diff --git a/README.md b/README.md index 66e2a0f..3d23911 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,9 @@ Nim witness generator (to be used with [`nim-groth16`](https://github.com/codex- - [x] parsing the graph file - [x] parsing json input -- [ ] generating the witness -- [ ] exporting the witness +- [x] generating the witness +- [x] exporting the witness +- [ ] support the complete set of operations - [ ] proper error handling ### Testing & correctness diff --git a/nim/circom_witnessgen/export_wtns.nim b/nim/circom_witnessgen/export_wtns.nim index ebe7512..42fc3f3 100644 --- a/nim/circom_witnessgen/export_wtns.nim +++ b/nim/circom_witnessgen/export_wtns.nim @@ -39,6 +39,11 @@ proc writeHeader(s: Stream, witnessLen: int) = #------------------------------------------------------------------------------- +proc exportFeltSequence*(filepath: string, values: seq[F]) = + var stream = newFileStream(filepath, fmWrite) + for i in 0.. T , node: UnoOpNode[S]): UnoOpNode[T] = + UnoOpNode[T]( op: node.op, arg1: fun(node.arg1) ) + +proc fmapDuo[S,T]( fun: (S) -> T , node: DuoOpNode[S]): DuoOpNode[T] = + DuoOpNode[T]( op: node.op, arg1: fun(node.arg1), arg2: fun(node.arg2) ) + +proc fmapTres[S,T]( fun: (S) -> T , node: TresOpNode[S]): TresOpNode[T] = + TresOpNode[T]( op: node.op, arg1: fun(node.arg1), arg2: fun(node.arg2), arg3: fun(node.arg3) ) + +proc fmap* [S,T]( fun: (S) -> T , node: Node[S]): Node[T] = + case node.kind: + of Input: Node[T](kind: Input , inp: node.inp ) + of Const: Node[T](kind: Const , kst: node.kst ) + of Uno: Node[T](kind: Uno , uno: fmapUno( fun, node.uno ) ) + of Duo: Node[T](kind: Duo , duo: fmapDuo( fun, node.duo ) ) + of Tres: Node[T](kind: Tres , tres: fmapTres(fun, node.tres) ) + +#------------------------------------------------------------------------------- + +proc showNodeUint32*( node: Node[uint32] ): string = + case node.kind: + of Input: "Input idx=" & ($node.inp.idx) + of Const: "Const kst=" & bigToDecimal(bigFromBigUInt(node.kst.bigVal)) + of Uno: "Uno op=" & ($node.uno.op ) & " | arg1=" & ($node.uno.arg1 ) + of Duo: "Duo op=" & ($node.duo.op ) & " | arg1=" & ($node.duo.arg1 ) & " | arg2=" & ($node.duo.arg2 ) + of Tres: "Tres op=" & ($node.tres.op) & " | arg1=" & ($node.tres.arg1) & " | arg2=" & ($node.tres.arg2) & " | arg3=" & ($node.tres.arg3) + +proc printNodeUint32*( node: Node[uint32] ) = echo showNodeUint32(node) + +#------------------------------------------------------------------------------- diff --git a/nim/circom_witnessgen/input_json.nim b/nim/circom_witnessgen/input_json.nim index 97763cf..753b291 100644 --- a/nim/circom_witnessgen/input_json.nim +++ b/nim/circom_witnessgen/input_json.nim @@ -4,12 +4,10 @@ import std/json import std/tables import ./field +import ./graph #------------------------------------------------------------------------------- -type - Inputs* = Table[string, seq[F]] - proc printInputs*(inputs: Inputs) = for key, list in pairs(inputs): echo key diff --git a/nim/circom_witnessgen/load.nim b/nim/circom_witnessgen/load.nim index 8bb532d..6be2935 100644 --- a/nim/circom_witnessgen/load.nim +++ b/nim/circom_witnessgen/load.nim @@ -7,13 +7,13 @@ import ./graph #------------------------------------------------------------------------------- proc parseVarUint64(buf: openArray[byte], p: var int): uint64 = - let x = buf[p] + let x : uint8 = uint8(buf[p]) p += 1 if x < 128: return uint64(x) else: let y = buf.parseVarUint64(p) - return uint64(x - 128) + (y shl 7) + return uint64(bitand(x, 0x7f)) + (y shl 7) proc parseVarUint32(buf: openArray[byte], p: var int): uint32 = return uint32( parseVarUint64(buf,p) ) @@ -64,7 +64,7 @@ proc parseGenericNode(buf: openArray[byte]): seq[uint32] = proc parseInputNode(buf: openArray[byte]): Node[uint32] = # echo "InputNode" let values = parseGenericNode(buf) - let node: InputNode[uint32] = InputNode[uint32](idx: values[1]) + let node: InputNode = InputNode(idx: values[1]) return Node[uint32](kind: Input, inp: node) proc parseConstantNode(buf: openArray[byte]): Node[uint32] = @@ -84,7 +84,7 @@ proc parseConstantNode(buf: openArray[byte]): Node[uint32] = for i in 0..= halfPrimePlus1): + return shiftRightF( x , fieldNegateB(kbig) ) + elif bool(kbig > numberOfBitsAsBigInt): + return zeroF + else: + let k = int(kbig.limbs[0]) + var y = fToBig(x) + for i in 0..= halfPrimePlus1): + return shiftLeftF( x , fieldNegateB(kbig) ) + elif bool(kbig > numberOfBitsAsBigInt): + return zeroF + else: + let k = int(kbig.limbs[0]) + var y : B = fToBig(x) + return bigToF( smallShiftRightB( y , k ) ) + +#[ +proc shiftSanityCheck*() = + let x: F = intToF(12345678903) + let k: B = uintToB(8) + let nk: B = fieldPrime - k + echo fToDecimal( shiftLeftF( x,k) ) + echo fToDecimal( shiftRightF(x,k) ) + echo fToDecimal( shiftLeftF( x,nk) ) + echo fToDecimal( shiftRightF(x,nk) ) + let x2: F = decimalToF("21051029818893485635560069555360071249585393429228441201546820650188605022495") # intToF(12345678903) + let k2: B = uintToB(0) + echo fToDecimal( shiftLeftF( x2,k2) ) + echo fToDecimal( shiftRightF(x2,k2) ) + let x3: F = decimalToF("21051029818893485635560069555360071249585393429228441201546820650188605022495") # intToF(12345678903) + let k3: B = uintToB(100) + echo fToDecimal( shiftRightF(x3,k3) ) +]# + +#------------------------------------------------------------------------------- + +func evalUnoOpNode(op: UnoOp, x: F): F = + case op: + of Neg: return negF(x) + of Id: return x + of LNot: return boolToF( not (fToBool x) ) + of Bnot: return fieldComplement(x) + +func evalDuoOpNode(op: DuoOp, x: F, y: F): F = + case op: + of Mul: return x * y + of Div: return if isZeroF(y): zeroF else: x / y + of Add: return x + y + of Sub: return x - y + of Pow: assert( false, "Pow: not yet implemented" ) + of Idiv: assert( false, "Idiv: not yet implemented" ) # return bigToF( fToBig(x) div fToBig(y) ) + of Mod: assert( false, "Mod: not yet implemented" ) # return bigToF( fToBig(x) mod fToBig(y) ) + of Eq: return boolToF( x === y ) + of Neq: return boolToF( not (x === y) ) + of Lt: return boolToF( bool( fToBig(x) < fToBig(y) ) ) + of Gt: return boolToF( bool( fToBig(x) > fToBig(y) ) ) + of Leq: return boolToF( bool( fToBig(x) <= fToBig(y) ) ) + of Geq: return boolToF( bool( fToBig(x) >= fToBig(y) ) ) + of Land: return boolToF( fToBool(x) and fToBool(y) ) + of Lor: return boolToF( fToBool(x) or fToBool(y) ) + of Shl: return shiftLeftF( x , fToBig(y) ) + of Shr: return shiftRightF( x , fToBig(y) ) + of Bor: return bigToF( bigIntBitwiseOr( fToBig(x) , fToBig(y) ) ) + of Band: return bigToF( bigIntBitwiseAnd( fToBig(x) , fToBig(y) ) ) + of Bxor: return bigToF( bigIntBitwiseXor( fToBig(x) , fToBig(y) ) ) + +func evalTresOpNode(op: TresOp, x: F, y: F, z: F): F = + case op: + of TernCond: + return (if fToBool(x): y else: z) + +#------------------------------------------------------------------------------- + +func evalNode*( inputs: Table[int,F] , node: Node[F] ): F = + case node.kind: + of Input: return inputs[int(node.inp.idx)] + of Const: return fromBigUInt(node.kst.bigVal) + of Uno: return evalUnoOpNode( node.uno.op , node.uno.arg1 ) + of Duo: return evalDuoOpNode( node.duo.op , node.duo.arg1 , node.duo.arg2 ) + of Tres: return evalTresOpNode(node.tres.op, node.tres.arg1, node.tres.arg2, node.tres.arg3 ) + +#------------------------------------------------------------------------------- diff --git a/nim/circom_witnessgen/witness.nim b/nim/circom_witnessgen/witness.nim new file mode 100644 index 0000000..82b2509 --- /dev/null +++ b/nim/circom_witnessgen/witness.nim @@ -0,0 +1,70 @@ + +import std/tables +import std/strformat + +import ./field +import ./graph +import ./semantics + +#------------------------------------------------------------------------------- + +proc expandInputs*(circuitInputs: seq[(string, SignalDescription)] , inputs: Inputs): Table[int, F] = + var table: Table[int, F] + table[0] = oneF + for (key, desc) in circuitInputs: + let k: int = int(desc.length) + let o: int = int(desc.offset) + assert( inputs.hasKey(key) , "input signal `" & key & "` not present" ) + let list: seq[F] = inputs[key] + assert( list.len == k , "input signal `" & key & "` has unexpected size" ) + for i in 0.. " & (showNodeUint32(node_orig)) + + return output + +#[ + let node_orig = sequence[i] + let o = 32*i + let hexo = fmt"{o=:x}" + if (i == 11838): # o == 0x4fcc: + echo " i = " & ($i) & " | ofs = " & hexo & " | node = " & ($node_orig) + case node_orig.kind: + of Duo: + echo node_orig.duo.arg1 + echo node_orig.duo.arg2 + echo fToDecimal(output[int(node_orig.duo.arg1)]) + echo fToDecimal(output[int(node_orig.duo.arg2)]) + else: + discard + echo "result = " & fToDecimal(output[i]) + echo " " +]# + +proc generateWitness*(graph: Graph, inputs: Inputs): seq[F] = + let mapping: seq[uint32] = graph.meta.witnessMapping.mapping + let pre_witness = generateFullComputation(graph, inputs) + var output: seq[F] = newSeq[F](mapping.len) + for (j, idx) in mapping.pairs(): + output[j] = pre_witness[int(idx)] + # echo " - " & ($j) & " -> " & fToDecimal(output[j]) & " | from " & ($idx) + return output + + +#------------------------------------------------------------------------------- diff --git a/nim/main.nim b/nim/main.nim index d96546d..9fc92a6 100644 --- a/nim/main.nim +++ b/nim/main.nim @@ -1,10 +1,7 @@ -import std/sequtils - -import circom_witnessgen/field -import circom_witnessgen/graph import circom_witnessgen/load import circom_witnessgen/input_json +import circom_witnessgen/witness import circom_witnessgen/export_wtns #------------------------------------------------------------------------------- @@ -12,19 +9,24 @@ import circom_witnessgen/export_wtns const graph_file: string = "../tmp/graph4.bin" const input_file: string = "../tmp/input4.json" const wtns_file: string = "../tmp/nim4.wtns" +const comp_file: string = "../tmp/nim4_full.bin" #------------------------------------------------------------------------------- when isMainModule: - echo "\nloading in " & input_file + echo "loading in " & input_file let inp = loadInputJSON(input_file) # printInputs(inp) - echo "\nloading in " & graph_file + echo "loading in " & graph_file let gr = loadGraph(graph_file) # echo $gr - let us: seq[int] = @[1,2,3,4,5,6,7] - let wtns: seq[F] = us.map(intToF); + # echo "generating full computation" + # let comp = generateFullComputation( gr, inp ) + # exportFeltSequence(comp_file, comp) + + echo "generating witness" + let wtns = generateWitness( gr, inp ) exportWitness(wtns_file, wtns) \ No newline at end of file