diff --git a/asyncdispatch2/transports/common.nim b/asyncdispatch2/transports/common.nim index 32f996f..8f4bd9d 100644 --- a/asyncdispatch2/transports/common.nim +++ b/asyncdispatch2/transports/common.nim @@ -18,7 +18,7 @@ const type ServerFlags* = enum ## Server's flags - ReuseAddr, ReusePort, NoAutoRead + ReuseAddr, ReusePort, NoAutoRead, GCUserData TransportAddress* = object ## Transport network address diff --git a/asyncdispatch2/transports/stream.nim b/asyncdispatch2/transports/stream.nim index 37380d0..3860be6 100644 --- a/asyncdispatch2/transports/stream.nim +++ b/asyncdispatch2/transports/stream.nim @@ -55,12 +55,10 @@ type todo2: int StreamCallback* = proc(server: StreamServer, - client: StreamTransport, - udata: pointer): Future[void] {.gcsafe.} + client: StreamTransport): Future[void] {.gcsafe.} ## New remote client connection callback ## ``server`` - StreamServer object. ## ``client`` - accepted client transport. - ## ``udata`` - user-defined pointer passed at ``createStreamServer()`` call. StreamServer* = ref object of SocketServer function*: StreamCallback @@ -440,8 +438,7 @@ when defined(windows): raise newException(TransportOsError, osErrorMsg(osLastError())) else: discard server.function(server, - newStreamSocketTransport(server.asock, server.bufferSize), - server.udata) + newStreamSocketTransport(server.asock, server.bufferSize)) elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt server.asock.closeAsyncSocket() @@ -714,6 +711,8 @@ proc close*(server: SocketServer) = closeAsyncSocket(server.sock) server.status = Closed server.loopFuture.complete() + if not isNil(server.udata) and GCUserData in server.flags: + GC_unref(cast[ref int](server.udata)) GC_unref(server) proc createStreamServer*(host: TransportAddress, @@ -788,6 +787,21 @@ proc createStreamServer*(host: TransportAddress, GC_ref(result) result.resumeAccept() +proc createStreamServer*[T](host: TransportAddress, + cbproc: StreamCallback, + flags: set[ServerFlags] = {}, + sock: AsyncFD = asyncInvalidSocket, + backlog: int = 100, + bufferSize: int = DefaultStreamBufferSize, + udata: ref T): StreamServer = + var fflags = flags + {GCUserData} + GC_ref(udata) + result = createStreamServer(host, cbproc, flags, sock, backlog, bufferSize, + cast[pointer](udata)) + +proc getUserData*[T](server: StreamServer): ref T {.inline.} = + result = cast[ref T](server.udata) + proc write*(transp: StreamTransport, pbytes: pointer, nbytes: int): Future[int] = ## Write data from buffer ``pbytes`` with size ``nbytes`` using transport diff --git a/tests/teststream.nim b/tests/teststream.nim index a1198be..5018cb3 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -24,8 +24,7 @@ const FilesCount = 50 FilesTestName = "tests/teststream.nim" -proc serveClient1(server: StreamServer, - transp: StreamTransport, udata: pointer) {.async.} = +proc serveClient1(server: StreamServer, transp: StreamTransport) {.async.} = while not transp.atEof(): var data = await transp.readLine() if len(data) == 0: @@ -39,8 +38,7 @@ proc serveClient1(server: StreamServer, doAssert(res == len(ans)) transp.close() -proc serveClient2(server: StreamServer, - transp: StreamTransport, udata: pointer) {.async.} = +proc serveClient2(server: StreamServer, transp: StreamTransport) {.async.} = var buffer: array[20, char] var check = "REQUEST" while not transp.atEof(): @@ -63,8 +61,7 @@ proc serveClient2(server: StreamServer, doAssert(res == MessageSize) transp.close() -proc serveClient3(server: StreamServer, - transp: StreamTransport, udata: pointer) {.async.} = +proc serveClient3(server: StreamServer, transp: StreamTransport) {.async.} = var buffer: array[20, char] var check = "REQUEST" var suffixStr = "SUFFIX" @@ -90,8 +87,7 @@ proc serveClient3(server: StreamServer, dec(counter) transp.close() -proc serveClient4(server: StreamServer, - transp: StreamTransport, udata: pointer) {.async.} = +proc serveClient4(server: StreamServer, transp: StreamTransport) {.async.} = var pathname = await transp.readLine() var size = await transp.readLine() var sizeNum = parseInt(size) @@ -106,8 +102,7 @@ proc serveClient4(server: StreamServer, doAssert(res == len(answer)) transp.close() -proc serveClient5(server: StreamServer, - transp: StreamTransport, udata: pointer) {.async.} = +proc serveClient5(server: StreamServer, transp: StreamTransport) {.async.} = var data = await transp.read() doAssert(len(data) == len(ConstantMessage) * MessagesCount) transp.close() @@ -115,14 +110,13 @@ proc serveClient5(server: StreamServer, for i in 0..