diff --git a/README.md b/README.md index e8990c9..e433d34 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,13 @@ at your choice. ### TODO -- [ ] make it a nimble package -- [ ] refactor `bn128.nim` into smaller files -- [ ] proper MSM implementation (I couldn't make constantine's one to work) +- [ ] clean up the code +- [x] make it a nimble package +- [/] refactor `bn128.nim` into smaller files +- [/] proper MSM implementation (at first I couldn't make constantine's one to work) - [x] compare `.r1cs` to the "coeffs" section of `.zkey` - [x] generate fake circuit-specific setup ourselves -- [ ] multithreaded support (MSM, and possibly also FFT) +- [ ] multithreading support (MSM, and possibly also FFT) - [ ] add Groth16 notes - [ ] document the `snarkjs` circuit-specific setup `H` points convention - [ ] make it work for different curves diff --git a/bn128.nim b/bn128.nim deleted file mode 100644 index b5c9bc0..0000000 --- a/bn128.nim +++ /dev/null @@ -1,797 +0,0 @@ -# -# the `alt-bn128` elliptic curve -# -# See for example -# -# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -# -# equation: y^2 = x^3 + 3 -# - -import sugar - -import std/bitops -import std/strutils -import std/sequtils -import std/streams -import std/random - -import constantine/platforms/abstractions -import constantine/math/isogenies/frobenius - -import constantine/math/arithmetic -import constantine/math/io/io_fields -import constantine/math/io/io_bigints -import constantine/math/config/curves -import constantine/math/config/type_ff as tff - -import constantine/math/extension_fields/towers as ext -import constantine/math/elliptic/ec_shortweierstrass_affine as aff -import constantine/math/elliptic/ec_shortweierstrass_projective as prj -import constantine/math/pairings/pairings_bn as ate -import constantine/math/elliptic/ec_scalar_mul as scl -import constantine/math/elliptic/ec_multi_scalar_mul as msm - -#------------------------------------------------------------------------------- - -type B* = BigInt[256] -type Fr* = tff.Fr[BN254Snarks] -type Fp* = tff.Fp[BN254Snarks] - -type Fp2* = ext.QuadraticExt[Fp] -type Fp12* = ext.Fp12[BN254Snarks] - -type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1] -type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2] - -type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1] -type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2] - -func mkFp2* (i: Fp, u: Fp) : Fp2 = - let c : array[2, Fp] = [i,u] - return ext.QuadraticExt[Fp]( coords: c ) - -func unsafeMkG1* ( X, Y: Fp ) : G1 = - return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y) - -func unsafeMkG2* ( X, Y: Fp2 ) : G2 = - return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y) - -#------------------------------------------------------------------------------- - -func pairing* (p: G1, q: G2) : Fp12 = - var t : Fp12 - pairing_bn( t, p, q ) - return t - -#------------------------------------------------------------------------------- - -const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) -const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) - -const primeP_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) -const primeR_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) - -#------------------------------------------------------------------------------- - -const zeroFp* : Fp = fromHex( Fp, "0x00" ) -const zeroFr* : Fr = fromHex( Fr, "0x00" ) -const oneFp* : Fp = fromHex( Fp, "0x01" ) -const oneFr* : Fr = fromHex( Fr, "0x01" ) - -const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) -const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) - -const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp ) -const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) - -#------------------------------------------------------------------------------- - -func intToB*(a: uint): B = - var y : B - y.setUint(a) - return y - -func intToFp*(a: int): Fp = - var y : Fp - y.fromInt(a) - return y - -func intToFr*(a: int): Fr = - var y : Fr - y.fromInt(a) - return y - -#------------------------------------------------------------------------------- - -func isZeroFp*(x: Fp): bool = bool(isZero(x)) -func isZeroFr*(x: Fr): bool = bool(isZero(x)) - -func isEqualFp*(x, y: Fp): bool = bool(x == y) -func isEqualFr*(x, y: Fr): bool = bool(x == y) - -func `===`*(x, y: Fp): bool = isEqualFp(x,y) -func `===`*(x, y: Fr): bool = isEqualFr(x,y) - -#------------------- - -func isEqualFpSeq*(xs, ys: seq[Fp]): bool = - let n = xs.len - assert( n == ys.len ) - var b = true - for i in 0.. 0): - if bitand(e,1) > 0: a *= s - e = (e shr 1) - square(s) - return a - -func smallPowFr*(base: Fr, expo: int): Fr = - if expo >= 0: - return smallPowFr( base, uint(expo) ) - else: - return smallPowFr( invFr(base) , uint(-expo) ) - -#------------------------------------------------------------------------------- - -func deltaFr*(i, j: int) : Fr = - return (if (i == j): oneFr else: zeroFr) - -#------------------------------------------------------------------------------- - -func toDecimalBig*[n](a : BigInt[n]): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -func toDecimalFp*(a : Fp): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -func toDecimalFr*(a : Fr): string = - var s : string = toDecimal(a) - s = s.strip( leading=true, trailing=false, chars={'0'} ) - if s.len == 0: s="0" - return s - -#--------------------------------------- - -const k65536 : BigInt[254] = fromHex( BigInt[254], "0x10000", bigEndian ) - -func signedToDecimalFp*(a : Fp): string = - if bool( a.toBig() > primeP_254 - k65536 ): - return "-" & toDecimalFp(negFp(a)) - else: - return toDecimalFp(a) - -func signedToDecimalFr*(a : Fr): string = - if bool( a.toBig() > primeR_254 - k65536 ): - return "-" & toDecimalFr(negFr(a)) - else: - return toDecimalFr(a) - -#------------------------------------------------------------------------------- - -proc debugPrintFp*(prefix: string, x: Fp) = - echo(prefix & toDecimalFp(x)) - -proc debugPrintFp2*(prefix: string, z: Fp2) = - echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0])) - echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) - -proc debugPrintFr*(prefix: string, x: Fr) = - echo(prefix & toDecimalFr(x)) - -proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = - echo "---------------------" - echo msg - for x in xs: - debugPrintFr( " " , x ) - -proc debugPrintG1*(msg: string, pt: G1) = - echo(msg & ":") - debugPrintFp( " x = ", pt.x ) - debugPrintFp( " y = ", pt.y ) - -proc debugPrintG2*(msg: string, pt: G2) = - echo(msg & ":") - debugPrintFp2( " x = ", pt.x ) - debugPrintFp2( " y = ", pt.y ) - -#------------------------------------------------------------------------------- - -# Montgomery batch inversion -func batchInverse*( xs: seq[Fr] ) : seq[Fr] = - let n = xs.len - assert(n>0) - var us : seq[Fr] = newSeq[Fr](n+1) - var a = xs[0] - us[0] = oneFr - us[1] = a - for i in 1.. -# - -#[ -import sugar -import constantine/math/config/curves -import constantine/math/io/io_fields -import constantine/math/io/io_bigints -import ./zkey -]# - -import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_extfields except Fp12 -import constantine/math/extension_fields/towers except Fp2, Fp12 - -import ./bn128 -import ./domain -import ./poly -import ./zkey_types -import ./witness - -#------------------------------------------------------------------------------- - -type - Proof* = object - publicIO* : seq[Fr] - pi_a* : G1 - pi_b* : G2 - pi_c* : G1 - curve : string - -#------------------------------------------------------------------------------- -# the verifier -# - -proc verifyProof* (vkey: VKey, prf: Proof): bool = - - assert( prf.curve == "bn128" ) - - assert( isOnCurveG1(prf.pi_a) , "pi_a is not in G1" ) - assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" ) - assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" ) - - var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC ) - - let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b > - let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta > - let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > - let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > - - var eq : Fp12 - eq = lhs - eq *= rhs1 - eq *= rhs2 - eq *= rhs3 - - return bool(isOne(eq)) - -#------------------------------------------------------------------------------- -# A, B, C column vectors -# - -type - ABC = object - valuesA : seq[Fr] - valuesB : seq[Fr] - valuesC : seq[Fr] - -func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC = - let hdr: GrothHeader = zkey.header - let domSize = hdr.domainSize - - var valuesA : seq[Fr] = newSeq[Fr](domSize) - var valuesB : seq[Fr] = newSeq[Fr](domSize) - for entry in zkey.coeffs: - case entry.matrix - of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col] - of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col] - else: raise newException(AssertionDefect, "fatal error") - - var valuesC : seq[Fr] = newSeq[Fr](domSize) - for i in 0..= 1) - var ys : seq[Fr] = newSeq[Fr](n) - ys[0] = xs[0] - if n >= 1: ys[1] = eta * xs[1] - var spow : Fr = eta - for i in 2.. -# -func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] = - let n = abc.valuesA.len - let D = createDomain(n) - let eta = createDomain(2*n).domainGen - let A1 = shiftEvalDomain( abc.valuesA, D, eta ) - let B1 = shiftEvalDomain( abc.valuesB, D, eta ) - let C1 = shiftEvalDomain( abc.valuesC, D, eta ) - var ys : seq[Fr] = newSeq[Fr]( n ) - for j in 0.. +# +# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +# +# equation: y^2 = x^3 + 3 +# + +#------------------------------------------------------------------------------- + +import groth16/bn128/fields +import groth16/bn128/curves +import groth16/bn128/msm +import groth16/bn128/io +import groth16/bn128/rnd +import groth16/bn128/debug + +#------------------- + +export fields +export curves +export msm +export io +export rnd +export debug + +#------------------------------------------------------------------------------- + diff --git a/groth16/bn128/curves.nim b/groth16/bn128/curves.nim new file mode 100644 index 0000000..77ba811 --- /dev/null +++ b/groth16/bn128/curves.nim @@ -0,0 +1,231 @@ +# +# the `alt-bn128` elliptic curve +# +# See for example +# +# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +# +# equation: y^2 = x^3 + 3 +# + + +#import constantine/platforms/abstractions +#import constantine/math/isogenies/frobenius + +import constantine/math/arithmetic except Fp, Fr +import constantine/math/io/io_fields except Fp, Fr +import constantine/math/io/io_bigints +import constantine/math/config/curves + +import constantine/math/config/type_ff as tff except Fp, Fr +import constantine/math/extension_fields/towers as ext except Fp, Fp2, Fp12, Fr + +import constantine/math/elliptic/ec_shortweierstrass_affine as aff +import constantine/math/elliptic/ec_shortweierstrass_projective as prj +import constantine/math/pairings/pairings_bn as ate +import constantine/math/elliptic/ec_scalar_mul as scl + +import groth16/bn128/fields + +#------------------------------------------------------------------------------- + +type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1] +type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2] + +type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1] +type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2] + +#------------------------------------------------------------------------------- + +func unsafeMkG1* ( X, Y: Fp ) : G1 = + return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y) + +func unsafeMkG2* ( X, Y: Fp2 ) : G2 = + return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y) + +#------------------------------------------------------------------------------- + +const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp ) +const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) + +#------------------------------------------------------------------------------- + +func checkCurveEqG1*( x, y: Fp ) : bool = + if bool(isZero(x)) and bool(isZero(y)): + # the point at infinity is on the curve by definition + return true + else: + var x2 : Fp = squareFp(x) + var y2 : Fp = squareFp(y) + var x3 : Fp = x2 * x + var eq : Fp + eq = x3 + eq += intToFp(3) + eq -= y2 + # echo("eq = ",toDecimalFp(eq)) + return (bool(isZero(eq))) + +#--------------------------------------- + +# y^2 = x^3 + B +# B = b1 + bu*u +# b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373 +# b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690 +const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5") +const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2") +const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u ) + +func checkCurveEqG2*( x, y: Fp2 ) : bool = + if isZeroFp2(x) and isZeroFp2(y): + # the point at infinity is on the curve by definition + return true + else: + var x2 : Fp2 = squareFp2(x) + var y2 : Fp2 = squareFp2(y) + var x3 : Fp2 = x2 * x; + var eq : Fp2 + eq = x3 + eq += twistCoeffB + eq -= y2 + return isZeroFp2(eq) + +#------------------------------------------------------------------------------- + +func mkG1*( x, y: Fp ) : G1 = + if isZeroFp(x) and isZeroFp(y): + return infG1 + else: + assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" ) + return unsafeMkG1(x,y) + +func mkG2*( x, y: Fp2 ) : G2 = + if isZeroFp2(x) and isZeroFp2(y): + return infG2 + else: + assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" ) + return unsafeMkG2(x,y) + +#------------------------------------------------------------------------------- +# group generators + +const gen1_x : Fp = fromHex(Fp, "0x01") +const gen1_y : Fp = fromHex(Fp, "0x02") + +const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701") +const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b") +const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8") +const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c") + +const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu ) +const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu ) + +const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y ) +const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y ) + +#------------------------------------------------------------------------------- + +func isOnCurveG1* ( p: G1 ) : bool = + return checkCurveEqG1( p.x, p.y ) + +func isOnCurveG2* ( p: G2 ) : bool = + return checkCurveEqG2( p.x, p.y ) + +#=============================================================================== + +func addG1*(p,q: G1): G1 = + var r, x, y : ProjG1 + prj.fromAffine(x, p) + prj.fromAffine(y, q) + prj.sum(r, x, y) + var s : G1 + prj.affine(s, r) + return s + +#--------------------------------------- + +func addG2*(p,q: G2): G2 = + var r, x, y : ProjG2 + prj.fromAffine(x, p) + prj.fromAffine(y, q) + prj.sum(r, x, y) + var s : G2 + prj.affine(s, r) + return s + +func negG1*(p: G1): G1 = + var r : G1 = p + neg(r) + return r + +func negG2*(p: G2): G2 = + var r : G2 = p + neg(r) + return r + +#--------------------------------------- + +func `+`*(p,q: G1): G1 = addG1(p,q) +func `+`*(p,q: G2): G2 = addG2(p,q) + +func `+=`*(p: var G1, q: G1) = p = addG1(p,q) +func `+=`*(p: var G2, q: G2) = p = addG2(p,q) + +func `-=`*(p: var G1, q: G1) = p = addG1(p,negG1(q)) +func `-=`*(p: var G2, q: G2) = p = addG2(p,negG2(q)) + +#------------------------------------------------------------------------------- +# +# (affine) scalar multiplication +# + +func `**`*( coeff: Fr , point: G1 ) : G1 = + var q : ProjG1 + prj.fromAffine( q , point ) + scl.scalarMul( q , coeff.toBig() ) + var r : G1 + prj.affine( r, q ) + return r + +func `**`*( coeff: Fr , point: G2 ) : G2 = + var q : ProjG2 + prj.fromAffine( q , point ) + scl.scalarMul( q , coeff.toBig() ) + var r : G2 + prj.affine( r, q ) + return r + +#------------------- + +func `**`*( coeff: BigInt , point: G1 ) : G1 = + var q : ProjG1 + prj.fromAffine( q , point ) + scl.scalarMul( q , coeff ) + var r : G1 + prj.affine( r, q ) + return r + +func `**`*( coeff: BigInt , point: G2 ) : G2 = + var q : ProjG2 + prj.fromAffine( q , point ) + scl.scalarMul( q , coeff ) + var r : G2 + prj.affine( r, q ) + return r + +#------------------------------------------------------------------------------- + +func pairing* (p: G1, q: G2) : Fp12 = + var t : Fp12 + ate.pairing_bn[BN254Snarks]( t, p, q ) + return t + +#------------------------------------------------------------------------------- + +proc sanityCheckGroupGen*() = + echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) ) + echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) ) + echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) ) + echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) ) + +#------------------------------------------------------------------------------- diff --git a/groth16/bn128/debug.nim b/groth16/bn128/debug.nim new file mode 100644 index 0000000..ab7403c --- /dev/null +++ b/groth16/bn128/debug.nim @@ -0,0 +1,45 @@ +# +# the `alt-bn128` elliptic curve +# +# See for example +# +# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +# +# equation: y^2 = x^3 + 3 +# + +import groth16/bn128/fields +import groth16/bn128/curves +import groth16/bn128/io + +#------------------------------------------------------------------------------- + +proc debugPrintFp*(prefix: string, x: Fp) = + echo(prefix & toDecimalFp(x)) + +proc debugPrintFp2*(prefix: string, z: Fp2) = + echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0])) + echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) + +proc debugPrintFr*(prefix: string, x: Fr) = + echo(prefix & toDecimalFr(x)) + +proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = + echo "---------------------" + echo msg + for x in xs: + debugPrintFr( " " , x ) + +proc debugPrintG1*(msg: string, pt: G1) = + echo(msg & ":") + debugPrintFp( " x = ", pt.x ) + debugPrintFp( " y = ", pt.y ) + +proc debugPrintG2*(msg: string, pt: G2) = + echo(msg & ":") + debugPrintFp2( " x = ", pt.x ) + debugPrintFp2( " y = ", pt.y ) + +#------------------------------------------------------------------------------- + diff --git a/groth16/bn128/fields.nim b/groth16/bn128/fields.nim new file mode 100644 index 0000000..5ffa8a1 --- /dev/null +++ b/groth16/bn128/fields.nim @@ -0,0 +1,189 @@ + +# +# the prime fields Fp and Fr with sizes +# +# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +# + +import sugar + +import std/bitops +import std/sequtils + +import constantine/math/arithmetic +import constantine/math/io/io_fields +import constantine/math/io/io_bigints +import constantine/math/config/curves +import constantine/math/config/type_ff as tff +import constantine/math/extension_fields/towers as ext + +#------------------------------------------------------------------------------- + +type B* = BigInt[256] +type Fr* = tff.Fr[BN254Snarks] +type Fp* = tff.Fp[BN254Snarks] + +type Fp2* = ext.QuadraticExt[Fp] +type Fp12* = ext.Fp12[BN254Snarks] + +func mkFp2* (i: Fp, u: Fp) : Fp2 = + let c : array[2, Fp] = [i,u] + return ext.QuadraticExt[Fp]( coords: c ) + +#------------------------------------------------------------------------------- + +const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) +const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) + +#------------------------------------------------------------------------------- + +const zeroFp* : Fp = fromHex( Fp, "0x00" ) +const zeroFr* : Fr = fromHex( Fr, "0x00" ) +const oneFp* : Fp = fromHex( Fp, "0x01" ) +const oneFr* : Fr = fromHex( Fr, "0x01" ) + +const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) +const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) + +const minusOneFp* : Fp = fromHex( Fp, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" ) +const minusOneFr* : Fr = fromHex( Fr, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" ) + +#------------------------------------------------------------------------------- + +func intToB*(a: uint): B = + var y : B + y.setUint(a) + return y + +func intToFp*(a: int): Fp = + var y : Fp + y.fromInt(a) + return y + +func intToFr*(a: int): Fr = + var y : Fr + y.fromInt(a) + return y + +#------------------------------------------------------------------------------- + +func isZeroFp* (x: Fp ): bool = bool(isZero(x)) +func isZeroFp2*(x: Fp2): bool = bool(isZero(x)) +func isZeroFr* (x: Fr ): bool = bool(isZero(x)) + +func isEqualFp* (x, y: Fp ): bool = bool(x == y) +func isEqualFp2*(x, y: Fp2): bool = bool(x == y) +func isEqualFr* (x, y: Fr ): bool = bool(x == y) + +func `===`*(x, y: Fp ): bool = isEqualFp(x,y) +func `===`*(x, y: Fp2): bool = isEqualFp2(x,y) +func `===`*(x, y: Fr ): bool = isEqualFr(x,y) + +#------------------- + +func isEqualFpSeq*(xs, ys: seq[Fp]): bool = + let n = xs.len + assert( n == ys.len ) + var b = true + for i in 0.. 0): + if bitand(e,1) > 0: a *= s + e = (e shr 1) + square(s) + return a + +func smallPowFr*(base: Fr, expo: int): Fr = + if expo >= 0: + return smallPowFr( base, uint(expo) ) + else: + return smallPowFr( invFr(base) , uint(-expo) ) + +#------------------------------------------------------------------------------- + +func deltaFr*(i, j: int) : Fr = + return (if (i == j): oneFr else: zeroFr) + +#------------------------------------------------------------------------------- + +# Montgomery batch inversion +func batchInverseFr*( xs: seq[Fr] ) : seq[Fr] = + let n = xs.len + assert(n>0) + var us : seq[Fr] = newSeq[Fr](n+1) + var a = xs[0] + us[0] = oneFr + us[1] = a + for i in 1.. primeP_254 - k65536 ): + return "-" & toDecimalFp(negFp(a)) + else: + return toDecimalFp(a) + +func signedToDecimalFr*(a : Fr): string = + if bool( a.toBig() > primeR_254 - k65536 ): + return "-" & toDecimalFr(negFr(a)) + else: + return toDecimalFr(a) + +#------------------------------------------------------------------------------- +# Dealing with Montgomery representation +# + +# R=2^256; this computes 2^256 mod Fp +func calcFpMontR*() : Fp = + var x : Fp = intToFp(2) + for i in 1..8: + square(x) + return x + +# R=2^256; this computes the inverse of (2^256 mod Fp) +func calcFpInvMontR*() : Fp = + var x : Fp = calcFpMontR() + inv(x) + return x + +# R=2^256; this computes 2^256 mod Fr +func calcFrMontR*() : Fr = + var x : Fr = intToFr(2) + for i in 1..8: + square(x) + return x + +# R=2^256; this computes the inverse of (2^256 mod Fp) +func calcFrInvMontR*() : Fr = + var x : Fr = calcFrMontR() + inv(x) + return x + +# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) +const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) +const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) + +# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) +const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) +const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) + +proc checkMontgomeryConstants*() = + assert( bool( fpMontR == calcFpMontR() ) ) + assert( bool( frMontR == calcFrMontR() ) ) + assert( bool( fpInvMontR == calcFpInvMontR() ) ) + assert( bool( frInvMontR == calcFrInvMontR() ) ) + echo("OK") + +#--------------------------------------- + +# the binary file `.zkey` used by the `circom` ecosystem uses little-endian +# Montgomery representation. So when we unmarshal with Constantine, it will +# give the wrong result. Calling this function on the result fixes that. +func fromMontgomeryFp*(x : Fp) : Fp = + var y : Fp = x; + y *= fpInvMontR + return y + +func fromMontgomeryFr*(x : Fr) : Fr = + var y : Fr = x; + y *= frInvMontR + return y + +func toMontgomeryFr*(x : Fr) : Fr = + var y : Fr = x; + y *= frMontR + return y + +#------------------------------------------------------------------------------- +# Unmarshalling field elements +# Note: in the `.zkey` coefficients, e apparently DOUBLE Montgomery encoding is used ?!? +# + +func unmarshalFpMont* ( bs: array[32,byte] ) : Fp = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fp + x.fromBig( big ) + return fromMontgomeryFp(x) + +# WTF Jordi, go home you are drunk +func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return fromMontgomeryFr(fromMontgomeryFr(x)) + +func unmarshalFrStd* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return x + +func unmarshalFrMont* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return fromMontgomeryFr(x) + +#------------------------------------------------------------------------------- + +func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] = + var vals : seq[Fp] = newSeq[Fp]( len ) + var bytes : array[32,byte] + for i in 0..= 1024, we use 8 threads + let nthreads = max( 1 , min( N div 128 , nthreadsTarget ) ) + + let m = N div nthreads + + var threads : seq[Thread[InputTuple]] = newSeq[Thread[InputTuple]]( nthreads ) + var results : seq[G1] = newSeq[G1]( nthreads ) + + proc myThreadFunc( inp: InputTuple ) {.thread.} = + results[inp.idx] = msmConstantineG1( inp.coeffs, inp.points ) + + for i in 0.. = Fp[]" + , "Fp2. = Fp.extension(x^2+1)" + , "def mkFp2(a,b):" + , " return ( a + u*b )" + , "R. = Fp2[]" + , "Fp12. = Fp2.extension(x^6 - (9+u))" + , "E12 = E.base_extend(Fp12)" + , "" + , "# twisted curve" + , "B_twist = Fp2(19485874751759354771024239261021720505790618469301721065564631296452457478373 + 266929791119991161246907387137283842545076965332900288569378510910307636690*u )" + , "E2 = EllipticCurve(Fp2,[0,B_twist])" + , "size_E2 = E2.cardinality();" + , "cofactor_E2 = size_E2 / r;" + , "print(\"|E2| = \", size_E2 );" + , "print(\"h(E2) = \", cofactor_E2 );" + , "" + , "# map from E2 to E12" + , "def Psi(pt):" + , " pt.normalize_coordinates()" + , " x = pt[0]" + , " y = pt[1]" + , " return E12( Fp12(w^2 * x) , Fp12(w^3 * y) )" + , "" + , "def pairing(P,Q):" + , " return E12(P).ate_pairing( Psi(Q), n=r, k=12, t=bn_t, q=p^12 )" + , "" + ] + +const sage_bn128 : string = join(sage_bn128_lines, sep="\n") + +#------------------------------------------------------------------------------- + +const verify_lines : seq[string] = + @[ "pubG1 = pointsIC[0]" + , "for i in [1..len(pubIO)-1]:" + , " pubG1 = pubG1 + pubIO[i]*pointsIC[i]" + , "" + , "lhs = pairing( -piA , piB )" + , "rhs1 = pairing( alpha1 , beta2 )" + , "rhs2 = pairing( piC , delta2 )" + , "rhs3 = pairing( pubG1 , gamma2 )" + , "eq = lhs * rhs1 * rhs2 * rhs3" + , "print(\"verification suceeded =\\n\",eq == 1)" + ] + +const verify_script : string = join(verify_lines, sep="\n") + +#------------------------------------------------------------------------------- + +proc exportSage*(fpath: string, vkey: VKey, prf: Proof) = + + let h = openFileStream(fpath, fmWrite) + defer: h.close() + + h.writeLine(sage_bn128) + h.exportVKey(vkey); + h.exportProof(prf); + h.writeLine(verify_script) + +#------------------------------------------------------------------------------- + diff --git a/r1cs.nim b/groth16/files/r1cs.nim similarity index 98% rename from r1cs.nim rename to groth16/files/r1cs.nim index e07fe99..cb65712 100644 --- a/r1cs.nim +++ b/groth16/files/r1cs.nim @@ -43,7 +43,7 @@ # ... # ... # -# 4: Custom gates application +# 5: Custom gates application # --------------------------- # ... # ... @@ -54,8 +54,8 @@ import std/streams import constantine/math/arithmetic except Fp, Fr import constantine/math/io/io_bigints -import ./bn128 -import ./container +import groth16/bn128 +import groth16/files/container #------------------------------------------------------------------------------- diff --git a/witness.nim b/groth16/files/witness.nim similarity index 91% rename from witness.nim rename to groth16/files/witness.nim index 9c45b4e..f849eb3 100644 --- a/witness.nim +++ b/groth16/files/witness.nim @@ -11,9 +11,7 @@ # # nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret # -# NOTE: Unlike the `.zkey` files, which encode field elements in the -# Montgomery representation, the `.wtns` file encode field elements in -# the standard representation! +# Field elements are encoded in the standard representation. # import std/streams @@ -21,8 +19,8 @@ import std/streams import constantine/math/arithmetic except Fp, Fr import constantine/math/io/io_bigints -import ./bn128 -import ./container +import groth16/bn128 +import groth16/files/container #------------------------------------------------------------------------------- diff --git a/zkey.nim b/groth16/files/zkey.nim similarity index 98% rename from zkey.nim rename to groth16/files/zkey.nim index abc3f97..73ce1e5 100644 --- a/zkey.nim +++ b/groth16/files/zkey.nim @@ -31,7 +31,7 @@ # beta2 : G2 = [beta]_2 # gamma2 : G2 = [gamma]_2 # delta1 : G1 = [delta]_1 -# delta2 : G2 = [delta_2] +# delta2 : G2 = [delta]_2 # # 3: IC # ----- @@ -97,10 +97,10 @@ import std/streams import constantine/math/arithmetic except Fp, Fr #import constantine/math/io/io_bigints -import ./bn128 -import ./zkey_types -import ./container -import ./misc +import groth16/bn128 +import groth16/zkey_types +import groth16/files/container +import groth16/misc #------------------------------------------------------------------------------- diff --git a/domain.nim b/groth16/math/domain.nim similarity index 97% rename from domain.nim rename to groth16/math/domain.nim index 334c2ce..2185814 100644 --- a/domain.nim +++ b/groth16/math/domain.nim @@ -7,8 +7,8 @@ import constantine/math/arithmetic except Fp,Fr import constantine/math/io/io_fields except Fp,Fr #import constantine/math/io/io_bigints -import ./bn128 -import ./misc +import groth16/bn128 +import groth16/misc #------------------------------------------------------------------------------- diff --git a/ntt.nim b/groth16/math/ntt.nim similarity index 99% rename from ntt.nim rename to groth16/math/ntt.nim index 8a84276..0173a1b 100644 --- a/ntt.nim +++ b/groth16/math/ntt.nim @@ -9,8 +9,8 @@ import constantine/math/arithmetic except Fp,Fr import constantine/math/io/io_fields -import bn128 -import domain +import groth16/bn128 +import groth16/math/domain #------------------------------------------------------------------------------- diff --git a/poly.nim b/groth16/math/poly.nim similarity index 98% rename from poly.nim rename to groth16/math/poly.nim index 44d599b..f9e82d3 100644 --- a/poly.nim +++ b/groth16/math/poly.nim @@ -10,12 +10,12 @@ import std/sequtils import std/sugar import constantine/math/arithmetic except Fp,Fr -import constantine/math/io/io_fields +#import constantine/math/io/io_fields -import bn128 -import domain -import ntt -import misc +import groth16/bn128 +import groth16/math/domain +import groth16/math/ntt +import groth16/misc #------------------------------------------------------------------------------- @@ -220,7 +220,7 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly = #------------------------------------------------------------------------------- # Lagrange basis polynomials -func lagrangePoly(D: Domain, k: int): Poly = +func lagrangePoly*(D: Domain, k: int): Poly = let N = D.domainSize let omMinusK : Fr = smallPowFr( D.invDomainGen , k ) let invN : Fr = invFr(intToFr(N)) diff --git a/misc.nim b/groth16/misc.nim similarity index 100% rename from misc.nim rename to groth16/misc.nim diff --git a/groth16/prover.nim b/groth16/prover.nim new file mode 100644 index 0000000..8eb6be5 --- /dev/null +++ b/groth16/prover.nim @@ -0,0 +1,216 @@ + +# +# Groth16 prover +# +# WARNING! +# the points H in `.zkey` are *NOT* what normal people would think they are +# See +# + +#[ +import sugar +import constantine/math/config/curves +import constantine/math/io/io_fields +import constantine/math/io/io_bigints +import ./zkey +]# + +import constantine/math/arithmetic except Fp, Fr +#import constantine/math/io/io_extfields except Fp12 +#import constantine/math/extension_fields/towers except Fp2, Fp12 + +import groth16/bn128 +import groth16/math/domain +import groth16/math/poly +import groth16/zkey_types +import groth16/files/witness + +#------------------------------------------------------------------------------- + +type + Proof* = object + publicIO* : seq[Fr] + pi_a* : G1 + pi_b* : G2 + pi_c* : G1 + curve* : string + +#------------------------------------------------------------------------------- +# A, B, C column vectors +# + +type + ABC = object + valuesA : seq[Fr] + valuesB : seq[Fr] + valuesC : seq[Fr] + +func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC = + let hdr: GrothHeader = zkey.header + let domSize = hdr.domainSize + + var valuesA : seq[Fr] = newSeq[Fr](domSize) + var valuesB : seq[Fr] = newSeq[Fr](domSize) + for entry in zkey.coeffs: + case entry.matrix + of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col] + of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col] + else: raise newException(AssertionDefect, "fatal error") + + var valuesC : seq[Fr] = newSeq[Fr](domSize) + for i in 0..= 1) + var ys : seq[Fr] = newSeq[Fr](n) + ys[0] = xs[0] + if n >= 1: ys[1] = eta * xs[1] + var spow : Fr = eta + for i in 2.. +# +func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] = + let n = abc.valuesA.len + let D = createDomain(n) + let eta = createDomain(2*n).domainGen + let A1 = shiftEvalDomain( abc.valuesA, D, eta ) + let B1 = shiftEvalDomain( abc.valuesB, D, eta ) + let C1 = shiftEvalDomain( abc.valuesC, D, eta ) + var ys : seq[Fr] = newSeq[Fr]( n ) + for j in 0.. + let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta > + let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > + let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > + + var eq : Fp12 + eq = lhs + eq *= rhs1 + eq *= rhs2 + eq *= rhs3 + + return bool(isOne(eq)) + +#------------------------------------------------------------------------------- diff --git a/zkey_types.nim b/groth16/zkey_types.nim similarity index 95% rename from zkey_types.nim rename to groth16/zkey_types.nim index f1fc815..ca51f8a 100644 --- a/zkey_types.nim +++ b/groth16/zkey_types.nim @@ -1,7 +1,7 @@ import constantine/math/arithmetic except Fp, Fr -import ./bn128 +import groth16/bn128 #------------------------------------------------------------------------------- @@ -80,14 +80,14 @@ func matrixSelToString(sel: MatrixSel): string = of MatrixB: return "B" of MatrixC: return "C" -proc printCoeff(cf: Coeff) = +proc debugPrintCoeff(cf: Coeff) = echo( "matrix=", matrixSelToString(cf.matrix) , " | i=", cf.row , " | j=", cf.col , " | val=", signedToDecimalFr(cf.coeff) ) -proc printCoeffs*(cfs: seq[Coeff]) = - for cf in cfs: printCoeff(cf) +proc debugPrintCoeffs*(cfs: seq[Coeff]) = + for cf in cfs: debugPrintCoeff(cf) #------------------------------------------------------------------------------- diff --git a/tests/groth16/testProver.nim b/tests/groth16/testProver.nim new file mode 100644 index 0000000..ba1e8e9 --- /dev/null +++ b/tests/groth16/testProver.nim @@ -0,0 +1,75 @@ + +import std/unittest +import std/sequtils + +import groth16/prover +import groth16/verifier +import groth16/fake_setup +import groth16/zkey_types +import groth16/files/witness +import groth16/files/r1cs +import groth16/bn128/fields + +#------------------------------------------------------------------------------- +# simple hand-crafted arithmetic circuit +# + +const myWitnessCfg = + WitnessConfig( nWires: 7 + , nPubOut: 1 # public output = input + a*b*c = 1022 + 7*11*13 = 2023 + , nPubIn: 1 # public input = 1022 + , nPrivIn: 3 # private inputs: 7, 11, 13 + , nLabels: 0 + ) + +# 2023 == 1022 + 7*3*11 +const myEq1 : Constraint = ( @[] , @[] , @[ (0,minusOneFr) , (1,oneFr) , (6,oneFr) ] ) + +# 7*11 == 77 +const myEq2 : Constraint = ( @[ (2,oneFr) ] , @[ (3,oneFr) ] , @[ (5,oneFr) ] ) + +# 77*13 == 1001 +const myEq3 : Constraint = ( @[ (4,oneFr) ] , @[ (5,oneFr) ] , @[ (6,oneFr) ] ) + +const myConstraints : seq[Constraint] = @[ myEq1, myEq2, myEq3 ] + +const myLabels : seq[int] = @[] + +const myR1CS = + R1CS( r: primeR + , cfg: myWitnessCfg + , nConstr: myConstraints.len + , constraints: myConstraints + , wireToLabel: myLabels + ) + +# the equation we want prove is `7*11*13 + 1022 == 2023` +let myWitnessValues : seq[Fr] = map( @[ 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr ) +# wire indices: ^^^^^^^ 0 1 2 3 4 5 6 + +let myWitness = + Witness( curve: "bn128" + , r: primeR + , nvars: 7 + , values: myWitnessValues + ) + +#------------------------------------------------------------------------------- + +proc testProof(zkey: ZKey, witness: Witness): bool = + let proof = generateProof( zkey, witness ) + let vkey = extractVKey( zkey) + let ok = verifyProof( vkey, proof ) + return ok + +suite "prover": + + test "prove & verify simple multiplication circuit, `JensGroth` flavour": + let zkey = createFakeCircuitSetup( myR1cs, flavour=JensGroth ) + check testProof( zkey, myWitness ) + + test "prove & verify simple multiplication circuit, `Snarkjs` flavour": + let zkey = createFakeCircuitSetup( myR1cs, flavour=Snarkjs ) + check testProof( zkey, myWitness ) + +#------------------------------------------------------------------------------- diff --git a/tests/nim.cfg b/tests/nim.cfg new file mode 100644 index 0000000..0f840a1 --- /dev/null +++ b/tests/nim.cfg @@ -0,0 +1 @@ +--path:".." diff --git a/tests/test.nim b/tests/test.nim new file mode 100644 index 0000000..f2bf786 --- /dev/null +++ b/tests/test.nim @@ -0,0 +1,3 @@ + +import ./groth16/testProver +