diff --git a/trie_research/new_bintrie.py b/trie_research/new_bintrie.py index b3d6fcc..d0cc676 100644 --- a/trie_research/new_bintrie.py +++ b/trie_research/new_bintrie.py @@ -17,12 +17,15 @@ BRANCH_TYPE = 1 b1 = bytes([1]) b0 = bytes([0]) +# Input: a serialized node +# Output: keypath OR left child, child OR right child, node typr def parse_node(node): if len(node) == 64: return node[:32], node[32:], BRANCH_TYPE else: return decode_bin_path(node[:-32]), node[-32:], KV_TYPE +# Serializes a key/value node def encode_kv_node(keypath, node): assert keypath assert len(node) == 32 @@ -30,107 +33,154 @@ def encode_kv_node(keypath, node): assert len(o) < 64 return o +# Serializes a branch node (ie. a node with 2 children) def encode_branch_node(left, right): assert len(left) == len(right) == 32 return left + right +# Saves a value into the database and returns its hash def hash_and_save(db, node): h = sha3(node) db.put(h, node) return h +# Fetches the value with a given keypath from the given node def _get(db, node, keypath): if not keypath: return db.get(node) L, R, nodetype = parse_node(db.get(node)) + # Key-value node descend if nodetype == KV_TYPE: if keypath[:len(L)] == L: return _get(db, R, keypath[len(L):]) else: return None + # Branch node descend elif nodetype == BRANCH_TYPE: if keypath[:1] == b0: return _get(db, L, keypath[1:]) else: return _get(db, R, keypath[1:]) +# Updates the value at the given keypath from the given node def _update(db, node, keypath, val): + # Base case if not keypath: if val: return hash_and_save(db, val) else: return b'' + # Empty trie if not node: return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, val))) L, R, nodetype = parse_node(db.get(node)) + # node is a key-value node if nodetype == KV_TYPE: + # Keypath prefixes match if keypath[:len(L)] == L: + # Recurse into child o = _update(db, R, keypath[len(L):], val) - assert o is not None + # We are at the end, child is a value node if len(L) == len(keypath): return hash_and_save(db, encode_kv_node(L, o)) if o else b'' subL, subR, subnodetype = parse_node(db.get(o)) + # If the child is a key-value node, compress together the keypaths + # into one node if subnodetype == KV_TYPE: return hash_and_save(db, encode_kv_node(L + subL, subR)) else: return hash_and_save(db, encode_kv_node(L, o)) if o else b'' + # Keypath prefixes don't match. Here we will be converting a key-value node + # of the form (k, CHILD) into a structure of one of the following forms: + # i. (k[:-1], (NEWCHILD, CHILD)) + # ii. (k[:-1], ((k2, NEWCHILD), CHILD)) + # iii. (k1, ((k2, CHILD), NEWCHILD)) + # iv. (k1, ((k2, CHILD), (k2', NEWCHILD)) + # v. (CHILD, NEWCHILD) + # vi. ((k[1:], CHILD), (k', NEWCHILD)) + # vii. ((k[1:], CHILD), NEWCHILD) + # viii (CHILD, (k[1:], NEWCHILD)) else: cf = common_prefix_length(L, keypath[:len(L)]) + # valnode: the child node that has the new value we are adding + # Case 1: keypath prefixes almost match, so we are in case (i), (ii), (v), (vi) if len(keypath) == cf + 1: valnode = val + # Case 2: keypath prefixes mismatch in the middle, so we need to break + # the keypath in half. We are in case (iii), (iv), (vii), (viii) else: valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, val))) + # oldnode: the child node the has the old child value + # Case 1: (i), (iii), (v), (vi) if len(L) == cf + 1: oldnode = R + # (ii), (iv), (vi), (viii) else: oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R)) + # Create the new branch node (because the key paths diverge, there has to + # be some "first bit" at which they diverge, so there must be a branch + # node somewhere) if keypath[cf:cf+1] == b1: newsub = hash_and_save(db, encode_branch_node(oldnode, valnode)) else: newsub = hash_and_save(db, encode_branch_node(valnode, oldnode)) + # Case 1: keypath prefixes match in the first bit, so we still need + # a kv node at the top + # (i) (ii) (iii) (iv) if cf: return hash_and_save(db, encode_kv_node(L[:cf], newsub)) + # Case 2: keypath prefixes diverge in the first bit, so we replace the + # kv node with a branch node + # (v) (vi) (vii) (viii) else: return newsub + # node is a branch node elif nodetype == BRANCH_TYPE: newL, newR = L, R + # Which child node to update? Depends on first bit in keypath if keypath[:1] == b0: newL = _update(db, L, keypath[1:], val) else: newR = _update(db, R, keypath[1:], val) + # Compress branch node into kv node if not newL or not newR: subL, subR, subnodetype = parse_node(db.get(newL or newR)) first_bit = b1 if newR else b0 + # Compress (k1, (k2, NODE)) -> (k1 + k2, NODE) if subnodetype == KV_TYPE: return hash_and_save(db, encode_kv_node(first_bit + subL, subR)) + # kv node pointing to a branch node elif subnodetype == BRANCH_TYPE: return hash_and_save(db, encode_kv_node(first_bit, newL or newR)) - raise Exception("cow") else: return hash_and_save(db, encode_branch_node(newL, newR)) - raise Exception("cow") + raise Exception("How did I get here?") +# Prints a tree, and checks that all invariants check out def print_and_check_invariants(db, node, prefix=b''): - #print('pci', node, prefix) if len(prefix) == 160: return {prefix: db.get(node)} if node == b'' and prefix == b'': return {} L, R, nodetype = parse_node(db.get(node)) - #print('lrn', L, R, nodetype) if nodetype == KV_TYPE: + # (k1, (k2, node)) two nested key values nodes not allowed assert 0 < len(L) <= 160 - len(prefix) if len(L) + len(prefix) < 160: subL, subR, subnodetype = parse_node(db.get(R)) assert subnodetype != KV_TYPE + # Childre of a key node cannot be empty + assert subR != sha3(b'') return print_and_check_invariants(db, R, prefix + L) else: - assert L and R + # Children of a branch node cannot be empty + assert L != sha3(b'') and R != sha3(b'') o = {} o.update(print_and_check_invariants(db, L, prefix + b0)) o.update(print_and_check_invariants(db, R, prefix + b1)) return o +# Pretty-print all nodes in a tree (for debugging purposes) def print_nodes(db, node, prefix=b''): if len(prefix) == 160: print('value node', encode_hex(node[:4]), db.get(node)) @@ -147,6 +197,7 @@ def print_nodes(db, node, prefix=b''): print_nodes(db, L, prefix + b0) print_nodes(db, R, prefix + b1) +# Get a Merkle proof def _get_branch(db, node, keypath): if not keypath: return [db.get(node)] @@ -163,6 +214,7 @@ def _get_branch(db, node, keypath): else: return [b'\x03'+L] + _get_branch(db, R, keypath[1:]) +# Verify a Merkle proof def _verify_branch(branch, root, keypath, value): nodes = [branch[-1]] _keypath = b'' @@ -192,6 +244,7 @@ def _verify_branch(branch, root, keypath, value): assert _get(db, root, keypath) == value return True +# Trie wrapper class class Trie(): def __init__(self, db, root): self.db = db