diff --git a/codex/codex.nim b/codex/codex.nim index 63f56c6b..5ac23af0 100644 --- a/codex/codex.nim +++ b/codex/codex.nim @@ -347,8 +347,9 @@ proc new*( store = NetworkStore.new(engine, repoStore) prover = if config.prover: - let backend = - config.initializeBackend().expect("Unable to create prover backend.") + let backend = config.initializeBackend(taskpool = taskpool).expect( + "Unable to create prover backend." + ) some Prover.new(store, backend, config.numProofSamples) else: none Prover diff --git a/codex/slots/proofs/backendfactory.nim b/codex/slots/proofs/backendfactory.nim index 7aba27d8..34706191 100644 --- a/codex/slots/proofs/backendfactory.nim +++ b/codex/slots/proofs/backendfactory.nim @@ -2,6 +2,7 @@ import os import strutils import pkg/chronos import pkg/chronicles +import pkg/taskpools import pkg/questionable import pkg/confutils/defs import pkg/stew/io2 @@ -11,7 +12,9 @@ import ../../conf import ./backends import ./backendutils -proc initializeFromConfig(config: CodexConf, utils: BackendUtils): ?!AnyBackend = +proc initializeFromConfig( + config: CodexConf, utils: BackendUtils, taskpool: Taskpool +): ?!AnyBackend = if not fileAccessible($config.circomR1cs, {AccessFlags.Read}) or not endsWith($config.circomR1cs, ".r1cs"): return failure("Circom R1CS file not accessible") @@ -27,7 +30,7 @@ proc initializeFromConfig(config: CodexConf, utils: BackendUtils): ?!AnyBackend trace "Initialized prover backend from cli config" success( utils.initializeCircomBackend( - $config.circomR1cs, $config.circomWasm, $config.circomZkey + $config.circomR1cs, $config.circomWasm, $config.circomZkey, taskpool ) ) @@ -41,14 +44,14 @@ proc zkeyFilePath(config: CodexConf): string = config.circuitDir / "proof_main.zkey" proc initializeFromCircuitDirFiles( - config: CodexConf, utils: BackendUtils + config: CodexConf, utils: BackendUtils, taskpool: Taskpool ): ?!AnyBackend {.gcsafe.} = if fileExists(config.r1csFilePath) and fileExists(config.wasmFilePath) and fileExists(config.zkeyFilePath): trace "Initialized prover backend from local files" return success( utils.initializeCircomBackend( - config.r1csFilePath, config.wasmFilePath, config.zkeyFilePath + config.r1csFilePath, config.wasmFilePath, config.zkeyFilePath, taskpool ) ) @@ -68,11 +71,11 @@ proc suggestDownloadTool(config: CodexConf) = instructions proc initializeBackend*( - config: CodexConf, utils: BackendUtils = BackendUtils() + config: CodexConf, utils: BackendUtils = BackendUtils(), taskpool: Taskpool ): ?!AnyBackend = - without backend =? initializeFromConfig(config, utils), cliErr: + without backend =? initializeFromConfig(config, utils, taskpool), cliErr: info "Could not initialize prover backend from CLI options...", msg = cliErr.msg - without backend =? initializeFromCircuitDirFiles(config, utils), localErr: + without backend =? initializeFromCircuitDirFiles(config, utils, taskpool), localErr: info "Could not initialize prover backend from circuit dir files...", msg = localErr.msg suggestDownloadTool(config) diff --git a/codex/slots/proofs/backends/circomcompat.nim b/codex/slots/proofs/backends/circomcompat.nim index 1d2e3e19..6c9583fe 100644 --- a/codex/slots/proofs/backends/circomcompat.nim +++ b/codex/slots/proofs/backends/circomcompat.nim @@ -9,9 +9,11 @@ {.push raises: [].} -import std/sugar +import std/[sugar, atomics, locks] import pkg/chronos +import pkg/taskpools +import pkg/chronos/threadsync import pkg/questionable/results import pkg/circomcompat @@ -22,6 +24,7 @@ import ../../../contracts import ./converters export circomcompat, converters +export taskpools type CircomCompat* = object @@ -35,9 +38,25 @@ type zkeyPath: string # path to the zkey file backendCfg: ptr CircomBn254Cfg vkp*: ptr CircomKey + taskpool: Taskpool + lock: ptr Lock NormalizedProofInputs*[H] {.borrow: `.`.} = distinct ProofInputs[H] + ProveTask = object + circom: ptr CircomCompat + ctx: ptr CircomCompatCtx + proof: ptr Proof + success: Atomic[bool] + signal: ThreadSignalPtr + + VerifyTask = object + proof: ptr CircomProof + vkp: ptr CircomKey + inputs: ptr CircomInputs + success: VerifyResult + signal: ThreadSignalPtr + func normalizeInput*[H]( self: CircomCompat, input: ProofInputs[H] ): NormalizedProofInputs[H] = @@ -79,7 +98,33 @@ proc release*(self: CircomCompat) = if not isNil(self.vkp): self.vkp.unsafeAddr.release_key() -proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProof = + if not isNil(self.lock): + deinitLock(self.lock[]) # Cleanup the lock + dealloc(self.lock) # Free the memory + +proc circomProveTask(task: ptr ProveTask) {.gcsafe.} = + withLock task[].circom.lock[]: + defer: + discard task[].signal.fireSync() + + var proofPtr: ptr Proof = nil + try: + if ( + let res = task.circom.backendCfg.prove_circuit(task.ctx, proofPtr.addr) + res != ERR_OK + ) or proofPtr == nil: + task.success.store(false) + return + + copyProof(task.proof, proofPtr[]) + task.success.store(true) + finally: + if proofPtr != nil: + proofPtr.addr.release_proof() + +proc asyncProve*[H]( + self: CircomCompat, input: NormalizedProofInputs[H], proof: ptr Proof +): Future[?!void] {.async.} = doAssert input.samples.len == self.numSamples, "Number of samples does not match" doAssert input.slotProof.len <= self.datasetDepth, @@ -143,13 +188,11 @@ proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProo for s in input.samples: var - merklePaths = s.merklePaths.mapIt(it.toBytes) + merklePaths = s.merklePaths.mapIt(@(it.toBytes)).concat data = s.cellData.mapIt(@(it.toBytes)).concat if ctx.push_input_u256_array( - "merklePaths".cstring, - merklePaths[0].addr, - uint (merklePaths[0].len * merklePaths.len), + "merklePaths".cstring, merklePaths[0].addr, uint (merklePaths.len) ) != ERR_OK: return failure("Failed to push merkle paths") @@ -157,44 +200,118 @@ proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProo ERR_OK: return failure("Failed to push cell data") - var proofPtr: ptr Proof = nil + without threadPtr =? ThreadSignalPtr.new(): + return failure("Unable to create thread signal") - let proof = - try: - if (let res = self.backendCfg.prove_circuit(ctx, proofPtr.addr); res != ERR_OK) or - proofPtr == nil: - return failure("Failed to prove - err code: " & $res) + defer: + threadPtr.close().expect("closing once works") - proofPtr[] - finally: - if proofPtr != nil: - proofPtr.addr.release_proof() + var task = ProveTask(circom: addr self, ctx: ctx, proof: proof, signal: threadPtr) - success proof + let taskPtr = addr task -proc prove*[H](self: CircomCompat, input: ProofInputs[H]): ?!CircomProof = - self.prove(self.normalizeInput(input)) + doAssert task.circom.taskpool.numThreads > 1, + "Must have at least one separate thread or signal will never be fired" + task.circom.taskpool.spawn circomProveTask(taskPtr) + let threadFut = threadPtr.wait() + + if joinErr =? catch(await threadFut.join()).errorOption: + if err =? catch(await noCancel threadFut).errorOption: + return failure(err) + if joinErr of CancelledError: + raise joinErr + else: + return failure(joinErr) + + if not task.success.load(): + return failure("Failed to prove circuit") + + success() + +proc prove*[H]( + self: CircomCompat, input: ProofInputs[H] +): Future[?!CircomProof] {.async, raises: [CancelledError].} = + var proof = ProofPtr.new() + defer: + destroyProof(proof) + + try: + if error =? (await self.asyncProve(self.normalizeInput(input), proof)).errorOption: + return failure(error) + return success(deepCopy(proof)[]) + except CancelledError as exc: + raise exc + +proc circomVerifyTask(task: ptr VerifyTask) {.gcsafe.} = + defer: + task[].inputs[].releaseCircomInputs() + discard task[].signal.fireSync() + + let res = verify_circuit(task[].proof, task[].inputs, task[].vkp) + if res == ERR_OK: + task[].success[] = true + elif res == ERR_FAILED_TO_VERIFY_PROOF: + task[].success[] = false + else: + task[].success[] = false + error "Failed to verify proof", errorCode = res + +proc asyncVerify*[H]( + self: CircomCompat, + proof: CircomProof, + inputs: ProofInputs[H], + success: VerifyResult, +): Future[?!void] {.async.} = + var proofPtr = unsafeAddr proof + var inputs = inputs.toCircomInputs() + + without threadPtr =? ThreadSignalPtr.new(): + return failure("Unable to create thread signal") + + defer: + threadPtr.close().expect("closing once works") + + var task = VerifyTask( + proof: proofPtr, + vkp: self.vkp, + inputs: addr inputs, + success: success, + signal: threadPtr, + ) + + let taskPtr = addr task + + doAssert self.taskpool.numThreads > 1, + "Must have at least one separate thread or signal will never be fired" + + self.taskpool.spawn circomVerifyTask(taskPtr) + + let threadFut = threadPtr.wait() + + if joinErr =? catch(await threadFut.join()).errorOption: + if err =? catch(await noCancel threadFut).errorOption: + return failure(err) + if joinErr of CancelledError: + raise joinErr + else: + return failure(joinErr) + + success() proc verify*[H]( self: CircomCompat, proof: CircomProof, inputs: ProofInputs[H] -): ?!bool = +): Future[?!bool] {.async, raises: [CancelledError].} = ## Verify a proof using a ctx ## - - var - proofPtr = unsafeAddr proof - inputs = inputs.toCircomInputs() - + var res = VerifyResult.new() + defer: + destroyVerifyResult(res) try: - let res = verify_circuit(proofPtr, inputs.addr, self.vkp) - if res == ERR_OK: - success true - elif res == ERR_FAILED_TO_VERIFY_PROOF: - success false - else: - failure("Failed to verify proof - err code: " & $res) - finally: - inputs.releaseCircomInputs() + if error =? (await self.asyncVerify(proof, inputs, res)).errorOption: + return failure(error) + return success(res[]) + except CancelledError as exc: + raise exc proc init*( _: type CircomCompat, @@ -206,10 +323,13 @@ proc init*( blkDepth = DefaultBlockDepth, cellElms = DefaultCellElms, numSamples = DefaultSamplesNum, + taskpool: Taskpool, ): CircomCompat = - ## Create a new ctx - ## + # Allocate and initialize the lock + var lockPtr = create(Lock) # Allocate memory for the lock + initLock(lockPtr[]) # Initialize the lock + ## Create a new ctx var cfg: ptr CircomBn254Cfg var zkey = if zkeyPath.len > 0: zkeyPath.cstring else: nil @@ -237,4 +357,6 @@ proc init*( numSamples: numSamples, backendCfg: cfg, vkp: vkpPtr, + taskpool: taskpool, + lock: lockPtr, ) diff --git a/codex/slots/proofs/backends/converters.nim b/codex/slots/proofs/backends/converters.nim index ee771477..0e74b000 100644 --- a/codex/slots/proofs/backends/converters.nim +++ b/codex/slots/proofs/backends/converters.nim @@ -10,6 +10,7 @@ {.push raises: [].} import pkg/circomcompat +import std/atomics import ../../../contracts import ../../types @@ -22,6 +23,14 @@ type CircomProof* = Proof CircomKey* = VerifyingKey CircomInputs* = Inputs + VerifyResult* = ptr bool + ProofPtr* = ptr Proof + +proc new*(_: type ProofPtr): ProofPtr = + cast[ptr Proof](allocShared0(sizeof(Proof))) + +proc new*(_: type VerifyResult): VerifyResult = + cast[ptr bool](allocShared0(sizeof(bool))) proc toCircomInputs*(inputs: ProofInputs[Poseidon2Hash]): CircomInputs = var @@ -52,3 +61,26 @@ func toG2*(g: CircomG2): G2Point = func toGroth16Proof*(proof: CircomProof): Groth16Proof = Groth16Proof(a: proof.a.toG1, b: proof.b.toG2, c: proof.c.toG1) + +proc destroyVerifyResult*(result: VerifyResult) = + if result != nil: + deallocShared(result) + +proc destroyProof*(proof: ProofPtr) = + if proof != nil: + deallocShared(proof) + +proc copyInto*(dest: var G1, src: G1) = + copyMem(addr dest.x[0], addr src.x[0], 32) + copyMem(addr dest.y[0], addr src.y[0], 32) + +proc copyInto*(dest: var G2, src: G2) = + for i in 0 .. 1: + copyMem(addr dest.x[i][0], addr src.x[i][0], 32) + copyMem(addr dest.y[i][0], addr src.y[i][0], 32) + +proc copyProof*(dest: ptr Proof, src: Proof) = + if not isNil(dest): + copyInto(dest.a, src.a) + copyInto(dest.b, src.b) + copyInto(dest.c, src.c) diff --git a/codex/slots/proofs/backendutils.nim b/codex/slots/proofs/backendutils.nim index 0e334ace..b6995896 100644 --- a/codex/slots/proofs/backendutils.nim +++ b/codex/slots/proofs/backendutils.nim @@ -1,8 +1,13 @@ import ./backends +import pkg/taskpools type BackendUtils* = ref object of RootObj method initializeCircomBackend*( - self: BackendUtils, r1csFile: string, wasmFile: string, zKeyFile: string + self: BackendUtils, + r1csFile: string, + wasmFile: string, + zKeyFile: string, + taskpool: Taskpool, ): AnyBackend {.base, gcsafe.} = - CircomCompat.init(r1csFile, wasmFile, zKeyFile) + CircomCompat.init(r1csFile, wasmFile, zKeyFile, taskpool = taskpool) diff --git a/codex/slots/proofs/prover.nim b/codex/slots/proofs/prover.nim index 1afcd068..7ce5b08a 100644 --- a/codex/slots/proofs/prover.nim +++ b/codex/slots/proofs/prover.nim @@ -74,7 +74,7 @@ proc prove*( return failure(err) # prove slot - without proof =? self.backend.prove(proofInput), err: + without proof =? await self.backend.prove(proofInput), err: error "Unable to prove slot", err = err.msg return failure(err) @@ -85,7 +85,11 @@ proc verify*( ): Future[?!bool] {.async: (raises: [CancelledError]).} = ## Prove a statement using backend. ## Returns a future that resolves to a proof. - self.backend.verify(proof, inputs) + without res =? (await self.backend.verify(proof, inputs)), err: + error "Unable to verify proof", err = err.msg + return failure(err) + + return success(res) proc new*( _: type Prover, store: BlockStore, backend: AnyBackend, nSamples: int diff --git a/tests/codex/slots/backends/testcircomcompat.nim b/tests/codex/slots/backends/testcircomcompat.nim index b61d4f18..a9e1fdf3 100644 --- a/tests/codex/slots/backends/testcircomcompat.nim +++ b/tests/codex/slots/backends/testcircomcompat.nim @@ -26,29 +26,36 @@ suite "Test Circom Compat Backend - control inputs": var circom: CircomCompat proofInputs: ProofInputs[Poseidon2Hash] + taskpool: Taskpool setup: let inputData = readFile("tests/circuits/fixtures/input.json") inputJson = !JsonNode.parse(inputData) + taskpool = Taskpool.new() proofInputs = Poseidon2Hash.jsonToProofInput(inputJson) - circom = CircomCompat.init(r1cs, wasm, zkey) + circom = CircomCompat.init(r1cs, wasm, zkey, taskpool = taskpool) teardown: circom.release() # this comes from the rust FFI + taskpool.shutdown() test "Should verify with correct inputs": - let proof = circom.prove(proofInputs).tryGet + let proof = (await circom.prove(proofInputs)).tryGet - check circom.verify(proof, proofInputs).tryGet + var resp = (await circom.verify(proof, proofInputs)).tryGet + + check resp test "Should not verify with incorrect inputs": proofInputs.slotIndex = 1 # change slot index - let proof = circom.prove(proofInputs).tryGet + let proof = (await circom.prove(proofInputs)).tryGet - check circom.verify(proof, proofInputs).tryGet == false + var resp = (await circom.verify(proof, proofInputs)).tryGet + + check resp == false suite "Test Circom Compat Backend": let @@ -77,6 +84,7 @@ suite "Test Circom Compat Backend": challenge: array[32, byte] builder: Poseidon2Builder sampler: Poseidon2Sampler + taskpool: Taskpool setup: let @@ -91,8 +99,9 @@ suite "Test Circom Compat Backend": builder = Poseidon2Builder.new(store, verifiable).tryGet sampler = Poseidon2Sampler.new(slotId, store, builder).tryGet + taskpool = Taskpool.new() - circom = CircomCompat.init(r1cs, wasm, zkey) + circom = CircomCompat.init(r1cs, wasm, zkey, taskpool = taskpool) challenge = 1234567.toF.toBytes.toArray32 proofInputs = (await sampler.getProofInput(challenge, samples)).tryGet @@ -101,15 +110,20 @@ suite "Test Circom Compat Backend": circom.release() # this comes from the rust FFI await repoTmp.destroyDb() await metaTmp.destroyDb() + taskpool.shutdown() test "Should verify with correct input": - var proof = circom.prove(proofInputs).tryGet + var proof = (await circom.prove(proofInputs)).tryGet - check circom.verify(proof, proofInputs).tryGet + var resp = (await circom.verify(proof, proofInputs)).tryGet + + check resp == true test "Should not verify with incorrect input": proofInputs.slotIndex = 1 # change slot index - let proof = circom.prove(proofInputs).tryGet + let proof = (await circom.prove(proofInputs)).tryGet - check circom.verify(proof, proofInputs).tryGet == false + var resp = (await circom.verify(proof, proofInputs)).tryGet + + check resp == false diff --git a/tests/codex/slots/testbackendfactory.nim b/tests/codex/slots/testbackendfactory.nim index a24bc41a..fa17b5f3 100644 --- a/tests/codex/slots/testbackendfactory.nim +++ b/tests/codex/slots/testbackendfactory.nim @@ -4,6 +4,7 @@ import ../../asynctest import pkg/chronos import pkg/confutils/defs import pkg/codex/conf +import pkg/taskpools import pkg/codex/slots/proofs/backends import pkg/codex/slots/proofs/backendfactory import pkg/codex/slots/proofs/backendutils @@ -18,7 +19,11 @@ type BackendUtilsMock = ref object of BackendUtils argZKeyFile: string method initializeCircomBackend*( - self: BackendUtilsMock, r1csFile: string, wasmFile: string, zKeyFile: string + self: BackendUtilsMock, + r1csFile: string, + wasmFile: string, + zKeyFile: string, + taskpool: Taskpool, ): AnyBackend = self.argR1csFile = r1csFile self.argWasmFile = wasmFile @@ -52,7 +57,7 @@ suite "Test BackendFactory": circomWasm: InputFile("tests/circuits/fixtures/proof_main.wasm"), circomZkey: InputFile("tests/circuits/fixtures/proof_main.zkey"), ) - backend = config.initializeBackend(utilsMock).tryGet + backend = config.initializeBackend(utilsMock, taskpool = nil).tryGet check: backend.vkp != nil @@ -73,7 +78,7 @@ suite "Test BackendFactory": # will be picked up as local files: circuitDir: OutDir("tests/circuits/fixtures"), ) - backend = config.initializeBackend(utilsMock).tryGet + backend = config.initializeBackend(utilsMock, taskpool = nil).tryGet check: backend.vkp != nil @@ -91,7 +96,7 @@ suite "Test BackendFactory": marketplaceAddress: EthAddress.example.some, circuitDir: OutDir(circuitDir), ) - backendResult = config.initializeBackend(utilsMock) + backendResult = config.initializeBackend(utilsMock, taskpool = nil) check: backendResult.isErr diff --git a/tests/codex/slots/testprover.nim b/tests/codex/slots/testprover.nim index c567db55..7986cb97 100644 --- a/tests/codex/slots/testprover.nim +++ b/tests/codex/slots/testprover.nim @@ -29,6 +29,8 @@ suite "Test Prover": var store: BlockStore prover: Prover + backend: AnyBackend + taskpool: Taskpool setup: let @@ -44,7 +46,8 @@ suite "Test Prover": circomZkey: InputFile("tests/circuits/fixtures/proof_main.zkey"), numProofSamples: samples, ) - backend = config.initializeBackend().tryGet() + taskpool = Taskpool.new() + backend = config.initializeBackend(taskpool = taskpool).tryGet() store = RepoStore.new(repoDs, metaDs) prover = Prover.new(store, backend, config.numProofSamples) @@ -52,6 +55,7 @@ suite "Test Prover": teardown: await repoTmp.destroyDb() await metaTmp.destroyDb() + taskpool.shutdown() test "Should sample and prove a slot": let (_, _, verifiable) = await createVerifiableManifest( @@ -86,3 +90,79 @@ suite "Test Prover": check: (await prover.verify(proof, inputs)).tryGet == true + + test "Should concurrently prove/verify": + const iterations = 5 + + var proveTasks = newSeq[Future[?!(AnyProofInputs, AnyProof)]]() + var verifyTasks = newSeq[Future[?!bool]]() + + for i in 0 ..< iterations: + # create multiple prove tasks + let (_, _, verifiable) = await createVerifiableManifest( + store, + 8, # number of blocks in the original dataset (before EC) + 5, # ecK + 3, # ecM + blockSize, + cellSize, + ) + + proveTasks.add(prover.prove(1, verifiable, challenge)) + + let proveResults = await allFinished(proveTasks) + # + for i in 0 ..< proveResults.len: + var (inputs, proofs) = proveTasks[i].read().tryGet() + verifyTasks.add(prover.verify(proofs, inputs)) + + let verifyResults = await allFinished(verifyTasks) + + for i in 0 ..< verifyResults.len: + check: + verifyResults[i].read().tryGet() == true + + test "Should complete prove/verify task when cancelled": + let (_, _, verifiable) = await createVerifiableManifest( + store, + 8, # number of blocks in the original dataset (before EC) + 5, # ecK + 3, # ecM + blockSize, + cellSize, + ) + + let (inputs, proof) = (await prover.prove(1, verifiable, challenge)).tryGet + + var cancelledProof = ProofPtr.new() + defer: + destroyProof(cancelledProof) + + # call asyncProve and cancel the task + let proveFut = backend.asyncProve(backend.normalizeInput(inputs), cancelledProof) + proveFut.cancel() + + try: + discard await proveFut + except CatchableError as exc: + check exc of CancelledError + finally: + # validate the cancelledProof + check: + (await prover.verify(cancelledProof[], inputs)).tryGet == true + + var verifyRes = VerifyResult.new() + defer: + destroyVerifyResult(verifyRes) + + # call asyncVerify and cancel the task + let verifyFut = backend.asyncVerify(proof, inputs, verifyRes) + verifyFut.cancel() + + try: + discard await verifyFut + except CatchableError as exc: + check exc of CancelledError + finally: + # validate the verifyResponse + check verifyRes[] == true diff --git a/vendor/nim-circom-compat b/vendor/nim-circom-compat index d3fb9039..88197fe6 160000 --- a/vendor/nim-circom-compat +++ b/vendor/nim-circom-compat @@ -1 +1 @@ -Subproject commit d3fb903945c3895f28a2e50685745e0a9762ece5 +Subproject commit 88197fe6ec929559b37e1443785ece650d2e9255