Added comments to code

This commit is contained in:
Vitalik Buterin 2017-08-02 22:46:03 -04:00
parent a829adb72b
commit 20b86c2fb9
1 changed files with 59 additions and 6 deletions

View File

@ -17,12 +17,15 @@ BRANCH_TYPE = 1
b1 = bytes([1]) b1 = bytes([1])
b0 = bytes([0]) b0 = bytes([0])
# Input: a serialized node
# Output: keypath OR left child, child OR right child, node typr
def parse_node(node): def parse_node(node):
if len(node) == 64: if len(node) == 64:
return node[:32], node[32:], BRANCH_TYPE return node[:32], node[32:], BRANCH_TYPE
else: else:
return decode_bin_path(node[:-32]), node[-32:], KV_TYPE return decode_bin_path(node[:-32]), node[-32:], KV_TYPE
# Serializes a key/value node
def encode_kv_node(keypath, node): def encode_kv_node(keypath, node):
assert keypath assert keypath
assert len(node) == 32 assert len(node) == 32
@ -30,107 +33,154 @@ def encode_kv_node(keypath, node):
assert len(o) < 64 assert len(o) < 64
return o return o
# Serializes a branch node (ie. a node with 2 children)
def encode_branch_node(left, right): def encode_branch_node(left, right):
assert len(left) == len(right) == 32 assert len(left) == len(right) == 32
return left + right return left + right
# Saves a value into the database and returns its hash
def hash_and_save(db, node): def hash_and_save(db, node):
h = sha3(node) h = sha3(node)
db.put(h, node) db.put(h, node)
return h return h
# Fetches the value with a given keypath from the given node
def _get(db, node, keypath): def _get(db, node, keypath):
if not keypath: if not keypath:
return db.get(node) return db.get(node)
L, R, nodetype = parse_node(db.get(node)) L, R, nodetype = parse_node(db.get(node))
# Key-value node descend
if nodetype == KV_TYPE: if nodetype == KV_TYPE:
if keypath[:len(L)] == L: if keypath[:len(L)] == L:
return _get(db, R, keypath[len(L):]) return _get(db, R, keypath[len(L):])
else: else:
return None return None
# Branch node descend
elif nodetype == BRANCH_TYPE: elif nodetype == BRANCH_TYPE:
if keypath[:1] == b0: if keypath[:1] == b0:
return _get(db, L, keypath[1:]) return _get(db, L, keypath[1:])
else: else:
return _get(db, R, keypath[1:]) return _get(db, R, keypath[1:])
# Updates the value at the given keypath from the given node
def _update(db, node, keypath, val): def _update(db, node, keypath, val):
# Base case
if not keypath: if not keypath:
if val: if val:
return hash_and_save(db, val) return hash_and_save(db, val)
else: else:
return b'' return b''
# Empty trie
if not node: if not node:
return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, val))) return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, val)))
L, R, nodetype = parse_node(db.get(node)) L, R, nodetype = parse_node(db.get(node))
# node is a key-value node
if nodetype == KV_TYPE: if nodetype == KV_TYPE:
# Keypath prefixes match
if keypath[:len(L)] == L: if keypath[:len(L)] == L:
# Recurse into child
o = _update(db, R, keypath[len(L):], val) o = _update(db, R, keypath[len(L):], val)
assert o is not None # We are at the end, child is a value node
if len(L) == len(keypath): if len(L) == len(keypath):
return hash_and_save(db, encode_kv_node(L, o)) if o else b'' return hash_and_save(db, encode_kv_node(L, o)) if o else b''
subL, subR, subnodetype = parse_node(db.get(o)) subL, subR, subnodetype = parse_node(db.get(o))
# If the child is a key-value node, compress together the keypaths
# into one node
if subnodetype == KV_TYPE: if subnodetype == KV_TYPE:
return hash_and_save(db, encode_kv_node(L + subL, subR)) return hash_and_save(db, encode_kv_node(L + subL, subR))
else: else:
return hash_and_save(db, encode_kv_node(L, o)) if o else b'' return hash_and_save(db, encode_kv_node(L, o)) if o else b''
# Keypath prefixes don't match. Here we will be converting a key-value node
# of the form (k, CHILD) into a structure of one of the following forms:
# i. (k[:-1], (NEWCHILD, CHILD))
# ii. (k[:-1], ((k2, NEWCHILD), CHILD))
# iii. (k1, ((k2, CHILD), NEWCHILD))
# iv. (k1, ((k2, CHILD), (k2', NEWCHILD))
# v. (CHILD, NEWCHILD)
# vi. ((k[1:], CHILD), (k', NEWCHILD))
# vii. ((k[1:], CHILD), NEWCHILD)
# viii (CHILD, (k[1:], NEWCHILD))
else: else:
cf = common_prefix_length(L, keypath[:len(L)]) cf = common_prefix_length(L, keypath[:len(L)])
# valnode: the child node that has the new value we are adding
# Case 1: keypath prefixes almost match, so we are in case (i), (ii), (v), (vi)
if len(keypath) == cf + 1: if len(keypath) == cf + 1:
valnode = val valnode = val
# Case 2: keypath prefixes mismatch in the middle, so we need to break
# the keypath in half. We are in case (iii), (iv), (vii), (viii)
else: else:
valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, val))) valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, val)))
# oldnode: the child node the has the old child value
# Case 1: (i), (iii), (v), (vi)
if len(L) == cf + 1: if len(L) == cf + 1:
oldnode = R oldnode = R
# (ii), (iv), (vi), (viii)
else: else:
oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R)) oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R))
# Create the new branch node (because the key paths diverge, there has to
# be some "first bit" at which they diverge, so there must be a branch
# node somewhere)
if keypath[cf:cf+1] == b1: if keypath[cf:cf+1] == b1:
newsub = hash_and_save(db, encode_branch_node(oldnode, valnode)) newsub = hash_and_save(db, encode_branch_node(oldnode, valnode))
else: else:
newsub = hash_and_save(db, encode_branch_node(valnode, oldnode)) newsub = hash_and_save(db, encode_branch_node(valnode, oldnode))
# Case 1: keypath prefixes match in the first bit, so we still need
# a kv node at the top
# (i) (ii) (iii) (iv)
if cf: if cf:
return hash_and_save(db, encode_kv_node(L[:cf], newsub)) return hash_and_save(db, encode_kv_node(L[:cf], newsub))
# Case 2: keypath prefixes diverge in the first bit, so we replace the
# kv node with a branch node
# (v) (vi) (vii) (viii)
else: else:
return newsub return newsub
# node is a branch node
elif nodetype == BRANCH_TYPE: elif nodetype == BRANCH_TYPE:
newL, newR = L, R newL, newR = L, R
# Which child node to update? Depends on first bit in keypath
if keypath[:1] == b0: if keypath[:1] == b0:
newL = _update(db, L, keypath[1:], val) newL = _update(db, L, keypath[1:], val)
else: else:
newR = _update(db, R, keypath[1:], val) newR = _update(db, R, keypath[1:], val)
# Compress branch node into kv node
if not newL or not newR: if not newL or not newR:
subL, subR, subnodetype = parse_node(db.get(newL or newR)) subL, subR, subnodetype = parse_node(db.get(newL or newR))
first_bit = b1 if newR else b0 first_bit = b1 if newR else b0
# Compress (k1, (k2, NODE)) -> (k1 + k2, NODE)
if subnodetype == KV_TYPE: if subnodetype == KV_TYPE:
return hash_and_save(db, encode_kv_node(first_bit + subL, subR)) return hash_and_save(db, encode_kv_node(first_bit + subL, subR))
# kv node pointing to a branch node
elif subnodetype == BRANCH_TYPE: elif subnodetype == BRANCH_TYPE:
return hash_and_save(db, encode_kv_node(first_bit, newL or newR)) return hash_and_save(db, encode_kv_node(first_bit, newL or newR))
raise Exception("cow")
else: else:
return hash_and_save(db, encode_branch_node(newL, newR)) return hash_and_save(db, encode_branch_node(newL, newR))
raise Exception("cow") raise Exception("How did I get here?")
# Prints a tree, and checks that all invariants check out
def print_and_check_invariants(db, node, prefix=b''): def print_and_check_invariants(db, node, prefix=b''):
#print('pci', node, prefix)
if len(prefix) == 160: if len(prefix) == 160:
return {prefix: db.get(node)} return {prefix: db.get(node)}
if node == b'' and prefix == b'': if node == b'' and prefix == b'':
return {} return {}
L, R, nodetype = parse_node(db.get(node)) L, R, nodetype = parse_node(db.get(node))
#print('lrn', L, R, nodetype)
if nodetype == KV_TYPE: if nodetype == KV_TYPE:
# (k1, (k2, node)) two nested key values nodes not allowed
assert 0 < len(L) <= 160 - len(prefix) assert 0 < len(L) <= 160 - len(prefix)
if len(L) + len(prefix) < 160: if len(L) + len(prefix) < 160:
subL, subR, subnodetype = parse_node(db.get(R)) subL, subR, subnodetype = parse_node(db.get(R))
assert subnodetype != KV_TYPE assert subnodetype != KV_TYPE
# Childre of a key node cannot be empty
assert subR != sha3(b'')
return print_and_check_invariants(db, R, prefix + L) return print_and_check_invariants(db, R, prefix + L)
else: else:
assert L and R # Children of a branch node cannot be empty
assert L != sha3(b'') and R != sha3(b'')
o = {} o = {}
o.update(print_and_check_invariants(db, L, prefix + b0)) o.update(print_and_check_invariants(db, L, prefix + b0))
o.update(print_and_check_invariants(db, R, prefix + b1)) o.update(print_and_check_invariants(db, R, prefix + b1))
return o return o
# Pretty-print all nodes in a tree (for debugging purposes)
def print_nodes(db, node, prefix=b''): def print_nodes(db, node, prefix=b''):
if len(prefix) == 160: if len(prefix) == 160:
print('value node', encode_hex(node[:4]), db.get(node)) print('value node', encode_hex(node[:4]), db.get(node))
@ -147,6 +197,7 @@ def print_nodes(db, node, prefix=b''):
print_nodes(db, L, prefix + b0) print_nodes(db, L, prefix + b0)
print_nodes(db, R, prefix + b1) print_nodes(db, R, prefix + b1)
# Get a Merkle proof
def _get_branch(db, node, keypath): def _get_branch(db, node, keypath):
if not keypath: if not keypath:
return [db.get(node)] return [db.get(node)]
@ -163,6 +214,7 @@ def _get_branch(db, node, keypath):
else: else:
return [b'\x03'+L] + _get_branch(db, R, keypath[1:]) return [b'\x03'+L] + _get_branch(db, R, keypath[1:])
# Verify a Merkle proof
def _verify_branch(branch, root, keypath, value): def _verify_branch(branch, root, keypath, value):
nodes = [branch[-1]] nodes = [branch[-1]]
_keypath = b'' _keypath = b''
@ -192,6 +244,7 @@ def _verify_branch(branch, root, keypath, value):
assert _get(db, root, keypath) == value assert _get(db, root, keypath) == value
return True return True
# Trie wrapper class
class Trie(): class Trie():
def __init__(self, db, root): def __init__(self, db, root):
self.db = db self.db = db