add "fake" trusted setup for testing purposes; implement both the paper version and snarkjs version of the prover

This commit is contained in:
Balazs Komuves 2023-11-13 19:40:15 +01:00
parent 0544ef5b9e
commit 5ce7926e92
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
12 changed files with 542 additions and 51 deletions

View File

@ -6,24 +6,24 @@ This is Groth16 prover implementation in Nim, using the
[`constantine`](https://github.com/mratsim/constantine) [`constantine`](https://github.com/mratsim/constantine)
library as an arithmetic / curve backend. library as an arithmetic / curve backend.
The implementation is compatible with the `circom` ecosystem. The implementation is compatible with the `circom` + `snarkjs` ecosystem.
At the moment only the `BN254` (aka. `alt-bn128`) curve is supported. At the moment only the `BN254` (aka. `alt-bn128`) curve is supported.
### License ### License
Licensed and distributed under either of Licensed and distributed under either of the
[MIT license](http://opensource.org/licenses/MIT) or [MIT license](http://opensource.org/licenses/MIT) or
[Apache License, v2.0](http://www.apache.org/licenses/LICENSE-2.0), [Apache License, v2.0](http://www.apache.org/licenses/LICENSE-2.0),
at your option. at your choice.
### TODO ### TODO
- [ ] make it a nimble package - [ ] make it a nimble package
- [ ] refactor `bn128.nim` into smaller files - [ ] refactor `bn128.nim` into smaller files
- [ ] proper MSM implementation (I couldn't make constantine's one to work) - [ ] proper MSM implementation (I couldn't make constantine's one to work)
- [ ] compare `.r1cs` to the "coeffs" section of `.zkey` - [x] compare `.r1cs` to the "coeffs" section of `.zkey`
- [ ] generate fake circuit-specific setup ourselves - [x] generate fake circuit-specific setup ourselves
- [ ] multithreaded support (MSM, and possibly also FFT) - [ ] multithreaded support (MSM, and possibly also FFT)
- [ ] add Groth16 notes - [ ] add Groth16 notes
- [ ] document the `snarkjs` circuit-specific setup `H` points convention - [ ] document the `snarkjs` circuit-specific setup `H` points convention

108
bn128.nim
View File

@ -67,6 +67,9 @@ func pairing* (p: G1, q: G2) : Fp12 =
const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
const primeP_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
const primeR_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
const zeroFp* : Fp = fromHex( Fp, "0x00" ) const zeroFp* : Fp = fromHex( Fp, "0x00" )
@ -82,6 +85,11 @@ const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 )
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
func intToB*(a: uint): B =
var y : B
y.setUint(a)
return y
func intToFp*(a: int): Fp = func intToFp*(a: int): Fp =
var y : Fp var y : Fp
y.fromInt(a) y.fromInt(a)
@ -97,14 +105,46 @@ func intToFr*(a: int): Fr =
func isZeroFp*(x: Fp): bool = bool(isZero(x)) func isZeroFp*(x: Fp): bool = bool(isZero(x))
func isZeroFr*(x: Fr): bool = bool(isZero(x)) func isZeroFr*(x: Fr): bool = bool(isZero(x))
func isEquaLFp*(x, y: Fp): bool = bool(x == y) func isEqualFp*(x, y: Fp): bool = bool(x == y)
func isEquaLFr*(x, y: Fr): bool = bool(x == y) func isEqualFr*(x, y: Fr): bool = bool(x == y)
func `===`*(x, y: Fp): bool = isEqualFp(x,y)
func `===`*(x, y: Fr): bool = isEqualFr(x,y)
#-------------------
func isEqualFpSeq*(xs, ys: seq[Fp]): bool =
let n = xs.len
assert( n == ys.len )
var b = true
for i in 0..<n:
if not bool(xs[i] == ys[i]):
b = false
break
return b
func isEqualFrSeq*(xs, ys: seq[Fr]): bool =
let n = xs.len
assert( n == ys.len )
var b = true
for i in 0..<n:
if not bool(xs[i] == ys[i]):
b = false
break
return b
func `===`*(xs, ys: seq[Fp]): bool = isEqualFpSeq(xs,ys)
func `===`*(xs, ys: seq[Fr]): bool = isEqualFrSeq(xs,ys)
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
#func `+`*(x, y: B ): B = ( var z : B = x ; z += y ; return z )
func `+`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z += y ; return z )
func `+`*(x, y: Fp): Fp = ( var z : Fp = x ; z += y ; return z ) func `+`*(x, y: Fp): Fp = ( var z : Fp = x ; z += y ; return z )
func `+`*(x, y: Fr): Fr = ( var z : Fr = x ; z += y ; return z ) func `+`*(x, y: Fr): Fr = ( var z : Fr = x ; z += y ; return z )
#func `-`*(x, y: B ): B = ( var z : B = x ; z -= y ; return z )
func `-`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z -= y ; return z )
func `-`*(x, y: Fp): Fp = ( var z : Fp = x ; z -= y ; return z ) func `-`*(x, y: Fp): Fp = ( var z : Fp = x ; z -= y ; return z )
func `-`*(x, y: Fr): Fr = ( var z : Fr = x ; z -= y ; return z ) func `-`*(x, y: Fr): Fr = ( var z : Fr = x ; z -= y ; return z )
@ -139,6 +179,11 @@ func smallPowFr*(base: Fr, expo: int): Fr =
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
func deltaFr*(i, j: int) : Fr =
return (if (i == j): oneFr else: zeroFr)
#-------------------------------------------------------------------------------
func toDecimalBig*[n](a : BigInt[n]): string = func toDecimalBig*[n](a : BigInt[n]): string =
var s : string = toDecimal(a) var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} ) s = s.strip( leading=true, trailing=false, chars={'0'} )
@ -157,6 +202,22 @@ func toDecimalFr*(a : Fr): string =
if s.len == 0: s="0" if s.len == 0: s="0"
return s return s
#---------------------------------------
const k65536 : BigInt[254] = fromHex( BigInt[254], "0x10000", bigEndian )
func signedToDecimalFp*(a : Fp): string =
if bool( a.toBig() > primeP_254 - k65536 ):
return "-" & toDecimalFp(negFp(a))
else:
return toDecimalFp(a)
func signedToDecimalFr*(a : Fr): string =
if bool( a.toBig() > primeR_254 - k65536 ):
return "-" & toDecimalFr(negFr(a))
else:
return toDecimalFr(a)
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
proc debugPrintFp*(prefix: string, x: Fp) = proc debugPrintFp*(prefix: string, x: Fp) =
@ -333,6 +394,23 @@ func mkG2( x, y: Fp2 ) : G2 =
assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" ) assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" )
return unsafeMkG2(x,y) return unsafeMkG2(x,y)
#-------------------------------------------------------------------------------
# group generators
const gen1_x : Fp = fromHex(Fp, "0x01")
const gen1_y : Fp = fromHex(Fp, "0x02")
const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701")
const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b")
const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8")
const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c")
const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu )
const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu )
const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y )
const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y )
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
func isOnCurveG1* ( p: G1 ) : bool = func isOnCurveG1* ( p: G1 ) : bool =
@ -625,6 +703,24 @@ func `**`*( coeff: Fr , point: G2 ) : G2 =
prj.affine( r, q ) prj.affine( r, q )
return r return r
#-------------------
func `**`*( coeff: BigInt , point: G1 ) : G1 =
var q : ProjG1
prj.fromAffine( q , point )
scl.scalarMul( q , coeff )
var r : G1
prj.affine( r, q )
return r
func `**`*( coeff: BigInt , point: G2 ) : G2 =
var q : ProjG2
prj.fromAffine( q , point )
scl.scalarMul( q , coeff )
var r : G2
prj.affine( r, q )
return r
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 = func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
@ -675,3 +771,11 @@ func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
return msmNaiveG2( coeffs, points ) return msmNaiveG2( coeffs, points )
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
proc sanityCheckGroupGen*() =
echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) )
echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) )
echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) )
echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) )
#-------------------------------------------------------------------------------

View File

@ -14,9 +14,11 @@ import ./misc
type type
Domain* = object Domain* = object
domainSize* : int domainSize* : int # `N = 2^n`
logDomainSize* : int logDomainSize* : int # `n = log2(N)`
domainGen* : Fr domainGen* : Fr # `g`
invDomainGen* : Fr # `g^-1`
invDomainSize* : Fr # `1/n`
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
@ -36,7 +38,12 @@ func createDomain*(size: int): Domain =
assert( bool(a == oneFr) , "domain generator sanity check /A" ) assert( bool(a == oneFr) , "domain generator sanity check /A" )
assert( not bool(b == oneFr) , "domain generator sanity check /B" ) assert( not bool(b == oneFr) , "domain generator sanity check /B" )
return Domain( domainSize:size, logDomainSize:log2, domainGen:gen ) return Domain( domainSize: size
, logDomainSize: log2
, domainGen: gen
, invDomainGen: invFr(gen)
, invDomainSize: invFr(intToFr(size))
)
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
@ -49,3 +56,4 @@ func enumerateDomain*(D: Domain): seq[Fr] =
return xs return xs
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------

234
fake_setup.nim Normal file
View File

@ -0,0 +1,234 @@
#
# create "fake" circuit-specific trusted setup for testing purposes
#
# by fake here I mean that no actual ceremoney is done, we just generate
# some random toxic waste
#
import sugar
import std/sequtils
import constantine/math/arithmetic except Fp, Fr
import bn128
import domain
import poly
import zkey_types
import r1cs
import misc
#-------------------------------------------------------------------------------
type
ToxicWaste = object
alpha: Fr
beta: Fr
gamma: Fr
delta: Fr
tau: Fr
proc randomToxicWaste(): ToxicWaste =
let a = randFr()
let b = randFr()
let c = randFr()
let d = randFr()
let t = randFr()
return ToxicWaste( alpha: a
, beta: b
, gamma: c
, delta: d
, tau: t )
#-------------------------------------------------------------------------------
func r1csToCoeffs*(r1cs: R1CS): seq[Coeff] =
var coeffs : seq[Coeff]
let n = r1cs.constraints.len
let p = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut
for i in 0..<n:
let ct = r1cs.constraints[i]
for term in ct.A:
let c = Coeff(matrix:MatrixA, row:i, col:term.wireIdx, coeff:term.value)
coeffs.add(c)
for term in ct.B:
let c = Coeff(matrix:MatrixB, row:i, col:term.wireIdx, coeff:term.value)
coeffs.add(c)
# Snarkjs adds some dummy coefficients to the matrix "A", for the public I/O
# Let's emulate that here
for i in n..n+p:
let c = Coeff(matrix:MatrixA, row:i, col:i-n, coeff:oneFr)
coeffs.add(c)
return coeffs
#-------------------------------------------------------------------------------
type Column*[T] = seq[T]
type Matrix*[T] = seq[Column[T]]
type
Matrices* = object
A* : Matrix[Fr]
B* : Matrix[Fr]
C* : Matrix[Fr]
func r1csToMatrices*(r1cs: R1CS): Matrices =
let n = r1cs.constraints.len
let m = r1cs.cfg.nWires
let p = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut
let logDomSize = ceilingLog2(n+p+1)
let domSize = 1 shl logDomSize
var matA, matB, matC: Matrix[Fr]
for i in 0..<m:
var colA = newSeq[Fr](domSize)
var colB = newSeq[Fr](domSize)
var colC = newSeq[Fr](domSize)
matA.add( colA )
matB.add( colB )
matC.add( colC )
for i in 0..<n:
let ct = r1cs.constraints[i]
for term in ct.A: matA[term.wireIdx][i] += term.value
for term in ct.B: matB[term.wireIdx][i] += term.value
for term in ct.C: matC[term.wireIdx][i] += term.value
# Snarkjs adds some dummy coefficients to the matrix "A", for the public I/O
# Let's emulate that here
for i in n..n+p:
matA[i-n][i] += oneFr
return Matrices(A:matA, B:matB, C:matC)
#-------------------------------------------------------------------------------
func matricesToCoeffs*(matrices: Matrices): seq[Coeff] =
let n = matrices.A[0].len
let m = matrices.A.len
var coeffs : seq[Coeff]
for i in 0..<n:
for j in 0..<m:
let a = matrices.A[j][i]
if not bool(isZero(a)):
let x = Coeff(matrix:MatrixA, row:i, col:j, coeff:a)
coeffs.add(x)
let b = matrices.B[j][i]
if not bool(isZero(b)):
let x = Coeff(matrix:MatrixB, row:i, col:j, coeff:b)
coeffs.add(x)
return coeffs
#-------------------------------------------------------------------------------
func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
let neqs = r1cs.constraints.len
let npub = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut
let logDomSize = ceilingLog2(neqs+npub+1)
let domSize = 1 shl logDomSize
let nvars = r1cs.cfg.nWires
let npubs = r1cs.cfg.nPubOut + r1cs.cfg.nPubIn
# echo("nvars = ",nvars)
# echo("npub = ",npubs)
# echo("neqs = ",neqs)
# echo("domain = ",domSize)
let header = GrothHeader( curve: "bn128"
, flavour: flavour
, p: primeP
, r: primeR
, nvars: nvars
, npubs: npubs
, domainSize: domSize
, logDomainSize: logDomSize
)
let spec = SpecPoints( alpha1 : toxic.alpha ** gen1
, beta1 : toxic.beta ** gen1
, beta2 : toxic.beta ** gen2
, gamma2 : toxic.gamma ** gen2
, delta1 : toxic.delta ** gen1
, delta2 : toxic.delta ** gen2
, alphaBeta : pairing( toxic.alpha ** gen1 , toxic.beta ** gen2 )
)
let matrices = r1csToMatrices(r1cs)
let coeffs = r1csToCoeffs( r1cs )
# let coeffs = matricesToCoeffs(matrices)
let D : Domain = createDomain(domSize)
let polyAs : seq[Poly] = collect( newSeq , (for col in matrices.A: polyInverseNTT(col, D) ))
let polyBs : seq[Poly] = collect( newSeq , (for col in matrices.B: polyInverseNTT(col, D) ))
let polyCs : seq[Poly] = collect( newSeq , (for col in matrices.C: polyInverseNTT(col, D) ))
let pointsA : seq[G1] = collect( newSeq , (for p in polyAs: polyEvalAt(p, toxic.tau) ** gen1) )
let pointsB1 : seq[G1] = collect( newSeq , (for p in polyBs: polyEvalAt(p, toxic.tau) ** gen1) )
let pointsB2 : seq[G2] = collect( newSeq , (for p in polyBs: polyEvalAt(p, toxic.tau) ** gen2) )
let pointsC : seq[G1] = collect( newSeq , (for p in polyCs: polyEvalAt(p, toxic.tau) ** gen1) )
let gammaInv : Fr = invFr(toxic.gamma)
let deltaInv : Fr = invFr(toxic.delta)
let pointsL : seq[G1] = collect( newSeq , (for j in 0..npub:
gammaInv ** ( toxic.beta ** pointsA[j] + toxic.alpha ** pointsB1[j] + pointsC[j] ) ))
let pointsK : seq[G1] = collect( newSeq , (for j in npub+1..nvars-1:
deltaInv ** ( toxic.beta ** pointsA[j] + toxic.alpha ** pointsB1[j] + pointsC[j] ) ))
let polyZ = vanishingPoly(D)
let ztauG1 = polyEvalAt(polyZ, toxic.tau) ** gen1
var pointsH : seq[G1]
case flavour
# in the original paper, these are the curve points
# [ delta^-1 * tau^i * Z(tau) ]
of JensGroth:
pointsH = collect( newSeq , (for i in 0..<domSize:
(deltaInv * smallPowFr(toxic.tau,i)) ** ztauG1 ))
# in the Snarkjs implementation, these are the curve points
# [ delta^-1 * L_{2i+1} (tau) ]
# where L_k are the Lagrange polynomials on the refined domain
of Snarkjs:
let D2 : Domain = createDomain(2*domSize)
let eta : Fr = D2.domainGen
pointsH = collect( newSeq , (for i in 0..<domSize:
(deltaInv * evalLagrangePolyAt(D2, 2*i+1, toxic.tau)) ** gen1 ))
let vPoints = VerifierPoints( pointsIC: pointsL )
let pPoints = ProverPoints( pointsA1: pointsA
, pointsB1: pointsB1
, pointsB2: pointsB2
, pointsC1: pointsK
, pointsH1: pointsH
)
return ZKey( header: header
, specPoints: spec
, vPoints: vPoints
, pPoints: pPoints
, coeffs: coeffs
)
#-------------------------------------------------------------------------------
proc createFakeCircuitSetup*(r1cs: R1CS, flavour=Snarkjs): ZKey =
let toxic = randomToxicWaste()
return fakeCircuitSetup(r1cs, toxic, flavour=flavour)
#-------------------------------------------------------------------------------

View File

@ -139,15 +139,15 @@ func computeQuotientPointwise( abc: ABC ): Poly =
# (eta*omega^j)^n - 1 = eta^n - 1 # (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1) # 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen let eta = createDomain(2*n).domainGen
let Inv1 = invFr( smallPowFr(eta,n) - oneFr ) let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
let A1 = shiftEvalDomain( abc.valuesA, D, eta ) let A1 = shiftEvalDomain( abc.valuesA, D, eta )
let B1 = shiftEvalDomain( abc.valuesB, D, eta ) let B1 = shiftEvalDomain( abc.valuesB, D, eta )
let C1 = shiftEvalDomain( abc.valuesC, D, eta ) let C1 = shiftEvalDomain( abc.valuesC, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n ) var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * Inv1 for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D ) let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) ) let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
@ -175,7 +175,7 @@ func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] =
# the prover # the prover
# #
proc generateProof* ( zkey: ZKey, wtns: Witness ): Proof = proc generateProof*( zkey: ZKey, wtns: Witness ): Proof =
assert( zkey.header.curve == wtns.curve ) assert( zkey.header.curve == wtns.curve )
let witness = wtns.values let witness = wtns.values
@ -194,10 +194,18 @@ proc generateProof* ( zkey: ZKey, wtns: Witness ): Proof =
var abc : ABC = buildABC( zkey, witness ) var abc : ABC = buildABC( zkey, witness )
# let polyQ1 = computeQuotientNaive( abc ) var qs : seq[Fr]
# let polyQ2 = computeQuotientPointwise( abc ) case zkey.header.flavour
let qs = computeSnarkjsScalarCoeffs( abc ) # the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc )
qs = polyQ.coeffs
# the points H are [delta^-1 * L_i(tau*eta) / Z(omega^i*eta)]
# where eta^2 = omega and L_i are Lagrange basis polynomials
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc )
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 ) var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
for j in npubs+1..<nvars: for j in npubs+1..<nvars:
@ -207,9 +215,6 @@ proc generateProof* ( zkey: ZKey, wtns: Witness ): Proof =
let r : Fr = randFr() let r : Fr = randFr()
let s : Fr = randFr() let s : Fr = randFr()
# let r : Fr = intToFr(3)
# let s : Fr = intToFr(4)
var pi_a : G1 var pi_a : G1
pi_a = spec.alpha1 pi_a = spec.alpha1
pi_a += r ** spec.delta1 pi_a += r ** spec.delta1

View File

@ -5,6 +5,11 @@
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
func delta*(i, j: int) : int =
return (if (i == j): 1 else: 0)
#-------------------------------------------------------------------------------
func floorLog2* (x : int) : int = func floorLog2* (x : int) : int =
var k = -1 var k = -1
var y = x var y = x

View File

@ -70,6 +70,7 @@ func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
return tgt return tgt
# pads the input with zeros to get a pwoer of two size # pads the input with zeros to get a pwoer of two size
# TODO: optimize the FFT so that it doesn't do the multiplications with zeros
func extendAndForwardNTT*(src: seq[Fr], D: Domain): seq[Fr] = func extendAndForwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
let n = src.len let n = src.len
let N = D.domainSize let N = D.domainSize
@ -97,6 +98,7 @@ func inverseNTT_worker( m: int
of 0: of 0:
tgt[tgtOfs] = src[srcOfs] tgt[tgtOfs] = src[srcOfs]
# TODO: faster division by 2
of 1: of 1:
tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] ) * oneHalfFr tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] ) * oneHalfFr
tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] ) * oneHalfFr tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] ) * oneHalfFr
@ -107,6 +109,7 @@ func inverseNTT_worker( m: int
let ginv : Fr = invFr(gen) let ginv : Fr = invFr(gen)
var gpow : Fr = oneHalfFr var gpow : Fr = oneHalfFr
# TODO: precalculate the gpow vector for repeated iNTT-s ?
for j in 0..<halfN: for j in 0..<halfN:
buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halfN] ) * oneHalfFr buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halfN] ) * oneHalfFr
buf[bufOfs+j+halfN] = ( src[srcOfs+j] - src[srcOfs+j+halfN] ) * gpow buf[bufOfs+j+halfN] = ( src[srcOfs+j] - src[srcOfs+j+halfN] ) * gpow

View File

@ -5,8 +5,6 @@
# constantine's implementation is "somewhat lacking", so we have to # constantine's implementation is "somewhat lacking", so we have to
# implement these ourselves... # implement these ourselves...
# #
# TODO: more efficient implementations (right now I just want something working)
#
import std/sequtils import std/sequtils
import std/sugar import std/sugar
@ -143,6 +141,7 @@ func polyMulFFT*(P, Q: Poly): Poly =
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
# WARNING: this is using the naive implementation!
func polyMul*(P, Q : Poly): Poly = func polyMul*(P, Q : Poly): Poly =
# return polyMulFFT(P, Q) # return polyMulFFT(P, Q)
return polyMulNaive(P, Q) return polyMulNaive(P, Q)
@ -172,6 +171,9 @@ func generalizedVanishingPoly*(N: int, a: Fr, b: Fr): Poly =
func vanishingPoly*(N: int): Poly = func vanishingPoly*(N: int): Poly =
return generalizedVanishingPoly(N, oneFr, oneFr) return generalizedVanishingPoly(N, oneFr, oneFr)
func vanishingPoly*(D: Domain): Poly =
return vanishingPoly(D.domainSize)
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
type type
@ -217,6 +219,38 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly =
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
# Lagrange basis polynomials
func lagrangePoly(D: Domain, k: int): Poly =
let N = D.domainSize
let omMinusK : Fr = smallPowFr( D.invDomainGen , k )
let invN : Fr = invFr(intToFr(N))
var cs : seq[Fr] = newSeq[Fr]( N )
if k == 0:
for i in 0..<N: cs[i] = invN
else:
var s : Fr = invN
for i in 0..<N:
cs[i] = s
s *= omMinusK
return Poly(coeffs: cs)
#---------------------------------------
# evaluate a Lagrange basis polynomial at a given point `zeta` (outside the domain)
func evalLagrangePolyAt*(D: Domain, k: int, zeta: Fr): Fr =
let omegaK = smallPowFr(D.domainGen, k)
let denom = (zeta - omegaK)
if bool(isZero(denom)):
# we are inside the domain
raise newException(AssertionDefect, "point should be outside the domain")
else:
# we are outside the domain
return omegaK * (smallPowFr(zeta, D.domainSize) - oneFr) * D.invDomainSize * invFr(denom)
#-------------------------------------------------------------------------------
# evaluates a polynomial on an FFT domain # evaluates a polynomial on an FFT domain
func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] = func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] =
let n = P.coeffs.len let n = P.coeffs.len
@ -244,6 +278,8 @@ proc sanityCheckOneHalf*() =
echo(toDecimalFr(invTwo * two)) echo(toDecimalFr(invTwo * two))
echo(toHex(invTwo)) echo(toHex(invTwo))
#-------------------
proc sanityCheckVanishing*() = proc sanityCheckVanishing*() =
var js : seq[int] = toSeq(101..112) var js : seq[int] = toSeq(101..112)
let cs : seq[Fr] = map( js, intToFr ) let cs : seq[Fr] = map( js, intToFr )
@ -266,6 +302,8 @@ proc sanityCheckVanishing*() =
debugPrintFrSeq("zs", S.coeffs) debugPrintFrSeq("zs", S.coeffs)
echo( polyIsEqual(P,S) ) echo( polyIsEqual(P,S) )
#-------------------
proc sanityCheckNTT*() = proc sanityCheckNTT*() =
var js : seq[int] = toSeq(101..108) var js : seq[int] = toSeq(101..108)
let cs : seq[Fr] = map( js, intToFr ) let cs : seq[Fr] = map( js, intToFr )
@ -280,6 +318,8 @@ proc sanityCheckNTT*() =
debugPrintFrSeq("zs", zs) debugPrintFrSeq("zs", zs)
debugPrintFrSeq("us", Q.coeffs) debugPrintFrSeq("us", Q.coeffs)
#-------------------
proc sanityCheckMulFFT*() = proc sanityCheckMulFFT*() =
var js : seq[int] = toSeq(101..110) var js : seq[int] = toSeq(101..110)
let cs : seq[Fr] = map( js, intToFr ) let cs : seq[Fr] = map( js, intToFr )
@ -297,6 +337,43 @@ proc sanityCheckMulFFT*() =
echo( "multiply test = ", polyIsEqual(R1,R2) ) echo( "multiply test = ", polyIsEqual(R1,R2) )
#-------------------
proc sanityCheckLagrangeBases*() =
let n = 8
let D = createDomain(n)
let L : seq[Poly] = collect( newSeq, (for k in 0..<n: lagrangePoly(D,k) ))
let xs = enumerateDomain(D)
let ys0 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[0],x) ))
let ys1 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[1],x) ))
let ys5 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[5],x) ))
let zs0 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(0,i) ))
let zs1 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(1,i) ))
let zs5 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(5,i) ))
echo("==============")
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys0[i]), " | z[i] = ",toDecimalFr(zs0[i]))
echo("--------------")
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys1[i]), " | z[i] = ",toDecimalFr(zs1[i]))
echo("--------------")
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys5[i]), " | z[i] = ",toDecimalFr(zs5[i]))
let zeta = intToFr(123457)
let us : seq[Fr] = collect( newSeq, (for i in 0..<n: polyEvalAt(L[i],zeta)) )
let vs : seq[Fr] = collect( newSeq, (for i in 0..<n: evalLagrangePolyAt(D,i,zeta)) )
echo("==============")
for i in 0..<n: echo("i = ",i, " | u[i] = ",toDecimalFr(us[i]), " | v[i] = ",toDecimalFr(vs[i]))
let prefix = "Lagrange basis sanity check = "
if ( ys0===zs0 and ys1===zs1 and ys5===zs5 and
us===vs ):
echo( prefix & "OK")
else:
echo( prefix & "FAILED")
]# ]#
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------

View File

@ -1,8 +1,10 @@
import ./groth16 import ./groth16
import ./witness import ./witness
import ./r1cs
import ./zkey import ./zkey
import ./zkey_types import ./zkey_types
import ./fake_setup
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
@ -12,14 +14,38 @@ proc testProveAndVerify*( zkey_fname, wtns_fname: string): Proof =
let witness = parseWitness( wtns_fname) let witness = parseWitness( wtns_fname)
let zkey = parseZKey( zkey_fname) let zkey = parseZKey( zkey_fname)
# printCoeffs(zkey.coeffs)
echo("generating proof...") echo("generating proof...")
let vkey = extractVKey( zkey) let vkey = extractVKey( zkey)
let proof = generateProof( zkey, witness ) let proof = generateProof( zkey, witness )
echo("verifying the proof...") echo("verifying the proof...")
let ok = verifyProof( vkey, proof) let ok = verifyProof( vkey, proof )
echo("verification succeeded = ",ok) echo("verification succeeded = ",ok)
return proof return proof
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): Proof =
echo("trusted setup flavour = ",flavour)
echo("parsing witness & r1cs files...")
let witness = parseWitness( wtns_fname)
let r1cs = parseR1CS( r1cs_fname)
echo("performing fake trusted setup...")
let zkey = createFakeCircuitSetup( r1cs, flavour=flavour )
# printCoeffs(zkey.coeffs)
echo("generating proof...")
let vkey = extractVKey( zkey)
let proof = generateProof( zkey, witness )
echo("verifying the proof...")
let ok = verifyProof( vkey, proof )
echo("verification succeeded = ",ok)
return proof

View File

@ -11,9 +11,9 @@
# #
# nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret # nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret
# #
# WARNING! unlike the `.r1cs` and `.zkey` files, which encode field elements # NOTE: Unlike the `.zkey` files, which encode field elements in the
# in Montgomery representation, the `.wtns` file encode field elements in # Montgomery representation, the `.wtns` file encode field elements in
# the standard representation!! # the standard representation!
# #
import std/streams import std/streams

View File

@ -7,8 +7,10 @@
# =========== # ===========
# #
# standard iden3 binary container format. # standard iden3 binary container format.
# field elements are in Montgomery representation # field elements are in Montgomery representation, except for the coefficients
# # which for some reason are double Montgomery encoded... (and unlike the
# `.wtns` and `.r1cs` files which use the standard representation)
#
# sections: # sections:
# #
# 1: Header # 1: Header
@ -52,7 +54,7 @@
# compute (C*witness)[i] = (A*witness)[i] * (B*witness)[i] # compute (C*witness)[i] = (A*witness)[i] * (B*witness)[i]
# These 3 column vectors is all we need in the proof generation. # These 3 column vectors is all we need in the proof generation.
# #
# WARNING! It appears that the values here are *doubly Montgomery encoded* (?!) # WARNING! It appears that the values here are *doubly Montgomery encoded* (??!)
# #
# 5: PointsA # 5: PointsA
# ---------- # ----------
@ -76,8 +78,10 @@
# #
# 9: PointsH # 9: PointsH
# ---------- # ----------
# what normally should be the curve points [delta^-1 * tau^i * Z(tau)] # what normally should be the curve points `[ delta^-1 * tau^i * Z(tau) ]_1`
# HOWEVER, they are NOT! (??) # HOWEVER, in the snarkjs implementation, they are different; namely
# `[ delta^-1 * L_{2i+1} (tau) ]_1` where L_k are Lagrange polynomials
# on the refined (double sized) domain
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs> # See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
# length = 2 * n8p * domSize = domSize G1 points # length = 2 * n8p * domSize = domSize G1 points
# #
@ -122,6 +126,8 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
header.p = p header.p = p
header.r = r header.r = r
header.flavour = Snarkjs
assert( n8p == 32 , "expecting 256 bit primes") assert( n8p == 32 , "expecting 256 bit primes")
assert( n8r == 32 , "expecting 256 bit primes") assert( n8r == 32 , "expecting 256 bit primes")

View File

@ -6,34 +6,39 @@ import ./bn128
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
type type
Flavour* = enum
JensGroth # the version described in the original Groth16 paper
Snarkjs # the version implemented by Snarkjs
GrothHeader* = object GrothHeader* = object
curve* : string curve* : string # name of the curve, eg. "bn128"
p* : BigInt[256] flavour* : Flavour # which variation of the trusted setup
r* : BigInt[256] p* : BigInt[256] # size of the base field
nvars* : int r* : BigInt[256] # size of the scalar field
npubs* : int nvars* : int # number of witness variables (including the constant 1)
domainSize* : int npubs* : int # number of public input/outputs (excluding the constant 1)
domainSize* : int # size of the domain (should be power of two)
logDomainSize* : int logDomainSize* : int
SpecPoints* = object SpecPoints* = object
alpha1* : G1 alpha1* : G1 # = alpha * g1
beta1* : G1 beta1* : G1 # = beta * g1
beta2* : G2 beta2* : G2 # = beta * g2
gamma2* : G2 gamma2* : G2 # = gamma * g2
delta1* : G1 delta1* : G1 # = delta * g1
delta2* : G2 delta2* : G2 # = delta * g2
alphaBeta* : Fp12 # = <alpha1,beta2> alphaBeta* : Fp12 # = <alpha1 , beta2>
VerifierPoints* = object VerifierPoints* = object
pointsIC* : seq[G1] pointsIC* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j <= npub)
ProverPoints* = object ProverPoints* = object
pointsA1* : seq[G1] pointsA1* : seq[G1] # the points `A_j(tau) * g1`
pointsB1* : seq[G1] pointsB1* : seq[G1] # the points `B_j(tau) * g1`
pointsB2* : seq[G2] pointsB2* : seq[G2] # the points `B_j(tau) * g2`
pointsC1* : seq[G1] pointsC1* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j > npub)
pointsH1* : seq[G1] pointsH1* : seq[G1] # meaning depends on `flavour`
MatrixSel* = enum MatrixSel* = enum
MatrixA MatrixA
@ -68,3 +73,21 @@ func extractVKey*(zkey: Zkey): VKey =
return VKey(curve:curve, spec:spec, vpoints:vpts) return VKey(curve:curve, spec:spec, vpoints:vpts)
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
func matrixSelToString(sel: MatrixSel): string =
case sel
of MatrixA: return "A"
of MatrixB: return "B"
of MatrixC: return "C"
proc printCoeff(cf: Coeff) =
echo( "matrix=", matrixSelToString(cf.matrix)
, " | i=", cf.row
, " | j=", cf.col
, " | val=", signedToDecimalFr(cf.coeff)
)
proc printCoeffs*(cfs: seq[Coeff]) =
for cf in cfs: printCoeff(cf)
#-------------------------------------------------------------------------------