diff --git a/codex/slots/proofs/backends/asynccircoms.nim b/codex/slots/proofs/backends/asynccircoms.nim index fc54b7b0..613a250e 100644 --- a/codex/slots/proofs/backends/asynccircoms.nim +++ b/codex/slots/proofs/backends/asynccircoms.nim @@ -11,38 +11,36 @@ import ../../../utils/asyncthreads import ./circomcompat -type - AsyncCircomCompat* = object - params*: CircomCompatParams - tp*: Taskpool - - # Args objects are missing seq[seq[byte]] field, to avoid unnecessary data copy - ProveTaskArgs* = object - signal: ThreadSignalPtr - params: CircomCompatParams +type AsyncCircomCompat* = object + params*: CircomCompatParams + tp*: Taskpool var circomBackend {.threadvar.}: Option[CircomCompat] proc proveTask[H]( - args: ProveTaskArgs, data: ProofInputs[H] -): Result[CircomProof, string] = - + params: CircomCompatParams, + data: ProofInputs[H], + results: SignalQueuePtr[Result[CircomProof, string]], +) = try: if circomBackend.isNone: - circomBackend = some CircomCompat.init(args.params) + circomBackend = some CircomCompat.init(params) else: - assert circomBackend.get().params == args.params + assert circomBackend.get().params == params - let res = circomBackend.get().prove(data) - if res.isOk: - return ok(res.get()) + let proof = circomBackend.get().prove(data) + var val: Result[CircomProof, string] + if proof.isOk(): + val.ok(proof.get()) else: - return err(res.error().msg) + val.err(proof.error().msg) + + if (let sent = results.send(val); sent.isErr()): + error "Error sending proof results", msg = sent.error().msg except CatchableError as exception: - return err(exception.msg) - finally: - if err =? args.signal.fireSync().mapFailure.errorOption(): - error "Error firing signal in proveTask ", msg = err.msg + var err = Result[CircomProof, string].err(exception.msg) + if (let res = results.send(err); res.isErr()): + error "Error sending proof results", msg = res.error().msg proc prove*[H]( self: AsyncCircomCompat, input: ProofInputs[H] @@ -50,27 +48,24 @@ proc prove*[H]( ## Generates proof using circom-compat asynchronously ## - without signal =? ThreadSignalPtr.new().mapFailure, err: + without queue =? newSignalQueue[Result[CircomProof, string]](), err: return failure(err) - defer: - if err =? signal.close().mapFailure.errorOption(): - error "Error closing signal", msg = $err.msg - let args = ProveTaskArgs(signal: signal, params: self.params) - proc spawnTask(): Flowvar[Result[CircomProof, string]] = - self.tp.spawn proveTask(args, input) - let flowvar = spawnTask() + proc spawnTask() = + self.tp.spawn proveTask(self.params, input, queue) - without taskRes =? await awaitThreadResult(signal, flowvar), err: + spawnTask() + + without taskRes =? await queue.recvAsync(), err: return failure(err) + if (let res = queue.release(); res.isErr): + return failure "Error releasing proof queue " & res.error().msg + without proof =? taskRes.mapFailure, err: - let res: ?!CircomProof = failure(err) - return res - - let pf: CircomProof = proof - success(pf) + return failure(err) + success(proof) proc verify*[H]( self: AsyncCircomCompat, proof: CircomProof, inputs: ProofInputs[H] diff --git a/codex/utils/asyncthreads.nim b/codex/utils/asyncthreads.nim index 341f57f1..67fc2277 100644 --- a/codex/utils/asyncthreads.nim +++ b/codex/utils/asyncthreads.nim @@ -1,4 +1,4 @@ - +import std/options import pkg/taskpools import pkg/taskpools/flowvars import pkg/chronos @@ -24,3 +24,58 @@ proc awaitThreadResult*[T](signal: ThreadSignalPtr, handle: Flowvar[T]): Future[ await sleepAsync(CompletionRetryDelay) return failure("Task signaled finish but didn't return any result within " & $CompletionRetryDelay) + +type + SignalQueue[T] = object + signal: ThreadSignalPtr + chan*: Channel[T] + + SignalQueuePtr*[T] = ptr SignalQueue[T] + +proc release*[T](queue: SignalQueuePtr[T]): ?!void = + ## Call to properly dispose of a SignalQueue. + queue[].chan.close() + if err =? queue[].signal.close().mapFailure.errorOption(): + return failure(err.msg) + deallocShared(queue) + +proc newSignalQueue*[T]( + maxItems: int = 0 +): ?!SignalQueuePtr[T] = + ## Create a signal queue compatible with Chronos async. + result = success cast[ptr SignalQueue[T]](allocShared0(sizeof(SignalQueue[T]))) + without signal =? ThreadSignalPtr.new().mapFailure, err: + return failure(err) + result[].signal = signal + result[].chan.open(maxItems) + +proc send*[T](queue: SignalQueuePtr[T], msg: T): ?!void {.raises: [].} = + ## Sends a message to a thread. `msg` is copied. + ## Note: may be blocking. + ## + try: + queue[].chan.send(msg) + except Exception as exc: + return failure(exc.msg) + + let res = queue[].signal.fireSync() + if res.isErr(): + return failure(res.error()) + result = ok() + +proc recv*[T](queue: SignalQueue[T]): ?!T = + ## Receive item from queue, blocking. + try: + ok(queue.chan[].recv()) + except Exception as exc: + failure(exc.msg) + +proc recvAsync*[T](queue: SignalQueuePtr[T]): Future[?!T] {.async.} = + ## Async compatible receive from queue. Pauses async execution until + ## an item is received from the queue + await wait(queue.signal) + let res = queue.chan.tryRecv() + if not res.dataAvailable: + return failure("unable to retrieve expected queue value") + else: + return success(res.msg)