Refine merkle tree construction

This commit is contained in:
Tomasz Bekas 2023-08-15 13:37:14 +02:00
parent e8601274b9
commit 7a45fe8592
No known key found for this signature in database
GPG Key ID: 4854E04C98824959
2 changed files with 310 additions and 177 deletions

View File

@ -1,5 +1,5 @@
## Nim-Codex ## Nim-Codex
## Copyright (c) 2022 Status Research & Development GmbH ## Copyright (c) 2023 Status Research & Development GmbH
## Licensed under either of ## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) ## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT)) ## * MIT license ([LICENSE-MIT](LICENSE-MIT))
@ -7,41 +7,29 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import std/sequtils
import std/math import std/math
import std/bitops import std/bitops
import std/sugar import std/strutils
import pkg/libp2p
import pkg/stew/byteutils
import pkg/questionable
import pkg/questionable/results import pkg/questionable/results
import pkg/nimcrypto/sha2
const digestSize = sha256.sizeDigest
type type
MerkleHash* = MultiHash MerkleHash* = array[digestSize, byte]
MerkleTree* = object MerkleTree* = object
leavesCount: int leavesCount: Natural
nodes: seq[MerkleHash] nodes: seq[MerkleHash]
MerkleProof* = object MerkleProof* = object
index: int index: Natural
path: seq[MerkleHash] path: seq[MerkleHash]
MerkleTreeBuilder* = object
buffer: seq[MerkleHash]
# Tree constructed from leaves H0..H2 is ###########################################################
# H5=H(H3 & H4) # Helper functions
# / \ ###########################################################
# H3=H(H0 & H1) H4=H(H2 & H2)
# / \ /
# H0=H(A) H1=H(B) H2=H(C)
# | | |
# A B C
#
# Memory layout is [H0, H1, H2, H3, H4, H5]
#
# Proofs of inclusion are
# - [H1, H4] for A
# - [H0, H4] for B
# - [H2, H3] for C
func computeTreeHeight(leavesCount: int): int = func computeTreeHeight(leavesCount: int): int =
if isPowerOfTwo(leavesCount): if isPowerOfTwo(leavesCount):
@ -49,48 +37,100 @@ func computeTreeHeight(leavesCount: int): int =
else: else:
fastLog2(leavesCount) + 2 fastLog2(leavesCount) + 2
func getLowHigh(leavesCount, level: int): (int, int) = func computeLevels(leavesCount: int): seq[tuple[offset: int, width: int]] =
var width = leavesCount
var low = 0
for _ in 0..<level:
low += width
width = (width + 1) div 2
(low, low + width - 1)
func getLowHigh(self: MerkleTree, level: int): (int, int) =
getLowHigh(self.leavesCount, level)
func computeTotalSize(leavesCount: int): int =
let height = computeTreeHeight(leavesCount) let height = computeTreeHeight(leavesCount)
getLowHigh(leavesCount, height - 1)[1] + 1 result = newSeq[tuple[offset: int, width: int]](height)
proc getWidth(self: MerkleTree, level: int): int = result[0].offset = 0
let (low, high) = self.getLowHigh(level) result[0].width = leavesCount
high - low + 1 for i in 1..<height:
result[i].offset = result[i - 1].offset + result[i - 1].width
result[i].width = (result[i - 1].width + 1) div 2
func getChildren(self: MerkleTree, level, i: int): (MerkleHash, MerkleHash) = proc digestFn(data: openArray[byte], output: var MerkleHash): void =
let (low, high) = self.getLowHigh(level - 1) var digest = sha256.digest(data)
let leftIdx = low + 2 * i copyMem(addr output, addr digest.data[0], digestSize)
let rightIdx = min(leftIdx + 1, high)
(self.nodes[leftIdx], self.nodes[rightIdx]) ###########################################################
# MerkleHash
###########################################################
func getSibling(self: MerkleTree, level, i: int): MerkleHash = var zeroHash: MerkleHash
let (low, high) = self.getLowHigh(level)
if i mod 2 == 0: proc `$`*(self: MerkleHash): string =
self.nodes[min(low + i + 1, high)] result = newStringOfCap(self.len)
for i in 0..<self.len:
result.add(toHex(self[i]))
###########################################################
# MerkleTreeBuilder
###########################################################
proc addDataBlock*(self: var MerkleTreeBuilder, dataBlock: openArray[byte]): void =
## Hashes the data block and adds the result of hashing to a buffer
##
let oldLen = self.buffer.len
self.buffer.setLen(oldLen + 1)
digestFn(dataBlock, self.buffer[oldLen])
proc build*(self: MerkleTreeBuilder): ?!MerkleTree =
## Builds a tree from previously added data blocks
##
## Tree built from data blocks A, B and C is
## H5=H(H3 & H4)
## / \
## H3=H(H0 & H1) H4=H(H2 & HZ)
## / \ /
## H0=H(A) H1=H(B) H2=H(C)
## | | |
## A B C
##
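## where HZ is the all-zero digest (`zeroHash`, digestSize zero bytes, not the hash of any data), used in place of a missing right child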
##
## Memory layout is [H0, H1, H2, H3, H4, H5]
##
let leavesCount = self.buffer.len
if leavesCount == 0:
return failure("At least one data block is required")
let levels = computeLevels(leavesCount)
let totalSize = levels[^1].offset + 1
var tree = MerkleTree(leavesCount: leavesCount, nodes: newSeq[MerkleHash](totalSize))
# copy leaves
copyMem(addr tree.nodes[0], unsafeAddr self.buffer[0], leavesCount * digestSize)
# calculate intermediate nodes
var concatBuf: array[2 * digestSize, byte]
var prevLevel = levels[0]
for level in levels[1..^1]:
for i in 0..<level.width:
let parentIndex = level.offset + i
let leftChildIndex = prevLevel.offset + 2 * i
let rightChildIndex = leftChildIndex + 1
copyMem(addr concatBuf[0], addr tree.nodes[leftChildIndex], digestSize)
if rightChildIndex < prevLevel.offset + prevLevel.width:
copyMem(addr concatBuf[digestSize], addr tree.nodes[rightChildIndex], digestSize)
else: else:
self.nodes[low + i - 1] copyMem(addr concatBuf[digestSize], addr zeroHash, digestSize)
proc setNode(self: var MerkleTree, level, i: int, value: MerkleHash): void = digestFn(concatBuf, tree.nodes[parentIndex])
let (low, _) = self.getLowHigh(level) prevLevel = level
self.nodes[low + i] = value
return success(tree)
###########################################################
# MerkleTree
###########################################################
proc root*(self: MerkleTree): MerkleHash = proc root*(self: MerkleTree): MerkleHash =
self.nodes[^1] self.nodes[^1]
proc len*(self: MerkleTree): int = proc len*(self: MerkleTree): Natural =
self.nodes.len self.nodes.len
proc leaves*(self: MerkleTree): seq[MerkleHash] = proc leaves*(self: MerkleTree): seq[MerkleHash] =
@ -99,76 +139,56 @@ proc leaves*(self: MerkleTree): seq[MerkleHash] =
proc nodes*(self: MerkleTree): seq[MerkleHash] = proc nodes*(self: MerkleTree): seq[MerkleHash] =
self.nodes self.nodes
proc height*(self: MerkleTree): int = proc height*(self: MerkleTree): Natural =
computeTreeHeight(self.leavesCount) computeTreeHeight(self.leavesCount)
proc getProof*(self: MerkleTree, index: Natural): ?!MerkleProof =
## Extracts proof from a tree for a given index
##
## Given a tree built from data blocks A, B and C
## H5
## / \
## H3 H4
## / \ /
## H0 H1 H2
## | | |
## A B C
##
## Proofs of inclusion (index and path) are
## - 0,[H1, H4] for data block A
## - 1,[H0, H4] for data block B
## - 2,[HZ, H3] for data block C
##
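## where HZ is the all-zero digest (`zeroHash`, digestSize zero bytes, not the hash of any data), used when a node has no sibling on its level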
##
if index >= self.leavesCount:
return failure("Index " & $index & " out of range [0.." & $self.leaves.high & "]" )
let levels = computeLevels(self.leavesCount)
var path = newSeq[MerkleHash](levels.len - 1)
for levelIndex, level in levels[0..^2]:
let i = index div (1 shl levelIndex)
let siblingIndex = if i mod 2 == 0:
level.offset + i + 1
else:
level.offset + i - 1
if siblingIndex < level.offset + level.width:
path[levelIndex] = self.nodes[siblingIndex]
else:
path[levelIndex] = zeroHash
success(MerkleProof(index: index, path: path))
proc `$`*(self: MerkleTree): string = proc `$`*(self: MerkleTree): string =
result &= "leavesCount: " & $self.leavesCount result &= "leavesCount: " & $self.leavesCount
result &= "\nnodes: " & $self.nodes result &= "\nnodes: " & $self.nodes
proc getProof*(self: MerkleTree, index: int): ?!MerkleProof = ###########################################################
if index >= self.leavesCount or index < 0: # MerkleProof
return failure("Index " & $index & " out of range [0.." & $self.leaves.high & "]" ) ###########################################################
var path = newSeq[MerkleHash](self.height - 1) proc index*(self: MerkleProof): Natural =
for level in 0..<path.len:
let i = index div (1 shl level)
path[level] = self.getSibling(level, i)
success(MerkleProof(index: index, path: path))
proc initTreeFromLeaves(leaves: openArray[MerkleHash]): ?!MerkleTree =
without mcodec =? leaves.?[0].?mcodec and
digestSize =? leaves.?[0].?size:
return failure("At least one leaf is required")
if not leaves.allIt(it.mcodec == mcodec):
return failure("All leaves must use the same codec")
let totalSize = computeTotalSize(leaves.len)
var tree = MerkleTree(leavesCount: leaves.len, nodes: newSeq[MerkleHash](totalSize))
var buf = newSeq[byte](digestSize * 2)
proc combine(l, r: MerkleHash): ?!MerkleHash =
copyMem(addr buf[0], unsafeAddr l.data.buffer[0], digestSize)
copyMem(addr buf[digestSize], unsafeAddr r.data.buffer[0], digestSize)
MultiHash.digest($mcodec, buf).mapErr(
c => newException(CatchableError, "Error calculating hash using codec " & $mcodec & ": " & $c)
)
# copy leaves
for i in 0..<tree.getWidth(0):
tree.setNode(0, i, leaves[i])
# calculate intermediate nodes
for level in 1..<tree.height:
for i in 0..<tree.getWidth(level):
let (left, right) = tree.getChildren(level, i)
without mhash =? combine(left, right), error:
return failure(error)
tree.setNode(level, i, mhash)
success(tree)
func init*(
T: type MerkleTree,
root: MerkleHash,
leavesCount: int
): MerkleTree =
let totalSize = computeTotalSize(leavesCount)
var nodes = newSeq[MerkleHash](totalSize)
nodes[^1] = root
MerkleTree(nodes: nodes, leavesCount: leavesCount)
proc init*(
T: type MerkleTree,
leaves: openArray[MerkleHash]
): ?!MerkleTree =
initTreeFromLeaves(leaves)
proc index*(self: MerkleProof): int =
self.index self.index
proc path*(self: MerkleProof): seq[MerkleHash] = proc path*(self: MerkleProof): seq[MerkleHash] =
@ -183,7 +203,7 @@ func `==`*(a, b: MerkleProof): bool =
proc init*( proc init*(
T: type MerkleProof, T: type MerkleProof,
index: int, index: Natural,
path: seq[MerkleHash] path: seq[MerkleHash]
): MerkleProof = ): MerkleProof =
MerkleProof(index: index, path: path) MerkleProof(index: index, path: path)

View File

@ -1,73 +1,138 @@
import std/unittest import std/unittest
import std/bitops
import std/random
import std/sequtils import std/sequtils
import pkg/libp2p import std/tables
import pkg/questionable/results
import pkg/stew/byteutils
import pkg/nimcrypto/sha2
import codex/merkletree/merkletree import codex/merkletree/merkletree
import ../helpers import ../helpers
import pkg/questionable/results
checksuite "merkletree": checksuite "merkletree":
const sha256 = multiCodec("sha2-256") const data =
const sha512 = multiCodec("sha2-512") [
"0123456789012345678901234567890123456789".toBytes,
"1234567890123456789012345678901234567890".toBytes,
"2345678901234567890123456789012345678901".toBytes,
"3456789012345678901234567890123456789012".toBytes,
"4567890123456789012345678901234567890123".toBytes,
"5678901234567890123456789012345678901234".toBytes,
"6789012345678901234567890123456789012345".toBytes,
"7890123456789012345678901234567890123456".toBytes,
"8901234567890123456789012345678901234567".toBytes,
"9012345678901234567890123456789012345678".toBytes,
]
var zeroHash: MerkleHash
var expectedLeaves: array[data.len, MerkleHash]
var builder: MerkleTreeBuilder
proc randomHash(codec: MultiCodec = sha256): MerkleHash = proc combine(a, b: MerkleHash): MerkleHash =
var data: array[0..31, byte] var buf = newSeq[byte](a.len + b.len)
for i in 0..31: for i in 0..<a.len:
data[i] = rand(uint8) buf[i] = a[i]
return MultiHash.digest($codec, data).tryGet() for i in 0..<b.len:
buf[i + a.len] = b[i]
proc combine(a, b: MerkleHash, codec: MultiCodec = sha256): MerkleHash = var digest = sha256.digest(buf)
var buf = newSeq[byte](a.size + b.size) return digest.data
for i in 0..<a.size:
buf[i] = a.data.buffer[i]
for i in 0..<b.size:
buf[i + a.size] = b.data.buffer[i]
return MultiHash.digest($codec, buf).tryGet()
var
leaves: array[0..10, MerkleHash]
setup: setup:
for i in 0..leaves.high: for i in 0..<data.len:
leaves[i] = randomHash() var digest = sha256.digest(data[i])
expectedLeaves[i] = digest.data
builder = MerkleTreeBuilder()
test "tree with one leaf has expected root": test "tree with one leaf has expected root":
let tree = MerkleTree.init(leaves[0..0]).tryGet() builder.addDataBlock(data[0])
let tree = builder.build().tryGet()
check: check:
tree.leaves == leaves[0..0] tree.leaves == expectedLeaves[0..0]
tree.root == leaves[0] tree.root == expectedLeaves[0]
tree.len == 1 tree.len == 1
test "tree with two leaves has expected root": test "tree with two leaves has expected root":
let builder.addDataBlock(data[0])
expectedRoot = combine(leaves[0], leaves[1]) builder.addDataBlock(data[1])
let tree = MerkleTree.init(leaves[0..1]).tryGet() let tree = builder.build().tryGet()
let expectedRoot = combine(expectedLeaves[0], expectedLeaves[1])
check: check:
tree.leaves == leaves[0..1] tree.leaves == expectedLeaves[0..1]
tree.len == 3 tree.len == 3
tree.root == expectedRoot tree.root == expectedRoot
test "tree with three leaves has expected root": test "tree with three leaves has expected root":
let builder.addDataBlock(data[0])
expectedRoot = combine(combine(leaves[0], leaves[1]), combine(leaves[2], leaves[2])) builder.addDataBlock(data[1])
builder.addDataBlock(data[2])
let tree = MerkleTree.init(leaves[0..2]).tryGet() let tree = builder.build().tryGet()
let
expectedRoot = combine(
combine(expectedLeaves[0], expectedLeaves[1]),
combine(expectedLeaves[2], zeroHash)
)
check: check:
tree.leaves == leaves[0..2] tree.leaves == expectedLeaves[0..2]
tree.len == 6 tree.len == 6
tree.root == expectedRoot tree.root == expectedRoot
test "tree with ten leaves has expected root":
builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
builder.addDataBlock(data[2])
builder.addDataBlock(data[3])
builder.addDataBlock(data[4])
builder.addDataBlock(data[5])
builder.addDataBlock(data[6])
builder.addDataBlock(data[7])
builder.addDataBlock(data[8])
builder.addDataBlock(data[9])
let tree = builder.build().tryGet()
let
expectedRoot = combine(
combine(
combine(
combine(expectedLeaves[0], expectedLeaves[1]),
combine(expectedLeaves[2], expectedLeaves[3]),
),
combine(
combine(expectedLeaves[4], expectedLeaves[5]),
combine(expectedLeaves[6], expectedLeaves[7])
)
),
combine(
combine(
combine(expectedLeaves[8], expectedLeaves[9]),
zeroHash
),
zeroHash
)
)
check:
tree.leaves == expectedLeaves[0..9]
tree.len == 21
tree.root == expectedRoot
test "tree with two leaves provides expected proofs": test "tree with two leaves provides expected proofs":
let tree = MerkleTree.init(leaves[0..1]).tryGet() builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
let tree = builder.build().tryGet()
let expectedProofs = [ let expectedProofs = [
MerkleProof.init(0, @[leaves[1]]), MerkleProof.init(0, @[expectedLeaves[1]]),
MerkleProof.init(1, @[leaves[0]]), MerkleProof.init(1, @[expectedLeaves[0]]),
] ]
check: check:
@ -75,12 +140,16 @@ checksuite "merkletree":
tree.getProof(1).tryGet() == expectedProofs[1] tree.getProof(1).tryGet() == expectedProofs[1]
test "tree with three leaves provides expected proofs": test "tree with three leaves provides expected proofs":
let tree = MerkleTree.init(leaves[0..2]).tryGet() builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
builder.addDataBlock(data[2])
let tree = builder.build().tryGet()
let expectedProofs = [ let expectedProofs = [
MerkleProof.init(0, @[leaves[1], combine(leaves[2], leaves[2])]), MerkleProof.init(0, @[expectedLeaves[1], combine(expectedLeaves[2], zeroHash)]),
MerkleProof.init(1, @[leaves[0], combine(leaves[2], leaves[2])]), MerkleProof.init(1, @[expectedLeaves[0], combine(expectedLeaves[2], zeroHash)]),
MerkleProof.init(2, @[leaves[2], combine(leaves[0], leaves[1])]), MerkleProof.init(2, @[zeroHash, combine(expectedLeaves[0], expectedLeaves[1])]),
] ]
check: check:
@ -88,21 +157,65 @@ checksuite "merkletree":
tree.getProof(1).tryGet() == expectedProofs[1] tree.getProof(1).tryGet() == expectedProofs[1]
tree.getProof(2).tryGet() == expectedProofs[2] tree.getProof(2).tryGet() == expectedProofs[2]
test "tree with ten leaves provides expected proofs":
builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
builder.addDataBlock(data[2])
builder.addDataBlock(data[3])
builder.addDataBlock(data[4])
builder.addDataBlock(data[5])
builder.addDataBlock(data[6])
builder.addDataBlock(data[7])
builder.addDataBlock(data[8])
builder.addDataBlock(data[9])
let tree = builder.build().tryGet()
let expectedProofs = {
4:
MerkleProof.init(4, @[
expectedLeaves[5],
combine(expectedLeaves[6], expectedLeaves[7]),
combine(
combine(expectedLeaves[0], expectedLeaves[1]),
combine(expectedLeaves[2], expectedLeaves[3]),
),
combine(
combine(
combine(expectedLeaves[8], expectedLeaves[9]),
zeroHash
),
zeroHash
)
]),
9:
MerkleProof.init(9, @[
expectedLeaves[8],
zeroHash,
zeroHash,
combine(
combine(
combine(expectedLeaves[0], expectedLeaves[1]),
combine(expectedLeaves[2], expectedLeaves[3]),
),
combine(
combine(expectedLeaves[4], expectedLeaves[5]),
combine(expectedLeaves[6], expectedLeaves[7])
)
)
]),
}.newTable
check:
tree.getProof(4).tryGet() == expectedProofs[4]
tree.getProof(9).tryGet() == expectedProofs[9]
test "getProof fails for index out of bounds": test "getProof fails for index out of bounds":
let tree = MerkleTree.init(leaves[0..3]).tryGet() builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
builder.addDataBlock(data[2])
let tree = builder.build().tryGet()
check: check:
isErr(tree.getProof(-1))
isErr(tree.getProof(4)) isErr(tree.getProof(4))
test "can create MerkleTree directly from root hash":
let tree = MerkleTree.init(leaves[0], 1)
check:
tree.root == leaves[0]
test "cannot create MerkleTree from leaves with different codec":
let res = MerkleTree.init(@[randomHash(sha256), randomHash(sha512)])
check:
isErr(res)