circom files parsing (r1cs, wtns and zkey files)

This commit is contained in:
Balazs Komuves 2023-11-09 15:20:49 +01:00
parent e6eb074c0b
commit ba04191b72
7 changed files with 955 additions and 0 deletions

284
bn128.nim Normal file
View File

@ -0,0 +1,284 @@
import std/strutils
import std/streams
import constantine/math/arithmetic
import constantine/math/io/io_fields
import constantine/math/io/io_bigints
import constantine/math/config/curves
import constantine/math/config/type_ff as tff
import constantine/math/extension_fields/towers as ext
import constantine/math/elliptic/ec_shortweierstrass_affine as ell
#-------------------------------------------------------------------------------
type B* = BigInt[256]
type Fr* = tff.Fr[BN254Snarks]
type Fp* = tff.Fp[BN254Snarks]
type Fp2* = ext.QuadraticExt[Fp]
type G1* = ell.ECP_ShortW_Aff[Fp , ell.G1]
type G2* = ell.ECP_ShortW_Aff[Fp2, ell.G2]
func mkFp2(i: Fp, u: Fp) : Fp2 =
let c : array[2, Fp] = [i,u]
return ext.QuadraticExt[Fp]( coords: c )
func unsafeMkG1( X, Y: Fp ) : G1 =
return ell.ECP_ShortW_Aff[Fp, ell.G1](x: X, y: Y)
func unsafeMkG2( X, Y: Fp2 ) : G2 =
return ell.ECP_ShortW_Aff[Fp2, ell.G2](x: X, y: Y)
#-------------------------------------------------------------------------------
const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian )
const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
#-------------------------------------------------------------------------------
const zeroFp* : Fp = fromHex( Fp, "0x00" )
const zeroFr* : Fr = fromHex( Fr, "0x00" )
const oneFp* : Fp = fromHex( Fp, "0x01" )
const oneFr* : Fr = fromHex( Fr, "0x01" )
const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp )
const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp )
const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 )
#-------------------------------------------------------------------------------
func intToFp*(a: int) : Fp =
var y : Fp
y.fromInt(a)
return y
func intToFr*(a: int) : Fr =
var y : Fr
y.fromInt(a)
return y
#-------------------------------------------------------------------------------
func toDecimalBig*[n](a : BigInt[n]): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
func toDecimalFp*(a : Fp): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
func toDecimalFr*(a : Fr): string =
var s : string = toDecimal(a)
s = s.strip( leading=true, trailing=false, chars={'0'} )
if s.len == 0: s="0"
return s
#-------------------------------------------------------------------------------
func checkCurveEqG1*( x, y: Fp ) : bool =
var x2 : Fp = x ; square(x2);
var y2 : Fp = y ; square(y2);
var x3 : Fp = x2 ; x3 *= x;
var eq : Fp
eq = x3
eq += intToFp(3)
eq -= y2
# echo("eq = ",toDecimalFp(eq))
return (bool(isZero(eq)))
# y^2 = x^3 + B
# B = b1 + bu*u
# b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373
# b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690
const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5")
const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2")
const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u )
func checkCurveEqG2*( x, y: Fp2 ) : bool =
var x2 : Fp2 = x ; square(x2);
var y2 : Fp2 = y ; square(y2);
var x3 : Fp2 = x2 ; x3 *= x;
var eq : Fp2
eq = x3
eq += twistCoeffB
eq -= y2
return (bool(isZero(eq)))
#-------------------------------------------------------------------------------
func mkG1( x, y: Fp ) : G1 =
if bool(isZero(x)) and bool(isZero(y)):
return infG1
else:
assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" )
return unsafeMkG1(x,y)
func mkG2( x, y: Fp2 ) : G2 =
if bool(isZero(x)) and bool(isZero(y)):
return infG2
else:
assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" )
return unsafeMkG2(x,y)
#-------------------------------------------------------------------------------
# Dealing with Montgomery representation
#
# R=2^256; this computes 2^256 mod Fp
func calcFpMontR*() : Fp =
var x : Fp = intToFp(2)
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFpInvMontR*() : Fp =
var x : Fp = calcFpMontR()
inv(x)
return x
# R=2^256; this computes 2^256 mod Fr
func calcFrMontR*() : Fr =
var x : Fr = intToFr(2)
for i in 1..8:
square(x)
return x
# R=2^256; this computes the inverse of (2^256 mod Fp)
func calcFrInvMontR*() : Fr =
var x : Fr = calcFrMontR()
inv(x)
return x
# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?)
const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" )
const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" )
# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?)
const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" )
const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" )
proc checkMontgomeryConstants*() =
assert( bool( fpMontR == calcFpMontR() ) )
assert( bool( frMontR == calcFrMontR() ) )
assert( bool( fpInvMontR == calcFpInvMontR() ) )
assert( bool( frInvMontR == calcFrInvMontR() ) )
echo("OK")
#---------------------------------------
# the binary files used by the `circom` ecosystem always use little-endian
# Montgomery representation. So when we unmarshal with Constantine, it will
# give the wrong result. Calling this function on the result fixes that.
func fromMontgomeryFp*(x : Fp) : Fp =
var y : Fp = x;
y *= fpInvMontR
return y
func fromMontgomeryFr*(x : Fr) : Fr =
var y : Fr = x;
y *= frInvMontR
return y
#-------------------------------------------------------------------------------
# Unmarshalling field elements
# (note: circom binary files use little-endian Montgomery representation)
#
func unmarshalFp* ( bs: array[32,byte] ) : Fp =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fp
x.fromBig( big )
return fromMontgomeryFp(x)
func unmarshalFr* ( bs: array[32,byte] ) : Fr =
var big : BigInt[254]
unmarshal( big, bs, littleEndian );
var x : Fr
x.fromBig( big )
return fromMontgomeryFr(x)
#-------------------------------------------------------------------------------
func unmarshalFpSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] =
var vals : seq[Fp] = newSeq[Fp]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFp( bytes )
return vals
func unmarshalFrSeq* ( len: int, bs: openArray[byte] ) : seq[Fr] =
var vals : seq[Fr] = newSeq[Fr]( len )
var bytes : array[32,byte]
for i in 0..<len:
copyMem( addr(bytes) , unsafeAddr(bs[32*i]) , 32 )
vals[i] = unmarshalFr( bytes )
return vals
#-------------------------------------------------------------------------------
proc loadValueFp*( stream: Stream ) : Fp =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFp(bytes)
proc loadValueFr*( stream: Stream ) : Fr =
var bytes : array[32,byte]
let n = stream.readData( addr(bytes), 32 )
assert( n == 32 )
return unmarshalFr(bytes)
proc loadValueFp2*( stream: Stream ) : Fp2 =
let i = loadValueFp( stream )
let u = loadValueFp( stream )
return mkFp2(i,u)
#---------------------------------------
proc loadValuesFp*( len: int, stream: Stream ) : seq[Fp] =
var values : seq[Fp]
for i in 1..len:
values.add( loadValueFp(stream) )
return values
proc loadValuesFr*( len: int, stream: Stream ) : seq[Fr] =
var values : seq[Fr]
for i in 1..len:
values.add( loadValueFr(stream) )
return values
#-------------------------------------------------------------------------------
proc loadPointG1*( stream: Stream ) : G1 =
let x = loadValueFp( stream )
let y = loadValueFp( stream )
return mkG1(x,y)
proc loadPointG2*( stream: Stream ) : G2 =
let x = loadValueFp2( stream )
let y = loadValueFp2( stream )
return mkG2(x,y)
#---------------------------------------
proc loadPointsG1*( len: int, stream: Stream ) : seq[G1] =
var points : seq[G1]
for i in 1..len:
points.add( loadPointG1(stream) )
return points
proc loadPointsG2*( len: int, stream: Stream ) : seq[G2] =
var points : seq[G2]
for i in 1..len:
points.add( loadPointG2(stream) )
return points
#-------------------------------------------------------------------------------

98
container.nim Normal file
View File

@ -0,0 +1,98 @@
#
# the container format used by `circom` / `snarkjs`
# see <https://github.com/iden3/binfileutils>
#
# format:
# =======
#
# global header:
# --------------
# magic : word32
# version : word32
# number of sections : word32
#
# for each section:
# -----------------
# section id : word32
# section size : word64
# section data : <section_size> number of bytes
#
#-------------------------------------------------------------------------------
import std/streams
import sugar
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
#-------------------------------------------------------------------------------
type
SectionCallback*[T] = proc (stream: Stream, sectId: int, sectLen: int, user: var T) {.closure.}
#-------------------------------------------------------------------------------
func magicWord(magic: string): uint32 =
assert( magic.len == 4, "magicWord: expecting a string of 4 characters" )
var w : uint32 = 0
for i in 0..3:
let a = uint32(ord(magic[i]))
w += a shl (8*i)
return w
#-------------------------------------------------------------------------------
proc parsePrimeField*( stream: Stream ) : (int, BigInt[256]) =
let n8p = int( stream.readUint32() )
assert( n8p <= 32 , "at most 256 bit primes are allowed" )
var p_bytes : array[32, uint8]
discard stream.readData( addr(p_bytes), n8p )
var p : BigInt[256]
unmarshal(p, p_bytes, littleEndian);
return (n8p, p)
#-------------------------------------------------------------------------------
proc readSection[T] ( expectedMagic: string
, expectedVersion: int
, stream: Stream
, user: var T
, callback: SectionCallback[T]
, filt: (int) -> bool ) =
let sectId = int( stream.readUint32() )
let sectLen = int( stream.readUint64() )
let oldpos = stream.getPosition()
if filt(sectId):
# echo("section id = ",sectId )
# echo("section len = ",sectLen)
callback(stream, sectId, sectLen, user)
stream.setPosition(oldpos + sectLen)
#-------------------------------------------------------------------------------
proc parseContainer*[T] ( expectedMagic: string
, expectedVersion: int
, fname: string
, user: var T
, callback: SectionCallback[T]
, filt: (int) -> bool ) =
let stream = newFileStream(fname, mode = fmRead)
defer: stream.close()
let magic = stream.readUint32()
assert( magic == magicWord(expectedMagic) , "not a `" & expectedMagic & "` file" )
let version = stream.readUint32()
assert( version == uint32(expectedVersion) , "not a version " & ($expectedVersion) & " `" & expectedMagic & "` file" )
let nsections = stream.readUint32()
echo("number of sections = ",nsections)
for i in 1..nsections:
readSection(expectedMagic, expectedVersion, stream, user, callback, filt)
#-------------------------------------------------------------------------------

19
main.nim Normal file
View File

@ -0,0 +1,19 @@
import ./r1cs
import ./zkey
import ./witness
import ./bn128
#-------------------------------------------------------------------------------
proc testMain() =
# checkMontgomeryConstants()
let r1cs_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain.r1cs"
let zkey_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain.zkey"
let wtns_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain_witness.wtns"
parseWitness( wtns_fname)
parseR1CS( r1cs_fname)
parseZKey( zkey_fname)
when isMainModule:
testMain()

28
misc.nim Normal file
View File

@ -0,0 +1,28 @@
#-------------------------------------------------------------------------------
func floorLog2* (x : int) : int =
var k = -1
var y = x
while (y > 0):
k += 1
y = y shr 1
return k
func ceilingLog2* (x : int) : int =
if (x==0):
return -1
else:
return (floorLog2(x-1) + 1)
#
# import std/math
#
# proc sanityCheckLog2* () =
# for i in 0..18:
# let x = float64(i)
# echo( i," | ",floorLog2(i),"=",floor(log2(x))," | ",ceilingLog2(i),"=",ceil(log2(x)) )
#
#-------------------------------------------------------------------------------

174
r1cs.nim Normal file
View File

@ -0,0 +1,174 @@
#
# parsing the `.r1cs` file computed by `circom` witness code genereators
#
# file format
# ===========
#
# standard iden3 binary container format.
# field elements are in Montgomery representation
#
# sections:
#
# 1: Header
# ---------
# n8r : word32 = how many bytes are a field element in Fr
# r : n8r bytes = the size of the prime field Fr (the scalar field)
# nWires : word32 = number of wires (or witness variables)
# nPubOut : word32 = number of public outputs
# nPubIn : word32 = number of public inputs
# nPrivIn : word32 = number of private inputs
# nLabels : word64 = number of labels (variable names in the circom source code)
#
# 2: Constraints
# --------------
# nConstr : word32 = number of constraints
# then an array of constraints:
# A : LinComb
# B : LinComb
# C : LinComb
# meaning `A*B=C`, where LinComb looks like this:
# nTerms : word32 = number of terms
# <an array of terms>
# where a term looks like this:
# idx : word32 = which witness variable
# coeff : Fr = the coefficient
#
# 3: Wire-to-label mapping
# ------------------------
# <an array of `nWires` many 64 bit words>
#
# 4: Custom gates list
# --------------------
# ...
# ...
#
# 4: Custom gates application
# ---------------------------
# ...
# ...
#
import std/streams
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
import ./bn128
import ./container
#-------------------------------------------------------------------------------
type
WitnessConfig* = object
nWires : int # total number of wires (or witness variables), including the constant 1 "variable"
nPubOut : int # number of public outputs
nPubIn : int # number of public inputs
nPrivIn : int # number of private inputs
nLabels : int # number of labels
Term* = tuple[ wireIdx: int, value: Fr ]
LinComb* = seq[Term]
Constraint* = tuple[ A: LinComb, B: LinComb, C: LinComb ]
R1CS* = object
r : BigInt[256]
cfg : WitnessConfig
nConstr : int
constraints : seq[Constraint]
wireToLabel : seq[int]
#-------------------------------------------------------------------------------
proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) =
echo "\nparsing r1cs header"
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
echo("r = ",toDecimalBig(r))
user.r = r;
assert( sectionLen == 4 + n8r + 16 + 8 + 4, "unexpected section length")
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
var cfg : WitnessConfig
cfg.nWires = int( stream.readUint32() )
cfg.nPubOut = int( stream.readUint32() )
cfg.nPubIn = int( stream.readUint32() )
cfg.nPrivIn = int( stream.readUint32() )
cfg.nLabels = int( stream.readUint64() )
user.cfg = cfg
let nConstr = int( stream.readUint32() )
user.nConstr = nConstr
echo("witness config = ",cfg)
echo("nConstr = ",nConstr)
#-------------------------------------------------------------------------------
proc loadTerm( stream: Stream ): Term =
let idx = int( stream.readUint32() )
let coeff = loadValueFr( stream )
return (wireIdx:idx, value:coeff)
proc loadLinComb( stream: Stream ): LinComb =
let nterms = int( stream.readUint32() )
var terms : seq[Term]
for i in 1..nterms:
terms.add( loadTerm(stream) )
return terms
proc loadConstraint( stream: Stream ): Constraint =
let a = loadLinComb( stream )
let b = loadLinComb( stream )
let c = loadLinComb( stream )
return (A:a, B:b, C:c)
#-------------------------------------------------------------------------------
proc parseSection2_constraints( stream: Stream, user: var R1CS, sectionLen: int ) =
var constr: seq[Constraint]
var ncoeffsA, ncoeffsB, ncoeffsC: int
for i in 1..(user.nConstr):
let abc = loadConstraint(stream)
constr.add( abc )
ncoeffsA += abc.A.len
ncoeffsB += abc.B.len
ncoeffsC += abc.C.len
user.constraints = constr
echo( "number of nonzero coefficients in matrix A = ", ncoeffsA )
echo( "number of nonzero coefficients in matrix B = ", ncoeffsB )
echo( "number of nonzero coefficients in matrix C = ", ncoeffsC )
#-------------------------------------------------------------------------------
proc parseSection3_wireToLabel( stream: Stream, user: var R1CS, sectionLen: int ) =
assert( sectionLen == 8 * user.cfg.nWires, "unexpected section length")
var labels: seq[int]
for i in 1..(user.cfg.nWires):
let label = int( stream.readUint64() )
labels.add( label )
user.wireToLabel = labels
#-------------------------------------------------------------------------------
proc r1csCallback( stream: Stream
, sectId: int
, sectLen: int
, user: var R1CS
) =
case sectId
of 1: parseSection1_header( stream, user, sectLen )
of 2: parseSection2_constraints( stream, user, sectLen )
of 3: parseSection3_wireToLabel( stream, user, sectLen )
else: discard
proc parseR1CS* (fname: string) =
var r1cs : R1CS
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id == 1 )
parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id != 1 )
#-------------------------------------------------------------------------------

73
witness.nim Normal file
View File

@ -0,0 +1,73 @@
#
# parsing the `.wtns` file computed by `circom` witness code genereators
#
# Note: the witness values are a flat array of size `nvars`, organized
# in the following order:
#
# [ 1 | public output | public input | private input | secret witness ]
#
# so we have
#
# nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret
#
import std/streams
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
import ./bn128
import ./container
#-------------------------------------------------------------------------------
type
Witness* = object
r : BigInt[256]
nvars : int
values : seq[Fr]
#-------------------------------------------------------------------------------
proc parseSection1_header( stream: Stream, user: var Witness, sectionLen: int ) =
echo "\nparsing witness header"
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
user.r = r;
echo("r = ",toDecimalBig(r))
assert( sectionLen == 4 + n8r + 4 , "unexpected section length")
assert( n8r == 32 , "expecting 256 bit prime" )
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
let nvars = int( stream.readUint32() )
user.nvars = nvars;
echo("nvars = ",nvars)
#-------------------------------------------------------------------------------
proc parseSection2_witness( stream: Stream, user: var Witness, sectionLen: int ) =
assert( sectionLen == 32 * user.nvars )
user.values = loadValuesFr( user.nvars, stream )
#-------------------------------------------------------------------------------
proc wtnsCallback(stream: Stream, sectId: int, sectLen: int, user: var Witness) =
#echo(sectId)
case sectId
of 1: parseSection1_header( stream, user, sectLen )
of 2: parseSection2_witness( stream, user, sectLen )
else: discard
proc parseWitness* (fname: string) =
var wtns : Witness
parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id == 1 )
parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id != 1 )
#-------------------------------------------------------------------------------

279
zkey.nim Normal file
View File

@ -0,0 +1,279 @@
#
# parsing the `.zkey` file format used by the `circom` ecosystem.
# this contains the prover and verifier keys.
#
# file format
# ===========
#
# standard iden3 binary container format.
# field elements are in Montgomery representation
#
# sections:
#
# 1: Header
# ---------
# prover_type : word32 (Groth16 = 0x0001)
#
# 2: Groth16-specific header
# --------------------------
# n8p : word32 = how many bytes are a field element in Fp
# p : n8p bytes = the size of the prime field Fp (the base field)
# n8r : word32 = how many bytes are a field element in Fr
# r : n8p bytes = the size of the prime field Fr (the scalar field)
# nvars : word32 = number of witness variables
# npub : word32 = number of public variables (public input/output)
# domSize : word32 = domain size (power of two)
# alpha1 : G1 = [alpha]_1
# beta1 : G1 = [beta]_1
# beta2 : G2 = [beta]_2
# gamma2 : G2 = [gamma]_2
# delta1 : G1 = [delta]_1
# delta2 : G2 = [delta_2]
#
# 3: IC
# -----
# the curve points (corresponding to public input) required by the verifier
# length = 2 * n8p * (npub + 1) = (npub+1) G1 points
#
# 4: Coeffs
# ---------
# ncoeffs : words32 = number of entries
# The nonzero coefficients in the A,B R1CS matrices (that is, sparse representation)
# Remark: since we now that (A*witness).(B*witness) = C.witness
# (12+n8r) bytes per entry:
# m : word32 = which matrix (0=A, 1=B)
# c : word32 = which row, from 0..domSize-1
# s : word32 = which column, from 0..nvars-1
# value : Fr (n8r bytes)
#
# for each such entry, we add `value * witness[c]` to the `i`-th element of
# the corresponding column vector (meaning `A*witness` and `B*witness), then
# compute (C*witness)[i] = (A*witness)[i] * (B*witness)[i]
# These 3 column vectors is all we need in the proof generation.
#
# 5: PointsA
# ----------
# the curve points [A_j(tau)]_1 in G1
# length = 2 * n8p * nvars = nvars G1 points
#
# 6: PointsB1
# -----------
# the curve points [B_j(tau)]_1 in G1
# length = 2 * n8p * nvars = nvars G1 points
#
# 7: PointsB2
# -----------
# the curve points [B_j(tau)]_2 in G2
# length = 4 * n8p * nvars = nvars G2 points
#
# 8: PointsC
# ----------
# the curve points [ delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) ]_1 in G1
# length = 2 * n8p * (nvars - npub - 1) = (nvars-npub-1) G1 points
#
# 9: PointsH
# ----------
# the curve points [delta^-1 * tau^i * Z(tau)]
# length = 2 * n8p * domSize = domSize G1 points
#
# 10: Contributions
# -----
# ??? (but not required for proving, only for checking that the `.zkey` file is valid)
#
import std/streams
import constantine/math/arithmetic except Fp, Fr
import constantine/math/io/io_bigints
import ./bn128
import ./container
import ./misc
#-------------------------------------------------------------------------------
type
GrothHeader* = object
p : BigInt[256]
r : BigInt[256]
nvars : int
npubs : int
domainSize : int
logDomainSize : int
SpecPoints* = object
alpha1 : G1
beta1 : G1
beta2 : G2
gamma2 : G2
delta1 : G1
delta2 : G2
VerifierPoints* = object
pointsIC : seq[G1]
ProverPoints* = object
pointsA1 : seq[G1]
pointsB1 : seq[G1]
pointsB2 : seq[G2]
pointsC1 : seq[G1]
pointsH1 : seq[G1]
MatrixSel* = enum
MatrixA
MatrixB
MatrixC
Coeff* = object
matrix : MatrixSel
row : int
col : int
coeff : Fr
ZKey* = object
sectionMask : uint32
header : GrothHeader
specPoints : SpecPoints
vPoints : VerifierPoints
pPoints : ProverPoints
coeffs : seq[Coeff]
proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int ) =
assert( sectionLen == 4 , "unexpected section length" )
let proverType = stream.readUint32
assert( proverType == 1 , "expecting `.zkey` file for a Groth16 prover")
#-------------------------------------------------------------------------------
proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int ) =
echo "\nparsing the Groth16 zkey header"
let (n8p, p) = parsePrimeField( stream ) # size of the base field
let (n8r, r) = parsePrimeField( stream ) # size of the scalar field
echo("p = ",toDecimalBig(p))
echo("r = ",toDecimalBig(r))
assert( sectionLen == 2*4 + n8p + n8r + 3*4 + 3*64 + 3*128 , "unexpected section length" )
var header : GrothHeader
header.p = p
header.r = r
assert( n8p == 32 , "expecting 256 bit primes")
assert( n8r == 32 , "expecting 256 bit primes")
assert( bool(p == primeP) , "expecting the alt-bn128 curve" )
assert( bool(r == primeR) , "expecting the alt-bn128 curve" )
let nvars = int( stream.readUint32() )
let npubs = int( stream.readUint32() )
let domsiz = int( stream.readUint32() )
let log2siz = ceilingLog2(domsiz)
assert( (1 shl log2siz) == domsiz , "domain size should be a power of two" )
echo("nvars = ",nvars)
echo("npubs = ",npubs)
echo("domsiz = ",domsiz)
header.nvars = nvars
header.npubs = npubs
header.domainSize = domsiz
header.logDomainSize = log2siz
user.header = header
# 3 group elements in G1, 3 in G2
var spec : SpecPoints
spec.alpha1 = loadPointG1( stream )
spec.beta1 = loadPointG1( stream )
spec.beta2 = loadPointG2( stream )
spec.gamma2 = loadPointG2( stream )
spec.delta1 = loadPointG1( stream )
spec.delta2 = loadPointG2( stream )
user.specPoints = spec
#-------------------------------------------------------------------------------
proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) =
let ncoeffs = int( stream.readUint32() )
assert( sectionLen == 4 + ncoeffs*(32+12) , "unexpected section length" )
let nrows = user.header.domainSize
let ncols = user.header.nvars
var coeffs : seq[Coeff]
for i in 1..ncoeffs:
let m = int( stream.readUint32() )
let r = int( stream.readUint32() )
let c = int( stream.readUint32() )
assert( m >= 0 and m <= 2 , "invalid matrix selector" )
let sel : MatrixSel = case m
of 0: MatrixA
of 1: MatrixB
of 2: MatrixC
else: raise newException(AssertionDefect, "fatal error")
assert( r >= 0 and r < nrows, "row index out of range" )
assert( c >= 0 and c < ncols, "column index out of range" )
let cf = loadValueFr( stream )
let entry = Coeff( matrix:sel, row:r, col:c, coeff:cf )
coeffs.add( entry )
user.coeffs = coeffs
#-------------------------------------------------------------------------------
proc parseSection3_PointsIC( stream: Stream, user: var ZKey, sectionLen: int ) =
let npoints = user.header.npubs + 1
assert( sectionLen == 64*npoints , "unexpected section length" )
user.vPoints.pointsIC = loadPointsG1( npoints, stream )
proc parseSection5_PointsA1( stream: Stream, user: var ZKey, sectionLen: int ) =
let npoints = user.header.nvars
assert( sectionLen == 64*npoints , "unexpected section length" )
user.pPoints.pointsA1 = loadPointsG1( npoints, stream )
proc parseSection6_PointsB1( stream: Stream, user: var ZKey, sectionLen: int ) =
let npoints = user.header.nvars
assert( sectionLen == 64*npoints , "unexpected section length" )
user.pPoints.pointsB1 = loadPointsG1( npoints, stream )
proc parseSection7_PointsB2( stream: Stream, user: var ZKey, sectionLen: int ) =
let npoints = user.header.nvars
assert( sectionLen == 128*npoints , "unexpected section length" )
user.pPoints.pointsB2 = loadPointsG2( npoints, stream )
proc parseSection8_PointsC1( stream: Stream, user: var ZKey, sectionLen: int ) =
let npoints = user.header.nvars - user.header.npubs - 1
assert( sectionLen == 64*npoints , "unexpected section length" )
user.pPoints.pointsC1 = loadPointsG1( npoints, stream )
proc parseSection9_PointsH1( stream: Stream, user: var ZKey, sectionLen: int ) =
let npoints = user.header.domainSize
assert( sectionLen == 64*npoints , "unexpected section length" )
user.pPoints.pointsH1 = loadPointsG1( npoints, stream )
#-------------------------------------------------------------------------------
proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) =
case sectId
of 1: parseSection1_proverType( stream, user, sectLen )
of 2: parseSection2_GrothHeader( stream, user, sectLen )
of 3: parseSection3_PointsIC( stream, user, sectLen )
of 4: parseSection4_Coeffs( stream, user, sectLen )
of 5: parseSection5_PointsA1( stream, user, sectLen )
of 6: parseSection6_PointsB1( stream, user, sectLen )
of 7: parseSection7_PointsB2( stream, user, sectLen )
of 8: parseSection8_PointsC1( stream, user, sectLen )
of 9: parseSection9_PointsH1( stream, user, sectLen )
else: discard
proc parseZKey* (fname: string) =
var zkey : ZKey
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 1 )
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 2 )
parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id >= 3 )
#-------------------------------------------------------------------------------