Merge pull request #3 from codex-storage/refactor

Refactor & nimblify
This commit is contained in:
Balazs Komuves 2023-11-23 14:57:07 +01:00 committed by GitHub
commit fbe637e8d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1658 additions and 1140 deletions

View File

@ -19,12 +19,13 @@ at your choice.
### TODO
- [ ] make it a nimble package
- [ ] refactor `bn128.nim` into smaller files
- [ ] proper MSM implementation (I couldn't make constantine's one to work)
- [ ] clean up the code
- [x] make it a nimble package
- [/] refactor `bn128.nim` into smaller files
- [/] proper MSM implementation (at first I couldn't make constantine's one to work)
- [x] compare `.r1cs` to the "coeffs" section of `.zkey`
- [x] generate fake circuit-specific setup ourselves
- [ ] multithreaded support (MSM, and possibly also FFT)
- [ ] multithreading support (MSM, and possibly also FFT)
- [ ] add Groth16 notes
- [ ] document the `snarkjs` circuit-specific setup `H` points convention
- [ ] make it work for different curves

797
bn128.nim
View File

@ -1,797 +0,0 @@
#
# the `alt-bn128` elliptic curve
#
# See for example <https://hackmd.io/@jpw/bn254>
#
# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
#
# equation: y^2 = x^3 + 3
#
import sugar
import std/bitops
import std/strutils
import std/sequtils
import std/streams
import std/random
import constantine/platforms/abstractions
import constantine/math/isogenies/frobenius
import constantine/math/arithmetic
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/type_ff as tff
import constantine/math/extension_fields/towers as ext
import constantine/math/elliptic/ec_shortweierstrass_affine as aff
import constantine/math/elliptic/ec_shortweierstrass_projective as prj
import constantine/math/pairings/pairings_bn as ate
import constantine/math/elliptic/ec_scalar_mul as scl
import constantine/math/elliptic/ec_multi_scalar_mul as msm
#-------------------------------------------------------------------------------
type B* = BigInt[256]
type Fr* = tff.Fr[BN254Snarks]
type Fp* = tff.Fp[BN254Snarks]
type Fp2* = ext.QuadraticExt[Fp]
type Fp12* = ext.Fp12[BN254Snarks]
type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1]
type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2]
type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1]
type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2]
func mkFp2* (i: Fp, u: Fp) : Fp2 =
let c : array[2, Fp] = [i,u]
return ext.QuadraticExt[Fp]( coords: c )
func unsafeMkG1* ( X, Y: Fp ) : G1 =
return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y)
func unsafeMkG2* ( X, Y: Fp2 ) : G2 =
return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y)
#-------------------------------------------------------------------------------
func pairing* (p: G1, q: G2) : Fp12 =
var t : Fp12
pairing_bn( t, p, q )
return t
#-------------------------------------------------------------------------------
const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
const primeP_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
const primeR_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
#-------------------------------------------------------------------------------
const zeroFp* : Fp = fromHex( Fp, "0x00" )
const zeroFr* : Fr = fromHex( Fr, "0x00" )
const oneFp* : Fp = fromHex( Fp, "0x01" )
const oneFr* : Fr = fromHex( Fr, "0x01" )
const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp )
const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp )
const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp )
const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 )
#-------------------------------------------------------------------------------
func intToB*(a: uint): B =
var y : B
y.setUint(a)
return y
func intToFp*(a: int): Fp =
var y : Fp
y.fromInt(a)
return y
func intToFr*(a: int): Fr =
var y : Fr
y.fromInt(a)
return y
#-------------------------------------------------------------------------------
func isZeroFp*(x: Fp): bool = bool(isZero(x))
func isZeroFr*(x: Fr): bool = bool(isZero(x))
func isEqualFp*(x, y: Fp): bool = bool(x == y)
func isEqualFr*(x, y: Fr): bool = bool(x == y)
func `===`*(x, y: Fp): bool = isEqualFp(x,y)
func `===`*(x, y: Fr): bool = isEqualFr(x,y)
#-------------------
func isEqualFpSeq*(xs, ys: seq[Fp]): bool =
let n = xs.len
assert( n == ys.len )
var b = true
for i in 0..<n:
if not bool(xs[i] == ys[i]):
b = false
break
return b
func isEqualFrSeq*(xs, ys: seq[Fr]): bool =
let n = xs.len
assert( n == ys.len )
var b = true
for i in 0..<n:
if not bool(xs[i] == ys[i]):
b = false
break
return b
func `===`*(xs, ys: seq[Fp]): bool = isEqualFpSeq(xs,ys)
func `===`*(xs, ys: seq[Fr]): bool = isEqualFrSeq(xs,ys)
#-------------------------------------------------------------------------------
#func `+`*(x, y: B ): B = ( var z : B = x ; z += y ; return z )
func `+`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z += y ; return z )
func `+`*(x, y: Fp): Fp = ( var z : Fp = x ; z += y ; return z )
func `+`*(x, y: Fr): Fr = ( var z : Fr = x ; z += y ; return z )
#func `-`*(x, y: B ): B = ( var z : B = x ; z -= y ; return z )
func `-`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z -= y ; return z )
func `-`*(x, y: Fp): Fp = ( var z : Fp = x ; z -= y ; return z )
func `-`*(x, y: Fr): Fr = ( var z : Fr = x ; z -= y ; return z )
func `*`*(x, y: Fp): Fp = ( var z : Fp = x ; z *= y ; return z )
func `*`*(x, y: Fr): Fr = ( var z : Fr = x ; z *= y ; return z )
func negFp*(y: Fp): Fp = ( var z : Fp = zeroFp ; z -= y ; return z )
func negFr*(y: Fr): Fr = ( var z : Fr = zeroFr ; z -= y ; return z )
func invFp*(y: Fp): Fp = ( var z : Fp = y ; inv(z) ; return z )
func invFr*(y: Fr): Fr = ( var z : Fr = y ; inv(z) ; return z )
# template/generic instantiation of `pow_vartime` from here
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/arithmetic/finite_fields.nim(389, 7) template/generic instantiation of `fieldMod` from here
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/config/curves_prop_field_derived.nim(67, 5) Error: undeclared identifier: 'getCurveOrder'
# ...
func smallPowFr*(base: Fr, expo: uint): Fr =
var a : Fr = oneFr
var s : Fr = base
var e : uint = expo
while (e > 0):
if bitand(e,1) > 0: a *= s
e = (e shr 1)
square(s)
return a
func smallPowFr*(base: Fr, expo: int): Fr =
if expo >= 0:
return smallPowFr( base, uint(expo) )
else:
return smallPowFr( invFr(base) , uint(-expo) )
#-------------------------------------------------------------------------------
func deltaFr*(i, j: int) : Fr =
return (if (i == j): oneFr else: zeroFr)
#-------------------------------------------------------------------------------
func toDecimalBig*[n](a : BigInt[n]): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
func toDecimalFp*(a : Fp): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
func toDecimalFr*(a : Fr): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
#---------------------------------------
const k65536 : BigInt[254] = fromHex( BigInt[254], "0x10000", bigEndian )
func signedToDecimalFp*(a : Fp): string =
if bool( a.toBig() > primeP_254 - k65536 ):
return "-" & toDecimalFp(negFp(a))
else:
return toDecimalFp(a)
func signedToDecimalFr*(a : Fr): string =
if bool( a.toBig() > primeR_254 - k65536 ):
return "-" & toDecimalFr(negFr(a))
else:
return toDecimalFr(a)
#-------------------------------------------------------------------------------
proc debugPrintFp*(prefix: string, x: Fp) =
echo(prefix & toDecimalFp(x))
proc debugPrintFp2*(prefix: string, z: Fp2) =
echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0]))
echo(prefix & " u ~> " & toDecimalFp(z.coords[1]))
proc debugPrintFr*(prefix: string, x: Fr) =
echo(prefix & toDecimalFr(x))
proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) =
echo "---------------------"
echo msg
for x in xs:
debugPrintFr( " " , x )
proc debugPrintG1*(msg: string, pt: G1) =
echo(msg & ":")
debugPrintFp( " x = ", pt.x )
debugPrintFp( " y = ", pt.y )
proc debugPrintG2*(msg: string, pt: G2) =
echo(msg & ":")
debugPrintFp2( " x = ", pt.x )
debugPrintFp2( " y = ", pt.y )
#-------------------------------------------------------------------------------
# Montgomery batch inversion
func batchInverse*( xs: seq[Fr] ) : seq[Fr] =
let n = xs.len
assert(n>0)
var us : seq[Fr] = newSeq[Fr](n+1)
var a = xs[0]
us[0] = oneFr
us[1] = a
for i in 1..<n: ( a *= xs[i] ; us[i+1] = a )
var vs : seq[Fr] = newSeq[Fr](n)
vs[n-1] = invFr( us[n] )
for i in countdown(n-2,0): vs[i] = vs[i+1] * xs[i+1]
return collect( newSeq, (for i in 0..<n: us[i]*vs[i] ) )
proc sanityCheckBatchInverse*() =
let xs : seq[Fr] = map( toSeq(101..137) , intToFr )
let ys = batchInverse( xs )
let zs = collect( newSeq, (for x in xs: invFr(x)) )
let n = xs.len
# for i in 0..<n: echo(i," | batch = ",toDecimalFr(ys[i])," | ref = ",toDecimalFr(zs[i]) )
for i in 0..<n:
if not bool(ys[i] == zs[i]):
echo "batch inverse test FAILED!"
return
echo "batch iverse test OK."
#-------------------------------------------------------------------------------
# random values
var randomInitialized : bool = false
var randomState : Rand = initRand( 12345 )
proc rndUint64() : uint64 =
return randomState.next()
proc initializeRandomIfNecessary() =
if not randomInitialized:
randomState = initRand()
randomInitialized = true
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
const m64 : B = fromHex( B, "0x0000000000000000000000000000000000000000000000010000000000000000", bigEndian )
const m128 : B = fromHex( B, "0x0000000000000000000000000000000100000000000000000000000000000000", bigEndian )
const m192 : B = fromHex( B, "0x0000000000000001000000000000000000000000000000000000000000000000", bigEndian )
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
proc randBig*[bits: static int](): BigInt[bits] =
initializeRandomIfNecessary()
let a0 : uint64 = rndUint64()
let a1 : uint64 = rndUint64()
let a2 : uint64 = rndUint64()
let a3 : uint64 = rndUint64()
# echo((a0,a1,a2,a3))
var b0 : BigInt[bits] ; b0.fromUint(a0)
var b1 : BigInt[bits] ; b1.fromUint(a1)
var b2 : BigInt[bits] ; b2.fromUint(a2)
var b3 : BigInt[bits] ; b3.fromUint(a3)
# constantine doesn't appear to have left shift....
var c1,c2,c3 : BigInt[bits]
prod( c1 , b1 , m64 )
prod( c2 , b2 , m128 )
prod( c3 , b3 , m192 )
var d : BigInt[bits]
d = b0
d += c1
d += c2
d += c3
return d
proc randFr*(): Fr =
let b : BigInt[254] = randBig[254]()
var y : Fr
y.fromBig( b )
return y
proc testRandom*() =
for i in 1..20:
let x = randFr()
echo(x.toHex())
echo("-------------------")
echo(primeR.toHex())
#-------------------------------------------------------------------------------
func checkCurveEqG1*( x, y: Fp ) : bool =
if bool(isZero(x)) and bool(isZero(y)):
# the point at infinity is on the curve by definition
return true
else:
var x2 : Fp = x ; square(x2);
var y2 : Fp = y ; square(y2);
var x3 : Fp = x2 ; x3 *= x;
var eq : Fp
eq = x3
eq += intToFp(3)
eq -= y2
# echo("eq = ",toDecimalFp(eq))
return (bool(isZero(eq)))
#---------------------------------------
# y^2 = x^3 + B
# B = b1 + bu*u
# b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373
# b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690
const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5")
const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2")
const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u )
func checkCurveEqG2*( x, y: Fp2 ) : bool =
if bool(isZero(x)) and bool(isZero(y)):
# the point at infinity is on the curve by definition
return true
else:
var x2 : Fp2 = x ; square(x2);
var y2 : Fp2 = y ; square(y2);
var x3 : Fp2 = x2 ; x3 *= x;
var eq : Fp2
eq = x3
eq += twistCoeffB
eq -= y2
return (bool(isZero(eq)))
#-------------------------------------------------------------------------------
func mkG1( x, y: Fp ) : G1 =
if bool(isZero(x)) and bool(isZero(y)):
return infG1
else:
assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" )
return unsafeMkG1(x,y)
func mkG2( x, y: Fp2 ) : G2 =
if bool(isZero(x)) and bool(isZero(y)):
return infG2
else:
assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" )
return unsafeMkG2(x,y)
#-------------------------------------------------------------------------------
# group generators
const gen1_x : Fp = fromHex(Fp, "0x01")
const gen1_y : Fp = fromHex(Fp, "0x02")
const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701")
const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b")
const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8")
const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c")
const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu )
const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu )
const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y )
const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y )
#-------------------------------------------------------------------------------
func isOnCurveG1* ( p: G1 ) : bool =
return checkCurveEqG1( p.x, p.y )
func isOnCurveG2* ( p: G2 ) : bool =
return checkCurveEqG2( p.x, p.y )
#-------------------------------------------------------------------------------
# Dealing with Montgomery representation
#
# R=2^256; this computes 2^256 mod Fp
func calcFpMontR*() : Fp =
var x : Fp = intToFp(2)
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFpInvMontR*() : Fp =
var x : Fp = calcFpMontR()
inv(x)
return x
# R=2^256; this computes 2^256 mod Fr
func calcFrMontR*() : Fr =
var x : Fr = intToFr(2)
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFrInvMontR*() : Fr =
var x : Fr = calcFrMontR()
inv(x)
return x
# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?)
const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" )
const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" )
# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?)
const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" )
const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" )
proc checkMontgomeryConstants*() =
assert( bool( fpMontR == calcFpMontR() ) )
assert( bool( frMontR == calcFrMontR() ) )
assert( bool( fpInvMontR == calcFpInvMontR() ) )
assert( bool( frInvMontR == calcFrInvMontR() ) )
echo("OK")
#---------------------------------------
# the binary files used by the `circom` ecosystem (EXCEPT the witness file!)
# always use little-endian Montgomery representation. So when we unmarshal
# with Constantine, it will give the wrong result. Calling this function on
# the result fixes that.
func fromMontgomeryFp*(x : Fp) : Fp =
var y : Fp = x;
y *= fpInvMontR
return y
func fromMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
y *= frInvMontR
return y
func toMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
y *= frMontR
return y
#-------------------------------------------------------------------------------
# Unmarshalling field elements
# (note: circom binary files use little-endian Montgomery representation)
# Except, in witness files, where the standard representation is used
# And, EXCEPT in the zkey coefficients, where apparently DOUBLE Montgomery encoding is used ???
#
# WTF Jordi, go home you are drunk
func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return fromMontgomeryFr(fromMontgomeryFr(x))
func unmarshalFrStd* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return x
func unmarshalFpMont* ( bs: array[32,byte] ) : Fp =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fp
x.fromBig( big )
return fromMontgomeryFp(x)
func unmarshalFrMont* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return fromMontgomeryFr(x)
#-------------------------------------------------------------------------------
func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
var vals : seq[Fp] = newSeq[Fp]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFpMont( bytes )
return vals
func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
var vals : seq[Fr] = newSeq[Fr]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFrMont( bytes )
return vals
#-------------------------------------------------------------------------------
proc loadValueFrWTF*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
# for i in 0..<32: stdout.write(" " & toHex(bytes[i]))
# echo("")
assert( n == 32 )
return unmarshalFrWTF(bytes)
proc loadValueFrStd*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFrStd(bytes)
proc loadValueFrMont*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFrMont(bytes)
proc loadValueFpMont*( stream: Stream ) : Fp =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFpMont(bytes)
proc loadValueFp2Mont*( stream: Stream ) : Fp2 =
let i = loadValueFpMont( stream )
let u = loadValueFpMont( stream )
return mkFp2(i,u)
#---------------------------------------
proc loadValuesFrStd*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFrStd(stream) )
return values
proc loadValuesFpMont*( len: int, stream: Stream ) : seq[Fp] =
var values : seq[Fp]
for i in 1..len:
values.add( loadValueFpMont(stream) )
return values
proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFrMont(stream) )
return values
#-------------------------------------------------------------------------------
proc loadPointG1*( stream: Stream ) : G1 =
let x = loadValueFpMont( stream )
let y = loadValueFpMont( stream )
return mkG1(x,y)
proc loadPointG2*( stream: Stream ) : G2 =
let x = loadValueFp2Mont( stream )
let y = loadValueFp2Mont( stream )
return mkG2(x,y)
#---------------------------------------
proc loadPointsG1*( len: int, stream: Stream ) : seq[G1] =
var points : seq[G1]
for i in 1..len:
points.add( loadPointG1(stream) )
return points
proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] =
var points : seq[G2]
for i in 1..len:
points.add( loadPointG2(stream) )
return points
#===============================================================================
func addG1*(p,q: G1): G1 =
var r, x, y : ProjG1
prj.fromAffine(x, p)
prj.fromAffine(y, q)
prj.sum(r, x, y)
var s : G1
prj.affine(s, r)
return s
func addG2*(p,q: G2): G2 =
var r, x, y : ProjG2
prj.fromAffine(x, p)
prj.fromAffine(y, q)
prj.sum(r, x, y)
var s : G2
prj.affine(s, r)
return s
func negG1*(p: G1): G1 =
var r : G1 = p
neg(r)
return r
func negG2*(p: G2): G2 =
var r : G2 = p
neg(r)
return r
func `+`*(p,q: G1): G1 = addG1(p,q)
func `+`*(p,q: G2): G2 = addG2(p,q)
func `+=`*(p: var G1, q: G1) = p = addG1(p,q)
func `+=`*(p: var G2, q: G2) = p = addG2(p,q)
func `-=`*(p: var G1, q: G1) = p = addG1(p,negG1(q))
func `-=`*(p: var G2, q: G2) = p = addG2(p,negG2(q))
#-------------------------------------------------------------------------------
func msmG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
# var arr1 = toOpenArray(coeffs, 0, N-1)
# var arr2 = toOpenArray(points, 0, N-1)
var bigcfs : seq[BigInt[254]]
for x in coeffs:
bigcfs.add( x.toBig() )
var r : ProjG1
# [Fp,aff.G1]
msm.multiScalarMul_vartime( r,
toOpenArray(bigcfs, 0, N-1),
toOpenArray(points, 0, N-1) )
var rAff: G1
prj.affine(rAff, r)
return rAff
func msmG2*( coeffs: openArray[Fr] , points: openArray[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var bigcfs : seq[BigInt[254]]
for x in coeffs:
bigcfs.add( x.toBig() )
var r : ProjG2
# [Fp,aff.G1]
msm.multiScalarMul_vartime( r,
toOpenArray(bigcfs, 0, N-1),
toOpenArray(points, 0, N-1) )
var rAff: G2
prj.affine(rAff, r)
return rAff
#-------------------------------------------------------------------------------
#
# (affine) scalar multiplication
#
func `**`*( coeff: Fr , point: G1 ) : G1 =
var q : ProjG1
prj.fromAffine( q , point )
scl.scalarMul( q , coeff.toBig() )
var r : G1
prj.affine( r, q )
return r
func `**`*( coeff: Fr , point: G2 ) : G2 =
var q : ProjG2
prj.fromAffine( q , point )
scl.scalarMul( q , coeff.toBig() )
var r : G2
prj.affine( r, q )
return r
#-------------------
func `**`*( coeff: BigInt , point: G1 ) : G1 =
var q : ProjG1
prj.fromAffine( q , point )
scl.scalarMul( q , coeff )
var r : G1
prj.affine( r, q )
return r
func `**`*( coeff: BigInt , point: G2 ) : G2 =
var q : ProjG2
prj.fromAffine( q , point )
scl.scalarMul( q , coeff )
var r : G2
prj.affine( r, q )
return r
#-------------------------------------------------------------------------------
func msmNaiveG1( coeffs: seq[Fr] , points: seq[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG1
s.setInf()
for i in 0..<N:
var t : ProjG1
prj.fromAffine( t, points[i] )
scl.scalarMul( t , coeffs[i].toBig() )
s += t
var r : G1
prj.affine( r, s )
return r
#---------------------------------------
func msmNaiveG2( coeffs: seq[Fr] , points: seq[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG2
s.setInf()
for i in 0..<N:
var t : ProjG2
prj.fromAffine( t, points[i] )
scl.scalarMul( t , coeffs[i].toBig() )
s += t
var r : G2
prj.affine( r, s)
return r
#-------------------------------------------------------------------------------
proc sanityCheckGroupGen*() =
echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) )
echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) )
echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) )
echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) )
#-------------------------------------------------------------------------------

View File

@ -1,11 +0,0 @@
import ../test_proof
import ../export_json
let zkey_fname : string = "./build/product.zkey"
let wtns_fname : string = "./build/product.wtns"
let proof = testProveAndVerify( zkey_fname, wtns_fname)
exportPublicIO( "./build/nim_public.json" , proof )
exportProof( "./build/nim_proof.json" , proof )

View File

@ -1,242 +1,6 @@
#
# Groth16 prover
#
# WARNING!
# the points H in `.zkey` are *NOT* what normal people would think they are
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
#[
import sugar
import constantine/math/config/curves
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import ./zkey
]#
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_extfields except Fp12
import constantine/math/extension_fields/towers except Fp2, Fp12
import ./bn128
import ./domain
import ./poly
import ./zkey_types
import ./witness
#-------------------------------------------------------------------------------
type
Proof* = object
publicIO* : seq[Fr]
pi_a* : G1
pi_b* : G2
pi_c* : G1
curve : string
#-------------------------------------------------------------------------------
# the verifier
#
proc verifyProof* (vkey: VKey, prf: Proof): bool =
assert( prf.curve == "bn128" )
assert( isOnCurveG1(prf.pi_a) , "pi_a is not in G1" )
assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" )
assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" )
var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC )
let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b >
let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta >
let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta >
let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma >
var eq : Fp12
eq = lhs
eq *= rhs1
eq *= rhs2
eq *= rhs3
return bool(isOne(eq))
#-------------------------------------------------------------------------------
# A, B, C column vectors
#
type
ABC = object
valuesA : seq[Fr]
valuesB : seq[Fr]
valuesC : seq[Fr]
func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC =
let hdr: GrothHeader = zkey.header
let domSize = hdr.domainSize
var valuesA : seq[Fr] = newSeq[Fr](domSize)
var valuesB : seq[Fr] = newSeq[Fr](domSize)
for entry in zkey.coeffs:
case entry.matrix
of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col]
of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col]
else: raise newException(AssertionDefect, "fatal error")
var valuesC : seq[Fr] = newSeq[Fr](domSize)
for i in 0..<domSize:
valuesC[i] = valuesA[i] * valuesB[i]
return ABC( valuesA:valuesA, valuesB:valuesB, valuesC:valuesC )
#-------------------------------------------------------------------------------
# quotient poly
#
# interpolates A,B,C, and computes the quotient polynomial Q = (A*B - C) / Z
func computeQuotientNaive( abc: ABC ): Poly=
let n = abc.valuesA.len
assert( abc.valuesB.len == n )
assert( abc.valuesC.len == n )
let D = createDomain(n)
let polyA : Poly = polyInverseNTT( abc.valuesA , D )
let polyB : Poly = polyInverseNTT( abc.valuesB , D )
let polyC : Poly = polyInverseNTT( abc.valuesC , D )
let polyBig = polyMulFFT( polyA , polyB ) - polyC
var polyQ = polyDivideByVanishing(polyBig, D.domainSize)
polyQ.coeffs.add( zeroFr ) # make it a power of two
return polyQ
#---------------------------------------
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
func multiplyByPowers( xs: seq[Fr], eta: Fr ): seq[Fr] =
let n = xs.len
assert(n >= 1)
var ys : seq[Fr] = newSeq[Fr](n)
ys[0] = xs[0]
if n >= 1: ys[1] = eta * xs[1]
var spow : Fr = eta
for i in 2..<n:
spow *= eta
ys[i] = spow * xs[i]
return ys
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
func shiftEvalDomain( values: seq[Fr], D: Domain, eta: Fr ): seq[Fr] =
let poly : Poly = polyInverseNTT( values , D )
let cs : seq[Fr] = poly.coeffs
var ds : seq[Fr] = multiplyByPowers( cs, eta )
return polyForwardNTT( Poly(coeffs:ds), D )
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
func computeQuotientPointwise( abc: ABC ): Poly =
let n = abc.valuesA.len
let D = createDomain(n)
# (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
return Poly(coeffs: cs)
#---------------------------------------
# Snarkjs does something different, not actually computing the quotient poly
# they can get away with this, because during the trusted setup, they
# transform the H points into (shifted??) Lagrange bases (?)
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] =
let n = abc.valuesA.len
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] )
return ys
#-------------------------------------------------------------------------------
# the prover
#
proc generateProof*( zkey: ZKey, wtns: Witness ): Proof =
assert( zkey.header.curve == wtns.curve )
let witness = wtns.values
let hdr : GrothHeader = zkey.header
let spec : SpecPoints = zkey.specPoints
let pts : ProverPoints = zkey.pPoints
let nvars = hdr.nvars
let npubs = hdr.npubs
assert( nvars == witness.len , "wrong witness length" )
var pubIO : seq[Fr] = newSeq[Fr]( npubs + 1)
for i in 0..npubs: pubIO[i] = witness[i]
var abc : ABC = buildABC( zkey, witness )
var qs : seq[Fr]
case zkey.header.flavour
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc )
qs = polyQ.coeffs
# the points H are [delta^-1 * L_i(tau*eta) / Z(omega^i*eta)]
# where eta^2 = omega and L_i are Lagrange basis polynomials
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc )
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
for j in npubs+1..<nvars:
zs[j-npubs-1] = witness[j]
# masking coeffs
let r : Fr = randFr()
let s : Fr = randFr()
var pi_a : G1
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmG1( witness , pts.pointsA1 )
var rho : G1
rho = spec.beta1
rho += s ** spec.delta1
rho += msmG1( witness , pts.pointsB1 )
var pi_b : G2
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmG2( witness , pts.pointsB2 )
var pi_c : G1
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmG1( qs , pts.pointsH1 )
pi_c += msmG1( zs , pts.pointsC1 )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------
import groth16/bn128/types
import groth16/files/zkey
import groth16/files/witness
import groth16/prover
import groth16/verifier

9
groth16.nimble Normal file
View File

@ -0,0 +1,9 @@
version = "0.0.1"
author = "Balazs Komuves"
description = "Groth16 proof system"
license = "MIT OR Apache-2.0"
skipDirs = @["groth16/example"]
binDir = "build"
requires "https://github.com/mratsim/constantine"

31
groth16/bn128.nim Normal file
View File

@ -0,0 +1,31 @@
#
# the `alt-bn128` elliptic curve
#
# See for example <https://hackmd.io/@jpw/bn254>
#
# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
#
# equation: y^2 = x^3 + 3
#
#-------------------------------------------------------------------------------
import groth16/bn128/fields
import groth16/bn128/curves
import groth16/bn128/msm
import groth16/bn128/io
import groth16/bn128/rnd
import groth16/bn128/debug
#-------------------
export fields
export curves
export msm
export io
export rnd
export debug
#-------------------------------------------------------------------------------

231
groth16/bn128/curves.nim Normal file
View File

@ -0,0 +1,231 @@
#
# the `alt-bn128` elliptic curve
#
# See for example <https://hackmd.io/@jpw/bn254>
#
# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
#
# equation: y^2 = x^3 + 3
#
#import constantine/platforms/abstractions
#import constantine/math/isogenies/frobenius
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_fields except Fp, Fr
import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/type_ff as tff except Fp, Fr
import constantine/math/extension_fields/towers as ext except Fp, Fp2, Fp12, Fr
import constantine/math/elliptic/ec_shortweierstrass_affine as aff
import constantine/math/elliptic/ec_shortweierstrass_projective as prj
import constantine/math/pairings/pairings_bn as ate
import constantine/math/elliptic/ec_scalar_mul as scl
import groth16/bn128/fields
#-------------------------------------------------------------------------------
type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1]
type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2]
type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1]
type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2]
#-------------------------------------------------------------------------------
func unsafeMkG1* ( X, Y: Fp ) : G1 =
return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y)
func unsafeMkG2* ( X, Y: Fp2 ) : G2 =
return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y)
#-------------------------------------------------------------------------------
const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp )
const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 )
#-------------------------------------------------------------------------------
func checkCurveEqG1*( x, y: Fp ) : bool =
if bool(isZero(x)) and bool(isZero(y)):
# the point at infinity is on the curve by definition
return true
else:
var x2 : Fp = squareFp(x)
var y2 : Fp = squareFp(y)
var x3 : Fp = x2 * x
var eq : Fp
eq = x3
eq += intToFp(3)
eq -= y2
# echo("eq = ",toDecimalFp(eq))
return (bool(isZero(eq)))
#---------------------------------------
# y^2 = x^3 + B
# B = b1 + bu*u
# b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373
# b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690
const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5")
const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2")
const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u )
func checkCurveEqG2*( x, y: Fp2 ) : bool =
if isZeroFp2(x) and isZeroFp2(y):
# the point at infinity is on the curve by definition
return true
else:
var x2 : Fp2 = squareFp2(x)
var y2 : Fp2 = squareFp2(y)
var x3 : Fp2 = x2 * x;
var eq : Fp2
eq = x3
eq += twistCoeffB
eq -= y2
return isZeroFp2(eq)
#-------------------------------------------------------------------------------
func mkG1*( x, y: Fp ) : G1 =
if isZeroFp(x) and isZeroFp(y):
return infG1
else:
assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" )
return unsafeMkG1(x,y)
func mkG2*( x, y: Fp2 ) : G2 =
if isZeroFp2(x) and isZeroFp2(y):
return infG2
else:
assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" )
return unsafeMkG2(x,y)
#-------------------------------------------------------------------------------
# group generators
const gen1_x : Fp = fromHex(Fp, "0x01")
const gen1_y : Fp = fromHex(Fp, "0x02")
const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701")
const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b")
const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8")
const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c")
const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu )
const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu )
const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y )
const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y )
#-------------------------------------------------------------------------------
func isOnCurveG1* ( p: G1 ) : bool =
return checkCurveEqG1( p.x, p.y )
func isOnCurveG2* ( p: G2 ) : bool =
return checkCurveEqG2( p.x, p.y )
#===============================================================================
func addG1*(p,q: G1): G1 =
var r, x, y : ProjG1
prj.fromAffine(x, p)
prj.fromAffine(y, q)
prj.sum(r, x, y)
var s : G1
prj.affine(s, r)
return s
#---------------------------------------
func addG2*(p,q: G2): G2 =
var r, x, y : ProjG2
prj.fromAffine(x, p)
prj.fromAffine(y, q)
prj.sum(r, x, y)
var s : G2
prj.affine(s, r)
return s
func negG1*(p: G1): G1 =
var r : G1 = p
neg(r)
return r
func negG2*(p: G2): G2 =
var r : G2 = p
neg(r)
return r
#---------------------------------------
func `+`*(p,q: G1): G1 = addG1(p,q)
func `+`*(p,q: G2): G2 = addG2(p,q)
func `+=`*(p: var G1, q: G1) = p = addG1(p,q)
func `+=`*(p: var G2, q: G2) = p = addG2(p,q)
func `-=`*(p: var G1, q: G1) = p = addG1(p,negG1(q))
func `-=`*(p: var G2, q: G2) = p = addG2(p,negG2(q))
#-------------------------------------------------------------------------------
#
# (affine) scalar multiplication
#
func `**`*( coeff: Fr , point: G1 ) : G1 =
var q : ProjG1
prj.fromAffine( q , point )
scl.scalarMul( q , coeff.toBig() )
var r : G1
prj.affine( r, q )
return r
func `**`*( coeff: Fr , point: G2 ) : G2 =
var q : ProjG2
prj.fromAffine( q , point )
scl.scalarMul( q , coeff.toBig() )
var r : G2
prj.affine( r, q )
return r
#-------------------
func `**`*( coeff: BigInt , point: G1 ) : G1 =
var q : ProjG1
prj.fromAffine( q , point )
scl.scalarMul( q , coeff )
var r : G1
prj.affine( r, q )
return r
func `**`*( coeff: BigInt , point: G2 ) : G2 =
var q : ProjG2
prj.fromAffine( q , point )
scl.scalarMul( q , coeff )
var r : G2
prj.affine( r, q )
return r
#-------------------------------------------------------------------------------
func pairing* (p: G1, q: G2) : Fp12 =
var t : Fp12
ate.pairing_bn[BN254Snarks]( t, p, q )
return t
#-------------------------------------------------------------------------------
proc sanityCheckGroupGen*() =
echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) )
echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) )
echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) )
echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) )
#-------------------------------------------------------------------------------

45
groth16/bn128/debug.nim Normal file
View File

@ -0,0 +1,45 @@
#
# the `alt-bn128` elliptic curve
#
# See for example <https://hackmd.io/@jpw/bn254>
#
# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
#
# equation: y^2 = x^3 + 3
#
import groth16/bn128/fields
import groth16/bn128/curves
import groth16/bn128/io
#-------------------------------------------------------------------------------
proc debugPrintFp*(prefix: string, x: Fp) =
echo(prefix & toDecimalFp(x))
proc debugPrintFp2*(prefix: string, z: Fp2) =
echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0]))
echo(prefix & " u ~> " & toDecimalFp(z.coords[1]))
proc debugPrintFr*(prefix: string, x: Fr) =
echo(prefix & toDecimalFr(x))
proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) =
echo "---------------------"
echo msg
for x in xs:
debugPrintFr( " " , x )
proc debugPrintG1*(msg: string, pt: G1) =
echo(msg & ":")
debugPrintFp( " x = ", pt.x )
debugPrintFp( " y = ", pt.y )
proc debugPrintG2*(msg: string, pt: G2) =
echo(msg & ":")
debugPrintFp2( " x = ", pt.x )
debugPrintFp2( " y = ", pt.y )
#-------------------------------------------------------------------------------

189
groth16/bn128/fields.nim Normal file
View File

@ -0,0 +1,189 @@
#
# the prime fields Fp and Fr with sizes
#
# p = 21888242871839275222246405745257275088696311157297823662689037894645226208583
# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
#
import sugar
import std/bitops
import std/sequtils
import constantine/math/arithmetic
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/type_ff as tff
import constantine/math/extension_fields/towers as ext
#-------------------------------------------------------------------------------
type B* = BigInt[256]
type Fr* = tff.Fr[BN254Snarks]
type Fp* = tff.Fp[BN254Snarks]
type Fp2* = ext.QuadraticExt[Fp]
type Fp12* = ext.Fp12[BN254Snarks]
func mkFp2* (i: Fp, u: Fp) : Fp2 =
let c : array[2, Fp] = [i,u]
return ext.QuadraticExt[Fp]( coords: c )
#-------------------------------------------------------------------------------
const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
#-------------------------------------------------------------------------------
const zeroFp* : Fp = fromHex( Fp, "0x00" )
const zeroFr* : Fr = fromHex( Fr, "0x00" )
const oneFp* : Fp = fromHex( Fp, "0x01" )
const oneFr* : Fr = fromHex( Fr, "0x01" )
const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp )
const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp )
const minusOneFp* : Fp = fromHex( Fp, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" )
const minusOneFr* : Fr = fromHex( Fr, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" )
#-------------------------------------------------------------------------------
func intToB*(a: uint): B =
var y : B
y.setUint(a)
return y
func intToFp*(a: int): Fp =
var y : Fp
y.fromInt(a)
return y
func intToFr*(a: int): Fr =
var y : Fr
y.fromInt(a)
return y
#-------------------------------------------------------------------------------
func isZeroFp* (x: Fp ): bool = bool(isZero(x))
func isZeroFp2*(x: Fp2): bool = bool(isZero(x))
func isZeroFr* (x: Fr ): bool = bool(isZero(x))
func isEqualFp* (x, y: Fp ): bool = bool(x == y)
func isEqualFp2*(x, y: Fp2): bool = bool(x == y)
func isEqualFr* (x, y: Fr ): bool = bool(x == y)
func `===`*(x, y: Fp ): bool = isEqualFp(x,y)
func `===`*(x, y: Fp2): bool = isEqualFp2(x,y)
func `===`*(x, y: Fr ): bool = isEqualFr(x,y)
#-------------------
func isEqualFpSeq*(xs, ys: seq[Fp]): bool =
let n = xs.len
assert( n == ys.len )
var b = true
for i in 0..<n:
if not bool(xs[i] == ys[i]):
b = false
break
return b
func isEqualFrSeq*(xs, ys: seq[Fr]): bool =
let n = xs.len
assert( n == ys.len )
var b = true
for i in 0..<n:
if not bool(xs[i] == ys[i]):
b = false
break
return b
func `===`*(xs, ys: seq[Fp]): bool = isEqualFpSeq(xs,ys)
func `===`*(xs, ys: seq[Fr]): bool = isEqualFrSeq(xs,ys)
#-------------------------------------------------------------------------------
func `+`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z += y ; return z )
func `+`*(x, y: Fp ): Fp = ( var z : Fp = x ; z += y ; return z )
func `+`*(x, y: Fp2): Fp2 = ( var z : Fp2 = x ; z += y ; return z )
func `+`*(x, y: Fr ): Fr = ( var z : Fr = x ; z += y ; return z )
func `-`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z -= y ; return z )
func `-`*(x, y: Fp ): Fp = ( var z : Fp = x ; z -= y ; return z )
func `-`*(x, y: Fp2): Fp2 = ( var z : Fp2 = x ; z -= y ; return z )
func `-`*(x, y: Fr ): Fr = ( var z : Fr = x ; z -= y ; return z )
func `*`*(x, y: Fp ): Fp = ( var z : Fp = x ; z *= y ; return z )
func `*`*(x, y: Fp2): Fp2 = ( var z : Fp2 = x ; z *= y ; return z )
func `*`*(x, y: Fr ): Fr = ( var z : Fr = x ; z *= y ; return z )
func negFp* (y: Fp ): Fp = ( var z : Fp = zeroFp ; z -= y ; return z )
func negFp2*(y: Fp2): Fp2 = ( var z : Fp2 = zeroFp2 ; z -= y ; return z )
func negFr* (y: Fr ): Fr = ( var z : Fr = zeroFr ; z -= y ; return z )
func invFp*(y: Fp): Fp = ( var z : Fp = y ; inv(z) ; return z )
func invFr*(y: Fr): Fr = ( var z : Fr = y ; inv(z) ; return z )
func squareFp* (y: Fp): Fp = ( var z : Fp = y ; square(z) ; return z )
func squareFp2*(y: Fp2): Fp2 = ( var z : Fp2 = y ; square(z) ; return z )
func squareFr* (y: Fr): Fr = ( var z : Fr = y ; square(z) ; return z )
# template/generic instantiation of `pow_vartime` from here
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/arithmetic/finite_fields.nim(389, 7) template/generic instantiation of `fieldMod` from here
# /Users/bkomuves/.nimble/pkgs/constantine-0.0.1/constantine/math/config/curves_prop_field_derived.nim(67, 5) Error: undeclared identifier: 'getCurveOrder'
# ...
func smallPowFr*(base: Fr, expo: uint): Fr =
var a : Fr = oneFr
var s : Fr = base
var e : uint = expo
while (e > 0):
if bitand(e,1) > 0: a *= s
e = (e shr 1)
square(s)
return a
func smallPowFr*(base: Fr, expo: int): Fr =
if expo >= 0:
return smallPowFr( base, uint(expo) )
else:
return smallPowFr( invFr(base) , uint(-expo) )
#-------------------------------------------------------------------------------
func deltaFr*(i, j: int) : Fr =
return (if (i == j): oneFr else: zeroFr)
#-------------------------------------------------------------------------------
# Montgomery batch inversion
func batchInverseFr*( xs: seq[Fr] ) : seq[Fr] =
let n = xs.len
assert(n>0)
var us : seq[Fr] = newSeq[Fr](n+1)
var a = xs[0]
us[0] = oneFr
us[1] = a
for i in 1..<n: ( a *= xs[i] ; us[i+1] = a )
var vs : seq[Fr] = newSeq[Fr](n)
vs[n-1] = invFr( us[n] )
for i in countdown(n-2,0): vs[i] = vs[i+1] * xs[i+1]
return collect( newSeq, (for i in 0..<n: us[i]*vs[i] ) )
proc sanityCheckBatchInverseFr*() =
let xs : seq[Fr] = map( toSeq(101..137) , intToFr )
let ys = batchInverseFr( xs )
let zs = collect( newSeq, (for x in xs: invFr(x)) )
let n = xs.len
# for i in 0..<n: echo(i," | batch = ",toDecimalFr(ys[i])," | ref = ",toDecimalFr(zs[i]) )
for i in 0..<n:
if not bool(ys[i] == zs[i]):
echo "batch inverse test FAILED!"
return
echo "batch iverse test OK."
#-------------------------------------------------------------------------------

253
groth16/bn128/io.nim Normal file
View File

@ -0,0 +1,253 @@
import std/strutils
import std/streams
import constantine/math/arithmetic except Fp, Fp2, Fr
import constantine/math/io/io_fields except Fp, Fp2, Fp
import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/type_ff as tff except Fp, Fp2, Fr
import groth16/bn128/fields
import groth16/bn128/curves
#-------------------------------------------------------------------------------
const primeP_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
const primeR_254 : BigInt[254] = fromHex( BigInt[254], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
#-------------------------------------------------------------------------------
func toDecimalBig*[n](a : BigInt[n]): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
func toDecimalFp*(a : Fp): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
func toDecimalFr*(a : Fr): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
#---------------------------------------
const k65536 : BigInt[254] = fromHex( BigInt[254], "0x10000", bigEndian )
func signedToDecimalFp*(a : Fp): string =
if bool( a.toBig() > primeP_254 - k65536 ):
return "-" & toDecimalFp(negFp(a))
else:
return toDecimalFp(a)
func signedToDecimalFr*(a : Fr): string =
if bool( a.toBig() > primeR_254 - k65536 ):
return "-" & toDecimalFr(negFr(a))
else:
return toDecimalFr(a)
#-------------------------------------------------------------------------------
# Dealing with Montgomery representation
#
# R=2^256; this computes 2^256 mod Fp
func calcFpMontR*() : Fp =
var x : Fp = intToFp(2)
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFpInvMontR*() : Fp =
var x : Fp = calcFpMontR()
inv(x)
return x
# R=2^256; this computes 2^256 mod Fr
func calcFrMontR*() : Fr =
var x : Fr = intToFr(2)
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFrInvMontR*() : Fr =
var x : Fr = calcFrMontR()
inv(x)
return x
# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?)
const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" )
const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" )
# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?)
const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" )
const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" )
proc checkMontgomeryConstants*() =
assert( bool( fpMontR == calcFpMontR() ) )
assert( bool( frMontR == calcFrMontR() ) )
assert( bool( fpInvMontR == calcFpInvMontR() ) )
assert( bool( frInvMontR == calcFrInvMontR() ) )
echo("OK")
#---------------------------------------
# the binary file `.zkey` used by the `circom` ecosystem uses little-endian
# Montgomery representation. So when we unmarshal with Constantine, it will
# give the wrong result. Calling this function on the result fixes that.
func fromMontgomeryFp*(x : Fp) : Fp =
var y : Fp = x;
y *= fpInvMontR
return y
func fromMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
y *= frInvMontR
return y
func toMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
y *= frMontR
return y
#-------------------------------------------------------------------------------
# Unmarshalling field elements
# Note: in the `.zkey` coefficients, e apparently DOUBLE Montgomery encoding is used ?!?
#
func unmarshalFpMont* ( bs: array[32,byte] ) : Fp =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fp
x.fromBig( big )
return fromMontgomeryFp(x)
# WTF Jordi, go home you are drunk
func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return fromMontgomeryFr(fromMontgomeryFr(x))
func unmarshalFrStd* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return x
func unmarshalFrMont* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return fromMontgomeryFr(x)
#-------------------------------------------------------------------------------
func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
var vals : seq[Fp] = newSeq[Fp]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFpMont( bytes )
return vals
func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
var vals : seq[Fr] = newSeq[Fr]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFrMont( bytes )
return vals
#-------------------------------------------------------------------------------
proc loadValueFrWTF*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
# for i in 0..<32: stdout.write(" " & toHex(bytes[i]))
# echo("")
assert( n == 32 )
return unmarshalFrWTF(bytes)
proc loadValueFrStd*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFrStd(bytes)
proc loadValueFrMont*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFrMont(bytes)
proc loadValueFpMont*( stream: Stream ) : Fp =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFpMont(bytes)
proc loadValueFp2Mont*( stream: Stream ) : Fp2 =
let i = loadValueFpMont( stream )
let u = loadValueFpMont( stream )
return mkFp2(i,u)
#---------------------------------------
proc loadValuesFrStd*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFrStd(stream) )
return values
proc loadValuesFpMont*( len: int, stream: Stream ) : seq[Fp] =
var values : seq[Fp]
for i in 1..len:
values.add( loadValueFpMont(stream) )
return values
proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFrMont(stream) )
return values
#-------------------------------------------------------------------------------
proc loadPointG1*( stream: Stream ) : G1 =
let x = loadValueFpMont( stream )
let y = loadValueFpMont( stream )
return mkG1(x,y)
proc loadPointG2*( stream: Stream ) : G2 =
let x = loadValueFp2Mont( stream )
let y = loadValueFp2Mont( stream )
return mkG2(x,y)
#---------------------------------------
proc loadPointsG1*( len: int, stream: Stream ) : seq[G1] =
var points : seq[G1]
for i in 1..len:
points.add( loadPointG1(stream) )
return points
proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] =
var points : seq[G2]
for i in 1..len:
points.add( loadPointG2(stream) )
return points
#-------------------------------------------------------------------------------

158
groth16/bn128/msm.nim Normal file
View File

@ -0,0 +1,158 @@
#
# Multi-Scalar Multiplication (MSM)
#
import system
# import constantine/curves_primitives except Fp, Fp2, Fr
import constantine/platforms/abstractions except Subgroup
import constantine/math/isogenies/frobenius except Subgroup
import constantine/math/arithmetic except Fp, Fp2, Fr
import constantine/math/io/io_fields except Fp, Fp2, Fr
import constantine/math/io/io_bigints
import constantine/math/config/curves except G1, G2, Subgroup
import constantine/math/config/type_ff except Fp, Fr, Subgroup
import constantine/math/extension_fields/towers as ext except Fp, Fp2, Fp12, Fr
import constantine/math/elliptic/ec_shortweierstrass_affine as aff except Subgroup
import constantine/math/elliptic/ec_shortweierstrass_projective as prj except Subgroup
import constantine/math/elliptic/ec_scalar_mul as scl except Subgroup
import constantine/math/elliptic/ec_multi_scalar_mul as msm except Subgroup
import groth16/bn128/fields
import groth16/bn128/curves
#-------------------------------------------------------------------------------
func msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var bigcfs : seq[BigInt[254]]
for x in coeffs:
bigcfs.add( x.toBig() )
var r : ProjG1
# [Fp,aff.G1]
msm.multiScalarMul_vartime( r,
toOpenArray(bigcfs, 0, N-1),
toOpenArray(points, 0, N-1) )
var rAff: G1
prj.affine(rAff, r)
return rAff
#---------------------------------------
func msmConstantineG2*( coeffs: openArray[Fr] , points: openArray[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var bigcfs : seq[BigInt[254]]
for x in coeffs:
bigcfs.add( x.toBig() )
var r : ProjG2
# [Fp,aff.G1]
msm.multiScalarMul_vartime( r,
toOpenArray(bigcfs, 0, N-1),
toOpenArray(points, 0, N-1) )
var rAff: G2
prj.affine(rAff, r)
return rAff
#-------------------------------------------------------------------------------
#[
type InputTuple = tuple[idx:int, coeffs: openArray[Fr] , points: openArray[G1]]
func msmMultiThreadedG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
let nthreadsTarget = 8
# for N <= 255 , we use 1 thread
# for N == 256 , we use 2 threads
# for N == 512 , we use 4 threads
# for N >= 1024, we use 8 threads
let nthreads = max( 1 , min( N div 128 , nthreadsTarget ) )
let m = N div nthreads
var threads : seq[Thread[InputTuple]] = newSeq[Thread[InputTuple]]( nthreads )
var results : seq[G1] = newSeq[G1]( nthreads )
proc myThreadFunc( inp: InputTuple ) {.thread.} =
results[inp.idx] = msmConstantineG1( inp.coeffs, inp.points )
for i in 0..<nthreads:
let a = i*m
let b = if (i == nthreads-1): N else: (i+1)*m
createThread(threads[i], myThreadFunc, (i, coeffs[a..<b], points[a..<b]))
joinThreads(threads)
var r : G1 = infG1
for i in 0..<nthreads: r += results[i]
return r
]#
#-------------------------------------------------------------------------------
func msmNaiveG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG1
s.setInf()
for i in 0..<N:
var t : ProjG1
prj.fromAffine( t, points[i] )
scl.scalarMul( t , coeffs[i].toBig() )
s += t
var r : G1
prj.affine( r, s )
return r
#---------------------------------------
func msmNaiveG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG2
s.setInf()
for i in 0..<N:
var t : ProjG2
prj.fromAffine( t, points[i] )
scl.scalarMul( t , coeffs[i].toBig() )
s += t
var r : G2
prj.affine( r, s)
return r
#-------------------------------------------------------------------------------
func msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 = msmConstantineG1(coeffs, points)
func msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 = msmConstantineG2(coeffs, points)
#-------------------------------------------------------------------------------

76
groth16/bn128/rnd.nim Normal file
View File

@ -0,0 +1,76 @@
import std/random
# import constantine/platforms/abstractions
import constantine/math/arithmetic except Fp, Fp2, Fr
import constantine/math/io/io_fields except Fp, Fp2, Fr
import constantine/math/io/io_bigints
import groth16/bn128/fields
#-------------------------------------------------------------------------------
# random values
var randomInitialized : bool = false
var randomState : Rand = initRand( 12345 )
proc rndUint64() : uint64 =
return randomState.next()
proc initializeRandomIfNecessary() =
if not randomInitialized:
randomState = initRand()
randomInitialized = true
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
const m64 : B = fromHex( B, "0x0000000000000000000000000000000000000000000000010000000000000000", bigEndian )
const m128 : B = fromHex( B, "0x0000000000000000000000000000000100000000000000000000000000000000", bigEndian )
const m192 : B = fromHex( B, "0x0000000000000001000000000000000000000000000000000000000000000000", bigEndian )
#----------------------------| 01234567890abcdf01234567890abcdf01234567890abcdf01234567890abcdf
proc randBig*[bits: static int](): BigInt[bits] =
initializeRandomIfNecessary()
let a0 : uint64 = rndUint64()
let a1 : uint64 = rndUint64()
let a2 : uint64 = rndUint64()
let a3 : uint64 = rndUint64()
# echo((a0,a1,a2,a3))
var b0 : BigInt[bits] ; b0.fromUint(a0)
var b1 : BigInt[bits] ; b1.fromUint(a1)
var b2 : BigInt[bits] ; b2.fromUint(a2)
var b3 : BigInt[bits] ; b3.fromUint(a3)
# constantine doesn't appear to have left shift....
var c1,c2,c3 : BigInt[bits]
prod( c1 , b1 , m64 )
prod( c2 , b2 , m128 )
prod( c3 , b3 , m192 )
var d : BigInt[bits]
d = b0
d += c1
d += c2
d += c3
return d
proc randFr*(): Fr =
let b : BigInt[254] = randBig[254]()
var y : Fr
y.fromBig( b )
return y
proc testRandom*() =
for i in 1..20:
let x = randFr()
echo(x.toHex())
echo("-------------------")
echo(primeR.toHex())
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,18 @@
import groth16/test_proof
import groth16/files/export_json
#-------------------------------------------------------------------------------
proc exampleProveAndVerify() =
let zkey_fname : string = "./build/product.zkey"
let wtns_fname : string = "./build/product.wtns"
let proof = testProveAndVerify( zkey_fname, wtns_fname)
exportPublicIO( "./build/nim_public.json" , proof )
exportProof( "./build/nim_proof.json" , proof )
#-------------------------------------------------------------------------------
when isMainModule:
exampleProveAndVerify()

View File

@ -7,16 +7,15 @@
#
import sugar
import std/sequtils
import constantine/math/arithmetic except Fp, Fr
import bn128
import domain
import poly
import zkey_types
import r1cs
import misc
import groth16/bn128
import groth16/math/domain
import groth16/math/poly
import groth16/zkey_types
import groth16/files/r1cs
import groth16/misc
#-------------------------------------------------------------------------------
@ -33,12 +32,13 @@ proc randomToxicWaste(): ToxicWaste =
let b = randFr()
let c = randFr()
let d = randFr()
let t = randFr()
return ToxicWaste( alpha: a
, beta: b
, gamma: c
, delta: d
, tau: t )
let t = randFr() # intToFr(106)
return
ToxicWaste( alpha: a
, beta: b
, gamma: c
, delta: d
, tau: t )
#-------------------------------------------------------------------------------
@ -129,6 +129,16 @@ func matricesToCoeffs*(matrices: Matrices): seq[Coeff] =
#-------------------------------------------------------------------------------
func dotProdFr(xs, ys: seq[Fr]): Fr =
let n = xs.len
assert( n == ys.len, "dotProdFr: incompatible vector lengths" )
var s : Fr = zeroFr
for i in 0..<n:
s += xs[i] * ys[i]
return s
#-------------------------------------------------------------------------------
func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
let neqs = r1cs.constraints.len
@ -144,24 +154,26 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
# echo("neqs = ",neqs)
# echo("domain = ",domSize)
let header = GrothHeader( curve: "bn128"
, flavour: flavour
, p: primeP
, r: primeR
, nvars: nvars
, npubs: npubs
, domainSize: domSize
, logDomainSize: logDomSize
)
let header =
GrothHeader( curve: "bn128"
, flavour: flavour
, p: primeP
, r: primeR
, nvars: nvars
, npubs: npubs
, domainSize: domSize
, logDomainSize: logDomSize
)
let spec = SpecPoints( alpha1 : toxic.alpha ** gen1
, beta1 : toxic.beta ** gen1
, beta2 : toxic.beta ** gen2
, gamma2 : toxic.gamma ** gen2
, delta1 : toxic.delta ** gen1
, delta2 : toxic.delta ** gen2
, alphaBeta : pairing( toxic.alpha ** gen1 , toxic.beta ** gen2 )
)
let spec =
SpecPoints( alpha1 : toxic.alpha ** gen1
, beta1 : toxic.beta ** gen1
, beta2 : toxic.beta ** gen2
, gamma2 : toxic.gamma ** gen2
, delta1 : toxic.delta ** gen1
, delta2 : toxic.delta ** gen2
, alphaBeta : pairing( toxic.alpha ** gen1 , toxic.beta ** gen2 )
)
let matrices = r1csToMatrices(r1cs)
let coeffs = r1csToCoeffs( r1cs )
@ -169,6 +181,9 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
let D : Domain = createDomain(domSize)
#[
# this approach is very inefficient
let polyAs : seq[Poly] = collect( newSeq , (for col in matrices.A: polyInverseNTT(col, D) ))
let polyBs : seq[Poly] = collect( newSeq , (for col in matrices.B: polyInverseNTT(col, D) ))
let polyCs : seq[Poly] = collect( newSeq , (for col in matrices.C: polyInverseNTT(col, D) ))
@ -177,6 +192,20 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
let pointsB1 : seq[G1] = collect( newSeq , (for p in polyBs: polyEvalAt(p, toxic.tau) ** gen1) )
let pointsB2 : seq[G2] = collect( newSeq , (for p in polyBs: polyEvalAt(p, toxic.tau) ** gen2) )
let pointsC : seq[G1] = collect( newSeq , (for p in polyCs: polyEvalAt(p, toxic.tau) ** gen1) )
]#
# the Lagrange polynomials L_k(x) evaluated at x=tau
# we can then simply take the dot product of these with the column vectors to compute the points A,B1,B2,C
let lagrangeTaus : seq[Fr] = collect( newSeq, (for k in 0..<domSize: evalLagrangePolyAt(D, k, toxic.tau) ))
let columnTausA : seq[Fr] = collect( newSeq, (for col in matrices.A: dotProdFr(col,lagrangeTaus) ))
let columnTausB : seq[Fr] = collect( newSeq, (for col in matrices.B: dotProdFr(col,lagrangeTaus) ))
let columnTausC : seq[Fr] = collect( newSeq, (for col in matrices.C: dotProdFr(col,lagrangeTaus) ))
let pointsA : seq[G1] = collect( newSeq , (for y in columnTausA: (y ** gen1) ))
let pointsB1 : seq[G1] = collect( newSeq , (for y in columnTausB: (y ** gen1) ))
let pointsB2 : seq[G2] = collect( newSeq , (for y in columnTausB: (y ** gen2) ))
let pointsC : seq[G1] = collect( newSeq , (for y in columnTausC: (y ** gen1) ))
let gammaInv : Fr = invFr(toxic.gamma)
let deltaInv : Fr = invFr(toxic.delta)
@ -193,37 +222,43 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
var pointsH : seq[G1]
case flavour
#---------------------------------------------------------------------------
# in the original paper, these are the curve points
# [ delta^-1 * tau^i * Z(tau) ]
#
of JensGroth:
pointsH = collect( newSeq , (for i in 0..<domSize:
(deltaInv * smallPowFr(toxic.tau,i)) ** ztauG1 ))
#---------------------------------------------------------------------------
# in the Snarkjs implementation, these are the curve points
# [ delta^-1 * L_{2i+1} (tau) ]
# where L_k are the Lagrange polynomials on the refined domain
#
of Snarkjs:
let D2 : Domain = createDomain(2*domSize)
let eta : Fr = D2.domainGen
pointsH = collect( newSeq , (for i in 0..<domSize:
(deltaInv * evalLagrangePolyAt(D2, 2*i+1, toxic.tau)) ** gen1 ))
#---------------------------------------------------------------------------
let vPoints = VerifierPoints( pointsIC: pointsL )
let pPoints = ProverPoints( pointsA1: pointsA
, pointsB1: pointsB1
, pointsB2: pointsB2
, pointsC1: pointsK
, pointsH1: pointsH
)
let pPoints =
ProverPoints( pointsA1: pointsA
, pointsB1: pointsB1
, pointsB2: pointsB2
, pointsC1: pointsK
, pointsH1: pointsH
)
return ZKey( header: header
, specPoints: spec
, vPoints: vPoints
, pPoints: pPoints
, coeffs: coeffs
)
return
ZKey( header: header
, specPoints: spec
, vPoints: vPoints
, pPoints: pPoints
, coeffs: coeffs
)
#-------------------------------------------------------------------------------

View File

@ -6,8 +6,8 @@
import constantine/math/arithmetic except Fp, Fr
#import constantine/math/io/io_fields except Fp, Fr
import bn128
from ./groth16 import Proof
import groth16/bn128
from groth16/prover import Proof
#-------------------------------------------------------------------------------
@ -33,6 +33,8 @@ proc exportPublicIO*( fpath: string, prf: Proof ) =
let f = open(fpath, fmWrite)
defer: f.close()
# note: we start from 1 because the 0th element is the constant 1 "variable",
# which is automatically added by the tools
for i in 1..<n:
let str : string = toQuotedDecimalFr( prf.publicIO[i] )
if i==1:

View File

@ -0,0 +1,152 @@
#
# export proof, public input and verifier as a SageMath script
#
import std/strutils
import std/streams
import constantine/math/arithmetic except Fp, Fr
import groth16/bn128
import groth16/zkey_types
from groth16/prover import Proof
#-------------------------------------------------------------------------------
func toSpaces(str: string): string = spaces(str.len)
func sageFp(prefix: string, x: Fp): string = prefix & "Fp(" & toDecimalFp(x) & ")"
func sageFr(prefix: string, x: Fr): string = prefix & "Fr(" & toDecimalFr(x) & ")"
func sageFp2(prefix: string, z: Fp2): string =
sageFp( prefix & "mkFp2(" , z.coords[0]) & ",\n" &
sageFp( toSpaces(prefix) & " " , z.coords[1]) & ")"
func sageG1(prefix: string, p: G1): string =
sageFp( prefix & "E(" , p.x) & ",\n" &
sageFp( toSpaces(prefix) & " " , p.y) & ")"
func sageG2(prefix: string, p: G2): string =
sageFp2( prefix & "E2(" , p.x) & ",\n" &
sageFp2( toSpaces(prefix) & " " , p.y) & ")"
#-------------------------------------------------------------------------------
proc exportVKey(h: Stream, vkey: VKey ) =
let spec = vkey.spec
h.writeLine("alpha1 = \\") ; h.writeLine(sageG1(" ", spec.alpha1))
h.writeLine("beta2 = \\") ; h.writeLine(sageG2(" ", spec.beta2 ))
h.writeLine("gamma2 = \\") ; h.writeLine(sageG2(" ", spec.gamma2))
h.writeLine("delta2 = \\") ; h.writeLine(sageG2(" ", spec.delta2))
let pts = vkey.vpoints.pointsIC
h.writeLine("pointsIC = \\")
for i in 0..<pts.len:
let prefix = if (i==0): " [ " else: " "
let postfix = if (i<pts.len-1): "," else: " ]"
h.writeLine( sageG1(prefix, pts[i]) & postfix )
#---------------------------------------
proc exportProof*(h: Stream, prf: Proof ) =
h.writeLine("piA = \\") ; h.writeLine(sageG1(" ", prf.pi_a ))
h.writeLine("piB = \\") ; h.writeLine(sageG2(" ", prf.pi_b ))
h.writeLine("piC = \\") ; h.writeLine(sageG1(" ", prf.pi_c ))
# note: the first element is just the constant 1
let coeffs = prf.publicIO
h.writeLine("pubIO = \\")
for i in 0..<coeffs.len:
let prefix = if (i==0): " [ " else: " "
let postfix = if (i<coeffs.len-1): "," else: " ]"
h.writeLine( prefix & toDecimalFr(coeffs[i]) & postfix )
#-------------------------------------------------------------------------------
const sage_bn128_lines : seq[string] =
@[ "# BN128 elliptic curve"
, "p = 21888242871839275222246405745257275088696311157297823662689037894645226208583"
, "r = 21888242871839275222246405745257275088548364400416034343698204186575808495617"
, "h = 1"
, "Fp = GF(p)"
, "Fr = GF(r)"
, "A = Fp(0)"
, "B = Fp(3)"
, "E = EllipticCurve(Fp,[A,B])"
, "gx = Fp(1)"
, "gy = Fp(2)"
, "gen = E(gx,gy) # subgroup generator"
, "print(\"scalar field check: \", gen.additive_order() == r )"
, "print(\"cofactor check: \", E.cardinality() == r*h )"
, ""
, "# r and trace of Frobenius from the BN parameter x"
, "x = 4965661367192848881"
, "bn_r=36*x^4+36*x^3+18*x^2+6*x+1"
, "bn_t=6*x^2+1"
, "print(\"BN r = \",bn_r)"
, "print(\"BN t = \",bn_t)"
, "print(\"test p+1 === t (mod r) : \", mod(p+1-bn_t,r) )"
, ""
, "# extension tower"
, "R.<x> = Fp[]"
, "Fp2.<u> = Fp.extension(x^2+1)"
, "def mkFp2(a,b):"
, " return ( a + u*b )"
, "R.<x> = Fp2[]"
, "Fp12.<w> = Fp2.extension(x^6 - (9+u))"
, "E12 = E.base_extend(Fp12)"
, ""
, "# twisted curve"
, "B_twist = Fp2(19485874751759354771024239261021720505790618469301721065564631296452457478373 + 266929791119991161246907387137283842545076965332900288569378510910307636690*u )"
, "E2 = EllipticCurve(Fp2,[0,B_twist])"
, "size_E2 = E2.cardinality();"
, "cofactor_E2 = size_E2 / r;"
, "print(\"|E2| = \", size_E2 );"
, "print(\"h(E2) = \", cofactor_E2 );"
, ""
, "# map from E2 to E12"
, "def Psi(pt):"
, " pt.normalize_coordinates()"
, " x = pt[0]"
, " y = pt[1]"
, " return E12( Fp12(w^2 * x) , Fp12(w^3 * y) )"
, ""
, "def pairing(P,Q):"
, " return E12(P).ate_pairing( Psi(Q), n=r, k=12, t=bn_t, q=p^12 )"
, ""
]
const sage_bn128 : string = join(sage_bn128_lines, sep="\n")
#-------------------------------------------------------------------------------
const verify_lines : seq[string] =
@[ "pubG1 = pointsIC[0]"
, "for i in [1..len(pubIO)-1]:"
, " pubG1 = pubG1 + pubIO[i]*pointsIC[i]"
, ""
, "lhs = pairing( -piA , piB )"
, "rhs1 = pairing( alpha1 , beta2 )"
, "rhs2 = pairing( piC , delta2 )"
, "rhs3 = pairing( pubG1 , gamma2 )"
, "eq = lhs * rhs1 * rhs2 * rhs3"
, "print(\"verification suceeded =\\n\",eq == 1)"
]
const verify_script : string = join(verify_lines, sep="\n")
#-------------------------------------------------------------------------------
proc exportSage*(fpath: string, vkey: VKey, prf: Proof) =
let h = openFileStream(fpath, fmWrite)
defer: h.close()
h.writeLine(sage_bn128)
h.exportVKey(vkey);
h.exportProof(prf);
h.writeLine(verify_script)
#-------------------------------------------------------------------------------

View File

@ -43,7 +43,7 @@
# ...
# ...
#
# 4: Custom gates application
# 5: Custom gates application
# ---------------------------
# ...
# ...
@ -54,8 +54,8 @@ import std/streams
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
import ./bn128
import ./container
import groth16/bn128
import groth16/files/container
#-------------------------------------------------------------------------------

View File

@ -11,9 +11,7 @@
#
# nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret
#
# NOTE: Unlike the `.zkey` files, which encode field elements in the
# Montgomery representation, the `.wtns` file encode field elements in
# the standard representation!
# Field elements are encoded in the standard representation.
#
import std/streams
@ -21,8 +19,8 @@ import std/streams
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
import ./bn128
import ./container
import groth16/bn128
import groth16/files/container
#-------------------------------------------------------------------------------

View File

@ -31,7 +31,7 @@
# beta2 : G2 = [beta]_2
# gamma2 : G2 = [gamma]_2
# delta1 : G1 = [delta]_1
# delta2 : G2 = [delta_2]
# delta2 : G2 = [delta]_2
#
# 3: IC
# -----
@ -97,10 +97,10 @@ import std/streams
import constantine/math/arithmetic except Fp, Fr
#import constantine/math/io/io_bigints
import ./bn128
import ./zkey_types
import ./container
import ./misc
import groth16/bn128
import groth16/zkey_types
import groth16/files/container
import groth16/misc
#-------------------------------------------------------------------------------

View File

@ -7,8 +7,8 @@ import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields except Fp,Fr
#import constantine/math/io/io_bigints
import ./bn128
import ./misc
import groth16/bn128
import groth16/misc
#-------------------------------------------------------------------------------

View File

@ -9,8 +9,8 @@
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields
import bn128
import domain
import groth16/bn128
import groth16/math/domain
#-------------------------------------------------------------------------------

View File

@ -10,12 +10,12 @@ import std/sequtils
import std/sugar
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields
#import constantine/math/io/io_fields
import bn128
import domain
import ntt
import misc
import groth16/bn128
import groth16/math/domain
import groth16/math/ntt
import groth16/misc
#-------------------------------------------------------------------------------
@ -220,7 +220,7 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly =
#-------------------------------------------------------------------------------
# Lagrange basis polynomials
func lagrangePoly(D: Domain, k: int): Poly =
func lagrangePoly*(D: Domain, k: int): Poly =
let N = D.domainSize
let omMinusK : Fr = smallPowFr( D.invDomainGen , k )
let invN : Fr = invFr(intToFr(N))

216
groth16/prover.nim Normal file
View File

@ -0,0 +1,216 @@
#
# Groth16 prover
#
# WARNING!
# the points H in `.zkey` are *NOT* what normal people would think they are
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
#[
import sugar
import constantine/math/config/curves
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import ./zkey
]#
import constantine/math/arithmetic except Fp, Fr
#import constantine/math/io/io_extfields except Fp12
#import constantine/math/extension_fields/towers except Fp2, Fp12
import groth16/bn128
import groth16/math/domain
import groth16/math/poly
import groth16/zkey_types
import groth16/files/witness
#-------------------------------------------------------------------------------
type
Proof* = object
publicIO* : seq[Fr]
pi_a* : G1
pi_b* : G2
pi_c* : G1
curve* : string
#-------------------------------------------------------------------------------
# A, B, C column vectors
#
type
ABC = object
valuesA : seq[Fr]
valuesB : seq[Fr]
valuesC : seq[Fr]
func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC =
let hdr: GrothHeader = zkey.header
let domSize = hdr.domainSize
var valuesA : seq[Fr] = newSeq[Fr](domSize)
var valuesB : seq[Fr] = newSeq[Fr](domSize)
for entry in zkey.coeffs:
case entry.matrix
of MatrixA: valuesA[entry.row] += entry.coeff * witness[entry.col]
of MatrixB: valuesB[entry.row] += entry.coeff * witness[entry.col]
else: raise newException(AssertionDefect, "fatal error")
var valuesC : seq[Fr] = newSeq[Fr](domSize)
for i in 0..<domSize:
valuesC[i] = valuesA[i] * valuesB[i]
return ABC( valuesA:valuesA, valuesB:valuesB, valuesC:valuesC )
#-------------------------------------------------------------------------------
# quotient poly
#
# interpolates A,B,C, and computes the quotient polynomial Q = (A*B - C) / Z
func computeQuotientNaive( abc: ABC ): Poly=
let n = abc.valuesA.len
assert( abc.valuesB.len == n )
assert( abc.valuesC.len == n )
let D = createDomain(n)
let polyA : Poly = polyInverseNTT( abc.valuesA , D )
let polyB : Poly = polyInverseNTT( abc.valuesB , D )
let polyC : Poly = polyInverseNTT( abc.valuesC , D )
let polyBig = polyMulFFT( polyA , polyB ) - polyC
var polyQ = polyDivideByVanishing(polyBig, D.domainSize)
polyQ.coeffs.add( zeroFr ) # make it a power of two
return polyQ
#---------------------------------------
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
func multiplyByPowers( xs: seq[Fr], eta: Fr ): seq[Fr] =
let n = xs.len
assert(n >= 1)
var ys : seq[Fr] = newSeq[Fr](n)
ys[0] = xs[0]
if n >= 1: ys[1] = eta * xs[1]
var spow : Fr = eta
for i in 2..<n:
spow *= eta
ys[i] = spow * xs[i]
return ys
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
func shiftEvalDomain( values: seq[Fr], D: Domain, eta: Fr ): seq[Fr] =
let poly : Poly = polyInverseNTT( values , D )
let cs : seq[Fr] = poly.coeffs
var ds : seq[Fr] = multiplyByPowers( cs, eta )
return polyForwardNTT( Poly(coeffs:ds), D )
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
func computeQuotientPointwise( abc: ABC ): Poly =
let n = abc.valuesA.len
let D = createDomain(n)
# (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
return Poly(coeffs: cs)
#---------------------------------------
# Snarkjs does something different, not actually computing the quotient poly
# they can get away with this, because during the trusted setup, they
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
func computeSnarkjsScalarCoeffs( abc: ABC ): seq[Fr] =
let n = abc.valuesA.len
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
let A1 = shiftEvalDomain( abc.valuesA, D, eta )
let B1 = shiftEvalDomain( abc.valuesB, D, eta )
let C1 = shiftEvalDomain( abc.valuesC, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] )
return ys
#-------------------------------------------------------------------------------
# the prover
#
proc generateProof*( zkey: ZKey, wtns: Witness ): Proof =
assert( zkey.header.curve == wtns.curve )
let witness = wtns.values
let hdr : GrothHeader = zkey.header
let spec : SpecPoints = zkey.specPoints
let pts : ProverPoints = zkey.pPoints
let nvars = hdr.nvars
let npubs = hdr.npubs
assert( nvars == witness.len , "wrong witness length" )
var pubIO : seq[Fr] = newSeq[Fr]( npubs + 1)
for i in 0..npubs: pubIO[i] = witness[i]
var abc : ABC = buildABC( zkey, witness )
var qs : seq[Fr]
case zkey.header.flavour
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc )
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
for j in npubs+1..<nvars:
zs[j-npubs-1] = witness[j]
# masking coeffs
let r : Fr = randFr()
let s : Fr = randFr()
var pi_a : G1
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmG1( witness , pts.pointsA1 )
var rho : G1
rho = spec.beta1
rho += s ** spec.delta1
rho += msmG1( witness , pts.pointsB1 )
var pi_b : G2
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmG2( witness , pts.pointsB2 )
var pi_c : G1
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmG1( qs , pts.pointsH1 )
pi_c += msmG1( zs , pts.pointsC1 )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------

View File

@ -1,34 +1,42 @@
import ./groth16
import ./witness
import ./r1cs
import ./zkey
import ./zkey_types
import ./fake_setup
import std/[times,os]
import strformat
import groth16/prover
import groth16/verifier
import groth16/files/witness
import groth16/files/r1cs
import groth16/files/zkey
import groth16/zkey_types
import groth16/fake_setup
func seconds(x: float): string = fmt"{x:.4f} seconds"
#-------------------------------------------------------------------------------
proc testProveAndVerify*( zkey_fname, wtns_fname: string): Proof =
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
echo("parsing witness & zkey files...")
let witness = parseWitness( wtns_fname)
let zkey = parseZKey( zkey_fname)
# printCoeffs(zkey.coeffs)
echo("generating proof...")
let vkey = extractVKey( zkey)
let start = cpuTime()
let proof = generateProof( zkey, witness )
let elapsed = cpuTime() - start
echo("proving took ",seconds(elapsed))
echo("verifying the proof...")
let ok = verifyProof( vkey, proof )
let vkey = extractVKey( zkey)
let ok = verifyProof( vkey, proof )
echo("verification succeeded = ",ok)
return proof
return (vkey,proof)
#-------------------------------------------------------------------------------
proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): Proof =
proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): (VKey,Proof) =
echo("trusted setup flavour = ",flavour)
echo("parsing witness & r1cs files...")
@ -36,16 +44,23 @@ proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs):
let r1cs = parseR1CS( r1cs_fname)
echo("performing fake trusted setup...")
let start1 = cpuTime()
let zkey = createFakeCircuitSetup( r1cs, flavour=flavour )
let elapsed1 = cpuTime() - start1
echo("fake setup took ",seconds(elapsed1))
# printCoeffs(zkey.coeffs)
echo("generating proof...")
let vkey = extractVKey( zkey)
let start = cpuTime()
let proof = generateProof( zkey, witness )
let elapsed = cpuTime() - start
echo("proving took ",seconds(elapsed))
echo("verifying the proof...")
let ok = verifyProof( vkey, proof )
echo("verification succeeded = ",ok)
return proof
return (vkey,proof)

54
groth16/verifier.nim Normal file
View File

@ -0,0 +1,54 @@
#
# Groth16 prover
#
# WARNING!
# the points H in `.zkey` are *NOT* what normal people would think they are
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
#[
import sugar
import constantine/math/config/curves
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import ./zkey
]#
# import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_extfields except Fp12
import constantine/math/extension_fields/towers except Fp2, Fp12
import groth16/bn128
import groth16/zkey_types
from groth16/prover import Proof
#-------------------------------------------------------------------------------
# the verifier
#
proc verifyProof* (vkey: VKey, prf: Proof): bool =
assert( prf.curve == "bn128" )
assert( isOnCurveG1(prf.pi_a) , "pi_a is not in G1" )
assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" )
assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" )
var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC )
let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b >
let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta >
let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta >
let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma >
var eq : Fp12
eq = lhs
eq *= rhs1
eq *= rhs2
eq *= rhs3
return bool(isOne(eq))
#-------------------------------------------------------------------------------

View File

@ -1,7 +1,7 @@
import constantine/math/arithmetic except Fp, Fr
import ./bn128
import groth16/bn128
#-------------------------------------------------------------------------------
@ -80,14 +80,14 @@ func matrixSelToString(sel: MatrixSel): string =
of MatrixB: return "B"
of MatrixC: return "C"
proc printCoeff(cf: Coeff) =
proc debugPrintCoeff(cf: Coeff) =
echo( "matrix=", matrixSelToString(cf.matrix)
, " | i=", cf.row
, " | j=", cf.col
, " | val=", signedToDecimalFr(cf.coeff)
)
proc printCoeffs*(cfs: seq[Coeff]) =
for cf in cfs: printCoeff(cf)
proc debugPrintCoeffs*(cfs: seq[Coeff]) =
for cf in cfs: debugPrintCoeff(cf)
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,75 @@
import std/unittest
import std/sequtils
import groth16/prover
import groth16/verifier
import groth16/fake_setup
import groth16/zkey_types
import groth16/files/witness
import groth16/files/r1cs
import groth16/bn128/fields
#-------------------------------------------------------------------------------
# simple hand-crafted arithmetic circuit
#
const myWitnessCfg =
WitnessConfig( nWires: 7
, nPubOut: 1 # public output = input + a*b*c = 1022 + 7*11*13 = 2023
, nPubIn: 1 # public input = 1022
, nPrivIn: 3 # private inputs: 7, 11, 13
, nLabels: 0
)
# 2023 == 1022 + 7*3*11
const myEq1 : Constraint = ( @[] , @[] , @[ (0,minusOneFr) , (1,oneFr) , (6,oneFr) ] )
# 7*11 == 77
const myEq2 : Constraint = ( @[ (2,oneFr) ] , @[ (3,oneFr) ] , @[ (5,oneFr) ] )
# 77*13 == 1001
const myEq3 : Constraint = ( @[ (4,oneFr) ] , @[ (5,oneFr) ] , @[ (6,oneFr) ] )
const myConstraints : seq[Constraint] = @[ myEq1, myEq2, myEq3 ]
const myLabels : seq[int] = @[]
const myR1CS =
R1CS( r: primeR
, cfg: myWitnessCfg
, nConstr: myConstraints.len
, constraints: myConstraints
, wireToLabel: myLabels
)
# the equation we want prove is `7*11*13 + 1022 == 2023`
let myWitnessValues : seq[Fr] = map( @[ 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr )
# wire indices: ^^^^^^^ 0 1 2 3 4 5 6
let myWitness =
Witness( curve: "bn128"
, r: primeR
, nvars: 7
, values: myWitnessValues
)
#-------------------------------------------------------------------------------
proc testProof(zkey: ZKey, witness: Witness): bool =
let proof = generateProof( zkey, witness )
let vkey = extractVKey( zkey)
let ok = verifyProof( vkey, proof )
return ok
suite "prover":
test "prove & verify simple multiplication circuit, `JensGroth` flavour":
let zkey = createFakeCircuitSetup( myR1cs, flavour=JensGroth )
check testProof( zkey, myWitness )
test "prove & verify simple multiplication circuit, `Snarkjs` flavour":
let zkey = createFakeCircuitSetup( myR1cs, flavour=Snarkjs )
check testProof( zkey, myWitness )
#-------------------------------------------------------------------------------

1
tests/nim.cfg Normal file
View File

@ -0,0 +1 @@
--path:".."

3
tests/test.nim Normal file
View File

@ -0,0 +1,3 @@
import ./groth16/testProver