diff --git a/trie_research/new_bintrie.py b/trie_research/new_bintrie.py new file mode 100644 index 0000000..df5a0d2 --- /dev/null +++ b/trie_research/new_bintrie.py @@ -0,0 +1,186 @@ +from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, encode_bin, decode_bin +from ethereum.utils import sha3, encode_hex + +class EphemDB(): + def __init__(self): + self.kv = {} + + def get(self, k): + return self.kv.get(k, None) + + def put(self, k, v): + self.kv[k] = v + +KV_TYPE = 0 +BRANCH_TYPE = 1 + +b1 = bytes([1]) +b0 = bytes([0]) + +def parse_node(node): + if len(node) == 64: + return node[:32], node[32:], BRANCH_TYPE + else: + return decode_bin_path(node[:-32]), node[-32:], KV_TYPE + +def encode_kv_node(keypath, node): + assert keypath + assert len(node) == 32 + o = encode_bin_path(keypath) + node + assert len(o) < 64 + return o + +def encode_branch_node(left, right): + assert len(left) == len(right) == 32 + return left + right + +def hash_and_save(db, node): + h = sha3(node) + db.put(h, node) + return h + +def _get(db, node, keypath): + if not keypath: + return db.get(node) + L, R, nodetype = parse_node(db.get(node)) + if nodetype == KV_TYPE: + if keypath[:len(L)] == L: + return _get(db, R, keypath[len(L):]) + else: + return None + elif nodetype == BRANCH_TYPE: + if keypath[:1] == b0: + return _get(db, L, keypath[1:]) + else: + return _get(db, R, keypath[1:]) + +def _update(db, node, keypath, val): + if not keypath: + return hash_and_save(db, val) + if not node: + return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, val))) + L, R, nodetype = parse_node(db.get(node)) + if nodetype == KV_TYPE: + if keypath[:len(L)] == L: + o = _update(db, R, keypath[len(L):], val) + assert o is not None + if len(L) == len(keypath): + return hash_and_save(db, encode_kv_node(L, o)) if o else b'' + subL, subR, subnodetype = parse_node(db.get(o)) + if subnodetype == KV_TYPE: + return hash_and_save(db, encode_kv_node(L + subL, subR)) + else: + return hash_and_save(db, encode_kv_node(L, o)) if o else b'' + else: + cf = common_prefix_length(L, keypath[:len(L)]) + if len(keypath) == cf + 1: + valnode = val + else: + valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, val))) + if len(L) == cf + 1: + oldnode = R + else: + oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R)) + if keypath[cf:cf+1] == b1: + newsub = hash_and_save(db, encode_branch_node(oldnode, valnode)) + else: + newsub = hash_and_save(db, encode_branch_node(valnode, oldnode)) + if cf: + return hash_and_save(db, encode_kv_node(L[:cf], newsub)) + else: + return newsub + elif nodetype == BRANCH_TYPE: + newL, newR = L, R + if keypath[:1] == b0: + newL = _update(db, L, keypath[1:], val) + else: + newR = _update(db, R, keypath[1:], val) + if not newL or not newR: + subL, subR, subnodetype = parse_node(db.get(newL or newR)) + first_bit = b1 if newR else b0 + if subnodetype == KV_TYPE: + return hash_and_save(db, encode_kv_node(first_bit + subL, subR)) + elif subnodetype == BRANCH_TYPE: + return hash_and_save(db, encode_kv_node(first_bit, newL or newR)) + raise Exception("cow") + else: + return hash_and_save(db, encode_branch_node(newL, newR)) + raise Exception("cow") + +def print_and_check_invariants(db, node, prefix=b''): + #print('pci', node, prefix) + if len(prefix) == 160: + return {prefix: db.get(node)} + if node == b'' and prefix == b'': + return {} + L, R, nodetype = parse_node(db.get(node)) + #print('lrn', L, R, nodetype) + if nodetype == KV_TYPE: + assert 0 < len(L) <= 160 - len(prefix) + if len(L) + len(prefix) < 160: + subL, subR, subnodetype = parse_node(db.get(R)) + assert subnodetype != KV_TYPE + return print_and_check_invariants(db, R, prefix + L) + else: + assert L and R + o = {} + o.update(print_and_check_invariants(db, L, prefix + b0)) + o.update(print_and_check_invariants(db, R, prefix + b1)) + return o + +def print_nodes(db, node, prefix=b''): + if len(prefix) == 160: + print('value node', encode_hex(node[:4]), db.get(node)) + return + if node == b'': + print('empty node') + return + L, R, nodetype = parse_node(db.get(node)) + if nodetype == KV_TYPE: + print(('kv node:', encode_hex(node[:4]), ''.join(['1' if x == 1 else '0' for x in L]), encode_hex(R[:4]))) + print_nodes(db, R, prefix + L) + else: + print(('branch node:', encode_hex(node[:4]), encode_hex(L[:4]), encode_hex(R[:4]))) + print_nodes(db, L, prefix + b0) + print_nodes(db, R, prefix + b1) + +def _get_branch(db, node, keypath): + if not keypath: + return [node] + L, R, nodetype = parse_node(db.get(node)) + if nodetype == KV_TYPE: + path = encode_bin_path(L) + if keypath[:len(L)] == L: + return [path] + _get_branch(db, R, keypath[len(L):]) + else: + return [path] + elif nodetype == BRANCH_TYPE: + if keypath[:1] == b0: + return [R] + _get_branch(db, L, keypath[1:]) + else: + return [L] + _get_branch(db, R, keypath[1:]) + +class Trie(): + def __init__(self, db, root): + self.db = db + self.root = root + assert isinstance(self.root, bytes) + + def get(self, key): + assert len(key) == 20 + return _get(self.db, self.root, encode_bin(key)) + + def get_branch(self, key): + return _get_branch(self.db, self.root, encode_bin(key)) + + def update(self, key, value): + assert len(key) == 20 + self.root = _update(self.db, self.root, encode_bin(key), value) + + def to_dict(self, hexify=False): + o = print_and_check_invariants(self.db, self.root) + encoder = lambda x: encode_hex(x) if hexify else x + return {encoder(decode_bin(k)): v for k, v in o.items()} + + def print_nodes(self): + print_nodes(self.db, self.root) diff --git a/trie_research/new_bintrie_tests.py b/trie_research/new_bintrie_tests.py new file mode 100644 index 0000000..dc250c3 --- /dev/null +++ b/trie_research/new_bintrie_tests.py @@ -0,0 +1,32 @@ +from new_bintrie import Trie, EphemDB, encode_bin, encode_bin_path, decode_bin_path +from ethereum.utils import sha3, encode_hex +import random + +def shuffle_in_place(x): + y = x[::] + random.shuffle(y) + return y + +kvpairs = [(sha3(str(i))[12:], str(i) * 5) for i in range(1000)] + + +for path in ([], [1,0,1], [0,0,1,0], [1,0,0,1,0], [1,0,0,1,0,0,1,0], [1,0] * 8): + assert decode_bin_path(encode_bin_path(bytes(path))) == bytes(path) + +r1 = None + +for i in range(3): + t = Trie(EphemDB(), b'') + for k, v in shuffle_in_place(kvpairs): + #print(t.to_dict()) + t.update(k, v) + assert t.get(k) == v + #t.print_nodes() + assert r1 is None or t.root == r1 + r1 = t.root + t.update(kvpairs[0][0], kvpairs[0][1]) + assert t.root == r1 + print(t.get_branch(kvpairs[0][0])) + print(encode_hex(t.root)) + +t.to_dict()