diff --git a/poseidon2/merkle.nim b/poseidon2/merkle.nim index f8b98f4..48c8e91 100644 --- a/poseidon2/merkle.nim +++ b/poseidon2/merkle.nim @@ -1,4 +1,3 @@ -import std/sequtils import constantine/math/arithmetic import constantine/math/io/io_fields import ./types @@ -10,39 +9,57 @@ const KeyBottomLayer = F.fromHex("0x1") const KeyOdd = F.fromHex("0x2") const KeyOddAndBottomLayer = F.fromhex("0x3") -func merkleRoot(xs: openArray[F], isBottomLayer: static bool) : F = - let a = low(xs) - let b = high(xs) - let m = b-a+1 +type Merkle* = object + todo: seq[F] # nodes that haven't been combined yet + width: int # width of the current subtree + leafs: int # amount of leafs processed - when isBottomLayer: - assert m > 0, "merkle root of empty sequence is not defined" +func init*(_: type Merkle): Merkle = + Merkle(width: 2) - when not isBottomLayer: - if m==1: - return xs[a] - - let halfn : int = m div 2 - let n : int = 2*halfn - let isOdd : bool = (n != m) - - var ys : seq[F] - if not isOdd: - ys = newSeq[F](halfn) +func compress(merkle: var Merkle, odd: static bool) = + when odd: + let a = merkle.todo.pop() + let b = zero + let key = if merkle.width == 2: KeyOddAndBottomLayer else: KeyOdd + merkle.todo.add(compress(a, b, key = key)) + merkle.leafs += merkle.width div 2 # zero node represents this many leafs else: - ys = newSeq[F](halfn+1) + let b = merkle.todo.pop() + let a = merkle.todo.pop() + let key = if merkle.width == 2: KeyBottomLayer else: KeyNone + merkle.todo.add(compress(a, b, key = key)) + merkle.width *= 2 - for i in 0.. 0, "merkle root of empty sequence is not defined" -func merkleRoot*(xs: openArray[F]) : F = - merkleRoot(xs, isBottomLayer = true) + if merkle.leafs == 1: + merkle.compress(odd = true) -func merkleRoot*(bytes: openArray[byte]): F = - merkleRoot(toSeq bytes.elements(F)) + while merkle.todo.len > 1: + if merkle.leafs mod merkle.width == 0: + merkle.compress(odd = false) + else: + merkle.compress(odd = true) + + return merkle.todo[0] + +func digest*(_: type Merkle, elements: openArray[F]): F = + var merkle = Merkle.init() + for element in elements: + merkle.update(element) + return merkle.finish() + +func digest*(_: type Merkle, bytes: openArray[byte]): F = + var merkle = Merkle.init() + for element in bytes.elements(F): + merkle.update(element) + return merkle.finish() diff --git a/tests/poseidon2/testMerkle.nim b/tests/poseidon2/testMerkle.nim index bddc6ad..84bfd66 100644 --- a/tests/poseidon2/testMerkle.nim +++ b/tests/poseidon2/testMerkle.nim @@ -25,7 +25,7 @@ suite "merkle root": for i in 1..n: xs.add( toF(i) ) - let root = merkleRoot(xs) + let root = Merkle.digest(xs) check root.toHex(littleEndian) == "0x593e01f200cb1aee4e75fe2a9206abc3abd2a1216ab75f1061965e97371e8623" test "merkle root of even elements": @@ -34,7 +34,7 @@ suite "merkle root": compress(1.toF, 2.toF, key = isBottomLayer.toF), compress(3.toF, 4.toF, key = isBottomLayer.toF), ) - check bool(merkleRoot(elements) == expected) + check bool(Merkle.digest(elements) == expected) test "merkle root of odd elements": let elements = toSeq(1..3).mapIt(toF(it)) @@ -42,40 +42,40 @@ suite "merkle root": compress(1.toF, 2.toF, key = isBottomLayer.toF), compress(3.toF, 0.toF, key = (isBottomLayer + isOddNode).toF) ) - check bool(merkleRoot(elements) == expected) + check bool(Merkle.digest(elements) == expected) test "data ending with 0 differs from padded data": let a = toSeq(1..3).mapIt(it.toF) let b = a & @[0.toF] - check not bool(merkleRoot(a) == merkleRoot(b)) + check not bool(Merkle.digest(a) == Merkle.digest(b)) test "merkle root of single element does not equal the element": - check not bool(merkleRoot([1.toF]) == 1.toF) + check not bool(Merkle.digest([1.toF]) == 1.toF) test "merkle root differs from merkle root of merkle root": let a = 1.toF let b = 2.toF - check not bool(merkleRoot([a, b]) == merkleRoot([merkleRoot([a, b])])) + check not bool(Merkle.digest([a, b]) == Merkle.digest([Merkle.digest([a, b])])) test "merkle root of bytes": let bytes = toSeq 1'u8..80'u8 - let root = merkleRoot(bytes) + let root = Merkle.digest(bytes) check root.toHex(littleEndian) == "0x40989b63104f39e3331767883381085bcfc46e2202679123371f1ffe53521b16" test "merkle root of bytes converted to bytes": let bytes = toSeq 1'u8..80'u8 - let rootAsBytes = merkleRoot(bytes).toBytes() + let rootAsBytes = Merkle.digest(bytes).toBytes() check rootAsBytes.toHex == "0x40989b63104f39e3331767883381085bcfc46e2202679123371f1ffe53521b16" test "merkle root of empty sequence of elements": let empty = seq[F].default expect Exception: - discard merkleRoot(empty) + discard Merkle.digest(empty) test "merkle root of empty sequency of bytes": # merkle root of empty sequence of bytes is uniquely defined through padding let empty = seq[byte].default - check merkleRoot(empty).toBytes.toHex == "0xcc8da1d157900e611b89e258d95450e707f4f9eec169422d7c26aba54f803c08" + check Merkle.digest(empty).toBytes.toHex == "0xcc8da1d157900e611b89e258d95450e707f4f9eec169422d7c26aba54f803c08" suite "merkle root test vectors": @@ -126,7 +126,7 @@ suite "merkle root test vectors": for n in 1..40: let input = collect(newSeq, (for i in 1..n: i.toF)) - let root = merkleRoot(input) + let root = Merkle.digest(input) check root.toDecimal == expected[n-1] test "byte sequences": @@ -217,5 +217,5 @@ suite "merkle root test vectors": for n in 0..80: let input = collect(newSeq, (for i in 1..n: byte(i))) - let root = merkleRoot(input) + let root = Merkle.digest(input) check root.toDecimal == expected[n]