diff --git a/asyncdispatch2/transports/common.nim b/asyncdispatch2/transports/common.nim index df2a3ea..3f685b9 100644 --- a/asyncdispatch2/transports/common.nim +++ b/asyncdispatch2/transports/common.nim @@ -281,6 +281,11 @@ template getError*(t: untyped): ref Exception = (t).error = nil err +proc raiseTransportOsError*(err: OSErrorCode) = + ## Raises transport specific OS error. + var msg = "(" & $int(err) & ") " & osErrorMsg(err) + raise newException(TransportOsError, msg) + when defined(windows): import winlean diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim index 30ab3fd..6641149 100644 --- a/asyncdispatch2/transports/stream.nim +++ b/asyncdispatch2/transports/stream.nim @@ -33,25 +33,56 @@ type Pipe, # Pipe transport File # File transport +when defined(windows): + const SO_UPDATE_CONNECT_CONTEXT = 0x7010 + + type + StreamTransport* = ref object of RootRef + fd*: AsyncFD # File descriptor + state: set[TransportState] # Current Transport state + reader: Future[void] # Current reader Future + buffer: seq[byte] # Reading buffer + offset: int # Reading buffer offset + error: ref Exception # Current error + queue: Deque[StreamVector] # Writer queue + future: Future[void] # Stream life future + # Windows specific part + rwsabuf: TWSABuf # Reader WSABUF + wwsabuf: TWSABuf # Writer WSABUF + rovl: CustomOverlapped # Reader OVERLAPPED structure + wovl: CustomOverlapped # Writer OVERLAPPED structure + roffset: int # Pending reading offset + case kind*: TransportKind + of TransportKind.Socket: + domain: Domain # Socket transport domain (IPv4/IPv6) + local: TransportAddress # Local address + remote: TransportAddress # Remote address + of TransportKind.Pipe: + todo1: int + of TransportKind.File: + todo2: int +else: + type + StreamTransport* = ref object of RootRef + fd*: AsyncFD # File descriptor + state: set[TransportState] # Current Transport state + reader: Future[void] # Current reader Future + buffer: seq[byte] # Reading buffer + offset: int # Reading buffer offset + error: ref Exception # Current error + queue: Deque[StreamVector] # Writer queue + future: Future[void] # Stream life future + case kind*: TransportKind + of TransportKind.Socket: + domain: Domain # Socket transport domain (IPv4/IPv6) + local: TransportAddress # Local address + remote: TransportAddress # Remote address + of TransportKind.Pipe: + todo1: int + of TransportKind.File: + todo2: int + type - StreamTransport* = ref object of RootRef - fd*: AsyncFD # File descriptor - state: set[TransportState] # Current Transport state - reader: Future[void] # Current reader Future - buffer: seq[byte] # Reading buffer - offset: int # Reading buffer offset - error: ref Exception # Current error - queue: Deque[StreamVector] # Writer queue - future: Future[void] # Stream life future - case kind*: TransportKind - of TransportKind.Socket: - domain: Domain # Socket transport domain (IPv4/IPv6) - local: TransportAddress # Local address - remote: TransportAddress # Remote address - of TransportKind.Pipe: - todo1: int - of TransportKind.File: - todo2: int StreamCallback* = proc(server: StreamServer, client: StreamTransport): Future[void] {.gcsafe.} @@ -59,8 +90,17 @@ type ## ``server`` - StreamServer object. ## ``client`` - accepted client transport. + TransportInitCallback* = proc(server: StreamServer, + fd: AsyncFD): StreamTransport {.gcsafe.} + ## Custom transport initialization procedure, which can allocated inherited + ## StreamTransport object. + StreamServer* = ref object of SocketServer - function*: StreamCallback + ## StreamServer object + function*: StreamCallback # callback which will be called after new + # client accepted + init*: TransportInitCallback # callback which will be called before + # transport for new client proc remoteAddress*(transp: StreamTransport): TransportAddress = ## Returns ``transp`` remote socket address. @@ -116,16 +156,6 @@ template shiftVectorFile(v, o: untyped) = (v).offset += cast[uint]((o)) when defined(windows): - import winlean - type - WindowsStreamTransport = ref object of StreamTransport - rwsabuf: TWSABuf # Reader WSABUF - wwsabuf: TWSABuf # Writer WSABUF - rovl: CustomOverlapped # Reader OVERLAPPED structure - wovl: CustomOverlapped # Writer OVERLAPPED structure - roffset: int # Pending reading offset - - const SO_UPDATE_CONNECT_CONTEXT = 0x7010 template zeroOvelappedOffset(t: untyped) = (t).offset = 0 @@ -157,7 +187,7 @@ when defined(windows): proc writeStreamLoop(udata: pointer) {.gcsafe, nimcall.} = var bytesCount: int32 var ovl = cast[PtrCustomOverlapped](udata) - var transp = cast[WindowsStreamTransport](ovl.data.udata) + var transp = cast[StreamTransport](ovl.data.udata) while len(transp.queue) > 0: if WritePending in transp.state: @@ -258,7 +288,7 @@ when defined(windows): proc readStreamLoop(udata: pointer) {.gcsafe, nimcall.} = var ovl = cast[PtrCustomOverlapped](udata) - var transp = cast[WindowsStreamTransport](ovl.data.udata) + var transp = cast[StreamTransport](ovl.data.udata) while true: if ReadPending in transp.state: @@ -324,8 +354,13 @@ when defined(windows): ## Finish Loop break - proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport = - var transp = WindowsStreamTransport(kind: TransportKind.Socket) + proc newStreamSocketTransport(sock: AsyncFD, bufsize: int, + child: StreamTransport): StreamTransport = + var transp: StreamTransport + if not isNil(child): + transp = child + else: + transp = StreamTransport(kind: TransportKind.Socket) transp.fd = sock transp.rovl.data = CompletionData(fd: sock, cb: readStreamLoop, udata: cast[pointer](transp)) @@ -335,12 +370,8 @@ when defined(windows): transp.state = {ReadPaused, WritePaused} transp.queue = initDeque[StreamVector]() transp.future = newFuture[void]("stream.socket.transport") - # ZAH: If these objects are going to be manually managed, why do we bother - # with using the GC at all? It's better to rely on a destructor. If someone - # wants to share a Transport reference, they can still create a GC-managed - # wrapping object. GC_ref(transp) - result = cast[StreamTransport](transp) + result = transp proc bindToDomain(handle: AsyncFD, domain: Domain): bool = result = true @@ -358,7 +389,8 @@ when defined(windows): result = false proc connect*(address: TransportAddress, - bufferSize = DefaultStreamBufferSize): Future[StreamTransport] = + bufferSize = DefaultStreamBufferSize, + child: StreamTransport = nil): Future[StreamTransport] = ## Open new connection to remote peer with address ``address`` and create ## new transport object ``StreamTransport`` for established connection. ## ``bufferSize`` is size of internal buffer for transport. @@ -392,7 +424,8 @@ when defined(windows): osErrorMsg(osLastError()))) else: retFuture.complete(newStreamSocketTransport(povl.data.fd, - bufferSize)) + bufferSize, + child)) else: sock.closeAsyncSocket() retFuture.fail(newException(TransportOsError, @@ -433,24 +466,38 @@ when defined(windows): addr server.sock, SockLen(sizeof(SocketHandle))) != 0'i32: server.asock.closeAsyncSocket() - raise newException(TransportOsError, osErrorMsg(osLastError())) + raiseTransportOsError(osLastError()) else: - discard server.function(server, - newStreamSocketTransport(server.asock, server.bufferSize)) + if not isNil(server.init): + var transp = server.init(server, server.asock) + discard server.function( + server, + newStreamSocketTransport(server.asock, server.bufferSize, + transp) + ) + else: + discard server.function( + server, + newStreamSocketTransport(server.asock, server.bufferSize, nil) + ) elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt server.asock.closeAsyncSocket() break else: server.asock.closeAsyncSocket() - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) else: ## Initiation + if server.status in {ServerStatus.Stopped, ServerStatus.Closed}: + ## Server was already stopped/closed exiting + break + server.apending = true server.asock = createAsyncSocket(server.domain, SockType.SOCK_STREAM, Protocol.IPPROTO_TCP) if server.asock == asyncInvalidSocket: - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) var dwBytesReceived = DWORD(0) let dwReceiveDataLength = DWORD(0) @@ -471,18 +518,16 @@ when defined(windows): elif int32(err) == ERROR_IO_PENDING: discard else: - raiseOsError(osLastError()) + raiseTransportOsError(err) break proc resumeRead(transp: StreamTransport) {.inline.} = - var wtransp = cast[WindowsStreamTransport](transp) - wtransp.state.excl(ReadPaused) - readStreamLoop(cast[pointer](addr wtransp.rovl)) + transp.state.excl(ReadPaused) + readStreamLoop(cast[pointer](addr transp.rovl)) proc resumeWrite(transp: StreamTransport) {.inline.} = - var wtransp = cast[WindowsStreamTransport](transp) - wtransp.state.excl(WritePaused) - writeStreamLoop(cast[pointer](addr wtransp.wovl)) + transp.state.excl(WritePaused) + writeStreamLoop(cast[pointer](addr transp.wovl)) proc pauseAccept(server: StreamServer) {.inline.} = if server.apending: @@ -492,10 +537,6 @@ when defined(windows): if not server.apending: acceptLoop(cast[pointer](addr server.aovl)) else: - import posix - - type - UnixStreamTransport* = ref object of StreamTransport template getVectorBuffer(v: untyped): pointer = cast[pointer](cast[uint]((v).buf) + uint((v).boffset)) @@ -514,7 +555,7 @@ else: if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)): # Transport was closed earlier, exiting return - var transp = cast[UnixStreamTransport](cdata.udata) + var transp = cast[StreamTransport](cdata.udata) let fd = SocketHandle(cdata.fd) if len(transp.queue) > 0: var vector = transp.queue.popFirst() @@ -562,7 +603,7 @@ else: if not isNil(cdata) and (int(cdata.fd) == 0 or isNil(cdata.udata)): # Transport was closed earlier, exiting return - var transp = cast[UnixStreamTransport](cdata.udata) + var transp = cast[StreamTransport](cdata.udata) let fd = SocketHandle(cdata.fd) while true: var res = posix.recv(fd, addr transp.buffer[transp.offset], @@ -589,18 +630,25 @@ else: transp.finishReader() break - proc newStreamSocketTransport(sock: AsyncFD, bufsize: int): StreamTransport = - var transp = UnixStreamTransport(kind: TransportKind.Socket) + proc newStreamSocketTransport(sock: AsyncFD, bufsize: int, + child: StreamTransport): StreamTransport = + var transp: StreamTransport + if not isNil(child): + transp = child + else: + transp = StreamTransport(kind: TransportKind.Socket) + transp.fd = sock transp.buffer = newSeq[byte](bufsize) transp.state = {ReadPaused, WritePaused} transp.queue = initDeque[StreamVector]() transp.future = newFuture[void]("socket.stream.transport") GC_ref(transp) - result = cast[StreamTransport](transp) + result = transp proc connect*(address: TransportAddress, - bufferSize = DefaultStreamBufferSize): Future[StreamTransport] = + bufferSize = DefaultStreamBufferSize, + child: StreamTransport = nil): Future[StreamTransport] = ## Open new connection to remote peer with address ``address`` and create ## new transport object ``StreamTransport`` for established connection. ## ``bufferSize`` - size of internal buffer for transport. @@ -613,7 +661,7 @@ else: sock = createAsyncSocket(address.address.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP) if sock == asyncInvalidSocket: - retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) + retFuture.fail(newException(TransportOsError, osErrorMsg(osLastError()))) return retFuture proc continuation(udata: pointer) = @@ -622,20 +670,22 @@ else: let fd = data.fd if not fd.getSocketError(err): fd.closeAsyncSocket() - retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) + retFuture.fail(newException(TransportOsError, + osErrorMsg(osLastError()))) return if err != 0: fd.closeAsyncSocket() - retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(err)))) + retFuture.fail(newException(TransportOsError, + osErrorMsg(OSErrorCode(err)))) return fd.removeWriter() - retFuture.complete(newStreamSocketTransport(fd, bufferSize)) + retFuture.complete(newStreamSocketTransport(fd, bufferSize, child)) while true: var res = posix.connect(SocketHandle(sock), cast[ptr SockAddr](addr saddr), slen) if res == 0: - retFuture.complete(newStreamSocketTransport(sock, bufferSize)) + retFuture.complete(newStreamSocketTransport(sock, bufferSize, child)) break else: let err = osLastError() @@ -646,11 +696,11 @@ else: break else: sock.closeAsyncSocket() - retFuture.fail(newException(OSError, osErrorMsg(err))) + retFuture.fail(newException(TransportOsError, osErrorMsg(err))) break return retFuture - proc serverCallback(udata: pointer) = + proc acceptLoop(udata: pointer) = var saddr: Sockaddr_storage slen: SockLen @@ -661,8 +711,13 @@ else: if int(res) > 0: let sock = wrapAsyncSocket(res) if sock != asyncInvalidSocket: - discard server.function(server, - newStreamSocketTransport(sock, server.bufferSize)) + if not isNil(server.init): + var transp = server.init(server, sock) + discard server.function(server, + newStreamSocketTransport(sock, server.bufferSize, transp)) + else: + discard server.function(server, + newStreamSocketTransport(sock, server.bufferSize, nil)) break else: let err = osLastError() @@ -670,10 +725,10 @@ else: continue else: ## Critical unrecoverable error - raiseOsError(err) + raiseTransportOsError(err) proc resumeAccept(server: StreamServer) = - addReader(server.sock, serverCallback, cast[pointer](server)) + addReader(server.sock, acceptLoop, cast[pointer](server)) proc pauseAccept(server: StreamServer) = removeReader(server.sock) @@ -709,7 +764,7 @@ proc close*(server: StreamServer) = ## Release ``server`` resources. if server.status == ServerStatus.Stopped: closeAsyncSocket(server.sock) - server.status = Closed + server.status = ServerStatus.Closed server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) @@ -721,6 +776,8 @@ proc createStreamServer*(host: TransportAddress, sock: AsyncFD = asyncInvalidSocket, backlog: int = 100, bufferSize: int = DefaultStreamBufferSize, + child: StreamServer = nil, + init: TransportInitCallback = nil, udata: pointer = nil): StreamServer = ## Create new TCP stream server. ## @@ -732,6 +789,8 @@ proc createStreamServer*(host: TransportAddress, ## ``backlog`` - number of outstanding connections in the socket's listen ## queue. ## ``bufferSize`` - size of internal buffer for transport. + ## ``child`` - existing object ``StreamServer``object to initialize, can be + ## used to initalize ``StreamServer`` inherited objects. ## ``udata`` - user-defined pointer. var saddr: Sockaddr_storage @@ -742,10 +801,10 @@ proc createStreamServer*(host: TransportAddress, SockType.SOCK_STREAM, Protocol.IPPROTO_TCP) if serverSocket == asyncInvalidSocket: - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) else: if not setSocketBlocking(SocketHandle(sock), false): - raiseOsError(osLastError()) + raiseTransportOsError(osLastError()) register(sock) serverSocket = sock @@ -754,7 +813,7 @@ proc createStreamServer*(host: TransportAddress, let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(serverSocket) - raiseOsError(err) + raiseTransportOsError(err) toSockAddr(host.address, host.port, saddr, slen) if bindAddr(SocketHandle(serverSocket), cast[ptr SockAddr](addr saddr), @@ -762,17 +821,22 @@ proc createStreamServer*(host: TransportAddress, let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(serverSocket) - raiseOsError(err) + raiseTransportOsError(err) if nativesockets.listen(SocketHandle(serverSocket), cint(backlog)) != 0: let err = osLastError() if sock == asyncInvalidSocket: closeAsyncSocket(serverSocket) - raiseOsError(err) + raiseTransportOsError(err) + + if not isNil(child): + result = child + else: + result = StreamServer() - result = StreamServer() result.sock = serverSocket result.function = cbproc + result.init = init result.bufferSize = bufferSize result.status = Starting result.loopFuture = newFuture[void]("stream.server") @@ -790,10 +854,11 @@ proc createStreamServer*(host: TransportAddress, proc createStreamServer*[T](host: TransportAddress, cbproc: StreamCallback, flags: set[ServerFlags] = {}, + udata: ref T, sock: AsyncFD = asyncInvalidSocket, backlog: int = 100, bufferSize: int = DefaultStreamBufferSize, - udata: ref T): StreamServer = + child: StreamServer = nil): StreamServer = var fflags = flags + {GCUserData} GC_ref(udata) result = createStreamServer(host, cbproc, flags, sock, backlog, bufferSize, diff --git a/tests/testserver.nim b/tests/testserver.nim index 79cf814..37c4896 100644 --- a/tests/testserver.nim +++ b/tests/testserver.nim @@ -9,10 +9,34 @@ import strutils, unittest import ../asyncdispatch2 +type + CustomServer = ref object of StreamServer + test1: string + test2: string + + CustomTransport = ref object of StreamTransport + test: string + proc serveStreamClient(server: StreamServer, transp: StreamTransport) {.async.} = discard +proc serveCustomStreamClient(server: StreamServer, + transp: StreamTransport) {.async.} = + var cserver = cast[CustomServer](server) + var ctransp = cast[CustomTransport](transp) + cserver.test1 = "CONNECTION" + cserver.test2 = ctransp.test + transp.close() + server.stop() + server.close() + +proc customServerTransport(server: StreamServer, + fd: AsyncFD): StreamTransport = + var transp = CustomTransport() + transp.test = "CUSTOM" + result = cast[StreamTransport](transp) + proc serveDatagramClient(transp: DatagramTransport, pbytes: pointer, nbytes: int, raddr: TransportAddress, @@ -21,11 +45,11 @@ proc serveDatagramClient(transp: DatagramTransport, proc test1(): bool = var ta = initTAddress("127.0.0.1:31354") - var server1 = createStreamServer(ta, serveStreamClient, {ReuseAddr}) + var server1 = createStreamServer(ta, serveStreamClient, {}) server1.start() server1.stop() server1.close() - var server2 = createStreamServer(ta, serveStreamClient, {ReuseAddr}) + var server2 = createStreamServer(ta, serveStreamClient, {}) server2.start() server2.stop() server2.close() @@ -33,19 +57,41 @@ proc test1(): bool = proc test2(): bool = var ta = initTAddress("127.0.0.1:31354") - var server1 = createDatagramServer(ta, serveDatagramClient, {ReuseAddr}) + var server1 = createDatagramServer(ta, serveDatagramClient, {}) server1.start() server1.stop() server1.close() - var server2 = createDatagramServer(ta, serveDatagramClient, {ReuseAddr}) + var server2 = createDatagramServer(ta, serveDatagramClient, {}) server2.start() server2.stop() server2.close() result = true +proc client(server: CustomServer, ta: TransportAddress) {.async.} = + var transp = CustomTransport() + transp.test = "CLIENT" + server.start() + var ptransp = await connect(ta, child = transp) + var etransp = cast[CustomTransport](ptransp) + doAssert(etransp.test == "CLIENT") + transp.close() + await server.join() + +proc test3(): bool = + var server = CustomServer() + server.test1 = "TEST" + var ta = initTAddress("127.0.0.1:31354") + var pserver = createStreamServer(ta, serveCustomStreamClient, {}, + child = cast[StreamServer](server), + init = customServerTransport) + waitFor client(server, ta) + result = (server.test1 == "CONNECTION") and (server.test2 == "CUSTOM") + when isMainModule: suite "Server's test suite": test "Stream Server start/stop test": check test1() == true + test "Stream Server inherited object test": + check test3() == true test "Datagram Server start/stop test": check test2() == true