diff --git a/README.md b/README.md index 86cde3d..e8990c9 100644 --- a/README.md +++ b/README.md @@ -6,24 +6,24 @@ This is Groth16 prover implementation in Nim, using the [`constantine`](https://github.com/mratsim/constantine) library as an arithmetic / curve backend. -The implementation is compatible with the `circom` ecosystem. +The implementation is compatible with the `circom` + `snarkjs` ecosystem. At the moment only the `BN254` (aka. `alt-bn128`) curve is supported. ### License -Licensed and distributed under either of +Licensed and distributed under either of the [MIT license](http://opensource.org/licenses/MIT) or [Apache License, v2.0](http://www.apache.org/licenses/LICENSE-2.0), -at your option. +at your choice. ### TODO - [ ] make it a nimble package - [ ] refactor `bn128.nim` into smaller files - [ ] proper MSM implementation (I couldn't make constantine's one to work) -- [ ] compare `.r1cs` to the "coeffs" section of `.zkey` -- [ ] generate fake circuit-specific setup ourselves +- [x] compare `.r1cs` to the "coeffs" section of `.zkey` +- [x] generate fake circuit-specific setup ourselves - [ ] multithreaded support (MSM, and possibly also FFT) - [ ] add Groth16 notes - [ ] document the `snarkjs` circuit-specific setup `H` points convention diff --git a/bn128.nim b/bn128.nim index 45decef..4830311 100644 --- a/bn128.nim +++ b/bn128.nim @@ -67,6 +67,9 @@ func pairing* (p: G1, q: G2) : Fp12 = const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) +const primeP_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) +const primeR_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) + #------------------------------------------------------------------------------- const zeroFp* : Fp = fromHex( Fp, "0x00" ) @@ -82,6 +85,11 @@ const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) #------------------------------------------------------------------------------- +func intToB*(a: uint): B = + var y : B + y.setUint(a) + return y + func intToFp*(a: int): Fp = var y : Fp y.fromInt(a) @@ -97,14 +105,46 @@ func intToFr*(a: int): Fr = func isZeroFp*(x: Fp): bool = bool(isZero(x)) func isZeroFr*(x: Fr): bool = bool(isZero(x)) -func isEquaLFp*(x, y: Fp): bool = bool(x == y) -func isEquaLFr*(x, y: Fr): bool = bool(x == y) +func isEqualFp*(x, y: Fp): bool = bool(x == y) +func isEqualFr*(x, y: Fr): bool = bool(x == y) + +func `===`*(x, y: Fp): bool = isEqualFp(x,y) +func `===`*(x, y: Fr): bool = isEqualFr(x,y) + +#------------------- + +func isEqualFpSeq*(xs, ys: seq[Fp]): bool = + let n = xs.len + assert( n == ys.len ) + var b = true + for i in 0.. primeP_254 - k65536 ): + return "-" & toDecimalFp(negFp(a)) + else: + return toDecimalFp(a) + +func signedToDecimalFr*(a : Fr): string = + if bool( a.toBig() > primeR_254 - k65536 ): + return "-" & toDecimalFr(negFr(a)) + else: + return toDecimalFr(a) + #------------------------------------------------------------------------------- proc debugPrintFp*(prefix: string, x: Fp) = @@ -333,6 +394,23 @@ func mkG2( x, y: Fp2 ) : G2 = assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" ) return unsafeMkG2(x,y) +#------------------------------------------------------------------------------- +# group generators + +const gen1_x : Fp = fromHex(Fp, "0x01") +const gen1_y : Fp = fromHex(Fp, "0x02") + +const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701") +const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b") +const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8") +const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c") + +const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu ) +const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu ) + +const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y ) +const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y ) + #------------------------------------------------------------------------------- func isOnCurveG1* ( p: G1 ) : bool = @@ -625,6 +703,24 @@ func `**`*( coeff: Fr , point: G2 ) : G2 = prj.affine( r, q ) return r +#------------------- + +func `**`*( coeff: BigInt , point: G1 ) : G1 = + var q : ProjG1 + prj.fromAffine( q , point ) + scl.scalarMul( q , coeff ) + var r : G1 + prj.affine( r, q ) + return r + +func `**`*( coeff: BigInt , point: G2 ) : G2 = + var q : ProjG2 + prj.fromAffine( q , point ) + scl.scalarMul( q , coeff ) + var r : G2 + prj.affine( r, q ) + return r + #------------------------------------------------------------------------------- func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 = @@ -675,3 +771,11 @@ func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 = return msmNaiveG2( coeffs, points ) #------------------------------------------------------------------------------- + +proc sanityCheckGroupGen*() = + echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) ) + echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) ) + echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) ) + echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) ) + +#------------------------------------------------------------------------------- diff --git a/domain.nim b/domain.nim index 340953e..334c2ce 100644 --- a/domain.nim +++ b/domain.nim @@ -14,9 +14,11 @@ import ./misc type Domain* = object - domainSize* : int - logDomainSize* : int - domainGen* : Fr + domainSize* : int # `N = 2^n` + logDomainSize* : int # `n = log2(N)` + domainGen* : Fr # `g` + invDomainGen* : Fr # `g^-1` + invDomainSize* : Fr # `1/n` #------------------------------------------------------------------------------- @@ -36,7 +38,12 @@ func createDomain*(size: int): Domain = assert( bool(a == oneFr) , "domain generator sanity check /A" ) assert( not bool(b == oneFr) , "domain generator sanity check /B" ) - return Domain( domainSize:size, logDomainSize:log2, domainGen:gen ) + return Domain( domainSize: size + , logDomainSize: log2 + , domainGen: gen + , invDomainGen: invFr(gen) + , invDomainSize: invFr(intToFr(size)) + ) #------------------------------------------------------------------------------- @@ -49,3 +56,4 @@ func enumerateDomain*(D: Domain): seq[Fr] = return xs #------------------------------------------------------------------------------- + diff --git a/fake_setup.nim b/fake_setup.nim new file mode 100644 index 0000000..00bbae4 --- /dev/null +++ b/fake_setup.nim @@ -0,0 +1,234 @@ + +# +# create "fake" circuit-specific trusted setup for testing purposes +# +# by fake here I mean that no actual ceremoney is done, we just generate +# some random toxic waste +# + +import sugar +import std/sequtils + +import constantine/math/arithmetic except Fp, Fr + +import bn128 +import domain +import poly +import zkey_types +import r1cs +import misc + +#------------------------------------------------------------------------------- + +type + ToxicWaste = object + alpha: Fr + beta: Fr + gamma: Fr + delta: Fr + tau: Fr + +proc randomToxicWaste(): ToxicWaste = + let a = randFr() + let b = randFr() + let c = randFr() + let d = randFr() + let t = randFr() + return ToxicWaste( alpha: a + , beta: b + , gamma: c + , delta: d + , tau: t ) + +#------------------------------------------------------------------------------- + +func r1csToCoeffs*(r1cs: R1CS): seq[Coeff] = + var coeffs : seq[Coeff] + let n = r1cs.constraints.len + let p = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut + for i in 0.. + alpha1* : G1 # = alpha * g1 + beta1* : G1 # = beta * g1 + beta2* : G2 # = beta * g2 + gamma2* : G2 # = gamma * g2 + delta1* : G1 # = delta * g1 + delta2* : G2 # = delta * g2 + alphaBeta* : Fp12 # = VerifierPoints* = object - pointsIC* : seq[G1] + pointsIC* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j <= npub) ProverPoints* = object - pointsA1* : seq[G1] - pointsB1* : seq[G1] - pointsB2* : seq[G2] - pointsC1* : seq[G1] - pointsH1* : seq[G1] + pointsA1* : seq[G1] # the points `A_j(tau) * g1` + pointsB1* : seq[G1] # the points `B_j(tau) * g1` + pointsB2* : seq[G2] # the points `B_j(tau) * g2` + pointsC1* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j > npub) + pointsH1* : seq[G1] # meaning depends on `flavour` MatrixSel* = enum MatrixA @@ -68,3 +73,21 @@ func extractVKey*(zkey: Zkey): VKey = return VKey(curve:curve, spec:spec, vpoints:vpts) #------------------------------------------------------------------------------- + +func matrixSelToString(sel: MatrixSel): string = + case sel + of MatrixA: return "A" + of MatrixB: return "B" + of MatrixC: return "C" + +proc printCoeff(cf: Coeff) = + echo( "matrix=", matrixSelToString(cf.matrix) + , " | i=", cf.row + , " | j=", cf.col + , " | val=", signedToDecimalFr(cf.coeff) + ) + +proc printCoeffs*(cfs: seq[Coeff]) = + for cf in cfs: printCoeff(cf) + +#-------------------------------------------------------------------------------