diff --git a/poseidon2.nim b/poseidon2.nim index 05db865..f777295 100644 --- a/poseidon2.nim +++ b/poseidon2.nim @@ -98,4 +98,4 @@ func merkleRoot*(xs: openArray[F]) : F = return merkleRoot(ys) func merkleRoot*(bytes: openArray[byte]): F = - merkleRoot(seq[F].unmarshal(bytes, littleEndian)) + merkleRoot(seq[F].unmarshal(bytes)) diff --git a/poseidon2/io.nim b/poseidon2/io.nim index 2d2daa6..06827d1 100644 --- a/poseidon2/io.nim +++ b/poseidon2/io.nim @@ -2,32 +2,20 @@ import ./types import constantine/math/arithmetic import constantine/math/io/io_bigints -func padRight(source: openArray[byte], endian: static Endianness): array[32, byte] = - assert source.len <= 31 - when endian == littleEndian: - copyMem(addr result[0], unsafeAddr source[0], source.len) - when endian == bigEndian: - copyMem(addr result[1], unsafeAddr source[0], source.len) - -func unmarshal*( - _: type F, - bytes: openArray[byte], - endian: static Endianness): F = +func unmarshal*(_: type F, bytes: openArray[byte]): F = assert bytes.len <= 31 - let padded = bytes.padRight(endian) - let bigint = B.unmarshal(padded, endian) + var padded: array[32, byte] + copyMem(addr padded[0], unsafeAddr bytes[0], bytes.len) + let bigint = B.unmarshal(padded, littleEndian) return F.fromBig(bigint) -func unmarshal*( - _: type seq[F], - bytes: openArray[byte], - endian: static Endianness): seq[F] = +func unmarshal*(_: type seq[F], bytes: openArray[byte]): seq[F] = const chunkLen = 31 var elements: seq[F] var chunkStart = 0 while chunkStart < bytes.len: let chunkEnd = min(chunkStart + 31, bytes.len) - let element = F.unmarshal(bytes.toOpenArray(chunkStart, chunkEnd - 1), endian) + let element = F.unmarshal(bytes.toOpenArray(chunkStart, chunkEnd - 1)) elements.add(element) chunkStart += chunkLen return elements diff --git a/tests/poseidon2/testIo.nim b/tests/poseidon2/testIo.nim index a38bf30..98c5a80 100644 --- a/tests/poseidon2/testIo.nim +++ b/tests/poseidon2/testIo.nim @@ -7,49 +7,29 @@ import poseidon2/io suite "unmarshalling": - test "converts big endian bytes into field elements": - let bytes = toSeq 1'u8..31'u8 - let paddedTo32 = @[0x00'u8] & bytes # most significant byte is not used - let expected = F.fromBig(B.unmarshal(paddedTo32, bigEndian)) - let unmarshalled = F.unmarshal(bytes, bigEndian) - check bool(unmarshalled == expected) - test "converts little endian bytes into field elements": let bytes = toSeq 1'u8..31'u8 - let paddedTo32 = bytes & @[0x00'u8] # most significant byte is not used + let paddedTo32 = bytes & @[0'u8] # most significant byte is not used let expected = F.fromBig(B.unmarshal(paddedTo32, littleEndian)) - let unmarshalled = F.unmarshal(bytes, littleEndian) - check bool(unmarshalled == expected) - - test "pads big endian bytes to the right with 0's": - let bytes = @[0x12'u8, 0x34, 0x56] - let paddedTo31 = bytes & 0x00'u8.repeat(31 - bytes.len) - let paddedTo32 = @[0x00'u8] & paddedTo31 # most significant byte is not used - let expected = F.fromBig(B.unmarshal(paddedTo32, bigEndian)) - let unmarshalled = F.unmarshal(bytes, bigEndian) + let unmarshalled = F.unmarshal(bytes) check bool(unmarshalled == expected) test "pads little endian bytes to the right with 0's": let bytes = @[0x56'u8, 0x34, 0x12] - let paddedTo31 = bytes & 0x00'u8.repeat(31 - bytes.len) - let paddedTo32 = paddedTo31 & @[0x00'u8] # most significant byte is not used + let paddedTo32 = bytes & 0'u8.repeat(32 - bytes.len) let expected = F.fromBig(B.unmarshal(paddedTo32, littleEndian)) - let unmarshalled = F.unmarshal(bytes, littleEndian) + let unmarshalled = F.unmarshal(bytes) check bool(unmarshalled == expected) test "converts every 31 bytes into a field element": - template checkConversion(endian) = - let bytes = toSeq 1'u8..80'u8 - let padded = bytes & 0'u8.repeat(93 - bytes.len) - let expected1 = F.fromBig(B.unmarshal(padded[ 0..<31], endian)) - let expected2 = F.fromBig(B.unmarshal(padded[31..<62], endian)) - let expected3 = F.fromBig(B.unmarshal(padded[62..<93], endian)) - let elements = seq[F].unmarshal(bytes, endian) - check elements.len == 3 - check bool(elements[0] == expected1) - check bool(elements[1] == expected2) - check bool(elements[2] == expected3) - - checkConversion(littleEndian) - checkConversion(bigEndian) + let bytes = toSeq 1'u8..80'u8 + let padded = bytes & 0'u8.repeat(93 - bytes.len) + let expected1 = F.fromBig(B.unmarshal(padded[ 0..<31] & @[0'u8], littleEndian)) + let expected2 = F.fromBig(B.unmarshal(padded[31..<62] & @[0'u8], littleEndian)) + let expected3 = F.fromBig(B.unmarshal(padded[62..<93] & @[0'u8], littleEndian)) + let elements = seq[F].unmarshal(bytes) + check elements.len == 3 + check bool(elements[0] == expected1) + check bool(elements[1] == expected2) + check bool(elements[2] == expected3)