diff --git a/bn128.nim b/bn128.nim new file mode 100644 index 0000000..8465578 --- /dev/null +++ b/bn128.nim @@ -0,0 +1,284 @@ + +import std/strutils +import std/streams + +import constantine/math/arithmetic +import constantine/math/io/io_fields +import constantine/math/io/io_bigints +import constantine/math/config/curves +import constantine/math/config/type_ff as tff + +import constantine/math/extension_fields/towers as ext +import constantine/math/elliptic/ec_shortweierstrass_affine as ell + +#------------------------------------------------------------------------------- + +type B* = BigInt[256] +type Fr* = tff.Fr[BN254Snarks] +type Fp* = tff.Fp[BN254Snarks] +type Fp2* = ext.QuadraticExt[Fp] +type G1* = ell.ECP_ShortW_Aff[Fp , ell.G1] +type G2* = ell.ECP_ShortW_Aff[Fp2, ell.G2] + +func mkFp2(i: Fp, u: Fp) : Fp2 = + let c : array[2, Fp] = [i,u] + return ext.QuadraticExt[Fp]( coords: c ) + +func unsafeMkG1( X, Y: Fp ) : G1 = + return ell.ECP_ShortW_Aff[Fp, ell.G1](x: X, y: Y) + +func unsafeMkG2( X, Y: Fp2 ) : G2 = + return ell.ECP_ShortW_Aff[Fp2, ell.G2](x: X, y: Y) + +#------------------------------------------------------------------------------- + +const primeP* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", bigEndian ) +const primeR* : B = fromHex( B, "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian ) + +#------------------------------------------------------------------------------- + +const zeroFp* : Fp = fromHex( Fp, "0x00" ) +const zeroFr* : Fr = fromHex( Fr, "0x00" ) +const oneFp* : Fp = fromHex( Fp, "0x01" ) +const oneFr* : Fr = fromHex( Fr, "0x01" ) + +const zeroFp2* : Fp2 = mkFp2( zeroFp, zeroFp ) +const infG1* : G1 = unsafeMkG1( zeroFp , zeroFp ) +const infG2* : G2 = unsafeMkG2( zeroFp2 , zeroFp2 ) + +#------------------------------------------------------------------------------- + +func intToFp*(a: int) : Fp = + var y : Fp + y.fromInt(a) + return y + +func intToFr*(a: int) : Fr = + var y : Fr + y.fromInt(a) + return y + +#------------------------------------------------------------------------------- + +func toDecimalBig*[n](a : BigInt[n]): string = + var s : string = toDecimal(a) + s = s.strip( leading=true, trailing=false, chars={'0'} ) + if s.len == 0: s="0" + return s + +func toDecimalFp*(a : Fp): string = + var s : string = toDecimal(a) + s = s.strip( leading=true, trailing=false, chars={'0'} ) + if s.len == 0: s="0" + return s + +func toDecimalFr*(a : Fr): string = + var s : string = toDecimal(a) + s = s.strip( leading=true, trailing=false, chars={'0'} ) + if s.len == 0: s="0" + return s + +#------------------------------------------------------------------------------- + +func checkCurveEqG1*( x, y: Fp ) : bool = + var x2 : Fp = x ; square(x2); + var y2 : Fp = y ; square(y2); + var x3 : Fp = x2 ; x3 *= x; + var eq : Fp + eq = x3 + eq += intToFp(3) + eq -= y2 + # echo("eq = ",toDecimalFp(eq)) + return (bool(isZero(eq))) + +# y^2 = x^3 + B +# B = b1 + bu*u +# b1 = 19485874751759354771024239261021720505790618469301721065564631296452457478373 +# b2 = 266929791119991161246907387137283842545076965332900288569378510910307636690 +const twistCoeffB_1 : Fp = fromHex(Fp, "0x2b149d40ceb8aaae81be18991be06ac3b5b4c5e559dbefa33267e6dc24a138e5") +const twistCoeffB_u : Fp = fromHex(Fp, "0x009713b03af0fed4cd2cafadeed8fdf4a74fa084e52d1852e4a2bd0685c315d2") +const twistCoeffB : Fp2 = mkFp2( twistCoeffB_1 , twistCoeffB_u ) + +func checkCurveEqG2*( x, y: Fp2 ) : bool = + var x2 : Fp2 = x ; square(x2); + var y2 : Fp2 = y ; square(y2); + var x3 : Fp2 = x2 ; x3 *= x; + var eq : Fp2 + eq = x3 + eq += twistCoeffB + eq -= y2 + return (bool(isZero(eq))) + +#------------------------------------------------------------------------------- + +func mkG1( x, y: Fp ) : G1 = + if bool(isZero(x)) and bool(isZero(y)): + return infG1 + else: + assert( checkCurveEqG1(x,y) , "mkG1: not a G1 curve point" ) + return unsafeMkG1(x,y) + +func mkG2( x, y: Fp2 ) : G2 = + if bool(isZero(x)) and bool(isZero(y)): + return infG2 + else: + assert( checkCurveEqG2(x,y) , "mkG2: not a G2 curve point" ) + return unsafeMkG2(x,y) + +#------------------------------------------------------------------------------- +# Dealing with Montgomery representation +# + +# R=2^256; this computes 2^256 mod Fp +func calcFpMontR*() : Fp = + var x : Fp = intToFp(2) + for i in 1..8: + square(x) + return x + +# R=2^256; this computes the inverse of (2^256 mod Fp) +func calcFpInvMontR*() : Fp = + var x : Fp = calcFpMontR() + inv(x) + return x + +# R=2^256; this computes 2^256 mod Fr +func calcFrMontR*() : Fr = + var x : Fr = intToFr(2) + for i in 1..8: + square(x) + return x + +# R=2^256; this computes the inverse of (2^256 mod Fp) +func calcFrInvMontR*() : Fr = + var x : Fr = calcFrMontR() + inv(x) + return x + +# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) +const fpMontR* : Fp = fromHex( Fp, "0x0e0a77c19a07df2f666ea36f7879462c0a78eb28f5c70b3dd35d438dc58f0d9d" ) +const fpInvMontR* : Fp = fromHex( Fp, "0x2e67157159e5c639cf63e9cfb74492d9eb2022850278edf8ed84884a014afa37" ) + +# apparently we cannot compute these in compile time for some reason or other... (maybe because `intToFp()`?) +const frMontR* : Fr = fromHex( Fr, "0x0e0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb" ) +const frInvMontR* : Fr = fromHex( Fr, "0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e" ) + +proc checkMontgomeryConstants*() = + assert( bool( fpMontR == calcFpMontR() ) ) + assert( bool( frMontR == calcFrMontR() ) ) + assert( bool( fpInvMontR == calcFpInvMontR() ) ) + assert( bool( frInvMontR == calcFrInvMontR() ) ) + echo("OK") + +#--------------------------------------- + +# the binary files used by the `circom` ecosystem always use little-endian +# Montgomery representation. So when we unmarshal with Constantine, it will +# give the wrong result. Calling this function on the result fixes that. +func fromMontgomeryFp*(x : Fp) : Fp = + var y : Fp = x; + y *= fpInvMontR + return y + +func fromMontgomeryFr*(x : Fr) : Fr = + var y : Fr = x; + y *= frInvMontR + return y + +#------------------------------------------------------------------------------- +# Unmarshalling field elements +# (note: circom binary files use little-endian Montgomery representation) +# + +func unmarshalFp* ( bs: array[32,byte] ) : Fp = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fp + x.fromBig( big ) + return fromMontgomeryFp(x) + +func unmarshalFr* ( bs: array[32,byte] ) : Fr = + var big : BigInt[254] + unmarshal( big, bs, littleEndian ); + var x : Fr + x.fromBig( big ) + return fromMontgomeryFr(x) + +#------------------------------------------------------------------------------- + +func unmarshalFpSeq* ( len: int, bs: openArray[byte] ) : seq[Fp] = + var vals : seq[Fp] = newSeq[Fp]( len ) + var bytes : array[32,byte] + for i in 0.. +# +# format: +# ======= +# +# global header: +# -------------- +# magic : word32 +# version : word32 +# number of sections : word32 +# +# for each section: +# ----------------- +# section id : word32 +# section size : word64 +# section data : number of bytes +# + +#------------------------------------------------------------------------------- + +import std/streams + +import sugar + +import constantine/math/arithmetic except Fp, Fr +import constantine/math/io/io_bigints + +#------------------------------------------------------------------------------- + +type + SectionCallback*[T] = proc (stream: Stream, sectId: int, sectLen: int, user: var T) {.closure.} + +#------------------------------------------------------------------------------- + +func magicWord(magic: string): uint32 = + assert( magic.len == 4, "magicWord: expecting a string of 4 characters" ) + var w : uint32 = 0 + for i in 0..3: + let a = uint32(ord(magic[i])) + w += a shl (8*i) + return w + +#------------------------------------------------------------------------------- + +proc parsePrimeField*( stream: Stream ) : (int, BigInt[256]) = + let n8p = int( stream.readUint32() ) + assert( n8p <= 32 , "at most 256 bit primes are allowed" ) + var p_bytes : array[32, uint8] + discard stream.readData( addr(p_bytes), n8p ) + var p : BigInt[256] + unmarshal(p, p_bytes, littleEndian); + return (n8p, p) + +#------------------------------------------------------------------------------- + +proc readSection[T] ( expectedMagic: string + , expectedVersion: int + , stream: Stream + , user: var T + , callback: SectionCallback[T] + , filt: (int) -> bool ) = + + let sectId = int( stream.readUint32() ) + let sectLen = int( stream.readUint64() ) + let oldpos = stream.getPosition() + if filt(sectId): + # echo("section id = ",sectId ) + # echo("section len = ",sectLen) + callback(stream, sectId, sectLen, user) + stream.setPosition(oldpos + sectLen) + +#------------------------------------------------------------------------------- + +proc parseContainer*[T] ( expectedMagic: string + , expectedVersion: int + , fname: string + , user: var T + , callback: SectionCallback[T] + , filt: (int) -> bool ) = + + let stream = newFileStream(fname, mode = fmRead) + defer: stream.close() + + let magic = stream.readUint32() + assert( magic == magicWord(expectedMagic) , "not a `" & expectedMagic & "` file" ) + let version = stream.readUint32() + assert( version == uint32(expectedVersion) , "not a version " & ($expectedVersion) & " `" & expectedMagic & "` file" ) + let nsections = stream.readUint32() + echo("number of sections = ",nsections) + + for i in 1..nsections: + readSection(expectedMagic, expectedVersion, stream, user, callback, filt) + +#------------------------------------------------------------------------------- + diff --git a/main.nim b/main.nim new file mode 100644 index 0000000..51ed776 --- /dev/null +++ b/main.nim @@ -0,0 +1,19 @@ + +import ./r1cs +import ./zkey +import ./witness +import ./bn128 + +#------------------------------------------------------------------------------- + +proc testMain() = + # checkMontgomeryConstants() + let r1cs_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain.r1cs" + let zkey_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain.zkey" + let wtns_fname : string = "/Users/bkomuves/zk/examples/circom/toy/build/toymain_witness.wtns" + parseWitness( wtns_fname) + parseR1CS( r1cs_fname) + parseZKey( zkey_fname) + +when isMainModule: + testMain() \ No newline at end of file diff --git a/misc.nim b/misc.nim new file mode 100644 index 0000000..f2e4964 --- /dev/null +++ b/misc.nim @@ -0,0 +1,28 @@ + + +#------------------------------------------------------------------------------- + +func floorLog2* (x : int) : int = + var k = -1 + var y = x + while (y > 0): + k += 1 + y = y shr 1 + return k + +func ceilingLog2* (x : int) : int = + if (x==0): + return -1 + else: + return (floorLog2(x-1) + 1) + +# +# import std/math +# +# proc sanityCheckLog2* () = +# for i in 0..18: +# let x = float64(i) +# echo( i," | ",floorLog2(i),"=",floor(log2(x))," | ",ceilingLog2(i),"=",ceil(log2(x)) ) +# + +#------------------------------------------------------------------------------- diff --git a/r1cs.nim b/r1cs.nim new file mode 100644 index 0000000..1bbda57 --- /dev/null +++ b/r1cs.nim @@ -0,0 +1,174 @@ + +# +# parsing the `.r1cs` file computed by `circom` witness code genereators +# +# file format +# =========== +# +# standard iden3 binary container format. +# field elements are in Montgomery representation +# +# sections: +# +# 1: Header +# --------- +# n8r : word32 = how many bytes are a field element in Fr +# r : n8r bytes = the size of the prime field Fr (the scalar field) +# nWires : word32 = number of wires (or witness variables) +# nPubOut : word32 = number of public outputs +# nPubIn : word32 = number of public inputs +# nPrivIn : word32 = number of private inputs +# nLabels : word64 = number of labels (variable names in the circom source code) +# +# 2: Constraints +# -------------- +# nConstr : word32 = number of constraints +# then an array of constraints: +# A : LinComb +# B : LinComb +# C : LinComb +# meaning `A*B=C`, where LinComb looks like this: +# nTerms : word32 = number of terms +# +# where a term looks like this: +# idx : word32 = which witness variable +# coeff : Fr = the coefficient +# +# 3: Wire-to-label mapping +# ------------------------ +# +# +# 4: Custom gates list +# -------------------- +# ... +# ... +# +# 4: Custom gates application +# --------------------------- +# ... +# ... +# + +import std/streams + +import constantine/math/arithmetic except Fp, Fr +import constantine/math/io/io_bigints + +import ./bn128 +import ./container + +#------------------------------------------------------------------------------- + +type + + WitnessConfig* = object + nWires : int # total number of wires (or witness variables), including the constant 1 "variable" + nPubOut : int # number of public outputs + nPubIn : int # number of public inputs + nPrivIn : int # number of private inputs + nLabels : int # number of labels + + Term* = tuple[ wireIdx: int, value: Fr ] + LinComb* = seq[Term] + Constraint* = tuple[ A: LinComb, B: LinComb, C: LinComb ] + + R1CS* = object + r : BigInt[256] + cfg : WitnessConfig + nConstr : int + constraints : seq[Constraint] + wireToLabel : seq[int] + +#------------------------------------------------------------------------------- + +proc parseSection1_header( stream: Stream, user: var R1CS, sectionLen: int ) = + echo "\nparsing r1cs header" + + let (n8r, r) = parsePrimeField( stream ) # size of the scalar field + echo("r = ",toDecimalBig(r)) + user.r = r; + + assert( sectionLen == 4 + n8r + 16 + 8 + 4, "unexpected section length") + + assert( bool(r == primeR) , "expecting the alt-bn128 curve" ) + + var cfg : WitnessConfig + + cfg.nWires = int( stream.readUint32() ) + cfg.nPubOut = int( stream.readUint32() ) + cfg.nPubIn = int( stream.readUint32() ) + cfg.nPrivIn = int( stream.readUint32() ) + cfg.nLabels = int( stream.readUint64() ) + user.cfg = cfg + + let nConstr = int( stream.readUint32() ) + user.nConstr = nConstr + + echo("witness config = ",cfg) + echo("nConstr = ",nConstr) + +#------------------------------------------------------------------------------- + +proc loadTerm( stream: Stream ): Term = + let idx = int( stream.readUint32() ) + let coeff = loadValueFr( stream ) + return (wireIdx:idx, value:coeff) + +proc loadLinComb( stream: Stream ): LinComb = + let nterms = int( stream.readUint32() ) + var terms : seq[Term] + for i in 1..nterms: + terms.add( loadTerm(stream) ) + return terms + +proc loadConstraint( stream: Stream ): Constraint = + let a = loadLinComb( stream ) + let b = loadLinComb( stream ) + let c = loadLinComb( stream ) + return (A:a, B:b, C:c) + +#------------------------------------------------------------------------------- + +proc parseSection2_constraints( stream: Stream, user: var R1CS, sectionLen: int ) = + var constr: seq[Constraint] + var ncoeffsA, ncoeffsB, ncoeffsC: int + for i in 1..(user.nConstr): + let abc = loadConstraint(stream) + constr.add( abc ) + ncoeffsA += abc.A.len + ncoeffsB += abc.B.len + ncoeffsC += abc.C.len + user.constraints = constr + echo( "number of nonzero coefficients in matrix A = ", ncoeffsA ) + echo( "number of nonzero coefficients in matrix B = ", ncoeffsB ) + echo( "number of nonzero coefficients in matrix C = ", ncoeffsC ) + +#------------------------------------------------------------------------------- + +proc parseSection3_wireToLabel( stream: Stream, user: var R1CS, sectionLen: int ) = + assert( sectionLen == 8 * user.cfg.nWires, "unexpected section length") + var labels: seq[int] + for i in 1..(user.cfg.nWires): + let label = int( stream.readUint64() ) + labels.add( label ) + user.wireToLabel = labels + +#------------------------------------------------------------------------------- + +proc r1csCallback( stream: Stream + , sectId: int + , sectLen: int + , user: var R1CS + ) = + case sectId + of 1: parseSection1_header( stream, user, sectLen ) + of 2: parseSection2_constraints( stream, user, sectLen ) + of 3: parseSection3_wireToLabel( stream, user, sectLen ) + else: discard + +proc parseR1CS* (fname: string) = + var r1cs : R1CS + parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id == 1 ) + parseContainer( "r1cs", 1, fname, r1cs, r1csCallback, proc (id: int): bool = id != 1 ) + +#------------------------------------------------------------------------------- diff --git a/witness.nim b/witness.nim new file mode 100644 index 0000000..8b1cec0 --- /dev/null +++ b/witness.nim @@ -0,0 +1,73 @@ + +# +# parsing the `.wtns` file computed by `circom` witness code genereators +# +# Note: the witness values are a flat array of size `nvars`, organized +# in the following order: +# +# [ 1 | public output | public input | private input | secret witness ] +# +# so we have +# +# nvars = 1 + pub + secret = 1 + npubout + npubin + nprivin + nsecret +# + +import std/streams + +import constantine/math/arithmetic except Fp, Fr +import constantine/math/io/io_bigints + +import ./bn128 +import ./container + +#------------------------------------------------------------------------------- + +type + Witness* = object + r : BigInt[256] + nvars : int + values : seq[Fr] + +#------------------------------------------------------------------------------- + +proc parseSection1_header( stream: Stream, user: var Witness, sectionLen: int ) = + echo "\nparsing witness header" + + let (n8r, r) = parsePrimeField( stream ) # size of the scalar field + user.r = r; + echo("r = ",toDecimalBig(r)) + + assert( sectionLen == 4 + n8r + 4 , "unexpected section length") + + assert( n8r == 32 , "expecting 256 bit prime" ) + assert( bool(r == primeR) , "expecting the alt-bn128 curve" ) + + let nvars = int( stream.readUint32() ) + user.nvars = nvars; + + echo("nvars = ",nvars) + +#------------------------------------------------------------------------------- + +proc parseSection2_witness( stream: Stream, user: var Witness, sectionLen: int ) = + + assert( sectionLen == 32 * user.nvars ) + user.values = loadValuesFr( user.nvars, stream ) + +#------------------------------------------------------------------------------- + +proc wtnsCallback(stream: Stream, sectId: int, sectLen: int, user: var Witness) = + #echo(sectId) + case sectId + of 1: parseSection1_header( stream, user, sectLen ) + of 2: parseSection2_witness( stream, user, sectLen ) + else: discard + +proc parseWitness* (fname: string) = + var wtns : Witness + parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id == 1 ) + parseContainer( "wtns", 2, fname, wtns, wtnsCallback, proc (id: int): bool = id != 1 ) + + +#------------------------------------------------------------------------------- + diff --git a/zkey.nim b/zkey.nim new file mode 100644 index 0000000..2b85a1e --- /dev/null +++ b/zkey.nim @@ -0,0 +1,279 @@ + +# +# parsing the `.zkey` file format used by the `circom` ecosystem. +# this contains the prover and verifier keys. +# +# file format +# =========== +# +# standard iden3 binary container format. +# field elements are in Montgomery representation +# +# sections: +# +# 1: Header +# --------- +# prover_type : word32 (Groth16 = 0x0001) +# +# 2: Groth16-specific header +# -------------------------- +# n8p : word32 = how many bytes are a field element in Fp +# p : n8p bytes = the size of the prime field Fp (the base field) +# n8r : word32 = how many bytes are a field element in Fr +# r : n8p bytes = the size of the prime field Fr (the scalar field) +# nvars : word32 = number of witness variables +# npub : word32 = number of public variables (public input/output) +# domSize : word32 = domain size (power of two) +# alpha1 : G1 = [alpha]_1 +# beta1 : G1 = [beta]_1 +# beta2 : G2 = [beta]_2 +# gamma2 : G2 = [gamma]_2 +# delta1 : G1 = [delta]_1 +# delta2 : G2 = [delta_2] +# +# 3: IC +# ----- +# the curve points (corresponding to public input) required by the verifier +# length = 2 * n8p * (npub + 1) = (npub+1) G1 points +# +# 4: Coeffs +# --------- +# ncoeffs : words32 = number of entries +# The nonzero coefficients in the A,B R1CS matrices (that is, sparse representation) +# Remark: since we now that (A*witness).(B*witness) = C.witness +# (12+n8r) bytes per entry: +# m : word32 = which matrix (0=A, 1=B) +# c : word32 = which row, from 0..domSize-1 +# s : word32 = which column, from 0..nvars-1 +# value : Fr (n8r bytes) +# +# for each such entry, we add `value * witness[c]` to the `i`-th element of +# the corresponding column vector (meaning `A*witness` and `B*witness), then +# compute (C*witness)[i] = (A*witness)[i] * (B*witness)[i] +# These 3 column vectors is all we need in the proof generation. +# +# 5: PointsA +# ---------- +# the curve points [A_j(tau)]_1 in G1 +# length = 2 * n8p * nvars = nvars G1 points +# +# 6: PointsB1 +# ----------- +# the curve points [B_j(tau)]_1 in G1 +# length = 2 * n8p * nvars = nvars G1 points +# +# 7: PointsB2 +# ----------- +# the curve points [B_j(tau)]_2 in G2 +# length = 4 * n8p * nvars = nvars G2 points +# +# 8: PointsC +# ---------- +# the curve points [ delta^-1 * ( beta*A_j(tau) + alpha*B_j(tau) + C_j(tau) ) ]_1 in G1 +# length = 2 * n8p * (nvars - npub - 1) = (nvars-npub-1) G1 points +# +# 9: PointsH +# ---------- +# the curve points [delta^-1 * tau^i * Z(tau)] +# length = 2 * n8p * domSize = domSize G1 points +# +# 10: Contributions +# ----- +# ??? (but not required for proving, only for checking that the `.zkey` file is valid) +# + +import std/streams + +import constantine/math/arithmetic except Fp, Fr +import constantine/math/io/io_bigints + +import ./bn128 +import ./container +import ./misc + +#------------------------------------------------------------------------------- + +type + + GrothHeader* = object + p : BigInt[256] + r : BigInt[256] + nvars : int + npubs : int + domainSize : int + logDomainSize : int + + SpecPoints* = object + alpha1 : G1 + beta1 : G1 + beta2 : G2 + gamma2 : G2 + delta1 : G1 + delta2 : G2 + + VerifierPoints* = object + pointsIC : seq[G1] + + ProverPoints* = object + pointsA1 : seq[G1] + pointsB1 : seq[G1] + pointsB2 : seq[G2] + pointsC1 : seq[G1] + pointsH1 : seq[G1] + + MatrixSel* = enum + MatrixA + MatrixB + MatrixC + + Coeff* = object + matrix : MatrixSel + row : int + col : int + coeff : Fr + + ZKey* = object + sectionMask : uint32 + header : GrothHeader + specPoints : SpecPoints + vPoints : VerifierPoints + pPoints : ProverPoints + coeffs : seq[Coeff] + +proc parseSection1_proverType ( stream: Stream, user: var Zkey, sectionLen: int ) = + assert( sectionLen == 4 , "unexpected section length" ) + let proverType = stream.readUint32 + assert( proverType == 1 , "expecting `.zkey` file for a Groth16 prover") + +#------------------------------------------------------------------------------- + +proc parseSection2_GrothHeader( stream: Stream, user: var ZKey, sectionLen: int ) = + echo "\nparsing the Groth16 zkey header" + + let (n8p, p) = parsePrimeField( stream ) # size of the base field + let (n8r, r) = parsePrimeField( stream ) # size of the scalar field + + echo("p = ",toDecimalBig(p)) + echo("r = ",toDecimalBig(r)) + + assert( sectionLen == 2*4 + n8p + n8r + 3*4 + 3*64 + 3*128 , "unexpected section length" ) + + var header : GrothHeader + header.p = p + header.r = r + + assert( n8p == 32 , "expecting 256 bit primes") + assert( n8r == 32 , "expecting 256 bit primes") + + assert( bool(p == primeP) , "expecting the alt-bn128 curve" ) + assert( bool(r == primeR) , "expecting the alt-bn128 curve" ) + + let nvars = int( stream.readUint32() ) + let npubs = int( stream.readUint32() ) + let domsiz = int( stream.readUint32() ) + let log2siz = ceilingLog2(domsiz) + + assert( (1 shl log2siz) == domsiz , "domain size should be a power of two" ) + + echo("nvars = ",nvars) + echo("npubs = ",npubs) + echo("domsiz = ",domsiz) + + header.nvars = nvars + header.npubs = npubs + header.domainSize = domsiz + header.logDomainSize = log2siz + + user.header = header + + # 3 group elements in G1, 3 in G2 + var spec : SpecPoints + spec.alpha1 = loadPointG1( stream ) + spec.beta1 = loadPointG1( stream ) + spec.beta2 = loadPointG2( stream ) + spec.gamma2 = loadPointG2( stream ) + spec.delta1 = loadPointG1( stream ) + spec.delta2 = loadPointG2( stream ) + user.specPoints = spec + +#------------------------------------------------------------------------------- + +proc parseSection4_Coeffs( stream: Stream, user: var ZKey, sectionLen: int ) = + let ncoeffs = int( stream.readUint32() ) + assert( sectionLen == 4 + ncoeffs*(32+12) , "unexpected section length" ) + let nrows = user.header.domainSize + let ncols = user.header.nvars + + var coeffs : seq[Coeff] + for i in 1..ncoeffs: + let m = int( stream.readUint32() ) + let r = int( stream.readUint32() ) + let c = int( stream.readUint32() ) + assert( m >= 0 and m <= 2 , "invalid matrix selector" ) + let sel : MatrixSel = case m + of 0: MatrixA + of 1: MatrixB + of 2: MatrixC + else: raise newException(AssertionDefect, "fatal error") + assert( r >= 0 and r < nrows, "row index out of range" ) + assert( c >= 0 and c < ncols, "column index out of range" ) + let cf = loadValueFr( stream ) + let entry = Coeff( matrix:sel, row:r, col:c, coeff:cf ) + coeffs.add( entry ) + + user.coeffs = coeffs + +#------------------------------------------------------------------------------- + +proc parseSection3_PointsIC( stream: Stream, user: var ZKey, sectionLen: int ) = + let npoints = user.header.npubs + 1 + assert( sectionLen == 64*npoints , "unexpected section length" ) + user.vPoints.pointsIC = loadPointsG1( npoints, stream ) + +proc parseSection5_PointsA1( stream: Stream, user: var ZKey, sectionLen: int ) = + let npoints = user.header.nvars + assert( sectionLen == 64*npoints , "unexpected section length" ) + user.pPoints.pointsA1 = loadPointsG1( npoints, stream ) + +proc parseSection6_PointsB1( stream: Stream, user: var ZKey, sectionLen: int ) = + let npoints = user.header.nvars + assert( sectionLen == 64*npoints , "unexpected section length" ) + user.pPoints.pointsB1 = loadPointsG1( npoints, stream ) + +proc parseSection7_PointsB2( stream: Stream, user: var ZKey, sectionLen: int ) = + let npoints = user.header.nvars + assert( sectionLen == 128*npoints , "unexpected section length" ) + user.pPoints.pointsB2 = loadPointsG2( npoints, stream ) + +proc parseSection8_PointsC1( stream: Stream, user: var ZKey, sectionLen: int ) = + let npoints = user.header.nvars - user.header.npubs - 1 + assert( sectionLen == 64*npoints , "unexpected section length" ) + user.pPoints.pointsC1 = loadPointsG1( npoints, stream ) + +proc parseSection9_PointsH1( stream: Stream, user: var ZKey, sectionLen: int ) = + let npoints = user.header.domainSize + assert( sectionLen == 64*npoints , "unexpected section length" ) + user.pPoints.pointsH1 = loadPointsG1( npoints, stream ) + +#------------------------------------------------------------------------------- + +proc zkeyCallback(stream: Stream, sectId: int, sectLen: int, user: var ZKey) = + case sectId + of 1: parseSection1_proverType( stream, user, sectLen ) + of 2: parseSection2_GrothHeader( stream, user, sectLen ) + of 3: parseSection3_PointsIC( stream, user, sectLen ) + of 4: parseSection4_Coeffs( stream, user, sectLen ) + of 5: parseSection5_PointsA1( stream, user, sectLen ) + of 6: parseSection6_PointsB1( stream, user, sectLen ) + of 7: parseSection7_PointsB2( stream, user, sectLen ) + of 8: parseSection8_PointsC1( stream, user, sectLen ) + of 9: parseSection9_PointsH1( stream, user, sectLen ) + else: discard + +proc parseZKey* (fname: string) = + var zkey : ZKey + parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 1 ) + parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id == 2 ) + parseContainer( "zkey", 1, fname, zkey, zkeyCallback, proc (id: int): bool = id >= 3 ) + +#-------------------------------------------------------------------------------