diff --git a/cli/cli_main.nim b/cli/cli_main.nim index b6a5a0e..b126503 100644 --- a/cli/cli_main.nim +++ b/cli/cli_main.nim @@ -9,6 +9,8 @@ import std/times import std/options # import strformat +import taskpools + import groth16/prover import groth16/verifier import groth16/files/witness @@ -203,11 +205,12 @@ proc cliMain(cfg: Config) = else: echo("generating proof...") let print_timings = cfg.measure_time and cfg.verbose + var pool = Taskpool.new(cfg.nthreads) withMeasureTime(cfg.measure_time,"proving"): if cfg.no_masking: - proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns) + proof = generateProofWithTrivialMask(zkey, wtns, pool, print_timings) else: - proof = generateProof(cfg.nthreads, print_timings, zkey, wtns) + proof = generateProof(zkey, wtns, pool, print_timings) if not (cfg.output_file == ""): echo("exporting the proof to " & quoted(cfg.output_file)) diff --git a/groth16/bn128/msm.nim b/groth16/bn128/msm.nim index f68fb9b..b4b4005 100644 --- a/groth16/bn128/msm.nim +++ b/groth16/bn128/msm.nim @@ -85,7 +85,7 @@ func msmConstantineG2*( coeffs: openArray[Fr[BN254_Snarks]] , points: openArray[ const task_multiplier : int = 1 -proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1] ): G1 = +proc msmMultiThreadedG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1], pool: Taskpool ): G1 = # for N <= 255 , we use 1 thread # for N == 256 , we use 2 threads @@ -94,11 +94,10 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po let N = coeffs.len assert( N == points.len, "incompatible sequence lengths" ) - let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 ) + let nthreads_target = min( pool.numThreads, 256 ) let nthreads = max( 1 , min( N div 128 , nthreads_target ) ) let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1 - var pool = Taskpool.new(num_threads = nthreads) var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks) var a : int = 0 @@ -117,22 +116,18 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po for k in 0..1: (nthreads*task_multiplier) else: 1 - var pool = Taskpool.new(num_threads = nthreads) var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks) var a : int = 0 @@ -151,9 +146,6 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr[BN254_Snarks]] , po for k in 0.. # -proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr[BN254_Snarks]] = +proc computeSnarkjsScalarCoeffs( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] = let n = abc.valuesAz.len assert( abc.valuesBz.len == n ) assert( abc.valuesCz.len == n ) let D = createDomain(n) let eta = createDomain(2*n).domainGen - var pool = Taskpool.new(num_threads = nthreads) - var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]] var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 ) @@ -202,9 +195,6 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr[BN254_Snarks]] var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n ) for j in 0..