diff --git a/README.md b/README.md index 8a97122..e16c8c1 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ at your choice. - [x] find and fix the _second_ totally surreal bug - [ ] clean up the code -- [ ] make it compatible with the latest constantine and also Nim 2.0.x +- [x] make it compatible with the latest constantine and also Nim 2.0.x - [x] make it a nimble package - [ ] compare `.r1cs` to the "coeffs" section of `.zkey` - [x] generate fake circuit-specific setup ourselves @@ -29,5 +29,5 @@ at your choice. - [x] multithreading support (MSM, and possibly also FFT) - [ ] add Groth16 notes - [ ] document the `snarkjs` circuit-specific setup `H` points convention +- [x] precalculate stuff for "partial" proofs - [ ] make it work for different curves - diff --git a/cli/cli_main.nim b/cli/cli_main.nim index b126503..e1dc300 100644 --- a/cli/cli_main.nim +++ b/cli/cli_main.nim @@ -1,13 +1,8 @@ -import sugar import std/strutils -import std/sequtils -import std/os import std/cpuinfo import std/parseopt import std/times -import std/options -# import strformat import taskpools @@ -21,6 +16,8 @@ import groth16/zkey_types import groth16/fake_setup import groth16/misc +import testing + #------------------------------------------------------------------------------- proc printHelp() = @@ -37,6 +34,7 @@ proc printHelp() = echo " -y, --verify : verify a proof" echo " -u, --setup : perform (fake) trusted setup" echo " -n, --nomask : don't use random masking for full ZK" + echo " -s, --sanity : sanity test the partial prover" echo " -z, --zkey = : the `.zkey` file" echo " -w, --wtns = : the `.wtns` file" echo " -r, --r1cs = : the `.r1cs` file" @@ -46,33 +44,35 @@ proc printHelp() = #------------------------------------------------------------------------------- type Config = object - zkey_file: string - r1cs_file: string - wtns_file: string - output_file: string - io_file: string - verbose: bool - debug: bool - measure_time: bool - do_prove: bool - do_verify: bool - do_setup: bool - no_masking: bool - nthreads: int + zkey_file: string + r1cs_file: string + wtns_file: string + output_file: string + io_file: string + verbose: bool + debug: bool + measure_time: bool + do_prove: bool + do_verify: bool + do_setup: bool + no_masking: bool + partial_sanity: bool + nthreads: int const dummyConfig = - Config( zkey_file: "" - , r1cs_file: "" - , wtns_file: "" - , output_file: "" - , io_file: "" - , verbose: false - , measure_time: false - , do_prove: false - , do_verify: false - , do_setup: false - , no_masking: false - , nthreads: 0 + Config( zkey_file: "" + , r1cs_file: "" + , wtns_file: "" + , output_file: "" + , io_file: "" + , verbose: false + , measure_time: false + , do_prove: false + , do_verify: false + , do_setup: false + , no_masking: false + , partial_sanity: false + , nthreads: 0 ) proc printConfig(cfg: Config) = @@ -114,6 +114,7 @@ proc parseCliOptions(): Config = of "y", "verify" : cfg.do_verify = true of "u", "setup" : cfg.do_setup = true of "n", "nomask" : cfg.no_masking = true + of "s", "sanity" : cfg.partial_sanity = true of "o", "output" : cfg.output_file = value of "r", "r1cs" : cfg.r1cs_file = value of "z", "zkey" : cfg.zkey_file = value @@ -136,28 +137,6 @@ proc parseCliOptions(): Config = return cfg -#------------------------------------------------------------------------------- - -#[ -proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) = - - echo("parsing witness & zkey files...") - let witness = parseWitness( wtns_fname) - let zkey = parseZKey( zkey_fname) - - echo("generating proof...") - let start = cpuTime() - let proof = generateProof( zkey, witness ) - let elapsed = cpuTime() - start - echo("proving took ",seconds(elapsed)) - - echo("verifying the proof...") - let vkey = extractVKey( zkey) - let ok = verifyProof( vkey, proof ) - echo("verification succeeded = ",ok) - - return (vkey,proof) -]# #------------------------------------------------------------------------------- @@ -198,6 +177,16 @@ proc cliMain(cfg: Config) = printGrothHeader(zkey.header) # debugPrintCoeffs(zkey.coeffs) + if cfg.partial_sanity: + if (cfg.wtns_file=="") or (cfg.zkey_file=="" and cfg.do_setup==false): + echo("cannot prove: missing witness and/or zkey file!") + quit() + else: + var pool = Taskpool.new(cfg.nthreads) + let print_timings = cfg.measure_time and cfg.verbose + sanityCheckPartialProofs(zkey,wtns,pool,print_timings) + echo("sanity testing partial proofs...") + if cfg.do_prove: if (cfg.wtns_file=="") or (cfg.zkey_file=="" and cfg.do_setup==false): echo("cannot prove: missing witness and/or zkey file!") diff --git a/cli/testing.nim b/cli/testing.nim new file mode 100644 index 0000000..bb30de2 --- /dev/null +++ b/cli/testing.nim @@ -0,0 +1,102 @@ + +import std/strutils +import std/times +import std/options +import std/random +import std/syncio + +import taskpools + +import constantine/named/properties_fields + +# import groth16/bn128 +import groth16/zkey_types +import groth16/files/witness +import groth16/misc +import groth16/files/export_json + +import groth16/partial/types +import groth16/partial/precalc +import groth16/partial/finish + +import groth16/prover +import groth16/prover/shared +import groth16/verifier + +#------------------------------------------------------------------------------- + +#[ +proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) = + + echo("parsing witness & zkey files...") + let witness = parseWitness( wtns_fname) + let zkey = parseZKey( zkey_fname) + + echo("generating proof...") + let start = cpuTime() + let proof = generateProof( zkey, witness ) + let elapsed = cpuTime() - start + echo("proving took ",seconds(elapsed)) + + echo("verifying the proof...") + let vkey = extractVKey( zkey) + let ok = verifyProof( vkey, proof ) + echo("verification succeeded = ",ok) + + return (vkey,proof) +]# + +#------------------------------------------------------------------------------- + +proc sanityCheckPartialProofs*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool) = + + let witness = wtns.values + let M = witness.len + + var partial_mask: seq[bool] = newSeq[bool]( M ) + var partial_witness: seq[Option[Fr[BN254_Snarks]]] = newSeq[Option[Fr[BN254_Snarks]]]( M ) + + # generate randomized partial witness + partial_mask[0] = true + partial_witness[0] = some(witness[0]) + var count = 0 + for i in 1..= startIdx: + compact_zs[k] = witness[i] + compact_pointsC1[k] = pts.pointsC1[ i - startIdx ] + k += 1 + + start = cpuTime() + var abc : ABC + withMeasureTime(printTimings,"building 'ABC'"): + abc = buildABC( zkey, witness ) + + start = cpuTime() + var qs : seq[Fr[BN254_Snarks]] + withMeasureTime(printTimings,"computing the quotient (FFTs)"): + case zkey.header.flavour + + # the points H are [delta^-1 * tau^i * Z(tau)] + of JensGroth: + let polyQ = computeQuotientPointwise( abc, pool ) + qs = polyQ.coeffs + + # the points H are `[delta^-1 * L_{2i+1}(tau)]_1` + # where L_i are Lagrange basis polynomials on the double-sized domain + of Snarkjs: + qs = computeSnarkjsScalarCoeffs( abc, pool ) + + # masking coeffs + let r = mask.r + let s = mask.s + + assert( witness.len == pts.pointsA1.len ) + assert( witness.len == pts.pointsB1.len ) + assert( witness.len == pts.pointsB2.len ) + assert( hdr.domainSize == qs.len ) + assert( hdr.domainSize == pts.pointsH1.len ) + assert( nvars - npubs - 1 == pts.pointsC1.len ) + + var pi_a : G1 = partialProof.partial_pi_a + withMeasureTime(printTimings,"computing pi_A (G1 MSM)"): + pi_a += r ** spec.delta1 + pi_a += msmMultiThreadedG1( compact_witness , compact_pointsA1, pool ) + + var rho : G1 = partialProof.partial_rho + withMeasureTime(printTimings,"computing rho (G1 MSM)"): + rho += s ** spec.delta1 + rho += msmMultiThreadedG1( compact_witness , compact_pointsB1, pool ) + + var pi_b : G2 = partialProof.partial_pi_b + withMeasureTime(printTimings,"computing pi_B (G2 MSM)"): + pi_b += s ** spec.delta2 + pi_b += msmMultiThreadedG2( compact_witness , compact_pointsB2, pool ) + + var pi_c : G1 = partialProof.partial_pi_c + withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"): + pi_c += s ** pi_a + pi_c += r ** rho + pi_c += negFr(r*s) ** spec.delta1 + pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool ) + pi_c += msmMultiThreadedG1( compact_zs , compact_pointsC1, pool ) + + return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c ) + +#------------------------------------------------------------------------------- + +proc finishPartialProofWithTrivialMask*( zkey: ZKey, wtns: Witness, partial: PartialProof, pool: Taskpool, printTimings: bool ): Proof = + let mask = Mask( r: zeroFr , s: zeroFr ) + return finishPartialProofWithMask( zkey, wtns, partial, mask, pool, printTimings ) + +proc finishPartialProof*( zkey: ZKey, wtns: Witness, partial: PartialProof, pool: Taskpool, printTimings = false ): Proof = + let mask = randomMask() + return finishPartialProofWithMask( zkey, wtns, partial, mask, pool, printTimings ) + +#------------------------------------------------------------------------------- diff --git a/groth16/partial/precalc.nim b/groth16/partial/precalc.nim new file mode 100644 index 0000000..b0f8d68 --- /dev/null +++ b/groth16/partial/precalc.nim @@ -0,0 +1,152 @@ + +# +# "a partial prover": precalculate stuff based on a partial witness +# +# this is useful when a significant portion of the witness does not changes +# between proofs - then such precalculation can result in a potentially big speedup. +# +# an example use is RLN proofs, where the circuit is dominated by the Merkle +#vinclusion proof check, which is expected to change relatively rarely. +# + +{.push raises:[].} + +import std/times +import std/options +import system + +import taskpools + +import groth16/bn128 +import groth16/zkey_types +import groth16/misc + +import groth16/partial/types + +#import groth16/math/domain +#import groth16/math/poly +#import groth16/prover/shared + +#------------------------------------------------------------------------------- +# the prover +# + +proc generatePartialProof*( zkey: ZKey, pwtns: PartialWitness, pool: Taskpool, printTimings: bool): PartialProof = + + when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)): + {.fatal: "Compile with arc/orc!".} + + # assert( zkey.header.curve == wtns.curve ) + var start : float = 0 + + let partial_witness = pwtns.values + + let hdr : GrothHeader = zkey.header + let spec : SpecPoints = zkey.specPoints + let pts : ProverPoints = zkey.pPoints + + let nvars = hdr.nvars + let npubs = hdr.npubs + + assert( nvars == partial_witness.len , "wrong witness length" ) + + assert( partial_witness.len == pts.pointsA1.len ) + assert( partial_witness.len == pts.pointsB1.len ) + assert( partial_witness.len == pts.pointsB2.len ) + + # note: in "zs" we have to ignore public inputs (plus 1 for the special first entry) + var partial_mask : seq[bool] = newSeq[bool](nvars) + for i in 0..= startIdx: + compact_zs[k] = wtns_value + compact_pointsC1[k] = pts.pointsC1[ i - startIdx ] + k += 1 + +#[ + start = cpuTime() + var abc : ABC + withMeasureTime(printTimings,"building 'ABC'"): + abc = buildABC( zkey, witness ) + + start = cpuTime() + var qs : seq[Fr[BN254_Snarks]] + withMeasureTime(printTimings,"computing the quotient (FFTs)"): + case zkey.header.flavour + + # the points H are [delta^-1 * tau^i * Z(tau)] + of JensGroth: + let polyQ = computeQuotientPointwise( abc, pool ) + qs = polyQ.coeffs + + # the points H are `[delta^-1 * L_{2i+1}(tau)]_1` + # where L_i are Lagrange basis polynomials on the double-sized domain + of Snarkjs: + qs = computeSnarkjsScalarCoeffs( abc, pool ) + +]# + + assert( hdr.domainSize == pts.pointsH1.len ) + assert( nvars - npubs - 1 == pts.pointsC1.len ) + + var pi_a : G1 + withMeasureTime(printTimings,"computing partial pi_A (G1 MSM)"): + pi_a = spec.alpha1 + pi_a += msmMultiThreadedG1( compact_witness , compact_pointsA1 , pool ) + + var rho : G1 + withMeasureTime(printTimings,"computing partial rho (G1 MSM)"): + rho = spec.beta1 + rho += msmMultiThreadedG1( compact_witness , compact_pointsB1 , pool ) + + var pi_b : G2 + withMeasureTime(printTimings,"computing partial pi_B (G2 MSM)"): + pi_b = spec.beta2 + pi_b += msmMultiThreadedG2( compact_witness , compact_pointsB2 , pool ) + + var pi_c : G1 + withMeasureTime(printTimings,"partial computing pi_C (G1 MSM)"): + pi_c = msmMultiThreadedG1( compact_zs , compact_pointsC1 , pool ) + + return PartialProof( + partial_mask : partial_mask + , partial_pi_a : pi_a + , partial_rho : rho + , partial_pi_b : pi_b + , partial_pi_c : pi_c + ) + +#------------------------------------------------------------------------------- + diff --git a/groth16/partial/types.nim b/groth16/partial/types.nim new file mode 100644 index 0000000..2c7b24c --- /dev/null +++ b/groth16/partial/types.nim @@ -0,0 +1,29 @@ + +{.push raises:[].} + +import std/options + +import constantine/named/properties_fields + +import groth16/bn128 +import groth16/prover/types + +export types + +#------------------------------------------------------------------------------- + +type + + F* = Fr[BN254_Snarks] + + PartialWitness* = object + values* : seq[Option[F]] + + PartialProof* = object + partial_mask* : seq[bool] + partial_pi_a* : G1 # = [alpha]_1 + sum z_j*[A_j]_1 + partial_rho* : G1 # = [beta]_1 + sum z_j*[B_j]_1 + partial_pi_b* : G2 # = [beta]_2 + sum z_j*[B_j]_2 + partial_pi_c* : G1 # = sum z_j*[K_j]_1 + +#------------------------------------------------------------------------------- diff --git a/groth16/prover.nim b/groth16/prover.nim index 8f1bb89..ef914fa 100644 --- a/groth16/prover.nim +++ b/groth16/prover.nim @@ -7,332 +7,9 @@ # See # -{.push raises:[].} +import groth16/prover/types +import groth16/prover/groth16 -#[ -import sugar -import constantine/math/config/curves -import constantine/math/io/io_fields -import constantine/math/io/io_bigints -import ./zkey -]# +export types +export groth16 -import std/os -import std/times -import std/cpuinfo -import system -import taskpools -import constantine/math/arithmetic -import constantine/named/properties_fields -import constantine/math/extension_fields/towers - -#import constantine/math/io/io_extfields except Fp12 - -import groth16/bn128 -import groth16/math/domain -import groth16/math/poly -import groth16/zkey_types -import groth16/files/witness -import groth16/misc - -#------------------------------------------------------------------------------- - -type - Proof* = object - publicIO* : seq[Fr[BN254_Snarks]] - pi_a* : G1 - pi_b* : G2 - pi_c* : G1 - curve* : string - -#------------------------------------------------------------------------------- -# Az, Bz, Cz column vectors -# - -type - ABC = object - valuesAz : seq[Fr[BN254_Snarks]] - valuesBz : seq[Fr[BN254_Snarks]] - valuesCz : seq[Fr[BN254_Snarks]] - -# computes the vectors A*z, B*z, C*z where z is the witness -func buildABC( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC = - let hdr: GrothHeader = zkey.header - let domSize = hdr.domainSize - - var valuesAz = newSeq[Fr[BN254_Snarks]](domSize) - var valuesBz = newSeq[Fr[BN254_Snarks]](domSize) - - for entry in zkey.coeffs: - case entry.matrix - of MatrixA: valuesAz[entry.row] += entry.coeff * witness[entry.col] - of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col] - else: raise newException(AssertionDefect, "fatal error") - - var valuesCz = newSeq[Fr[BN254_Snarks]](domSize) - for i in 0..= 1) - var ys = newSeq[Fr[BN254_Snarks]](n) - ys[0] = xs[0] - if n >= 1: ys[1] = eta * xs[1] - var spow = eta - for i in 2.. -# -proc computeSnarkjsScalarCoeffs( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] = - let n = abc.valuesAz.len - assert( abc.valuesBz.len == n ) - assert( abc.valuesCz.len == n ) - let D = createDomain(n) - let eta = createDomain(2*n).domainGen - - var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]] - - var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 ) - var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 ) - var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 ) - - discard sync taskA1 - discard sync taskB1 - discard sync taskC1 - - let A1 = outputA1.extract() - let B1 = outputB1.extract() - let C1 = outputC1.extract() - - var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n ) - for j in 0.. +# + +{.push raises:[].} + +import std/times +import system +import taskpools +import constantine/math/arithmetic +import constantine/named/properties_fields + +#import constantine/math/io/io_extfields except Fp12 + +import groth16/bn128 +#import groth16/math/domain +import groth16/math/poly +import groth16/zkey_types +import groth16/files/witness +import groth16/misc + +import groth16/prover/types +import groth16/prover/shared + +#------------------------------------------------------------------------------- +# the prover +# + +proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof = + + when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)): + {.fatal: "Compile with arc/orc!".} + + # if (zkey.header.curve != wtns.curve): + # echo( "zkey.header.curve = " & ($zkey.header.curve) ) + # echo( "wtns.curve = " & ($wtns.curve ) ) + + assert( zkey.header.curve == wtns.curve ) + var start : float = 0 + + let witness = wtns.values + + let hdr : GrothHeader = zkey.header + let spec : SpecPoints = zkey.specPoints + let pts : ProverPoints = zkey.pPoints + + let nvars = hdr.nvars + let npubs = hdr.npubs + + assert( nvars == witness.len , "wrong witness length" ) + + # remark: with the special variable "1" we actuall have (npub+1) public IO variables + var pubIO = newSeq[Fr[BN254_Snarks]]( npubs + 1) + for i in 0..npubs: pubIO[i] = witness[i] + + start = cpuTime() + var abc : ABC + withMeasureTime(printTimings,"building 'ABC'"): + abc = buildABC( zkey, witness ) + + start = cpuTime() + var qs : seq[Fr[BN254_Snarks]] + withMeasureTime(printTimings,"computing the quotient (FFTs)"): + case zkey.header.flavour + + # the points H are [delta^-1 * tau^i * Z(tau)] + of JensGroth: + let polyQ = computeQuotientPointwise( abc, pool ) + qs = polyQ.coeffs + + # the points H are `[delta^-1 * L_{2i+1}(tau)]_1` + # where L_i are Lagrange basis polynomials on the double-sized domain + of Snarkjs: + qs = computeSnarkjsScalarCoeffs( abc, pool ) + + var zs = newSeq[Fr[BN254_Snarks]]( nvars - npubs - 1 ) + for j in npubs+1.. +# + +{.push raises:[].} + +import system +import taskpools +import constantine/math/arithmetic +import constantine/named/properties_fields + +import groth16/bn128 +import groth16/math/domain +import groth16/math/poly +import groth16/zkey_types +#import groth16/misc + +import groth16/prover/types + +#------------------------------------------------------------------------------- + +proc randomMask*(): Mask = + # masking coeffs + let r = randFr() + let s = randFr() + let mask = Mask(r: r, s: s) + return mask + +#------------------------------------------------------------------------------- + +# computes the vectors A*z, B*z, C*z where z is the witness +func buildABC*( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC = + let hdr: GrothHeader = zkey.header + let domSize = hdr.domainSize + + var valuesAz = newSeq[Fr[BN254_Snarks]](domSize) + var valuesBz = newSeq[Fr[BN254_Snarks]](domSize) + + for entry in zkey.coeffs: + case entry.matrix + of MatrixA: valuesAz[entry.row] += entry.coeff * witness[entry.col] + of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col] + else: raise newException(AssertionDefect, "fatal error") + + var valuesCz = newSeq[Fr[BN254_Snarks]](domSize) + for i in 0..= 1) + var ys = newSeq[Fr[BN254_Snarks]](n) + ys[0] = xs[0] + if n >= 1: ys[1] = eta * xs[1] + var spow = eta + for i in 2.. +# +proc computeSnarkjsScalarCoeffs*( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] = + let n = abc.valuesAz.len + assert( abc.valuesBz.len == n ) + assert( abc.valuesCz.len == n ) + let D = createDomain(n) + let eta = createDomain(2*n).domainGen + + var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]] + + var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 ) + var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 ) + var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 ) + + discard sync taskA1 + discard sync taskB1 + discard sync taskC1 + + let A1 = outputA1.extract() + let B1 = outputB1.extract() + let C1 = outputC1.extract() + + var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n ) + for j in 0..