diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index e752a39..81923ff 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -180,7 +180,7 @@ template setReadError(t, e: untyped) = (t).error = getTransportOsError(e) template checkPending(t: untyped) = - if not isNil((t).reader): + if not(isNil((t).reader)): raise newException(TransportError, "Read operation already pending!") template shiftBuffer(t, c: untyped) = @@ -282,7 +282,7 @@ proc clean(server: StreamServer) {.inline.} = if not(server.loopFuture.finished()): untrackServer(server) server.loopFuture.complete() - if not isNil(server.udata) and GCUserData in server.flags: + if not(isNil(server.udata)) and (GCUserData in server.flags): GC_unref(cast[ref int](server.udata)) GC_unref(server) @@ -634,7 +634,7 @@ when defined(windows): proc newStreamSocketTransport(sock: AsyncFD, bufsize: int, child: StreamTransport): StreamTransport = var transp: StreamTransport - if not isNil(child): + if not(isNil(child)): transp = child else: transp = StreamTransport(kind: TransportKind.Socket) @@ -654,7 +654,7 @@ when defined(windows): child: StreamTransport, flags: set[TransportFlags] = {}): StreamTransport = var transp: StreamTransport - if not isNil(child): + if not(isNil(child)): transp = child else: transp = StreamTransport(kind: TransportKind.Pipe) @@ -718,7 +718,7 @@ when defined(windows): retFuture.fail(getTransportOsError(osLastError())) return retFuture - if not bindToDomain(sock, raddress.getDomain()): + if not(bindToDomain(sock, raddress.getDomain())): let err = wsaGetLastError() sock.closeSocket() retFuture.fail(getTransportOsError(err)) @@ -751,12 +751,12 @@ when defined(windows): povl = RefCustomOverlapped() GC_ref(povl) povl.data = CompletionData(fd: sock, cb: socketContinuation) - var res = loop.connectEx(SocketHandle(sock), + let res = loop.connectEx(SocketHandle(sock), cast[ptr SockAddr](addr saddr), DWORD(slen), nil, 0, nil, cast[POVERLAPPED](povl)) # We will not process immediate completion, to avoid undefined behavior. - if not res: + if not(res): let err = osLastError() if int32(err) != ERROR_IO_PENDING: GC_unref(povl) @@ -839,7 +839,7 @@ when defined(windows): var flags = {WinServerPipe} if NoPipeFlash in server.flags: flags.incl(WinNoPipeFlash) - if not isNil(server.init): + if not(isNil(server.init)): var transp = server.init(server, server.sock) ntransp = newStreamPipeTransport(server.sock, server.bufferSize, transp, flags) @@ -928,7 +928,7 @@ when defined(windows): raiseAssert osErrorMsg(err) else: var ntransp: StreamTransport - if not isNil(server.init): + if not(isNil(server.init)): let transp = server.init(server, server.asock) ntransp = newStreamSocketTransport(server.asock, server.bufferSize, @@ -982,7 +982,7 @@ when defined(windows): dwReceiveDataLength, dwLocalAddressLength, dwRemoteAddressLength, addr dwBytesReceived, cast[POVERLAPPED](addr server.aovl)) - if not res: + if not(res): let err = osLastError() if int32(err) == ERROR_OPERATION_ABORTED: server.apending = false @@ -1014,7 +1014,7 @@ when defined(windows): discard cancelIO(Handle(server.sock)) proc resumeAccept(server: StreamServer) {.inline.} = - if not server.apending: + if not(server.apending): server.aovl.data.cb(addr server.aovl) proc accept*(server: StreamServer): Future[StreamTransport] = @@ -1052,7 +1052,7 @@ when defined(windows): retFuture.fail(getTransportOsError(err)) else: var ntransp: StreamTransport - if not isNil(server.init): + if not(isNil(server.init)): let transp = server.init(server, server.asock) ntransp = newStreamSocketTransport(server.asock, server.bufferSize, @@ -1073,6 +1073,8 @@ when defined(windows): retFuture.fail(getTransportOsError(ovl.data.errCode)) proc cancellationSocket(udata: pointer) {.gcsafe.} = + if server.apending: + server.apending = false server.asock.closeSocket() proc continuationPipe(udata: pointer) {.gcsafe.} = @@ -1091,7 +1093,7 @@ when defined(windows): var flags = {WinServerPipe} if NoPipeFlash in server.flags: flags.incl(WinNoPipeFlash) - if not isNil(server.init): + if not(isNil(server.init)): var transp = server.init(server, server.sock) ntransp = newStreamPipeTransport(server.sock, server.bufferSize, transp, flags) @@ -1126,6 +1128,8 @@ when defined(windows): retFuture.fail(getTransportOsError(ovl.data.errCode)) proc cancellationPipe(udata: pointer) {.gcsafe.} = + if server.apending: + server.apending = false server.sock.closeHandle() if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}: @@ -1160,7 +1164,7 @@ when defined(windows): dwReceiveDataLength, dwLocalAddressLength, dwRemoteAddressLength, addr dwBytesReceived, cast[POVERLAPPED](addr server.aovl)) - if not res: + if not(res): let err = osLastError() if int32(err) == ERROR_OPERATION_ABORTED: server.apending = false @@ -1494,7 +1498,7 @@ else: proc newStreamSocketTransport(sock: AsyncFD, bufsize: int, child: StreamTransport): StreamTransport = var transp: StreamTransport - if not isNil(child): + if not(isNil(child)): transp = child else: transp = StreamTransport(kind: TransportKind.Socket) @@ -1510,7 +1514,7 @@ else: proc newStreamPipeTransport(fd: AsyncFD, bufsize: int, child: StreamTransport): StreamTransport = var transp: StreamTransport - if not isNil(child): + if not(isNil(child)): transp = child else: transp = StreamTransport(kind: TransportKind.Pipe) @@ -1569,7 +1573,7 @@ else: retFuture.fail(exc) return - if not fd.getSocketError(err): + if not(fd.getSocketError(err)): closeSocket(fd) retFuture.fail(getTransportOsError(osLastError())) return @@ -1636,7 +1640,7 @@ else: raiseAsDefect exc, "wrapAsyncSocket" if sock != asyncInvalidSocket: var ntransp: StreamTransport - if not isNil(server.init): + if not(isNil(server.init)): let transp = server.init(server, sock) ntransp = newStreamSocketTransport(sock, server.bufferSize, transp) else: @@ -1719,7 +1723,7 @@ else: if sock != asyncInvalidSocket: var ntransp: StreamTransport - if not isNil(server.init): + if not(isNil(server.init)): let transp = server.init(server, sock) ntransp = newStreamSocketTransport(sock, server.bufferSize, transp) @@ -1773,7 +1777,8 @@ else: proc start*(server: StreamServer) {. raises: [Defect, IOSelectorsException, ValueError].} = ## Starts ``server``. - doAssert(not(isNil(server.function))) + doAssert(not(isNil(server.function)), + "You should not start server, if you have not set processing callback!") if server.status == ServerStatus.Starting: server.resumeAccept() server.status = ServerStatus.Running @@ -1782,7 +1787,8 @@ proc stop*(server: StreamServer) {. raises: [Defect, IOSelectorsException, ValueError].} = ## Stops ``server``. if server.status == ServerStatus.Running: - server.pauseAccept() + if not(isNil(server.function)): + server.pauseAccept() server.status = ServerStatus.Stopped elif server.status == ServerStatus.Starting: server.status = ServerStatus.Stopped @@ -1814,27 +1820,19 @@ proc close*(server: StreamServer) = if not(server.loopFuture.finished()): server.clean() - let r1 = (server.status == ServerStatus.Stopped) and - not(isNil(server.function)) - let r2 = (server.status == ServerStatus.Starting) and isNil(server.function) - - if r1 or r2: + if server.status in {ServerStatus.Starting, ServerStatus.Stopped}: server.status = ServerStatus.Closed when defined(windows): if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}: - if not server.apending: - server.sock.closeSocket(continuation) - else: + if server.apending: server.asock.closeSocket() - server.sock.closeSocket() + server.apending = false + server.sock.closeSocket(continuation) elif server.local.family in {AddressFamily.Unix}: if NoPipeFlash notin server.flags: discard flushFileBuffers(Handle(server.sock)) discard disconnectNamedPipe(Handle(server.sock)) - if not server.apending: - server.sock.closeHandle(continuation) - else: - server.sock.closeHandle() + server.sock.closeHandle(continuation) else: server.sock.closeSocket(continuation) @@ -1883,21 +1881,21 @@ proc createStreamServer*(host: TransportAddress, if serverSocket == asyncInvalidSocket: raiseTransportOsError(osLastError()) else: - if not setSocketBlocking(SocketHandle(sock), false): + if not(setSocketBlocking(SocketHandle(sock), false)): raiseTransportOsError(osLastError()) register(sock) serverSocket = sock # SO_REUSEADDR is not useful for Unix domain sockets. if ServerFlags.ReuseAddr in flags: - if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): + if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1)): let err = osLastError() if sock == asyncInvalidSocket: serverSocket.closeSocket() raiseTransportOsError(err) # TCP flags are not useful for Unix domain sockets. if ServerFlags.TcpNoDelay in flags: - if not setSockOpt(serverSocket, handles.IPPROTO_TCP, - handles.TCP_NODELAY, 1): + if not(setSockOpt(serverSocket, handles.IPPROTO_TCP, + handles.TCP_NODELAY, 1)): let err = osLastError() if sock == asyncInvalidSocket: serverSocket.closeSocket() @@ -1940,7 +1938,7 @@ proc createStreamServer*(host: TransportAddress, if serverSocket == asyncInvalidSocket: raiseTransportOsError(osLastError()) else: - if not setSocketBlocking(SocketHandle(sock), false): + if not(setSocketBlocking(SocketHandle(sock), false)): raiseTransportOsError(osLastError()) register(sock) @@ -1949,21 +1947,21 @@ proc createStreamServer*(host: TransportAddress, if host.family in {AddressFamily.IPv4, AddressFamily.IPv6}: # SO_REUSEADDR and SO_REUSEPORT are not useful for Unix domain sockets. if ServerFlags.ReuseAddr in flags: - if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): + if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1)): let err = osLastError() if sock == asyncInvalidSocket: serverSocket.closeSocket() raiseTransportOsError(err) if ServerFlags.ReusePort in flags: - if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEPORT, 1): + if not(setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEPORT, 1)): let err = osLastError() if sock == asyncInvalidSocket: serverSocket.closeSocket() raiseTransportOsError(err) # TCP flags are not useful for Unix domain sockets. if ServerFlags.TcpNoDelay in flags: - if not setSockOpt(serverSocket, handles.IPPROTO_TCP, - handles.TCP_NODELAY, 1): + if not(setSockOpt(serverSocket, handles.IPPROTO_TCP, + handles.TCP_NODELAY, 1)): let err = osLastError() if sock == asyncInvalidSocket: serverSocket.closeSocket() @@ -1997,7 +1995,7 @@ proc createStreamServer*(host: TransportAddress, serverSocket.closeSocket() raiseTransportOsError(err) - if not isNil(child): + if not(isNil(child)): result = child else: result = StreamServer() @@ -2099,7 +2097,7 @@ proc write*(transp: StreamTransport, msg: string, msglen = -1): Future[int] = var retFuture = newFutureStr[int]("stream.transport.write(string)") transp.checkClosed(retFuture) transp.checkWriteEof(retFuture) - if not isLiteral(msg): + if not(isLiteral(msg)): shallowCopy(retFuture.gcholder, msg) else: retFuture.gcholder = msg @@ -2117,7 +2115,7 @@ proc write*[T](transp: StreamTransport, msg: seq[T], msglen = -1): Future[int] = var retFuture = newFutureSeq[int, T]("stream.transport.write(seq)") transp.checkClosed(retFuture) transp.checkWriteEof(retFuture) - if not isLiteral(msg): + if not(isLiteral(msg)): shallowCopy(retFuture.gcholder, msg) else: retFuture.gcholder = msg diff --git a/tests/teststream.nim b/tests/teststream.nim index 00ffb37..607362d 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -1201,6 +1201,44 @@ suite "Stream Transport test suite": await acceptFut return res + proc testAcceptRace(address: TransportAddress): Future[bool] {.async.} = + proc test1(address: TransportAddress) {.async.} = + let server = createStreamServer(address, flags = {ReuseAddr}) + let acceptFut = server.accept() + server.close() + await allFutures(acceptFut.cancelAndWait(), server.join()) + + proc test2(address: TransportAddress) {.async.} = + let server = createStreamServer(address, flags = {ReuseAddr}) + let acceptFut = server.accept() + await acceptFut.cancelAndWait() + server.close() + await server.join() + + proc test3(address: TransportAddress) {.async.} = + let server = createStreamServer(address, flags = {ReuseAddr}) + let acceptFut = server.accept() + server.stop() + server.close() + await allFutures(acceptFut.cancelAndWait(), server.join()) + + proc test4(address: TransportAddress) {.async.} = + let server = createStreamServer(address, flags = {ReuseAddr}) + let acceptFut = server.accept() + await acceptFut.cancelAndWait() + server.stop() + server.close() + await server.join() + + try: + await test1(address).wait(5.seconds) + await test2(address).wait(5.seconds) + await test3(address).wait(5.seconds) + await test4(address).wait(5.seconds) + return true + except AsyncTimeoutError: + return false + markFD = getCurrentFD() for i in 0..