diff --git a/chronos/internal/asyncfutures.nim b/chronos/internal/asyncfutures.nim index 6c8f2bd..7a716c6 100644 --- a/chronos/internal/asyncfutures.nim +++ b/chronos/internal/asyncfutures.nim @@ -989,31 +989,87 @@ template cancel*(future: FutureBase) {. ## Cancel ``future``. cancelSoon(future, nil, nil, getSrcLocation()) -proc cancelAndWait*(future: FutureBase, loc: ptr SrcLoc): Future[void] {. - async: (raw: true, raises: []).} = +proc cancelAndWait( + loc: ptr SrcLoc, + futs: varargs[FutureBase] +): Future[void] {.async: (raw: true, raises: []).} = + let + retFuture = + Future[void].Raising([]).init( + "chronos.cancelAndWait(varargs[FutureBase])", + {FutureFlag.OwnCancelSchedule}) + var count = 0 + + proc continuation(udata: pointer) {.gcsafe.} = + dec(count) + if count == 0: + retFuture.complete() + + retFuture.cancelCallback = nil + + for futn in futs: + if not(futn.finished()): + inc(count) + cancelSoon(futn, continuation, cast[pointer](futn), loc) + + if count == 0: + retFuture.complete() + + retFuture + +proc cancelAndWait( + loc: ptr SrcLoc, + futs: openArray[SomeFuture] +): Future[void] {.async: (raw: true, raises: []).} = + cancelAndWait(loc, futs.mapIt(FutureBase(it))) + +template cancelAndWait*(future: FutureBase): Future[void].Raising([]) = ## Perform cancellation ``future`` return Future which will be completed when ## ``future`` become finished (completed with value, failed or cancelled). ## ## NOTE: Compared to the `tryCancel()` call, this procedure call guarantees ## that ``future``will be finished (completed with value, failed or cancelled) ## as quickly as possible. - let retFuture = newFuture[void]("chronos.cancelAndWait(FutureBase)", - {FutureFlag.OwnCancelSchedule}) + cancelAndWait(getSrcLocation(), future) - proc continuation(udata: pointer) {.gcsafe.} = - retFuture.complete() +template cancelAndWait*(future: SomeFuture): Future[void].Raising([]) = + ## Perform cancellation ``future`` return Future which will be completed when + ## ``future`` become finished (completed with value, failed or cancelled). + ## + ## NOTE: Compared to the `tryCancel()` call, this procedure call guarantees + ## that ``future``will be finished (completed with value, failed or cancelled) + ## as quickly as possible. + cancelAndWait(getSrcLocation(), FutureBase(future)) - if future.finished(): - retFuture.complete() - else: - retFuture.cancelCallback = nil - cancelSoon(future, continuation, cast[pointer](retFuture), loc) +template cancelAndWait*(futs: varargs[FutureBase]): Future[void].Raising([]) = + ## Perform cancellation of all the ``futs``. Returns Future which will be + ## completed when all the ``futs`` become finished (completed with value, + ## failed or cancelled). + ## + ## NOTE: Compared to the `tryCancel()` call, this procedure call guarantees + ## that all the ``futs``will be finished (completed with value, failed or + ## cancelled) as quickly as possible. + ## + ## NOTE: It is safe to pass finished futures in ``futs`` (completed with + ## value, failed or cancelled). + ## + ## NOTE: If ``futs`` is an empty array, procedure returns completed Future. + cancelAndWait(getSrcLocation(), futs) - retFuture - -template cancelAndWait*(future: FutureBase): Future[void].Raising([]) = - ## Cancel ``future``. - cancelAndWait(future, getSrcLocation()) +template cancelAndWait*(futs: openArray[SomeFuture]): Future[void].Raising([]) = + ## Perform cancellation of all the ``futs``. Returns Future which will be + ## completed when all the ``futs`` become finished (completed with value, + ## failed or cancelled). + ## + ## NOTE: Compared to the `tryCancel()` call, this procedure call guarantees + ## that all the ``futs``will be finished (completed with value, failed or + ## cancelled) as quickly as possible. + ## + ## NOTE: It is safe to pass finished futures in ``futs`` (completed with + ## value, failed or cancelled). + ## + ## NOTE: If ``futs`` is an empty array, procedure returns completed Future. + cancelAndWait(getSrcLocation(), futs) proc noCancel*[F: SomeFuture](future: F): auto = # async: (raw: true, raises: asyncraiseOf(future) - CancelledError ## Prevent cancellation requests from propagating to ``future`` while diff --git a/tests/testfut.nim b/tests/testfut.nim index 8d9fa58..c04b54d 100644 --- a/tests/testfut.nim +++ b/tests/testfut.nim @@ -2352,6 +2352,226 @@ suite "Future[T] behavior test suite": future1.cancelled() == true future2.cancelled() == true + asyncTest "cancelAndWait(varargs) should be able to cancel test": + proc test01() {.async.} = + await noCancel sleepAsync(100.milliseconds) + await noCancel sleepAsync(100.milliseconds) + await sleepAsync(100.milliseconds) + + proc test02() {.async.} = + await noCancel sleepAsync(100.milliseconds) + await sleepAsync(100.milliseconds) + await noCancel sleepAsync(100.milliseconds) + + proc test03() {.async.} = + await sleepAsync(100.milliseconds) + await noCancel sleepAsync(100.milliseconds) + await noCancel sleepAsync(100.milliseconds) + + proc test04() {.async.} = + while true: + await noCancel sleepAsync(50.milliseconds) + await sleepAsync(0.milliseconds) + + proc test05() {.async.} = + while true: + await sleepAsync(0.milliseconds) + await noCancel sleepAsync(50.milliseconds) + + proc test11() {.async: (raises: [CancelledError]).} = + await noCancel sleepAsync(100.milliseconds) + await noCancel sleepAsync(100.milliseconds) + await sleepAsync(100.milliseconds) + + proc test12() {.async: (raises: [CancelledError]).} = + await noCancel sleepAsync(100.milliseconds) + await sleepAsync(100.milliseconds) + await noCancel sleepAsync(100.milliseconds) + + proc test13() {.async: (raises: [CancelledError]).} = + await sleepAsync(100.milliseconds) + await noCancel sleepAsync(100.milliseconds) + await noCancel sleepAsync(100.milliseconds) + + proc test14() {.async: (raises: [CancelledError]).} = + while true: + await noCancel sleepAsync(50.milliseconds) + await sleepAsync(0.milliseconds) + + proc test15() {.async: (raises: [CancelledError]).} = + while true: + await sleepAsync(0.milliseconds) + await noCancel sleepAsync(50.milliseconds) + + template runTest(N1, N2, N3: untyped) = + let + future01 = `test N2 N1`() + future02 = `test N2 N1`() + future03 = `test N2 N1`() + future04 = `test N2 N1`() + future05 = `test N2 N1`() + future06 = `test N2 N1`() + future07 = `test N2 N1`() + future08 = `test N2 N1`() + future09 = `test N2 N1`() + future10 = `test N2 N1`() + future11 = `test N2 N1`() + future12 = `test N2 N1`() + + await allFutures( + cancelAndWait(future01, future02), + cancelAndWait(FutureBase(future03), FutureBase(future04)), + cancelAndWait([future05, future06]), + cancelAndWait([FutureBase(future07), FutureBase(future08)]), + cancelAndWait(@[future09, future10]), + cancelAndWait(@[FutureBase(future11), FutureBase(future12)]) + ) + + let + future21 = `test N2 N1`() + future22 = `test N2 N1`() + future23 = `test N2 N1`() + future24 = `test N2 N1`() + future25 = `test N2 N1`() + future26 = `test N2 N1`() + future27 = `test N2 N1`() + future28 = `test N2 N1`() + future29 = `test N2 N1`() + future30 = `test N2 N1`() + future31 = `test N2 N1`() + future32 = `test N2 N1`() + + await sleepAsync(`N3`) + + await allFutures( + cancelAndWait(future21, future22), + cancelAndWait(FutureBase(future23), FutureBase(future24)), + cancelAndWait([future25, future26]), + cancelAndWait([FutureBase(future27), FutureBase(future28)]), + cancelAndWait(@[future29, future30]), + cancelAndWait(@[FutureBase(future31), FutureBase(future32)]) + ) + + check: + future01.state == FutureState.Cancelled + future02.state == FutureState.Cancelled + future03.state == FutureState.Cancelled + future04.state == FutureState.Cancelled + future05.state == FutureState.Cancelled + future06.state == FutureState.Cancelled + future07.state == FutureState.Cancelled + future08.state == FutureState.Cancelled + future09.state == FutureState.Cancelled + future10.state == FutureState.Cancelled + future11.state == FutureState.Cancelled + future12.state == FutureState.Cancelled + future21.state == FutureState.Cancelled + future22.state == FutureState.Cancelled + future23.state == FutureState.Cancelled + future24.state == FutureState.Cancelled + future25.state == FutureState.Cancelled + future26.state == FutureState.Cancelled + future27.state == FutureState.Cancelled + future28.state == FutureState.Cancelled + future29.state == FutureState.Cancelled + future30.state == FutureState.Cancelled + future31.state == FutureState.Cancelled + future32.state == FutureState.Cancelled + + runTest(1, 0, 10.milliseconds) + runTest(1, 1, 10.milliseconds) + runTest(2, 0, 10.milliseconds) + runTest(2, 1, 10.milliseconds) + runTest(3, 0, 10.milliseconds) + runTest(3, 1, 10.milliseconds) + runTest(4, 0, 333.milliseconds) + runTest(4, 1, 333.milliseconds) + runTest(5, 0, 333.milliseconds) + runTest(5, 1, 333.milliseconds) + + asyncTest "cancelAndWait([]) on empty set returns completed Future test": + var + a0: array[0, Future[void]] + a1: array[0, Future[void].Raising([CancelledError])] + a2: seq[Future[void].Raising([CancelledError])] + a3: seq[Future[void]] + + let + future1 = cancelAndWait() + future2 = cancelAndWait(a0) + future3 = cancelAndWait(a1) + future4 = cancelAndWait(a2) + future5 = cancelAndWait(a3) + + check: + future1.finished() == true + future2.finished() == true + future3.finished() == true + future4.finished() == true + future5.finished() == true + + asyncTest "cancelAndWait([]) should ignore finished futures test": + let + future0 = + Future[void].Raising([]).init("future0", {OwnCancelSchedule}) + future1 = + Future[void].Raising([CancelledError]).init("future1") + future2 = + Future[void].Raising([CancelledError, ValueError]).init("future2") + future3 = + Future[string].Raising([]).init("future3", {OwnCancelSchedule}) + future4 = + Future[string].Raising([CancelledError]).init("future4") + future5 = + Future[string].Raising([CancelledError, ValueError]).init("future5") + future6 = + newFuture[void]("future6") + future7 = + newFuture[void]("future7") + future8 = + newFuture[void]("future8") + future9 = + newFuture[string]("future9") + future10 = + newFuture[string]("future10") + future11 = + newFuture[string]("future11") + + future0.complete() + check future1.tryCancel() == true + future2.fail(newException(ValueError, "Test Error")) + future3.complete("test") + check future4.tryCancel() == true + future5.fail(newException(ValueError, "Test Error")) + future6.complete() + check future7.tryCancel() == true + future8.fail(newException(ValueError, "Test Error")) + future9.complete("test") + check future10.tryCancel() == true + future11.fail(newException(ValueError, "Test Error")) + + check: + cancelAndWait(future0, future1, future2).finished() == true + cancelAndWait(future3, future4, future5).finished() == true + cancelAndWait(future6, future7, future8).finished() == true + cancelAndWait(future9, future10, future11).finished() == true + cancelAndWait(future0, future1, future2, + future3, future4, future5, + future5, future7, future8, + future9, future10, future11).finished() == true + + cancelAndWait([future0, future1, future2]).finished() == true + cancelAndWait([future3, future4]).finished() == true + cancelAndWait([future5]).finished() == true + cancelAndWait([future6, future7, future8]).finished() == true + cancelAndWait([future9, future10, future11]).finished() == true + + cancelAndWait(@[future0, future1, future2]).finished() == true + cancelAndWait(@[future3, future4]).finished() == true + cancelAndWait(@[future5]).finished() == true + cancelAndWait(@[future6, future7, future8]).finished() == true + cancelAndWait(@[future9, future10, future11]).finished() == true + asyncTest "join() test": proc joinFoo0(future: FutureBase) {.async.} = await join(future)