From 50c0c5f12351ed0fe92781922c5c54bddb49726d Mon Sep 17 00:00:00 2001 From: KonradStaniec Date: Mon, 9 Aug 2021 12:17:21 +0200 Subject: [PATCH] Add helpers to generate merkle proofs (#381) --- eth/ssz/merkle_tree.nim | 111 ++++++++++++++++++++++++++++++++++ eth/ssz/merkleization.nim | 34 ++++++++++- tests/ssz/all_tests.nim | 3 +- tests/ssz/test_proofs.nim | 123 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 268 insertions(+), 3 deletions(-) create mode 100644 eth/ssz/merkle_tree.nim create mode 100644 tests/ssz/test_proofs.nim diff --git a/eth/ssz/merkle_tree.nim b/eth/ssz/merkle_tree.nim new file mode 100644 index 0000000..21977aa --- /dev/null +++ b/eth/ssz/merkle_tree.nim @@ -0,0 +1,111 @@ +{.push raises: [Defect].} + +import + math, sequtils, ssz_serialization, options, algorithm, + nimcrypto/hash, + ../common/eth_types, ./types, ./merkleization + +const maxTreeDepth: uint64 = 32 +const empty: seq[Digest] = @[] + +type + MerkleNodeType = enum + LeafType, + NodeType, + ZeroType + + MerkleNode = ref object + case kind: MerkleNodeType + of LeafType: + digest: Digest + of NodeType: + innerDigest: Digest + left: MerkleNode + right: MerkleNode + of ZeroType: + depth: uint64 + +func zeroNodes(): seq[MerkleNode] = + var nodes = newSeq[MerkleNode]() + for i in 0..maxTreeDepth: + nodes.add(MerkleNode(kind: ZeroType, depth: i)) + return nodes + +let zNodes = zeroNodes() + +# This looks like something that should be in the standard library. 
+func splitAt[T](s: openArray[T], idx: uint64): (seq[T], seq[T]) = + var lSeq = newSeq[T]() + var rSeq = newSeq[T]() + for i, e in s: + if (uint64(i) < idx): + lSeq.add(e) + else: + rSeq.add(e) + (lSeq, rSeq) + +func splitLeaves(l: openArray[Digest], cap: uint64): (seq[Digest], seq[Digest]) = + if (uint64(len(l)) <= cap): + (l.toSeq(), empty) + else: + splitAt(l, cap) + +proc getSubTrees(node: MerkleNode): Option[(MerkleNode, MerkleNode)] = + case node.kind + of LeafType: + return none[(MerkleNode, MerkleNode)]() + of NodeType: + return some((node.left, node.right)) + of ZeroType: + if node.depth == 0: + return none[(MerkleNode, MerkleNode)]() + else: + return some((zNodes[node.depth - 1], zNodes[node.depth - 1])) + +func hash*(node: MerkleNode): Digest = + case node.kind + of LeafType: + node.digest + of NodeType: + node.innerDigest + of ZeroType: + zeroHashes[node.depth] + +func getCapacityAtDepth(depth: uint64): uint64 = + uint64 math.pow(2'f64, float64 depth) + +func createTree*(leaves: openArray[Digest], depth: uint64): MerkleNode = + if len(leaves) == 0: + return MerkleNode(kind: ZeroType, depth: depth) + elif depth == 0: + return MerkleNode(kind: LeafType, digest: leaves[0]) + else: + let nexLevelDepth = depth - 1 + let subCap = getCapacityAtDepth(nexLevelDepth) + let (left, right) = splitLeaves(leaves, subCap) + let leftTree = createTree(left, nexLevelDepth) + let rightTree = createTree(right, nexLevelDepth) + let finalHash = mergeBranches(leftTree.hash(), rightTree.hash()) + return MerkleNode(kind: NodeType, innerDigest: finalHash, left: leftTree, right: rightTree) + +proc genProof*(tree: MerkleNode, idx: uint64, treeDepth: uint64): seq[Digest] = + var proof = newSeq[Digest]() + var currNode = tree + var currDepth = treeDepth + while currDepth > 0: + let ithBit = (idx shr (currDepth - 1)) and 1 + # should be safe to call unsafeGet() as leaves are on the lowest level, and depth is + # always larger than 0 + let (left, right) = getSubTrees(currNode).unsafeGet() 
+ if ithBit == 1: + proof.add(left.hash()) + currNode = right + else: + proof.add(right.hash()) + currNode = left + currDepth = currDepth - 1 + + proof.reverse() + proof + +# TODO: add a method to add a leaf to the existing tree diff --git a/eth/ssz/merkleization.nim b/eth/ssz/merkleization.nim index 172d4d7..b553499 100644 --- a/eth/ssz/merkleization.nim +++ b/eth/ssz/merkleization.nim @@ -12,7 +12,7 @@ {.push raises: [Defect].} import - math, + math, sequtils, stew/[bitops2, endians2, ptrops], stew/ranges/ptr_arith, nimcrypto/[hash, sha2], serialization/testing/tracing, @@ -70,6 +70,10 @@ template computeDigest*(body: untyped): Digest = body finish(h) +func digest(a: openArray[byte]): Digest = + result = computeDigest: + h.update(a) + func digest(a, b: openArray[byte]): Digest = result = computeDigest: trs "DIGESTING ARRAYS ", toHex(a), " ", toHex(b) @@ -99,7 +103,7 @@ template mergeBranches(existing: Digest, newData: array[32, byte]): Digest = trs "MERGING BRANCHES ARRAY" digest(existing.data, newData) -template mergeBranches(a, b: Digest): Digest = +template mergeBranches*(a, b: Digest): Digest = trs "MERGING BRANCHES DIGEST" digest(a.data, b.data) @@ -636,3 +640,29 @@ func isValidProof*(leaf: Digest, proof: openArray[Digest], value == root else: false + +proc slice[T](x: openArray[T]): seq[T] = x.toSeq() + +# Helper functions to get proof for any element of a list +proc getProofForAllListElements*(list: List): seq[Digest] = + type T = type(list) + type E = ElemType(T) + # basic types have different chunking rules + static: + doAssert (E is not BasicType) + var digests: seq[Digest] = @[] + for e in list: + let root = hash_tree_root(e) + digests.add(root) + var merk = createMerkleizer(list.maxLen) + merk.addChunksAndGenMerkleProofs(digests) + +proc getProofWithIdx*(list: List, allProofs: seq[Digest], idx: int): seq[Digest] = + let treeHeight = binaryTreeHeight(list.maxLen) + let startPos = idx * treeHeight + let endPos = startPos + treeHeight - 2 
+ slice(allProofs.toOpenArray(startPos, endPos)) + +proc generateAndGetProofWithIdx*(list: List, idx: int): seq[Digest] = + let allProofs = getProofForAllListElements(list) + getProofWithIdx(list, allProofs, idx) diff --git a/tests/ssz/all_tests.nim b/tests/ssz/all_tests.nim index c6d8857..0872cb6 100644 --- a/tests/ssz/all_tests.nim +++ b/tests/ssz/all_tests.nim @@ -1,3 +1,4 @@ import ./test_merkleization, - ./test_verification + ./test_verification, + ./test_proofs diff --git a/tests/ssz/test_proofs.nim b/tests/ssz/test_proofs.nim new file mode 100644 index 0000000..6328b31 --- /dev/null +++ b/tests/ssz/test_proofs.nim @@ -0,0 +1,123 @@ +{.used.} + +import + sequtils, unittest, math, + nimcrypto/[hash, sha2], + stew/endians2, + ../eth/ssz/merkleization, + ../eth/ssz/ssz_serialization, + ../eth/ssz/merkle_tree + +template toSszType(x: auto): auto = + x + +proc h(a: openArray[byte]): Digest = + var h: sha256 + h.init() + h.update(a) + h.finish() + +type TestObject = object + digest: array[32, byte] + num: uint64 + +proc genObject(num: uint64): TestObject = + let numAsHash = h(num.toBytesLE()) + TestObject(digest: numAsHash.data, num: num) + +proc genNObjects(n: int): seq[TestObject] = + var objs = newSeq[TestObject]() + for i in 1..n: + let obj = genObject(uint64 i) + objs.add(obj) + objs + +proc getGenIndex(idx: int, depth: uint64): uint64 = + uint64 (math.pow(2'f64, float64 depth) + float64 idx) + +# Normal hash_tree_root adds the list length to the final hash calculation. Proofs by default +# are generated without it. If necessary, the length of the list can be added manually +# at the end of the proof, but here we are just hashing the list with no mixin. 
+proc getListRootNoMixin(list: List): Digest = + var merk = createMerkleizer(list.maxLen) + for e in list: + let hash = hash_tree_root(e) + merk.addChunk(hash.data) + merk.getFinalHash() + +type TestCase = object + numOfElements: int + limit: int + +const TestCases = ( + TestCase(numOfElements: 0, limit: 2), + TestCase(numOfElements: 1, limit: 2), + TestCase(numOfElements: 2, limit: 2), + + TestCase(numOfElements: 0, limit: 4), + TestCase(numOfElements: 1, limit: 4), + TestCase(numOfElements: 2, limit: 4), + TestCase(numOfElements: 3, limit: 4), + TestCase(numOfElements: 4, limit: 4), + + TestCase(numOfElements: 0, limit: 8), + TestCase(numOfElements: 1, limit: 8), + TestCase(numOfElements: 2, limit: 8), + TestCase(numOfElements: 3, limit: 8), + TestCase(numOfElements: 4, limit: 8), + TestCase(numOfElements: 5, limit: 8), + TestCase(numOfElements: 6, limit: 8), + TestCase(numOfElements: 7, limit: 8), + TestCase(numOfElements: 8, limit: 8), + + TestCase(numOfElements: 0, limit: 16), + TestCase(numOfElements: 1, limit: 16), + TestCase(numOfElements: 2, limit: 16), + TestCase(numOfElements: 3, limit: 16), + TestCase(numOfElements: 4, limit: 16), + TestCase(numOfElements: 5, limit: 16), + TestCase(numOfElements: 6, limit: 16), + TestCase(numOfElements: 7, limit: 16), + TestCase(numOfElements: 16, limit: 16), + + TestCase(numOfElements: 32, limit: 32), + + TestCase(numOfElements: 64, limit: 64) +) + +suite "Merkle Proof generation": + test "generation of proof for various tree sizes": + for testCase in TestCases.fields: + let testObjects = genNObjects(testCase.numOfElements) + let treeDepth = uint64 binaryTreeHeight(testCase.limit) - 1 + + # Create List and generate root by using merkleizer + let list = List.init(testObjects, testCase.limit) + let listRoot = getListRootNoMixin(list) + + # Create sparse merkle tree from list elements and generate root + let listDigests = map(testObjects, proc(x: TestObject): Digest = hash_tree_root(x)) + let tree = 
createTree(listDigests, treeDepth) + let treeHash = tree.hash() + + # Assert that by using both methods we get the same hash + check listRoot == treeHash + + for i, e in list: + # generate proof by using merkleizer + let merkleizerProof = generateAndGetProofWithIdx(list, i) + # generate proof by sparse merkle tree + let sparseTreeProof = genProof(tree, uint64 i, treeDepth) + + let leafHash = hash_tree_root(e) + let genIndex = getGenIndex(i, treeDepth) + + # both proofs are valid. If both are valid that means that both proofs are + # effectively the same + let isValidProof = isValidProof(leafHash , merkleizerProof, genIndex, listRoot) + let isValidProof1 = isValidProof(leafHash , sparseTreeProof, genIndex, listRoot) + + check isValidProof + check isValidProof1 + +