implement the Nim witness generator (some operations are missing, because... constantine)

This commit is contained in:
Balazs Komuves 2025-03-16 20:09:36 +01:00
parent b832b20442
commit 4651602cb6
No known key found for this signature in database
GPG Key ID: F63B7AEF18435562
9 changed files with 375 additions and 36 deletions

View File

@ -31,8 +31,9 @@ Nim witness generator (to be used with [`nim-groth16`](https://github.com/codex-
- [x] parsing the graph file
- [x] parsing json input
- [ ] generating the witness
- [ ] exporting the witness
- [x] generating the witness
- [x] exporting the witness
- [ ] support the complete set of operations
- [ ] proper error handling
### Testing & correctness

View File

@ -39,6 +39,11 @@ proc writeHeader(s: Stream, witnessLen: int) =
#-------------------------------------------------------------------------------
proc exportFeltSequence*(filepath: string, values: seq[F]) =
var stream = newFileStream(filepath, fmWrite)
for i in 0..<values.len:
stream.writeFelt( values[i] )
proc exportWitness*(filepath: string, witness: seq[F]) =
var stream = newFileStream(filepath, fmWrite)
stream.writeHeader(witness.len)

View File

@ -10,19 +10,30 @@ type
B* = BigInt[254]
F* = Fr[BN254Snarks]
const zeroF* : F = fromHex( Fr[BN254Snarks], "0x00" )
const oneF* : F = fromHex( Fr[BN254Snarks], "0x01" )
const zeroB* : B = fromHex( BigInt[254], "0x00" )
const oneB* : B = fromHex( BigInt[254], "0x01" )
func isZeroF* (x: F ) : bool = bool(isZero(x))
func isEqualF* (x, y: F ) : bool = bool(x == y)
func `===`* (x, y: F ) : bool = isEqualF(x,y)
const zeroF* : F = fromHex( Fr[BN254Snarks], "0x00" )
const oneF* : F = fromHex( Fr[BN254Snarks], "0x01" )
const fieldMask* : B = fromHex( BigInt[254] , "0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", bigEndian )
const fieldPrime* : B = fromHex( BigInt[254] , "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
const halfPrimPlus1* : B = fromHex( BigInt[254] , "0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000001", bigEndian )
func isZeroB* (x: B ) : bool = bool(isZero(x))
func isZeroF* (x: F ) : bool = bool(isZero(x))
func isNonZeroF*(x: F ) : bool = not isZeroF(x)
func isEqualF* (x, y: F ) : bool = bool(x == y)
func `===`* (x, y: F ) : bool = isEqualF(x,y)
const fieldMask* : B = fromHex( BigInt[254] , "0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", bigEndian )
const fieldPrime* : B = fromHex( BigInt[254] , "0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", bigEndian )
const halfPrimePlus1* : B = fromHex( BigInt[254] , "0x183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000001", bigEndian )
const numberOfBitsAsBigInt* : B = fromHex( BigInt[254] , "0xfe", bigEndian )
#-------------------------------------------------------------------------------
func uintToB*(a: uint): B =
var y : B
y.fromUint(a)
return y
func intToF*(a: int): F =
var y : F
y.fromInt(a)
@ -37,7 +48,7 @@ func boolToF*(b: bool): F =
return (if b: oneF else: zeroF)
func fToBool*(x: F): bool =
return not (isZeroF x)
return isNonZeroF(x)
func bigToF*(big: B): F =
var x : F
@ -58,13 +69,17 @@ proc decimalToF*(s: string): F =
# let ok = y.fromDecimal(s) # wtf nim
# return y
func bigToDecimal*(x: B): string =
return toDecimal(x)
func fToDecimal*(x: F): string =
return toDecimal(x)
#-------------------------------------------------------------------------------
func negB* (y: B ): B = ( var z : B = zeroB ; z -= y ; return z )
func negF* (y: F ): F = ( var z : F = zeroF ; z -= y ; return z )
func invF* (y: F ): F = ( var z : F = y ; inv(z) ; return z )
func invF* (y: F ): F = ( var z : F = y ; if isNonZeroF(y): z.inv() ; return z )
func `+`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z += y ; return z )
func `-`*[n](x, y: BigInt[n] ): BigInt[n] = ( var z : BigInt[n] = x ; z -= y ; return z )

View File

@ -1,9 +1,22 @@
import std/sugar
import std/tables
import pkg/constantine/math/io/io_bigints
import ./field
#-------------------------------------------------------------------------------
type
Inputs* = Table[string, seq[F]]
UnoOp* = enum
Neg,
Id
Id,
Lnot,
Bnot
DuoOp* = enum
Mul,
@ -28,14 +41,16 @@ type
Bxor
TresOp* = enum
TresCond
TernCond
BigUInt* = distinct seq[uint8]
BigUInt* = object
bytes*: seq[byte]
InputNode*[T] = object
idx*: T
InputNode* = object
idx*: uint32
ConstantNode* = distinct BigUInt
ConstantNode* = object
bigVal*: BigUInt
UnoOpNode*[T] = object
op*: UnoOp
@ -56,7 +71,7 @@ type
Node*[T] = object
case kind*: NodeKind
of Input: inp*: InputNode[T]
of Input: inp*: InputNode
of Const: kst*: ConstantNode
of Uno: uno*: UnoOpNode[T]
of Duo: duo*: DuoOpNode[T]
@ -66,7 +81,8 @@ type
offset*: uint32
length*: uint32
WitnessMapping* = distinct seq[uint32]
WitnessMapping* = object
mapping*: seq[uint32]
CircuitInputs* = seq[(string, SignalDescription)]
@ -78,3 +94,51 @@ type
nodes*: seq[Node[uint32]]
meta*: GraphMetaData
#-------------------------------------------------------------------------------
func unwrapBigUInt*(x: BigUInt): seq[byte] = x.bytes
func bigFromBigUInt*(big: BigUInt): B =
let bytes = unwrapBigUInt(big)
var buf: seq[byte] = newSeq[byte](32)
for i, x in bytes.pairs():
buf[i] = x
var output : B
unmarshal(output, buf, littleEndian)
return output
func fromBigUInt*(big: BigUInt): F =
return bigToF(bigFromBigUInt(big))
#-------------------------------------------------------------------------------
proc fmapUno[S,T]( fun: (S) -> T , node: UnoOpNode[S]): UnoOpNode[T] =
UnoOpNode[T]( op: node.op, arg1: fun(node.arg1) )
proc fmapDuo[S,T]( fun: (S) -> T , node: DuoOpNode[S]): DuoOpNode[T] =
DuoOpNode[T]( op: node.op, arg1: fun(node.arg1), arg2: fun(node.arg2) )
proc fmapTres[S,T]( fun: (S) -> T , node: TresOpNode[S]): TresOpNode[T] =
TresOpNode[T]( op: node.op, arg1: fun(node.arg1), arg2: fun(node.arg2), arg3: fun(node.arg3) )
proc fmap* [S,T]( fun: (S) -> T , node: Node[S]): Node[T] =
case node.kind:
of Input: Node[T](kind: Input , inp: node.inp )
of Const: Node[T](kind: Const , kst: node.kst )
of Uno: Node[T](kind: Uno , uno: fmapUno( fun, node.uno ) )
of Duo: Node[T](kind: Duo , duo: fmapDuo( fun, node.duo ) )
of Tres: Node[T](kind: Tres , tres: fmapTres(fun, node.tres) )
#-------------------------------------------------------------------------------
proc showNodeUint32*( node: Node[uint32] ): string =
case node.kind:
of Input: "Input idx=" & ($node.inp.idx)
of Const: "Const kst=" & bigToDecimal(bigFromBigUInt(node.kst.bigVal))
of Uno: "Uno op=" & ($node.uno.op ) & " | arg1=" & ($node.uno.arg1 )
of Duo: "Duo op=" & ($node.duo.op ) & " | arg1=" & ($node.duo.arg1 ) & " | arg2=" & ($node.duo.arg2 )
of Tres: "Tres op=" & ($node.tres.op) & " | arg1=" & ($node.tres.arg1) & " | arg2=" & ($node.tres.arg2) & " | arg3=" & ($node.tres.arg3)
proc printNodeUint32*( node: Node[uint32] ) = echo showNodeUint32(node)
#-------------------------------------------------------------------------------

View File

@ -4,12 +4,10 @@ import std/json
import std/tables
import ./field
import ./graph
#-------------------------------------------------------------------------------
type
Inputs* = Table[string, seq[F]]
proc printInputs*(inputs: Inputs) =
for key, list in pairs(inputs):
echo key

View File

@ -7,13 +7,13 @@ import ./graph
#-------------------------------------------------------------------------------
proc parseVarUint64(buf: openArray[byte], p: var int): uint64 =
let x = buf[p]
let x : uint8 = uint8(buf[p])
p += 1
if x < 128:
return uint64(x)
else:
let y = buf.parseVarUint64(p)
return uint64(x - 128) + (y shl 7)
return uint64(bitand(x, 0x7f)) + (y shl 7)
proc parseVarUint32(buf: openArray[byte], p: var int): uint32 =
return uint32( parseVarUint64(buf,p) )
@ -64,7 +64,7 @@ proc parseGenericNode(buf: openArray[byte]): seq[uint32] =
proc parseInputNode(buf: openArray[byte]): Node[uint32] =
# echo "InputNode"
let values = parseGenericNode(buf)
let node: InputNode[uint32] = InputNode[uint32](idx: values[1])
let node: InputNode = InputNode(idx: values[1])
return Node[uint32](kind: Input, inp: node)
proc parseConstantNode(buf: openArray[byte]): Node[uint32] =
@ -84,7 +84,7 @@ proc parseConstantNode(buf: openArray[byte]): Node[uint32] =
for i in 0..<l2: bytes[i] = buf[p+i]
# echo leBytesToHex(bytes)
let node: ConstantNode = ConstantNode(BigUInt(bytes))
let node: ConstantNode = ConstantNode(bigVal: BigUInt(bytes: bytes))
return Node[uint32](kind: Const, kst: node)
proc parseUnoOpNode(buf: openArray[byte]): Node[uint32] =
@ -205,7 +205,7 @@ proc parseMeta(buf: openArray[byte]): GraphMetaData =
let entry = buf.parseCircuitInput(p)
entries.add(entry)
return GraphMetaData(witnessMapping: WitnessMapping(mapping), inputSignals: entries)
return GraphMetaData(witnessMapping: WitnessMapping(mapping: mapping), inputSignals: entries)
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,184 @@
import std/bitops
import std/tables
import pkg/constantine/math/arithmetic
import pkg/constantine/math/io/io_bigints
import ./graph
import ./field
#-------------------------------------------------------------------------------
func bigIntBitwiseComplement(x: B): B =
var bytes1 : seq[byte] = newSeq[byte](32)
var bytes2 : seq[byte] = newSeq[byte](32)
marshal(bytes1, x, littleEndian)
for i in 0..<32:
bytes2[i] = bitxor( bytes1[i] , 0xff )
var output : B
unmarshal(output, bytes2, littleEndian)
return output
func bigIntBitwiseAnd(x, y: B): B =
var bytes1 : seq[byte] = newSeq[byte](32)
var bytes2 : seq[byte] = newSeq[byte](32)
var bytes3 : seq[byte] = newSeq[byte](32)
marshal(bytes1, x, littleEndian)
marshal(bytes2, y, littleEndian)
for i in 0..<32:
bytes3[i] = bitand( bytes1[i] , bytes2[i] )
var output : B
unmarshal(output, bytes3, littleEndian)
return output
func bigIntBitwiseOr(x, y: B): B =
var bytes1 : seq[byte] = newSeq[byte](32)
var bytes2 : seq[byte] = newSeq[byte](32)
var bytes3 : seq[byte] = newSeq[byte](32)
marshal(bytes1, x, littleEndian)
marshal(bytes2, y, littleEndian)
for i in 0..<32:
bytes3[i] = bitor( bytes1[i] , bytes2[i] )
var output : B
unmarshal(output, bytes3, littleEndian)
return output
func bigIntBitwiseXor(x, y: B): B =
var bytes1 : seq[byte] = newSeq[byte](32)
var bytes2 : seq[byte] = newSeq[byte](32)
var bytes3 : seq[byte] = newSeq[byte](32)
marshal(bytes1, x, littleEndian)
marshal(bytes2, y, littleEndian)
for i in 0..<32:
bytes3[i] = bitxor( bytes1[i] , bytes2[i] )
var output : B
unmarshal(output, bytes3, littleEndian)
return output
#-------------------------------------------------------------------------------
func applyFieldMask(big : B) : F =
return bigToF( bigIntBitwiseAnd( fieldMask, big ) )
func fieldComplement(x: F): F =
let big1 = fToBig(x)
let comp = bigIntBitwiseComplement( big1 )
return applyFieldMask(comp)
#-------------------------------------------------------------------------------
func fieldNegateB(x : B): B =
if bool(isZero(x)):
return x
else:
return fieldPrime - x
func smallShiftRightB(x: B, k: int): B =
if (k == 0):
return x
elif (k < 64):
var y : B = x
y.shiftRight(k)
return y
else:
# more constantine limitations...
var y : B = x
y.shiftRight(63)
return smallShiftRightB(y, k-63)
func shiftLeftF*( x: F, kbig: B ) : F
func shiftRightF*( x: F, kbig: B ) : F
func shiftLeftF*( x: F, kbig: B ) : F =
if (isZeroB(kbig)):
return x
elif bool(kbig >= halfPrimePlus1):
return shiftRightF( x , fieldNegateB(kbig) )
elif bool(kbig > numberOfBitsAsBigInt):
return zeroF
else:
let k = int(kbig.limbs[0])
var y = fToBig(x)
for i in 0..<k: # constantine has `shiftRight` but no `shiftLeft`, WTF seriously
let _ = y.double()
return applyFieldMask( y )
func shiftRightF*( x: F, kbig: B ) : F =
if (isZeroB(kbig)):
return x # WTF constantine ?!?!?!
if bool(kbig >= halfPrimePlus1):
return shiftLeftF( x , fieldNegateB(kbig) )
elif bool(kbig > numberOfBitsAsBigInt):
return zeroF
else:
let k = int(kbig.limbs[0])
var y : B = fToBig(x)
return bigToF( smallShiftRightB( y , k ) )
#[
proc shiftSanityCheck*() =
let x: F = intToF(12345678903)
let k: B = uintToB(8)
let nk: B = fieldPrime - k
echo fToDecimal( shiftLeftF( x,k) )
echo fToDecimal( shiftRightF(x,k) )
echo fToDecimal( shiftLeftF( x,nk) )
echo fToDecimal( shiftRightF(x,nk) )
let x2: F = decimalToF("21051029818893485635560069555360071249585393429228441201546820650188605022495") # intToF(12345678903)
let k2: B = uintToB(0)
echo fToDecimal( shiftLeftF( x2,k2) )
echo fToDecimal( shiftRightF(x2,k2) )
let x3: F = decimalToF("21051029818893485635560069555360071249585393429228441201546820650188605022495") # intToF(12345678903)
let k3: B = uintToB(100)
echo fToDecimal( shiftRightF(x3,k3) )
]#
#-------------------------------------------------------------------------------
func evalUnoOpNode(op: UnoOp, x: F): F =
case op:
of Neg: return negF(x)
of Id: return x
of LNot: return boolToF( not (fToBool x) )
of Bnot: return fieldComplement(x)
func evalDuoOpNode(op: DuoOp, x: F, y: F): F =
case op:
of Mul: return x * y
of Div: return if isZeroF(y): zeroF else: x / y
of Add: return x + y
of Sub: return x - y
of Pow: assert( false, "Pow: not yet implemented" )
of Idiv: assert( false, "Idiv: not yet implemented" ) # return bigToF( fToBig(x) div fToBig(y) )
of Mod: assert( false, "Mod: not yet implemented" ) # return bigToF( fToBig(x) mod fToBig(y) )
of Eq: return boolToF( x === y )
of Neq: return boolToF( not (x === y) )
of Lt: return boolToF( bool( fToBig(x) < fToBig(y) ) )
of Gt: return boolToF( bool( fToBig(x) > fToBig(y) ) )
of Leq: return boolToF( bool( fToBig(x) <= fToBig(y) ) )
of Geq: return boolToF( bool( fToBig(x) >= fToBig(y) ) )
of Land: return boolToF( fToBool(x) and fToBool(y) )
of Lor: return boolToF( fToBool(x) or fToBool(y) )
of Shl: return shiftLeftF( x , fToBig(y) )
of Shr: return shiftRightF( x , fToBig(y) )
of Bor: return bigToF( bigIntBitwiseOr( fToBig(x) , fToBig(y) ) )
of Band: return bigToF( bigIntBitwiseAnd( fToBig(x) , fToBig(y) ) )
of Bxor: return bigToF( bigIntBitwiseXor( fToBig(x) , fToBig(y) ) )
func evalTresOpNode(op: TresOp, x: F, y: F, z: F): F =
case op:
of TernCond:
return (if fToBool(x): y else: z)
#-------------------------------------------------------------------------------
func evalNode*( inputs: Table[int,F] , node: Node[F] ): F =
case node.kind:
of Input: return inputs[int(node.inp.idx)]
of Const: return fromBigUInt(node.kst.bigVal)
of Uno: return evalUnoOpNode( node.uno.op , node.uno.arg1 )
of Duo: return evalDuoOpNode( node.duo.op , node.duo.arg1 , node.duo.arg2 )
of Tres: return evalTresOpNode(node.tres.op, node.tres.arg1, node.tres.arg2, node.tres.arg3 )
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,70 @@
import std/tables
import std/strformat
import ./field
import ./graph
import ./semantics
#-------------------------------------------------------------------------------
proc expandInputs*(circuitInputs: seq[(string, SignalDescription)] , inputs: Inputs): Table[int, F] =
var table: Table[int, F]
table[0] = oneF
for (key, desc) in circuitInputs:
let k: int = int(desc.length)
let o: int = int(desc.offset)
assert( inputs.hasKey(key) , "input signal `" & key & "` not present" )
let list: seq[F] = inputs[key]
assert( list.len == k , "input signal `" & key & "` has unexpected size" )
for i in 0..<k:
table[o + i] = list[i]
# echo "input value " & (fToDecimal(list[i])) & " at offset " & ($(o+i))
return table
# note: this contains temporary values which are not present in the actual witness
proc generateFullComputation*(graph: Graph, inputs: Inputs): seq[F] =
let sequence : seq[Node[uint32]] = graph.nodes
let graphMeta : GraphMetaData = graph.meta
let circuitInputs : seq[(string, SignalDescription)] = graphMeta.inputSignals
let inpTable = expandInputs(circuitInputs, inputs)
var output: seq[F] = newSeq[F]( sequence.len )
for (i, node_orig) in sequence.pairs():
let node: Node[F] = fmap[uint32,F]( proc (idx: uint32): F = output[int(idx)] , node_orig )
output[i] = evalNode( inpTable , node )
# echo "index = " & ($i) & " -> " & (showNodeUint32(node_orig))
return output
#[
let node_orig = sequence[i]
let o = 32*i
let hexo = fmt"{o=:x}"
if (i == 11838): # o == 0x4fcc:
echo " i = " & ($i) & " | ofs = " & hexo & " | node = " & ($node_orig)
case node_orig.kind:
of Duo:
echo node_orig.duo.arg1
echo node_orig.duo.arg2
echo fToDecimal(output[int(node_orig.duo.arg1)])
echo fToDecimal(output[int(node_orig.duo.arg2)])
else:
discard
echo "result = " & fToDecimal(output[i])
echo " "
]#
proc generateWitness*(graph: Graph, inputs: Inputs): seq[F] =
let mapping: seq[uint32] = graph.meta.witnessMapping.mapping
let pre_witness = generateFullComputation(graph, inputs)
var output: seq[F] = newSeq[F](mapping.len)
for (j, idx) in mapping.pairs():
output[j] = pre_witness[int(idx)]
# echo " - " & ($j) & " -> " & fToDecimal(output[j]) & " | from " & ($idx)
return output
#-------------------------------------------------------------------------------

View File

@ -1,10 +1,7 @@
import std/sequtils
import circom_witnessgen/field
import circom_witnessgen/graph
import circom_witnessgen/load
import circom_witnessgen/input_json
import circom_witnessgen/witness
import circom_witnessgen/export_wtns
#-------------------------------------------------------------------------------
@ -12,19 +9,24 @@ import circom_witnessgen/export_wtns
const graph_file: string = "../tmp/graph4.bin"
const input_file: string = "../tmp/input4.json"
const wtns_file: string = "../tmp/nim4.wtns"
const comp_file: string = "../tmp/nim4_full.bin"
#-------------------------------------------------------------------------------
when isMainModule:
echo "\nloading in " & input_file
echo "loading in " & input_file
let inp = loadInputJSON(input_file)
# printInputs(inp)
echo "\nloading in " & graph_file
echo "loading in " & graph_file
let gr = loadGraph(graph_file)
# echo $gr
let us: seq[int] = @[1,2,3,4,5,6,7]
let wtns: seq[F] = us.map(intToF);
# echo "generating full computation"
# let comp = generateFullComputation( gr, inp )
# exportFeltSequence(comp_file, comp)
echo "generating witness"
let wtns = generateWitness( gr, inp )
exportWitness(wtns_file, wtns)