perf: flatten merkle tree

A classic encoding of a merkle tree is to store the layers consecutively
in memory, breadth-first - see the layout sketch after the list below.
This encoding has several advantages:

* Good performance for accessing successive nodes, such as when
constructing the tree or serializing it
* Significantly lower memory usage - avoids the per-node allocation
overhead which otherwise more than doubles the memory usage for
"regular" 32-byte hashes
* Less memory management - a single memory allocation can reserve memory
for the whole tree, meaning there are fewer allocations to keep track of
* Simplified buffer lifetimes - with all memory allocated up-front,
there's no need for cross-thread memory management or transfers
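
For illustration, a minimal sketch of how the per-layer offsets into the
flat store can be derived (a standalone example only - the real helpers
are `nodesPerLevel` and `layerOffsets` in merkletree.nim below):

  proc layerOffsetsSketch(nleaves: int): seq[int] =
    # Offset (in nodes) at which each layer starts in the flat store,
    # from the leaves up to the root; node j of layer i then lives at
    # byte offset (result[i] + j) * nodeSize
    if nleaves <= 0:
      return @[]
    var n = nleaves
    var offset = 0
    while true:
      result.add offset
      if n == 1:
        break
      offset += n
      n = (n + 1) div 2 # an odd layer is padded with the zero hash

  doAssert layerOffsetsSketch(4) == @[0, 4, 6] # 4 leaves, 2 inner, 1 root
  doAssert layerOffsetsSketch(5) == @[0, 5, 8, 10]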

While we're here, we can clean up a few other things in the
implementation:

* Move async implementation to `merkletree` so that it doesn't have to
be repeated
* Factor tree construction into preparation and computation - the latter
is the part offloaded onto a different thread (see the usage sketch
after this list)
* Simplify task posting - `taskpools` already creates a "task" from
the worker function call
* Deprecate several high-overhead accessors that presumably are only
needed in tests
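
Rough usage sketch of the resulting split, based on the public API in the
diff below (imports, the surrounding proc and error handling omitted):

  # Synchronous path: construction prepares the flat store and hashes inline
  let tree = ?CodexTree.init(Sha256HashCodec, leaves)

  # Taskpool-backed path: same preparation, but the hashing runs on a
  # worker thread and completion is awaited inside compute(tp)
  let asyncTree = ?await CodexTree.init(tp, Sha256HashCodec, leaves)
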
Jacek Sieka 2025-12-17 13:30:37 +01:00
parent 99a78a41ee
commit db8f81cd63
4 changed files with 293 additions and 220 deletions

@ -21,12 +21,10 @@ import pkg/chronos/threadsync
import ../../utils
import ../../rng
import ../../errors
import ../../blocktype
import ../../codextypes
from ../../utils/digest import digestBytes
import ../../utils/uniqueptr
import ../merkletree
export merkletree
@ -45,8 +43,6 @@ type
ByteTree* = MerkleTree[ByteHash, ByteTreeKey]
ByteProof* = MerkleProof[ByteHash, ByteTreeKey]
CodexTreeTask* = MerkleTask[ByteHash, ByteTreeKey]
CodexTree* = ref object of ByteTree
mcodec*: MultiCodec
@ -119,9 +115,7 @@ func compress*(x, y: openArray[byte], key: ByteTreeKey, codec: MultiCodec): ?!By
let digest = ?MultiHash.digest(codec, input).mapFailure
success digest.digestBytes
func init*(
_: type CodexTree, mcodec: MultiCodec = Sha256HashCodec, leaves: openArray[ByteHash]
): ?!CodexTree =
func initTree(mcodec: MultiCodec, leaves: openArray[ByteHash]): ?!CodexTree =
if leaves.len == 0:
return failure "Empty leaves"
@ -134,70 +128,25 @@ func init*(
if digestSize != leaves[0].len:
return failure "Invalid hash length"
var self = CodexTree(mcodec: mcodec, compress: compressor, zero: Zero)
self.layers = ?merkleTreeWorker(self[], leaves, isBottomLayer = true)
var self = CodexTree(mcodec: mcodec)
?self.prepare(compressor, Zero, leaves)
success self
func init*(
_: type CodexTree, mcodec: MultiCodec = Sha256HashCodec, leaves: openArray[ByteHash]
): ?!CodexTree =
let tree = ?initTree(mcodec, leaves)
?tree.compute()
success tree
proc init*(
_: type CodexTree,
tp: Taskpool,
mcodec: MultiCodec = Sha256HashCodec,
leaves: seq[ByteHash],
): Future[?!CodexTree] {.async: (raises: []).} =
if leaves.len == 0:
return failure "Empty leaves"
let
compressor = proc(x, y: seq[byte], key: ByteTreeKey): ?!ByteHash {.noSideEffect.} =
compress(x, y, key, mcodec)
digestSize = ?mcodec.digestSize.mapFailure
Zero: ByteHash = newSeq[byte](digestSize)
if digestSize != leaves[0].len:
return failure "Invalid hash length"
without signal =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")
defer:
signal.close().expect("closing once works")
var tree = CodexTree(mcodec: mcodec, compress: compressor, zero: Zero)
var
nodesPerLevel = nodesPerLevel(leaves.len)
hashes = nodesPerLevel.foldl(a + b, 0)
layers = newSeq[byte](hashes * digestSize)
var task = CodexTreeTask(
tree: addr tree[],
leaves: SharedBuf.view(leaves),
layers: SharedBuf.view(layers),
signal: signal,
)
doAssert tp.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"
tp.spawn merkleTreeWorker(addr task)
# To support cancellation, we'd have to ensure the task we posted to taskpools
# exits early - since we're not doing that, block cancellation attempts
try:
await noCancel signal.wait()
except AsyncError as exc:
# Since we initialized the signal, the OS or chronos is misbehaving. In any
# case, it would mean the task is still running, which would cause a memory
# violation if we let it run - panic instead
raiseAssert "Could not wait for signal, was it initialized? " & exc.msg
if not task.success.load():
return failure("merkle tree task failed")
tree.layers = unpack[ByteHash](task.layers, leaves.len, digestSize)
): Future[?!CodexTree] {.async: (raises: [CancelledError]).} =
let tree = ?initTree(mcodec, leaves)
?await tree.compute(tp)
success tree
func init*(_: type CodexTree, leaves: openArray[MultiHash]): ?!CodexTree =
@ -262,15 +211,8 @@ proc fromNodes*(
if digestSize != nodes[0].len:
return failure "Invalid hash length"
var
self = CodexTree(compress: compressor, zero: Zero, mcodec: mcodec)
layer = nleaves
pos = 0
while pos < nodes.len:
self.layers.add(nodes[pos ..< (pos + layer)])
pos += layer
layer = divUp(layer, 2)
var self = CodexTree(mcodec: mcodec)
?self.fromNodes(compressor, Zero, nodes, nleaves)
let
index = Rng.instance.rand(nleaves - 1)

@ -9,11 +9,12 @@
{.push raises: [].}
import std/[bitops, atomics]
import std/[bitops, atomics, sequtils]
import stew/assign2
import pkg/questionable/results
import pkg/taskpools
import pkg/chronos
import pkg/chronos/threadsync
import ../errors
@ -21,13 +22,43 @@ import ../utils/sharedbuf
export sharedbuf
template nodeData(
data: openArray[byte], offsets: openArray[int], nodeSize, i, j: int
): openArray[byte] =
## Bytes of the j'th entry of the i'th level in the tree, starting with the
## leaves (at level 0).
let start = (offsets[i] + j) * nodeSize
data.toOpenArray(start, start + nodeSize - 1)
type
# TODO hash functions don't fail - removing the ?! from this function would
# significantly simplify the flow below
CompressFn*[H, K] = proc(x, y: H, key: K): ?!H {.noSideEffect, raises: [].}
CompressData[H, K] = object
fn: CompressFn[H, K]
nodeSize: int
zero: H
MerkleTreeObj*[H, K] = object of RootObj
layers*: seq[seq[H]]
compress*: CompressFn[H, K]
zero*: H
store*: seq[byte]
## Flattened merkle tree where hashes are assumed to be trivial bytes and
## uniform in size.
##
## Each layer of the tree is stored serially starting with the leaves and
## ending with the root.
##
## Because the tree might not be balanced, `layerOffsets` contains the
## index of the starting point of each level, for easy lookup.
layerOffsets*: seq[int]
## Starting point of each level in the tree, starting from the leaves -
## multiplied by the entry size, this is the offset in the payload where
## the entries of that level start
##
## For example, a tree with 4 leaves will have [0, 4, 6] stored here.
##
## See the nodesPerLevel function, from which this sequence is derived
compress*: CompressData[H, K]
MerkleTree*[H, K] = ref MerkleTreeObj[H, K]
@ -39,24 +70,35 @@ type
zero*: H # zero value
MerkleTask*[H, K] = object
tree*: ptr MerkleTreeObj[H, K]
leaves*: SharedBuf[H]
store*: SharedBuf[byte]
layerOffsets: SharedBuf[int]
compress*: ptr CompressData[H, K]
signal*: ThreadSignalPtr
layers*: SharedBuf[byte]
success*: Atomic[bool]
func levels*[H, K](self: MerkleTree[H, K]): int =
return self.layerOffsets.len
func depth*[H, K](self: MerkleTree[H, K]): int =
return self.layers.len - 1
return self.levels() - 1
func nodesInLayer(offsets: openArray[int], layer: int): int =
if layer == offsets.high:
1
else:
offsets[layer + 1] - offsets[layer]
func nodesInLayer(self: MerkleTree | MerkleTreeObj, layer: int): int =
self.layerOffsets.nodesInLayer(layer)
func leavesCount*[H, K](self: MerkleTree[H, K]): int =
return self.layers[0].len
return self.nodesInLayer(0)
func levels*[H, K](self: MerkleTree[H, K]): int =
return self.layers.len
func nodesPerLevel*(nleaves: int): seq[int] =
## Given a number of leaves, the number of nodes at each depth (from the
## bottom/leaves to the root)
func nodesPerLevel(nleaves: int): seq[int] =
## Given a number of leaves, return a seq with the number of nodes at each
## layer of the tree (from the bottom/leaves to the root)
##
## I.e. for a tree of 4 leaves, return `[4, 2, 1]`
if nleaves <= 0:
return @[]
elif nleaves == 1:
@ -73,24 +115,60 @@ func nodesPerLevel*(nleaves: int): seq[int] =
nodes
func leaves*[H, K](self: MerkleTree[H, K]): seq[H] =
return self.layers[0]
func layerOffsets(nleaves: int): seq[int] =
## Given a number of leaves, return a seq of the starting offsets of each
## layer in the node store that results from flattening the binary tree
##
## I.e. for a tree of 4 leaves, return `[0, 4, 6]`
let nodes = nodesPerLevel(nleaves)
var tot = 0
let offsets = nodes.mapIt:
let cur = tot
tot += it
cur
offsets
iterator layers*[H, K](self: MerkleTree[H, K]): seq[H] =
for layer in self.layers:
yield layer
template nodeData(self: MerkleTreeObj, i, j: int): openArray[byte] =
## Bytes of the j'th node of the i'th level in the tree, starting with the
## leaves (at level 0).
self.store.nodeData(self.layerOffsets, self.compress.nodeSize, i, j)
func layer*[H, K](
self: MerkleTree[H, K], layer: int
): seq[H] {.deprecated: "Expensive".} =
var nodes = newSeq[H](self.nodesInLayer(layer))
for i, h in nodes.mpairs:
assign(h, self[].nodeData(layer, i))
return nodes
func leaves*[H, K](self: MerkleTree[H, K]): seq[H] {.deprecated: "Expensive".} =
self.layer(0)
iterator layers*[H, K](self: MerkleTree[H, K]): seq[H] {.deprecated: "Expensive".} =
for i in 0 ..< self.layerOffsets.len:
yield self.layer(i)
proc layers*[H, K](self: MerkleTree[H, K]): seq[seq[H]] {.deprecated: "Expensive".} =
for l in self.layers():
result.add l
iterator nodes*[H, K](self: MerkleTree[H, K]): H =
for layer in self.layers:
for node in layer:
## Iterate over the nodes of each layer starting with the leaves
var node: H
for i in 0 ..< self.layerOffsets.len:
let nodesInLayer = self.nodesInLayer(i)
for j in 0 ..< nodesInLayer:
assign(node, self[].nodeData(i, j))
yield node
func root*[H, K](self: MerkleTree[H, K]): ?!H =
let last = self.layers[^1]
if last.len != 1:
mixin assign
if self.layerOffsets.len == 0:
return failure "invalid tree"
return success last[0]
var h: H
assign(h, self[].nodeData(self.layerOffsets.high(), 0))
return success h
func getProof*[H, K](
self: MerkleTree[H, K], index: int, proof: MerkleProof[H, K]
@ -106,18 +184,19 @@ func getProof*[H, K](
var m = nleaves
for i in 0 ..< depth:
let j = k xor 1
path[i] =
if (j < m):
self.layers[i][j]
else:
self.zero
if (j < m):
assign(path[i], self[].nodeData(i, j))
else:
path[i] = self.compress.zero
k = k shr 1
m = (m + 1) shr 1
proof.index = index
proof.path = path
proof.nleaves = nleaves
proof.compress = self.compress
proof.compress = self.compress.fn
success()
@ -156,68 +235,170 @@ func reconstructRoot*[H, K](proof: MerkleProof[H, K], leaf: H): ?!H =
func verify*[H, K](proof: MerkleProof[H, K], leaf: H, root: H): ?!bool =
success bool(root == ?proof.reconstructRoot(leaf))
func merkleTreeWorker*[H, K](
self: MerkleTreeObj[H, K], xs: openArray[H], isBottomLayer: static bool
): ?!seq[seq[H]] =
let a = low(xs)
let b = high(xs)
let m = b - a + 1
func fromNodes*[H, K](
self: MerkleTree[H, K],
compressor: CompressFn,
zero: H,
nodes: openArray[H],
nleaves: int,
): ?!void =
mixin assign
if nodes.len < 2: # At least leaf and root
return failure "Not enough nodes"
if nleaves == 0:
return failure "No leaves"
self.compress = CompressData[H, K](fn: compressor, nodeSize: nodes[0].len, zero: zero)
self.layerOffsets = layerOffsets(nleaves)
if self.layerOffsets[^1] + 1 != nodes.len:
return failure "bad node count"
self.store = newSeqUninit[byte](nodes.len * self.compress.nodeSize)
for i in 0 ..< nodes.len:
assign(
self[].store.toOpenArray(
i * self.compress.nodeSize, (i + 1) * self.compress.nodeSize - 1
),
nodes[i],
)
success()
func merkleTreeWorker[H, K](
store: var openArray[byte],
offsets: openArray[int],
compress: CompressData[H, K],
layer: int,
isBottomLayer: static bool,
): ?!void =
## Worker used to compute the merkle tree from the leaves that are assumed to
## already be stored at the beginning of the `store`, as done by `prepare`.
# Throughout, we use `assign` to convert from H to bytes and back, assuming
# this assignment can be done somewhat efficiently (i.e. memcpy) - because
# the code must work with multihash, where len(H) can differ, we cannot
# simply use a fixed-size array here.
mixin assign
template nodeData(i, j: int): openArray[byte] =
# Pick out the bytes of node j in layer i
store.nodeData(offsets, compress.nodeSize, i, j)
let m = offsets.nodesInLayer(layer)
when not isBottomLayer:
if m == 1:
return success @[@xs]
return success()
let halfn: int = m div 2
let n: int = 2 * halfn
let isOdd: bool = (n != m)
var ys: seq[H]
if not isOdd:
ys = newSeq[H](halfn)
else:
ys = newSeq[H](halfn + 1)
# Because the compression function we work with works with H and not bytes,
# we need to extract H from the raw data - a little abstraction tax that
# ensures that properties like alignment of H are respected.
var a, b, tmp: H
for i in 0 ..< halfn:
const key = when isBottomLayer: K.KeyBottomLayer else: K.KeyNone
ys[i] = ?self.compress(xs[a + 2 * i], xs[a + 2 * i + 1], key = key)
assign(a, nodeData(layer, i * 2))
assign(b, nodeData(layer, i * 2 + 1))
tmp = ?compress.fn(a, b, key = key)
assign(nodeData(layer + 1, i), tmp)
if isOdd:
const key = when isBottomLayer: K.KeyOddAndBottomLayer else: K.KeyOdd
ys[halfn] = ?self.compress(xs[n], self.zero, key = key)
success @[@xs] & ?self.merkleTreeWorker(ys, isBottomLayer = false)
assign(a, nodeData(layer, n))
proc pack*[H](tgt: SharedBuf[byte], v: seq[seq[H]]) =
# Pack the given nested sequences into a flat buffer
var pos = 0
for layer in v:
for h in layer:
assign(tgt.toOpenArray(pos, pos + h.len - 1), h)
pos += h.len
tmp = ?compress.fn(a, compress.zero, key = key)
proc unpack*[H](src: SharedBuf[byte], nleaves, digestSize: int): seq[seq[H]] =
# Given a flat buffer and the number of leaves, unpack the merkle tree from
# its flat storage
var
nodesPerLevel = nodesPerLevel(nleaves)
res = newSeq[seq[H]](nodesPerLevel.len)
pos = 0
for i, layer in res.mpairs:
layer = newSeq[H](nodesPerLevel[i])
for j, h in layer.mpairs:
assign(h, src.toOpenArray(pos, pos + digestSize - 1))
pos += digestSize
res
assign(nodeData(layer + 1, halfn), tmp)
proc merkleTreeWorker*[H, K](task: ptr MerkleTask[H, K]) {.gcsafe.} =
merkleTreeWorker(store, offsets, compress, layer + 1, false)
proc merkleTreeWorker[H, K](task: ptr MerkleTask[H, K]) {.gcsafe.} =
defer:
discard task[].signal.fireSync()
let res = merkleTreeWorker(
task[].tree[], task[].leaves.toOpenArray(), isBottomLayer = true
).valueOr:
task[].success.store(false)
return
task[].store.toOpenArray(),
task[].layerOffsets.toOpenArray(),
task[].compress[],
0,
isBottomLayer = true,
)
task.layers.pack(res)
task[].success.store(res.isOk())
task[].success.store(true)
func prepare*[H, K](
self: MerkleTree[H, K], compressor: CompressFn, zero: H, leaves: openArray[H]
): ?!void =
## Prepare the instance for computing the merkle tree of the given leaves using
## the given compression function. After preparation, `compute` should be
## called to perform the actual computation. `leaves` will be copied into the
## tree so they can be freed after the call.
if leaves.len == 0:
return failure "No leaves"
self.compress =
CompressData[H, K](fn: compressor, nodeSize: leaves[0].len, zero: zero)
self.layerOffsets = layerOffsets(leaves.len)
self.store = newSeqUninit[byte]((self.layerOffsets[^1] + 1) * self.compress.nodeSize)
for j in 0 ..< leaves.len:
assign(self[].nodeData(0, j), leaves[j])
return success()
proc compute*[H, K](self: MerkleTree[H, K]): ?!void =
merkleTreeWorker(
self.store, self.layerOffsets, self.compress, 0, isBottomLayer = true
)
proc compute*[H, K](
self: MerkleTree[H, K], tp: Taskpool
): Future[?!void] {.async: (raises: []).} =
if tp.numThreads == 1:
# With a single thread, there's no point creating a separate task
return self.compute()
# TODO this signal would benefit from reuse across computations
without signal =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")
defer:
signal.close().expect("closing once works")
var task = MerkleTask[H, K](
store: SharedBuf.view(self.store),
layerOffsets: SharedBuf.view(self.layerOffsets),
compress: addr self.compress,
signal: signal,
)
tp.spawn merkleTreeWorker(addr task)
# To support cancellation, we'd have to ensure the task we posted to taskpools
# exits early - since we're not doing that, block cancellation attempts
try:
await noCancel signal.wait()
except AsyncError as exc:
# Since we initialized the signal, the OS or chronos is misbehaving. In any
# case, it would mean the task is still running, which would cause a memory
# violation if we let it run - panic instead
raiseAssert "Could not wait for signal, was it initialized? " & exc.msg
if not task.success.load():
return failure("merkle tree task failed")
return success()

@ -46,7 +46,16 @@ type
Poseidon2Tree* = MerkleTree[Poseidon2Hash, PoseidonKeysEnum]
Poseidon2Proof* = MerkleProof[Poseidon2Hash, PoseidonKeysEnum]
Poseidon2TreeTask* = MerkleTask[Poseidon2Hash, PoseidonKeysEnum]
proc len*(v: Poseidon2Hash): int =
sizeof(v)
proc assign*(v: var openArray[byte], h: Poseidon2Hash) =
doAssert v.len == sizeof(h)
copyMem(addr v[0], addr h, sizeof(h))
proc assign*(h: var Poseidon2Hash, v: openArray[byte]) =
doAssert v.len == sizeof(h)
copyMem(addr h, addr v[0], sizeof(h))
proc `$`*(self: Poseidon2Tree): string =
let root = if self.root.isOk: self.root.get.toHex else: "none"
@ -67,7 +76,7 @@ converter toKey*(key: PoseidonKeysEnum): Poseidon2Hash =
of KeyOdd: KeyOddF
of KeyOddAndBottomLayer: KeyOddAndBottomLayerF
func init*(_: type Poseidon2Tree, leaves: openArray[Poseidon2Hash]): ?!Poseidon2Tree =
proc initTree(leaves: openArray[Poseidon2Hash]): ?!Poseidon2Tree =
if leaves.len == 0:
return failure "Empty leaves"
@ -76,72 +85,24 @@ func init*(_: type Poseidon2Tree, leaves: openArray[Poseidon2Hash]): ?!Poseidon2
): ?!Poseidon2Hash {.noSideEffect.} =
success compress(x, y, key.toKey)
var self = Poseidon2Tree(compress: compressor, zero: Poseidon2Zero)
self.layers = ?merkleTreeWorker(self[], leaves, isBottomLayer = true)
var self = Poseidon2Tree()
?self.prepare(compressor, Poseidon2Zero, leaves)
success self
proc len*(v: Poseidon2Hash): int =
sizeof(v)
func init*(_: type Poseidon2Tree, leaves: openArray[Poseidon2Hash]): ?!Poseidon2Tree =
let self = ?initTree(leaves)
?self.compute()
proc assign*(v: var openArray[byte], h: Poseidon2Hash) =
doAssert v.len == sizeof(h)
copyMem(addr v[0], addr h, sizeof(h))
proc assign*(h: var Poseidon2Hash, v: openArray[byte]) =
doAssert v.len == sizeof(h)
copyMem(addr h, addr v[0], sizeof(h))
success self
proc init*(
_: type Poseidon2Tree, tp: Taskpool, leaves: seq[Poseidon2Hash]
): Future[?!Poseidon2Tree] {.async: (raises: []).} =
if leaves.len == 0:
return failure "Empty leaves"
): Future[?!Poseidon2Tree] {.async: (raises: [CancelledError]).} =
let self = ?initTree(leaves)
let compressor = proc(
x, y: Poseidon2Hash, key: PoseidonKeysEnum
): ?!Poseidon2Hash {.noSideEffect.} =
success compress(x, y, key.toKey)
?await self.compute(tp)
without signal =? ThreadSignalPtr.new():
return failure("Unable to create thread signal")
defer:
signal.close().expect("closing once works")
var tree = Poseidon2Tree(compress: compressor, zero: Poseidon2Zero)
var
nodesPerLevel = nodesPerLevel(leaves.len)
hashes = nodesPerLevel.foldl(a + b, 0)
layers = newSeq[byte](hashes * sizeof(Poseidon2Hash))
var task = Poseidon2TreeTask(
tree: addr tree[],
leaves: SharedBuf.view(leaves),
layers: SharedBuf.view(layers),
signal: signal,
)
doAssert tp.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"
tp.spawn merkleTreeWorker(addr task)
# To support cancellation, we'd have to ensure the task we posted to taskpools
# exits early - since we're not doing that, block cancellation attempts
try:
await noCancel signal.wait()
except AsyncError as exc:
# Since we initialized the signal, the OS or chronos is misbehaving. In any
# case, it would mean the task is still running, which would cause a memory
# violation if we let it run - panic instead
raiseAssert "Could not wait for signal, was it initialized? " & exc.msg
if not task.success.load():
return failure("merkle tree task failed")
tree.layers = unpack[Poseidon2Hash](task.layers, leaves.len, sizeof(Poseidon2Hash))
success tree
success self
func init*(_: type Poseidon2Tree, leaves: openArray[array[31, byte]]): ?!Poseidon2Tree =
Poseidon2Tree.init(leaves.mapIt(Poseidon2Hash.fromBytes(it)))
@ -154,23 +115,13 @@ proc init*(
proc fromNodes*(
_: type Poseidon2Tree, nodes: openArray[Poseidon2Hash], nleaves: int
): ?!Poseidon2Tree =
if nodes.len == 0:
return failure "Empty nodes"
let compressor = proc(
x, y: Poseidon2Hash, key: PoseidonKeysEnum
): ?!Poseidon2Hash {.noSideEffect.} =
success compress(x, y, key.toKey)
var
self = Poseidon2Tree(compress: compressor, zero: zero)
layer = nleaves
pos = 0
while pos < nodes.len:
self.layers.add(nodes[pos ..< (pos + layer)])
pos += layer
layer = divUp(layer, 2)
let self = Poseidon2Tree()
?self.fromNodes(compressor, Poseidon2Zero, nodes, nleaves)
let
index = Rng.instance.rand(nleaves - 1)

@ -72,7 +72,6 @@ suite "Test Poseidon2Tree":
tree.leaves == expectedLeaves
test "Build poseidon2 tree from byte leaves asynchronously":
echo "Build poseidon2 tree from byte leaves asynchronously"
var tp = Taskpool.new()
defer:
tp.shutdown()