Merkle trees can use codecs different than sha256

This commit is contained in:
Tomasz Bekas 2023-08-23 18:38:58 +02:00
parent 0668d54c2e
commit 61f059e04f
No known key found for this signature in database
GPG Key ID: 4854E04C98824959
2 changed files with 190 additions and 147 deletions

View File

@ -9,23 +9,30 @@
import std/math import std/math
import std/bitops import std/bitops
import std/strutils import std/sequtils
import std/sugar
import pkg/questionable/results import pkg/questionable/results
import pkg/nimcrypto/sha2 import pkg/nimcrypto/sha2
import pkg/libp2p/[multicodec, multihash, vbuffer]
const digestSize = sha256.sizeDigest import ../errors
type type
MerkleHash* = array[digestSize, byte]
MerkleTree* = object MerkleTree* = object
mcodec: MultiCodec
digestSize: Natural
leavesCount: Natural leavesCount: Natural
nodes: seq[MerkleHash] nodesBuffer: seq[byte]
MerkleProof* = object MerkleProof* = object
mcodec: MultiCodec
digestSize: Natural
index: Natural index: Natural
path: seq[MerkleHash] nodesBuffer: seq[byte]
MerkleTreeBuilder* = object MerkleTreeBuilder* = object
buffer: seq[MerkleHash] mcodec: MultiCodec
digestSize: Natural
buffer: seq[byte]
########################################################### ###########################################################
# Helper functions # Helper functions
@ -37,41 +44,41 @@ func computeTreeHeight(leavesCount: int): int =
else: else:
fastLog2(leavesCount) + 2 fastLog2(leavesCount) + 2
func computeLevels(leavesCount: int): seq[tuple[offset: int, width: int]] = func computeLevels(leavesCount: int): seq[tuple[offset: int, width: int, index: int]] =
let height = computeTreeHeight(leavesCount) let height = computeTreeHeight(leavesCount)
result = newSeq[tuple[offset: int, width: int]](height) var levels = newSeq[tuple[offset: int, width: int, index: int]](height)
result[0].offset = 0 levels[0].offset = 0
result[0].width = leavesCount levels[0].width = leavesCount
levels[0].index = 0
for i in 1..<height: for i in 1..<height:
result[i].offset = result[i - 1].offset + result[i - 1].width levels[i].offset = levels[i - 1].offset + levels[i - 1].width
result[i].width = (result[i - 1].width + 1) div 2 levels[i].width = (levels[i - 1].width + 1) div 2
levels[i].index = i
levels
proc digestFn(data: openArray[byte], output: var MerkleHash): void = proc digestFn(mcodec: MultiCodec, output: pointer, data: openArray[byte]): ?!void =
var digest = sha256.digest(data) var mhash = ? MultiHash.digest($mcodec, data).mapFailure
copyMem(addr output, addr digest.data[0], digestSize) copyMem(output, addr mhash.data.buffer[mhash.dpos], mhash.size)
success()
###########################################################
# MerkleHash
###########################################################
var zeroHash: MerkleHash
proc `$`*(self: MerkleHash): string =
result = newStringOfCap(self.len)
for i in 0..<self.len:
result.add(toHex(self[i]))
########################################################### ###########################################################
# MerkleTreeBuilder # MerkleTreeBuilder
########################################################### ###########################################################
proc addDataBlock*(self: var MerkleTreeBuilder, dataBlock: openArray[byte]): void = proc init*(
T: type MerkleTreeBuilder,
mcodec: MultiCodec
): ?!MerkleTreeBuilder =
let mhash = ? MultiHash.digest($mcodec, "".toBytes).mapFailure
success(MerkleTreeBuilder(mcodec: mcodec, digestSize: mhash.size, buffer: newSeq[byte]()))
proc addDataBlock*(self: var MerkleTreeBuilder, dataBlock: openArray[byte]): ?!void =
  ## Hashes the data block and adds the result of hashing to a buffer
  ##
  ## The digest is written directly into a freshly reserved slot of
  ## `digestSize` bytes at the end of `self.buffer`. When hashing fails the
  ## slot is released again, so a failed call leaves the builder unchanged
  ## and no zero-filled pseudo-leaf can end up in a tree built later.
  ##
  ## Returns the result of `digestFn`: success, or the hashing failure.
  ##
  let oldLen = self.buffer.len
  # Reserve room for exactly one digest; digestFn writes through the pointer.
  self.buffer.setLen(oldLen + self.digestSize)

  let res = digestFn(self.mcodec, addr self.buffer[oldLen], dataBlock)
  if res.isErr:
    # Roll back the reservation, otherwise `build` would count the
    # untouched (all-zero) slot as a leaf.
    self.buffer.setLen(oldLen)
  res
proc build*(self: MerkleTreeBuilder): ?!MerkleTree = proc build*(self: MerkleTreeBuilder): ?!MerkleTree =
## Builds a tree from previously added data blocks ## Builds a tree from previously added data blocks
@ -79,31 +86,36 @@ proc build*(self: MerkleTreeBuilder): ?!MerkleTree =
## Tree built from data blocks A, B and C is ## Tree built from data blocks A, B and C is
## H5=H(H3 & H4) ## H5=H(H3 & H4)
## / \ ## / \
## H3=H(H0 & H1) H4=H(H2 & HZ) ## H3=H(H0 & H1) H4=H(H2 & 0x00)
## / \ / ## / \ /
## H0=H(A) H1=H(B) H2=H(C) ## H0=H(A) H1=H(B) H2=H(C)
## | | | ## | | |
## A B C ## A B C
## ##
## where HZ=0x0b
##
## Memory layout is [H0, H1, H2, H3, H4, H5] ## Memory layout is [H0, H1, H2, H3, H4, H5]
## ##
let leavesCount = self.buffer.len let
mcodec = self.mcodec
digestSize = self.digestSize
leavesCount = self.buffer.len div self.digestSize
if leavesCount == 0: if leavesCount == 0:
return failure("At least one data block is required") return failure("At least one data block is required")
let levels = computeLevels(leavesCount) let levels = computeLevels(leavesCount)
let totalSize = levels[^1].offset + 1 let totalNodes = levels[^1].offset + 1
var tree = MerkleTree(leavesCount: leavesCount, nodes: newSeq[MerkleHash](totalSize)) var tree = MerkleTree(mcodec: mcodec, digestSize: digestSize, leavesCount: leavesCount, nodesBuffer: newSeq[byte](totalNodes * digestSize))
# copy leaves # copy leaves
copyMem(addr tree.nodes[0], unsafeAddr self.buffer[0], leavesCount * digestSize) copyMem(addr tree.nodesBuffer[0], unsafeAddr self.buffer[0], leavesCount * digestSize)
# calculate intermediate nodes # calculate intermediate nodes
var concatBuf: array[2 * digestSize, byte] var zero = newSeq[byte](self.digestSize)
var one = newSeq[byte](self.digestSize)
one[^1] = 0x01
var concatBuf = newSeq[byte](2 * digestSize)
var prevLevel = levels[0] var prevLevel = levels[0]
for level in levels[1..^1]: for level in levels[1..^1]:
for i in 0..<level.width: for i in 0..<level.width:
@ -111,14 +123,16 @@ proc build*(self: MerkleTreeBuilder): ?!MerkleTree =
let leftChildIndex = prevLevel.offset + 2 * i let leftChildIndex = prevLevel.offset + 2 * i
let rightChildIndex = leftChildIndex + 1 let rightChildIndex = leftChildIndex + 1
copyMem(addr concatBuf[0], addr tree.nodes[leftChildIndex], digestSize) copyMem(addr concatBuf[0], addr tree.nodesBuffer[leftChildIndex * digestSize], digestSize)
var dummyValue = if prevLevel.index == 0: zero else: one
if rightChildIndex < prevLevel.offset + prevLevel.width: if rightChildIndex < prevLevel.offset + prevLevel.width:
copyMem(addr concatBuf[digestSize], addr tree.nodes[rightChildIndex], digestSize) copyMem(addr concatBuf[digestSize], addr tree.nodesBuffer[rightChildIndex * digestSize], digestSize)
else: else:
copyMem(addr concatBuf[digestSize], addr zeroHash, digestSize) copyMem(addr concatBuf[digestSize], addr dummyValue[0], digestSize)
digestFn(concatBuf, tree.nodes[parentIndex]) ? digestFn(mcodec, addr tree.nodesBuffer[parentIndex * digestSize], concatBuf)
prevLevel = level prevLevel = level
return success(tree) return success(tree)
@ -127,17 +141,25 @@ proc build*(self: MerkleTreeBuilder): ?!MerkleTree =
# MerkleTree # MerkleTree
########################################################### ###########################################################
proc root*(self: MerkleTree): MerkleHash = proc nodeBufferToMultiHash(self: (MerkleTree | MerkleProof), index: int): MultiHash =
self.nodes[^1] var buf = newSeq[byte](self.digestSize)
copyMem(addr buf[0], unsafeAddr self.nodesBuffer[index * self.digestSize], self.digestSize)
without mhash =? MultiHash.init($self.mcodec, buf).mapFailure, error:
raise error
mhash
proc len*(self: MerkleTree): Natural = proc len*(self: (MerkleTree | MerkleProof)): Natural =
self.nodes.len self.nodesBuffer.len div self.digestSize
proc leaves*(self: MerkleTree): seq[MerkleHash] = proc nodes*(self: (MerkleTree | MerkleProof)): seq[MultiHash] =
self.nodes[0..<self.leavesCount] toSeq(0..<self.len).map(i => self.nodeBufferToMultiHash(i))
proc nodes*(self: MerkleTree): seq[MerkleHash] = proc root*(self: MerkleTree): MultiHash =
self.nodes let rootIndex = self.len - 1
self.nodeBufferToMultiHash(rootIndex)
proc leaves*(self: MerkleTree): seq[MultiHash] =
toSeq(0..<self.leavesCount).map(i => self.nodeBufferToMultiHash(i))
proc height*(self: MerkleTree): Natural = proc height*(self: MerkleTree): Natural =
computeTreeHeight(self.leavesCount) computeTreeHeight(self.leavesCount)
@ -157,32 +179,39 @@ proc getProof*(self: MerkleTree, index: Natural): ?!MerkleProof =
## Proofs of inclusion (index and path) are ## Proofs of inclusion (index and path) are
## - 0,[H1, H4] for data block A ## - 0,[H1, H4] for data block A
## - 1,[H0, H4] for data block B ## - 1,[H0, H4] for data block B
## - 2,[HZ, H3] for data block C ## - 2,[0x00, H3] for data block C
##
## where HZ=0x0b
## ##
if index >= self.leavesCount: if index >= self.leavesCount:
return failure("Index " & $index & " out of range [0.." & $self.leaves.high & "]" ) return failure("Index " & $index & " out of range [0.." & $(self.leavesCount - 1) & "]" )
var zero = newSeq[byte](self.digestSize)
var one = newSeq[byte](self.digestSize)
one[^1] = 0x01
let levels = computeLevels(self.leavesCount) let levels = computeLevels(self.leavesCount)
var path = newSeq[MerkleHash](levels.len - 1) var proofNodesBuffer = newSeq[byte]((levels.len - 1) * self.digestSize)
for levelIndex, level in levels[0..^2]: for level in levels[0..^2]:
let i = index div (1 shl levelIndex) let i = index div (1 shl level.index)
let siblingIndex = if i mod 2 == 0: let siblingIndex = if i mod 2 == 0:
level.offset + i + 1 level.offset + i + 1
else: else:
level.offset + i - 1 level.offset + i - 1
if siblingIndex < level.offset + level.width: var dummyValue = if level.index == 0: zero else: one
path[levelIndex] = self.nodes[siblingIndex]
else:
path[levelIndex] = zeroHash
success(MerkleProof(index: index, path: path)) if siblingIndex < level.offset + level.width:
copyMem(addr proofNodesBuffer[level.index * self.digestSize], unsafeAddr self.nodesBuffer[siblingIndex * self.digestSize], self.digestSize)
else:
copyMem(addr proofNodesBuffer[level.index * self.digestSize], addr dummyValue[0], self.digestSize)
# path[levelIndex] = zeroHash
success(MerkleProof(mcodec: self.mcodec, digestSize: self.digestSize, index: index, nodesBuffer: proofNodesBuffer))
proc `$`*(self: MerkleTree): string = proc `$`*(self: MerkleTree): string =
result &= "leavesCount: " & $self.leavesCount "mcodec:" & $self.mcodec &
result &= "\nnodes: " & $self.nodes "\nleavesCount: " & $self.leavesCount &
"\nnodes: " & $self.nodes
########################################################### ###########################################################
# MerkleProof # MerkleProof
@ -191,19 +220,28 @@ proc `$`*(self: MerkleTree): string =
proc index*(self: MerkleProof): Natural = proc index*(self: MerkleProof): Natural =
self.index self.index
proc path*(self: MerkleProof): seq[MerkleHash] =
self.path
proc `$`*(self: MerkleProof): string = proc `$`*(self: MerkleProof): string =
result &= "index: " & $self.index "mcodec:" & $self.mcodec &
result &= "\npath: " & $self.path "\nindex: " & $self.index &
"\nnodes: " & $self.nodes
func `==`*(a, b: MerkleProof): bool =
  ## Two proofs are equal only when all of their fields match: the leaf
  ## index, the hash codec, the digest size and the raw bytes of the nodes.
  ##
  ## Every comparison is joined with `and`. Joining the last two with `==`
  ## would instead test whether "sizes match" and "buffers match" have the
  ## same truth value, which makes proofs differing in both compare equal.
  (a.index == b.index) and
    (a.mcodec == b.mcodec) and
    (a.digestSize == b.digestSize) and
    (a.nodesBuffer == b.nodesBuffer)
proc init*(
  T: type MerkleProof,
  index: Natural,
  nodes: seq[MultiHash]
): ?!MerkleProof =
  ## Creates a proof for the leaf at `index` from a sequence of node hashes.
  ##
  ## The codec and digest size of the proof are taken from the first node;
  ## the digests of all nodes are packed back to back into one byte buffer.
  ##
  ## Fails when `nodes` is empty, or when the nodes do not all share the
  ## same codec and digest size.
  ##
  if nodes.len == 0:
    return failure("At least one node is required")

  let
    mcodec = nodes[0].mcodec
    digestSize = nodes[0].size

  # The copy below reads `digestSize` bytes from every node. A node with a
  # different digest length would be read out of bounds or truncated, and a
  # node with a different codec would be stored under the wrong label.
  for node in nodes:
    if node.mcodec != mcodec or node.size != digestSize:
      return failure("All nodes must use the same codec and digest size")

  var nodesBuffer = newSeq[byte](nodes.len * digestSize)
  for nodeIndex, node in nodes:
    # `dpos` is the offset of the digest payload inside the multihash buffer.
    copyMem(
      addr nodesBuffer[nodeIndex * digestSize],
      unsafeAddr node.data.buffer[node.dpos],
      digestSize
    )

  success(MerkleProof(
    mcodec: mcodec,
    digestSize: digestSize,
    index: index,
    nodesBuffer: nodesBuffer
  ))

View File

@ -23,28 +23,35 @@ checksuite "merkletree":
"8901234567890123456789012345678901234567".toBytes, "8901234567890123456789012345678901234567".toBytes,
"9012345678901234567890123456789012345678".toBytes, "9012345678901234567890123456789012345678".toBytes,
] ]
var zeroHash: MerkleHash
var expectedLeaves: array[data.len, MerkleHash]
var builder: MerkleTreeBuilder
proc combine(a, b: MerkleHash): MerkleHash = const sha256 = multiCodec("sha2-256")
var buf = newSeq[byte](a.len + b.len) const sha512 = multiCodec("sha2-512")
for i in 0..<a.len:
buf[i] = a[i] proc combine(a, b: MultiHash, codec: MultiCodec = sha256): MultiHash =
for i in 0..<b.len: var buf = newSeq[byte](a.size + b.size)
buf[i + a.len] = b[i] copyMem(addr buf[0], unsafeAddr a.data.buffer[a.dpos], a.size)
var digest = sha256.digest(buf) copyMem(addr buf[a.size], unsafeAddr b.data.buffer[b.dpos], b.size)
return digest.data return MultiHash.digest($codec, buf).tryGet()
var zeroHash: MultiHash
var oneHash: MultiHash
var expectedLeaves: array[data.len, MultiHash]
var builder: MerkleTreeBuilder
setup: setup:
for i in 0..<data.len: for i in 0..<data.len:
var digest = sha256.digest(data[i]) expectedLeaves[i] = MultiHash.digest($sha256, data[i]).tryGet()
expectedLeaves[i] = digest.data
builder = MerkleTreeBuilder() builder = MerkleTreeBuilder.init(sha256).tryGet()
var zero: array[32, byte]
var one: array[32, byte]
one[^1] = 0x01
zeroHash = MultiHash.init($sha256, zero).tryGet()
oneHash = MultiHash.init($sha256, one).tryGet()
test "tree with one leaf has expected root": test "tree with one leaf has expected structure":
builder.addDataBlock(data[0]) builder.addDataBlock(data[0]).tryGet()
let tree = builder.build().tryGet() let tree = builder.build().tryGet()
@ -53,9 +60,9 @@ checksuite "merkletree":
tree.root == expectedLeaves[0] tree.root == expectedLeaves[0]
tree.len == 1 tree.len == 1
test "tree with two leaves has expected root": test "tree with two leaves has expected structure":
builder.addDataBlock(data[0]) builder.addDataBlock(data[0]).tryGet()
builder.addDataBlock(data[1]) builder.addDataBlock(data[1]).tryGet()
let tree = builder.build().tryGet() let tree = builder.build().tryGet()
@ -66,10 +73,10 @@ checksuite "merkletree":
tree.len == 3 tree.len == 3
tree.root == expectedRoot tree.root == expectedRoot
test "tree with three leaves has expected root": test "tree with three leaves has expected structure":
builder.addDataBlock(data[0]) builder.addDataBlock(data[0]).tryGet()
builder.addDataBlock(data[1]) builder.addDataBlock(data[1]).tryGet()
builder.addDataBlock(data[2]) builder.addDataBlock(data[2]).tryGet()
let tree = builder.build().tryGet() let tree = builder.build().tryGet()
@ -84,17 +91,16 @@ checksuite "merkletree":
tree.len == 6 tree.len == 6
tree.root == expectedRoot tree.root == expectedRoot
test "tree with ten leaves has expected root": test "tree with nine leaves has expected structure":
builder.addDataBlock(data[0]) builder.addDataBlock(data[0]).tryGet()
builder.addDataBlock(data[1]) builder.addDataBlock(data[1]).tryGet()
builder.addDataBlock(data[2]) builder.addDataBlock(data[2]).tryGet()
builder.addDataBlock(data[3]) builder.addDataBlock(data[3]).tryGet()
builder.addDataBlock(data[4]) builder.addDataBlock(data[4]).tryGet()
builder.addDataBlock(data[5]) builder.addDataBlock(data[5]).tryGet()
builder.addDataBlock(data[6]) builder.addDataBlock(data[6]).tryGet()
builder.addDataBlock(data[7]) builder.addDataBlock(data[7]).tryGet()
builder.addDataBlock(data[8]) builder.addDataBlock(data[8]).tryGet()
builder.addDataBlock(data[9])
let tree = builder.build().tryGet() let tree = builder.build().tryGet()
@ -112,27 +118,27 @@ checksuite "merkletree":
), ),
combine( combine(
combine( combine(
combine(expectedLeaves[8], expectedLeaves[9]), combine(expectedLeaves[8], zeroHash),
zeroHash oneHash
), ),
zeroHash oneHash
) )
) )
check: check:
tree.leaves == expectedLeaves[0..9] tree.leaves == expectedLeaves[0..8]
tree.len == 21 tree.len == 20
tree.root == expectedRoot tree.root == expectedRoot
test "tree with two leaves provides expected proofs": test "tree with two leaves provides expected proofs":
builder.addDataBlock(data[0]) builder.addDataBlock(data[0]).tryGet()
builder.addDataBlock(data[1]) builder.addDataBlock(data[1]).tryGet()
let tree = builder.build().tryGet() let tree = builder.build().tryGet()
let expectedProofs = [ let expectedProofs = [
MerkleProof.init(0, @[expectedLeaves[1]]), MerkleProof.init(0, @[expectedLeaves[1]]).tryGet(),
MerkleProof.init(1, @[expectedLeaves[0]]), MerkleProof.init(1, @[expectedLeaves[0]]).tryGet(),
] ]
check: check:
@ -140,16 +146,16 @@ checksuite "merkletree":
tree.getProof(1).tryGet() == expectedProofs[1] tree.getProof(1).tryGet() == expectedProofs[1]
test "tree with three leaves provides expected proofs": test "tree with three leaves provides expected proofs":
builder.addDataBlock(data[0]) builder.addDataBlock(data[0]).tryGet()
builder.addDataBlock(data[1]) builder.addDataBlock(data[1]).tryGet()
builder.addDataBlock(data[2]) builder.addDataBlock(data[2]).tryGet()
let tree = builder.build().tryGet() let tree = builder.build().tryGet()
let expectedProofs = [ let expectedProofs = [
MerkleProof.init(0, @[expectedLeaves[1], combine(expectedLeaves[2], zeroHash)]), MerkleProof.init(0, @[expectedLeaves[1], combine(expectedLeaves[2], zeroHash)]).tryGet(),
MerkleProof.init(1, @[expectedLeaves[0], combine(expectedLeaves[2], zeroHash)]), MerkleProof.init(1, @[expectedLeaves[0], combine(expectedLeaves[2], zeroHash)]).tryGet(),
MerkleProof.init(2, @[zeroHash, combine(expectedLeaves[0], expectedLeaves[1])]), MerkleProof.init(2, @[zeroHash, combine(expectedLeaves[0], expectedLeaves[1])]).tryGet(),
] ]
check: check:
@ -157,17 +163,16 @@ checksuite "merkletree":
tree.getProof(1).tryGet() == expectedProofs[1] tree.getProof(1).tryGet() == expectedProofs[1]
tree.getProof(2).tryGet() == expectedProofs[2] tree.getProof(2).tryGet() == expectedProofs[2]
test "tree with ten leaves provides expected proofs": test "tree with nine leaves provides expected proofs":
builder.addDataBlock(data[0]) builder.addDataBlock(data[0]).tryGet()
builder.addDataBlock(data[1]) builder.addDataBlock(data[1]).tryGet()
builder.addDataBlock(data[2]) builder.addDataBlock(data[2]).tryGet()
builder.addDataBlock(data[3]) builder.addDataBlock(data[3]).tryGet()
builder.addDataBlock(data[4]) builder.addDataBlock(data[4]).tryGet()
builder.addDataBlock(data[5]) builder.addDataBlock(data[5]).tryGet()
builder.addDataBlock(data[6]) builder.addDataBlock(data[6]).tryGet()
builder.addDataBlock(data[7]) builder.addDataBlock(data[7]).tryGet()
builder.addDataBlock(data[8]) builder.addDataBlock(data[8]).tryGet()
builder.addDataBlock(data[9])
let tree = builder.build().tryGet() let tree = builder.build().tryGet()
@ -182,17 +187,17 @@ checksuite "merkletree":
), ),
combine( combine(
combine( combine(
combine(expectedLeaves[8], expectedLeaves[9]), combine(expectedLeaves[8], zeroHash),
zeroHash oneHash
), ),
zeroHash oneHash
) )
]), ]).tryGet(),
9: 8:
MerkleProof.init(9, @[ MerkleProof.init(8, @[
expectedLeaves[8],
zeroHash,
zeroHash, zeroHash,
oneHash,
oneHash,
combine( combine(
combine( combine(
combine(expectedLeaves[0], expectedLeaves[1]), combine(expectedLeaves[0], expectedLeaves[1]),
@ -203,17 +208,17 @@ checksuite "merkletree":
combine(expectedLeaves[6], expectedLeaves[7]) combine(expectedLeaves[6], expectedLeaves[7])
) )
) )
]), ]).tryGet(),
}.newTable }.newTable
check: check:
tree.getProof(4).tryGet() == expectedProofs[4] tree.getProof(4).tryGet() == expectedProofs[4]
tree.getProof(9).tryGet() == expectedProofs[9] tree.getProof(8).tryGet() == expectedProofs[8]
test "getProof fails for index out of bounds": test "getProof fails for index out of bounds":
builder.addDataBlock(data[0]) builder.addDataBlock(data[0]).tryGet()
builder.addDataBlock(data[1]) builder.addDataBlock(data[1]).tryGet()
builder.addDataBlock(data[2]) builder.addDataBlock(data[2]).tryGet()
let tree = builder.build().tryGet() let tree = builder.build().tryGet()