proving and verifying _finally_ works

This commit is contained in:
Balazs Komuves 2023-11-11 13:35:13 +01:00
parent e893d37b43
commit f094de8df3
16 changed files with 759 additions and 180 deletions

1
.gitignore vendored
View File

@ -2,4 +2,5 @@
_bck*
tmp
main
main.nim
*.json

View File

@ -1,11 +1,24 @@
Groth16 prover in Nim
---------------------
Groth16 prover written in Nim
-----------------------------
This is Groth16 prover implementation in Nim, using the
[`constantine`](https://github.com/mratsim/constantine)
library as an arithmetic / curve backend.
The implementation should be compatible with the `circom` ecosystem.
The implementation is compatible with the `circom` ecosystem.
At the moment only the `BN254` (aka. `alt-bn128`) curve is supported.
### TODO
- [ ] make it a nimble package
- [ ] proper MSM implementation (I couldn't make constantine's one to work)
- [ ] proper polynomial implemention (constantine's one is essentially missing)
- [ ] compare `.r1cs` to the "coeffs" section of `.zkey`
- [ ] make it work for different curves
- [ ] multithreaded support (MSM, and possibly also FFT)
- [ ] add Groth16 notes
- [ ] document the `snarkjs` circuit-specific setup `H` points convention

172
bn128.nim
View File

@ -1,6 +1,19 @@
#
# the `alt-bn128` elliptic curve
#
# See for example <https://hackmd.io/@jpw/bn254>
#
# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
#
# equation: y^2 = x^3 + 3
#
import sugar
import std/bitops
import std/strutils
import std/sequtils
import std/streams
import std/random
@ -14,8 +27,8 @@ import constantine/math/extension_fields/towers as ext
import constantine/math/elliptic/ec_shortweierstrass_affine as aff
import constantine/math/elliptic/ec_shortweierstrass_projective as prj
import constantine/math/pairings/pairings_bn as ate
import constantine/math/elliptic/ec_multi_scalar_mul as msm
import constantine/math/elliptic/ec_scalar_mul as scl
# import constantine/math/elliptic/ec_multi_scalar_mul as msm
#-------------------------------------------------------------------------------
@ -118,6 +131,12 @@ func smallPowFr*(base: Fr, expo: uint): Fr =
square(s)
return a
func smallPowFr*(base: Fr, expo: int): Fr =
if expo >= 0:
return smallPowFr( base, uint(expo) )
else:
return smallPowFr( invFr(base) , uint(-expo) )
#-------------------------------------------------------------------------------
func toDecimalBig*[n](a : BigInt[n]): string =
@ -140,11 +159,59 @@ func toDecimalFr*(a : Fr): string =
#-------------------------------------------------------------------------------
proc debugPrintSeqFr*(msg: string, xs: seq[Fr]) =
proc debugPrintFp*(prefix: string, x: Fp) =
echo(prefix & toDecimalFp(x))
proc debugPrintFp2*(prefix: string, z: Fp2) =
echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0]))
echo(prefix & " u ~> " & toDecimalFp(z.coords[1]))
proc debugPrintFr*(prefix: string, x: Fr) =
echo(prefix & toDecimalFr(x))
proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) =
echo "---------------------"
echo msg
for x in xs:
echo(" " & toDecimalFr(x))
debugPrintFr( " " , x )
proc debugPrintG1*(msg: string, pt: G1) =
echo(msg & ":")
debugPrintFp( " x = ", pt.x )
debugPrintFp( " y = ", pt.y )
proc debugPrintG2*(msg: string, pt: G2) =
echo(msg & ":")
debugPrintFp2( " x = ", pt.x )
debugPrintFp2( " y = ", pt.y )
#-------------------------------------------------------------------------------
# Montgomery batch inversion
func batchInverse*( xs: seq[Fr] ) : seq[Fr] =
let n = xs.len
assert(n>0)
var us : seq[Fr] = newSeq[Fr](n+1)
var a = xs[0]
us[0] = oneFr
us[1] = a
for i in 1..<n: ( a *= xs[i] ; us[i+1] = a )
var vs : seq[Fr] = newSeq[Fr](n)
vs[n-1] = invFr( us[n] )
for i in countdown(n-2,0): vs[i] = vs[i+1] * xs[i+1]
return collect( newSeq, (for i in 0..<n: us[i]*vs[i] ) )
proc sanityCheckBatchInverse*() =
let xs : seq[Fr] = map( toSeq(101..137) , intToFr )
let ys = batchInverse( xs )
let zs = collect( newSeq, (for x in xs: invFr(x)) )
let n = xs.len
# for i in 0..<n: echo(i," | batch = ",toDecimalFr(ys[i])," | ref = ",toDecimalFr(zs[i]) )
for i in 0..<n:
if not bool(ys[i] == zs[i]):
echo "batch inverse test FAILED!"
return
echo "batch iverse test OK."
#-------------------------------------------------------------------------------
# random values
@ -321,9 +388,10 @@ proc checkMontgomeryConstants*() =
#---------------------------------------
# the binary files used by the `circom` ecosystem always use little-endian
# Montgomery representation. So when we unmarshal with Constantine, it will
# give the wrong result. Calling this function on the result fixes that.
# the binary files used by the `circom` ecosystem (EXCEPT the witness file!)
# always use little-endian Montgomery representation. So when we unmarshal
# with Constantine, it will give the wrong result. Calling this function on
# the result fixes that.
func fromMontgomeryFp*(x : Fp) : Fp =
var y : Fp = x;
y *= fpInvMontR
@ -334,19 +402,41 @@ func fromMontgomeryFr*(x : Fr) : Fr =
y *= frInvMontR
return y
func toMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
y *= frMontR
return y
#-------------------------------------------------------------------------------
# Unmarshalling field elements
# (note: circom binary files use little-endian Montgomery representation)
# Except, in witness files, where the standard representation is used
# And, EXCEPT in the zkey coefficients, where apparently DOUBLE Montgomery encoding is used ???
#
func unmarshalFp* ( bs: array[32,byte] ) : Fp =
# WTF Jordi, go home you are drunk
func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return fromMontgomeryFr(fromMontgomeryFr(x))
func unmarshalFrStd* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return x
func unmarshalFpMont* ( bs: array[32,byte] ) : Fp =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fp
x.fromBig( big )
return fromMontgomeryFp(x)
func unmarshalFr* ( bs: array[32,byte] ) : Fr =
func unmarshalFrMont* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
@ -355,65 +445,85 @@ func unmarshalFr* ( bs: array[32,byte] ) : Fr =
#-------------------------------------------------------------------------------
func unmarshalFpSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
var vals : seq[Fp] = newSeq[Fp]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFp( bytes )
vals[i] = unmarshalFpMont( bytes )
return vals
func unmarshalFrSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
var vals : seq[Fr] = newSeq[Fr]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFr( bytes )
vals[i] = unmarshalFrMont( bytes )
return vals
#-------------------------------------------------------------------------------
proc loadValueFp*( stream: Stream ) : Fp =
proc loadValueFrWTF*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
# for i in 0..<32: stdout.write(" " & toHex(bytes[i]))
# echo("")
assert( n == 32 )
return unmarshalFrWTF(bytes)
proc loadValueFrStd*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFp(bytes)
return unmarshalFrStd(bytes)
proc loadValueFr*( stream: Stream ) : Fr =
proc loadValueFrMont*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFr(bytes)
return unmarshalFrMont(bytes)
proc loadValueFp2*( stream: Stream ) : Fp2 =
let i = loadValueFp( stream )
let u = loadValueFp( stream )
proc loadValueFpMont*( stream: Stream ) : Fp =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFpMont(bytes)
proc loadValueFp2Mont*( stream: Stream ) : Fp2 =
let i = loadValueFpMont( stream )
let u = loadValueFpMont( stream )
return mkFp2(i,u)
#---------------------------------------
proc loadValuesFp*( len: int, stream: Stream ) : seq[Fp] =
var values : seq[Fp]
for i in 1..len:
values.add( loadValueFp(stream) )
return values
proc loadValuesFr*( len: int, stream: Stream ) : seq[Fr] =
proc loadValuesFrStd*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFr(stream) )
values.add( loadValueFrStd(stream) )
return values
proc loadValuesFpMont*( len: int, stream: Stream ) : seq[Fp] =
var values : seq[Fp]
for i in 1..len:
values.add( loadValueFpMont(stream) )
return values
proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFrMont(stream) )
return values
#-------------------------------------------------------------------------------
proc loadPointG1*( stream: Stream ) : G1 =
let x = loadValueFp( stream )
let y = loadValueFp( stream )
let x = loadValueFpMont( stream )
let y = loadValueFpMont( stream )
return mkG1(x,y)
proc loadPointG2*( stream: Stream ) : G2 =
let x = loadValueFp2( stream )
let y = loadValueFp2( stream )
let x = loadValueFp2Mont( stream )
let y = loadValueFp2Mont( stream )
return mkG2(x,y)
#---------------------------------------

View File

@ -67,8 +67,6 @@ proc readSection[T] ( expectedMagic: string
let sectLen = int( stream.readUint64() )
let oldpos = stream.getPosition()
if filt(sectId):
# echo("section id = ",sectId )
# echo("section len = ",sectLen)
callback(stream, sectId, sectLen, user)
stream.setPosition(oldpos + sectLen)
@ -89,7 +87,7 @@ proc parseContainer*[T] ( expectedMagic: string
let version = stream.readUint32()
assert( version == uint32(expectedVersion) , "not a version " & ($expectedVersion) & " `" & expectedMagic & "` file" )
let nsections = stream.readUint32()
echo("number of sections = ",nsections)
# echo("number of sections = ",nsections)
for i in 1..nsections:
readSection(expectedMagic, expectedVersion, stream, user, callback, filt)

View File

@ -3,9 +3,9 @@
# power-of-two sized multiplicative FFT domains in the scalar field
#
import constantine/math/io/io_bigints
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields except Fp,Fr
#import constantine/math/io/io_bigints
import ./bn128
import ./misc
@ -31,8 +31,8 @@ func createDomain*(size: int): Domain =
let gen : Fr = smallPowFr(gen28, expo)
let halfSize = size div 2
let a : Fr = smallPowFr(gen, uint(size ))
let b : Fr = smallPowFr(gen, uint(halfSize))
let a : Fr = smallPowFr(gen, size )
let b : Fr = smallPowFr(gen, halfSize)
assert( bool(a == oneFr) , "domain generator sanity check /A" )
assert( not bool(b == oneFr) , "domain generator sanity check /B" )

95
export_json.nim Normal file
View File

@ -0,0 +1,95 @@
#
# export proof and public input in `circom`-compatible JSON files
#
import std/sequtils
import constantine/math/arithmetic except Fp, Fr
#import constantine/math/io/io_fields except Fp, Fr
import bn128
from ./groth16 import Proof
#-------------------------------------------------------------------------------
func toQuotedDecimalFp(x: Fp): string =
let s : string = toDecimalFp(x)
return ("\"" & s & "\"")
func toQuotedDecimalFr(x: Fr): string =
let s : string = toDecimalFr(x)
return ("\"" & s & "\"")
#-------------------------------------------------------------------------------
# exports the public input/output into as a JSON file
proc exportPublicIO*( fpath: string, prf: Proof ) =
# debugPrintFrSeq("public IO",prf.publicIO)
let n : int = prf.publicIO.len
assert( n > 0 )
assert( bool(prf.publicIO[0] == oneFr) )
let f = open(fpath, fmWrite)
defer: f.close()
for i in 1..<n:
let str : string = toQuotedDecimalFr( prf.publicIO[i] )
if i==1:
f.writeLine("[ " & str)
else:
f.writeLine(", " & str)
f.writeLine("] ")
#-------------------------------------------------------------------------------
proc writeFp2( f: File, c: char, z: Fp2 ) =
let prefix = " " & c & " "
let indent = " "
f.writeLine( prefix & "[ " & toQuotedDecimalFp( z.coords[0] ) )
f.writeLine( indent & ", " & toQuotedDecimalFp( z.coords[1] ) )
f.writeLine( indent & "]")
proc writeG1( f: File, p: G1 ) =
f.writeLine(" [ " & toQuotedDecimalFp( p.x ) )
f.writeLine(" , " & toQuotedDecimalFp( p.y ) )
f.writeLine(" , " & toQuotedDecimalFp( oneFp ) )
f.writeLine(" ]")
proc writeG2( f: File, p: G2 ) =
writeFp2( f , '[' , p.x )
writeFp2( f , ',' , p.y )
writeFp2( f , ',' , oneFp2 )
f.writeLine(" ]")
# exports the proof into as a JSON file
proc exportProof*( fpath: string, prf: Proof ) =
let f = open(fpath, fmWrite)
defer: f.close()
f.writeLine("{ \"protocol\": \"groth16\"")
f.writeLine(", \"curve\": \"bn128\"" )
f.writeLine(", \"pi_a\":" ) ; writeG1( f, prf.pi_a )
f.writeLine(", \"pi_b\":" ) ; writeG2( f, prf.pi_b )
f.writeLine(", \"pi_c\":" ) ; writeG1( f, prf.pi_c )
f.writeLine("}")
#-------------------------------------------------------------------------------
func getFakeProof*() : Proof =
let pub : seq[Fr] = map( [1,101,102,103,117,119] , intToFr )
let p = unsafeMkG1( intToFp(666) , intToFp(777) )
let r = unsafeMkG1( intToFp(888) , intToFp(999) )
let x = mkFp2( intToFp(22) , intToFp(33) )
let y = mkFp2( intToFp(44) , intToFp(55) )
let q = unsafeMkG2( x , y )
return Proof( publicIO:pub, pi_a:p, pi_b:q, pi_c:r )
proc exportFakeProof*() =
let prf = getFakeProof()
exportPublicIO( "fake_pub.json" , prf )
exportProof( "fake_prf.json" , prf )

236
groth16.nim Normal file
View File

@ -0,0 +1,236 @@
#
# Groth16 prover
#
# WARNING! the points H are *NOT* what normal people would think they are
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
#[]
import sugar
import constantine/math/config/curves
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import ./zkey
]#
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_extfields except Fp12
import constantine/math/extension_fields/towers except Fp2, Fp12
import ./bn128
import ./domain
import ./poly
import ./zkey_types
import ./witness
#-------------------------------------------------------------------------------
type
Proof* = object
publicIO* : seq[Fr]
pi_a* : G1
pi_b* : G2
pi_c* : G1
curve : string
#-------------------------------------------------------------------------------
# the verifier
#
proc verifyProof* (vkey: VKey, prf: Proof): bool =
assert( prf.curve == "bn128" )
assert( isOnCurveG1(prf.pi_a) , "pi_a is not in G1" )
assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" )
assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" )
var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC )
let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b >
let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta >
let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta >
let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma >
var eq : Fp12
eq = lhs
eq *= rhs1
eq *= rhs2
eq *= rhs3
return bool(isOne(eq))
#-------------------------------------------------------------------------------
# A, B, C column vectors
#
type
ABC = object
valuesA : seq[Fr]
valuesB : seq[Fr]
valuesC : seq[Fr]
func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC =
let hdr: GrothHeader = zkey.header
let domSize = hdr.domainSize
var valuesA : seq[Fr] = newSeq[Fr](domSize)
var valuesB : seq[Fr] = newSeq[Fr](domSize)
for entry in zkey.coeffs:
case entry.matrix
of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col]
of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col]
else: raise newException(AssertionDefect, "fatal error")
var valuesC : seq[Fr] = newSeq[Fr](domSize)
for i in 0..<hdr.nvars:
valuesC[i] = valuesA[i] * valuesB[i]
return ABC( valuesA:valuesA, valuesB:valuesB, valuesC:valuesC )
#-------------------------------------------------------------------------------
# quotient poly
#
# interpolates A,B,C, and computes the quotient polynomial Q = (A*B - C) / Z
func computeQuotientNaive( abc: ABC ): Poly=
let n = abc.valuesA.len
assert( abc.valuesB.len == n )
assert( abc.valuesC.len == n )
let D = createDomain(n)
let polyA : Poly = polyInverseNTT( abc.valuesA , D )
let polyB : Poly = polyInverseNTT( abc.valuesB , D )
let polyC : Poly = polyInverseNTT( abc.valuesC , D )
let polyBig = polyMulFFT( polyA , polyB ) - polyC
var polyQ = polyDivideByVanishing(polyBig, D.domainSize)
polyQ.coeffs.add( zeroFr ) # make it a power of two
return polyQ
#---------------------------------------
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
func multiplyByPowers( xs: seq[Fr], eta: Fr ): seq[Fr] =
let n = xs.len
assert(n >= 1)
var ys : seq[Fr] = newSeq[Fr](n)
ys[0] = xs[0]
if n >= 1: ys[1] = eta * xs[1]
var spow : Fr = eta
for i in 2..<n:
spow *= eta
ys[i] = spow * xs[i]
return ys
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
func shiftEvalDomain( values: seq[Fr], D: Domain, eta: Fr ): seq[Fr] =
let poly : Poly = polyInverseNTT( values , D )
let cs : seq[Fr] = poly.coeffs
var ds : seq[Fr] = multiplyByPowers( cs, eta )
return polyForwardNTT( Poly(coeffs:ds), D )
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
func computeQuotientPointwise( abc: ABC ): Poly =
let n = abc.valuesA.len
let D = createDomain(n)
# (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen
let Inv1 = invFr( smallPowFr(eta,n) - oneFr )
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * Inv1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
return Poly(coeffs: cs)
#---------------------------------------
# Snarkjs does something different, not actually computing the quotient poly
# they can get away with this, because during the trusted setup, they
# transform the H points into (shifted??) Lagrange bases (?)
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] =
let n = abc.valuesA.len
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] )
return ys
#-------------------------------------------------------------------------------
# the prover
#
proc generateProof* ( zkey: ZKey, wtns: Witness ): Proof =
assert( zkey.header.curve == wtns.curve )
let witness = wtns.values
let hdr : GrothHeader = zkey.header
let spec : SpecPoints = zkey.specPoints
let pts : ProverPoints = zkey.pPoints
let nvars = hdr.nvars
let npubs = hdr.npubs
assert( nvars == witness.len , "wrong witness length" )
var pubIO : seq[Fr] = newSeq[Fr]( npubs + 1)
for i in 0..npubs: pubIO[i] = witness[i]
var abc : ABC = buildABC( zkey, witness )
# let polyQ1 = computeQuotientNaive( abc )
# let polyQ2 = computeQuotientPointwise( abc )
let qs = computeSnarkjsScalarCoeffs( abc )
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
for j in npubs+1..<nvars:
zs[j-npubs-1] = witness[j]
# masking coeffs
let r : Fr = randFr()
let s : Fr = randFr()
# let r : Fr = intToFr(3)
# let s : Fr = intToFr(4)
var pi_a : G1
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmG1( witness , pts.pointsA1 )
var rho : G1
rho = spec.beta1
rho += s ** spec.delta1
rho += msmG1( witness , pts.pointsB1 )
var pi_b : G2
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmG2( witness , pts.pointsB2 )
var pi_c : G1
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmG1( qs , pts.pointsH1 )
pi_c += msmG1( zs , pts.pointsC1 )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------

View File

@ -1,19 +0,0 @@
import ./r1cs
import ./zkey
import ./witness
import ./bn128
#-------------------------------------------------------------------------------
proc testMain() =
# checkMontgomeryConstants()
let r1cs_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain.r1cs"
let zkey_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain.zkey"
let wtns_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain_witness.wtns"
parseWitness( wtns_fname)
parseR1CS( r1cs_fname)
parseZKey( zkey_fname)
when isMainModule:
testMain()

View File

@ -1,4 +1,7 @@
#
# miscellaneous routines
#
#-------------------------------------------------------------------------------
@ -16,13 +19,26 @@ func ceilingLog2* (x : int) : int =
else:
return (floorLog2(x-1) + 1)
#
# import std/math
#
# proc sanityCheckLog2* () =
# for i in 0..18:
# let x = float64(i)
# echo( i," | ",floorLog2(i),"=",floor(log2(x))," | ",ceilingLog2(i),"=",ceil(log2(x)) )
#
#-------------------
#[
import std/math
proc sanityCheckLog2* () =
for i in 0..18:
let x = float64(i)
echo( i," | ",floorLog2(i),"=",floor(log2(x))," | ",ceilingLog2(i),"=",ceil(log2(x)) )
]#
#-------------------------------------------------------------------------------
#[
func rotateSeq[T](xs: seq[T], ofs: int): seq[T] =
let n = xs.len
var ys : seq[T]
for i in (0..<n):
ys.add( xs[ (i+n+ofs) mod n ] )
return ys
]#
#-------------------------------------------------------------------------------

18
ntt.nim
View File

@ -1,9 +1,10 @@
#
# Number-theoretic transform
# Number-theoretic transform
# (that is, FFT for polynomials over finite fields)
#
import std/sequtils
#-------------------------------------------------------------------------------
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields
@ -68,6 +69,19 @@ func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
, tgt , 0 )
return tgt
# pads the input with zeros to get a pwoer of two size
func extendAndForwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
let n = src.len
let N = D.domainSize
assert( n <= N )
if n == N:
return forwardNTT(src, D)
else:
var padded : seq[Fr] = newSeq[Fr]( N )
for i in 0..<n: padded[i] = src[i]
# for i in n..<N: padded[i] = zeroFr
return forwardNTT(padded, D)
#-------------------------------------------------------------------------------
const oneHalfFr* : Fr = fromHex(Fr, "0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000001")

109
poly.nim
View File

@ -17,6 +17,7 @@ import constantine/math/io/io_fields
import bn128
import domain
import ntt
import misc
#-------------------------------------------------------------------------------
@ -74,25 +75,25 @@ func polyNeg*(P: Poly) : Poly =
func polyAdd*(P, Q: Poly) : Poly =
let xs = P.coeffs ; let n = xs.len
let ys = Q.coeffs ; let m = ys.len
var zs : seq[Fr]
var zs : seq[Fr] = newSeq[Fr](max(n,m))
if n >= m:
for i in 0..<m: zs.add( xs[i] + ys[i] )
for i in m..<n: zs.add( xs[i] )
for i in 0..<m: zs[i] = ( xs[i] + ys[i] )
for i in m..<n: zs[i] = ( xs[i] )
else:
for i in 0..<n: zs.add( xs[i] + ys[i] )
for i in n..<m: zs.add( ys[i] )
for i in 0..<n: zs[i] = ( xs[i] + ys[i] )
for i in n..<m: zs[i] = ( ys[i] )
return Poly(coeffs: zs)
func polySub*(P, Q: Poly) : Poly =
let xs = P.coeffs ; let n = xs.len
let ys = Q.coeffs ; let m = ys.len
var zs : seq[Fr]
var zs : seq[Fr] = newSeq[Fr](max(n,m))
if n >= m:
for i in 0..<m: zs.add( xs[i] - ys[i] )
for i in m..<n: zs.add( xs[i] )
for i in 0..<m: zs[i] = ( xs[i] - ys[i] )
for i in m..<n: zs[i] = ( xs[i] )
else:
for i in 0..<n: zs.add( xs[i] + ys[i] )
for i in n..<m: zs.add( zeroFr - ys[i] )
for i in 0..<n: zs[i] = ( xs[i] + ys[i] )
for i in n..<m: zs[i] = ( negFr( ys[i] ))
return Poly(coeffs: zs)
#-------------------------------------------------------------------------------
@ -122,7 +123,28 @@ func polyMulNaive*(P, Q : Poly): Poly =
zs[k] += xs[i] * ys[j]
return Poly(coeffs: zs)
#-------------------------------------------------------------------------------
# multiply two polynomials using FFT
func polyMulFFT*(P, Q: Poly): Poly =
let n1 = P.coeffs.len
let n2 = Q.coeffs.len
let log2 : int = max( ceilingLog2(n1) , ceilingLog2(n2) ) + 1
let N : int = (1 shl log2)
let D : Domain = createDomain( N )
let us : seq[Fr] = extendAndForwardNTT( P.coeffs, D )
let vs : seq[Fr] = extendAndForwardNTT( Q.coeffs, D )
let zs : seq[Fr] = collect( newSeq, (for i in 0..<N: us[i]*vs[i] ))
let ws : seq[Fr] = inverseNTT( zs, D )
return Poly(coeffs: ws)
#-------------------------------------------------------------------------------
func polyMul*(P, Q : Poly): Poly =
# return polyMulFFT(P, Q)
return polyMulNaive(P, Q)
#-------------------------------------------------------------------------------
@ -138,14 +160,20 @@ func `*`*(P: Poly, s: Fr ): Poly = return polyScale(s, P)
#-------------------------------------------------------------------------------
# the vanishing polynomial `(x^N - 1)`
func vanishingPoly*(N: int): Poly =
# the generalized vanishing polynomial `(a*x^N - b)`
func generalizedVanishingPoly*(N: int, a: Fr, b: Fr): Poly =
assert( N>=1 )
var cs : seq[Fr] = newSeq[Fr]( N+1 )
cs[0] = negFr( oneFr )
cs[N] = oneFr
cs[0] = negFr(b)
cs[N] = a
return Poly(coeffs: cs)
# the vanishing polynomial `(x^N - 1)`
func vanishingPoly*(N: int): Poly =
return generalizedVanishingPoly(N, oneFr, oneFr)
#-------------------------------------------------------------------------------
type
QuotRem*[T] = object
quot* : T
@ -164,13 +192,15 @@ func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] =
rem = src
else:
# compute quot
# compute quotient
for j in countdown(deg-N, 0):
if j+N <= deg-N:
quot[j] = src[j+N] + quot[j+N]
else:
quot[j] = src[j+N]
# compute rem
# compute remainder
for j in 0..<N:
if j <= deg-N:
rem[j] = src[j] + quot[j]
@ -191,14 +221,8 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly =
func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] =
let n = P.coeffs.len
assert( n <= D.domainSize , "the domain must be as least as big as the polynomial" )
if n == D.domainSize:
let src : seq[Fr] = P.coeffs
return forwardNTT(src, D)
else:
var src : seq[Fr] = P.coeffs
for i in n..<D.domainSize: src.add( zeroFr )
return forwardNTT(src, D)
let src : seq[Fr] = P.coeffs
return forwardNTT(src, D)
#---------------------------------------
@ -211,6 +235,8 @@ func polyInverseNTT*(ys: seq[Fr], D: Domain): Poly =
#-------------------------------------------------------------------------------
#[
proc sanityCheckOneHalf*() =
let two = oneFr + oneFr
let invTwo = oneHalfFr
@ -224,20 +250,20 @@ proc sanityCheckVanishing*() =
let P : Poly = Poly( coeffs:cs )
echo("degree of P = ",polyDegree(P))
debugPrintSeqFr("xs", P.coeffs)
debugPrintFrSeq("xs", P.coeffs)
let n : int = 5
let QR = polyQuotRemByVanishing(P, n)
let Q = QR.quot
let R = QR.rem
debugPrintSeqFr("Q", Q.coeffs)
debugPrintSeqFr("R", R.coeffs)
debugPrintFrSeq("Q", Q.coeffs)
debugPrintFrSeq("R", R.coeffs)
let Z : Poly = vanishingPoly(n)
let S : Poly = Q * Z + R
debugPrintSeqFr("zs", S.coeffs)
debugPrintFrSeq("zs", S.coeffs)
echo( polyIsEqual(P,S) )
proc sanityCheckNTT*() =
@ -249,9 +275,28 @@ proc sanityCheckNTT*() =
let ys : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(P,x)) )
let zs : seq[Fr] = polyForwardNTT(P ,D)
let Q : Poly = polyInverseNTT(zs,D)
debugPrintSeqFr("xs", xs)
debugPrintSeqFr("ys", ys)
debugPrintSeqFr("zs", zs)
debugPrintSeqFr("us", Q.coeffs)
debugPrintFrSeq("xs", xs)
debugPrintFrSeq("ys", ys)
debugPrintFrSeq("zs", zs)
debugPrintFrSeq("us", Q.coeffs)
proc sanityCheckMulFFT*() =
var js : seq[int] = toSeq(101..110)
let cs : seq[Fr] = map( js, intToFr )
let P : Poly = Poly( coeffs:cs )
var ks : seq[int] = toSeq(1001..1020)
let ds : seq[Fr] = map( ks, intToFr )
let Q : Poly = Poly( coeffs:ds )
let R1 : Poly = polyMulNaive( P , Q )
let R2 : Poly = polyMulFFT( P , Q )
# debugPrintFrSeq("naive coeffs", R1.coeffs)
# debugPrintFrSeq("fft coeffs", R2.coeffs)
echo( "multiply test = ", polyIsEqual(R1,R2) )
]#
#-------------------------------------------------------------------------------

View File

@ -82,12 +82,13 @@ type
#-------------------------------------------------------------------------------
proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) =
echo "\nparsing r1cs header"
# echo "\nparsing r1cs header"
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
echo("r = ",toDecimalBig(r))
user.r = r;
# echo("r = ",toDecimalBig(r))
assert( sectionLen == 4 + n8r + 16 + 8 + 4, "unexpected section length")
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
@ -104,14 +105,14 @@ proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) =
let nConstr = int( stream.readUint32() )
user.nConstr = nConstr
echo("witness config = ",cfg)
echo("nConstr = ",nConstr)
# echo("witness config = ",cfg)
# echo("nConstr = ",nConstr)
#-------------------------------------------------------------------------------
proc loadTerm( stream: Stream ): Term =
let idx = int( stream.readUint32() )
let coeff = loadValueFr( stream )
let coeff = loadValueFrMont( stream )
return (wireIdx:idx, value:coeff)
proc loadLinComb( stream: Stream ): LinComb =
@ -139,9 +140,9 @@ proc parseSection2_constraints( stream: Stream, user: var R1CS, sectionLen: int
ncoeffsB += abc.B.len
ncoeffsC += abc.C.len
user.constraints = constr
echo( "number of nonzero coefficients in matrix A = ", ncoeffsA )
echo( "number of nonzero coefficients in matrix B = ", ncoeffsB )
echo( "number of nonzero coefficients in matrix C = ", ncoeffsC )
# echo( "number of nonzero coefficients in matrix A = ", ncoeffsA )
# echo( "number of nonzero coefficients in matrix B = ", ncoeffsB )
# echo( "number of nonzero coefficients in matrix C = ", ncoeffsC )
#-------------------------------------------------------------------------------
@ -166,9 +167,10 @@ proc r1csCallback( stream: Stream
of 3: parseSection3_wireToLabel( stream, user, sectLen )
else: discard
proc parseR1CS* (fname: string) =
proc parseR1CS* (fname: string): R1CS =
var r1cs : R1CS
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id == 1 )
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id != 1 )
return r1cs
#-------------------------------------------------------------------------------

28
test_proof.nim Normal file
View File

@ -0,0 +1,28 @@
import ./groth16
import ./export_json
import ./witness
import ./zkey
import ./zkey_types
#-------------------------------------------------------------------------------
proc testProveAndVerify*( zkey_fname, wtns_fname: string) =
echo("parsing witness & zkey files...")
let witness = parseWitness( wtns_fname)
let zkey = parseZKey( zkey_fname)
echo("generating proof...")
let vkey = extractVKey( zkey)
let proof = generateProof( zkey, witness )
echo("exporting proof...")
exportPublicIO( "my_pub.json" , proof )
exportProof( "my_prf.json" , proof )
echo("verifying the proof...")
let ok = verifyProof( vkey, proof)
echo("verification succeeded = ",ok)
#-------------------------------------------------------------------------------

View File

@ -11,6 +11,10 @@
#
# nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret
#
# WARNING! unlike the `.r1cs` and `.zkey` files, which encode field elements
# in Montgomery representation, the `.wtns` file encode field elements in
# the standard representation!!
#
import std/streams
@ -24,35 +28,38 @@ import ./container
type
Witness* = object
r : BigInt[256]
nvars : int
values : seq[Fr]
curve* : string
r* : BigInt[256]
nvars* : int
values* : seq[Fr]
#-------------------------------------------------------------------------------
proc parseSection1_header( stream: Stream, user: var Witness, sectionLen: int ) =
echo "\nparsing witness header"
# echo "\nparsing witness header"
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
user.r = r;
echo("r = ",toDecimalBig(r))
# echo("r = ",toDecimalBig(r))
assert( sectionLen == 4 + n8r + 4 , "unexpected section length")
assert( n8r == 32 , "expecting 256 bit prime" )
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
user.curve = "bn128"
let nvars = int( stream.readUint32() )
user.nvars = nvars;
echo("nvars = ",nvars)
# echo("nvars = ",nvars)
#-------------------------------------------------------------------------------
proc parseSection2_witness( stream: Stream, user: var Witness, sectionLen: int ) =
assert( sectionLen == 32 * user.nvars )
user.values = loadValuesFr( user.nvars, stream )
user.values = loadValuesFrStd( user.nvars, stream )
#-------------------------------------------------------------------------------
@ -63,11 +70,11 @@ proc wtnsCallback(stream: Stream, sectId: int, sectLen: int, user: var Witness)
of 2: parseSection2_witness( stream, user, sectLen )
else: discard
proc parseWitness* (fname: string) =
proc parseWitness* (fname: string): Witness =
var wtns : Witness
parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id == 1 )
parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id != 1 )
return wtns
#-------------------------------------------------------------------------------

View File

@ -52,6 +52,8 @@
# compute (C*witness)[i] = (A*witness)[i] * (B*witness)[i]
# These 3 column vectors is all we need in the proof generation.
#
# WARNING! It appears that the values here are *doubly Montgomery encoded* (?!)
#
# 5: PointsA
# ----------
# the curve points [A_j(tau)]_1 in G1
@ -74,7 +76,9 @@
#
# 9: PointsH
# ----------
# the curve points [delta^-1 * tau^i * Z(tau)]
# what normally should be the curve points [delta^-1 * tau^i * Z(tau)]
# HOWEVER, they are NOT! (??)
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
# length = 2 * n8p * domSize = domSize G1 points
#
# 10: Contributions
@ -82,64 +86,20 @@
# ??? (but not required for proving, only for checking that the `.zkey` file is valid)
#
#-------------------------------------------------------------------------------
import std/streams
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
#import constantine/math/io/io_bigints
import ./bn128
import ./zkey_types
import ./container
import ./misc
#-------------------------------------------------------------------------------
type
GrothHeader* = object
p : BigInt[256]
r : BigInt[256]
nvars : int
npubs : int
domainSize : int
logDomainSize : int
SpecPoints* = object
alpha1 : G1
beta1 : G1
beta2 : G2
gamma2 : G2
delta1 : G1
delta2 : G2
VerifierPoints* = object
pointsIC : seq[G1]
ProverPoints* = object
pointsA1 : seq[G1]
pointsB1 : seq[G1]
pointsB2 : seq[G2]
pointsC1 : seq[G1]
pointsH1 : seq[G1]
MatrixSel* = enum
MatrixA
MatrixB
MatrixC
Coeff* = object
matrix : MatrixSel
row : int
col : int
coeff : Fr
ZKey* = object
sectionMask : uint32
header : GrothHeader
specPoints : SpecPoints
vPoints : VerifierPoints
pPoints : ProverPoints
coeffs : seq[Coeff]
proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int ) =
assert( sectionLen == 4 , "unexpected section length" )
let proverType = stream.readUint32
@ -148,13 +108,13 @@ proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int
#-------------------------------------------------------------------------------
proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int ) =
echo "\nparsing the Groth16 zkey header"
# echo "\nparsing the Groth16 zkey header"
let (n8p, p) = parsePrimeField( stream ) # size of the base field
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
echo("p = ",toDecimalBig(p))
echo("r = ",toDecimalBig(r))
# echo("p = ",toDecimalBig(p))
# echo("r = ",toDecimalBig(r))
assert( sectionLen == 2*4 + n8p + n8r + 3*4 + 3*64 + 3*128 , "unexpected section length" )
@ -167,6 +127,7 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
assert( bool(p == primeP) , "expecting the alt-bn128 curve" )
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
header.curve = "bn128"
let nvars = int( stream.readUint32() )
let npubs = int( stream.readUint32() )
@ -175,9 +136,9 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
assert( (1 shl log2siz) == domsiz , "domain size should be a power of two" )
echo("nvars = ",nvars)
echo("npubs = ",npubs)
echo("domsiz = ",domsiz)
# echo("nvars = ",nvars)
# echo("npubs = ",npubs)
# echo("domsiz = ",domsiz)
header.nvars = nvars
header.npubs = npubs
@ -194,6 +155,7 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
spec.gamma2 = loadPointG2( stream )
spec.delta1 = loadPointG1( stream )
spec.delta2 = loadPointG2( stream )
spec.alphaBeta = pairing( spec.alpha1, spec.beta2 )
user.specPoints = spec
#-------------------------------------------------------------------------------
@ -206,9 +168,9 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) =
var coeffs : seq[Coeff]
for i in 1..ncoeffs:
let m = int( stream.readUint32() )
let r = int( stream.readUint32() )
let c = int( stream.readUint32() )
let m = int( stream.readUint32() ) # which matrix
let r = int( stream.readUint32() ) # row (equation index)
let c = int( stream.readUint32() ) # column (witness index)
assert( m >= 0 and m <= 2 , "invalid matrix selector" )
let sel : MatrixSel = case m
of 0: MatrixA
@ -217,7 +179,7 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) =
else: raise newException(AssertionDefect, "fatal error")
assert( r >= 0 and r < nrows, "row index out of range" )
assert( c >= 0 and c < ncols, "column index out of range" )
let cf = loadValueFr( stream )
let cf = loadValueFrWTF( stream ) # Jordi, WTF is this encoding ?!?!?!!111
let entry = Coeff( matrix:sel, row:r, col:c, coeff:cf )
coeffs.add( entry )
@ -270,10 +232,11 @@ proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) =
of 9: parseSection9_PointsH1( stream, user, sectLen )
else: discard
proc parseZKey* (fname: string) =
proc parseZKey* (fname: string): ZKey =
var zkey : ZKey
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 1 )
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 2 )
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id >= 3 )
return zkey
#-------------------------------------------------------------------------------

70
zkey_types.nim Normal file
View File

@ -0,0 +1,70 @@
import constantine/math/arithmetic except Fp, Fr
import ./bn128
#-------------------------------------------------------------------------------
type
GrothHeader* = object
curve* : string
p* : BigInt[256]
r* : BigInt[256]
nvars* : int
npubs* : int
domainSize* : int
logDomainSize* : int
SpecPoints* = object
alpha1* : G1
beta1* : G1
beta2* : G2
gamma2* : G2
delta1* : G1
delta2* : G2
alphaBeta* : Fp12 # = <alpha1,beta2>
VerifierPoints* = object
pointsIC* : seq[G1]
ProverPoints* = object
pointsA1* : seq[G1]
pointsB1* : seq[G1]
pointsB2* : seq[G2]
pointsC1* : seq[G1]
pointsH1* : seq[G1]
MatrixSel* = enum
MatrixA
MatrixB
MatrixC
Coeff* = object
matrix* : MatrixSel
row* : int
col* : int
coeff* : Fr
ZKey* = object
# sectionMask* : uint32
header* : GrothHeader
specPoints* : SpecPoints
vPoints* : VerifierPoints
pPoints* : ProverPoints
coeffs* : seq[Coeff]
VKey* = object
curve* : string
spec* : SpecPoints
vpoints* : VerifierPoints
#-------------------------------------------------------------------------------
func extractVKey*(zkey: Zkey): VKey =
let curve = zkey.header.curve
let spec = zkey.specPoints
let vpts = zkey.vPoints
return VKey(curve:curve, spec:spec, vpoints:vpts)
#-------------------------------------------------------------------------------