From 1e6133a804dbfcb6102987c06e22f64a5e3ae4da Mon Sep 17 00:00:00 2001 From: Jaremy Creechley Date: Mon, 3 Jun 2024 18:57:14 +0100 Subject: [PATCH] thread local dupes --- codex/slots/proofs/backends/asynccircoms.nim | 30 +++++++++++++------- codex/slots/proofs/backends/circomcompat.nim | 10 ++++++- tests/codex/slots/testprover.nim | 2 +- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/codex/slots/proofs/backends/asynccircoms.nim b/codex/slots/proofs/backends/asynccircoms.nim index b63fad5a..38b70c4b 100644 --- a/codex/slots/proofs/backends/asynccircoms.nim +++ b/codex/slots/proofs/backends/asynccircoms.nim @@ -28,15 +28,20 @@ type proof: CircomProof inputs: ProofInputs[H] -proc proveTask[H](args: ptr ProverArgs[H], results: SignalQueuePtr[?!CircomProof]) = - let circom = args[].circom - let data = args[].data +var + localCircom {.threadvar.}: Option[CircomCompat] - let proof = circom.prove(data) +proc proveTask[H](args: ptr ProverArgs[H], results: SignalQueuePtr[?!CircomProof]) = + + if localCircom.isNone: + localCircom = some args.circom.duplicate() + + var data = args[].data + let proof = localCircom.get().prove(data) echo "PROVE TASK: proof: ", proof - let verified = circom.verify(proof.get(), data) + let verified = localCircom.get().verify(proof.get(), data) echo "PROVE TASK: verify: ", verified if (let sent = results.send(proof); sent.isErr()): @@ -50,7 +55,7 @@ proc prove*[H]( without queue =? newSignalQueue[?!CircomProof](maxItems = 1), qerr: return failure(qerr) - var args = (ref ProverArgs[H])(circom: self.circom.duplicate(), data: input) + var args = (ref ProverArgs[H])(circom: self.circom, data: input) GC_ref(args) proc spawnTask() = @@ -71,10 +76,13 @@ proc prove*[H]( success(proof) proc verifyTask[H](args: ptr VerifierArgs[H], results: SignalQueuePtr[?!bool]) = - let circom = args.circom - let proof = args.proof - let inputs = args.inputs - let verified = circom.verify(proof, inputs) + + if localCircom.isNone: + localCircom = some args.circom.duplicate() + + var proof = args[].proof + var inputs = args[].inputs + let verified = localCircom.get().verify(proof, inputs) if (let sent = results.send(verified); sent.isErr()): error "Error sending verification results", msg = sent.error().msg @@ -87,7 +95,7 @@ proc verify*[H]( without queue =? newSignalQueue[?!bool](maxItems = 1), qerr: return failure(qerr) - var args = (ref VerifierArgs[H])(circom: self.circom.duplicate(), proof: proof, inputs: inputs) + var args = (ref VerifierArgs[H])(circom: self.circom, proof: proof, inputs: inputs) GC_ref(args) proc spawnTask() = diff --git a/codex/slots/proofs/backends/circomcompat.nim b/codex/slots/proofs/backends/circomcompat.nim index 517e829d..f0efc01f 100644 --- a/codex/slots/proofs/backends/circomcompat.nim +++ b/codex/slots/proofs/backends/circomcompat.nim @@ -245,4 +245,12 @@ proc duplicate*( if cfg != nil: cfg.addr.releaseCfg() raiseAssert("failed to initialize circom compat config") - CircomCompat(params: self.params, backendCfg: cfg, vkp: self.vkp) + + var + vkpPtr: ptr VerifyingKey = nil + + if cfg.getVerifyingKey(vkpPtr.addr) != ERR_OK or vkpPtr == nil: + if vkpPtr != nil: vkpPtr.addr.releaseKey() + raiseAssert("Failed to get verifying key") + + CircomCompat(params: self.params, backendCfg: cfg, vkp: vkpPtr) diff --git a/tests/codex/slots/testprover.nim b/tests/codex/slots/testprover.nim index 932d84e8..139a3863 100644 --- a/tests/codex/slots/testprover.nim +++ b/tests/codex/slots/testprover.nim @@ -101,7 +101,7 @@ suite "Test Prover": r1cs = "tests/circuits/fixtures/proof_main.r1cs" wasm = "tests/circuits/fixtures/proof_main.wasm" - taskpool = Taskpool.new(num_threads = 6) + taskpool = Taskpool.new(num_threads = 5) params = CircomCompatParams.init(r1cs, wasm) circomBackend = AsyncCircomCompat.init(params, taskpool) prover = Prover.new(store, circomBackend, samples)