diff --git a/asyncdispatch2/transports/common.nim b/asyncdispatch2/transports/common.nim index e5872fb..0b11f5a 100644 --- a/asyncdispatch2/transports/common.nim +++ b/asyncdispatch2/transports/common.nim @@ -36,20 +36,41 @@ type Starting, # Server created Stopped, # Server stopped Running, # Server running - Paused # Server paused + Closed # Server closed - SocketServer* = ref object of RootRef - ## Socket server object - sock*: AsyncFD # Socket - local*: TransportAddress # Address - actEvent*: AsyncEvent # Activation event - action*: ServerCommand # Activation command - status*: ServerStatus # Current server status - udata*: pointer # User-defined pointer - flags*: set[ServerFlags] # Flags - bufferSize*: int # Size of internal transports' buffer - loopFuture*: Future[void] # Server's main Future +when defined(windows): + type + SocketServer* = ref object of RootRef + ## Socket server object + sock*: AsyncFD # Socket + local*: TransportAddress # Address + # actEvent*: AsyncEvent # Activation event + # action*: ServerCommand # Activation command + status*: ServerStatus # Current server status + udata*: pointer # User-defined pointer + flags*: set[ServerFlags] # Flags + bufferSize*: int # Size of internal transports' buffer + loopFuture*: Future[void] # Server's main Future + domain*: Domain # Current server domain (IPv4 or IPv6) + apending*: bool + asock*: AsyncFD # Current AcceptEx() socket + abuffer*: array[128, byte] # Windows AcceptEx() buffer + aovl*: CustomOverlapped # AcceptEx OVERLAPPED structure +else: + type + SocketServer* = ref object of RootRef + ## Socket server object + sock*: AsyncFD # Socket + local*: TransportAddress # Address + # actEvent*: AsyncEvent # Activation event + # action*: ServerCommand # Activation command + status*: ServerStatus # Current server status + udata*: pointer # User-defined pointer + flags*: set[ServerFlags] # Flags + bufferSize*: int # Size of internal transports' buffer + loopFuture*: Future[void] # Server's main Future +type TransportError* = object of Exception ## Transport's specific exception TransportOsError* = object of TransportError diff --git a/asyncdispatch2/transports/datagram.nim b/asyncdispatch2/transports/datagram.nim index d955b1d..af4b67d 100644 --- a/asyncdispatch2/transports/datagram.nim +++ b/asyncdispatch2/transports/datagram.nim @@ -190,8 +190,8 @@ when defined(windows): transp.state.excl(ReadPending) transp.state.incl(ReadPaused) elif int(err) == WSAECONNRESET: - transp.state.excl(ReadPending) - continue + transp.state = {ReadPaused, ReadEof} + break elif int(err) == ERROR_IO_PENDING: discard else: @@ -580,27 +580,12 @@ proc createDatagramServer*(host: TransportAddress, proc start*(server: DatagramServer) = ## Starts ``server``. - if server.status in {ServerStatus.Starting, ServerStatus.Paused}: + if server.status == ServerStatus.Starting: server.transport.resumeRead() + server.status = ServerStatus.Running proc stop*(server: DatagramServer) = ## Stops ``server``. - if server.status in {ServerStatus.Paused, ServerStatus.Running}: - when defined(windows): - if server.status == ServerStatus.Running: - if {WritePending, ReadPending} * server.transport.state != {}: - ## CancelIO will stop both reading and writing. - discard cancelIo(Handle(server.transport.fd)) - else: - if server.status == ServerStatus.Running: - if WritePaused notin server.transport.state: - server.transport.fd.removeWriter() - if ReadPaused notin server.transport.state: - server.transport.fd.removeReader() - server.status = ServerStatus.Stopped - -proc pause*(server: DatagramServer) = - ## Pause ``server``. if server.status == ServerStatus.Running: when defined(windows): if {WritePending, ReadPending} * server.transport.state != {}: @@ -611,7 +596,7 @@ proc pause*(server: DatagramServer) = server.transport.fd.removeWriter() if ReadPaused notin server.transport.state: server.transport.fd.removeReader() - server.status = ServerStatus.Paused + server.status = ServerStatus.Stopped proc join*(server: DatagramServer) {.async.} = ## Waits until ``server`` is not stopped. @@ -621,5 +606,6 @@ proc join*(server: DatagramServer) {.async.} = proc close*(server: DatagramServer) = ## Release ``server`` resources. if server.status == ServerStatus.Stopped: + server.status = ServerStatus.Closed server.transport.close() GC_unref(server) diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim index 8973f14..a894c54 100644 --- a/asyncdispatch2/transports/stream.nim +++ b/asyncdispatch2/transports/stream.nim @@ -7,9 +7,9 @@ # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) +import net, nativesockets, os, deques, strutils import ../asyncloop, ../asyncsync, ../handles, ../sendfile import common -import net, nativesockets, os, deques, strutils when defined(windows): import winlean @@ -108,7 +108,6 @@ template setWriteError(t, e: untyped) = template finishReader(t: untyped) = var reader = (t).reader - (t).reader = nil reader.complete() template checkPending(t: untyped) = @@ -141,12 +140,6 @@ when defined(windows): wovl: CustomOverlapped # Writer OVERLAPPED structure roffset: int # Pending reading offset - WindowsStreamServer* = ref object of RootRef - server: SocketServer # Server object - domain: Domain # Current server domain (IPv4 or IPv6) - abuffer: array[128, byte] # Windows AcceptEx() buffer - aovl: CustomOverlapped # AcceptEx OVERLAPPED structure - const SO_UPDATE_CONNECT_CONTEXT = 0x7010 template finishWriter(t: untyped) = @@ -330,7 +323,13 @@ when defined(windows): # CancelIO() interrupt transp.state.excl(ReadPending) transp.state.incl(ReadPaused) + elif int32(err) in {WSAECONNRESET, WSAENETRESET}: + if not isNil(transp.reader): + transp.state = {ReadEof, ReadPaused} + transp.finishReader() elif int32(err) != ERROR_IO_PENDING: + transp.state.excl(ReadPending) + transp.state.incl(ReadPaused) transp.setReadError(err) if not isNil(transp.reader): transp.finishReader() @@ -426,99 +425,65 @@ when defined(windows): retFuture.fail(newException(OSError, osErrorMsg(err))) return retFuture - proc acceptAddr(server: WindowsStreamServer): Future[AsyncFD] = - var retFuture = newFuture[AsyncFD]("transport.acceptAddr") - let loop = getGlobalDispatcher() - var sock = createAsyncSocket(server.domain, SockType.SOCK_STREAM, - Protocol.IPPROTO_TCP) - if sock == asyncInvalidSocket: - retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) + proc acceptLoop(udata: pointer) {.gcsafe, nimcall.} = + var ovl = cast[PtrCustomOverlapped](udata) + var server = cast[StreamServer](ovl.data.udata) + var loop = getGlobalDispatcher() - var dwBytesReceived = DWORD(0) - let dwReceiveDataLength = DWORD(0) - let dwLocalAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) - let dwRemoteAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) - - proc continuation(udata: pointer) = - var ovl = cast[PtrCustomOverlapped](udata) - if not retFuture.finished: - if server.server.status in {Stopped, Paused}: - sock.closeAsyncSocket() - retFuture.complete(asyncInvalidSocket) + while true: + if server.apending: + ## Continuation + server.apending = false + if server.status == ServerStatus.Stopped: + server.asock.closeAsyncSocket() else: if ovl.data.errCode == OSErrorCode(-1): - if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), + if setsockopt(SocketHandle(server.asock), cint(SOL_SOCKET), cint(SO_UPDATE_ACCEPT_CONTEXT), - addr server.server.sock, + addr server.sock, SockLen(sizeof(SocketHandle))) != 0'i32: - sock.closeAsyncSocket() - retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) + server.asock.closeAsyncSocket() + raiseOsError(osLastError()) else: - retFuture.complete(sock) + discard server.function(server, + newStreamSocketTransport(server.asock, server.bufferSize), + server.udata) elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt - sock.closeAsyncSocket() - retFuture.complete(asyncInvalidSocket) + server.asock.closeAsyncSocket() + break else: - sock.closeAsyncSocket() - retFuture.fail(newException(OSError, osErrorMsg(ovl.data.errCode))) + server.asock.closeAsyncSocket() + raiseOsError(osLastError()) + else: + ## Initiation + server.apending = true + server.asock = createAsyncSocket(server.domain, SockType.SOCK_STREAM, + Protocol.IPPROTO_TCP) + if server.asock == asyncInvalidSocket: + raiseOsError(osLastError()) - server.aovl.data.fd = server.server.sock - server.aovl.data.cb = continuation + var dwBytesReceived = DWORD(0) + let dwReceiveDataLength = DWORD(0) + let dwLocalAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) + let dwRemoteAddressLength = DWORD(sizeof(Sockaddr_in6) + 16) - let res = loop.acceptEx(SocketHandle(server.server.sock), - SocketHandle(sock), addr server.abuffer[0], - dwReceiveDataLength, dwLocalAddressLength, - dwRemoteAddressLength, addr dwBytesReceived, - cast[POVERLAPPED](addr server.aovl)) - - if not res: - let err = osLastError() - if int32(err) != ERROR_IO_PENDING: - retFuture.fail(newException(OSError, osErrorMsg(err))) - return retFuture - - proc serverLoop(server: StreamServer): Future[void] {.async.} = - ## TODO: This procedure must be reviewed, when cancellation support - ## will be added - var wserver = new WindowsStreamServer - wserver.server = server - wserver.domain = server.local.address.getDomain() - await server.actEvent.wait() - server.actEvent.clear() - var acceptFut: Future[AsyncFD] - if server.action == ServerCommand.Start: - server.status = Running - var eventFut = server.actEvent.wait() - while true: - if server.status in {Paused}: - await eventFut - else: - acceptFut = acceptAddr(wserver) - await eventFut or acceptFut - if eventFut.finished: - server.actEvent.clear() - eventFut = server.actEvent.wait() - if server.action == ServerCommand.Start: - if server.status in {Stopped, Paused}: - server.status = Running - elif server.action == ServerCommand.Stop: - if server.status in {Running}: - server.status = Stopped - break - elif server.status in {Paused}: - server.status = Stopped - break - elif server.action == ServerCommand.Pause: - if server.status in {Running}: - server.status = Paused - if acceptFut.finished: - if not acceptFut.failed: - var sock = acceptFut.read() - if sock != asyncInvalidSocket: - discard server.function(server, - newStreamSocketTransport(sock, server.bufferSize), - server.udata) + let res = loop.acceptEx(SocketHandle(server.sock), + SocketHandle(server.asock), + addr server.abuffer[0], + dwReceiveDataLength, dwLocalAddressLength, + dwRemoteAddressLength, addr dwBytesReceived, + cast[POVERLAPPED](addr server.aovl)) + if not res: + let err = osLastError() + if int32(err) == ERROR_OPERATION_ABORTED: + server.apending = false + break + elif int32(err) == ERROR_IO_PENDING: + discard + else: + raiseOsError(osLastError()) + break proc resumeRead(transp: StreamTransport) {.inline.} = var wtransp = cast[WindowsStreamTransport](transp) @@ -530,6 +495,13 @@ when defined(windows): wtransp.state.excl(WritePaused) writeStreamLoop(cast[pointer](addr wtransp.wovl)) + proc pauseAccept(server: SocketServer) {.inline.} = + if server.apending: + discard cancelIO(Handle(server.sock)) + + proc resumeAccept(server: SocketServer) {.inline.} = + if not server.apending: + acceptLoop(cast[pointer](addr server.aovl)) else: import posix @@ -713,27 +685,11 @@ else: ## Critical unrecoverable error raiseOsError(err) - proc serverLoop(server: SocketServer): Future[void] {.async.} = - while true: - await server.actEvent.wait() - server.actEvent.clear() - if server.action == ServerCommand.Start: - if server.status in {Stopped, Paused, Starting}: - addReader(server.sock, serverCallback, - cast[pointer](server)) - server.status = Running - elif server.action == ServerCommand.Stop: - if server.status in {Running}: - removeReader(server.sock) - server.status = Stopped - break - elif server.status in {Paused}: - server.status = Stopped - break - elif server.action == ServerCommand.Pause: - if server.status in {Running}: - removeReader(server.sock) - server.status = Paused + proc resumeAccept(server: SocketServer) = + addReader(server.sock, serverCallback, cast[pointer](server)) + + proc pauseAccept(server: SocketServer) = + removeReader(server.sock) proc resumeRead(transp: StreamTransport) {.inline.} = transp.state.excl(ReadPaused) @@ -745,29 +701,28 @@ else: proc start*(server: SocketServer) = ## Starts ``server``. - server.action = Start - server.actEvent.fire() + if server.status == ServerStatus.Starting: + server.resumeAccept() + server.status = ServerStatus.Running proc stop*(server: SocketServer) = ## Stops ``server``. - server.action = Stop - server.actEvent.fire() - -proc pause*(server: SocketServer) = - ## Pause ``server``. - when defined(windows): - discard cancelIo(Handle(server.sock)) - server.action = Pause - server.actEvent.fire() + if server.status == ServerStatus.Running: + server.pauseAccept() + server.status = ServerStatus.Stopped proc join*(server: SocketServer) {.async.} = - ## Waits until ``server`` is not stopped. + ## Waits until ``server`` is not closed. if not server.loopFuture.finished: await server.loopFuture proc close*(server: SocketServer) = ## Release ``server`` resources. - GC_unref(server) + if server.status == ServerStatus.Stopped: + closeAsyncSocket(server.sock) + server.status = Closed + server.loopFuture.complete() + GC_unref(server) proc createStreamServer*(host: TransportAddress, cbproc: StreamCallback, @@ -777,7 +732,7 @@ proc createStreamServer*(host: TransportAddress, bufferSize: int = DefaultStreamBufferSize, udata: pointer = nil): StreamServer = ## Create new TCP stream server. - ## + ## ## ``host`` - address to which server will be bound. ## ``flags`` - flags to apply to server socket. ## ``cbproc`` - callback function which will be called, when new client @@ -829,11 +784,17 @@ proc createStreamServer*(host: TransportAddress, result.function = cbproc result.bufferSize = bufferSize result.status = Starting - result.actEvent = newAsyncEvent() + result.loopFuture = newFuture[void]("stream.server") result.udata = udata result.local = host + + when defined(windows): + result.aovl.data = CompletionData(fd: serverSocket, cb: acceptLoop, + udata: cast[pointer](result)) + result.domain = host.address.getDomain() + result.apending = false GC_ref(result) - result.loopFuture = serverLoop(result) + result.resumeAccept() proc write*(transp: StreamTransport, pbytes: pointer, nbytes: int): Future[int] {.async.} = @@ -855,7 +816,7 @@ proc writeFile*(transp: StreamTransport, handle: int, offset: uint = 0, size: int = 0): Future[void] {.async.} = ## Write data from file descriptor ``handle`` to transport ``transp``. - ## + ## ## You can specify starting ``offset`` in opened file and number of bytes ## to transfer from file to transport via ``size``. if transp.kind != TransportKind.Socket: @@ -876,7 +837,7 @@ proc readExactly*(transp: StreamTransport, pbytes: pointer, nbytes: int) {.async.} = ## Read exactly ``nbytes`` bytes from transport ``transp`` and store it to ## ``pbytes``. - ## + ## ## If EOF is received and ``nbytes`` is not yet readed, the procedure ## will raise ``TransportIncompleteError``. checkClosed(transp) @@ -912,7 +873,7 @@ proc readExactly*(transp: StreamTransport, pbytes: pointer, proc readOnce*(transp: StreamTransport, pbytes: pointer, nbytes: int): Future[int] {.async.} = ## Perform one read operation on transport ``transp``. - ## + ## ## If internal buffer is not empty, ``nbytes`` bytes will be transferred from ## internal buffer, otherwise it will wait until some bytes will be received. checkClosed(transp) @@ -928,6 +889,8 @@ proc readOnce*(transp: StreamTransport, pbytes: pointer, if ReadPaused in transp.state: transp.resumeRead() await transp.reader + + # we are no longer need data transp.reader = nil else: if transp.offset > nbytes: @@ -942,16 +905,16 @@ proc readOnce*(transp: StreamTransport, pbytes: pointer, proc readUntil*(transp: StreamTransport, pbytes: pointer, nbytes: int, sep: seq[byte]): Future[int] {.async.} = ## Read data from the transport ``transp`` until separator ``sep`` is found. - ## + ## ## On success, the data and separator will be removed from the internal ## buffer (consumed). Returned data will NOT include the separator at the end. - ## + ## ## If EOF is received, and `sep` was not found, procedure will raise ## ``TransportIncompleteError``. - ## + ## ## If ``nbytes`` bytes has been received and `sep` was not found, procedure ## will raise ``TransportLimitError``. - ## + ## ## Procedure returns actual number of bytes read. checkClosed(transp) checkPending(transp) @@ -1006,13 +969,13 @@ proc readLine*(transp: StreamTransport, limit = 0, sep = "\r\n"): Future[string] {.async.} = ## Read one line from transport ``transp``, where "line" is a sequence of ## bytes ending with ``sep`` (default is "\r\n"). - ## + ## ## If EOF is received, and ``sep`` was not found, the method will return the ## partial read bytes. - ## + ## ## If the EOF was received and the internal buffer is empty, return an ## empty string. - ## + ## ## If ``limit`` more then 0, then read is limited to ``limit`` bytes. checkClosed(transp) checkPending(transp) @@ -1061,7 +1024,7 @@ proc readLine*(transp: StreamTransport, limit = 0, proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} = ## Read all bytes (n == -1) or `n` bytes from transport ``transp``. - ## + ## ## This procedure allocates buffer seq[byte] and return it as result. checkClosed(transp) checkPending(transp) @@ -1071,8 +1034,7 @@ proc read*(transp: StreamTransport, n = -1): Future[seq[byte]] {.async.} = while true: if (ReadError in transp.state): raise transp.getError() - # ZAH: Shouldn't this be {ReadEof, ReadClosed} * transp.state != {} - if (ReadEof in transp.state) or (ReadClosed in transp.state): + if {ReadEof, ReadClosed} * transp.state != {}: break if transp.offset > 0: @@ -1120,7 +1082,7 @@ proc join*(transp: StreamTransport) {.async.} = proc close*(transp: StreamTransport) = ## Closes and frees resources of transport ``transp``. - if ReadClosed notin transp.state and WriteClosed notin transp.state: + if {ReadClosed, WriteClosed} * transp.state == {}: when defined(windows): discard cancelIo(Handle(transp.fd)) closeAsyncSocket(transp.fd) diff --git a/tests/testserver.nim b/tests/testserver.nim index aa52055..f4a43c0 100644 --- a/tests/testserver.nim +++ b/tests/testserver.nim @@ -7,54 +7,45 @@ # MIT license (LICENSE-MIT) import strutils, unittest -import ../asyncdispatch2, ../asyncdispatch2/timer +import ../asyncdispatch2 -const TimeoutPeriod = 2000 +proc serveStreamClient(server: StreamServer, + transp: StreamTransport, udata: pointer) {.async.} = + discard -proc serveClient1(server: StreamServer, - transp: StreamTransport, udata: pointer) {.async.} = - var data = await transp.readLine() - if len(data) == 0: - doAssert(transp.atEof()) - if data == "PAUSE": - server.pause() - data = "DONE\r\n" - var res = await transp.write(cast[pointer](addr data[0]), len(data)) - doAssert(res == len(data)) - await sleepAsync(TimeoutPeriod) - server.start() - elif data == "CHECK": - data = "CONFIRM\r\n" - var res = await transp.write(cast[pointer](addr data[0]), len(data)) - doAssert(res == len(data)) - transp.close() +proc serveDatagramClient(transp: DatagramTransport, + pbytes: pointer, nbytes: int, + raddr: TransportAddress, + udata: pointer): Future[void] {.async.} = + discard -proc swarmWorker1(address: TransportAddress): Future[int] {.async.} = - var transp1 = await connect(address) - var data = "PAUSE\r\n" - var res = await transp1.write(cast[pointer](addr data[0]), len(data)) - doAssert(res == len(data)) - var answer = await transp1.readLine() - doAssert(answer == "DONE") - var st = fastEpochTime() - var transp2 = await connect(address) - data = "CHECK\r\n" - res = await transp2.write(cast[pointer](addr data[0]), len(data)) - doAssert(res == len(data)) - var confirm = await transp2.readLine() - doAssert(confirm == "CONFIRM") - var et = fastEpochTime() - result = int(et - st) - -proc test1(): Future[int] {.async.} = +proc test1(): bool = var ta = initTAddress("127.0.0.1:31354") - var server = createStreamServer(ta, serveClient1, {ReuseAddr}) - server.start() - result = await swarmWorker1(ta) - server.stop() - server.close() + var server1 = createStreamServer(ta, serveStreamClient, {ReuseAddr}) + server1.start() + server1.stop() + server1.close() + var server2 = createStreamServer(ta, serveStreamClient, {ReuseAddr}) + server2.start() + server2.stop() + server2.close() + result = true + +proc test2(): bool = + var ta = initTAddress("127.0.0.1:31354") + var server1 = createDatagramServer(ta, serveDatagramClient, {ReuseAddr}) + server1.start() + server1.stop() + server1.close() + var server2 = createDatagramServer(ta, serveDatagramClient, {ReuseAddr}) + server2.start() + server2.stop() + server2.close() + result = true when isMainModule: suite "Server's test suite": - test "Server pause/resume test": - check waitFor(test1()) >= TimeoutPeriod + test "Stream Server start/stop test": + check test1() == true + test "Datagram Server start/stop test": + check test2() == true