mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-01-04 06:33:12 +00:00
add "fake" trusted setup for testing purposes; implement both the paper version and snarkjs version of the prover
This commit is contained in:
parent
0544ef5b9e
commit
5ce7926e92
10
README.md
10
README.md
@ -6,24 +6,24 @@ This is Groth16 prover implementation in Nim, using the
|
||||
[`constantine`](https://github.com/mratsim/constantine)
|
||||
library as an arithmetic / curve backend.
|
||||
|
||||
The implementation is compatible with the `circom` ecosystem.
|
||||
The implementation is compatible with the `circom` + `snarkjs` ecosystem.
|
||||
|
||||
At the moment only the `BN254` (aka. `alt-bn128`) curve is supported.
|
||||
|
||||
### License
|
||||
|
||||
Licensed and distributed under either of
|
||||
Licensed and distributed under either of the
|
||||
[MIT license](http://opensource.org/licenses/MIT) or
|
||||
[Apache License, v2.0](http://www.apache.org/licenses/LICENSE-2.0),
|
||||
at your option.
|
||||
at your choice.
|
||||
|
||||
### TODO
|
||||
|
||||
- [ ] make it a nimble package
|
||||
- [ ] refactor `bn128.nim` into smaller files
|
||||
- [ ] proper MSM implementation (I couldn't make constantine's one to work)
|
||||
- [ ] compare `.r1cs` to the "coeffs" section of `.zkey`
|
||||
- [ ] generate fake circuit-specific setup ourselves
|
||||
- [x] compare `.r1cs` to the "coeffs" section of `.zkey`
|
||||
- [x] generate fake circuit-specific setup ourselves
|
||||
- [ ] multithreaded support (MSM, and possibly also FFT)
|
||||
- [ ] add Groth16 notes
|
||||
- [ ] document the `snarkjs` circuit-specific setup `H` points convention
|
||||
|
||||
108
bn128.nim
108
bn128.nim
@ -67,6 +67,9 @@ func pairing* (p: G1, q: G2) : Fp12 =
|
||||
const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
|
||||
const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
|
||||
|
||||
const primeP_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
|
||||
const primeR_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
const zeroFp* : Fp = fromHex( Fp, "0x00" )
|
||||
@ -82,6 +85,11 @@ const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func intToB*(a: uint): B =
|
||||
var y : B
|
||||
y.setUint(a)
|
||||
return y
|
||||
|
||||
func intToFp*(a: int): Fp =
|
||||
var y : Fp
|
||||
y.fromInt(a)
|
||||
@ -97,14 +105,46 @@ func intToFr*(a: int): Fr =
|
||||
func isZeroFp*(x: Fp): bool = bool(isZero(x))
|
||||
func isZeroFr*(x: Fr): bool = bool(isZero(x))
|
||||
|
||||
func isEquaLFp*(x, y: Fp): bool = bool(x == y)
|
||||
func isEquaLFr*(x, y: Fr): bool = bool(x == y)
|
||||
func isEqualFp*(x, y: Fp): bool = bool(x == y)
|
||||
func isEqualFr*(x, y: Fr): bool = bool(x == y)
|
||||
|
||||
func `===`*(x, y: Fp): bool = isEqualFp(x,y)
|
||||
func `===`*(x, y: Fr): bool = isEqualFr(x,y)
|
||||
|
||||
#-------------------
|
||||
|
||||
func isEqualFpSeq*(xs, ys: seq[Fp]): bool =
|
||||
let n = xs.len
|
||||
assert( n == ys.len )
|
||||
var b = true
|
||||
for i in 0..<n:
|
||||
if not bool(xs[i] == ys[i]):
|
||||
b = false
|
||||
break
|
||||
return b
|
||||
|
||||
func isEqualFrSeq*(xs, ys: seq[Fr]): bool =
|
||||
let n = xs.len
|
||||
assert( n == ys.len )
|
||||
var b = true
|
||||
for i in 0..<n:
|
||||
if not bool(xs[i] == ys[i]):
|
||||
b = false
|
||||
break
|
||||
return b
|
||||
|
||||
func `===`*(xs, ys: seq[Fp]): bool = isEqualFpSeq(xs,ys)
|
||||
func `===`*(xs, ys: seq[Fr]): bool = isEqualFrSeq(xs,ys)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
#func `+`*(x, y: B ): B = ( var z : B = x ; z += y ; return z )
|
||||
func `+`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z += y ; return z )
|
||||
func `+`*(x, y: Fp): Fp = ( var z : Fp = x ; z += y ; return z )
|
||||
func `+`*(x, y: Fr): Fr = ( var z : Fr = x ; z += y ; return z )
|
||||
|
||||
#func `-`*(x, y: B ): B = ( var z : B = x ; z -= y ; return z )
|
||||
func `-`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z -= y ; return z )
|
||||
func `-`*(x, y: Fp): Fp = ( var z : Fp = x ; z -= y ; return z )
|
||||
func `-`*(x, y: Fr): Fr = ( var z : Fr = x ; z -= y ; return z )
|
||||
|
||||
@ -139,6 +179,11 @@ func smallPowFr*(base: Fr, expo: int): Fr =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func deltaFr*(i, j: int) : Fr =
|
||||
return (if (i == j): oneFr else: zeroFr)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func toDecimalBig*[n](a : BigInt[n]): string =
|
||||
var s : string = toDecimal(a)
|
||||
s = s.strip( leading=true, trailing=false, chars={'0'} )
|
||||
@ -157,6 +202,22 @@ func toDecimalFr*(a : Fr): string =
|
||||
if s.len == 0: s="0"
|
||||
return s
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
const k65536 : BigInt[254] = fromHex( BigInt[254], "0x10000", bigEndian )
|
||||
|
||||
func signedToDecimalFp*(a : Fp): string =
|
||||
if bool( a.toBig() > primeP_254 - k65536 ):
|
||||
return "-" & toDecimalFp(negFp(a))
|
||||
else:
|
||||
return toDecimalFp(a)
|
||||
|
||||
func signedToDecimalFr*(a : Fr): string =
|
||||
if bool( a.toBig() > primeR_254 - k65536 ):
|
||||
return "-" & toDecimalFr(negFr(a))
|
||||
else:
|
||||
return toDecimalFr(a)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc debugPrintFp*(prefix: string, x: Fp) =
|
||||
@ -333,6 +394,23 @@ func mkG2( x, y: Fp2 ) : G2 =
|
||||
assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" )
|
||||
return unsafeMkG2(x,y)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# group generators
|
||||
|
||||
const gen1_x : Fp = fromHex(Fp, "0x01")
|
||||
const gen1_y : Fp = fromHex(Fp, "0x02")
|
||||
|
||||
const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701")
|
||||
const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b")
|
||||
const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8")
|
||||
const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c")
|
||||
|
||||
const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu )
|
||||
const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu )
|
||||
|
||||
const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y )
|
||||
const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func isOnCurveG1* ( p: G1 ) : bool =
|
||||
@ -625,6 +703,24 @@ func `**`*( coeff: Fr , point: G2 ) : G2 =
|
||||
prj.affine( r, q )
|
||||
return r
|
||||
|
||||
#-------------------
|
||||
|
||||
func `**`*( coeff: BigInt , point: G1 ) : G1 =
|
||||
var q : ProjG1
|
||||
prj.fromAffine( q , point )
|
||||
scl.scalarMul( q , coeff )
|
||||
var r : G1
|
||||
prj.affine( r, q )
|
||||
return r
|
||||
|
||||
func `**`*( coeff: BigInt , point: G2 ) : G2 =
|
||||
var q : ProjG2
|
||||
prj.fromAffine( q , point )
|
||||
scl.scalarMul( q , coeff )
|
||||
var r : G2
|
||||
prj.affine( r, q )
|
||||
return r
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
|
||||
@ -675,3 +771,11 @@ func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
|
||||
return msmNaiveG2( coeffs, points )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc sanityCheckGroupGen*() =
|
||||
echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) )
|
||||
echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) )
|
||||
echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) )
|
||||
echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
16
domain.nim
16
domain.nim
@ -14,9 +14,11 @@ import ./misc
|
||||
|
||||
type
|
||||
Domain* = object
|
||||
domainSize* : int
|
||||
logDomainSize* : int
|
||||
domainGen* : Fr
|
||||
domainSize* : int # `N = 2^n`
|
||||
logDomainSize* : int # `n = log2(N)`
|
||||
domainGen* : Fr # `g`
|
||||
invDomainGen* : Fr # `g^-1`
|
||||
invDomainSize* : Fr # `1/n`
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -36,7 +38,12 @@ func createDomain*(size: int): Domain =
|
||||
assert( bool(a == oneFr) , "domain generator sanity check /A" )
|
||||
assert( not bool(b == oneFr) , "domain generator sanity check /B" )
|
||||
|
||||
return Domain( domainSize:size, logDomainSize:log2, domainGen:gen )
|
||||
return Domain( domainSize: size
|
||||
, logDomainSize: log2
|
||||
, domainGen: gen
|
||||
, invDomainGen: invFr(gen)
|
||||
, invDomainSize: invFr(intToFr(size))
|
||||
)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -49,3 +56,4 @@ func enumerateDomain*(D: Domain): seq[Fr] =
|
||||
return xs
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
234
fake_setup.nim
Normal file
234
fake_setup.nim
Normal file
@ -0,0 +1,234 @@
|
||||
|
||||
#
|
||||
# create "fake" circuit-specific trusted setup for testing purposes
|
||||
#
|
||||
# by fake here I mean that no actual ceremoney is done, we just generate
|
||||
# some random toxic waste
|
||||
#
|
||||
|
||||
import sugar
|
||||
import std/sequtils
|
||||
|
||||
import constantine/math/arithmetic except Fp, Fr
|
||||
|
||||
import bn128
|
||||
import domain
|
||||
import poly
|
||||
import zkey_types
|
||||
import r1cs
|
||||
import misc
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
ToxicWaste = object
|
||||
alpha: Fr
|
||||
beta: Fr
|
||||
gamma: Fr
|
||||
delta: Fr
|
||||
tau: Fr
|
||||
|
||||
proc randomToxicWaste(): ToxicWaste =
|
||||
let a = randFr()
|
||||
let b = randFr()
|
||||
let c = randFr()
|
||||
let d = randFr()
|
||||
let t = randFr()
|
||||
return ToxicWaste( alpha: a
|
||||
, beta: b
|
||||
, gamma: c
|
||||
, delta: d
|
||||
, tau: t )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func r1csToCoeffs*(r1cs: R1CS): seq[Coeff] =
|
||||
var coeffs : seq[Coeff]
|
||||
let n = r1cs.constraints.len
|
||||
let p = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut
|
||||
for i in 0..<n:
|
||||
let ct = r1cs.constraints[i]
|
||||
for term in ct.A:
|
||||
let c = Coeff(matrix:MatrixA, row:i, col:term.wireIdx, coeff:term.value)
|
||||
coeffs.add(c)
|
||||
for term in ct.B:
|
||||
let c = Coeff(matrix:MatrixB, row:i, col:term.wireIdx, coeff:term.value)
|
||||
coeffs.add(c)
|
||||
|
||||
# Snarkjs adds some dummy coefficients to the matrix "A", for the public I/O
|
||||
# Let's emulate that here
|
||||
for i in n..n+p:
|
||||
let c = Coeff(matrix:MatrixA, row:i, col:i-n, coeff:oneFr)
|
||||
coeffs.add(c)
|
||||
|
||||
return coeffs
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type Column*[T] = seq[T]
|
||||
|
||||
type Matrix*[T] = seq[Column[T]]
|
||||
|
||||
type
|
||||
Matrices* = object
|
||||
A* : Matrix[Fr]
|
||||
B* : Matrix[Fr]
|
||||
C* : Matrix[Fr]
|
||||
|
||||
func r1csToMatrices*(r1cs: R1CS): Matrices =
|
||||
let n = r1cs.constraints.len
|
||||
let m = r1cs.cfg.nWires
|
||||
let p = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut
|
||||
|
||||
let logDomSize = ceilingLog2(n+p+1)
|
||||
let domSize = 1 shl logDomSize
|
||||
|
||||
var matA, matB, matC: Matrix[Fr]
|
||||
for i in 0..<m:
|
||||
var colA = newSeq[Fr](domSize)
|
||||
var colB = newSeq[Fr](domSize)
|
||||
var colC = newSeq[Fr](domSize)
|
||||
matA.add( colA )
|
||||
matB.add( colB )
|
||||
matC.add( colC )
|
||||
|
||||
for i in 0..<n:
|
||||
let ct = r1cs.constraints[i]
|
||||
for term in ct.A: matA[term.wireIdx][i] += term.value
|
||||
for term in ct.B: matB[term.wireIdx][i] += term.value
|
||||
for term in ct.C: matC[term.wireIdx][i] += term.value
|
||||
|
||||
# Snarkjs adds some dummy coefficients to the matrix "A", for the public I/O
|
||||
# Let's emulate that here
|
||||
for i in n..n+p:
|
||||
matA[i-n][i] += oneFr
|
||||
|
||||
return Matrices(A:matA, B:matB, C:matC)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func matricesToCoeffs*(matrices: Matrices): seq[Coeff] =
|
||||
let n = matrices.A[0].len
|
||||
let m = matrices.A.len
|
||||
|
||||
var coeffs : seq[Coeff]
|
||||
for i in 0..<n:
|
||||
for j in 0..<m:
|
||||
|
||||
let a = matrices.A[j][i]
|
||||
if not bool(isZero(a)):
|
||||
let x = Coeff(matrix:MatrixA, row:i, col:j, coeff:a)
|
||||
coeffs.add(x)
|
||||
|
||||
let b = matrices.B[j][i]
|
||||
if not bool(isZero(b)):
|
||||
let x = Coeff(matrix:MatrixB, row:i, col:j, coeff:b)
|
||||
coeffs.add(x)
|
||||
|
||||
return coeffs
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
|
||||
|
||||
let neqs = r1cs.constraints.len
|
||||
let npub = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut
|
||||
let logDomSize = ceilingLog2(neqs+npub+1)
|
||||
let domSize = 1 shl logDomSize
|
||||
|
||||
let nvars = r1cs.cfg.nWires
|
||||
let npubs = r1cs.cfg.nPubOut + r1cs.cfg.nPubIn
|
||||
|
||||
# echo("nvars = ",nvars)
|
||||
# echo("npub = ",npubs)
|
||||
# echo("neqs = ",neqs)
|
||||
# echo("domain = ",domSize)
|
||||
|
||||
let header = GrothHeader( curve: "bn128"
|
||||
, flavour: flavour
|
||||
, p: primeP
|
||||
, r: primeR
|
||||
, nvars: nvars
|
||||
, npubs: npubs
|
||||
, domainSize: domSize
|
||||
, logDomainSize: logDomSize
|
||||
)
|
||||
|
||||
let spec = SpecPoints( alpha1 : toxic.alpha ** gen1
|
||||
, beta1 : toxic.beta ** gen1
|
||||
, beta2 : toxic.beta ** gen2
|
||||
, gamma2 : toxic.gamma ** gen2
|
||||
, delta1 : toxic.delta ** gen1
|
||||
, delta2 : toxic.delta ** gen2
|
||||
, alphaBeta : pairing( toxic.alpha ** gen1 , toxic.beta ** gen2 )
|
||||
)
|
||||
|
||||
let matrices = r1csToMatrices(r1cs)
|
||||
let coeffs = r1csToCoeffs( r1cs )
|
||||
# let coeffs = matricesToCoeffs(matrices)
|
||||
|
||||
let D : Domain = createDomain(domSize)
|
||||
|
||||
let polyAs : seq[Poly] = collect( newSeq , (for col in matrices.A: polyInverseNTT(col, D) ))
|
||||
let polyBs : seq[Poly] = collect( newSeq , (for col in matrices.B: polyInverseNTT(col, D) ))
|
||||
let polyCs : seq[Poly] = collect( newSeq , (for col in matrices.C: polyInverseNTT(col, D) ))
|
||||
|
||||
let pointsA : seq[G1] = collect( newSeq , (for p in polyAs: polyEvalAt(p, toxic.tau) ** gen1) )
|
||||
let pointsB1 : seq[G1] = collect( newSeq , (for p in polyBs: polyEvalAt(p, toxic.tau) ** gen1) )
|
||||
let pointsB2 : seq[G2] = collect( newSeq , (for p in polyBs: polyEvalAt(p, toxic.tau) ** gen2) )
|
||||
let pointsC : seq[G1] = collect( newSeq , (for p in polyCs: polyEvalAt(p, toxic.tau) ** gen1) )
|
||||
|
||||
let gammaInv : Fr = invFr(toxic.gamma)
|
||||
let deltaInv : Fr = invFr(toxic.delta)
|
||||
|
||||
let pointsL : seq[G1] = collect( newSeq , (for j in 0..npub:
|
||||
gammaInv ** ( toxic.beta ** pointsA[j] + toxic.alpha ** pointsB1[j] + pointsC[j] ) ))
|
||||
|
||||
let pointsK : seq[G1] = collect( newSeq , (for j in npub+1..nvars-1:
|
||||
deltaInv ** ( toxic.beta ** pointsA[j] + toxic.alpha ** pointsB1[j] + pointsC[j] ) ))
|
||||
|
||||
let polyZ = vanishingPoly(D)
|
||||
let ztauG1 = polyEvalAt(polyZ, toxic.tau) ** gen1
|
||||
|
||||
var pointsH : seq[G1]
|
||||
case flavour
|
||||
|
||||
# in the original paper, these are the curve points
|
||||
# [ delta^-1 * tau^i * Z(tau) ]
|
||||
of JensGroth:
|
||||
pointsH = collect( newSeq , (for i in 0..<domSize:
|
||||
(deltaInv * smallPowFr(toxic.tau,i)) ** ztauG1 ))
|
||||
|
||||
# in the Snarkjs implementation, these are the curve points
|
||||
# [ delta^-1 * L_{2i+1} (tau) ]
|
||||
# where L_k are the Lagrange polynomials on the refined domain
|
||||
of Snarkjs:
|
||||
let D2 : Domain = createDomain(2*domSize)
|
||||
let eta : Fr = D2.domainGen
|
||||
|
||||
pointsH = collect( newSeq , (for i in 0..<domSize:
|
||||
(deltaInv * evalLagrangePolyAt(D2, 2*i+1, toxic.tau)) ** gen1 ))
|
||||
|
||||
let vPoints = VerifierPoints( pointsIC: pointsL )
|
||||
|
||||
let pPoints = ProverPoints( pointsA1: pointsA
|
||||
, pointsB1: pointsB1
|
||||
, pointsB2: pointsB2
|
||||
, pointsC1: pointsK
|
||||
, pointsH1: pointsH
|
||||
)
|
||||
|
||||
return ZKey( header: header
|
||||
, specPoints: spec
|
||||
, vPoints: vPoints
|
||||
, pPoints: pPoints
|
||||
, coeffs: coeffs
|
||||
)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc createFakeCircuitSetup*(r1cs: R1CS, flavour=Snarkjs): ZKey =
|
||||
let toxic = randomToxicWaste()
|
||||
return fakeCircuitSetup(r1cs, toxic, flavour=flavour)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
25
groth16.nim
25
groth16.nim
@ -139,15 +139,15 @@ func computeQuotientPointwise( abc: ABC ): Poly =
|
||||
|
||||
# (eta*omega^j)^n - 1 = eta^n - 1
|
||||
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let Inv1 = invFr( smallPowFr(eta,n) - oneFr )
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
|
||||
|
||||
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
|
||||
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
|
||||
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
|
||||
|
||||
var ys : seq[Fr] = newSeq[Fr]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * Inv1
|
||||
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
|
||||
let Q1 = polyInverseNTT( ys, D )
|
||||
|
||||
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
|
||||
@ -175,7 +175,7 @@ func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] =
|
||||
# the prover
|
||||
#
|
||||
|
||||
proc generateProof* ( zkey: ZKey, wtns: Witness ): Proof =
|
||||
proc generateProof*( zkey: ZKey, wtns: Witness ): Proof =
|
||||
assert( zkey.header.curve == wtns.curve )
|
||||
|
||||
let witness = wtns.values
|
||||
@ -194,10 +194,18 @@ proc generateProof* ( zkey: ZKey, wtns: Witness ): Proof =
|
||||
|
||||
var abc : ABC = buildABC( zkey, witness )
|
||||
|
||||
# let polyQ1 = computeQuotientNaive( abc )
|
||||
# let polyQ2 = computeQuotientPointwise( abc )
|
||||
var qs : seq[Fr]
|
||||
case zkey.header.flavour
|
||||
|
||||
let qs = computeSnarkjsScalarCoeffs( abc )
|
||||
# the points H are [delta^-1 * tau^i * Z(tau)]
|
||||
of JensGroth:
|
||||
let polyQ = computeQuotientPointwise( abc )
|
||||
qs = polyQ.coeffs
|
||||
|
||||
# the points H are [delta^-1 * L_i(tau*eta) / Z(omega^i*eta)]
|
||||
# where eta^2 = omega and L_i are Lagrange basis polynomials
|
||||
of Snarkjs:
|
||||
qs = computeSnarkjsScalarCoeffs( abc )
|
||||
|
||||
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
|
||||
for j in npubs+1..<nvars:
|
||||
@ -207,9 +215,6 @@ proc generateProof* ( zkey: ZKey, wtns: Witness ): Proof =
|
||||
let r : Fr = randFr()
|
||||
let s : Fr = randFr()
|
||||
|
||||
# let r : Fr = intToFr(3)
|
||||
# let s : Fr = intToFr(4)
|
||||
|
||||
var pi_a : G1
|
||||
pi_a = spec.alpha1
|
||||
pi_a += r ** spec.delta1
|
||||
|
||||
5
misc.nim
5
misc.nim
@ -5,6 +5,11 @@
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func delta*(i, j: int) : int =
|
||||
return (if (i == j): 1 else: 0)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func floorLog2* (x : int) : int =
|
||||
var k = -1
|
||||
var y = x
|
||||
|
||||
3
ntt.nim
3
ntt.nim
@ -70,6 +70,7 @@ func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
|
||||
return tgt
|
||||
|
||||
# pads the input with zeros to get a pwoer of two size
|
||||
# TODO: optimize the FFT so that it doesn't do the multiplications with zeros
|
||||
func extendAndForwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
|
||||
let n = src.len
|
||||
let N = D.domainSize
|
||||
@ -97,6 +98,7 @@ func inverseNTT_worker( m: int
|
||||
of 0:
|
||||
tgt[tgtOfs] = src[srcOfs]
|
||||
|
||||
# TODO: faster division by 2
|
||||
of 1:
|
||||
tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] ) * oneHalfFr
|
||||
tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] ) * oneHalfFr
|
||||
@ -107,6 +109,7 @@ func inverseNTT_worker( m: int
|
||||
let ginv : Fr = invFr(gen)
|
||||
var gpow : Fr = oneHalfFr
|
||||
|
||||
# TODO: precalculate the gpow vector for repeated iNTT-s ?
|
||||
for j in 0..<halfN:
|
||||
buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halfN] ) * oneHalfFr
|
||||
buf[bufOfs+j+halfN] = ( src[srcOfs+j] - src[srcOfs+j+halfN] ) * gpow
|
||||
|
||||
81
poly.nim
81
poly.nim
@ -5,8 +5,6 @@
|
||||
# constantine's implementation is "somewhat lacking", so we have to
|
||||
# implement these ourselves...
|
||||
#
|
||||
# TODO: more efficient implementations (right now I just want something working)
|
||||
#
|
||||
|
||||
import std/sequtils
|
||||
import std/sugar
|
||||
@ -143,6 +141,7 @@ func polyMulFFT*(P, Q: Poly): Poly =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# WARNING: this is using the naive implementation!
|
||||
func polyMul*(P, Q : Poly): Poly =
|
||||
# return polyMulFFT(P, Q)
|
||||
return polyMulNaive(P, Q)
|
||||
@ -172,6 +171,9 @@ func generalizedVanishingPoly*(N: int, a: Fr, b: Fr): Poly =
|
||||
func vanishingPoly*(N: int): Poly =
|
||||
return generalizedVanishingPoly(N, oneFr, oneFr)
|
||||
|
||||
func vanishingPoly*(D: Domain): Poly =
|
||||
return vanishingPoly(D.domainSize)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
@ -217,6 +219,38 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# Lagrange basis polynomials
|
||||
func lagrangePoly(D: Domain, k: int): Poly =
|
||||
let N = D.domainSize
|
||||
let omMinusK : Fr = smallPowFr( D.invDomainGen , k )
|
||||
let invN : Fr = invFr(intToFr(N))
|
||||
|
||||
var cs : seq[Fr] = newSeq[Fr]( N )
|
||||
if k == 0:
|
||||
for i in 0..<N: cs[i] = invN
|
||||
else:
|
||||
var s : Fr = invN
|
||||
for i in 0..<N:
|
||||
cs[i] = s
|
||||
s *= omMinusK
|
||||
|
||||
return Poly(coeffs: cs)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# evaluate a Lagrange basis polynomial at a given point `zeta` (outside the domain)
|
||||
func evalLagrangePolyAt*(D: Domain, k: int, zeta: Fr): Fr =
|
||||
let omegaK = smallPowFr(D.domainGen, k)
|
||||
let denom = (zeta - omegaK)
|
||||
if bool(isZero(denom)):
|
||||
# we are inside the domain
|
||||
raise newException(AssertionDefect, "point should be outside the domain")
|
||||
else:
|
||||
# we are outside the domain
|
||||
return omegaK * (smallPowFr(zeta, D.domainSize) - oneFr) * D.invDomainSize * invFr(denom)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# evaluates a polynomial on an FFT domain
|
||||
func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] =
|
||||
let n = P.coeffs.len
|
||||
@ -244,6 +278,8 @@ proc sanityCheckOneHalf*() =
|
||||
echo(toDecimalFr(invTwo * two))
|
||||
echo(toHex(invTwo))
|
||||
|
||||
#-------------------
|
||||
|
||||
proc sanityCheckVanishing*() =
|
||||
var js : seq[int] = toSeq(101..112)
|
||||
let cs : seq[Fr] = map( js, intToFr )
|
||||
@ -266,6 +302,8 @@ proc sanityCheckVanishing*() =
|
||||
debugPrintFrSeq("zs", S.coeffs)
|
||||
echo( polyIsEqual(P,S) )
|
||||
|
||||
#-------------------
|
||||
|
||||
proc sanityCheckNTT*() =
|
||||
var js : seq[int] = toSeq(101..108)
|
||||
let cs : seq[Fr] = map( js, intToFr )
|
||||
@ -280,6 +318,8 @@ proc sanityCheckNTT*() =
|
||||
debugPrintFrSeq("zs", zs)
|
||||
debugPrintFrSeq("us", Q.coeffs)
|
||||
|
||||
#-------------------
|
||||
|
||||
proc sanityCheckMulFFT*() =
|
||||
var js : seq[int] = toSeq(101..110)
|
||||
let cs : seq[Fr] = map( js, intToFr )
|
||||
@ -297,6 +337,43 @@ proc sanityCheckMulFFT*() =
|
||||
|
||||
echo( "multiply test = ", polyIsEqual(R1,R2) )
|
||||
|
||||
#-------------------
|
||||
|
||||
proc sanityCheckLagrangeBases*() =
|
||||
let n = 8
|
||||
let D = createDomain(n)
|
||||
|
||||
let L : seq[Poly] = collect( newSeq, (for k in 0..<n: lagrangePoly(D,k) ))
|
||||
|
||||
let xs = enumerateDomain(D)
|
||||
let ys0 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[0],x) ))
|
||||
let ys1 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[1],x) ))
|
||||
let ys5 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[5],x) ))
|
||||
let zs0 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(0,i) ))
|
||||
let zs1 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(1,i) ))
|
||||
let zs5 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(5,i) ))
|
||||
|
||||
echo("==============")
|
||||
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys0[i]), " | z[i] = ",toDecimalFr(zs0[i]))
|
||||
echo("--------------")
|
||||
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys1[i]), " | z[i] = ",toDecimalFr(zs1[i]))
|
||||
echo("--------------")
|
||||
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys5[i]), " | z[i] = ",toDecimalFr(zs5[i]))
|
||||
|
||||
let zeta = intToFr(123457)
|
||||
let us : seq[Fr] = collect( newSeq, (for i in 0..<n: polyEvalAt(L[i],zeta)) )
|
||||
let vs : seq[Fr] = collect( newSeq, (for i in 0..<n: evalLagrangePolyAt(D,i,zeta)) )
|
||||
|
||||
echo("==============")
|
||||
for i in 0..<n: echo("i = ",i, " | u[i] = ",toDecimalFr(us[i]), " | v[i] = ",toDecimalFr(vs[i]))
|
||||
|
||||
let prefix = "Lagrange basis sanity check = "
|
||||
if ( ys0===zs0 and ys1===zs1 and ys5===zs5 and
|
||||
us===vs ):
|
||||
echo( prefix & "OK")
|
||||
else:
|
||||
echo( prefix & "FAILED")
|
||||
|
||||
]#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
|
||||
import ./groth16
|
||||
import ./witness
|
||||
import ./r1cs
|
||||
import ./zkey
|
||||
import ./zkey_types
|
||||
import ./fake_setup
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -12,14 +14,38 @@ proc testProveAndVerify*( zkey_fname, wtns_fname: string): Proof =
|
||||
let witness = parseWitness( wtns_fname)
|
||||
let zkey = parseZKey( zkey_fname)
|
||||
|
||||
# printCoeffs(zkey.coeffs)
|
||||
|
||||
echo("generating proof...")
|
||||
let vkey = extractVKey( zkey)
|
||||
let proof = generateProof( zkey, witness )
|
||||
|
||||
echo("verifying the proof...")
|
||||
let ok = verifyProof( vkey, proof)
|
||||
let ok = verifyProof( vkey, proof )
|
||||
echo("verification succeeded = ",ok)
|
||||
|
||||
return proof
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): Proof =
|
||||
echo("trusted setup flavour = ",flavour)
|
||||
|
||||
echo("parsing witness & r1cs files...")
|
||||
let witness = parseWitness( wtns_fname)
|
||||
let r1cs = parseR1CS( r1cs_fname)
|
||||
|
||||
echo("performing fake trusted setup...")
|
||||
let zkey = createFakeCircuitSetup( r1cs, flavour=flavour )
|
||||
|
||||
# printCoeffs(zkey.coeffs)
|
||||
|
||||
echo("generating proof...")
|
||||
let vkey = extractVKey( zkey)
|
||||
let proof = generateProof( zkey, witness )
|
||||
|
||||
echo("verifying the proof...")
|
||||
let ok = verifyProof( vkey, proof )
|
||||
echo("verification succeeded = ",ok)
|
||||
|
||||
return proof
|
||||
|
||||
@ -11,9 +11,9 @@
|
||||
#
|
||||
# nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret
|
||||
#
|
||||
# WARNING! unlike the `.r1cs` and `.zkey` files, which encode field elements
|
||||
# in Montgomery representation, the `.wtns` file encode field elements in
|
||||
# the standard representation!!
|
||||
# NOTE: Unlike the `.zkey` files, which encode field elements in the
|
||||
# Montgomery representation, the `.wtns` file encode field elements in
|
||||
# the standard representation!
|
||||
#
|
||||
|
||||
import std/streams
|
||||
|
||||
14
zkey.nim
14
zkey.nim
@ -7,7 +7,9 @@
|
||||
# ===========
|
||||
#
|
||||
# standard iden3 binary container format.
|
||||
# field elements are in Montgomery representation
|
||||
# field elements are in Montgomery representation, except for the coefficients
|
||||
# which for some reason are double Montgomery encoded... (and unlike the
|
||||
# `.wtns` and `.r1cs` files which use the standard representation)
|
||||
#
|
||||
# sections:
|
||||
#
|
||||
@ -52,7 +54,7 @@
|
||||
# compute (C*witness)[i] = (A*witness)[i] * (B*witness)[i]
|
||||
# These 3 column vectors is all we need in the proof generation.
|
||||
#
|
||||
# WARNING! It appears that the values here are *doubly Montgomery encoded* (?!)
|
||||
# WARNING! It appears that the values here are *doubly Montgomery encoded* (??!)
|
||||
#
|
||||
# 5: PointsA
|
||||
# ----------
|
||||
@ -76,8 +78,10 @@
|
||||
#
|
||||
# 9: PointsH
|
||||
# ----------
|
||||
# what normally should be the curve points [delta^-1 * tau^i * Z(tau)]
|
||||
# HOWEVER, they are NOT! (??)
|
||||
# what normally should be the curve points `[ delta^-1 * tau^i * Z(tau) ]_1`
|
||||
# HOWEVER, in the snarkjs implementation, they are different; namely
|
||||
# `[ delta^-1 * L_{2i+1} (tau) ]_1` where L_k are Lagrange polynomials
|
||||
# on the refined (double sized) domain
|
||||
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
# length = 2 * n8p * domSize = domSize G1 points
|
||||
#
|
||||
@ -122,6 +126,8 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
|
||||
header.p = p
|
||||
header.r = r
|
||||
|
||||
header.flavour = Snarkjs
|
||||
|
||||
assert( n8p == 32 , "expecting 256 bit primes")
|
||||
assert( n8r == 32 , "expecting 256 bit primes")
|
||||
|
||||
|
||||
@ -7,33 +7,38 @@ import ./bn128
|
||||
|
||||
type
|
||||
|
||||
Flavour* = enum
|
||||
JensGroth # the version described in the original Groth16 paper
|
||||
Snarkjs # the version implemented by Snarkjs
|
||||
|
||||
GrothHeader* = object
|
||||
curve* : string
|
||||
p* : BigInt[256]
|
||||
r* : BigInt[256]
|
||||
nvars* : int
|
||||
npubs* : int
|
||||
domainSize* : int
|
||||
curve* : string # name of the curve, eg. "bn128"
|
||||
flavour* : Flavour # which variation of the trusted setup
|
||||
p* : BigInt[256] # size of the base field
|
||||
r* : BigInt[256] # size of the scalar field
|
||||
nvars* : int # number of witness variables (including the constant 1)
|
||||
npubs* : int # number of public input/outputs (excluding the constant 1)
|
||||
domainSize* : int # size of the domain (should be power of two)
|
||||
logDomainSize* : int
|
||||
|
||||
SpecPoints* = object
|
||||
alpha1* : G1
|
||||
beta1* : G1
|
||||
beta2* : G2
|
||||
gamma2* : G2
|
||||
delta1* : G1
|
||||
delta2* : G2
|
||||
alphaBeta* : Fp12 # = <alpha1,beta2>
|
||||
alpha1* : G1 # = alpha * g1
|
||||
beta1* : G1 # = beta * g1
|
||||
beta2* : G2 # = beta * g2
|
||||
gamma2* : G2 # = gamma * g2
|
||||
delta1* : G1 # = delta * g1
|
||||
delta2* : G2 # = delta * g2
|
||||
alphaBeta* : Fp12 # = <alpha1 , beta2>
|
||||
|
||||
VerifierPoints* = object
|
||||
pointsIC* : seq[G1]
|
||||
pointsIC* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j <= npub)
|
||||
|
||||
ProverPoints* = object
|
||||
pointsA1* : seq[G1]
|
||||
pointsB1* : seq[G1]
|
||||
pointsB2* : seq[G2]
|
||||
pointsC1* : seq[G1]
|
||||
pointsH1* : seq[G1]
|
||||
pointsA1* : seq[G1] # the points `A_j(tau) * g1`
|
||||
pointsB1* : seq[G1] # the points `B_j(tau) * g1`
|
||||
pointsB2* : seq[G2] # the points `B_j(tau) * g2`
|
||||
pointsC1* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j > npub)
|
||||
pointsH1* : seq[G1] # meaning depends on `flavour`
|
||||
|
||||
MatrixSel* = enum
|
||||
MatrixA
|
||||
@ -68,3 +73,21 @@ func extractVKey*(zkey: Zkey): VKey =
|
||||
return VKey(curve:curve, spec:spec, vpoints:vpts)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func matrixSelToString(sel: MatrixSel): string =
|
||||
case sel
|
||||
of MatrixA: return "A"
|
||||
of MatrixB: return "B"
|
||||
of MatrixC: return "C"
|
||||
|
||||
proc printCoeff(cf: Coeff) =
|
||||
echo( "matrix=", matrixSelToString(cf.matrix)
|
||||
, " | i=", cf.row
|
||||
, " | j=", cf.col
|
||||
, " | val=", signedToDecimalFr(cf.coeff)
|
||||
)
|
||||
|
||||
proc printCoeffs*(cfs: seq[Coeff]) =
|
||||
for cf in cfs: printCoeff(cf)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user