diff --git a/asyncdispatch2.nimble b/asyncdispatch2.nimble index 46d95440..edaf5cb5 100644 --- a/asyncdispatch2.nimble +++ b/asyncdispatch2.nimble @@ -1,5 +1,5 @@ packageName = "asyncdispatch2" -version = "2.0.8" +version = "2.0.9" author = "Status Research & Development GmbH" description = "Asyncdispatch2" license = "Apache License 2.0 or MIT" diff --git a/asyncdispatch2/asyncloop.nim b/asyncdispatch2/asyncloop.nim index 4838143d..fe35be7d 100644 --- a/asyncdispatch2/asyncloop.nim +++ b/asyncdispatch2/asyncloop.nim @@ -287,12 +287,14 @@ when defined(windows) or defined(nimdoc): var gDisp{.threadvar.}: PDispatcher ## Global dispatcher proc setGlobalDispatcher*(disp: PDispatcher) = + ## Set current thread's dispatcher instance to ``disp``. if not gDisp.isNil: assert gDisp.callbacks.len == 0 gDisp = disp initCallSoonProc() proc getGlobalDispatcher*(): PDispatcher = + ## Returns current thread's dispatcher instance. if gDisp.isNil: setGlobalDispatcher(newDispatcher()) result = gDisp @@ -303,14 +305,15 @@ when defined(windows) or defined(nimdoc): return disp.ioPort proc register*(fd: AsyncFD) = - ## Registers ``fd`` with the dispatcher. - let p = getGlobalDispatcher() - if createIoCompletionPort(fd.Handle, p.ioPort, + ## Register file descriptor ``fd`` in thread's dispatcher. + let loop = getGlobalDispatcher() + if createIoCompletionPort(fd.Handle, loop.ioPort, cast[CompletionKey](fd), 1) == 0: raiseOSError(osLastError()) - p.handles.incl(fd) + loop.handles.incl(fd) proc poll*() = + ## Perform single asynchronous step. let loop = getGlobalDispatcher() var curTime = fastEpochTime() var curTimeout = DWORD(0) @@ -397,16 +400,21 @@ when defined(windows) or defined(nimdoc): loop.transmitFile = cast[WSAPROC_TRANSMITFILE](funcPointer) close(sock) - proc closeSocket*(socket: AsyncFD) = + proc closeSocket*(socket: AsyncFD, aftercb: CallbackFunc = nil) = ## Closes a socket and ensures that it is unregistered. + let loop = getGlobalDispatcher() socket.SocketHandle.close() - getGlobalDispatcher().handles.excl(socket) + loop.handles.excl(socket) + if not isNil(aftercb): + var acb = AsyncCallback(function: aftercb) + loop.callbacks.addLast(acb) proc unregister*(fd: AsyncFD) = ## Unregisters ``fd``. getGlobalDispatcher().handles.excl(fd) proc contains*(disp: PDispatcher, fd: AsyncFD): bool = + ## Returns ``true`` if ``fd`` is registered in thread's dispatcher. return fd in disp.handles else: @@ -435,6 +443,7 @@ else: proc `==`*(x, y: AsyncFD): bool {.borrow.} proc newDispatcher*(): PDispatcher = + ## Create new dispatcher. new result result.selector = newSelector[SelectorData]() result.timers.newHeapQueue() @@ -444,40 +453,44 @@ else: var gDisp{.threadvar.}: PDispatcher ## Global dispatcher proc setGlobalDispatcher*(disp: PDispatcher) = + ## Set current thread's dispatcher instance to ``disp``. if not gDisp.isNil: assert gDisp.callbacks.len == 0 gDisp = disp initCallSoonProc() proc getGlobalDispatcher*(): PDispatcher = + ## Returns current thread's dispatcher instance. if gDisp.isNil: setGlobalDispatcher(newDispatcher()) result = gDisp proc getIoHandler*(disp: PDispatcher): Selector[SelectorData] = + ## Returns system specific OS queue. return disp.selector proc register*(fd: AsyncFD) = - ## Register file descriptor ``fd`` in selector. + ## Register file descriptor ``fd`` in thread's dispatcher. + let loop = getGlobalDispatcher() var data: SelectorData data.rdata.fd = fd data.wdata.fd = fd - let loop = getGlobalDispatcher() loop.selector.registerHandle(int(fd), {}, data) proc unregister*(fd: AsyncFD) = - ## Unregister file descriptor ``fd`` from selector. + ## Unregister file descriptor ``fd`` from thread's dispatcher. getGlobalDispatcher().selector.unregister(int(fd)) proc contains*(disp: PDispatcher, fd: AsyncFd): bool {.inline.} = + ## Returns ``true`` if ``fd`` is registered in thread's dispatcher. result = int(fd) in disp.selector proc addReader*(fd: AsyncFD, cb: CallbackFunc, udata: pointer = nil) = ## Start watching the file descriptor ``fd`` for read availability and then ## call the callback ``cb`` with specified argument ``udata``. - let p = getGlobalDispatcher() + let loop = getGlobalDispatcher() var newEvents = {Event.Read} - withData(p.selector, int(fd), adata) do: + withData(loop.selector, int(fd), adata) do: let acb = AsyncCallback(function: cb, udata: addr adata.rdata) adata.reader = acb adata.rdata = CompletionData(fd: fd, udata: udata) @@ -485,27 +498,27 @@ else: if not isNil(adata.writer.function): newEvents.incl(Event.Write) do: raise newException(ValueError, "File descriptor not registered.") - p.selector.updateHandle(int(fd), newEvents) + loop.selector.updateHandle(int(fd), newEvents) proc removeReader*(fd: AsyncFD) = ## Stop watching the file descriptor ``fd`` for read availability. - let p = getGlobalDispatcher() + let loop = getGlobalDispatcher() var newEvents: set[Event] - withData(p.selector, int(fd), adata) do: + withData(loop.selector, int(fd), adata) do: # We need to clear `reader` data, because `selectors` don't do it - adata.reader = AsyncCallback() - adata.rdata = CompletionData() + adata.reader.function = nil + # adata.rdata = CompletionData() if not isNil(adata.writer.function): newEvents.incl(Event.Write) do: raise newException(ValueError, "File descriptor not registered.") - p.selector.updateHandle(int(fd), newEvents) + loop.selector.updateHandle(int(fd), newEvents) proc addWriter*(fd: AsyncFD, cb: CallbackFunc, udata: pointer = nil) = ## Start watching the file descriptor ``fd`` for write availability and then ## call the callback ``cb`` with specified argument ``udata``. - let p = getGlobalDispatcher() + let loop = getGlobalDispatcher() var newEvents = {Event.Write} - withData(p.selector, int(fd), adata) do: + withData(loop.selector, int(fd), adata) do: let acb = AsyncCallback(function: cb, udata: addr adata.wdata) adata.writer = acb adata.wdata = CompletionData(fd: fd, udata: udata) @@ -513,20 +526,44 @@ else: if not isNil(adata.reader.function): newEvents.incl(Event.Read) do: raise newException(ValueError, "File descriptor not registered.") - p.selector.updateHandle(int(fd), newEvents) + loop.selector.updateHandle(int(fd), newEvents) proc removeWriter*(fd: AsyncFD) = ## Stop watching the file descriptor ``fd`` for write availability. - let p = getGlobalDispatcher() + let loop = getGlobalDispatcher() var newEvents: set[Event] - withData(p.selector, int(fd), adata) do: + withData(loop.selector, int(fd), adata) do: # We need to clear `writer` data, because `selectors` don't do it - adata.writer = AsyncCallback() - adata.wdata = CompletionData() + adata.writer.function = nil + # adata.wdata = CompletionData() if not isNil(adata.reader.function): newEvents.incl(Event.Read) do: raise newException(ValueError, "File descriptor not registered.") - p.selector.updateHandle(int(fd), newEvents) + loop.selector.updateHandle(int(fd), newEvents) + + proc closeSocket*(fd: AsyncFD, aftercb: CallbackFunc = nil) = + ## Close asynchronous socket. + ## + ## Please note, that socket is not closed immediately. To avoid bugs with + ## closing socket, while operation pending, socket will be closed as + ## soon as all pending operations will be notified. + ## You can execute ``aftercb`` before actual socket close operation. + let loop = getGlobalDispatcher() + + proc continuation(udata: pointer) = + aftercb(nil) + unregister(fd) + close(SocketHandle(fd)) + + withData(loop.selector, int(fd), adata) do: + if not isNil(adata.reader.function): + loop.callbacks.addLast(adata.reader) + if not isNil(adata.writer.function): + loop.callbacks.addLast(adata.writer) + + if not isNil(aftercb): + var acb = AsyncCallback(function: continuation) + loop.callbacks.addLast(acb) when ioselSupportedPlatform: proc addSignal*(signal: int, cb: CallbackFunc, @@ -535,10 +572,10 @@ else: ## callback ``cb`` with specified argument ``udata``. Returns signal ## identifier code, which can be used to remove signal callback ## via ``removeSignal``. - let p = getGlobalDispatcher() + let loop = getGlobalDispatcher() var data: SelectorData - result = p.selector.registerSignal(signal, data) - withData(p.selector, result, adata) do: + result = loop.selector.registerSignal(signal, data) + withData(loop.selector, result, adata) do: adata.reader = AsyncCallback(function: cb, udata: addr adata.rdata) adata.rdata.fd = AsyncFD(result) adata.rdata.udata = udata @@ -547,8 +584,8 @@ else: proc removeSignal*(sigfd: int) = ## Remove watching signal ``signal``. - let p = getGlobalDispatcher() - p.selector.unregister(sigfd) + let loop = getGlobalDispatcher() + loop.selector.unregister(sigfd) proc poll*() = ## Perform single asynchronous step. @@ -569,21 +606,18 @@ else: let fd = loop.keys[i].fd let events = loop.keys[i].events - if Event.Read in events or events == {Event.Error}: - withData(loop.selector, fd, adata) do: + withData(loop.selector, fd, adata) do: + if Event.Read in events or events == {Event.Error}: loop.callbacks.addLast(adata.reader) - if Event.Write in events or events == {Event.Error}: - withData(loop.selector, fd, adata) do: + if Event.Write in events or events == {Event.Error}: loop.callbacks.addLast(adata.writer) - if Event.User in events: - withData(loop.selector, fd, adata) do: + if Event.User in events: loop.callbacks.addLast(adata.reader) - when ioselSupportedPlatform: - if customSet * events != {}: - withData(loop.selector, fd, adata) do: + when ioselSupportedPlatform: + if customSet * events != {}: loop.callbacks.addLast(adata.reader) # Moving expired timers to `loop.callbacks`. @@ -618,10 +652,6 @@ proc removeTimer*(at: uint64, cb: CallbackFunc, udata: pointer = nil) = if index != -1: loop.timers.del(index) -# proc completeProxy*[T](data: pointer) = -# var future = cast[Future[T]](data) -# future.complete() - proc sleepAsync*(ms: int): Future[void] = ## Suspends the execution of the current async procedure for the next ## ``ms`` milliseconds. @@ -656,7 +686,7 @@ proc withTimeout*[T](fut: Future[T], timeout: int): Future[bool] = proc wait*[T](fut: Future[T], timeout = -1): Future[T] = ## Returns a future which will complete once future ``fut`` completes ## or if timeout of ``timeout`` milliseconds has been expired. - ## + ## ## If ``timeout`` is ``-1``, then statement ``await wait(fut)`` is ## equal to ``await fut``. var retFuture = newFuture[T]("asyncdispatch.wait") diff --git a/asyncdispatch2/handles.nim b/asyncdispatch2/handles.nim index d90beb5d..19f0649b 100644 --- a/asyncdispatch2/handles.nim +++ b/asyncdispatch2/handles.nim @@ -90,8 +90,3 @@ proc wrapAsyncSocket*(sock: SocketHandle): AsyncFD = return asyncInvalidSocket result = AsyncFD(sock) register(result) - -proc closeAsyncSocket*(s: AsyncFD) {.inline.} = - ## Closes asynchronous socket handle ``s``. - unregister(s) - close(SocketHandle(s)) diff --git a/asyncdispatch2/transports/datagram.nim b/asyncdispatch2/transports/datagram.nim index be0c0fb1..cc7c0947 100644 --- a/asyncdispatch2/transports/datagram.nim +++ b/asyncdispatch2/transports/datagram.nim @@ -228,7 +228,7 @@ when defined(windows): if not setSockOpt(localSock, SOL_SOCKET, SO_REUSEADDR, 1): let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(localSock) + closeSocket(localSock) raiseTransportOsError(err) ## Fix for Q263823. @@ -247,7 +247,7 @@ when defined(windows): slen) != 0: let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(localSock) + closeSocket(localSock) raiseTransportOsError(err) result.local = local else: @@ -263,7 +263,7 @@ when defined(windows): slen) != 0: let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(localSock) + closeSocket(localSock) raiseTransportOsError(err) if remote.port != Port(0): @@ -274,7 +274,7 @@ when defined(windows): slen) != 0: let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(localSock) + closeSocket(localSock) raiseTransportOsError(err) result.remote = remote @@ -297,30 +297,34 @@ when defined(windows): else: result.state.incl(ReadPaused) - proc close*(transp: DatagramTransport) = - ## Closes and frees resources of transport ``transp``. - if ReadClosed notin transp.state and WriteClosed notin transp.state: - # discard cancelIo(Handle(transp.fd)) - closeAsyncSocket(transp.fd) - transp.state.incl(WriteClosed) - transp.state.incl(ReadClosed) - transp.future.complete() - if not isNil(transp.udata) and GCUserData in transp.flags: - GC_unref(cast[ref int](transp.udata)) - GC_unref(transp) + # proc close*(transp: DatagramTransport) = + # ## Closes and frees resources of transport ``transp``. + # if ReadClosed notin transp.state and WriteClosed notin transp.state: + # # discard cancelIo(Handle(transp.fd)) + # closeSocket(transp.fd) + # transp.state.incl(WriteClosed) + # transp.state.incl(ReadClosed) + # transp.future.complete() + # if not isNil(transp.udata) and GCUserData in transp.flags: + # GC_unref(cast[ref int](transp.udata)) + # GC_unref(transp) else: # Linux/BSD/MacOS part proc readDatagramLoop(udata: pointer) = var raddr: TransportAddress + doAssert(not isNil(udata)) var cdata = cast[ptr CompletionData](udata) - if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)): - # Transport was closed earlier, exiting - return var transp = cast[DatagramTransport](cdata.udata) let fd = SocketHandle(cdata.fd) - if not isNil(transp): + if int(fd) == 0: + ## This situation can be happen, when there events present + ## after transport was closed. + return + if ReadClosed in transp.state: + transp.state.incl({ReadPaused}) + else: while true: transp.ralen = SockLen(sizeof(Sockaddr_storage)) var res = posix.recvfrom(fd, addr transp.buffer[0], @@ -343,13 +347,17 @@ else: proc writeDatagramLoop(udata: pointer) = var res: int + doAssert(not isNil(udata)) var cdata = cast[ptr CompletionData](udata) - if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)): - # Transport was closed earlier, exiting - return var transp = cast[DatagramTransport](cdata.udata) let fd = SocketHandle(cdata.fd) - if not isNil(transp): + if int(fd) == 0: + ## This situation can be happen, when there events present + ## after transport was closed. + return + if WriteClosed in transp.state: + transp.state.incl({WritePaused}) + else: if len(transp.queue) > 0: var vector = transp.queue.popFirst() while true: @@ -420,7 +428,7 @@ else: if not setSockOpt(localSock, SOL_SOCKET, SO_REUSEADDR, 1): let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(localSock) + closeSocket(localSock) raiseTransportOsError(err) if local.port != Port(0): @@ -431,7 +439,7 @@ else: slen) != 0: let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(localSock) + closeSocket(localSock) raiseTransportOsError(err) result.local = local @@ -443,7 +451,7 @@ else: slen) != 0: let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(localSock) + closeSocket(localSock) raiseTransportOsError(err) result.remote = remote @@ -461,13 +469,22 @@ else: else: result.state.incl(ReadPaused) - proc close*(transp: DatagramTransport) = - ## Closes and frees resources of transport ``transp``. +proc close*(transp: DatagramTransport) = + ## Closes and frees resources of transport ``transp``. + when defined(windows): if {ReadClosed, WriteClosed} * transp.state == {}: - closeAsyncSocket(transp.fd) + discard cancelIo(Handle(transp.fd)) + closeSocket(transp.fd) transp.state.incl({WriteClosed, ReadClosed}) transp.future.complete() GC_unref(transp) + else: + proc continuation(udata: pointer) = + transp.future.complete() + GC_unref(transp) + if {ReadClosed, WriteClosed} * transp.state == {}: + transp.state.incl({WriteClosed, ReadClosed}) + closeSocket(transp.fd, continuation) proc newDatagramTransport*(cbproc: DatagramCallback, remote: TransportAddress = AnyAddress, @@ -543,10 +560,19 @@ proc newDatagramTransport6*[T](cbproc: DatagramCallback, fflags, cast[pointer](udata), child, bufSize) -proc join*(transp: DatagramTransport) {.async.} = +proc join*(transp: DatagramTransport): Future[void] = ## Wait until the transport ``transp`` will be closed. + var retFuture = newFuture[void]("datagramtransport.join") + proc continuation(udata: pointer) = retFuture.complete() if not transp.future.finished: - await transp.future + transp.future.addCallback(continuation) + else: + retFuture.complete() + return retFuture + +proc closeWait*(transp: DatagramTransport): Future[void] = + ## Close transport ``transp`` and release all resources. + result = transp.join() proc send*(transp: DatagramTransport, pbytes: pointer, nbytes: int): Future[void] = diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim index 24f4d8c4..f4cc9d2b 100644 --- a/asyncdispatch2/transports/stream.nim +++ b/asyncdispatch2/transports/stream.nim @@ -291,6 +291,11 @@ when defined(windows): ## Continuation transp.state.excl(ReadPending) if ReadClosed in transp.state: + transp.state.incl({ReadPaused}) + if not isNil(transp.reader): + if not transp.reader.finished: + transp.reader.complete() + transp.reader = nil break let err = transp.rovl.data.errCode if err == OSErrorCode(-1): @@ -353,6 +358,11 @@ when defined(windows): if not isNil(transp.reader): transp.reader.complete() transp.reader = nil + else: + transp.state.incl(ReadPaused) + if not isNil(transp.reader): + transp.reader.complete() + transp.reader = nil ## Finish Loop break @@ -411,7 +421,7 @@ when defined(windows): result.fail(newException(TransportOsError, osErrorMsg(osLastError()))) if not bindToDomain(sock, address.address.getDomain()): - sock.closeAsyncSocket() + sock.closeSocket() result.fail(newException(TransportOsError, osErrorMsg(osLastError()))) proc continuation(udata: pointer) = @@ -421,7 +431,7 @@ when defined(windows): if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), cint(SO_UPDATE_CONNECT_CONTEXT), nil, SockLen(0)) != 0'i32: - sock.closeAsyncSocket() + sock.closeSocket() retFuture.fail(newException(TransportOsError, osErrorMsg(osLastError()))) else: @@ -429,7 +439,7 @@ when defined(windows): bufferSize, child)) else: - sock.closeAsyncSocket() + sock.closeSocket() retFuture.fail(newException(TransportOsError, osErrorMsg(ovl.data.errCode))) GC_unref(ovl) @@ -446,7 +456,7 @@ when defined(windows): let err = osLastError() if int32(err) != ERROR_IO_PENDING: GC_unref(povl) - sock.closeAsyncSocket() + sock.closeSocket() retFuture.fail(newException(TransportOsError, osErrorMsg(err))) return retFuture @@ -460,14 +470,14 @@ when defined(windows): ## Continuation server.apending = false if server.status == ServerStatus.Stopped: - server.asock.closeAsyncSocket() + server.asock.closeSocket() else: if ovl.data.errCode == OSErrorCode(-1): if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET), cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock, SockLen(sizeof(SocketHandle))) != 0'i32: - server.asock.closeAsyncSocket() + server.asock.closeSocket() raiseTransportOsError(osLastError()) else: if not isNil(server.init): @@ -482,10 +492,10 @@ when defined(windows): asyncCheck server.function(server, ntransp) elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt - server.asock.closeAsyncSocket() + server.asock.closeSocket() break else: - server.asock.closeAsyncSocket() + server.asock.closeSocket() raiseTransportOsError(osLastError()) else: ## Initiation @@ -553,11 +563,14 @@ else: proc writeStreamLoop(udata: pointer) {.gcsafe.} = var cdata = cast[ptr CompletionData](udata) - if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)): - # Transport was closed earlier, exiting - return var transp = cast[StreamTransport](cdata.udata) let fd = SocketHandle(cdata.fd) + + if int(fd) == 0: + ## This situation can be happen, when there events present + ## after transport was closed. + return + if len(transp.queue) > 0: var vector = transp.queue.popFirst() while true: @@ -601,36 +614,46 @@ else: proc readStreamLoop(udata: pointer) {.gcsafe.} = var cdata = cast[ptr CompletionData](udata) - if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)): - # Transport was closed earlier, exiting - return var transp = cast[StreamTransport](cdata.udata) let fd = SocketHandle(cdata.fd) - while true: - var res = posix.recv(fd, addr transp.buffer[transp.offset], - len(transp.buffer) - transp.offset, cint(0)) - if res < 0: - let err = osLastError() - if int(err) == EINTR: - continue - elif int(err) in {ECONNRESET}: + if int(fd) == 0: + ## This situation can be happen, when there events present + ## after transport was closed. + return + + if ReadClosed in transp.state: + transp.state.incl({ReadPaused}) + if not isNil(transp.reader): + if not transp.reader.finished: + transp.reader.complete() + transp.reader = nil + else: + while true: + var res = posix.recv(fd, addr transp.buffer[transp.offset], + len(transp.buffer) - transp.offset, cint(0)) + if res < 0: + let err = osLastError() + if int(err) == EINTR: + continue + elif int(err) in {ECONNRESET}: + transp.state.incl({ReadEof, ReadPaused}) + cdata.fd.removeReader() + else: + transp.state.incl(ReadPaused) + transp.setReadError(err) + cdata.fd.removeReader() + elif res == 0: transp.state.incl({ReadEof, ReadPaused}) cdata.fd.removeReader() else: - transp.setReadError(err) - cdata.fd.removeReader() - elif res == 0: - transp.state.incl({ReadEof, ReadPaused}) - cdata.fd.removeReader() - else: - transp.offset += res - if transp.offset == len(transp.buffer): - transp.state.incl(ReadPaused) - cdata.fd.removeReader() - if not isNil(transp.reader): - transp.reader.complete() - transp.reader = nil - break + transp.offset += res + if transp.offset == len(transp.buffer): + transp.state.incl(ReadPaused) + cdata.fd.removeReader() + if not isNil(transp.reader): + transp.reader.complete() + transp.reader = nil + break proc newStreamSocketTransport(sock: AsyncFD, bufsize: int, child: StreamTransport): StreamTransport = @@ -670,17 +693,17 @@ else: var data = cast[ptr CompletionData](udata) var err = 0 let fd = data.fd + fd.removeWriter() if not fd.getSocketError(err): - fd.closeAsyncSocket() + closeSocket(fd) retFuture.fail(newException(TransportOsError, osErrorMsg(osLastError()))) return if err != 0: - fd.closeAsyncSocket() + closeSocket(fd) retFuture.fail(newException(TransportOsError, osErrorMsg(OSErrorCode(err)))) return - fd.removeWriter() retFuture.complete(newStreamSocketTransport(fd, bufferSize, child)) while true: @@ -697,7 +720,7 @@ else: sock.addWriter(continuation) break else: - sock.closeAsyncSocket() + sock.closeSocket() retFuture.fail(newException(TransportOsError, osErrorMsg(err))) break return retFuture @@ -757,20 +780,33 @@ proc stop*(server: StreamServer) = elif server.status == ServerStatus.Starting: server.status = ServerStatus.Stopped -proc join*(server: StreamServer) {.async.} = +proc join*(server: StreamServer): Future[void] = ## Waits until ``server`` is not closed. + var retFuture = newFuture[void]("streamserver.join") + proc continuation(udata: pointer) = retFuture.complete() if not server.loopFuture.finished: - await server.loopFuture + server.loopFuture.addCallback(continuation) + else: + retFuture.complete() + return retFuture proc close*(server: StreamServer) = ## Release ``server`` resources. - if server.status == ServerStatus.Stopped: - closeAsyncSocket(server.sock) - server.status = ServerStatus.Closed + ## + ## Please note that release of resources is not completed immediately, to be + ## sure all resources got released please use ``await server.join()``. + proc continuation(udata: pointer) = server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) GC_unref(server) + if server.status == ServerStatus.Stopped: + server.status = ServerStatus.Closed + server.sock.closeSocket(continuation) + +proc closeWait*(server: StreamServer): Future[void] = + ## Close server ``server`` and release all resources. + result = server.join() proc createStreamServer*(host: TransportAddress, cbproc: StreamCallback, @@ -814,7 +850,7 @@ proc createStreamServer*(host: TransportAddress, if not setSockOpt(serverSocket, SOL_SOCKET, SO_REUSEADDR, 1): let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(serverSocket) + serverSocket.closeSocket() raiseTransportOsError(err) toSockAddr(host.address, host.port, saddr, slen) @@ -822,13 +858,13 @@ proc createStreamServer*(host: TransportAddress, slen) != 0: let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(serverSocket) + serverSocket.closeSocket() raiseTransportOsError(err) if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: let err = osLastError() if sock == asyncInvalidSocket: - closeAsyncSocket(serverSocket) + serverSocket.closeSocket() raiseTransportOsError(err) if not isNil(child): @@ -1193,21 +1229,35 @@ proc consume*(transp: StreamTransport, n = -1): Future[int] {.async.} = transp.resumeRead() await fut -proc join*(transp: StreamTransport) {.async.} = +proc join*(transp: StreamTransport): Future[void] = ## Wait until ``transp`` will not be closed. + var retFuture = newFuture[void]("streamtransport.join") + proc continuation(udata: pointer) = retFuture.complete() if not transp.future.finished: - await transp.future + transp.future.addCallback(continuation) + else: + retFuture.complete() + return retFuture proc close*(transp: StreamTransport) = ## Closes and frees resources of transport ``transp``. - if {ReadClosed, WriteClosed} * transp.state == {}: - when defined(windows): - discard cancelIo(Handle(transp.fd)) - closeAsyncSocket(transp.fd) - transp.state.incl({WriteClosed, ReadClosed}) + ## + ## Please note that release of resources is not completed immediately, to be + ## sure all resources got released please use ``await transp.join()``. + proc continuation(udata: pointer) = transp.future.complete() GC_unref(transp) + if {ReadClosed, WriteClosed} * transp.state == {}: + transp.state.incl({WriteClosed, ReadClosed}) + when defined(windows): + discard cancelIo(Handle(transp.fd)) + closeSocket(transp.fd, continuation) + +proc closeWait*(transp: StreamTransport): Future[void] = + ## Close and frees resources of transport ``transp``. + result = transp.join() + proc closed*(transp: StreamTransport): bool {.inline.} = ## Returns ``true`` if transport in closed state. result = ({ReadClosed, WriteClosed} * transp.state != {}) diff --git a/tests/testdatagram.nim b/tests/testdatagram.nim index 487718a7..d10c9202 100644 --- a/tests/testdatagram.nim +++ b/tests/testdatagram.nim @@ -139,7 +139,7 @@ proc client5(transp: DatagramTransport, if counterPtr[] == MessagesCount: transp.close() else: - var ta = initTAddress("127.0.0.1:33337") + var ta = initTAddress("127.0.0.1:33341") var req = "REQUEST" & $counterPtr[] await transp.sendTo(ta, addr req[0], len(req)) else: @@ -272,7 +272,7 @@ proc client10(transp: DatagramTransport, if counterPtr[] == TestsCount: transp.close() else: - var ta = initTAddress("127.0.0.1:33336") + var ta = initTAddress("127.0.0.1:33338") var req = "REQUEST" & $counterPtr[] var reqseq = newSeq[byte](len(req)) copyMem(addr reqseq[0], addr req[0], len(req)) @@ -370,7 +370,7 @@ proc testStringSend(): Future[int] {.async.} = proc testSeqSendTo(): Future[int] {.async.} = ## sendTo(string) test - var ta = initTAddress("127.0.0.1:33336") + var ta = initTAddress("127.0.0.1:33338") var counter = 0 var dgram1 = newDatagramTransport(client9, udata = addr counter, local = ta) var dgram2 = newDatagramTransport(client10, udata = addr counter) @@ -385,7 +385,7 @@ proc testSeqSendTo(): Future[int] {.async.} = proc testSeqSend(): Future[int] {.async.} = ## send(string) test - var ta = initTAddress("127.0.0.1:33337") + var ta = initTAddress("127.0.0.1:33339") var counter = 0 var dgram1 = newDatagramTransport(client9, udata = addr counter, local = ta) var dgram2 = newDatagramTransport(client11, udata = addr counter, remote = ta) @@ -412,7 +412,11 @@ proc waitAll(futs: seq[Future[void]]): Future[void] = return retFuture proc test3(bounded: bool): Future[int] {.async.} = - var ta = initTAddress("127.0.0.1:33337") + var ta: TransportAddress + if bounded: + ta = initTAddress("127.0.0.1:33340") + else: + ta = initTAddress("127.0.0.1:33341") var counter = 0 var dgram1 = newDatagramTransport(client1, udata = addr counter, local = ta) var clients = newSeq[Future[void]](ClientsCount) diff --git a/tests/testserver.nim b/tests/testserver.nim index 74cb1d1e..0e058ffb 100644 --- a/tests/testserver.nim +++ b/tests/testserver.nim @@ -35,6 +35,7 @@ proc serveCustomStreamClient(server: StreamServer, var answer = "ANSWER\r\n" discard await transp.write(answer) transp.close() + await transp.join() proc serveUdataStreamClient(server: StreamServer, transp: StreamTransport) {.async.} = @@ -43,6 +44,7 @@ proc serveUdataStreamClient(server: StreamServer, var msg = line & udata.test & "\r\n" discard await transp.write(msg) transp.close() + await transp.join() proc customServerTransport(server: StreamServer, fd: AsyncFD): StreamTransport = @@ -56,10 +58,12 @@ proc test1(): bool = server1.start() server1.stop() server1.close() + waitFor server1.join() var server2 = createStreamServer(ta, serveStreamClient, {ReuseAddr}) server2.start() server2.stop() server2.close() + waitFor server2.join() result = true proc client1(server: CustomServer, ta: TransportAddress) {.async.} = @@ -75,6 +79,7 @@ proc client1(server: CustomServer, ta: TransportAddress) {.async.} = transp.close() server.stop() server.close() + await server.join() proc client2(server: StreamServer, ta: TransportAddress): Future[bool] {.async.} = @@ -87,6 +92,7 @@ proc client2(server: StreamServer, transp.close() server.stop() server.close() + await server.join() proc test3(): bool = var server = CustomServer() diff --git a/tests/teststream.nim b/tests/teststream.nim index caa0ce52..8155ec3f 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -24,7 +24,7 @@ when sizeof(int) == 8: ClientsCount = 100 MessagesCount = 100 MessageSize = 20 - FilesCount = 50 + FilesCount = 100 elif sizeof(int) == 4: const BigMessageCount = 200 @@ -46,6 +46,7 @@ proc serveClient1(server: StreamServer, transp: StreamTransport) {.async.} = var res = await transp.write(cast[pointer](addr ans[0]), len(ans)) doAssert(res == len(ans)) transp.close() + await transp.join() proc serveClient2(server: StreamServer, transp: StreamTransport) {.async.} = var buffer: array[20, char] @@ -69,6 +70,7 @@ proc serveClient2(server: StreamServer, transp: StreamTransport) {.async.} = var res = await transp.write(cast[pointer](addr buffer[0]), MessageSize) doAssert(res == MessageSize) transp.close() + await transp.join() proc serveClient3(server: StreamServer, transp: StreamTransport) {.async.} = var buffer: array[20, char] @@ -95,6 +97,7 @@ proc serveClient3(server: StreamServer, transp: StreamTransport) {.async.} = doAssert(res == len(ans)) dec(counter) transp.close() + await transp.join() proc serveClient4(server: StreamServer, transp: StreamTransport) {.async.} = var pathname = await transp.readLine() @@ -110,6 +113,7 @@ proc serveClient4(server: StreamServer, transp: StreamTransport) {.async.} = var res = await transp.write(cast[pointer](addr answer[0]), len(answer)) doAssert(res == len(answer)) transp.close() + await transp.join() proc serveClient5(server: StreamServer, transp: StreamTransport) {.async.} = var data = await transp.read() @@ -124,6 +128,7 @@ proc serveClient5(server: StreamServer, transp: StreamTransport) {.async.} = if counter[] == 0: server.stop() server.close() + await server.join() proc serveClient6(server: StreamServer, transp: StreamTransport) {.async.} = var expect = ConstantMessage @@ -138,6 +143,7 @@ proc serveClient6(server: StreamServer, transp: StreamTransport) {.async.} = if counter[] == 0: server.stop() server.close() + await server.join() proc serveClient7(server: StreamServer, transp: StreamTransport) {.async.} = var answer = "DONE\r\n" @@ -150,6 +156,7 @@ proc serveClient7(server: StreamServer, transp: StreamTransport) {.async.} = var res = await transp.write(answer) doAssert(res == len(answer)) transp.close() + await transp.join() proc serveClient8(server: StreamServer, transp: StreamTransport) {.async.} = var answer = "DONE\r\n" @@ -171,6 +178,7 @@ proc serveClient8(server: StreamServer, transp: StreamTransport) {.async.} = transp.close() server.stop() server.close() + await server.join() proc swarmWorker1(address: TransportAddress): Future[int] {.async.} = var transp = await connect(address) @@ -185,6 +193,7 @@ proc swarmWorker1(address: TransportAddress): Future[int] {.async.} = doAssert(num == i) inc(result) transp.close() + await transp.join() proc swarmWorker2(address: TransportAddress): Future[int] {.async.} = var transp = await connect(address) @@ -208,6 +217,7 @@ proc swarmWorker2(address: TransportAddress): Future[int] {.async.} = doAssert(num == i) inc(result) transp.close() + await transp.join() proc swarmWorker3(address: TransportAddress): Future[int] {.async.} = var transp = await connect(address) @@ -235,6 +245,7 @@ proc swarmWorker3(address: TransportAddress): Future[int] {.async.} = doAssert(num == i) inc(result) transp.close() + await transp.join() proc swarmWorker4(address: TransportAddress): Future[int] {.async.} = var transp = await connect(address) @@ -249,9 +260,9 @@ proc swarmWorker4(address: TransportAddress): Future[int] {.async.} = handle = int(getFileHandle(fhandle)) doAssert(handle > 0) name = name & "\r\n" - ssize = $size & "\r\n" var res = await transp.write(cast[pointer](addr name[0]), len(name)) doAssert(res == len(name)) + ssize = $size & "\r\n" res = await transp.write(cast[pointer](addr ssize[0]), len(ssize)) doAssert(res == len(ssize)) var checksize = await transp.writeFile(handle, 0'u, size) @@ -261,6 +272,7 @@ proc swarmWorker4(address: TransportAddress): Future[int] {.async.} = doAssert(ans == "OK") result = 1 transp.close() + await transp.join() proc swarmWorker5(address: TransportAddress): Future[int] {.async.} = var transp = await connect(address) @@ -269,6 +281,7 @@ proc swarmWorker5(address: TransportAddress): Future[int] {.async.} = var res = await transp.write(data) result = MessagesCount transp.close() + await transp.join() proc swarmWorker6(address: TransportAddress): Future[int] {.async.} = var transp = await connect(address) @@ -279,6 +292,7 @@ proc swarmWorker6(address: TransportAddress): Future[int] {.async.} = var res = await transp.write(seqdata) result = MessagesCount transp.close() + await transp.join() proc swarmWorker7(address: TransportAddress): Future[int] {.async.} = var transp = await connect(address) @@ -292,6 +306,7 @@ proc swarmWorker7(address: TransportAddress): Future[int] {.async.} = doAssert(line == "DONE") result = 1 transp.close() + await transp.join() proc swarmWorker8(address: TransportAddress): Future[int] {.async.} = var transp = await connect(address) @@ -305,6 +320,7 @@ proc swarmWorker8(address: TransportAddress): Future[int] {.async.} = doAssert(line == "DONE") result = 1 transp.close() + await transp.join() proc waitAll[T](futs: seq[Future[T]]): Future[void] = var counter = len(futs) @@ -390,6 +406,7 @@ proc test1(): Future[int] {.async.} = result = await swarmManager1(ta) server.stop() server.close() + await server.join() proc test2(): Future[int] {.async.} = var ta = initTAddress("127.0.0.1:31345") @@ -399,6 +416,7 @@ proc test2(): Future[int] {.async.} = result = await swarmManager2(ta) server.stop() server.close() + await server.join() proc test3(): Future[int] {.async.} = var ta = initTAddress("127.0.0.1:31346") @@ -408,6 +426,7 @@ proc test3(): Future[int] {.async.} = result = await swarmManager3(ta) server.stop() server.close() + await server.join() proc test4(): Future[int] {.async.} = var ta = initTAddress("127.0.0.1:31347") @@ -416,6 +435,7 @@ proc test4(): Future[int] {.async.} = result = await swarmManager4(ta) server.stop() server.close() + await server.join() proc test5(): Future[int] {.async.} = var ta = initTAddress("127.0.0.1:31348") @@ -424,7 +444,6 @@ proc test5(): Future[int] {.async.} = udata = cast[pointer](addr counter)) server.start() result = await swarmManager5(ta) - await server.join() proc test6(): Future[int] {.async.} = var ta = initTAddress("127.0.0.1:31349") @@ -433,7 +452,6 @@ proc test6(): Future[int] {.async.} = udata = cast[pointer](addr counter)) server.start() result = await swarmManager6(ta) - await server.join() proc test7(): Future[int] {.async.} = var ta = initTAddress("127.0.0.1:31350") @@ -442,6 +460,7 @@ proc test7(): Future[int] {.async.} = result = await swarmWorker7(ta) server.stop() server.close() + await server.join() proc test8(): Future[int] {.async.} = var ta = initTAddress("127.0.0.1:31350") @@ -450,6 +469,168 @@ proc test8(): Future[int] {.async.} = result = await swarmWorker8(ta) server.stop() server.close() + await server.join() + +proc serveClient9(server: StreamServer, transp: StreamTransport) {.async.} = + var expect = "" + for i in 0..