diff --git a/poseidon2.nim b/poseidon2.nim index d68ec1c..7595af3 100644 --- a/poseidon2.nim +++ b/poseidon2.nim @@ -1,3 +1,4 @@ +import std/sequtils import constantine/math/arithmetic import poseidon2/types @@ -7,7 +8,7 @@ import poseidon2/sponge export sponge export toBytes - +export elements #------------------------------------------------------------------------------- @@ -46,6 +47,6 @@ func merkleRoot*(xs: openArray[F]) : F = return merkleRoot(ys) func merkleRoot*(bytes: openArray[byte]): F = - merkleRoot(seq[F].fromBytes(bytes)) + merkleRoot(toSeq bytes.elements(F)) #------------------------------------------------------------------------------- diff --git a/poseidon2/io.nim b/poseidon2/io.nim index 3ab70ae..721ccba 100644 --- a/poseidon2/io.nim +++ b/poseidon2/io.nim @@ -13,18 +13,16 @@ func fromBytes*(_: type F, bytes: openArray[byte]): F = let bigint = B.unmarshal(padded, littleEndian) return F.fromBig(bigint) -func fromBytes*(_: type seq[F], bytes: openArray[byte]): seq[F] = +iterator elements*(bytes: openArray[byte], _: type F): F = ## Converts bytes into field elements. The byte array is converted 31 bytes at ## a time with the `F.fromBytes()` function. const chunkLen = 31 - var elements: seq[F] var chunkStart = 0 while chunkStart < bytes.len: let chunkEnd = min(chunkStart + 31, bytes.len) let element = F.fromBytes(bytes.toOpenArray(chunkStart, chunkEnd - 1)) - elements.add(element) + yield element chunkStart += chunkLen - return elements func toBytes*(element: F): array[32, byte] = ## Converts a field element into its canonical representation in little-endian diff --git a/tests/poseidon2/testIo.nim b/tests/poseidon2/testIo.nim index ebcca69..b903f2a 100644 --- a/tests/poseidon2/testIo.nim +++ b/tests/poseidon2/testIo.nim @@ -27,7 +27,7 @@ suite "conversion to/from bytes": let expected1 = F.fromBig(B.unmarshal(padded[ 0..<31] & @[0'u8], littleEndian)) let expected2 = F.fromBig(B.unmarshal(padded[31..<62] & @[0'u8], littleEndian)) let expected3 = F.fromBig(B.unmarshal(padded[62..<93] & @[0'u8], littleEndian)) - let elements = seq[F].fromBytes(bytes) + let elements = toSeq bytes.elements(F) check elements.len == 3 check bool(elements[0] == expected1) check bool(elements[1] == expected2)