fix: address cross-thread safety issues

For minimal correctness, we must ensure that buffers that cross the
thread boundary are allocated and deallocated within the same thread and
that there is no reference counting going on during the computation.

To get there with minimal changes:

* Preallocate a buffer for the outcome of the merkle tree computation
(see the sizing sketch after this list)
* Pass pointers instead of `ref` types between threads
* Avoid relying on isolation - this is an ORC-only feature
* Add `SharedBuf` as a simple "view" type that allows working with a set
of values while at the same time avoiding allocations and refcounts -
the view checks for out-of-bounds access much like a seq, but the user
is responsible for managing lifetime (which in this case is simple, since
all that needs to happen is for the task to complete)
* In order not to upset the code too much, use a simple linear packer
for the hashes that copies the values back and forth
* Block cancellation and panic if the thread signalling mechanism fails
- cancelling the task itself would require inserting cancellation points
in the computation
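
As an aside, a rough, illustrative sketch of how the preallocation is sized
(the `nodesPerLevel` helper mirrors the one added in this change; the digest
width and leaf count below are made-up example figures):

```nim
import std/sequtils

func nodesPerLevel(nleaves: int): seq[int] =
  ## Number of nodes at each layer, from the bottom/leaves up to the root
  if nleaves <= 0:
    return @[]
  elif nleaves == 1:
    return @[1, 1] # leaf and root
  var m = nleaves
  while true:
    result.add m
    if m == 1:
      break
    m = (m + 1) shr 1 # the next layer holds ceil(m/2) nodes

# For 5 leaves the layers hold 5, 3, 2 and 1 nodes, so the flat output buffer
# is sized once, up front, on the owning thread - the worker never allocates:
const digestSize = 32 # example digest width
let levels = nodesPerLevel(5) # @[5, 3, 2, 1]
let buf = newSeq[byte](levels.foldl(a + b, 0) * digestSize)
doAssert buf.len == 11 * digestSize
```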

The worker task relies on a nuance, namely that calling a closure
procedure does not count as a reference-counting event - while this
works, it can be brittle in "general" code since it's easy to accidentally
make a copy of the closure itself - the refactoring necessary to
address this point is, however, beyond the scope of this change.
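
For illustration, a minimal sketch of that nuance, assuming Nim's ARC/ORC
closure semantics (`makeCompressor` and friends are made-up names, not part
of this change):

```nim
proc makeCompressor(): proc(x, y: int): int =
  var calls = 0 # captured -> lives in a heap-allocated, ref-counted environment
  proc compress(x, y: int): int =
    inc calls # invoking the closure only dereferences its environment
    x + y
  compress

let f = makeCompressor()
discard f(1, 2) # calling `f` does not copy the environment ref - no refcount event
let g = f # copying the closure copies the environment ref and bumps its refcount -
discard g(3, 4) # exactly the kind of accidental copy that is unsafe across threads
```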
Jacek Sieka 2025-12-15 19:59:11 +01:00
parent 9024246349
commit 99a78a41ee
6 changed files with 168 additions and 37 deletions

View File

@@ -136,7 +136,8 @@ func init*(
var self = CodexTree(mcodec: mcodec, compress: compressor, zero: Zero)
self.layers = ?merkleTreeWorker(self, leaves, isBottomLayer = true)
self.layers = ?merkleTreeWorker(self[], leaves, isBottomLayer = true)
success self
proc init*(
@@ -144,7 +145,7 @@ proc init*(
tp: Taskpool,
mcodec: MultiCodec = Sha256HashCodec,
leaves: seq[ByteHash],
): Future[?!CodexTree] {.async: (raises: [CancelledError]).} =
): Future[?!CodexTree] {.async: (raises: []).} =
if leaves.len == 0:
return failure "Empty leaves"
@@ -165,25 +166,37 @@ proc init*(
var tree = CodexTree(mcodec: mcodec, compress: compressor, zero: Zero)
var task =
CodexTreeTask(tree: cast[ptr ByteTree](addr tree), leaves: leaves, signal: signal)
var
nodesPerLevel = nodesPerLevel(leaves.len)
hashes = nodesPerLevel.foldl(a + b, 0)
layers = newSeq[byte](hashes * digestSize)
var task = CodexTreeTask(
tree: addr tree[],
leaves: SharedBuf.view(leaves),
layers: SharedBuf.view(layers),
signal: signal,
)
doAssert tp.numThreads > 1,
"Must have at least one separate thread or signal will never be fired"
tp.spawn merkleTreeWorker(addr task)
let threadFut = signal.wait()
if err =? catch(await threadFut.join()).errorOption:
?catch(await noCancel threadFut)
if err of CancelledError:
raise (ref CancelledError) err
# To support cancellation, we'd have to ensure the task we posted to taskpools
# exits early - since we're not doing that, block cancellation attempts
try:
await noCancel signal.wait()
except AsyncError as exc:
# Since we initialized the signal, the OS or chronos is misbehaving. In any
# case, it would mean the task is still running, which would cause a
# memory violation if we let it run - panic instead
raiseAssert "Could not wait for signal, was it initialized? " & exc.msg
if not task.success.load():
return failure("merkle tree task failed")
tree.layers = extractValue(task.layers)
tree.layers = unpack[ByteHash](task.layers, leaves.len, digestSize)
success tree

View File

@@ -10,22 +10,27 @@
{.push raises: [].}
import std/[bitops, atomics]
import stew/assign2
import pkg/questionable/results
import pkg/taskpools
import pkg/chronos/threadsync
import ../errors
import ../utils/uniqueptr
import ../utils/sharedbuf
export sharedbuf
type
CompressFn*[H, K] = proc(x, y: H, key: K): ?!H {.noSideEffect, raises: [].}
MerkleTree*[H, K] = ref object of RootObj
MerkleTreeObj*[H, K] = object of RootObj
layers*: seq[seq[H]]
compress*: CompressFn[H, K]
zero*: H
MerkleTree*[H, K] = ref MerkleTreeObj[H, K]
MerkleProof*[H, K] = ref object of RootObj
index*: int # linear index of the leaf, starting from 0
path*: seq[H] # order: from the bottom to the top
@@ -34,10 +39,10 @@ type
zero*: H # zero value
MerkleTask*[H, K] = object
tree*: ptr MerkleTree[H, K]
leaves*: seq[H]
tree*: ptr MerkleTreeObj[H, K]
leaves*: SharedBuf[H]
signal*: ThreadSignalPtr
layers*: UniquePtr[seq[seq[H]]]
layers*: SharedBuf[byte]
success*: Atomic[bool]
func depth*[H, K](self: MerkleTree[H, K]): int =
@@ -49,6 +54,25 @@ func leavesCount*[H, K](self: MerkleTree[H, K]): int =
func levels*[H, K](self: MerkleTree[H, K]): int =
return self.layers.len
func nodesPerLevel*(nleaves: int): seq[int] =
## Given a number of leaves, the number of nodes at each depth (from the
## bottom/leaves to the root)
if nleaves <= 0:
return @[]
elif nleaves == 1:
return @[1, 1] # leaf and root
var nodes: seq[int] = @[]
var m = nleaves
while true:
nodes.add(m)
if m == 1:
break
# Next layer size is ceil(m/2)
m = (m + 1) shr 1
nodes
func leaves*[H, K](self: MerkleTree[H, K]): seq[H] =
return self.layers[0]
@@ -133,7 +157,7 @@ func verify*[H, K](proof: MerkleProof[H, K], leaf: H, root: H): ?!bool =
success bool(root == ?proof.reconstructRoot(leaf))
func merkleTreeWorker*[H, K](
self: MerkleTree[H, K], xs: openArray[H], isBottomLayer: static bool
self: MerkleTreeObj[H, K], xs: openArray[H], isBottomLayer: static bool
): ?!seq[seq[H]] =
let a = low(xs)
let b = high(xs)
@@ -162,21 +186,38 @@ func merkleTreeWorker*[H, K](
success @[@xs] & ?self.merkleTreeWorker(ys, isBottomLayer = false)
proc pack*[H](tgt: SharedBuf[byte], v: seq[seq[H]]) =
# Pack the given nested sequences into a flat buffer
var pos = 0
for layer in v:
for h in layer:
assign(tgt.toOpenArray(pos, pos + h.len - 1), h)
pos += h.len
proc unpack*[H](src: SharedBuf[byte], nleaves, digestSize: int): seq[seq[H]] =
# Given a flat buffer and the number of leaves, unpack the merkle tree from
# its flat storage
var
nodesPerLevel = nodesPerLevel(nleaves)
res = newSeq[seq[H]](nodesPerLevel.len)
pos = 0
for i, layer in res.mpairs:
layer = newSeq[H](nodesPerLevel[i])
for j, h in layer.mpairs:
assign(h, src.toOpenArray(pos, pos + digestSize - 1))
pos += digestSize
res
proc merkleTreeWorker*[H, K](task: ptr MerkleTask[H, K]) {.gcsafe.} =
defer:
discard task[].signal.fireSync()
let res = merkleTreeWorker(task[].tree[], task[].leaves, isBottomLayer = true)
if res.isErr:
let res = merkleTreeWorker(
task[].tree[], task[].leaves.toOpenArray(), isBottomLayer = true
).valueOr:
task[].success.store(false)
return
var layers = res.get()
var newOuterSeq = newSeq[seq[H]](layers.len)
for i in 0 ..< layers.len:
var isoInner = isolate(layers[i])
newOuterSeq[i] = extract(isoInner)
task.layers.pack(res)
task[].layers = newUniquePtr(newOuterSeq)
task[].success.store(true)

View File

@@ -19,7 +19,6 @@ import pkg/constantine/platforms/abstractions
import pkg/questionable/results
import ../utils
import ../utils/uniqueptr
import ../rng
import ./merkletree
@@ -79,12 +78,22 @@ func init*(_: type Poseidon2Tree, leaves: openArray[Poseidon2Hash]): ?!Poseidon2Tree =
var self = Poseidon2Tree(compress: compressor, zero: Poseidon2Zero)
self.layers = ?merkleTreeWorker(self, leaves, isBottomLayer = true)
self.layers = ?merkleTreeWorker(self[], leaves, isBottomLayer = true)
success self
proc len*(v: Poseidon2Hash): int =
sizeof(v)
proc assign*(v: var openArray[byte], h: Poseidon2Hash) =
doAssert v.len == sizeof(h)
copyMem(addr v[0], addr h, sizeof(h))
proc assign*(h: var Poseidon2Hash, v: openArray[byte]) =
doAssert v.len == sizeof(h)
copyMem(addr h, addr v[0], sizeof(h))
proc init*(
_: type Poseidon2Tree, tp: Taskpool, leaves: seq[Poseidon2Hash]
): Future[?!Poseidon2Tree] {.async: (raises: [CancelledError]).} =
): Future[?!Poseidon2Tree] {.async: (raises: []).} =
if leaves.len == 0:
return failure "Empty leaves"
@@ -100,8 +109,16 @@ proc init*(
signal.close().expect("closing once works")
var tree = Poseidon2Tree(compress: compressor, zero: Poseidon2Zero)
var
nodesPerLevel = nodesPerLevel(leaves.len)
hashes = nodesPerLevel.foldl(a + b, 0)
layers = newSeq[byte](hashes * sizeof(Poseidon2Hash))
var task = Poseidon2TreeTask(
tree: cast[ptr Poseidon2Tree](addr tree), leaves: leaves, signal: signal
tree: addr tree[],
leaves: SharedBuf.view(leaves),
layers: SharedBuf.view(layers),
signal: signal,
)
doAssert tp.numThreads > 1,
@@ -109,17 +126,20 @@ proc init*(
tp.spawn merkleTreeWorker(addr task)
let threadFut = signal.wait()
if err =? catch(await threadFut.join()).errorOption:
?catch(await noCancel threadFut)
if err of CancelledError:
raise (ref CancelledError) err
# To support cancellation, we'd have to ensure the task we posted to taskpools
# exits early - since we're not doing that, block cancellation attempts
try:
await noCancel signal.wait()
except AsyncError as exc:
# Since we initialized the signal, the OS or chronos is misbehaving. In any
# case, it would mean the task is still running, which would cause a
# memory violation if we let it run - panic instead
raiseAssert "Could not wait for signal, was it initialized? " & exc.msg
if not task.success.load():
return failure("merkle tree task failed")
tree.layers = extractValue(task.layers)
tree.layers = unpack[Poseidon2Hash](task.layers, leaves.len, sizeof(Poseidon2Hash))
success tree

codex/utils/sharedbuf.nim (new file, 24 lines)
View File

@@ -0,0 +1,24 @@
import stew/ptrops
type SharedBuf*[T] = object
payload*: ptr UncheckedArray[T]
len*: int
proc view*[T](_: type SharedBuf, v: openArray[T]): SharedBuf[T] =
if v.len > 0:
SharedBuf[T](payload: makeUncheckedArray(addr v[0]), len: v.len)
else:
default(SharedBuf[T])
template checkIdx(v: SharedBuf, i: int) =
doAssert i >= 0 and i < v.len
proc `[]`*[T](v: SharedBuf[T], i: int): var T =
v.checkIdx(i)
v.payload[i]
template toOpenArray*[T](v: SharedBuf[T]): var openArray[T] =
v.payload.toOpenArray(0, v.len - 1)
template toOpenArray*[T](v: SharedBuf[T], s, e: int): var openArray[T] =
v.toOpenArray().toOpenArray(s, e)

View File

@@ -12,6 +12,13 @@ proc testGenericTree*[H, K, U](
let data = @data
suite "Correctness tests - " & name:
test "Should build correct tree for single leaf":
let expectedRoot = compress(data[0], zero, K.KeyOddAndBottomLayer)
let tree = makeTree(data[0 .. 0])
check:
tree.root.tryGet == expectedRoot
test "Should build correct tree for even bottom layer":
let expectedRoot = compress(
compress(

View File

@@ -96,6 +96,32 @@ suite "Test CodexTree":
tree.get().leaves == expectedLeaves.mapIt(it.mhash.tryGet.digestBytes)
tree.get().mcodec == sha256
test "Should build tree the same tree sync and async":
var tp = Taskpool.new(numThreads = 2)
defer:
tp.shutdown()
let expectedLeaves = data.mapIt(
Cid.init(CidVersion.CIDv1, BlockCodec, MultiHash.digest($sha256, it).tryGet).tryGet
)
let
atree = (await CodexTree.init(tp, leaves = expectedLeaves))
stree = CodexTree.init(leaves = expectedLeaves)
check:
toSeq(atree.get().nodes) == toSeq(stree.get().nodes)
atree.get().root == stree.get().root
# Single-leaf trees have their root separately computed
let
atree1 = (await CodexTree.init(tp, leaves = expectedLeaves[0..0]))
stree1 = CodexTree.init(leaves = expectedLeaves[0..0])
check:
toSeq(atree1.get().nodes) == toSeq(stree1.get().nodes)
atree1.get().root == stree1.get().root
test "Should build from raw digestbytes (should not hash leaves)":
let tree = CodexTree.init(sha256, leaves = data).tryGet