From 2047e7e47695cc7e7409a0b45bd4ae82c20e1e4e Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Tue, 14 Nov 2023 12:39:54 +0100 Subject: [PATCH 1/6] started some refactoring. However Nim import "mechanism" *does not help*... --- bn128.nim | 790 +---------------------------------------------- bn128/curves.nim | 231 ++++++++++++++ bn128/debug.nim | 45 +++ bn128/fields.nim | 186 +++++++++++ bn128/io.nim | 253 +++++++++++++++ bn128/msm.nim | 156 ++++++++++ bn128/rnd.nim | 76 +++++ 7 files changed, 959 insertions(+), 778 deletions(-) create mode 100644 bn128/curves.nim create mode 100644 bn128/debug.nim create mode 100644 bn128/fields.nim create mode 100644 bn128/io.nim create mode 100644 bn128/msm.nim create mode 100644 bn128/rnd.nim diff --git a/bn128.nim b/bn128.nim index b5c9bc0..d36a5fe 100644 --- a/bn128.nim +++ b/bn128.nim @@ -9,789 +9,23 @@ # equation: y^2 = x^3 + 3 # -import sugar - -import std/bitops -import std/strutils -import std/sequtils -import std/streams -import std/random - -import constantine/platforms/abstractions -import constantine/math/isogenies/frobenius - -import constantine/math/arithmetic -import constantine/math/io/io_fields -import constantine/math/io/io_bigints -import constantine/math/config/curves -import constantine/math/config/type_ff as tff - -import constantine/math/extension_fields/towers as ext -import constantine/math/elliptic/ec_shortweierstrass_affine as aff -import constantine/math/elliptic/ec_shortweierstrass_projective as prj -import constantine/math/pairings/pairings_bn as ate -import constantine/math/elliptic/ec_scalar_mul as scl -import constantine/math/elliptic/ec_multi_scalar_mul as msm - #------------------------------------------------------------------------------- -type B* = BigInt[256] -type Fr* = tff.Fr[BN254Snarks] -type Fp* = tff.Fp[BN254Snarks] - -type Fp2* = ext.QuadraticExt[Fp] -type Fp12* = ext.Fp12[BN254Snarks] - -type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1] -type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2] - -type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1] -type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2] - -func mkFp2* (i: Fp, u: Fp) : Fp2 = - let c : array[2, Fp] = [i,u] - return ext.QuadraticExt[Fp]( coords: c ) - -func unsafeMkG1* ( X, Y: Fp ) : G1 = - return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y) - -func unsafeMkG2* ( X, Y: Fp2 ) : G2 = - return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y) - -#------------------------------------------------------------------------------- - -func pairing* (p: G1, q: G2) : Fp12 = - var t : Fp12 - pairing_bn( t, p, q ) - return t - -#------------------------------------------------------------------------------- - -const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) -const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) - -const primeP_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) -const primeR_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) - -#------------------------------------------------------------------------------- - -const zeroFp* : Fp = fromHex( Fp, "0x00" ) -const zeroFr* : Fr = fromHex( Fr, "0x00" ) -const oneFp* : Fp = fromHex( Fp, "0x01" ) -const oneFr* : Fr = fromHex( Fr, "0x01" ) - -const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) -const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) - -const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp ) -const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) - -#------------------------------------------------------------------------------- - -func intToB*(a: uint): B = - var y : B - y.setUint(a) - return y - -func intToFp*(a: int): Fp = - var y : Fp - y.fromInt(a) - return y - -func intToFr*(a: int): Fr = - var y : Fr - y.fromInt(a) - return y - -#------------------------------------------------------------------------------- - -func isZeroFp*(x: Fp): bool = bool(isZero(x)) -func isZeroFr*(x: Fr): bool = bool(isZero(x)) - -func isEqualFp*(x, y: Fp): bool = bool(x == y) -func isEqualFr*(x, y: Fr): bool = bool(x == y) - -func `===`*(x, y: Fp): bool = isEqualFp(x,y) -func `===`*(x, y: Fr): bool = isEqualFr(x,y) +import ./bn128/fields +import ./bn128/curves +import ./bn128/msm +import ./bn128/io +import ./bn128/rnd +import ./bn128/debug #------------------- -func isEqualFpSeq*(xs, ys: seq[Fp]): bool = - let n = xs.len - assert( n == ys.len ) - var b = true - for i in 0.. 0): - if bitand(e,1) > 0: a *= s - e = (e shr 1) - square(s) - return a - -func smallPowFr*(base: Fr, expo: int): Fr = - if expo >= 0: - return smallPowFr( base, uint(expo) ) - else: - return smallPowFr( invFr(base) , uint(-expo) ) - -#------------------------------------------------------------------------------- - -func deltaFr*(i, j: int) : Fr = - return (if (i == j): oneFr else: zeroFr) - -#------------------------------------------------------------------------------- - -func toDecimalBig*[n](a : BigInt[n]): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -func toDecimalFp*(a : Fp): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -func toDecimalFr*(a : Fr): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -#--------------------------------------- - -const k65536 : BigInt[254] = fromHex( BigInt[254], "0x10000", bigEndian ) - -func signedToDecimalFp*(a : Fp): string = - if bool( a.toBig() > primeP_254 - k65536 ): - return "-" & toDecimalFp(negFp(a)) - else: - return toDecimalFp(a) - -func signedToDecimalFr*(a : Fr): string = - if bool( a.toBig() > primeR_254 - k65536 ): - return "-" & toDecimalFr(negFr(a)) - else: - return toDecimalFr(a) - -#------------------------------------------------------------------------------- - -proc debugPrintFp*(prefix: string, x: Fp) = - echo(prefix & toDecimalFp(x)) - -proc debugPrintFp2*(prefix: string, z: Fp2) = - echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0])) - echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) - -proc debugPrintFr*(prefix: string, x: Fr) = - echo(prefix & toDecimalFr(x)) - -proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = - echo "---------------------" - echo msg - for x in xs: - debugPrintFr( " " , x ) - -proc debugPrintG1*(msg: string, pt: G1) = - echo(msg & ":") - debugPrintFp( " x = ", pt.x ) - debugPrintFp( " y = ", pt.y ) - -proc debugPrintG2*(msg: string, pt: G2) = - echo(msg & ":") - debugPrintFp2( " x = ", pt.x ) - debugPrintFp2( " y = ", pt.y ) - -#------------------------------------------------------------------------------- - -# Montgomery batch inversion -func batchInverse*( xs: seq[Fr] ) : seq[Fr] = - let n = xs.len - assert(n>0) - var us : seq[Fr] = newSeq[Fr](n+1) - var a = xs[0] - us[0] = oneFr - us[1] = a - for i in 1.. " & toDecimalFp(z.coords[0])) + echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) + +proc debugPrintFr*(prefix: string, x: Fr) = + echo(prefix & toDecimalFr(x)) + +proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = + echo "---------------------" + echo msg + for x in xs: + debugPrintFr( " " , x ) + +proc debugPrintG1*(msg: string, pt: G1) = + echo(msg & ":") + debugPrintFp( " x = ", pt.x ) + debugPrintFp( " y = ", pt.y ) + +proc debugPrintG2*(msg: string, pt: G2) = + echo(msg & ":") + debugPrintFp2( " x = ", pt.x ) + debugPrintFp2( " y = ", pt.y ) + +#------------------------------------------------------------------------------- + diff --git a/bn128/fields.nim b/bn128/fields.nim new file mode 100644 index 0000000..84d4805 --- /dev/null +++ b/bn128/fields.nim @@ -0,0 +1,186 @@ + +# +# the prime fields Fp and Fr with sizes +# +# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +# + +import sugar + +import std/bitops +import std/sequtils + +import constantine/math/arithmetic +import constantine/math/io/io_fields +import constantine/math/io/io_bigints +import constantine/math/config/curves +import constantine/math/config/type_ff as tff +import constantine/math/extension_fields/towers as ext + +#------------------------------------------------------------------------------- + +type B* = BigInt[256] +type Fr* = tff.Fr[BN254Snarks] +type Fp* = tff.Fp[BN254Snarks] + +type Fp2* = ext.QuadraticExt[Fp] +type Fp12* = ext.Fp12[BN254Snarks] + +func mkFp2* (i: Fp, u: Fp) : Fp2 = + let c : array[2, Fp] = [i,u] + return ext.QuadraticExt[Fp]( coords: c ) + +#------------------------------------------------------------------------------- + +const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) +const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) + +#------------------------------------------------------------------------------- + +const zeroFp* : Fp = fromHex( Fp, "0x00" ) +const zeroFr* : Fr = fromHex( Fr, "0x00" ) +const oneFp* : Fp = fromHex( Fp, "0x01" ) +const oneFr* : Fr = fromHex( Fr, "0x01" ) + +const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) +const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) + +#------------------------------------------------------------------------------- + +func intToB*(a: uint): B = + var y : B + y.setUint(a) + return y + +func intToFp*(a: int): Fp = + var y : Fp + y.fromInt(a) + return y + +func intToFr*(a: int): Fr = + var y : Fr + y.fromInt(a) + return y + +#------------------------------------------------------------------------------- + +func isZeroFp* (x: Fp ): bool = bool(isZero(x)) +func isZeroFp2*(x: Fp2): bool = bool(isZero(x)) +func isZeroFr* (x: Fr ): bool = bool(isZero(x)) + +func isEqualFp* (x, y: Fp ): bool = bool(x == y) +func isEqualFp2*(x, y: Fp2): bool = bool(x == y) +func isEqualFr* (x, y: Fr ): bool = bool(x == y) + +func `===`*(x, y: Fp ): bool = isEqualFp(x,y) +func `===`*(x, y: Fp2): bool = isEqualFp2(x,y) +func `===`*(x, y: Fr ): bool = isEqualFr(x,y) + +#------------------- + +func isEqualFpSeq*(xs, ys: seq[Fp]): bool = + let n = xs.len + assert( n == ys.len ) + var b = true + for i in 0.. 0): + if bitand(e,1) > 0: a *= s + e = (e shr 1) + square(s) + return a + +func smallPowFr*(base: Fr, expo: int): Fr = + if expo >= 0: + return smallPowFr( base, uint(expo) ) + else: + return smallPowFr( invFr(base) , uint(-expo) ) + +#------------------------------------------------------------------------------- + +func deltaFr*(i, j: int) : Fr = + return (if (i == j): oneFr else: zeroFr) + +#------------------------------------------------------------------------------- + +# Montgomery batch inversion +func batchInverseFr*( xs: seq[Fr] ) : seq[Fr] = + let n = xs.len + assert(n>0) + var us : seq[Fr] = newSeq[Fr](n+1) + var a = xs[0] + us[0] = oneFr + us[1] = a + for i in 1.. primeP_254 - k65536 ): + return "-" & toDecimalFp(negFp(a)) + else: + return toDecimalFp(a) + +func signedToDecimalFr*(a : Fr): string = + if bool( a.toBig() > primeR_254 - k65536 ): + return "-" & toDecimalFr(negFr(a)) + else: + return toDecimalFr(a) + +#------------------------------------------------------------------------------- +# Dealing with Montgomery representation +# + +# R=2^256; this computes 2^256 mod Fp +func calcFpMontR*() : Fp = + var x : Fp = intToFp(2) + for i in 1..8: + square(x) + return x + +# R=2^256; this computes the inverse of (2^256 mod Fp) +func calcFpInvMontR*() : Fp = + var x : Fp = calcFpMontR() + inv(x) + return x + +# R=2^256; this computes 2^256 mod Fr +func calcFrMontR*() : Fr = + var x : Fr = intToFr(2) + for i in 1..8: + square(x) + return x + +# R=2^256; this computes the inverse of (2^256 mod Fp) +func calcFrInvMontR*() : Fr = + var x : Fr = calcFrMontR() + inv(x) + return x + +# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) +const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) +const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) + +# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) +const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) +const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) + +proc checkMontgomeryConstants*() = + assert( bool( fpMontR == calcFpMontR() ) ) + assert( bool( frMontR == calcFrMontR() ) ) + assert( bool( fpInvMontR == calcFpInvMontR() ) ) + assert( bool( frInvMontR == calcFrInvMontR() ) ) + echo("OK") + +#--------------------------------------- + +# the binary file `.zkey` used by the `circom` ecosystem uses little-endian +# Montgomery representation. So when we unmarshal with Constantine, it will +# give the wrong result. Calling this function on the result fixes that. +func fromMontgomeryFp*(x : Fp) : Fp = + var y : Fp = x; + y *= fpInvMontR + return y + +func fromMontgomeryFr*(x : Fr) : Fr = + var y : Fr = x; + y *= frInvMontR + return y + +func toMontgomeryFr*(x : Fr) : Fr = + var y : Fr = x; + y *= frMontR + return y + +#------------------------------------------------------------------------------- +# Unmarshalling field elements +# Note: in the `.zkey` coefficients, e apparently DOUBLE Montgomery encoding is used ?!? +# + +func unmarshalFpMont* ( bs: array[32,byte] ) : Fp = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fp + x.fromBig( big ) + return fromMontgomeryFp(x) + +# WTF Jordi, go home you are drunk +func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return fromMontgomeryFr(fromMontgomeryFr(x)) + +func unmarshalFrStd* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return x + +func unmarshalFrMont* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return fromMontgomeryFr(x) + +#------------------------------------------------------------------------------- + +func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] = + var vals : seq[Fp] = newSeq[Fp]( len ) + var bytes : array[32,byte] + for i in 0..= 1024, we use 8 threads + let nthreads = max( 1 , min( N div 128 , nthreadsTarget ) ) + + let m = N div nthreads + + var threads : seq[Thread[InputTuple]] = newSeq[Thread[InputTuple]]( nthreads ) + var results : seq[G1] = newSeq[G1]( nthreads ) + + proc myThreadFunc( inp: InputTuple ) {.thread.} = + results[inp.idx] = msmConstantineG1( inp.coeffs, inp.points ) + + for i in 0.. Date: Tue, 14 Nov 2023 12:40:13 +0100 Subject: [PATCH 2/6] measure proof time --- test_proof.nim | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test_proof.nim b/test_proof.nim index 66f021e..6e2f477 100644 --- a/test_proof.nim +++ b/test_proof.nim @@ -1,4 +1,7 @@ + +import std/[times, os] + import ./groth16 import ./witness import ./r1cs @@ -18,7 +21,11 @@ proc testProveAndVerify*( zkey_fname, wtns_fname: string): Proof = echo("generating proof...") let vkey = extractVKey( zkey) + + let start = cpuTime() let proof = generateProof( zkey, witness ) + let elapsed = cpuTime() - start + echo("proving took ",elapsed) echo("verifying the proof...") let ok = verifyProof( vkey, proof ) @@ -42,7 +49,11 @@ proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): echo("generating proof...") let vkey = extractVKey( zkey) + + let start = cpuTime() let proof = generateProof( zkey, witness ) + let elapsed = cpuTime() - start + echo("proving took ",elapsed) echo("verifying the proof...") let ok = verifyProof( vkey, proof ) From 26f708f9086b4d904c0a0fae887e32dcb1804056 Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Fri, 17 Nov 2023 09:43:43 +0100 Subject: [PATCH 3/6] make it a nimble package --- README.md | 9 +- example/example.nim | 11 - groth16.nim | 246 +----------------- groth16.nimble | 9 + bn128.nim => groth16/bn128.nim | 12 +- {bn128 => groth16/bn128}/curves.nim | 2 +- {bn128 => groth16/bn128}/debug.nim | 6 +- {bn128 => groth16/bn128}/fields.nim | 3 + {bn128 => groth16/bn128}/io.nim | 4 +- {bn128 => groth16/bn128}/msm.nim | 6 +- {bn128 => groth16/bn128}/rnd.nim | 2 +- {example => groth16/example}/.gitignore | 0 groth16/example/example.nim | 18 ++ {example => groth16/example}/product.circom | 0 {example => groth16/example}/prove.sh | 0 fake_setup.nim => groth16/fake_setup.nim | 94 ++++--- container.nim => groth16/files/container.nim | 0 .../files/export_json.nim | 4 +- r1cs.nim => groth16/files/r1cs.nim | 4 +- witness.nim => groth16/files/witness.nim | 4 +- zkey.nim => groth16/files/zkey.nim | 8 +- domain.nim => groth16/math/domain.nim | 4 +- ntt.nim => groth16/math/ntt.nim | 4 +- poly.nim => groth16/math/poly.nim | 12 +- misc.nim => groth16/misc.nim | 0 groth16/prover.nim | 216 +++++++++++++++ test_proof.nim => groth16/test_proof.nim | 20 +- groth16/verifier.nim | 54 ++++ zkey_types.nim => groth16/zkey_types.nim | 8 +- tests/groth16/testProver.nim | 75 ++++++ tests/nim.cfg | 1 + tests/test.nim | 3 + 32 files changed, 490 insertions(+), 349 deletions(-) delete mode 100644 example/example.nim create mode 100644 groth16.nimble rename bn128.nim => groth16/bn128.nim (77%) rename {bn128 => groth16/bn128}/curves.nim (99%) rename {bn128 => groth16/bn128}/debug.nim (93%) rename {bn128 => groth16/bn128}/fields.nim (96%) rename {bn128 => groth16/bn128}/io.nim (99%) rename {bn128 => groth16/bn128}/msm.nim (97%) rename {bn128 => groth16/bn128}/rnd.nim (98%) rename {example => groth16/example}/.gitignore (100%) create mode 100644 groth16/example/example.nim rename {example => groth16/example}/product.circom (100%) rename {example => groth16/example}/prove.sh (100%) rename fake_setup.nim => groth16/fake_setup.nim (77%) rename container.nim => groth16/files/container.nim (100%) rename export_json.nim => groth16/files/export_json.nim (98%) rename r1cs.nim => groth16/files/r1cs.nim (99%) rename witness.nim => groth16/files/witness.nim (97%) rename zkey.nim => groth16/files/zkey.nim (98%) rename domain.nim => groth16/math/domain.nim (97%) rename ntt.nim => groth16/math/ntt.nim (99%) rename poly.nim => groth16/math/poly.nim (98%) rename misc.nim => groth16/misc.nim (100%) create mode 100644 groth16/prover.nim rename test_proof.nim => groth16/test_proof.nim (84%) create mode 100644 groth16/verifier.nim rename zkey_types.nim => groth16/zkey_types.nim (95%) create mode 100644 tests/groth16/testProver.nim create mode 100644 tests/nim.cfg create mode 100644 tests/test.nim diff --git a/README.md b/README.md index e8990c9..e433d34 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,13 @@ at your choice. ### TODO -- [ ] make it a nimble package -- [ ] refactor `bn128.nim` into smaller files -- [ ] proper MSM implementation (I couldn't make constantine's one to work) +- [ ] clean up the code +- [x] make it a nimble package +- [/] refactor `bn128.nim` into smaller files +- [/] proper MSM implementation (at first I couldn't make constantine's one to work) - [x] compare `.r1cs` to the "coeffs" section of `.zkey` - [x] generate fake circuit-specific setup ourselves -- [ ] multithreaded support (MSM, and possibly also FFT) +- [ ] multithreading support (MSM, and possibly also FFT) - [ ] add Groth16 notes - [ ] document the `snarkjs` circuit-specific setup `H` points convention - [ ] make it work for different curves diff --git a/example/example.nim b/example/example.nim deleted file mode 100644 index d62933c..0000000 --- a/example/example.nim +++ /dev/null @@ -1,11 +0,0 @@ - -import ../test_proof -import ../export_json - -let zkey_fname : string = "./build/product.zkey" -let wtns_fname : string = "./build/product.wtns" -let proof = testProveAndVerify( zkey_fname, wtns_fname) - -exportPublicIO( "./build/nim_public.json" , proof ) -exportProof( "./build/nim_proof.json" , proof ) - diff --git a/groth16.nim b/groth16.nim index c0e072e..80c02c0 100644 --- a/groth16.nim +++ b/groth16.nim @@ -1,242 +1,6 @@ -# -# Groth16 prover -# -# WARNING! -# the points H in `.zkey` are *NOT* what normal people would think they are -# See -# - -#[ -import sugar -import constantine/math/config/curves -import constantine/math/io/io_fields -import constantine/math/io/io_bigints -import ./zkey -]# - -import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_extfields except Fp12 -import constantine/math/extension_fields/towers except Fp2, Fp12 - -import ./bn128 -import ./domain -import ./poly -import ./zkey_types -import ./witness - -#------------------------------------------------------------------------------- - -type - Proof* = object - publicIO* : seq[Fr] - pi_a* : G1 - pi_b* : G2 - pi_c* : G1 - curve : string - -#------------------------------------------------------------------------------- -# the verifier -# - -proc verifyProof* (vkey: VKey, prf: Proof): bool = - - assert( prf.curve == "bn128" ) - - assert( isOnCurveG1(prf.pi_a) , "pi_a is not in G1" ) - assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" ) - assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" ) - - var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC ) - - let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b > - let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta > - let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > - let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > - - var eq : Fp12 - eq = lhs - eq *= rhs1 - eq *= rhs2 - eq *= rhs3 - - return bool(isOne(eq)) - -#------------------------------------------------------------------------------- -# A, B, C column vectors -# - -type - ABC = object - valuesA : seq[Fr] - valuesB : seq[Fr] - valuesC : seq[Fr] - -func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC = - let hdr: GrothHeader = zkey.header - let domSize = hdr.domainSize - - var valuesA : seq[Fr] = newSeq[Fr](domSize) - var valuesB : seq[Fr] = newSeq[Fr](domSize) - for entry in zkey.coeffs: - case entry.matrix - of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col] - of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col] - else: raise newException(AssertionDefect, "fatal error") - - var valuesC : seq[Fr] = newSeq[Fr](domSize) - for i in 0..= 1) - var ys : seq[Fr] = newSeq[Fr](n) - ys[0] = xs[0] - if n >= 1: ys[1] = eta * xs[1] - var spow : Fr = eta - for i in 2.. -# -func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] = - let n = abc.valuesA.len - let D = createDomain(n) - let eta = createDomain(2*n).domainGen - let A1 = shiftEvalDomain( abc.valuesA, D, eta ) - let B1 = shiftEvalDomain( abc.valuesB, D, eta ) - let C1 = shiftEvalDomain( abc.valuesC, D, eta ) - var ys : seq[Fr] = newSeq[Fr]( n ) - for j in 0.. +# + +#[ +import sugar +import constantine/math/config/curves +import constantine/math/io/io_fields +import constantine/math/io/io_bigints +import ./zkey +]# + +import constantine/math/arithmetic except Fp, Fr +#import constantine/math/io/io_extfields except Fp12 +#import constantine/math/extension_fields/towers except Fp2, Fp12 + +import groth16/bn128 +import groth16/math/domain +import groth16/math/poly +import groth16/zkey_types +import groth16/files/witness + +#------------------------------------------------------------------------------- + +type + Proof* = object + publicIO* : seq[Fr] + pi_a* : G1 + pi_b* : G2 + pi_c* : G1 + curve* : string + +#------------------------------------------------------------------------------- +# A, B, C column vectors +# + +type + ABC = object + valuesA : seq[Fr] + valuesB : seq[Fr] + valuesC : seq[Fr] + +func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC = + let hdr: GrothHeader = zkey.header + let domSize = hdr.domainSize + + var valuesA : seq[Fr] = newSeq[Fr](domSize) + var valuesB : seq[Fr] = newSeq[Fr](domSize) + for entry in zkey.coeffs: + case entry.matrix + of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col] + of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col] + else: raise newException(AssertionDefect, "fatal error") + + var valuesC : seq[Fr] = newSeq[Fr](domSize) + for i in 0..= 1) + var ys : seq[Fr] = newSeq[Fr](n) + ys[0] = xs[0] + if n >= 1: ys[1] = eta * xs[1] + var spow : Fr = eta + for i in 2.. +# +func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] = + let n = abc.valuesA.len + let D = createDomain(n) + let eta = createDomain(2*n).domainGen + let A1 = shiftEvalDomain( abc.valuesA, D, eta ) + let B1 = shiftEvalDomain( abc.valuesB, D, eta ) + let C1 = shiftEvalDomain( abc.valuesC, D, eta ) + var ys : seq[Fr] = newSeq[Fr]( n ) + for j in 0.. + let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta > + let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > + let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > + + var eq : Fp12 + eq = lhs + eq *= rhs1 + eq *= rhs2 + eq *= rhs3 + + return bool(isOne(eq)) + +#------------------------------------------------------------------------------- diff --git a/zkey_types.nim b/groth16/zkey_types.nim similarity index 95% rename from zkey_types.nim rename to groth16/zkey_types.nim index f1fc815..ca51f8a 100644 --- a/zkey_types.nim +++ b/groth16/zkey_types.nim @@ -1,7 +1,7 @@ import constantine/math/arithmetic except Fp, Fr -import ./bn128 +import groth16/bn128 #------------------------------------------------------------------------------- @@ -80,14 +80,14 @@ func matrixSelToString(sel: MatrixSel): string = of MatrixB: return "B" of MatrixC: return "C" -proc printCoeff(cf: Coeff) = +proc debugPrintCoeff(cf: Coeff) = echo( "matrix=", matrixSelToString(cf.matrix) , " | i=", cf.row , " | j=", cf.col , " | val=", signedToDecimalFr(cf.coeff) ) -proc printCoeffs*(cfs: seq[Coeff]) = - for cf in cfs: printCoeff(cf) +proc debugPrintCoeffs*(cfs: seq[Coeff]) = + for cf in cfs: debugPrintCoeff(cf) #------------------------------------------------------------------------------- diff --git a/tests/groth16/testProver.nim b/tests/groth16/testProver.nim new file mode 100644 index 0000000..ba1e8e9 --- /dev/null +++ b/tests/groth16/testProver.nim @@ -0,0 +1,75 @@ + +import std/unittest +import std/sequtils + +import groth16/prover +import groth16/verifier +import groth16/fake_setup +import groth16/zkey_types +import groth16/files/witness +import groth16/files/r1cs +import groth16/bn128/fields + +#------------------------------------------------------------------------------- +# simple hand-crafted arithmetic circuit +# + +const myWitnessCfg = + WitnessConfig( nWires: 7 + , nPubOut: 1 # public output = input + a*b*c = 1022 + 7*11*13 = 2023 + , nPubIn: 1 # public input = 1022 + , nPrivIn: 3 # private inputs: 7, 11, 13 + , nLabels: 0 + ) + +# 2023 == 1022 + 7*3*11 +const myEq1 : Constraint = ( @[] , @[] , @[ (0,minusOneFr) , (1,oneFr) , (6,oneFr) ] ) + +# 7*11 == 77 +const myEq2 : Constraint = ( @[ (2,oneFr) ] , @[ (3,oneFr) ] , @[ (5,oneFr) ] ) + +# 77*13 == 1001 +const myEq3 : Constraint = ( @[ (4,oneFr) ] , @[ (5,oneFr) ] , @[ (6,oneFr) ] ) + +const myConstraints : seq[Constraint] = @[ myEq1, myEq2, myEq3 ] + +const myLabels : seq[int] = @[] + +const myR1CS = + R1CS( r: primeR + , cfg: myWitnessCfg + , nConstr: myConstraints.len + , constraints: myConstraints + , wireToLabel: myLabels + ) + +# the equation we want prove is `7*11*13 + 1022 == 2023` +let myWitnessValues : seq[Fr] = map( @[ 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr ) +# wire indices: ^^^^^^^ 0 1 2 3 4 5 6 + +let myWitness = + Witness( curve: "bn128" + , r: primeR + , nvars: 7 + , values: myWitnessValues + ) + +#------------------------------------------------------------------------------- + +proc testProof(zkey: ZKey, witness: Witness): bool = + let proof = generateProof( zkey, witness ) + let vkey = extractVKey( zkey) + let ok = verifyProof( vkey, proof ) + return ok + +suite "prover": + + test "prove & verify simple multiplication circuit, `JensGroth` flavour": + let zkey = createFakeCircuitSetup( myR1cs, flavour=JensGroth ) + check testProof( zkey, myWitness ) + + test "prove & verify simple multiplication circuit, `Snarkjs` flavour": + let zkey = createFakeCircuitSetup( myR1cs, flavour=Snarkjs ) + check testProof( zkey, myWitness ) + +#------------------------------------------------------------------------------- diff --git a/tests/nim.cfg b/tests/nim.cfg new file mode 100644 index 0000000..0f840a1 --- /dev/null +++ b/tests/nim.cfg @@ -0,0 +1 @@ +--path:".." diff --git a/tests/test.nim b/tests/test.nim new file mode 100644 index 0000000..f2bf786 --- /dev/null +++ b/tests/test.nim @@ -0,0 +1,3 @@ + +import ./groth16/testProver + From 502e031e9523ab7359d7bfcf522d14cca5afca9d Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Tue, 21 Nov 2023 11:31:27 +0100 Subject: [PATCH 4/6] small comments / fix typos --- groth16/files/export_json.nim | 2 ++ groth16/files/r1cs.nim | 2 +- groth16/files/witness.nim | 4 +--- groth16/files/zkey.nim | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/groth16/files/export_json.nim b/groth16/files/export_json.nim index ebb09a3..53b2828 100644 --- a/groth16/files/export_json.nim +++ b/groth16/files/export_json.nim @@ -33,6 +33,8 @@ proc exportPublicIO*( fpath: string, prf: Proof ) = let f = open(fpath, fmWrite) defer: f.close() + # note: we start from 1 because the 0th element is the constant 1 "variable", + # which is automatically added by the tools for i in 1.. Date: Wed, 22 Nov 2023 15:55:18 +0100 Subject: [PATCH 5/6] export SageMath verification script (bn128 curve only) --- groth16/files/export_sage.nim | 152 ++++++++++++++++++++++++++++++++++ groth16/test_proof.nim | 8 +- 2 files changed, 156 insertions(+), 4 deletions(-) create mode 100644 groth16/files/export_sage.nim diff --git a/groth16/files/export_sage.nim b/groth16/files/export_sage.nim new file mode 100644 index 0000000..5a0da72 --- /dev/null +++ b/groth16/files/export_sage.nim @@ -0,0 +1,152 @@ + +# +# export proof, public input and verifier as a SageMath script +# + +import std/strutils +import std/streams + +import constantine/math/arithmetic except Fp, Fr + +import groth16/bn128 +import groth16/zkey_types +from groth16/prover import Proof + +#------------------------------------------------------------------------------- + +func toSpaces(str: string): string = spaces(str.len) + +func sageFp(prefix: string, x: Fp): string = prefix & "Fp(" & toDecimalFp(x) & ")" +func sageFr(prefix: string, x: Fr): string = prefix & "Fr(" & toDecimalFr(x) & ")" + +func sageFp2(prefix: string, z: Fp2): string = + sageFp( prefix & "mkFp2(" , z.coords[0]) & ",\n" & + sageFp( toSpaces(prefix) & " " , z.coords[1]) & ")" + +func sageG1(prefix: string, p: G1): string = + sageFp( prefix & "E(" , p.x) & ",\n" & + sageFp( toSpaces(prefix) & " " , p.y) & ")" + +func sageG2(prefix: string, p: G2): string = + sageFp2( prefix & "E2(" , p.x) & ",\n" & + sageFp2( toSpaces(prefix) & " " , p.y) & ")" + +#------------------------------------------------------------------------------- + +proc exportVKey(h: Stream, vkey: VKey ) = + let spec = vkey.spec + h.writeLine("alpha1 = \\") ; h.writeLine(sageG1(" ", spec.alpha1)) + h.writeLine("beta2 = \\") ; h.writeLine(sageG2(" ", spec.beta2 )) + h.writeLine("gamma2 = \\") ; h.writeLine(sageG2(" ", spec.gamma2)) + h.writeLine("delta2 = \\") ; h.writeLine(sageG2(" ", spec.delta2)) + + let pts = vkey.vpoints.pointsIC + h.writeLine("pointsIC = \\") + for i in 0.. = Fp[]" + , "Fp2. = Fp.extension(x^2+1)" + , "def mkFp2(a,b):" + , " return ( a + u*b )" + , "R. = Fp2[]" + , "Fp12. = Fp2.extension(x^6 - (9+u))" + , "E12 = E.base_extend(Fp12)" + , "" + , "# twisted curve" + , "B_twist = Fp2(19485874751759354771024239261021720505790618469301721065564631296452457478373 + 266929791119991161246907387137283842545076965332900288569378510910307636690*u )" + , "E2 = EllipticCurve(Fp2,[0,B_twist])" + , "size_E2 = E2.cardinality();" + , "cofactor_E2 = size_E2 / r;" + , "print(\"|E2| = \", size_E2 );" + , "print(\"h(E2) = \", cofactor_E2 );" + , "" + , "# map from E2 to E12" + , "def Psi(pt):" + , " pt.normalize_coordinates()" + , " x = pt[0]" + , " y = pt[1]" + , " return E12( Fp12(w^2 * x) , Fp12(w^3 * y) )" + , "" + , "def pairing(P,Q):" + , " return E12(P).ate_pairing( Psi(Q), n=r, k=12, t=bn_t, q=p^12 )" + , "" + ] + +const sage_bn128 : string = join(sage_bn128_lines, sep="\n") + +#------------------------------------------------------------------------------- + +const verify_lines : seq[string] = + @[ "pubG1 = pointsIC[0]" + , "for i in [1..len(pubIO)-1]:" + , " pubG1 = pubG1 + pubIO[i]*pointsIC[i]" + , "" + , "lhs = pairing( -piA , piB )" + , "rhs1 = pairing( alpha1 , beta2 )" + , "rhs2 = pairing( piC , delta2 )" + , "rhs3 = pairing( pubG1 , gamma2 )" + , "eq = lhs * rhs1 * rhs2 * rhs3" + , "print(\"verification suceeded =\\n\",eq == 1)" + ] + +const verify_script : string = join(verify_lines, sep="\n") + +#------------------------------------------------------------------------------- + +proc exportSage*(fpath: string, vkey: VKey, prf: Proof) = + + let h = openFileStream(fpath, fmWrite) + defer: h.close() + + h.writeLine(sage_bn128) + h.exportVKey(vkey); + h.exportProof(prf); + h.writeLine(verify_script) + +#------------------------------------------------------------------------------- + diff --git a/groth16/test_proof.nim b/groth16/test_proof.nim index 2a30a4e..cf3d603 100644 --- a/groth16/test_proof.nim +++ b/groth16/test_proof.nim @@ -12,7 +12,7 @@ import groth16/fake_setup #------------------------------------------------------------------------------- -proc testProveAndVerify*( zkey_fname, wtns_fname: string): Proof = +proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) = echo("parsing witness & zkey files...") let witness = parseWitness( wtns_fname) @@ -29,11 +29,11 @@ proc testProveAndVerify*( zkey_fname, wtns_fname: string): Proof = let ok = verifyProof( vkey, proof ) echo("verification succeeded = ",ok) - return proof + return (vkey,proof) #------------------------------------------------------------------------------- -proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): Proof = +proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): (VKey,Proof) = echo("trusted setup flavour = ",flavour) echo("parsing witness & r1cs files...") @@ -57,4 +57,4 @@ proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): let ok = verifyProof( vkey, proof ) echo("verification succeeded = ",ok) - return proof + return (vkey,proof) From b433cbf7adca85d11d8213d7c44ed76d5ff90edc Mon Sep 17 00:00:00 2001 From: Balazs Komuves Date: Thu, 23 Nov 2023 13:25:26 +0100 Subject: [PATCH 6/6] speed up fake setup generation --- groth16/fake_setup.nim | 29 ++++++++++++++++++++++++++++- groth16/test_proof.nim | 12 +++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/groth16/fake_setup.nim b/groth16/fake_setup.nim index f4a19ec..f39f67a 100644 --- a/groth16/fake_setup.nim +++ b/groth16/fake_setup.nim @@ -32,7 +32,7 @@ proc randomToxicWaste(): ToxicWaste = let b = randFr() let c = randFr() let d = randFr() - let t = randFr() + let t = randFr() # intToFr(106) return ToxicWaste( alpha: a , beta: b @@ -129,6 +129,16 @@ func matricesToCoeffs*(matrices: Matrices): seq[Coeff] = #------------------------------------------------------------------------------- +func dotProdFr(xs, ys: seq[Fr]): Fr = + let n = xs.len + assert( n == ys.len, "dotProdFr: incompatible vector lengths" ) + var s : Fr = zeroFr + for i in 0..