diff --git a/cli/cli_main.nim b/cli/cli_main.nim index 2f69cf8..2287a40 100644 --- a/cli/cli_main.nim +++ b/cli/cli_main.nim @@ -28,6 +28,7 @@ proc printHelp() = echo " -h, --help : print this help" echo " -v, --verbose : verbose output" echo " -d, --debug : debug output" + echo " -j, --nthreads : number of CPU threads" echo " -t, --time : print time measurements" echo " -p, --prove : create a proof" echo " -y, --verify : verify a proof" @@ -54,6 +55,7 @@ type Config = object do_verify: bool do_setup: bool no_masking: bool + nthreads: int const dummyConfig = Config( zkey_file: "" @@ -67,6 +69,7 @@ const dummyConfig = , do_verify: false , do_setup: false , no_masking: false + , nthreads: 0 ) proc printConfig(cfg: Config) = @@ -102,6 +105,7 @@ proc parseCliOptions(): Config = of "h", "help" : printHelp() of "v", "verbose" : cfg.verbose = true of "d", "debug" : cfg.debug = true + of "j", "nthreads" : cfg.nthreads = parseInt(value) of "t", "time" : cfg.measure_time = true of "p", "prove" : cfg.do_prove = true of "y", "verify" : cfg.do_verify = true @@ -160,25 +164,19 @@ proc cliMain(cfg: Config) = if not (cfg.wtns_file == ""): echo("\nparsing witness file " & quoted(cfg.wtns_file)) - let start = cpuTime() - wtns = parseWitness(cfg.wtns_file) - let elapsed = cpuTime() - start - if cfg.measure_time: echo("parsing the witness took ",seconds(elapsed)) - + withMeasureTime(cfg.measure_time,"parsing the witness"): + wtns = parseWitness(cfg.wtns_file) + if not (cfg.zkey_file == ""): echo("\nparsing zkey file " & quoted(cfg.zkey_file)) - let start = cpuTime() - zkey = parseZKey(cfg.zkey_file) - let elapsed = cpuTime() - start - if cfg.measure_time: echo("parsing the zkey took ",seconds(elapsed)) - + withMeasureTime(cfg.measure_time,"parsing the zkey"): + zkey = parseZKey(cfg.zkey_file) + if not (cfg.r1cs_file == ""): echo("\nparsing r1cs file " & quoted(cfg.r1cs_file)) - let start = cpuTime() - r1cs = parseR1CS(cfg.r1cs_file) - let elapsed = cpuTime() - start - if cfg.measure_time: echo("parsing the r1cs took ",seconds(elapsed)) - + withMeasureTime(cfg.measure_time,"parsing the r1cs"): + r1cs = parseR1CS(cfg.r1cs_file) + if cfg.do_setup: if not (cfg.zkey_file == ""): echo("\nwe are doing a fake trusted setup, don't specify the zkey file!") @@ -187,11 +185,9 @@ proc cliMain(cfg: Config) = echo("\nerror: r1cs file is required for the fake setup!") quit() echo("\nperforming fake trusted setup...") - let start = cpuTime() - zkey = createFakeCircuitSetup( r1cs, flavour=Snarkjs ) - let elapsed = cpuTime() - start - if cfg.measure_time: echo("fake setup took ",seconds(elapsed)) - + withMeasureTime(cfg.measure_time,"fake setup"): + zkey = createFakeCircuitSetup( r1cs, flavour=Snarkjs ) + if cfg.debug: printGrothHeader(zkey.header) # debugPrintCoeffs(zkey.coeffs) @@ -202,14 +198,13 @@ proc cliMain(cfg: Config) = quit() else: echo("generating proof...") - let start = cpuTime() let print_timings = cfg.measure_time and cfg.verbose - if cfg.no_masking: - proof = generateProofWithTrivialMask(print_timings, zkey, wtns) - else: - proof = generateProof(print_timings, zkey, wtns) - let elapsed = cpuTime() - start - if cfg.measure_time: echo("proving took ",seconds(elapsed)) + withMeasureTime(cfg.measure_time,"proving"): + if cfg.no_masking: + proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns) + else: + proof = generateProof(cfg.nthreads, print_timings, zkey, wtns) + if not (cfg.output_file == ""): echo("exporting the proof to " & quoted(cfg.output_file)) exportProof( cfg.output_file, proof ) @@ -224,12 +219,11 @@ proc cliMain(cfg: Config) = else: let vkey = extractVKey( zkey) echo("\nverifying the proof...") - let start = cpuTime() - let ok = verifyProof( vkey, proof ) - let elapsed = cpuTime() - start - echo("verification succeeded = ",ok) - if cfg.measure_time: echo("verifying took ",seconds(elapsed)) - + var ok : bool + withMeasureTime(cfg.measure_time,"verifying"): + ok = verifyProof( vkey, proof ) + echo("verification succeeded = ",ok) + echo("") #------------------------------------------------------------------------------- diff --git a/cli/nim.cfg b/cli/nim.cfg index 0f840a1..b20433f 100644 --- a/cli/nim.cfg +++ b/cli/nim.cfg @@ -1 +1,2 @@ --path:".." +--threads:on \ No newline at end of file diff --git a/groth16.nimble b/groth16.nimble index c21656d..a6566ea 100644 --- a/groth16.nimble +++ b/groth16.nimble @@ -7,4 +7,5 @@ skipDirs = @["groth16/example"] binDir = "build" namedBin = {"cli/cli_main": "nim-groth16"}.toTable() +requires "https://github.com/status-im/nim-taskpools" requires "https://github.com/mratsim/constantine#5f7ba18f2ed351260015397c9eae079a6decaee1" \ No newline at end of file diff --git a/groth16/bn128/msm.nim b/groth16/bn128/msm.nim index fa32cbf..135e971 100644 --- a/groth16/bn128/msm.nim +++ b/groth16/bn128/msm.nim @@ -4,6 +4,8 @@ # import system +import std/cpuinfo +import taskpools # import constantine/curves_primitives except Fp, Fp2, Fr @@ -23,11 +25,16 @@ import constantine/math/elliptic/ec_scalar_mul_vartime as scl except Su import constantine/math/elliptic/ec_multi_scalar_mul as msm except Subgroup import groth16/bn128/fields -import groth16/bn128/curves +import groth16/bn128/curves as mycurves + +import groth16/misc # TEMP DEBUGGING +import std/times #------------------------------------------------------------------------------- -func msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 = +proc msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 = + + # let start = cpuTime() let N = coeffs.len assert( N == points.len, "incompatible sequence lengths" ) @@ -46,6 +53,9 @@ func msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 = var rAff: G1 prj.affine(rAff, r) + # let elapsed = cpuTime() - start + # echo("computing an MSM of size " & ($N) & " took " & seconds(elapsed)) + return rAff #--------------------------------------- @@ -74,41 +84,96 @@ func msmConstantineG2*( coeffs: openArray[Fr] , points: openArray[G2] ): G2 = #------------------------------------------------------------------------------- -#[ -type InputTuple = tuple[idx:int, coeffs: openArray[Fr] , points: openArray[G1]] +const task_multiplier : int = 1 -func msmMultiThreadedG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 = - let N = coeffs.len - assert( N == points.len, "incompatible sequence lengths" ) - - let nthreadsTarget = 8 +proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1] ): G1 = # for N <= 255 , we use 1 thread # for N == 256 , we use 2 threads # for N == 512 , we use 4 threads - # for N >= 1024, we use 8 threads - let nthreads = max( 1 , min( N div 128 , nthreadsTarget ) ) + # for N >= 1024, we use 8+ threads - let m = N div nthreads + let N = coeffs.len + assert( N == points.len, "incompatible sequence lengths" ) + let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 ) + let nthreads = max( 1 , min( N div 128 , nthreads_target ) ) + let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1 - var threads : seq[Thread[InputTuple]] = newSeq[Thread[InputTuple]]( nthreads ) - var results : seq[G1] = newSeq[G1]( nthreads ) + # echo("msm with #threads = " & $nthreads) - proc myThreadFunc( inp: InputTuple ) {.thread.} = - results[inp.idx] = msmConstantineG1( inp.coeffs, inp.points ) + var pool = Taskpool.new(num_threads = nthreads) + var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks) - for i in 0..1: (nthreads*task_multiplier) else: 1 + + # echo("G2 msm with #threads = " & $nthreads) + + var pool = Taskpool.new(num_threads = nthreads) + var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks) + + # nim is just batshit crazy... + GC_ref(coeffs) + GC_ref(points) + + var a : int = 0 + var b : int + for k in 0..