Fix cancellation race when low level futures are already completed, while cancellation process is pending. (#107)

Added test.
This commit is contained in:
Eugene Kabanov 2020-07-03 15:03:59 +03:00 committed by GitHub
parent 528688d01e
commit 16ed169f25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 43 deletions

View File

@ -26,7 +26,7 @@ type
Pending, Finished, Cancelled, Failed
FutureBase* = ref object of RootObj ## Untyped future.
location: array[2, ptr SrcLoc]
location*: array[2, ptr SrcLoc]
callbacks: Deque[AsyncCallback]
cancelcb*: CallbackFunc
child*: FutureBase
@ -34,7 +34,8 @@ type
error*: ref Exception ## Stored exception
errorStackTrace*: StackTrace
stackTrace: StackTrace ## For debugging purposes only.
id: int
mustCancel*: bool
id*: int
# ZAH: we have discussed some possible optimizations where
# the future can be stored within the caller's stack frame.
@ -71,11 +72,6 @@ template setupFutureBase(loc: ptr SrcLoc) =
result.location[LocCreateIndex] = loc
currentID.inc()
## ZAH: As far as I undestand `fromProc` is just a debugging helper.
## It would be more efficient if it's represented as a simple statically
## known `char *` in the final program (so it needs to be a `cstring` in Nim).
## The public API can be defined as a template expecting a `static[string]`
## and converting this immediately to a `cstring`.
proc newFuture[T](loc: ptr SrcLoc): Future[T] =
setupFutureBase(loc)
@ -141,7 +137,7 @@ proc failed*(future: FutureBase): bool {.inline.} =
## Determines whether ``future`` completed with an error.
result = (future.state == FutureState.Failed)
proc checkFinished[T](future: Future[T], loc: ptr SrcLoc) =
proc checkFinished(future: FutureBase, loc: ptr SrcLoc) =
## Checks whether `future` is finished. If it is then raises a
## ``FutureDefect``.
if future.finished():
@ -157,9 +153,6 @@ proc checkFinished[T](future: Future[T], loc: ptr SrcLoc) =
msg.add("\n " & $loc)
msg.add("\n Stack trace to moment of creation:")
msg.add("\n" & indent(future.stackTrace.strip(), 4))
when T is string:
msg.add("\n Contents (string): ")
msg.add("\n" & indent(future.value.repr, 4))
msg.add("\n Stack trace to moment of secondary completion:")
msg.add("\n" & indent(getStackTrace().strip(), 4))
msg.add("\n\n")
@ -189,7 +182,7 @@ proc remove(callbacks: var Deque[AsyncCallback], item: AsyncCallback) =
proc complete[T](future: Future[T], val: T, loc: ptr SrcLoc) =
if not(future.cancelled()):
checkFinished(future, loc)
checkFinished(FutureBase(future), loc)
doAssert(isNil(future.error))
future.value = val
future.state = FutureState.Finished
@ -201,7 +194,7 @@ template complete*[T](future: Future[T], val: T) =
proc complete(future: Future[void], loc: ptr SrcLoc) =
if not(future.cancelled()):
checkFinished(future, loc)
checkFinished(FutureBase(future), loc)
doAssert(isNil(future.error))
future.state = FutureState.Finished
future.callbacks.call()
@ -213,7 +206,7 @@ template complete*(future: Future[void]) =
proc complete[T](future: FutureVar[T], loc: ptr SrcLoc) =
if not(future.cancelled()):
template fut: untyped = Future[T](future)
checkFinished(fut, loc)
checkFinished(FutureBase(fut), loc)
doAssert(isNil(fut.error))
fut.state = FutureState.Finished
fut.callbacks.call()
@ -225,7 +218,7 @@ template complete*[T](futvar: FutureVar[T]) =
proc complete[T](futvar: FutureVar[T], val: T, loc: ptr SrcLoc) =
if not(futvar.cancelled()):
template fut: untyped = Future[T](futvar)
checkFinished(fut, loc)
checkFinished(FutureBase(fut), loc)
doAssert(isNil(fut.error))
fut.state = FutureState.Finished
fut.value = val
@ -239,7 +232,7 @@ template complete*[T](futvar: FutureVar[T], val: T) =
proc fail[T](future: Future[T], error: ref Exception, loc: ptr SrcLoc) =
if not(future.cancelled()):
checkFinished(future, loc)
checkFinished(FutureBase(future), loc)
future.state = FutureState.Failed
future.error = error
future.errorStackTrace =
@ -250,35 +243,33 @@ template fail*[T](future: Future[T], error: ref Exception) =
## Completes ``future`` with ``error``.
fail(future, error, getSrcLocation())
proc cancel[T](future: Future[T], loc: ptr SrcLoc) =
if future.finished():
proc cancelAndSchedule(future: FutureBase, loc: ptr SrcLoc) =
if not(future.finished()):
checkFinished(future, loc)
else:
var first = FutureBase(future)
var last = first
while not(isNil(last.child)) and not(last.child.cancelled()):
last = last.child
if last == first:
checkFinished(future, loc)
let isPending = (last.state == FutureState.Pending)
last.state = FutureState.Cancelled
last.error = newException(CancelledError, "")
if not(isNil(last.cancelcb)):
last.cancelcb(cast[pointer](last))
if isPending:
# If Future's state was `Finished` or `Failed` callbacks are already
# scheduled.
last.callbacks.call()
future.state = FutureState.Cancelled
future.error = newException(CancelledError, "")
future.errorStackTrace = getStackTrace()
future.callbacks.call()
template cancelAndSchedule*[T](future: Future[T]) =
cancelAndSchedule(FutureBase(future), getSrcLocation())
proc cancel(future: FutureBase, loc: ptr SrcLoc) =
if not(future.finished()):
if not(isNil(future.child)):
cancel(future.child, getSrcLocation())
future.mustCancel = true
else:
if not(isNil(future.cancelcb)):
future.cancelcb(cast[pointer](future))
cancelAndSchedule(future, getSrcLocation())
template cancel*[T](future: Future[T]) =
## Cancel ``future``.
cancel(future, getSrcLocation())
cancel(FutureBase(future), getSrcLocation())
proc clearCallbacks(future: FutureBase) =
var count = len(future.callbacks)
while count > 0:
discard future.callbacks.popFirst()
dec(count)
future.callbacks.clear()
proc addCallback*(future: FutureBase, cb: CallbackFunc, udata: pointer = nil) =
## Adds the callbacks proc to be called when the future completes.

View File

@ -49,7 +49,7 @@ template createCb(retFutureSym, iteratorNameSym,
{.gcsafe.}:
next.addCallback(identName)
except CancelledError:
retFutureSym.cancel()
retFutureSym.cancelAndSchedule()
except CatchableError as exc:
futureVarCompletions
@ -273,8 +273,10 @@ template await*[T](f: Future[T]): auto =
chronosInternalTmpFuture = f
chronosInternalRetFuture.child = chronosInternalTmpFuture
yield chronosInternalTmpFuture
chronosInternalTmpFuture.internalCheckComplete()
chronosInternalRetFuture.child = nil
if chronosInternalRetFuture.mustCancel:
raise newException(CancelledError, "")
chronosInternalTmpFuture.internalCheckComplete()
cast[type(f)](chronosInternalTmpFuture).internalRead()
else:
unsupported "await is only available within {.async.}"
@ -287,6 +289,8 @@ template awaitne*[T](f: Future[T]): Future[T] =
chronosInternalRetFuture.child = chronosInternalTmpFuture
yield chronosInternalTmpFuture
chronosInternalRetFuture.child = nil
if chronosInternalRetFuture.mustCancel:
raise newException(CancelledError, "")
cast[type(f)](chronosInternalTmpFuture)
else:
unsupported "awaitne is only available within {.async.}"

View File

@ -936,12 +936,38 @@ suite "Future[T] behavior test suite":
neverFlag1 and neverFlag2 and neverFlag3 and
waitProc1 and waitProc2
proc testCancellationRaceAsync(): Future[bool] {.async.} =
var someFut = newFuture[void]()
proc raceProc(): Future[void] {.async.} =
await someFut
var raceFut1 = raceProc()
someFut.complete()
await cancelAndWait(raceFut1)
someFut = newFuture[void]()
var raceFut2 = raceProc()
someFut.fail(newException(ValueError, ""))
await cancelAndWait(raceFut2)
someFut = newFuture[void]()
var raceFut3 = raceProc()
someFut.cancel()
await cancelAndWait(raceFut3)
result = (raceFut1.state == FutureState.Cancelled) and
(raceFut2.state == FutureState.Cancelled) and
(raceFut3.state == FutureState.Cancelled)
proc testWait(): bool =
result = waitFor(testWaitAsync())
waitFor(testWaitAsync())
proc testWithTimeout(): bool =
result = waitFor(testWithTimeoutAsync())
waitFor(testWithTimeoutAsync())
proc testCancellationRace(): bool =
waitFor(testCancellationRaceAsync())
test "Async undefined behavior (#7758) test":
check test1() == true
@ -994,3 +1020,5 @@ suite "Future[T] behavior test suite":
check testWait() == true
test "Cancellation withTimeout() test":
check testWithTimeout() == true
test "Cancellation race test":
check testCancellationRace() == true