Added comments to code
This commit is contained in:
parent
a829adb72b
commit
20b86c2fb9
|
@ -17,12 +17,15 @@ BRANCH_TYPE = 1
|
||||||
b1 = bytes([1])
|
b1 = bytes([1])
|
||||||
b0 = bytes([0])
|
b0 = bytes([0])
|
||||||
|
|
||||||
|
# Input: a serialized node
|
||||||
|
# Output: keypath OR left child, child OR right child, node type
|
||||||
def parse_node(node):
|
def parse_node(node):
|
||||||
if len(node) == 64:
|
if len(node) == 64:
|
||||||
return node[:32], node[32:], BRANCH_TYPE
|
return node[:32], node[32:], BRANCH_TYPE
|
||||||
else:
|
else:
|
||||||
return decode_bin_path(node[:-32]), node[-32:], KV_TYPE
|
return decode_bin_path(node[:-32]), node[-32:], KV_TYPE
|
||||||
|
|
||||||
|
# Serializes a key/value node
|
||||||
def encode_kv_node(keypath, node):
|
def encode_kv_node(keypath, node):
|
||||||
assert keypath
|
assert keypath
|
||||||
assert len(node) == 32
|
assert len(node) == 32
|
||||||
|
@ -30,107 +33,154 @@ def encode_kv_node(keypath, node):
|
||||||
assert len(o) < 64
|
assert len(o) < 64
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
# Serializes a branch node (ie. a node with 2 children)
|
||||||
def encode_branch_node(left, right):
|
def encode_branch_node(left, right):
|
||||||
assert len(left) == len(right) == 32
|
assert len(left) == len(right) == 32
|
||||||
return left + right
|
return left + right
|
||||||
|
|
||||||
|
# Saves a value into the database and returns its hash
|
||||||
def hash_and_save(db, node):
|
def hash_and_save(db, node):
|
||||||
h = sha3(node)
|
h = sha3(node)
|
||||||
db.put(h, node)
|
db.put(h, node)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
# Fetches the value with a given keypath from the given node
|
||||||
def _get(db, node, keypath):
|
def _get(db, node, keypath):
|
||||||
if not keypath:
|
if not keypath:
|
||||||
return db.get(node)
|
return db.get(node)
|
||||||
L, R, nodetype = parse_node(db.get(node))
|
L, R, nodetype = parse_node(db.get(node))
|
||||||
|
# Key-value node descend
|
||||||
if nodetype == KV_TYPE:
|
if nodetype == KV_TYPE:
|
||||||
if keypath[:len(L)] == L:
|
if keypath[:len(L)] == L:
|
||||||
return _get(db, R, keypath[len(L):])
|
return _get(db, R, keypath[len(L):])
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
# Branch node descend
|
||||||
elif nodetype == BRANCH_TYPE:
|
elif nodetype == BRANCH_TYPE:
|
||||||
if keypath[:1] == b0:
|
if keypath[:1] == b0:
|
||||||
return _get(db, L, keypath[1:])
|
return _get(db, L, keypath[1:])
|
||||||
else:
|
else:
|
||||||
return _get(db, R, keypath[1:])
|
return _get(db, R, keypath[1:])
|
||||||
|
|
||||||
|
# Updates the value at the given keypath from the given node
|
||||||
def _update(db, node, keypath, val):
|
def _update(db, node, keypath, val):
|
||||||
|
# Base case
|
||||||
if not keypath:
|
if not keypath:
|
||||||
if val:
|
if val:
|
||||||
return hash_and_save(db, val)
|
return hash_and_save(db, val)
|
||||||
else:
|
else:
|
||||||
return b''
|
return b''
|
||||||
|
# Empty trie
|
||||||
if not node:
|
if not node:
|
||||||
return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, val)))
|
return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, val)))
|
||||||
L, R, nodetype = parse_node(db.get(node))
|
L, R, nodetype = parse_node(db.get(node))
|
||||||
|
# node is a key-value node
|
||||||
if nodetype == KV_TYPE:
|
if nodetype == KV_TYPE:
|
||||||
|
# Keypath prefixes match
|
||||||
if keypath[:len(L)] == L:
|
if keypath[:len(L)] == L:
|
||||||
|
# Recurse into child
|
||||||
o = _update(db, R, keypath[len(L):], val)
|
o = _update(db, R, keypath[len(L):], val)
|
||||||
assert o is not None
|
# We are at the end, child is a value node
|
||||||
if len(L) == len(keypath):
|
if len(L) == len(keypath):
|
||||||
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
|
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
|
||||||
subL, subR, subnodetype = parse_node(db.get(o))
|
subL, subR, subnodetype = parse_node(db.get(o))
|
||||||
|
# If the child is a key-value node, compress together the keypaths
|
||||||
|
# into one node
|
||||||
if subnodetype == KV_TYPE:
|
if subnodetype == KV_TYPE:
|
||||||
return hash_and_save(db, encode_kv_node(L + subL, subR))
|
return hash_and_save(db, encode_kv_node(L + subL, subR))
|
||||||
else:
|
else:
|
||||||
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
|
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
|
||||||
|
# Keypath prefixes don't match. Here we will be converting a key-value node
|
||||||
|
# of the form (k, CHILD) into a structure of one of the following forms:
|
||||||
|
# i. (k[:-1], (NEWCHILD, CHILD))
|
||||||
|
# ii. (k[:-1], ((k2, NEWCHILD), CHILD))
|
||||||
|
# iii. (k1, ((k2, CHILD), NEWCHILD))
|
||||||
|
# iv. (k1, ((k2, CHILD), (k2', NEWCHILD)))
|
||||||
|
# v. (CHILD, NEWCHILD)
|
||||||
|
# vi. ((k[1:], CHILD), (k', NEWCHILD))
|
||||||
|
# vii. ((k[1:], CHILD), NEWCHILD)
|
||||||
|
# viii. (CHILD, (k[1:], NEWCHILD))
|
||||||
else:
|
else:
|
||||||
cf = common_prefix_length(L, keypath[:len(L)])
|
cf = common_prefix_length(L, keypath[:len(L)])
|
||||||
|
# valnode: the child node that has the new value we are adding
|
||||||
|
# Case 1: keypath prefixes almost match, so we are in case (i), (iii), (v), (vii)
|
||||||
if len(keypath) == cf + 1:
|
if len(keypath) == cf + 1:
|
||||||
valnode = val
|
valnode = val
|
||||||
|
# Case 2: keypath prefixes mismatch in the middle, so we need to break
|
||||||
|
# the keypath in half. We are in case (ii), (iv), (vi), (viii)
|
||||||
else:
|
else:
|
||||||
valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, val)))
|
valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, val)))
|
||||||
|
# oldnode: the child node that has the old child value
|
||||||
|
# Case 1: (i), (ii), (v), (viii)
|
||||||
if len(L) == cf + 1:
|
if len(L) == cf + 1:
|
||||||
oldnode = R
|
oldnode = R
|
||||||
|
# Case 2: (iii), (iv), (vi), (vii)
|
||||||
else:
|
else:
|
||||||
oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R))
|
oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R))
|
||||||
|
# Create the new branch node (because the key paths diverge, there has to
|
||||||
|
# be some "first bit" at which they diverge, so there must be a branch
|
||||||
|
# node somewhere)
|
||||||
if keypath[cf:cf+1] == b1:
|
if keypath[cf:cf+1] == b1:
|
||||||
newsub = hash_and_save(db, encode_branch_node(oldnode, valnode))
|
newsub = hash_and_save(db, encode_branch_node(oldnode, valnode))
|
||||||
else:
|
else:
|
||||||
newsub = hash_and_save(db, encode_branch_node(valnode, oldnode))
|
newsub = hash_and_save(db, encode_branch_node(valnode, oldnode))
|
||||||
|
# Case 1: keypath prefixes match in the first bit, so we still need
|
||||||
|
# a kv node at the top
|
||||||
|
# (i) (ii) (iii) (iv)
|
||||||
if cf:
|
if cf:
|
||||||
return hash_and_save(db, encode_kv_node(L[:cf], newsub))
|
return hash_and_save(db, encode_kv_node(L[:cf], newsub))
|
||||||
|
# Case 2: keypath prefixes diverge in the first bit, so we replace the
|
||||||
|
# kv node with a branch node
|
||||||
|
# (v) (vi) (vii) (viii)
|
||||||
else:
|
else:
|
||||||
return newsub
|
return newsub
|
||||||
|
# node is a branch node
|
||||||
elif nodetype == BRANCH_TYPE:
|
elif nodetype == BRANCH_TYPE:
|
||||||
newL, newR = L, R
|
newL, newR = L, R
|
||||||
|
# Which child node to update? Depends on first bit in keypath
|
||||||
if keypath[:1] == b0:
|
if keypath[:1] == b0:
|
||||||
newL = _update(db, L, keypath[1:], val)
|
newL = _update(db, L, keypath[1:], val)
|
||||||
else:
|
else:
|
||||||
newR = _update(db, R, keypath[1:], val)
|
newR = _update(db, R, keypath[1:], val)
|
||||||
|
# Compress branch node into kv node
|
||||||
if not newL or not newR:
|
if not newL or not newR:
|
||||||
subL, subR, subnodetype = parse_node(db.get(newL or newR))
|
subL, subR, subnodetype = parse_node(db.get(newL or newR))
|
||||||
first_bit = b1 if newR else b0
|
first_bit = b1 if newR else b0
|
||||||
|
# Compress (k1, (k2, NODE)) -> (k1 + k2, NODE)
|
||||||
if subnodetype == KV_TYPE:
|
if subnodetype == KV_TYPE:
|
||||||
return hash_and_save(db, encode_kv_node(first_bit + subL, subR))
|
return hash_and_save(db, encode_kv_node(first_bit + subL, subR))
|
||||||
|
# kv node pointing to a branch node
|
||||||
elif subnodetype == BRANCH_TYPE:
|
elif subnodetype == BRANCH_TYPE:
|
||||||
return hash_and_save(db, encode_kv_node(first_bit, newL or newR))
|
return hash_and_save(db, encode_kv_node(first_bit, newL or newR))
|
||||||
raise Exception("cow")
|
|
||||||
else:
|
else:
|
||||||
return hash_and_save(db, encode_branch_node(newL, newR))
|
return hash_and_save(db, encode_branch_node(newL, newR))
|
||||||
raise Exception("cow")
|
raise Exception("How did I get here?")
|
||||||
|
|
||||||
|
# Prints a tree, and checks that all invariants check out
|
||||||
def print_and_check_invariants(db, node, prefix=b''):
|
def print_and_check_invariants(db, node, prefix=b''):
|
||||||
#print('pci', node, prefix)
|
|
||||||
if len(prefix) == 160:
|
if len(prefix) == 160:
|
||||||
return {prefix: db.get(node)}
|
return {prefix: db.get(node)}
|
||||||
if node == b'' and prefix == b'':
|
if node == b'' and prefix == b'':
|
||||||
return {}
|
return {}
|
||||||
L, R, nodetype = parse_node(db.get(node))
|
L, R, nodetype = parse_node(db.get(node))
|
||||||
#print('lrn', L, R, nodetype)
|
|
||||||
if nodetype == KV_TYPE:
|
if nodetype == KV_TYPE:
|
||||||
|
# (k1, (k2, node)): two nested key-value nodes are not allowed
|
||||||
assert 0 < len(L) <= 160 - len(prefix)
|
assert 0 < len(L) <= 160 - len(prefix)
|
||||||
if len(L) + len(prefix) < 160:
|
if len(L) + len(prefix) < 160:
|
||||||
subL, subR, subnodetype = parse_node(db.get(R))
|
subL, subR, subnodetype = parse_node(db.get(R))
|
||||||
assert subnodetype != KV_TYPE
|
assert subnodetype != KV_TYPE
|
||||||
|
# Children of a key node cannot be empty
|
||||||
|
assert subR != sha3(b'')
|
||||||
return print_and_check_invariants(db, R, prefix + L)
|
return print_and_check_invariants(db, R, prefix + L)
|
||||||
else:
|
else:
|
||||||
assert L and R
|
# Children of a branch node cannot be empty
|
||||||
|
assert L != sha3(b'') and R != sha3(b'')
|
||||||
o = {}
|
o = {}
|
||||||
o.update(print_and_check_invariants(db, L, prefix + b0))
|
o.update(print_and_check_invariants(db, L, prefix + b0))
|
||||||
o.update(print_and_check_invariants(db, R, prefix + b1))
|
o.update(print_and_check_invariants(db, R, prefix + b1))
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
# Pretty-print all nodes in a tree (for debugging purposes)
|
||||||
def print_nodes(db, node, prefix=b''):
|
def print_nodes(db, node, prefix=b''):
|
||||||
if len(prefix) == 160:
|
if len(prefix) == 160:
|
||||||
print('value node', encode_hex(node[:4]), db.get(node))
|
print('value node', encode_hex(node[:4]), db.get(node))
|
||||||
|
@ -147,6 +197,7 @@ def print_nodes(db, node, prefix=b''):
|
||||||
print_nodes(db, L, prefix + b0)
|
print_nodes(db, L, prefix + b0)
|
||||||
print_nodes(db, R, prefix + b1)
|
print_nodes(db, R, prefix + b1)
|
||||||
|
|
||||||
|
# Get a Merkle proof
|
||||||
def _get_branch(db, node, keypath):
|
def _get_branch(db, node, keypath):
|
||||||
if not keypath:
|
if not keypath:
|
||||||
return [db.get(node)]
|
return [db.get(node)]
|
||||||
|
@ -163,6 +214,7 @@ def _get_branch(db, node, keypath):
|
||||||
else:
|
else:
|
||||||
return [b'\x03'+L] + _get_branch(db, R, keypath[1:])
|
return [b'\x03'+L] + _get_branch(db, R, keypath[1:])
|
||||||
|
|
||||||
|
# Verify a Merkle proof
|
||||||
def _verify_branch(branch, root, keypath, value):
|
def _verify_branch(branch, root, keypath, value):
|
||||||
nodes = [branch[-1]]
|
nodes = [branch[-1]]
|
||||||
_keypath = b''
|
_keypath = b''
|
||||||
|
@ -192,6 +244,7 @@ def _verify_branch(branch, root, keypath, value):
|
||||||
assert _get(db, root, keypath) == value
|
assert _get(db, root, keypath) == value
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# Trie wrapper class
|
||||||
class Trie():
|
class Trie():
|
||||||
def __init__(self, db, root):
|
def __init__(self, db, root):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
Loading…
Reference in New Issue