From 1291cf0e62ddd829b7be14cccc301af82e033391 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Wed, 28 May 2025 19:31:02 -0600 Subject: [PATCH] refactor: update function signatures to include raises annotations for better error handling --- nim/circom_witnessgen/graph.nim | 8 ++--- nim/circom_witnessgen/semantics.nim | 46 +++++++++++++++-------------- nim/circom_witnessgen/witness.nim | 8 +++-- 3 files changed, 33 insertions(+), 29 deletions(-) diff --git a/nim/circom_witnessgen/graph.nim b/nim/circom_witnessgen/graph.nim index 3719c0c..911bf52 100644 --- a/nim/circom_witnessgen/graph.nim +++ b/nim/circom_witnessgen/graph.nim @@ -112,16 +112,16 @@ func fromBigUInt*(big: BigUInt): F = #------------------------------------------------------------------------------- -proc fmapUno[S,T]( fun: ((S) {.gcsafe.} -> T) , node: UnoOpNode[S]): UnoOpNode[T] = +proc fmapUno[S,T]( fun: ((S) {.raises: [], gcsafe.} -> T) , node: UnoOpNode[S]): UnoOpNode[T] = UnoOpNode[T]( op: node.op, arg1: fun(node.arg1) ) -proc fmapDuo[S,T]( fun: ((S) {.gcsafe.} -> T) , node: DuoOpNode[S]): DuoOpNode[T] = +proc fmapDuo[S,T]( fun: ((S) {.raises: [], gcsafe.} -> T) , node: DuoOpNode[S]): DuoOpNode[T] = DuoOpNode[T]( op: node.op, arg1: fun(node.arg1), arg2: fun(node.arg2) ) -proc fmapTres[S,T]( fun: ((S) {.gcsafe.} -> T) , node: TresOpNode[S]): TresOpNode[T] = +proc fmapTres[S,T]( fun: ((S) {.raises: [], gcsafe.} -> T) , node: TresOpNode[S]): TresOpNode[T] = TresOpNode[T]( op: node.op, arg1: fun(node.arg1), arg2: fun(node.arg2), arg3: fun(node.arg3) ) -proc fmap* [S,T]( fun: ((S) {.gcsafe.} -> T) , node: Node[S]): Node[T] = +proc fmap* [S,T]( fun: ((S) {.raises: [], gcsafe.} -> T) , node: Node[S]): Node[T] = case node.kind: of Input: Node[T](kind: Input , inp: node.inp ) of Const: Node[T](kind: Const , kst: node.kst ) diff --git a/nim/circom_witnessgen/semantics.nim b/nim/circom_witnessgen/semantics.nim index fc27c50..e9b0f5d 100644 --- a/nim/circom_witnessgen/semantics.nim +++ b/nim/circom_witnessgen/semantics.nim @@ -1,4 +1,6 @@ +{.push raises: [].} + import std/bitops import std/tables @@ -14,7 +16,7 @@ func bigIntBitwiseComplement(x: B): B = var bytes1 : seq[byte] = newSeq[byte](32) var bytes2 : seq[byte] = newSeq[byte](32) marshal(bytes1, x, littleEndian) - for i in 0..<32: + for i in 0..<32: bytes2[i] = bitxor( bytes1[i] , 0xff ) var output : B unmarshal(output, bytes2, littleEndian) @@ -26,7 +28,7 @@ func bigIntBitwiseAnd(x, y: B): B = var bytes3 : seq[byte] = newSeq[byte](32) marshal(bytes1, x, littleEndian) marshal(bytes2, y, littleEndian) - for i in 0..<32: + for i in 0..<32: bytes3[i] = bitand( bytes1[i] , bytes2[i] ) var output : B unmarshal(output, bytes3, littleEndian) @@ -38,7 +40,7 @@ func bigIntBitwiseOr(x, y: B): B = var bytes3 : seq[byte] = newSeq[byte](32) marshal(bytes1, x, littleEndian) marshal(bytes2, y, littleEndian) - for i in 0..<32: + for i in 0..<32: bytes3[i] = bitor( bytes1[i] , bytes2[i] ) var output : B unmarshal(output, bytes3, littleEndian) @@ -50,7 +52,7 @@ func bigIntBitwiseXor(x, y: B): B = var bytes3 : seq[byte] = newSeq[byte](32) marshal(bytes1, x, littleEndian) marshal(bytes2, y, littleEndian) - for i in 0..<32: + for i in 0..<32: bytes3[i] = bitxor( bytes1[i] , bytes2[i] ) var output : B unmarshal(output, bytes3, littleEndian) @@ -61,20 +63,20 @@ func bigIntBitwiseXor(x, y: B): B = func applyFieldMask(big : B) : F = return bigToF( bigIntBitwiseAnd( fieldMask, big ) ) -func fieldComplement(x: F): F = +func fieldComplement(x: F): F = let big1 = fToBig(x) - let comp = bigIntBitwiseComplement( big1 ) + let comp = bigIntBitwiseComplement( big1 ) return applyFieldMask(comp) #------------------------------------------------------------------------------- -func fieldNegateB(x : B): B = +func fieldNegateB(x : B): B = if bool(isZero(x)): return x else: return fieldPrime - x -func smallShiftRightB(x: B, k: int): B = +func smallShiftRightB(x: B, k: int): B = if (k == 0): return x elif (k < 64): @@ -88,9 +90,9 @@ func smallShiftRightB(x: B, k: int): B = return smallShiftRightB(y, k-63) func shiftLeftF*( x: F, kbig: B ) : F -func shiftRightF*( x: F, kbig: B ) : F - -func shiftLeftF*( x: F, kbig: B ) : F = +func shiftRightF*( x: F, kbig: B ) : F + +func shiftLeftF*( x: F, kbig: B ) : F {.raises: [].} = if (isZeroB(kbig)): return x elif bool(kbig >= halfPrimePlus1): @@ -104,7 +106,7 @@ func shiftLeftF*( x: F, kbig: B ) : F = let _ = y.double() return applyFieldMask( y ) -func shiftRightF*( x: F, kbig: B ) : F = +func shiftRightF*( x: F, kbig: B ) : F = if (isZeroB(kbig)): return x # WTF constantine ?!?!?! if bool(kbig >= halfPrimePlus1): @@ -117,7 +119,7 @@ func shiftRightF*( x: F, kbig: B ) : F = return bigToF( smallShiftRightB( y , k ) ) #[ -proc shiftSanityCheck*() = +proc shiftSanityCheck*() = let x: F = intToF(12345678903) let k: B = uintToB(8) let nk: B = fieldPrime - k @@ -136,14 +138,14 @@ proc shiftSanityCheck*() = #------------------------------------------------------------------------------- -func evalUnoOpNode(op: UnoOp, x: F): F = +func evalUnoOpNode(op: UnoOp, x: F): F = case op: of Neg: return negF(x) of Id: return x of LNot: return boolToF( not (fToBool x) ) of Bnot: return fieldComplement(x) -func evalDuoOpNode(op: DuoOp, x: F, y: F): F = +func evalDuoOpNode(op: DuoOp, x: F, y: F): F = case op: of Mul: return x * y of Div: return if isZeroF(y): zeroF else: x / y @@ -166,19 +168,19 @@ func evalDuoOpNode(op: DuoOp, x: F, y: F): F = of Band: return bigToF( bigIntBitwiseAnd( fToBig(x) , fToBig(y) ) ) of Bxor: return bigToF( bigIntBitwiseXor( fToBig(x) , fToBig(y) ) ) -func evalTresOpNode(op: TresOp, x: F, y: F, z: F): F = +func evalTresOpNode(op: TresOp, x: F, y: F, z: F): F = case op: of TernCond: return (if fToBool(x): y else: z) #------------------------------------------------------------------------------- -func evalNode*( inputs: Table[int,F] , node: Node[F] ): F = +func evalNode*( inputs: Table[int,F] , node: Node[F] ): F {.raises: KeyError.} = case node.kind: - of Input: return inputs[int(node.inp.idx)] - of Const: return fromBigUInt(node.kst.bigVal) - of Uno: return evalUnoOpNode( node.uno.op , node.uno.arg1 ) - of Duo: return evalDuoOpNode( node.duo.op , node.duo.arg1 , node.duo.arg2 ) - of Tres: return evalTresOpNode(node.tres.op, node.tres.arg1, node.tres.arg2, node.tres.arg3 ) + of Input: inputs[int(node.inp.idx)] + of Const: fromBigUInt(node.kst.bigVal) + of Uno: evalUnoOpNode( node.uno.op , node.uno.arg1 ) + of Duo: evalDuoOpNode( node.duo.op , node.duo.arg1 , node.duo.arg2 ) + of Tres: evalTresOpNode(node.tres.op, node.tres.arg1, node.tres.arg2, node.tres.arg3 ) #------------------------------------------------------------------------------- diff --git a/nim/circom_witnessgen/witness.nim b/nim/circom_witnessgen/witness.nim index f22422a..48ed807 100644 --- a/nim/circom_witnessgen/witness.nim +++ b/nim/circom_witnessgen/witness.nim @@ -1,4 +1,6 @@ +{.push raises: [].} + import std/tables import std/strformat @@ -15,7 +17,7 @@ proc expandInputs*(circuitInputs: seq[(string, SignalDescription)] , inputs: Inp let k: int = int(desc.length) let o: int = int(desc.offset) assert( inputs.hasKey(key) , "input signal `" & key & "` not present" ) - let list: seq[F] = inputs[key] + let list: seq[F] = try: inputs[key] except KeyError as exc: raiseAssert(exc.msg) assert( list.len == k , "input signal `" & key & "` has unexpected size" ) for i in 0..