Refine merkle tree construction

This commit is contained in:
Tomasz Bekas 2023-08-15 13:37:14 +02:00
parent e8601274b9
commit 7a45fe8592
No known key found for this signature in database
GPG Key ID: 4854E04C98824959
2 changed files with 310 additions and 177 deletions

View File

@ -1,5 +1,5 @@
## Nim-Codex
## Copyright (c) 2022 Status Research & Development GmbH
## Copyright (c) 2023 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
@ -7,41 +7,29 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import std/sequtils
import std/math
import std/bitops
import std/sugar
import std/strutils
import pkg/libp2p
import pkg/stew/byteutils
import pkg/questionable
import pkg/questionable/results
import pkg/nimcrypto/sha2
const digestSize = sha256.sizeDigest
type
MerkleHash* = MultiHash
MerkleHash* = array[digestSize, byte]
MerkleTree* = object
leavesCount: int
leavesCount: Natural
nodes: seq[MerkleHash]
MerkleProof* = object
index: int
index: Natural
path: seq[MerkleHash]
MerkleTreeBuilder* = object
buffer: seq[MerkleHash]
# Tree constructed from leaves H0..H2 is
# H5=H(H3 & H4)
# / \
# H3=H(H0 & H1) H4=H(H2 & H2)
# / \ /
# H0=H(A) H1=H(B) H2=H(C)
# | | |
# A B C
#
# Memory layout is [H0, H1, H2, H3, H4, H5]
#
# Proofs of inclusion are
# - [H1, H4] for A
# - [H0, H4] for B
# - [H2, H3] for C
###########################################################
# Helper functions
###########################################################
func computeTreeHeight(leavesCount: int): int =
if isPowerOfTwo(leavesCount):
@ -49,48 +37,100 @@ func computeTreeHeight(leavesCount: int): int =
else:
fastLog2(leavesCount) + 2
func getLowHigh(leavesCount, level: int): (int, int) =
  ## Returns the inclusive `(first, last)` positions occupied by `level` in a
  ## flat node sequence where level 0 holds `leavesCount` entries and every
  ## level above it is half as wide, rounded up.
  var
    offset = 0
    levelWidth = leavesCount
    remaining = level
  # Walk past every level that sits below the requested one.
  while remaining > 0:
    offset += levelWidth
    levelWidth = (levelWidth + 1) div 2
    dec remaining
  result = (offset, offset + levelWidth - 1)
func getLowHigh(self: MerkleTree, level: int): (int, int) =
  ## Inclusive index range of `level`, computed from this tree's leaf count.
  result = getLowHigh(self.leavesCount, level)
func computeTotalSize(leavesCount: int): int =
func computeLevels(leavesCount: int): seq[tuple[offset: int, width: int]] =
let height = computeTreeHeight(leavesCount)
getLowHigh(leavesCount, height - 1)[1] + 1
result = newSeq[tuple[offset: int, width: int]](height)
proc getWidth(self: MerkleTree, level: int): int =
let (low, high) = self.getLowHigh(level)
high - low + 1
result[0].offset = 0
result[0].width = leavesCount
for i in 1..<height:
result[i].offset = result[i - 1].offset + result[i - 1].width
result[i].width = (result[i - 1].width + 1) div 2
func getChildren(self: MerkleTree, level, i: int): (MerkleHash, MerkleHash) =
let (low, high) = self.getLowHigh(level - 1)
let leftIdx = low + 2 * i
let rightIdx = min(leftIdx + 1, high)
proc digestFn(data: openArray[byte], output: var MerkleHash): void =
## Hashes `data` with `sha256` and writes the resulting digest into `output`.
var digest = sha256.digest(data)
# Raw copy of `digestSize` bytes from the digest's data into the caller's hash.
copyMem(addr output, addr digest.data[0], digestSize)
(self.nodes[leftIdx], self.nodes[rightIdx])
###########################################################
# MerkleHash
###########################################################
func getSibling(self: MerkleTree, level, i: int): MerkleHash =
  ## Returns the node paired with position `i` on `level`: the left neighbour
  ## for odd positions, the right neighbour for even ones. The right
  ## neighbour's index is clamped to the last index of the level.
  let bounds = self.getLowHigh(level)
  let siblingIdx =
    if i mod 2 != 0:
      bounds[0] + i - 1
    else:
      min(bounds[0] + i + 1, bounds[1])
  self.nodes[siblingIdx]
var zeroHash: MerkleHash
proc setNode(self: var MerkleTree, level, i: int, value: MerkleHash): void =
  ## Stores `value` at position `i` of `level` in the flat node sequence.
  let levelStart = self.getLowHigh(level)[0]
  self.nodes[levelStart + i] = value
proc `$`*(self: MerkleHash): string =
  ## Renders the hash as a hex string, one `toHex` call per byte.
  # Every byte expands to two hex digits, so reserve twice the byte count
  # up front instead of forcing the string to grow while it is filled.
  result = newStringOfCap(2 * self.len)
  for b in self:
    result.add(toHex(b))
###########################################################
# MerkleTreeBuilder
###########################################################
proc addDataBlock*(self: var MerkleTreeBuilder, dataBlock: openArray[byte]): void =
  ## Computes the digest of `dataBlock` and appends it to the builder's
  ## buffer, where it becomes the next leaf.
  ##
  var leaf: MerkleHash
  digestFn(dataBlock, leaf)
  self.buffer.add(leaf)
proc build*(self: MerkleTreeBuilder): ?!MerkleTree =
  ## Builds a tree from previously added data blocks.
  ##
  ## Returns a failure when no data block has been added.
  ##
  ## Tree built from data blocks A, B and C is
  ##            H5=H(H3 & H4)
  ##           /             \
  ##    H3=H(H0 & H1)     H4=H(H2 & HZ)
  ##     /        \         /
  ## H0=H(A)   H1=H(B)   H2=H(C)
  ##    |         |         |
  ##    A         B         C
  ##
  ## where HZ is the module-level `zeroHash`, used as padding whenever a node
  ## has no right sibling. `zeroHash` is declared without an initializer, so
  ## unless it is assigned elsewhere it is all-zero bytes rather than the
  ## digest of any input.
  ##
  ## Memory layout is [H0, H1, H2, H3, H4, H5]
  ##
  let leavesCount = self.buffer.len

  if leavesCount == 0:
    return failure("At least one data block is required")

  let levels = computeLevels(leavesCount)
  # Sizing by the last level's offset alone assumes that level holds only
  # the root.
  let totalSize = levels[^1].offset + 1

  var tree = MerkleTree(leavesCount: leavesCount, nodes: newSeq[MerkleHash](totalSize))

  # copy leaves
  copyMem(addr tree.nodes[0], unsafeAddr self.buffer[0], leavesCount * digestSize)

  # calculate intermediate nodes bottom-up; indexing the level table directly
  # avoids allocating a sliced copy of it just to iterate
  var concatBuf: array[2 * digestSize, byte]
  for levelIndex in 1 ..< levels.len:
    let
      level = levels[levelIndex]
      prevLevel = levels[levelIndex - 1]
      prevLevelEnd = prevLevel.offset + prevLevel.width
    for i in 0 ..< level.width:
      let
        parentIndex = level.offset + i
        leftChildIndex = prevLevel.offset + 2 * i
        rightChildIndex = leftChildIndex + 1
      copyMem(addr concatBuf[0], addr tree.nodes[leftChildIndex], digestSize)
      if rightChildIndex < prevLevelEnd:
        copyMem(addr concatBuf[digestSize], addr tree.nodes[rightChildIndex], digestSize)
      else:
        # odd-width level: the missing right child is padded with zeroHash
        copyMem(addr concatBuf[digestSize], addr zeroHash, digestSize)
      digestFn(concatBuf, tree.nodes[parentIndex])

  return success(tree)
###########################################################
# MerkleTree
###########################################################
proc root*(self: MerkleTree): MerkleHash =
  ## The last element of the node sequence, which is where the root is stored.
  result = self.nodes[self.nodes.high]
proc len*(self: MerkleTree): int =
proc len*(self: MerkleTree): Natural =
## Total number of entries in the tree's node sequence.
self.nodes.len
proc leaves*(self: MerkleTree): seq[MerkleHash] =
@ -99,76 +139,56 @@ proc leaves*(self: MerkleTree): seq[MerkleHash] =
proc nodes*(self: MerkleTree): seq[MerkleHash] =
  ## Returns the tree's flat node sequence.
  result = self.nodes
proc height*(self: MerkleTree): int =
proc height*(self: MerkleTree): Natural =
## Tree height as computed by `computeTreeHeight` from the leaf count.
computeTreeHeight(self.leavesCount)
proc getProof*(self: MerkleTree, index: Natural): ?!MerkleProof =
  ## Extracts the inclusion proof for the leaf at `index`.
  ##
  ## Returns a failure when `index` is not smaller than the leaf count.
  ##
  ## Given a tree built from data blocks A, B and C
  ##         H5
  ##        /  \
  ##      H3    H4
  ##     /  \   /
  ##   H0   H1 H2
  ##   |    |  |
  ##   A    B  C
  ##
  ## Proofs of inclusion (index and path) are
  ## - 0,[H1, H4] for data block A
  ## - 1,[H0, H4] for data block B
  ## - 2,[HZ, H3] for data block C
  ##
  ## where HZ is the module-level `zeroHash`, substituted whenever the
  ## sibling position lies past the end of its level. `zeroHash` is declared
  ## without an initializer, so unless it is assigned elsewhere it is
  ## all-zero bytes rather than the digest of any input.
  ##
  if index >= self.leavesCount:
    return failure("Index " & $index & " out of range [0.." & $self.leaves.high & "]")

  let levels = computeLevels(self.leavesCount)
  var path = newSeq[MerkleHash](levels.len - 1)

  # Every level except the topmost contributes exactly one sibling. Indexing
  # the level table directly avoids allocating a sliced copy of it.
  for levelIndex in 0 ..< levels.len - 1:
    let
      level = levels[levelIndex]
      # position, within this level, of the ancestor of leaf `index`
      i = index div (1 shl levelIndex)
      siblingIndex =
        if i mod 2 == 0:
          level.offset + i + 1
        else:
          level.offset + i - 1
    if siblingIndex < level.offset + level.width:
      path[levelIndex] = self.nodes[siblingIndex]
    else:
      path[levelIndex] = zeroHash

  success(MerkleProof(index: index, path: path))
proc `$`*(self: MerkleTree): string =
  ## Two-line textual dump: the leaf count, then the node sequence.
  "leavesCount: " & $self.leavesCount & "\nnodes: " & $self.nodes
proc getProof*(self: MerkleTree, index: int): ?!MerkleProof =
if index >= self.leavesCount or index < 0:
return failure("Index " & $index & " out of range [0.." & $self.leaves.high & "]" )
###########################################################
# MerkleProof
###########################################################
var path = newSeq[MerkleHash](self.height - 1)
for level in 0..<path.len:
let i = index div (1 shl level)
path[level] = self.getSibling(level, i)
success(MerkleProof(index: index, path: path))
proc initTreeFromLeaves(leaves: openArray[MerkleHash]): ?!MerkleTree =
## Builds a tree from already-hashed leaves.
##
## Fails when `leaves` is empty, when the leaves do not all share the codec
## of the first leaf, or when hashing a pair of children fails.
# Take codec and digest size from the first leaf; bail out if there is none.
without mcodec =? leaves.?[0].?mcodec and
digestSize =? leaves.?[0].?size:
return failure("At least one leaf is required")
if not leaves.allIt(it.mcodec == mcodec):
return failure("All leaves must use the same codec")
let totalSize = computeTotalSize(leaves.len)
var tree = MerkleTree(leavesCount: leaves.len, nodes: newSeq[MerkleHash](totalSize))
# Scratch buffer sized for two digests, reused by every `combine` call.
var buf = newSeq[byte](digestSize * 2)
proc combine(l, r: MerkleHash): ?!MerkleHash =
# Copies `digestSize` bytes from each hash's buffer, left then right, and
# hashes the concatenation with the leaves' codec.
copyMem(addr buf[0], unsafeAddr l.data.buffer[0], digestSize)
copyMem(addr buf[digestSize], unsafeAddr r.data.buffer[0], digestSize)
MultiHash.digest($mcodec, buf).mapErr(
c => newException(CatchableError, "Error calculating hash using codec " & $mcodec & ": " & $c)
)
# copy leaves
for i in 0..<tree.getWidth(0):
tree.setNode(0, i, leaves[i])
# calculate intermediate nodes
for level in 1..<tree.height:
for i in 0..<tree.getWidth(level):
let (left, right) = tree.getChildren(level, i)
# Propagate a hashing failure to the caller instead of continuing.
without mhash =? combine(left, right), error:
return failure(error)
tree.setNode(level, i, mhash)
success(tree)
func init*(
    T: type MerkleTree,
    root: MerkleHash,
    leavesCount: int
): MerkleTree =
  ## Creates a tree sized for `leavesCount` leaves in which only the last
  ## node (the root slot) is set; every other node keeps its default value.
  result = MerkleTree(
    leavesCount: leavesCount,
    nodes: newSeq[MerkleHash](computeTotalSize(leavesCount))
  )
  result.nodes[^1] = root
proc init*(
    T: type MerkleTree,
    leaves: openArray[MerkleHash]
): ?!MerkleTree =
  ## Public constructor that delegates to `initTreeFromLeaves`.
  result = initTreeFromLeaves(leaves)
proc index*(self: MerkleProof): int =
proc index*(self: MerkleProof): Natural =
## Read accessor for the proof's `index` field.
self.index
proc path*(self: MerkleProof): seq[MerkleHash] =
@ -183,7 +203,7 @@ func `==`*(a, b: MerkleProof): bool =
proc init*(
T: type MerkleProof,
index: int,
index: Natural,
path: seq[MerkleHash]
): MerkleProof =
## Creates a proof directly from an `index` and a `path` of hashes.
MerkleProof(index: index, path: path)

View File

@ -1,73 +1,138 @@
import std/unittest
import std/bitops
import std/random
import std/sequtils
import pkg/libp2p
import std/tables
import pkg/questionable/results
import pkg/stew/byteutils
import pkg/nimcrypto/sha2
import codex/merkletree/merkletree
import ../helpers
import pkg/questionable/results
checksuite "merkletree":
const sha256 = multiCodec("sha2-256")
const sha512 = multiCodec("sha2-512")
const data =
[
"0123456789012345678901234567890123456789".toBytes,
"1234567890123456789012345678901234567890".toBytes,
"2345678901234567890123456789012345678901".toBytes,
"3456789012345678901234567890123456789012".toBytes,
"4567890123456789012345678901234567890123".toBytes,
"5678901234567890123456789012345678901234".toBytes,
"6789012345678901234567890123456789012345".toBytes,
"7890123456789012345678901234567890123456".toBytes,
"8901234567890123456789012345678901234567".toBytes,
"9012345678901234567890123456789012345678".toBytes,
]
var zeroHash: MerkleHash
var expectedLeaves: array[data.len, MerkleHash]
var builder: MerkleTreeBuilder
proc randomHash(codec: MultiCodec = sha256): MerkleHash =
var data: array[0..31, byte]
for i in 0..31:
data[i] = rand(uint8)
return MultiHash.digest($codec, data).tryGet()
proc combine(a, b: MerkleHash, codec: MultiCodec = sha256): MerkleHash =
var buf = newSeq[byte](a.size + b.size)
for i in 0..<a.size:
buf[i] = a.data.buffer[i]
for i in 0..<b.size:
buf[i + a.size] = b.data.buffer[i]
return MultiHash.digest($codec, buf).tryGet()
var
leaves: array[0..10, MerkleHash]
proc combine(a, b: MerkleHash): MerkleHash =
  ## Test helper: `sha256` digest over the bytes of `a` followed by the bytes
  ## of `b`.
  let joined = @a & @b
  var digest = sha256.digest(joined)
  result = digest.data
setup:
for i in 0..leaves.high:
leaves[i] = randomHash()
for i in 0..<data.len:
var digest = sha256.digest(data[i])
expectedLeaves[i] = digest.data
builder = MerkleTreeBuilder()
test "tree with one leaf has expected root":
let tree = MerkleTree.init(leaves[0..0]).tryGet()
builder.addDataBlock(data[0])
let tree = builder.build().tryGet()
check:
tree.leaves == leaves[0..0]
tree.root == leaves[0]
tree.leaves == expectedLeaves[0..0]
tree.root == expectedLeaves[0]
tree.len == 1
test "tree with two leaves has expected root":
let
expectedRoot = combine(leaves[0], leaves[1])
builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
let tree = MerkleTree.init(leaves[0..1]).tryGet()
let tree = builder.build().tryGet()
let expectedRoot = combine(expectedLeaves[0], expectedLeaves[1])
check:
tree.leaves == leaves[0..1]
tree.leaves == expectedLeaves[0..1]
tree.len == 3
tree.root == expectedRoot
test "tree with three leaves has expected root":
let
expectedRoot = combine(combine(leaves[0], leaves[1]), combine(leaves[2], leaves[2]))
builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
builder.addDataBlock(data[2])
let tree = MerkleTree.init(leaves[0..2]).tryGet()
let tree = builder.build().tryGet()
let
expectedRoot = combine(
combine(expectedLeaves[0], expectedLeaves[1]),
combine(expectedLeaves[2], zeroHash)
)
check:
tree.leaves == leaves[0..2]
tree.leaves == expectedLeaves[0..2]
tree.len == 6
tree.root == expectedRoot
test "tree with ten leaves has expected root":
  # Feed all ten fixture blocks to the builder, in order.
  for dataBlock in data:
    builder.addDataBlock(dataBlock)

  let tree = builder.build().tryGet()

  # Rebuild the expected tree by hand, level by level.
  let
    n01 = combine(expectedLeaves[0], expectedLeaves[1])
    n23 = combine(expectedLeaves[2], expectedLeaves[3])
    n45 = combine(expectedLeaves[4], expectedLeaves[5])
    n67 = combine(expectedLeaves[6], expectedLeaves[7])
    n89 = combine(expectedLeaves[8], expectedLeaves[9])
    n0123 = combine(n01, n23)
    n4567 = combine(n45, n67)
    # positions without a right sibling are padded with zeroHash
    n89z = combine(n89, zeroHash)
    leftHalf = combine(n0123, n4567)
    rightHalf = combine(n89z, zeroHash)
    expectedRoot = combine(leftHalf, rightHalf)

  check:
    tree.leaves == expectedLeaves[0..9]
    tree.len == 21
    tree.root == expectedRoot
test "tree with two leaves provides expected proofs":
let tree = MerkleTree.init(leaves[0..1]).tryGet()
builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
let tree = builder.build().tryGet()
let expectedProofs = [
MerkleProof.init(0, @[leaves[1]]),
MerkleProof.init(1, @[leaves[0]]),
MerkleProof.init(0, @[expectedLeaves[1]]),
MerkleProof.init(1, @[expectedLeaves[0]]),
]
check:
@ -75,12 +140,16 @@ checksuite "merkletree":
tree.getProof(1).tryGet() == expectedProofs[1]
test "tree with three leaves provides expected proofs":
let tree = MerkleTree.init(leaves[0..2]).tryGet()
builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
builder.addDataBlock(data[2])
let tree = builder.build().tryGet()
let expectedProofs = [
MerkleProof.init(0, @[leaves[1], combine(leaves[2], leaves[2])]),
MerkleProof.init(1, @[leaves[0], combine(leaves[2], leaves[2])]),
MerkleProof.init(2, @[leaves[2], combine(leaves[0], leaves[1])]),
MerkleProof.init(0, @[expectedLeaves[1], combine(expectedLeaves[2], zeroHash)]),
MerkleProof.init(1, @[expectedLeaves[0], combine(expectedLeaves[2], zeroHash)]),
MerkleProof.init(2, @[zeroHash, combine(expectedLeaves[0], expectedLeaves[1])]),
]
check:
@ -88,21 +157,65 @@ checksuite "merkletree":
tree.getProof(1).tryGet() == expectedProofs[1]
tree.getProof(2).tryGet() == expectedProofs[2]
test "tree with ten leaves provides expected proofs":
  # Feed all ten fixture blocks to the builder, in order.
  for dataBlock in data:
    builder.addDataBlock(dataBlock)

  let tree = builder.build().tryGet()

  # Intermediate nodes shared by the two expected paths.
  let
    n01 = combine(expectedLeaves[0], expectedLeaves[1])
    n23 = combine(expectedLeaves[2], expectedLeaves[3])
    n45 = combine(expectedLeaves[4], expectedLeaves[5])
    n67 = combine(expectedLeaves[6], expectedLeaves[7])
    n89 = combine(expectedLeaves[8], expectedLeaves[9])
    n0123 = combine(n01, n23)

  let
    proofFor4 = MerkleProof.init(4, @[
      expectedLeaves[5],
      n67,
      n0123,
      combine(combine(n89, zeroHash), zeroHash)
    ])
    proofFor9 = MerkleProof.init(9, @[
      expectedLeaves[8],
      zeroHash,
      zeroHash,
      combine(n0123, combine(n45, n67))
    ])

  check:
    tree.getProof(4).tryGet() == proofFor4
    tree.getProof(9).tryGet() == proofFor9
test "getProof fails for index out of bounds":
let tree = MerkleTree.init(leaves[0..3]).tryGet()
builder.addDataBlock(data[0])
builder.addDataBlock(data[1])
builder.addDataBlock(data[2])
let tree = builder.build().tryGet()
check:
isErr(tree.getProof(-1))
isErr(tree.getProof(4))
test "can create MerkleTree directly from root hash":
let tree = MerkleTree.init(leaves[0], 1)
check:
tree.root == leaves[0]
test "cannot create MerkleTree from leaves with different codec":
let res = MerkleTree.init(@[randomHash(sha256), randomHash(sha512)])
check:
isErr(res)