From d790dc3162b828f87abc3ccf3eb181616d5047e9 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Thu, 29 Feb 2024 20:35:16 +0100 Subject: [PATCH] do the three FFT/IFFT pairs in parallel --- cli/cli_main.nim | 4 ++++ groth16/bn128/msm.nim | 5 ---- groth16/prover.nim | 54 +++++++++++++++++++++++++++++++++---------- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/cli/cli_main.nim b/cli/cli_main.nim index 2287a40..b6a5a0e 100644 --- a/cli/cli_main.nim +++ b/cli/cli_main.nim @@ -3,6 +3,7 @@ import sugar import std/strutils import std/sequtils import std/os +import std/cpuinfo import std/parseopt import std/times import std/options @@ -128,6 +129,9 @@ proc parseCliOptions(): Config = printHelp() quit() + if cfg.nthreads <= 0: + cfg.nthreads = countProcessors() + return cfg #------------------------------------------------------------------------------- diff --git a/groth16/bn128/msm.nim b/groth16/bn128/msm.nim index 135e971..d98b949 100644 --- a/groth16/bn128/msm.nim +++ b/groth16/bn128/msm.nim @@ -99,8 +99,6 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1] let nthreads = max( 1 , min( N div 128 , nthreads_target ) ) let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1 - # echo("msm with #threads = " & $nthreads) - var pool = Taskpool.new(num_threads = nthreads) var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks) @@ -142,12 +140,9 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2] let nthreads = max( 1 , min( N div 128 , nthreads_target ) ) let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1 - # echo("G2 msm with #threads = " & $nthreads) - var pool = Taskpool.new(num_threads = nthreads) var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks) - # nim is just batshit crazy... GC_ref(coeffs) GC_ref(points) diff --git a/groth16/prover.nim b/groth16/prover.nim index 3e7f393..0f226fb 100644 --- a/groth16/prover.nim +++ b/groth16/prover.nim @@ -17,6 +17,9 @@ import ./zkey import std/os import std/times +import std/cpuinfo +import system +import taskpools import constantine/math/arithmetic except Fp, Fr #import constantine/math/io/io_extfields except Fp12 @@ -112,7 +115,7 @@ func shiftEvalDomain( values: seq[Fr], D: Domain, eta: Fr ): seq[Fr] = # computes the quotient polynomial Q = (A*B - C) / Z # by computing the values on a shifted domain, and interpolating the result # remark: Q has degree `n-2`, so it's enough to use a domain of size n -func computeQuotientPointwise( abc: ABC ): Poly = +proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly = let n = abc.valuesAz.len assert( abc.valuesBz.len == n ) assert( abc.valuesCz.len == n ) @@ -124,15 +127,27 @@ func computeQuotientPointwise( abc: ABC ): Poly = let eta = createDomain(2*n).domainGen let invZ1 = invFr( smallPowFr(eta,n) - oneFr ) - let A1 = shiftEvalDomain( abc.valuesAz, D, eta ) - let B1 = shiftEvalDomain( abc.valuesBz, D, eta ) - let C1 = shiftEvalDomain( abc.valuesCz, D, eta ) + var pool = Taskpool.new(num_threads = nthreads) + GCref(abc.valuesAz) + GCref(abc.valuesBz) + GCref(abc.valuesCz) + var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta ) + var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta ) + var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta ) + pool.syncAll() + GCunref(abc.valuesAz) + GCunref(abc.valuesBz) + GCunref(abc.valuesCz) + let A1 = sync A1fv + let B1 = sync B1fv + let C1 = sync C1fv var ys : seq[Fr] = newSeq[Fr]( n ) for j in 0.. # -func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] = +proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] = let n = abc.valuesAz.len assert( abc.valuesBz.len == n ) assert( abc.valuesCz.len == n ) let D = createDomain(n) let eta = createDomain(2*n).domainGen - let A1 = shiftEvalDomain( abc.valuesAz, D, eta ) - let B1 = shiftEvalDomain( abc.valuesBz, D, eta ) - let C1 = shiftEvalDomain( abc.valuesCz, D, eta ) + + var pool = Taskpool.new(num_threads = nthreads) + GCref(abc.valuesAz) + GCref(abc.valuesBz) + GCref(abc.valuesCz) + var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta ) + var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta ) + var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta ) + pool.syncAll() + GCunref(abc.valuesAz) + GCunref(abc.valuesBz) + GCunref(abc.valuesCz) + let A1 = sync A1fv + let B1 = sync B1fv + let C1 = sync C1fv + var ys : seq[Fr] = newSeq[Fr]( n ) - for j in 0..