From 6a1f7785a01ef4a205741ccc12c6e758a5462136 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Thu, 4 Apr 2019 12:34:23 +0300 Subject: [PATCH] Custom tracking mechanism. 1. Add simple tracking of Datagram and Stream transports. 2. Fix leaks in tests. 3. Add leaks tests to Datagram and Stream transport tests. --- chronos/asyncloop.nim | 21 ++++- chronos/transports/datagram.nim | 67 +++++++++++++- chronos/transports/stream.nim | 154 +++++++++++++++++++++++++++----- tests/testdatagram.nim | 2 + tests/teststream.nim | 6 ++ 5 files changed, 223 insertions(+), 27 deletions(-) diff --git a/chronos/asyncloop.nim b/chronos/asyncloop.nim index 6befa34..59947b1 100644 --- a/chronos/asyncloop.nim +++ b/chronos/asyncloop.nim @@ -182,9 +182,15 @@ type finishAt*: Moment function*: AsyncCallback + TrackerBase* = ref object of RootRef + id*: string + dump*: proc(): string {.gcsafe.} + isLeaked*: proc(): bool {.gcsafe.} + PDispatcherBase = ref object of RootRef timers*: HeapQueue[TimerCallback] callbacks*: Deque[AsyncCallback] + trackers*: Table[string, TrackerBase] proc `<`(a, b: TimerCallback): bool = result = a.finishAt < b.finishAt @@ -305,6 +311,7 @@ when defined(windows) or defined(nimdoc): result.handles = initSet[AsyncFD]() result.timers.newHeapQueue() result.callbacks = initDeque[AsyncCallback](64) + result.trackers = initTable[string, TrackerBase]() var gDisp{.threadvar.}: PDispatcher ## Global dispatcher @@ -479,6 +486,7 @@ else: result.timers.newHeapQueue() result.callbacks = initDeque[AsyncCallback](64) result.keys = newSeq[ReadyKey](64) + result.trackers = initTable[string, TrackerBase]() var gDisp{.threadvar.}: PDispatcher ## Global dispatcher @@ -804,7 +812,7 @@ include asyncmacro2 proc callSoon(cbproc: CallbackFunc, data: pointer = nil) = ## Schedule `cbproc` to be called as soon as possible. ## The callback is called when control returns to the event loop. - doAssert cbproc != nil + doAssert(not isNil(cbproc)) let acb = AsyncCallback(function: cbproc, udata: data) getGlobalDispatcher().callbacks.addLast(acb) @@ -820,5 +828,16 @@ proc waitFor*[T](fut: Future[T]): T = fut.read +proc addTracker*[T](id: string, tracker: T) = + ## Add new ``tracker`` object to current thread dispatcher with identifier + ## ``id``. + let loop = getGlobalDispatcher() + loop.trackers[id] = tracker + +proc getTracker*(id: string): TrackerBase = + ## Get ``tracker`` from current thread dispatcher using identifier ``id``. + let loop = getGlobalDispatcher() + result = loop.trackers.getOrDefault(id, nil) + # Global API and callSoon() initialization. initAPI() diff --git a/chronos/transports/datagram.nim b/chronos/transports/datagram.nim index d0f2223..641ca8b 100644 --- a/chronos/transports/datagram.nim +++ b/chronos/transports/datagram.nim @@ -54,6 +54,13 @@ type rwsabuf: TWSABuf # Reader WSABUF structure wwsabuf: TWSABuf # Writer WSABUF structure + DgramTransportTracker* = ref object of TrackerBase + opened*: int64 + closed*: int64 + +const + DgramTransportTrackerName = "datagram.transport" + template setReadError(t, e: untyped) = (t).state.incl(ReadError) (t).error = getTransportOsError(e) @@ -62,6 +69,38 @@ template setWriterWSABuffer(t, v: untyped) = (t).wwsabuf.buf = cast[cstring](v.buf) (t).wwsabuf.len = cast[int32](v.buflen) +proc setupDgramTransportTracker(): DgramTransportTracker {.gcsafe.} + +proc getDgramTransportTracker(): DgramTransportTracker {.inline.} = + result = cast[DgramTransportTracker](getTracker(DgramTransportTrackerName)) + if isNil(result): + result = setupDgramTransportTracker() + +proc dumpTransportTracking(): string {.gcsafe.} = + var tracker = getDgramTransportTracker() + result = "Opened transports: " & $tracker.opened & "\n" & + "Closed transports: " & $tracker.closed + +proc leakTransport(): bool {.gcsafe.} = + var tracker = getDgramTransportTracker() + result = tracker.opened != tracker.closed + +proc trackDgram(t: DatagramTransport) {.inline.} = + var tracker = getDgramTransportTracker() + inc(tracker.opened) + +proc untrackDgram(t: DatagramTransport) {.inline.} = + var tracker = getDgramTransportTracker() + inc(tracker.closed) + +proc setupDgramTransportTracker(): DgramTransportTracker {.gcsafe.} = + result = new DgramTransportTracker + result.opened = 0 + result.closed = 0 + result.dump = dumpTransportTracking + result.isLeaked = leakTransport + addTracker(DgramTransportTrackerName, result) + when defined(windows): const IOC_VENDOR = DWORD(0x18000000) @@ -144,6 +183,8 @@ when defined(windows): # CancelIO() interrupt or closeSocket() call. transp.state.incl(ReadPaused) if ReadClosed in transp.state: + # Stop tracking transport + untrackDgram(transp) # If `ReadClosed` present, then close(transport) was called. transp.future.complete() GC_unref(transp) @@ -188,7 +229,10 @@ when defined(windows): # WSARecvFrom session. if ReadClosed in transp.state: if not transp.future.finished: + # Stop tracking transport + untrackDgram(transp) transp.future.complete() + GC_unref(transp) break proc resumeRead(transp: DatagramTransport) {.inline.} = @@ -299,6 +343,8 @@ when defined(windows): result.rwsabuf = TWSABuf(buf: cast[cstring](addr result.buffer[0]), len: int32(len(result.buffer))) GC_ref(result) + # Start tracking transport + trackDgram(result) if NoAutoRead notin flags: result.resumeRead() else: @@ -465,6 +511,8 @@ else: result.state = {WritePaused} result.future = newFuture[void]("datagram.transport") GC_ref(result) + # Start tracking transport + trackDgram(result) if NoAutoRead notin flags: result.resumeRead() else: @@ -472,14 +520,25 @@ else: proc close*(transp: DatagramTransport) = ## Closes and frees resources of transport ``transp``. + proc continuation(udata: pointer) = + if not transp.future.finished: + # Stop tracking transport + untrackDgram(transp) + transp.future.complete() + GC_unref(transp) + when defined(windows): if {ReadClosed, WriteClosed} * transp.state == {}: transp.state.incl({WriteClosed, ReadClosed}) - closeSocket(transp.fd) + if ReadPaused in transp.state: + # If readDatagramLoop() is not running we need to finish in + # continuation step. + closeSocket(transp.fd, continuation) + else: + # If readDatagramLoop() is running, it will be properly finished inside + # of readDatagramLoop(). + closeSocket(transp.fd) else: - proc continuation(udata: pointer) = - transp.future.complete() - GC_unref(transp) if {ReadClosed, WriteClosed} * transp.state == {}: transp.state.incl({WriteClosed, ReadClosed}) closeSocket(transp.fd, continuation) diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index 4cd288d..6b32032 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -53,6 +53,18 @@ type # Please use this flag only if you are making both client and server in # the same thread. + StreamTransportTracker* = ref object of TrackerBase + opened*: int64 + closed*: int64 + + StreamServerTracker* = ref object of TrackerBase + opened*: int64 + closed*: int64 + +const + StreamTransportTrackerName = "stream.transport" + StreamServerTrackerName = "stream.server" + when defined(windows): const SO_UPDATE_CONNECT_CONTEXT = 0x7010 @@ -171,6 +183,69 @@ template shiftVectorFile(v, o: untyped) = (v).buf = cast[pointer](cast[uint]((v).buf) - cast[uint](o)) (v).offset += cast[uint]((o)) +proc setupStreamTransportTracker(): StreamTransportTracker {.gcsafe.} +proc setupStreamServerTracker(): StreamServerTracker {.gcsafe.} + +proc getStreamTransportTracker(): StreamTransportTracker {.inline.} = + result = cast[StreamTransportTracker](getTracker(StreamTransportTrackerName)) + if isNil(result): + result = setupStreamTransportTracker() + +proc getStreamServerTracker(): StreamServerTracker {.inline.} = + result = cast[StreamServerTracker](getTracker(StreamServerTrackerName)) + if isNil(result): + result = setupStreamServerTracker() + +proc dumpTransportTracking(): string {.gcsafe.} = + var tracker = getStreamTransportTracker() + result = "Opened transports: " & $tracker.opened & "\n" & + "Closed transports: " & $tracker.closed + +proc dumpServerTracking(): string {.gcsafe.} = + var tracker = getStreamServerTracker() + result = "Opened servers: " & $tracker.opened & "\n" & + "Closed servers: " & $tracker.closed + +proc leakTransport(): bool {.gcsafe.} = + var tracker = getStreamTransportTracker() + result = tracker.opened != tracker.closed + +proc leakServer(): bool {.gcsafe.} = + var tracker = getStreamServerTracker() + result = tracker.opened != tracker.closed + +proc trackStream(t: StreamTransport) {.inline.} = + var tracker = getStreamTransportTracker() + inc(tracker.opened) + +proc untrackStream(t: StreamTransport) {.inline.} = + var tracker = getStreamTransportTracker() + inc(tracker.closed) + +proc trackServer(s: StreamServer) {.inline.} = + var tracker = getStreamServerTracker() + inc(tracker.opened) + +proc untrackServer(s: StreamServer) {.inline.} = + var tracker = getStreamServerTracker() + inc(tracker.closed) + +proc setupStreamTransportTracker(): StreamTransportTracker {.gcsafe.} = + result = new StreamTransportTracker + result.opened = 0 + result.closed = 0 + result.dump = dumpTransportTracking + result.isLeaked = leakTransport + addTracker(StreamTransportTrackerName, result) + +proc setupStreamServerTracker(): StreamServerTracker {.gcsafe.} = + result = new StreamServerTracker + result.opened = 0 + result.closed = 0 + result.dump = dumpServerTracking + result.isLeaked = leakServer + addTracker(StreamServerTrackerName, result) + when defined(windows): template zeroOvelappedOffset(t: untyped) = @@ -361,14 +436,6 @@ when defined(windows): ERROR_BROKEN_PIPE, ERROR_NETNAME_DELETED}: # CancelIO() interrupt or closeSocket() call. transp.state.incl(ReadPaused) - if ReadClosed in transp.state: - if not isNil(transp.reader): - if not transp.reader.finished: - transp.reader.complete() - transp.reader = nil - # If `ReadClosed` present, then close(transport) was called. - transp.future.complete() - GC_unref(transp) elif transp.kind == TransportKind.Socket and (int(err) in {ERROR_NETNAME_DELETED, WSAECONNABORTED}): transp.state.incl({ReadEof, ReadPaused}) @@ -377,10 +444,19 @@ when defined(windows): transp.state.incl({ReadEof, ReadPaused}) else: transp.setReadError(err) + if not isNil(transp.reader): if not transp.reader.finished: transp.reader.complete() transp.reader = nil + + if ReadClosed in transp.state: + # Stop tracking transport + untrackStream(transp) + # If `ReadClosed` present, then close(transport) was called. + transp.future.complete() + GC_unref(transp) + if ReadPaused in transp.state: # Transport buffer is full, so we will not continue on reading. break @@ -553,9 +629,11 @@ when defined(windows): sock.closeSocket() retFuture.fail(getTransportOsError(err)) else: - retFuture.complete(newStreamSocketTransport(povl.data.fd, - bufferSize, - child)) + let transp = newStreamSocketTransport(povl.data.fd, bufferSize, + child) + # Start tracking transport + trackStream(transp) + retFuture.complete(transp) else: sock.closeSocket() retFuture.fail(getTransportOsError(ovl.data.errCode)) @@ -594,8 +672,11 @@ when defined(windows): retFuture.fail(getTransportOsError(err)) else: register(AsyncFD(pipeHandle)) - retFuture.complete(newStreamPipeTransport(AsyncFD(pipeHandle), - bufferSize, child)) + let transp = newStreamPipeTransport(AsyncFD(pipeHandle), + bufferSize, child) + # Start tracking transport + trackStream(transp) + retFuture.complete(transp) pipeContinuation(nil) return retFuture @@ -621,10 +702,15 @@ when defined(windows): else: ntransp = newStreamPipeTransport(server.sock, server.bufferSize, nil, flags) + # Start tracking transport + trackStream(ntransp) asyncCheck server.function(server, ntransp) elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt or close call. if server.status == ServerStatus.Closed: + # Stop tracking server + untrackServer(server) + # Completing server's Future server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) @@ -674,9 +760,12 @@ when defined(windows): # connectNamedPipe session. if server.status == ServerStatus.Closed: if not server.loopFuture.finished: + # Stop tracking server + untrackServer(server) server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) + GC_unref(server) proc acceptLoop(udata: pointer) {.gcsafe, nimcall.} = @@ -705,11 +794,15 @@ when defined(windows): else: ntransp = newStreamSocketTransport(server.asock, server.bufferSize, nil) + # Start tracking transport + trackStream(ntransp) asyncCheck server.function(server, ntransp) elif int32(ovl.data.errCode) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt or close. if server.status == ServerStatus.Closed: + # Stop tracking server + untrackServer(server) server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) @@ -753,6 +846,8 @@ when defined(windows): # AcceptEx session. if server.status == ServerStatus.Closed: if not server.loopFuture.finished: + # Stop tracking server + untrackServer(server) server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) @@ -930,13 +1025,19 @@ else: closeSocket(fd) retFuture.fail(getTransportOsError(OSErrorCode(err))) return - retFuture.complete(newStreamSocketTransport(fd, bufferSize, child)) + let transp = newStreamSocketTransport(fd, bufferSize, child) + # Start tracking transport + trackStream(transp) + retFuture.complete(transp) while true: var res = posix.connect(SocketHandle(sock), cast[ptr SockAddr](addr saddr), slen) if res == 0: - retFuture.complete(newStreamSocketTransport(sock, bufferSize, child)) + let transp = newStreamSocketTransport(sock, bufferSize, child) + # Start tracking transport + trackStream(transp) + retFuture.complete(transp) break else: let err = osLastError() @@ -962,13 +1063,15 @@ else: if int(res) > 0: let sock = wrapAsyncSocket(res) if sock != asyncInvalidSocket: + var ntransp: StreamTransport if not isNil(server.init): - var transp = server.init(server, sock) - asyncCheck server.function(server, - newStreamSocketTransport(sock, server.bufferSize, transp)) + let transp = server.init(server, sock) + ntransp = newStreamSocketTransport(sock, server.bufferSize, transp) else: - asyncCheck server.function(server, - newStreamSocketTransport(sock, server.bufferSize, nil)) + ntransp = newStreamSocketTransport(sock, server.bufferSize, nil) + # Start tracking transport + trackStream(ntransp) + asyncCheck server.function(server, ntransp) break else: let err = osLastError() @@ -1023,6 +1126,8 @@ proc close*(server: StreamServer) = ## sure all resources got released please use ``await server.join()``. when not defined(windows): proc continuation(udata: pointer) = + # Stop tracking server + untrackServer(server) server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) @@ -1197,6 +1302,8 @@ proc createStreamServer*(host: TransportAddress, result.domain = host.getDomain() result.apending = false + # Start tracking server + trackServer(result) GC_ref(result) proc createStreamServer*[T](host: TransportAddress, @@ -1562,8 +1669,11 @@ proc close*(transp: StreamTransport) = ## Please note that release of resources is not completed immediately, to be ## sure all resources got released please use ``await transp.join()``. proc continuation(udata: pointer) = - transp.future.complete() - GC_unref(transp) + if not transp.future.finished: + transp.future.complete() + # Stop tracking stream + untrackStream(transp) + GC_unref(transp) if {ReadClosed, WriteClosed} * transp.state == {}: transp.state.incl({WriteClosed, ReadClosed}) diff --git a/tests/testdatagram.nim b/tests/testdatagram.nim index dac73a7..0b31464 100644 --- a/tests/testdatagram.nim +++ b/tests/testdatagram.nim @@ -503,3 +503,5 @@ suite "Datagram Transport test suite": check waitFor(testConnReset()) == true test "Broadcast test": check waitFor(testBroadcast()) == 1 + test "Transports leak test": + check getTracker("datagram.transport").isLeaked() == false diff --git a/tests/teststream.nim b/tests/teststream.nim index aa9c7ac..9ebbfd1 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -631,6 +631,7 @@ suite "Stream Transport test suite": server.stop() server.close() await server.join() + await transp.join() result = subres proc testConnectionRefused(address: TransportAddress): Future[bool] {.async.} = @@ -733,3 +734,8 @@ suite "Stream Transport test suite": check waitFor(testConnectionRefused(address)) == true test prefixes[i] & m16: check waitFor(test16(addresses[i])) == 1 + + test "Servers leak test": + check getTracker("stream.server").isLeaked() == false + test "Transports leak test": + check getTracker("stream.transport").isLeaked() == false