Merge fc56b91393a65b931658dc5f014e93b72aed26e4 into f876bd6dba00b2accf3ad1dcb308bda3c3cb3690

This commit is contained in:
munna0908 2026-05-07 10:23:47 +00:00 committed by GitHub
commit 00e8636ca1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 307 additions and 51 deletions

View File

@ -1,3 +1,3 @@
--path:".."
--threads:on
--mm:arc
--mm:refc

View File

@ -1,3 +1,4 @@
version = "0.1.1"
author = "Balazs Komuves"
description = "Groth16 proof system"

View File

@ -22,6 +22,7 @@ import constantine/math/elliptic/ec_multi_scalar_mul as msm except Su
#import groth16/bn128/fields
import groth16/bn128/curves as mycurves
import groth16/sharedbuf
#import groth16/misc # TEMP DEBUGGING
#import std/cpuinfo
@ -79,6 +80,33 @@ func msmConstantineG2*( coeffs: openArray[Fr[BN254_Snarks]] , points: openArray[
return rAff
#-------------------------------------------------------------------------------
# spawnable wrappers: take SharedBuf views, delegate to the core.
# These are what `pool.spawn` calls, so they carry {.gcsafe, raises: [].}.
#
# Local aliases `AffG1`/`AffG2` are required because taskpools' `spawn` macro
# does `getImpl().replaceSymsByIdents()`, which strips qualifications. With
# `SharedBuf[mycurves.G1]` the bare ident `G1` then re-resolves to the enum
# value `aff.G1` (of type `Subgroup`), not the type alias. Renaming dodges
# the collision.
type
AffG1 = mycurves.G1
AffG2 = mycurves.G2
FrBN = Fr[BN254_Snarks]
func msmConstantineG1Range( coeffs: SharedBuf[FrBN] ,
points: SharedBuf[AffG1] ): AffG1
{.gcsafe, raises: [].} =
msmConstantineG1( toOpenArray(coeffs.payload, 0, coeffs.len - 1),
toOpenArray(points.payload, 0, points.len - 1) )
func msmConstantineG2Range( coeffs: SharedBuf[FrBN] ,
points: SharedBuf[AffG2] ): AffG2
{.gcsafe, raises: [].} =
msmConstantineG2( toOpenArray(coeffs.payload, 0, coeffs.len - 1),
toOpenArray(points.payload, 0, points.len - 1) )
#-------------------------------------------------------------------------------
const task_multiplier : int = 1
@ -105,9 +133,9 @@ proc msmMultiThreadedG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1], pool:
b = (N*(k+1)) div ntasks
else:
b = N
let cs = coeffs[a..<b]
let ps = points[a..<b]
pending[k] = pool.spawn msmConstantineG1( cs, ps );
let cs = SharedBuf.view(toOpenArray(coeffs, a, b - 1))
let ps = SharedBuf.view(toOpenArray(points, a, b - 1))
pending[k] = pool.spawn msmConstantineG1Range(cs, ps)
a = b
var res : G1 = infG1
@ -135,9 +163,9 @@ proc msmMultiThreadedG2*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2], pool:
b = (N*(k+1)) div ntasks
else:
b = N
let cs = coeffs[a..<b]
let ps = points[a..<b]
pending[k] = pool.spawn msmConstantineG2( cs, ps );
let cs = SharedBuf.view(toOpenArray(coeffs, a, b - 1))
let ps = SharedBuf.view(toOpenArray(points, a, b - 1))
pending[k] = pool.spawn msmConstantineG2Range(cs, ps)
a = b
var res : G2 = infG2

View File

@ -4,11 +4,11 @@ import groth16/files/export_json
#-------------------------------------------------------------------------------
proc exampleProveAndVerify() =
proc exampleProveAndVerify() =
let zkey_fname : string = "./build/product.zkey"
let wtns_fname : string = "./build/product.wtns"
let proof = testProveAndVerify( zkey_fname, wtns_fname)
let (_, proof) = testProveAndVerify( zkey_fname, wtns_fname)
exportPublicIO( "./build/nim_public.json" , proof )
exportProof( "./build/nim_proof.json" , proof )

3
groth16/example/nim.cfg Normal file
View File

@ -0,0 +1,3 @@
--path:"../.."
--threads:on
--mm:arc

View File

@ -28,9 +28,6 @@ import groth16/prover/shared
proc finishPartialProofWithMask*( zkey: ZKey, wtns: Witness, partialProof: PartialProof, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
# echo( "wtns.curve = " & ($wtns.curve ) )

View File

@ -33,9 +33,6 @@ import groth16/partial/types
proc generatePartialProof*( zkey: ZKey, pwtns: PartialWitness, pool: Taskpool, printTimings: bool): PartialProof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# assert( zkey.header.curve == wtns.curve )
let partial_witness = pwtns.values

View File

@ -33,9 +33,6 @@ import groth16/prover/shared
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
# echo( "wtns.curve = " & ($wtns.curve ) )

View File

@ -18,10 +18,19 @@ import groth16/bn128
import groth16/math/domain
import groth16/math/poly
import groth16/zkey_types
import groth16/sharedbuf
#import groth16/misc
import groth16/prover/types
# `FrBN` shadows `Fr[BN254_Snarks]` to dodge a taskpools `spawn` macro issue:
# its `getImpl().replaceSymsByIdents()` doesn't roundtrip qualified generic
# instantiations cleanly, so the bare ident `Fr[BN254_Snarks]` re-resolves
# wrong at the wrapper-proc reconstruction step. Using a plain alias gives
# the macro an unambiguous symbol to work with. Same workaround as
# `AffG1`/`AffG2` in `groth16/bn128/msm.nim`.
type FrBN = Fr[BN254_Snarks]
#-------------------------------------------------------------------------------
proc randomMask*(): Mask =
@ -96,16 +105,40 @@ func shiftEvalDomain*(
var ds : seq[Fr[BN254_Snarks]] = multiplyByPowers( cs, eta )
return polyForwardNTT( Poly(coeffs:ds), D )
# Wraps shiftEvalDomain such that it can be called by Taskpool.spawn. The result
# is written to the output parameter. Has an unused return type because
# Taskpool.spawn cannot handle a void return type.
func shiftEvalDomainTask*(
values: seq[Fr[BN254_Snarks]],
D: Domain,
eta: Fr[BN254_Snarks],
output: ptr Isolated[seq[Fr[BN254_Snarks]]]): bool =
# Spawnable wrapper for shiftEvalDomain. Crosses the spawn boundary using
# SharedBuf views — no GC type travels across, so this works under any
# GC mode (refc / arc / orc).
#
# Contract: caller owns both `values` (read) and `output` (written),
# both of length n. Caller must keep them alive until the FlowVar is sync'd.
proc shiftEvalDomainTask(
values: SharedBuf[FrBN],
D: Domain,
eta: FrBN,
output: SharedBuf[FrBN]): bool
{.gcsafe, raises: [].} =
output[] = isolate shiftEvalDomain(values, D, eta)
let n = values.len
# TODO(perf): two N-sized boundary copies (input → `local`, `result` → output)
# are an artifact of `shiftEvalDomain` / `polyInverseNTT` / `polyForwardNTT`
# taking and returning `seq[Fr]`. To eliminate both:
# 1. Add `inverseNTTInto`/`forwardNTTInto` variants in `groth16/math/ntt.nim`
# that take `openArray[Fr]` for input and `var openArray[Fr]` for output
# (the inner workers already operate on stride/offset triples — change is
# mechanical).
# 2. Rewrite this proc to call the `…Into` variants directly on
# `toOpenArray(values)` and `toOpenArray(output)`.
# Copy input view → worker-local seq for the existing seq-flavoured core.
var local = newSeq[FrBN](n)
for i in 0 ..< n: local[i] = values.payload[i]
let res = shiftEvalDomain(local, D, eta)
# Copy result → caller's output buffer through the SharedBuf payload pointer.
for i in 0 ..< n: output.payload[i] = res[i]
return true
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
@ -116,28 +149,37 @@ proc computeQuotientPointwise*( abc: ABC, pool: TaskPool ): Poly =
assert( abc.valuesCz.len == n )
let D = createDomain(n)
# (eta*omega^j)^n - 1 = eta^n - 1
# (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
# Pre-allocate caller-owned output buffers; workers write into these
# through bare pointers, so nothing GC-managed crosses the spawn.
var outA = newSeq[Fr[BN254_Snarks]](n)
var outB = newSeq[Fr[BN254_Snarks]](n)
var outC = newSeq[Fr[BN254_Snarks]](n)
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
let taskA1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesAz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outA, 0, n - 1)))
let taskB1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesBz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outB, 0, n - 1)))
let taskC1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesCz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outC, 0, n - 1)))
discard sync taskA1
discard sync taskB1
discard sync taskC1
let A1 = outputA1.extract()
let B1 = outputB1.extract()
let C1 = outputC1.extract()
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
for j in 0..<n: ys[j] = ( outA[j]*outB[j] - outC[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
@ -147,7 +189,7 @@ proc computeQuotientPointwise*( abc: ABC, pool: TaskPool ): Poly =
# Snarkjs does something different, not actually computing the quotient poly
# they can get away with this, because during the trusted setup, they
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
@ -158,22 +200,29 @@ proc computeSnarkjsScalarCoeffs*( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snark
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
var outA = newSeq[Fr[BN254_Snarks]](n)
var outB = newSeq[Fr[BN254_Snarks]](n)
var outC = newSeq[Fr[BN254_Snarks]](n)
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
let taskA1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesAz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outA, 0, n - 1)))
let taskB1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesBz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outB, 0, n - 1)))
let taskC1 = pool.spawn shiftEvalDomainTask(
SharedBuf.view(toOpenArray(abc.valuesCz, 0, n - 1)),
D, eta,
SharedBuf.view(toOpenArray(outC, 0, n - 1)))
discard sync taskA1
discard sync taskB1
discard sync taskC1
let A1 = outputA1.extract()
let B1 = outputB1.extract()
let C1 = outputC1.extract()
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
for j in 0..<n: ys[j] = ( outA[j] * outB[j] - outC[j] )
return ys

43
groth16/sharedbuf.nim Normal file
View File

@ -0,0 +1,43 @@
#
# SharedBuf[T] — non-owning (ptr, len) view over a contiguous buffer.
#
# Purpose: cross a `Taskpool.spawn` boundary safely under any GC mode
# (refc / arc / orc) without sending a `seq[T]` through the spawn closure.
#
# Contract:
# - The pointed-to memory is owned by the caller. The caller MUST keep
# the underlying storage alive until every worker holding the view
# has finished (i.e. all FlowVars from `pool.spawn` have been `sync`ed).
# - Element type `T` must have no GC fields (workers read/write payload
# bytes through a raw pointer; the seq's GC machinery is never touched).
#
{.push raises: [], gcsafe.}
template makeUncheckedArray[T](p: ptr T): ptr UncheckedArray[T] =
cast[ptr UncheckedArray[T]](p)
type SharedBuf*[T] = object
payload*: ptr UncheckedArray[T]
len*: int
func view*[T](_: type SharedBuf, v: openArray[T]): SharedBuf[T] =
if v.len > 0:
SharedBuf[T](payload: makeUncheckedArray(addr v[0]), len: v.len)
else:
default(SharedBuf[T])
template checkIdx(v: SharedBuf, i: int) =
doAssert i >= 0 and i < v.len
func `[]`*[T](v: SharedBuf[T], i: int): var T =
v.checkIdx(i)
v.payload[i]
template toOpenArray*[T](v: SharedBuf[T]): var openArray[T] =
v.payload.toOpenArray(0, v.len - 1)
template toOpenArray*[T](v: SharedBuf[T], s, e: int): var openArray[T] =
v.toOpenArray().toOpenArray(s, e)
{.pop.}

View File

@ -3,6 +3,8 @@
import std/[times,os]
import strformat
import taskpools
import groth16/prover
import groth16/verifier
import groth16/files/witness
@ -23,7 +25,8 @@ proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
echo("generating proof...")
let start = cpuTime()
let proof = generateProof( zkey, witness )
var pool = Taskpool.new()
let proof = generateProof( zkey, witness ,pool)
let elapsed = cpuTime() - start
echo("proving took ",seconds(elapsed))
@ -55,7 +58,8 @@ proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs):
let vkey = extractVKey( zkey)
let start = cpuTime()
let proof = generateProof( zkey, witness )
var pool = Taskpool.new()
let proof = generateProof( zkey, witness ,pool)
let elapsed = cpuTime() - start
echo("proving took ",seconds(elapsed))

View File

@ -0,0 +1,134 @@
{.used.}
# Multi-threading correctness tests.
#
# Two complementary checks:
#
# 1. Trivial-mask determinism (r=s=0): proof is a pure deterministic function
# of (zkey, witness), so sweeping the thread count must produce
# byte-identical proof points. Catches races that produce *different but
# still valid* proofs across configurations.
#
# 2. Random-mask end-to-end verify: proves with random masking (the
# production code path) under varied (gc-mode, thread-count) and asserts
# every resulting proof verifies. Random masks change the MSM coefficient
# inputs, which exercises the data-dependent (non-constant-time) parts of
# the MSM where coefficient-magnitude-driven races have historically
# hidden — invisible under trivial-mask testing.
import std/unittest
import std/sequtils
import taskpools
import groth16/prover
import groth16/prover/groth16 as proverImpl
import groth16/verifier
import groth16/fake_setup
import groth16/zkey_types
import groth16/files/witness
import groth16/files/r1cs
import groth16/bn128/fields
#-------------------------------------------------------------------------------
# Same simple multiplication circuit testProver.nim uses: 7*11*13 + 1022 = 2023.
# Small but exercises the full prover path (4 MSMs + quotient computation).
const myWitnessCfg =
WitnessConfig( nWires: 8
, nPubOut: 1
, nPubIn: 1
, nPrivIn: 3
, nLabels: 0
)
const myEq1 : Constraint = ( @[] , @[] , @[ (1,minusOneFr) , (2,oneFr) , (7,oneFr) ] )
const myEq2 : Constraint = ( @[ (3,oneFr) ] , @[ (4,oneFr) ] , @[ (6,oneFr) ] )
const myEq3 : Constraint = ( @[ (5,oneFr) ] , @[ (6,oneFr) ] , @[ (7,oneFr) ] )
const myConstraints : seq[Constraint] = @[ myEq1, myEq2, myEq3 ]
const myR1CS =
R1CS( r: primeR
, cfg: myWitnessCfg
, nConstr: myConstraints.len
, constraints: myConstraints
, wireToLabel: @[]
)
let myWitnessValues = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr )
let myWitness =
Witness( curve: "bn128"
, r: primeR
, nvars: 8
, values: myWitnessValues
)
const ThreadCounts = [1, 2, 4, 8]
#-------------------------------------------------------------------------------
proc proveWithThreads(zkey: ZKey, witness: Witness, nThreads: int): Proof =
var pool = Taskpool.new(numThreads = nThreads)
result = generateProofWithTrivialMask( zkey, witness, pool, printTimings = false )
pool.shutdown()
proc verifyWith(zkey: ZKey, proof: Proof): bool =
let vkey = extractVKey(zkey)
return verifyProof(vkey, proof)
#-------------------------------------------------------------------------------
suite "multithreading":
test "repeated proofs on the same pool match (no per-call state leak)":
# Reusing one pool across many proofs must not change the output: rules
# out residual state in worker-local buffers between invocations.
let zkey = createFakeCircuitSetup( myR1cs, flavour=Snarkjs )
var pool = Taskpool.new(numThreads = 4)
defer: pool.shutdown()
let first = generateProofWithTrivialMask(zkey, myWitness, pool, false)
for _ in 0 ..< 4:
let again = generateProofWithTrivialMask(zkey, myWitness, pool, false)
check isEqualProof(first, again)
test "trivial-mask proof is deterministic across thread counts (JensGroth)":
let zkey = createFakeCircuitSetup( myR1cs, flavour=JensGroth )
let reference = proveWithThreads(zkey, myWitness, ThreadCounts[0])
check verifyWith(zkey, reference)
for j in ThreadCounts[1..^1]:
let proof = proveWithThreads(zkey, myWitness, j)
check isEqualProof(reference, proof)
check verifyWith(zkey, proof)
test "trivial-mask proof is deterministic across thread counts (Snarkjs)":
let zkey = createFakeCircuitSetup( myR1cs, flavour=Snarkjs )
let reference = proveWithThreads(zkey, myWitness, ThreadCounts[0])
check verifyWith(zkey, reference)
for j in ThreadCounts[1..^1]:
let proof = proveWithThreads(zkey, myWitness, j)
check isEqualProof(reference, proof)
check verifyWith(zkey, proof)
test "random-mask proofs verify across thread counts (Snarkjs)":
let zkey = createFakeCircuitSetup( myR1cs, flavour=Snarkjs )
let vkey = extractVKey(zkey)
for j in ThreadCounts:
var pool = Taskpool.new(numThreads = j)
defer: pool.shutdown()
for _ in 0 ..< 100:
let proof = generateProof(zkey, myWitness, pool, false)
check verifyProof(vkey, proof)
test "random-mask proofs verify across thread counts (JensGroth)":
let zkey = createFakeCircuitSetup( myR1cs, flavour=JensGroth )
let vkey = extractVKey(zkey)
for j in ThreadCounts:
var pool = Taskpool.new(numThreads = j)
defer: pool.shutdown()
for _ in 0 ..< 100:
let proof = generateProof(zkey, myWitness, pool, false)
check verifyProof(vkey, proof)

View File

@ -1 +1,3 @@
--path:".."
--threads:on
--mm:refc

View File

@ -2,4 +2,5 @@
import ./groth16/testPtCompression
import ./groth16/testCurve
import ./groth16/testProver
import ./groth16/testMultithreading