multi threading support for circom prover

This commit is contained in:
munna0908 2025-02-20 17:48:48 +05:30
parent 16dce0fc43
commit 65b8b46cb0
No known key found for this signature in database
GPG Key ID: 2FFCD637E937D3E6
6 changed files with 214 additions and 46 deletions

View File

@ -289,8 +289,9 @@ proc new*(
store = NetworkStore.new(engine, repoStore)
prover =
if config.prover:
let backend =
config.initializeBackend().expect("Unable to create prover backend.")
let backend = config.initializeBackend(taskpool = taskpool).expect(
"Unable to create prover backend."
)
some Prover.new(store, backend, config.numProofSamples)
else:
none Prover

View File

@ -2,6 +2,7 @@ import os
import strutils
import pkg/chronos
import pkg/chronicles
import pkg/taskpools
import pkg/questionable
import pkg/confutils/defs
import pkg/stew/io2
@ -11,7 +12,9 @@ import ../../conf
import ./backends
import ./backendutils
proc initializeFromConfig(config: CodexConf, utils: BackendUtils): ?!AnyBackend =
proc initializeFromConfig(
config: CodexConf, utils: BackendUtils, taskpool: Taskpool
): ?!AnyBackend =
if not fileAccessible($config.circomR1cs, {AccessFlags.Read}) or
not endsWith($config.circomR1cs, ".r1cs"):
return failure("Circom R1CS file not accessible")
@ -27,7 +30,7 @@ proc initializeFromConfig(config: CodexConf, utils: BackendUtils): ?!AnyBackend
trace "Initialized prover backend from cli config"
success(
utils.initializeCircomBackend(
$config.circomR1cs, $config.circomWasm, $config.circomZkey
$config.circomR1cs, $config.circomWasm, $config.circomZkey, taskpool
)
)
@ -41,14 +44,14 @@ proc zkeyFilePath(config: CodexConf): string =
config.circuitDir / "proof_main.zkey"
proc initializeFromCircuitDirFiles(
config: CodexConf, utils: BackendUtils
config: CodexConf, utils: BackendUtils, taskpool: Taskpool
): ?!AnyBackend {.gcsafe.} =
if fileExists(config.r1csFilePath) and fileExists(config.wasmFilePath) and
fileExists(config.zkeyFilePath):
trace "Initialized prover backend from local files"
return success(
utils.initializeCircomBackend(
config.r1csFilePath, config.wasmFilePath, config.zkeyFilePath
config.r1csFilePath, config.wasmFilePath, config.zkeyFilePath, taskpool
)
)
@ -68,11 +71,11 @@ proc suggestDownloadTool(config: CodexConf) =
instructions
proc initializeBackend*(
config: CodexConf, utils: BackendUtils = BackendUtils()
config: CodexConf, utils: BackendUtils = BackendUtils(), taskpool: Taskpool
): ?!AnyBackend =
without backend =? initializeFromConfig(config, utils), cliErr:
without backend =? initializeFromConfig(config, utils, taskpool), cliErr:
info "Could not initialize prover backend from CLI options...", msg = cliErr.msg
without backend =? initializeFromCircuitDirFiles(config, utils), localErr:
without backend =? initializeFromCircuitDirFiles(config, utils, taskpool), localErr:
info "Could not initialize prover backend from circuit dir files...",
msg = localErr.msg
suggestDownloadTool(config)

View File

@ -9,9 +9,11 @@
{.push raises: [].}
import std/sugar
import std/[sugar, atomics]
import pkg/chronos
import pkg/taskpools
import pkg/chronos/threadsync
import pkg/questionable/results
import pkg/circomcompat
@ -22,6 +24,7 @@ import ../../../contracts
import ./converters
export circomcompat, converters
export taskpools
type
CircomCompat* = object
@ -35,9 +38,24 @@ type
zkeyPath: string # path to the zkey file
backendCfg: ptr CircomBn254Cfg
vkp*: ptr CircomKey
taskpool: Taskpool
NormalizedProofInputs*[H] {.borrow: `.`.} = distinct ProofInputs[H]
ProveTask = object
circom: ptr CircomCompat
ctx: ptr CircomCompatCtx
proof: ptr Proof
success: Atomic[bool]
signal: ThreadSignalPtr
VerifyTask = object
proof: ptr CircomProof
vkp: ptr CircomKey
inputs: ptr CircomInputs
success: ptr Atomic[bool]
signal: ThreadSignalPtr
func normalizeInput*[H](
self: CircomCompat, input: ProofInputs[H]
): NormalizedProofInputs[H] =
@ -79,7 +97,28 @@ proc release*(self: CircomCompat) =
if not isNil(self.vkp):
self.vkp.unsafeAddr.release_key()
proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProof =
proc circomProveTask(task: ptr ProveTask) {.gcsafe.} =
defer:
discard task[].signal.fireSync()
var proofPtr: ptr Proof = nil
try:
if (
let res = task.circom.backendCfg.prove_circuit(task.ctx, proofPtr.addr)
res != ERR_OK
) or proofPtr == nil:
task.success.store(false)
return
copyProof(task.proof, proofPtr[])
task.success.store(true)
finally:
if proofPtr != nil:
proofPtr.addr.release_proof()
proc asyncProve*[H](
self: CircomCompat, input: NormalizedProofInputs[H], proof: ptr Proof
): Future[?!void] {.async.} =
doAssert input.samples.len == self.numSamples, "Number of samples does not match"
doAssert input.slotProof.len <= self.datasetDepth,
@ -143,13 +182,11 @@ proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProo
for s in input.samples:
var
merklePaths = s.merklePaths.mapIt(it.toBytes)
merklePaths = s.merklePaths.mapIt(@(it.toBytes)).concat
data = s.cellData.mapIt(@(it.toBytes)).concat
if ctx.push_input_u256_array(
"merklePaths".cstring,
merklePaths[0].addr,
uint (merklePaths[0].len * merklePaths.len),
"merklePaths".cstring, merklePaths[0].addr, uint (merklePaths.len)
) != ERR_OK:
return failure("Failed to push merkle paths")
@ -157,44 +194,127 @@ proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProo
ERR_OK:
return failure("Failed to push cell data")
var proofPtr: ptr Proof = nil
without threadPtr =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")
let proof =
defer:
threadPtr.close().expect("closing once works")
var task = ProveTask(circom: addr self, ctx: ctx, proof: proof, signal: threadPtr)
let taskPtr = addr task
doAssert task.circom.taskpool.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"
task.circom.taskpool.spawn circomProveTask(taskPtr)
let threadFut = threadPtr.wait()
try:
await threadFut.join()
except CatchableError as exc:
try:
if (let res = self.backendCfg.prove_circuit(ctx, proofPtr.addr); res != ERR_OK) or
proofPtr == nil:
return failure("Failed to prove - err code: " & $res)
proofPtr[]
await threadFut
except AsyncError as asyncExc:
return failure(asyncExc.msg)
finally:
if proofPtr != nil:
proofPtr.addr.release_proof()
if exc of CancelledError:
raise (ref CancelledError) exc
else:
return failure(exc.msg)
success proof
if not task.success.load():
return failure("Failed to prove circuit")
proc prove*[H](self: CircomCompat, input: ProofInputs[H]): ?!CircomProof =
self.prove(self.normalizeInput(input))
success()
proc prove*[H](
self: CircomCompat, input: ProofInputs[H]
): Future[?!CircomProof] {.async, raises: [CancelledError].} =
var proof = newProof()
defer:
destroyProof(proof)
try:
if error =? (await self.asyncProve(self.normalizeInput(input), proof)).errorOption:
return failure(error)
return success(deepCopy(proof)[])
except CancelledError as exc:
raise exc
proc circomVerifyTask(task: ptr VerifyTask) {.gcsafe.} =
defer:
task[].inputs[].releaseCircomInputs()
discard task[].signal.fireSync()
let res = verify_circuit(task[].proof, task[].inputs, task[].vkp)
if res == ERR_OK:
task[].success[].store(true)
elif res == ERR_FAILED_TO_VERIFY_PROOF:
task[].success[].store(false)
else:
task[].success[].store(false)
error "Failed to verify proof", errorCode = res
proc asyncVerify*[H](
self: CircomCompat,
proof: CircomProof,
inputs: ProofInputs[H],
success: ptr Atomic[bool],
): Future[?!void] {.async.} =
var proofPtr = unsafeAddr proof
var inputs = inputs.toCircomInputs()
without threadPtr =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")
defer:
threadPtr.close().expect("closing once works")
var task = VerifyTask(
proof: proofPtr,
vkp: self.vkp,
inputs: addr inputs,
success: success,
signal: threadPtr,
)
let taskPtr = addr task
doAssert self.taskpool.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"
self.taskpool.spawn circomVerifyTask(taskPtr)
let threadFut = threadPtr.wait()
try:
await threadFut.join()
except CatchableError as exc:
try:
await threadFut
except AsyncError as asyncExc:
return failure(asyncExc.msg)
finally:
if exc of CancelledError:
raise (ref CancelledError) exc
else:
return failure(exc.msg)
success()
proc verify*[H](
self: CircomCompat, proof: CircomProof, inputs: ProofInputs[H]
): ?!bool =
): Future[?!bool] {.async, raises: [CancelledError].} =
## Verify a proof using a ctx
##
var
proofPtr = unsafeAddr proof
inputs = inputs.toCircomInputs()
var res = newVerifyResult()
defer:
destroyVerifyResult(res)
try:
let res = verify_circuit(proofPtr, inputs.addr, self.vkp)
if res == ERR_OK:
success true
elif res == ERR_FAILED_TO_VERIFY_PROOF:
success false
else:
failure("Failed to verify proof - err code: " & $res)
finally:
inputs.releaseCircomInputs()
if error =? (await self.asyncVerify(proof, inputs, res)).errorOption:
return failure(error)
return success(res[].load())
except CancelledError as exc:
raise exc
proc init*(
_: type CircomCompat,
@ -206,6 +326,7 @@ proc init*(
blkDepth = DefaultBlockDepth,
cellElms = DefaultCellElms,
numSamples = DefaultSamplesNum,
taskpool: Taskpool,
): CircomCompat =
## Create a new ctx
##
@ -237,4 +358,5 @@ proc init*(
numSamples: numSamples,
backendCfg: cfg,
vkp: vkpPtr,
taskpool: taskpool,
)

View File

@ -10,6 +10,7 @@
{.push raises: [].}
import pkg/circomcompat
import std/atomics
import ../../../contracts
import ../../types
@ -22,6 +23,9 @@ type
CircomProof* = Proof
CircomKey* = VerifyingKey
CircomInputs* = Inputs
VerifyResult* = Atomic[bool]
export VerifyResult
proc toCircomInputs*(inputs: ProofInputs[Poseidon2Hash]): CircomInputs =
var
@ -52,3 +56,32 @@ func toG2*(g: CircomG2): G2Point =
func toGroth16Proof*(proof: CircomProof): Groth16Proof =
Groth16Proof(a: proof.a.toG1, b: proof.b.toG2, c: proof.c.toG1)
proc newProof*(): ptr Proof =
result = cast[ptr Proof](allocShared0(sizeof(Proof)))
proc newVerifyResult*(): ptr VerifyResult =
result = cast[ptr VerifyResult](allocShared0(sizeof(VerifyResult)))
proc destroyVerifyResult*(result: ptr VerifyResult) =
if result != nil:
deallocShared(result)
proc destroyProof*(proof: ptr Proof) =
if proof != nil:
deallocShared(proof)
proc copyInto*(dest: var G1, src: G1) =
copyMem(addr dest.x[0], addr src.x[0], 32)
copyMem(addr dest.y[0], addr src.y[0], 32)
proc copyInto*(dest: var G2, src: G2) =
for i in 0 .. 1:
copyMem(addr dest.x[i][0], addr src.x[i][0], 32)
copyMem(addr dest.y[i][0], addr src.y[i][0], 32)
proc copyProof*(dest: ptr Proof, src: Proof) =
if not isNil(dest):
copyInto(dest.a, src.a)
copyInto(dest.b, src.b)
copyInto(dest.c, src.c)

View File

@ -1,8 +1,13 @@
import ./backends
import pkg/taskpools
type BackendUtils* = ref object of RootObj
method initializeCircomBackend*(
self: BackendUtils, r1csFile: string, wasmFile: string, zKeyFile: string
self: BackendUtils,
r1csFile: string,
wasmFile: string,
zKeyFile: string,
taskpool: Taskpool,
): AnyBackend {.base, gcsafe.} =
CircomCompat.init(r1csFile, wasmFile, zKeyFile)
CircomCompat.init(r1csFile, wasmFile, zKeyFile, taskpool = taskpool)

View File

@ -72,7 +72,7 @@ proc prove*(
return failure(err)
# prove slot
without proof =? self.backend.prove(proofInput), err:
without proof =? await self.backend.prove(proofInput), err:
error "Unable to prove slot", err = err.msg
return failure(err)
@ -83,7 +83,11 @@ proc verify*(
): Future[?!bool] {.async.} =
## Prove a statement using backend.
## Returns a future that resolves to a proof.
self.backend.verify(proof, inputs)
without res =? (await self.backend.verify(proof, inputs)), err:
error "Unable to verify proof", err = err.msg
return failure(err)
return success(res)
proc new*(
_: type Prover, store: BlockStore, backend: AnyBackend, nSamples: int