diff --git a/poseidon2/compress.nim b/poseidon2/compress.nim index 85c5d8c..5fc9dd0 100644 --- a/poseidon2/compress.nim +++ b/poseidon2/compress.nim @@ -6,5 +6,5 @@ func compress*(a, b : F, key = zero) : F = var x = a var y = b var z = key - permInplace(x, y, z) + permInPlace(x, y, z) return x diff --git a/poseidon2/io.nim b/poseidon2/io.nim index 2c8dd6b..68d7157 100644 --- a/poseidon2/io.nim +++ b/poseidon2/io.nim @@ -1,8 +1,9 @@ -import ./types +import std/options import constantine/math/arithmetic import constantine/math/io/io_bigints import constantine/math/io/io_fields import constantine/math/config/curves +import ./types export curves @@ -14,6 +15,13 @@ func fromBytes*(_: type F, bytes: array[31, byte]): F = ## canonical little-endian big integer. F.fromOpenArray(bytes) +func fromBytes*(_: type F, bytes: array[32, byte]): Option[F] = + ## Converts bytes into a field element. The byte array is interpreted as a + ## canonical little-endian big integer. + let big = B.unmarshal(bytes, littleEndian) + if bool(big < F.fieldMod()): + return some(F.fromBig(big)) + func toBytes*(element: F): array[32, byte] = ## Converts a field element into its canonical representation in little-endian ## byte order. Uses at most 254 bits, the remaining most-significant bits are diff --git a/poseidon2/permutation.nim b/poseidon2/permutation.nim index 6911f00..6e41f15 100644 --- a/poseidon2/permutation.nim +++ b/poseidon2/permutation.nim @@ -2,7 +2,7 @@ import ./types import ./roundfun # the Poseidon2 permutation (mutable, in-place version) -proc permInplace*(x, y, z : var F) = +proc permInPlace*(x, y, z : var F) = linearLayer(x, y, z) for j in 0..3: externalRound(j, x, y, z) @@ -14,5 +14,5 @@ proc permInplace*(x, y, z : var F) = # the Poseidon2 permutation func perm*(xyz: S) : S = var (x,y,z) = xyz - permInplace(x, y, z) + permInPlace(x, y, z) return (x,y,z) diff --git a/poseidon2/types.nim b/poseidon2/types.nim index 5cc931c..2c9fa52 100644 --- a/poseidon2/types.nim +++ b/poseidon2/types.nim @@ -38,3 +38,6 @@ func arrayFromHex*[N]( for i in low(inp)..high(inp): tmp[i] = hexToF(inp[i], endian) return tmp + +func `==`*(a, b: F): bool = + bool(arithmetic.`==`(a, b)) diff --git a/tests/poseidon2/testIo.nim b/tests/poseidon2/testIo.nim index 23b1ce9..a8189b1 100644 --- a/tests/poseidon2/testIo.nim +++ b/tests/poseidon2/testIo.nim @@ -1,5 +1,6 @@ import std/unittest import std/sequtils +import std/options import constantine/math/io/io_bigints import constantine/math/io/io_fields import constantine/math/arithmetic @@ -11,6 +12,9 @@ suite "conversion to/from bytes": func toArray(bytes: openArray[byte]): array[31, byte] = result[0..