Added new binary trie design
This commit is contained in:
parent
8c6f40ebab
commit
210cae3fed
|
@ -0,0 +1,186 @@
|
||||||
|
from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, encode_bin, decode_bin
|
||||||
|
from ethereum.utils import sha3, encode_hex
|
||||||
|
|
||||||
|
class EphemDB():
|
||||||
|
def __init__(self):
|
||||||
|
self.kv = {}
|
||||||
|
|
||||||
|
def get(self, k):
|
||||||
|
return self.kv.get(k, None)
|
||||||
|
|
||||||
|
def put(self, k, v):
|
||||||
|
self.kv[k] = v
|
||||||
|
|
||||||
|
KV_TYPE = 0
|
||||||
|
BRANCH_TYPE = 1
|
||||||
|
|
||||||
|
b1 = bytes([1])
|
||||||
|
b0 = bytes([0])
|
||||||
|
|
||||||
|
def parse_node(node):
|
||||||
|
if len(node) == 64:
|
||||||
|
return node[:32], node[32:], BRANCH_TYPE
|
||||||
|
else:
|
||||||
|
return decode_bin_path(node[:-32]), node[-32:], KV_TYPE
|
||||||
|
|
||||||
|
def encode_kv_node(keypath, node):
|
||||||
|
assert keypath
|
||||||
|
assert len(node) == 32
|
||||||
|
o = encode_bin_path(keypath) + node
|
||||||
|
assert len(o) < 64
|
||||||
|
return o
|
||||||
|
|
||||||
|
def encode_branch_node(left, right):
|
||||||
|
assert len(left) == len(right) == 32
|
||||||
|
return left + right
|
||||||
|
|
||||||
|
def hash_and_save(db, node):
|
||||||
|
h = sha3(node)
|
||||||
|
db.put(h, node)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def _get(db, node, keypath):
|
||||||
|
if not keypath:
|
||||||
|
return db.get(node)
|
||||||
|
L, R, nodetype = parse_node(db.get(node))
|
||||||
|
if nodetype == KV_TYPE:
|
||||||
|
if keypath[:len(L)] == L:
|
||||||
|
return _get(db, R, keypath[len(L):])
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
elif nodetype == BRANCH_TYPE:
|
||||||
|
if keypath[:1] == b0:
|
||||||
|
return _get(db, L, keypath[1:])
|
||||||
|
else:
|
||||||
|
return _get(db, R, keypath[1:])
|
||||||
|
|
||||||
|
def _update(db, node, keypath, val):
|
||||||
|
if not keypath:
|
||||||
|
return hash_and_save(db, val)
|
||||||
|
if not node:
|
||||||
|
return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, val)))
|
||||||
|
L, R, nodetype = parse_node(db.get(node))
|
||||||
|
if nodetype == KV_TYPE:
|
||||||
|
if keypath[:len(L)] == L:
|
||||||
|
o = _update(db, R, keypath[len(L):], val)
|
||||||
|
assert o is not None
|
||||||
|
if len(L) == len(keypath):
|
||||||
|
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
|
||||||
|
subL, subR, subnodetype = parse_node(db.get(o))
|
||||||
|
if subnodetype == KV_TYPE:
|
||||||
|
return hash_and_save(db, encode_kv_node(L + subL, subR))
|
||||||
|
else:
|
||||||
|
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
|
||||||
|
else:
|
||||||
|
cf = common_prefix_length(L, keypath[:len(L)])
|
||||||
|
if len(keypath) == cf + 1:
|
||||||
|
valnode = val
|
||||||
|
else:
|
||||||
|
valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, val)))
|
||||||
|
if len(L) == cf + 1:
|
||||||
|
oldnode = R
|
||||||
|
else:
|
||||||
|
oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R))
|
||||||
|
if keypath[cf:cf+1] == b1:
|
||||||
|
newsub = hash_and_save(db, encode_branch_node(oldnode, valnode))
|
||||||
|
else:
|
||||||
|
newsub = hash_and_save(db, encode_branch_node(valnode, oldnode))
|
||||||
|
if cf:
|
||||||
|
return hash_and_save(db, encode_kv_node(L[:cf], newsub))
|
||||||
|
else:
|
||||||
|
return newsub
|
||||||
|
elif nodetype == BRANCH_TYPE:
|
||||||
|
newL, newR = L, R
|
||||||
|
if keypath[:1] == b0:
|
||||||
|
newL = _update(db, L, keypath[1:], val)
|
||||||
|
else:
|
||||||
|
newR = _update(db, R, keypath[1:], val)
|
||||||
|
if not newL or not newR:
|
||||||
|
subL, subR, subnodetype = parse_node(db.get(newL or newR))
|
||||||
|
first_bit = b1 if newR else b0
|
||||||
|
if subnodetype == KV_TYPE:
|
||||||
|
return hash_and_save(db, encode_kv_node(first_bit + subL, subR))
|
||||||
|
elif subnodetype == BRANCH_TYPE:
|
||||||
|
return hash_and_save(db, encode_kv_node(first_bit, newL or newR))
|
||||||
|
raise Exception("cow")
|
||||||
|
else:
|
||||||
|
return hash_and_save(db, encode_branch_node(newL, newR))
|
||||||
|
raise Exception("cow")
|
||||||
|
|
||||||
|
def print_and_check_invariants(db, node, prefix=b''):
|
||||||
|
#print('pci', node, prefix)
|
||||||
|
if len(prefix) == 160:
|
||||||
|
return {prefix: db.get(node)}
|
||||||
|
if node == b'' and prefix == b'':
|
||||||
|
return {}
|
||||||
|
L, R, nodetype = parse_node(db.get(node))
|
||||||
|
#print('lrn', L, R, nodetype)
|
||||||
|
if nodetype == KV_TYPE:
|
||||||
|
assert 0 < len(L) <= 160 - len(prefix)
|
||||||
|
if len(L) + len(prefix) < 160:
|
||||||
|
subL, subR, subnodetype = parse_node(db.get(R))
|
||||||
|
assert subnodetype != KV_TYPE
|
||||||
|
return print_and_check_invariants(db, R, prefix + L)
|
||||||
|
else:
|
||||||
|
assert L and R
|
||||||
|
o = {}
|
||||||
|
o.update(print_and_check_invariants(db, L, prefix + b0))
|
||||||
|
o.update(print_and_check_invariants(db, R, prefix + b1))
|
||||||
|
return o
|
||||||
|
|
||||||
|
def print_nodes(db, node, prefix=b''):
|
||||||
|
if len(prefix) == 160:
|
||||||
|
print('value node', encode_hex(node[:4]), db.get(node))
|
||||||
|
return
|
||||||
|
if node == b'':
|
||||||
|
print('empty node')
|
||||||
|
return
|
||||||
|
L, R, nodetype = parse_node(db.get(node))
|
||||||
|
if nodetype == KV_TYPE:
|
||||||
|
print(('kv node:', encode_hex(node[:4]), ''.join(['1' if x == 1 else '0' for x in L]), encode_hex(R[:4])))
|
||||||
|
print_nodes(db, R, prefix + L)
|
||||||
|
else:
|
||||||
|
print(('branch node:', encode_hex(node[:4]), encode_hex(L[:4]), encode_hex(R[:4])))
|
||||||
|
print_nodes(db, L, prefix + b0)
|
||||||
|
print_nodes(db, R, prefix + b1)
|
||||||
|
|
||||||
|
def _get_branch(db, node, keypath):
|
||||||
|
if not keypath:
|
||||||
|
return [node]
|
||||||
|
L, R, nodetype = parse_node(db.get(node))
|
||||||
|
if nodetype == KV_TYPE:
|
||||||
|
path = encode_bin_path(L)
|
||||||
|
if keypath[:len(L)] == L:
|
||||||
|
return [path] + _get_branch(db, R, keypath[len(L):])
|
||||||
|
else:
|
||||||
|
return [path]
|
||||||
|
elif nodetype == BRANCH_TYPE:
|
||||||
|
if keypath[:1] == b0:
|
||||||
|
return [R] + _get_branch(db, L, keypath[1:])
|
||||||
|
else:
|
||||||
|
return [L] + _get_branch(db, R, keypath[1:])
|
||||||
|
|
||||||
|
class Trie():
|
||||||
|
def __init__(self, db, root):
|
||||||
|
self.db = db
|
||||||
|
self.root = root
|
||||||
|
assert isinstance(self.root, bytes)
|
||||||
|
|
||||||
|
def get(self, key):
|
||||||
|
assert len(key) == 20
|
||||||
|
return _get(self.db, self.root, encode_bin(key))
|
||||||
|
|
||||||
|
def get_branch(self, key):
|
||||||
|
return _get_branch(self.db, self.root, encode_bin(key))
|
||||||
|
|
||||||
|
def update(self, key, value):
|
||||||
|
assert len(key) == 20
|
||||||
|
self.root = _update(self.db, self.root, encode_bin(key), value)
|
||||||
|
|
||||||
|
def to_dict(self, hexify=False):
|
||||||
|
o = print_and_check_invariants(self.db, self.root)
|
||||||
|
encoder = lambda x: encode_hex(x) if hexify else x
|
||||||
|
return {encoder(decode_bin(k)): v for k, v in o.items()}
|
||||||
|
|
||||||
|
def print_nodes(self):
|
||||||
|
print_nodes(self.db, self.root)
|
|
@ -0,0 +1,32 @@
|
||||||
|
from new_bintrie import Trie, EphemDB, encode_bin, encode_bin_path, decode_bin_path
|
||||||
|
from ethereum.utils import sha3, encode_hex
|
||||||
|
import random
|
||||||
|
|
||||||
|
def shuffle_in_place(x):
|
||||||
|
y = x[::]
|
||||||
|
random.shuffle(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
kvpairs = [(sha3(str(i))[12:], str(i) * 5) for i in range(1000)]
|
||||||
|
|
||||||
|
|
||||||
|
for path in ([], [1,0,1], [0,0,1,0], [1,0,0,1,0], [1,0,0,1,0,0,1,0], [1,0] * 8):
|
||||||
|
assert decode_bin_path(encode_bin_path(bytes(path))) == bytes(path)
|
||||||
|
|
||||||
|
r1 = None
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
t = Trie(EphemDB(), b'')
|
||||||
|
for k, v in shuffle_in_place(kvpairs):
|
||||||
|
#print(t.to_dict())
|
||||||
|
t.update(k, v)
|
||||||
|
assert t.get(k) == v
|
||||||
|
#t.print_nodes()
|
||||||
|
assert r1 is None or t.root == r1
|
||||||
|
r1 = t.root
|
||||||
|
t.update(kvpairs[0][0], kvpairs[0][1])
|
||||||
|
assert t.root == r1
|
||||||
|
print(t.get_branch(kvpairs[0][0]))
|
||||||
|
print(encode_hex(t.root))
|
||||||
|
|
||||||
|
t.to_dict()
|
Loading…
Reference in New Issue