feat!: pass in task pool instead of instantiating it ourselves

This commit is contained in:
Mark Spanbroek 2025-07-02 16:15:19 +02:00
parent b3d5adf106
commit 8dd99f1964
4 changed files with 28 additions and 39 deletions

View File

@ -9,6 +9,8 @@ import std/times
import std/options
# import strformat
import taskpools
import groth16/prover
import groth16/verifier
import groth16/files/witness
@ -203,11 +205,12 @@ proc cliMain(cfg: Config) =
else:
echo("generating proof...")
let print_timings = cfg.measure_time and cfg.verbose
var pool = Taskpool.new(cfg.nthreads)
withMeasureTime(cfg.measure_time,"proving"):
if cfg.no_masking:
proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns)
proof = generateProofWithTrivialMask(zkey, wtns, pool, print_timings)
else:
proof = generateProof(cfg.nthreads, print_timings, zkey, wtns)
proof = generateProof(zkey, wtns, pool, print_timings)
if not (cfg.output_file == ""):
echo("exporting the proof to " & quoted(cfg.output_file))

View File

@ -85,7 +85,7 @@ func msmConstantineG2*( coeffs: openArray[Fr[BN254_Snarks]] , points: openArray[
const task_multiplier : int = 1
proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1] ): G1 =
proc msmMultiThreadedG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1], pool: Taskpool ): G1 =
# for N <= 255 , we use 1 thread
# for N == 256 , we use 2 threads
@ -94,11 +94,10 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
let nthreads_target = min( pool.numThreads, 256 )
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks)
var a : int = 0
@ -117,22 +116,18 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po
for k in 0..<ntasks:
res += sync pending[k]
pool.syncAll()
pool.shutdown()
return res
#---------------------------------------
proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2] ): G2 =
proc msmMultiThreadedG2*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2], pool: Taskpool ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
let nthreads_target = min( pool.numThreads, 256 )
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks)
var a : int = 0
@ -151,9 +146,6 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po
for k in 0..<ntasks:
res += sync pending[k]
pool.syncAll()
pool.shutdown()
return res
#-------------------------------------------------------------------------------

View File

@ -130,7 +130,7 @@ func shiftEvalDomainTask(
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
proc computeQuotientPointwise( abc: ABC, pool: TaskPool ): Poly =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
@ -142,8 +142,6 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
var pool = Taskpool.new(num_threads = nthreads)
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
@ -163,9 +161,6 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
pool.syncAll()
pool.shutdown()
return Poly(coeffs: cs)
#---------------------------------------
@ -176,15 +171,13 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr[BN254_Snarks]] =
proc computeSnarkjsScalarCoeffs( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
var pool = Taskpool.new(num_threads = nthreads)
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
@ -202,9 +195,6 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr[BN254_Snarks]]
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
pool.syncAll()
pool.shutdown()
return ys
#[
@ -239,7 +229,7 @@ type
r*: Fr[BN254_Snarks] # masking coefficients
s*: Fr[BN254_Snarks] # for zero knowledge
proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
@ -278,13 +268,13 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( nthreads, abc )
let polyQ = computeQuotientPointwise( abc, pool )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( nthreads, abc )
qs = computeSnarkjsScalarCoeffs( abc, pool )
var zs = newSeq[Fr[BN254_Snarks]]( nvars - npubs - 1 )
for j in npubs+1..<nvars:
@ -306,41 +296,41 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmMultiThreadedG1( nthreads , witness , pts.pointsA1 )
pi_a += msmMultiThreadedG1( witness , pts.pointsA1, pool )
var rho : G1
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
rho = spec.beta1
rho += s ** spec.delta1
rho += msmMultiThreadedG1( nthreads , witness , pts.pointsB1 )
rho += msmMultiThreadedG1( witness , pts.pointsB1, pool )
var pi_b : G2
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmMultiThreadedG2( nthreads , witness , pts.pointsB2 )
pi_b += msmMultiThreadedG2( witness , pts.pointsB2, pool )
var pi_c : G1
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmMultiThreadedG1( nthreads, qs , pts.pointsH1 )
pi_c += msmMultiThreadedG1( nthreads, zs , pts.pointsC1 )
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
pi_c += msmMultiThreadedG1( zs , pts.pointsC1, pool )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------
proc generateProofWithTrivialMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
proc generateProofWithTrivialMask*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool ): Proof =
let mask = Mask( r: zeroFr , s: zeroFr )
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
proc generateProof*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
proc generateProof*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings = false ): Proof =
# masking coeffs
let r = randFr()
let s = randFr()
let mask = Mask(r: r, s: s)
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )

View File

@ -2,6 +2,8 @@
import std/unittest
import std/sequtils
import taskpools
import groth16/prover
import groth16/verifier
import groth16/fake_setup
@ -57,9 +59,11 @@ let myWitness =
#-------------------------------------------------------------------------------
proc testProof(zkey: ZKey, witness: Witness): bool =
let proof = generateProof( 8, false, zkey, witness )
var pool = Taskpool.new()
let proof = generateProof( zkey, witness, pool )
let vkey = extractVKey( zkey)
let ok = verifyProof( vkey, proof )
pool.shutdown()
return ok
suite "prover":