diff --git a/.gitignore b/.gitignore index aeb4c66..110e37e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ _bck* tmp main +main.nim *.json diff --git a/README.md b/README.md index 87f9ca2..ce614b5 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,24 @@ -Groth16 prover in Nim ---------------------- +Groth16 prover written in Nim +----------------------------- This is Groth16 prover implementation in Nim, using the [`constantine`](https://github.com/mratsim/constantine) library as an arithmetic / curve backend. -The implementation should be compatible with the `circom` ecosystem. +The implementation is compatible with the `circom` ecosystem. At the moment only the `BN254` (aka. `alt-bn128`) curve is supported. + + +### TODO + +- [ ] make it a nimble package +- [ ] proper MSM implementation (I couldn't make constantine's one to work) +- [ ] proper polynomial implemention (constantine's one is essentially missing) +- [ ] compare `.r1cs` to the "coeffs" section of `.zkey` +- [ ] make it work for different curves +- [ ] multithreaded support (MSM, and possibly also FFT) +- [ ] add Groth16 notes +- [ ] document the `snarkjs` circuit-specific setup `H` points convention + diff --git a/bn128.nim b/bn128.nim index be43806..96c9d54 100644 --- a/bn128.nim +++ b/bn128.nim @@ -1,6 +1,19 @@ +# +# the `alt-bn128` elliptic curve +# +# See for example +# +# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +# +# equation: y^2 = x^3 + 3 +# + +import sugar import std/bitops import std/strutils +import std/sequtils import std/streams import std/random @@ -14,8 +27,8 @@ import constantine/math/extension_fields/towers as ext import constantine/math/elliptic/ec_shortweierstrass_affine as aff import constantine/math/elliptic/ec_shortweierstrass_projective as prj import constantine/math/pairings/pairings_bn as ate -import constantine/math/elliptic/ec_multi_scalar_mul as msm import constantine/math/elliptic/ec_scalar_mul as scl +# import constantine/math/elliptic/ec_multi_scalar_mul as msm #------------------------------------------------------------------------------- @@ -118,6 +131,12 @@ func smallPowFr*(base: Fr, expo: uint): Fr = square(s) return a +func smallPowFr*(base: Fr, expo: int): Fr = + if expo >= 0: + return smallPowFr( base, uint(expo) ) + else: + return smallPowFr( invFr(base) , uint(-expo) ) + #------------------------------------------------------------------------------- func toDecimalBig*[n](a : BigInt[n]): string = @@ -140,11 +159,59 @@ func toDecimalFr*(a : Fr): string = #------------------------------------------------------------------------------- -proc debugPrintSeqFr*(msg: string, xs: seq[Fr]) = +proc debugPrintFp*(prefix: string, x: Fp) = + echo(prefix & toDecimalFp(x)) + +proc debugPrintFp2*(prefix: string, z: Fp2) = + echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0])) + echo(prefix & " u ~> " & toDecimalFp(z.coords[1])) + +proc debugPrintFr*(prefix: string, x: Fr) = + echo(prefix & toDecimalFr(x)) + +proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) = echo "---------------------" echo msg for x in xs: - echo(" " & toDecimalFr(x)) + debugPrintFr( " " , x ) + +proc debugPrintG1*(msg: string, pt: G1) = + echo(msg & ":") + debugPrintFp( " x = ", pt.x ) + debugPrintFp( " y = ", pt.y ) + +proc debugPrintG2*(msg: string, pt: G2) = + echo(msg & ":") + debugPrintFp2( " x = ", pt.x ) + debugPrintFp2( " y = ", pt.y ) + +#------------------------------------------------------------------------------- + +# Montgomery batch inversion +func batchInverse*( xs: seq[Fr] ) : seq[Fr] = + let n = xs.len + assert(n>0) + var us : seq[Fr] = newSeq[Fr](n+1) + var a = xs[0] + us[0] = oneFr + us[1] = a + for i in 1.. 0 ) + assert( bool(prf.publicIO[0] == oneFr) ) + + let f = open(fpath, fmWrite) + defer: f.close() + + for i in 1.. +# + +#[] +import sugar +import constantine/math/config/curves +import constantine/math/io/io_fields +import constantine/math/io/io_bigints +import ./zkey +]# + +import constantine/math/arithmetic except Fp, Fr +import constantine/math/io/io_extfields except Fp12 +import constantine/math/extension_fields/towers except Fp2, Fp12 + +import ./bn128 +import ./domain +import ./poly +import ./zkey_types +import ./witness + +#------------------------------------------------------------------------------- + +type + Proof* = object + publicIO* : seq[Fr] + pi_a* : G1 + pi_b* : G2 + pi_c* : G1 + curve : string + +#------------------------------------------------------------------------------- +# the verifier +# + +proc verifyProof* (vkey: VKey, prf: Proof): bool = + + assert( prf.curve == "bn128" ) + + assert( isOnCurveG1(prf.pi_a) , "pi_a is not in G1" ) + assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" ) + assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" ) + + var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC ) + + let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b > + let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta > + let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta > + let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma > + + var eq : Fp12 + eq = lhs + eq *= rhs1 + eq *= rhs2 + eq *= rhs3 + + return bool(isOne(eq)) + +#------------------------------------------------------------------------------- +# A, B, C column vectors +# + +type + ABC = object + valuesA : seq[Fr] + valuesB : seq[Fr] + valuesC : seq[Fr] + +func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC = + let hdr: GrothHeader = zkey.header + let domSize = hdr.domainSize + + var valuesA : seq[Fr] = newSeq[Fr](domSize) + var valuesB : seq[Fr] = newSeq[Fr](domSize) + for entry in zkey.coeffs: + case entry.matrix + of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col] + of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col] + else: raise newException(AssertionDefect, "fatal error") + + var valuesC : seq[Fr] = newSeq[Fr](domSize) + for i in 0..= 1) + var ys : seq[Fr] = newSeq[Fr](n) + ys[0] = xs[0] + if n >= 1: ys[1] = eta * xs[1] + var spow : Fr = eta + for i in 2.. +# +func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] = + let n = abc.valuesA.len + let D = createDomain(n) + let eta = createDomain(2*n).domainGen + let A1 = shiftEvalDomain( abc.valuesA, D, eta ) + let B1 = shiftEvalDomain( abc.valuesB, D, eta ) + let C1 = shiftEvalDomain( abc.valuesC, D, eta ) + var ys : seq[Fr] = newSeq[Fr]( n ) + for j in 0..= m: - for i in 0..= m: - for i in 0..=1 ) var cs : seq[Fr] = newSeq[Fr]( N+1 ) - cs[0] = negFr( oneFr ) - cs[N] = oneFr + cs[0] = negFr(b) + cs[N] = a return Poly(coeffs: cs) +# the vanishing polynomial `(x^N - 1)` +func vanishingPoly*(N: int): Poly = + return generalizedVanishingPoly(N, oneFr, oneFr) + +#------------------------------------------------------------------------------- + type QuotRem*[T] = object quot* : T @@ -164,13 +192,15 @@ func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] = rem = src else: - # compute quot + + # compute quotient for j in countdown(deg-N, 0): if j+N <= deg-N: quot[j] = src[j+N] + quot[j+N] else: quot[j] = src[j+N] - # compute rem + + # compute remainder for j in 0.. # length = 2 * n8p * domSize = domSize G1 points # # 10: Contributions @@ -82,64 +86,20 @@ # ??? (but not required for proving, only for checking that the `.zkey` file is valid) # +#------------------------------------------------------------------------------- + import std/streams import constantine/math/arithmetic except Fp, Fr -import constantine/math/io/io_bigints - +#import constantine/math/io/io_bigints + import ./bn128 +import ./zkey_types import ./container import ./misc #------------------------------------------------------------------------------- -type - - GrothHeader* = object - p : BigInt[256] - r : BigInt[256] - nvars : int - npubs : int - domainSize : int - logDomainSize : int - - SpecPoints* = object - alpha1 : G1 - beta1 : G1 - beta2 : G2 - gamma2 : G2 - delta1 : G1 - delta2 : G2 - - VerifierPoints* = object - pointsIC : seq[G1] - - ProverPoints* = object - pointsA1 : seq[G1] - pointsB1 : seq[G1] - pointsB2 : seq[G2] - pointsC1 : seq[G1] - pointsH1 : seq[G1] - - MatrixSel* = enum - MatrixA - MatrixB - MatrixC - - Coeff* = object - matrix : MatrixSel - row : int - col : int - coeff : Fr - - ZKey* = object - sectionMask : uint32 - header : GrothHeader - specPoints : SpecPoints - vPoints : VerifierPoints - pPoints : ProverPoints - coeffs : seq[Coeff] - proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int ) = assert( sectionLen == 4 , "unexpected section length" ) let proverType = stream.readUint32 @@ -148,13 +108,13 @@ proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int #------------------------------------------------------------------------------- proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int ) = - echo "\nparsing the Groth16 zkey header" + # echo "\nparsing the Groth16 zkey header" let (n8p, p) = parsePrimeField( stream ) # size of the base field let (n8r, r) = parsePrimeField( stream ) # size of the scalar field - echo("p = ",toDecimalBig(p)) - echo("r = ",toDecimalBig(r)) + # echo("p = ",toDecimalBig(p)) + # echo("r = ",toDecimalBig(r)) assert( sectionLen == 2*4 + n8p + n8r + 3*4 + 3*64 + 3*128 , "unexpected section length" ) @@ -167,6 +127,7 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int assert( bool(p == primeP) , "expecting the alt-bn128 curve" ) assert( bool(r == primeR) , "expecting the alt-bn128 curve" ) + header.curve = "bn128" let nvars = int( stream.readUint32() ) let npubs = int( stream.readUint32() ) @@ -175,9 +136,9 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int assert( (1 shl log2siz) == domsiz , "domain size should be a power of two" ) - echo("nvars = ",nvars) - echo("npubs = ",npubs) - echo("domsiz = ",domsiz) + # echo("nvars = ",nvars) + # echo("npubs = ",npubs) + # echo("domsiz = ",domsiz) header.nvars = nvars header.npubs = npubs @@ -194,6 +155,7 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int spec.gamma2 = loadPointG2( stream ) spec.delta1 = loadPointG1( stream ) spec.delta2 = loadPointG2( stream ) + spec.alphaBeta = pairing( spec.alpha1, spec.beta2 ) user.specPoints = spec #------------------------------------------------------------------------------- @@ -206,9 +168,9 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) = var coeffs : seq[Coeff] for i in 1..ncoeffs: - let m = int( stream.readUint32() ) - let r = int( stream.readUint32() ) - let c = int( stream.readUint32() ) + let m = int( stream.readUint32() ) # which matrix + let r = int( stream.readUint32() ) # row (equation index) + let c = int( stream.readUint32() ) # column (witness index) assert( m >= 0 and m <= 2 , "invalid matrix selector" ) let sel : MatrixSel = case m of 0: MatrixA @@ -217,7 +179,7 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) = else: raise newException(AssertionDefect, "fatal error") assert( r >= 0 and r < nrows, "row index out of range" ) assert( c >= 0 and c < ncols, "column index out of range" ) - let cf = loadValueFr( stream ) + let cf = loadValueFrWTF( stream ) # Jordi, WTF is this encoding ?!?!?!!111 let entry = Coeff( matrix:sel, row:r, col:c, coeff:cf ) coeffs.add( entry ) @@ -270,10 +232,11 @@ proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) = of 9: parseSection9_PointsH1( stream, user, sectLen ) else: discard -proc parseZKey* (fname: string) = +proc parseZKey* (fname: string): ZKey = var zkey : ZKey parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 1 ) parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 2 ) parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id >= 3 ) + return zkey #------------------------------------------------------------------------------- diff --git a/zkey_types.nim b/zkey_types.nim new file mode 100644 index 0000000..5dd48cf --- /dev/null +++ b/zkey_types.nim @@ -0,0 +1,70 @@ + +import constantine/math/arithmetic except Fp, Fr + +import ./bn128 + +#------------------------------------------------------------------------------- + +type + + GrothHeader* = object + curve* : string + p* : BigInt[256] + r* : BigInt[256] + nvars* : int + npubs* : int + domainSize* : int + logDomainSize* : int + + SpecPoints* = object + alpha1* : G1 + beta1* : G1 + beta2* : G2 + gamma2* : G2 + delta1* : G1 + delta2* : G2 + alphaBeta* : Fp12 # = + + VerifierPoints* = object + pointsIC* : seq[G1] + + ProverPoints* = object + pointsA1* : seq[G1] + pointsB1* : seq[G1] + pointsB2* : seq[G2] + pointsC1* : seq[G1] + pointsH1* : seq[G1] + + MatrixSel* = enum + MatrixA + MatrixB + MatrixC + + Coeff* = object + matrix* : MatrixSel + row* : int + col* : int + coeff* : Fr + + ZKey* = object + # sectionMask* : uint32 + header* : GrothHeader + specPoints* : SpecPoints + vPoints* : VerifierPoints + pPoints* : ProverPoints + coeffs* : seq[Coeff] + + VKey* = object + curve* : string + spec* : SpecPoints + vpoints* : VerifierPoints + +#------------------------------------------------------------------------------- + +func extractVKey*(zkey: Zkey): VKey = + let curve = zkey.header.curve + let spec = zkey.specPoints + let vpts = zkey.vPoints + return VKey(curve:curve, spec:spec, vpoints:vpts) + +#-------------------------------------------------------------------------------