From 02b8da986bd68cc14a5925a76e96ee7384f56e67 Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Wed, 24 Jun 2020 11:21:52 +0300 Subject: [PATCH] Add accept() call (#103) * Add accept() call and tests. * Fix rare fd leaks on Windows. * Fix compilation warnings. * Add fd leak test. * Bump version to 2.4.0. --- chronos.nimble | 2 +- chronos/asyncloop.nim | 2 +- chronos/handles.nim | 30 ++- chronos/transports/common.nim | 14 +- chronos/transports/stream.nim | 348 +++++++++++++++++++++++++++++----- tests/teststream.nim | 125 +++++++++++- 6 files changed, 466 insertions(+), 55 deletions(-) diff --git a/chronos.nimble b/chronos.nimble index f67945e..7c34fca 100644 --- a/chronos.nimble +++ b/chronos.nimble @@ -1,5 +1,5 @@ packageName = "chronos" -version = "2.3.9" +version = "2.4.0" author = "Status Research & Development GmbH" description = "Chronos" license = "Apache License 2.0 or MIT" diff --git a/chronos/asyncloop.nim b/chronos/asyncloop.nim index 2474b9d..96dfee8 100644 --- a/chronos/asyncloop.nim +++ b/chronos/asyncloop.nim @@ -465,7 +465,7 @@ when defined(windows) or defined(nimdoc): ## Closes a (pipe/file) handle and ensures that it is unregistered. let loop = getGlobalDispatcher() loop.handles.excl(fd) - doAssert closeHandle(Handle(fd)) == 1 + discard closeHandle(Handle(fd)) if not isNil(aftercb): var acb = AsyncCallback(function: aftercb) loop.callbacks.addLast(acb) diff --git a/chronos/handles.nim b/chronos/handles.nim index a05c4e3..eef66cb 100644 --- a/chronos/handles.nim +++ b/chronos/handles.nim @@ -26,7 +26,7 @@ when defined(windows): proc connectNamedPipe(hNamedPipe: Handle, lpOverlapped: pointer): WINBOOL {.importc: "ConnectNamedPipe", stdcall, dynlib: "kernel32".} else: - import posix + import os, posix const asyncInvalidSocket* = AsyncFD(posix.INVALID_SOCKET) TCP_NODELAY* = 1 @@ -117,6 +117,34 @@ proc wrapAsyncSocket*(sock: SocketHandle): AsyncFD = result = AsyncFD(sock) register(result) +proc getMaxOpenFiles*(): int = + ## Returns maximum file descriptor number that can be opened by this process. + ## + ## Note: On Windows its impossible to obtain such number, so getMaxOpenFiles() + ## will return constant value of 16384. You can get more information on this + ## link https://docs.microsoft.com/en-us/archive/blogs/markrussinovich/pushing-the-limits-of-windows-handles + when defined(windows): + result = 16384 + else: + var limits: RLimit + if getrlimit(posix.RLIMIT_NOFILE, limits) != 0: + raiseOSError(osLastError()) + result = int(limits.rlim_cur) + +proc setMaxOpenFiles*(count: int) = + ## Set maximum file descriptor number that can be opened by this process. + ## + ## Note: On Windows its impossible to set this value, so it just a nop call. + when defined(windows): + discard + else: + var limits: RLimit + if getrlimit(posix.RLIMIT_NOFILE, limits) != 0: + raiseOSError(osLastError()) + limits.rlim_cur = count + if setrlimit(posix.RLIMIT_NOFILE, limits) != 0: + raiseOSError(osLastError()) + proc createAsyncPipe*(): tuple[read: AsyncFD, write: AsyncFD] = ## Create new asynchronouse pipe. ## Returns tuple of read pipe handle and write pipe handle``asyncInvalidPipe`` on error. diff --git a/chronos/transports/common.nim b/chronos/transports/common.nim index 8787e49..0d87d91 100644 --- a/chronos/transports/common.nim +++ b/chronos/transports/common.nim @@ -69,6 +69,7 @@ when defined(windows): domain*: Domain # Current server domain (IPv4 or IPv6) apending*: bool asock*: AsyncFD # Current AcceptEx() socket + errorCode*: OSErrorCode # Current error code abuffer*: array[128, byte] # Windows AcceptEx() buffer aovl*: CustomOverlapped # AcceptEx OVERLAPPED structure else: @@ -82,6 +83,7 @@ else: flags*: set[ServerFlags] # Flags bufferSize*: int # Size of internal transports' buffer loopFuture*: Future[void] # Server's main Future + errorCode*: OSErrorCode # Current error code type TransportError* = object of AsyncError @@ -100,6 +102,8 @@ type ## Transport's capability not supported exception TransportUseClosedError* = object of TransportError ## Usage after transport close exception + TransportTooManyError* = object of TransportError + ## Too many open file descriptors exception TransportState* = enum ## Transport's state @@ -470,7 +474,8 @@ template checkClosed*(t: untyped) = template checkClosed*(t: untyped, future: untyped) = if (ReadClosed in (t).state) or (WriteClosed in (t).state): - future.fail(newException(TransportUseClosedError, "Transport is already closed!")) + future.fail(newException(TransportUseClosedError, + "Transport is already closed!")) return future template checkWriteEof*(t: untyped, future: untyped) = @@ -484,6 +489,12 @@ template getError*(t: untyped): ref Exception = (t).error = nil err +template getServerUseClosedError*(): ref TransportUseClosedError = + newException(TransportUseClosedError, "Server is already closed!") + +template getTransportTooManyError*(): ref TransportTooManyError = + newException(TransportTooManyError, "Too many open transports!") + template getTransportOsError*(err: OSErrorCode): ref TransportOsError = var msg = "(" & $int(err) & ") " & osErrorMsg(err) var tre = newException(TransportOsError, msg) @@ -526,6 +537,7 @@ when defined(windows): ERROR_PIPE_NOT_CONNECTED* = 233 ERROR_NO_DATA* = 232 ERROR_CONNECTION_ABORTED* = 1236 + ERROR_TOO_MANY_OPEN_FILES* = 4 proc cancelIo*(hFile: HANDLE): WINBOOL {.stdcall, dynlib: "kernel32", importc: "CancelIo".} diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index bc13c8d..e5052f6 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -273,6 +273,20 @@ proc failPendingWriteQueue(queue: var Deque[StreamVector], if not(vector.writer.finished()): vector.writer.fail(error) +proc clean(server: StreamServer) {.inline.} = + if not(server.loopFuture.finished()): + untrackServer(server) + server.loopFuture.complete() + if not isNil(server.udata) and GCUserData in server.flags: + GC_unref(cast[ref int](server.udata)) + GC_unref(server) + +proc clean(transp: StreamTransport) {.inline.} = + if not(transp.future.finished()): + untrackStream(transp) + transp.future.complete() + GC_unref(transp) + when defined(windows): template zeroOvelappedOffset(t: untyped) = @@ -539,11 +553,7 @@ when defined(windows): if ReadClosed in transp.state: # Stop tracking transport - untrackStream(transp) - # If `ReadClosed` present, then close(transport) was called. - if not(transp.future.finished()): - transp.future.complete() - GC_unref(transp) + transp.clean() if ReadPaused in transp.state: # Transport buffer is full, so we will not continue on reading. @@ -771,6 +781,26 @@ when defined(windows): return retFuture + proc createAcceptPipe(server: StreamServer) = + let pipeSuffix = $cast[cstring](addr server.local.address_un) + let pipeName = newWideCString(r"\\.\pipe\" & pipeSuffix[1 .. ^1]) + var openMode = PIPE_ACCESS_DUPLEX or FILE_FLAG_OVERLAPPED + if FirstPipe notin server.flags: + openMode = openMode or FILE_FLAG_FIRST_PIPE_INSTANCE + server.flags.incl(FirstPipe) + let pipeMode = int32(PIPE_TYPE_BYTE or PIPE_READMODE_BYTE or PIPE_WAIT) + let pipeHandle = createNamedPipe(pipeName, openMode, pipeMode, + PIPE_UNLIMITED_INSTANCES, + DWORD(server.bufferSize), + DWORD(server.bufferSize), + DWORD(0), nil) + if pipeHandle != INVALID_HANDLE_VALUE: + server.sock = AsyncFD(pipeHandle) + register(server.sock) + else: + server.sock = asyncInvalidPipe + server.errorCode = osLastError() + proc acceptPipeLoop(udata: pointer) {.gcsafe, nimcall.} = var ovl = cast[PtrCustomOverlapped](udata) var server = cast[StreamServer](ovl.data.udata) @@ -797,14 +827,7 @@ when defined(windows): elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt or close call. if server.status in {ServerStatus.Closed, ServerStatus.Stopped}: - # Stop tracking server - untrackServer(server) - # Completing server's Future - if not(server.loopFuture.finished()): - server.loopFuture.complete() - if not isNil(server.udata) and GCUserData in server.flags: - GC_unref(cast[ref int](server.udata)) - GC_unref(server) + server.clean() break else: doAssert disconnectNamedPipe(Handle(server.sock)) == 1 @@ -850,13 +873,7 @@ when defined(windows): # connectNamedPipe session. if server.status in {ServerStatus.Closed, ServerStatus.Stopped}: if not(server.loopFuture.finished()): - # Stop tracking server - untrackServer(server) - server.loopFuture.complete() - if not isNil(server.udata) and GCUserData in server.flags: - GC_unref(cast[ref int](server.udata)) - - GC_unref(server) + server.clean() proc acceptLoop(udata: pointer) {.gcsafe, nimcall.} = var ovl = cast[PtrCustomOverlapped](udata) @@ -890,14 +907,11 @@ when defined(windows): elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt or close. + server.asock.closeSocket() if server.status in {ServerStatus.Closed, ServerStatus.Stopped}: # Stop tracking server if not(server.loopFuture.finished()): - untrackServer(server) - server.loopFuture.complete() - if not isNil(server.udata) and GCUserData in server.flags: - GC_unref(cast[ref int](server.udata)) - GC_unref(server) + server.clean() break else: server.asock.closeSocket() @@ -937,12 +951,7 @@ when defined(windows): # AcceptEx session. if server.status in {ServerStatus.Closed, ServerStatus.Stopped}: if not(server.loopFuture.finished()): - # Stop tracking server - untrackServer(server) - server.loopFuture.complete() - if not isNil(server.udata) and GCUserData in server.flags: - GC_unref(cast[ref int](server.udata)) - GC_unref(server) + server.clean() proc resumeRead(transp: StreamTransport) {.inline.} = if ReadPaused in transp.state: @@ -962,6 +971,166 @@ when defined(windows): if not server.apending: server.aovl.data.cb(addr server.aovl) + proc accept*(server: StreamServer): Future[StreamTransport] = + var retFuture = newFuture[StreamTransport]("stream.server.accept") + + doAssert(server.status != ServerStatus.Running, + "You could not use accept() if server was already started") + + if server.status == ServerStatus.Closed: + retFuture.fail(getServerUseClosedError()) + return retFuture + + proc continuationSocket(udata: pointer) {.gcsafe.} = + var ovl = cast[PtrCustomOverlapped](udata) + var server = cast[StreamServer](ovl.data.udata) + + server.apending = false + if ovl.data.errCode == OSErrorCode(-1): + if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET), + cint(SO_UPDATE_ACCEPT_CONTEXT), addr server.sock, + SockLen(sizeof(SocketHandle))) != 0'i32: + let err = OSErrorCode(wsaGetLastError()) + server.asock.closeSocket() + retFuture.fail(getTransportOsError(err)) + else: + var ntransp: StreamTransport + if not isNil(server.init): + let transp = server.init(server, server.asock) + ntransp = newStreamSocketTransport(server.asock, + server.bufferSize, + transp) + else: + ntransp = newStreamSocketTransport(server.asock, + server.bufferSize, nil) + # Start tracking transport + trackStream(ntransp) + retFuture.complete(ntransp) + elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: + # CancelIO() interrupt or close. + server.asock.closeSocket() + retFuture.fail(getServerUseClosedError()) + server.clean() + else: + server.asock.closeSocket() + retFuture.fail(getTransportOsError(ovl.data.errCode)) + + proc cancellationSocket(udata: pointer) {.gcsafe.} = + server.asock.closeSocket() + + proc continuationPipe(udata: pointer) {.gcsafe.} = + var ovl = cast[PtrCustomOverlapped](udata) + var server = cast[StreamServer](ovl.data.udata) + + server.apending = false + if ovl.data.errCode == OSErrorCode(-1): + var ntransp: StreamTransport + var flags = {WinServerPipe} + if NoPipeFlash in server.flags: + flags.incl(WinNoPipeFlash) + if not isNil(server.init): + var transp = server.init(server, server.sock) + ntransp = newStreamPipeTransport(server.sock, server.bufferSize, + transp, flags) + else: + ntransp = newStreamPipeTransport(server.sock, server.bufferSize, + nil, flags) + # Start tracking transport + trackStream(ntransp) + server.createAcceptPipe() + retFuture.complete(ntransp) + + elif int32(ovl.data.errCode) in {ERROR_OPERATION_ABORTED, + ERROR_PIPE_NOT_CONNECTED}: + # CancelIO() interrupt or close call. + retFuture.fail(getServerUseClosedError()) + server.clean() + else: + let sock = server.sock + server.createAcceptPipe() + closeHandle(sock) + retFuture.fail(getTransportOsError(ovl.data.errCode)) + + proc cancellationPipe(udata: pointer) {.gcsafe.} = + server.sock.closeHandle() + + if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}: + # TCP Sockets part + var loop = getGlobalDispatcher() + server.asock = createAsyncSocket(server.domain, SockType.SOCK_STREAM, + Protocol.IPPROTO_TCP) + if server.asock == asyncInvalidSocket: + let err = osLastError() + if int32(err) == ERROR_TOO_MANY_OPEN_FILES: + retFuture.fail(getTransportTooManyError()) + else: + retFuture.fail(getTransportOsError(err)) + return retFuture + + var dwBytesReceived = DWORD(0) + let dwReceiveDataLength = DWORD(0) + let dwLocalAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) + let dwRemoteAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) + + server.aovl.data = CompletionData(fd: server.sock, + cb: continuationSocket, + udata: cast[pointer](server)) + server.apending = true + let res = loop.acceptEx(SocketHandle(server.sock), + SocketHandle(server.asock), + addr server.abuffer[0], + dwReceiveDataLength, dwLocalAddressLength, + dwRemoteAddressLength, addr dwBytesReceived, + cast[POVERLAPPED](addr server.aovl)) + if not res: + let err = osLastError() + if int32(err) == ERROR_OPERATION_ABORTED: + server.apending = false + retFuture.fail(getServerUseClosedError()) + return retFuture + elif int32(err) == ERROR_IO_PENDING: + discard + else: + server.apending = false + retFuture.fail(getTransportOsError(err)) + return retFuture + + retFuture.cancelCallback = cancellationSocket + + elif server.local.family in {AddressFamily.Unix}: + # Unix domain sockets emulation via Windows Named pipes part. + server.apending = true + if server.sock == asyncInvalidPipe: + let err = server.errorCode + if int32(err) == ERROR_TOO_MANY_OPEN_FILES: + retFuture.fail(getTransportTooManyError()) + else: + retFuture.fail(getTransportOsError(err)) + return retFuture + + server.aovl.data = CompletionData(fd: server.sock, + cb: continuationPipe, + udata: cast[pointer](server)) + server.apending = true + let res = connectNamedPipe(HANDLE(server.sock), + cast[POVERLAPPED](addr server.aovl)) + if res == 0: + let err = osLastError() + if int32(err) == ERROR_OPERATION_ABORTED: + server.apending = false + retFuture.fail(getServerUseClosedError()) + return retFuture + elif int32(err) in {ERROR_IO_PENDING, ERROR_PIPE_CONNECTED}: + discard + else: + server.apending = false + retFuture.fail(getTransportOsError(err)) + return retFuture + + retFuture.cancelCallback = cancellationPipe + + return retFuture + else: import ../sendfile @@ -1227,7 +1396,11 @@ else: sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto) if sock == asyncInvalidSocket: - retFuture.fail(getTransportOsError(osLastError())) + let err = osLastError() + if int(err) == EMFILE: + retFuture.fail(getTransportTooManyError()) + else: + retFuture.fail(getTransportOsError(err)) return retFuture proc continuation(udata: pointer) {.gcsafe.} = @@ -1324,8 +1497,63 @@ else: transp.state.excl(WritePaused) addWriter(transp.fd, writeStreamLoop, cast[pointer](transp)) + proc accept*(server: StreamServer): Future[StreamTransport] = + var retFuture = newFuture[StreamTransport]("stream.server.accept") + + doAssert(server.status != ServerStatus.Running, + "You could not use accept() if server was started with start()") + if server.status == ServerStatus.Closed: + retFuture.fail(getServerUseClosedError()) + return retFuture + + proc continuation(udata: pointer) {.gcsafe.} = + var + saddr: Sockaddr_storage + slen: SockLen + while true: + let res = posix.accept(SocketHandle(server.sock), + cast[ptr SockAddr](addr saddr), addr slen) + if int(res) > 0: + let sock = wrapAsyncSocket(res) + if sock != asyncInvalidSocket: + var ntransp: StreamTransport + if not isNil(server.init): + let transp = server.init(server, sock) + ntransp = newStreamSocketTransport(sock, server.bufferSize, + transp) + else: + ntransp = newStreamSocketTransport(sock, server.bufferSize, nil) + # Start tracking transport + trackStream(ntransp) + retFuture.complete(ntransp) + else: + retFuture.fail(getTransportOsError(osLastError())) + else: + let err = osLastError() + if int(err) == EINTR: + continue + elif int(err) == EAGAIN: + # This error appears only when server get closed, while accept() + # call pending. + retFuture.fail(getServerUseClosedError()) + elif int(err) == EMFILE: + retFuture.fail(getTransportTooManyError()) + else: + retFuture.fail(getTransportOsError(err)) + break + + removeReader(server.sock) + + proc cancellation(udata: pointer) {.gcsafe.} = + removeReader(server.sock) + + addReader(server.sock, continuation, nil) + retFuture.cancelCallback = cancellation + return retFuture + proc start*(server: StreamServer) = ## Starts ``server``. + doAssert(not(isNil(server.function))) if server.status == ServerStatus.Starting: server.resumeAccept() server.status = ServerStatus.Running @@ -1363,24 +1591,25 @@ proc close*(server: StreamServer) = proc continuation(udata: pointer) {.gcsafe.} = # Stop tracking server if not(server.loopFuture.finished()): - untrackServer(server) - server.loopFuture.complete() - if not isNil(server.udata) and GCUserData in server.flags: - GC_unref(cast[ref int](server.udata)) - GC_unref(server) + server.clean() - if server.status == ServerStatus.Stopped: + let r1 = (server.status == ServerStatus.Stopped) and + not(isNil(server.function)) + let r2 = (server.status == ServerStatus.Starting) and isNil(server.function) + + if r1 or r2: server.status = ServerStatus.Closed when defined(windows): if server.local.family in {AddressFamily.IPv4, AddressFamily.IPv6}: if not server.apending: server.sock.closeSocket(continuation) else: + server.asock.closeSocket() server.sock.closeSocket() elif server.local.family in {AddressFamily.Unix}: if NoPipeFlash notin server.flags: discard flushFileBuffers(Handle(server.sock)) - doAssert disconnectNamedPipe(Handle(server.sock)) == 1 + discard disconnectNamedPipe(Handle(server.sock)) if not server.apending: server.sock.closeHandle(continuation) else: @@ -1563,8 +1792,13 @@ proc createStreamServer*(host: TransportAddress, elif host.family == AddressFamily.Unix: cb = acceptPipeLoop - result.aovl.data = CompletionData(fd: serverSocket, cb: cb, - udata: cast[pointer](result)) + if not(isNil(cbproc)): + result.aovl.data = CompletionData(fd: serverSocket, cb: cb, + udata: cast[pointer](result)) + else: + if host.family == AddressFamily.Unix: + result.createAcceptPipe() + result.domain = host.getDomain() result.apending = false @@ -1572,6 +1806,17 @@ proc createStreamServer*(host: TransportAddress, trackServer(result) GC_ref(result) +proc createStreamServer*(host: TransportAddress, + flags: set[ServerFlags] = {}, + sock: AsyncFD = asyncInvalidSocket, + backlog: int = 100, + bufferSize: int = DefaultStreamBufferSize, + child: StreamServer = nil, + init: TransportInitCallback = nil, + udata: pointer = nil): StreamServer = + result = createStreamServer(host, nil, flags, sock, backlog, bufferSize, + child, init, cast[pointer](udata)) + proc createStreamServer*[T](host: TransportAddress, cbproc: StreamCallback, flags: set[ServerFlags] = {}, @@ -1586,6 +1831,19 @@ proc createStreamServer*[T](host: TransportAddress, result = createStreamServer(host, cbproc, fflags, sock, backlog, bufferSize, child, init, cast[pointer](udata)) +proc createStreamServer*[T](host: TransportAddress, + flags: set[ServerFlags] = {}, + udata: ref T, + sock: AsyncFD = asyncInvalidSocket, + backlog: int = 100, + bufferSize: int = DefaultStreamBufferSize, + child: StreamServer = nil, + init: TransportInitCallback = nil): StreamServer = + var fflags = flags + {GCUserData} + GC_ref(udata) + result = createStreamServer(host, nil, fflags, sock, backlog, bufferSize, + child, init, cast[pointer](udata)) + proc getUserData*[T](server: StreamServer): T {.inline.} = ## Obtain user data stored in ``server`` object. result = cast[T](server.udata) @@ -1916,11 +2174,7 @@ proc close*(transp: StreamTransport) = ## Please note that release of resources is not completed immediately, to be ## sure all resources got released please use ``await transp.join()``. proc continuation(udata: pointer) {.gcsafe.} = - if not(transp.future.finished()): - transp.future.complete() - # Stop tracking stream - untrackStream(transp) - GC_unref(transp) + transp.clean() if {ReadClosed, WriteClosed} * transp.state == {}: transp.state.incl({WriteClosed, ReadClosed}) @@ -1929,7 +2183,7 @@ proc close*(transp: StreamTransport) = if WinServerPipe in transp.flags: if WinNoPipeFlash notin transp.flags: discard flushFileBuffers(Handle(transp.fd)) - doAssert disconnectNamedPipe(Handle(transp.fd)) == 1 + discard disconnectNamedPipe(Handle(transp.fd)) else: if WinNoPipeFlash notin transp.flags: discard flushFileBuffers(Handle(transp.fd)) diff --git a/tests/teststream.nim b/tests/teststream.nim index b32c0b3..4f55805 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -25,6 +25,7 @@ suite "Stream Transport test suite": MessagesCount = 10 MessageSize = 20 FilesCount = 10 + TestsCount = 100 m1 = "readLine() multiple clients with messages (" & $ClientsCount & " clients x " & $MessagesCount & " messages)" @@ -48,16 +49,26 @@ suite "Stream Transport test suite": m17 = "0.0.0.0/::0 (INADDR_ANY) test" when defined(windows): - var addresses = [ + let addresses = [ initTAddress("127.0.0.1:33335"), initTAddress(r"/LOCAL\testpipe") ] else: - var addresses = [ + let addresses = [ initTAddress("127.0.0.1:33335"), initTAddress(r"/tmp/testpipe") ] - var prefixes = ["[IP] ", "[UNIX] "] + + let prefixes = ["[IP] ", "[UNIX] "] + + var markFD: int + + proc getCurrentFD(): int = + let local = initTAddress("127.0.0.1:33334") + let sock = createAsyncSocket(local.getDomain(), SockType.SOCK_DGRAM, + Protocol.IPPROTO_UDP) + closeSocket(sock) + return int(sock) proc createBigMessage(size: int): seq[byte] = var message = "MESSAGE" @@ -1057,6 +1068,92 @@ suite "Stream Transport test suite": await server.join() result = c7 + proc testAccept(address: TransportAddress): Future[bool] {.async.} = + var server = createStreamServer(address, flags = {ReuseAddr}) + var connected = 0 + var accepted = 0 + + proc acceptTask(server: StreamServer) {.async.} = + for i in 0 ..< TestsCount: + let transp = await server.accept() + await transp.closeWait() + inc(accepted) + + var acceptFut = acceptTask(server) + var transp: StreamTransport + + try: + for i in 0 ..< TestsCount: + transp = await connect(address) + await sleepAsync(10.milliseconds) + await transp.closeWait() + inc(connected) + if await withTimeout(acceptFut, 5.seconds): + if acceptFut.finished() and not(acceptFut.failed()): + result = (connected == TestsCount) and (connected == accepted) + finally: + await server.closeWait() + if not(isNil(transp)): + await transp.closeWait() + + proc testAcceptClose(address: TransportAddress): Future[bool] {.async.} = + var server = createStreamServer(address, flags = {ReuseAddr}) + + proc acceptTask(server: StreamServer) {.async.} = + let transp = await server.accept() + await transp.closeWait() + + var acceptFut = acceptTask(server) + await server.closeWait() + + if await withTimeout(acceptFut, 5.seconds): + if acceptFut.finished() and acceptFut.failed(): + if acceptFut.readError() of TransportUseClosedError: + result = true + else: + result = false + + when not(defined(windows)): + proc testAcceptTooMany(address: TransportAddress): Future[bool] {.async.} = + let maxFiles = getMaxOpenFiles() + var server = createStreamServer(address, flags = {ReuseAddr}) + let isock = int(server.sock) + let newMaxFiles = isock + 4 + setMaxOpenFiles(newMaxFiles) + + proc acceptTask(server: StreamServer): Future[bool] {.async.} = + var transports = newSeq[StreamTransport]() + try: + for i in 0 ..< 3: + let transp = await server.accept() + transports.add(transp) + except TransportTooManyError: + var pending = newSeq[Future[void]]() + for item in transports: + pending.add(closeWait(item)) + await allFutures(pending) + return true + + var acceptFut = acceptTask(server) + + try: + for i in 0 ..< 3: + try: + let transp = await connect(address) + await sleepAsync(10.milliseconds) + await transp.closeWait() + except TransportTooManyError: + break + if await withTimeout(acceptFut, 5.seconds): + if acceptFut.finished() and not(acceptFut.failed()): + if acceptFut.read() == true: + result = true + finally: + await server.closeWait() + setMaxOpenFiles(maxFiles) + + markFD = getCurrentFD() + for i in 0..