From a5364958990c197a16eb0708976f9e6229f821ba Mon Sep 17 00:00:00 2001
From: Vitalik Buterin
Date: Fri, 23 Mar 2018 09:31:07 -0400
Subject: [PATCH] Bring back patricia tree

---
 trie_research/bintrie1/bin_utils.py           |  59 +++
 trie_research/bintrie1/compact_branches.py    |  49 +++
 trie_research/bintrie1/compress_witness.py    |  89 +++++
 trie_research/bintrie1/new_bintrie.py         | 327 ++++++++++++++++++
 .../bintrie1/new_bintrie_aggregate.py         | 125 +++++++
 .../bintrie1/new_bintrie_aggregate_tests.py   |  53 +++
 trie_research/bintrie1/new_bintrie_tests.py   |  66 ++++
 7 files changed, 768 insertions(+)
 create mode 100644 trie_research/bintrie1/bin_utils.py
 create mode 100644 trie_research/bintrie1/compact_branches.py
 create mode 100644 trie_research/bintrie1/compress_witness.py
 create mode 100644 trie_research/bintrie1/new_bintrie.py
 create mode 100644 trie_research/bintrie1/new_bintrie_aggregate.py
 create mode 100644 trie_research/bintrie1/new_bintrie_aggregate_tests.py
 create mode 100644 trie_research/bintrie1/new_bintrie_tests.py

diff --git a/trie_research/bintrie1/bin_utils.py b/trie_research/bintrie1/bin_utils.py
new file mode 100644
index 0000000..ce8820a
--- /dev/null
+++ b/trie_research/bintrie1/bin_utils.py
@@ -0,0 +1,59 @@
+from ethereum.utils import safe_ord as ord
+
+# 0100000101010111010000110100100101001001 -> ASCII (8 bits per byte, most significant bit first)
+def decode_bin(x):
+    o = bytearray(len(x) // 8)
+    for i in range(0, len(x), 8):
+        v = 0
+        for c in x[i:i+8]:
+            v = v * 2 + c
+        o[i//8] = v
+    return bytes(o)
+
+
+# ASCII -> 0100000101010111010000110100100101001001 (one 0/1 byte per bit, most significant bit first)
+def encode_bin(x):
+    o = b''
+    for c in x:
+        c = ord(c)
+        p = bytearray(8)
+        for i in range(8):
+            p[7-i] = c % 2
+            c //= 2
+        o += p
+    return o
+
+two_bits = [bytes([0,0]), bytes([0,1]),
+            bytes([1,0]), bytes([1,1])]
+prefix00 = bytes([0,0])
+prefix100000 = bytes([1,0,0,0,0,0])
+
+
+# Encodes a sequence of 0s and 1s into tightly packed bytes. The path is
+# left-padded with zero bits to a multiple of 4 and preceded by a header that ends
+# in the two bits of len(b) % 4: 4 bits (00LL) or 8 bits (100000LL), whichever byte-aligns
+def encode_bin_path(b):
+    b2 = bytes((4 - len(b)) % 4) + b
+    prefix = two_bits[len(b) % 4]
+    if len(b2) % 8 == 4:
+        return decode_bin(prefix00 + prefix + b2)
+    else:
+        return decode_bin(prefix100000 + prefix + b2)
+
+
+# Decodes bytes into a sequence of 0s and 1s (inverse of encode_bin_path)
+def decode_bin_path(p):
+    p = encode_bin(p)
+    if p[0] == 1:
+        p = p[4:]
+    assert p[0:2] == prefix00
+    L = two_bits.index(p[2:4])
+    return p[4+((4 - L) % 4):]
+
+# Length of the longest common prefix of a and b
+def common_prefix_length(a, b):
+    o = 0
+    while o < len(a) and o < len(b) and a[o] == b[o]:
+        o += 1
+    return o
+
diff --git a/trie_research/bintrie1/compact_branches.py b/trie_research/bintrie1/compact_branches.py
new file mode 100644
index 0000000..e8bee44
--- /dev/null
+++ b/trie_research/bintrie1/compact_branches.py
@@ -0,0 +1,49 @@
+# NOTE(review): this file has no imports, yet it uses parse_node, KV_TYPE, BRANCH_TYPE, b0, b1,
+# encode_bin_path, decode_bin_path, encode_kv_node, encode_branch_node, sha3, EphemDB and _get
+
+# Get a Merkle proof (compact form: b'\x01' + encoded keypath for a kv node, b'\x02'/b'\x03' + sibling hash for a branch)
+def _get_branch(db, node, keypath):
+    if not keypath:
+        return [db.get(node)]
+    L, R, nodetype = parse_node(db.get(node))
+    if nodetype == KV_TYPE:
+        path = encode_bin_path(L)
+        if keypath[:len(L)] == L:
+            return [b'\x01'+path] + _get_branch(db, R, keypath[len(L):])
+        else:
+            return [b'\x01'+path, db.get(R)]
+    elif nodetype == BRANCH_TYPE:
+        if keypath[:1] == b0:
+            return [b'\x02'+R] + _get_branch(db, L, keypath[1:])
+        else:
+            return [b'\x03'+L] + _get_branch(db, R, keypath[1:])
+
+# Verify a Merkle proof
+def _verify_branch(branch, root, keypath, value):
+    nodes = [branch[-1]]
+    _keypath = b''
+    for data in branch[-2::-1]:
+        marker, node = data[0], data[1:]
+        # kv node: prepend its keypath and wrap the node rebuilt so far
+        if marker == 1:
+            node = decode_bin_path(node)
+            _keypath = node + _keypath
+            nodes.insert(0, encode_kv_node(node, sha3(nodes[0])))
+        # branch node, path went left (bit 0): the payload is the right child's hash
+        elif marker == 2:
+            _keypath = b0 + _keypath
+            nodes.insert(0, encode_branch_node(sha3(nodes[0]), node))
+        # branch node, path went right (bit 1): the payload is the left child's hash
+        elif marker == 3:
+            _keypath = b1 + _keypath
+            nodes.insert(0, encode_branch_node(node, sha3(nodes[0])))
+        else:
+            raise Exception("Foo")
+    L, R, nodetype = parse_node(nodes[0])
+    if value:
+        assert _keypath == keypath
+    assert sha3(nodes[0]) == root
+    db = EphemDB()
+    db.kv = {sha3(node): node for node in nodes}
+    assert _get(db, root, keypath) == value
+    return True
diff --git a/trie_research/bintrie1/compress_witness.py b/trie_research/bintrie1/compress_witness.py
new file mode 100644
index 0000000..74e59f6
--- /dev/null
+++ b/trie_research/bintrie1/compress_witness.py
@@ -0,0 +1,89 @@
+from ethereum.utils import sha3, encode_hex
+from new_bintrie import parse_node, KV_TYPE, BRANCH_TYPE, LEAF_TYPE, encode_bin_path, encode_kv_node, encode_branch_node, decode_bin_path
+
+KV_COMPRESS_TYPE = 128  # kv node; payload is its encoded keypath
+BRANCH_LEFT_TYPE = 129  # branch whose left child is the preceding node; payload is the right child's hash
+BRANCH_RIGHT_TYPE = 130  # branch whose right child is the preceding node; payload is the left child's hash
+
+# Compresses a witness (a list of serialized nodes). Each leaf is emitted in full,
+# followed by its chain of ancestors up the tree, where a kv node is reduced to its
+# keypath and a branch node to the hash of its other child; shared ancestors appear once
+def compress(witness):
+    parentmap = {}
+    leaves = []
+    for w in witness:
+        L, R, nodetype = parse_node(w)
+        if nodetype == LEAF_TYPE:
+            leaves.append(w)
+        elif nodetype == KV_TYPE:
+            parentmap[R] = w
+        elif nodetype == BRANCH_TYPE:
+            parentmap[L] = w
+            parentmap[R] = w
+    used = {}
+    proof = []
+    for node in leaves:
+        proof.append(node)
+        used[node] = True
+        h = sha3(node)
+        while h in parentmap:
+            node = parentmap[h]
+            L, R, nodetype = parse_node(node)
+            if nodetype == KV_TYPE:
+                proof.append(bytes([KV_COMPRESS_TYPE]) + encode_bin_path(L))
+            elif nodetype == BRANCH_TYPE and L == h:
+                proof.append(bytes([BRANCH_LEFT_TYPE]) + R)
+            elif nodetype == BRANCH_TYPE and R == h:
+                proof.append(bytes([BRANCH_RIGHT_TYPE]) + L)
+            else:
+                raise Exception("something is wrong")
+            h = sha3(node)
+            if h in used:
+                proof.pop()
+                break
+            used[h] = True
+    assert len(used) == len(proof)
+    return proof
+
+# Input: a serialized proof node (as produced by compress); output: (payload, node type)
+def parse_proof_node(node):
+    if node[0] == BRANCH_LEFT_TYPE:
+        # Output: right child, node type
+        return node[1:33], BRANCH_LEFT_TYPE
+    elif node[0] == BRANCH_RIGHT_TYPE:
+        # Output: left child, node type
+        return node[1:33], BRANCH_RIGHT_TYPE
+    elif node[0] == KV_COMPRESS_TYPE:
+        # Output: keypath, node type
+        return decode_bin_path(node[1:]), KV_COMPRESS_TYPE
+    elif node[0] == LEAF_TYPE:
+        # Output: value, node type
+        return node[1:], LEAF_TYPE
+    else:
+        raise Exception("Bad node")
+
+# Expands a compressed proof back into full serialized nodes; the child hash that
+# compress left out is always the hash of the node rebuilt just before
+def expand(proof):
+    witness = []
+    lasthash = None
+    for p in proof:
+        sub, nodetype = parse_proof_node(p)
+        if nodetype == LEAF_TYPE:
+            witness.append(p)
+            lasthash = sha3(p)
+        elif nodetype == KV_COMPRESS_TYPE:
+            fullnode = encode_kv_node(sub, lasthash)
+            witness.append(fullnode)
+            lasthash = sha3(fullnode)
+        elif nodetype == BRANCH_LEFT_TYPE:
+            fullnode = encode_branch_node(lasthash, sub)
+            witness.append(fullnode)
+            lasthash = sha3(fullnode)
+        elif nodetype == BRANCH_RIGHT_TYPE:
+            fullnode = encode_branch_node(sub, lasthash)
+            witness.append(fullnode)
+            lasthash = sha3(fullnode)
+        else:
+            raise Exception("Bad node")
+    return witness
diff --git a/trie_research/bintrie1/new_bintrie.py b/trie_research/bintrie1/new_bintrie.py
new file mode 100644
index 0000000..a438cad
--- /dev/null
+++ b/trie_research/bintrie1/new_bintrie.py
@@ -0,0 +1,327 @@
+from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, encode_bin, decode_bin
+from ethereum.utils import sha3, encode_hex
+
+# Minimal in-memory key/value store backed by a dict
+class EphemDB():
+    def __init__(self, kv=None):
+        self.kv = kv or {}
+
+    def get(self, k):
+        return self.kv.get(k, None)
+
+    def put(self, k, v):
+        self.kv[k] = v
+
+    def delete(self, k):
+        del self.kv[k]
+
+KV_TYPE = 0
+BRANCH_TYPE = 1
+LEAF_TYPE = 2
+
+b1 = bytes([1])
+b0 = bytes([0])
+
+# Input: a serialized node
+def parse_node(node):
+    if node[0] == BRANCH_TYPE:
+        # Output: left child, right child, node type
+        return node[1:33], node[33:], BRANCH_TYPE
+    elif node[0] == KV_TYPE:
+        # Output: keypath, child, node type
+        return decode_bin_path(node[1:-32]), node[-32:], KV_TYPE
+    elif node[0] == LEAF_TYPE:
+        # Output: None, value, node type
+        return None, node[1:], LEAF_TYPE
+    else:
+        raise Exception("Bad node")
+
+# Serializes a key/value node
+def encode_kv_node(keypath, node):
+    assert keypath
+    assert len(node) == 32
+    o = bytes([KV_TYPE]) + encode_bin_path(keypath) + node
+    return o
+
+# Serializes a branch node (ie. a node with 2 children)
+def encode_branch_node(left, right):
+    assert len(left) == len(right) == 32
+    return bytes([BRANCH_TYPE]) + left + right
+
+# Serializes a leaf node
+def encode_leaf_node(value):
+    return bytes([LEAF_TYPE]) + value
+
+# Saves a value into the database and returns its hash
+def hash_and_save(db, node):
+    h = sha3(node)
+    db.put(h, node)
+    return h
+
+# Fetches the value with a given keypath from the given node
+def _get(db, node, keypath):
+    # Empty trie: there is no node to parse, so no key has a value
+    if not node:
+        return None
+    L, R, nodetype = parse_node(db.get(node))
+    # Leaf node: return the stored value
+    if nodetype == LEAF_TYPE:
+        return R
+    elif nodetype == KV_TYPE:
+        # Keypath too short
+        if not keypath:
+            return None
+        if keypath[:len(L)] == L:
+            return _get(db, R, keypath[len(L):])
+        else:
+            return None
+    # Branch node descend
+    elif nodetype == BRANCH_TYPE:
+        # Keypath too short
+        if not keypath:
+            return None
+        if keypath[:1] == b0:
+            return _get(db, L, keypath[1:])
+        else:
+            return _get(db, R, keypath[1:])
+
+# Updates the value at the given keypath from the given node
+def _update(db, node, keypath, val):
+    # Empty trie
+    if not node:
+        if val:
+            return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, encode_leaf_node(val))))
+        else:
+            return b''
+    L, R, nodetype = parse_node(db.get(node))
+    # Node is a leaf node
+    if nodetype == LEAF_TYPE:
+        # Keypath must match, there should be no remaining keypath
+        if keypath:
+            raise Exception("Existing kv pair is being effaced because it's key is the prefix of the new key")
+        return hash_and_save(db, encode_leaf_node(val)) if val else b''
+    # node is a key-value node
+    elif nodetype == KV_TYPE:
+        # Keypath too short
+        if not keypath:
+            return node
+        # Keypath prefixes match
+        if keypath[:len(L)] == L:
+            # Recurse into child
+            o = _update(db, R, keypath[len(L):], val)
+            # If child is empty
+            if not o:
+                return b''
+            #print(db.get(o))
+            subL, subR, subnodetype = parse_node(db.get(o))
+            # If the child is a key-value node, compress together the keypaths
+            # into one node
+            if subnodetype == KV_TYPE:
+                return hash_and_save(db, encode_kv_node(L + subL, subR))
+            else:
+                return hash_and_save(db, encode_kv_node(L, o)) if o else b''
+        # Keypath prefixes don't match. Here we will be converting a key-value node
+        # of the form (k, CHILD) into a structure of one of the following forms:
+        # i.   (k[:-1], (NEWCHILD, CHILD))
+        # ii.  (k[:-1], ((k2, NEWCHILD), CHILD))
+        # iii. (k1, ((k2, CHILD), NEWCHILD))
+        # iv.  (k1, ((k2, CHILD), (k2', NEWCHILD))
+        # v.   (CHILD, NEWCHILD)
+        # vi.  ((k[1:], CHILD), (k', NEWCHILD))
+        # vii. ((k[1:], CHILD), NEWCHILD)
+        # viii (CHILD, (k[1:], NEWCHILD))
+        else:
+            cf = common_prefix_length(L, keypath[:len(L)])
+            # New key-value pair can not contain empty value
+            if not val:
+                return node
+            # valnode: the child node that has the new value we are adding
+            # Case 1: the new keypath ends right after the diverging bit, so NEWCHILD is used directly: (i), (iii), (v), (vii)
+            if len(keypath) == cf + 1:
+                valnode = hash_and_save(db, encode_leaf_node(val))
+            # Case 2: the new keypath continues past the diverging bit, so NEWCHILD is
+            # wrapped in a kv node holding the rest of it. We are in case (ii), (iv), (vi), (viii)
+            else:
+                valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, encode_leaf_node(val))))
+            # oldnode: the child node that has the old child value
+            # Case 1: the old keypath ends right after the diverging bit, CHILD is used directly: (i), (ii), (v), (viii)
+            if len(L) == cf + 1:
+                oldnode = R
+            # Case 2: CHILD is wrapped in a kv node holding the rest of the old keypath: (iii), (iv), (vi), (vii)
+            else:
+                oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R))
+            # Create the new branch node (because the key paths diverge, there has to
+            # be some "first bit" at which they diverge, so there must be a branch
+            # node somewhere)
+            if keypath[cf:cf+1] == b1:
+                newsub = hash_and_save(db, encode_branch_node(oldnode, valnode))
+            else:
+                newsub = hash_and_save(db, encode_branch_node(valnode, oldnode))
+            # Case 1: keypath prefixes match in the first bit, so we still need
+            # a kv node at the top
+            # (i) (ii) (iii) (iv)
+            if cf:
+                return hash_and_save(db, encode_kv_node(L[:cf], newsub))
+            # Case 2: keypath prefixes diverge in the first bit, so we replace the
+            # kv node with a branch node
+            # (v) (vi) (vii) (viii)
+            else:
+                return newsub
+    # node is a branch node
+    elif nodetype == BRANCH_TYPE:
+        # Keypath too short
+        if not keypath:
+            return node
+        newL, newR = L, R
+        # Which child node to update? Depends on first bit in keypath
+        if keypath[:1] == b0:
+            newL = _update(db, L, keypath[1:], val)
+        else:
+            newR = _update(db, R, keypath[1:], val)
+        # Compress branch node into kv node
+        if not newL or not newR:
+            subL, subR, subnodetype = parse_node(db.get(newL or newR))
+            first_bit = b1 if newR else b0
+            # Compress (k1, (k2, NODE)) -> (k1 + k2, NODE)
+            if subnodetype == KV_TYPE:
+                return hash_and_save(db, encode_kv_node(first_bit + subL, subR))
+            # kv node pointing to a branch node
+            elif subnodetype == BRANCH_TYPE or subnodetype == LEAF_TYPE:
+                return hash_and_save(db, encode_kv_node(first_bit, newL or newR))
+        else:
+            return hash_and_save(db, encode_branch_node(newL, newR))
+    raise Exception("How did I get here?")
+
+# Returns the tree as a {keypath: value} dict, checking all invariants on the way (prints nothing)
+def print_and_check_invariants(db, node, prefix=b''):
+    if node == b'' and prefix == b'':
+        return {}
+    L, R, nodetype = parse_node(db.get(node))
+    if nodetype == LEAF_TYPE:
+        # All keys must be 160 bits
+        assert len(prefix) == 160
+        return {prefix: R}
+    elif nodetype == KV_TYPE:
+        # Keypath is non-empty and fits in 160 bits; below: (k1, (k2, node)) two nested key values nodes not allowed
+        assert 0 < len(L) <= 160 - len(prefix)
+        if len(L) + len(prefix) < 160:
+            subL, subR, subnodetype = parse_node(db.get(R))
+            assert subnodetype != KV_TYPE
+            # Children of a key node cannot be empty
+            assert subR != sha3(b'')
+        return print_and_check_invariants(db, R, prefix + L)
+    else:
+        # Children of a branch node cannot be empty
+        assert L != sha3(b'') and R != sha3(b'')
+        o = {}
+        o.update(print_and_check_invariants(db, L, prefix + b0))
+        o.update(print_and_check_invariants(db, R, prefix + b1))
+        return o
+
+# Pretty-print all nodes in a tree (for debugging purposes)
+def print_nodes(db, node, prefix=b''):
+    if node == b'':
+        print('empty node')
+        return
+    L, R, nodetype = parse_node(db.get(node))
+    if nodetype == LEAF_TYPE:
+        print('value node', encode_hex(node[:4]), R)
+    elif nodetype == KV_TYPE:
+        print(('kv node:', encode_hex(node[:4]), ''.join(['1' if x == 1 else '0' for x in L]), encode_hex(R[:4])))
+        print_nodes(db, R, prefix + L)
+    else:
+        print(('branch node:', encode_hex(node[:4]), encode_hex(L[:4]), encode_hex(R[:4])))
+        print_nodes(db, L, prefix + b0)
+        print_nodes(db, R, prefix + b1)
+
+# Get a long-format Merkle branch
+def _get_long_format_branch(db, node, keypath):
+    if not keypath:
+        return [db.get(node)]
+    L, R, nodetype = parse_node(db.get(node))
+    if nodetype == KV_TYPE:
+        path = encode_bin_path(L)
+        if keypath[:len(L)] == L:
+            return [db.get(node)] + _get_long_format_branch(db, R, keypath[len(L):])
+        else:
+            return [db.get(node)]
+    elif nodetype == BRANCH_TYPE:
+        if keypath[:1] == b0:
+            return [db.get(node)] + _get_long_format_branch(db, L, keypath[1:])
+        else:
+            return [db.get(node)] + _get_long_format_branch(db, R, keypath[1:])
+
+# Verify a long-format branch: load its nodes into a fresh DB and replay the lookup
+def _verify_long_format_branch(branch, root, keypath, value):
+    db = EphemDB()
+    db.kv = {sha3(node): node for node in branch}
+    assert _get(db, root, keypath) == value
+    return True
+
+# Get full subtrie
+def _get_subtrie(db, node):
+    dbnode = db.get(node)
+    L, R, nodetype = parse_node(dbnode)
+    if nodetype == KV_TYPE:
+        return [dbnode] + _get_subtrie(db, R)
+    elif nodetype == BRANCH_TYPE:
+        return [dbnode] + _get_subtrie(db, L) + _get_subtrie(db, R)
+    elif nodetype == LEAF_TYPE:
+        return [dbnode]
+
+# Get witness for prefix
+def _get_prefix_witness(db, node, keypath):
+    dbnode = db.get(node)
+    if not keypath:
+        return _get_subtrie(db, node)
+    L, R, nodetype = parse_node(dbnode)
+    if nodetype == KV_TYPE:
+        path = encode_bin_path(L)
+        if len(keypath) < len(L) and L[:len(keypath)] == keypath:
+            return [dbnode] + _get_subtrie(db, R)
+        if keypath[:len(L)] == L:
+            return [dbnode] + _get_prefix_witness(db, R, keypath[len(L):])
+        else:
+            return [dbnode]
+    elif nodetype == BRANCH_TYPE:
+        if keypath[:1] == b0:
+            return [dbnode] + _get_prefix_witness(db, L, keypath[1:])
+        else:
+            return [dbnode] + _get_prefix_witness(db, R, keypath[1:])
+
+
+# Trie wrapper class
+class Trie():
+    def __init__(self, db, root):
+        self.db = db
+        self.root = root
+        assert isinstance(self.root, bytes)
+
+    def get(self, key):
+        #assert len(key) == 20
+        return _get(self.db, self.root, encode_bin(key))
+
+    #def get_branch(self, key):
+    #    o = _get_branch(self.db, self.root, encode_bin(key))
+    #    assert _verify_branch(o, self.root, encode_bin(key), self.get(key))
+    #    return o
+
+    def get_long_format_branch(self, key):
+        o = _get_long_format_branch(self.db, self.root, encode_bin(key))
+        assert _verify_long_format_branch(o, self.root, encode_bin(key), self.get(key))
+        return o
+
+    def get_prefix_witness(self, key):
+        return _get_prefix_witness(self.db, self.root, encode_bin(key))
+
+    def update(self, key, value):
+        #assert len(key) == 20
+        self.root = _update(self.db, self.root, encode_bin(key), value)
+
+    def to_dict(self, hexify=False):
+        o = print_and_check_invariants(self.db, self.root)
+        encoder = lambda x: encode_hex(x) if hexify else x
+        return {encoder(decode_bin(k)): v for k, v in o.items()}
+
+    def print_nodes(self):
+        print_nodes(self.db, self.root)
diff --git a/trie_research/bintrie1/new_bintrie_aggregate.py b/trie_research/bintrie1/new_bintrie_aggregate.py
new file mode 100644
index 0000000..5d98941
--- /dev/null
+++ b/trie_research/bintrie1/new_bintrie_aggregate.py
@@ -0,0 +1,125 @@
+from new_bintrie import b0, b1, KV_TYPE, BRANCH_TYPE, LEAF_TYPE, parse_node, encode_kv_node, encode_branch_node, encode_leaf_node
+from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, encode_bin, decode_bin
+from ethereum.utils import sha3 as _sha3, encode_hex
+
+sha3_cache = {}
+
+# Memoized sha3 (results are kept in sha3_cache and never evicted)
+def sha3(x):
+    if x not in sha3_cache:
+        sha3_cache[x] = _sha3(x)
+    return sha3_cache[x]
+
+# Concatenates nodes, each prefixed with its length as a 3-byte big-endian integer
+def quick_encode(nodes):
+    o = b''
+    for node in nodes:
+        o += len(node).to_bytes(3, 'big') + node
+    return o
+
+# Inverse of quick_encode: splits the data back into the list of nodes
+def quick_decode(nodedata):
+    o = []
+    pos = 0
+    while pos < len(nodedata):
+        L = nodedata[pos] * 65536 + nodedata[pos+1] * 256 + nodedata[pos+2]
+        o.append(nodedata[pos+3: pos+3+L])
+        pos += 3+L
+    return o
+
+
+# DB wrapper that packs nearby trie nodes into "substores": one parent-DB entry, keyed by
+# the hash of its first node, holds several quick_encode-d nodes. Writes are buffered until commit()
+class WrapperDB():
+    def __init__(self, parent_db):
+        self.parent_db = parent_db
+        self.substores = {}
+        self.node_to_substore = {}
+        self.new_nodes = {}
+        self.parent_db_reads = 0
+        self.parent_db_writes = 0
+        self.printing_mode = False
+
+    # Loads a substore (quick_encode-d list of closeby trie nodes) from the DB
+    def fetch_substore(self, key):
+        substore_values = self.parent_db.get(key)
+        assert substore_values is not None
+        children = quick_decode(substore_values)
+        self.parent_db_reads += 1
+        self.substores[key] = {sha3(n): n for n in children}
+        self.node_to_substore.update({sha3(n): key for n in children})
+        assert key in self.node_to_substore and key in self.substores
+
+    # Looks in uncommitted nodes first, then in loaded substores, fetching from the parent DB on a miss
+    def get(self, k):
+        if k in self.new_nodes:
+            return self.new_nodes[k]
+        if k not in self.node_to_substore:
+            self.fetch_substore(k)
+        o = self.substores[self.node_to_substore[k]][k]
+        assert sha3(o) == k
+        return o
+
+    # Buffers a node for the next commit unless it is already known
+    def put(self, k, v):
+        if k not in self.new_nodes and k not in self.node_to_substore:
+            self.new_nodes[k] = v
+
+    # Given a key, returns a collection of candidate nodes to form
+    # a substore, as well as the children of that substore
+    def get_substore_candidate_and_children(self, key, depth=5):
+        if depth == 0:
+            return [], [key]
+        elif self.parent_db.get(key) is not None:
+            return [], [key]
+        else:
+            node = self.get(key)
+            L, R, nodetype = parse_node(node)
+            if nodetype == BRANCH_TYPE:
+                Ln, Lc = self.get_substore_candidate_and_children(L, depth-1)
+                Rn, Rc = self.get_substore_candidate_and_children(R, depth-1)
+                return [node] + Ln + Rn, Lc + Rc
+            elif nodetype == KV_TYPE:
+                Rn, Rc = self.get_substore_candidate_and_children(R, depth-1)
+                return [node] + Rn, Rc
+            elif nodetype == LEAF_TYPE:
+                return [node], []
+
+    # Commits to the parent DB
+    def commit(self):
+        processed = {}
+        assert_exists = {}
+        for k, v in self.new_nodes.items():
+            if k in processed:
+                continue
+            nodes, children = self.get_substore_candidate_and_children(k)
+            if not nodes:
+                continue
+            assert k == sha3(nodes[0])
+            for c in children:
+                assert_exists[c] = True
+                if c not in self.substores:
+                    self.fetch_substore(c)
+                cvalues = list(self.substores[c].values())
+                if len(quick_encode(cvalues + nodes)) < 3072:
+                    del self.substores[c]
+                    nodes.extend(cvalues)
+            self.parent_db.put(k, quick_encode(nodes))
+            self.parent_db_writes += 1
+            self.substores[k] = {}
+            for n in nodes:
+                h = sha3(n)
+                self.substores[k][h] = n
+                self.node_to_substore[h] = k
+                processed[h] = k
+        for c in assert_exists:
+            assert self.parent_db.get(c) is not None
+        print('reads', self.parent_db_reads, 'writes', self.parent_db_writes)
+        self.parent_db_reads = self.parent_db_writes = 0
+        self.new_nodes = {}
+
+    # Forgets all loaded substores; only valid when there are no uncommitted nodes
+    def clear_cache(self):
+        assert len(self.new_nodes) == 0
+        self.substores = {}
+        self.node_to_substore = {}
diff --git a/trie_research/bintrie1/new_bintrie_aggregate_tests.py b/trie_research/bintrie1/new_bintrie_aggregate_tests.py
new file mode 100644
index 0000000..f248e30
--- /dev/null
+++ b/trie_research/bintrie1/new_bintrie_aggregate_tests.py
@@ -0,0 +1,53 @@
+from new_bintrie import Trie, EphemDB, encode_bin, encode_bin_path, decode_bin_path
+from new_bintrie_aggregate import WrapperDB
+from ethereum.utils import sha3, encode_hex
+import random
+import rlp
+
+# Despite the name, returns a shuffled copy; x itself is left untouched
+def shuffle_in_place(x):
+    y = x[::]
+    random.shuffle(y)
+    return y
+
+kvpairs = [(sha3(str(i))[12:], str(i).encode('utf-8') * 5) for i in range(2000)]
+
+
+for path in ([], [1,0,1], [0,0,1,0], [1,0,0,1,0], [1,0,0,1,0,0,1,0], [1,0] * 8):
+    assert decode_bin_path(encode_bin_path(bytes(path))) == bytes(path)
+
+r1 = None
+
+t = Trie(WrapperDB(EphemDB()), b'')
+for i, (k, v) in enumerate(shuffle_in_place(kvpairs)):
+    #print(t.to_dict())
+    t.update(k, v)
+    assert t.get(k) == v
+    if not i % 50:
+        t.db.commit()
+        assert t.db.parent_db.get(t.root) is not None
+    if not i % 250:
+        t.to_dict()
+        print("Length of branch at %d nodes: %d" % (i, len(rlp.encode(t.get_long_format_branch(k)))))
+assert r1 is None or t.root == r1
+r1 = t.root
+t.update(kvpairs[0][0], kvpairs[0][1])
+assert t.root == r1
+print(t.get_long_format_branch(kvpairs[0][0]))
+print(t.get_long_format_branch(kvpairs[0][0][::-1]))
+print(encode_hex(t.root))
+t.db.commit()
+assert t.db.parent_db.get(t.root) is not None
+t.db.clear_cache()
+t.db.printing_mode = True
+for k, v in shuffle_in_place(kvpairs):
+    t.get(k)
+    t.db.clear_cache()
+print('Average DB reads: %.3f' % (t.db.parent_db_reads / len(kvpairs)))
+for k, v in shuffle_in_place(kvpairs):
+    t.update(k, b'')
+    if not random.randrange(100):
+        t.to_dict()
+    t.db.commit()
+#t.print_nodes()
+assert t.root == b''
diff --git a/trie_research/bintrie1/new_bintrie_tests.py b/trie_research/bintrie1/new_bintrie_tests.py
new file mode 100644
index 0000000..11cc018
--- /dev/null
+++ b/trie_research/bintrie1/new_bintrie_tests.py
@@ -0,0 +1,66 @@
+from new_bintrie import Trie, EphemDB, encode_bin, encode_bin_path, decode_bin_path
+from ethereum.utils import sha3, encode_hex
+from compress_witness import compress, expand
+import random
+import rlp
+
+# Despite the name, returns a shuffled copy; x itself is left untouched
+def shuffle_in_place(x):
+    y = x[::]
+    random.shuffle(y)
+    return y
+
+kvpairs = [(sha3(str(i))[12:], str(i).encode('utf-8') * 5) for i in range(2000)]
+
+
+for path in ([], [1,0,1], [0,0,1,0], [1,0,0,1,0], [1,0,0,1,0,0,1,0], [1,0] * 8):
+    assert decode_bin_path(encode_bin_path(bytes(path))) == bytes(path)
+
+r1 = None
+
+for _ in range(3):
+    t = Trie(EphemDB(), b'')
+    for i, (k, v) in enumerate(shuffle_in_place(kvpairs)):
+        #print(t.to_dict())
+        t.update(k, v)
+        assert t.get(k) == v
+        if not i % 50:
+            if not i % 250:
+                t.to_dict()
+            b = t.get_long_format_branch(k)
+            print("Length of long-format branch at %d nodes: %d" % (i, len(rlp.encode(b))))
+            c = compress(b)
+            b2 = expand(c)
+            print("Length of compressed witness: %d" % len(rlp.encode(c)))
+            assert sorted(b2) == sorted(b), "Witness compression fails"
+    print('Added 1000 values, doing checks')
+    assert r1 is None or t.root == r1
+    r1 = t.root
+    t.update(kvpairs[0][0], kvpairs[0][1])
+    assert t.root == r1
+    print(encode_hex(t.root))
+    print('Checking that single-key witnesses are the same as branches')
+    for k, v in sorted(kvpairs):
+        assert t.get_prefix_witness(k) == t.get_long_format_branch(k)
+    print('Checking byte-wide witnesses')
+    for _ in range(16):
+        byte = random.randrange(256)
+        witness = t.get_prefix_witness(bytearray([byte]))
+        c = compress(witness)
+        w2 = expand(c)
+        assert sorted(w2) == sorted(witness), "Witness compression fails"
+        print('Witness compression for prefix witnesses: %d original %d compressed' %
+            (len(rlp.encode(witness)), len(rlp.encode(c))))
+        subtrie = Trie(EphemDB({sha3(x): x for x in witness}), t.root)
+        print('auditing byte', byte, 'with', len([k for k,v in kvpairs if k[0] == byte]), 'keys')
+        for k, v in sorted(kvpairs):
+            if k[0] == byte:
+                assert subtrie.get(k) == v
+        assert subtrie.get(bytearray([byte] + [0] * 19)) is None
+        assert subtrie.get(bytearray([byte] + [255] * 19)) is None
+    for k, v in shuffle_in_place(kvpairs):
+        t.update(k, b'')
+        if not random.randrange(100):
+            t.to_dict()
+    #t.print_nodes()
+    assert t.root == b''