mostly fixes to make it compile with latest constantine and codex

This commit is contained in:
Dmitriy Ryajov 2025-05-28 19:36:01 -06:00
parent 5616a1c52f
commit 434170541e
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
28 changed files with 689 additions and 609 deletions

View File

@ -1,5 +1,5 @@
import sugar
import std/sugar
import std/strutils
import std/sequtils
import std/os
@ -36,10 +36,10 @@ proc printHelp() =
echo " -u, --setup : perform (fake) trusted setup"
echo " -n, --nomask : don't use random masking for full ZK"
echo " -z, --zkey = <circuit.zkey> : the `.zkey` file"
echo " -w, --wtns = <circuit.wtns> : the `.wtns` file"
echo " -r, --r1cs = <circuit.r1cs> : the `.r1cs` file"
echo " -w, --wtns = <circuit.wtns> : the `.wtns` file"
echo " -r, --r1cs = <circuit.r1cs> : the `.r1cs` file"
echo " -o, --output = <proof.json> : the proof file"
echo " -i, --io = <public.json> : the public input/output file"
echo " -i, --io = <public.json> : the public input/output file"
#-------------------------------------------------------------------------------
@ -58,7 +58,7 @@ type Config = object
no_masking: bool
nthreads: int
const dummyConfig =
const dummyConfig =
Config( zkey_file: ""
, r1cs_file: ""
, wtns_file: ""
@ -123,13 +123,13 @@ proc parseCliOptions(): Config =
quit()
of cmdEnd:
discard
discard
if swCtr==0 and argCtr==0:
printHelp()
quit()
if cfg.nthreads <= 0:
if cfg.nthreads <= 0:
cfg.nthreads = countProcessors()
return cfg
@ -137,7 +137,7 @@ proc parseCliOptions(): Config =
#-------------------------------------------------------------------------------
#[
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
echo("parsing witness & zkey files...")
let witness = parseWitness( wtns_fname)
@ -170,20 +170,20 @@ proc cliMain(cfg: Config) =
echo("\nparsing witness file " & quoted(cfg.wtns_file))
withMeasureTime(cfg.measure_time,"parsing the witness"):
wtns = parseWitness(cfg.wtns_file)
if not (cfg.zkey_file == ""):
echo("\nparsing zkey file " & quoted(cfg.zkey_file))
withMeasureTime(cfg.measure_time,"parsing the zkey"):
zkey = parseZKey(cfg.zkey_file)
if not (cfg.r1cs_file == ""):
echo("\nparsing r1cs file " & quoted(cfg.r1cs_file))
withMeasureTime(cfg.measure_time,"parsing the r1cs"):
r1cs = parseR1CS(cfg.r1cs_file)
if cfg.do_setup:
if not (cfg.zkey_file == ""):
echo("\nwe are doing a fake trusted setup, don't specify the zkey file!")
echo("\nwe are doing a fake trusted setup, don't specify the zkey file!")
quit()
if (cfg.r1cs_file == ""):
echo("\nerror: r1cs file is required for the fake setup!")
@ -191,24 +191,24 @@ proc cliMain(cfg: Config) =
echo("\nperforming fake trusted setup...")
withMeasureTime(cfg.measure_time,"fake setup"):
zkey = createFakeCircuitSetup( r1cs, flavour=Snarkjs )
if cfg.debug:
printGrothHeader(zkey.header)
# debugPrintCoeffs(zkey.coeffs)
if cfg.do_prove:
if (cfg.wtns_file=="") or (cfg.zkey_file=="" and cfg.do_setup==false):
echo("cannot prove: missing witness and/or zkey file!")
echo("cannot prove: missing witness and/or zkey file!")
quit()
else:
echo("generating proof...")
let print_timings = cfg.measure_time and cfg.verbose
withMeasureTime(cfg.measure_time,"proving"):
if cfg.no_masking:
proof = generateProofWithTrivialMask(cfg.nthreads, print_timings, zkey, wtns)
proof = generateProofWithTrivialMask(zkey, wtns, cfg.nthreads, print_timings)
else:
proof = generateProof(cfg.nthreads, print_timings, zkey, wtns)
proof = generateProof(zkey, wtns, cfg.nthreads, print_timings)
if not (cfg.output_file == ""):
echo("exporting the proof to " & quoted(cfg.output_file))
exportProof( cfg.output_file, proof )
@ -218,7 +218,7 @@ proc cliMain(cfg: Config) =
if cfg.do_verify:
if (cfg.zkey_file == "" and cfg.do_setup==false):
echo("cannot verify: missing vkey (well, zkey)")
echo("cannot verify: missing vkey (well, zkey)")
quit()
else:
let vkey = extractVKey( zkey)
@ -227,7 +227,7 @@ proc cliMain(cfg: Config) =
withMeasureTime(cfg.measure_time,"verifying"):
ok = verifyProof( vkey, proof )
echo("verification succeeded = ",ok)
echo("")
#-------------------------------------------------------------------------------

View File

@ -1,3 +1 @@
--path:".."
--threads:on
--mm:arc

View File

@ -1,12 +1,16 @@
import groth16/bn128
import groth16/files/zkey
import groth16/zkey_types
import groth16/files/r1cs
import groth16/files/witness
import groth16/prover
import groth16/verifier
export r1cs
export bn128
export zkey
export zkey_types
export witness
export prover
export verifier

View File

@ -8,4 +8,4 @@ binDir = "build"
namedBin = {"cli/cli_main": "nim-groth16"}.toTable()
requires "https://github.com/status-im/nim-taskpools"
requires "https://github.com/mratsim/constantine#5f7ba18f2ed351260015397c9eae079a6decaee1"
requires "constantine >= 0.2.0"

View File

@ -13,36 +13,36 @@
#import constantine/platforms/abstractions
#import constantine/math/isogenies/frobenius
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_fields except Fp, Fr
import constantine/math/io/io_bigints
import constantine/math/config/curves
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/math/io/io_bigints
# import pkg/constantine/math/config/curves
import constantine/math/config/type_ff as tff except Fp, Fr
import constantine/math/extension_fields/towers as ext except Fp, Fp2, Fp12, Fr
import pkg/constantine/named/properties_fields as tff
import pkg/constantine/math/extension_fields/towers as ext
import constantine/math/elliptic/ec_shortweierstrass_affine as aff
import constantine/math/elliptic/ec_shortweierstrass_projective as prj
import constantine/math/pairings/pairings_bn as ate
import constantine/math/elliptic/ec_scalar_mul_vartime as scl
import pkg/constantine/math/elliptic/ec_shortweierstrass_affine as aff
import pkg/constantine/math/elliptic/ec_shortweierstrass_projective as prj
import pkg/constantine/math/pairings/pairings_bn as ate
import pkg/constantine/math/elliptic/ec_scalar_mul_vartime as scl
import groth16/bn128/fields
#-------------------------------------------------------------------------------
type G1* = aff.ECP_ShortW_Aff[Fp , aff.G1]
type G2* = aff.ECP_ShortW_Aff[Fp2, aff.G2]
type G1* = aff.EC_ShortW_Aff[Fp[BN254_Snarks] , aff.G1]
type G2* = aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2]
type ProjG1* = prj.ECP_ShortW_Prj[Fp , prj.G1]
type ProjG2* = prj.ECP_ShortW_Prj[Fp2, prj.G2]
type ProjG1* = prj.EC_ShortW_Prj[Fp[BN254_Snarks] , prj.G1]
type ProjG2* = prj.EC_ShortW_Prj[Fp2[BN254_Snarks], prj.G2]
#-------------------------------------------------------------------------------
func unsafeMkG1* ( X, Y: Fp ) : G1 =
return aff.ECP_ShortW_Aff[Fp, aff.G1](x: X, y: Y)
func unsafeMkG1* ( X, Y: Fp[BN254_Snarks] ) : G1 =
return aff.EC_ShortW_Aff[Fp[BN254_Snarks], aff.G1](x: X, y: Y)
func unsafeMkG2* ( X, Y: Fp2 ) : G2 =
return aff.ECP_ShortW_Aff[Fp2, aff.G2](x: X, y: Y)
func unsafeMkG2* ( X, Y: Fp2[BN254_Snarks] ) : G2 =
return aff.EC_ShortW_Aff[Fp2[BN254_Snarks], aff.G2](x: X, y: Y)
#-------------------------------------------------------------------------------
@ -51,15 +51,15 @@ const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 )
#-------------------------------------------------------------------------------
func checkCurveEqG1*( x, y: Fp ) : bool =
func checkCurveEqG1*( x, y: Fp[BN254_Snarks] ) : bool =
if bool(isZero(x)) and bool(isZero(y)):
# the point at infinity is on the curve by definition
return true
else:
var x2 : Fp = squareFp(x)
var y2 : Fp = squareFp(y)
var x3 : Fp = x2 * x
var eq : Fp
var x2 = squareFp(x)
var y2 = squareFp(y)
var x3 = x2 * x
var eq : Fp[BN254_Snarks]
eq = x3
eq += intToFp(3)
eq -= y2
@ -72,19 +72,19 @@ func checkCurveEqG1*( x, y: Fp ) : bool =
# B = b1 + bu*u
# b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373
# b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690
const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5")
const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2")
const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u )
const twistCoeffB_1 = fromHex(Fp[BN254_Snarks], "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5")
const twistCoeffB_u = fromHex(Fp[BN254_Snarks], "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2")
const twistCoeffB = mkFp2( twistCoeffB_1 , twistCoeffB_u )
func checkCurveEqG2*( x, y: Fp2 ) : bool =
func checkCurveEqG2*( x, y: Fp2[BN254_Snarks] ) : bool =
if isZeroFp2(x) and isZeroFp2(y):
# the point at infinity is on the curve by definition
return true
else:
var x2 : Fp2 = squareFp2(x)
var y2 : Fp2 = squareFp2(y)
var x3 : Fp2 = x2 * x;
var eq : Fp2
var x2 = squareFp2(x)
var y2 = squareFp2(y)
var x3 = x2 * x
var eq : Fp2[BN254_Snarks]
eq = x3
eq += twistCoeffB
eq -= y2
@ -92,14 +92,14 @@ func checkCurveEqG2*( x, y: Fp2 ) : bool =
#-------------------------------------------------------------------------------
func mkG1*( x, y: Fp ) : G1 =
func mkG1*( x, y: Fp[BN254_Snarks] ) : G1 =
if isZeroFp(x) and isZeroFp(y):
return infG1
else:
assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" )
return unsafeMkG1(x,y)
func mkG2*( x, y: Fp2 ) : G2 =
func mkG2*( x, y: Fp2[BN254_Snarks] ) : G2 =
if isZeroFp2(x) and isZeroFp2(y):
return infG2
else:
@ -109,16 +109,16 @@ func mkG2*( x, y: Fp2 ) : G2 =
#-------------------------------------------------------------------------------
# group generators
const gen1_x : Fp = fromHex(Fp, "0x01")
const gen1_y : Fp = fromHex(Fp, "0x02")
const gen1_x = fromHex(Fp[BN254_Snarks], "0x01")
const gen1_y = fromHex(Fp[BN254_Snarks], "0x02")
const gen2_xi : Fp = fromHex(Fp, "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701")
const gen2_xu : Fp = fromHex(Fp, "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b")
const gen2_yi : Fp = fromHex(Fp, "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8")
const gen2_yu : Fp = fromHex(Fp, "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c")
const gen2_xi = fromHex(Fp[BN254_Snarks], "0x1adcd0ed10df9cb87040f46655e3808f98aa68a570acf5b0bde23fab1f149701")
const gen2_xu = fromHex(Fp[BN254_Snarks], "0x09e847e9f05a6082c3cd2a1d0a3a82e6fbfbe620f7f31269fa15d21c1c13b23b")
const gen2_yi = fromHex(Fp[BN254_Snarks], "0x056c01168a5319461f7ca7aa19d4fcfd1c7cdf52dbfc4cbee6f915250b7f6fc8")
const gen2_yu = fromHex(Fp[BN254_Snarks], "0x0efe500a2d02dd77f5f401329f30895df553b878fc3c0dadaaa86456a623235c")
const gen2_x : Fp2 = mkFp2( gen2_xi, gen2_xu )
const gen2_y : Fp2 = mkFp2( gen2_yi, gen2_yu )
const gen2_x = mkFp2( gen2_xi, gen2_xu )
const gen2_y = mkFp2( gen2_yi, gen2_yu )
const gen1* : G1 = unsafeMkG1( gen1_x, gen1_y )
const gen2* : G2 = unsafeMkG2( gen2_x, gen2_y )
@ -215,9 +215,9 @@ func `**`*( coeff: BigInt , point: G2 ) : G2 =
#-------------------------------------------------------------------------------
func pairing* (p: G1, q: G2) : Fp12 =
var t : Fp12
ate.pairing_bn[BN254Snarks]( t, p, q )
func pairing* (p: G1, q: G2) : Fp12[BN254_Snarks] =
var t : Fp12[BN254_Snarks]
ate.pairing_bn( t, p, q )
return t
#-------------------------------------------------------------------------------
@ -225,7 +225,7 @@ func pairing* (p: G1, q: G2) : Fp12 =
proc sanityCheckGroupGen*() =
echo( "gen1 on the curve = ", checkCurveEqG1(gen1.x,gen1.y) )
echo( "gen2 on the curve = ", checkCurveEqG2(gen2.x,gen2.y) )
echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) )
echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) )
# echo( "order of gen1 is R = ", (not bool(isInf(gen1))) and bool(isInf(primeR ** gen1)) )
# echo( "order of gen2 is R = ", (not bool(isInf(gen2))) and bool(isInf(primeR ** gen2)) )
#-------------------------------------------------------------------------------

View File

@ -9,23 +9,27 @@
# equation: y^2 = x^3 + 3
#
import pkg/constantine/math/io/io_fields
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128/fields
import groth16/bn128/curves
import groth16/bn128/io
#-------------------------------------------------------------------------------
proc debugPrintFp*(prefix: string, x: Fp) =
proc debugPrintFp*(prefix: string, x: Fp[BN254_Snarks]) =
echo(prefix & toDecimalFp(x))
proc debugPrintFp2*(prefix: string, z: Fp2) =
proc debugPrintFp2*(prefix: string, z: Fp2[BN254_Snarks]) =
echo(prefix & " 1 ~> " & toDecimalFp(z.coords[0]))
echo(prefix & " u ~> " & toDecimalFp(z.coords[1]))
proc debugPrintFr*(prefix: string, x: Fr) =
proc debugPrintFr*(prefix: string, x: Fr[BN254_Snarks]) =
echo(prefix & toDecimalFr(x))
proc debugPrintFrSeq*(msg: string, xs: seq[Fr]) =
proc debugPrintFrSeq*(msg: string, xs: seq[Fr[BN254_Snarks]]) =
echo "---------------------"
echo msg
for x in xs:

View File

@ -1,4 +1,3 @@
#
# the prime fields Fp and Fr with sizes
#
@ -6,30 +5,31 @@
# r = 21888242871839275222246405745257275088548364400416034343698204186575808495617
#
import sugar
import std/sugar
import std/bitops
import std/sequtils
import constantine/math/arithmetic
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/type_ff as tff
import constantine/math/extension_fields/towers as ext
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
# import pkg/constantine/math/io/io_extfields
import pkg/constantine/math/io/io_bigints
# import pkg/constantine/math/config/curves
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers as ext
#-------------------------------------------------------------------------------
type B* = BigInt[256]
type Fr* = tff.Fr[BN254Snarks]
type Fp* = tff.Fp[BN254Snarks]
# type Fr* = tff.Fr[BN254_Snarks]
# type Fp* = tff.Fp[BN254_Snarks]
type Fp2* = ext.QuadraticExt[Fp]
type Fp12* = ext.Fp12[BN254Snarks]
# type Fp2* = ext.QuadraticExt[Fp]
# type Fp12* = ext.Fp12[BN254_Snarks]
func mkFp2* (i: Fp, u: Fp) : Fp2 =
let c : array[2, Fp] = [i,u]
return ext.QuadraticExt[Fp]( coords: c )
func mkFp2* (i: Fp[BN254_Snarks], u: Fp[BN254_Snarks]) : Fp2[BN254_Snarks] =
let c : array[2, Fp[BN254_Snarks]] = [i,u]
return ext.QuadraticExt[Fp[BN254_Snarks]]( coords: c )
#-------------------------------------------------------------------------------
@ -38,16 +38,16 @@ const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b97
#-------------------------------------------------------------------------------
const zeroFp* : Fp = fromHex( Fp, "0x00" )
const zeroFr* : Fr = fromHex( Fr, "0x00" )
const oneFp* : Fp = fromHex( Fp, "0x01" )
const oneFr* : Fr = fromHex( Fr, "0x01" )
const zeroFp* = fromHex( Fp[BN254_Snarks], "0x00" )
const zeroFr* = fromHex( Fr[BN254_Snarks], "0x00" )
const oneFp* = fromHex( Fp[BN254_Snarks], "0x01" )
const oneFr* = fromHex( Fr[BN254_Snarks], "0x01" )
const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp )
const oneFp2* : Fp2 = mkFp2( oneFp , zeroFp )
const zeroFp2* = mkFp2( zeroFp, zeroFp )
const oneFp2* = mkFp2( oneFp , zeroFp )
const minusOneFp* : Fp = fromHex( Fp, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" )
const minusOneFr* : Fr = fromHex( Fr, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" )
const minusOneFp* = fromHex( Fp[BN254_Snarks], "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd46" )
const minusOneFr* = fromHex( Fr[BN254_Snarks], "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000000" )
#-------------------------------------------------------------------------------
@ -56,13 +56,13 @@ func intToB*(a: uint): B =
y.setUint(a)
return y
func intToFp*(a: int): Fp =
var y : Fp
func intToFp*(a: int): Fp[BN254_Snarks] =
var y : Fp[BN254_Snarks]
y.fromInt(a)
return y
func intToFr*(a: int): Fr =
var y : Fr
func intToFr*(a: int): Fr[BN254_Snarks] =
var y : Fr[BN254_Snarks]
y.fromInt(a)
return y
@ -154,27 +154,27 @@ func smallPowFr*(base: Fr, expo: int): Fr =
#-------------------------------------------------------------------------------
func deltaFr*(i, j: int) : Fr =
func deltaFr*[T](i, j: int) : Fr[T] =
return (if (i == j): oneFr else: zeroFr)
#-------------------------------------------------------------------------------
# Montgomery batch inversion
func batchInverseFr*( xs: seq[Fr] ) : seq[Fr] =
func batchInverseFr*( xs: seq[Fr[BN254_Snarks]] ) : seq[Fr[BN254_Snarks]] =
let n = xs.len
assert(n>0)
var us : seq[Fr] = newSeq[Fr](n+1)
var us : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]](n+1)
var a = xs[0]
us[0] = oneFr
us[1] = a
for i in 1..<n: ( a *= xs[i] ; us[i+1] = a )
var vs : seq[Fr] = newSeq[Fr](n)
var vs : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]](n)
vs[n-1] = invFr( us[n] )
for i in countdown(n-2,0): vs[i] = vs[i+1] * xs[i+1]
return collect( newSeq, (for i in 0..<n: us[i]*vs[i] ) )
proc sanityCheckBatchInverseFr*() =
let xs : seq[Fr] = map( toSeq(101..137) , intToFr )
let xs = map( toSeq(101..137) , intToFr )
let ys = batchInverseFr( xs )
let zs = collect( newSeq, (for x in xs: invFr(x)) )
let n = xs.len
@ -183,7 +183,7 @@ proc sanityCheckBatchInverseFr*() =
if not bool(ys[i] == zs[i]):
echo "batch inverse test FAILED!"
return
echo "batch iverse test OK."
echo "batch inverse test OK."
#-------------------------------------------------------------------------------

View File

@ -3,11 +3,12 @@
import std/strutils
import std/streams
import constantine/math/arithmetic except Fp, Fp2, Fr
import constantine/math/io/io_fields except Fp, Fp2, Fp
import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/type_ff as tff except Fp, Fp2, Fr
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/math/io/io_bigints
# import pkg/constantine/math/config/curves
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128/fields
import groth16/bn128/curves
@ -25,13 +26,13 @@ func toDecimalBig*[n](a : BigInt[n]): string =
if s.len == 0: s="0"
return s
func toDecimalFp*(a : Fp): string =
func toDecimalFp*(a : Fp[BN254_Snarks]): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
func toDecimalFr*(a : Fr): string =
func toDecimalFr*(a : Fr[BN254_Snarks]): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
@ -41,13 +42,13 @@ func toDecimalFr*(a : Fr): string =
const k65536 : BigInt[254] = fromHex( BigInt[254], "0x10000", bigEndian )
func signedToDecimalFp*(a : Fp): string =
func signedToDecimalFp*(a : Fp[BN254_Snarks]): string =
if bool( a.toBig() > primeP_254 - k65536 ):
return "-" & toDecimalFp(negFp(a))
else:
return toDecimalFp(a)
func signedToDecimalFr*(a : Fr): string =
func signedToDecimalFr*(a : Fr[BN254_Snarks]): string =
if bool( a.toBig() > primeR_254 - k65536 ):
return "-" & toDecimalFr(negFr(a))
else:
@ -58,38 +59,38 @@ func signedToDecimalFr*(a : Fr): string =
#
# R=2^256; this computes 2^256 mod Fp
func calcFpMontR*() : Fp =
var x : Fp = intToFp(2)
func calcFpMontR*() : Fp[BN254_Snarks] =
var x : Fp[BN254_Snarks] = intToFp(2)
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFpInvMontR*() : Fp =
var x : Fp = calcFpMontR()
func calcFpInvMontR*() : Fp[BN254_Snarks] =
var x : Fp[BN254_Snarks] = calcFpMontR()
inv(x)
return x
# R=2^256; this computes 2^256 mod Fr
func calcFrMontR*() : Fr =
var x : Fr = intToFr(2)
func calcFrMontR*() : Fr[BN254_Snarks] =
var x : Fr[BN254_Snarks] = intToFr(2)
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFrInvMontR*() : Fr =
var x : Fr = calcFrMontR()
func calcFrInvMontR*() : Fr[BN254_Snarks] =
var x : Fr[BN254_Snarks] = calcFrMontR()
inv(x)
return x
# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?)
const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" )
const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" )
const fpMontR* = fromHex( Fp[BN254_Snarks], "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" )
const fpInvMontR* = fromHex( Fp[BN254_Snarks], "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" )
# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?)
const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" )
const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" )
const frMontR* = fromHex( Fr[BN254_Snarks], "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" )
const frInvMontR* = fromHex( Fr[BN254_Snarks], "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" )
proc checkMontgomeryConstants*() =
assert( bool( fpMontR == calcFpMontR() ) )
@ -100,21 +101,21 @@ proc checkMontgomeryConstants*() =
#---------------------------------------
# the binary file `.zkey` used by the `circom` ecosystem uses little-endian
# Montgomery representation. So when we unmarshal with Constantine, it will
# the binary file `.zkey` used by the `circom` ecosystem uses little-endian
# Montgomery representation. So when we unmarshal with Constantine, it will
# give the wrong result. Calling this function on the result fixes that.
func fromMontgomeryFp*(x : Fp) : Fp =
var y : Fp = x;
func fromMontgomeryFp*(x : Fp[BN254_Snarks]) : Fp[BN254_Snarks] =
var y : Fp[BN254_Snarks] = x;
y *= fpInvMontR
return y
func fromMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
func fromMontgomeryFr*(x : Fr[BN254_Snarks]) : Fr[BN254_Snarks] =
var y = x;
y *= frInvMontR
return y
func toMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
func toMontgomeryFr*(x : Fr[BN254_Snarks]) : Fr[BN254_Snarks] =
var y = x;
y *= frMontR
return y
@ -123,47 +124,47 @@ func toMontgomeryFr*(x : Fr) : Fr =
# Note: in the `.zkey` coefficients, e apparently DOUBLE Montgomery encoding is used ?!?
#
func unmarshalFpMont* ( bs: array[32,byte] ) : Fp =
func unmarshalFpMont* ( bs: array[32,byte] ) : Fp[BN254_Snarks] =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fp
var x : Fp[BN254_Snarks]
x.fromBig( big )
return fromMontgomeryFp(x)
# WTF Jordi, go home you are drunk
func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr =
func unmarshalFrWTF* ( bs: array[32,byte] ) : Fr[BN254_Snarks] =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
var x : Fr[BN254_Snarks]
x.fromBig( big )
return fromMontgomeryFr(fromMontgomeryFr(x))
func unmarshalFrStd* ( bs: array[32,byte] ) : Fr =
func unmarshalFrStd* ( bs: array[32,byte] ) : Fr[BN254_Snarks] =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
var x : Fr[BN254_Snarks]
x.fromBig( big )
return x
func unmarshalFrMont* ( bs: array[32,byte] ) : Fr =
func unmarshalFrMont* ( bs: array[32,byte] ) : Fr[BN254_Snarks] =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
var x : Fr[BN254_Snarks]
x.fromBig( big )
return fromMontgomeryFr(x)
#-------------------------------------------------------------------------------
func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
var vals : seq[Fp] = newSeq[Fp]( len )
func unmarshalFpMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fp[BN254_Snarks]] =
var vals : seq[Fp[BN254_Snarks]] = newSeq[Fp[BN254_Snarks]]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFpMont( bytes )
return vals
func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
var vals : seq[Fr] = newSeq[Fr]( len )
func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr[BN254_Snarks]] =
var vals : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
@ -172,7 +173,7 @@ func unmarshalFrMontSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
#-------------------------------------------------------------------------------
proc loadValueFrWTF*( stream: Stream ) : Fr =
proc loadValueFrWTF*( stream: Stream ) : Fr[BN254_Snarks] =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
# for i in 0..<32: stdout.write(" " & toHex(bytes[i]))
@ -180,45 +181,45 @@ proc loadValueFrWTF*( stream: Stream ) : Fr =
assert( n == 32 )
return unmarshalFrWTF(bytes)
proc loadValueFrStd*( stream: Stream ) : Fr =
proc loadValueFrStd*( stream: Stream ) : Fr[BN254_Snarks] =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFrStd(bytes)
proc loadValueFrMont*( stream: Stream ) : Fr =
proc loadValueFrMont*( stream: Stream ) : Fr[BN254_Snarks] =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFrMont(bytes)
proc loadValueFpMont*( stream: Stream ) : Fp =
proc loadValueFpMont*( stream: Stream ) : Fp[BN254_Snarks] =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFpMont(bytes)
proc loadValueFp2Mont*( stream: Stream ) : Fp2 =
proc loadValueFp2Mont*( stream: Stream ) : Fp2[BN254_Snarks] =
let i = loadValueFpMont( stream )
let u = loadValueFpMont( stream )
return mkFp2(i,u)
#---------------------------------------
proc loadValuesFrStd*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
proc loadValuesFrStd*( len: int, stream: Stream ) : seq[Fr[BN254_Snarks]] =
var values: seq[Fr[BN254_Snarks]]
for i in 1..len:
values.add( loadValueFrStd(stream) )
return values
proc loadValuesFpMont*( len: int, stream: Stream ) : seq[Fp] =
var values : seq[Fp]
proc loadValuesFpMont*( len: int, stream: Stream ) : seq[Fp[BN254_Snarks]] =
var values : seq[Fp[BN254_Snarks]]
for i in 1..len:
values.add( loadValueFpMont(stream) )
return values
proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
proc loadValuesFrMont*( len: int, stream: Stream ) : seq[Fr[BN254_Snarks]] =
var values: seq[Fr[BN254_Snarks]]
for i in 1..len:
values.add( loadValueFrMont(stream) )
return values

View File

@ -1,38 +1,38 @@
#
# Multi-Scalar Multiplication (MSM)
#
#
import system
import std/cpuinfo
import taskpools
import std/times
import pkg/taskpools
# import constantine/curves_primitives except Fp, Fp2, Fr
import constantine/platforms/abstractions except Subgroup
import constantine/math/isogenies/frobenius except Subgroup
import constantine/math/arithmetic except Fp, Fp2, Fr
import constantine/math/io/io_fields except Fp, Fp2, Fr
import constantine/math/io/io_bigints
import constantine/math/config/curves except G1, G2, Subgroup
import constantine/math/config/type_ff except Fp, Fr, Subgroup
import pkg/constantine/platforms/abstractions except Subgroup
import pkg/constantine/math/endomorphisms/frobenius except Subgroup
import constantine/math/extension_fields/towers as ext except Fp, Fp2, Fp12, Fr
import constantine/math/elliptic/ec_shortweierstrass_affine as aff except Subgroup
import constantine/math/elliptic/ec_shortweierstrass_projective as prj except Subgroup
import constantine/math/elliptic/ec_scalar_mul_vartime as scl except Subgroup
import constantine/math/elliptic/ec_multi_scalar_mul as msm except Subgroup
import pkg/constantine/named/algebras
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/math/io/io_bigints
import pkg/constantine/math/extension_fields/towers as ext
import pkg/constantine/math/elliptic/ec_shortweierstrass_affine as aff except Subgroup
import pkg/constantine/math/elliptic/ec_shortweierstrass_projective as prj except Subgroup
import pkg/constantine/math/elliptic/ec_scalar_mul_vartime as scl except Subgroup
import pkg/constantine/math/elliptic/ec_multi_scalar_mul as msm except Subgroup
import groth16/bn128/fields
import groth16/bn128/curves as mycurves
import groth16/misc # TEMP DEBUGGING
import std/times
#-------------------------------------------------------------------------------
proc msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
proc msmConstantineG1*( coeffs: openArray[Fr[BN254_Snarks]] , points: openArray[G1] ): G1 =
# let start = cpuTime()
@ -60,7 +60,7 @@ proc msmConstantineG1*( coeffs: openArray[Fr] , points: openArray[G1] ): G1 =
#---------------------------------------
func msmConstantineG2*( coeffs: openArray[Fr] , points: openArray[G2] ): G2 =
func msmConstantineG2*( coeffs: openArray[Fr[BN254_Snarks]] , points: openArray[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
@ -86,21 +86,20 @@ func msmConstantineG2*( coeffs: openArray[Fr] , points: openArray[G2] ): G2 =
const task_multiplier : int = 1
proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1] ): G1 =
proc msmMultiThreadedG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1], pool: Taskpool ): G1 =
# for N <= 255 , we use 1 thread
# for N == 256 , we use 2 threads
# for N == 512 , we use 4 threads
# for N >= 1024, we use 8+ threads
# for N == 512 , we use 4 threads
# for N >= 1024, we use 8+ threads
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
let nthreads_target = min( pool.numThreads, 256 )
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
let ntasks = if nthreads>1: ( nthreads * task_multiplier ) else: 1
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G1]] = newSeq[FlowVar[mycurves.G1]](ntasks)
var pending : seq[Flowvar[mycurves.G1]] = newSeq[Flowvar[mycurves.G1]](ntasks)
var a : int = 0
var b : int
@ -118,23 +117,19 @@ proc msmMultiThreadedG1*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G1]
for k in 0..<ntasks:
res += sync pending[k]
pool.syncAll()
pool.shutdown()
return res
#---------------------------------------
proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2] ): G2 =
proc msmMultiThreadedG2*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2], pool: Taskpool ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
let nthreads_target = if (nthreads_hint<=0): countProcessors() else: min( nthreads_hint, 256 )
let nthreads_target = min( pool.numThreads, 256 )
let nthreads = max( 1 , min( N div 128 , nthreads_target ) )
let ntasks = if nthreads>1: (nthreads*task_multiplier) else: 1
var pool = Taskpool.new(num_threads = nthreads)
var pending : seq[FlowVar[mycurves.G2]] = newSeq[FlowVar[mycurves.G2]](ntasks)
var pending : seq[Flowvar[mycurves.G2]] = newSeq[Flowvar[mycurves.G2]](ntasks)
var a : int = 0
var b : int
@ -152,19 +147,16 @@ proc msmMultiThreadedG2*( nthreads_hint: int, coeffs: seq[Fr] , points: seq[G2]
for k in 0..<ntasks:
res += sync pending[k]
pool.syncAll()
pool.shutdown()
return res
#-------------------------------------------------------------------------------
func msmNaiveG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 =
func msmNaiveG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1] ): G1 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG1
s.setInf()
s.setNeutral()
for i in 0..<N:
var t : ProjG1
@ -179,12 +171,12 @@ func msmNaiveG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 =
#---------------------------------------
func msmNaiveG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
func msmNaiveG2*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2] ): G2 =
let N = coeffs.len
assert( N == points.len, "incompatible sequence lengths" )
var s : ProjG2
s.setInf()
s.setNeutral()
for i in 0..<N:
var t : ProjG2
@ -199,8 +191,8 @@ func msmNaiveG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 =
#-------------------------------------------------------------------------------
proc msmG1*( coeffs: seq[Fr] , points: seq[G1] ): G1 = msmConstantineG1(coeffs, points)
proc msmG2*( coeffs: seq[Fr] , points: seq[G2] ): G2 = msmConstantineG2(coeffs, points)
proc msmG1*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G1] ): G1 = msmConstantineG1(coeffs, points)
proc msmG2*( coeffs: seq[Fr[BN254_Snarks]] , points: seq[G2] ): G2 = msmConstantineG2(coeffs, points)
#-------------------------------------------------------------------------------

View File

@ -3,9 +3,10 @@ import std/random
# import constantine/platforms/abstractions
import constantine/math/arithmetic except Fp, Fp2, Fr
import constantine/math/io/io_fields except Fp, Fp2, Fr
import constantine/math/io/io_bigints
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/math/io/io_bigints
import pkg/constantine/named/properties_fields
import groth16/bn128/fields
@ -59,9 +60,9 @@ proc randBig*[bits: static int](): BigInt[bits] =
return d
proc randFr*(): Fr =
proc randFr*(): Fr[BN254_Snarks] =
let b : BigInt[254] = randBig[254]()
var y : Fr
var y : Fr[BN254_Snarks]
y.fromBig( b )
return y

View File

@ -4,15 +4,15 @@ import groth16/files/export_json
#-------------------------------------------------------------------------------
proc exampleProveAndVerify() =
proc exampleProveAndVerify() =
let zkey_fname : string = "./build/product.zkey"
let wtns_fname : string = "./build/product.wtns"
let proof = testProveAndVerify( zkey_fname, wtns_fname)
let (_, proof) = testProveAndVerify( zkey_fname, wtns_fname)
exportPublicIO( "./build/nim_public.json" , proof )
exportProof( "./build/nim_proof.json" , proof )
#-------------------------------------------------------------------------------
when isMainModule:
exampleProveAndVerify()
exampleProveAndVerify()

View File

@ -37,4 +37,4 @@ template Main(n) {
//------------------------------------------------------------------------------
component main {public [plus]} = Main(3);
component main {public [plus]} = Main(3);

View File

@ -2,14 +2,16 @@
#
# create "fake" circuit-specific trusted setup for testing purposes
#
# by fake here I mean that no actual ceremoney is done, we just generate
# by fake here I mean that no actual ceremoney is done, we just generate
# some random toxic waste
#
#
import sugar
import std/tables
import constantine/math/arithmetic except Fp, Fr
import pkg/constantine/math/arithmetic
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128
import groth16/math/domain
@ -20,21 +22,21 @@ import groth16/misc
#-------------------------------------------------------------------------------
type
type
ToxicWaste* = object
alpha*: Fr
beta*: Fr
gamma*: Fr
delta*: Fr
tau*: Fr
alpha*: Fr[BN254_Snarks]
beta*: Fr[BN254_Snarks]
gamma*: Fr[BN254_Snarks]
delta*: Fr[BN254_Snarks]
tau*: Fr[BN254_Snarks]
proc randomToxicWaste*(): ToxicWaste =
proc randomToxicWaste*(): ToxicWaste =
let a = randFr()
let b = randFr()
let c = randFr()
let d = randFr()
let t = randFr() # intToFr(106)
return
let t = randFr() # intToFr(106)
return
ToxicWaste( alpha: a
, beta: b
, gamma: c
@ -43,7 +45,7 @@ proc randomToxicWaste*(): ToxicWaste =
#-------------------------------------------------------------------------------
func r1csToCoeffs*(r1cs: R1CS): seq[Coeff] =
func r1csToCoeffs*(r1cs: R1CS): seq[Coeff] =
var coeffs : seq[Coeff]
let n = r1cs.constraints.len
let p = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut
@ -71,11 +73,11 @@ type DenseColumn*[T] = seq[T]
type DenseMatrix*[T] = seq[DenseColumn[T]]
type
type
DenseMatrices* = object
A* : DenseMatrix[Fr]
B* : DenseMatrix[Fr]
C* : DenseMatrix[Fr]
A* : DenseMatrix[Fr[BN254_Snarks]]
B* : DenseMatrix[Fr[BN254_Snarks]]
C* : DenseMatrix[Fr[BN254_Snarks]]
#[
@ -111,7 +113,7 @@ func r1csToDenseMatrices*(r1cs: R1CS): DenseMatrices =
#-------------------------------------------------------------------------------
func denseMatricesToCoeffs*(matrices: DenseMatrices): seq[Coeff] =
func denseMatricesToCoeffs*(matrices: DenseMatrices): seq[Coeff] =
let n = matrices.A[0].len
let m = matrices.A.len
@ -137,24 +139,24 @@ func denseMatricesToCoeffs*(matrices: DenseMatrices): seq[Coeff] =
type SparseColumn*[T] = Table[int,T]
proc columnInsertWithAddFr( col: var SparseColumn[Fr] , i: int, y: Fr ) =
proc columnInsertWithAddFr( col: var SparseColumn[Fr[BN254_Snarks]] , i: int, y: Fr[BN254_Snarks] ) =
var x = getOrDefault( col, i, zeroFr )
x += y
col[i] = x
proc sparseDenseDotProdFr( U: SparseColumn[Fr], V: DenseColumn[Fr] ): Fr =
var acc : Fr = zeroFr
proc sparseDenseDotProdFr( U: SparseColumn[Fr[BN254_Snarks]], V: DenseColumn[Fr[BN254_Snarks]] ): Fr[BN254_Snarks] =
var acc : Fr[BN254_Snarks] = zeroFr
for i,x in U.pairs:
acc += x * V[i]
return acc
type SparseMatrix*[T] = seq[SparseColumn[T]]
type
type
SparseMatrices* = object
A* : SparseMatrix[Fr]
B* : SparseMatrix[Fr]
C* : SparseMatrix[Fr]
A* : SparseMatrix[Fr[BN254_Snarks]]
B* : SparseMatrix[Fr[BN254_Snarks]]
C* : SparseMatrix[Fr[BN254_Snarks]]
func r1csToSparseMatrices*(r1cs: R1CS): SparseMatrices =
let n = r1cs.constraints.len
@ -164,11 +166,11 @@ func r1csToSparseMatrices*(r1cs: R1CS): SparseMatrices =
let logDomSize = ceilingLog2(n+p+1)
let domSize = 1 shl logDomSize
var matA, matB, matC: SparseMatrix[Fr]
var matA, matB, matC: SparseMatrix[Fr[BN254_Snarks]]
for i in 0..<m:
var colA : SparseColumn[Fr] = initTable[int,Fr]()
var colB : SparseColumn[Fr] = initTable[int,Fr]()
var colC : SparseColumn[Fr] = initTable[int,Fr]()
var colA : SparseColumn[Fr[BN254_Snarks]] = initTable[int,Fr[BN254_Snarks]]()
var colB : SparseColumn[Fr[BN254_Snarks]] = initTable[int,Fr[BN254_Snarks]]()
var colC : SparseColumn[Fr[BN254_Snarks]] = initTable[int,Fr[BN254_Snarks]]()
matA.add( colA )
matB.add( colB )
matC.add( colC )
@ -188,21 +190,21 @@ func r1csToSparseMatrices*(r1cs: R1CS): SparseMatrices =
#-------------------------------------------------------------------------------
func dotProdFr(xs, ys: seq[Fr]): Fr =
func dotProdFr(xs, ys: seq[Fr[BN254_Snarks]]): Fr[BN254_Snarks] =
let n = xs.len
assert( n == ys.len, "dotProdFr: incompatible vector lengths" )
var s : Fr = zeroFr
var s : Fr[BN254_Snarks] = zeroFr
for i in 0..<n:
s += xs[i] * ys[i]
return s
#-------------------------------------------------------------------------------
func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
let neqs = r1cs.constraints.len
let npub = r1cs.cfg.nPubIn + r1cs.cfg.nPubOut
let logDomSize = ceilingLog2(neqs+npub+1)
let logDomSize = ceilingLog2(neqs+npub+1)
let domSize = 1 shl logDomSize
let nvars = r1cs.cfg.nWires
@ -213,7 +215,7 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
# echo("neqs = ",neqs)
# echo("domain = ",domSize)
let header =
let header =
GrothHeader( curve: "bn128"
, flavour: flavour
, p: primeP
@ -224,18 +226,18 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
, logDomainSize: logDomSize
)
let spec =
let spec =
SpecPoints( alpha1 : toxic.alpha ** gen1
, beta1 : toxic.beta ** gen1
, beta2 : toxic.beta ** gen2
, gamma2 : toxic.gamma ** gen2
, delta1 : toxic.delta ** gen1
, delta2 : toxic.delta ** gen2
, alphaBeta : pairing( toxic.alpha ** gen1 , toxic.beta ** gen2 )
, alphaBeta : pairing( toxic.alpha ** gen1 , toxic.beta ** gen2 )
)
let matrices = r1csToSparseMatrices(r1cs)
let D : Domain = createDomain(domSize)
let D : Domain = createDomain(domSize)
#[
# this approach is extremely inefficient
@ -252,8 +254,8 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
# the Lagrange polynomials L_k(x) evaluated at x=tau
# we can then simply take the dot product of these with the column vectors to compute the points A,B1,B2,C
let lagrangeTaus : seq[Fr] = collect( newSeq, (for k in 0..<domSize: evalLagrangePolyAt(D, k, toxic.tau) ))
let lagrangeTaus : seq[Fr[BN254_Snarks]] = collect( newSeq, (for k in 0..<domSize: evalLagrangePolyAt(D, k, toxic.tau) ))
#[
# dense matrices use way too much memory
let columnTausA : seq[Fr] = collect( newSeq, (for col in matrices.A: dotProdFr(col,lagrangeTaus) ))
@ -261,36 +263,36 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
let columnTausC : seq[Fr] = collect( newSeq, (for col in matrices.C: dotProdFr(col,lagrangeTaus) ))
]#
let columnTausA : seq[Fr] = collect( newSeq, (for col in matrices.A: sparseDenseDotProdFr(col,lagrangeTaus) ))
let columnTausB : seq[Fr] = collect( newSeq, (for col in matrices.B: sparseDenseDotProdFr(col,lagrangeTaus) ))
let columnTausC : seq[Fr] = collect( newSeq, (for col in matrices.C: sparseDenseDotProdFr(col,lagrangeTaus) ))
let columnTausA : seq[Fr[BN254_Snarks]] = collect( newSeq, (for col in matrices.A: sparseDenseDotProdFr(col,lagrangeTaus) ))
let columnTausB : seq[Fr[BN254_Snarks]] = collect( newSeq, (for col in matrices.B: sparseDenseDotProdFr(col,lagrangeTaus) ))
let columnTausC : seq[Fr[BN254_Snarks]] = collect( newSeq, (for col in matrices.C: sparseDenseDotProdFr(col,lagrangeTaus) ))
let pointsA : seq[G1] = collect( newSeq , (for y in columnTausA: (y ** gen1) ))
let pointsB1 : seq[G1] = collect( newSeq , (for y in columnTausB: (y ** gen1) ))
let pointsB2 : seq[G2] = collect( newSeq , (for y in columnTausB: (y ** gen2) ))
let pointsC : seq[G1] = collect( newSeq , (for y in columnTausC: (y ** gen1) ))
let gammaInv : Fr = invFr(toxic.gamma)
let deltaInv : Fr = invFr(toxic.delta)
let gammaInv : Fr[BN254_Snarks] = invFr(toxic.gamma)
let deltaInv : Fr[BN254_Snarks] = invFr(toxic.delta)
let pointsL : seq[G1] = collect( newSeq , (for j in 0..npub:
let pointsL : seq[G1] = collect( newSeq , (for j in 0..npub:
gammaInv ** ( toxic.beta ** pointsA[j] + toxic.alpha ** pointsB1[j] + pointsC[j] ) ))
let pointsK : seq[G1] = collect( newSeq , (for j in npub+1..nvars-1:
let pointsK : seq[G1] = collect( newSeq , (for j in npub+1..nvars-1:
deltaInv ** ( toxic.beta ** pointsA[j] + toxic.alpha ** pointsB1[j] + pointsC[j] ) ))
let polyZ = vanishingPoly(D)
let ztauG1 = polyEvalAt(polyZ, toxic.tau) ** gen1
var pointsH : seq[G1]
case flavour
case flavour
#---------------------------------------------------------------------------
# in the original paper, these are the curve points
# [ delta^-1 * tau^i * Z(tau) ]
# [ delta^-1 * tau^i * Z(tau) ]
#
of JensGroth:
pointsH = collect( newSeq , (for i in 0..<domSize:
pointsH = collect( newSeq , (for i in 0..<domSize:
(deltaInv * smallPowFr(toxic.tau,i)) ** ztauG1 ))
#---------------------------------------------------------------------------
@ -300,34 +302,34 @@ func fakeCircuitSetup*(r1cs: R1CS, toxic: ToxicWaste, flavour=Snarkjs): ZKey =
#
of Snarkjs:
let D2 : Domain = createDomain(2*domSize)
pointsH = collect( newSeq , (for i in 0..<domSize:
pointsH = collect( newSeq , (for i in 0..<domSize:
(deltaInv * evalLagrangePolyAt(D2, 2*i+1, toxic.tau)) ** gen1 ))
#---------------------------------------------------------------------------
let vPoints = VerifierPoints( pointsIC: pointsL )
let pPoints =
let pPoints =
ProverPoints( pointsA1: pointsA
, pointsB1: pointsB1
, pointsB2: pointsB2
, pointsB1: pointsB1
, pointsB2: pointsB2
, pointsC1: pointsK
, pointsH1: pointsH
, pointsH1: pointsH
)
let coeffs = r1csToCoeffs( r1cs )
return
return
ZKey( header: header
, specPoints: spec
, vPoints: vPoints
, pPoints: pPoints
, coeffs: coeffs
)
#-------------------------------------------------------------------------------
proc createFakeCircuitSetup*(r1cs: R1CS, flavour=Snarkjs): ZKey =
proc createFakeCircuitSetup*(r1cs: R1CS, flavour=Snarkjs): ZKey =
let toxic = randomToxicWaste()
return fakeCircuitSetup(r1cs, toxic, flavour=flavour)

View File

@ -14,7 +14,7 @@
#
# for each section:
# -----------------
# section id : word32
# section id : word32
# section size : word64
# section data : <section_size> number of bytes
#
@ -23,32 +23,33 @@
import std/streams
import sugar
import std/sugar
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
import pkg/constantine/math/arithmetic except Fp, Fr
import pkg/constantine/math/io/io_bigints
#-------------------------------------------------------------------------------
type
type
SectionCallback*[T] = proc (stream: Stream, sectId: int, sectLen: int, user: var T) {.closure.}
Filt* = ((int) {.raises: [], gcsafe.} -> bool)
#-------------------------------------------------------------------------------
func magicWord(magic: string): uint32 =
func magicWord(magic: string): uint32 =
assert( magic.len == 4, "magicWord: expecting a string of 4 characters" )
var w : uint32 = 0
var w : uint32 = 0
for i in 0..3:
let a = uint32(ord(magic[i]))
let a = uint32(ord(magic[i]))
w += a shl (8*i)
return w
#-------------------------------------------------------------------------------
proc parsePrimeField*( stream: Stream ) : (int, BigInt[256]) =
proc parsePrimeField*( stream: Stream ) : (int, BigInt[256]) =
let n8p = int( stream.readUint32() )
assert( n8p <= 32 , "at most 256 bit primes are allowed" )
var p_bytes : array[32, uint8]
var p_bytes : array[32, uint8]
discard stream.readData( addr(p_bytes), n8p )
var p : BigInt[256]
unmarshal(p, p_bytes, littleEndian);
@ -60,8 +61,8 @@ proc readSection[T] ( expectedMagic: string
, expectedVersion: int
, stream: Stream
, user: var T
, callback: SectionCallback[T]
, filt: (int) -> bool ) =
, callback: SectionCallback[T]
, filt: Filt ) {.raises: [IOError, OSError].} =
let sectId = int( stream.readUint32() )
let sectLen = int( stream.readUint64() )
@ -76,8 +77,8 @@ proc parseContainer*[T] ( expectedMagic: string
, expectedVersion: int
, fname: string
, user: var T
, callback: SectionCallback[T]
, filt: (int) -> bool ) =
, callback: SectionCallback[T]
, filt: Filt ) {.raises: [IOError, OSError].} =
let stream = newFileStream(fname, mode = fmRead)
defer: stream.close()

View File

@ -3,37 +3,41 @@
# export proof and public input in `circom`-compatible JSON files
#
import constantine/math/arithmetic except Fp, Fr
#import constantine/math/io/io_fields except Fp, Fr
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
# import pkg/constantine/math/config/curves
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128
from groth16/prover import Proof
#-------------------------------------------------------------------------------
func toQuotedDecimalFp(x: Fp): string =
func toQuotedDecimalFp(x: Fp[BN254_Snarks]): string =
let s : string = toDecimalFp(x)
return ("\"" & s & "\"")
func toQuotedDecimalFr(x: Fr): string =
func toQuotedDecimalFr(x: Fr[BN254_Snarks]): string =
let s : string = toDecimalFr(x)
return ("\"" & s & "\"")
#-------------------------------------------------------------------------------
# exports the public input/output into as a JSON file
proc exportPublicIO*( fpath: string, prf: Proof ) =
proc exportPublicIO*( fpath: string, prf: Proof ) =
# debugPrintFrSeq("public IO",prf.publicIO)
debugPrintFrSeq("public IO",prf.publicIO)
let n : int = prf.publicIO.len
let n : int = prf.publicIO.len
assert( n > 0 )
assert( bool(prf.publicIO[0] == oneFr) )
let f = open(fpath, fmWrite)
defer: f.close()
# note: we start from 1 because the 0th element is the constant 1 "variable",
# note: we start from 1 because the 0th element is the constant 1 "variable",
# which is automatically added by the tools
for i in 1..<n:
let str : string = toQuotedDecimalFr( prf.publicIO[i] )
@ -67,7 +71,7 @@ proc writeG2( f: File, p: G2 ) =
#-------------------------------------------------------------------------------
# exports the proof into as a JSON file
proc exportProof*( fpath: string, prf: Proof ) =
proc exportProof*( fpath: string, prf: Proof ) =
let f = open(fpath, fmWrite)
defer: f.close()
@ -84,7 +88,7 @@ proc exportProof*( fpath: string, prf: Proof ) =
#[
#import std/sequtils
func getFakeProof*() : Proof =
func getFakeProof*() : Proof =
let pub : seq[Fr] = map( [1,101,102,103,117,119] , intToFr )
let p = unsafeMkG1( intToFp(666) , intToFp(777) )
let r = unsafeMkG1( intToFp(888) , intToFp(999) )
@ -93,7 +97,7 @@ func getFakeProof*() : Proof =
let q = unsafeMkG2( x , y )
return Proof( publicIO:pub, pi_a:p, pi_b:q, pi_c:r )
proc exportFakeProof*() =
proc exportFakeProof*() =
let prf = getFakeProof()
exportPublicIO( "fake_pub.json" , prf )
exportProof( "fake_prf.json" , prf )

View File

@ -6,7 +6,7 @@
import std/strutils
import std/streams
import constantine/math/arithmetic except Fp, Fr
import pkg/constantine/math/arithmetic except Fp, Fr
import groth16/bn128
import groth16/zkey_types
@ -19,52 +19,52 @@ func toSpaces(str: string): string = spaces(str.len)
func sageFp*(prefix: string, x: Fp): string = prefix & "Fp(" & toDecimalFp(x) & ")"
func sageFr*(prefix: string, x: Fr): string = prefix & "Fr(" & toDecimalFr(x) & ")"
func sageFp2*(prefix: string, z: Fp2): string =
sageFp( prefix & "mkFp2(" , z.coords[0]) & ",\n" &
func sageFp2*(prefix: string, z: Fp2): string =
sageFp( prefix & "mkFp2(" , z.coords[0]) & ",\n" &
sageFp( toSpaces(prefix) & " " , z.coords[1]) & ")"
func sageG1*(prefix: string, p: G1): string =
sageFp( prefix & "E(" , p.x) & ",\n" &
func sageG1*(prefix: string, p: G1): string =
sageFp( prefix & "E(" , p.x) & ",\n" &
sageFp( toSpaces(prefix) & " " , p.y) & ")"
func sageG2*(prefix: string, p: G2): string =
sageFp2( prefix & "E2(" , p.x) & ",\n" &
sageFp2( prefix & "E2(" , p.x) & ",\n" &
sageFp2( toSpaces(prefix) & " " , p.y) & ")"
#-------------------------------------------------------------------------------
proc exportVKey*(h: Stream, vkey: VKey ) =
proc exportVKey*(h: Stream, vkey: VKey ) =
let spec = vkey.spec
h.writeLine("alpha1 = \\") ; h.writeLine(sageG1(" ", spec.alpha1))
h.writeLine("beta2 = \\") ; h.writeLine(sageG2(" ", spec.beta2 ))
h.writeLine("gamma2 = \\") ; h.writeLine(sageG2(" ", spec.gamma2))
h.writeLine("delta2 = \\") ; h.writeLine(sageG2(" ", spec.delta2))
let pts = vkey.vpoints.pointsIC
let pts = vkey.vpoints.pointsIC
h.writeLine("pointsIC = \\")
for i in 0..<pts.len:
let prefix = if (i==0): " [ " else: " "
let postfix = if (i<pts.len-1): "," else: " ]"
let postfix = if (i<pts.len-1): "," else: " ]"
h.writeLine( sageG1(prefix, pts[i]) & postfix )
#---------------------------------------
proc exportProof*(h: Stream, prf: Proof ) =
proc exportProof*(h: Stream, prf: Proof ) =
h.writeLine("piA = \\") ; h.writeLine(sageG1(" ", prf.pi_a ))
h.writeLine("piB = \\") ; h.writeLine(sageG2(" ", prf.pi_b ))
h.writeLine("piC = \\") ; h.writeLine(sageG1(" ", prf.pi_c ))
# note: the first element is just the constant 1
let coeffs = prf.publicIO
h.writeLine("pubIO = \\")
for i in 0..<coeffs.len:
let prefix = if (i==0): " [ " else: " "
let postfix = if (i<coeffs.len-1): "," else: " ]"
let postfix = if (i<coeffs.len-1): "," else: " ]"
h.writeLine( prefix & toDecimalFr(coeffs[i]) & postfix )
#-------------------------------------------------------------------------------
const sage_bn128_lines : seq[string] =
const sage_bn128_lines : seq[string] =
@[ "# BN128 elliptic curve"
, "p = 21888242871839275222246405745257275088696311157297823662689037894645226208583"
, "r = 21888242871839275222246405745257275088548364400416034343698204186575808495617"
@ -121,7 +121,7 @@ const sage_bn128* : string = join(sage_bn128_lines, sep="\n")
#-------------------------------------------------------------------------------
const verify_lines : seq[string] =
const verify_lines : seq[string] =
@[ "pubG1 = pointsIC[0]"
, "for i in [1..len(pubIO)-1]:"
, " pubG1 = pubG1 + pubIO[i]*pointsIC[i]"
@ -138,7 +138,7 @@ const verify_script : string = join(verify_lines, sep="\n")
#-------------------------------------------------------------------------------
proc exportSage*(fpath: string, vkey: VKey, prf: Proof) =
proc exportSage*(fpath: string, vkey: VKey, prf: Proof) =
let h = openFileStream(fpath, fmWrite)
defer: h.close()

View File

@ -4,7 +4,7 @@
#
# file format
# ===========
#
#
# standard iden3 binary container format.
# field elements are in standard representation
#
@ -33,7 +33,7 @@
# where a term looks like this:
# idx : word32 = which witness variable
# coeff : Fr = the coefficient
#
#
# 3: Wire-to-label mapping
# ------------------------
# <an array of `nWires` many 64 bit words>
@ -49,31 +49,36 @@
# ...
#
import std/streams
{.push raises: [IOError, OSError].}
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
import std/streams
import std/sugar
import pkg/constantine/math/arithmetic except Fp, Fr
import pkg/constantine/math/io/io_bigints
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128
import groth16/files/container
#-------------------------------------------------------------------------------
type
type
WitnessConfig* = object
nWires* : int # total number of wires (or witness variables), including the constant 1 "variable"
nPubOut* : int # number of public outputs
nPubIn* : int # number of public inputs
nPrivIn* : int # number of private inputs
nLabels* : int # number of labels
Term* = tuple[ wireIdx: int, value: Fr ]
Term* = tuple[ wireIdx: int, value: Fr[BN254_Snarks] ]
LinComb* = seq[Term]
Constraint* = tuple[ A: LinComb, B: LinComb, C: LinComb ]
R1CS* = object
r* : BigInt[256]
r* : BigInt[256]
cfg* : WitnessConfig
nConstr* : int
constraints* : seq[Constraint]
@ -83,7 +88,7 @@ type
proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) =
# echo "\nparsing r1cs header"
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
user.r = r;
@ -110,19 +115,19 @@ proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) =
#-------------------------------------------------------------------------------
proc loadTerm( stream: Stream ): Term =
proc loadTerm( stream: Stream ): Term =
let idx = int( stream.readUint32() )
let coeff = loadValueFrStd( stream )
return (wireIdx:idx, value:coeff)
proc loadLinComb( stream: Stream ): LinComb =
proc loadLinComb( stream: Stream ): LinComb =
let nterms = int( stream.readUint32() )
var terms : seq[Term]
for i in 1..nterms:
terms.add( loadTerm(stream) )
return terms
proc loadConstraint( stream: Stream ): Constraint =
proc loadConstraint( stream: Stream ): Constraint =
let a = loadLinComb( stream )
let b = loadLinComb( stream )
let c = loadLinComb( stream )
@ -160,17 +165,17 @@ proc r1csCallback( stream: Stream
, sectId: int
, sectLen: int
, user: var R1CS
) =
) =
case sectId
of 1: parseSection1_header( stream, user, sectLen )
of 2: parseSection2_constraints( stream, user, sectLen )
of 3: parseSection3_wireToLabel( stream, user, sectLen )
else: discard
proc parseR1CS* (fname: string): R1CS =
proc parseR1CS* (fname: string): R1CS =
var r1cs : R1CS
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id == 1 )
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id != 1 )
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, (id) {.raises: [], gcsafe.} => id == 1 )
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, (id) {.raises: [], gcsafe.} => id != 1 )
return r1cs
#-------------------------------------------------------------------------------

View File

@ -16,26 +16,29 @@
import std/streams
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/math/io/io_bigints
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128
import groth16/files/container
#-------------------------------------------------------------------------------
type
type
Witness* = object
curve* : string
r* : BigInt[256]
r* : BigInt[256]
nvars* : int
values* : seq[Fr]
values* : seq[Fr[BN254_Snarks]]
#-------------------------------------------------------------------------------
proc parseSection1_header( stream: Stream, user: var Witness, sectionLen: int ) =
# echo "\nparsing witness header"
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
user.r = r;
@ -61,14 +64,14 @@ proc parseSection2_witness( stream: Stream, user: var Witness, sectionLen: int )
#-------------------------------------------------------------------------------
proc wtnsCallback(stream: Stream, sectId: int, sectLen: int, user: var Witness) =
proc wtnsCallback(stream: Stream, sectId: int, sectLen: int, user: var Witness) =
#echo(sectId)
case sectId
of 1: parseSection1_header( stream, user, sectLen )
of 2: parseSection2_witness( stream, user, sectLen )
else: discard
proc parseWitness* (fname: string): Witness =
proc parseWitness* (fname: string): Witness =
var wtns : Witness
parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id == 1 )
parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id != 1 )

View File

@ -5,12 +5,12 @@
#
# file format
# ===========
#
#
# standard iden3 binary container format.
# field elements are in Montgomery representation, except for the coefficients
# which for some reason are double Montgomery encoded... (and unlike the
# which for some reason are double Montgomery encoded... (and unlike the
# `.wtns` and `.r1cs` files which use the standard representation)
#
#
# sections:
#
# 1: Header
@ -81,7 +81,7 @@
# what normally should be the curve points `[ delta^-1 * tau^i * Z(tau) ]_1`
# HOWEVER, in the snarkjs implementation, they are different; namely
# `[ delta^-1 * L_{2i+1} (tau) ]_1` where L_k are Lagrange polynomials
# on the refined (double sized) domain
# on the refined (double sized) domain
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
# length = 2 * n8p * domSize = domSize G1 points
#
@ -94,18 +94,20 @@
import std/streams
import constantine/math/arithmetic except Fp, Fr
import pkg/constantine/math/arithmetic except Fp, Fr
#import constantine/math/io/io_bigints
import groth16/bn128
import groth16/zkey_types
import groth16/files/container
import groth16/misc
export zkey_types
#-------------------------------------------------------------------------------
proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int ) =
assert( sectionLen == 4 , "unexpected section length" )
proc parseSection1_proverType ( stream: Stream, user: var ZKey, sectionLen: int ) =
assert( sectionLen == 4 , "unexpected section length" )
let proverType = stream.readUint32
assert( proverType == 1 , "expecting `.zkey` file for a Groth16 prover")
@ -128,8 +130,8 @@ proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int
header.flavour = Snarkjs
assert( n8p == 32 , "expecting 256 bit primes")
assert( n8r == 32 , "expecting 256 bit primes")
assert( n8p == 32 , "expecting 256 bit primes")
assert( n8r == 32 , "expecting 256 bit primes")
assert( bool(p == primeP) , "expecting the alt-bn128 curve" )
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
@ -171,7 +173,7 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) =
assert( sectionLen == 4 + ncoeffs*(32+12) , "unexpected section length" )
let nrows = user.header.domainSize
let ncols = user.header.nvars
var coeffs : seq[Coeff]
for i in 1..ncoeffs:
let m = int( stream.readUint32() ) # which matrix
@ -188,7 +190,7 @@ proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) =
let cf = loadValueFrWTF( stream ) # Jordi, WTF is this encoding ?!?!?!!111
let entry = Coeff( matrix:sel, row:r, col:c, coeff:cf )
coeffs.add( entry )
user.coeffs = coeffs
#-------------------------------------------------------------------------------
@ -225,7 +227,7 @@ proc parseSection9_PointsH1( stream: Stream, user: var ZKey, sectionLen: int ) =
#-------------------------------------------------------------------------------
proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) =
proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) =
case sectId
of 1: parseSection1_proverType( stream, user, sectLen )
of 2: parseSection2_GrothHeader( stream, user, sectLen )
@ -238,7 +240,7 @@ proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) =
of 9: parseSection9_PointsH1( stream, user, sectLen )
else: discard
proc parseZKey* (fname: string): ZKey =
proc parseZKey* (fname: string): ZKey =
var zkey : ZKey
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 1 )
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 2 )

View File

@ -3,38 +3,39 @@
# power-of-two sized multiplicative FFT domains in the scalar field
#
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields except Fp,Fr
#import constantine/math/io/io_bigints
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128
import groth16/misc
#-------------------------------------------------------------------------------
type
type
Domain* = object
domainSize* : int # `N = 2^n`
logDomainSize* : int # `n = log2(N)`
domainGen* : Fr # `g`
invDomainGen* : Fr # `g^-1`
invDomainSize* : Fr # `1/n`
domainSize* : int # `N = 2^n`
logDomainSize* : int # `n = log2(N)`
domainGen* : Fr[BN254_Snarks] # `g`
invDomainGen* : Fr[BN254_Snarks] # `g^-1`
invDomainSize* : Fr[BN254_Snarks] # `1/n`
#-------------------------------------------------------------------------------
# the generator of the multiplicative subgroup with size `2^28`
const gen28 : Fr = fromHex( Fr, "0x2a3c09f0a58a7e8500e0a7eb8ef62abc402d111e41112ed49bd61b6e725b19f0" )
const gen28 = fromHex( Fr[BN254_Snarks], "0x2a3c09f0a58a7e8500e0a7eb8ef62abc402d111e41112ed49bd61b6e725b19f0" )
func createDomain*(size: int): Domain =
func createDomain*(size: int): Domain =
let log2 = ceilingLog2(size)
assert( (1 shl log2) == size , "domain must have a power-of-two size" )
let expo : uint = 1'u shl (28 - log2)
let gen : Fr = smallPowFr(gen28, expo)
let expo = 1'u shl (28 - log2)
let gen = smallPowFr(gen28, expo)
let halfSize = size div 2
let a : Fr = smallPowFr(gen, size )
let b : Fr = smallPowFr(gen, halfSize)
let a = smallPowFr(gen, size )
let b = smallPowFr(gen, halfSize)
assert( bool(a == oneFr) , "domain generator sanity check /A" )
assert( not bool(b == oneFr) , "domain generator sanity check /B" )
@ -42,14 +43,14 @@ func createDomain*(size: int): Domain =
, logDomainSize: log2
, domainGen: gen
, invDomainGen: invFr(gen)
, invDomainSize: invFr(intToFr(size))
, invDomainSize: invFr(intToFr(size))
)
#-------------------------------------------------------------------------------
func enumerateDomain*(D: Domain): seq[Fr] =
var xs : seq[Fr] = newSeq[Fr](D.domainSize)
var g : Fr = oneFr
func enumerateDomain*(D: Domain): seq[Fr[BN254_Snarks]] =
var xs = newSeq[Fr[BN254_Snarks]](D.domainSize)
var g = oneFr
for i in 0..<D.domainSize:
xs[i] = g
g *= D.domainGen

View File

@ -1,13 +1,15 @@
#
# Number-theoretic transform
# Number-theoretic transform
# (that is, FFT for polynomials over finite fields)
#
#-------------------------------------------------------------------------------
import constantine/math/arithmetic except Fp,Fr
import constantine/math/io/io_fields
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128
import groth16/math/domain
@ -16,22 +18,22 @@ import groth16/math/domain
func forwardNTT_worker( m: int
, srcStride: int
, gpows: seq[Fr]
, src: seq[Fr] , srcOfs: int
, buf: var seq[Fr] , bufOfs: int
, tgt: var seq[Fr] , tgtOfs: int ) =
case m
, gpows: seq[Fr[BN254_Snarks]]
, src: seq[Fr[BN254_Snarks]] , srcOfs: int
, buf: var seq[Fr[BN254_Snarks]] , bufOfs: int
, tgt: var seq[Fr[BN254_Snarks]] , tgtOfs: int ) =
case m
of 0:
of 0:
tgt[tgtOfs] = src[srcOfs]
of 1:
tgt[tgtOfs ] = src[srcOfs] + src[srcOfs+srcStride]
tgt[tgtOfs+1] = src[srcOfs] - src[srcOfs+srcStride]
else:
let N : int = 1 shl m
let halfN : int = 1 shl (m-1)
let N : int = 1 shl m
let halFN : int = 1 shl (m-1)
forwardNTT_worker( m-1
, srcStride shl 1
, gpows
@ -43,28 +45,28 @@ func forwardNTT_worker( m: int
, gpows
, src , srcOfs + srcStride
, buf , bufOfs + N
, buf , bufOfs + halfN )
for j in 0..<halfN:
let y : Fr = gpows[j*srcStride] * buf[bufOfs+j+halfN]
, buf , bufOfs + halFN )
for j in 0..<halFN:
let y = gpows[j*srcStride] * buf[bufOfs+j+halFN]
tgt[tgtOfs+j ] = buf[bufOfs+j] + y
tgt[tgtOfs+j+halfN] = buf[bufOfs+j] - y
tgt[tgtOfs+j+halFN] = buf[bufOfs+j] - y
#---------------------------------------
# forward number-theoretical transform (corresponds to polynomial evaluation)
func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
func forwardNTT*(src: seq[Fr[BN254_Snarks]], D: Domain): seq[Fr[BN254_Snarks]] =
assert( D.domainSize == (1 shl D.logDomainSize) , "domain must have a power-of-two size" )
assert( D.domainSize == src.len , "input must have the same size as the domain" )
var buf : seq[Fr] = newSeq[Fr]( 2 * D.domainSize )
var tgt : seq[Fr] = newSeq[Fr]( D.domainSize )
var buf = newSeq[Fr[BN254_Snarks] ]( 2 * D.domainSize )
var tgt = newSeq[Fr[BN254_Snarks]]( D.domainSize )
# precalc powers of gen
let N = D.domainSize
let halFN = N div 2
var gpows : seq[Fr] = newSeq[Fr]( halfN )
var x : Fr = oneFr
let gen : Fr = D.domainGen
for i in 0..<halfN:
var gpows = newSeq[Fr[BN254_Snarks]]( halFN )
var x = oneFr
let gen = D.domainGen
for i in 0..<halFN:
gpows[i] = x
x *= gen
@ -77,47 +79,47 @@ func forwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
return tgt
# pads the input with zeros to get a pwoer of two size
# TODO: optimize the FFT so that it doesn't do the multiplications with zeros
func extendAndForwardNTT*(src: seq[Fr], D: Domain): seq[Fr] =
# TODO: optimize the FFT so that it doesn't do the multiplications with zeros
func extendAndForwardNTT*(src: seq[Fr[BN254_Snarks]], D: Domain): seq[Fr[BN254_Snarks]] =
let n = src.len
let N = D.domainSize
let N = D.domainSize
assert( n <= N )
if n == N:
return forwardNTT(src, D)
else:
var padded : seq[Fr] = newSeq[Fr]( N )
var padded = newSeq[Fr[BN254_Snarks]]( N )
for i in 0..<n: padded[i] = src[i]
# for i in n..<N: padded[i] = zeroFr
# for i in n..<N: padded[i] = zeroFr
return forwardNTT(padded, D)
#-------------------------------------------------------------------------------
const oneHalfFr* : Fr = fromHex(Fr, "0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000001")
const oneHalfFr* = fromHex(Fr[BN254_Snarks], "0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000001")
func inverseNTT_worker( m: int
, tgtStride: int
, gpows: seq[Fr]
, src: seq[Fr] , srcOfs: int
, buf: var seq[Fr] , bufOfs: int
, tgt: var seq[Fr] , tgtOfs: int ) =
case m
, gpows: seq[Fr[BN254_Snarks]]
, src: seq[Fr[BN254_Snarks]] , srcOfs: int
, buf: var seq[Fr[BN254_Snarks]] , bufOfs: int
, tgt: var seq[Fr[BN254_Snarks]] , tgtOfs: int ) =
case m
of 0:
of 0:
tgt[tgtOfs] = src[srcOfs]
of 1:
tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] )
tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] )
tgt[tgtOfs ] = ( src[srcOfs] + src[srcOfs+1] )
tgt[tgtOfs+tgtStride] = ( src[srcOfs] - src[srcOfs+1] )
div2( tgt[tgtOfs ] )
div2( tgt[tgtOfs+tgtStride] )
else:
let N : int = 1 shl m
let halfN : int = 1 shl (m-1)
let N : int = 1 shl m
let halFN : int = 1 shl (m-1)
for j in 0..<halfN:
buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halfN] )
buf[bufOfs+j+halfN] = ( src[srcOfs+j] - src[srcOfs+j+halfN] ) * gpows[ j*tgtStride ]
for j in 0..<halFN:
buf[bufOfs+j ] = ( src[srcOfs+j] + src[srcOfs+j+halFN] )
buf[bufOfs+j+halFN] = ( src[srcOfs+j] - src[srcOfs+j+halFN] ) * gpows[ j*tgtStride ]
div2( buf[bufOfs+j ] )
inverseNTT_worker( m-1
@ -129,26 +131,26 @@ func inverseNTT_worker( m: int
inverseNTT_worker( m-1
, tgtStride shl 1
, gpows
, buf , bufOfs + halfN
, buf , bufOfs + halFN
, buf , bufOfs + N
, tgt , tgtOfs + tgtStride )
#---------------------------------------
# inverse number-theoretical transform (corresponds to polynomial interpolation)
func inverseNTT*(src: seq[Fr], D: Domain): seq[Fr] =
func inverseNTT*(src: seq[Fr[BN254_Snarks]], D: Domain): seq[Fr[BN254_Snarks]] =
assert( D.domainSize == (1 shl D.logDomainSize) , "domain must have a power-of-two size" )
assert( D.domainSize == src.len , "input must have the same size as the domain" )
var buf : seq[Fr] = newSeq[Fr]( 2 * D.domainSize )
var tgt : seq[Fr] = newSeq[Fr]( D.domainSize )
var buf = newSeq[Fr[BN254_Snarks]]( 2 * D.domainSize )
var tgt = newSeq[Fr[BN254_Snarks]]( D.domainSize )
# precalc 1/2 times powers of gen^-1
let N = D.domainSize
let halFN = N div 2
var gpows : seq[Fr] = newSeq[Fr]( halfN )
var x : Fr = oneHalfFr
let ginv : Fr = invFr( D.domainGen )
for i in 0..<halfN:
var gpows = newSeq[Fr[BN254_Snarks]]( halFN )
var x = oneHalfFr
let ginv = invFr( D.domainGen )
for i in 0..<halFN:
gpows[i] = x
x *= ginv

View File

@ -9,7 +9,12 @@
import std/sequtils
import std/sugar
import constantine/math/arithmetic except Fp,Fr
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/math/io/io_bigints
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
#import constantine/math/io/io_fields
import groth16/bn128
@ -19,9 +24,9 @@ import groth16/misc
#-------------------------------------------------------------------------------
type
type
Poly* = object
coeffs* : seq[Fr]
coeffs* : seq[Fr[BN254_Snarks]]
#-------------------------------------------------------------------------------
@ -31,7 +36,7 @@ func polyDegree*(P: Poly) : int =
while isZeroFr(xs[d]) and (d >= 0): d -= 1
return d
func polyIsZero*(P: Poly) : bool =
func polyIsZero*(P: Poly) : bool =
let xs = P.coeffs ; let n = xs.len
var b = true
for i in 0..<n:
@ -40,9 +45,9 @@ func polyIsZero*(P: Poly) : bool =
break
return b
func polyIsEqual*(P, Q: Poly) : bool =
let xs : seq[Fr] = P.coeffs ; let n = xs.len
let ys : seq[Fr] = Q.coeffs ; let m = ys.len
func polyIsEqual*(P, Q: Poly) : bool =
let xs = P.coeffs ; let n = xs.len
let ys = Q.coeffs ; let m = ys.len
var b = true
if n >= m:
for i in 0..<m: ( if not isEqualFr(xs[i], ys[i]): ( b = false ; break ) )
@ -54,10 +59,10 @@ func polyIsEqual*(P, Q: Poly) : bool =
#-------------------------------------------------------------------------------
func polyEvalAt*(P: Poly, x0: Fr): Fr =
func polyEvalAt*(P: Poly, x0: Fr[BN254_Snarks]): Fr[BN254_Snarks] =
let cs = P.coeffs ; let n = cs.len
var y : Fr = zeroFr
var r : Fr = oneFr
var y : Fr[BN254_Snarks] = zeroFr
var r : Fr[BN254_Snarks] = oneFr
if n > 0: y = cs[0]
for i in 1..<n:
r *= x0
@ -67,13 +72,13 @@ func polyEvalAt*(P: Poly, x0: Fr): Fr =
#-------------------------------------------------------------------------------
func polyNeg*(P: Poly) : Poly =
let zs : seq[Fr] = map( P.coeffs , negFr )
let zs = map( P.coeffs , negFr )
return Poly(coeffs: zs)
func polyAdd*(P, Q: Poly) : Poly =
let xs = P.coeffs ; let n = xs.len
let ys = Q.coeffs ; let m = ys.len
var zs : seq[Fr] = newSeq[Fr](max(n,m))
var zs = newSeq[Fr[BN254_Snarks]](max(n,m))
if n >= m:
for i in 0..<m: zs[i] = ( xs[i] + ys[i] )
for i in m..<n: zs[i] = ( xs[i] )
@ -85,7 +90,7 @@ func polyAdd*(P, Q: Poly) : Poly =
func polySub*(P, Q: Poly) : Poly =
let xs = P.coeffs ; let n = xs.len
let ys = Q.coeffs ; let m = ys.len
var zs : seq[Fr] = newSeq[Fr](max(n,m))
var zs = newSeq[Fr[BN254_Snarks]](max(n,m))
if n >= m:
for i in 0..<m: zs[i] = ( xs[i] - ys[i] )
for i in m..<n: zs[i] = ( xs[i] )
@ -97,7 +102,7 @@ func polySub*(P, Q: Poly) : Poly =
#-------------------------------------------------------------------------------
func polyScale*(s: Fr, P: Poly): Poly =
let zs : seq[Fr] = map( P.coeffs , proc (x: Fr): Fr = s*x )
let zs = map( P.coeffs , proc (x: Fr[BN254_Snarks]): Fr[BN254_Snarks] = s*x )
return Poly(coeffs: zs)
#-------------------------------------------------------------------------------
@ -106,13 +111,13 @@ func polyMulNaive*(P, Q : Poly): Poly =
let xs = P.coeffs ; let n1 = xs.len
let ys = Q.coeffs ; let n2 = ys.len
let N = n1 + n2 - 1
var zs : seq[Fr] = newSeq[Fr](N)
var zs = newSeq[Fr[BN254_Snarks]](N)
for k in 0..<N:
# 0 <= i <= min(k , n1-1)
# 0 <= j <= min(k , n2-1)
# k = i + j
# 0 >= i = k - j >= k - min(k , n2-1)
# 0 >= j = k - i >= k - min(k , n1-1)
# 0 >= j = k - i >= k - min(k , n1-1)
let A : int = max( 0 , k - min(k , n2-1) )
let B : int = min( k , n1-1 )
zs[k] = zeroFr
@ -124,7 +129,7 @@ func polyMulNaive*(P, Q : Poly): Poly =
#-------------------------------------------------------------------------------
# multiply two polynomials using FFT
func polyMulFFT*(P, Q: Poly): Poly =
func polyMulFFT*(P, Q: Poly): Poly =
let n1 = P.coeffs.len
let n2 = Q.coeffs.len
@ -132,10 +137,10 @@ func polyMulFFT*(P, Q: Poly): Poly =
let N : int = (1 shl log2)
let D : Domain = createDomain( N )
let us : seq[Fr] = extendAndForwardNTT( P.coeffs, D )
let vs : seq[Fr] = extendAndForwardNTT( Q.coeffs, D )
let zs : seq[Fr] = collect( newSeq, (for i in 0..<N: us[i]*vs[i] ))
let ws : seq[Fr] = inverseNTT( zs, D )
let us = extendAndForwardNTT( P.coeffs, D )
let vs = extendAndForwardNTT( Q.coeffs, D )
let zs = collect( newSeq, (for i in 0..<N: us[i]*vs[i] ))
let ws = inverseNTT( zs, D )
return Poly(coeffs: ws)
@ -143,8 +148,8 @@ func polyMulFFT*(P, Q: Poly): Poly =
# WARNING: this is using the naive implementation!
func polyMul*(P, Q : Poly): Poly =
# return polyMulFFT(P, Q)
return polyMulNaive(P, Q)
# return polyMulFFT(P, Q)
return polyMulNaive(P, Q)
#-------------------------------------------------------------------------------
@ -160,39 +165,39 @@ func `*`*(P: Poly, s: Fr ): Poly = return polyScale(s, P)
#-------------------------------------------------------------------------------
# the generalized vanishing polynomial `(a*x^N - b)`
func generalizedVanishingPoly*(N: int, a: Fr, b: Fr): Poly =
func generalizedVanishingPoly*(N: int, a: Fr[BN254_Snarks], b: Fr[BN254_Snarks]): Poly =
assert( N>=1 )
var cs : seq[Fr] = newSeq[Fr]( N+1 )
var cs = newSeq[Fr[BN254_Snarks]]( N+1 )
cs[0] = negFr(b)
cs[N] = a
return Poly(coeffs: cs)
# the vanishing polynomial `(x^N - 1)`
func vanishingPoly*(N: int): Poly =
func vanishingPoly*(N: int): Poly =
return generalizedVanishingPoly(N, oneFr, oneFr)
func vanishingPoly*(D: Domain): Poly =
func vanishingPoly*(D: Domain): Poly =
return vanishingPoly(D.domainSize)
#-------------------------------------------------------------------------------
type
QuotRem*[T] = object
quot* : T
rem* : T
quot* : T
rem* : T
# divide by the vanishing polynomial `(x^N - 1)`
# returns the quotient and remainder
func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] =
func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] =
assert( N>=1 )
let deg : int = polyDegree(P)
let src : seq[Fr] = P.coeffs
var quot : seq[Fr] = newSeq[Fr]( max(1, deg - N + 1) )
var rem : seq[Fr] = newSeq[Fr]( N )
let src = P.coeffs
var quot = newSeq[Fr[BN254_Snarks]]( max(1, deg - N + 1) )
var rem = newSeq[Fr[BN254_Snarks]]( N )
if deg < N:
rem = src
else:
# compute quotient
@ -212,7 +217,7 @@ func polyQuotRemByVanishing*(P: Poly, N: int): QuotRem[Poly] =
return QuotRem[Poly]( quot:Poly(coeffs:quot), rem:Poly(coeffs:rem) )
# divide by the vanishing polynomial `(x^N - 1)`
func polyDivideByVanishing*(P: Poly, N: int): Poly =
func polyDivideByVanishing*(P: Poly, N: int): Poly =
let qr = polyQuotRemByVanishing(P, N)
assert( polyIsZero(qr.rem) )
return qr.quot
@ -222,15 +227,15 @@ func polyDivideByVanishing*(P: Poly, N: int): Poly =
# Lagrange basis polynomials
func lagrangePoly*(D: Domain, k: int): Poly =
let N = D.domainSize
let omMinusK : Fr = smallPowFr( D.invDomainGen , k )
let invN : Fr = invFr(intToFr(N))
let omMinusK = smallPowFr( D.invDomainGen , k )
let invN = invFr(intToFr(N))
var cs : seq[Fr] = newSeq[Fr]( N )
var cs = newSeq[Fr[BN254_Snarks]]( N )
if k == 0:
for i in 0..<N: cs[i] = invN
else:
var s : Fr = invN
for i in 0..<N:
var s = invN
for i in 0..<N:
cs[i] = s
s *= omMinusK
@ -239,7 +244,7 @@ func lagrangePoly*(D: Domain, k: int): Poly =
#---------------------------------------
# evaluate a Lagrange basis polynomial at a given point `zeta` (outside the domain)
func evalLagrangePolyAt*(D: Domain, k: int, zeta: Fr): Fr =
func evalLagrangePolyAt*(D: Domain, k: int, zeta: Fr[BN254_Snarks]): Fr[BN254_Snarks] =
let omegaK = smallPowFr(D.domainGen, k)
let denom = (zeta - omegaK)
if bool(isZero(denom)):
@ -252,16 +257,16 @@ func evalLagrangePolyAt*(D: Domain, k: int, zeta: Fr): Fr =
#-------------------------------------------------------------------------------
# evaluates a polynomial on an FFT domain
func polyForwardNTT*(P: Poly, D: Domain): seq[Fr] =
func polyForwardNTT*(P: Poly, D: Domain): seq[Fr[BN254_Snarks]] =
let n = P.coeffs.len
assert( n <= D.domainSize , "the domain must be as least as big as the polynomial" )
let src : seq[Fr] = P.coeffs
let src = P.coeffs
return forwardNTT(src, D)
#---------------------------------------
# interpolates a polynomial on an FFT domain
func polyInverseNTT*(ys: seq[Fr], D: Domain): Poly =
func polyInverseNTT*(ys: seq[Fr[BN254_Snarks]], D: Domain): Poly =
let n = ys.len
assert( n == D.domainSize , "the domain must be same size as the input" )
let tgt = inverseNTT(ys, D)
@ -280,7 +285,7 @@ proc sanityCheckOneHalf*() =
#-------------------
proc sanityCheckVanishing*() =
proc sanityCheckVanishing*() =
var js : seq[int] = toSeq(101..112)
let cs : seq[Fr] = map( js, intToFr )
let P : Poly = Poly( coeffs:cs )
@ -304,13 +309,13 @@ proc sanityCheckVanishing*() =
#-------------------
proc sanityCheckNTT*() =
proc sanityCheckNTT*() =
var js : seq[int] = toSeq(101..108)
let cs : seq[Fr] = map( js, intToFr )
let P : Poly = Poly( coeffs:cs )
let D : Domain = createDomain(8)
let xs : seq[Fr] = D.enumerateDomain()
let ys : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(P,x)) )
let ys : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(P,x)) )
let zs : seq[Fr] = polyForwardNTT(P ,D)
let Q : Poly = polyInverseNTT(zs,D)
debugPrintFrSeq("xs", xs)
@ -320,7 +325,7 @@ proc sanityCheckNTT*() =
#-------------------
proc sanityCheckMulFFT*() =
proc sanityCheckMulFFT*() =
var js : seq[int] = toSeq(101..110)
let cs : seq[Fr] = map( js, intToFr )
let P : Poly = Poly( coeffs:cs )
@ -339,19 +344,19 @@ proc sanityCheckMulFFT*() =
#-------------------
proc sanityCheckLagrangeBases*() =
proc sanityCheckLagrangeBases*() =
let n = 8
let D = createDomain(n)
let L : seq[Poly] = collect( newSeq, (for k in 0..<n: lagrangePoly(D,k) ))
let xs = enumerateDomain(D)
let ys0 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[0],x) ))
let ys1 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[1],x) ))
let ys5 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[5],x) ))
let zs0 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(0,i) ))
let zs1 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(1,i) ))
let zs5 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(5,i) ))
let ys0 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[0],x) ))
let ys1 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[1],x) ))
let ys5 : seq[Fr] = collect( newSeq, (for x in xs: polyEvalAt(L[5],x) ))
let zs0 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(0,i) ))
let zs1 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(1,i) ))
let zs5 : seq[Fr] = collect( newSeq, (for i in 0..<n: deltaFr(5,i) ))
echo("==============")
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys0[i]), " | z[i] = ",toDecimalFr(zs0[i]))
@ -359,16 +364,16 @@ proc sanityCheckLagrangeBases*() =
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys1[i]), " | z[i] = ",toDecimalFr(zs1[i]))
echo("--------------")
for i in 0..<n: echo("i = ",i, " | y[i] = ",toDecimalFr(ys5[i]), " | z[i] = ",toDecimalFr(zs5[i]))
let zeta = intToFr(123457)
let us : seq[Fr] = collect( newSeq, (for i in 0..<n: polyEvalAt(L[i],zeta)) )
let vs : seq[Fr] = collect( newSeq, (for i in 0..<n: evalLagrangePolyAt(D,i,zeta)) )
let us : seq[Fr] = collect( newSeq, (for i in 0..<n: polyEvalAt(L[i],zeta)) )
let vs : seq[Fr] = collect( newSeq, (for i in 0..<n: evalLagrangePolyAt(D,i,zeta)) )
echo("==============")
for i in 0..<n: echo("i = ",i, " | u[i] = ",toDecimalFr(us[i]), " | v[i] = ",toDecimalFr(vs[i]))
let prefix = "Lagrange basis sanity check = "
if ( ys0===zs0 and ys1===zs1 and ys5===zs5 and
if ( ys0===zs0 and ys1===zs1 and ys5===zs5 and
us===vs ):
echo( prefix & "OK")
else:

View File

@ -3,8 +3,8 @@
# miscellaneous routines
#
import strformat
import times, os, strutils
import std/strformat
import std/[times, os, strutils]
#-------------------------------------------------------------------------------
@ -32,7 +32,7 @@ func delta*(i, j: int) : int =
#-------------------------------------------------------------------------------
func floorLog2* (x : int) : int =
func floorLog2* (x : int) : int =
var k = -1
var y = x
while (y > 0):
@ -40,7 +40,7 @@ func floorLog2* (x : int) : int =
y = y shr 1
return k
func ceilingLog2* (x : int) : int =
func ceilingLog2* (x : int) : int =
if (x==0):
return -1
else:

View File

@ -2,28 +2,28 @@
#
# Groth16 prover
#
# WARNING!
# WARNING!
# the points H in `.zkey` are *NOT* what normal people would think they are
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
#[
import sugar
import constantine/math/config/curves
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import ./zkey
]#
{.push raises: [].}
import std/os
import std/times
import std/cpuinfo
import system
import taskpools
import std/isolation
import pkg/taskpools
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/math/io/io_bigints
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import constantine/math/arithmetic except Fp, Fr
#import constantine/math/io/io_extfields except Fp12
#import constantine/math/extension_fields/towers except Fp2, Fp12
#import constantine/math/extension_fields/towers except Fp2, Fp12
import groth16/bn128
import groth16/math/domain
@ -36,7 +36,7 @@ import groth16/misc
type
Proof* = object
publicIO* : seq[Fr]
publicIO* : seq[Fr[BN254_Snarks]]
pi_a* : G1
pi_b* : G2
pi_c* : G1
@ -44,29 +44,35 @@ type
#-------------------------------------------------------------------------------
# Az, Bz, Cz column vectors
#
#
type
ABC = object
valuesAz : seq[Fr]
valuesBz : seq[Fr]
valuesCz : seq[Fr]
valuesAz : seq[Fr[BN254_Snarks]]
valuesBz : seq[Fr[BN254_Snarks]]
valuesCz : seq[Fr[BN254_Snarks]]
ShiftEvalDomainTask* = object
values : seq[Fr[BN254_Snarks]]
D : Domain
eta : Fr[BN254_Snarks]
res : Isolated[seq[Fr[BN254_Snarks]]]
# computes the vectors A*z, B*z, C*z where z is the witness
func buildABC( zkey: ZKey, witness: seq[Fr] ): ABC =
func buildABC( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC =
let hdr: GrothHeader = zkey.header
let domSize = hdr.domainSize
var valuesAz : seq[Fr] = newSeq[Fr](domSize)
var valuesBz : seq[Fr] = newSeq[Fr](domSize)
var valuesAz = newSeq[Fr[BN254_Snarks]](domSize)
var valuesBz = newSeq[Fr[BN254_Snarks]](domSize)
for entry in zkey.coeffs:
case entry.matrix
case entry.matrix
of MatrixA: valuesAz[entry.row] += entry.coeff * witness[entry.col]
of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col]
else: raise newException(AssertionDefect, "fatal error")
var valuesCz : seq[Fr] = newSeq[Fr](domSize)
var valuesCz = newSeq[Fr[BN254_Snarks]](domSize)
for i in 0..<domSize:
valuesCz[i] = valuesAz[i] * valuesBz[i]
@ -93,90 +99,132 @@ func computeQuotientNaive( abc: ABC ): Poly=
#---------------------------------------
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
func multiplyByPowers( xs: seq[Fr], eta: Fr ): seq[Fr] =
func multiplyByPowers( xs: seq[Fr[BN254_Snarks]], eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
let n = xs.len
assert(n >= 1)
var ys : seq[Fr] = newSeq[Fr](n)
var ys = newSeq[Fr[BN254_Snarks]](n)
ys[0] = xs[0]
if n >= 1: ys[1] = eta * xs[1]
var spow : Fr = eta
for i in 2..<n:
var spow = eta
for i in 2..<n:
spow *= eta
ys[i] = spow * xs[i]
return ys
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
func shiftEvalDomain( values: seq[Fr], D: Domain, eta: Fr ): seq[Fr] =
func shiftEvalDomain(
values: seq[Fr[BN254_Snarks]],
D: Domain,
eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
let poly : Poly = polyInverseNTT( values , D )
let cs : seq[Fr] = poly.coeffs
var ds : seq[Fr] = multiplyByPowers( cs, eta )
let cs : seq[Fr[BN254_Snarks]] = poly.coeffs
var ds : seq[Fr[BN254_Snarks]] = multiplyByPowers( cs, eta )
return polyForwardNTT( Poly(coeffs:ds), D )
func shiftEvalDomainTask( task: ptr ShiftEvalDomainTask ): bool =
let D = task[].D
let eta = task[].eta
let values = task[].values
var res = isolate(shiftEvalDomain( values, D, eta ))
task[].res = move res
return true
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
proc computeQuotientPointwise( nthreads: int, abc: ABC ): Poly =
proc computeQuotientPointwise( abc: ABC, pool: Taskpool ): Poly =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
# (eta*omega^j)^n - 1 = eta^n - 1
# (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
var pool = Taskpool.new(num_threads = nthreads)
var a1fvTask = ShiftEvalDomainTask(
values: abc.valuesAz,
D: D,
eta: eta,
)
var A1fv = pool.spawn shiftEvalDomainTask( addr a1fvTask )
var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta )
var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta )
var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta )
var b1fvTask = ShiftEvalDomainTask(
values: abc.valuesBz,
D: D,
eta: eta,
)
var B1fv = pool.spawn shiftEvalDomainTask( addr b1fvTask )
let A1 = sync A1fv
let B1 = sync B1fv
let C1 = sync C1fv
var c1fvTask = ShiftEvalDomainTask(
values: abc.valuesCz,
D: D,
eta: eta,
)
var C1fv = pool.spawn shiftEvalDomainTask( addr c1fvTask )
var ys : seq[Fr] = newSeq[Fr]( n )
discard sync A1fv
let A1 = a1fvTask.res.extract
discard sync B1fv
let B1 = b1fvTask.res.extract
discard sync C1fv
let C1 = c1fvTask.res.extract
var ys = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
pool.syncAll()
pool.shutdown()
return Poly(coeffs: cs)
#---------------------------------------
# Snarkjs does something different, not actually computing the quotient poly
# they can get away with this, because during the trusted setup, they
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC): seq[Fr] =
proc computeSnarkjsScalarCoeffs( abc: ABC, pool: Taskpool ): seq[Fr[BN254_Snarks]] =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
var pool = Taskpool.new(num_threads = nthreads)
var a1fvTask = ShiftEvalDomainTask(
values: abc.valuesAz,
D: D,
eta: eta,
)
var A1fv = pool.spawn shiftEvalDomainTask( addr a1fvTask )
var A1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesAz, D, eta )
var B1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesBz, D, eta )
var C1fv : FlowVar[seq[Fr]] = pool.spawn shiftEvalDomain( abc.valuesCz, D, eta )
var b1fvTask = ShiftEvalDomainTask(
values: abc.valuesBz,
D: D,
eta: eta,
)
var B1fv = pool.spawn shiftEvalDomainTask( addr b1fvTask )
let A1 = sync A1fv
let B1 = sync B1fv
let C1 = sync C1fv
var c1fvTask = ShiftEvalDomainTask(
values: abc.valuesCz,
D: D,
eta: eta,
)
var C1fv = pool.spawn shiftEvalDomainTask( addr c1fvTask )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
discard sync A1fv
let A1 = a1fvTask.res.extract
discard sync B1fv
let B1 = b1fvTask.res.extract
discard sync C1fv
let C1 = c1fvTask.res.extract
pool.syncAll()
pool.shutdown()
var ys = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
return ys
@ -192,7 +240,7 @@ proc computeSnarkjsScalarCoeffs_st( abc: ABC ): seq[Fr] =
let B1 : seq[Fr] = shiftEvalDomain( abc.valuesBz, D, eta )
let C1 : seq[Fr] = shiftEvalDomain( abc.valuesCz, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
return ys
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
@ -209,13 +257,13 @@ proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
type
Mask* = object
r*: Fr # masking coefficients
s*: Fr # for zero knowledge
r*: Fr[BN254_Snarks] # masking coefficients
s*: Fr[BN254_Snarks] # for zero knowledge
proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness, mask: Mask ): Proof =
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool ): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
# {.fatal: "Compile with arc/orc!".}
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
@ -228,7 +276,7 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
let hdr : GrothHeader = zkey.header
let spec : SpecPoints = zkey.specPoints
let pts : ProverPoints = zkey.pPoints
let pts : ProverPoints = zkey.pPoints
let nvars = hdr.nvars
let npubs = hdr.npubs
@ -236,30 +284,30 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
assert( nvars == witness.len , "wrong witness length" )
# remark: with the special variable "1" we actuall have (npub+1) public IO variables
var pubIO : seq[Fr] = newSeq[Fr]( npubs + 1)
for i in 0..npubs: pubIO[i] = witness[i]
var pubIO = newSeq[Fr[BN254_Snarks]]( npubs + 1)
for i in 0..npubs: pubIO[i] = witness[i]
start = cpuTime()
var abc : ABC
var abc : ABC
withMeasureTime(printTimings,"building 'ABC'"):
abc = buildABC( zkey, witness )
start = cpuTime()
var qs : seq[Fr]
var qs : seq[Fr[BN254_Snarks]]
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
case zkey.header.flavour
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( nthreads, abc )
let polyQ = computeQuotientPointwise( abc, pool )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( nthreads, abc )
qs = computeSnarkjsScalarCoeffs( abc, pool )
var zs : seq[Fr] = newSeq[Fr]( nvars - npubs - 1 )
var zs = newSeq[Fr[BN254_Snarks]]( nvars - npubs - 1 )
for j in npubs+1..<nvars:
zs[j-npubs-1] = witness[j]
@ -275,45 +323,45 @@ proc generateProofWithMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns
assert( nvars - npubs - 1 == zs.len )
assert( nvars - npubs - 1 == pts.pointsC1.len )
var pi_a : G1
var pi_a : G1
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmMultiThreadedG1( nthreads , witness , pts.pointsA1 )
pi_a += msmMultiThreadedG1( witness , pts.pointsA1, pool )
var rho : G1
var rho : G1
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
rho = spec.beta1
rho += s ** spec.delta1
rho += msmMultiThreadedG1( nthreads , witness , pts.pointsB1 )
rho += msmMultiThreadedG1( witness , pts.pointsB1, pool )
var pi_b : G2
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmMultiThreadedG2( nthreads , witness , pts.pointsB2 )
pi_b += msmMultiThreadedG2( witness , pts.pointsB2, pool )
var pi_c : G1
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmMultiThreadedG1( nthreads, qs , pts.pointsH1 )
pi_c += msmMultiThreadedG1( nthreads, zs , pts.pointsC1 )
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
pi_c += msmMultiThreadedG1( zs , pts.pointsC1, pool )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------
proc generateProofWithTrivialMask*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
proc generateProofWithTrivialMask*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool ): Proof =
let mask = Mask( r: zeroFr , s: zeroFr )
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
proc generateProof*( nthreads: int, printTimings: bool, zkey: ZKey, wtns: Witness ): Proof =
proc generateProof*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool = false ): Proof {.raises: [].} =
# masking coeffs
let r : Fr = randFr()
let s : Fr = randFr()
let r = randFr()
let s = randFr()
let mask = Mask(r: r, s: s)
return generateProofWithMask( nthreads, printTimings, zkey, wtns, mask )
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )

View File

@ -1,7 +1,7 @@
import std/[times,os]
import strformat
import std/strformat
import groth16/prover
import groth16/verifier
@ -15,7 +15,7 @@ func seconds(x: float): string = fmt"{x:.4f} seconds"
#-------------------------------------------------------------------------------
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
echo("parsing witness & zkey files...")
let witness = parseWitness( wtns_fname)
@ -23,7 +23,7 @@ proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
echo("generating proof...")
let start = cpuTime()
let proof = generateProof( zkey, witness )
let proof = generateProof( zkey, witness, Taskpool.new() )
let elapsed = cpuTime() - start
echo("proving took ",seconds(elapsed))
@ -36,7 +36,7 @@ proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
#-------------------------------------------------------------------------------
proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): (VKey,Proof) =
proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs): (VKey,Proof) =
echo("trusted setup flavour = ",flavour)
echo("parsing witness & r1cs files...")
@ -55,7 +55,7 @@ proc testFakeSetupAndVerify*( r1cs_fname, wtns_fname: string, flavour=Snarkjs):
let vkey = extractVKey( zkey)
let start = cpuTime()
let proof = generateProof( zkey, witness )
let proof = generateProof( zkey, witness, Taskpool.new() )
let elapsed = cpuTime() - start
echo("proving took ",seconds(elapsed))

View File

@ -2,22 +2,25 @@
#
# Groth16 prover
#
# WARNING!
# WARNING!
# the points H in `.zkey` are *NOT* what normal people would think they are
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
#[
import sugar
import constantine/math/config/curves
import constantine/math/config/curves
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import ./zkey
]#
# import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_extfields except Fp12
import constantine/math/extension_fields/towers except Fp2, Fp12
import pkg/constantine/math/io/io_extfields
import pkg/constantine/math/extension_fields/towers
import pkg/constantine/math/arithmetic
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128
import groth16/zkey_types
@ -36,15 +39,15 @@ proc verifyProof* (vkey: VKey, prf: Proof): bool =
assert( isOnCurveG2(prf.pi_b) , "pi_b is not in G2" )
assert( isOnCurveG1(prf.pi_c) , "pi_c is not in G1" )
var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC )
var pubG1 : G1 = msmG1( prf.publicIO , vkey.vpoints.pointsIC )
let lhs : Fp12 = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b >
let rhs1 : Fp12 = vkey.spec.alphaBeta # < alpha , beta >
let rhs2 : Fp12 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta >
let rhs3 : Fp12 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma >
let lhs = pairing( negG1(prf.pi_a) , prf.pi_b ) # < -pi_a , pi_b >
let rhs1 = vkey.spec.alphaBeta # < alpha , beta >
let rhs2 = pairing( prf.pi_c , vkey.spec.delta2 ) # < pi_c , delta >
let rhs3 = pairing( pubG1 , vkey.spec.gamma2 ) # < sum... , gamma >
var eq : Fp12
eq = lhs
var eq : Fp12[BN254_Snarks]
eq = lhs
eq *= rhs1
eq *= rhs2
eq *= rhs3

View File

@ -1,12 +1,16 @@
import constantine/math/arithmetic except Fp, Fr
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_fields
import pkg/constantine/math/io/io_bigints
import pkg/constantine/named/properties_fields
import pkg/constantine/math/extension_fields/towers
import groth16/bn128
#-------------------------------------------------------------------------------
type
type
Flavour* = enum
JensGroth # the version described in the original Groth16 paper
Snarkjs # the version implemented by Snarkjs
@ -27,8 +31,8 @@ type
beta2* : G2 # = beta * g2
gamma2* : G2 # = gamma * g2
delta1* : G1 # = delta * g1
delta2* : G2 # = delta * g2
alphaBeta* : Fp12 # = <alpha1 , beta2>
delta2* : G2 # = delta * g2
alphaBeta* : Fp12[BN254_Snarks] # = <alpha1 , beta2>
VerifierPoints* = object
pointsIC* : seq[G1] # the points `delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) * g1` (for j <= npub)
@ -49,7 +53,7 @@ type
matrix* : MatrixSel
row* : int
col* : int
coeff* : Fr
coeff* : Fr[BN254_Snarks]
ZKey* = object
# sectionMask* : uint32
@ -59,14 +63,14 @@ type
pPoints* : ProverPoints
coeffs* : seq[Coeff]
VKey* = object
VKey* = object
curve* : string
spec* : SpecPoints
vpoints* : VerifierPoints
#-------------------------------------------------------------------------------
func extractVKey*(zkey: Zkey): VKey =
func extractVKey*(zkey: ZKey): VKey =
let curve = zkey.header.curve
let spec = zkey.specPoints
let vpts = zkey.vPoints
@ -74,32 +78,32 @@ func extractVKey*(zkey: Zkey): VKey =
#-------------------------------------------------------------------------------
proc printGrothHeader*(hdr: GrothHeader) =
echo("curve = " & ($hdr.curve ) )
echo("flavour = " & ($hdr.flavour ) )
echo("|Fp| = " & (toDecimalBig(hdr.p)) )
echo("|Fr| = " & (toDecimalBig(hdr.r)) )
echo("nvars = " & ($hdr.nvars ) )
echo("npubs = " & ($hdr.npubs ) )
echo("domainSize = " & ($hdr.domainSize ) )
echo("logDomainSize= " & ($hdr.logDomainSize) )
proc printGrothHeader*(hdr: GrothHeader) =
echo("curve = " & ($hdr.curve ) )
echo("flavour = " & ($hdr.flavour ) )
echo("|Fp| = " & (toDecimalBig(hdr.p)) )
echo("|Fr| = " & (toDecimalBig(hdr.r)) )
echo("nvars = " & ($hdr.nvars ) )
echo("npubs = " & ($hdr.npubs ) )
echo("domainSize = " & ($hdr.domainSize ) )
echo("logDomainSize= " & ($hdr.logDomainSize) )
#-------------------------------------------------------------------------------
func matrixSelToString(sel: MatrixSel): string =
case sel
func matrixSelToString(sel: MatrixSel): string =
case sel
of MatrixA: return "A"
of MatrixB: return "B"
of MatrixC: return "C"
proc debugPrintCoeff(cf: Coeff) =
proc debugPrintCoeff(cf: Coeff) =
echo( "matrix=", matrixSelToString(cf.matrix)
, " | i=", cf.row
, " | j=", cf.col
, " | val=", signedToDecimalFr(cf.coeff)
)
proc debugPrintCoeffs*(cfs: seq[Coeff]) =
proc debugPrintCoeffs*(cfs: seq[Coeff]) =
for cf in cfs: debugPrintCoeff(cf)
#-------------------------------------------------------------------------------

View File

@ -19,7 +19,7 @@ const myWitnessCfg =
, nPubOut: 1 # public output = input + a*b*c = 1022 + 7*11*13 = 2023
, nPubIn: 1 # public input = 1022
, nPrivIn: 3 # private inputs: 7, 11, 13
, nLabels: 0
, nLabels: 0
)
# 2023 == 1022 + 7*3*11
@ -44,10 +44,10 @@ const myR1CS =
)
# the equation we want prove is `7*11*13 + 1022 == 2023`
let myWitnessValues : seq[Fr] = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr )
let myWitnessValues = map( @[ 1, 2023, 1022, 7, 11, 13, 7*11, 7*11*13 ] , intToFr )
# wire indices: ^^^^^^^ 0 1 2 3 4 5 6 7
let myWitness =
let myWitness =
Witness( curve: "bn128"
, r: primeR
, nvars: 8
@ -56,8 +56,8 @@ let myWitness =
#-------------------------------------------------------------------------------
proc testProof(zkey: ZKey, witness: Witness): bool =
let proof = generateProof( zkey, witness )
proc testProof(zkey: ZKey, witness: Witness): bool =
let proof = generateProof( zkey, witness, Taskpool.new() )
let vkey = extractVKey( zkey)
let ok = verifyProof( vkey, proof )
return ok
@ -65,11 +65,11 @@ proc testProof(zkey: ZKey, witness: Witness): bool =
suite "prover":
test "prove & verify simple multiplication circuit, `JensGroth` flavour":
let zkey = createFakeCircuitSetup( myR1cs, flavour=JensGroth )
let zkey = createFakeCircuitSetup( myR1CS, flavour=JensGroth )
check testProof( zkey, myWitness )
test "prove & verify simple multiplication circuit, `Snarkjs` flavour":
let zkey = createFakeCircuitSetup( myR1cs, flavour=Snarkjs )
let zkey = createFakeCircuitSetup( myR1CS, flavour=Snarkjs )
check testProof( zkey, myWitness )
#-------------------------------------------------------------------------------