initial support for "partial proofs" (precalculating based on a partial witness)

This commit is contained in:
Balazs Komuves 2026-01-15 22:41:56 +01:00
parent f7a4549d86
commit 3ddf4948f9
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
17 changed files with 868 additions and 398 deletions

View File

@ -21,7 +21,7 @@ at your choice.
- [x] find and fix the _second_ totally surreal bug
- [ ] clean up the code
- [ ] make it compatible with the latest constantine and also Nim 2.0.x
- [x] make it compatible with the latest constantine and also Nim 2.0.x
- [x] make it a nimble package
- [ ] compare `.r1cs` to the "coeffs" section of `.zkey`
- [x] generate fake circuit-specific setup ourselves
@ -29,5 +29,5 @@ at your choice.
- [x] multithreading support (MSM, and possibly also FFT)
- [ ] add Groth16 notes
- [ ] document the `snarkjs` circuit-specific setup `H` points convention
- [x] precalculate stuff for "partial" proofs
- [ ] make it work for different curves

View File

@ -1,13 +1,8 @@
import sugar
import std/strutils
import std/sequtils
import std/os
import std/cpuinfo
import std/parseopt
import std/times
import std/options
# import strformat
import taskpools
@ -21,6 +16,8 @@ import groth16/zkey_types
import groth16/fake_setup
import groth16/misc
import testing
#-------------------------------------------------------------------------------
proc printHelp() =
@ -37,6 +34,7 @@ proc printHelp() =
echo " -y, --verify : verify a proof"
echo " -u, --setup : perform (fake) trusted setup"
echo " -n, --nomask : don't use random masking for full ZK"
echo " -s, --sanity : sanity test the partial prover"
echo " -z, --zkey = <circuit.zkey> : the `.zkey` file"
echo " -w, --wtns = <circuit.wtns> : the `.wtns` file"
echo " -r, --r1cs = <circuit.r1cs> : the `.r1cs` file"
@ -46,33 +44,35 @@ proc printHelp() =
#-------------------------------------------------------------------------------
type Config = object
zkey_file: string
r1cs_file: string
wtns_file: string
output_file: string
io_file: string
verbose: bool
debug: bool
measure_time: bool
do_prove: bool
do_verify: bool
do_setup: bool
no_masking: bool
nthreads: int
zkey_file: string
r1cs_file: string
wtns_file: string
output_file: string
io_file: string
verbose: bool
debug: bool
measure_time: bool
do_prove: bool
do_verify: bool
do_setup: bool
no_masking: bool
partial_sanity: bool
nthreads: int
const dummyConfig =
Config( zkey_file: ""
, r1cs_file: ""
, wtns_file: ""
, output_file: ""
, io_file: ""
, verbose: false
, measure_time: false
, do_prove: false
, do_verify: false
, do_setup: false
, no_masking: false
, nthreads: 0
Config( zkey_file: ""
, r1cs_file: ""
, wtns_file: ""
, output_file: ""
, io_file: ""
, verbose: false
, measure_time: false
, do_prove: false
, do_verify: false
, do_setup: false
, no_masking: false
, partial_sanity: false
, nthreads: 0
)
proc printConfig(cfg: Config) =
@ -114,6 +114,7 @@ proc parseCliOptions(): Config =
of "y", "verify" : cfg.do_verify = true
of "u", "setup" : cfg.do_setup = true
of "n", "nomask" : cfg.no_masking = true
of "s", "sanity" : cfg.partial_sanity = true
of "o", "output" : cfg.output_file = value
of "r", "r1cs" : cfg.r1cs_file = value
of "z", "zkey" : cfg.zkey_file = value
@ -136,28 +137,6 @@ proc parseCliOptions(): Config =
return cfg
#-------------------------------------------------------------------------------
#[
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
echo("parsing witness & zkey files...")
let witness = parseWitness( wtns_fname)
let zkey = parseZKey( zkey_fname)
echo("generating proof...")
let start = cpuTime()
let proof = generateProof( zkey, witness )
let elapsed = cpuTime() - start
echo("proving took ",seconds(elapsed))
echo("verifying the proof...")
let vkey = extractVKey( zkey)
let ok = verifyProof( vkey, proof )
echo("verification succeeded = ",ok)
return (vkey,proof)
]#
#-------------------------------------------------------------------------------
@ -198,6 +177,16 @@ proc cliMain(cfg: Config) =
printGrothHeader(zkey.header)
# debugPrintCoeffs(zkey.coeffs)
if cfg.partial_sanity:
if (cfg.wtns_file=="") or (cfg.zkey_file=="" and cfg.do_setup==false):
echo("cannot prove: missing witness and/or zkey file!")
quit()
else:
var pool = Taskpool.new(cfg.nthreads)
let print_timings = cfg.measure_time and cfg.verbose
sanityCheckPartialProofs(zkey,wtns,pool,print_timings)
echo("sanity testing partial proofs...")
if cfg.do_prove:
if (cfg.wtns_file=="") or (cfg.zkey_file=="" and cfg.do_setup==false):
echo("cannot prove: missing witness and/or zkey file!")

102
cli/testing.nim Normal file
View File

@ -0,0 +1,102 @@
import std/strutils
import std/times
import std/options
import std/random
import std/syncio
import taskpools
import constantine/named/properties_fields
# import groth16/bn128
import groth16/zkey_types
import groth16/files/witness
import groth16/misc
import groth16/files/export_json
import groth16/partial/types
import groth16/partial/precalc
import groth16/partial/finish
import groth16/prover
import groth16/prover/shared
import groth16/verifier
#-------------------------------------------------------------------------------
#[
proc testProveAndVerify*( zkey_fname, wtns_fname: string): (VKey,Proof) =
echo("parsing witness & zkey files...")
let witness = parseWitness( wtns_fname)
let zkey = parseZKey( zkey_fname)
echo("generating proof...")
let start = cpuTime()
let proof = generateProof( zkey, witness )
let elapsed = cpuTime() - start
echo("proving took ",seconds(elapsed))
echo("verifying the proof...")
let vkey = extractVKey( zkey)
let ok = verifyProof( vkey, proof )
echo("verification succeeded = ",ok)
return (vkey,proof)
]#
#-------------------------------------------------------------------------------
proc sanityCheckPartialProofs*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool) =
let witness = wtns.values
let M = witness.len
var partial_mask: seq[bool] = newSeq[bool]( M )
var partial_witness: seq[Option[Fr[BN254_Snarks]]] = newSeq[Option[Fr[BN254_Snarks]]]( M )
# generate randomized partial witness
partial_mask[0] = true
partial_witness[0] = some(witness[0])
var count = 0
for i in 1..<M:
let b : bool = rand(bool)
partial_mask[i] = b
if b:
partial_witness[i] = some(witness[i])
count += 1
else:
partial_witness[i] = none(Fr[BN254_Snarks])
echo "\nrandomized a partial witness of size " & $(count) & " out of " & $(M)
let partial_wtns = PartialWitness(values: partial_witness )
let mask = randomMask()
var fullProof : Proof
withMeasureTime(true,"\ngenerating the full proof"):
fullProof = generateProofWithMask( zkey, wtns, mask, pool, printTimings )
writeProof(stdout,fullProof)
let vkey = extractVKey(zkey)
echo "verifying the full proof succeeds = " & $verifyProof(vkey, fullProof)
var partialProof : PartialProof
withMeasureTime(true,"\ngenerating the partial proof"):
partialProof = generatePartialProof( zkey, partial_wtns, pool, printTimings )
var finishedProof : Proof
withMeasureTime(true,"\nfinishing the partial proof"):
finishedProof = finishPartialProofWithMask( zkey, wtns, partialProof, mask, pool, printTimings )
writeProof(stdout,finishedProof)
echo "verifying the finished proof succeeds = " & $verifyProof(vkey, finishedProof)
if (not isEqualProof(fullProof, finishedProof)):
echo "PROBLEM! the two proofs DIFFER!!!"
else:
echo "OK. the two proofs agree"
#-------------------------------------------------------------------------------

View File

@ -37,6 +37,14 @@ type ProjG2* = prj.EC_ShortW_Prj[Fp2[BN254_Snarks], prj.G2]
#-------------------------------------------------------------------------------
func isEqualG1* (x, y: G1 ): bool = bool(x == y)
func isEqualG2* (x, y: G2 ): bool = bool(x == y)
func `===`*(x, y: G1 ): bool = isEqualG1(x,y)
func `===`*(x, y: G2 ): bool = isEqualG2(x,y)
#-------------------------------------------------------------------------------
func unsafeMkG1* ( X, Y: Fp[BN254_Snarks] ) : G1 =
return aff.EC_ShortW_Aff[Fp[BN254_Snarks], aff.G1](x: X, y: Y)

View File

@ -12,7 +12,7 @@
import constantine/named/properties_fields
import constantine/math/extension_fields/towers
import groth16/bn128/fields
#import groth16/bn128/fields
import groth16/bn128/curves
import groth16/bn128/io

View File

@ -19,7 +19,7 @@ import constantine/math/extension_fields/towers as ext
#-------------------------------------------------------------------------------
type B* = BigInt[256]
type B* = BigInt[256]
func mkFp2* (i: Fp[BN254_Snarks], u: Fp[BN254_Snarks]) : Fp2[BN254_Snarks] =
let c : array[2, Fp[BN254_Snarks]] = [i,u]

View File

@ -4,12 +4,9 @@
#
import system
import std/cpuinfo
import taskpools
# import constantine/curves_primitives except Fp, Fp2, Fr
import constantine/platforms/abstractions except Subgroup
import constantine/platforms/abstractions except Subgroup
import constantine/math/endomorphisms/frobenius except Subgroup
import constantine/math/io/io_bigints
@ -23,11 +20,12 @@ import constantine/math/elliptic/ec_shortweierstrass_projective as prj except Su
import constantine/math/elliptic/ec_scalar_mul_vartime as scl except Subgroup
import constantine/math/elliptic/ec_multi_scalar_mul as msm except Subgroup
import groth16/bn128/fields
#import groth16/bn128/fields
import groth16/bn128/curves as mycurves
import groth16/misc # TEMP DEBUGGING
import std/times
#import groth16/misc # TEMP DEBUGGING
#import std/cpuinfo
#import std/times
#-------------------------------------------------------------------------------

View File

@ -23,8 +23,6 @@
import std/streams
import sugar
import constantine/math/arithmetic
import constantine/math/io/io_bigints

View File

@ -67,11 +67,7 @@ proc writeG2( f: File, p: G2 ) =
#-------------------------------------------------------------------------------
# exports the proof into as a JSON file
proc exportProof*( fpath: string, prf: Proof ) =
let f = open(fpath, fmWrite)
defer: f.close()
proc writeProof*( f: File, prf: Proof) =
f.writeLine("{ \"protocol\": \"groth16\"")
f.writeLine(", \"curve\": \"bn128\"" )
@ -80,6 +76,13 @@ proc exportProof*( fpath: string, prf: Proof ) =
f.writeLine(", \"pi_c\":" ) ; writeG1( f, prf.pi_c )
f.writeLine("}")
# exports the proof into as a JSON file
proc exportProof*( fpath: string, prf: Proof ) =
let f = open(fpath, fmWrite)
defer: f.close()
writeProof(f, prf)
#-------------------------------------------------------------------------------
#[

View File

@ -4,7 +4,7 @@
#
import strformat
import times, os, strutils
import times, strutils
#-------------------------------------------------------------------------------

158
groth16/partial/finish.nim Normal file
View File

@ -0,0 +1,158 @@
#
# Finish a "partial proof"
#
{.push raises:[].}
import std/times
import system
import taskpools
import constantine/math/arithmetic
import constantine/named/properties_fields
import groth16/bn128
#import groth16/math/domain
import groth16/math/poly
import groth16/zkey_types
import groth16/files/witness
import groth16/misc
import groth16/partial/types
import groth16/prover/types
import groth16/prover/shared
#-------------------------------------------------------------------------------
# the finishing prover
#
proc finishPartialProofWithMask*( zkey: ZKey, wtns: Witness, partialProof: PartialProof, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
# echo( "wtns.curve = " & ($wtns.curve ) )
assert( zkey.header.curve == wtns.curve )
var start : float = 0
let witness = wtns.values
let hdr : GrothHeader = zkey.header
let spec : SpecPoints = zkey.specPoints
let pts : ProverPoints = zkey.pPoints
let nvars = hdr.nvars
let npubs = hdr.npubs
assert( nvars == witness.len , "wrong witness length" )
let partial_mask = partialProof.partial_mask
# remark: with the special variable "1" we actuall have (npub+1) public IO variables
var pubIO = newSeq[Fr[BN254_Snarks]]( npubs + 1)
for i in 0..npubs: pubIO[i] = witness[i]
# note: we have to ignore public inputs (plus 1 for the special first entry)
let startIdx : int = npubs + 1
var count: int = 0 # count the number of NEW witness elements
var zs_cnt: int = 0 # count the number of NEW witness elements, which are NOT public IO
for i in 0..<startidx:
if (not partial_mask[i]):
count += 1
for i in startIdx..<nvars:
if (not partial_mask[i]):
count += 1
zs_cnt += 1
# compactify the stuff so we can call normal vector MSM
var compact_witness: seq[F] = newSeq[F](count)
var compact_pointsA1: seq[G1] = newSeq[G1](count)
var compact_pointsB1: seq[G1] = newSeq[G1](count)
var compact_pointsB2: seq[G2] = newSeq[G2](count)
var compact_zs: seq[F] = newSeq[F](zs_cnt)
var compact_pointsC1: seq[G1] = newSeq[G1](zs_cnt)
block:
var j: int = 0;
var k: int = 0;
for i in 0..<nvars:
if (not partial_mask[i]):
compact_witness[j ] = witness[i]
compact_pointsA1[j] = pts.pointsA1[i]
compact_pointsB1[j] = pts.pointsB1[i]
compact_pointsB2[j] = pts.pointsB2[i]
j += 1
if i >= startIdx:
compact_zs[k] = witness[i]
compact_pointsC1[k] = pts.pointsC1[ i - startIdx ]
k += 1
start = cpuTime()
var abc : ABC
withMeasureTime(printTimings,"building 'ABC'"):
abc = buildABC( zkey, witness )
start = cpuTime()
var qs : seq[Fr[BN254_Snarks]]
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
case zkey.header.flavour
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc, pool )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc, pool )
# masking coeffs
let r = mask.r
let s = mask.s
assert( witness.len == pts.pointsA1.len )
assert( witness.len == pts.pointsB1.len )
assert( witness.len == pts.pointsB2.len )
assert( hdr.domainSize == qs.len )
assert( hdr.domainSize == pts.pointsH1.len )
assert( nvars - npubs - 1 == pts.pointsC1.len )
var pi_a : G1 = partialProof.partial_pi_a
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
pi_a += r ** spec.delta1
pi_a += msmMultiThreadedG1( compact_witness , compact_pointsA1, pool )
var rho : G1 = partialProof.partial_rho
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
rho += s ** spec.delta1
rho += msmMultiThreadedG1( compact_witness , compact_pointsB1, pool )
var pi_b : G2 = partialProof.partial_pi_b
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
pi_b += s ** spec.delta2
pi_b += msmMultiThreadedG2( compact_witness , compact_pointsB2, pool )
var pi_c : G1 = partialProof.partial_pi_c
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
pi_c += s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
pi_c += msmMultiThreadedG1( compact_zs , compact_pointsC1, pool )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------
proc finishPartialProofWithTrivialMask*( zkey: ZKey, wtns: Witness, partial: PartialProof, pool: Taskpool, printTimings: bool ): Proof =
let mask = Mask( r: zeroFr , s: zeroFr )
return finishPartialProofWithMask( zkey, wtns, partial, mask, pool, printTimings )
proc finishPartialProof*( zkey: ZKey, wtns: Witness, partial: PartialProof, pool: Taskpool, printTimings = false ): Proof =
let mask = randomMask()
return finishPartialProofWithMask( zkey, wtns, partial, mask, pool, printTimings )
#-------------------------------------------------------------------------------

152
groth16/partial/precalc.nim Normal file
View File

@ -0,0 +1,152 @@
#
# "a partial prover": precalculate stuff based on a partial witness
#
# this is useful when a significant portion of the witness does not changes
# between proofs - then such precalculation can result in a potentially big speedup.
#
# an example use is RLN proofs, where the circuit is dominated by the Merkle
#vinclusion proof check, which is expected to change relatively rarely.
#
{.push raises:[].}
import std/times
import std/options
import system
import taskpools
import groth16/bn128
import groth16/zkey_types
import groth16/misc
import groth16/partial/types
#import groth16/math/domain
#import groth16/math/poly
#import groth16/prover/shared
#-------------------------------------------------------------------------------
# the prover
#
proc generatePartialProof*( zkey: ZKey, pwtns: PartialWitness, pool: Taskpool, printTimings: bool): PartialProof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# assert( zkey.header.curve == wtns.curve )
var start : float = 0
let partial_witness = pwtns.values
let hdr : GrothHeader = zkey.header
let spec : SpecPoints = zkey.specPoints
let pts : ProverPoints = zkey.pPoints
let nvars = hdr.nvars
let npubs = hdr.npubs
assert( nvars == partial_witness.len , "wrong witness length" )
assert( partial_witness.len == pts.pointsA1.len )
assert( partial_witness.len == pts.pointsB1.len )
assert( partial_witness.len == pts.pointsB2.len )
# note: in "zs" we have to ignore public inputs (plus 1 for the special first entry)
var partial_mask : seq[bool] = newSeq[bool](nvars)
for i in 0..<nvars:
partial_mask[i] = false
let startIdx: int = npubs + 1 # in "zs" we have to skip the public IO
var count: int = 0 # count the number of existing witness elements
var zs_cnt: int = 0 # count the number of existing witness elements, which are NOT public IO
for i in 0..<startidx:
if isSome(partial_witness[i]):
partial_mask[i] = true
count += 1
for i in startIdx..<nvars:
if isSome(partial_witness[i]):
partial_mask[i] = true
count += 1
zs_cnt += 1
# compactify the stuff so we can call normal vector MSM
var compact_witness: seq[F] = newSeq[F](count)
var compact_pointsA1: seq[G1] = newSeq[G1](count)
var compact_pointsB1: seq[G1] = newSeq[G1](count)
var compact_pointsB2: seq[G2] = newSeq[G2](count)
var compact_zs: seq[F] = newSeq[F](zs_cnt)
var compact_pointsC1: seq[G1] = newSeq[G1](zs_cnt)
block:
var j: int = 0;
var k: int = 0;
for i in 0..<nvars:
if isSome(partial_witness[i]):
let wtns_value = partial_witness[i].unsafeGet()
compact_witness[j] = wtns_value
compact_pointsA1[j] = pts.pointsA1[ i ]
compact_pointsB1[j] = pts.pointsB1[ i ]
compact_pointsB2[j] = pts.pointsB2[ i ]
j += 1
if i >= startIdx:
compact_zs[k] = wtns_value
compact_pointsC1[k] = pts.pointsC1[ i - startIdx ]
k += 1
#[
start = cpuTime()
var abc : ABC
withMeasureTime(printTimings,"building 'ABC'"):
abc = buildABC( zkey, witness )
start = cpuTime()
var qs : seq[Fr[BN254_Snarks]]
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
case zkey.header.flavour
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc, pool )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc, pool )
]#
assert( hdr.domainSize == pts.pointsH1.len )
assert( nvars - npubs - 1 == pts.pointsC1.len )
var pi_a : G1
withMeasureTime(printTimings,"computing partial pi_A (G1 MSM)"):
pi_a = spec.alpha1
pi_a += msmMultiThreadedG1( compact_witness , compact_pointsA1 , pool )
var rho : G1
withMeasureTime(printTimings,"computing partial rho (G1 MSM)"):
rho = spec.beta1
rho += msmMultiThreadedG1( compact_witness , compact_pointsB1 , pool )
var pi_b : G2
withMeasureTime(printTimings,"computing partial pi_B (G2 MSM)"):
pi_b = spec.beta2
pi_b += msmMultiThreadedG2( compact_witness , compact_pointsB2 , pool )
var pi_c : G1
withMeasureTime(printTimings,"partial computing pi_C (G1 MSM)"):
pi_c = msmMultiThreadedG1( compact_zs , compact_pointsC1 , pool )
return PartialProof(
partial_mask : partial_mask
, partial_pi_a : pi_a
, partial_rho : rho
, partial_pi_b : pi_b
, partial_pi_c : pi_c
)
#-------------------------------------------------------------------------------

29
groth16/partial/types.nim Normal file
View File

@ -0,0 +1,29 @@
{.push raises:[].}
import std/options
import constantine/named/properties_fields
import groth16/bn128
import groth16/prover/types
export types
#-------------------------------------------------------------------------------
type
F* = Fr[BN254_Snarks]
PartialWitness* = object
values* : seq[Option[F]]
PartialProof* = object
partial_mask* : seq[bool]
partial_pi_a* : G1 # = [alpha]_1 + sum z_j*[A_j]_1
partial_rho* : G1 # = [beta]_1 + sum z_j*[B_j]_1
partial_pi_b* : G2 # = [beta]_2 + sum z_j*[B_j]_2
partial_pi_c* : G1 # = sum z_j*[K_j]_1
#-------------------------------------------------------------------------------

View File

@ -7,332 +7,9 @@
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
{.push raises:[].}
import groth16/prover/types
import groth16/prover/groth16
#[
import sugar
import constantine/math/config/curves
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import ./zkey
]#
export types
export groth16
import std/os
import std/times
import std/cpuinfo
import system
import taskpools
import constantine/math/arithmetic
import constantine/named/properties_fields
import constantine/math/extension_fields/towers
#import constantine/math/io/io_extfields except Fp12
import groth16/bn128
import groth16/math/domain
import groth16/math/poly
import groth16/zkey_types
import groth16/files/witness
import groth16/misc
#-------------------------------------------------------------------------------
type
Proof* = object
publicIO* : seq[Fr[BN254_Snarks]]
pi_a* : G1
pi_b* : G2
pi_c* : G1
curve* : string
#-------------------------------------------------------------------------------
# Az, Bz, Cz column vectors
#
type
ABC = object
valuesAz : seq[Fr[BN254_Snarks]]
valuesBz : seq[Fr[BN254_Snarks]]
valuesCz : seq[Fr[BN254_Snarks]]
# computes the vectors A*z, B*z, C*z where z is the witness
func buildABC( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC =
let hdr: GrothHeader = zkey.header
let domSize = hdr.domainSize
var valuesAz = newSeq[Fr[BN254_Snarks]](domSize)
var valuesBz = newSeq[Fr[BN254_Snarks]](domSize)
for entry in zkey.coeffs:
case entry.matrix
of MatrixA: valuesAz[entry.row] += entry.coeff * witness[entry.col]
of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col]
else: raise newException(AssertionDefect, "fatal error")
var valuesCz = newSeq[Fr[BN254_Snarks]](domSize)
for i in 0..<domSize:
valuesCz[i] = valuesAz[i] * valuesBz[i]
return ABC( valuesAz:valuesAz, valuesBz:valuesBz, valuesCz:valuesCz )
#-------------------------------------------------------------------------------
# quotient poly
#
# interpolates A,B,C, and computes the quotient polynomial Q = (A*B - C) / Z
func computeQuotientNaive( abc: ABC ): Poly=
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let polyA : Poly = polyInverseNTT( abc.valuesAz , D )
let polyB : Poly = polyInverseNTT( abc.valuesBz , D )
let polyC : Poly = polyInverseNTT( abc.valuesCz , D )
let polyBig = polyMulFFT( polyA , polyB ) - polyC
var polyQ = polyDivideByVanishing(polyBig, D.domainSize)
polyQ.coeffs.add( zeroFr ) # make it a power of two
return polyQ
#---------------------------------------
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
func multiplyByPowers( xs: seq[Fr[BN254_Snarks]], eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
let n = xs.len
assert(n >= 1)
var ys = newSeq[Fr[BN254_Snarks]](n)
ys[0] = xs[0]
if n >= 1: ys[1] = eta * xs[1]
var spow = eta
for i in 2..<n:
spow *= eta
ys[i] = spow * xs[i]
return ys
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
func shiftEvalDomain(
values: seq[Fr[BN254_Snarks]],
D: Domain,
eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
let poly : Poly = polyInverseNTT( values , D )
let cs : seq[Fr[BN254_Snarks]] = poly.coeffs
var ds : seq[Fr[BN254_Snarks]] = multiplyByPowers( cs, eta )
return polyForwardNTT( Poly(coeffs:ds), D )
# Wraps shiftEvalDomain such that it can be called by Taskpool.spawn. The result
# is written to the output parameter. Has an unused return type because
# Taskpool.spawn cannot handle a void return type.
func shiftEvalDomainTask(
values: seq[Fr[BN254_Snarks]],
D: Domain,
eta: Fr[BN254_Snarks],
output: ptr Isolated[seq[Fr[BN254_Snarks]]]): bool =
output[] = isolate shiftEvalDomain(values, D, eta)
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
proc computeQuotientPointwise( abc: ABC, pool: TaskPool ): Poly =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
# (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
discard sync taskA1
discard sync taskB1
discard sync taskC1
let A1 = outputA1.extract()
let B1 = outputB1.extract()
let C1 = outputC1.extract()
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
return Poly(coeffs: cs)
#---------------------------------------
# Snarkjs does something different, not actually computing the quotient poly
# they can get away with this, because during the trusted setup, they
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
proc computeSnarkjsScalarCoeffs( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
discard sync taskA1
discard sync taskB1
discard sync taskC1
let A1 = outputA1.extract()
let B1 = outputB1.extract()
let C1 = outputC1.extract()
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
return ys
#[
proc computeSnarkjsScalarCoeffs_st( abc: ABC ): seq[Fr] =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
let A1 : seq[Fr] = shiftEvalDomain( abc.valuesAz, D, eta )
let B1 : seq[Fr] = shiftEvalDomain( abc.valuesBz, D, eta )
let C1 : seq[Fr] = shiftEvalDomain( abc.valuesCz, D, eta )
var ys : seq[Fr] = newSeq[Fr]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
return ys
proc computeSnarkjsScalarCoeffs( nthreads: int, abc: ABC ): seq[Fr] =
if nthreads <= 1:
computeSnarkjsScalarCoeffs_st( abc )
else:
computeSnarkjsScalarCoeffs_mt( nthreads, abc )
]#
#-------------------------------------------------------------------------------
# the prover
#
type
Mask* = object
r*: Fr[BN254_Snarks] # masking coefficients
s*: Fr[BN254_Snarks] # for zero knowledge
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
# echo( "wtns.curve = " & ($wtns.curve ) )
assert( zkey.header.curve == wtns.curve )
var start : float = 0
let witness = wtns.values
let hdr : GrothHeader = zkey.header
let spec : SpecPoints = zkey.specPoints
let pts : ProverPoints = zkey.pPoints
let nvars = hdr.nvars
let npubs = hdr.npubs
assert( nvars == witness.len , "wrong witness length" )
# remark: with the special variable "1" we actuall have (npub+1) public IO variables
var pubIO = newSeq[Fr[BN254_Snarks]]( npubs + 1)
for i in 0..npubs: pubIO[i] = witness[i]
start = cpuTime()
var abc : ABC
withMeasureTime(printTimings,"building 'ABC'"):
abc = buildABC( zkey, witness )
start = cpuTime()
var qs : seq[Fr[BN254_Snarks]]
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
case zkey.header.flavour
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc, pool )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc, pool )
var zs = newSeq[Fr[BN254_Snarks]]( nvars - npubs - 1 )
for j in npubs+1..<nvars:
zs[j-npubs-1] = witness[j]
# masking coeffs
let r = mask.r
let s = mask.s
assert( witness.len == pts.pointsA1.len )
assert( witness.len == pts.pointsB1.len )
assert( witness.len == pts.pointsB2.len )
assert( hdr.domainSize == qs.len )
assert( hdr.domainSize == pts.pointsH1.len )
assert( nvars - npubs - 1 == zs.len )
assert( nvars - npubs - 1 == pts.pointsC1.len )
var pi_a : G1
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmMultiThreadedG1( witness , pts.pointsA1, pool )
var rho : G1
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
rho = spec.beta1
rho += s ** spec.delta1
rho += msmMultiThreadedG1( witness , pts.pointsB1, pool )
var pi_b : G2
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmMultiThreadedG2( witness , pts.pointsB2, pool )
var pi_c : G1
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
pi_c += msmMultiThreadedG1( zs , pts.pointsC1, pool )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------
proc generateProofWithTrivialMask*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool ): Proof =
let mask = Mask( r: zeroFr , s: zeroFr )
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
proc generateProof*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings = false ): Proof =
# masking coeffs
let r = randFr()
let s = randFr()
let mask = Mask(r: r, s: s)
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )

135
groth16/prover/groth16.nim Normal file
View File

@ -0,0 +1,135 @@
#
# Groth16 prover
#
# WARNING!
# the points H in `.zkey` are *NOT* what normal people would think they are
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
{.push raises:[].}
import std/times
import system
import taskpools
import constantine/math/arithmetic
import constantine/named/properties_fields
#import constantine/math/io/io_extfields except Fp12
import groth16/bn128
#import groth16/math/domain
import groth16/math/poly
import groth16/zkey_types
import groth16/files/witness
import groth16/misc
import groth16/prover/types
import groth16/prover/shared
#-------------------------------------------------------------------------------
# the prover
#
proc generateProofWithMask*( zkey: ZKey, wtns: Witness, mask: Mask, pool: Taskpool, printTimings: bool): Proof =
when not (defined(gcArc) or defined(gcOrc) or defined(gcAtomicArc)):
{.fatal: "Compile with arc/orc!".}
# if (zkey.header.curve != wtns.curve):
# echo( "zkey.header.curve = " & ($zkey.header.curve) )
# echo( "wtns.curve = " & ($wtns.curve ) )
assert( zkey.header.curve == wtns.curve )
var start : float = 0
let witness = wtns.values
let hdr : GrothHeader = zkey.header
let spec : SpecPoints = zkey.specPoints
let pts : ProverPoints = zkey.pPoints
let nvars = hdr.nvars
let npubs = hdr.npubs
assert( nvars == witness.len , "wrong witness length" )
# remark: with the special variable "1" we actuall have (npub+1) public IO variables
var pubIO = newSeq[Fr[BN254_Snarks]]( npubs + 1)
for i in 0..npubs: pubIO[i] = witness[i]
start = cpuTime()
var abc : ABC
withMeasureTime(printTimings,"building 'ABC'"):
abc = buildABC( zkey, witness )
start = cpuTime()
var qs : seq[Fr[BN254_Snarks]]
withMeasureTime(printTimings,"computing the quotient (FFTs)"):
case zkey.header.flavour
# the points H are [delta^-1 * tau^i * Z(tau)]
of JensGroth:
let polyQ = computeQuotientPointwise( abc, pool )
qs = polyQ.coeffs
# the points H are `[delta^-1 * L_{2i+1}(tau)]_1`
# where L_i are Lagrange basis polynomials on the double-sized domain
of Snarkjs:
qs = computeSnarkjsScalarCoeffs( abc, pool )
var zs = newSeq[Fr[BN254_Snarks]]( nvars - npubs - 1 )
for j in npubs+1..<nvars:
zs[j-npubs-1] = witness[j]
# masking coeffs
let r = mask.r
let s = mask.s
assert( witness.len == pts.pointsA1.len )
assert( witness.len == pts.pointsB1.len )
assert( witness.len == pts.pointsB2.len )
assert( hdr.domainSize == qs.len )
assert( hdr.domainSize == pts.pointsH1.len )
assert( nvars - npubs - 1 == zs.len )
assert( nvars - npubs - 1 == pts.pointsC1.len )
var pi_a : G1
withMeasureTime(printTimings,"computing pi_A (G1 MSM)"):
pi_a = spec.alpha1
pi_a += r ** spec.delta1
pi_a += msmMultiThreadedG1( witness , pts.pointsA1, pool )
var rho : G1
withMeasureTime(printTimings,"computing rho (G1 MSM)"):
rho = spec.beta1
rho += s ** spec.delta1
rho += msmMultiThreadedG1( witness , pts.pointsB1, pool )
var pi_b : G2
withMeasureTime(printTimings,"computing pi_B (G2 MSM)"):
pi_b = spec.beta2
pi_b += s ** spec.delta2
pi_b += msmMultiThreadedG2( witness , pts.pointsB2, pool )
var pi_c : G1
withMeasureTime(printTimings,"computing pi_C (2x G1 MSM)"):
pi_c = s ** pi_a
pi_c += r ** rho
pi_c += negFr(r*s) ** spec.delta1
pi_c += msmMultiThreadedG1( qs , pts.pointsH1, pool )
pi_c += msmMultiThreadedG1( zs , pts.pointsC1, pool )
return Proof( curve:"bn128", publicIO:pubIO, pi_a:pi_a, pi_b:pi_b, pi_c:pi_c )
#-------------------------------------------------------------------------------
proc generateProofWithTrivialMask*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings: bool ): Proof =
let mask = Mask( r: zeroFr , s: zeroFr )
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
proc generateProof*( zkey: ZKey, wtns: Witness, pool: Taskpool, printTimings = false ): Proof =
let mask = randomMask()
return generateProofWithMask( zkey, wtns, mask, pool, printTimings )
#-------------------------------------------------------------------------------

180
groth16/prover/shared.nim Normal file
View File

@ -0,0 +1,180 @@
#
# Groth16 prover
#
# WARNING!
# the points H in `.zkey` are *NOT* what normal people would think they are
# See <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
{.push raises:[].}
import system
import taskpools
import constantine/math/arithmetic
import constantine/named/properties_fields
import groth16/bn128
import groth16/math/domain
import groth16/math/poly
import groth16/zkey_types
#import groth16/misc
import groth16/prover/types
#-------------------------------------------------------------------------------
proc randomMask*(): Mask =
# masking coeffs
let r = randFr()
let s = randFr()
let mask = Mask(r: r, s: s)
return mask
#-------------------------------------------------------------------------------
# computes the vectors A*z, B*z, C*z where z is the witness
func buildABC*( zkey: ZKey, witness: seq[Fr[BN254_Snarks]] ): ABC =
let hdr: GrothHeader = zkey.header
let domSize = hdr.domainSize
var valuesAz = newSeq[Fr[BN254_Snarks]](domSize)
var valuesBz = newSeq[Fr[BN254_Snarks]](domSize)
for entry in zkey.coeffs:
case entry.matrix
of MatrixA: valuesAz[entry.row] += entry.coeff * witness[entry.col]
of MatrixB: valuesBz[entry.row] += entry.coeff * witness[entry.col]
else: raise newException(AssertionDefect, "fatal error")
var valuesCz = newSeq[Fr[BN254_Snarks]](domSize)
for i in 0..<domSize:
valuesCz[i] = valuesAz[i] * valuesBz[i]
return ABC( valuesAz:valuesAz, valuesBz:valuesBz, valuesCz:valuesCz )
#-------------------------------------------------------------------------------
# quotient poly
#
# interpolates A,B,C, and computes the quotient polynomial Q = (A*B - C) / Z
func computeQuotientNaive*( abc: ABC ): Poly=
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let polyA : Poly = polyInverseNTT( abc.valuesAz , D )
let polyB : Poly = polyInverseNTT( abc.valuesBz , D )
let polyC : Poly = polyInverseNTT( abc.valuesCz , D )
let polyBig = polyMulFFT( polyA , polyB ) - polyC
var polyQ = polyDivideByVanishing(polyBig, D.domainSize)
polyQ.coeffs.add( zeroFr ) # make it a power of two
return polyQ
#---------------------------------------
# returns [ eta^i * xs[i] | i<-[0..n-1] ]
func multiplyByPowers*( xs: seq[Fr[BN254_Snarks]], eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
let n = xs.len
assert(n >= 1)
var ys = newSeq[Fr[BN254_Snarks]](n)
ys[0] = xs[0]
if n >= 1: ys[1] = eta * xs[1]
var spow = eta
for i in 2..<n:
spow *= eta
ys[i] = spow * xs[i]
return ys
# interpolates a polynomial, shift the variable by `eta`, and compute the shifted values
func shiftEvalDomain*(
values: seq[Fr[BN254_Snarks]],
D: Domain,
eta: Fr[BN254_Snarks] ): seq[Fr[BN254_Snarks]] =
let poly : Poly = polyInverseNTT( values , D )
let cs : seq[Fr[BN254_Snarks]] = poly.coeffs
var ds : seq[Fr[BN254_Snarks]] = multiplyByPowers( cs, eta )
return polyForwardNTT( Poly(coeffs:ds), D )
# Wraps shiftEvalDomain such that it can be called by Taskpool.spawn. The result
# is written to the output parameter. Has an unused return type because
# Taskpool.spawn cannot handle a void return type.
func shiftEvalDomainTask*(
values: seq[Fr[BN254_Snarks]],
D: Domain,
eta: Fr[BN254_Snarks],
output: ptr Isolated[seq[Fr[BN254_Snarks]]]): bool =
output[] = isolate shiftEvalDomain(values, D, eta)
# computes the quotient polynomial Q = (A*B - C) / Z
# by computing the values on a shifted domain, and interpolating the result
# remark: Q has degree `n-2`, so it's enough to use a domain of size n
proc computeQuotientPointwise*( abc: ABC, pool: TaskPool ): Poly =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
# (eta*omega^j)^n - 1 = eta^n - 1
# 1 / [ (eta*omega^j)^n - 1] = 1/(eta^n - 1)
let eta = createDomain(2*n).domainGen
let invZ1 = invFr( smallPowFr(eta,n) - oneFr )
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
discard sync taskA1
discard sync taskB1
discard sync taskC1
let A1 = outputA1.extract()
let B1 = outputB1.extract()
let C1 = outputC1.extract()
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j]*B1[j] - C1[j] ) * invZ1
let Q1 = polyInverseNTT( ys, D )
let cs = multiplyByPowers( Q1.coeffs, invFr(eta) )
return Poly(coeffs: cs)
#---------------------------------------
# Snarkjs does something different, not actually computing the quotient poly
# they can get away with this, because during the trusted setup, they
# replace the points encoding the values `delta^-1 * tau^i * Z(tau)` by
# (shifted) Lagrange bases.
# see <https://geometry.xyz/notebook/the-hidden-little-secret-in-snarkjs>
#
proc computeSnarkjsScalarCoeffs*( abc: ABC, pool: TaskPool ): seq[Fr[BN254_Snarks]] =
let n = abc.valuesAz.len
assert( abc.valuesBz.len == n )
assert( abc.valuesCz.len == n )
let D = createDomain(n)
let eta = createDomain(2*n).domainGen
var outputA1, outputB1, outputC1: Isolated[seq[Fr[BN254_Snarks]]]
var taskA1 = pool.spawn shiftEvalDomainTask( abc.valuesAz, D, eta, addr outputA1 )
var taskB1 = pool.spawn shiftEvalDomainTask( abc.valuesBz, D, eta, addr outputB1 )
var taskC1 = pool.spawn shiftEvalDomainTask( abc.valuesCz, D, eta, addr outputC1 )
discard sync taskA1
discard sync taskB1
discard sync taskC1
let A1 = outputA1.extract()
let B1 = outputB1.extract()
let C1 = outputC1.extract()
var ys : seq[Fr[BN254_Snarks]] = newSeq[Fr[BN254_Snarks]]( n )
for j in 0..<n: ys[j] = ( A1[j] * B1[j] - C1[j] )
return ys
#-------------------------------------------------------------------------------

41
groth16/prover/types.nim Normal file
View File

@ -0,0 +1,41 @@
{.push raises:[].}
import constantine/named/properties_fields
import groth16/bn128
#-------------------------------------------------------------------------------
type
Mask* = object
r*: Fr[BN254_Snarks] # masking coefficients
s*: Fr[BN254_Snarks] # for zero knowledge
#-------------------------------------------------------------------------------
# a Groth16 proof
type
Proof* = object
publicIO* : seq[Fr[BN254_Snarks]]
pi_a* : G1
pi_b* : G2
pi_c* : G1
curve* : string
func isEqualProof*(prf1, prf2: Proof): bool =
return (prf1.pi_a === prf2.pi_a) and
(prf1.pi_b === prf2.pi_b) and
(prf1.pi_c === prf2.pi_c)
#-------------------------------------------------------------------------------
# Az, Bz, Cz column vectors
#
type
ABC* = object
valuesAz* : seq[Fr[BN254_Snarks]]
valuesBz* : seq[Fr[BN254_Snarks]]
valuesCz* : seq[Fr[BN254_Snarks]]
#-------------------------------------------------------------------------------