mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-01-05 07:03:09 +00:00
proving and verifying _finally_ works
This commit is contained in:
parent
e893d37b43
commit
f094de8df3
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,4 +2,5 @@
|
||||
_bck*
|
||||
tmp
|
||||
main
|
||||
main.nim
|
||||
*.json
|
||||
|
||||
19
README.md
19
README.md
@ -1,11 +1,24 @@
|
||||
|
||||
Groth16 prover in Nim
|
||||
---------------------
|
||||
Groth16 prover written in Nim
|
||||
-----------------------------
|
||||
|
||||
This is Groth16 prover implementation in Nim, using the
|
||||
[`constantine`](https://github.com/mratsim/constantine)
|
||||
library as an arithmetic / curve backend.
|
||||
|
||||
The implementation should be compatible with the `circom` ecosystem.
|
||||
The implementation is compatible with the `circom` ecosystem.
|
||||
|
||||
At the moment only the `BN254` (aka. `alt-bn128`) curve is supported.
|
||||
|
||||
|
||||
### TODO
|
||||
|
||||
- [ ] make it a nimble package
|
||||
- [ ] proper MSM implementation (I couldn't make constantine's one to work)
|
||||
- [ ] proper polynomial implemention (constantine's one is essentially missing)
|
||||
- [ ] compare `.r1cs` to the "coeffs" section of `.zkey`
|
||||
- [ ] make it work for different curves
|
||||
- [ ] multithreaded support (MSM, and possibly also FFT)
|
||||
- [ ] add Groth16 notes
|
||||
- [ ] document the `snarkjs` circuit-specific setup `H` points convention
|
||||
|
||||
|
||||
172
bn128.nim
172
bn128.nim
@ -1,6 +1,19 @@
|
||||
#
|
||||
# the `alt-bn128` elliptic curve
|
||||
#
|
||||
# See for example <https://hackmd.io/@jpw/bn254>
|
||||
#
|
||||
# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
|
||||
# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
|
||||
#
|
||||
# equation: y^2 = x^3 + 3
|
||||
#
|
||||
|
||||
import sugar
|
||||
|
||||
import std/bitops
|
||||
import std/strutils
|
||||
import std/sequtils
|
||||
import std/streams
|
||||
import std/random
|
||||
|
||||
@ -14,8 +27,8 @@ import constantine/math/extension_fields/towers as ext
|
||||
import constantine/math/elliptic/ec_shortweierstrass_affine as aff
|
||||
import constantine/math/elliptic/ec_shortweierstrass_projective as prj
|
||||
import constantine/math/pairings/pairings_bn as ate
|
||||
import constantine/math/elliptic/ec_multi_scalar_mul as msm
|
||||
import constantine/math/elliptic/ec_scalar_mul as scl
|
||||
# import constantine/math/elliptic/ec_multi_scalar_mul as msm
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -118,6 +131,12 @@ func smallPowFr*(base: Fr, expo: uint): Fr =
|
||||
square(s)
|
||||
return a
|
||||
|
||||
func smallPowFr*(base: Fr, expo: int): Fr =
|
||||
if expo >= 0:
|
||||
return smallPowFr( base, uint(expo) )
|
||||
else:
|
||||
return smallPowFr( invFr(base) , uint(-expo) )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func toDecimalBig*[n](a : BigInt[n]): string =
|
||||
@ -140,11 +159,59 @@ func toDecimalFr*(a : Fr): string =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc debugPrintSeqFr*(msg: string, xs: seq[Fr]) =
|
||||
proc debugPrintFp*(prefix: string, x: Fp) =
|
||||
echo(prefix & toDecimalFp(x))
|
||||
|
||||
proc debugPrintFp2*(prefix: string, z: Fp2) =
|
||||
echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0]))
|
||||
echo(prefix & " u ~> " & toDecimalFp(z.coords[1]))
|
||||
|
||||
proc debugPrintFr*(prefix: string, x: Fr) =
|
||||
echo(prefix & toDecimalFr(x))
|
||||
|
||||
proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) =
|
||||
echo "---------------------"
|
||||
echo msg
|
||||
for x in xs:
|
||||
echo(" " & toDecimalFr(x))
|
||||
debugPrintFr( " " , x )
|
||||
|
||||
proc debugPrintG1*(msg: string, pt: G1) =
|
||||
echo(msg & ":")
|
||||
debugPrintFp( " x = ", pt.x )
|
||||
debugPrintFp( " y = ", pt.y )
|
||||
|
||||
proc debugPrintG2*(msg: string, pt: G2) =
|
||||
echo(msg & ":")
|
||||
debugPrintFp2( " x = ", pt.x )
|
||||
debugPrintFp2( " y = ", pt.y )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# Montgomery batch inversion
|
||||
func batchInverse*( xs: seq[Fr] ) : seq[Fr] =
|
||||
let n = xs.len
|
||||
assert(n>0)
|
||||
var us : seq[Fr] = newSeq[Fr](n+1)
|
||||
var a = xs[0]
|
||||
us[0] = oneFr
|
||||
us[1] = a
|
||||
for i in 1..<n: ( a *= xs[i] ; us[i+1] = a )
|
||||
var vs : seq[Fr] = newSeq[Fr](n)
|
||||
vs[n-1] = invFr( us[n] )
|
||||
for i in countdown(n-2,0): vs[i] = vs[i+1] * xs[i+1]
|
||||
return collect( newSeq, (for i in 0..<n: us[i]*vs[i] ) )
|
||||
|
||||
proc sanityCheckBatchInverse*() =
|
||||
let xs : seq[Fr] = map( toSeq(101..137) , intToFr )
|
||||
let ys = batchInverse( xs )
|
||||
let zs = collect( newSeq, (for x in xs: invFr(x)) )
|
||||
let n = xs.len
|
||||
# for i in 0..<n: echo(i," | batch = ",toDecimalFr(ys[i])," | ref = ",toDecimalFr(zs[i]) )
|
||||
for i in 0..<n:
|
||||
if not bool(ys[i] == zs[i]):
|
||||
echo "batch inverse test FAILED!"
|
||||
return
|
||||
echo "batch iverse test OK."
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# random values
|
||||
@ -321,9 +388,10 @@ proc checkMontgomeryConstants*() =
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# the binary files used by the `circom` ecosystem always use little-endian
|
||||
# Montgomery representation. So when we unmarshal with Constantine, it will
|
||||
# give the wrong result. Calling this function on the result fixes that.
|
||||
# the binary files used by the `circom` ecosystem (EXCEPT the witness file!)
|
||||
# always use little-endian Montgomery representation. So when we unmarshal
|
||||
# with Constantine, it will give the wrong result. Calling this function on
|
||||
# the result fixes that.
|
||||
func fromMontgomeryFp*(x : Fp) : Fp =
|
||||
var y : Fp = x;
|
||||
y *= fpInvMontR
|
||||
@ -334,19 +402,41 @@ func fromMontgomeryFr*(x : Fr) : Fr =
|
||||
y *= frInvMontR
|
||||
return y
|
||||
|
||||
func toMontgomeryFr*(x : Fr) : Fr =
|
||||
var y : Fr = x;
|
||||
y *= frMontR
|
||||
return y
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Unmarshalling field elements
|
||||
# (note: circom binary files use little-endian Montgomery representation)
|
||||
# Except, in witness files, where the standard representation is used
|
||||
# And, EXCEPT in the zkey coefficients, where apparently DOUBLE Montgomery encoding is used ???
|
||||
#
|
||||
|
||||
func unmarshalFp* ( bs: array[32,byte] ) : Fp =
|
||||
# WTF Jordi, go home you are drunk
|
||||
func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr =
|
||||
var big : BigInt[254]
|
||||
unmarshal( big, bs, littleEndian );
|
||||
var x : Fr
|
||||
x.fromBig( big )
|
||||
return fromMontgomeryFr(fromMontgomeryFr(x))
|
||||
|
||||
func unmarshalFrStd* ( bs: array[32,byte] ) : Fr =
|
||||
var big : BigInt[254]
|
||||
unmarshal( big, bs, littleEndian );
|
||||
var x : Fr
|
||||
x.fromBig( big )
|
||||
return x
|
||||
|
||||
func unmarshalFpMont* ( bs: array[32,byte] ) : Fp =
|
||||
var big : BigInt[254]
|
||||
unmarshal( big, bs, littleEndian );
|
||||
var x : Fp
|
||||
x.fromBig( big )
|
||||
return fromMontgomeryFp(x)
|
||||
|
||||
func unmarshalFr* ( bs: array[32,byte] ) : Fr =
|
||||
func unmarshalFrMont* ( bs: array[32,byte] ) : Fr =
|
||||
var big : BigInt[254]
|
||||
unmarshal( big, bs, littleEndian );
|
||||
var x : Fr
|
||||
@ -355,65 +445,85 @@ func unmarshalFr* ( bs: array[32,byte] ) : Fr =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func unmarshalFpSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
|
||||
func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
|
||||
var vals : seq[Fp] = newSeq[Fp]( len )
|
||||
var bytes : array[32,byte]
|
||||
for i in 0..<len:
|
||||
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
|
||||
vals[i] = unmarshalFp( bytes )
|
||||
vals[i] = unmarshalFpMont( bytes )
|
||||
return vals
|
||||
|
||||
func unmarshalFrSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
|
||||
func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
|
||||
var vals : seq[Fr] = newSeq[Fr]( len )
|
||||
var bytes : array[32,byte]
|
||||
for i in 0..<len:
|
||||
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
|
||||
vals[i] = unmarshalFr( bytes )
|
||||
vals[i] = unmarshalFrMont( bytes )
|
||||
return vals
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc loadValueFp*( stream: Stream ) : Fp =
|
||||
proc loadValueFrWTF*( stream: Stream ) : Fr =
|
||||
var bytes : array[32,byte]
|
||||
let n = stream.readData( addr(bytes), 32 )
|
||||
# for i in 0..<32: stdout.write(" " & toHex(bytes[i]))
|
||||
# echo("")
|
||||
assert( n == 32 )
|
||||
return unmarshalFrWTF(bytes)
|
||||
|
||||
proc loadValueFrStd*( stream: Stream ) : Fr =
|
||||
var bytes : array[32,byte]
|
||||
let n = stream.readData( addr(bytes), 32 )
|
||||
assert( n == 32 )
|
||||
return unmarshalFp(bytes)
|
||||
return unmarshalFrStd(bytes)
|
||||
|
||||
proc loadValueFr*( stream: Stream ) : Fr =
|
||||
proc loadValueFrMont*( stream: Stream ) : Fr =
|
||||
var bytes : array[32,byte]
|
||||
let n = stream.readData( addr(bytes), 32 )
|
||||
assert( n == 32 )
|
||||
return unmarshalFr(bytes)
|
||||
return unmarshalFrMont(bytes)
|
||||
|
||||
proc loadValueFp2*( stream: Stream ) : Fp2 =
|
||||
let i = loadValueFp( stream )
|
||||
let u = loadValueFp( stream )
|
||||
proc loadValueFpMont*( stream: Stream ) : Fp =
|
||||
var bytes : array[32,byte]
|
||||
let n = stream.readData( addr(bytes), 32 )
|
||||
assert( n == 32 )
|
||||
return unmarshalFpMont(bytes)
|
||||
|
||||
proc loadValueFp2Mont*( stream: Stream ) : Fp2 =
|
||||
let i = loadValueFpMont( stream )
|
||||
let u = loadValueFpMont( stream )
|
||||
return mkFp2(i,u)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
proc loadValuesFp*( len: int, stream: Stream ) : seq[Fp] =
|
||||
var values : seq[Fp]
|
||||
for i in 1..len:
|
||||
values.add( loadValueFp(stream) )
|
||||
return values
|
||||
|
||||
proc loadValuesFr*( len: int, stream: Stream ) : seq[Fr] =
|
||||
proc loadValuesFrStd*( len: int, stream: Stream ) : seq[Fr] =
|
||||
var values : seq[Fr]
|
||||
for i in 1..len:
|
||||
values.add( loadValueFr(stream) )
|
||||
values.add( loadValueFrStd(stream) )
|
||||
return values
|
||||
|
||||
proc loadValuesFpMont*( len: int, stream: Stream ) : seq[Fp] =
|
||||
var values : seq[Fp]
|
||||
for i in 1..len:
|
||||
values.add( loadValueFpMont(stream) )
|
||||
return values
|
||||
|
||||
proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr] =
|
||||
var values : seq[Fr]
|
||||
for i in 1..len:
|
||||
values.add( loadValueFrMont(stream) )
|
||||
return values
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc loadPointG1*( stream: Stream ) : G1 =
|
||||
let x = loadValueFp( stream )
|
||||
let y = loadValueFp( stream )
|
||||
let x = loadValueFpMont( stream )
|
||||
let y = loadValueFpMont( stream )
|
||||
return mkG1(x,y)
|
||||
|
||||
proc loadPointG2*( stream: Stream ) : G2 =
|
||||
let x = loadValueFp2( stream )
|
||||
let y = loadValueFp2( stream )
|
||||
let x = loadValueFp2Mont( stream )
|
||||
let y = loadValueFp2Mont( stream )
|
||||
return mkG2(x,y)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
@ -67,8 +67,6 @@ proc readSection[T] ( expectedMagic: string
|
||||
let sectLen = int( stream.readUint64() )
|
||||
let oldpos = stream.getPosition()
|
||||
if filt(sectId):
|
||||
# echo("section id = ",sectId )
|
||||
# echo("section len = ",sectLen)
|
||||
callback(stream, sectId, sectLen, user)
|
||||
stream.setPosition(oldpos + sectLen)
|
||||
|
||||
@ -89,7 +87,7 @@ proc parseContainer*[T] ( expectedMagic: string
|
||||
let version = stream.readUint32()
|
||||
assert( version == uint32(expectedVersion) , "not a version " & ($expectedVersion) & " `" & expectedMagic & "` file" )
|
||||
let nsections = stream.readUint32()
|
||||
echo("number of sections = ",nsections)
|
||||
# echo("number of sections = ",nsections)
|
||||
|
||||
for i in 1..nsections:
|
||||
readSection(expectedMagic, expectedVersion, stream, user, callback, filt)
|
||||
|
||||
@ -3,9 +3,9 @@
|
||||
# power-of-two sized multiplicative FFT domains in the scalar field
|
||||
#
|
||||
|
||||
import constantine/math/io/io_bigints
|
||||
import constantine/math/arithmetic except Fp,Fr
|
||||
import constantine/math/io/io_fields except Fp,Fr
|
||||
#import constantine/math/io/io_bigints
|
||||
|
||||
import ./bn128
|
||||
import ./misc
|
||||
@ -31,8 +31,8 @@ func createDomain*(size: int): Domain =
|
||||
let gen : Fr = smallPowFr(gen28, expo)
|
||||
|
||||
let halfSize = size div 2
|
||||
let a : Fr = smallPowFr(gen, uint(size ))
|
||||
let b : Fr = smallPowFr(gen, uint(halfSize))
|
||||
let a : Fr = smallPowFr(gen, size )
|
||||
let b : Fr = smallPowFr(gen, halfSize)
|
||||
assert( bool(a == oneFr) , "domain generator sanity check /A" )
|
||||
assert( not bool(b == oneFr) , "domain generator sanity check /B" )
|
||||
|
||||
|
||||
95
export_json.nim
Normal file
95
export_json.nim
Normal file
@ -0,0 +1,95 @@
|
||||
|
||||
#
|
||||
# export proof and public input in `circom`-compatible JSON files
|
||||
#
|
||||
|
||||
import std/sequtils
|
||||
|
||||
import constantine/math/arithmetic except Fp, Fr
|
||||
#import constantine/math/io/io_fields except Fp, Fr
|
||||
|
||||
import bn128
|
||||
from ./groth16 import Proof
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func toQuotedDecimalFp(x: Fp): string =
|
||||
let s : string = toDecimalFp(x)
|
||||
return ("\"" & s & "\"")
|
||||
|
||||
func toQuotedDecimalFr(x: Fr): string =
|
||||
let s : string = toDecimalFr(x)
|
||||
return ("\"" & s & "\"")
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# exports the public input/output into as a JSON file
|
||||
proc exportPublicIO*( fpath: string, prf: Proof ) =
|
||||
|
||||
# debugPrintFrSeq("public IO",prf.publicIO)
|
||||
|
||||
let n : int = prf.publicIO.len
|
||||
assert( n > 0 )
|
||||
assert( bool(prf.publicIO[0] == oneFr) )
|
||||
|
||||
let f = open(fpath, fmWrite)
|
||||
defer: f.close()
|
||||
|
||||
for i in 1..<n:
|
||||
let str : string = toQuotedDecimalFr( prf.publicIO[i] )
|
||||
if i==1:
|
||||
f.writeLine("[ " & str)
|
||||
else:
|
||||
f.writeLine(", " & str)
|
||||
f.writeLine("] ")
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc writeFp2( f: File, c: char, z: Fp2 ) =
|
||||
let prefix = " " & c & " "
|
||||
let indent = " "
|
||||
f.writeLine( prefix & "[ " & toQuotedDecimalFp( z.coords[0] ) )
|
||||
f.writeLine( indent & ", " & toQuotedDecimalFp( z.coords[1] ) )
|
||||
f.writeLine( indent & "]")
|
||||
|
||||
proc writeG1( f: File, p: G1 ) =
|
||||
f.writeLine(" [ " & toQuotedDecimalFp( p.x ) )
|
||||
f.writeLine(" , " & toQuotedDecimalFp( p.y ) )
|
||||
f.writeLine(" , " & toQuotedDecimalFp( oneFp ) )
|
||||
f.writeLine(" ]")
|
||||
|
||||
proc writeG2( f: File, p: G2 ) =
|
||||
writeFp2( f , '[' , p.x )
|
||||
writeFp2( f , ',' , p.y )
|
||||
writeFp2( f , ',' , oneFp2 )
|
||||
f.writeLine(" ]")
|
||||
|
||||
# exports the proof into as a JSON file
|
||||
proc exportProof*( fpath: string, prf: Proof ) =
|
||||
|
||||
let f = open(fpath, fmWrite)
|
||||
defer: f.close()
|
||||
|
||||
f.writeLine("{ \"protocol\": \"groth16\"")
|
||||
f.writeLine(", \"curve\": \"bn128\"" )
|
||||
f.writeLine(", \"pi_a\":" ) ; writeG1( f, prf.pi_a )
|
||||
f.writeLine(", \"pi_b\":" ) ; writeG2( f, prf.pi_b )
|
||||
f.writeLine(", \"pi_c\":" ) ; writeG1( f, prf.pi_c )
|
||||
f.writeLine("}")
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func getFakeProof*() : Proof =
|
||||
let pub : seq[Fr] = map( [1,101,102,103,117,119] , intToFr )
|
||||
let p = unsafeMkG1( intToFp(666) , intToFp(777) )
|
||||
let r = unsafeMkG1( intToFp(888) , intToFp(999) )
|
||||
let x = mkFp2( intToFp(22) , intToFp(33) )
|
||||
let y = mkFp2( intToFp(44) , intToFp(55) )
|
||||
let q = unsafeMkG2( x , y )
|
||||
return Proof( publicIO:pub, pi_a:p, pi_b:q, pi_c:r )
|
||||
|
||||
proc exportFakeProof*() =
|
||||
let prf = getFakeProof()
|
||||
exportPublicIO( "fake_pub.json" , prf )
|
||||
exportProof( "fake_prf.json" , prf )
|
||||
|
||||
236
groth16.nim
Normal file
236
groth16.nim
Normal file
@ -0,0 +1,236 @@
|
||||
|
||||
#
|
||||
# Groth16 prover
|
||||
#
|
||||
# WARNING! the points H are *NOT* what normal people would think they are
|
||||
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
|
||||
#[]
|
||||
import sugar
|
||||
import constantine/math/config/curves
|
||||
import constantine/math/io/io_fields
|
||||
import constantine/math/io/io_bigints
|
||||
import ./zkey
|
||||
]#
|
||||
|
||||
import constantine/math/arithmetic except Fp, Fr
|
||||
import constantine/math/io/io_extfields except Fp12
|
||||
import constantine/math/extension_fields/towers except Fp2, Fp12
|
||||
|
||||
import ./bn128
|
||||
import ./domain
|
||||
import ./poly
|
||||
import ./zkey_types
|
||||
import ./witness
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
Proof* = object
|
||||
publicIO* : seq[Fr]
|
||||
pi_a* : G1
|
||||
pi_b* : G2
|
||||
pi_c* : G1
|
||||
curve : string
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# the verifier
|
||||
#
|
||||
|
||||
proc verifyProof* (vkey: VKey, prf: Proof): bool =
|
||||
|
||||
assert( prf.curve == "bn128" )
|
||||
|
||||
assert( isOnCurveG1(prf.pi_a) , "pi_a is not in G1" )
|
||||
assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" )
|
||||
assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" )
|
||||
|
||||
var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC )
|
||||
|
||||
let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b >
|
||||
let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta >
|
||||
let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta >
|
||||
let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma >
|
||||
|
||||
var eq : Fp12
|
||||
eq = lhs
|
||||
eq *= rhs1
|
||||
eq *= rhs2
|
||||
eq *= rhs3
|
||||
|
||||
return bool(isOne(eq))
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# A, B, C column vectors
|
||||
#
|
||||
|
||||
type
|
||||
ABC = object
|
||||
valuesA : seq[Fr]
|
||||
valuesB : seq[Fr]
|
||||
valuesC : seq[Fr]
|
||||
|
||||
func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC =
|
||||
let hdr: GrothHeader = zkey.header
|
||||
let domSize = hdr.domainSize
|
||||
|
||||
var valuesA : seq[Fr] = newSeq[Fr](domSize)
|
||||
var valuesB : seq[Fr] = newSeq[Fr](domSize)
|
||||
for entry in zkey.coeffs:
|
||||
case entry.matrix
|
||||
of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col]
|
||||
of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col]
|
||||
else: raise newException(AssertionDefect, "fatal error")
|
||||
|
||||
var valuesC : seq[Fr] = newSeq[Fr](domSize)
|
||||
for i in 0..<hdr.nvars:
|
||||
valuesC[i] = valuesA[i] * valuesB[i]
|
||||
|
||||
return ABC( valuesA:valuesA, valuesB:valuesB, valuesC:valuesC )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# quotient poly
|
||||
#
|
||||
|
||||
# interpolates A,B,C, and computes the quotient polynomial Q = (A*B - C) / Z
|
||||
func computeQuotientNaive( abc: ABC ): Poly=
|
||||
let n = abc.valuesA.len
|
||||
assert( abc.valuesB.len == n )
|
||||
assert( abc.valuesC.len == n )
|
||||
let D = createDomain(n)
|
||||
let polyA : Poly = polyInverseNTT( abc.valuesA , D )
|
||||
let polyB : Poly = polyInverseNTT( abc.valuesB , D )
|
||||
let polyC : Poly = polyInverseNTT( abc.valuesC , D )
|
||||
let polyBig = polyMulFFT( polyA , polyB ) - polyC
|
||||
var polyQ = polyDivideByVanishing(polyBig, D.domainSize)
|
||||
polyQ.coeffs.add( zeroFr ) # make it a power of two
|
||||
return polyQ
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
|
||||
func multiplyByPowers( xs: seq[Fr], eta: Fr ): seq[Fr] =
|
||||
let n = xs.len
|
||||
assert(n >= 1)
|
||||
var ys : seq[Fr] = newSeq[Fr](n)
|
||||
ys[0] = xs[0]
|
||||
if n >= 1: ys[1] = eta * xs[1]
|
||||
var spow : Fr = eta
|
||||
for i in 2..<n:
|
||||
spow *= eta
|
||||
ys[i] = spow * xs[i]
|
||||
return ys
|
||||
|
||||
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
|
||||
func shiftEvalDomain( values: seq[Fr], D: Domain, eta: Fr ): seq[Fr] =
|
||||
let poly : Poly = polyInverseNTT( values , D )
|
||||
let cs : seq[Fr] = poly.coeffs
|
||||
var ds : seq[Fr] = multiplyByPowers( cs, eta )
|
||||
return polyForwardNTT( Poly(coeffs:ds), D )
|
||||
|
||||
# computes the quotient polynomial Q = (A*B - C) / Z
|
||||
# by computing the values on a shifted domain, and interpolating the result
|
||||
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
|
||||
func computeQuotientPointwise( abc: ABC ): Poly =
|
||||
let n = abc.valuesA.len
|
||||
let D = createDomain(n)
|
||||
|
||||
# (eta*omega^j)^n - 1 = eta^n - 1
|
||||
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let Inv1 = invFr( smallPowFr(eta,n) - oneFr )
|
||||
|
||||
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
|
||||
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
|
||||
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
|
||||
|
||||
var ys : seq[Fr] = newSeq[Fr]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * Inv1
|
||||
let Q1 = polyInverseNTT( ys, D )
|
||||
|
||||
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
|
||||
return Poly(coeffs: cs)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# Snarkjs does something different, not actually computing the quotient poly
|
||||
# they can get away with this, because during the trusted setup, they
|
||||
# transform the H points into (shifted??) Lagrange bases (?)
|
||||
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] =
|
||||
let n = abc.valuesA.len
|
||||
let D = createDomain(n)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
|
||||
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
|
||||
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
|
||||
var ys : seq[Fr] = newSeq[Fr]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] )
|
||||
return ys
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# the prover
|
||||
#
|
||||
|
||||
proc generateProof* ( zkey: ZKey, wtns: Witness ): Proof =
|
||||
assert( zkey.header.curve == wtns.curve )
|
||||
|
||||
let witness = wtns.values
|
||||
|
||||
let hdr : GrothHeader = zkey.header
|
||||
let spec : SpecPoints = zkey.specPoints
|
||||
let pts : ProverPoints = zkey.pPoints
|
||||
|
||||
let nvars = hdr.nvars
|
||||
let npubs = hdr.npubs
|
||||
|
||||
assert( nvars == witness.len , "wrong witness length" )
|
||||
|
||||
var pubIO : seq[Fr] = newSeq[Fr]( npubs + 1)
|
||||
for i in 0..npubs: pubIO[i] = witness[i]
|
||||
|
||||
var abc : ABC = buildABC( zkey, witness )
|
||||
|
||||
# let polyQ1 = computeQuotientNaive( abc )
|
||||
# let polyQ2 = computeQuotientPointwise( abc )
|
||||
|
||||
let qs = computeSnarkjsScalarCoeffs( abc )
|
||||
|
||||
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
|
||||
for j in npubs+1..<nvars:
|
||||
zs[j-npubs-1] = witness[j]
|
||||
|
||||
# masking coeffs
|
||||
let r : Fr = randFr()
|
||||
let s : Fr = randFr()
|
||||
|
||||
# let r : Fr = intToFr(3)
|
||||
# let s : Fr = intToFr(4)
|
||||
|
||||
var pi_a : G1
|
||||
pi_a = spec.alpha1
|
||||
pi_a += r ** spec.delta1
|
||||
pi_a += msmG1( witness , pts.pointsA1 )
|
||||
|
||||
var rho : G1
|
||||
rho = spec.beta1
|
||||
rho += s ** spec.delta1
|
||||
rho += msmG1( witness , pts.pointsB1 )
|
||||
|
||||
var pi_b : G2
|
||||
pi_b = spec.beta2
|
||||
pi_b += s ** spec.delta2
|
||||
pi_b += msmG2( witness , pts.pointsB2 )
|
||||
|
||||
var pi_c : G1
|
||||
pi_c = s ** pi_a
|
||||
pi_c += r ** rho
|
||||
pi_c += negFr(r*s) ** spec.delta1
|
||||
pi_c += msmG1( qs , pts.pointsH1 )
|
||||
pi_c += msmG1( zs , pts.pointsC1 )
|
||||
|
||||
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
19
main.nim
19
main.nim
@ -1,19 +0,0 @@
|
||||
|
||||
import ./r1cs
|
||||
import ./zkey
|
||||
import ./witness
|
||||
import ./bn128
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc testMain() =
|
||||
# checkMontgomeryConstants()
|
||||
let r1cs_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain.r1cs"
|
||||
let zkey_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain.zkey"
|
||||
let wtns_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain_witness.wtns"
|
||||
parseWitness( wtns_fname)
|
||||
parseR1CS( r1cs_fname)
|
||||
parseZKey( zkey_fname)
|
||||
|
||||
when isMainModule:
|
||||
testMain()
|
||||
32
misc.nim
32
misc.nim
@ -1,4 +1,7 @@
|
||||
|
||||
#
|
||||
# miscellaneous routines
|
||||
#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -16,13 +19,26 @@ func ceilingLog2* (x : int) : int =
|
||||
else:
|
||||
return (floorLog2(x-1) + 1)
|
||||
|
||||
#
|
||||
# import std/math
|
||||
#
|
||||
# proc sanityCheckLog2* () =
|
||||
# for i in 0..18:
|
||||
# let x = float64(i)
|
||||
# echo( i," | ",floorLog2(i),"=",floor(log2(x))," | ",ceilingLog2(i),"=",ceil(log2(x)) )
|
||||
#
|
||||
#-------------------
|
||||
|
||||
#[
|
||||
import std/math
|
||||
|
||||
proc sanityCheckLog2* () =
|
||||
for i in 0..18:
|
||||
let x = float64(i)
|
||||
echo( i," | ",floorLog2(i),"=",floor(log2(x))," | ",ceilingLog2(i),"=",ceil(log2(x)) )
|
||||
]#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
#[
|
||||
func rotateSeq[T](xs: seq[T], ofs: int): seq[T] =
|
||||
let n = xs.len
|
||||
var ys : seq[T]
|
||||
for i in (0..<n):
|
||||
ys.add( xs[ (i+n+ofs) mod n ] )
|
||||
return ys
|
||||
]#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
18
ntt.nim
18
ntt.nim
@ -1,9 +1,10 @@
|
||||
|
||||
#
|
||||
# Number-theoretic transform
|
||||
# Number-theoretic transform
|
||||
# (that is, FFT for polynomials over finite fields)
|
||||
#
|
||||
|
||||
import std/sequtils
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
import constantine/math/arithmetic except Fp,Fr
|
||||
import constantine/math/io/io_fields
|
||||
@ -68,6 +69,19 @@ func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
|
||||
, tgt , 0 )
|
||||
return tgt
|
||||
|
||||
# pads the input with zeros to get a pwoer of two size
|
||||
func extendAndForwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
|
||||
let n = src.len
|
||||
let N = D.domainSize
|
||||
assert( n <= N )
|
||||
if n == N:
|
||||
return forwardNTT(src, D)
|
||||
else:
|
||||
var padded : seq[Fr] = newSeq[Fr]( N )
|
||||
for i in 0..<n: padded[i] = src[i]
|
||||
# for i in n..<N: padded[i] = zeroFr
|
||||
return forwardNTT(padded, D)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
const oneHalfFr* : Fr = fromHex(Fr, "0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000001")
|
||||
|
||||
109
poly.nim
109
poly.nim
@ -17,6 +17,7 @@ import constantine/math/io/io_fields
|
||||
import bn128
|
||||
import domain
|
||||
import ntt
|
||||
import misc
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -74,25 +75,25 @@ func polyNeg*(P: Poly) : Poly =
|
||||
func polyAdd*(P, Q: Poly) : Poly =
|
||||
let xs = P.coeffs ; let n = xs.len
|
||||
let ys = Q.coeffs ; let m = ys.len
|
||||
var zs : seq[Fr]
|
||||
var zs : seq[Fr] = newSeq[Fr](max(n,m))
|
||||
if n >= m:
|
||||
for i in 0..<m: zs.add( xs[i] + ys[i] )
|
||||
for i in m..<n: zs.add( xs[i] )
|
||||
for i in 0..<m: zs[i] = ( xs[i] + ys[i] )
|
||||
for i in m..<n: zs[i] = ( xs[i] )
|
||||
else:
|
||||
for i in 0..<n: zs.add( xs[i] + ys[i] )
|
||||
for i in n..<m: zs.add( ys[i] )
|
||||
for i in 0..<n: zs[i] = ( xs[i] + ys[i] )
|
||||
for i in n..<m: zs[i] = ( ys[i] )
|
||||
return Poly(coeffs: zs)
|
||||
|
||||
func polySub*(P, Q: Poly) : Poly =
|
||||
let xs = P.coeffs ; let n = xs.len
|
||||
let ys = Q.coeffs ; let m = ys.len
|
||||
var zs : seq[Fr]
|
||||
var zs : seq[Fr] = newSeq[Fr](max(n,m))
|
||||
if n >= m:
|
||||
for i in 0..<m: zs.add( xs[i] - ys[i] )
|
||||
for i in m..<n: zs.add( xs[i] )
|
||||
for i in 0..<m: zs[i] = ( xs[i] - ys[i] )
|
||||
for i in m..<n: zs[i] = ( xs[i] )
|
||||
else:
|
||||
for i in 0..<n: zs.add( xs[i] + ys[i] )
|
||||
for i in n..<m: zs.add( zeroFr - ys[i] )
|
||||
for i in 0..<n: zs[i] = ( xs[i] + ys[i] )
|
||||
for i in n..<m: zs[i] = ( negFr( ys[i] ))
|
||||
return Poly(coeffs: zs)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
@ -122,7 +123,28 @@ func polyMulNaive*(P, Q : Poly): Poly =
|
||||
zs[k] += xs[i] * ys[j]
|
||||
return Poly(coeffs: zs)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# multiply two polynomials using FFT
|
||||
func polyMulFFT*(P, Q: Poly): Poly =
|
||||
let n1 = P.coeffs.len
|
||||
let n2 = Q.coeffs.len
|
||||
|
||||
let log2 : int = max( ceilingLog2(n1) , ceilingLog2(n2) ) + 1
|
||||
let N : int = (1 shl log2)
|
||||
let D : Domain = createDomain( N )
|
||||
|
||||
let us : seq[Fr] = extendAndForwardNTT( P.coeffs, D )
|
||||
let vs : seq[Fr] = extendAndForwardNTT( Q.coeffs, D )
|
||||
let zs : seq[Fr] = collect( newSeq, (for i in 0..<N: us[i]*vs[i] ))
|
||||
let ws : seq[Fr] = inverseNTT( zs, D )
|
||||
|
||||
return Poly(coeffs: ws)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func polyMul*(P, Q : Poly): Poly =
|
||||
# return polyMulFFT(P, Q)
|
||||
return polyMulNaive(P, Q)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
@ -138,14 +160,20 @@ func `*`*(P: Poly, s: Fr ): Poly = return polyScale(s, P)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# the vanishing polynomial `(x^N - 1)`
|
||||
func vanishingPoly*(N: int): Poly =
|
||||
# the generalized vanishing polynomial `(a*x^N - b)`
|
||||
func generalizedVanishingPoly*(N: int, a: Fr, b: Fr): Poly =
|
||||
assert( N>=1 )
|
||||
var cs : seq[Fr] = newSeq[Fr]( N+1 )
|
||||
cs[0] = negFr( oneFr )
|
||||
cs[N] = oneFr
|
||||
cs[0] = negFr(b)
|
||||
cs[N] = a
|
||||
return Poly(coeffs: cs)
|
||||
|
||||
# the vanishing polynomial `(x^N - 1)`
|
||||
func vanishingPoly*(N: int): Poly =
|
||||
return generalizedVanishingPoly(N, oneFr, oneFr)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
QuotRem*[T] = object
|
||||
quot* : T
|
||||
@ -164,13 +192,15 @@ func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] =
|
||||
rem = src
|
||||
|
||||
else:
|
||||
# compute quot
|
||||
|
||||
# compute quotient
|
||||
for j in countdown(deg-N, 0):
|
||||
if j+N <= deg-N:
|
||||
quot[j] = src[j+N] + quot[j+N]
|
||||
else:
|
||||
quot[j] = src[j+N]
|
||||
# compute rem
|
||||
|
||||
# compute remainder
|
||||
for j in 0..<N:
|
||||
if j <= deg-N:
|
||||
rem[j] = src[j] + quot[j]
|
||||
@ -191,14 +221,8 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly =
|
||||
func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] =
|
||||
let n = P.coeffs.len
|
||||
assert( n <= D.domainSize , "the domain must be as least as big as the polynomial" )
|
||||
|
||||
if n == D.domainSize:
|
||||
let src : seq[Fr] = P.coeffs
|
||||
return forwardNTT(src, D)
|
||||
else:
|
||||
var src : seq[Fr] = P.coeffs
|
||||
for i in n..<D.domainSize: src.add( zeroFr )
|
||||
return forwardNTT(src, D)
|
||||
let src : seq[Fr] = P.coeffs
|
||||
return forwardNTT(src, D)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
@ -211,6 +235,8 @@ func polyInverseNTT*(ys: seq[Fr], D: Domain): Poly =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
#[
|
||||
|
||||
proc sanityCheckOneHalf*() =
|
||||
let two = oneFr + oneFr
|
||||
let invTwo = oneHalfFr
|
||||
@ -224,20 +250,20 @@ proc sanityCheckVanishing*() =
|
||||
let P : Poly = Poly( coeffs:cs )
|
||||
|
||||
echo("degree of P = ",polyDegree(P))
|
||||
debugPrintSeqFr("xs", P.coeffs)
|
||||
debugPrintFrSeq("xs", P.coeffs)
|
||||
|
||||
let n : int = 5
|
||||
let QR = polyQuotRemByVanishing(P, n)
|
||||
let Q = QR.quot
|
||||
let R = QR.rem
|
||||
|
||||
debugPrintSeqFr("Q", Q.coeffs)
|
||||
debugPrintSeqFr("R", R.coeffs)
|
||||
debugPrintFrSeq("Q", Q.coeffs)
|
||||
debugPrintFrSeq("R", R.coeffs)
|
||||
|
||||
let Z : Poly = vanishingPoly(n)
|
||||
let S : Poly = Q * Z + R
|
||||
|
||||
debugPrintSeqFr("zs", S.coeffs)
|
||||
debugPrintFrSeq("zs", S.coeffs)
|
||||
echo( polyIsEqual(P,S) )
|
||||
|
||||
proc sanityCheckNTT*() =
|
||||
@ -249,9 +275,28 @@ proc sanityCheckNTT*() =
|
||||
let ys : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(P,x)) )
|
||||
let zs : seq[Fr] = polyForwardNTT(P ,D)
|
||||
let Q : Poly = polyInverseNTT(zs,D)
|
||||
debugPrintSeqFr("xs", xs)
|
||||
debugPrintSeqFr("ys", ys)
|
||||
debugPrintSeqFr("zs", zs)
|
||||
debugPrintSeqFr("us", Q.coeffs)
|
||||
debugPrintFrSeq("xs", xs)
|
||||
debugPrintFrSeq("ys", ys)
|
||||
debugPrintFrSeq("zs", zs)
|
||||
debugPrintFrSeq("us", Q.coeffs)
|
||||
|
||||
proc sanityCheckMulFFT*() =
|
||||
var js : seq[int] = toSeq(101..110)
|
||||
let cs : seq[Fr] = map( js, intToFr )
|
||||
let P : Poly = Poly( coeffs:cs )
|
||||
|
||||
var ks : seq[int] = toSeq(1001..1020)
|
||||
let ds : seq[Fr] = map( ks, intToFr )
|
||||
let Q : Poly = Poly( coeffs:ds )
|
||||
|
||||
let R1 : Poly = polyMulNaive( P , Q )
|
||||
let R2 : Poly = polyMulFFT( P , Q )
|
||||
|
||||
# debugPrintFrSeq("naive coeffs", R1.coeffs)
|
||||
# debugPrintFrSeq("fft coeffs", R2.coeffs)
|
||||
|
||||
echo( "multiply test = ", polyIsEqual(R1,R2) )
|
||||
|
||||
]#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
20
r1cs.nim
20
r1cs.nim
@ -82,12 +82,13 @@ type
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) =
|
||||
echo "\nparsing r1cs header"
|
||||
# echo "\nparsing r1cs header"
|
||||
|
||||
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
|
||||
echo("r = ",toDecimalBig(r))
|
||||
user.r = r;
|
||||
|
||||
# echo("r = ",toDecimalBig(r))
|
||||
|
||||
assert( sectionLen == 4 + n8r + 16 + 8 + 4, "unexpected section length")
|
||||
|
||||
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
|
||||
@ -104,14 +105,14 @@ proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) =
|
||||
let nConstr = int( stream.readUint32() )
|
||||
user.nConstr = nConstr
|
||||
|
||||
echo("witness config = ",cfg)
|
||||
echo("nConstr = ",nConstr)
|
||||
# echo("witness config = ",cfg)
|
||||
# echo("nConstr = ",nConstr)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc loadTerm( stream: Stream ): Term =
|
||||
let idx = int( stream.readUint32() )
|
||||
let coeff = loadValueFr( stream )
|
||||
let coeff = loadValueFrMont( stream )
|
||||
return (wireIdx:idx, value:coeff)
|
||||
|
||||
proc loadLinComb( stream: Stream ): LinComb =
|
||||
@ -139,9 +140,9 @@ proc parseSection2_constraints( stream: Stream, user: var R1CS, sectionLen: int
|
||||
ncoeffsB += abc.B.len
|
||||
ncoeffsC += abc.C.len
|
||||
user.constraints = constr
|
||||
echo( "number of nonzero coefficients in matrix A = ", ncoeffsA )
|
||||
echo( "number of nonzero coefficients in matrix B = ", ncoeffsB )
|
||||
echo( "number of nonzero coefficients in matrix C = ", ncoeffsC )
|
||||
# echo( "number of nonzero coefficients in matrix A = ", ncoeffsA )
|
||||
# echo( "number of nonzero coefficients in matrix B = ", ncoeffsB )
|
||||
# echo( "number of nonzero coefficients in matrix C = ", ncoeffsC )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -166,9 +167,10 @@ proc r1csCallback( stream: Stream
|
||||
of 3: parseSection3_wireToLabel( stream, user, sectLen )
|
||||
else: discard
|
||||
|
||||
proc parseR1CS* (fname: string) =
|
||||
proc parseR1CS* (fname: string): R1CS =
|
||||
var r1cs : R1CS
|
||||
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id == 1 )
|
||||
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id != 1 )
|
||||
return r1cs
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
28
test_proof.nim
Normal file
28
test_proof.nim
Normal file
@ -0,0 +1,28 @@
|
||||
|
||||
import ./groth16
|
||||
import ./export_json
|
||||
import ./witness
|
||||
import ./zkey
|
||||
import ./zkey_types
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc testProveAndVerify*( zkey_fname, wtns_fname: string) =
|
||||
|
||||
echo("parsing witness & zkey files...")
|
||||
let witness = parseWitness( wtns_fname)
|
||||
let zkey = parseZKey( zkey_fname)
|
||||
|
||||
echo("generating proof...")
|
||||
let vkey = extractVKey( zkey)
|
||||
let proof = generateProof( zkey, witness )
|
||||
|
||||
echo("exporting proof...")
|
||||
exportPublicIO( "my_pub.json" , proof )
|
||||
exportProof( "my_prf.json" , proof )
|
||||
|
||||
echo("verifying the proof...")
|
||||
let ok = verifyProof( vkey, proof)
|
||||
echo("verification succeeded = ",ok)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
25
witness.nim
25
witness.nim
@ -11,6 +11,10 @@
|
||||
#
|
||||
# nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret
|
||||
#
|
||||
# WARNING! unlike the `.r1cs` and `.zkey` files, which encode field elements
|
||||
# in Montgomery representation, the `.wtns` file encode field elements in
|
||||
# the standard representation!!
|
||||
#
|
||||
|
||||
import std/streams
|
||||
|
||||
@ -24,35 +28,38 @@ import ./container
|
||||
|
||||
type
|
||||
Witness* = object
|
||||
r : BigInt[256]
|
||||
nvars : int
|
||||
values : seq[Fr]
|
||||
curve* : string
|
||||
r* : BigInt[256]
|
||||
nvars* : int
|
||||
values* : seq[Fr]
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc parseSection1_header( stream: Stream, user: var Witness, sectionLen: int ) =
|
||||
echo "\nparsing witness header"
|
||||
# echo "\nparsing witness header"
|
||||
|
||||
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
|
||||
user.r = r;
|
||||
echo("r = ",toDecimalBig(r))
|
||||
|
||||
# echo("r = ",toDecimalBig(r))
|
||||
|
||||
assert( sectionLen == 4 + n8r + 4 , "unexpected section length")
|
||||
|
||||
assert( n8r == 32 , "expecting 256 bit prime" )
|
||||
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
|
||||
user.curve = "bn128"
|
||||
|
||||
let nvars = int( stream.readUint32() )
|
||||
user.nvars = nvars;
|
||||
|
||||
echo("nvars = ",nvars)
|
||||
# echo("nvars = ",nvars)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc parseSection2_witness( stream: Stream, user: var Witness, sectionLen: int ) =
|
||||
|
||||
assert( sectionLen == 32 * user.nvars )
|
||||
user.values = loadValuesFr( user.nvars, stream )
|
||||
user.values = loadValuesFrStd( user.nvars, stream )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -63,11 +70,11 @@ proc wtnsCallback(stream: Stream, sectId: int, sectLen: int, user: var Witness)
|
||||
of 2: parseSection2_witness( stream, user, sectLen )
|
||||
else: discard
|
||||
|
||||
proc parseWitness* (fname: string) =
|
||||
proc parseWitness* (fname: string): Witness =
|
||||
var wtns : Witness
|
||||
parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id == 1 )
|
||||
parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id != 1 )
|
||||
|
||||
return wtns
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
85
zkey.nim
85
zkey.nim
@ -52,6 +52,8 @@
|
||||
# compute (C*witness)[i] = (A*witness)[i] * (B*witness)[i]
|
||||
# These 3 column vectors is all we need in the proof generation.
|
||||
#
|
||||
# WARNING! It appears that the values here are *doubly Montgomery encoded* (?!)
|
||||
#
|
||||
# 5: PointsA
|
||||
# ----------
|
||||
# the curve points [A_j(tau)]_1 in G1
|
||||
@ -74,7 +76,9 @@
|
||||
#
|
||||
# 9: PointsH
|
||||
# ----------
|
||||
# the curve points [delta^-1 * tau^i * Z(tau)]
|
||||
# what normally should be the curve points [delta^-1 * tau^i * Z(tau)]
|
||||
# HOWEVER, they are NOT! (??)
|
||||
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
# length = 2 * n8p * domSize = domSize G1 points
|
||||
#
|
||||
# 10: Contributions
|
||||
@ -82,64 +86,20 @@
|
||||
# ??? (but not required for proving, only for checking that the `.zkey` file is valid)
|
||||
#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
import std/streams
|
||||
|
||||
import constantine/math/arithmetic except Fp, Fr
|
||||
import constantine/math/io/io_bigints
|
||||
|
||||
#import constantine/math/io/io_bigints
|
||||
|
||||
import ./bn128
|
||||
import ./zkey_types
|
||||
import ./container
|
||||
import ./misc
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
|
||||
GrothHeader* = object
|
||||
p : BigInt[256]
|
||||
r : BigInt[256]
|
||||
nvars : int
|
||||
npubs : int
|
||||
domainSize : int
|
||||
logDomainSize : int
|
||||
|
||||
SpecPoints* = object
|
||||
alpha1 : G1
|
||||
beta1 : G1
|
||||
beta2 : G2
|
||||
gamma2 : G2
|
||||
delta1 : G1
|
||||
delta2 : G2
|
||||
|
||||
VerifierPoints* = object
|
||||
pointsIC : seq[G1]
|
||||
|
||||
ProverPoints* = object
|
||||
pointsA1 : seq[G1]
|
||||
pointsB1 : seq[G1]
|
||||
pointsB2 : seq[G2]
|
||||
pointsC1 : seq[G1]
|
||||
pointsH1 : seq[G1]
|
||||
|
||||
MatrixSel* = enum
|
||||
MatrixA
|
||||
MatrixB
|
||||
MatrixC
|
||||
|
||||
Coeff* = object
|
||||
matrix : MatrixSel
|
||||
row : int
|
||||
col : int
|
||||
coeff : Fr
|
||||
|
||||
ZKey* = object
|
||||
sectionMask : uint32
|
||||
header : GrothHeader
|
||||
specPoints : SpecPoints
|
||||
vPoints : VerifierPoints
|
||||
pPoints : ProverPoints
|
||||
coeffs : seq[Coeff]
|
||||
|
||||
proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int ) =
|
||||
assert( sectionLen == 4 , "unexpected section length" )
|
||||
let proverType = stream.readUint32
|
||||
@ -148,13 +108,13 @@ proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int ) =
|
||||
echo "\nparsing the Groth16 zkey header"
|
||||
# echo "\nparsing the Groth16 zkey header"
|
||||
|
||||
let (n8p, p) = parsePrimeField( stream ) # size of the base field
|
||||
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
|
||||
|
||||
echo("p = ",toDecimalBig(p))
|
||||
echo("r = ",toDecimalBig(r))
|
||||
# echo("p = ",toDecimalBig(p))
|
||||
# echo("r = ",toDecimalBig(r))
|
||||
|
||||
assert( sectionLen == 2*4 + n8p + n8r + 3*4 + 3*64 + 3*128 , "unexpected section length" )
|
||||
|
||||
@ -167,6 +127,7 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
|
||||
|
||||
assert( bool(p == primeP) , "expecting the alt-bn128 curve" )
|
||||
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
|
||||
header.curve = "bn128"
|
||||
|
||||
let nvars = int( stream.readUint32() )
|
||||
let npubs = int( stream.readUint32() )
|
||||
@ -175,9 +136,9 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
|
||||
|
||||
assert( (1 shl log2siz) == domsiz , "domain size should be a power of two" )
|
||||
|
||||
echo("nvars = ",nvars)
|
||||
echo("npubs = ",npubs)
|
||||
echo("domsiz = ",domsiz)
|
||||
# echo("nvars = ",nvars)
|
||||
# echo("npubs = ",npubs)
|
||||
# echo("domsiz = ",domsiz)
|
||||
|
||||
header.nvars = nvars
|
||||
header.npubs = npubs
|
||||
@ -194,6 +155,7 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
|
||||
spec.gamma2 = loadPointG2( stream )
|
||||
spec.delta1 = loadPointG1( stream )
|
||||
spec.delta2 = loadPointG2( stream )
|
||||
spec.alphaBeta = pairing( spec.alpha1, spec.beta2 )
|
||||
user.specPoints = spec
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
@ -206,9 +168,9 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) =
|
||||
|
||||
var coeffs : seq[Coeff]
|
||||
for i in 1..ncoeffs:
|
||||
let m = int( stream.readUint32() )
|
||||
let r = int( stream.readUint32() )
|
||||
let c = int( stream.readUint32() )
|
||||
let m = int( stream.readUint32() ) # which matrix
|
||||
let r = int( stream.readUint32() ) # row (equation index)
|
||||
let c = int( stream.readUint32() ) # column (witness index)
|
||||
assert( m >= 0 and m <= 2 , "invalid matrix selector" )
|
||||
let sel : MatrixSel = case m
|
||||
of 0: MatrixA
|
||||
@ -217,7 +179,7 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) =
|
||||
else: raise newException(AssertionDefect, "fatal error")
|
||||
assert( r >= 0 and r < nrows, "row index out of range" )
|
||||
assert( c >= 0 and c < ncols, "column index out of range" )
|
||||
let cf = loadValueFr( stream )
|
||||
let cf = loadValueFrWTF( stream ) # Jordi, WTF is this encoding ?!?!?!!111
|
||||
let entry = Coeff( matrix:sel, row:r, col:c, coeff:cf )
|
||||
coeffs.add( entry )
|
||||
|
||||
@ -270,10 +232,11 @@ proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) =
|
||||
of 9: parseSection9_PointsH1( stream, user, sectLen )
|
||||
else: discard
|
||||
|
||||
proc parseZKey* (fname: string) =
|
||||
proc parseZKey* (fname: string): ZKey =
|
||||
var zkey : ZKey
|
||||
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 1 )
|
||||
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 2 )
|
||||
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id >= 3 )
|
||||
return zkey
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
70
zkey_types.nim
Normal file
70
zkey_types.nim
Normal file
@ -0,0 +1,70 @@
|
||||
|
||||
import constantine/math/arithmetic except Fp, Fr
|
||||
|
||||
import ./bn128
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
|
||||
GrothHeader* = object
|
||||
curve* : string
|
||||
p* : BigInt[256]
|
||||
r* : BigInt[256]
|
||||
nvars* : int
|
||||
npubs* : int
|
||||
domainSize* : int
|
||||
logDomainSize* : int
|
||||
|
||||
SpecPoints* = object
|
||||
alpha1* : G1
|
||||
beta1* : G1
|
||||
beta2* : G2
|
||||
gamma2* : G2
|
||||
delta1* : G1
|
||||
delta2* : G2
|
||||
alphaBeta* : Fp12 # = <alpha1,beta2>
|
||||
|
||||
VerifierPoints* = object
|
||||
pointsIC* : seq[G1]
|
||||
|
||||
ProverPoints* = object
|
||||
pointsA1* : seq[G1]
|
||||
pointsB1* : seq[G1]
|
||||
pointsB2* : seq[G2]
|
||||
pointsC1* : seq[G1]
|
||||
pointsH1* : seq[G1]
|
||||
|
||||
MatrixSel* = enum
|
||||
MatrixA
|
||||
MatrixB
|
||||
MatrixC
|
||||
|
||||
Coeff* = object
|
||||
matrix* : MatrixSel
|
||||
row* : int
|
||||
col* : int
|
||||
coeff* : Fr
|
||||
|
||||
ZKey* = object
|
||||
# sectionMask* : uint32
|
||||
header* : GrothHeader
|
||||
specPoints* : SpecPoints
|
||||
vPoints* : VerifierPoints
|
||||
pPoints* : ProverPoints
|
||||
coeffs* : seq[Coeff]
|
||||
|
||||
VKey* = object
|
||||
curve* : string
|
||||
spec* : SpecPoints
|
||||
vpoints* : VerifierPoints
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func extractVKey*(zkey: Zkey): VKey =
|
||||
let curve = zkey.header.curve
|
||||
let spec = zkey.specPoints
|
||||
let vpts = zkey.vPoints
|
||||
return VKey(curve:curve, spec:spec, vpoints:vpts)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
Loading…
x
Reference in New Issue
Block a user