research/trie_research/new_bintrie.py

289 lines
10 KiB
Python
Raw Normal View History

2017-08-01 08:42:08 -04:00
from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, encode_bin, decode_bin
from ethereum.utils import sha3, encode_hex
class EphemDB:
    """In-memory key/value store backed by a plain dict (``self.kv``)."""

    def __init__(self):
        # Maps key -> value; exposed publicly because other code assigns to it.
        self.kv = dict()

    def get(self, k):
        """Return the value stored under ``k``, or ``None`` if absent."""
        return self.kv.get(k)

    def put(self, k, v):
        """Store ``v`` under ``k``, replacing any previous value."""
        self.kv[k] = v

    def delete(self, k):
        """Remove ``k``; raises ``KeyError`` if it is not present."""
        self.kv.pop(k)
2017-08-01 08:42:08 -04:00
# Node type tags; each serialized node starts with one of these as its
# first byte.
KV_TYPE, BRANCH_TYPE, LEAF_TYPE = 0, 1, 2

# Single-byte constants compared against one-byte keypath slices to decide
# which side of a branch to follow.
b1 = b'\x01'
b0 = b'\x00'
2017-08-02 22:46:03 -04:00
# Input: a serialized node
def parse_node(node):
    """Split a serialized node into a ``(left, right, node_type)`` triple.

    The first byte of ``node`` is the type tag:

    * ``BRANCH_TYPE`` -> ``(node[1:33], node[33:], BRANCH_TYPE)``, i.e. the
      left and right children.
    * ``KV_TYPE`` -> ``(decoded keypath, last 32 bytes, KV_TYPE)``, i.e. the
      keypath and the child.
    * ``LEAF_TYPE`` -> ``(None, node[1:], LEAF_TYPE)``, i.e. the value.

    Raises:
        Exception: ``"Bad node"`` when the tag is none of the above.
    """
    tag = node[0]
    if tag == BRANCH_TYPE:
        left_child = node[1:33]
        right_child = node[33:]
        return left_child, right_child, BRANCH_TYPE
    if tag == KV_TYPE:
        keypath = decode_bin_path(node[1:-32])
        child = node[-32:]
        return keypath, child, KV_TYPE
    if tag == LEAF_TYPE:
        return None, node[1:], LEAF_TYPE
    raise Exception("Bad node")
2017-08-01 08:42:08 -04:00
2017-08-02 22:46:03 -04:00
# Serializes a key/value node
def encode_kv_node(keypath, node):
    """Serialize a key/value node.

    The result is the ``KV_TYPE`` tag byte, followed by
    ``encode_bin_path(keypath)``, followed by ``node``.

    ``keypath`` must be non-empty and ``node`` must be exactly 32 bytes
    long; both conditions are enforced with ``assert``.
    """
    assert keypath
    assert len(node) == 32
    tag = bytes([KV_TYPE])
    return tag + encode_bin_path(keypath) + node
2017-08-02 22:46:03 -04:00
# Serializes a branch node (ie. a node with 2 children)
def encode_branch_node(left, right):
    """Serialize a branch node as ``BRANCH_TYPE`` tag + ``left`` + ``right``.

    Both children must be exactly 32 bytes long (enforced with ``assert``).
    """
    sizes = (len(left), len(right))
    assert sizes == (32, 32)
    tag = bytes([BRANCH_TYPE])
    return tag + left + right
# Serializes a leaf node
def encode_leaf_node(value):
    """Serialize a leaf node as the ``LEAF_TYPE`` tag byte followed by ``value``."""
    tag = bytes((LEAF_TYPE,))
    return tag + value
2017-08-01 08:42:08 -04:00
2017-08-02 22:46:03 -04:00
# Saves a value into the database and returns its hash
def hash_and_save(db, node):
    """Store ``node`` in ``db`` keyed by ``sha3(node)`` and return that hash.

    ``db`` only needs a ``put(key, value)`` method.
    """
    digest = sha3(node)
    db.put(digest, node)
    return digest
2017-08-02 22:46:03 -04:00
# Fetches the value with a given keypath from the given node
2017-08-01 08:42:08 -04:00
def _get(db, node, keypath):
L, R, nodetype = parse_node(db.get(node))
2017-08-02 22:46:03 -04:00
# Key-value node descend
2017-08-03 00:23:05 -04:00
if nodetype == LEAF_TYPE:
return R
elif nodetype == KV_TYPE:
2017-08-01 08:42:08 -04:00
if keypath[:len(L)] == L:
return _get(db, R, keypath[len(L):])
else:
return None
2017-08-02 22:46:03 -04:00
# Branch node descend
2017-08-01 08:42:08 -04:00
elif nodetype == BRANCH_TYPE:
if keypath[:1] == b0:
return _get(db, L, keypath[1:])
else:
return _get(db, R, keypath[1:])
2017-08-02 22:46:03 -04:00
# Updates the value at the given keypath from the given node
2017-08-01 08:42:08 -04:00
def _update(db, node, keypath, val):
2017-08-03 00:23:05 -04:00
# Empty trie
if not node:
2017-08-01 08:49:08 -04:00
if val:
2017-08-03 00:23:05 -04:00
return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, encode_leaf_node(val))))
2017-08-01 08:49:08 -04:00
else:
return b''
2017-08-01 08:42:08 -04:00
L, R, nodetype = parse_node(db.get(node))
2017-08-03 00:23:05 -04:00
# Node is a leaf node
if nodetype == LEAF_TYPE:
return hash_and_save(db, encode_leaf_node(val)) if val else b''
2017-08-02 22:46:03 -04:00
# node is a key-value node
2017-08-03 00:23:05 -04:00
elif nodetype == KV_TYPE:
2017-08-02 22:46:03 -04:00
# Keypath prefixes match
2017-08-01 08:42:08 -04:00
if keypath[:len(L)] == L:
2017-08-02 22:46:03 -04:00
# Recurse into child
2017-08-01 08:42:08 -04:00
o = _update(db, R, keypath[len(L):], val)
2017-08-03 00:23:05 -04:00
# If child is empty
if not o:
return b''
#print(db.get(o))
2017-08-01 08:42:08 -04:00
subL, subR, subnodetype = parse_node(db.get(o))
2017-08-02 22:46:03 -04:00
# If the child is a key-value node, compress together the keypaths
# into one node
2017-08-01 08:42:08 -04:00
if subnodetype == KV_TYPE:
return hash_and_save(db, encode_kv_node(L + subL, subR))
else:
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
2017-08-02 22:46:03 -04:00
# Keypath prefixes don't match. Here we will be converting a key-value node
# of the form (k, CHILD) into a structure of one of the following forms:
# i. (k[:-1], (NEWCHILD, CHILD))
# ii. (k[:-1], ((k2, NEWCHILD), CHILD))
# iii. (k1, ((k2, CHILD), NEWCHILD))
# iv. (k1, ((k2, CHILD), (k2', NEWCHILD))
# v. (CHILD, NEWCHILD)
# vi. ((k[1:], CHILD), (k', NEWCHILD))
# vii. ((k[1:], CHILD), NEWCHILD)
# viii (CHILD, (k[1:], NEWCHILD))
2017-08-01 08:42:08 -04:00
else:
cf = common_prefix_length(L, keypath[:len(L)])
2017-08-02 22:46:03 -04:00
# valnode: the child node that has the new value we are adding
# Case 1: keypath prefixes almost match, so we are in case (i), (ii), (v), (vi)
2017-08-01 08:42:08 -04:00
if len(keypath) == cf + 1:
valnode = val
2017-08-02 22:46:03 -04:00
# Case 2: keypath prefixes mismatch in the middle, so we need to break
# the keypath in half. We are in case (iii), (iv), (vii), (viii)
2017-08-01 08:42:08 -04:00
else:
2017-08-03 00:23:05 -04:00
valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, encode_leaf_node(val))))
2017-08-02 22:46:03 -04:00
# oldnode: the child node the has the old child value
# Case 1: (i), (iii), (v), (vi)
2017-08-01 08:42:08 -04:00
if len(L) == cf + 1:
oldnode = R
2017-08-02 22:46:03 -04:00
# (ii), (iv), (vi), (viii)
2017-08-01 08:42:08 -04:00
else:
oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R))
2017-08-02 22:46:03 -04:00
# Create the new branch node (because the key paths diverge, there has to
# be some "first bit" at which they diverge, so there must be a branch
# node somewhere)
2017-08-01 08:42:08 -04:00
if keypath[cf:cf+1] == b1:
newsub = hash_and_save(db, encode_branch_node(oldnode, valnode))
else:
newsub = hash_and_save(db, encode_branch_node(valnode, oldnode))
2017-08-02 22:46:03 -04:00
# Case 1: keypath prefixes match in the first bit, so we still need
# a kv node at the top
# (i) (ii) (iii) (iv)
2017-08-01 08:42:08 -04:00
if cf:
return hash_and_save(db, encode_kv_node(L[:cf], newsub))
2017-08-02 22:46:03 -04:00
# Case 2: keypath prefixes diverge in the first bit, so we replace the
# kv node with a branch node
# (v) (vi) (vii) (viii)
2017-08-01 08:42:08 -04:00
else:
return newsub
2017-08-02 22:46:03 -04:00
# node is a branch node
2017-08-01 08:42:08 -04:00
elif nodetype == BRANCH_TYPE:
newL, newR = L, R
2017-08-02 22:46:03 -04:00
# Which child node to update? Depends on first bit in keypath
2017-08-01 08:42:08 -04:00
if keypath[:1] == b0:
newL = _update(db, L, keypath[1:], val)
else:
newR = _update(db, R, keypath[1:], val)
2017-08-02 22:46:03 -04:00
# Compress branch node into kv node
2017-08-01 08:42:08 -04:00
if not newL or not newR:
subL, subR, subnodetype = parse_node(db.get(newL or newR))
first_bit = b1 if newR else b0
2017-08-02 22:46:03 -04:00
# Compress (k1, (k2, NODE)) -> (k1 + k2, NODE)
2017-08-01 08:42:08 -04:00
if subnodetype == KV_TYPE:
return hash_and_save(db, encode_kv_node(first_bit + subL, subR))
2017-08-02 22:46:03 -04:00
# kv node pointing to a branch node
2017-08-01 08:42:08 -04:00
elif subnodetype == BRANCH_TYPE:
return hash_and_save(db, encode_kv_node(first_bit, newL or newR))
else:
return hash_and_save(db, encode_branch_node(newL, newR))
2017-08-02 22:46:03 -04:00
raise Exception("How did I get here?")
2017-08-01 08:42:08 -04:00
2017-08-02 22:46:03 -04:00
# Walks a tree, checks that all invariants hold, and collects its contents
def print_and_check_invariants(db, node, prefix=b''):
    """Return ``{full keypath: value}`` for the subtree, asserting invariants.

    Despite the name, nothing is printed here. The checks, all done with
    ``assert``, are:

    * a leaf is only reached once ``prefix`` is 160 entries long;
    * a kv node's path is non-empty and does not run past 160;
    * a kv node that stops short of 160 does not point at another kv node;
    * neither child hash of a branch equals ``sha3(b'')``.

    An empty root (``node == b''`` with an empty ``prefix``) yields ``{}``.

    NOTE(review): ``_update`` in this file represents an empty subtree as
    ``b''``, not ``sha3(b'')``; confirm the ``sha3(b'')`` comparisons are
    the intended emptiness check.
    """
    if node == b'' and prefix == b'':
        return {}

    left, right, kind = parse_node(db.get(node))

    if kind == LEAF_TYPE:
        # All keys must be 160 bits
        assert len(prefix) == 160
        return {prefix: right}

    if kind == KV_TYPE:
        assert 0 < len(left) <= 160 - len(prefix)
        if len(left) + len(prefix) < 160:
            _, child_right, child_kind = parse_node(db.get(right))
            # (k1, (k2, node)): two nested key-value nodes are not allowed
            assert child_kind != KV_TYPE
            # Children of a key node cannot be empty
            assert child_right != sha3(b'')
        return print_and_check_invariants(db, right, prefix + left)

    # Children of a branch node cannot be empty
    assert left != sha3(b'') and right != sha3(b'')
    collected = dict(print_and_check_invariants(db, left, prefix + b0))
    collected.update(print_and_check_invariants(db, right, prefix + b1))
    return collected
2017-08-02 22:46:03 -04:00
# Pretty-print all nodes in a tree (for debugging purposes)
def print_nodes(db, node, prefix=b''):
    """Print one line per node of the subtree rooted at ``node``.

    Nodes are visited depth-first, left child before right. An empty
    ``node`` prints ``empty node``. Leaf lines are printed as separate
    arguments; kv and branch lines are printed as a single tuple, so they
    appear in tuple ``repr`` form.

    ``prefix`` is only extended and passed down; it is never printed.
    """
    if node == b'':
        print('empty node')
        return

    left, right, kind = parse_node(db.get(node))
    short_hash = encode_hex(node[:4])

    if kind == LEAF_TYPE:
        print('value node', short_hash, right)
        return

    if kind == KV_TYPE:
        bits = ''.join('1' if bit == 1 else '0' for bit in left)
        print(('kv node:', short_hash, bits, encode_hex(right[:4])))
        print_nodes(db, right, prefix + left)
        return

    print(('branch node:', short_hash, encode_hex(left[:4]), encode_hex(right[:4])))
    print_nodes(db, left, prefix + b0)
    print_nodes(db, right, prefix + b1)
2017-08-02 22:46:03 -04:00
# Get a Merkle proof
2017-08-01 08:42:08 -04:00
def _get_branch(db, node, keypath):
if not keypath:
2017-08-02 04:36:44 -04:00
return [db.get(node)]
2017-08-01 08:42:08 -04:00
L, R, nodetype = parse_node(db.get(node))
if nodetype == KV_TYPE:
path = encode_bin_path(L)
if keypath[:len(L)] == L:
2017-08-02 04:36:44 -04:00
return [b'\x01'+path] + _get_branch(db, R, keypath[len(L):])
2017-08-01 08:42:08 -04:00
else:
2017-08-02 04:36:44 -04:00
return [b'\x01'+path, db.get(R)]
2017-08-01 08:42:08 -04:00
elif nodetype == BRANCH_TYPE:
if keypath[:1] == b0:
2017-08-02 04:36:44 -04:00
return [b'\x02'+R] + _get_branch(db, L, keypath[1:])
2017-08-01 08:42:08 -04:00
else:
2017-08-02 04:36:44 -04:00
return [b'\x03'+L] + _get_branch(db, R, keypath[1:])
2017-08-02 22:46:03 -04:00
# Verify a Merkle proof
def _verify_branch(branch, root, keypath, value):
    """Check a proof produced by ``_get_branch``; return ``True`` on success.

    The proof is replayed from its last entry (a serialized node) up to
    the root, rebuilding each parent node from the marker entries. It is
    accepted when:

    * the rebuilt keypath equals ``keypath`` (only checked if ``value`` is
      truthy),
    * the hash of the rebuilt top node equals ``root``, and
    * looking ``keypath`` up in a database holding only the rebuilt nodes
      gives ``value``.

    Raises:
        AssertionError: if any of the checks above fails.
        Exception: ``"Foo"`` for an unknown marker byte, or whatever
            ``parse_node`` raises for a malformed top node.
    """
    top = branch[-1]
    rebuilt_nodes = [top]
    rebuilt_path = b''

    # Walk the remaining entries from the bottom of the proof to the top.
    for entry in reversed(branch[:-1]):
        marker, payload = entry[0], entry[1:]
        if marker == 1:
            # it's a keypath
            bits = decode_bin_path(payload)
            rebuilt_path = bits + rebuilt_path
            top = encode_kv_node(bits, sha3(top))
        elif marker == 2:
            # it's a right-side branch: payload is the right sibling
            rebuilt_path = b0 + rebuilt_path
            top = encode_branch_node(sha3(top), payload)
        elif marker == 3:
            # it's a left-side branch: payload is the left sibling
            rebuilt_path = b1 + rebuilt_path
            top = encode_branch_node(payload, sha3(top))
        else:
            raise Exception("Foo")
        rebuilt_nodes.append(top)

    # Result unused; the call only makes a malformed top node raise.
    parse_node(top)

    if value:
        assert rebuilt_path == keypath
    assert sha3(top) == root

    db = EphemDB()
    db.kv = {sha3(item): item for item in rebuilt_nodes}
    assert _get(db, root, keypath) == value
    return True
2017-08-01 08:42:08 -04:00
2017-08-02 22:46:03 -04:00
# Trie wrapper class
class Trie:
    """Convenience wrapper pairing a node database with a current root hash.

    Keys passed to ``get`` and ``update`` must be exactly 20 items long
    (asserted); they are converted with ``encode_bin`` before use.
    """

    def __init__(self, db, root):
        self.db, self.root = db, root
        assert isinstance(self.root, bytes)

    def get(self, key):
        """Return the value stored under ``key``, as given by ``_get``."""
        assert len(key) == 20
        keypath = encode_bin(key)
        return _get(self.db, self.root, keypath)

    def get_branch(self, key):
        """Return a Merkle proof for ``key`` after self-checking it."""
        proof = _get_branch(self.db, self.root, encode_bin(key))
        assert _verify_branch(proof, self.root, encode_bin(key), self.get(key))
        return proof

    def update(self, key, value):
        """Set ``key`` to ``value`` and move ``self.root`` to the new root."""
        assert len(key) == 20
        keypath = encode_bin(key)
        self.root = _update(self.db, self.root, keypath, value)

    def to_dict(self, hexify=False):
        """Return all key/value pairs, checking trie invariants on the way.

        Keys are passed through ``decode_bin`` and, when ``hexify`` is
        true, additionally through ``encode_hex``.
        """
        leaves = print_and_check_invariants(self.db, self.root)
        result = {}
        for bits, stored in leaves.items():
            raw_key = decode_bin(bits)
            result[encode_hex(raw_key) if hexify else raw_key] = stored
        return result

    def print_nodes(self):
        """Print every node reachable from the current root."""
        print_nodes(self.db, self.root)