diff --git a/beacon_chain/merkle_minimal.nim b/beacon_chain/merkle_minimal.nim index e74f0b299..2762ce803 100644 --- a/beacon_chain/merkle_minimal.nim +++ b/beacon_chain/merkle_minimal.nim @@ -13,9 +13,10 @@ {.push raises: [Defect].} import - sequtils, strutils, macros, bitops, + sequtils, macros, + stew/endians2, # Specs - ../../beacon_chain/spec/[beaconstate, datatypes, digest, helpers], + ../../beacon_chain/spec/[datatypes, digest], ../../beacon_chain/ssz/merkleization # TODO All tests need to be moved to the test suite. @@ -33,9 +34,14 @@ type SparseMerkleTree*[Depth: static int] = object # There is an extra "depth" layer to store leaf nodes # This stores leaves at depth = 0 # and the root hash at the last depth - nnznodes: array[Depth+1, seq[Eth2Digest]] # nodes that leads to non-zero leaves + nnznodes*: array[Depth+1, seq[Eth2Digest]] # nodes that leads to non-zero leaves -func merkleTreeFromLeaves( +type + MerkleTreeFragment* = object + depth: int + elements: seq[Eth2Digest] + +func merkleTreeFromLeaves*( values: openarray[Eth2Digest], Depth: static[int] = DEPOSIT_CONTRACT_TREE_DEPTH ): SparseMerkleTree[Depth] = @@ -62,9 +68,9 @@ func merkleTreeFromLeaves( h.update zeroHashes[depth-1] result.nnznodes[depth].add nodeHash -func getMerkleProof[Depth: static int](tree: SparseMerkleTree[Depth], - index: int, - depositMode = false): array[Depth, Eth2Digest] = +func getMerkleProof*[Depth: static int](tree: SparseMerkleTree[Depth], + index: int, + depositMode = false): array[Depth, Eth2Digest] = # Descend down the tree according to the bit representation # of the index: # - 0 --> go left @@ -91,113 +97,12 @@ func getMerkleProof[Depth: static int](tree: SparseMerkleTree[Depth], depthLen = (depthLen + 1) div 2 func attachMerkleProofs*(deposits: var openarray[Deposit]) = - let - deposit_data_roots = mapIt(deposits, it.data.hash_tree_root) - merkle_tree = merkleTreeFromLeaves(deposit_data_roots) - var - deposit_data_sums: seq[Eth2Digest] - for prefix_root in hash_tree_roots_prefix( - deposit_data_roots, 1'i64 shl DEPOSIT_CONTRACT_TREE_DEPTH): - deposit_data_sums.add prefix_root + let depositsRoots = mapIt(deposits, hash_tree_root(it.data)) - for val_idx in 0 ..< deposits.len: - deposits[val_idx].proof[0..31] = merkle_tree.getMerkleProof(val_idx, true) - deposits[val_idx].proof[32].data[0..7] = uint_to_bytes8((val_idx + 1).uint64) + const depositContractLimit = Limit(1'u64 shl (DEPOSIT_CONTRACT_TREE_DEPTH - 1'u64)) + var incrementalMerkleProofs = createMerkleizer(depositContractLimit) + + for i in 0 ..< depositsRoots.len: + incrementalMerkleProofs.addChunkAndGenMerkleProof(depositsRoots[i], deposits[i].proof) + deposits[i].proof[32].data[0..7] = toBytesLE uint64(i + 1) - doAssert is_valid_merkle_branch( - deposit_data_roots[val_idx], deposits[val_idx].proof, - DEPOSIT_CONTRACT_TREE_DEPTH + 1, val_idx.uint64, - deposit_data_sums[val_idx]) - -proc testMerkleMinimal*(): bool = - proc toDigest[N: static int](x: array[N, byte]): Eth2Digest = - result.data[0 .. N-1] = x - - let a = [byte 0x01, 0x02, 0x03].toDigest - let b = [byte 0x04, 0x05, 0x06].toDigest - let c = [byte 0x07, 0x08, 0x09].toDigest - - block: # SSZ Sanity checks vs Python impl - block: # 3 leaves - let leaves = List[Eth2Digest, 3](@[a, b, c]) - let root = hash_tree_root(leaves) - doAssert $root == "9ff412e827b7c9d40fc7df2725021fd579ab762581d1ff5c270316682868456e".toUpperAscii - - block: # 2^3 leaves - let leaves = List[Eth2Digest, int64(1 shl 3)](@[a, b, c]) - let root = hash_tree_root(leaves) - doAssert $root == "5248085b588fab1dd1e03f3cd62201602b12e6560665935964f46e805977e8c5".toUpperAscii - - block: # 2^10 leaves - let leaves = List[Eth2Digest, int64(1 shl 10)](@[a, b, c]) - let root = hash_tree_root(leaves) - doAssert $root == "9fb7d518368dc14e8cc588fb3fd2749beef9f493fef70ae34af5721543c67173".toUpperAscii - - block: # Round-trips - # TODO: there is an issue (also in EF specs?) - # using hash_tree_root([a, b, c]) - # doesn't give the same hash as - # - hash_tree_root(@[a, b, c]) - # - sszList(@[a, b, c], int64(nleaves)) - # which both have the same hash. - # - # hash_tree_root([a, b, c]) gives the same hash as - # the last hash of merkleTreeFromLeaves - # - # Running tests with hash_tree_root([a, b, c]) - # works for depth 2 (3 or 4 leaves) - - macro roundTrips(): untyped = - result = newStmtList() - - # compile-time unrolled test - for nleaves in [3, 4, 5, 7, 8, 1 shl 10, 1 shl 32]: - let depth = fastLog2(nleaves-1) + 1 - - result.add quote do: - block: - let tree = merkleTreeFromLeaves([a, b, c], Depth = `depth`) - #echo "Tree: ", tree - - doAssert tree.nnznodes[`depth`].len == 1 - let root = tree.nnznodes[`depth`][0] - #echo "Root: ", root - - block: # proof for a - let index = 0 - - doAssert is_valid_merkle_branch( - a, get_merkle_proof(tree, index = index), - depth = `depth`, - index = index.uint64, - root = root - ), "Failed (depth: " & $`depth` & - ", nleaves: " & $`nleaves` & ')' - - block: # proof for b - let index = 1 - - doAssert is_valid_merkle_branch( - b, get_merkle_proof(tree, index = index), - depth = `depth`, - index = index.uint64, - root = root - ), "Failed (depth: " & $`depth` & - ", nleaves: " & $`nleaves` & ')' - - block: # proof for c - let index = 2 - - doAssert is_valid_merkle_branch( - c, get_merkle_proof(tree, index = index), - depth = `depth`, - index = index.uint64, - root = root - ), "Failed (depth: " & $`depth` & - ", nleaves: " & $`nleaves` & ')' - - roundTrips() - true - -when isMainModule: - discard testMerkleMinimal() diff --git a/beacon_chain/ssz/merkleization.nim b/beacon_chain/ssz/merkleization.nim index 136f624a3..91ffb3a15 100644 --- a/beacon_chain/ssz/merkleization.nim +++ b/beacon_chain/ssz/merkleization.nim @@ -29,11 +29,14 @@ const bitsPerChunk = bytesPerChunk * 8 type - SszChunksMerkleizer = object + SszChunksMerkleizer* = object combinedChunks: ptr UncheckedArray[Eth2Digest] totalChunks: uint64 topIndex: int +template chunks*(m: SszChunksMerkleizer): openarray[Eth2Digest] = + m.combinedChunks.toOpenArray(0, m.topIndex) + func digest(a, b: openArray[byte]): Eth2Digest = result = withEth2Hash: trs "DIGESTING ARRAYS ", toHex(a), " ", toHex(b) @@ -74,19 +77,10 @@ func computeZeroHashes: array[sizeof(Limit) * 8, Eth2Digest] = const zeroHashes* = computeZeroHashes() -func addChunk(merkleizer: var SszChunksMerkleizer, data: openarray[byte]) = +func addChunk*(merkleizer: var SszChunksMerkleizer, data: openarray[byte]) = doAssert data.len > 0 and data.len <= bytesPerChunk - if not getBitLE(merkleizer.totalChunks, 0): - let paddingBytes = bytesPerChunk - data.len - - merkleizer.combinedChunks[0].data[0.. 0 and merkleizer.topIndex > 0 + + let proofHeight = merkleizer.topIndex + 1 + result = newSeq[Eth2Digest](chunks.len * proofHeight) + + if chunks.len == 1: + merkleizer.addChunkAndGenMerkleProof(chunks[0], result) + return + + let newTotalChunks = merkleizer.totalChunks + chunks.len.uint64 + + var + # A perfect binary tree will take either `chunks.len * 2` values if the + # number of elements in the base layer is odd and `chunks.len * 2 - 1` + # otherwise. Each row may also need a single extra element at most if + # it must be combined with the existing values in the Merkleizer: + merkleTree = newSeqOfCap[Eth2Digest](chunks.len + merkleizer.topIndex) + inRowIdx = merkleizer.totalChunks + postUpdateInRowIdx = newTotalChunks + zeroMixed = false + + template writeResult(chunkIdx, level: int, chunk: Eth2Digest) = + result[chunkIdx * proofHeight + level] = chunk + + # We'll start by generating the first row of the merkle tree. + var currPairEnd = if inRowIdx.isOdd: + # an odd chunk number means that we must combine the + # hash with the existing pending sibling hash in the + # merkleizer. + writeResult(0, 0, merkleizer.combinedChunks[0]) + merkleTree.add mergeBranches(merkleizer.combinedChunks[0], chunks[0]) + + # TODO: can we immediately write this out? + merkleizer.completeStartedChunk(merkleTree[^1], 1) + 2 + else: + 1 + + if postUpdateInRowIdx.isOdd: + merkleizer.combinedChunks[0] = chunks[^1] + + while currPairEnd < chunks.len: + writeResult(currPairEnd - 1, 0, chunks[currPairEnd]) + writeResult(currPairEnd, 0, chunks[currPairEnd - 1]) + merkleTree.add mergeBranches(chunks[currPairEnd - 1], + chunks[currPairEnd]) + currPairEnd += 2 + + if currPairEnd - 1 < chunks.len: + zeroMixed = true + writeResult(currPairEnd - 1, 0, zeroHashes[0]) + merkleTree.add mergeBranches(chunks[currPairEnd - 1], + zeroHashes[0]) + var + level = 0 + baseChunksPerElement = 1 + treeRowStart = 0 + rowLen = merkleTree.len + + template writeProofs(rowChunkIdx: int, hash: Eth2Digest) = + let + startAbsIdx = (inRowIdx.int + rowChunkIdx) * baseChunksPerElement + endAbsIdx = startAbsIdx + baseChunksPerElement + startResIdx = max(startAbsIdx - merkleizer.totalChunks.int, 0) + endResIdx = min(endAbsIdx - merkleizer.totalChunks.int, chunks.len) + + for resultPos in startResIdx ..< endResIdx: + writeResult(resultPos, level, hash) + + if rowLen > 1: + while level < merkleizer.topIndex: + inc level + baseChunksPerElement *= 2 + inRowIdx = inRowIdx div 2 + postUpdateInRowIdx = postUpdateInRowIdx div 2 + + var currPairEnd = if inRowIdx.isOdd: + # an odd chunk number means that we must combine the + # hash with the existing pending sibling hash in the + # merkleizer. + writeProofs(0, merkleizer.combinedChunks[level]) + merkleTree.add mergeBranches(merkleizer.combinedChunks[level], + merkleTree[treeRowStart]) + + # TODO: can we immediately write this out? + merkleizer.completeStartedChunk(merkleTree[^1], level + 1) + 2 + else: + 1 + + if postUpdateInRowIdx.isOdd: + merkleizer.combinedChunks[level] = merkleTree[treeRowStart + rowLen - + ord(zeroMixed) - 1] + while currPairEnd < rowLen: + writeProofs(currPairEnd - 1, merkleTree[treeRowStart + currPairEnd]) + writeProofs(currPairEnd, merkleTree[treeRowStart + currPairEnd - 1]) + merkleTree.add mergeBranches(merkleTree[treeRowStart + currPairEnd - 1], + merkleTree[treeRowStart + currPairEnd]) + currPairEnd += 2 + + if currPairEnd - 1 < rowLen: + zeroMixed = true + writeProofs(currPairEnd - 1, zeroHashes[level]) + merkleTree.add mergeBranches(merkleTree[treeRowStart + currPairEnd - 1], + zeroHashes[level]) + + treeRowStart += rowLen + rowLen = merkleTree.len - treeRowStart + + if rowLen == 1: + break + + doAssert rowLen == 1 + + if (inRowIdx and 2) != 0: + merkleizer.completeStartedChunk( + mergeBranches(merkleizer.combinedChunks[level + 1], merkleTree[^1]), + level + 2) + + if (not zeroMixed) and (postUpdateInRowIdx and 2) != 0: + merkleizer.combinedChunks[level + 1] = merkleTree[^1] + + while level < merkleizer.topIndex: + inc level + baseChunksPerElement *= 2 + inRowIdx = inRowIdx div 2 + + let hash = if getBitLE(merkleizer.totalChunks, level): + merkleizer.combinedChunks[level] + else: + zeroHashes[level] + + writeProofs(0, hash) + + merkleizer.totalChunks = newTotalChunks + +func binaryTreeHeight*(totalElements: Limit): int = + bitWidth nextPow2(uint64 totalElements) + +type + SszHeapMerkleizer[limit: static[Limit]] = object + chunks: array[binaryTreeHeight limit, Eth2Digest] + m: SszChunksMerkleizer + +proc init*(S: type SszHeapMerkleizer): S = + result.m.combinedChunks = cast[ptr UncheckedArray[Eth2Digest]](addr result.chunks) + result.m.topIndex = result.limit - 1 + result.m.totalChunks = 0 + +template createMerkleizer*(totalElements: static Limit): SszChunksMerkleizer = trs "CREATING A MERKLEIZER FOR ", totalElements - const treeHeight = bitWidth nextPow2(uint64 totalElements) + const treeHeight = binaryTreeHeight totalElements var combinedChunks {.noInit.}: array[treeHeight, Eth2Digest] SszChunksMerkleizer( @@ -112,7 +307,7 @@ template createMerkleizer(totalElements: static Limit): SszChunksMerkleizer = topIndex: treeHeight - 1, totalChunks: 0) -func getFinalHash(merkleizer: var SszChunksMerkleizer): Eth2Digest = +func getFinalHash*(merkleizer: SszChunksMerkleizer): Eth2Digest = if merkleizer.totalChunks == 0: return zeroHashes[merkleizer.topIndex] diff --git a/tests/test_ssz_merkleization.nim b/tests/test_ssz_merkleization.nim new file mode 100644 index 000000000..3c1676abe --- /dev/null +++ b/tests/test_ssz_merkleization.nim @@ -0,0 +1,198 @@ +import + std/[strutils, sequtils, macros, bitops], + stew/[bitops2, endians2], + ../beacon_chain/spec/[beaconstate, datatypes, digest, helpers], + ../beacon_chain/[ssz, merkle_minimal], + mocking/mock_deposits + +proc testMerkleMinimal*(): bool = + proc toDigest[N: static int](x: array[N, byte]): Eth2Digest = + result.data[0 .. N-1] = x + + let a = [byte 0x01, 0x02, 0x03].toDigest + let b = [byte 0x04, 0x05, 0x06].toDigest + let c = [byte 0x07, 0x08, 0x09].toDigest + + block: # SSZ Sanity checks vs Python impl + block: # 3 leaves + let leaves = List[Eth2Digest, 3](@[a, b, c]) + let root = hash_tree_root(leaves) + doAssert $root == "9ff412e827b7c9d40fc7df2725021fd579ab762581d1ff5c270316682868456e".toUpperAscii + + block: # 2^3 leaves + let leaves = List[Eth2Digest, int64(1 shl 3)](@[a, b, c]) + let root = hash_tree_root(leaves) + doAssert $root == "5248085b588fab1dd1e03f3cd62201602b12e6560665935964f46e805977e8c5".toUpperAscii + + block: # 2^10 leaves + let leaves = List[Eth2Digest, int64(1 shl 10)](@[a, b, c]) + let root = hash_tree_root(leaves) + doAssert $root == "9fb7d518368dc14e8cc588fb3fd2749beef9f493fef70ae34af5721543c67173".toUpperAscii + + block: # Round-trips + # TODO: there is an issue (also in EF specs?) + # using hash_tree_root([a, b, c]) + # doesn't give the same hash as + # - hash_tree_root(@[a, b, c]) + # - sszList(@[a, b, c], int64(nleaves)) + # which both have the same hash. + # + # hash_tree_root([a, b, c]) gives the same hash as + # the last hash of merkleTreeFromLeaves + # + # Running tests with hash_tree_root([a, b, c]) + # works for depth 2 (3 or 4 leaves) + + macro roundTrips(): untyped = + result = newStmtList() + + # compile-time unrolled test + for nleaves in [3, 4, 5, 7, 8, 1 shl 10, 1 shl 32]: + let depth = fastLog2(nleaves-1) + 1 + + result.add quote do: + block: + let tree = merkleTreeFromLeaves([a, b, c], Depth = `depth`) + #echo "Tree: ", tree + + doAssert tree.nnznodes[`depth`].len == 1 + let root = tree.nnznodes[`depth`][0] + #echo "Root: ", root + + block: # proof for a + let index = 0 + + doAssert is_valid_merkle_branch( + a, tree.getMerkleProof(index = index), + depth = `depth`, + index = index.uint64, + root = root + ), "Failed (depth: " & $`depth` & + ", nleaves: " & $`nleaves` & ')' + + block: # proof for b + let index = 1 + + doAssert is_valid_merkle_branch( + b, tree.getMerkleProof(index = index), + depth = `depth`, + index = index.uint64, + root = root + ), "Failed (depth: " & $`depth` & + ", nleaves: " & $`nleaves` & ')' + + block: # proof for c + let index = 2 + + doAssert is_valid_merkle_branch( + c, tree.getMerkleProof(index = index), + depth = `depth`, + index = index.uint64, + root = root + ), "Failed (depth: " & $`depth` & + ", nleaves: " & $`nleaves` & ')' + + roundTrips() + true + +doAssert testMerkleMinimal() + +let + digests = mapIt(1..65, eth2digest toBytesLE(uint64 it)) + +proc compareTreeVsMerkleizer(hashes: openarray[Eth2Digest], limit: static Limit) = + const treeHeight = binaryTreeHeight(limit) + let tree = merkleTreeFromLeaves(hashes, treeHeight) + + var merkleizer = createMerkleizer(limit) + for hash in hashes: + merkleizer.addChunk hash.data + + doAssert merkleizer.getFinalHash() == tree.nnznodes[treeHeight - 1][0] + +proc testMultiProofsGeneration(preludeRecords: int, + totalProofs: int, + followUpRecords: int, + limit: static Limit) = + var + m1 = createMerkleizer(limit) + m2 = createMerkleizer(limit) + + var preludeHashes = newSeq[Eth2Digest]() + for i in 0 ..< preludeRecords: + let hash = eth2digest toBytesLE(uint64(100000000 + i)) + m1.addChunk hash.data + m2.addChunk hash.data + preludeHashes.add hash + + var proofsHashes = newSeq[Eth2Digest]() + for i in 0 ..< totalProofs: + let hash = eth2digest toBytesLE(uint64(200000000 + i)) + m1.addChunk hash.data + proofsHashes.add hash + + var proofs = addChunksAndGenMerkleProofs(m2, proofsHashes) + + const treeHeight = binaryTreeHeight(limit) + let merkleTree = merkleTreeFromLeaves(preludeHashes & proofsHashes, + treeHeight) + + doAssert m1.getFinalHash == merkleTree.nnznodes[treeHeight - 1][0] + doAssert m1.getFinalHash == m2.getFinalHash + + for i in 0 ..< totalProofs: + let + referenceProof = merkle_tree.getMerkleProof(preludeRecords + i, false) + startPos = i * treeHeight + endPos = startPos + treeHeight - 1 + + doAssert referenceProof == proofs.toOpenArray(startPos, endPos) + + for i in 0 ..< followUpRecords: + let hash = eth2digest toBytesLE(uint64(300000000 + i)) + m1.addChunk hash.data + m2.addChunk hash.data + + doAssert m1.getFinalHash == m2.getFinalHash + +for prelude in [0, 1, 2, 5, 6, 12, 13, 16]: + for proofs in [1, 2, 4, 17, 64]: + for followUpHashes in [0, 1, 2, 5, 7, 8, 15, 48]: + testMultiProofsGeneration(prelude, proofs, followUpHashes, 128) + testMultiProofsGeneration(prelude, proofs, followUpHashes, 5000) + +func attachMerkleProofsReferenceImpl(deposits: var openarray[Deposit]) = + let + deposit_data_roots = mapIt(deposits, it.data.hash_tree_root) + merkle_tree = merkleTreeFromLeaves(deposit_data_roots) + var + deposit_data_sums: seq[Eth2Digest] + for prefix_root in hash_tree_roots_prefix( + deposit_data_roots, 1'i64 shl DEPOSIT_CONTRACT_TREE_DEPTH): + deposit_data_sums.add prefix_root + + for val_idx in 0 ..< deposits.len: + deposits[val_idx].proof[0..31] = merkle_tree.getMerkleProof(val_idx, true) + deposits[val_idx].proof[32].data[0..7] = uint_to_bytes8((val_idx + 1).uint64) + + doAssert is_valid_merkle_branch( + deposit_data_roots[val_idx], deposits[val_idx].proof, + DEPOSIT_CONTRACT_TREE_DEPTH + 1, val_idx.uint64, + deposit_data_sums[val_idx]) + +proc testMerkleizer = + for i in 0 ..< digests.len: + compareTreeVsMerkleizer(digests.toOpenArray(0, i), 128) + compareTreeVsMerkleizer(digests.toOpenArray(0, i), 5000) + + var deposits = mockGenesisBalancedDeposits(65, 100000) + var depositsCopy = deposits + + attachMerkleProofsReferenceImpl(deposits) + attachMerkleProofs(depositsCopy) + + for i in 0 ..< deposits.len: + doAssert deposits[i].proof == depositsCopy[i].proof + +testMerkleizer() +