From 16ed169f251a3443e9500ba10fe06f368849c5c8 Mon Sep 17 00:00:00 2001 From: Eugene Kabanov Date: Fri, 3 Jul 2020 15:03:59 +0300 Subject: [PATCH] Fix cancellation race when low level futures are already completed, while cancellation process is pending. (#107) Added test. --- chronos/asyncfutures2.nim | 69 +++++++++++++++++---------------------- chronos/asyncmacro2.nim | 8 +++-- tests/testfut.nim | 32 ++++++++++++++++-- 3 files changed, 66 insertions(+), 43 deletions(-) diff --git a/chronos/asyncfutures2.nim b/chronos/asyncfutures2.nim index 28597b0..ce003ec 100644 --- a/chronos/asyncfutures2.nim +++ b/chronos/asyncfutures2.nim @@ -26,7 +26,7 @@ type Pending, Finished, Cancelled, Failed FutureBase* = ref object of RootObj ## Untyped future. - location: array[2, ptr SrcLoc] + location*: array[2, ptr SrcLoc] callbacks: Deque[AsyncCallback] cancelcb*: CallbackFunc child*: FutureBase @@ -34,7 +34,8 @@ type error*: ref Exception ## Stored exception errorStackTrace*: StackTrace stackTrace: StackTrace ## For debugging purposes only. - id: int + mustCancel*: bool + id*: int # ZAH: we have discussed some possible optimizations where # the future can be stored within the caller's stack frame. @@ -71,11 +72,6 @@ template setupFutureBase(loc: ptr SrcLoc) = result.location[LocCreateIndex] = loc currentID.inc() -## ZAH: As far as I undestand `fromProc` is just a debugging helper. -## It would be more efficient if it's represented as a simple statically -## known `char *` in the final program (so it needs to be a `cstring` in Nim). -## The public API can be defined as a template expecting a `static[string]` -## and converting this immediately to a `cstring`. proc newFuture[T](loc: ptr SrcLoc): Future[T] = setupFutureBase(loc) @@ -141,7 +137,7 @@ proc failed*(future: FutureBase): bool {.inline.} = ## Determines whether ``future`` completed with an error. result = (future.state == FutureState.Failed) -proc checkFinished[T](future: Future[T], loc: ptr SrcLoc) = +proc checkFinished(future: FutureBase, loc: ptr SrcLoc) = ## Checks whether `future` is finished. If it is then raises a ## ``FutureDefect``. if future.finished(): @@ -157,9 +153,6 @@ proc checkFinished[T](future: Future[T], loc: ptr SrcLoc) = msg.add("\n " & $loc) msg.add("\n Stack trace to moment of creation:") msg.add("\n" & indent(future.stackTrace.strip(), 4)) - when T is string: - msg.add("\n Contents (string): ") - msg.add("\n" & indent(future.value.repr, 4)) msg.add("\n Stack trace to moment of secondary completion:") msg.add("\n" & indent(getStackTrace().strip(), 4)) msg.add("\n\n") @@ -189,7 +182,7 @@ proc remove(callbacks: var Deque[AsyncCallback], item: AsyncCallback) = proc complete[T](future: Future[T], val: T, loc: ptr SrcLoc) = if not(future.cancelled()): - checkFinished(future, loc) + checkFinished(FutureBase(future), loc) doAssert(isNil(future.error)) future.value = val future.state = FutureState.Finished @@ -201,7 +194,7 @@ template complete*[T](future: Future[T], val: T) = proc complete(future: Future[void], loc: ptr SrcLoc) = if not(future.cancelled()): - checkFinished(future, loc) + checkFinished(FutureBase(future), loc) doAssert(isNil(future.error)) future.state = FutureState.Finished future.callbacks.call() @@ -213,7 +206,7 @@ template complete*(future: Future[void]) = proc complete[T](future: FutureVar[T], loc: ptr SrcLoc) = if not(future.cancelled()): template fut: untyped = Future[T](future) - checkFinished(fut, loc) + checkFinished(FutureBase(fut), loc) doAssert(isNil(fut.error)) fut.state = FutureState.Finished fut.callbacks.call() @@ -225,7 +218,7 @@ template complete*[T](futvar: FutureVar[T]) = proc complete[T](futvar: FutureVar[T], val: T, loc: ptr SrcLoc) = if not(futvar.cancelled()): template fut: untyped = Future[T](futvar) - checkFinished(fut, loc) + checkFinished(FutureBase(fut), loc) doAssert(isNil(fut.error)) fut.state = FutureState.Finished fut.value = val @@ -239,7 +232,7 @@ template complete*[T](futvar: FutureVar[T], val: T) = proc fail[T](future: Future[T], error: ref Exception, loc: ptr SrcLoc) = if not(future.cancelled()): - checkFinished(future, loc) + checkFinished(FutureBase(future), loc) future.state = FutureState.Failed future.error = error future.errorStackTrace = @@ -250,35 +243,33 @@ template fail*[T](future: Future[T], error: ref Exception) = ## Completes ``future`` with ``error``. fail(future, error, getSrcLocation()) -proc cancel[T](future: Future[T], loc: ptr SrcLoc) = - if future.finished(): +proc cancelAndSchedule(future: FutureBase, loc: ptr SrcLoc) = + if not(future.finished()): checkFinished(future, loc) - else: - var first = FutureBase(future) - var last = first - while not(isNil(last.child)) and not(last.child.cancelled()): - last = last.child - if last == first: - checkFinished(future, loc) - let isPending = (last.state == FutureState.Pending) - last.state = FutureState.Cancelled - last.error = newException(CancelledError, "") - if not(isNil(last.cancelcb)): - last.cancelcb(cast[pointer](last)) - if isPending: - # If Future's state was `Finished` or `Failed` callbacks are already - # scheduled. - last.callbacks.call() + future.state = FutureState.Cancelled + future.error = newException(CancelledError, "") + future.errorStackTrace = getStackTrace() + future.callbacks.call() + +template cancelAndSchedule*[T](future: Future[T]) = + cancelAndSchedule(FutureBase(future), getSrcLocation()) + +proc cancel(future: FutureBase, loc: ptr SrcLoc) = + if not(future.finished()): + if not(isNil(future.child)): + cancel(future.child, getSrcLocation()) + future.mustCancel = true + else: + if not(isNil(future.cancelcb)): + future.cancelcb(cast[pointer](future)) + cancelAndSchedule(future, getSrcLocation()) template cancel*[T](future: Future[T]) = ## Cancel ``future``. - cancel(future, getSrcLocation()) + cancel(FutureBase(future), getSrcLocation()) proc clearCallbacks(future: FutureBase) = - var count = len(future.callbacks) - while count > 0: - discard future.callbacks.popFirst() - dec(count) + future.callbacks.clear() proc addCallback*(future: FutureBase, cb: CallbackFunc, udata: pointer = nil) = ## Adds the callbacks proc to be called when the future completes. diff --git a/chronos/asyncmacro2.nim b/chronos/asyncmacro2.nim index 8daf6a0..b596ebb 100644 --- a/chronos/asyncmacro2.nim +++ b/chronos/asyncmacro2.nim @@ -49,7 +49,7 @@ template createCb(retFutureSym, iteratorNameSym, {.gcsafe.}: next.addCallback(identName) except CancelledError: - retFutureSym.cancel() + retFutureSym.cancelAndSchedule() except CatchableError as exc: futureVarCompletions @@ -273,8 +273,10 @@ template await*[T](f: Future[T]): auto = chronosInternalTmpFuture = f chronosInternalRetFuture.child = chronosInternalTmpFuture yield chronosInternalTmpFuture - chronosInternalTmpFuture.internalCheckComplete() chronosInternalRetFuture.child = nil + if chronosInternalRetFuture.mustCancel: + raise newException(CancelledError, "") + chronosInternalTmpFuture.internalCheckComplete() cast[type(f)](chronosInternalTmpFuture).internalRead() else: unsupported "await is only available within {.async.}" @@ -287,6 +289,8 @@ template awaitne*[T](f: Future[T]): Future[T] = chronosInternalRetFuture.child = chronosInternalTmpFuture yield chronosInternalTmpFuture chronosInternalRetFuture.child = nil + if chronosInternalRetFuture.mustCancel: + raise newException(CancelledError, "") cast[type(f)](chronosInternalTmpFuture) else: unsupported "awaitne is only available within {.async.}" diff --git a/tests/testfut.nim b/tests/testfut.nim index a82b9ce..d0e4bc6 100644 --- a/tests/testfut.nim +++ b/tests/testfut.nim @@ -936,12 +936,38 @@ suite "Future[T] behavior test suite": neverFlag1 and neverFlag2 and neverFlag3 and waitProc1 and waitProc2 + proc testCancellationRaceAsync(): Future[bool] {.async.} = + var someFut = newFuture[void]() + + proc raceProc(): Future[void] {.async.} = + await someFut + + var raceFut1 = raceProc() + someFut.complete() + await cancelAndWait(raceFut1) + + someFut = newFuture[void]() + var raceFut2 = raceProc() + someFut.fail(newException(ValueError, "")) + await cancelAndWait(raceFut2) + + someFut = newFuture[void]() + var raceFut3 = raceProc() + someFut.cancel() + await cancelAndWait(raceFut3) + + result = (raceFut1.state == FutureState.Cancelled) and + (raceFut2.state == FutureState.Cancelled) and + (raceFut3.state == FutureState.Cancelled) proc testWait(): bool = - result = waitFor(testWaitAsync()) + waitFor(testWaitAsync()) proc testWithTimeout(): bool = - result = waitFor(testWithTimeoutAsync()) + waitFor(testWithTimeoutAsync()) + + proc testCancellationRace(): bool = + waitFor(testCancellationRaceAsync()) test "Async undefined behavior (#7758) test": check test1() == true @@ -994,3 +1020,5 @@ suite "Future[T] behavior test suite": check testWait() == true test "Cancellation withTimeout() test": check testWithTimeout() == true + test "Cancellation race test": + check testCancellationRace() == true