diff --git a/codex/merkletree/codexmerkletree/coders.nim b/codex/merkletree/codexmerkletree/coders.nim index 72ffcd08..8a292d3d 100644 --- a/codex/merkletree/codexmerkletree/coders.nim +++ b/codex/merkletree/codexmerkletree/coders.nim @@ -28,13 +28,12 @@ const MaxMerkleProofSize = 1.MiBs.uint proc encode*(self: CodexMerkleTree): seq[byte] = var pb = initProtoBuffer(maxSize = MaxMerkleTreeSize) pb.write(1, self.mcodec.uint64) - pb.write(2, self.digestSize.uint64) - pb.write(3, self.leavesCount.uint64) - var nodesPb = initProtoBuffer(maxSize = MaxMerkleTreeSize) + pb.write(2, self.leavesCount.uint64) for node in self.nodes: + var nodesPb = initProtoBuffer(maxSize = MaxMerkleTreeSize) nodesPb.write(1, node) - nodesPb.finish() - pb.write(4, nodesPb) + nodesPb.finish() + pb.write(3, nodesPb) pb.finish pb.buffer @@ -42,11 +41,9 @@ proc encode*(self: CodexMerkleTree): seq[byte] = proc decode*(_: type CodexMerkleTree, data: seq[byte]): ?!CodexMerkleTree = var pb = initProtoBuffer(data, maxSize = MaxMerkleTreeSize) var mcodecCode: uint64 - var digestSize: uint64 var leavesCount: uint64 discard ? pb.getField(1, mcodecCode).mapFailure - discard ? pb.getField(2, digestSize).mapFailure - discard ? pb.getField(3, leavesCount).mapFailure + discard ? pb.getField(2, leavesCount).mapFailure let mcodec = MultiCodec.codec(mcodecCode.int) if mcodec == InvalidMultiCodec: @@ -56,42 +53,42 @@ proc decode*(_: type CodexMerkleTree, data: seq[byte]): ?!CodexMerkleTree = nodesBuff: seq[seq[byte]] nodes: seq[ByteHash] - if ? pb.getRepeatedField(4, nodesBuff).mapFailure: + if ? pb.getRepeatedField(3, nodesBuff).mapFailure: for nodeBuff in nodesBuff: var node: ByteHash - let nodePb = initProtoBuffer(nodeBuff) - discard ? nodePb.getField(1, node).mapFailure + discard ? initProtoBuffer(nodeBuff).getField(1, node).mapFailure nodes.add node - let tree = ? CodexMerkleTree.fromNodes(mcodec, digestSize, leavesCount, nodesBuffer) - success(tree) + CodexMerkleTree.fromNodes(mcodec, nodes, leavesCount.int) proc encode*(self: CodexMerkleProof): seq[byte] = var pb = initProtoBuffer(maxSize = MaxMerkleProofSize) pb.write(1, self.mcodec.uint64) - pb.write(2, self.digestSize.uint64) - pb.write(3, self.index.uint64) - var nodesPb = initProtoBuffer(maxSize = MaxMerkleTreeSize) + pb.write(2, self.index.uint64) + pb.write(3, self.nleaves.uint64) + for node in self.path: + var nodesPb = initProtoBuffer(maxSize = MaxMerkleTreeSize) nodesPb.write(1, node) - nodesPb.finish() - pb.write(4, nodesPb) + nodesPb.finish() + pb.write(4, nodesPb) + pb.finish pb.buffer proc decode*(_: type CodexMerkleProof, data: seq[byte]): ?!CodexMerkleProof = var pb = initProtoBuffer(data, maxSize = MaxMerkleProofSize) var mcodecCode: uint64 - var digestSize: uint64 var index: uint64 + var nleaves: uint64 discard ? pb.getField(1, mcodecCode).mapFailure let mcodec = MultiCodec.codec(mcodecCode.int) if mcodec == InvalidMultiCodec: return failure("Invalid MultiCodec code " & $mcodecCode) - discard ? pb.getField(2, digestSize).mapFailure - discard ? pb.getField(3, index).mapFailure + discard ? pb.getField(2, index).mapFailure + discard ? pb.getField(3, nleaves).mapFailure var nodesBuff: seq[seq[byte]] @@ -104,7 +101,4 @@ proc decode*(_: type CodexMerkleProof, data: seq[byte]): ?!CodexMerkleProof = discard ? nodePb.getField(1, node).mapFailure nodes.add node - let - proof = ? CodexMerkleProof.init(mcodec, index.int, nodes) - - success(proof) + CodexMerkleProof.init(mcodec, index.int, nleaves.int, nodes) diff --git a/codex/merkletree/codexmerkletree/codexmerkletree.nim b/codex/merkletree/codexmerkletree/codexmerkletree.nim index ae3e98d4..f51aec84 100644 --- a/codex/merkletree/codexmerkletree/codexmerkletree.nim +++ b/codex/merkletree/codexmerkletree/codexmerkletree.nim @@ -24,6 +24,8 @@ import pkg/questionable/results import pkg/libp2p/[cid, multicodec, multihash] import pkg/stew/byteutils +import ../../utils +import ../../rng import ../../errors import ../../blocktype @@ -86,23 +88,29 @@ func getProof*(self: CodexMerkleTree, index: int): ?!CodexMerkleProof = success proof -func verify*(self: CodexMerkleProof, root: MultiHash): ?!void = +func verify*(self: CodexMerkleProof, leaf: MultiHash, root: MultiHash): ?!void = ## Verify hash ## let - bytes = root.bytes + rootBytes = root.bytes + leafBytes = leaf.bytes - if self.mcodec != root.mcodec: + if self.mcodec != root.mcodec or + self.mcodec != leaf.mcodec: return failure "Hash codec mismatch" - if bytes.len != root.size: + if rootBytes.len != root.size and + leafBytes.len != leaf.size: return failure "Invalid hash length" - ? self.verify(bytes) + ? self.verify(leafBytes, rootBytes) success() +func verify*(self: CodexMerkleProof, leaf: Cid, root: Cid): ?!void = + self.verify(? leaf.mhash.mapFailure, ? leaf.mhash.mapFailure) + proc rootCid*( self: CodexMerkleTree, version = CIDv1, @@ -133,6 +141,17 @@ func getLeafCid*( dataCodec, ? MultiHash.init(self.mcodec, self.root).mapFailure).mapFailure +proc `==`*(a, b: CodexMerkleTree): bool = + (a.mcodec == b.mcodec) and + (a.leavesCount == b.leavesCount) and + (a.levels == b.levels) + +proc `==`*(a, b: CodexMerkleProof): bool = + (a.mcodec == b.mcodec) and + (a.nleaves == b.nleaves) and + (a.path == b.path) and + (a.index == b.index) + func compress*( x, y: openArray[byte], key: ByteTreeKey, @@ -192,10 +211,10 @@ func init*( CodexMerkleTree.init(mcodec, leaves) -func fromNodes*( +proc fromNodes*( _: type CodexMerkleTree, mcodec: MultiCodec, - nodes: openArray[seq[ByteHash]], + nodes: openArray[ByteHash], nleaves: int): ?!CodexMerkleTree = if nodes.len == 0: @@ -203,33 +222,35 @@ func fromNodes*( let mhash = ? mcodec.getMhash() - Zero = newSeq[ByteHash](mhash.size) - compressor = proc(x, y: openArray[byte], key: ByteTreeKey): ?!ByteHash {.noSideEffect.} = + Zero = newSeq[byte](mhash.size) + compressor = proc(x, y: seq[byte], key: ByteTreeKey): ?!ByteHash {.noSideEffect.} = compress(x, y, key, mhash) if mhash.size != nodes[0].len: return failure "Invalid hash length" - let - self = CodexMerkleTree(compress: compressor, zero: Zero, mhash: mhash) - var + self = CodexMerkleTree(compress: compressor, zero: Zero, mhash: mhash) layer = nleaves pos = 0 - while layer > 0: - self.layers.add( nodes[pos..