Merge f408b78ab042684a2a7a0c0703cb09913c56055e into 49e801803f31544c1896995190c82da5434e80df

This commit is contained in:
munna0908 2025-12-17 11:11:32 -06:00 committed by GitHub
commit 76f9e22280
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 331 additions and 65 deletions

View File

@ -347,8 +347,9 @@ proc new*(
store = NetworkStore.new(engine, repoStore)
prover =
if config.prover:
let backend =
config.initializeBackend().expect("Unable to create prover backend.")
let backend = config.initializeBackend(taskpool = taskpool).expect(
"Unable to create prover backend."
)
some Prover.new(store, backend, config.numProofSamples)
else:
none Prover

View File

@ -2,6 +2,7 @@ import os
import strutils
import pkg/chronos
import pkg/chronicles
import pkg/taskpools
import pkg/questionable
import pkg/confutils/defs
import pkg/stew/io2
@ -11,7 +12,9 @@ import ../../conf
import ./backends
import ./backendutils
proc initializeFromConfig(config: CodexConf, utils: BackendUtils): ?!AnyBackend =
proc initializeFromConfig(
config: CodexConf, utils: BackendUtils, taskpool: Taskpool
): ?!AnyBackend =
if not fileAccessible($config.circomR1cs, {AccessFlags.Read}) or
not endsWith($config.circomR1cs, ".r1cs"):
return failure("Circom R1CS file not accessible")
@ -27,7 +30,7 @@ proc initializeFromConfig(config: CodexConf, utils: BackendUtils): ?!AnyBackend
trace "Initialized prover backend from cli config"
success(
utils.initializeCircomBackend(
$config.circomR1cs, $config.circomWasm, $config.circomZkey
$config.circomR1cs, $config.circomWasm, $config.circomZkey, taskpool
)
)
@ -41,14 +44,14 @@ proc zkeyFilePath(config: CodexConf): string =
config.circuitDir / "proof_main.zkey"
proc initializeFromCircuitDirFiles(
config: CodexConf, utils: BackendUtils
config: CodexConf, utils: BackendUtils, taskpool: Taskpool
): ?!AnyBackend {.gcsafe.} =
if fileExists(config.r1csFilePath) and fileExists(config.wasmFilePath) and
fileExists(config.zkeyFilePath):
trace "Initialized prover backend from local files"
return success(
utils.initializeCircomBackend(
config.r1csFilePath, config.wasmFilePath, config.zkeyFilePath
config.r1csFilePath, config.wasmFilePath, config.zkeyFilePath, taskpool
)
)
@ -68,11 +71,11 @@ proc suggestDownloadTool(config: CodexConf) =
instructions
proc initializeBackend*(
config: CodexConf, utils: BackendUtils = BackendUtils()
config: CodexConf, utils: BackendUtils = BackendUtils(), taskpool: Taskpool
): ?!AnyBackend =
without backend =? initializeFromConfig(config, utils), cliErr:
without backend =? initializeFromConfig(config, utils, taskpool), cliErr:
info "Could not initialize prover backend from CLI options...", msg = cliErr.msg
without backend =? initializeFromCircuitDirFiles(config, utils), localErr:
without backend =? initializeFromCircuitDirFiles(config, utils, taskpool), localErr:
info "Could not initialize prover backend from circuit dir files...",
msg = localErr.msg
suggestDownloadTool(config)

View File

@ -9,9 +9,11 @@
{.push raises: [].}
import std/sugar
import std/[sugar, atomics, locks]
import pkg/chronos
import pkg/taskpools
import pkg/chronos/threadsync
import pkg/questionable/results
import pkg/circomcompat
@ -22,6 +24,7 @@ import ../../../contracts
import ./converters
export circomcompat, converters
export taskpools
type
CircomCompat* = object
@ -35,9 +38,25 @@ type
zkeyPath: string # path to the zkey file
backendCfg: ptr CircomBn254Cfg
vkp*: ptr CircomKey
taskpool: Taskpool
lock: ptr Lock
NormalizedProofInputs*[H] {.borrow: `.`.} = distinct ProofInputs[H]
ProveTask = object
circom: ptr CircomCompat
ctx: ptr CircomCompatCtx
proof: ptr Proof
success: Atomic[bool]
signal: ThreadSignalPtr
VerifyTask = object
proof: ptr CircomProof
vkp: ptr CircomKey
inputs: ptr CircomInputs
success: VerifyResult
signal: ThreadSignalPtr
func normalizeInput*[H](
self: CircomCompat, input: ProofInputs[H]
): NormalizedProofInputs[H] =
@ -79,7 +98,33 @@ proc release*(self: CircomCompat) =
if not isNil(self.vkp):
self.vkp.unsafeAddr.release_key()
proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProof =
if not isNil(self.lock):
deinitLock(self.lock[]) # Cleanup the lock
dealloc(self.lock) # Free the memory
proc circomProveTask(task: ptr ProveTask) {.gcsafe.} =
withLock task[].circom.lock[]:
defer:
discard task[].signal.fireSync()
var proofPtr: ptr Proof = nil
try:
if (
let res = task.circom.backendCfg.prove_circuit(task.ctx, proofPtr.addr)
res != ERR_OK
) or proofPtr == nil:
task.success.store(false)
return
copyProof(task.proof, proofPtr[])
task.success.store(true)
finally:
if proofPtr != nil:
proofPtr.addr.release_proof()
proc asyncProve*[H](
self: CircomCompat, input: NormalizedProofInputs[H], proof: ptr Proof
): Future[?!void] {.async.} =
doAssert input.samples.len == self.numSamples, "Number of samples does not match"
doAssert input.slotProof.len <= self.datasetDepth,
@ -143,13 +188,11 @@ proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProo
for s in input.samples:
var
merklePaths = s.merklePaths.mapIt(it.toBytes)
merklePaths = s.merklePaths.mapIt(@(it.toBytes)).concat
data = s.cellData.mapIt(@(it.toBytes)).concat
if ctx.push_input_u256_array(
"merklePaths".cstring,
merklePaths[0].addr,
uint (merklePaths[0].len * merklePaths.len),
"merklePaths".cstring, merklePaths[0].addr, uint (merklePaths.len)
) != ERR_OK:
return failure("Failed to push merkle paths")
@ -157,44 +200,118 @@ proc prove[H](self: CircomCompat, input: NormalizedProofInputs[H]): ?!CircomProo
ERR_OK:
return failure("Failed to push cell data")
var proofPtr: ptr Proof = nil
without threadPtr =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")
let proof =
try:
if (let res = self.backendCfg.prove_circuit(ctx, proofPtr.addr); res != ERR_OK) or
proofPtr == nil:
return failure("Failed to prove - err code: " & $res)
defer:
threadPtr.close().expect("closing once works")
proofPtr[]
finally:
if proofPtr != nil:
proofPtr.addr.release_proof()
var task = ProveTask(circom: addr self, ctx: ctx, proof: proof, signal: threadPtr)
success proof
let taskPtr = addr task
proc prove*[H](self: CircomCompat, input: ProofInputs[H]): ?!CircomProof =
self.prove(self.normalizeInput(input))
doAssert task.circom.taskpool.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"
task.circom.taskpool.spawn circomProveTask(taskPtr)
let threadFut = threadPtr.wait()
if joinErr =? catch(await threadFut.join()).errorOption:
if err =? catch(await noCancel threadFut).errorOption:
return failure(err)
if joinErr of CancelledError:
raise joinErr
else:
return failure(joinErr)
if not task.success.load():
return failure("Failed to prove circuit")
success()
proc prove*[H](
self: CircomCompat, input: ProofInputs[H]
): Future[?!CircomProof] {.async, raises: [CancelledError].} =
var proof = ProofPtr.new()
defer:
destroyProof(proof)
try:
if error =? (await self.asyncProve(self.normalizeInput(input), proof)).errorOption:
return failure(error)
return success(deepCopy(proof)[])
except CancelledError as exc:
raise exc
proc circomVerifyTask(task: ptr VerifyTask) {.gcsafe.} =
defer:
task[].inputs[].releaseCircomInputs()
discard task[].signal.fireSync()
let res = verify_circuit(task[].proof, task[].inputs, task[].vkp)
if res == ERR_OK:
task[].success[] = true
elif res == ERR_FAILED_TO_VERIFY_PROOF:
task[].success[] = false
else:
task[].success[] = false
error "Failed to verify proof", errorCode = res
proc asyncVerify*[H](
self: CircomCompat,
proof: CircomProof,
inputs: ProofInputs[H],
success: VerifyResult,
): Future[?!void] {.async.} =
var proofPtr = unsafeAddr proof
var inputs = inputs.toCircomInputs()
without threadPtr =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")
defer:
threadPtr.close().expect("closing once works")
var task = VerifyTask(
proof: proofPtr,
vkp: self.vkp,
inputs: addr inputs,
success: success,
signal: threadPtr,
)
let taskPtr = addr task
doAssert self.taskpool.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"
self.taskpool.spawn circomVerifyTask(taskPtr)
let threadFut = threadPtr.wait()
if joinErr =? catch(await threadFut.join()).errorOption:
if err =? catch(await noCancel threadFut).errorOption:
return failure(err)
if joinErr of CancelledError:
raise joinErr
else:
return failure(joinErr)
success()
proc verify*[H](
self: CircomCompat, proof: CircomProof, inputs: ProofInputs[H]
): ?!bool =
): Future[?!bool] {.async, raises: [CancelledError].} =
## Verify a proof using a ctx
##
var
proofPtr = unsafeAddr proof
inputs = inputs.toCircomInputs()
var res = VerifyResult.new()
defer:
destroyVerifyResult(res)
try:
let res = verify_circuit(proofPtr, inputs.addr, self.vkp)
if res == ERR_OK:
success true
elif res == ERR_FAILED_TO_VERIFY_PROOF:
success false
else:
failure("Failed to verify proof - err code: " & $res)
finally:
inputs.releaseCircomInputs()
if error =? (await self.asyncVerify(proof, inputs, res)).errorOption:
return failure(error)
return success(res[])
except CancelledError as exc:
raise exc
proc init*(
_: type CircomCompat,
@ -206,10 +323,13 @@ proc init*(
blkDepth = DefaultBlockDepth,
cellElms = DefaultCellElms,
numSamples = DefaultSamplesNum,
taskpool: Taskpool,
): CircomCompat =
## Create a new ctx
##
# Allocate and initialize the lock
var lockPtr = create(Lock) # Allocate memory for the lock
initLock(lockPtr[]) # Initialize the lock
## Create a new ctx
var cfg: ptr CircomBn254Cfg
var zkey = if zkeyPath.len > 0: zkeyPath.cstring else: nil
@ -237,4 +357,6 @@ proc init*(
numSamples: numSamples,
backendCfg: cfg,
vkp: vkpPtr,
taskpool: taskpool,
lock: lockPtr,
)

View File

@ -10,6 +10,7 @@
{.push raises: [].}
import pkg/circomcompat
import std/atomics
import ../../../contracts
import ../../types
@ -22,6 +23,14 @@ type
CircomProof* = Proof
CircomKey* = VerifyingKey
CircomInputs* = Inputs
VerifyResult* = ptr bool
ProofPtr* = ptr Proof
proc new*(_: type ProofPtr): ProofPtr =
cast[ptr Proof](allocShared0(sizeof(Proof)))
proc new*(_: type VerifyResult): VerifyResult =
cast[ptr bool](allocShared0(sizeof(bool)))
proc toCircomInputs*(inputs: ProofInputs[Poseidon2Hash]): CircomInputs =
var
@ -52,3 +61,26 @@ func toG2*(g: CircomG2): G2Point =
func toGroth16Proof*(proof: CircomProof): Groth16Proof =
Groth16Proof(a: proof.a.toG1, b: proof.b.toG2, c: proof.c.toG1)
proc destroyVerifyResult*(result: VerifyResult) =
if result != nil:
deallocShared(result)
proc destroyProof*(proof: ProofPtr) =
if proof != nil:
deallocShared(proof)
proc copyInto*(dest: var G1, src: G1) =
copyMem(addr dest.x[0], addr src.x[0], 32)
copyMem(addr dest.y[0], addr src.y[0], 32)
proc copyInto*(dest: var G2, src: G2) =
for i in 0 .. 1:
copyMem(addr dest.x[i][0], addr src.x[i][0], 32)
copyMem(addr dest.y[i][0], addr src.y[i][0], 32)
proc copyProof*(dest: ptr Proof, src: Proof) =
if not isNil(dest):
copyInto(dest.a, src.a)
copyInto(dest.b, src.b)
copyInto(dest.c, src.c)

View File

@ -1,8 +1,13 @@
import ./backends
import pkg/taskpools
type BackendUtils* = ref object of RootObj
method initializeCircomBackend*(
self: BackendUtils, r1csFile: string, wasmFile: string, zKeyFile: string
self: BackendUtils,
r1csFile: string,
wasmFile: string,
zKeyFile: string,
taskpool: Taskpool,
): AnyBackend {.base, gcsafe.} =
CircomCompat.init(r1csFile, wasmFile, zKeyFile)
CircomCompat.init(r1csFile, wasmFile, zKeyFile, taskpool = taskpool)

View File

@ -74,7 +74,7 @@ proc prove*(
return failure(err)
# prove slot
without proof =? self.backend.prove(proofInput), err:
without proof =? await self.backend.prove(proofInput), err:
error "Unable to prove slot", err = err.msg
return failure(err)
@ -85,7 +85,11 @@ proc verify*(
): Future[?!bool] {.async: (raises: [CancelledError]).} =
## Prove a statement using backend.
## Returns a future that resolves to a proof.
self.backend.verify(proof, inputs)
without res =? (await self.backend.verify(proof, inputs)), err:
error "Unable to verify proof", err = err.msg
return failure(err)
return success(res)
proc new*(
_: type Prover, store: BlockStore, backend: AnyBackend, nSamples: int

View File

@ -26,29 +26,36 @@ suite "Test Circom Compat Backend - control inputs":
var
circom: CircomCompat
proofInputs: ProofInputs[Poseidon2Hash]
taskpool: Taskpool
setup:
let
inputData = readFile("tests/circuits/fixtures/input.json")
inputJson = !JsonNode.parse(inputData)
taskpool = Taskpool.new()
proofInputs = Poseidon2Hash.jsonToProofInput(inputJson)
circom = CircomCompat.init(r1cs, wasm, zkey)
circom = CircomCompat.init(r1cs, wasm, zkey, taskpool = taskpool)
teardown:
circom.release() # this comes from the rust FFI
taskpool.shutdown()
test "Should verify with correct inputs":
let proof = circom.prove(proofInputs).tryGet
let proof = (await circom.prove(proofInputs)).tryGet
check circom.verify(proof, proofInputs).tryGet
var resp = (await circom.verify(proof, proofInputs)).tryGet
check resp
test "Should not verify with incorrect inputs":
proofInputs.slotIndex = 1 # change slot index
let proof = circom.prove(proofInputs).tryGet
let proof = (await circom.prove(proofInputs)).tryGet
check circom.verify(proof, proofInputs).tryGet == false
var resp = (await circom.verify(proof, proofInputs)).tryGet
check resp == false
suite "Test Circom Compat Backend":
let
@ -77,6 +84,7 @@ suite "Test Circom Compat Backend":
challenge: array[32, byte]
builder: Poseidon2Builder
sampler: Poseidon2Sampler
taskpool: Taskpool
setup:
let
@ -91,8 +99,9 @@ suite "Test Circom Compat Backend":
builder = Poseidon2Builder.new(store, verifiable).tryGet
sampler = Poseidon2Sampler.new(slotId, store, builder).tryGet
taskpool = Taskpool.new()
circom = CircomCompat.init(r1cs, wasm, zkey)
circom = CircomCompat.init(r1cs, wasm, zkey, taskpool = taskpool)
challenge = 1234567.toF.toBytes.toArray32
proofInputs = (await sampler.getProofInput(challenge, samples)).tryGet
@ -101,15 +110,20 @@ suite "Test Circom Compat Backend":
circom.release() # this comes from the rust FFI
await repoTmp.destroyDb()
await metaTmp.destroyDb()
taskpool.shutdown()
test "Should verify with correct input":
var proof = circom.prove(proofInputs).tryGet
var proof = (await circom.prove(proofInputs)).tryGet
check circom.verify(proof, proofInputs).tryGet
var resp = (await circom.verify(proof, proofInputs)).tryGet
check resp == true
test "Should not verify with incorrect input":
proofInputs.slotIndex = 1 # change slot index
let proof = circom.prove(proofInputs).tryGet
let proof = (await circom.prove(proofInputs)).tryGet
check circom.verify(proof, proofInputs).tryGet == false
var resp = (await circom.verify(proof, proofInputs)).tryGet
check resp == false

View File

@ -4,6 +4,7 @@ import ../../asynctest
import pkg/chronos
import pkg/confutils/defs
import pkg/codex/conf
import pkg/taskpools
import pkg/codex/slots/proofs/backends
import pkg/codex/slots/proofs/backendfactory
import pkg/codex/slots/proofs/backendutils
@ -18,7 +19,11 @@ type BackendUtilsMock = ref object of BackendUtils
argZKeyFile: string
method initializeCircomBackend*(
self: BackendUtilsMock, r1csFile: string, wasmFile: string, zKeyFile: string
self: BackendUtilsMock,
r1csFile: string,
wasmFile: string,
zKeyFile: string,
taskpool: Taskpool,
): AnyBackend =
self.argR1csFile = r1csFile
self.argWasmFile = wasmFile
@ -52,7 +57,7 @@ suite "Test BackendFactory":
circomWasm: InputFile("tests/circuits/fixtures/proof_main.wasm"),
circomZkey: InputFile("tests/circuits/fixtures/proof_main.zkey"),
)
backend = config.initializeBackend(utilsMock).tryGet
backend = config.initializeBackend(utilsMock, taskpool = nil).tryGet
check:
backend.vkp != nil
@ -73,7 +78,7 @@ suite "Test BackendFactory":
# will be picked up as local files:
circuitDir: OutDir("tests/circuits/fixtures"),
)
backend = config.initializeBackend(utilsMock).tryGet
backend = config.initializeBackend(utilsMock, taskpool = nil).tryGet
check:
backend.vkp != nil
@ -91,7 +96,7 @@ suite "Test BackendFactory":
marketplaceAddress: EthAddress.example.some,
circuitDir: OutDir(circuitDir),
)
backendResult = config.initializeBackend(utilsMock)
backendResult = config.initializeBackend(utilsMock, taskpool = nil)
check:
backendResult.isErr

View File

@ -29,6 +29,8 @@ suite "Test Prover":
var
store: BlockStore
prover: Prover
backend: AnyBackend
taskpool: Taskpool
setup:
let
@ -44,7 +46,8 @@ suite "Test Prover":
circomZkey: InputFile("tests/circuits/fixtures/proof_main.zkey"),
numProofSamples: samples,
)
backend = config.initializeBackend().tryGet()
taskpool = Taskpool.new()
backend = config.initializeBackend(taskpool = taskpool).tryGet()
store = RepoStore.new(repoDs, metaDs)
prover = Prover.new(store, backend, config.numProofSamples)
@ -52,6 +55,7 @@ suite "Test Prover":
teardown:
await repoTmp.destroyDb()
await metaTmp.destroyDb()
taskpool.shutdown()
test "Should sample and prove a slot":
let (_, _, verifiable) = await createVerifiableManifest(
@ -86,3 +90,79 @@ suite "Test Prover":
check:
(await prover.verify(proof, inputs)).tryGet == true
test "Should concurrently prove/verify":
const iterations = 5
var proveTasks = newSeq[Future[?!(AnyProofInputs, AnyProof)]]()
var verifyTasks = newSeq[Future[?!bool]]()
for i in 0 ..< iterations:
# create multiple prove tasks
let (_, _, verifiable) = await createVerifiableManifest(
store,
8, # number of blocks in the original dataset (before EC)
5, # ecK
3, # ecM
blockSize,
cellSize,
)
proveTasks.add(prover.prove(1, verifiable, challenge))
let proveResults = await allFinished(proveTasks)
#
for i in 0 ..< proveResults.len:
var (inputs, proofs) = proveTasks[i].read().tryGet()
verifyTasks.add(prover.verify(proofs, inputs))
let verifyResults = await allFinished(verifyTasks)
for i in 0 ..< verifyResults.len:
check:
verifyResults[i].read().tryGet() == true
test "Should complete prove/verify task when cancelled":
let (_, _, verifiable) = await createVerifiableManifest(
store,
8, # number of blocks in the original dataset (before EC)
5, # ecK
3, # ecM
blockSize,
cellSize,
)
let (inputs, proof) = (await prover.prove(1, verifiable, challenge)).tryGet
var cancelledProof = ProofPtr.new()
defer:
destroyProof(cancelledProof)
# call asyncProve and cancel the task
let proveFut = backend.asyncProve(backend.normalizeInput(inputs), cancelledProof)
proveFut.cancel()
try:
discard await proveFut
except CatchableError as exc:
check exc of CancelledError
finally:
# validate the cancelledProof
check:
(await prover.verify(cancelledProof[], inputs)).tryGet == true
var verifyRes = VerifyResult.new()
defer:
destroyVerifyResult(verifyRes)
# call asyncVerify and cancel the task
let verifyFut = backend.asyncVerify(proof, inputs, verifyRes)
verifyFut.cancel()
try:
discard await verifyFut
except CatchableError as exc:
check exc of CancelledError
finally:
# validate the verifyResponse
check verifyRes[] == true

@ -1 +1 @@
Subproject commit d3fb903945c3895f28a2e50685745e0a9762ece5
Subproject commit 88197fe6ec929559b37e1443785ece650d2e9255