From b0de8d352f6236c9fa2244fed871546fabb016d1 Mon Sep 17 00:00:00 2001 From: Vitalik Buterin Date: Sat, 2 Dec 2017 21:08:35 -0500 Subject: [PATCH] Added witness prefix fucnctions --- trie_research/bintrie.py | 289 ----------------------------- trie_research/bintrie_sample.py | 13 -- trie_research/compact_branches.py | 46 +++++ trie_research/new_bintrie.py | 102 +++++----- trie_research/new_bintrie_tests.py | 20 +- 5 files changed, 107 insertions(+), 363 deletions(-) delete mode 100644 trie_research/bintrie.py delete mode 100644 trie_research/bintrie_sample.py create mode 100644 trie_research/compact_branches.py diff --git a/trie_research/bintrie.py b/trie_research/bintrie.py deleted file mode 100644 index 97e474e..0000000 --- a/trie_research/bintrie.py +++ /dev/null @@ -1,289 +0,0 @@ -# All nodes are of the form [path1, child1, path2, child2] -# or - -from ethereum import utils -from ethereum.db import EphemDB, ListeningDB -import rlp, sys -import copy - -hashfunc = utils.sha3 - -HASHLEN = 32 - - -# 0100000101010111010000110100100101001001 -> ASCII -def decode_bin(x): - return ''.join([chr(int(x[i:i+8], 2)) for i in range(0, len(x), 8)]) - - -# ASCII -> 0100000101010111010000110100100101001001 -def encode_bin(x): - o = '' - for c in x: - c = ord(c) - p = '' - for i in range(8): - p = str(c % 2) + p - c /= 2 - o += p - return o - - -# Encodes a binary list [0,1,0,1,1,0] of any length into bytes -def encode_bin_path(li): - if li == []: - return '' - b = ''.join([str(x) for x in li]) - b2 = '0' * ((4 - len(b)) % 4) + b - prefix = ['00', '01', '10', '11'][len(b) % 4] - if len(b2) % 8 == 4: - return decode_bin('00' + prefix + b2) - else: - return decode_bin('100000' + prefix + b2) - - -# Decodes bytes into a binary list -def decode_bin_path(p): - if p == '': - return [] - p = encode_bin(p) - if p[0] == '1': - p = p[4:] - assert p[0:2] == '00' - L = ['00', '01', '10', '11'].index(p[2:4]) - p = p[4+((4 - L) % 4):] - return [(1 if x == '1' else 0) for x in p] - - -# Get a node from a database if needed -def dbget(node, db): - if len(node) == HASHLEN: - return rlp.decode(db.get(node)) - return node - - -# Place a node into a database if needed -def dbput(node, db): - r = rlp.encode(node) - if len(r) == HASHLEN or len(r) > HASHLEN * 2: - h = hashfunc(r) - db.put(h, r) - return h - return node - - -# Get a value from a tree -def get(node, db, key): - node = dbget(node, db) - if key == []: - return node[0] - elif len(node) == 1 or len(node) == 0: - return '' - else: - sub = dbget(node[key[0]], db) - if len(sub) == 2: - subpath, subnode = sub - else: - subpath, subnode = '', sub[0] - subpath = decode_bin_path(subpath) - if key[1:len(subpath)+1] != subpath: - return '' - return get(subnode, db, key[len(subpath)+1:]) - - -# Get length of shared prefix of inputs -def get_shared_length(l1, l2): - i = 0 - while i < len(l1) and i < len(l2) and l1[i] == l2[i]: - i += 1 - return i - - -# Replace ['', v] with [v] and compact nodes into hashes -# if needed -def contract_node(n, db): - if len(n[0]) == 2 and n[0][0] == '': - n[0] = [n[0][1]] - if len(n[1]) == 2 and n[1][0] == '': - n[1] = [n[1][1]] - if len(n[0]) != 32: - n[0] = dbput(n[0], db) - if len(n[1]) != 32: - n[1] = dbput(n[1], db) - return dbput(n, db) - - -# Update a trie -def update(node, db, key, val): - node = dbget(node, db) - # Unfortunately this particular design does not allow - # a node to have one child, so at the root for empty - # tries we need to add two dummy children - if node == '': - node = [dbput([encode_bin_path([]), ''], db), - dbput([encode_bin_path([1]), ''], db)] - if key == []: - node = [val] - elif len(node) == 1: - raise Exception("DB must be prefix-free") - else: - assert len(node) == 2, node - sub = dbget(node[key[0]], db) - if len(sub) == 2: - _subpath, subnode = sub - else: - _subpath, subnode = '', sub[0] - subpath = decode_bin_path(_subpath) - sl = get_shared_length(subpath, key[1:]) - if sl == len(subpath): - node[key[0]] = [_subpath, update(subnode, db, key[sl+1:], val)] - else: - subpath_next = subpath[sl] - n = [0, 0] - n[subpath_next] = [encode_bin_path(subpath[sl+1:]), subnode] - n[(1 - subpath_next)] = [encode_bin_path(key[sl+2:]), [val]] - n = contract_node(n, db) - node[key[0]] = dbput([encode_bin_path(subpath[:sl]), n], db) - return contract_node(node, db) - - -# Compression algorithm specialized for merkle proof databases -# The idea is similar to standard compression algorithms, where -# you replace an instance of a repeat with a pointer to the repeat, -# except that here you replace an instance of a hash of a value -# with the pointer of a value. This is useful since merkle branches -# usually include nodes which contain hashes of each other -magic = '\xff\x39' - - -def compress_db(db): - out = [] - values = db.kv.values() - keys = [hashfunc(x) for x in values] - assert len(keys) < 65300 - for v in values: - o = '' - pos = 0 - while pos < len(v): - done = False - if v[pos:pos+2] == magic: - o += magic + magic - done = True - pos += 2 - for i, k in enumerate(keys): - if v[pos:].startswith(k): - o += magic + chr(i // 256) + chr(i % 256) - done = True - pos += len(k) - break - if not done: - o += v[pos] - pos += 1 - out.append(o) - return rlp.encode(out) - - -def decompress_db(ins): - ins = rlp.decode(ins) - vals = [None] * len(ins) - - def decipher(i): - if vals[i] is None: - v = ins[i] - o = '' - pos = 0 - while pos < len(v): - if v[pos:pos+2] == magic: - if v[pos+2:pos+4] == magic: - o += magic - else: - ind = ord(v[pos+2]) * 256 + ord(v[pos+3]) - o += hashfunc(decipher(ind)) - pos += 4 - else: - o += v[pos] - pos += 1 - vals[i] = o - return vals[i] - - for i in range(len(ins)): - decipher(i) - - o = EphemDB() - for v in vals: - o.put(hashfunc(v), v) - return o - - -# Convert a merkle branch directly into RLP (ie. remove -# the hashing indirection). As it turns out, this is a -# really compact way to represent a branch -def compress_branch(db, root): - o = dbget(copy.copy(root), db) - - def evaluate_node(x): - for i in range(len(x)): - if len(x[i]) == HASHLEN and x[i] in db.kv: - x[i] = evaluate_node(dbget(x[i], db)) - elif isinstance(x, list): - x[i] = evaluate_node(x[i]) - return x - - o2 = rlp.encode(evaluate_node(o)) - return o2 - - -def decompress_branch(branch): - branch = rlp.decode(branch) - db = EphemDB() - - def evaluate_node(x): - if isinstance(x, list): - x = [evaluate_node(n) for n in x] - x = dbput(x, db) - return x - evaluate_node(branch) - return db - - -# Test with n nodes and k branch picks -def test(n, m=100): - assert m <= n - db = EphemDB() - x = '' - for i in range(n): - k = hashfunc(str(i)) - v = hashfunc('v'+str(i)) - x = update(x, db, [int(a) for a in encode_bin(rlp.encode(k))], v) - print(x) - print(sum([len(val) for key, val in db.db.items()])) - l1 = ListeningDB(db) - o = 0 - p = 0 - q = 0 - ecks = x - for i in range(m): - x = copy.deepcopy(ecks) - k = hashfunc(str(i)) - v = hashfunc('v'+str(i)) - l2 = ListeningDB(l1) - v2 = get(x, l2, [int(a) for a in encode_bin(rlp.encode(k))]) - assert v == v2 - o += sum([len(val) for key, val in l2.kv.items()]) - cdb = compress_db(l2) - p += len(cdb) - assert decompress_db(cdb).kv == l2.kv - cbr = compress_branch(l2, x) - q += len(cbr) - dbranch = decompress_branch(cbr) - assert v == get(x, dbranch, [int(a) for a in encode_bin(rlp.encode(k))]) - # for k in l2.kv: - # assert k in dbranch.kv - o = { - 'total_db_size': sum([len(val) for key, val in l1.kv.items()]), - 'avg_proof_size': sum([len(val) for key, val in l1.kv.items()]), - 'avg_compressed_proof_size': (p // min(n, m)), - 'avg_branch_size': (q // min(n, m)), - 'compressed_db_size': len(compress_db(l1)) - } - return o diff --git a/trie_research/bintrie_sample.py b/trie_research/bintrie_sample.py deleted file mode 100644 index 4be7f0a..0000000 --- a/trie_research/bintrie_sample.py +++ /dev/null @@ -1,13 +0,0 @@ -import bintrie - -datapoints = [1, 3, 10, 31, 100, 316, 1000, 3162] -o = [] - -for i in range(len(datapoints)): - p = [] - for j in range(i+1): - print 'Running with: %d %d' % (datapoints[i], datapoints[j]) - p.append(bintrie.test(datapoints[i], datapoints[j])['compressed_db_size']) - o.append(p) - -print o diff --git a/trie_research/compact_branches.py b/trie_research/compact_branches.py new file mode 100644 index 0000000..e8bee44 --- /dev/null +++ b/trie_research/compact_branches.py @@ -0,0 +1,46 @@ +# Get a Merkle proof +def _get_branch(db, node, keypath): + if not keypath: + return [db.get(node)] + L, R, nodetype = parse_node(db.get(node)) + if nodetype == KV_TYPE: + path = encode_bin_path(L) + if keypath[:len(L)] == L: + return [b'\x01'+path] + _get_branch(db, R, keypath[len(L):]) + else: + return [b'\x01'+path, db.get(R)] + elif nodetype == BRANCH_TYPE: + if keypath[:1] == b0: + return [b'\x02'+R] + _get_branch(db, L, keypath[1:]) + else: + return [b'\x03'+L] + _get_branch(db, R, keypath[1:]) + +# Verify a Merkle proof +def _verify_branch(branch, root, keypath, value): + nodes = [branch[-1]] + _keypath = b'' + for data in branch[-2::-1]: + marker, node = data[0], data[1:] + # it's a keypath + if marker == 1: + node = decode_bin_path(node) + _keypath = node + _keypath + nodes.insert(0, encode_kv_node(node, sha3(nodes[0]))) + # it's a right-side branch + elif marker == 2: + _keypath = b0 + _keypath + nodes.insert(0, encode_branch_node(sha3(nodes[0]), node)) + # it's a left-side branch + elif marker == 3: + _keypath = b1 + _keypath + nodes.insert(0, encode_branch_node(node, sha3(nodes[0]))) + else: + raise Exception("Foo") + L, R, nodetype = parse_node(nodes[0]) + if value: + assert _keypath == keypath + assert sha3(nodes[0]) == root + db = EphemDB() + db.kv = {sha3(node): node for node in nodes} + assert _get(db, root, keypath) == value + return True diff --git a/trie_research/new_bintrie.py b/trie_research/new_bintrie.py index ba01897..827b97f 100644 --- a/trie_research/new_bintrie.py +++ b/trie_research/new_bintrie.py @@ -2,8 +2,8 @@ from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, en from ethereum.utils import sha3, encode_hex class EphemDB(): - def __init__(self): - self.kv = {} + def __init__(self, kv=None): + self.kv = kv or {} def get(self, k): return self.kv.get(k, None) @@ -212,23 +212,6 @@ def print_nodes(db, node, prefix=b''): print_nodes(db, L, prefix + b0) print_nodes(db, R, prefix + b1) -# Get a Merkle proof -def _get_branch(db, node, keypath): - if not keypath: - return [db.get(node)] - L, R, nodetype = parse_node(db.get(node)) - if nodetype == KV_TYPE: - path = encode_bin_path(L) - if keypath[:len(L)] == L: - return [b'\x01'+path] + _get_branch(db, R, keypath[len(L):]) - else: - return [b'\x01'+path, db.get(R)] - elif nodetype == BRANCH_TYPE: - if keypath[:1] == b0: - return [b'\x02'+R] + _get_branch(db, L, keypath[1:]) - else: - return [b'\x03'+L] + _get_branch(db, R, keypath[1:]) - # Get a long-format Merkle branch def _get_long_format_branch(db, node, keypath): if not keypath: @@ -237,14 +220,14 @@ def _get_long_format_branch(db, node, keypath): if nodetype == KV_TYPE: path = encode_bin_path(L) if keypath[:len(L)] == L: - return [db.get(node)] + _get_branch(db, R, keypath[len(L):]) + return [db.get(node)] + _get_long_format_branch(db, R, keypath[len(L):]) else: - return [db.get(node), db.get(R)] + return [db.get(node)] elif nodetype == BRANCH_TYPE: if keypath[:1] == b0: - return [db.get(node)] + _get_branch(db, L, keypath[1:]) + return [db.get(node)] + _get_long_format_branch(db, L, keypath[1:]) else: - return [db.get(node)] + _get_branch(db, R, keypath[1:]) + return [db.get(node)] + _get_long_format_branch(db, R, keypath[1:]) def _verify_long_format_branch(branch, root, keypath, value): db = EphemDB() @@ -252,35 +235,37 @@ def _verify_long_format_branch(branch, root, keypath, value): assert _get(db, root, keypath) == value return True -# Verify a Merkle proof -def _verify_branch(branch, root, keypath, value): - nodes = [branch[-1]] - _keypath = b'' - for data in branch[-2::-1]: - marker, node = data[0], data[1:] - # it's a keypath - if marker == 1: - node = decode_bin_path(node) - _keypath = node + _keypath - nodes.insert(0, encode_kv_node(node, sha3(nodes[0]))) - # it's a right-side branch - elif marker == 2: - _keypath = b0 + _keypath - nodes.insert(0, encode_branch_node(sha3(nodes[0]), node)) - # it's a left-side branch - elif marker == 3: - _keypath = b1 + _keypath - nodes.insert(0, encode_branch_node(node, sha3(nodes[0]))) +# Get full subtrie +def _get_subtrie(db, node): + dbnode = db.get(node) + L, R, nodetype = parse_node(dbnode) + if nodetype == KV_TYPE: + return [dbnode] + _get_subtrie(db, R) + elif nodetype == BRANCH_TYPE: + return [dbnode] + _get_subtrie(db, L) + _get_subtrie(db, R) + elif nodetype == LEAF_TYPE: + return [dbnode] + +# Get witness for prefix +def _get_prefix_witness(db, node, keypath): + dbnode = db.get(node) + if not keypath: + return _get_subtrie(db, node) + L, R, nodetype = parse_node(dbnode) + if nodetype == KV_TYPE: + path = encode_bin_path(L) + if len(keypath) < len(L) and L[:len(keypath)] == keypath: + return [dbnode] + _get_subtrie(db, R) + if keypath[:len(L)] == L: + return [dbnode] + _get_prefix_witness(db, R, keypath[len(L):]) else: - raise Exception("Foo") - L, R, nodetype = parse_node(nodes[0]) - if value: - assert _keypath == keypath - assert sha3(nodes[0]) == root - db = EphemDB() - db.kv = {sha3(node): node for node in nodes} - assert _get(db, root, keypath) == value - return True + return [dbnode] + elif nodetype == BRANCH_TYPE: + if keypath[:1] == b0: + return [dbnode] + _get_prefix_witness(db, L, keypath[1:]) + else: + return [dbnode] + _get_prefix_witness(db, R, keypath[1:]) + # Trie wrapper class class Trie(): @@ -290,21 +275,24 @@ class Trie(): assert isinstance(self.root, bytes) def get(self, key): - assert len(key) == 20 + #assert len(key) == 20 return _get(self.db, self.root, encode_bin(key)) - def get_branch(self, key): - o = _get_branch(self.db, self.root, encode_bin(key)) - assert _verify_branch(o, self.root, encode_bin(key), self.get(key)) - return o + #def get_branch(self, key): + # o = _get_branch(self.db, self.root, encode_bin(key)) + # assert _verify_branch(o, self.root, encode_bin(key), self.get(key)) + # return o def get_long_format_branch(self, key): o = _get_long_format_branch(self.db, self.root, encode_bin(key)) assert _verify_long_format_branch(o, self.root, encode_bin(key), self.get(key)) return o + def get_prefix_witness(self, key): + return _get_prefix_witness(self.db, self.root, encode_bin(key)) + def update(self, key, value): - assert len(key) == 20 + #assert len(key) == 20 self.root = _update(self.db, self.root, encode_bin(key), value) def to_dict(self, hexify=False): diff --git a/trie_research/new_bintrie_tests.py b/trie_research/new_bintrie_tests.py index 570887c..be71d4a 100644 --- a/trie_research/new_bintrie_tests.py +++ b/trie_research/new_bintrie_tests.py @@ -25,18 +25,30 @@ for _ in range(3): if not i % 50: if not i % 250: t.to_dict() - print("Length of branch at %d nodes: %d" % (i, len(rlp.encode(t.get_branch(k))))) + print("Length of long-format branch at %d nodes: %d" % (i, len(rlp.encode(t.get_long_format_branch(k))))) + print('Added 1000 values, doing checks') assert r1 is None or t.root == r1 r1 = t.root t.update(kvpairs[0][0], kvpairs[0][1]) assert t.root == r1 - print(t.get_branch(kvpairs[0][0])) - print(t.get_branch(kvpairs[0][0][::-1])) print(encode_hex(t.root)) + print('Checking that single-key witnesses are the same as branches') + for k, v in sorted(kvpairs): + assert t.get_prefix_witness(k) == t.get_long_format_branch(k) + print('Checking byte-wide witnesses') + for _ in range(16): + byte = random.randrange(256) + witness = t.get_prefix_witness(bytearray([byte])) + subtrie = Trie(EphemDB({sha3(x): x for x in witness}), t.root) + print('auditing byte', byte, 'with', len([k for k,v in kvpairs if k[0] == byte]), 'keys') + for k, v in sorted(kvpairs): + if k[0] == byte: + assert subtrie.get(k) == v + assert subtrie.get(bytearray([byte] + [0] * 19)) == None + assert subtrie.get(bytearray([byte] + [255] * 19)) == None for k, v in shuffle_in_place(kvpairs): t.update(k, b'') if not random.randrange(100): t.to_dict() #t.print_nodes() assert t.root == b'' -