mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-05-18 16:49:30 +00:00
135 lines
4.9 KiB
Nim
135 lines
4.9 KiB
Nim
{.used.}
|
|
|
|
# Multi-threading correctness tests.
|
|
#
|
|
# Two complementary checks:
|
|
#
|
|
# 1. Trivial-mask determinism (r=s=0): proof is a pure deterministic function
|
|
# of (zkey, witness), so sweeping the thread count must produce
|
|
# byte-identical proof points. Catches races that produce *different but
|
|
# still valid* proofs across configurations.
|
|
#
|
|
# 2. Random-mask end-to-end verify: proves with random masking (the
|
|
# production code path) under varied (gc-mode, thread-count) and asserts
|
|
# every resulting proof verifies. Random masks change the MSM coefficient
|
|
# inputs, which exercises the data-dependent (non-constant-time) parts of
|
|
# the MSM where coefficient-magnitude-driven races have historically
|
|
# hidden — invisible under trivial-mask testing.
|
|
|
|
import std/unittest
|
|
import std/sequtils
|
|
|
|
import taskpools
|
|
|
|
import groth16/prover
|
|
import groth16/prover/groth16 as proverImpl
|
|
import groth16/verifier
|
|
import groth16/fake_setup
|
|
import groth16/zkey_types
|
|
import groth16/files/witness
|
|
import groth16/files/r1cs
|
|
import groth16/bn128/fields
|
|
|
|
#-------------------------------------------------------------------------------
|
|
# Same simple multiplication circuit testProver.nim uses: 7*11*13 + 1022 = 2023.
|
|
# Small but exercises the full prover path (4 MSMs + quotient computation).
|
|
|
|
# Wire layout of the test circuit: 8 wires in total, of which one is a
# public output, one a public input and three are private inputs.
const myWitnessCfg = WitnessConfig(
  nWires: 8,
  nPubOut: 1,
  nPubIn: 1,
  nPrivIn: 3,
  nLabels: 0
)
|
|
|
|
# The three constraints of the test circuit, given as sparse lists of
# (wire index, coefficient) pairs.
# NOTE(review): presumably each tuple is (A, B, C) with A*B = C -- confirm
# against the `Constraint` definition. Under that reading, with the witness
# used in this file:
#   myEq1:  0       = -w1 + w2 + w7   (2023 = 1022 + 1001)
#   myEq2:  w3 * w4 = w6              (7 * 11 = 77)
#   myEq3:  w5 * w6 = w7              (13 * 77 = 1001)
const
  myEq1: Constraint =
    (@[], @[], @[(1, minusOneFr), (2, oneFr), (7, oneFr)])
  myEq2: Constraint =
    (@[(3, oneFr)], @[(4, oneFr)], @[(6, oneFr)])
  myEq3: Constraint =
    (@[(5, oneFr)], @[(6, oneFr)], @[(7, oneFr)])

  myConstraints: seq[Constraint] = @[myEq1, myEq2, myEq3]
|
|
|
|
# R1CS description of the test circuit, assembled from the witness
# configuration and the constraint list; no wire labels are attached.
const myR1CS = R1CS(
  r: primeR,
  cfg: myWitnessCfg,
  nConstr: len(myConstraints),
  constraints: myConstraints,
  wireToLabel: @[]
)
|
|
|
|
# Concrete wire values for the circuit (7*11*13 + 1022 = 2023), each
# converted to a field element with `intToFr`.
let myWitnessValues =
  @[1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13].map(intToFr)

# Witness record wrapping the eight wire values above.
let myWitness = Witness(
  curve: "bn128",
  r: primeR,
  nvars: 8,
  values: myWitnessValues
)
|
|
|
|
# Thread counts swept by the tests. The determinism tests take the first
# entry as the reference configuration and compare the others against it.
const ThreadCounts: array[4, int] = [1, 2, 4, 8]
|
|
|
|
#-------------------------------------------------------------------------------
|
|
|
|
proc proveWithThreads(zkey: ZKey, witness: Witness, nThreads: int): Proof =
  ## Generates a trivial-mask proof for `witness` under `zkey` on a freshly
  ## created task pool with `nThreads` threads, and returns that proof.
  ##
  ## The pool lives only for the duration of this call. It is shut down on
  ## every exit path, including when proof generation raises, so a failing
  ## prover run does not leave the pool's threads behind.
  var pool = Taskpool.new(numThreads = nThreads)
  # Registered right after creation so the shutdown also covers an exception
  # escaping from the prover call below.
  defer: pool.shutdown()
  result = generateProofWithTrivialMask(
    zkey, witness, pool, printTimings = false)
|
|
|
|
proc verifyWith(zkey: ZKey, proof: Proof): bool =
  ## Verifies `proof` against the verification key extracted from `zkey`
  ## and returns the verifier's verdict.
  verifyProof(extractVKey(zkey), proof)
|
|
|
|
#-------------------------------------------------------------------------------
|
|
|
|
suite "multithreading":

  test "repeated proofs on the same pool match (no per-call state leak)":
    # One pool serves every proof in this test. Each later proof is compared
    # with the first one, so state surviving between calls would show up as
    # a mismatch.
    let zkey = createFakeCircuitSetup(myR1cs, flavour = Snarkjs)
    var pool = Taskpool.new(numThreads = 4)
    try:
      let baseline =
        generateProofWithTrivialMask(zkey, myWitness, pool, false)
      for _ in 1 .. 4:
        let repeated =
          generateProofWithTrivialMask(zkey, myWitness, pool, false)
        check isEqualProof(baseline, repeated)
    finally:
      pool.shutdown()

  test "trivial-mask proof is deterministic across thread counts (JensGroth)":
    # The proof made with the first thread count is the reference; every
    # other thread count must reproduce it and must also verify.
    let zkey = createFakeCircuitSetup(myR1cs, flavour = JensGroth)
    let expected = proveWithThreads(zkey, myWitness, ThreadCounts[0])
    check verifyWith(zkey, expected)
    for idx in 1 .. ThreadCounts.high:
      let candidate = proveWithThreads(zkey, myWitness, ThreadCounts[idx])
      check:
        isEqualProof(expected, candidate)
        verifyWith(zkey, candidate)

  test "trivial-mask proof is deterministic across thread counts (Snarkjs)":
    # Same sweep as the JensGroth test, on a Snarkjs-flavoured setup.
    let zkey = createFakeCircuitSetup(myR1cs, flavour = Snarkjs)
    let expected = proveWithThreads(zkey, myWitness, ThreadCounts[0])
    check verifyWith(zkey, expected)
    for idx in 1 .. ThreadCounts.high:
      let candidate = proveWithThreads(zkey, myWitness, ThreadCounts[idx])
      check:
        isEqualProof(expected, candidate)
        verifyWith(zkey, candidate)

  test "random-mask proofs verify across thread counts (Snarkjs)":
    # For each thread count: a fresh pool, 100 proofs from `generateProof`,
    # each of which must verify. The pool is shut down before the next
    # thread count is tried.
    let zkey = createFakeCircuitSetup(myR1cs, flavour = Snarkjs)
    let vkey = extractVKey(zkey)
    for nThreads in ThreadCounts:
      var pool = Taskpool.new(numThreads = nThreads)
      try:
        for _ in 1 .. 100:
          let proof = generateProof(zkey, myWitness, pool, false)
          check verifyProof(vkey, proof)
      finally:
        pool.shutdown()

  test "random-mask proofs verify across thread counts (JensGroth)":
    # Same sweep as the Snarkjs test, on a JensGroth-flavoured setup.
    let zkey = createFakeCircuitSetup(myR1cs, flavour = JensGroth)
    let vkey = extractVKey(zkey)
    for nThreads in ThreadCounts:
      var pool = Taskpool.new(numThreads = nThreads)
      try:
        for _ in 1 .. 100:
          let proof = generateProof(zkey, myWitness, pool, false)
          check verifyProof(vkey, proof)
      finally:
        pool.shutdown()
|
|
|
|
|