mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-01-03 22:23:08 +00:00
feat!: pass in task pool instead of instantiating it ourselves
This commit is contained in:
parent
b3d5adf106
commit
8dd99f1964
@ -9,6 +9,8 @@ import std/times
|
||||
import std/options
|
||||
# import strformat
|
||||
|
||||
import taskpools
|
||||
|
||||
import groth16/prover
|
||||
import groth16/verifier
|
||||
import groth16/files/witness
|
||||
@ -203,11 +205,12 @@ proc cliMain(cfg: Config) =
|
||||
else:
|
||||
echo("generating proof...")
|
||||
let print_timings = cfg.measure_time and cfg.verbose
|
||||
var pool = Taskpool.new(cfg.nthreads)
|
||||
withMeasureTime(cfg.measure_time,"proving"):
|
||||
if cfg.no_masking:
|
||||
proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns)
|
||||
proof = generateProofWithTrivialMask(zkey, wtns, pool, print_timings)
|
||||
else:
|
||||
proof = generateProof(cfg.nthreads, print_timings, zkey, wtns)
|
||||
proof = generateProof(zkey, wtns, pool, print_timings)
|
||||
|
||||
if not (cfg.output_file == ""):
|
||||
echo("exporting the proof to " & quoted(cfg.output_file))
|
||||
|
||||
@ -85,7 +85,7 @@ func msmConstantineG2*( coeffs: openArray[Fr[BN254_Snarks]] , points: openArray[
|
||||
|
||||
const task_multiplier : int = 1
|
||||
|
||||
proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1] ): G1 =
|
||||
proc msmMultiThreadedG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1], pool: Taskpool ): G1 =
|
||||
|
||||
# for N <= 255 , we use 1 thread
|
||||
# for N == 256 , we use 2 threads
|
||||
@ -94,11 +94,10 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po
|
||||
|
||||
let N = coeffs.len
|
||||
assert( N == points.len, "incompatible sequence lengths" )
|
||||
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
|
||||
let nthreads_target = min( pool.numThreads, 256 )
|
||||
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
|
||||
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
|
||||
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks)
|
||||
|
||||
var a : int = 0
|
||||
@ -117,22 +116,18 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po
|
||||
for k in 0..<ntasks:
|
||||
res += sync pending[k]
|
||||
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
return res
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2] ): G2 =
|
||||
proc msmMultiThreadedG2*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2], pool: Taskpool ): G2 =
|
||||
|
||||
let N = coeffs.len
|
||||
assert( N == points.len, "incompatible sequence lengths" )
|
||||
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
|
||||
let nthreads_target = min( pool.numThreads, 256 )
|
||||
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
|
||||
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
|
||||
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks)
|
||||
|
||||
var a : int = 0
|
||||
@ -151,9 +146,6 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po
|
||||
for k in 0..<ntasks:
|
||||
res += sync pending[k]
|
||||
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
return res
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -130,7 +130,7 @@ func shiftEvalDomainTask(
|
||||
# computes the quotient polynomial Q = (A*B - C) / Z
|
||||
# by computing the values on a shifted domain, and interpolating the result
|
||||
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
|
||||
proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
|
||||
proc computeQuotientPointwise( abc: ABC, pool: TaskPool ): Poly =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
@ -142,8 +142,6 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
|
||||
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
|
||||
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
|
||||
|
||||
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
|
||||
@ -163,9 +161,6 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
|
||||
let Q1 = polyInverseNTT( ys, D )
|
||||
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
|
||||
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
return Poly(coeffs: cs)
|
||||
|
||||
#---------------------------------------
|
||||
@ -176,15 +171,13 @@ proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
|
||||
# (shifted) Lagrange bases.
|
||||
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr[BN254_Snarks]] =
|
||||
proc computeSnarkjsScalarCoeffs( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
let D = createDomain(n)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
|
||||
var pool = Taskpool.new(num_threads = nthreads)
|
||||
|
||||
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
|
||||
|
||||
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
|
||||
@ -202,9 +195,6 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr[BN254_Snarks]]
|
||||
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
|
||||
|
||||
pool.syncAll()
|
||||
pool.shutdown()
|
||||
|
||||
return ys
|
||||
|
||||
#[
|
||||
@ -239,7 +229,7 @@ type
|
||||
r*: Fr[BN254_Snarks] # masking coefficients
|
||||
s*: Fr[BN254_Snarks] # for zero knowledge
|
||||
|
||||
proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
|
||||
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
|
||||
|
||||
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
|
||||
{.fatal: "Compile with arc/orc!".}
|
||||
@ -278,13 +268,13 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
|
||||
|
||||
# the points H are [delta^-1 * tau^i * Z(tau)]
|
||||
of JensGroth:
|
||||
let polyQ = computeQuotientPointwise( nthreads, abc )
|
||||
let polyQ = computeQuotientPointwise( abc, pool )
|
||||
qs = polyQ.coeffs
|
||||
|
||||
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
|
||||
# where L_i are Lagrange basis polynomials on the double-sized domain
|
||||
of Snarkjs:
|
||||
qs = computeSnarkjsScalarCoeffs( nthreads, abc )
|
||||
qs = computeSnarkjsScalarCoeffs( abc, pool )
|
||||
|
||||
var zs = newSeq[Fr[BN254_Snarks]]( nvars - npubs - 1 )
|
||||
for j in npubs+1..<nvars:
|
||||
@ -306,41 +296,41 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
|
||||
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
|
||||
pi_a = spec.alpha1
|
||||
pi_a += r ** spec.delta1
|
||||
pi_a += msmMultiThreadedG1( nthreads , witness , pts.pointsA1 )
|
||||
pi_a += msmMultiThreadedG1( witness , pts.pointsA1, pool )
|
||||
|
||||
var rho : G1
|
||||
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
|
||||
rho = spec.beta1
|
||||
rho += s ** spec.delta1
|
||||
rho += msmMultiThreadedG1( nthreads , witness , pts.pointsB1 )
|
||||
rho += msmMultiThreadedG1( witness , pts.pointsB1, pool )
|
||||
|
||||
var pi_b : G2
|
||||
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
|
||||
pi_b = spec.beta2
|
||||
pi_b += s ** spec.delta2
|
||||
pi_b += msmMultiThreadedG2( nthreads , witness , pts.pointsB2 )
|
||||
pi_b += msmMultiThreadedG2( witness , pts.pointsB2, pool )
|
||||
|
||||
var pi_c : G1
|
||||
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
|
||||
pi_c = s ** pi_a
|
||||
pi_c += r ** rho
|
||||
pi_c += negFr(r*s) ** spec.delta1
|
||||
pi_c += msmMultiThreadedG1( nthreads, qs , pts.pointsH1 )
|
||||
pi_c += msmMultiThreadedG1( nthreads, zs , pts.pointsC1 )
|
||||
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
|
||||
pi_c += msmMultiThreadedG1( zs , pts.pointsC1, pool )
|
||||
|
||||
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc generateProofWithTrivialMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
|
||||
proc generateProofWithTrivialMask*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool ): Proof =
|
||||
let mask = Mask( r: zeroFr , s: zeroFr )
|
||||
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
|
||||
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
|
||||
|
||||
proc generateProof*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
|
||||
proc generateProof*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings = false ): Proof =
|
||||
|
||||
# masking coeffs
|
||||
let r = randFr()
|
||||
let s = randFr()
|
||||
let mask = Mask(r: r, s: s)
|
||||
|
||||
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
|
||||
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
import std/unittest
|
||||
import std/sequtils
|
||||
|
||||
import taskpools
|
||||
|
||||
import groth16/prover
|
||||
import groth16/verifier
|
||||
import groth16/fake_setup
|
||||
@ -57,9 +59,11 @@ let myWitness =
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc testProof(zkey: ZKey, witness: Witness): bool =
|
||||
let proof = generateProof( 8, false, zkey, witness )
|
||||
var pool = Taskpool.new()
|
||||
let proof = generateProof( zkey, witness, pool )
|
||||
let vkey = extractVKey( zkey)
|
||||
let ok = verifyProof( vkey, proof )
|
||||
pool.shutdown()
|
||||
return ok
|
||||
|
||||
suite "prover":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user