diff --git a/cli/cli_main.nim b/cli/cli_main.nim index b6a5a0e..08a4abc 100644 --- a/cli/cli_main.nim +++ b/cli/cli_main.nim @@ -1,5 +1,5 @@ -import sugar +import std/sugar import std/strutils import std/sequtils import std/os @@ -36,10 +36,10 @@ proc printHelp() = echo " -u, --setup : perform (fake) trusted setup" echo " -n, --nomask : don't use random masking for full ZK" echo " -z, --zkey = : the `.zkey` file" - echo " -w, --wtns = : the `.wtns` file" - echo " -r, --r1cs = : the `.r1cs` file" + echo " -w, --wtns = : the `.wtns` file" + echo " -r, --r1cs = : the `.r1cs` file" echo " -o, --output = : the proof file" - echo " -i, --io = : the public input/output file" + echo " -i, --io = : the public input/output file" #------------------------------------------------------------------------------- @@ -58,7 +58,7 @@ type Config = object no_masking: bool nthreads: int -const dummyConfig = +const dummyConfig = Config( zkey_file: "" , r1cs_file: "" , wtns_file: "" @@ -123,13 +123,13 @@ proc parseCliOptions(): Config = quit() of cmdEnd: - discard + discard if swCtr==0 and argCtr==0: printHelp() quit() - if cfg.nthreads <= 0: + if cfg.nthreads <= 0: cfg.nthreads = countProcessors() return cfg @@ -137,7 +137,7 @@ proc parseCliOptions(): Config = #------------------------------------------------------------------------------- #[ -proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) = +proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) = echo("parsing witness & zkey files...") let witness = parseWitness( wtns_fname) @@ -170,20 +170,20 @@ proc cliMain(cfg: Config) = echo("\nparsing witness file " & quoted(cfg.wtns_file)) withMeasureTime(cfg.measure_time,"parsing the witness"): wtns = parseWitness(cfg.wtns_file) - + if not (cfg.zkey_file == ""): echo("\nparsing zkey file " & quoted(cfg.zkey_file)) withMeasureTime(cfg.measure_time,"parsing the zkey"): zkey = parseZKey(cfg.zkey_file) - + if not (cfg.r1cs_file == ""): echo("\nparsing r1cs file " & quoted(cfg.r1cs_file)) withMeasureTime(cfg.measure_time,"parsing the r1cs"): r1cs = parseR1CS(cfg.r1cs_file) - + if cfg.do_setup: if not (cfg.zkey_file == ""): - echo("\nwe are doing a fake trusted setup, don't specify the zkey file!") + echo("\nwe are doing a fake trusted setup, don't specify the zkey file!") quit() if (cfg.r1cs_file == ""): echo("\nerror: r1cs file is required for the fake setup!") @@ -191,24 +191,24 @@ proc cliMain(cfg: Config) = echo("\nperforming fake trusted setup...") withMeasureTime(cfg.measure_time,"fake setup"): zkey = createFakeCircuitSetup( r1cs, flavour=Snarkjs ) - + if cfg.debug: printGrothHeader(zkey.header) # debugPrintCoeffs(zkey.coeffs) if cfg.do_prove: if (cfg.wtns_file=="") or (cfg.zkey_file=="" and cfg.do_setup==false): - echo("cannot prove: missing witness and/or zkey file!") + echo("cannot prove: missing witness and/or zkey file!") quit() else: echo("generating proof...") let print_timings = cfg.measure_time and cfg.verbose withMeasureTime(cfg.measure_time,"proving"): if cfg.no_masking: - proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns) + proof = generateProofWithTrivialMask(zkey, wtns, cfg.nthreads, print_timings) else: - proof = generateProof(cfg.nthreads, print_timings, zkey, wtns) - + proof = generateProof(zkey, wtns, cfg.nthreads, print_timings) + if not (cfg.output_file == ""): echo("exporting the proof to " & quoted(cfg.output_file)) exportProof( cfg.output_file, proof ) @@ -218,7 +218,7 @@ proc cliMain(cfg: Config) = if cfg.do_verify: if (cfg.zkey_file == "" and cfg.do_setup==false): - echo("cannot verify: missing vkey (well, zkey)") + echo("cannot verify: missing vkey (well, zkey)") quit() else: let vkey = extractVKey( zkey) @@ -227,7 +227,7 @@ proc cliMain(cfg: Config) = withMeasureTime(cfg.measure_time,"verifying"): ok = verifyProof( vkey, proof ) echo("verification succeeded = ",ok) - + echo("") #------------------------------------------------------------------------------- diff --git a/cli/nim.cfg b/cli/nim.cfg index abe4065..0f840a1 100644 --- a/cli/nim.cfg +++ b/cli/nim.cfg @@ -1,3 +1 @@ --path:".." ---threads:on ---mm:arc \ No newline at end of file diff --git a/groth16.nim b/groth16.nim index 658c6f0..7697246 100644 --- a/groth16.nim +++ b/groth16.nim @@ -1,12 +1,16 @@ import groth16/bn128 import groth16/files/zkey +import groth16/zkey_types +import groth16/files/r1cs import groth16/files/witness import groth16/prover import groth16/verifier +export r1cs export bn128 export zkey +export zkey_types export witness export prover export verifier diff --git a/groth16.nimble b/groth16.nimble index a6566ea..041efda 100644 --- a/groth16.nimble +++ b/groth16.nimble @@ -8,4 +8,4 @@ binDir = "build" namedBin = {"cli/cli_main": "nim-groth16"}.toTable() requires "https://github.com/status-im/nim-taskpools" -requires "https://github.com/mratsim/constantine#5f7ba18f2ed351260015397c9eae079a6decaee1" \ No newline at end of file +requires "constantine >= 0.2.0" diff --git a/groth16/bn128/curves.nim b/groth16/bn128/curves.nim index 42611b2..57d1607 100644 --- a/groth16/bn128/curves.nim +++ b/groth16/bn128/curves.nim @@ -13,36 +13,36 @@ #import constantine/platforms/abstractions #import constantine/math/isogenies/frobenius -import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_fields except Fp, Fr -import constantine/math/io/io_bigints -import constantine/math/config/curves +import pkg/constantine/math/arithmetic +import pkg/constantine/math/io/io_fields +import pkg/constantine/math/io/io_bigints +# import pkg/constantine/math/config/curves -import constantine/math/config/type_ff as tff except Fp, Fr -import constantine/math/extension_fields/towers as ext except Fp, Fp2, Fp12, Fr +import pkg/constantine/named/properties_fields as tff +import pkg/constantine/math/extension_fields/towers as ext -import constantine/math/elliptic/ec_shortweierstrass_affine as aff -import constantine/math/elliptic/ec_shortweierstrass_projective as prj -import constantine/math/pairings/pairings_bn as ate -import constantine/math/elliptic/ec_scalar_mul_vartime as scl +import pkg/constantine/math/elliptic/ec_shortweierstrass_affine as aff +import pkg/constantine/math/elliptic/ec_shortweierstrass_projective as prj +import pkg/constantine/math/pairings/pairings_bn as ate +import pkg/constantine/math/elliptic/ec_scalar_mul_vartime as scl import groth16/bn128/fields #------------------------------------------------------------------------------- -type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1] -type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2] +type G1* = aff.EC_ShortW_Aff[Fp[BN254_Snarks] , aff.G1] +type G2* = aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2] -type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1] -type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2] +type ProjG1* = prj.EC_ShortW_Prj[Fp[BN254_Snarks] , prj.G1] +type ProjG2* = prj.EC_ShortW_Prj[Fp2[BN254_Snarks], prj.G2] #------------------------------------------------------------------------------- -func unsafeMkG1* ( X, Y: Fp ) : G1 = - return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y) +func unsafeMkG1* ( X, Y: Fp[BN254_Snarks] ) : G1 = + return aff.EC_ShortW_Aff[Fp[BN254_Snarks], aff.G1](x: X, y: Y) -func unsafeMkG2* ( X, Y: Fp2 ) : G2 = - return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y) +func unsafeMkG2* ( X, Y: Fp2[BN254_Snarks] ) : G2 = + return aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2](x: X, y: Y) #------------------------------------------------------------------------------- @@ -51,15 +51,15 @@ const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) #------------------------------------------------------------------------------- -func checkCurveEqG1*( x, y: Fp ) : bool = +func checkCurveEqG1*( x, y: Fp[BN254_Snarks] ) : bool = if bool(isZero(x)) and bool(isZero(y)): # the point at infinity is on the curve by definition return true else: - var x2 : Fp = squareFp(x) - var y2 : Fp = squareFp(y) - var x3 : Fp = x2 * x - var eq : Fp + var x2 = squareFp(x) + var y2 = squareFp(y) + var x3 = x2 * x + var eq : Fp[BN254_Snarks] eq = x3 eq += intToFp(3) eq -= y2 @@ -72,19 +72,19 @@ func checkCurveEqG1*( x, y: Fp ) : bool = # B = b1 + bu*u # b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373 # b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690 -const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5") -const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2") -const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u ) +const twistCoeffB_1 = fromHex(Fp[BN254_Snarks], "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5") +const twistCoeffB_u = fromHex(Fp[BN254_Snarks], "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2") +const twistCoeffB = mkFp2( twistCoeffB_1 , twistCoeffB_u ) -func checkCurveEqG2*( x, y: Fp2 ) : bool = +func checkCurveEqG2*( x, y: Fp2[BN254_Snarks] ) : bool = if isZeroFp2(x) and isZeroFp2(y): # the point at infinity is on the curve by definition return true else: - var x2 : Fp2 = squareFp2(x) - var y2 : Fp2 = squareFp2(y) - var x3 : Fp2 = x2 * x; - var eq : Fp2 + var x2 = squareFp2(x) + var y2 = squareFp2(y) + var x3 = x2 * x + var eq : Fp2[BN254_Snarks] eq = x3 eq += twistCoeffB eq -= y2 @@ -92,14 +92,14 @@ func checkCurveEqG2*( x, y: Fp2 ) : bool = #------------------------------------------------------------------------------- -func mkG1*( x, y: Fp ) : G1 = +func mkG1*( x, y: Fp[BN254_Snarks] ) : G1 = if isZeroFp(x) and isZeroFp(y): return infG1 else: assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" ) return unsafeMkG1(x,y) -func mkG2*( x, y: Fp2 ) : G2 = +func mkG2*( x, y: Fp2[BN254_Snarks] ) : G2 = if isZeroFp2(x) and isZeroFp2(y): return infG2 else: @@ -109,16 +109,16 @@ func mkG2*( x, y: Fp2 ) : G2 = #------------------------------------------------------------------------------- # group generators -const gen1_x : Fp = fromHex(Fp, "0x01") -const gen1_y : Fp = fromHex(Fp, "0x02") +const gen1_x = fromHex(Fp[BN254_Snarks], "0x01") +const gen1_y = fromHex(Fp[BN254_Snarks], "0x02") -const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701") -const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b") -const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8") -const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c") +const gen2_xi = fromHex(Fp[BN254_Snarks], "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701") +const gen2_xu = fromHex(Fp[BN254_Snarks], "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b") +const gen2_yi = fromHex(Fp[BN254_Snarks], "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8") +const gen2_yu = fromHex(Fp[BN254_Snarks], "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c") -const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu ) -const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu ) +const gen2_x = mkFp2( gen2_xi, gen2_xu ) +const gen2_y = mkFp2( gen2_yi, gen2_yu ) const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y ) const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y ) @@ -215,9 +215,9 @@ func `**`*( coeff: BigInt , point: G2 ) : G2 = #------------------------------------------------------------------------------- -func pairing* (p: G1, q: G2) : Fp12 = - var t : Fp12 - ate.pairing_bn[BN254Snarks]( t, p, q ) +func pairing* (p: G1, q: G2) : Fp12[BN254_Snarks] = + var t : Fp12[BN254_Snarks] + ate.pairing_bn( t, p, q ) return t #------------------------------------------------------------------------------- @@ -225,7 +225,7 @@ func pairing* (p: G1, q: G2) : Fp12 = proc sanityCheckGroupGen*() = echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) ) echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) ) - echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) ) - echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) ) + # echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) ) + # echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) ) #------------------------------------------------------------------------------- diff --git a/groth16/bn128/debug.nim b/groth16/bn128/debug.nim index ab7403c..9492a81 100644 --- a/groth16/bn128/debug.nim +++ b/groth16/bn128/debug.nim @@ -9,23 +9,27 @@ # equation: y^2 = x^3 + 3 # +import pkg/constantine/math/io/io_fields +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers + import groth16/bn128/fields import groth16/bn128/curves import groth16/bn128/io #------------------------------------------------------------------------------- -proc debugPrintFp*(prefix: string, x: Fp) = +proc debugPrintFp*(prefix: string, x: Fp[BN254_Snarks]) = echo(prefix & toDecimalFp(x)) -proc debugPrintFp2*(prefix: string, z: Fp2) = +proc debugPrintFp2*(prefix: string, z: Fp2[BN254_Snarks]) = echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0])) echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) -proc debugPrintFr*(prefix: string, x: Fr) = +proc debugPrintFr*(prefix: string, x: Fr[BN254_Snarks]) = echo(prefix & toDecimalFr(x)) -proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = +proc debugPrintFrSeq*(msg: string, xs: seq[Fr[BN254_Snarks]]) = echo "---------------------" echo msg for x in xs: diff --git a/groth16/bn128/fields.nim b/groth16/bn128/fields.nim index 5ffa8a1..4a027ef 100644 --- a/groth16/bn128/fields.nim +++ b/groth16/bn128/fields.nim @@ -1,4 +1,3 @@ - # # the prime fields Fp and Fr with sizes # @@ -6,30 +5,31 @@ # r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 # -import sugar +import std/sugar import std/bitops import std/sequtils -import constantine/math/arithmetic -import constantine/math/io/io_fields -import constantine/math/io/io_bigints -import constantine/math/config/curves -import constantine/math/config/type_ff as tff -import constantine/math/extension_fields/towers as ext +import pkg/constantine/math/arithmetic +import pkg/constantine/math/io/io_fields +# import pkg/constantine/math/io/io_extfields +import pkg/constantine/math/io/io_bigints +# import pkg/constantine/math/config/curves +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers as ext #------------------------------------------------------------------------------- type B* = BigInt[256] -type Fr* = tff.Fr[BN254Snarks] -type Fp* = tff.Fp[BN254Snarks] +# type Fr* = tff.Fr[BN254_Snarks] +# type Fp* = tff.Fp[BN254_Snarks] -type Fp2* = ext.QuadraticExt[Fp] -type Fp12* = ext.Fp12[BN254Snarks] +# type Fp2* = ext.QuadraticExt[Fp] +# type Fp12* = ext.Fp12[BN254_Snarks] -func mkFp2* (i: Fp, u: Fp) : Fp2 = - let c : array[2, Fp] = [i,u] - return ext.QuadraticExt[Fp]( coords: c ) +func mkFp2* (i: Fp[BN254_Snarks], u: Fp[BN254_Snarks]) : Fp2[BN254_Snarks] = + let c : array[2, Fp[BN254_Snarks]] = [i,u] + return ext.QuadraticExt[Fp[BN254_Snarks]]( coords: c ) #------------------------------------------------------------------------------- @@ -38,16 +38,16 @@ const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b97 #------------------------------------------------------------------------------- -const zeroFp* : Fp = fromHex( Fp, "0x00" ) -const zeroFr* : Fr = fromHex( Fr, "0x00" ) -const oneFp* : Fp = fromHex( Fp, "0x01" ) -const oneFr* : Fr = fromHex( Fr, "0x01" ) +const zeroFp* = fromHex( Fp[BN254_Snarks], "0x00" ) +const zeroFr* = fromHex( Fr[BN254_Snarks], "0x00" ) +const oneFp* = fromHex( Fp[BN254_Snarks], "0x01" ) +const oneFr* = fromHex( Fr[BN254_Snarks], "0x01" ) -const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) -const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) +const zeroFp2* = mkFp2( zeroFp, zeroFp ) +const oneFp2* = mkFp2( oneFp , zeroFp ) -const minusOneFp* : Fp = fromHex( Fp, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" ) -const minusOneFr* : Fr = fromHex( Fr, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" ) +const minusOneFp* = fromHex( Fp[BN254_Snarks], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" ) +const minusOneFr* = fromHex( Fr[BN254_Snarks], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" ) #------------------------------------------------------------------------------- @@ -56,13 +56,13 @@ func intToB*(a: uint): B = y.setUint(a) return y -func intToFp*(a: int): Fp = - var y : Fp +func intToFp*(a: int): Fp[BN254_Snarks] = + var y : Fp[BN254_Snarks] y.fromInt(a) return y -func intToFr*(a: int): Fr = - var y : Fr +func intToFr*(a: int): Fr[BN254_Snarks] = + var y : Fr[BN254_Snarks] y.fromInt(a) return y @@ -154,27 +154,27 @@ func smallPowFr*(base: Fr, expo: int): Fr = #------------------------------------------------------------------------------- -func deltaFr*(i, j: int) : Fr = +func deltaFr*[T](i, j: int) : Fr[T] = return (if (i == j): oneFr else: zeroFr) #------------------------------------------------------------------------------- # Montgomery batch inversion -func batchInverseFr*( xs: seq[Fr] ) : seq[Fr] = +func batchInverseFr*( xs: seq[Fr[BN254_Snarks]] ) : seq[Fr[BN254_Snarks]] = let n = xs.len assert(n>0) - var us : seq[Fr] = newSeq[Fr](n+1) + var us : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]](n+1) var a = xs[0] us[0] = oneFr us[1] = a for i in 1.. primeP_254 - k65536 ): return "-" & toDecimalFp(negFp(a)) else: return toDecimalFp(a) -func signedToDecimalFr*(a : Fr): string = +func signedToDecimalFr*(a : Fr[BN254_Snarks]): string = if bool( a.toBig() > primeR_254 - k65536 ): return "-" & toDecimalFr(negFr(a)) else: @@ -58,38 +59,38 @@ func signedToDecimalFr*(a : Fr): string = # # R=2^256; this computes 2^256 mod Fp -func calcFpMontR*() : Fp = - var x : Fp = intToFp(2) +func calcFpMontR*() : Fp[BN254_Snarks] = + var x : Fp[BN254_Snarks] = intToFp(2) for i in 1..8: square(x) return x # R=2^256; this computes the inverse of (2^256 mod Fp) -func calcFpInvMontR*() : Fp = - var x : Fp = calcFpMontR() +func calcFpInvMontR*() : Fp[BN254_Snarks] = + var x : Fp[BN254_Snarks] = calcFpMontR() inv(x) return x # R=2^256; this computes 2^256 mod Fr -func calcFrMontR*() : Fr = - var x : Fr = intToFr(2) +func calcFrMontR*() : Fr[BN254_Snarks] = + var x : Fr[BN254_Snarks] = intToFr(2) for i in 1..8: square(x) return x # R=2^256; this computes the inverse of (2^256 mod Fp) -func calcFrInvMontR*() : Fr = - var x : Fr = calcFrMontR() +func calcFrInvMontR*() : Fr[BN254_Snarks] = + var x : Fr[BN254_Snarks] = calcFrMontR() inv(x) return x # apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) -const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) -const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) +const fpMontR* = fromHex( Fp[BN254_Snarks], "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) +const fpInvMontR* = fromHex( Fp[BN254_Snarks], "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) # apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) -const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) -const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) +const frMontR* = fromHex( Fr[BN254_Snarks], "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) +const frInvMontR* = fromHex( Fr[BN254_Snarks], "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) proc checkMontgomeryConstants*() = assert( bool( fpMontR == calcFpMontR() ) ) @@ -100,21 +101,21 @@ proc checkMontgomeryConstants*() = #--------------------------------------- -# the binary file `.zkey` used by the `circom` ecosystem uses little-endian -# Montgomery representation. So when we unmarshal with Constantine, it will +# the binary file `.zkey` used by the `circom` ecosystem uses little-endian +# Montgomery representation. So when we unmarshal with Constantine, it will # give the wrong result. Calling this function on the result fixes that. -func fromMontgomeryFp*(x : Fp) : Fp = - var y : Fp = x; +func fromMontgomeryFp*(x : Fp[BN254_Snarks]) : Fp[BN254_Snarks] = + var y : Fp[BN254_Snarks] = x; y *= fpInvMontR return y -func fromMontgomeryFr*(x : Fr) : Fr = - var y : Fr = x; +func fromMontgomeryFr*(x : Fr[BN254_Snarks]) : Fr[BN254_Snarks] = + var y = x; y *= frInvMontR return y -func toMontgomeryFr*(x : Fr) : Fr = - var y : Fr = x; +func toMontgomeryFr*(x : Fr[BN254_Snarks]) : Fr[BN254_Snarks] = + var y = x; y *= frMontR return y @@ -123,47 +124,47 @@ func toMontgomeryFr*(x : Fr) : Fr = # Note: in the `.zkey` coefficients, e apparently DOUBLE Montgomery encoding is used ?!? # -func unmarshalFpMont* ( bs: array[32,byte] ) : Fp = +func unmarshalFpMont* ( bs: array[32,byte] ) : Fp[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fp + var x : Fp[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFp(x) # WTF Jordi, go home you are drunk -func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr = +func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFr(fromMontgomeryFr(x)) -func unmarshalFrStd* ( bs: array[32,byte] ) : Fr = +func unmarshalFrStd* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return x -func unmarshalFrMont* ( bs: array[32,byte] ) : Fr = +func unmarshalFrMont* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFr(x) #------------------------------------------------------------------------------- -func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] = - var vals : seq[Fp] = newSeq[Fp]( len ) +func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp[BN254_Snarks]] = + var vals : seq[Fp[BN254_Snarks]] = newSeq[Fp[BN254_Snarks]]( len ) var bytes : array[32,byte] for i in 0..= 1024, we use 8+ threads + # for N == 512 , we use 4 threads + # for N >= 1024, we use 8+ threads let N = coeffs.len assert( N == points.len, "incompatible sequence lengths" ) - let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 ) + let nthreads_target = min( pool.numThreads, 256 ) let nthreads = max( 1 , min( N div 128 , nthreads_target ) ) - let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1 + let ntasks = if nthreads>1: ( nthreads * task_multiplier ) else: 1 - var pool = Taskpool.new(num_threads = nthreads) - var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks) + var pending : seq[Flowvar[mycurves.G1]] = newSeq[Flowvar[mycurves.G1]](ntasks) var a : int = 0 var b : int @@ -118,23 +117,19 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1] for k in 0..1: (nthreads*task_multiplier) else: 1 - var pool = Taskpool.new(num_threads = nthreads) - var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks) + var pending : seq[Flowvar[mycurves.G2]] = newSeq[Flowvar[mycurves.G2]](ntasks) var a : int = 0 var b : int @@ -152,19 +147,16 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2] for k in 0.. number of bytes # @@ -23,32 +23,33 @@ import std/streams -import sugar +import std/sugar -import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_bigints +import pkg/constantine/math/arithmetic except Fp, Fr +import pkg/constantine/math/io/io_bigints #------------------------------------------------------------------------------- -type +type SectionCallback*[T] = proc (stream: Stream, sectId: int, sectLen: int, user: var T) {.closure.} + Filt* = ((int) {.raises: [], gcsafe.} -> bool) #------------------------------------------------------------------------------- -func magicWord(magic: string): uint32 = +func magicWord(magic: string): uint32 = assert( magic.len == 4, "magicWord: expecting a string of 4 characters" ) - var w : uint32 = 0 + var w : uint32 = 0 for i in 0..3: - let a = uint32(ord(magic[i])) + let a = uint32(ord(magic[i])) w += a shl (8*i) return w #------------------------------------------------------------------------------- -proc parsePrimeField*( stream: Stream ) : (int, BigInt[256]) = +proc parsePrimeField*( stream: Stream ) : (int, BigInt[256]) = let n8p = int( stream.readUint32() ) assert( n8p <= 32 , "at most 256 bit primes are allowed" ) - var p_bytes : array[32, uint8] + var p_bytes : array[32, uint8] discard stream.readData( addr(p_bytes), n8p ) var p : BigInt[256] unmarshal(p, p_bytes, littleEndian); @@ -60,8 +61,8 @@ proc readSection[T] ( expectedMagic: string , expectedVersion: int , stream: Stream , user: var T - , callback: SectionCallback[T] - , filt: (int) -> bool ) = + , callback: SectionCallback[T] + , filt: Filt ) {.raises: [IOError, OSError].} = let sectId = int( stream.readUint32() ) let sectLen = int( stream.readUint64() ) @@ -76,8 +77,8 @@ proc parseContainer*[T] ( expectedMagic: string , expectedVersion: int , fname: string , user: var T - , callback: SectionCallback[T] - , filt: (int) -> bool ) = + , callback: SectionCallback[T] + , filt: Filt ) {.raises: [IOError, OSError].} = let stream = newFileStream(fname, mode = fmRead) defer: stream.close() diff --git a/groth16/files/export_json.nim b/groth16/files/export_json.nim index 53b2828..498009c 100644 --- a/groth16/files/export_json.nim +++ b/groth16/files/export_json.nim @@ -3,37 +3,41 @@ # export proof and public input in `circom`-compatible JSON files # -import constantine/math/arithmetic except Fp, Fr -#import constantine/math/io/io_fields except Fp, Fr +import pkg/constantine/math/arithmetic +import pkg/constantine/math/io/io_fields +# import pkg/constantine/math/config/curves + +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers import groth16/bn128 from groth16/prover import Proof #------------------------------------------------------------------------------- -func toQuotedDecimalFp(x: Fp): string = +func toQuotedDecimalFp(x: Fp[BN254_Snarks]): string = let s : string = toDecimalFp(x) return ("\"" & s & "\"") -func toQuotedDecimalFr(x: Fr): string = +func toQuotedDecimalFr(x: Fr[BN254_Snarks]): string = let s : string = toDecimalFr(x) return ("\"" & s & "\"") #------------------------------------------------------------------------------- # exports the public input/output into as a JSON file -proc exportPublicIO*( fpath: string, prf: Proof ) = +proc exportPublicIO*( fpath: string, prf: Proof ) = - # debugPrintFrSeq("public IO",prf.publicIO) + debugPrintFrSeq("public IO",prf.publicIO) - let n : int = prf.publicIO.len + let n : int = prf.publicIO.len assert( n > 0 ) assert( bool(prf.publicIO[0] == oneFr) ) let f = open(fpath, fmWrite) defer: f.close() - # note: we start from 1 because the 0th element is the constant 1 "variable", + # note: we start from 1 because the 0th element is the constant 1 "variable", # which is automatically added by the tools for i in 1.. @@ -49,31 +49,36 @@ # ... # -import std/streams +{.push raises: [IOError, OSError].} -import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_bigints +import std/streams +import std/sugar + +import pkg/constantine/math/arithmetic except Fp, Fr +import pkg/constantine/math/io/io_bigints +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers import groth16/bn128 import groth16/files/container #------------------------------------------------------------------------------- -type - +type + WitnessConfig* = object nWires* : int # total number of wires (or witness variables), including the constant 1 "variable" nPubOut* : int # number of public outputs nPubIn* : int # number of public inputs nPrivIn* : int # number of private inputs nLabels* : int # number of labels - - Term* = tuple[ wireIdx: int, value: Fr ] + + Term* = tuple[ wireIdx: int, value: Fr[BN254_Snarks] ] LinComb* = seq[Term] Constraint* = tuple[ A: LinComb, B: LinComb, C: LinComb ] R1CS* = object - r* : BigInt[256] + r* : BigInt[256] cfg* : WitnessConfig nConstr* : int constraints* : seq[Constraint] @@ -83,7 +88,7 @@ type proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) = # echo "\nparsing r1cs header" - + let (n8r, r) = parsePrimeField( stream ) # size of the scalar field user.r = r; @@ -110,19 +115,19 @@ proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) = #------------------------------------------------------------------------------- -proc loadTerm( stream: Stream ): Term = +proc loadTerm( stream: Stream ): Term = let idx = int( stream.readUint32() ) let coeff = loadValueFrStd( stream ) return (wireIdx:idx, value:coeff) -proc loadLinComb( stream: Stream ): LinComb = +proc loadLinComb( stream: Stream ): LinComb = let nterms = int( stream.readUint32() ) var terms : seq[Term] for i in 1..nterms: terms.add( loadTerm(stream) ) return terms -proc loadConstraint( stream: Stream ): Constraint = +proc loadConstraint( stream: Stream ): Constraint = let a = loadLinComb( stream ) let b = loadLinComb( stream ) let c = loadLinComb( stream ) @@ -160,17 +165,17 @@ proc r1csCallback( stream: Stream , sectId: int , sectLen: int , user: var R1CS - ) = + ) = case sectId of 1: parseSection1_header( stream, user, sectLen ) of 2: parseSection2_constraints( stream, user, sectLen ) of 3: parseSection3_wireToLabel( stream, user, sectLen ) else: discard -proc parseR1CS* (fname: string): R1CS = +proc parseR1CS* (fname: string): R1CS = var r1cs : R1CS - parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id == 1 ) - parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id != 1 ) + parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, (id) {.raises: [], gcsafe.} => id == 1 ) + parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, (id) {.raises: [], gcsafe.} => id != 1 ) return r1cs #------------------------------------------------------------------------------- diff --git a/groth16/files/witness.nim b/groth16/files/witness.nim index f849eb3..7f74b5e 100644 --- a/groth16/files/witness.nim +++ b/groth16/files/witness.nim @@ -16,26 +16,29 @@ import std/streams -import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_bigints +import pkg/constantine/math/arithmetic +import pkg/constantine/math/io/io_fields +import pkg/constantine/math/io/io_bigints +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers import groth16/bn128 import groth16/files/container #------------------------------------------------------------------------------- -type +type Witness* = object curve* : string - r* : BigInt[256] + r* : BigInt[256] nvars* : int - values* : seq[Fr] + values* : seq[Fr[BN254_Snarks]] #------------------------------------------------------------------------------- proc parseSection1_header( stream: Stream, user: var Witness, sectionLen: int ) = # echo "\nparsing witness header" - + let (n8r, r) = parsePrimeField( stream ) # size of the scalar field user.r = r; @@ -61,14 +64,14 @@ proc parseSection2_witness( stream: Stream, user: var Witness, sectionLen: int ) #------------------------------------------------------------------------------- -proc wtnsCallback(stream: Stream, sectId: int, sectLen: int, user: var Witness) = +proc wtnsCallback(stream: Stream, sectId: int, sectLen: int, user: var Witness) = #echo(sectId) case sectId of 1: parseSection1_header( stream, user, sectLen ) of 2: parseSection2_witness( stream, user, sectLen ) else: discard -proc parseWitness* (fname: string): Witness = +proc parseWitness* (fname: string): Witness = var wtns : Witness parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id == 1 ) parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id != 1 ) diff --git a/groth16/files/zkey.nim b/groth16/files/zkey.nim index 73ce1e5..af75b07 100644 --- a/groth16/files/zkey.nim +++ b/groth16/files/zkey.nim @@ -5,12 +5,12 @@ # # file format # =========== -# +# # standard iden3 binary container format. # field elements are in Montgomery representation, except for the coefficients -# which for some reason are double Montgomery encoded... (and unlike the +# which for some reason are double Montgomery encoded... (and unlike the # `.wtns` and `.r1cs` files which use the standard representation) -# +# # sections: # # 1: Header @@ -81,7 +81,7 @@ # what normally should be the curve points `[ delta^-1 * tau^i * Z(tau) ]_1` # HOWEVER, in the snarkjs implementation, they are different; namely # `[ delta^-1 * L_{2i+1} (tau) ]_1` where L_k are Lagrange polynomials -# on the refined (double sized) domain +# on the refined (double sized) domain # See # length = 2 * n8p * domSize = domSize G1 points # @@ -94,18 +94,20 @@ import std/streams -import constantine/math/arithmetic except Fp, Fr +import pkg/constantine/math/arithmetic except Fp, Fr #import constantine/math/io/io_bigints - + import groth16/bn128 import groth16/zkey_types import groth16/files/container import groth16/misc +export zkey_types + #------------------------------------------------------------------------------- -proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int ) = - assert( sectionLen == 4 , "unexpected section length" ) +proc parseSection1_proverType ( stream: Stream, user: var ZKey, sectionLen: int ) = + assert( sectionLen == 4 , "unexpected section length" ) let proverType = stream.readUint32 assert( proverType == 1 , "expecting `.zkey` file for a Groth16 prover") @@ -128,8 +130,8 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int header.flavour = Snarkjs - assert( n8p == 32 , "expecting 256 bit primes") - assert( n8r == 32 , "expecting 256 bit primes") + assert( n8p == 32 , "expecting 256 bit primes") + assert( n8r == 32 , "expecting 256 bit primes") assert( bool(p == primeP) , "expecting the alt-bn128 curve" ) assert( bool(r == primeR) , "expecting the alt-bn128 curve" ) @@ -171,7 +173,7 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) = assert( sectionLen == 4 + ncoeffs*(32+12) , "unexpected section length" ) let nrows = user.header.domainSize let ncols = user.header.nvars - + var coeffs : seq[Coeff] for i in 1..ncoeffs: let m = int( stream.readUint32() ) # which matrix @@ -188,7 +190,7 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) = let cf = loadValueFrWTF( stream ) # Jordi, WTF is this encoding ?!?!?!!111 let entry = Coeff( matrix:sel, row:r, col:c, coeff:cf ) coeffs.add( entry ) - + user.coeffs = coeffs #------------------------------------------------------------------------------- @@ -225,7 +227,7 @@ proc parseSection9_PointsH1( stream: Stream, user: var ZKey, sectionLen: int ) = #------------------------------------------------------------------------------- -proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) = +proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) = case sectId of 1: parseSection1_proverType( stream, user, sectLen ) of 2: parseSection2_GrothHeader( stream, user, sectLen ) @@ -238,7 +240,7 @@ proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) = of 9: parseSection9_PointsH1( stream, user, sectLen ) else: discard -proc parseZKey* (fname: string): ZKey = +proc parseZKey* (fname: string): ZKey = var zkey : ZKey parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 1 ) parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 2 ) diff --git a/groth16/math/domain.nim b/groth16/math/domain.nim index 2185814..aef7b06 100644 --- a/groth16/math/domain.nim +++ b/groth16/math/domain.nim @@ -3,38 +3,39 @@ # power-of-two sized multiplicative FFT domains in the scalar field # -import constantine/math/arithmetic except Fp,Fr -import constantine/math/io/io_fields except Fp,Fr -#import constantine/math/io/io_bigints +import pkg/constantine/math/arithmetic +import pkg/constantine/math/io/io_fields +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers import groth16/bn128 import groth16/misc #------------------------------------------------------------------------------- -type +type Domain* = object - domainSize* : int # `N = 2^n` - logDomainSize* : int # `n = log2(N)` - domainGen* : Fr # `g` - invDomainGen* : Fr # `g^-1` - invDomainSize* : Fr # `1/n` + domainSize* : int # `N = 2^n` + logDomainSize* : int # `n = log2(N)` + domainGen* : Fr[BN254_Snarks] # `g` + invDomainGen* : Fr[BN254_Snarks] # `g^-1` + invDomainSize* : Fr[BN254_Snarks] # `1/n` #------------------------------------------------------------------------------- # the generator of the multiplicative subgroup with size `2^28` -const gen28 : Fr = fromHex( Fr, "0x2a3c09f0a58a7e8500e0a7eb8ef62abc402d111e41112ed49bd61b6e725b19f0" ) +const gen28 = fromHex( Fr[BN254_Snarks], "0x2a3c09f0a58a7e8500e0a7eb8ef62abc402d111e41112ed49bd61b6e725b19f0" ) -func createDomain*(size: int): Domain = +func createDomain*(size: int): Domain = let log2 = ceilingLog2(size) assert( (1 shl log2) == size , "domain must have a power-of-two size" ) - let expo : uint = 1'u shl (28 - log2) - let gen : Fr = smallPowFr(gen28, expo) + let expo = 1'u shl (28 - log2) + let gen = smallPowFr(gen28, expo) let halfSize = size div 2 - let a : Fr = smallPowFr(gen, size ) - let b : Fr = smallPowFr(gen, halfSize) + let a = smallPowFr(gen, size ) + let b = smallPowFr(gen, halfSize) assert( bool(a == oneFr) , "domain generator sanity check /A" ) assert( not bool(b == oneFr) , "domain generator sanity check /B" ) @@ -42,14 +43,14 @@ func createDomain*(size: int): Domain = , logDomainSize: log2 , domainGen: gen , invDomainGen: invFr(gen) - , invDomainSize: invFr(intToFr(size)) + , invDomainSize: invFr(intToFr(size)) ) #------------------------------------------------------------------------------- -func enumerateDomain*(D: Domain): seq[Fr] = - var xs : seq[Fr] = newSeq[Fr](D.domainSize) - var g : Fr = oneFr +func enumerateDomain*(D: Domain): seq[Fr[BN254_Snarks]] = + var xs = newSeq[Fr[BN254_Snarks]](D.domainSize) + var g = oneFr for i in 0..= 0): d -= 1 return d -func polyIsZero*(P: Poly) : bool = +func polyIsZero*(P: Poly) : bool = let xs = P.coeffs ; let n = xs.len var b = true for i in 0..= m: for i in 0.. 0: y = cs[0] for i in 1..= m: for i in 0..= m: for i in 0..= i = k - j >= k - min(k , n2-1) - # 0 >= j = k - i >= k - min(k , n1-1) + # 0 >= j = k - i >= k - min(k , n1-1) let A : int = max( 0 , k - min(k , n2-1) ) let B : int = min( k , n1-1 ) zs[k] = zeroFr @@ -124,7 +129,7 @@ func polyMulNaive*(P, Q : Poly): Poly = #------------------------------------------------------------------------------- # multiply two polynomials using FFT -func polyMulFFT*(P, Q: Poly): Poly = +func polyMulFFT*(P, Q: Poly): Poly = let n1 = P.coeffs.len let n2 = Q.coeffs.len @@ -132,10 +137,10 @@ func polyMulFFT*(P, Q: Poly): Poly = let N : int = (1 shl log2) let D : Domain = createDomain( N ) - let us : seq[Fr] = extendAndForwardNTT( P.coeffs, D ) - let vs : seq[Fr] = extendAndForwardNTT( Q.coeffs, D ) - let zs : seq[Fr] = collect( newSeq, (for i in 0..=1 ) - var cs : seq[Fr] = newSeq[Fr]( N+1 ) + var cs = newSeq[Fr[BN254_Snarks]]( N+1 ) cs[0] = negFr(b) cs[N] = a return Poly(coeffs: cs) # the vanishing polynomial `(x^N - 1)` -func vanishingPoly*(N: int): Poly = +func vanishingPoly*(N: int): Poly = return generalizedVanishingPoly(N, oneFr, oneFr) -func vanishingPoly*(D: Domain): Poly = +func vanishingPoly*(D: Domain): Poly = return vanishingPoly(D.domainSize) #------------------------------------------------------------------------------- type QuotRem*[T] = object - quot* : T - rem* : T + quot* : T + rem* : T # divide by the vanishing polynomial `(x^N - 1)` # returns the quotient and remainder -func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] = +func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] = assert( N>=1 ) let deg : int = polyDegree(P) - let src : seq[Fr] = P.coeffs - var quot : seq[Fr] = newSeq[Fr]( max(1, deg - N + 1) ) - var rem : seq[Fr] = newSeq[Fr]( N ) + let src = P.coeffs + var quot = newSeq[Fr[BN254_Snarks]]( max(1, deg - N + 1) ) + var rem = newSeq[Fr[BN254_Snarks]]( N ) if deg < N: rem = src - + else: # compute quotient @@ -212,7 +217,7 @@ func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] = return QuotRem[Poly]( quot:Poly(coeffs:quot), rem:Poly(coeffs:rem) ) # divide by the vanishing polynomial `(x^N - 1)` -func polyDivideByVanishing*(P: Poly, N: int): Poly = +func polyDivideByVanishing*(P: Poly, N: int): Poly = let qr = polyQuotRemByVanishing(P, N) assert( polyIsZero(qr.rem) ) return qr.quot @@ -222,15 +227,15 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly = # Lagrange basis polynomials func lagrangePoly*(D: Domain, k: int): Poly = let N = D.domainSize - let omMinusK : Fr = smallPowFr( D.invDomainGen , k ) - let invN : Fr = invFr(intToFr(N)) + let omMinusK = smallPowFr( D.invDomainGen , k ) + let invN = invFr(intToFr(N)) - var cs : seq[Fr] = newSeq[Fr]( N ) + var cs = newSeq[Fr[BN254_Snarks]]( N ) if k == 0: for i in 0.. 0): @@ -40,7 +40,7 @@ func floorLog2* (x : int) : int = y = y shr 1 return k -func ceilingLog2* (x : int) : int = +func ceilingLog2* (x : int) : int = if (x==0): return -1 else: diff --git a/groth16/prover.nim b/groth16/prover.nim index 70fdcd1..66de332 100644 --- a/groth16/prover.nim +++ b/groth16/prover.nim @@ -2,28 +2,28 @@ # # Groth16 prover # -# WARNING! +# WARNING! # the points H in `.zkey` are *NOT* what normal people would think they are # See # -#[ -import sugar -import constantine/math/config/curves -import constantine/math/io/io_fields -import constantine/math/io/io_bigints -import ./zkey -]# +{.push raises: [].} import std/os import std/times import std/cpuinfo -import system -import taskpools +import std/isolation + +import pkg/taskpools + +import pkg/constantine/math/arithmetic +import pkg/constantine/math/io/io_fields +import pkg/constantine/math/io/io_bigints +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers -import constantine/math/arithmetic except Fp, Fr #import constantine/math/io/io_extfields except Fp12 -#import constantine/math/extension_fields/towers except Fp2, Fp12 +#import constantine/math/extension_fields/towers except Fp2, Fp12 import groth16/bn128 import groth16/math/domain @@ -36,7 +36,7 @@ import groth16/misc type Proof* = object - publicIO* : seq[Fr] + publicIO* : seq[Fr[BN254_Snarks]] pi_a* : G1 pi_b* : G2 pi_c* : G1 @@ -44,29 +44,35 @@ type #------------------------------------------------------------------------------- # Az, Bz, Cz column vectors -# +# type ABC = object - valuesAz : seq[Fr] - valuesBz : seq[Fr] - valuesCz : seq[Fr] + valuesAz : seq[Fr[BN254_Snarks]] + valuesBz : seq[Fr[BN254_Snarks]] + valuesCz : seq[Fr[BN254_Snarks]] + + ShiftEvalDomainTask* = object + values : seq[Fr[BN254_Snarks]] + D : Domain + eta : Fr[BN254_Snarks] + res : Isolated[seq[Fr[BN254_Snarks]]] # computes the vectors A*z, B*z, C*z where z is the witness -func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC = +func buildABC( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC = let hdr: GrothHeader = zkey.header let domSize = hdr.domainSize - var valuesAz : seq[Fr] = newSeq[Fr](domSize) - var valuesBz : seq[Fr] = newSeq[Fr](domSize) + var valuesAz = newSeq[Fr[BN254_Snarks]](domSize) + var valuesBz = newSeq[Fr[BN254_Snarks]](domSize) for entry in zkey.coeffs: - case entry.matrix + case entry.matrix of MatrixA: valuesAz[entry.row] += entry.coeff * witness[entry.col] of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col] else: raise newException(AssertionDefect, "fatal error") - var valuesCz : seq[Fr] = newSeq[Fr](domSize) + var valuesCz = newSeq[Fr[BN254_Snarks]](domSize) for i in 0..= 1) - var ys : seq[Fr] = newSeq[Fr](n) + var ys = newSeq[Fr[BN254_Snarks]](n) ys[0] = xs[0] if n >= 1: ys[1] = eta * xs[1] - var spow : Fr = eta - for i in 2.. # -proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr] = +proc computeSnarkjsScalarCoeffs( abc: ABC, pool: Taskpool ): seq[Fr[BN254_Snarks]] = let n = abc.valuesAz.len assert( abc.valuesBz.len == n ) assert( abc.valuesCz.len == n ) let D = createDomain(n) let eta = createDomain(2*n).domainGen - var pool = Taskpool.new(num_threads = nthreads) + var a1fvTask = ShiftEvalDomainTask( + values: abc.valuesAz, + D: D, + eta: eta, + ) + var A1fv = pool.spawn shiftEvalDomainTask( addr a1fvTask ) - var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta ) - var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta ) - var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta ) + var b1fvTask = ShiftEvalDomainTask( + values: abc.valuesBz, + D: D, + eta: eta, + ) + var B1fv = pool.spawn shiftEvalDomainTask( addr b1fvTask ) - let A1 = sync A1fv - let B1 = sync B1fv - let C1 = sync C1fv + var c1fvTask = ShiftEvalDomainTask( + values: abc.valuesCz, + D: D, + eta: eta, + ) + var C1fv = pool.spawn shiftEvalDomainTask( addr c1fvTask ) - var ys : seq[Fr] = newSeq[Fr]( n ) - for j in 0.. # #[ import sugar -import constantine/math/config/curves +import constantine/math/config/curves import constantine/math/io/io_fields import constantine/math/io/io_bigints import ./zkey ]# # import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_extfields except Fp12 -import constantine/math/extension_fields/towers except Fp2, Fp12 +import pkg/constantine/math/io/io_extfields +import pkg/constantine/math/extension_fields/towers +import pkg/constantine/math/arithmetic +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers import groth16/bn128 import groth16/zkey_types @@ -36,15 +39,15 @@ proc verifyProof* (vkey: VKey, prf: Proof): bool = assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" ) assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" ) - var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC ) + var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC ) - let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b > - let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta > - let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > - let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > + let lhs = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b > + let rhs1 = vkey.spec.alphaBeta # < alpha , beta > + let rhs2 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > + let rhs3 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > - var eq : Fp12 - eq = lhs + var eq : Fp12[BN254_Snarks] + eq = lhs eq *= rhs1 eq *= rhs2 eq *= rhs3 diff --git a/groth16/zkey_types.nim b/groth16/zkey_types.nim index 16b3969..a66ed71 100644 --- a/groth16/zkey_types.nim +++ b/groth16/zkey_types.nim @@ -1,12 +1,16 @@ -import constantine/math/arithmetic except Fp, Fr +import pkg/constantine/math/arithmetic +import pkg/constantine/math/io/io_fields +import pkg/constantine/math/io/io_bigints +import pkg/constantine/named/properties_fields +import pkg/constantine/math/extension_fields/towers import groth16/bn128 #------------------------------------------------------------------------------- -type - +type + Flavour* = enum JensGroth # the version described in the original Groth16 paper Snarkjs # the version implemented by Snarkjs @@ -27,8 +31,8 @@ type beta2* : G2 # = beta * g2 gamma2* : G2 # = gamma * g2 delta1* : G1 # = delta * g1 - delta2* : G2 # = delta * g2 - alphaBeta* : Fp12 # = + delta2* : G2 # = delta * g2 + alphaBeta* : Fp12[BN254_Snarks] # = VerifierPoints* = object pointsIC* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j <= npub) @@ -49,7 +53,7 @@ type matrix* : MatrixSel row* : int col* : int - coeff* : Fr + coeff* : Fr[BN254_Snarks] ZKey* = object # sectionMask* : uint32 @@ -59,14 +63,14 @@ type pPoints* : ProverPoints coeffs* : seq[Coeff] - VKey* = object + VKey* = object curve* : string spec* : SpecPoints vpoints* : VerifierPoints #------------------------------------------------------------------------------- -func extractVKey*(zkey: Zkey): VKey = +func extractVKey*(zkey: ZKey): VKey = let curve = zkey.header.curve let spec = zkey.specPoints let vpts = zkey.vPoints @@ -74,32 +78,32 @@ func extractVKey*(zkey: Zkey): VKey = #------------------------------------------------------------------------------- -proc printGrothHeader*(hdr: GrothHeader) = - echo("curve = " & ($hdr.curve ) ) - echo("flavour = " & ($hdr.flavour ) ) - echo("|Fp| = " & (toDecimalBig(hdr.p)) ) - echo("|Fr| = " & (toDecimalBig(hdr.r)) ) - echo("nvars = " & ($hdr.nvars ) ) - echo("npubs = " & ($hdr.npubs ) ) - echo("domainSize = " & ($hdr.domainSize ) ) - echo("logDomainSize= " & ($hdr.logDomainSize) ) +proc printGrothHeader*(hdr: GrothHeader) = + echo("curve = " & ($hdr.curve ) ) + echo("flavour = " & ($hdr.flavour ) ) + echo("|Fp| = " & (toDecimalBig(hdr.p)) ) + echo("|Fr| = " & (toDecimalBig(hdr.r)) ) + echo("nvars = " & ($hdr.nvars ) ) + echo("npubs = " & ($hdr.npubs ) ) + echo("domainSize = " & ($hdr.domainSize ) ) + echo("logDomainSize= " & ($hdr.logDomainSize) ) #------------------------------------------------------------------------------- -func matrixSelToString(sel: MatrixSel): string = - case sel +func matrixSelToString(sel: MatrixSel): string = + case sel of MatrixA: return "A" of MatrixB: return "B" of MatrixC: return "C" -proc debugPrintCoeff(cf: Coeff) = +proc debugPrintCoeff(cf: Coeff) = echo( "matrix=", matrixSelToString(cf.matrix) , " | i=", cf.row , " | j=", cf.col , " | val=", signedToDecimalFr(cf.coeff) ) -proc debugPrintCoeffs*(cfs: seq[Coeff]) = +proc debugPrintCoeffs*(cfs: seq[Coeff]) = for cf in cfs: debugPrintCoeff(cf) #------------------------------------------------------------------------------- diff --git a/tests/groth16/testProver.nim b/tests/groth16/testProver.nim index cdf8a5b..51bcd0b 100644 --- a/tests/groth16/testProver.nim +++ b/tests/groth16/testProver.nim @@ -19,7 +19,7 @@ const myWitnessCfg = , nPubOut: 1 # public output = input + a*b*c = 1022 + 7*11*13 = 2023 , nPubIn: 1 # public input = 1022 , nPrivIn: 3 # private inputs: 7, 11, 13 - , nLabels: 0 + , nLabels: 0 ) # 2023 == 1022 + 7*3*11 @@ -44,10 +44,10 @@ const myR1CS = ) # the equation we want prove is `7*11*13 + 1022 == 2023` -let myWitnessValues : seq[Fr] = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr ) +let myWitnessValues = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr ) # wire indices: ^^^^^^^ 0 1 2 3 4 5 6 7 -let myWitness = +let myWitness = Witness( curve: "bn128" , r: primeR , nvars: 8 @@ -56,8 +56,8 @@ let myWitness = #------------------------------------------------------------------------------- -proc testProof(zkey: ZKey, witness: Witness): bool = - let proof = generateProof( zkey, witness ) +proc testProof(zkey: ZKey, witness: Witness): bool = + let proof = generateProof( zkey, witness, Taskpool.new() ) let vkey = extractVKey( zkey) let ok = verifyProof( vkey, proof ) return ok @@ -65,11 +65,11 @@ proc testProof(zkey: ZKey, witness: Witness): bool = suite "prover": test "prove & verify simple multiplication circuit, `JensGroth` flavour": - let zkey = createFakeCircuitSetup( myR1cs, flavour=JensGroth ) + let zkey = createFakeCircuitSetup( myR1CS, flavour=JensGroth ) check testProof( zkey, myWitness ) test "prove & verify simple multiplication circuit, `Snarkjs` flavour": - let zkey = createFakeCircuitSetup( myR1cs, flavour=Snarkjs ) + let zkey = createFakeCircuitSetup( myR1CS, flavour=Snarkjs ) check testProof( zkey, myWitness ) #-------------------------------------------------------------------------------