diff --git a/chronos.nimble b/chronos.nimble index 7b951af..af25a26 100644 --- a/chronos.nimble +++ b/chronos.nimble @@ -1,5 +1,5 @@ packageName = "chronos" -version = "2.2.6" +version = "2.2.7" author = "Status Research & Development GmbH" description = "Chronos" license = "Apache License 2.0 or MIT" diff --git a/chronos/asyncfutures2.nim b/chronos/asyncfutures2.nim index d37c64a..97ac1f7 100644 --- a/chronos/asyncfutures2.nim +++ b/chronos/asyncfutures2.nim @@ -26,13 +26,19 @@ type deleted*: bool # ZAH: This can probably be stored with a cheaper representation - # until the moment it needs to be printed to the screen (e.g. seq[StackTraceEntry]) + # until the moment it needs to be printed to the screen + # (e.g. seq[StackTraceEntry]) StackTrace = string + FutureState* {.pure.} = enum + Pending, Finished, Cancelled, Failed + FutureBase* = ref object of RootObj ## Untyped future. location: array[2, ptr SrcLoc] callbacks: Deque[AsyncCallback] - finished: bool + cancelcb*: CallbackFunc + child*: FutureBase + state*: FutureState error*: ref Exception ## Stored exception errorStackTrace*: StackTrace stackTrace: StackTrace ## For debugging purposes only. @@ -58,6 +64,8 @@ type FutureError* = object of Exception cause*: FutureBase + CancelledError* = object of FutureError + var currentID* {.threadvar.}: int currentID = 0 @@ -79,7 +87,7 @@ proc callSoon*(c: CallbackFunc, u: pointer = nil) = template setupFutureBase(loc: ptr SrcLoc) = new(result) - result.finished = false + result.state = FutureState.Pending result.stackTrace = getStackTrace() result.id = currentID result.location[LocCreateIndex] = loc @@ -135,13 +143,30 @@ template newFutureVar*[T](fromProc: static[string] = ""): auto = proc clean*[T](future: FutureVar[T]) = ## Resets the ``finished`` status of ``future``. - Future[T](future).finished = false + Future[T](future).state = FutureState.Pending Future[T](future).error = nil +proc finished*(future: FutureBase | FutureVar): bool {.inline.} = + ## Determines whether ``future`` has completed. + ## + ## ``True`` may indicate an error or a value. Use ``failed`` to distinguish. + when future is FutureVar: + result = (FutureBase(future).state != FutureState.Pending) + else: + result = (future.state != FutureState.Pending) + +proc cancelled*(future: FutureBase): bool {.inline.} = + ## Determines whether ``future`` has cancelled. + result = (future.state == FutureState.Cancelled) + +proc failed*(future: FutureBase): bool {.inline.} = + ## Determines whether ``future`` completed with an error. + result = (future.state == FutureState.Failed) + proc checkFinished[T](future: Future[T], loc: ptr SrcLoc) = ## Checks whether `future` is finished. If it is then raises a ## ``FutureError``. - if future.finished: + if future.finished(): var msg = "" msg.add("An attempt was made to complete a Future more than once. ") msg.add("Details:") @@ -170,7 +195,7 @@ proc call(callbacks: var Deque[AsyncCallback]) = var count = len(callbacks) while count > 0: var item = callbacks.popFirst() - if not item.deleted: + if not(item.deleted): callSoon(item.function, item.udata) dec(count) @@ -185,44 +210,48 @@ proc remove(callbacks: var Deque[AsyncCallback], item: AsyncCallback) = p.deleted = true proc complete[T](future: Future[T], val: T, loc: ptr SrcLoc) = - checkFinished(future, loc) - doAssert(isNil(future.error)) - future.value = val - future.finished = true - future.callbacks.call() + if not(future.cancelled()): + checkFinished(future, loc) + doAssert(isNil(future.error)) + future.value = val + future.state = FutureState.Finished + future.callbacks.call() template complete*[T](future: Future[T], val: T) = ## Completes ``future`` with value ``val``. complete(future, val, getSrcLocation()) proc complete(future: Future[void], loc: ptr SrcLoc) = - ## Completes a void ``future``. - checkFinished(future, loc) - doAssert(isNil(future.error)) - future.finished = true - future.callbacks.call() + if not(future.cancelled()): + checkFinished(future, loc) + doAssert(isNil(future.error)) + future.state = FutureState.Finished + future.callbacks.call() template complete*(future: Future[void]) = + ## Completes a void ``future``. complete(future, getSrcLocation()) proc complete[T](future: FutureVar[T], loc: ptr SrcLoc) = - template fut: untyped = Future[T](future) - checkFinished(fut, loc) - doAssert(isNil(fut.error)) - fut.finished = true - fut.callbacks.call() + if not(future.cancelled()): + template fut: untyped = Future[T](future) + checkFinished(fut, loc) + doAssert(isNil(fut.error)) + fut.state = FutureState.Finished + fut.callbacks.call() template complete*[T](futvar: FutureVar[T]) = ## Completes a ``FutureVar``. complete(futvar, getSrcLocation()) proc complete[T](futvar: FutureVar[T], val: T, loc: ptr SrcLoc) = - template fut: untyped = Future[T](futvar) - checkFinished(fut, loc) - doAssert(isNil(fut.error)) - fut.finished = true - fut.value = val - fut.callbacks.call() + if not(futvar.cancelled()): + template fut: untyped = Future[T](futvar) + checkFinished(fut, loc) + doAssert(isNil(fut.error)) + fut.state = FutureState.Finished + fut.value = val + fut.callbacks.call() template complete*[T](futvar: FutureVar[T], val: T) = ## Completes a ``FutureVar`` with value ``val``. @@ -231,17 +260,38 @@ template complete*[T](futvar: FutureVar[T], val: T) = complete(futvar, val, getSrcLocation()) proc fail[T](future: Future[T], error: ref Exception, loc: ptr SrcLoc) = - checkFinished(future, loc) - future.finished = true - future.error = error - future.errorStackTrace = - if getStackTrace(error) == "": getStackTrace() else: getStackTrace(error) - future.callbacks.call() + if not(future.cancelled()): + checkFinished(future, loc) + future.state = FutureState.Failed + future.error = error + future.errorStackTrace = + if getStackTrace(error) == "": getStackTrace() else: getStackTrace(error) + future.callbacks.call() template fail*[T](future: Future[T], error: ref Exception) = ## Completes ``future`` with ``error``. fail(future, error, getSrcLocation()) +proc cancel[T](future: Future[T], loc: ptr SrcLoc) = + if future.finished(): + checkFinished(future, loc) + else: + var first = FutureBase(future) + var last = first + while (not isNil(last.child)) and (not(last.child.finished())): + last = last.child + if last == first: + checkFinished(future, loc) + last.state = FutureState.Cancelled + last.error = newException(CancelledError, "") + if not(isNil(last.cancelcb)): + last.cancelcb(cast[pointer](last)) + last.callbacks.call() + +template cancel*[T](future: Future[T]) = + ## Cancel ``future``. + cancel(future, getSrcLocation()) + proc clearCallbacks(future: FutureBase) = var count = len(future.callbacks) while count > 0: @@ -253,8 +303,7 @@ proc addCallback*(future: FutureBase, cb: CallbackFunc, udata: pointer = nil) = ## ## If future has already completed then ``cb`` will be called immediately. doAssert(not isNil(cb)) - if future.finished: - # ZAH: it seems that the Future needs to know its associated Dispatcher + if future.finished(): callSoon(cb, udata) else: let acb = AsyncCallback(function: cb, udata: udata) @@ -292,6 +341,12 @@ proc `callback=`*[T](future: Future[T], cb: CallbackFunc) = ## If future has already completed then ``cb`` will be called immediately. `callback=`(future, cb, cast[pointer](future)) +proc `cancelCallback=`*[T](future: Future[T], cb: CallbackFunc) = + ## Sets the callback procedure to be called when the future is cancelled. + ## + ## This callback will be called immediately as ``future.cancel()`` invoked. + future.cancelcb = cb + proc getHint(entry: StackTraceEntry): string = ## We try to provide some hints about stack trace entries that the user ## may not be familiar with, in particular calls inside the stdlib. @@ -371,8 +426,8 @@ proc read*[T](future: Future[T] | FutureVar[T]): T = {.push hint[ConvFromXtoItselfNotNeeded]: off.} let fut = Future[T](future) {.pop.} - if fut.finished: - if not isNil(fut.error): + if fut.finished(): + if not(isNil(fut.error)): injectStacktrace(fut) raise fut.error when T isnot void: @@ -386,8 +441,10 @@ proc readError*[T](future: Future[T]): ref Exception = ## ## An ``ValueError`` exception will be thrown if no exception exists ## in the specified Future. - if not isNil(future.error): return future.error + if not(isNil(future.error)): + return future.error else: + # TODO: Make a custom exception type for this? raise newException(ValueError, "No error in future.") proc mget*[T](future: FutureVar[T]): var T = @@ -397,19 +454,6 @@ proc mget*[T](future: FutureVar[T]): var T = ## Future has not been finished. result = Future[T](future).value -proc finished*(future: FutureBase | FutureVar): bool = - ## Determines whether ``future`` has completed. - ## - ## ``True`` may indicate an error or a value. Use ``failed`` to distinguish. - when future is FutureVar: - result = (FutureBase(future)).finished - else: - result = future.finished - -proc failed*(future: FutureBase): bool = - ## Determines whether ``future`` completed with an error. - return (not isNil(future.error)) - proc asyncCheck*[T](future: Future[T]) = ## Sets a callback on ``future`` which raises an exception if the future ## finished with an error. @@ -417,7 +461,7 @@ proc asyncCheck*[T](future: Future[T]) = ## This should be used instead of ``discard`` to discard void futures. doAssert(not isNil(future), "Future is nil") proc cb(data: pointer) = - if future.failed: + if future.failed() or future.cancelled(): injectStacktrace(future) raise future.error future.callback = cb @@ -425,50 +469,68 @@ proc asyncCheck*[T](future: Future[T]) = proc asyncDiscard*[T](future: Future[T]) = discard ## This is async workaround for discard ``Future[T]``. -# ZAH: The return type here could be a Future[(T, Y)] -proc `and`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] = +proc `and`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] {. + deprecated: "Use allFutures[T](varargs[Future[T]])".} = ## Returns a future which will complete once both ``fut1`` and ``fut2`` ## complete. - # ZAH: The Rust implementation of futures is making the case that the - # `and` combinator can be implemented in a more efficient way without - # resorting to closures and callbacks. I haven't thought this through - # completely yet, but here is their write-up: - # http://aturon.github.io/2016/09/07/futures-design/ - # - # We should investigate this further, before settling on the final design. - # The same reasoning applies to `or` and `all`. + ## + ## If cancelled, ``fut1`` and ``fut2`` futures WILL NOT BE cancelled. var retFuture = newFuture[void]("chronos.`and`") proc cb(data: pointer) = - if not retFuture.finished: - if (fut1.failed or fut1.finished) and (fut2.failed or fut2.finished): + if not(retFuture.finished()): + if fut1.finished() and fut2.finished(): if cast[pointer](fut1) == data: - if fut1.failed: retFuture.fail(fut1.error) - elif fut2.finished: retFuture.complete() + if fut1.failed(): + retFuture.fail(fut1.error) + else: + retFuture.complete() else: - if fut2.failed: retFuture.fail(fut2.error) - elif fut1.finished: retFuture.complete() + if fut2.failed(): + retFuture.fail(fut2.error) + else: + retFuture.complete() fut1.callback = cb fut2.callback = cb + + proc cancel(udata: pointer) {.gcsafe.} = + # On cancel we remove all our callbacks only. + if not(retFuture.finished()): + fut1.removeCallback(cb) + fut2.removeCallback(cb) + + retFuture.cancelCallback = cancel return retFuture -proc `or`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] = +proc `or`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] {. + deprecated: "Use one[T](varargs[Future[T]])".} = ## Returns a future which will complete once either ``fut1`` or ``fut2`` ## complete. + ## + ## If cancelled, ``fut1`` and ``fut2`` futures WILL NOT BE cancelled. var retFuture = newFuture[void]("chronos.`or`") - proc cb(data: pointer) {.gcsafe.} = - if not retFuture.finished: - var fut = cast[FutureBase](data) - if cast[pointer](fut1) == data: + proc cb(udata: pointer) {.gcsafe.} = + if not(retFuture.finished()): + var fut = cast[FutureBase](udata) + if cast[pointer](fut1) == udata: fut2.removeCallback(cb) else: fut1.removeCallback(cb) - if fut.failed: retFuture.fail(fut.error) + if fut.failed(): retFuture.fail(fut.error) else: retFuture.complete() fut1.callback = cb fut2.callback = cb + + proc cancel(udata: pointer) {.gcsafe.} = + # On cancel we remove all our callbacks only. + if not(retFuture.finished()): + fut1.removeCallback(cb) + fut2.removeCallback(cb) + + retFuture.cancelCallback = cancel return retFuture -proc all*[T](futs: varargs[Future[T]]): auto = +proc all*[T](futs: varargs[Future[T]]): auto {. + deprecated: "Use allFutures(varargs[Future[T]])".} = ## Returns a future which will complete once all futures in ``futs`` complete. ## If the argument is empty, the returned future completes immediately. ## @@ -480,6 +542,11 @@ proc all*[T](futs: varargs[Future[T]]): auto = ## ## Note, that if one of the futures in ``futs`` will fail, result of ``all()`` ## will also be failed with error from failed future. + ## + ## TODO: This procedure has bug on handling cancelled futures from ``futs``. + ## So if future from ``futs`` list become cancelled, what must be returned? + ## You can't cancel result ``retFuture`` because in such way infinite + ## recursion will happen. let totalFutures = len(futs) var completedFutures = 0 @@ -488,17 +555,19 @@ proc all*[T](futs: varargs[Future[T]]): auto = when T is void: var retFuture = newFuture[void]("chronos.all(void)") - for fut in nfuts: - fut.addCallback proc (data: pointer) = + proc cb(udata: pointer) {.gcsafe.} = + if not(retFuture.finished()): inc(completedFutures) - if not retFuture.finished: - if completedFutures == totalFutures: - for nfut in nfuts: - if nfut.failed: - retFuture.fail(nfut.error) - break - if not retFuture.failed: - retFuture.complete() + if completedFutures == totalFutures: + for nfut in nfuts: + if nfut.failed(): + retFuture.fail(nfut.error) + break + if not(retFuture.failed()): + retFuture.complete() + + for fut in nfuts: + fut.addCallback(cb) if len(nfuts) == 0: retFuture.complete() @@ -507,26 +576,30 @@ proc all*[T](futs: varargs[Future[T]]): auto = else: var retFuture = newFuture[seq[T]]("chronos.all(T)") var retValues = newSeq[T](totalFutures) - for fut in nfuts: - fut.addCallback proc (data: pointer) = + + proc cb(udata: pointer) {.gcsafe.} = + if not(retFuture.finished()): inc(completedFutures) - if not retFuture.finished: - if completedFutures == totalFutures: - for k, nfut in nfuts: - if nfut.failed: - retFuture.fail(nfut.error) - break - else: - retValues[k] = nfut.read() - if not retFuture.failed: - retFuture.complete(retValues) + if completedFutures == totalFutures: + for k, nfut in nfuts: + if nfut.failed(): + retFuture.fail(nfut.error) + break + else: + retValues[k] = nfut.read() + if not(retFuture.failed()): + retFuture.complete(retValues) + + for fut in nfuts: + fut.addCallback(cb) if len(nfuts) == 0: retFuture.complete(retValues) return retFuture -proc oneIndex*[T](futs: varargs[Future[T]]): Future[int] = +proc oneIndex*[T](futs: varargs[Future[T]]): Future[int] {. + deprecated: "Use one[T](varargs[Future[T]])".} = ## Returns a future which will complete once one of the futures in ``futs`` ## complete. ## @@ -537,10 +610,10 @@ proc oneIndex*[T](futs: varargs[Future[T]]): Future[int] = var nfuts = @futs var retFuture = newFuture[int]("chronos.oneIndex(T)") - proc cb(data: pointer) {.gcsafe.} = + proc cb(udata: pointer) {.gcsafe.} = var res = -1 - if not retFuture.finished: - var rfut = cast[FutureBase](data) + if not(retFuture.finished()): + var rfut = cast[FutureBase](udata) for i in 0.. yield future - result.add newNimNode(nnkYieldStmt, fromNode).add(futureVarNode) - # -> future.read - valueReceiver = newDotExpr(futureVarNode, newIdentNode("read")) - result.add rootReceiver + if isawait: + # -> yield future + result.add newNimNode(nnkYieldStmt, fromNode).add(futureVarNode) + # -> future.read + valueReceiver = newDotExpr(futureVarNode, newIdentNode("read")) + result.add rootReceiver + else: + # -> yield future + result.add newNimNode(nnkYieldStmt, fromNode).add(futureVarNode) + valueReceiver = futureVarNode + result.add rootReceiver template createVar(result: var NimNode, futSymName: string, asyncProc: NimNode, - valueReceiver, rootReceiver: untyped, - fromNode: NimNode) = + valueReceiver, rootReceiver, retFutSym: untyped, + fromNode: NimNode, isawait: bool) = result = newNimNode(nnkStmtList, fromNode) var futSym = genSym(nskVar, "future") result.add newVarStmt(futSym, asyncProc) # -> var future = y - useVar(result, futSym, valueReceiver, rootReceiver, fromNode) + # retFuture.child = future + result.add newAssignment( + newDotExpr( + newCall(newIdentNode("FutureBase"), copyNimNode(retFutSym)), + newIdentNode("child") + ), + newCall(newIdentNode("FutureBase"), copyNimNode(futSym)) + ) + useVar(result, futSym, valueReceiver, rootReceiver, fromNode, isawait) proc createFutureVarCompletions(futureVarIdents: seq[NimNode], fromNode: NimNode): NimNode {.compileTime.} = @@ -134,7 +150,8 @@ proc processBody(node, retFutureSym: NimNode, result.add newNimNode(nnkReturnStmt, node).add(newNilLit()) return # Don't process the children of this return stmt of nnkCommand, nnkCall: - if node[0].kind == nnkIdent and node[0].eqIdent("await"): + if node[0].kind == nnkIdent and + (node[0].eqIdent("await") or node[0].eqIdent("awaitne")): case node[1].kind of nnkIdent, nnkInfix, nnkDotExpr, nnkCall, nnkCommand: # await x @@ -143,40 +160,48 @@ proc processBody(node, retFutureSym: NimNode, # await foo p, x var futureValue: NimNode result.createVar("future" & $node[1][0].toStrLit, node[1], futureValue, - futureValue, node) + futureValue, retFutureSym, node, + node[0].eqIdent("await")) else: error("Invalid node kind in 'await', got: " & $node[1].kind) elif node.len > 1 and node[1].kind == nnkCommand and - node[1][0].kind == nnkIdent and node[1][0].eqIdent("await"): + node[1][0].kind == nnkIdent and + (node[1][0].eqIdent("await") or node[1][0].eqIdent("awaitne")): # foo await x var newCommand = node result.createVar("future" & $node[0].toStrLit, node[1][1], newCommand[1], - newCommand, node) + newCommand, retFutureSym, node, + node[1][0].eqIdent("await")) of nnkVarSection, nnkLetSection: case node[0][2].kind of nnkCommand: - if node[0][2][0].kind == nnkIdent and node[0][2][0].eqIdent("await"): + if node[0][2][0].kind == nnkIdent and + (node[0][2][0].eqIdent("await") or node[0][2][0].eqIdent("awaitne")): # var x = await y var newVarSection = node # TODO: Should this use copyNimNode? result.createVar("future" & node[0][0].strVal, node[0][2][1], - newVarSection[0][2], newVarSection, node) + newVarSection[0][2], newVarSection, retFutureSym, node, + node[0][2][0].eqIdent("await")) else: discard of nnkAsgn: case node[1].kind of nnkCommand: - if node[1][0].eqIdent("await"): + if node[1][0].eqIdent("await") or node[1][0].eqIdent("awaitne"): # x = await y var newAsgn = node - result.createVar("future" & $node[0].toStrLit, node[1][1], newAsgn[1], newAsgn, node) + result.createVar("future" & $node[0].toStrLit, node[1][1], newAsgn[1], + newAsgn, retFutureSym, node, + node[1][0].eqIdent("await")) else: discard of nnkDiscardStmt: # discard await x if node[0].kind == nnkCommand and node[0][0].kind == nnkIdent and - node[0][0].eqIdent("await"): + (node[0][0].eqIdent("await") or node[0][0].eqIdent("awaitne")): var newDiscard = node result.createVar("futureDiscard_" & $toStrLit(node[0][1]), node[0][1], - newDiscard[0], newDiscard, node) + newDiscard[0], newDiscard, retFutureSym, node, + node[0][0].eqIdent("await")) else: discard for i in 0 ..< result.len: @@ -336,103 +361,3 @@ macro async*(prc: untyped): untyped = result = asyncSingleProc(prc) when defined(nimDumpAsync): echo repr result - - -# Multisync -proc emptyNoop[T](x: T): T = - # The ``await``s are replaced by a call to this for simplicity. - when T isnot void: - return x - -proc stripAwait(node: NimNode): NimNode = - ## Strips out all ``await`` commands from a procedure body, replaces them - ## with ``emptyNoop`` for simplicity. - result = node - - let emptyNoopSym = bindSym("emptyNoop") - - case node.kind - of nnkCommand, nnkCall: - if node[0].kind == nnkIdent and node[0].eqIdent("await"): - node[0] = emptyNoopSym - elif node.len > 1 and node[1].kind == nnkCommand and - node[1][0].kind == nnkIdent and node[1][0].eqIdent("await"): - # foo await x - node[1][0] = emptyNoopSym - of nnkVarSection, nnkLetSection: - case node[0][2].kind - of nnkCommand: - if node[0][2][0].kind == nnkIdent and node[0][2][0].eqIdent("await"): - # var x = await y - node[0][2][0] = emptyNoopSym - else: discard - of nnkAsgn: - case node[1].kind - of nnkCommand: - if node[1][0].eqIdent("await"): - # x = await y - node[1][0] = emptyNoopSym - else: discard - of nnkDiscardStmt: - # discard await x - if node[0].kind == nnkCommand and node[0][0].kind == nnkIdent and - node[0][0].eqIdent("await"): - node[0][0] = emptyNoopSym - else: discard - - for i in 0 ..< result.len: - result[i] = stripAwait(result[i]) - -proc splitParamType(paramType: NimNode, async: bool): NimNode = - result = paramType - if paramType.kind == nnkInfix and paramType[0].strVal in ["|", "or"]: - let firstAsync = "async" in paramType[1].strVal.normalize - let secondAsync = "async" in paramType[2].strVal.normalize - - if firstAsync: - result = paramType[if async: 1 else: 2] - elif secondAsync: - result = paramType[if async: 2 else: 1] - -proc stripReturnType(returnType: NimNode): NimNode = - # Strip out the 'Future' from 'Future[T]'. - result = returnType - if returnType.kind == nnkBracketExpr: - let fut = repr(returnType[0]) - verifyReturnType(fut) - result = returnType[1] - -proc splitProc(prc: NimNode): (NimNode, NimNode) = - ## Takes a procedure definition which takes a generic union of arguments, - ## for example: proc (socket: Socket | AsyncSocket). - ## It transforms them so that ``proc (socket: Socket)`` and - ## ``proc (socket: AsyncSocket)`` are returned. - - result[0] = prc.copyNimTree() - # Retrieve the `T` inside `Future[T]`. - let returnType = stripReturnType(result[0][3][0]) - result[0][3][0] = splitParamType(returnType, async=false) - for i in 1 ..< result[0][3].len: - # Sync proc (0) -> FormalParams (3) -> IdentDefs, the parameter (i) -> - # parameter type (1). - result[0][3][i][1] = splitParamType(result[0][3][i][1], async=false) - result[0][6] = stripAwait(result[0][6]) - - result[1] = prc.copyNimTree() - if result[1][3][0].kind == nnkBracketExpr: - result[1][3][0][1] = splitParamType(result[1][3][0][1], async=true) - for i in 1 ..< result[1][3].len: - # Async proc (1) -> FormalParams (3) -> IdentDefs, the parameter (i) -> - # parameter type (1). - result[1][3][i][1] = splitParamType(result[1][3][i][1], async=true) - -macro multisync*(prc: untyped): untyped = - ## Macro which processes async procedures into both asynchronous and - ## synchronous procedures. - ## - ## The generated async procedures use the ``async`` macro, whereas the - ## generated synchronous procedures simply strip off the ``await`` calls. - let (sync, asyncPrc) = splitProc(prc) - result = newStmtList() - result.add(asyncSingleProc(asyncPrc)) - result.add(sync) diff --git a/chronos/asyncsync.nim b/chronos/asyncsync.nim index 9382a7f..dc77fa5 100644 --- a/chronos/asyncsync.nim +++ b/chronos/asyncsync.nim @@ -80,12 +80,12 @@ proc acquire*(lock: AsyncLock) {.async.} = ## ## This procedure blocks until the lock ``lock`` is unlocked, then sets it ## to locked and returns. - if not lock.locked: + if not(lock.locked): lock.locked = true else: var w = newFuture[void]("AsyncLock.acquire") lock.waiters.addLast(w) - yield w + await w lock.locked = true proc own*(lock: AsyncLock) = @@ -112,7 +112,7 @@ proc release*(lock: AsyncLock) = lock.locked = false while len(lock.waiters) > 0: w = lock.waiters.popFirst() - if not w.finished: + if not(w.finished()): w.complete() break else: @@ -143,18 +143,18 @@ proc wait*(event: AsyncEvent) {.async.} = else: var w = newFuture[void]("AsyncEvent.wait") event.waiters.addLast(w) - yield w + await w proc fire*(event: AsyncEvent) = ## Set the internal flag of ``event`` to `true`. All tasks waiting for it ## to become `true` are awakened. Task that call `wait()` once the flag is ## `true` will not block at all. var w: Future[void] - if not event.flag: + if not(event.flag): event.flag = true while len(event.waiters) > 0: w = event.waiters.popFirst() - if not w.finished: + if not(w.finished()): w.complete() proc clear*(event: AsyncEvent) = @@ -203,7 +203,8 @@ proc addFirstNoWait*[T](aq: AsyncQueue[T], item: T) = aq.queue.addFirst(item) while len(aq.getters) > 0: w = aq.getters.popFirst() - if not w.finished: w.complete() + if not(w.finished()): + w.complete() proc addLastNoWait*[T](aq: AsyncQueue[T], item: T) = ## Put an item ``item`` at the end of the queue ``aq`` immediately. @@ -215,7 +216,8 @@ proc addLastNoWait*[T](aq: AsyncQueue[T], item: T) = aq.queue.addLast(item) while len(aq.getters) > 0: w = aq.getters.popFirst() - if not w.finished: w.complete() + if not(w.finished()): + w.complete() proc popFirstNoWait*[T](aq: AsyncQueue[T]): T = ## Get an item from the beginning of the queue ``aq`` immediately. @@ -227,7 +229,8 @@ proc popFirstNoWait*[T](aq: AsyncQueue[T]): T = result = aq.queue.popFirst() while len(aq.putters) > 0: w = aq.putters.popFirst() - if not w.finished: w.complete() + if not(w.finished()): + w.complete() proc popLastNoWait*[T](aq: AsyncQueue[T]): T = ## Get an item from the end of the queue ``aq`` immediately. @@ -239,7 +242,8 @@ proc popLastNoWait*[T](aq: AsyncQueue[T]): T = result = aq.queue.popLast() while len(aq.putters) > 0: w = aq.putters.popFirst() - if not w.finished: w.complete() + if not(w.finished()): + w.complete() proc addFirst*[T](aq: AsyncQueue[T], item: T) {.async.} = ## Put an ``item`` to the beginning of the queue ``aq``. If the queue is full, @@ -247,7 +251,7 @@ proc addFirst*[T](aq: AsyncQueue[T], item: T) {.async.} = while aq.full(): var putter = newFuture[void]("AsyncQueue.addFirst") aq.putters.addLast(putter) - yield putter + await putter aq.addFirstNoWait(item) proc addLast*[T](aq: AsyncQueue[T], item: T) {.async.} = @@ -256,7 +260,7 @@ proc addLast*[T](aq: AsyncQueue[T], item: T) {.async.} = while aq.full(): var putter = newFuture[void]("AsyncQueue.addLast") aq.putters.addLast(putter) - yield putter + await putter aq.addLastNoWait(item) proc popFirst*[T](aq: AsyncQueue[T]): Future[T] {.async.} = @@ -265,7 +269,7 @@ proc popFirst*[T](aq: AsyncQueue[T]): Future[T] {.async.} = while aq.empty(): var getter = newFuture[void]("AsyncQueue.popFirst") aq.getters.addLast(getter) - yield getter + await getter result = aq.popFirstNoWait() proc popLast*[T](aq: AsyncQueue[T]): Future[T] {.async.} = @@ -274,7 +278,7 @@ proc popLast*[T](aq: AsyncQueue[T]): Future[T] {.async.} = while aq.empty(): var getter = newFuture[void]("AsyncQueue.popLast") aq.getters.addLast(getter) - yield getter + await getter result = aq.popLastNoWait() proc putNoWait*[T](aq: AsyncQueue[T], item: T) {.inline.} = diff --git a/chronos/streams/asyncstream.nim b/chronos/streams/asyncstream.nim index af99a29..3479ef1 100644 --- a/chronos/streams/asyncstream.nim +++ b/chronos/streams/asyncstream.nim @@ -567,11 +567,12 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, raise newAsyncStreamIncorrectError("Zero length message") if isNil(wstream.wsource): - var resFut = write(wstream.tsource, pbytes, nbytes) - yield resFut - if resFut.failed: - raise newAsyncStreamWriteError(resFut.error) - if resFut.read() != nbytes: + var res: int + try: + res = await write(wstream.tsource, pbytes, nbytes) + except: + raise newAsyncStreamWriteError(getCurrentException()) + if res != nbytes: raise newAsyncStreamIncompleteError() else: if isNil(wstream.writerLoop): @@ -582,8 +583,9 @@ proc write*(wstream: AsyncStreamWriter, pbytes: pointer, item.size = nbytes item.future = newFuture[void]("async.stream.write(pointer)") await wstream.queue.put(item) - yield item.future - if item.future.failed: + try: + await item.future + except: raise newAsyncStreamWriteError(item.future.error) proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], @@ -604,11 +606,12 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], raise newAsyncStreamIncorrectError("Zero length message") if isNil(wstream.wsource): - var resFut = write(wstream.tsource, sbytes, msglen) - yield resFut - if resFut.failed: - raise newAsyncStreamWriteError(resFut.error) - if resFut.read() != length: + var res: int + try: + res = await write(wstream.tsource, sbytes, msglen) + except: + raise newAsyncStreamWriteError(getCurrentException()) + if res != length: raise newAsyncStreamIncompleteError() else: if isNil(wstream.writerLoop): @@ -622,8 +625,9 @@ proc write*(wstream: AsyncStreamWriter, sbytes: seq[byte], item.size = length item.future = newFuture[void]("async.stream.write(seq)") await wstream.queue.put(item) - yield item.future - if item.future.failed: + try: + await item.future + except: raise newAsyncStreamWriteError(item.future.error) proc write*(wstream: AsyncStreamWriter, sbytes: string, @@ -643,11 +647,12 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, raise newAsyncStreamIncorrectError("Zero length message") if isNil(wstream.wsource): - var resFut = write(wstream.tsource, sbytes, msglen) - yield resFut - if resFut.failed: - raise newAsyncStreamWriteError(resFut.error) - if resFut.read() != length: + var res: int + try: + res = await write(wstream.tsource, sbytes, msglen) + except: + raise newAsyncStreamWriteError(getCurrentException()) + if res != length: raise newAsyncStreamIncompleteError() else: if isNil(wstream.writerLoop): @@ -661,8 +666,9 @@ proc write*(wstream: AsyncStreamWriter, sbytes: string, item.size = length item.future = newFuture[void]("async.stream.write(string)") await wstream.queue.put(item) - yield item.future - if item.future.failed: + try: + await item.future + except: raise newAsyncStreamWriteError(item.future.error) proc finish*(wstream: AsyncStreamWriter) {.async.} = @@ -678,8 +684,9 @@ proc finish*(wstream: AsyncStreamWriter) {.async.} = item.size = 0 item.future = newFuture[void]("async.stream.finish") await wstream.queue.put(item) - yield item.future - if item.future.failed: + try: + await item.future + except: raise newAsyncStreamWriteError(item.future.error) proc join*(rw: AsyncStreamRW): Future[void] = @@ -689,11 +696,19 @@ proc join*(rw: AsyncStreamRW): Future[void] = var retFuture = newFuture[void]("async.stream.reader.join") else: var retFuture = newFuture[void]("async.stream.writer.join") - proc continuation(udata: pointer) = retFuture.complete() - if not rw.future.finished: - rw.future.addCallback(continuation) + + proc continuation(udata: pointer) {.gcsafe.} = + retFuture.complete() + + proc cancel(udata: pointer) {.gcsafe.} = + rw.future.removeCallback(continuation, cast[pointer](retFuture)) + + if not(rw.future.finished()): + rw.future.addCallback(continuation, cast[pointer](retFuture)) + rw.future.cancelCallback = cancel else: retFuture.complete() + return retFuture proc close*(rw: AsyncStreamRW) = diff --git a/chronos/streams/chunkstream.nim b/chronos/streams/chunkstream.nim index fc81469..750872d 100644 --- a/chronos/streams/chunkstream.nim +++ b/chronos/streams/chunkstream.nim @@ -32,7 +32,7 @@ proc oneOf*[A, B](fut1: Future[A], fut2: Future[B]): Future[void] = ## error, so you need to check `fut1` and `fut2` for error. var retFuture = newFuture[void]("chunked.oneOf()") proc cb(data: pointer) {.gcsafe.} = - if not retFuture.finished: + if not(retFuture.finished()): if cast[pointer](fut1) == data: fut2.removeCallback(cb) elif cast[pointer](fut2) == data: @@ -95,11 +95,11 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = var ruFut1 = rstream.rsource.readUntil(addr buffer[0], 1024, CRLF) await oneOf(ruFut1, exitFut) - if exitFut.finished: + if exitFut.finished(): rstream.state = AsyncStreamState.Stopped break - if ruFut1.failed: + if ruFut1.failed(): rstream.error = ruFut1.error rstream.state = AsyncStreamState.Error break @@ -118,18 +118,20 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = toRead) await oneOf(reFut2, exitFut) - if exitFut.finished: + if exitFut.finished(): rstream.state = AsyncStreamState.Stopped break - if reFut2.failed: + if reFut2.failed(): rstream.error = reFut2.error rstream.state = AsyncStreamState.Error break rstream.buffer.update(toRead) - await rstream.buffer.transfer() or exitFut - if exitFut.finished: + + await oneOf(rstream.buffer.transfer(), exitFut) + + if exitFut.finished(): rstream.state = AsyncStreamState.Stopped break @@ -141,11 +143,11 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = # Reading chunk trailing CRLF var reFut3 = rstream.rsource.readExactly(addr buffer[0], 2) await oneOf(reFut3, exitFut) - if exitFut.finished: + if exitFut.finished(): rstream.state = AsyncStreamState.Stopped break - if reFut3.failed: + if reFut3.failed(): rstream.error = reFut3.error rstream.state = AsyncStreamState.Error break @@ -159,18 +161,20 @@ proc chunkedReadLoop(stream: AsyncStreamReader) {.async.} = var ruFut4 = rstream.rsource.readUntil(addr buffer[0], len(buffer), CRLF) await oneOf(ruFut4, exitFut) - if exitFut.finished: + if exitFut.finished(): rstream.state = AsyncStreamState.Stopped break - if ruFut4.failed: + if ruFut4.failed(): rstream.error = ruFut4.error rstream.state = AsyncStreamState.Error break rstream.state = AsyncStreamState.Finished - await rstream.buffer.transfer() or exitFut - if exitFut.finished: + + await oneOf(rstream.buffer.transfer(), exitFut) + + if exitFut.finished(): rstream.state = AsyncStreamState.Stopped break @@ -190,7 +194,7 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = # Getting new item from stream's queue. var getFut = wstream.queue.get() await oneOf(getFut, exitFut) - if exitFut.finished: + if exitFut.finished(): wstream.state = AsyncStreamState.Stopped break var item = getFut.read() @@ -202,11 +206,11 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = wFut1 = wstream.wsource.write(addr buffer[0], length) await oneOf(wFut1, exitFut) - if exitFut.finished: + if exitFut.finished(): wstream.state = AsyncStreamState.Stopped break - if wFut1.failed: + if wFut1.failed(): item.future.fail(wFut1.error) continue @@ -219,22 +223,22 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = wFut2 = wstream.wsource.write(addr item.data3[0], item.size) await oneOf(wFut2, exitFut) - if exitFut.finished: + if exitFut.finished(): wstream.state = AsyncStreamState.Stopped break - if wFut2.failed: + if wFut2.failed(): item.future.fail(wFut2.error) continue # Writing chunk footer CRLF. var wFut3 = wstream.wsource.write(CRLF) await oneOf(wFut3, exitFut) - if exitFut.finished: + if exitFut.finished(): wstream.state = AsyncStreamState.Stopped break - if wFut3.failed: + if wFut3.failed(): item.future.fail(wFut3.error) continue @@ -246,11 +250,11 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = # Write finish chunk `0`. wFut1 = wstream.wsource.write(addr buffer[0], length) await oneOf(wFut1, exitFut) - if exitFut.finished: + if exitFut.finished(): wstream.state = AsyncStreamState.Stopped break - if wFut1.failed: + if wFut1.failed(): item.future.fail(wFut1.error) # We break here, because this is last chunk break @@ -259,11 +263,11 @@ proc chunkedWriteLoop(stream: AsyncStreamWriter) {.async.} = wFut2 = wstream.wsource.write(CRLF) await oneOf(wFut2, exitFut) - if exitFut.finished: + if exitFut.finished(): wstream.state = AsyncStreamState.Stopped break - if wFut2.failed: + if wFut2.failed(): item.future.fail(wFut2.error) # We break here, because this is last chunk break diff --git a/chronos/transports/datagram.nim b/chronos/transports/datagram.nim index 18f073d..30cd049 100644 --- a/chronos/transports/datagram.nim +++ b/chronos/transports/datagram.nim @@ -87,7 +87,7 @@ proc dumpTransportTracking(): string {.gcsafe.} = proc leakTransport(): bool {.gcsafe.} = var tracker = getDgramTransportTracker() - result = tracker.opened != tracker.closed + result = (tracker.opened != tracker.closed) proc trackDgram(t: DatagramTransport) {.inline.} = var tracker = getDgramTransportTracker() @@ -123,14 +123,17 @@ when defined(windows): let err = transp.wovl.data.errCode let vector = transp.queue.popFirst() if err == OSErrorCode(-1): - vector.writer.complete() + if not(vector.writer.finished()): + vector.writer.complete() elif int(err) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt transp.state.incl(WritePaused) - vector.writer.complete() + if not(vector.writer.finished()): + vector.writer.complete() else: transp.state.incl({WritePaused, WriteError}) - vector.writer.fail(getTransportOsError(err)) + if not(vector.writer.finished()): + vector.writer.fail(getTransportOsError(err)) else: ## Initiation transp.state.incl(WritePending) @@ -153,13 +156,15 @@ when defined(windows): # CancelIO() interrupt transp.state.excl(WritePending) transp.state.incl(WritePaused) - vector.writer.complete() + if not(vector.writer.finished()): + vector.writer.complete() elif int(err) == ERROR_IO_PENDING: transp.queue.addFirst(vector) else: transp.state.excl(WritePending) transp.state.incl({WritePaused, WriteError}) - vector.writer.fail(getTransportOsError(err)) + if not(vector.writer.finished()): + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) break @@ -188,7 +193,7 @@ when defined(windows): elif int(err) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt or closeSocket() call. transp.state.incl(ReadPaused) - if ReadClosed in transp.state: + if ReadClosed in transp.state and not(transp.future.finished()): # Stop tracking transport untrackDgram(transp) # If `ReadClosed` present, then close(transport) was called. @@ -233,12 +238,11 @@ when defined(windows): else: # Transport closure happens in callback, and we not started new # WSARecvFrom session. - if ReadClosed in transp.state: - if not transp.future.finished: - # Stop tracking transport - untrackDgram(transp) - transp.future.complete() - GC_unref(transp) + if ReadClosed in transp.state and not(transp.future.finished()): + # Stop tracking transport + untrackDgram(transp) + transp.future.complete() + GC_unref(transp) break proc resumeRead(transp: DatagramTransport) {.inline.} = @@ -424,13 +428,15 @@ else: elif vector.kind == WithoutAddress: res = posix.send(fd, vector.buf, vector.buflen, MSG_NOSIGNAL) if res >= 0: - vector.writer.complete() + if not(vector.writer.finished()): + vector.writer.complete() else: let err = osLastError() if int(err) == EINTR: continue else: - vector.writer.fail(getTransportOsError(err)) + if not(vector.writer.finished()): + vector.writer.fail(getTransportOsError(err)) break else: transp.state.incl(WritePaused) @@ -550,7 +556,7 @@ else: proc close*(transp: DatagramTransport) = ## Closes and frees resources of transport ``transp``. proc continuation(udata: pointer) = - if not transp.future.finished: + if not(transp.future.finished()): # Stop tracking transport untrackDgram(transp) transp.future.complete() @@ -657,11 +663,19 @@ proc newDatagramTransport6*[T](cbproc: DatagramCallback, proc join*(transp: DatagramTransport): Future[void] = ## Wait until the transport ``transp`` will be closed. var retFuture = newFuture[void]("datagram.transport.join") - proc continuation(udata: pointer) = retFuture.complete() - if not transp.future.finished: - transp.future.addCallback(continuation) + + proc continuation(udata: pointer) {.gcsafe.} = + retFuture.complete() + + proc cancel(udata: pointer) {.gcsafe.} = + transp.future.removeCallback(continuation, cast[pointer](retFuture)) + + if not(transp.future.finished()): + transp.future.addCallback(continuation, cast[pointer](retFuture)) + retFuture.cancelCallback = cancel else: retFuture.complete() + return retFuture proc closeWait*(transp: DatagramTransport): Future[void] = @@ -696,7 +710,7 @@ proc send*(transp: DatagramTransport, msg: string, msglen = -1): Future[void] = retFuture.gcholder = msg let length = if msglen <= 0: len(msg) else: msglen let vector = GramVector(kind: WithoutAddress, buf: addr retFuture.gcholder[0], - buflen: len(msg), + buflen: length, writer: cast[Future[void]](retFuture)) transp.queue.addLast(vector) if WritePaused in transp.state: diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index ab49dcb..6915504 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -249,7 +249,8 @@ proc completePendingWriteQueue(queue: var Deque[StreamVector], v: int) {.inline.} = while len(queue) > 0: var vector = queue.popFirst() - vector.writer.complete(v) + if not(vector.writer.finished()): + vector.writer.complete(v) when defined(windows): @@ -295,7 +296,8 @@ when defined(windows): bytesCount = transp.wovl.data.bytesCount var vector = transp.queue.popFirst() if bytesCount == 0: - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) else: if transp.kind == TransportKind.Socket: if vector.kind == VectorKind.DataBuffer: @@ -303,25 +305,29 @@ when defined(windows): vector.shiftVectorBuffer(bytesCount) transp.queue.addFirst(vector) else: - vector.writer.complete(transp.wwsabuf.len) + if not(vector.writer.finished()): + vector.writer.complete(transp.wwsabuf.len) else: if uint(bytesCount) < getFileSize(vector): vector.shiftVectorFile(bytesCount) transp.queue.addFirst(vector) else: - vector.writer.complete(int(getFileSize(vector))) + if not(vector.writer.finished()): + vector.writer.complete(int(getFileSize(vector))) elif transp.kind == TransportKind.Pipe: if vector.kind == VectorKind.DataBuffer: if bytesCount < transp.wwsabuf.len: vector.shiftVectorBuffer(bytesCount) transp.queue.addFirst(vector) else: - vector.writer.complete(transp.wwsabuf.len) + if not(vector.writer.finished()): + vector.writer.complete(transp.wwsabuf.len) elif int(err) == ERROR_OPERATION_ABORTED: # CancelIO() interrupt transp.state.incl(WritePaused) let v = transp.queue.popFirst() - v.writer.complete(0) + if not(v.writer.finished()): + v.writer.complete(0) break else: let v = transp.queue.popFirst() @@ -329,12 +335,14 @@ when defined(windows): # Soft error happens which indicates that remote peer got # disconnected, complete all pending writes in queue with 0. transp.state.incl(WriteEof) - v.writer.complete(0) + if not(v.writer.finished()): + v.writer.complete(0) completePendingWriteQueue(transp.queue, 0) break else: transp.state.incl(WriteError) - v.writer.fail(getTransportOsError(err)) + if not(v.writer.finished()): + v.writer.fail(getTransportOsError(err)) else: ## Initiation transp.state.incl(WritePending) @@ -353,7 +361,8 @@ when defined(windows): # CancelIO() interrupt transp.state.excl(WritePending) transp.state.incl(WritePaused) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) elif int(err) == ERROR_IO_PENDING: transp.queue.addFirst(vector) else: @@ -362,12 +371,14 @@ when defined(windows): # Soft error happens which indicates that remote peer got # disconnected, complete all pending writes in queue with 0. transp.state.incl({WritePaused, WriteEof}) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) completePendingWriteQueue(transp.queue, 0) break else: transp.state.incl({WritePaused, WriteError}) - vector.writer.fail(getTransportOsError(err)) + if not(vector.writer.finished()): + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) else: @@ -390,7 +401,8 @@ when defined(windows): # CancelIO() interrupt transp.state.excl(WritePending) transp.state.incl(WritePaused) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) elif int(err) == ERROR_IO_PENDING: transp.queue.addFirst(vector) else: @@ -399,12 +411,14 @@ when defined(windows): # Soft error happens which indicates that remote peer got # disconnected, complete all pending writes in queue with 0. transp.state.incl({WritePaused, WriteEof}) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) completePendingWriteQueue(transp.queue, 0) break else: transp.state.incl({WritePaused, WriteError}) - vector.writer.fail(getTransportOsError(err)) + if not(vector.writer.finished()): + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) elif transp.kind == TransportKind.Pipe: @@ -422,26 +436,30 @@ when defined(windows): # CancelIO() interrupt transp.state.excl(WritePending) transp.state.incl(WritePaused) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) elif int(err) == ERROR_IO_PENDING: transp.queue.addFirst(vector) elif int(err) == ERROR_NO_DATA: # The pipe is being closed. transp.state.excl(WritePending) transp.state.incl(WritePaused) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) else: transp.state.excl(WritePending) if isConnResetError(err): # Soft error happens which indicates that remote peer got # disconnected, complete all pending writes in queue with 0. transp.state.incl({WritePaused, WriteEof}) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) completePendingWriteQueue(transp.queue, 0) break else: transp.state.incl({WritePaused, WriteError}) - vector.writer.fail(getTransportOsError(err)) + if not(vector.writer.finished()): + vector.writer.fail(getTransportOsError(err)) else: transp.queue.addFirst(vector) break @@ -483,16 +501,16 @@ when defined(windows): else: transp.setReadError(err) - if not isNil(transp.reader): - if not transp.reader.finished: - transp.reader.complete() - transp.reader = nil + if not(isNil(transp.reader)) and not(transp.reader.finished()): + transp.reader.complete() + transp.reader = nil if ReadClosed in transp.state: # Stop tracking transport untrackStream(transp) # If `ReadClosed` present, then close(transport) was called. - transp.future.complete() + if not(transp.future.finished()): + transp.future.complete() GC_unref(transp) if ReadPaused in transp.state: @@ -521,14 +539,14 @@ when defined(windows): elif int32(err) in {WSAECONNRESET, WSAENETRESET, WSAECONNABORTED}: transp.state.excl(ReadPending) transp.state.incl({ReadEof, ReadPaused}) - if not isNil(transp.reader): + if not(isNil(transp.reader)) and not(transp.reader.finished()): transp.reader.complete() transp.reader = nil elif int32(err) != ERROR_IO_PENDING: transp.state.excl(ReadPending) transp.state.incl(ReadPaused) transp.setReadError(err) - if not isNil(transp.reader): + if not(isNil(transp.reader)) and not(transp.reader.finished()): transp.reader.complete() transp.reader = nil elif transp.kind == TransportKind.Pipe: @@ -547,25 +565,25 @@ when defined(windows): elif int32(err) in {ERROR_BROKEN_PIPE, ERROR_PIPE_NOT_CONNECTED}: transp.state.excl(ReadPending) transp.state.incl({ReadEof, ReadPaused}) - if not isNil(transp.reader): + if not(isNil(transp.reader)) and not(transp.reader.finished()): transp.reader.complete() transp.reader = nil elif int32(err) != ERROR_IO_PENDING: transp.state.excl(ReadPending) transp.state.incl(ReadPaused) transp.setReadError(err) - if not isNil(transp.reader): + if not(isNil(transp.reader)) and not(transp.reader.finished()): transp.reader.complete() transp.reader = nil else: transp.state.incl(ReadPaused) - if not isNil(transp.reader): + if not(isNil(transp.reader)) and not(transp.reader.finished()): transp.reader.complete() transp.reader = nil # Transport close happens in callback, and we not started new # WSARecvFrom session. if ReadClosed in transp.state: - if not transp.future.finished: + if not(transp.future.finished()): transp.future.complete() ## Finish Loop break @@ -648,7 +666,8 @@ when defined(windows): proto = Protocol.IPPROTO_TCP sock = createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto) if sock == asyncInvalidSocket: - result.fail(getTransportOsError(osLastError())) + retFuture.fail(getTransportOsError(osLastError())) + return retFuture if not bindToDomain(sock, address.getDomain()): let err = wsaGetLastError() @@ -656,9 +675,9 @@ when defined(windows): retFuture.fail(getTransportOsError(err)) return retFuture - proc socketContinuation(udata: pointer) = + proc socketContinuation(udata: pointer) {.gcsafe.} = var ovl = cast[RefCustomOverlapped](udata) - if not retFuture.finished: + if not(retFuture.finished()): if ovl.data.errCode == OSErrorCode(-1): if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), cint(SO_UPDATE_CONNECT_CONTEXT), nil, @@ -677,6 +696,11 @@ when defined(windows): retFuture.fail(getTransportOsError(ovl.data.errCode)) GC_unref(ovl) + proc cancel(udata: pointer) {.gcsafe.} = + sock.closeSocket() + + retFuture.cancelCallback = cancel + povl = RefCustomOverlapped() GC_ref(povl) povl.data = CompletionData(fd: sock, cb: socketContinuation) @@ -695,26 +719,29 @@ when defined(windows): elif address.family == AddressFamily.Unix: ## Unix domain socket emulation with Windows Named Pipes. + var pipeHandle = INVALID_HANDLE_VALUE proc pipeContinuation(udata: pointer) {.gcsafe.} = - var pipeSuffix = $cast[cstring](unsafeAddr address.address_un[0]) - var pipeName = newWideCString(r"\\.\pipe\" & pipeSuffix[1 .. ^1]) - var pipeHandle = createFileW(pipeName, GENERIC_READ or GENERIC_WRITE, - FILE_SHARE_READ or FILE_SHARE_WRITE, - nil, OPEN_EXISTING, - FILE_FLAG_OVERLAPPED, Handle(0)) - if pipeHandle == INVALID_HANDLE_VALUE: - let err = osLastError() - if int32(err) == ERROR_PIPE_BUSY: - addTimer(Moment.fromNow(50.milliseconds), pipeContinuation, nil) + # Continue only if `retFuture` is not cancelled. + if not(retFuture.finished()): + var pipeSuffix = $cast[cstring](unsafeAddr address.address_un[0]) + var pipeName = newWideCString(r"\\.\pipe\" & pipeSuffix[1 .. ^1]) + pipeHandle = createFileW(pipeName, GENERIC_READ or GENERIC_WRITE, + FILE_SHARE_READ or FILE_SHARE_WRITE, + nil, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, Handle(0)) + if pipeHandle == INVALID_HANDLE_VALUE: + let err = osLastError() + if int32(err) == ERROR_PIPE_BUSY: + addTimer(Moment.fromNow(50.milliseconds), pipeContinuation, nil) + else: + retFuture.fail(getTransportOsError(err)) else: - retFuture.fail(getTransportOsError(err)) - else: - register(AsyncFD(pipeHandle)) - let transp = newStreamPipeTransport(AsyncFD(pipeHandle), - bufferSize, child) - # Start tracking transport - trackStream(transp) - retFuture.complete(transp) + register(AsyncFD(pipeHandle)) + let transp = newStreamPipeTransport(AsyncFD(pipeHandle), + bufferSize, child) + # Start tracking transport + trackStream(transp) + retFuture.complete(transp) pipeContinuation(nil) return retFuture @@ -748,7 +775,8 @@ when defined(windows): # Stop tracking server untrackServer(server) # Completing server's Future - server.loopFuture.complete() + if not(server.loopFuture.finished()): + server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) GC_unref(server) @@ -796,7 +824,7 @@ when defined(windows): # Server close happens in callback, and we are not started new # connectNamedPipe session. if server.status in {ServerStatus.Closed, ServerStatus.Stopped}: - if not server.loopFuture.finished: + if not(server.loopFuture.finished()): # Stop tracking server untrackServer(server) server.loopFuture.complete() @@ -839,7 +867,7 @@ when defined(windows): # CancelIO() interrupt or close. if server.status in {ServerStatus.Closed, ServerStatus.Stopped}: # Stop tracking server - if not server.loopFuture.finished: + if not(server.loopFuture.finished()): untrackServer(server) server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: @@ -883,7 +911,7 @@ when defined(windows): # Server close happens in callback, and we are not started new # AcceptEx session. if server.status in {ServerStatus.Closed, ServerStatus.Stopped}: - if not server.loopFuture.finished: + if not(server.loopFuture.finished()): # Stop tracking server untrackServer(server) server.loopFuture.complete() @@ -937,7 +965,8 @@ else: let res = posix.send(fd, vector.buf, vector.buflen, MSG_NOSIGNAL) if res >= 0: if vector.buflen - res == 0: - vector.writer.complete(vector.buflen) + if not(vector.writer.finished()): + vector.writer.complete(vector.buflen) else: vector.shiftVectorBuffer(res) transp.queue.addFirst(vector) @@ -950,11 +979,13 @@ else: # Soft error happens which indicates that remote peer got # disconnected, complete all pending writes in queue with 0. transp.state.incl({WriteEof, WritePaused}) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) completePendingWriteQueue(transp.queue, 0) transp.fd.removeWriter() else: - vector.writer.fail(getTransportOsError(err)) + if not(vector.writer.finished()): + vector.writer.fail(getTransportOsError(err)) else: var nbytes = cast[int](vector.buf) let res = sendfile(int(fd), cast[int](vector.buflen), @@ -963,7 +994,8 @@ else: if res >= 0: if cast[int](vector.buf) - nbytes == 0: vector.size += nbytes - vector.writer.complete(vector.size) + if not(vector.writer.finished()): + vector.writer.complete(vector.size) else: vector.size += nbytes vector.shiftVectorFile(nbytes) @@ -977,11 +1009,13 @@ else: # Soft error happens which indicates that remote peer got # disconnected, complete all pending writes in queue with 0. transp.state.incl({WriteEof, WritePaused}) - vector.writer.complete(0) + if not(vector.writer.finished()): + vector.writer.complete(0) completePendingWriteQueue(transp.queue, 0) transp.fd.removeWriter() else: - vector.writer.fail(getTransportOsError(err)) + if not(vector.writer.finished()): + vector.writer.fail(getTransportOsError(err)) break else: transp.state.incl(WritePaused) @@ -998,10 +1032,9 @@ else: if ReadClosed in transp.state: transp.state.incl({ReadPaused}) - if not isNil(transp.reader): - if not transp.reader.finished: - transp.reader.complete() - transp.reader = nil + if not(isNil(transp.reader)) and not(transp.reader.finished()): + transp.reader.complete() + transp.reader = nil else: while true: var res = posix.recv(fd, addr transp.buffer[transp.offset], @@ -1025,7 +1058,7 @@ else: if transp.offset == len(transp.buffer): transp.state.incl(ReadPaused) cdata.fd.removeReader() - if not isNil(transp.reader): + if not(isNil(transp.reader)) and not(transp.reader.finished()): transp.reader.complete() transp.reader = nil break @@ -1070,8 +1103,8 @@ else: retFuture.fail(getTransportOsError(osLastError())) return retFuture - proc continuation(udata: pointer) = - if not retFuture.finished: + proc continuation(udata: pointer) {.gcsafe.} = + if not(retFuture.finished()): var data = cast[ptr CompletionData](udata) var err = 0 let fd = data.fd @@ -1089,6 +1122,9 @@ else: trackStream(transp) retFuture.complete(transp) + proc cancel(udata: pointer) {.gcsafe.} = + closeSocket(sock) + while true: var res = posix.connect(SocketHandle(sock), cast[ptr SockAddr](addr saddr), slen) @@ -1100,10 +1136,15 @@ else: break else: let err = osLastError() - if int(err) == EINTR: - continue - elif int(err) == EINPROGRESS: + # If connect() is interrupted by a signal that is caught while blocked + # waiting to establish a connection, connect() shall fail and set + # connect() to [EINTR], but the connection request shall not be aborted, + # and the connection shall be established asynchronously. + # + # http://www.madore.org/~david/computers/connect-intr.html + if int(err) == EINPROGRESS or int(err) == EINTR: sock.addWriter(continuation) + retFuture.cancelCallback = cancel break else: sock.closeSocket() @@ -1171,9 +1212,16 @@ proc stop*(server: StreamServer) = proc join*(server: StreamServer): Future[void] = ## Waits until ``server`` is not closed. var retFuture = newFuture[void]("stream.transport.server.join") - proc continuation(udata: pointer) = retFuture.complete() - if not server.loopFuture.finished: - server.loopFuture.addCallback(continuation) + + proc continuation(udata: pointer) {.gcsafe.} = + retFuture.complete() + + proc cancel(udata: pointer) {.gcsafe.} = + server.loopFuture.removeCallback(continuation, cast[pointer](retFuture)) + + if not(server.loopFuture.finished()): + server.loopFuture.addCallback(continuation, cast[pointer](retFuture)) + retFuture.cancelCallback = cancel else: retFuture.complete() return retFuture @@ -1183,14 +1231,15 @@ proc close*(server: StreamServer) = ## ## Please note that release of resources is not completed immediately, to be ## sure all resources got released please use ``await server.join()``. - proc continuation(udata: pointer) = + proc continuation(udata: pointer) {.gcsafe.} = # Stop tracking server - if not server.loopFuture.finished: + if not(server.loopFuture.finished()): untrackServer(server) server.loopFuture.complete() if not isNil(server.udata) and GCUserData in server.flags: GC_unref(cast[ref int](server.udata)) GC_unref(server) + if server.status == ServerStatus.Stopped: server.status = ServerStatus.Closed when defined(windows): @@ -1723,9 +1772,16 @@ proc consume*(transp: StreamTransport, n = -1): Future[int] {.async.} = proc join*(transp: StreamTransport): Future[void] = ## Wait until ``transp`` will not be closed. var retFuture = newFuture[void]("stream.transport.join") - proc continuation(udata: pointer) = retFuture.complete() - if not transp.future.finished: - transp.future.addCallback(continuation) + + proc continuation(udata: pointer) {.gcsafe.} = + retFuture.complete() + + proc cancel(udata: pointer) {.gcsafe.} = + transp.future.removeCallback(continuation, cast[pointer](retFuture)) + + if not(transp.future.finished()): + transp.future.addCallback(continuation, cast[pointer](retFuture)) + retFuture.cancelCallback = cancel else: retFuture.complete() return retFuture @@ -1735,8 +1791,8 @@ proc close*(transp: StreamTransport) = ## ## Please note that release of resources is not completed immediately, to be ## sure all resources got released please use ``await transp.join()``. - proc continuation(udata: pointer) = - if not transp.future.finished: + proc continuation(udata: pointer) {.gcsafe.} = + if not(transp.future.finished()): transp.future.complete() # Stop tracking stream untrackStream(transp) diff --git a/tests/testall.nim b/tests/testall.nim index 710afee..cc3bb9e 100644 --- a/tests/testall.nim +++ b/tests/testall.nim @@ -5,6 +5,6 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import testsync, testsoon, testtime, testfut, testsignal, testaddress, - testdatagram, teststream, testserver, testbugs, testnet, +import testmacro, testsync, testsoon, testtime, testfut, testsignal, + testaddress, testdatagram, teststream, testserver, testbugs, testnet, testasyncstream diff --git a/tests/testfut.nim b/tests/testfut.nim index 50ebc65..ad0fa3e 100644 --- a/tests/testfut.nim +++ b/tests/testfut.nim @@ -139,7 +139,79 @@ suite "Future[T] behavior test suite": proc test5(): int = result = waitFor(testFuture4()) - proc testAllVarargs(): int = + proc testAsyncDiscard(): int = + var completedFutures = 0 + + proc client1() {.async.} = + await sleepAsync(100.milliseconds) + inc(completedFutures) + + proc client2() {.async.} = + await sleepAsync(200.milliseconds) + inc(completedFutures) + + proc client3() {.async.} = + await sleepAsync(300.milliseconds) + inc(completedFutures) + + proc client4() {.async.} = + await sleepAsync(400.milliseconds) + inc(completedFutures) + + proc client5() {.async.} = + await sleepAsync(500.milliseconds) + inc(completedFutures) + + proc client1f() {.async.} = + await sleepAsync(100.milliseconds) + inc(completedFutures) + if true: + raise newException(ValueError, "") + + proc client2f() {.async.} = + await sleepAsync(200.milliseconds) + inc(completedFutures) + if true: + raise newException(ValueError, "") + + proc client3f() {.async.} = + await sleepAsync(300.milliseconds) + inc(completedFutures) + if true: + raise newException(ValueError, "") + + proc client4f() {.async.} = + await sleepAsync(400.milliseconds) + inc(completedFutures) + if true: + raise newException(ValueError, "") + + proc client5f() {.async.} = + await sleepAsync(500.milliseconds) + inc(completedFutures) + if true: + raise newException(ValueError, "") + + asyncDiscard client1() + asyncDiscard client1f() + asyncDiscard client2() + asyncDiscard client2f() + asyncDiscard client3() + asyncDiscard client3f() + asyncDiscard client4() + asyncDiscard client4f() + asyncDiscard client5() + asyncDiscard client5f() + + waitFor(sleepAsync(2000.milliseconds)) + result = completedFutures + + proc testAllFuturesZero(): bool = + var tseq = newSeq[Future[int]]() + var fut = allFutures(tseq) + result = fut.finished + + proc testAllFuturesVarargs(): int = var completedFutures = 0 proc vlient1() {.async.} = @@ -247,43 +319,34 @@ suite "Future[T] behavior test suite": if true: raise newException(ValueError, "") - waitFor(all(vlient1(), vlient2(), vlient3(), vlient4(), vlient5())) + waitFor(allFutures(vlient1(), vlient2(), vlient3(), vlient4(), vlient5())) # 5 completed futures = 5 result += completedFutures + completedFutures = 0 - try: - waitFor(all(vlient1(), vlient1f(), - vlient2(), vlient2f(), - vlient3(), vlient3f(), - vlient4(), vlient4f(), - vlient5(), vlient5f())) - result -= 10000 - except: - discard + waitFor(allFutures(vlient1(), vlient1f(), + vlient2(), vlient2f(), + vlient3(), vlient3f(), + vlient4(), vlient4f(), + vlient5(), vlient5f())) # 10 completed futures = 10 result += completedFutures completedFutures = 0 - var res = waitFor(all(client1(), client2(), client3(), client4(), client5())) - for item in res: - result += item - # 5 completed futures + 5 values = 10 + waitFor(allFutures(client1(), client2(), client3(), client4(), client5())) + # 5 completed futures result += completedFutures completedFutures = 0 - try: - var res = waitFor(all(client1(), client1f(), - client2(), client2f(), - client3(), client3f(), - client4(), client4f(), - client5(), client5f())) - result -= 10000 - except: - discard + waitFor(allFutures(client1(), client1f(), + client2(), client2f(), + client3(), client3f(), + client4(), client4f(), + client5(), client5f())) # 10 completed futures = 10 result += completedFutures - proc testAllSeq(): int = + proc testAllFuturesSeq(): int = var completedFutures = 0 var vfutures = newSeq[Future[void]]() var nfutures = newSeq[Future[int]]() @@ -401,7 +464,7 @@ suite "Future[T] behavior test suite": vfutures.add(vlient4()) vfutures.add(vlient5()) - waitFor(all(vfutures)) + waitFor(allFutures(vfutures)) # 5 * 10 completed futures = 50 result += completedFutures @@ -419,11 +482,7 @@ suite "Future[T] behavior test suite": vfutures.add(vlient5()) vfutures.add(vlient5f()) - try: - waitFor(all(vfutures)) - result -= 10000 - except: - discard + waitFor(allFutures(vfutures)) # 10 * 10 completed futures = 100 result += completedFutures @@ -436,10 +495,8 @@ suite "Future[T] behavior test suite": nfutures.add(client4()) nfutures.add(client5()) - var res = waitFor(all(nfutures)) - for i in 0..