diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d915edc --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: CI + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + nim: [2.2.4] + steps: + - uses: actions/checkout@v2 + - uses: iffy/install-nim@v5 + with: + version: ${{ matrix.nim }} + - name: Install dependencies + run: nimble install -d -y + - name: Test + run: nimble test -y + - name: Build binary + run: nimble build -y diff --git a/cli/cli_main.nim b/cli/cli_main.nim index b6a5a0e..b126503 100644 --- a/cli/cli_main.nim +++ b/cli/cli_main.nim @@ -9,6 +9,8 @@ import std/times import std/options # import strformat +import taskpools + import groth16/prover import groth16/verifier import groth16/files/witness @@ -203,11 +205,12 @@ proc cliMain(cfg: Config) = else: echo("generating proof...") let print_timings = cfg.measure_time and cfg.verbose + var pool = Taskpool.new(cfg.nthreads) withMeasureTime(cfg.measure_time,"proving"): if cfg.no_masking: - proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns) + proof = generateProofWithTrivialMask(zkey, wtns, pool, print_timings) else: - proof = generateProof(cfg.nthreads, print_timings, zkey, wtns) + proof = generateProof(zkey, wtns, pool, print_timings) if not (cfg.output_file == ""): echo("exporting the proof to " & quoted(cfg.output_file)) diff --git a/groth16.nimble b/groth16.nimble index a6566ea..168d302 100644 --- a/groth16.nimble +++ b/groth16.nimble @@ -8,4 +8,4 @@ binDir = "build" namedBin = {"cli/cli_main": "nim-groth16"}.toTable() requires "https://github.com/status-im/nim-taskpools" -requires "https://github.com/mratsim/constantine#5f7ba18f2ed351260015397c9eae079a6decaee1" \ No newline at end of file +requires "https://github.com/mratsim/constantine >= 0.2.0" \ No newline at end of file diff --git a/groth16/bn128/curves.nim b/groth16/bn128/curves.nim index 42611b2..f8bb9f1 100644 --- a/groth16/bn128/curves.nim +++ b/groth16/bn128/curves.nim @@ -13,13 +13,12 @@ #import constantine/platforms/abstractions #import constantine/math/isogenies/frobenius -import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_fields except Fp, Fr import constantine/math/io/io_bigints -import constantine/math/config/curves +import constantine/math/arithmetic +import constantine/math/io/io_fields -import constantine/math/config/type_ff as tff except Fp, Fr -import constantine/math/extension_fields/towers as ext except Fp, Fp2, Fp12, Fr +import constantine/named/properties_fields as tff +import constantine/math/extension_fields/towers as ext import constantine/math/elliptic/ec_shortweierstrass_affine as aff import constantine/math/elliptic/ec_shortweierstrass_projective as prj @@ -30,19 +29,19 @@ import groth16/bn128/fields #------------------------------------------------------------------------------- -type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1] -type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2] +type G1* = aff.EC_ShortW_Aff[Fp[BN254_Snarks] , aff.G1] +type G2* = aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2] -type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1] -type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2] +type ProjG1* = prj.EC_ShortW_Prj[Fp[BN254_Snarks] , prj.G1] +type ProjG2* = prj.EC_ShortW_Prj[Fp2[BN254_Snarks], prj.G2] #------------------------------------------------------------------------------- -func unsafeMkG1* ( X, Y: Fp ) : G1 = - return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y) +func unsafeMkG1* ( X, Y: Fp[BN254_Snarks] ) : G1 = + return aff.EC_ShortW_Aff[Fp[BN254_Snarks], aff.G1](x: X, y: Y) -func unsafeMkG2* ( X, Y: Fp2 ) : G2 = - return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y) +func unsafeMkG2* ( X, Y: Fp2[BN254_Snarks] ) : G2 = + return aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2](x: X, y: Y) #------------------------------------------------------------------------------- @@ -51,15 +50,15 @@ const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) #------------------------------------------------------------------------------- -func checkCurveEqG1*( x, y: Fp ) : bool = +func checkCurveEqG1*( x, y: Fp[BN254_Snarks] ) : bool = if bool(isZero(x)) and bool(isZero(y)): # the point at infinity is on the curve by definition return true else: - var x2 : Fp = squareFp(x) - var y2 : Fp = squareFp(y) - var x3 : Fp = x2 * x - var eq : Fp + var x2 = squareFp(x) + var y2 = squareFp(y) + var x3 = x2 * x + var eq : Fp[BN254_Snarks] eq = x3 eq += intToFp(3) eq -= y2 @@ -72,19 +71,19 @@ func checkCurveEqG1*( x, y: Fp ) : bool = # B = b1 + bu*u # b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373 # b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690 -const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5") -const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2") -const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u ) +const twistCoeffB_1 = fromHex(Fp[BN254_Snarks], "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5") +const twistCoeffB_u = fromHex(Fp[BN254_Snarks], "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2") +const twistCoeffB = mkFp2( twistCoeffB_1 , twistCoeffB_u ) -func checkCurveEqG2*( x, y: Fp2 ) : bool = +func checkCurveEqG2*( x, y: Fp2[BN254_Snarks] ) : bool = if isZeroFp2(x) and isZeroFp2(y): # the point at infinity is on the curve by definition return true else: - var x2 : Fp2 = squareFp2(x) - var y2 : Fp2 = squareFp2(y) - var x3 : Fp2 = x2 * x; - var eq : Fp2 + var x2 = squareFp2(x) + var y2 = squareFp2(y) + var x3 = x2 * x + var eq : Fp2[BN254_Snarks] eq = x3 eq += twistCoeffB eq -= y2 @@ -92,14 +91,14 @@ func checkCurveEqG2*( x, y: Fp2 ) : bool = #------------------------------------------------------------------------------- -func mkG1*( x, y: Fp ) : G1 = +func mkG1*( x, y: Fp[BN254_Snarks] ) : G1 = if isZeroFp(x) and isZeroFp(y): return infG1 else: assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" ) return unsafeMkG1(x,y) -func mkG2*( x, y: Fp2 ) : G2 = +func mkG2*( x, y: Fp2[BN254_Snarks] ) : G2 = if isZeroFp2(x) and isZeroFp2(y): return infG2 else: @@ -109,16 +108,16 @@ func mkG2*( x, y: Fp2 ) : G2 = #------------------------------------------------------------------------------- # group generators -const gen1_x : Fp = fromHex(Fp, "0x01") -const gen1_y : Fp = fromHex(Fp, "0x02") +const gen1_x = fromHex(Fp[BN254_Snarks], "0x01") +const gen1_y = fromHex(Fp[BN254_Snarks], "0x02") -const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701") -const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b") -const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8") -const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c") +const gen2_xi = fromHex(Fp[BN254_Snarks], "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701") +const gen2_xu = fromHex(Fp[BN254_Snarks], "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b") +const gen2_yi = fromHex(Fp[BN254_Snarks], "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8") +const gen2_yu = fromHex(Fp[BN254_Snarks], "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c") -const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu ) -const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu ) +const gen2_x = mkFp2( gen2_xi, gen2_xu ) +const gen2_y = mkFp2( gen2_yi, gen2_yu ) const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y ) const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y ) @@ -215,9 +214,9 @@ func `**`*( coeff: BigInt , point: G2 ) : G2 = #------------------------------------------------------------------------------- -func pairing* (p: G1, q: G2) : Fp12 = - var t : Fp12 - ate.pairing_bn[BN254Snarks]( t, p, q ) +func pairing* (p: G1, q: G2) : Fp12[BN254_Snarks] = + var t : Fp12[BN254_Snarks] + ate.pairing_bn( t, p, q ) return t #------------------------------------------------------------------------------- @@ -225,7 +224,8 @@ func pairing* (p: G1, q: G2) : Fp12 = proc sanityCheckGroupGen*() = echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) ) echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) ) - echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) ) - echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) ) + # TODO: fix compilation error with Constantine 0.2.0: + # echo( "order of gen1 is R = ", (not bool(isNeutral(gen1))) and bool(isNeutral(primeR ** gen1)) ) + # echo( "order of gen2 is R = ", (not bool(isNeutral(gen2))) and bool(isNeutral(primeR ** gen2)) ) #------------------------------------------------------------------------------- diff --git a/groth16/bn128/debug.nim b/groth16/bn128/debug.nim index ab7403c..7795294 100644 --- a/groth16/bn128/debug.nim +++ b/groth16/bn128/debug.nim @@ -9,23 +9,26 @@ # equation: y^2 = x^3 + 3 # +import constantine/named/properties_fields +import constantine/math/extension_fields/towers + import groth16/bn128/fields import groth16/bn128/curves import groth16/bn128/io #------------------------------------------------------------------------------- -proc debugPrintFp*(prefix: string, x: Fp) = +proc debugPrintFp*(prefix: string, x: Fp[BN254_Snarks]) = echo(prefix & toDecimalFp(x)) -proc debugPrintFp2*(prefix: string, z: Fp2) = +proc debugPrintFp2*(prefix: string, z: Fp2[BN254_Snarks]) = echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0])) echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) -proc debugPrintFr*(prefix: string, x: Fr) = +proc debugPrintFr*(prefix: string, x: Fr[BN254_Snarks]) = echo(prefix & toDecimalFr(x)) -proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = +proc debugPrintFrSeq*(msg: string, xs: seq[Fr[BN254_Snarks]]) = echo "---------------------" echo msg for x in xs: diff --git a/groth16/bn128/fields.nim b/groth16/bn128/fields.nim index 5ffa8a1..2f52c50 100644 --- a/groth16/bn128/fields.nim +++ b/groth16/bn128/fields.nim @@ -14,22 +14,16 @@ import std/sequtils import constantine/math/arithmetic import constantine/math/io/io_fields import constantine/math/io/io_bigints -import constantine/math/config/curves -import constantine/math/config/type_ff as tff +import constantine/named/properties_fields as tff import constantine/math/extension_fields/towers as ext #------------------------------------------------------------------------------- type B* = BigInt[256] -type Fr* = tff.Fr[BN254Snarks] -type Fp* = tff.Fp[BN254Snarks] -type Fp2* = ext.QuadraticExt[Fp] -type Fp12* = ext.Fp12[BN254Snarks] - -func mkFp2* (i: Fp, u: Fp) : Fp2 = - let c : array[2, Fp] = [i,u] - return ext.QuadraticExt[Fp]( coords: c ) +func mkFp2* (i: Fp[BN254_Snarks], u: Fp[BN254_Snarks]) : Fp2[BN254_Snarks] = + let c : array[2, Fp[BN254_Snarks]] = [i,u] + return ext.QuadraticExt[Fp[BN254_Snarks]]( coords: c ) #------------------------------------------------------------------------------- @@ -38,16 +32,16 @@ const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b97 #------------------------------------------------------------------------------- -const zeroFp* : Fp = fromHex( Fp, "0x00" ) -const zeroFr* : Fr = fromHex( Fr, "0x00" ) -const oneFp* : Fp = fromHex( Fp, "0x01" ) -const oneFr* : Fr = fromHex( Fr, "0x01" ) +const zeroFp* = fromHex( Fp[BN254_Snarks], "0x00" ) +const zeroFr* = fromHex( Fr[BN254_Snarks], "0x00" ) +const oneFp* = fromHex( Fp[BN254_Snarks], "0x01" ) +const oneFr* = fromHex( Fr[BN254_Snarks], "0x01" ) -const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) -const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp ) +const zeroFp2* = mkFp2( zeroFp, zeroFp ) +const oneFp2* = mkFp2( oneFp , zeroFp ) -const minusOneFp* : Fp = fromHex( Fp, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" ) -const minusOneFr* : Fr = fromHex( Fr, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" ) +const minusOneFp* = fromHex( Fp[BN254_Snarks], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" ) +const minusOneFr* = fromHex( Fr[BN254_Snarks], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" ) #------------------------------------------------------------------------------- @@ -56,13 +50,13 @@ func intToB*(a: uint): B = y.setUint(a) return y -func intToFp*(a: int): Fp = - var y : Fp +func intToFp*(a: int): Fp[BN254_Snarks] = + var y : Fp[BN254_Snarks] y.fromInt(a) return y -func intToFr*(a: int): Fr = - var y : Fr +func intToFr*(a: int): Fr[BN254_Snarks] = + var y : Fr[BN254_Snarks] y.fromInt(a) return y @@ -154,27 +148,27 @@ func smallPowFr*(base: Fr, expo: int): Fr = #------------------------------------------------------------------------------- -func deltaFr*(i, j: int) : Fr = +func deltaFr*[T](i, j: int) : Fr[T] = return (if (i == j): oneFr else: zeroFr) #------------------------------------------------------------------------------- # Montgomery batch inversion -func batchInverseFr*( xs: seq[Fr] ) : seq[Fr] = +func batchInverseFr*( xs: seq[Fr[BN254_Snarks]] ) : seq[Fr[BN254_Snarks]] = let n = xs.len assert(n>0) - var us : seq[Fr] = newSeq[Fr](n+1) + var us : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]](n+1) var a = xs[0] us[0] = oneFr us[1] = a for i in 1.. primeP_254 - k65536 ): return "-" & toDecimalFp(negFp(a)) else: return toDecimalFp(a) -func signedToDecimalFr*(a : Fr): string = +func signedToDecimalFr*(a : Fr[BN254_Snarks]): string = if bool( a.toBig() > primeR_254 - k65536 ): return "-" & toDecimalFr(negFr(a)) else: @@ -58,38 +58,38 @@ func signedToDecimalFr*(a : Fr): string = # # R=2^256; this computes 2^256 mod Fp -func calcFpMontR*() : Fp = - var x : Fp = intToFp(2) +func calcFpMontR*() : Fp[BN254_Snarks] = + var x : Fp[BN254_Snarks] = intToFp(2) for i in 1..8: square(x) return x # R=2^256; this computes the inverse of (2^256 mod Fp) -func calcFpInvMontR*() : Fp = - var x : Fp = calcFpMontR() +func calcFpInvMontR*() : Fp[BN254_Snarks] = + var x : Fp[BN254_Snarks] = calcFpMontR() inv(x) return x # R=2^256; this computes 2^256 mod Fr -func calcFrMontR*() : Fr = - var x : Fr = intToFr(2) +func calcFrMontR*() : Fr[BN254_Snarks] = + var x : Fr[BN254_Snarks] = intToFr(2) for i in 1..8: square(x) return x # R=2^256; this computes the inverse of (2^256 mod Fp) -func calcFrInvMontR*() : Fr = - var x : Fr = calcFrMontR() +func calcFrInvMontR*() : Fr[BN254_Snarks] = + var x : Fr[BN254_Snarks] = calcFrMontR() inv(x) return x # apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) -const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) -const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) +const fpMontR* = fromHex( Fp[BN254_Snarks], "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) +const fpInvMontR* = fromHex( Fp[BN254_Snarks], "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) # apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) -const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) -const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) +const frMontR* = fromHex( Fr[BN254_Snarks], "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) +const frInvMontR* = fromHex( Fr[BN254_Snarks], "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) proc checkMontgomeryConstants*() = assert( bool( fpMontR == calcFpMontR() ) ) @@ -103,18 +103,18 @@ proc checkMontgomeryConstants*() = # the binary file `.zkey` used by the `circom` ecosystem uses little-endian # Montgomery representation. So when we unmarshal with Constantine, it will # give the wrong result. Calling this function on the result fixes that. -func fromMontgomeryFp*(x : Fp) : Fp = - var y : Fp = x; +func fromMontgomeryFp*(x : Fp[BN254_Snarks]) : Fp[BN254_Snarks] = + var y : Fp[BN254_Snarks] = x; y *= fpInvMontR return y -func fromMontgomeryFr*(x : Fr) : Fr = - var y : Fr = x; +func fromMontgomeryFr*(x : Fr[BN254_Snarks]) : Fr[BN254_Snarks] = + var y = x; y *= frInvMontR return y -func toMontgomeryFr*(x : Fr) : Fr = - var y : Fr = x; +func toMontgomeryFr*(x : Fr[BN254_Snarks]) : Fr[BN254_Snarks] = + var y = x; y *= frMontR return y @@ -123,47 +123,47 @@ func toMontgomeryFr*(x : Fr) : Fr = # Note: in the `.zkey` coefficients, e apparently DOUBLE Montgomery encoding is used ?!? # -func unmarshalFpMont* ( bs: array[32,byte] ) : Fp = +func unmarshalFpMont* ( bs: array[32,byte] ) : Fp[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fp + var x : Fp[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFp(x) # WTF Jordi, go home you are drunk -func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr = +func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFr(fromMontgomeryFr(x)) -func unmarshalFrStd* ( bs: array[32,byte] ) : Fr = +func unmarshalFrStd* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return x -func unmarshalFrMont* ( bs: array[32,byte] ) : Fr = +func unmarshalFrMont* ( bs: array[32,byte] ) : Fr[BN254_Snarks] = var big : BigInt[254] unmarshal( big, bs, littleEndian ); - var x : Fr + var x : Fr[BN254_Snarks] x.fromBig( big ) return fromMontgomeryFr(x) #------------------------------------------------------------------------------- -func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] = - var vals : seq[Fp] = newSeq[Fp]( len ) +func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp[BN254_Snarks]] = + var vals : seq[Fp[BN254_Snarks]] = newSeq[Fp[BN254_Snarks]]( len ) var bytes : array[32,byte] for i in 0..1: (nthreads*task_multiplier) else: 1 - var pool = Taskpool.new(num_threads = nthreads) var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks) var a : int = 0 @@ -118,22 +116,18 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1] for k in 0..1: (nthreads*task_multiplier) else: 1 - var pool = Taskpool.new(num_threads = nthreads) var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks) var a : int = 0 @@ -152,19 +146,16 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2] for k in 0.. bool ) = + , filt: proc(_: int): bool {.raises: [IOError, OSError].} + ) {.raises: [IOError, OSError].} = let sectId = int( stream.readUint32() ) let sectLen = int( stream.readUint64() ) @@ -77,7 +78,8 @@ proc parseContainer*[T] ( expectedMagic: string , fname: string , user: var T , callback: SectionCallback[T] - , filt: (int) -> bool ) = + , filt: proc(_: int): bool {.raises: [IOError, OSError].} + ) {.raises: [IOError, OSError].} = let stream = newFileStream(fname, mode = fmRead) defer: stream.close() diff --git a/groth16/files/export_json.nim b/groth16/files/export_json.nim index 53b2828..a4d991f 100644 --- a/groth16/files/export_json.nim +++ b/groth16/files/export_json.nim @@ -3,19 +3,20 @@ # export proof and public input in `circom`-compatible JSON files # -import constantine/math/arithmetic except Fp, Fr -#import constantine/math/io/io_fields except Fp, Fr +import constantine/math/arithmetic +import constantine/named/properties_fields +import constantine/math/extension_fields/towers import groth16/bn128 from groth16/prover import Proof #------------------------------------------------------------------------------- -func toQuotedDecimalFp(x: Fp): string = +func toQuotedDecimalFp(x: Fp[BN254_Snarks]): string = let s : string = toDecimalFp(x) return ("\"" & s & "\"") -func toQuotedDecimalFr(x: Fr): string = +func toQuotedDecimalFr(x: Fr[BN254_Snarks]): string = let s : string = toDecimalFr(x) return ("\"" & s & "\"") diff --git a/groth16/files/r1cs.nim b/groth16/files/r1cs.nim index cb65712..ad3d4fa 100644 --- a/groth16/files/r1cs.nim +++ b/groth16/files/r1cs.nim @@ -51,8 +51,9 @@ import std/streams -import constantine/math/arithmetic except Fp, Fr +import constantine/math/arithmetic import constantine/math/io/io_bigints +import constantine/named/properties_fields import groth16/bn128 import groth16/files/container @@ -67,8 +68,8 @@ type nPubIn* : int # number of public inputs nPrivIn* : int # number of private inputs nLabels* : int # number of labels - - Term* = tuple[ wireIdx: int, value: Fr ] + + Term* = tuple[ wireIdx: int, value: Fr[BN254_Snarks] ] LinComb* = seq[Term] Constraint* = tuple[ A: LinComb, B: LinComb, C: LinComb ] diff --git a/groth16/files/witness.nim b/groth16/files/witness.nim index f849eb3..381e85c 100644 --- a/groth16/files/witness.nim +++ b/groth16/files/witness.nim @@ -16,8 +16,9 @@ import std/streams -import constantine/math/arithmetic except Fp, Fr +import constantine/math/arithmetic import constantine/math/io/io_bigints +import constantine/named/properties_fields import groth16/bn128 import groth16/files/container @@ -29,7 +30,7 @@ type curve* : string r* : BigInt[256] nvars* : int - values* : seq[Fr] + values* : seq[Fr[BN254_Snarks]] #------------------------------------------------------------------------------- diff --git a/groth16/files/zkey.nim b/groth16/files/zkey.nim index 73ce1e5..8ee4129 100644 --- a/groth16/files/zkey.nim +++ b/groth16/files/zkey.nim @@ -94,7 +94,7 @@ import std/streams -import constantine/math/arithmetic except Fp, Fr +import constantine/math/arithmetic #import constantine/math/io/io_bigints import groth16/bn128 diff --git a/groth16/math/domain.nim b/groth16/math/domain.nim index 2185814..160b418 100644 --- a/groth16/math/domain.nim +++ b/groth16/math/domain.nim @@ -3,9 +3,10 @@ # power-of-two sized multiplicative FFT domains in the scalar field # -import constantine/math/arithmetic except Fp,Fr -import constantine/math/io/io_fields except Fp,Fr +import constantine/math/arithmetic +import constantine/math/io/io_fields #import constantine/math/io/io_bigints +import constantine/named/properties_fields import groth16/bn128 import groth16/misc @@ -14,27 +15,27 @@ import groth16/misc type Domain* = object - domainSize* : int # `N = 2^n` - logDomainSize* : int # `n = log2(N)` - domainGen* : Fr # `g` - invDomainGen* : Fr # `g^-1` - invDomainSize* : Fr # `1/n` + domainSize* : int # `N = 2^n` + logDomainSize* : int # `n = log2(N)` + domainGen* : Fr[BN254_Snarks] # `g` + invDomainGen* : Fr[BN254_Snarks] # `g^-1` + invDomainSize* : Fr[BN254_Snarks] # `1/n` #------------------------------------------------------------------------------- # the generator of the multiplicative subgroup with size `2^28` -const gen28 : Fr = fromHex( Fr, "0x2a3c09f0a58a7e8500e0a7eb8ef62abc402d111e41112ed49bd61b6e725b19f0" ) +const gen28 = fromHex( Fr[BN254_Snarks], "0x2a3c09f0a58a7e8500e0a7eb8ef62abc402d111e41112ed49bd61b6e725b19f0" ) func createDomain*(size: int): Domain = let log2 = ceilingLog2(size) assert( (1 shl log2) == size , "domain must have a power-of-two size" ) let expo : uint = 1'u shl (28 - log2) - let gen : Fr = smallPowFr(gen28, expo) + let gen = smallPowFr(gen28, expo) let halfSize = size div 2 - let a : Fr = smallPowFr(gen, size ) - let b : Fr = smallPowFr(gen, halfSize) + let a = smallPowFr(gen, size ) + let b = smallPowFr(gen, halfSize) assert( bool(a == oneFr) , "domain generator sanity check /A" ) assert( not bool(b == oneFr) , "domain generator sanity check /B" ) @@ -47,9 +48,9 @@ func createDomain*(size: int): Domain = #------------------------------------------------------------------------------- -func enumerateDomain*(D: Domain): seq[Fr] = - var xs : seq[Fr] = newSeq[Fr](D.domainSize) - var g : Fr = oneFr +func enumerateDomain*(D: Domain): seq[Fr[BN254_Snarks]] = + var xs = newSeq[Fr[BN254_Snarks]](D.domainSize) + var g = oneFr for i in 0..= m: for i in 0.. 0: y = cs[0] for i in 1..= m: for i in 0..= m: for i in 0..=1 ) - var cs : seq[Fr] = newSeq[Fr]( N+1 ) + var cs = newSeq[Fr[BN254_Snarks]]( N+1 ) cs[0] = negFr(b) cs[N] = a return Poly(coeffs: cs) @@ -186,9 +187,9 @@ type func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] = assert( N>=1 ) let deg : int = polyDegree(P) - let src : seq[Fr] = P.coeffs - var quot : seq[Fr] = newSeq[Fr]( max(1, deg - N + 1) ) - var rem : seq[Fr] = newSeq[Fr]( N ) + let src = P.coeffs + var quot = newSeq[Fr[BN254_Snarks]]( max(1, deg - N + 1) ) + var rem = newSeq[Fr[BN254_Snarks]]( N ) if deg < N: rem = src @@ -222,14 +223,14 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly = # Lagrange basis polynomials func lagrangePoly*(D: Domain, k: int): Poly = let N = D.domainSize - let omMinusK : Fr = smallPowFr( D.invDomainGen , k ) - let invN : Fr = invFr(intToFr(N)) + let omMinusK = smallPowFr( D.invDomainGen , k ) + let invN = invFr(intToFr(N)) - var cs : seq[Fr] = newSeq[Fr]( N ) + var cs = newSeq[Fr[BN254_Snarks]]( N ) if k == 0: for i in 0.. # +{.push raises:[].} + #[ import sugar import constantine/math/config/curves @@ -20,10 +22,11 @@ import std/times import std/cpuinfo import system import taskpools +import constantine/math/arithmetic +import constantine/named/properties_fields +import constantine/math/extension_fields/towers -import constantine/math/arithmetic except Fp, Fr #import constantine/math/io/io_extfields except Fp12 -#import constantine/math/extension_fields/towers except Fp2, Fp12 import groth16/bn128 import groth16/math/domain @@ -36,7 +39,7 @@ import groth16/misc type Proof* = object - publicIO* : seq[Fr] + publicIO* : seq[Fr[BN254_Snarks]] pi_a* : G1 pi_b* : G2 pi_c* : G1 @@ -48,17 +51,17 @@ type type ABC = object - valuesAz : seq[Fr] - valuesBz : seq[Fr] - valuesCz : seq[Fr] + valuesAz : seq[Fr[BN254_Snarks]] + valuesBz : seq[Fr[BN254_Snarks]] + valuesCz : seq[Fr[BN254_Snarks]] # computes the vectors A*z, B*z, C*z where z is the witness -func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC = +func buildABC( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC = let hdr: GrothHeader = zkey.header let domSize = hdr.domainSize - var valuesAz : seq[Fr] = newSeq[Fr](domSize) - var valuesBz : seq[Fr] = newSeq[Fr](domSize) + var valuesAz = newSeq[Fr[BN254_Snarks]](domSize) + var valuesBz = newSeq[Fr[BN254_Snarks]](domSize) for entry in zkey.coeffs: case entry.matrix @@ -66,7 +69,7 @@ func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC = of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col] else: raise newException(AssertionDefect, "fatal error") - var valuesCz : seq[Fr] = newSeq[Fr](domSize) + var valuesCz = newSeq[Fr[BN254_Snarks]](domSize) for i in 0..= 1) - var ys : seq[Fr] = newSeq[Fr](n) + var ys = newSeq[Fr[BN254_Snarks]](n) ys[0] = xs[0] if n >= 1: ys[1] = eta * xs[1] - var spow : Fr = eta + var spow = eta for i in 2.. # -proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr] = +proc computeSnarkjsScalarCoeffs( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] = let n = abc.valuesAz.len assert( abc.valuesBz.len == n ) assert( abc.valuesCz.len == n ) let D = createDomain(n) let eta = createDomain(2*n).domainGen - var pool = Taskpool.new(num_threads = nthreads) + var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]] - var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta ) - var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta ) - var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta ) + var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 ) + var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 ) + var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 ) - let A1 = sync A1fv - let B1 = sync B1fv - let C1 = sync C1fv + discard sync taskA1 + discard sync taskB1 + discard sync taskC1 - var ys : seq[Fr] = newSeq[Fr]( n ) + let A1 = outputA1.extract() + let B1 = outputB1.extract() + let C1 = outputC1.extract() + + var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n ) for j in 0.. - let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta > - let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > - let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > + let lhs = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b > + let rhs1 = vkey.spec.alphaBeta # < alpha , beta > + let rhs2 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > + let rhs3 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > - var eq : Fp12 + var eq : Fp12[BN254_Snarks] eq = lhs eq *= rhs1 eq *= rhs2 diff --git a/groth16/zkey_types.nim b/groth16/zkey_types.nim index 16b3969..1ce1d0b 100644 --- a/groth16/zkey_types.nim +++ b/groth16/zkey_types.nim @@ -1,5 +1,7 @@ -import constantine/math/arithmetic except Fp, Fr +import constantine/math/arithmetic +import constantine/named/properties_fields +import constantine/math/extension_fields/towers import groth16/bn128 @@ -28,7 +30,7 @@ type gamma2* : G2 # = gamma * g2 delta1* : G1 # = delta * g1 delta2* : G2 # = delta * g2 - alphaBeta* : Fp12 # = + alphaBeta* : Fp12[BN254_Snarks] # = VerifierPoints* = object pointsIC* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j <= npub) @@ -49,7 +51,7 @@ type matrix* : MatrixSel row* : int col* : int - coeff* : Fr + coeff* : Fr[BN254_Snarks] ZKey* = object # sectionMask* : uint32 diff --git a/tests/groth16/testProver.nim b/tests/groth16/testProver.nim index cdf8a5b..3bb3eda 100644 --- a/tests/groth16/testProver.nim +++ b/tests/groth16/testProver.nim @@ -2,6 +2,8 @@ import std/unittest import std/sequtils +import taskpools + import groth16/prover import groth16/verifier import groth16/fake_setup @@ -44,7 +46,7 @@ const myR1CS = ) # the equation we want prove is `7*11*13 + 1022 == 2023` -let myWitnessValues : seq[Fr] = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr ) +let myWitnessValues = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr ) # wire indices: ^^^^^^^ 0 1 2 3 4 5 6 7 let myWitness = @@ -57,9 +59,11 @@ let myWitness = #------------------------------------------------------------------------------- proc testProof(zkey: ZKey, witness: Witness): bool = - let proof = generateProof( zkey, witness ) + var pool = Taskpool.new() + let proof = generateProof( zkey, witness, pool ) let vkey = extractVKey( zkey) let ok = verifyProof( vkey, proof ) + pool.shutdown() return ok suite "prover":