diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d915edc --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: CI + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + nim: [2.2.4] + steps: + - uses: actions/checkout@v2 + - uses: iffy/install-nim@v5 + with: + version: ${{ matrix.nim }} + - name: Install dependencies + run: nimble install -d -y + - name: Test + run: nimble test -y + - name: Build binary + run: nimble build -y diff --git a/groth16.nimble b/groth16.nimble index a6566ea..168d302 100644 --- a/groth16.nimble +++ b/groth16.nimble @@ -8,4 +8,4 @@ binDir = "build" namedBin = {"cli/cli_main": "nim-groth16"}.toTable() requires "https://github.com/status-im/nim-taskpools" -requires "https://github.com/mratsim/constantine#5f7ba18f2ed351260015397c9eae079a6decaee1" \ No newline at end of file +requires "https://github.com/mratsim/constantine >= 0.2.0" \ No newline at end of file diff --git a/groth16/bn128/curves.nim b/groth16/bn128/curves.nim index 42611b2..f8bb9f1 100644 --- a/groth16/bn128/curves.nim +++ b/groth16/bn128/curves.nim @@ -13,13 +13,12 @@ #import constantine/platforms/abstractions #import constantine/math/isogenies/frobenius -import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_fields except Fp, Fr import constantine/math/io/io_bigints -import constantine/math/config/curves +import constantine/math/arithmetic +import constantine/math/io/io_fields -import constantine/math/config/type_ff as tff except Fp, Fr -import constantine/math/extension_fields/towers as ext except Fp, Fp2, Fp12, Fr +import constantine/named/properties_fields as tff +import constantine/math/extension_fields/towers as ext import constantine/math/elliptic/ec_shortweierstrass_affine as aff import constantine/math/elliptic/ec_shortweierstrass_projective as prj @@ -30,19 +29,19 @@ import groth16/bn128/fields #------------------------------------------------------------------------------- -type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1] -type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2] +type G1* = aff.EC_ShortW_Aff[Fp[BN254_Snarks] , aff.G1] +type G2* = aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2] -type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1] -type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2] +type ProjG1* = prj.EC_ShortW_Prj[Fp[BN254_Snarks] , prj.G1] +type ProjG2* = prj.EC_ShortW_Prj[Fp2[BN254_Snarks], prj.G2] #------------------------------------------------------------------------------- -func unsafeMkG1* ( X, Y: Fp ) : G1 = - return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y) +func unsafeMkG1* ( X, Y: Fp[BN254_Snarks] ) : G1 = + return aff.EC_ShortW_Aff[Fp[BN254_Snarks], aff.G1](x: X, y: Y) -func unsafeMkG2* ( X, Y: Fp2 ) : G2 = - return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y) +func unsafeMkG2* ( X, Y: Fp2[BN254_Snarks] ) : G2 = + return aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2](x: X, y: Y) #------------------------------------------------------------------------------- @@ -51,15 +50,15 @@ const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) #------------------------------------------------------------------------------- -func checkCurveEqG1*( x, y: Fp ) : bool = +func checkCurveEqG1*( x, y: Fp[BN254_Snarks] ) : bool = if bool(isZero(x)) and bool(isZero(y)): # the point at infinity is on the curve by definition return true else: - var x2 : Fp = squareFp(x) - var y2 : Fp = squareFp(y) - var x3 : Fp = x2 * x - var eq : Fp + var x2 = squareFp(x) + var y2 = squareFp(y) + var x3 = x2 * x + var eq : Fp[BN254_Snarks] eq = x3 eq += intToFp(3) eq -= y2 @@ -72,19 +71,19 @@ func checkCurveEqG1*( x, y: Fp ) : bool = # B = b1 + bu*u # b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373 # b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690 -const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5") -const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2") -const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u ) +const twistCoeffB_1 = fromHex(Fp[BN254_Snarks], "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5") +const twistCoeffB_u = fromHex(Fp[BN254_Snarks], "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2") +const twistCoeffB = mkFp2( twistCoeffB_1 , twistCoeffB_u ) -func checkCurveEqG2*( x, y: Fp2 ) : bool = +func checkCurveEqG2*( x, y: Fp2[BN254_Snarks] ) : bool = if isZeroFp2(x) and isZeroFp2(y): # the point at infinity is on the curve by definition return true else: - var x2 : Fp2 = squareFp2(x) - var y2 : Fp2 = squareFp2(y) - var x3 : Fp2 = x2 * x; - var eq : Fp2 + var x2 = squareFp2(x) + var y2 = squareFp2(y) + var x3 = x2 * x + var eq : Fp2[BN254_Snarks] eq = x3 eq += twistCoeffB eq -= y2 @@ -92,14 +91,14 @@ func checkCurveEqG2*( x, y: Fp2 ) : bool = #------------------------------------------------------------------------------- -func mkG1*( x, y: Fp ) : G1 = +func mkG1*( x, y: Fp[BN254_Snarks] ) : G1 = if isZeroFp(x) and isZeroFp(y): return infG1 else: assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" ) return unsafeMkG1(x,y) -func mkG2*( x, y: Fp2 ) : G2 = +func mkG2*( x, y: Fp2[BN254_Snarks] ) : G2 = if isZeroFp2(x) and isZeroFp2(y): return infG2 else: @@ -109,16 +108,16 @@ func mkG2*( x, y: Fp2 ) : G2 = #------------------------------------------------------------------------------- # group generators -const gen1_x : Fp = fromHex(Fp, "0x01") -const gen1_y : Fp = fromHex(Fp, "0x02") +const gen1_x = fromHex(Fp[BN254_Snarks], "0x01") +const gen1_y = fromHex(Fp[BN254_Snarks], "0x02") -const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701") -const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b") -const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8") -const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c") +const gen2_xi = fromHex(Fp[BN254_Snarks], "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701") +const gen2_xu = fromHex(Fp[BN254_Snarks], "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b") +const gen2_yi = fromHex(Fp[BN254_Snarks], "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8") +const gen2_yu = fromHex(Fp[BN254_Snarks], "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c") -const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu ) -const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu ) +const gen2_x = mkFp2( gen2_xi, gen2_xu ) +const gen2_y = mkFp2( gen2_yi, gen2_yu ) const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y ) const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y ) @@ -215,9 +214,9 @@ func `**`*( coeff: BigInt , point: G2 ) : G2 = #------------------------------------------------------------------------------- -func pairing* (p: G1, q: G2) : Fp12 = - var t : Fp12 - ate.pairing_bn[BN254Snarks]( t, p, q ) +func pairing* (p: G1, q: G2) : Fp12[BN254_Snarks] = + var t : Fp12[BN254_Snarks] + ate.pairing_bn( t, p, q ) return t #------------------------------------------------------------------------------- @@ -225,7 +224,8 @@ func pairing* (p: G1, q: G2) : Fp12 = proc sanityCheckGroupGen*() = echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) ) echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) ) - echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) ) - echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) ) + # TODO: fix compilation error with Constantine 0.2.0: + # echo( "order of gen1 is R = ", (not bool(isNeutral(gen1))) and bool(isNeutral(primeR ** gen1)) ) + # echo( "order of gen2 is R = ", (not bool(isNeutral(gen2))) and bool(isNeutral(primeR ** gen2)) ) #------------------------------------------------------------------------------- diff --git a/groth16/bn128/debug.nim b/groth16/bn128/debug.nim index ab7403c..7795294 100644 --- a/groth16/bn128/debug.nim +++ b/groth16/bn128/debug.nim @@ -9,23 +9,26 @@ # equation: y^2 = x^3 + 3 # +import constantine/named/properties_fields +import constantine/math/extension_fields/towers + import groth16/bn128/fields import groth16/bn128/curves import groth16/bn128/io #------------------------------------------------------------------------------- -proc debugPrintFp*(prefix: string, x: Fp) = +proc debugPrintFp*(prefix: string, x: Fp[BN254_Snarks]) = echo(prefix & toDecimalFp(x)) -proc debugPrintFp2*(prefix: string, z: Fp2) = +proc debugPrintFp2*(prefix: string, z: Fp2[BN254_Snarks]) = echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0])) echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) -proc debugPrintFr*(prefix: string, x: Fr) = +proc debugPrintFr*(prefix: string, x: Fr[BN254_Snarks]) = echo(prefix & toDecimalFr(x)) -proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = +proc debugPrintFrSeq*(msg: string, xs: seq[Fr[BN254_Snarks]]) = echo "---------------------" echo msg for x in xs: diff --git a/groth16/bn128/fields.nim b/groth16/bn128/fields.nim index 5ffa8a1..2f52c50 100644 --- a/groth16/bn128/fields.nim +++ b/groth16/bn128/fields.nim @@ -14,22 +14,16 @@ import std/sequtils import constantine/math/arithmetic import constantine/math/io/io_fields import constantine/math/io/io_bigints -import constantine/math/config/curves -import constantine/math/config/type_ff as tff +import constantine/named/properties_fields as tff import constantine/math/extension_fields/towers as ext #------------------------------------------------------------------------------- type B* = BigInt[256] -type Fr* = tff.Fr[BN254Snarks] -type Fp* = tff.Fp[BN254Snarks] -type Fp2* = ext.QuadraticExt[Fp] -type Fp12* = ext.Fp12[BN254Snarks] - -func mkFp2* (i: Fp, u: Fp) : Fp2 = - let c : array[2, Fp] = [i,u] - return ext.QuadraticExt[Fp]( coords: c ) +func mkFp2* (i: Fp[BN254_Snarks], u: Fp[BN254_Snarks]) : Fp2[BN254_Snarks] = + let c : array[2, Fp[BN254_Snarks]] = [i,u] + return ext.QuadraticExt[Fp[BN254_Snarks]]( coords: c ) #------------------------------------------------------------------------------- @@ -38,16 +32,16 @@ const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b97 #------------------------------------------------------------------------------- -const zeroFp* : Fp = fromHex( Fp, "0x00" ) -const zeroFr* : Fr = fromHex( Fr, "0x00" ) -const oneFp* : Fp = fromHex( Fp, "0x01" ) -const oneFr* : Fr = fromHex( Fr, "0x01" ) +const zeroFp* = fromHex( Fp[BN254_Snarks], "0x00" ) +const zeroFr* = fromHex( Fr[BN254_Snarks], "0x00" ) +const oneFp* = fromHex( Fp[BN254_Snarks], "0x01" ) +const oneFr* = fromHex( Fr[BN254_Snarks], "0x01" ) -const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) -const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) +const zeroFp2* = mkFp2( zeroFp, zeroFp ) +const oneFp2* = mkFp2( oneFp , zeroFp ) -const minusOneFp* : Fp = fromHex( Fp, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" ) -const minusOneFr* : Fr = fromHex( Fr, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" ) +const minusOneFp* = fromHex( Fp[BN254_Snarks], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" ) +const minusOneFr* = fromHex( Fr[BN254_Snarks], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" ) #------------------------------------------------------------------------------- @@ -56,13 +50,13 @@ func intToB*(a: uint): B = y.setUint(a) return y -func intToFp*(a: int): Fp = - var y : Fp +func intToFp*(a: int): Fp[BN254_Snarks] = + var y : Fp[BN254_Snarks] y.fromInt(a) return y -func intToFr*(a: int): Fr = - var y : Fr +func intToFr*(a: int): Fr[BN254_Snarks] = + var y : Fr[BN254_Snarks] y.fromInt(a) return y @@ -154,27 +148,27 @@ func smallPowFr*(base: Fr, expo: int): Fr = #------------------------------------------------------------------------------- -func deltaFr*(i, j: int) : Fr = +func deltaFr*[T](i, j: int) : Fr[T] = return (if (i == j): oneFr else: zeroFr) #------------------------------------------------------------------------------- # Montgomery batch inversion -func batchInverseFr*( xs: seq[Fr] ) : seq[Fr] = +func batchInverseFr*( xs: seq[Fr[BN254_Snarks]] ) : seq[Fr[BN254_Snarks]] = let n = xs.len assert(n>0) - var us : seq[Fr] = newSeq[Fr](n+1) + var us : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]](n+1) var a = xs[0] us[0] = oneFr us[1] = a for i in 1.. primeP_254 - k65536 ): return "-" & toDecimalFp(negFp(a)) else: return toDecimalFp(a) -func signedToDecimalFr*(a : Fr): string = +func signedToDecimalFr*(a : Fr[BN254_Snarks]): string = if bool( a.toBig() > primeR_254 - k65536 ): return "-" & toDecimalFr(negFr(a)) else: @@ -58,38 +58,38 @@ func signedToDecimalFr*(a : Fr): string = # # R=2^256; this computes 2^256 mod Fp -func calcFpMontR*() : Fp = - var x : Fp = intToFp(2) +func calcFpMontR*() : Fp[BN254_Snarks] = + var x : Fp[BN254_Snarks] = intToFp(2) for i in 1..8: square(x) return x # R=2^256; this computes the inverse of (2^256 mod Fp) -func calcFpInvMontR*() : Fp = - var x : Fp = calcFpMontR() +func calcFpInvMontR*() : Fp[BN254_Snarks] = + var x : Fp[BN254_Snarks] = calcFpMontR() inv(x) return x # R=2^256; this computes 2^256 mod Fr -func calcFrMontR*() : Fr = - var x : Fr = intToFr(2) +func calcFrMontR*() : Fr[BN254_Snarks] = + var x : Fr[BN254_Snarks] = intToFr(2) for i in 1..8: square(x) return x # R=2^256; this computes the inverse of (2^256 mod Fp) -func calcFrInvMontR*() : Fr = - var x : Fr = calcFrMontR() +func calcFrInvMontR*() : Fr[BN254_Snarks] = + var x : Fr[BN254_Snarks] = calcFrMontR() inv(x) return x # apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) -const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) -const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) +const fpMontR* = fromHex( Fp[BN254_Snarks], "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) +const fpInvMontR* = fromHex( Fp[BN254_Snarks], "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) # apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) -const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) -const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) +const frMontR* = fromHex( Fr[BN254_Snarks], "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) +const frInvMontR* = fromHex( Fr[BN254_Snarks], "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) proc checkMontgomeryConstants*() = assert( bool( fpMontR == calcFpMontR() ) ) @@ -103,18 +103,18 @@ proc checkMontgomeryConstants*() = # the binary file `.zkey` used by the `circom` ecosystem uses little-endian # Montgomery representation. So when we unmarshal with Constantine, it will # give the wrong result. Calling this function on the result fixes that. -func fromMontgomeryFp*(x : Fp) : Fp = - var y : Fp = x; +func fromMontgomeryFp*(x : Fp[BN254_Snarks]) : Fp[BN254_Snarks] = + var y : Fp[BN254_Snarks] = x; y *= fpInvMontR return y -func fromMontgomeryFr*(x : Fr) : Fr = - var y : Fr = x; +func fromMontgomeryFr*(x : Fr[BN254_Snarks]) : Fr[BN254_Snarks] = + var y = x; y *= frInvMontR return y -func toMontgomeryFr*(x : Fr) : Fr = - var y : Fr = x; +func toMontgomeryFr*(x : Fr[BN254_Snarks]) : Fr[BN254_Snarks] = + var y = x; y *= frMontR return y @@ -123,47 +123,47 @@ func toMontgomeryFr*(x : Fr) : Fr = # Note: in the `.zkey` coefficients, e apparently DOUBLE Montgomery encoding is used ?!? # -func unmarshalFpMont* ( bs: array[32,byte] ) : Fp = +func unmarshalFpMont* ( bs: array[32,byte] ) : Fp[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fp + var x : Fp[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFp(x) # WTF Jordi, go home you are drunk -func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr = +func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFr(fromMontgomeryFr(x)) -func unmarshalFrStd* ( bs: array[32,byte] ) : Fr = +func unmarshalFrStd* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return x -func unmarshalFrMont* ( bs: array[32,byte] ) : Fr = +func unmarshalFrMont* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFr(x) #------------------------------------------------------------------------------- -func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] = - var vals : seq[Fp] = newSeq[Fp]( len ) +func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp[BN254_Snarks]] = + var vals : seq[Fp[BN254_Snarks]] = newSeq[Fp[BN254_Snarks]]( len ) var bytes : array[32,byte] for i in 0..= m: for i in 0.. 0: y = cs[0] for i in 1..= m: for i in 0..= m: for i in 0..=1 ) - var cs : seq[Fr] = newSeq[Fr]( N+1 ) + var cs = newSeq[Fr[BN254_Snarks]]( N+1 ) cs[0] = negFr(b) cs[N] = a return Poly(coeffs: cs) @@ -186,9 +187,9 @@ type func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] = assert( N>=1 ) let deg : int = polyDegree(P) - let src : seq[Fr] = P.coeffs - var quot : seq[Fr] = newSeq[Fr]( max(1, deg - N + 1) ) - var rem : seq[Fr] = newSeq[Fr]( N ) + let src = P.coeffs + var quot = newSeq[Fr[BN254_Snarks]]( max(1, deg - N + 1) ) + var rem = newSeq[Fr[BN254_Snarks]]( N ) if deg < N: rem = src @@ -222,14 +223,14 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly = # Lagrange basis polynomials func lagrangePoly*(D: Domain, k: int): Poly = let N = D.domainSize - let omMinusK : Fr = smallPowFr( D.invDomainGen , k ) - let invN : Fr = invFr(intToFr(N)) + let omMinusK = smallPowFr( D.invDomainGen , k ) + let invN = invFr(intToFr(N)) - var cs : seq[Fr] = newSeq[Fr]( N ) + var cs = newSeq[Fr[BN254_Snarks]]( N ) if k == 0: for i in 0..= 1) - var ys : seq[Fr] = newSeq[Fr](n) + var ys = newSeq[Fr[BN254_Snarks]](n) ys[0] = xs[0] if n >= 1: ys[1] = eta * xs[1] - var spow : Fr = eta + var spow = eta for i in 2.. # -proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr] = +proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr[BN254_Snarks]] = let n = abc.valuesAz.len assert( abc.valuesBz.len == n ) assert( abc.valuesCz.len == n ) @@ -164,15 +185,21 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr] = var pool = Taskpool.new(num_threads = nthreads) - var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta ) - var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta ) - var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta ) + var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]] - let A1 = sync A1fv - let B1 = sync B1fv - let C1 = sync C1fv + var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 ) + var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 ) + var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 ) - var ys : seq[Fr] = newSeq[Fr]( n ) + discard sync taskA1 + discard sync taskB1 + discard sync taskC1 + + let A1 = outputA1.extract() + let B1 = outputB1.extract() + let C1 = outputC1.extract() + + var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n ) for j in 0.. - let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta > - let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > - let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > + let lhs = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b > + let rhs1 = vkey.spec.alphaBeta # < alpha , beta > + let rhs2 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > + let rhs3 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > - var eq : Fp12 + var eq : Fp12[BN254_Snarks] eq = lhs eq *= rhs1 eq *= rhs2 diff --git a/groth16/zkey_types.nim b/groth16/zkey_types.nim index 16b3969..1ce1d0b 100644 --- a/groth16/zkey_types.nim +++ b/groth16/zkey_types.nim @@ -1,5 +1,7 @@ -import constantine/math/arithmetic except Fp, Fr +import constantine/math/arithmetic +import constantine/named/properties_fields +import constantine/math/extension_fields/towers import groth16/bn128 @@ -28,7 +30,7 @@ type gamma2* : G2 # = gamma * g2 delta1* : G1 # = delta * g1 delta2* : G2 # = delta * g2 - alphaBeta* : Fp12 # = + alphaBeta* : Fp12[BN254_Snarks] # = VerifierPoints* = object pointsIC* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j <= npub) @@ -49,7 +51,7 @@ type matrix* : MatrixSel row* : int col* : int - coeff* : Fr + coeff* : Fr[BN254_Snarks] ZKey* = object # sectionMask* : uint32 diff --git a/tests/groth16/testProver.nim b/tests/groth16/testProver.nim index cdf8a5b..04daf2f 100644 --- a/tests/groth16/testProver.nim +++ b/tests/groth16/testProver.nim @@ -44,7 +44,7 @@ const myR1CS = ) # the equation we want prove is `7*11*13 + 1022 == 2023` -let myWitnessValues : seq[Fr] = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr ) +let myWitnessValues = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr ) # wire indices: ^^^^^^^ 0 1 2 3 4 5 6 7 let myWitness = @@ -57,7 +57,7 @@ let myWitness = #------------------------------------------------------------------------------- proc testProof(zkey: ZKey, witness: Witness): bool = - let proof = generateProof( zkey, witness ) + let proof = generateProof( 8, false, zkey, witness ) let vkey = extractVKey( zkey) let ok = verifyProof( vkey, proof ) return ok