diff --git a/codex/slots/proofs/backends/circomcompat.nim b/codex/slots/proofs/backends/circomcompat.nim index b9f8e84c..a38cf346 100644 --- a/codex/slots/proofs/backends/circomcompat.nim +++ b/codex/slots/proofs/backends/circomcompat.nim @@ -9,7 +9,7 @@ {.push raises: [].} -import std/[sugar, atomics] +import std/[sugar, atomics, locks] import pkg/chronos import pkg/taskpools @@ -39,6 +39,7 @@ type backendCfg: ptr CircomBn254Cfg vkp*: ptr CircomKey taskpool: Taskpool + lock: ptr Lock NormalizedProofInputs*[H] {.borrow: `.`.} = distinct ProofInputs[H] @@ -97,24 +98,29 @@ proc release*(self: CircomCompat) = if not isNil(self.vkp): self.vkp.unsafeAddr.release_key() + if not isNil(self.lock): + deinitLock(self.lock[]) # Cleanup the lock + dealloc(self.lock) # Free the memory + proc circomProveTask(task: ptr ProveTask) {.gcsafe.} = - defer: - discard task[].signal.fireSync() + withLock task[].circom.lock[]: + defer: + discard task[].signal.fireSync() - var proofPtr: ptr Proof = nil - try: - if ( - let res = task.circom.backendCfg.prove_circuit(task.ctx, proofPtr.addr) - res != ERR_OK - ) or proofPtr == nil: - task.success.store(false) - return + var proofPtr: ptr Proof = nil + try: + if ( + let res = task.circom.backendCfg.prove_circuit(task.ctx, proofPtr.addr) + res != ERR_OK + ) or proofPtr == nil: + task.success.store(false) + return - copyProof(task.proof, proofPtr[]) - task.success.store(true) - finally: - if proofPtr != nil: - proofPtr.addr.release_proof() + copyProof(task.proof, proofPtr[]) + task.success.store(true) + finally: + if proofPtr != nil: + proofPtr.addr.release_proof() proc asyncProve*[H]( self: CircomCompat, input: NormalizedProofInputs[H], proof: ptr Proof @@ -328,9 +334,11 @@ proc init*( numSamples = DefaultSamplesNum, taskpool: Taskpool, ): CircomCompat = - ## Create a new ctx - ## + # Allocate and initialize the lock + var lockPtr = create(Lock) # Allocate memory for the lock + initLock(lockPtr[]) # Initialize the lock + ## Create a new ctx var cfg: ptr CircomBn254Cfg var zkey = if zkeyPath.len > 0: zkeyPath.cstring else: nil @@ -359,4 +367,5 @@ proc init*( backendCfg: cfg, vkp: vkpPtr, taskpool: taskpool, + lock: lockPtr, )