release callback memory early (#130)

* release callback memory early

this fixes a memory leak where a deleted callback may keep references
alive until the future is finished.

In particular, when using helpers like `or` which try to remove
themselves from the callback list when a dependent future is completed,
create a reference chain between all futures in the expression - in the
pathological case where one of the futures is completes only rarely (for
example a timeout or a cancellation task), the buildup will be
significant.

* Removing unnecessary asserts, and place comments instead.
This commit is contained in:
Jacek Sieka 2020-09-15 09:55:43 +02:00 committed by GitHub
parent 2134980744
commit 1ffd1cd3dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 64 deletions

View File

@ -8,7 +8,7 @@
# Apache License, version 2.0, (LICENSE-APACHEv2) # Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT) # MIT license (LICENSE-MIT)
import os, tables, strutils, heapqueue, options, deques, cstrutils import std/[os, tables, strutils, heapqueue, options, deques, cstrutils, sequtils]
import srcloc import srcloc
export srcloc export srcloc
@ -25,7 +25,7 @@ type
FutureBase* = ref object of RootObj ## Untyped future. FutureBase* = ref object of RootObj ## Untyped future.
location*: array[2, ptr SrcLoc] location*: array[2, ptr SrcLoc]
callbacks: Deque[AsyncCallback] callbacks: seq[AsyncCallback]
cancelcb*: CallbackFunc cancelcb*: CallbackFunc
child*: FutureBase child*: FutureBase
state*: FutureState state*: FutureState
@ -142,6 +142,7 @@ template newFutureVar*[T](fromProc: static[string] = ""): auto =
proc clean*[T](future: FutureVar[T]) = proc clean*[T](future: FutureVar[T]) =
## Resets the ``finished`` status of ``future``. ## Resets the ``finished`` status of ``future``.
Future[T](future).state = FutureState.Pending Future[T](future).state = FutureState.Pending
Future[T](future).value = default(T)
Future[T](future).error = nil Future[T](future).error = nil
proc finished*(future: FutureBase | FutureVar): bool {.inline.} = proc finished*(future: FutureBase | FutureVar): bool {.inline.} =
@ -201,33 +202,27 @@ proc checkFinished(future: FutureBase, loc: ptr SrcLoc) =
else: else:
future.location[LocCompleteIndex] = loc future.location[LocCompleteIndex] = loc
proc call(callbacks: var Deque[AsyncCallback]) = proc finish(fut: FutureBase, state: FutureState) =
var count = len(callbacks) # We do not perform any checks here, because:
while count > 0: # 1. `finish()` is a private procedure and `state` is under our control.
var item = callbacks.popFirst() # 2. `fut.state` is checked by `checkFinished()`.
if not(item.deleted): fut.state = state
fut.cancelcb = nil # release cancellation callback memory
for item in fut.callbacks.mitems():
if not(isNil(item.function)):
callSoon(item.function, item.udata) callSoon(item.function, item.udata)
dec(count) item = default(AsyncCallback) # release memory as early as possible
fut.callbacks = default(seq[AsyncCallback]) # release seq as well
proc add(callbacks: var Deque[AsyncCallback], item: AsyncCallback) = when defined(chronosFutureTracking):
if len(callbacks) == 0: scheduleDestructor(fut)
callbacks = initDeque[AsyncCallback]()
callbacks.addLast(item)
proc remove(callbacks: var Deque[AsyncCallback], item: AsyncCallback) =
for p in callbacks.mitems():
if p.function == item.function and p.udata == item.udata:
p.deleted = true
proc complete[T](future: Future[T], val: T, loc: ptr SrcLoc) = proc complete[T](future: Future[T], val: T, loc: ptr SrcLoc) =
if not(future.cancelled()): if not(future.cancelled()):
checkFinished(FutureBase(future), loc) checkFinished(FutureBase(future), loc)
doAssert(isNil(future.error)) doAssert(isNil(future.error))
future.value = val future.value = val
future.state = FutureState.Finished future.finish(FutureState.Finished)
future.callbacks.call()
when defined(chronosFutureTracking):
scheduleDestructor(FutureBase(future))
template complete*[T](future: Future[T], val: T) = template complete*[T](future: Future[T], val: T) =
## Completes ``future`` with value ``val``. ## Completes ``future`` with value ``val``.
@ -237,10 +232,7 @@ proc complete(future: Future[void], loc: ptr SrcLoc) =
if not(future.cancelled()): if not(future.cancelled()):
checkFinished(FutureBase(future), loc) checkFinished(FutureBase(future), loc)
doAssert(isNil(future.error)) doAssert(isNil(future.error))
future.state = FutureState.Finished future.finish(FutureState.Finished)
future.callbacks.call()
when defined(chronosFutureTracking):
scheduleDestructor(FutureBase(future))
template complete*(future: Future[void]) = template complete*(future: Future[void]) =
## Completes a void ``future``. ## Completes a void ``future``.
@ -251,10 +243,7 @@ proc complete[T](future: FutureVar[T], loc: ptr SrcLoc) =
template fut: untyped = Future[T](future) template fut: untyped = Future[T](future)
checkFinished(FutureBase(fut), loc) checkFinished(FutureBase(fut), loc)
doAssert(isNil(fut.error)) doAssert(isNil(fut.error))
fut.state = FutureState.Finished fut.finish(FutureState.Finished)
fut.callbacks.call()
when defined(chronosFutureTracking):
scheduleDestructor(FutureBase(future))
template complete*[T](futvar: FutureVar[T]) = template complete*[T](futvar: FutureVar[T]) =
## Completes a ``FutureVar``. ## Completes a ``FutureVar``.
@ -265,11 +254,8 @@ proc complete[T](futvar: FutureVar[T], val: T, loc: ptr SrcLoc) =
template fut: untyped = Future[T](futvar) template fut: untyped = Future[T](futvar)
checkFinished(FutureBase(fut), loc) checkFinished(FutureBase(fut), loc)
doAssert(isNil(fut.error)) doAssert(isNil(fut.error))
fut.state = FutureState.Finished
fut.value = val fut.value = val
fut.callbacks.call() fut.finish(FutureState.Finished)
when defined(chronosFutureTracking):
scheduleDestructor(FutureBase(fut))
template complete*[T](futvar: FutureVar[T], val: T) = template complete*[T](futvar: FutureVar[T], val: T) =
## Completes a ``FutureVar`` with value ``val``. ## Completes a ``FutureVar`` with value ``val``.
@ -280,16 +266,13 @@ template complete*[T](futvar: FutureVar[T], val: T) =
proc fail[T](future: Future[T], error: ref Exception, loc: ptr SrcLoc) = proc fail[T](future: Future[T], error: ref Exception, loc: ptr SrcLoc) =
if not(future.cancelled()): if not(future.cancelled()):
checkFinished(FutureBase(future), loc) checkFinished(FutureBase(future), loc)
future.state = FutureState.Failed
future.error = error future.error = error
when defined(chronosStackTrace): when defined(chronosStackTrace):
future.errorStackTrace = if getStackTrace(error) == "": future.errorStackTrace = if getStackTrace(error) == "":
getStackTrace() getStackTrace()
else: else:
getStackTrace(error) getStackTrace(error)
future.callbacks.call() future.finish(FutureState.Failed)
when defined(chronosFutureTracking):
scheduleDestructor(FutureBase(future))
template fail*[T](future: Future[T], error: ref Exception) = template fail*[T](future: Future[T], error: ref Exception) =
## Completes ``future`` with ``error``. ## Completes ``future`` with ``error``.
@ -301,13 +284,10 @@ template newCancelledError(): ref CancelledError =
proc cancelAndSchedule(future: FutureBase, loc: ptr SrcLoc) = proc cancelAndSchedule(future: FutureBase, loc: ptr SrcLoc) =
if not(future.finished()): if not(future.finished()):
checkFinished(future, loc) checkFinished(future, loc)
future.state = FutureState.Cancelled
future.error = newCancelledError() future.error = newCancelledError()
when defined(chronosStackTrace): when defined(chronosStackTrace):
future.errorStackTrace = getStackTrace() future.errorStackTrace = getStackTrace()
future.callbacks.call() future.finish(FutureState.Cancelled)
when defined(chronosFutureTracking):
scheduleDestructor(future)
template cancelAndSchedule*[T](future: Future[T]) = template cancelAndSchedule*[T](future: Future[T]) =
cancelAndSchedule(FutureBase(future), getSrcLocation()) cancelAndSchedule(FutureBase(future), getSrcLocation())
@ -324,6 +304,7 @@ proc cancel(future: FutureBase, loc: ptr SrcLoc) =
else: else:
if not(isNil(future.cancelcb)): if not(isNil(future.cancelcb)):
future.cancelcb(cast[pointer](future)) future.cancelcb(cast[pointer](future))
future.cancelcb = nil
cancelAndSchedule(future, getSrcLocation()) cancelAndSchedule(future, getSrcLocation())
template cancel*[T](future: Future[T]) = template cancel*[T](future: Future[T]) =
@ -331,7 +312,7 @@ template cancel*[T](future: Future[T]) =
cancel(FutureBase(future), getSrcLocation()) cancel(FutureBase(future), getSrcLocation())
proc clearCallbacks(future: FutureBase) = proc clearCallbacks(future: FutureBase) =
future.callbacks.clear() future.callbacks = default(seq[AsyncCallback])
proc addCallback*(future: FutureBase, cb: CallbackFunc, udata: pointer = nil) = proc addCallback*(future: FutureBase, cb: CallbackFunc, udata: pointer = nil) =
## Adds the callbacks proc to be called when the future completes. ## Adds the callbacks proc to be called when the future completes.
@ -352,9 +333,13 @@ proc addCallback*[T](future: Future[T], cb: CallbackFunc) =
proc removeCallback*(future: FutureBase, cb: CallbackFunc, proc removeCallback*(future: FutureBase, cb: CallbackFunc,
udata: pointer = nil) = udata: pointer = nil) =
## Remove future from list of callbacks - this operation may be slow if there
## are many registered callbacks!
doAssert(not isNil(cb)) doAssert(not isNil(cb))
let acb = AsyncCallback(function: cb, udata: udata) # Make sure to release memory associated with callback, or reference chains
future.callbacks.remove acb # may be created!
future.callbacks.keepItIf:
it.function != cb or it.udata != udata
proc removeCallback*[T](future: Future[T], cb: CallbackFunc) = proc removeCallback*[T](future: Future[T], cb: CallbackFunc) =
future.removeCallback(cb, cast[pointer](future)) future.removeCallback(cb, cast[pointer](future))

View File

@ -183,7 +183,6 @@ type
AsyncCallback* = object AsyncCallback* = object
function*: CallbackFunc function*: CallbackFunc
udata*: pointer udata*: pointer
deleted*: bool
AsyncError* = object of CatchableError AsyncError* = object of CatchableError
## Generic async exception ## Generic async exception
@ -193,7 +192,6 @@ type
TimerCallback* = ref object TimerCallback* = ref object
finishAt*: Moment finishAt*: Moment
function*: AsyncCallback function*: AsyncCallback
deleted*: bool
TrackerBase* = ref object of RootRef TrackerBase* = ref object of RootRef
id*: string id*: string
@ -231,7 +229,7 @@ func getAsyncTimestamp*(a: Duration): auto {.inline.} =
template processTimersGetTimeout(loop, timeout: untyped) = template processTimersGetTimeout(loop, timeout: untyped) =
var lastFinish = curTime var lastFinish = curTime
while loop.timers.len > 0: while loop.timers.len > 0:
if loop.timers[0].deleted: if loop.timers[0].function.function.isNil:
discard loop.timers.pop() discard loop.timers.pop()
continue continue
@ -256,7 +254,7 @@ template processTimersGetTimeout(loop, timeout: untyped) =
template processTimers(loop: untyped) = template processTimers(loop: untyped) =
var curTime = Moment.now() var curTime = Moment.now()
while loop.timers.len > 0: while loop.timers.len > 0:
if loop.timers[0].deleted: if loop.timers[0].function.function.isNil:
discard loop.timers.pop() discard loop.timers.pop()
continue continue
@ -581,7 +579,7 @@ elif unixPlatform:
var newEvents: set[Event] var newEvents: set[Event]
withData(loop.selector, int(fd), adata) do: withData(loop.selector, int(fd), adata) do:
# We need to clear `reader` data, because `selectors` don't do it # We need to clear `reader` data, because `selectors` don't do it
adata.reader.function = nil adata.reader = default(AsyncCallback)
# adata.rdata = CompletionData() # adata.rdata = CompletionData()
if not(isNil(adata.writer.function)): if not(isNil(adata.writer.function)):
newEvents.incl(Event.Write) newEvents.incl(Event.Write)
@ -611,7 +609,7 @@ elif unixPlatform:
var newEvents: set[Event] var newEvents: set[Event]
withData(loop.selector, int(fd), adata) do: withData(loop.selector, int(fd), adata) do:
# We need to clear `writer` data, because `selectors` don't do it # We need to clear `writer` data, because `selectors` don't do it
adata.writer.function = nil adata.writer = default(AsyncCallback)
# adata.wdata = CompletionData() # adata.wdata = CompletionData()
if not(isNil(adata.reader.function)): if not(isNil(adata.reader.function)):
newEvents.incl(Event.Read) newEvents.incl(Event.Read)
@ -638,16 +636,16 @@ elif unixPlatform:
withData(loop.selector, int(fd), adata) do: withData(loop.selector, int(fd), adata) do:
# We are scheduling reader and writer callbacks to be called # We are scheduling reader and writer callbacks to be called
# explicitly, so they can get an error and continue work. # explicitly, so they can get an error and continue work.
if not(isNil(adata.reader.function)): # Callbacks marked as deleted so we don't need to get REAL notifications
if not adata.reader.deleted:
loop.callbacks.addLast(adata.reader)
if not(isNil(adata.writer.function)):
if not adata.writer.deleted:
loop.callbacks.addLast(adata.writer)
# Mark callbacks as deleted, we don't need to get REAL notifications
# from system queue for this reader and writer. # from system queue for this reader and writer.
adata.reader.deleted = true
adata.writer.deleted = true if not(isNil(adata.reader.function)):
loop.callbacks.addLast(adata.reader)
adata.reader = default(AsyncCallback)
if not(isNil(adata.writer.function)):
loop.callbacks.addLast(adata.writer)
adata.writer = default(AsyncCallback)
# We can't unregister file descriptor from system queue here, because # We can't unregister file descriptor from system queue here, because
# in such case processing queue will stuck on poll() call, because there # in such case processing queue will stuck on poll() call, because there
@ -707,20 +705,20 @@ elif unixPlatform:
withData(loop.selector, fd, adata) do: withData(loop.selector, fd, adata) do:
if Event.Read in events or events == {Event.Error}: if Event.Read in events or events == {Event.Error}:
if not adata.reader.deleted: if not isNil(adata.reader.function):
loop.callbacks.addLast(adata.reader) loop.callbacks.addLast(adata.reader)
if Event.Write in events or events == {Event.Error}: if Event.Write in events or events == {Event.Error}:
if not adata.writer.deleted: if not isNil(adata.writer.function):
loop.callbacks.addLast(adata.writer) loop.callbacks.addLast(adata.writer)
if Event.User in events: if Event.User in events:
if not adata.reader.deleted: if not isNil(adata.reader.function):
loop.callbacks.addLast(adata.reader) loop.callbacks.addLast(adata.reader)
when ioselSupportedPlatform: when ioselSupportedPlatform:
if customSet * events != {}: if customSet * events != {}:
if not adata.reader.deleted: if not isNil(adata.reader.function):
loop.callbacks.addLast(adata.reader) loop.callbacks.addLast(adata.reader)
# Moving expired timers to `loop.callbacks`. # Moving expired timers to `loop.callbacks`.
@ -744,7 +742,7 @@ proc setTimer*(at: Moment, cb: CallbackFunc,
loop.timers.push(result) loop.timers.push(result)
proc clearTimer*(timer: TimerCallback) {.inline.} = proc clearTimer*(timer: TimerCallback) {.inline.} =
timer.deleted = true timer.function = default(AsyncCallback)
proc addTimer*(at: Moment, cb: CallbackFunc, udata: pointer = nil) {. proc addTimer*(at: Moment, cb: CallbackFunc, udata: pointer = nil) {.
inline, deprecated: "Use setTimer/clearTimer instead".} = inline, deprecated: "Use setTimer/clearTimer instead".} =