Add helpers to generate merkle proofs (#381)

This commit is contained in:
KonradStaniec 2021-08-09 12:17:21 +02:00 committed by GitHub
parent 9bc4fa366a
commit 50c0c5f123
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 268 additions and 3 deletions

111
eth/ssz/merkle_tree.nim Normal file
View File

@ -0,0 +1,111 @@
{.push raises: [Defect].}
import
math, sequtils, ssz_serialization, options, algorithm,
nimcrypto/hash,
../common/eth_types, ./types, ./merkleization
const maxTreeDepth: uint64 = 32
const empty: seq[Digest] = @[]
type
MerkleNodeType = enum
LeafType,
NodeType,
ZeroType
MerkleNode = ref object
case kind: MerkleNodeType
of LeafType:
digest: Digest
of NodeType:
innerDigest: Digest
left: MerkleNode
right: MerkleNode
of ZeroType:
depth: uint64
func zeroNodes(): seq[MerkleNode] =
var nodes = newSeq[MerkleNode]()
for i in 0..maxTreeDepth:
nodes.add(MerkleNode(kind: ZeroType, depth: i))
return nodes
let zNodes = zeroNodes()
# This look like something that should be in standard lib.
func splitAt[T](s: openArray[T], idx: uint64): (seq[T], seq[T]) =
var lSeq = newSeq[T]()
var rSeq = newSeq[T]()
for i, e in s:
if (uint64(i) < idx):
lSeq.add(e)
else:
rSeq.add(e)
(lSeq, rSeq)
func splitLeaves(l: openArray[Digest], cap: uint64): (seq[Digest], seq[Digest]) =
if (uint64(len(l)) <= cap):
(l.toSeq(), empty)
else:
splitAt(l, cap)
proc getSubTrees(node: MerkleNode): Option[(MerkleNode, MerkleNode)] =
case node.kind
of LeafType:
return none[(MerkleNode, MerkleNode)]()
of NodeType:
return some((node.left, node.right))
of ZeroType:
if node.depth == 0:
return none[(MerkleNode, MerkleNode)]()
else:
return some((zNodes[node.depth - 1], zNodes[node.depth - 1]))
func hash*(node: MerkleNode): Digest =
case node.kind
of LeafType:
node.digest
of NodeType:
node.innerDigest
of ZeroType:
zeroHashes[node.depth]
func getCapacityAtDepth(depth: uint64): uint64 =
uint64 math.pow(2'f64, float64 depth)
func createTree*(leaves: openArray[Digest], depth: uint64): MerkleNode =
if len(leaves) == 0:
return MerkleNode(kind: ZeroType, depth: depth)
elif depth == 0:
return MerkleNode(kind: LeafType, digest: leaves[0])
else:
let nexLevelDepth = depth - 1
let subCap = getCapacityAtDepth(nexLevelDepth)
let (left, right) = splitLeaves(leaves, subCap)
let leftTree = createTree(left, nexLevelDepth)
let rightTree = createTree(right, nexLevelDepth)
let finalHash = mergeBranches(leftTree.hash(), rightTree.hash())
return MerkleNode(kind: NodeType, innerDigest: finalHash, left: leftTree, right: rightTree)
proc genProof*(tree: MerkleNode, idx: uint64, treeDepth: uint64): seq[Digest] =
var proof = newSeq[Digest]()
var currNode = tree
var currDepth = treeDepth
while currDepth > 0:
let ithBit = (idx shr (currDepth - 1)) and 1
# should be safe to call unsafeGet() as leaves are on lowest level, and depth is
# always larger than 0
let (left, right) = getSubTrees(currNode).unsafeGet()
if ithBit == 1:
proof.add(left.hash())
currNode = right
else:
proof.add(right.hash())
currNode = left
currDepth = currDepth - 1
proof.reverse()
proof
# TODO add method to add leaf to the exisiting tree

View File

@ -12,7 +12,7 @@
{.push raises: [Defect].}
import
math,
math, sequtils,
stew/[bitops2, endians2, ptrops],
stew/ranges/ptr_arith, nimcrypto/[hash, sha2],
serialization/testing/tracing,
@ -70,6 +70,10 @@ template computeDigest*(body: untyped): Digest =
body
finish(h)
func digest(a: openArray[byte]): Digest =
result = computeDigest:
h.update(a)
func digest(a, b: openArray[byte]): Digest =
result = computeDigest:
trs "DIGESTING ARRAYS ", toHex(a), " ", toHex(b)
@ -99,7 +103,7 @@ template mergeBranches(existing: Digest, newData: array[32, byte]): Digest =
trs "MERGING BRANCHES ARRAY"
digest(existing.data, newData)
template mergeBranches(a, b: Digest): Digest =
template mergeBranches*(a, b: Digest): Digest =
trs "MERGING BRANCHES DIGEST"
digest(a.data, b.data)
@ -636,3 +640,29 @@ func isValidProof*(leaf: Digest, proof: openArray[Digest],
value == root
else:
false
proc slice[T](x: openArray[T]): seq[T] = x.toSeq()
# Helper functions to get proof for any element of a list
proc getProofForAllListElements*(list: List): seq[Digest] =
type T = type(list)
type E = ElemType(T)
# basic types have different chunking rules
static:
doAssert (E is not BasicType)
var digests: seq[Digest] = @[]
for e in list:
let root = hash_tree_root(e)
digests.add(root)
var merk = createMerkleizer(list.maxLen)
merk.addChunksAndGenMerkleProofs(digests)
proc getProofWithIdx*(list: List, allProofs: seq[Digest], idx: int): seq[Digest] =
let treeHeight = binaryTreeHeight(list.maxLen)
let startPos = idx * treeHeight
let endPos = startPos + treeHeight - 2
slice(allProofs.toOpenArray(startPos, endPos))
proc generateAndGetProofWithIdx*(list: List, idx: int): seq[Digest] =
let allProofs = getProofForAllListElements(list)
getProofWithIdx(list, allProofs, idx)

View File

@ -1,3 +1,4 @@
import
./test_merkleization,
./test_verification
./test_verification,
./test_proofs

123
tests/ssz/test_proofs.nim Normal file
View File

@ -0,0 +1,123 @@
{.used.}
import
sequtils, unittest, math,
nimcrypto/[hash, sha2],
stew/endians2,
../eth/ssz/merkleization,
../eth/ssz/ssz_serialization,
../eth/ssz/merkle_tree
template toSszType(x: auto): auto =
x
proc h(a: openArray[byte]): Digest =
var h: sha256
h.init()
h.update(a)
h.finish()
type TestObject = object
digest: array[32, byte]
num: uint64
proc genObject(num: uint64): TestObject =
let numAsHash = h(num.toBytesLE())
TestObject(digest: numAsHash.data, num: num)
proc genNObjects(n: int): seq[TestObject] =
var objs = newSeq[TestObject]()
for i in 1..n:
let obj = genObject(uint64 i)
objs.add(obj)
objs
proc getGenIndex(idx: int, depth: uint64): uint64 =
uint64 (math.pow(2'f64, float64 depth) + float64 idx)
# Normal hash_tree_root add list length to final hash calculation. Proofs by default
# are generated without it. If necessary length of the list can be added manually
# at the end of the proof but here we are just hashing list with no mixin.
proc getListRootNoMixin(list: List): Digest =
var merk = createMerkleizer(list.maxLen)
for e in list:
let hash = hash_tree_root(e)
merk.addChunk(hash.data)
merk.getFinalHash()
type TestCase = object
numOfElements: int
limit: int
const TestCases = (
TestCase(numOfElements: 0, limit: 2),
TestCase(numOfElements: 1, limit: 2),
TestCase(numOfElements: 2, limit: 2),
TestCase(numOfElements: 0, limit: 4),
TestCase(numOfElements: 1, limit: 4),
TestCase(numOfElements: 2, limit: 4),
TestCase(numOfElements: 3, limit: 4),
TestCase(numOfElements: 4, limit: 4),
TestCase(numOfElements: 0, limit: 8),
TestCase(numOfElements: 1, limit: 8),
TestCase(numOfElements: 2, limit: 8),
TestCase(numOfElements: 3, limit: 8),
TestCase(numOfElements: 4, limit: 8),
TestCase(numOfElements: 5, limit: 8),
TestCase(numOfElements: 6, limit: 8),
TestCase(numOfElements: 7, limit: 8),
TestCase(numOfElements: 8, limit: 8),
TestCase(numOfElements: 0, limit: 16),
TestCase(numOfElements: 1, limit: 16),
TestCase(numOfElements: 2, limit: 16),
TestCase(numOfElements: 3, limit: 16),
TestCase(numOfElements: 4, limit: 16),
TestCase(numOfElements: 5, limit: 16),
TestCase(numOfElements: 6, limit: 16),
TestCase(numOfElements: 7, limit: 16),
TestCase(numOfElements: 16, limit: 16),
TestCase(numOfElements: 32, limit: 32),
TestCase(numOfElements: 64, limit: 64)
)
suite "Merkle Proof generation":
test "generation of proof for various tree sizes":
for testCase in TestCases.fields:
let testObjects = genNObjects(testCase.numOfElements)
let treeDepth = uint64 binaryTreeHeight(testCase.limit) - 1
# Create List and and genereate root by using merkelizer
let list = List.init(testObjects, testCase.limit)
let listRoot = getListRootNoMixin(list)
# Create sparse merkle tree from list elements and generate root
let listDigests = map(testObjects, proc(x: TestObject): Digest = hash_tree_root(x))
let tree = createTree(listDigests, treeDepth)
let treeHash = tree.hash()
# Assert that by using both methods we get same hash
check listRoot == treeHash
for i, e in list:
# generate proof by using merkelizer
let merkleizerProof = generateAndGetProofWithIdx(list, i)
# generate proof by sparse merkle tree
let sparseTreeProof = genProof(tree, uint64 i, treeDepth)
let leafHash = hash_tree_root(e)
let genIndex = getGenIndex(i, treeDepth)
# both proof are valid. If both are valid that means that both proof are
# effectivly the same
let isValidProof = isValidProof(leafHash , merkleizerProof, genIndex, listRoot)
let isValidProof1 = isValidProof(leafHash , sparseTreeProof, genIndex, listRoot)
check isValidProof
check isValidProof1