Made leaf nodes more explicit

This commit is contained in:
Vitalik Buterin 2017-08-03 00:23:05 -04:00
parent 20b86c2fb9
commit b82db5d868
2 changed files with 43 additions and 30 deletions

View File

@ -13,30 +13,40 @@ class EphemDB():
# Node type tags: each encode_*_node function stores its tag as the first
# byte of the serialized node, and parse_node dispatches on that byte.
KV_TYPE = 0
BRANCH_TYPE = 1
LEAF_TYPE = 2
# Single-byte constants 0x01 and 0x00.
# NOTE(review): their use is not visible in this section -- presumably
# individual key-path bits; confirm against the rest of the module.
b1 = bytes([1])
b0 = bytes([0])
# Input: a serialized node
# Output: keypath OR left child, child OR right child, node type
def parse_node(node):
    """Split a serialized node into its two components and its type tag.

    The first byte of ``node`` is the type tag and selects the layout:

    * ``BRANCH_TYPE``: returns ``(left child, right child, BRANCH_TYPE)``;
      the left child is bytes 1..32 and the right child is everything
      from byte 33 onwards.
    * ``KV_TYPE``: returns ``(keypath, child, KV_TYPE)``; the child is the
      last 32 bytes and the bytes between the tag and the child are
      decoded with ``decode_bin_path``.
    * ``LEAF_TYPE``: returns ``(None, value, LEAF_TYPE)``; the value is
      everything after the tag byte.

    Raises:
        Exception: ``"Bad node"`` when the tag byte is not a known type.
    """
    # Dispatch purely on the tag byte. A length-based check (len == 64)
    # must not be used here: a tagged key/value or leaf node can also be
    # exactly 64 bytes long and would be misread as a branch.
    tag = node[0]
    if tag == BRANCH_TYPE:
        return node[1:33], node[33:], BRANCH_TYPE
    elif tag == KV_TYPE:
        return decode_bin_path(node[1:-32]), node[-32:], KV_TYPE
    elif tag == LEAF_TYPE:
        return None, node[1:], LEAF_TYPE
    else:
        raise Exception("Bad node")
# Serializes a key/value node
def encode_kv_node(keypath, node):
    """Serialize a key/value node.

    The result is the ``KV_TYPE`` tag byte, followed by the key path as
    encoded by ``encode_bin_path``, followed by the 32-byte child ``node``.

    Args:
        keypath: non-empty key path to store in the node.
        node: the child reference; must be exactly 32 bytes long.

    Raises:
        AssertionError: if ``keypath`` is empty or ``node`` is not 32 bytes.
    """
    assert keypath
    assert len(node) == 32
    # The tag byte identifies the node kind, so the total length of the
    # serialized node is not constrained here.
    return bytes([KV_TYPE]) + encode_bin_path(keypath) + node
# Serializes a branch node (ie. a node with 2 children)
def encode_branch_node(left, right):
    """Serialize a branch node (a node with two children).

    The result is the ``BRANCH_TYPE`` tag byte followed by the left child
    and then the right child.

    Args:
        left: the left child reference; must be exactly 32 bytes long.
        right: the right child reference; must be exactly 32 bytes long.

    Raises:
        AssertionError: if either child is not 32 bytes long.
    """
    assert len(left) == len(right) == 32
    # The tag must be emitted: parse_node decides the layout from byte 0.
    return bytes([BRANCH_TYPE]) + left + right
# Serializes a leaf node
def encode_leaf_node(value):
    """Serialize a leaf node: the ``LEAF_TYPE`` tag byte followed by ``value``."""
    tag = bytes([LEAF_TYPE])
    return tag + value
# Saves a value into the database and returns its hash
def hash_and_save(db, node):
@ -46,11 +56,11 @@ def hash_and_save(db, node):
# Fetches the value with a given keypath from the given node
def _get(db, node, keypath):
if not keypath:
return db.get(node)
L, R, nodetype = parse_node(db.get(node))
# Key-value node descend
if nodetype == KV_TYPE:
if nodetype == LEAF_TYPE:
return R
elif nodetype == KV_TYPE:
if keypath[:len(L)] == L:
return _get(db, R, keypath[len(L):])
else:
@ -64,25 +74,26 @@ def _get(db, node, keypath):
# Updates the value at the given keypath from the given node
def _update(db, node, keypath, val):
# Base case
if not keypath:
if val:
return hash_and_save(db, val)
else:
return b''
# Empty trie
if not node:
return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, val)))
if val:
return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, encode_leaf_node(val))))
else:
return b''
L, R, nodetype = parse_node(db.get(node))
# Node is a leaf node
if nodetype == LEAF_TYPE:
return hash_and_save(db, encode_leaf_node(val)) if val else b''
# node is a key-value node
if nodetype == KV_TYPE:
elif nodetype == KV_TYPE:
# Keypath prefixes match
if keypath[:len(L)] == L:
# Recurse into child
o = _update(db, R, keypath[len(L):], val)
# We are at the end, child is a value node
if len(L) == len(keypath):
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
# If child is empty
if not o:
return b''
#print(db.get(o))
subL, subR, subnodetype = parse_node(db.get(o))
# If the child is a key-value node, compress together the keypaths
# into one node
@ -109,7 +120,7 @@ def _update(db, node, keypath, val):
# Case 2: keypath prefixes mismatch in the middle, so we need to break
# the keypath in half. We are in case (iii), (iv), (vii), (viii)
else:
valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, val)))
valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, encode_leaf_node(val))))
# oldnode: the child node that has the old child value
# Case 1: (i), (iii), (v), (vi)
if len(L) == cf + 1:
@ -158,12 +169,14 @@ def _update(db, node, keypath, val):
# Prints a tree, and checks that all invariants check out
def print_and_check_invariants(db, node, prefix=b''):
if len(prefix) == 160:
return {prefix: db.get(node)}
if node == b'' and prefix == b'':
return {}
L, R, nodetype = parse_node(db.get(node))
if nodetype == KV_TYPE:
if nodetype == LEAF_TYPE:
# All keys must be 160 bits
assert len(prefix) == 160
return {prefix: R}
elif nodetype == KV_TYPE:
# (k1, (k2, node)) two nested key values nodes not allowed
assert 0 < len(L) <= 160 - len(prefix)
if len(L) + len(prefix) < 160:
@ -182,14 +195,13 @@ def print_and_check_invariants(db, node, prefix=b''):
# Pretty-print all nodes in a tree (for debugging purposes)
def print_nodes(db, node, prefix=b''):
if len(prefix) == 160:
print('value node', encode_hex(node[:4]), db.get(node))
return
if node == b'':
print('empty node')
return
L, R, nodetype = parse_node(db.get(node))
if nodetype == KV_TYPE:
if nodetype == LEAF_TYPE:
print('value node', encode_hex(node[:4]), R)
elif nodetype == KV_TYPE:
print(('kv node:', encode_hex(node[:4]), ''.join(['1' if x == 1 else '0' for x in L]), encode_hex(R[:4])))
print_nodes(db, R, prefix + L)
else:

View File

@ -8,7 +8,7 @@ def shuffle_in_place(x):
random.shuffle(y)
return y
kvpairs = [(sha3(str(i))[12:], str(i) * 5) for i in range(2000)]
kvpairs = [(sha3(str(i))[12:], str(i).encode('utf-8') * 5) for i in range(2000)]
for path in ([], [1,0,1], [0,0,1,0], [1,0,0,1,0], [1,0,0,1,0,0,1,0], [1,0] * 8):
@ -37,5 +37,6 @@ for _ in range(3):
t.update(k, b'')
if not random.randrange(100):
t.to_dict()
#t.print_nodes()
assert t.root == b''