visual-identity/plasma_cash/utils/merkle/sparse_merkle_tree.py
2018-07-31 11:24:58 -04:00

118 lines
4.2 KiB
Python

from collections import OrderedDict
from eth_utils.crypto import keccak
class SparseMerkleTree(object):
def __init__(self, depth=64, leaves={}):
self.depth = depth
if len(leaves) > 2 ** depth:
raise self.TreeSizeExceededException(
'tree with depth {} cannot have {} leaves'.format(
depth, len(leaves)
)
)
# Sort the transaction dict by index.
self.leaves = OrderedDict(sorted(leaves.items(), key=lambda t: t[0]))
self.default_nodes = self.create_default_nodes(self.depth)
if leaves:
self.tree = self.create_tree(
self.leaves, self.depth, self.default_nodes
)
self.root = self.tree[-1][0]
else:
self.tree = []
self.root = self.default_nodes[self.depth]
def create_default_nodes(self, depth):
# Default nodes are the nodes whose children are both empty nodes at
# each level.
default_hash = keccak(b'\x00' * 32)
default_nodes = [default_hash]
for level in range(1, depth + 1):
prev_default = default_nodes[level - 1]
default_nodes.append(keccak(prev_default * 2))
return default_nodes
def create_tree(self, ordered_leaves, depth, default_nodes):
tree = [ordered_leaves]
tree_level = ordered_leaves
for level in range(depth):
next_level = {}
for index, value in tree_level.items():
if index % 2 == 0:
co_index = index + 1
if co_index in tree_level:
next_level[index // 2] = keccak(
value + tree_level[co_index]
)
else:
next_level[index // 2] = keccak(
value + default_nodes[level]
)
else:
# If the node is a right node, check if its left sibling is
# a default node.
co_index = index - 1
if co_index not in tree_level:
next_level[index // 2] = keccak(
default_nodes[level] + value
)
tree_level = next_level
tree.append(tree_level)
return tree
def create_merkle_proof(self, uid):
# Generate a merkle proof for a leaf with provided index.
# First `depth/8` bytes of the proof are necessary for checking if
# we are at a default-node
index = uid
proof = b''
proofbits = 0
# Edge case of tree being empty
if len(self.tree) == 0:
return b'\x00\x00\x00\x00\x00\x00\x00\x00'
for level in range(self.depth):
sibling_index = index + 1 if index % 2 == 0 else index - 1
index = index // 2
if sibling_index in self.tree[level]:
proof += self.tree[level][sibling_index]
proofbits += 2 ** level
proof_bytes = proofbits.to_bytes(8, byteorder='big')
return proof_bytes + proof
def verify(self, uid, proof):
''' Checks if the proof for the leaf at `uid` is valid'''
# assert (len(proof) -8 % 32) == 0
assert len(proof) <= 2056
proofbits = int.from_bytes((proof[0:8]), byteorder='big')
index = uid
p = 8
if index in self.leaves:
computed_hash = self.leaves[index]
# in case the tx is not included, computed_hash is the default leaf
else:
computed_hash = self.default_nodes[-1]
for d in range(self.depth):
if proofbits % 2 == 0:
proof_element = self.default_nodes[d]
else:
proof_element = proof[p : p + 32]
p += 32
if index % 2 == 0:
computed_hash = keccak(computed_hash + proof_element)
else:
computed_hash = keccak(proof_element + computed_hash)
proofbits = proofbits // 2
index = index // 2
return computed_hash == self.root
class TreeSizeExceededException(Exception):
"""there are too many leaves for the tree to build"""