mirror of
https://github.com/logos-storage/nim-groth16.git
synced 2026-02-04 05:43:09 +00:00
initial support for "partial proofs" (precalculating based on a partial witness)
This commit is contained in:
parent
f7a4549d86
commit
3ddf4948f9
@ -21,7 +21,7 @@ at your choice.
|
||||
|
||||
- [x] find and fix the _second_ totally surreal bug
|
||||
- [ ] clean up the code
|
||||
- [ ] make it compatible with the latest constantine and also Nim 2.0.x
|
||||
- [x] make it compatible with the latest constantine and also Nim 2.0.x
|
||||
- [x] make it a nimble package
|
||||
- [ ] compare `.r1cs` to the "coeffs" section of `.zkey`
|
||||
- [x] generate fake circuit-specific setup ourselves
|
||||
@ -29,5 +29,5 @@ at your choice.
|
||||
- [x] multithreading support (MSM, and possibly also FFT)
|
||||
- [ ] add Groth16 notes
|
||||
- [ ] document the `snarkjs` circuit-specific setup `H` points convention
|
||||
- [x] precalculate stuff for "partial" proofs
|
||||
- [ ] make it work for different curves
|
||||
|
||||
|
||||
@ -1,13 +1,8 @@
|
||||
|
||||
import sugar
|
||||
import std/strutils
|
||||
import std/sequtils
|
||||
import std/os
|
||||
import std/cpuinfo
|
||||
import std/parseopt
|
||||
import std/times
|
||||
import std/options
|
||||
# import strformat
|
||||
|
||||
import taskpools
|
||||
|
||||
@ -21,6 +16,8 @@ import groth16/zkey_types
|
||||
import groth16/fake_setup
|
||||
import groth16/misc
|
||||
|
||||
import testing
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc printHelp() =
|
||||
@ -37,6 +34,7 @@ proc printHelp() =
|
||||
echo " -y, --verify : verify a proof"
|
||||
echo " -u, --setup : perform (fake) trusted setup"
|
||||
echo " -n, --nomask : don't use random masking for full ZK"
|
||||
echo " -s, --sanity : sanity test the partial prover"
|
||||
echo " -z, --zkey = <circuit.zkey> : the `.zkey` file"
|
||||
echo " -w, --wtns = <circuit.wtns> : the `.wtns` file"
|
||||
echo " -r, --r1cs = <circuit.r1cs> : the `.r1cs` file"
|
||||
@ -46,33 +44,35 @@ proc printHelp() =
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type Config = object
|
||||
zkey_file: string
|
||||
r1cs_file: string
|
||||
wtns_file: string
|
||||
output_file: string
|
||||
io_file: string
|
||||
verbose: bool
|
||||
debug: bool
|
||||
measure_time: bool
|
||||
do_prove: bool
|
||||
do_verify: bool
|
||||
do_setup: bool
|
||||
no_masking: bool
|
||||
nthreads: int
|
||||
zkey_file: string
|
||||
r1cs_file: string
|
||||
wtns_file: string
|
||||
output_file: string
|
||||
io_file: string
|
||||
verbose: bool
|
||||
debug: bool
|
||||
measure_time: bool
|
||||
do_prove: bool
|
||||
do_verify: bool
|
||||
do_setup: bool
|
||||
no_masking: bool
|
||||
partial_sanity: bool
|
||||
nthreads: int
|
||||
|
||||
const dummyConfig =
|
||||
Config( zkey_file: ""
|
||||
, r1cs_file: ""
|
||||
, wtns_file: ""
|
||||
, output_file: ""
|
||||
, io_file: ""
|
||||
, verbose: false
|
||||
, measure_time: false
|
||||
, do_prove: false
|
||||
, do_verify: false
|
||||
, do_setup: false
|
||||
, no_masking: false
|
||||
, nthreads: 0
|
||||
Config( zkey_file: ""
|
||||
, r1cs_file: ""
|
||||
, wtns_file: ""
|
||||
, output_file: ""
|
||||
, io_file: ""
|
||||
, verbose: false
|
||||
, measure_time: false
|
||||
, do_prove: false
|
||||
, do_verify: false
|
||||
, do_setup: false
|
||||
, no_masking: false
|
||||
, partial_sanity: false
|
||||
, nthreads: 0
|
||||
)
|
||||
|
||||
proc printConfig(cfg: Config) =
|
||||
@ -114,6 +114,7 @@ proc parseCliOptions(): Config =
|
||||
of "y", "verify" : cfg.do_verify = true
|
||||
of "u", "setup" : cfg.do_setup = true
|
||||
of "n", "nomask" : cfg.no_masking = true
|
||||
of "s", "sanity" : cfg.partial_sanity = true
|
||||
of "o", "output" : cfg.output_file = value
|
||||
of "r", "r1cs" : cfg.r1cs_file = value
|
||||
of "z", "zkey" : cfg.zkey_file = value
|
||||
@ -136,28 +137,6 @@ proc parseCliOptions(): Config =
|
||||
|
||||
return cfg
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
#[
|
||||
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
|
||||
|
||||
echo("parsing witness & zkey files...")
|
||||
let witness = parseWitness( wtns_fname)
|
||||
let zkey = parseZKey( zkey_fname)
|
||||
|
||||
echo("generating proof...")
|
||||
let start = cpuTime()
|
||||
let proof = generateProof( zkey, witness )
|
||||
let elapsed = cpuTime() - start
|
||||
echo("proving took ",seconds(elapsed))
|
||||
|
||||
echo("verifying the proof...")
|
||||
let vkey = extractVKey( zkey)
|
||||
let ok = verifyProof( vkey, proof )
|
||||
echo("verification succeeded = ",ok)
|
||||
|
||||
return (vkey,proof)
|
||||
]#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -198,6 +177,16 @@ proc cliMain(cfg: Config) =
|
||||
printGrothHeader(zkey.header)
|
||||
# debugPrintCoeffs(zkey.coeffs)
|
||||
|
||||
if cfg.partial_sanity:
|
||||
if (cfg.wtns_file=="") or (cfg.zkey_file=="" and cfg.do_setup==false):
|
||||
echo("cannot prove: missing witness and/or zkey file!")
|
||||
quit()
|
||||
else:
|
||||
var pool = Taskpool.new(cfg.nthreads)
|
||||
let print_timings = cfg.measure_time and cfg.verbose
|
||||
sanityCheckPartialProofs(zkey,wtns,pool,print_timings)
|
||||
echo("sanity testing partial proofs...")
|
||||
|
||||
if cfg.do_prove:
|
||||
if (cfg.wtns_file=="") or (cfg.zkey_file=="" and cfg.do_setup==false):
|
||||
echo("cannot prove: missing witness and/or zkey file!")
|
||||
|
||||
102
cli/testing.nim
Normal file
102
cli/testing.nim
Normal file
@ -0,0 +1,102 @@
|
||||
|
||||
import std/strutils
|
||||
import std/times
|
||||
import std/options
|
||||
import std/random
|
||||
import std/syncio
|
||||
|
||||
import taskpools
|
||||
|
||||
import constantine/named/properties_fields
|
||||
|
||||
# import groth16/bn128
|
||||
import groth16/zkey_types
|
||||
import groth16/files/witness
|
||||
import groth16/misc
|
||||
import groth16/files/export_json
|
||||
|
||||
import groth16/partial/types
|
||||
import groth16/partial/precalc
|
||||
import groth16/partial/finish
|
||||
|
||||
import groth16/prover
|
||||
import groth16/prover/shared
|
||||
import groth16/verifier
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
#[
|
||||
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
|
||||
|
||||
echo("parsing witness & zkey files...")
|
||||
let witness = parseWitness( wtns_fname)
|
||||
let zkey = parseZKey( zkey_fname)
|
||||
|
||||
echo("generating proof...")
|
||||
let start = cpuTime()
|
||||
let proof = generateProof( zkey, witness )
|
||||
let elapsed = cpuTime() - start
|
||||
echo("proving took ",seconds(elapsed))
|
||||
|
||||
echo("verifying the proof...")
|
||||
let vkey = extractVKey( zkey)
|
||||
let ok = verifyProof( vkey, proof )
|
||||
echo("verification succeeded = ",ok)
|
||||
|
||||
return (vkey,proof)
|
||||
]#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc sanityCheckPartialProofs*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool) =
|
||||
|
||||
let witness = wtns.values
|
||||
let M = witness.len
|
||||
|
||||
var partial_mask: seq[bool] = newSeq[bool]( M )
|
||||
var partial_witness: seq[Option[Fr[BN254_Snarks]]] = newSeq[Option[Fr[BN254_Snarks]]]( M )
|
||||
|
||||
# generate randomized partial witness
|
||||
partial_mask[0] = true
|
||||
partial_witness[0] = some(witness[0])
|
||||
var count = 0
|
||||
for i in 1..<M:
|
||||
let b : bool = rand(bool)
|
||||
partial_mask[i] = b
|
||||
if b:
|
||||
partial_witness[i] = some(witness[i])
|
||||
count += 1
|
||||
else:
|
||||
partial_witness[i] = none(Fr[BN254_Snarks])
|
||||
|
||||
echo "\nrandomized a partial witness of size " & $(count) & " out of " & $(M)
|
||||
let partial_wtns = PartialWitness(values: partial_witness )
|
||||
|
||||
let mask = randomMask()
|
||||
|
||||
var fullProof : Proof
|
||||
withMeasureTime(true,"\ngenerating the full proof"):
|
||||
fullProof = generateProofWithMask( zkey, wtns, mask, pool, printTimings )
|
||||
writeProof(stdout,fullProof)
|
||||
|
||||
let vkey = extractVKey(zkey)
|
||||
echo "verifying the full proof succeeds = " & $verifyProof(vkey, fullProof)
|
||||
|
||||
var partialProof : PartialProof
|
||||
withMeasureTime(true,"\ngenerating the partial proof"):
|
||||
partialProof = generatePartialProof( zkey, partial_wtns, pool, printTimings )
|
||||
|
||||
var finishedProof : Proof
|
||||
withMeasureTime(true,"\nfinishing the partial proof"):
|
||||
finishedProof = finishPartialProofWithMask( zkey, wtns, partialProof, mask, pool, printTimings )
|
||||
writeProof(stdout,finishedProof)
|
||||
|
||||
echo "verifying the finished proof succeeds = " & $verifyProof(vkey, finishedProof)
|
||||
|
||||
if (not isEqualProof(fullProof, finishedProof)):
|
||||
echo "PROBLEM! the two proofs DIFFER!!!"
|
||||
else:
|
||||
echo "OK. the two proofs agree"
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
@ -37,6 +37,14 @@ type ProjG2* = prj.EC_ShortW_Prj[Fp2[BN254_Snarks], prj.G2]
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func isEqualG1* (x, y: G1 ): bool = bool(x == y)
|
||||
func isEqualG2* (x, y: G2 ): bool = bool(x == y)
|
||||
|
||||
func `===`*(x, y: G1 ): bool = isEqualG1(x,y)
|
||||
func `===`*(x, y: G2 ): bool = isEqualG2(x,y)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
func unsafeMkG1* ( X, Y: Fp[BN254_Snarks] ) : G1 =
|
||||
return aff.EC_ShortW_Aff[Fp[BN254_Snarks], aff.G1](x: X, y: Y)
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@
|
||||
import constantine/named/properties_fields
|
||||
import constantine/math/extension_fields/towers
|
||||
|
||||
import groth16/bn128/fields
|
||||
#import groth16/bn128/fields
|
||||
import groth16/bn128/curves
|
||||
import groth16/bn128/io
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ import constantine/math/extension_fields/towers as ext
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type B* = BigInt[256]
|
||||
type B* = BigInt[256]
|
||||
|
||||
func mkFp2* (i: Fp[BN254_Snarks], u: Fp[BN254_Snarks]) : Fp2[BN254_Snarks] =
|
||||
let c : array[2, Fp[BN254_Snarks]] = [i,u]
|
||||
|
||||
@ -4,12 +4,9 @@
|
||||
#
|
||||
|
||||
import system
|
||||
import std/cpuinfo
|
||||
import taskpools
|
||||
|
||||
# import constantine/curves_primitives except Fp, Fp2, Fr
|
||||
|
||||
import constantine/platforms/abstractions except Subgroup
|
||||
import constantine/platforms/abstractions except Subgroup
|
||||
import constantine/math/endomorphisms/frobenius except Subgroup
|
||||
|
||||
import constantine/math/io/io_bigints
|
||||
@ -23,11 +20,12 @@ import constantine/math/elliptic/ec_shortweierstrass_projective as prj except Su
|
||||
import constantine/math/elliptic/ec_scalar_mul_vartime as scl except Subgroup
|
||||
import constantine/math/elliptic/ec_multi_scalar_mul as msm except Subgroup
|
||||
|
||||
import groth16/bn128/fields
|
||||
#import groth16/bn128/fields
|
||||
import groth16/bn128/curves as mycurves
|
||||
|
||||
import groth16/misc # TEMP DEBUGGING
|
||||
import std/times
|
||||
#import groth16/misc # TEMP DEBUGGING
|
||||
#import std/cpuinfo
|
||||
#import std/times
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@ -23,8 +23,6 @@
|
||||
|
||||
import std/streams
|
||||
|
||||
import sugar
|
||||
|
||||
import constantine/math/arithmetic
|
||||
import constantine/math/io/io_bigints
|
||||
|
||||
|
||||
@ -67,11 +67,7 @@ proc writeG2( f: File, p: G2 ) =
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# exports the proof into as a JSON file
|
||||
proc exportProof*( fpath: string, prf: Proof ) =
|
||||
|
||||
let f = open(fpath, fmWrite)
|
||||
defer: f.close()
|
||||
proc writeProof*( f: File, prf: Proof) =
|
||||
|
||||
f.writeLine("{ \"protocol\": \"groth16\"")
|
||||
f.writeLine(", \"curve\": \"bn128\"" )
|
||||
@ -80,6 +76,13 @@ proc exportProof*( fpath: string, prf: Proof ) =
|
||||
f.writeLine(", \"pi_c\":" ) ; writeG1( f, prf.pi_c )
|
||||
f.writeLine("}")
|
||||
|
||||
# exports the proof into as a JSON file
|
||||
proc exportProof*( fpath: string, prf: Proof ) =
|
||||
|
||||
let f = open(fpath, fmWrite)
|
||||
defer: f.close()
|
||||
writeProof(f, prf)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
#[
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
#
|
||||
|
||||
import strformat
|
||||
import times, os, strutils
|
||||
import times, strutils
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
|
||||
158
groth16/partial/finish.nim
Normal file
158
groth16/partial/finish.nim
Normal file
@ -0,0 +1,158 @@
|
||||
|
||||
#
|
||||
# Finish a "partial proof"
|
||||
#
|
||||
|
||||
{.push raises:[].}
|
||||
|
||||
import std/times
|
||||
import system
|
||||
import taskpools
|
||||
import constantine/math/arithmetic
|
||||
import constantine/named/properties_fields
|
||||
|
||||
import groth16/bn128
|
||||
#import groth16/math/domain
|
||||
import groth16/math/poly
|
||||
import groth16/zkey_types
|
||||
import groth16/files/witness
|
||||
import groth16/misc
|
||||
|
||||
import groth16/partial/types
|
||||
import groth16/prover/types
|
||||
import groth16/prover/shared
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# the finishing prover
|
||||
#
|
||||
|
||||
proc finishPartialProofWithMask*( zkey: ZKey, wtns: Witness, partialProof: PartialProof, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
|
||||
|
||||
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
|
||||
{.fatal: "Compile with arc/orc!".}
|
||||
|
||||
# if (zkey.header.curve != wtns.curve):
|
||||
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
|
||||
# echo( "wtns.curve = " & ($wtns.curve ) )
|
||||
|
||||
assert( zkey.header.curve == wtns.curve )
|
||||
var start : float = 0
|
||||
|
||||
let witness = wtns.values
|
||||
|
||||
let hdr : GrothHeader = zkey.header
|
||||
let spec : SpecPoints = zkey.specPoints
|
||||
let pts : ProverPoints = zkey.pPoints
|
||||
|
||||
let nvars = hdr.nvars
|
||||
let npubs = hdr.npubs
|
||||
|
||||
assert( nvars == witness.len , "wrong witness length" )
|
||||
|
||||
let partial_mask = partialProof.partial_mask
|
||||
|
||||
# remark: with the special variable "1" we actuall have (npub+1) public IO variables
|
||||
var pubIO = newSeq[Fr[BN254_Snarks]]( npubs + 1)
|
||||
for i in 0..npubs: pubIO[i] = witness[i]
|
||||
|
||||
# note: we have to ignore public inputs (plus 1 for the special first entry)
|
||||
let startIdx : int = npubs + 1
|
||||
var count: int = 0 # count the number of NEW witness elements
|
||||
var zs_cnt: int = 0 # count the number of NEW witness elements, which are NOT public IO
|
||||
for i in 0..<startidx:
|
||||
if (not partial_mask[i]):
|
||||
count += 1
|
||||
for i in startIdx..<nvars:
|
||||
if (not partial_mask[i]):
|
||||
count += 1
|
||||
zs_cnt += 1
|
||||
|
||||
# compactify the stuff so we can call normal vector MSM
|
||||
var compact_witness: seq[F] = newSeq[F](count)
|
||||
var compact_pointsA1: seq[G1] = newSeq[G1](count)
|
||||
var compact_pointsB1: seq[G1] = newSeq[G1](count)
|
||||
var compact_pointsB2: seq[G2] = newSeq[G2](count)
|
||||
var compact_zs: seq[F] = newSeq[F](zs_cnt)
|
||||
var compact_pointsC1: seq[G1] = newSeq[G1](zs_cnt)
|
||||
block:
|
||||
var j: int = 0;
|
||||
var k: int = 0;
|
||||
for i in 0..<nvars:
|
||||
if (not partial_mask[i]):
|
||||
compact_witness[j ] = witness[i]
|
||||
compact_pointsA1[j] = pts.pointsA1[i]
|
||||
compact_pointsB1[j] = pts.pointsB1[i]
|
||||
compact_pointsB2[j] = pts.pointsB2[i]
|
||||
j += 1
|
||||
if i >= startIdx:
|
||||
compact_zs[k] = witness[i]
|
||||
compact_pointsC1[k] = pts.pointsC1[ i - startIdx ]
|
||||
k += 1
|
||||
|
||||
start = cpuTime()
|
||||
var abc : ABC
|
||||
withMeasureTime(printTimings,"building 'ABC'"):
|
||||
abc = buildABC( zkey, witness )
|
||||
|
||||
start = cpuTime()
|
||||
var qs : seq[Fr[BN254_Snarks]]
|
||||
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
|
||||
case zkey.header.flavour
|
||||
|
||||
# the points H are [delta^-1 * tau^i * Z(tau)]
|
||||
of JensGroth:
|
||||
let polyQ = computeQuotientPointwise( abc, pool )
|
||||
qs = polyQ.coeffs
|
||||
|
||||
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
|
||||
# where L_i are Lagrange basis polynomials on the double-sized domain
|
||||
of Snarkjs:
|
||||
qs = computeSnarkjsScalarCoeffs( abc, pool )
|
||||
|
||||
# masking coeffs
|
||||
let r = mask.r
|
||||
let s = mask.s
|
||||
|
||||
assert( witness.len == pts.pointsA1.len )
|
||||
assert( witness.len == pts.pointsB1.len )
|
||||
assert( witness.len == pts.pointsB2.len )
|
||||
assert( hdr.domainSize == qs.len )
|
||||
assert( hdr.domainSize == pts.pointsH1.len )
|
||||
assert( nvars - npubs - 1 == pts.pointsC1.len )
|
||||
|
||||
var pi_a : G1 = partialProof.partial_pi_a
|
||||
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
|
||||
pi_a += r ** spec.delta1
|
||||
pi_a += msmMultiThreadedG1( compact_witness , compact_pointsA1, pool )
|
||||
|
||||
var rho : G1 = partialProof.partial_rho
|
||||
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
|
||||
rho += s ** spec.delta1
|
||||
rho += msmMultiThreadedG1( compact_witness , compact_pointsB1, pool )
|
||||
|
||||
var pi_b : G2 = partialProof.partial_pi_b
|
||||
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
|
||||
pi_b += s ** spec.delta2
|
||||
pi_b += msmMultiThreadedG2( compact_witness , compact_pointsB2, pool )
|
||||
|
||||
var pi_c : G1 = partialProof.partial_pi_c
|
||||
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
|
||||
pi_c += s ** pi_a
|
||||
pi_c += r ** rho
|
||||
pi_c += negFr(r*s) ** spec.delta1
|
||||
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
|
||||
pi_c += msmMultiThreadedG1( compact_zs , compact_pointsC1, pool )
|
||||
|
||||
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc finishPartialProofWithTrivialMask*( zkey: ZKey, wtns: Witness, partial: PartialProof, pool: Taskpool, printTimings: bool ): Proof =
|
||||
let mask = Mask( r: zeroFr , s: zeroFr )
|
||||
return finishPartialProofWithMask( zkey, wtns, partial, mask, pool, printTimings )
|
||||
|
||||
proc finishPartialProof*( zkey: ZKey, wtns: Witness, partial: PartialProof, pool: Taskpool, printTimings = false ): Proof =
|
||||
let mask = randomMask()
|
||||
return finishPartialProofWithMask( zkey, wtns, partial, mask, pool, printTimings )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
152
groth16/partial/precalc.nim
Normal file
152
groth16/partial/precalc.nim
Normal file
@ -0,0 +1,152 @@
|
||||
|
||||
#
|
||||
# "a partial prover": precalculate stuff based on a partial witness
|
||||
#
|
||||
# this is useful when a significant portion of the witness does not changes
|
||||
# between proofs - then such precalculation can result in a potentially big speedup.
|
||||
#
|
||||
# an example use is RLN proofs, where the circuit is dominated by the Merkle
|
||||
#vinclusion proof check, which is expected to change relatively rarely.
|
||||
#
|
||||
|
||||
{.push raises:[].}
|
||||
|
||||
import std/times
|
||||
import std/options
|
||||
import system
|
||||
|
||||
import taskpools
|
||||
|
||||
import groth16/bn128
|
||||
import groth16/zkey_types
|
||||
import groth16/misc
|
||||
|
||||
import groth16/partial/types
|
||||
|
||||
#import groth16/math/domain
|
||||
#import groth16/math/poly
|
||||
#import groth16/prover/shared
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# the prover
|
||||
#
|
||||
|
||||
proc generatePartialProof*( zkey: ZKey, pwtns: PartialWitness, pool: Taskpool, printTimings: bool): PartialProof =
|
||||
|
||||
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
|
||||
{.fatal: "Compile with arc/orc!".}
|
||||
|
||||
# assert( zkey.header.curve == wtns.curve )
|
||||
var start : float = 0
|
||||
|
||||
let partial_witness = pwtns.values
|
||||
|
||||
let hdr : GrothHeader = zkey.header
|
||||
let spec : SpecPoints = zkey.specPoints
|
||||
let pts : ProverPoints = zkey.pPoints
|
||||
|
||||
let nvars = hdr.nvars
|
||||
let npubs = hdr.npubs
|
||||
|
||||
assert( nvars == partial_witness.len , "wrong witness length" )
|
||||
|
||||
assert( partial_witness.len == pts.pointsA1.len )
|
||||
assert( partial_witness.len == pts.pointsB1.len )
|
||||
assert( partial_witness.len == pts.pointsB2.len )
|
||||
|
||||
# note: in "zs" we have to ignore public inputs (plus 1 for the special first entry)
|
||||
var partial_mask : seq[bool] = newSeq[bool](nvars)
|
||||
for i in 0..<nvars:
|
||||
partial_mask[i] = false
|
||||
|
||||
let startIdx: int = npubs + 1 # in "zs" we have to skip the public IO
|
||||
var count: int = 0 # count the number of existing witness elements
|
||||
var zs_cnt: int = 0 # count the number of existing witness elements, which are NOT public IO
|
||||
for i in 0..<startidx:
|
||||
if isSome(partial_witness[i]):
|
||||
partial_mask[i] = true
|
||||
count += 1
|
||||
for i in startIdx..<nvars:
|
||||
if isSome(partial_witness[i]):
|
||||
partial_mask[i] = true
|
||||
count += 1
|
||||
zs_cnt += 1
|
||||
|
||||
# compactify the stuff so we can call normal vector MSM
|
||||
var compact_witness: seq[F] = newSeq[F](count)
|
||||
var compact_pointsA1: seq[G1] = newSeq[G1](count)
|
||||
var compact_pointsB1: seq[G1] = newSeq[G1](count)
|
||||
var compact_pointsB2: seq[G2] = newSeq[G2](count)
|
||||
var compact_zs: seq[F] = newSeq[F](zs_cnt)
|
||||
var compact_pointsC1: seq[G1] = newSeq[G1](zs_cnt)
|
||||
block:
|
||||
var j: int = 0;
|
||||
var k: int = 0;
|
||||
for i in 0..<nvars:
|
||||
if isSome(partial_witness[i]):
|
||||
let wtns_value = partial_witness[i].unsafeGet()
|
||||
compact_witness[j] = wtns_value
|
||||
compact_pointsA1[j] = pts.pointsA1[ i ]
|
||||
compact_pointsB1[j] = pts.pointsB1[ i ]
|
||||
compact_pointsB2[j] = pts.pointsB2[ i ]
|
||||
j += 1
|
||||
if i >= startIdx:
|
||||
compact_zs[k] = wtns_value
|
||||
compact_pointsC1[k] = pts.pointsC1[ i - startIdx ]
|
||||
k += 1
|
||||
|
||||
#[
|
||||
start = cpuTime()
|
||||
var abc : ABC
|
||||
withMeasureTime(printTimings,"building 'ABC'"):
|
||||
abc = buildABC( zkey, witness )
|
||||
|
||||
start = cpuTime()
|
||||
var qs : seq[Fr[BN254_Snarks]]
|
||||
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
|
||||
case zkey.header.flavour
|
||||
|
||||
# the points H are [delta^-1 * tau^i * Z(tau)]
|
||||
of JensGroth:
|
||||
let polyQ = computeQuotientPointwise( abc, pool )
|
||||
qs = polyQ.coeffs
|
||||
|
||||
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
|
||||
# where L_i are Lagrange basis polynomials on the double-sized domain
|
||||
of Snarkjs:
|
||||
qs = computeSnarkjsScalarCoeffs( abc, pool )
|
||||
|
||||
]#
|
||||
|
||||
assert( hdr.domainSize == pts.pointsH1.len )
|
||||
assert( nvars - npubs - 1 == pts.pointsC1.len )
|
||||
|
||||
var pi_a : G1
|
||||
withMeasureTime(printTimings,"computing partial pi_A (G1 MSM)"):
|
||||
pi_a = spec.alpha1
|
||||
pi_a += msmMultiThreadedG1( compact_witness , compact_pointsA1 , pool )
|
||||
|
||||
var rho : G1
|
||||
withMeasureTime(printTimings,"computing partial rho (G1 MSM)"):
|
||||
rho = spec.beta1
|
||||
rho += msmMultiThreadedG1( compact_witness , compact_pointsB1 , pool )
|
||||
|
||||
var pi_b : G2
|
||||
withMeasureTime(printTimings,"computing partial pi_B (G2 MSM)"):
|
||||
pi_b = spec.beta2
|
||||
pi_b += msmMultiThreadedG2( compact_witness , compact_pointsB2 , pool )
|
||||
|
||||
var pi_c : G1
|
||||
withMeasureTime(printTimings,"partial computing pi_C (G1 MSM)"):
|
||||
pi_c = msmMultiThreadedG1( compact_zs , compact_pointsC1 , pool )
|
||||
|
||||
return PartialProof(
|
||||
partial_mask : partial_mask
|
||||
, partial_pi_a : pi_a
|
||||
, partial_rho : rho
|
||||
, partial_pi_b : pi_b
|
||||
, partial_pi_c : pi_c
|
||||
)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
29
groth16/partial/types.nim
Normal file
29
groth16/partial/types.nim
Normal file
@ -0,0 +1,29 @@
|
||||
|
||||
{.push raises:[].}
|
||||
|
||||
import std/options
|
||||
|
||||
import constantine/named/properties_fields
|
||||
|
||||
import groth16/bn128
|
||||
import groth16/prover/types
|
||||
|
||||
export types
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
|
||||
F* = Fr[BN254_Snarks]
|
||||
|
||||
PartialWitness* = object
|
||||
values* : seq[Option[F]]
|
||||
|
||||
PartialProof* = object
|
||||
partial_mask* : seq[bool]
|
||||
partial_pi_a* : G1 # = [alpha]_1 + sum z_j*[A_j]_1
|
||||
partial_rho* : G1 # = [beta]_1 + sum z_j*[B_j]_1
|
||||
partial_pi_b* : G2 # = [beta]_2 + sum z_j*[B_j]_2
|
||||
partial_pi_c* : G1 # = sum z_j*[K_j]_1
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
@ -7,332 +7,9 @@
|
||||
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
|
||||
{.push raises:[].}
|
||||
import groth16/prover/types
|
||||
import groth16/prover/groth16
|
||||
|
||||
#[
|
||||
import sugar
|
||||
import constantine/math/config/curves
|
||||
import constantine/math/io/io_fields
|
||||
import constantine/math/io/io_bigints
|
||||
import ./zkey
|
||||
]#
|
||||
export types
|
||||
export groth16
|
||||
|
||||
import std/os
|
||||
import std/times
|
||||
import std/cpuinfo
|
||||
import system
|
||||
import taskpools
|
||||
import constantine/math/arithmetic
|
||||
import constantine/named/properties_fields
|
||||
import constantine/math/extension_fields/towers
|
||||
|
||||
#import constantine/math/io/io_extfields except Fp12
|
||||
|
||||
import groth16/bn128
|
||||
import groth16/math/domain
|
||||
import groth16/math/poly
|
||||
import groth16/zkey_types
|
||||
import groth16/files/witness
|
||||
import groth16/misc
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
Proof* = object
|
||||
publicIO* : seq[Fr[BN254_Snarks]]
|
||||
pi_a* : G1
|
||||
pi_b* : G2
|
||||
pi_c* : G1
|
||||
curve* : string
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Az, Bz, Cz column vectors
|
||||
#
|
||||
|
||||
type
|
||||
ABC = object
|
||||
valuesAz : seq[Fr[BN254_Snarks]]
|
||||
valuesBz : seq[Fr[BN254_Snarks]]
|
||||
valuesCz : seq[Fr[BN254_Snarks]]
|
||||
|
||||
# computes the vectors A*z, B*z, C*z where z is the witness
|
||||
func buildABC( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC =
|
||||
let hdr: GrothHeader = zkey.header
|
||||
let domSize = hdr.domainSize
|
||||
|
||||
var valuesAz = newSeq[Fr[BN254_Snarks]](domSize)
|
||||
var valuesBz = newSeq[Fr[BN254_Snarks]](domSize)
|
||||
|
||||
for entry in zkey.coeffs:
|
||||
case entry.matrix
|
||||
of MatrixA: valuesAz[entry.row] += entry.coeff * witness[entry.col]
|
||||
of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col]
|
||||
else: raise newException(AssertionDefect, "fatal error")
|
||||
|
||||
var valuesCz = newSeq[Fr[BN254_Snarks]](domSize)
|
||||
for i in 0..<domSize:
|
||||
valuesCz[i] = valuesAz[i] * valuesBz[i]
|
||||
|
||||
return ABC( valuesAz:valuesAz, valuesBz:valuesBz, valuesCz:valuesCz )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# quotient poly
|
||||
#
|
||||
|
||||
# interpolates A,B,C, and computes the quotient polynomial Q = (A*B - C) / Z
|
||||
func computeQuotientNaive( abc: ABC ): Poly=
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
let D = createDomain(n)
|
||||
let polyA : Poly = polyInverseNTT( abc.valuesAz , D )
|
||||
let polyB : Poly = polyInverseNTT( abc.valuesBz , D )
|
||||
let polyC : Poly = polyInverseNTT( abc.valuesCz , D )
|
||||
let polyBig = polyMulFFT( polyA , polyB ) - polyC
|
||||
var polyQ = polyDivideByVanishing(polyBig, D.domainSize)
|
||||
polyQ.coeffs.add( zeroFr ) # make it a power of two
|
||||
return polyQ
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
|
||||
func multiplyByPowers( xs: seq[Fr[BN254_Snarks]], eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
|
||||
let n = xs.len
|
||||
assert(n >= 1)
|
||||
var ys = newSeq[Fr[BN254_Snarks]](n)
|
||||
ys[0] = xs[0]
|
||||
if n >= 1: ys[1] = eta * xs[1]
|
||||
var spow = eta
|
||||
for i in 2..<n:
|
||||
spow *= eta
|
||||
ys[i] = spow * xs[i]
|
||||
return ys
|
||||
|
||||
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
|
||||
func shiftEvalDomain(
|
||||
values: seq[Fr[BN254_Snarks]],
|
||||
D: Domain,
|
||||
eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
|
||||
let poly : Poly = polyInverseNTT( values , D )
|
||||
let cs : seq[Fr[BN254_Snarks]] = poly.coeffs
|
||||
var ds : seq[Fr[BN254_Snarks]] = multiplyByPowers( cs, eta )
|
||||
return polyForwardNTT( Poly(coeffs:ds), D )
|
||||
|
||||
# Wraps shiftEvalDomain such that it can be called by Taskpool.spawn. The result
|
||||
# is written to the output parameter. Has an unused return type because
|
||||
# Taskpool.spawn cannot handle a void return type.
|
||||
func shiftEvalDomainTask(
|
||||
values: seq[Fr[BN254_Snarks]],
|
||||
D: Domain,
|
||||
eta: Fr[BN254_Snarks],
|
||||
output: ptr Isolated[seq[Fr[BN254_Snarks]]]): bool =
|
||||
|
||||
output[] = isolate shiftEvalDomain(values, D, eta)
|
||||
|
||||
# computes the quotient polynomial Q = (A*B - C) / Z
|
||||
# by computing the values on a shifted domain, and interpolating the result
|
||||
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
|
||||
proc computeQuotientPointwise( abc: ABC, pool: TaskPool ): Poly =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
|
||||
let D = createDomain(n)
|
||||
|
||||
# (eta*omega^j)^n - 1 = eta^n - 1
|
||||
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
|
||||
|
||||
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
|
||||
|
||||
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
|
||||
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
|
||||
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
|
||||
|
||||
discard sync taskA1
|
||||
discard sync taskB1
|
||||
discard sync taskC1
|
||||
|
||||
let A1 = outputA1.extract()
|
||||
let B1 = outputB1.extract()
|
||||
let C1 = outputC1.extract()
|
||||
|
||||
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
|
||||
let Q1 = polyInverseNTT( ys, D )
|
||||
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
|
||||
|
||||
return Poly(coeffs: cs)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# Snarkjs does something different, not actually computing the quotient poly
|
||||
# they can get away with this, because during the trusted setup, they
|
||||
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
|
||||
# (shifted) Lagrange bases.
|
||||
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
proc computeSnarkjsScalarCoeffs( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
let D = createDomain(n)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
|
||||
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
|
||||
|
||||
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
|
||||
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
|
||||
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
|
||||
|
||||
discard sync taskA1
|
||||
discard sync taskB1
|
||||
discard sync taskC1
|
||||
|
||||
let A1 = outputA1.extract()
|
||||
let B1 = outputB1.extract()
|
||||
let C1 = outputC1.extract()
|
||||
|
||||
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
|
||||
|
||||
return ys
|
||||
|
||||
#[
|
||||
|
||||
proc computeSnarkjsScalarCoeffs_st( abc: ABC ): seq[Fr] =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
let D = createDomain(n)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let A1 : seq[Fr] = shiftEvalDomain( abc.valuesAz, D, eta )
|
||||
let B1 : seq[Fr] = shiftEvalDomain( abc.valuesBz, D, eta )
|
||||
let C1 : seq[Fr] = shiftEvalDomain( abc.valuesCz, D, eta )
|
||||
var ys : seq[Fr] = newSeq[Fr]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
|
||||
return ys
|
||||
|
||||
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
|
||||
if nthreads <= 1:
|
||||
computeSnarkjsScalarCoeffs_st( abc )
|
||||
else:
|
||||
computeSnarkjsScalarCoeffs_mt( nthreads, abc )
|
||||
|
||||
]#
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# the prover
|
||||
#
|
||||
|
||||
type
|
||||
Mask* = object
|
||||
r*: Fr[BN254_Snarks] # masking coefficients
|
||||
s*: Fr[BN254_Snarks] # for zero knowledge
|
||||
|
||||
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
|
||||
|
||||
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
|
||||
{.fatal: "Compile with arc/orc!".}
|
||||
|
||||
# if (zkey.header.curve != wtns.curve):
|
||||
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
|
||||
# echo( "wtns.curve = " & ($wtns.curve ) )
|
||||
|
||||
assert( zkey.header.curve == wtns.curve )
|
||||
var start : float = 0
|
||||
|
||||
let witness = wtns.values
|
||||
|
||||
let hdr : GrothHeader = zkey.header
|
||||
let spec : SpecPoints = zkey.specPoints
|
||||
let pts : ProverPoints = zkey.pPoints
|
||||
|
||||
let nvars = hdr.nvars
|
||||
let npubs = hdr.npubs
|
||||
|
||||
assert( nvars == witness.len , "wrong witness length" )
|
||||
|
||||
# remark: with the special variable "1" we actuall have (npub+1) public IO variables
|
||||
var pubIO = newSeq[Fr[BN254_Snarks]]( npubs + 1)
|
||||
for i in 0..npubs: pubIO[i] = witness[i]
|
||||
|
||||
start = cpuTime()
|
||||
var abc : ABC
|
||||
withMeasureTime(printTimings,"building 'ABC'"):
|
||||
abc = buildABC( zkey, witness )
|
||||
|
||||
start = cpuTime()
|
||||
var qs : seq[Fr[BN254_Snarks]]
|
||||
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
|
||||
case zkey.header.flavour
|
||||
|
||||
# the points H are [delta^-1 * tau^i * Z(tau)]
|
||||
of JensGroth:
|
||||
let polyQ = computeQuotientPointwise( abc, pool )
|
||||
qs = polyQ.coeffs
|
||||
|
||||
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
|
||||
# where L_i are Lagrange basis polynomials on the double-sized domain
|
||||
of Snarkjs:
|
||||
qs = computeSnarkjsScalarCoeffs( abc, pool )
|
||||
|
||||
var zs = newSeq[Fr[BN254_Snarks]]( nvars - npubs - 1 )
|
||||
for j in npubs+1..<nvars:
|
||||
zs[j-npubs-1] = witness[j]
|
||||
|
||||
# masking coeffs
|
||||
let r = mask.r
|
||||
let s = mask.s
|
||||
|
||||
assert( witness.len == pts.pointsA1.len )
|
||||
assert( witness.len == pts.pointsB1.len )
|
||||
assert( witness.len == pts.pointsB2.len )
|
||||
assert( hdr.domainSize == qs.len )
|
||||
assert( hdr.domainSize == pts.pointsH1.len )
|
||||
assert( nvars - npubs - 1 == zs.len )
|
||||
assert( nvars - npubs - 1 == pts.pointsC1.len )
|
||||
|
||||
var pi_a : G1
|
||||
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
|
||||
pi_a = spec.alpha1
|
||||
pi_a += r ** spec.delta1
|
||||
pi_a += msmMultiThreadedG1( witness , pts.pointsA1, pool )
|
||||
|
||||
var rho : G1
|
||||
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
|
||||
rho = spec.beta1
|
||||
rho += s ** spec.delta1
|
||||
rho += msmMultiThreadedG1( witness , pts.pointsB1, pool )
|
||||
|
||||
var pi_b : G2
|
||||
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
|
||||
pi_b = spec.beta2
|
||||
pi_b += s ** spec.delta2
|
||||
pi_b += msmMultiThreadedG2( witness , pts.pointsB2, pool )
|
||||
|
||||
var pi_c : G1
|
||||
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
|
||||
pi_c = s ** pi_a
|
||||
pi_c += r ** rho
|
||||
pi_c += negFr(r*s) ** spec.delta1
|
||||
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
|
||||
pi_c += msmMultiThreadedG1( zs , pts.pointsC1, pool )
|
||||
|
||||
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc generateProofWithTrivialMask*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool ): Proof =
|
||||
let mask = Mask( r: zeroFr , s: zeroFr )
|
||||
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
|
||||
|
||||
proc generateProof*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings = false ): Proof =
|
||||
|
||||
# masking coeffs
|
||||
let r = randFr()
|
||||
let s = randFr()
|
||||
let mask = Mask(r: r, s: s)
|
||||
|
||||
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
|
||||
|
||||
135
groth16/prover/groth16.nim
Normal file
135
groth16/prover/groth16.nim
Normal file
@ -0,0 +1,135 @@
|
||||
|
||||
#
|
||||
# Groth16 prover
|
||||
#
|
||||
# WARNING!
|
||||
# the points H in `.zkey` are *NOT* what normal people would think they are
|
||||
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
|
||||
{.push raises:[].}
|
||||
|
||||
import std/times
|
||||
import system
|
||||
import taskpools
|
||||
import constantine/math/arithmetic
|
||||
import constantine/named/properties_fields
|
||||
|
||||
#import constantine/math/io/io_extfields except Fp12
|
||||
|
||||
import groth16/bn128
|
||||
#import groth16/math/domain
|
||||
import groth16/math/poly
|
||||
import groth16/zkey_types
|
||||
import groth16/files/witness
|
||||
import groth16/misc
|
||||
|
||||
import groth16/prover/types
|
||||
import groth16/prover/shared
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# the prover
|
||||
#
|
||||
|
||||
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
|
||||
|
||||
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
|
||||
{.fatal: "Compile with arc/orc!".}
|
||||
|
||||
# if (zkey.header.curve != wtns.curve):
|
||||
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
|
||||
# echo( "wtns.curve = " & ($wtns.curve ) )
|
||||
|
||||
assert( zkey.header.curve == wtns.curve )
|
||||
var start : float = 0
|
||||
|
||||
let witness = wtns.values
|
||||
|
||||
let hdr : GrothHeader = zkey.header
|
||||
let spec : SpecPoints = zkey.specPoints
|
||||
let pts : ProverPoints = zkey.pPoints
|
||||
|
||||
let nvars = hdr.nvars
|
||||
let npubs = hdr.npubs
|
||||
|
||||
assert( nvars == witness.len , "wrong witness length" )
|
||||
|
||||
# remark: with the special variable "1" we actuall have (npub+1) public IO variables
|
||||
var pubIO = newSeq[Fr[BN254_Snarks]]( npubs + 1)
|
||||
for i in 0..npubs: pubIO[i] = witness[i]
|
||||
|
||||
start = cpuTime()
|
||||
var abc : ABC
|
||||
withMeasureTime(printTimings,"building 'ABC'"):
|
||||
abc = buildABC( zkey, witness )
|
||||
|
||||
start = cpuTime()
|
||||
var qs : seq[Fr[BN254_Snarks]]
|
||||
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
|
||||
case zkey.header.flavour
|
||||
|
||||
# the points H are [delta^-1 * tau^i * Z(tau)]
|
||||
of JensGroth:
|
||||
let polyQ = computeQuotientPointwise( abc, pool )
|
||||
qs = polyQ.coeffs
|
||||
|
||||
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
|
||||
# where L_i are Lagrange basis polynomials on the double-sized domain
|
||||
of Snarkjs:
|
||||
qs = computeSnarkjsScalarCoeffs( abc, pool )
|
||||
|
||||
var zs = newSeq[Fr[BN254_Snarks]]( nvars - npubs - 1 )
|
||||
for j in npubs+1..<nvars:
|
||||
zs[j-npubs-1] = witness[j]
|
||||
|
||||
# masking coeffs
|
||||
let r = mask.r
|
||||
let s = mask.s
|
||||
|
||||
assert( witness.len == pts.pointsA1.len )
|
||||
assert( witness.len == pts.pointsB1.len )
|
||||
assert( witness.len == pts.pointsB2.len )
|
||||
assert( hdr.domainSize == qs.len )
|
||||
assert( hdr.domainSize == pts.pointsH1.len )
|
||||
assert( nvars - npubs - 1 == zs.len )
|
||||
assert( nvars - npubs - 1 == pts.pointsC1.len )
|
||||
|
||||
var pi_a : G1
|
||||
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
|
||||
pi_a = spec.alpha1
|
||||
pi_a += r ** spec.delta1
|
||||
pi_a += msmMultiThreadedG1( witness , pts.pointsA1, pool )
|
||||
|
||||
var rho : G1
|
||||
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
|
||||
rho = spec.beta1
|
||||
rho += s ** spec.delta1
|
||||
rho += msmMultiThreadedG1( witness , pts.pointsB1, pool )
|
||||
|
||||
var pi_b : G2
|
||||
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
|
||||
pi_b = spec.beta2
|
||||
pi_b += s ** spec.delta2
|
||||
pi_b += msmMultiThreadedG2( witness , pts.pointsB2, pool )
|
||||
|
||||
var pi_c : G1
|
||||
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
|
||||
pi_c = s ** pi_a
|
||||
pi_c += r ** rho
|
||||
pi_c += negFr(r*s) ** spec.delta1
|
||||
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
|
||||
pi_c += msmMultiThreadedG1( zs , pts.pointsC1, pool )
|
||||
|
||||
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc generateProofWithTrivialMask*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool ): Proof =
|
||||
let mask = Mask( r: zeroFr , s: zeroFr )
|
||||
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
|
||||
|
||||
proc generateProof*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings = false ): Proof =
|
||||
let mask = randomMask()
|
||||
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
180
groth16/prover/shared.nim
Normal file
180
groth16/prover/shared.nim
Normal file
@ -0,0 +1,180 @@
|
||||
|
||||
#
|
||||
# Groth16 prover
|
||||
#
|
||||
# WARNING!
|
||||
# the points H in `.zkey` are *NOT* what normal people would think they are
|
||||
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
|
||||
{.push raises:[].}
|
||||
|
||||
import system
|
||||
import taskpools
|
||||
import constantine/math/arithmetic
|
||||
import constantine/named/properties_fields
|
||||
|
||||
import groth16/bn128
|
||||
import groth16/math/domain
|
||||
import groth16/math/poly
|
||||
import groth16/zkey_types
|
||||
#import groth16/misc
|
||||
|
||||
import groth16/prover/types
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
proc randomMask*(): Mask =
|
||||
# masking coeffs
|
||||
let r = randFr()
|
||||
let s = randFr()
|
||||
let mask = Mask(r: r, s: s)
|
||||
return mask
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# computes the vectors A*z, B*z, C*z where z is the witness
|
||||
func buildABC*( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC =
|
||||
let hdr: GrothHeader = zkey.header
|
||||
let domSize = hdr.domainSize
|
||||
|
||||
var valuesAz = newSeq[Fr[BN254_Snarks]](domSize)
|
||||
var valuesBz = newSeq[Fr[BN254_Snarks]](domSize)
|
||||
|
||||
for entry in zkey.coeffs:
|
||||
case entry.matrix
|
||||
of MatrixA: valuesAz[entry.row] += entry.coeff * witness[entry.col]
|
||||
of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col]
|
||||
else: raise newException(AssertionDefect, "fatal error")
|
||||
|
||||
var valuesCz = newSeq[Fr[BN254_Snarks]](domSize)
|
||||
for i in 0..<domSize:
|
||||
valuesCz[i] = valuesAz[i] * valuesBz[i]
|
||||
|
||||
return ABC( valuesAz:valuesAz, valuesBz:valuesBz, valuesCz:valuesCz )
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# quotient poly
|
||||
#
|
||||
|
||||
# interpolates A,B,C, and computes the quotient polynomial Q = (A*B - C) / Z
|
||||
func computeQuotientNaive*( abc: ABC ): Poly=
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
let D = createDomain(n)
|
||||
let polyA : Poly = polyInverseNTT( abc.valuesAz , D )
|
||||
let polyB : Poly = polyInverseNTT( abc.valuesBz , D )
|
||||
let polyC : Poly = polyInverseNTT( abc.valuesCz , D )
|
||||
let polyBig = polyMulFFT( polyA , polyB ) - polyC
|
||||
var polyQ = polyDivideByVanishing(polyBig, D.domainSize)
|
||||
polyQ.coeffs.add( zeroFr ) # make it a power of two
|
||||
return polyQ
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
|
||||
func multiplyByPowers*( xs: seq[Fr[BN254_Snarks]], eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
|
||||
let n = xs.len
|
||||
assert(n >= 1)
|
||||
var ys = newSeq[Fr[BN254_Snarks]](n)
|
||||
ys[0] = xs[0]
|
||||
if n >= 1: ys[1] = eta * xs[1]
|
||||
var spow = eta
|
||||
for i in 2..<n:
|
||||
spow *= eta
|
||||
ys[i] = spow * xs[i]
|
||||
return ys
|
||||
|
||||
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
|
||||
func shiftEvalDomain*(
|
||||
values: seq[Fr[BN254_Snarks]],
|
||||
D: Domain,
|
||||
eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
|
||||
let poly : Poly = polyInverseNTT( values , D )
|
||||
let cs : seq[Fr[BN254_Snarks]] = poly.coeffs
|
||||
var ds : seq[Fr[BN254_Snarks]] = multiplyByPowers( cs, eta )
|
||||
return polyForwardNTT( Poly(coeffs:ds), D )
|
||||
|
||||
# Wraps shiftEvalDomain such that it can be called by Taskpool.spawn. The result
|
||||
# is written to the output parameter. Has an unused return type because
|
||||
# Taskpool.spawn cannot handle a void return type.
|
||||
func shiftEvalDomainTask*(
|
||||
values: seq[Fr[BN254_Snarks]],
|
||||
D: Domain,
|
||||
eta: Fr[BN254_Snarks],
|
||||
output: ptr Isolated[seq[Fr[BN254_Snarks]]]): bool =
|
||||
|
||||
output[] = isolate shiftEvalDomain(values, D, eta)
|
||||
|
||||
# computes the quotient polynomial Q = (A*B - C) / Z
|
||||
# by computing the values on a shifted domain, and interpolating the result
|
||||
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
|
||||
proc computeQuotientPointwise*( abc: ABC, pool: TaskPool ): Poly =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
|
||||
let D = createDomain(n)
|
||||
|
||||
# (eta*omega^j)^n - 1 = eta^n - 1
|
||||
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
|
||||
|
||||
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
|
||||
|
||||
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
|
||||
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
|
||||
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
|
||||
|
||||
discard sync taskA1
|
||||
discard sync taskB1
|
||||
discard sync taskC1
|
||||
|
||||
let A1 = outputA1.extract()
|
||||
let B1 = outputB1.extract()
|
||||
let C1 = outputC1.extract()
|
||||
|
||||
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
|
||||
let Q1 = polyInverseNTT( ys, D )
|
||||
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
|
||||
|
||||
return Poly(coeffs: cs)
|
||||
|
||||
#---------------------------------------
|
||||
|
||||
# Snarkjs does something different, not actually computing the quotient poly
|
||||
# they can get away with this, because during the trusted setup, they
|
||||
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
|
||||
# (shifted) Lagrange bases.
|
||||
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
|
||||
#
|
||||
proc computeSnarkjsScalarCoeffs*( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] =
|
||||
let n = abc.valuesAz.len
|
||||
assert( abc.valuesBz.len == n )
|
||||
assert( abc.valuesCz.len == n )
|
||||
let D = createDomain(n)
|
||||
let eta = createDomain(2*n).domainGen
|
||||
|
||||
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
|
||||
|
||||
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
|
||||
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
|
||||
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
|
||||
|
||||
discard sync taskA1
|
||||
discard sync taskB1
|
||||
discard sync taskC1
|
||||
|
||||
let A1 = outputA1.extract()
|
||||
let B1 = outputB1.extract()
|
||||
let C1 = outputC1.extract()
|
||||
|
||||
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
|
||||
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
|
||||
|
||||
return ys
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
41
groth16/prover/types.nim
Normal file
41
groth16/prover/types.nim
Normal file
@ -0,0 +1,41 @@
|
||||
|
||||
{.push raises:[].}
|
||||
|
||||
import constantine/named/properties_fields
|
||||
|
||||
import groth16/bn128
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
type
|
||||
Mask* = object
|
||||
r*: Fr[BN254_Snarks] # masking coefficients
|
||||
s*: Fr[BN254_Snarks] # for zero knowledge
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# a Groth16 proof
|
||||
|
||||
type
|
||||
Proof* = object
|
||||
publicIO* : seq[Fr[BN254_Snarks]]
|
||||
pi_a* : G1
|
||||
pi_b* : G2
|
||||
pi_c* : G1
|
||||
curve* : string
|
||||
|
||||
func isEqualProof*(prf1, prf2: Proof): bool =
|
||||
return (prf1.pi_a === prf2.pi_a) and
|
||||
(prf1.pi_b === prf2.pi_b) and
|
||||
(prf1.pi_c === prf2.pi_c)
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Az, Bz, Cz column vectors
|
||||
#
|
||||
|
||||
type
|
||||
ABC* = object
|
||||
valuesAz* : seq[Fr[BN254_Snarks]]
|
||||
valuesBz* : seq[Fr[BN254_Snarks]]
|
||||
valuesCz* : seq[Fr[BN254_Snarks]]
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
Loading…
x
Reference in New Issue
Block a user