From c483cb1c233301d3f9e6ee3d694612aedab99bcb Mon Sep 17 00:00:00 2001
From: Vitalik Buterin
Date: Tue, 23 Jan 2018 18:01:35 -0500
Subject: [PATCH] Added witness compression

---
Review notes (not part of the commit; text between the "---" line and the
diff is not applied):

* This patch was reconstructed from a copy whose newlines and indentation
  had been collapsed. Hunk line counts match the hunk headers, but the
  indentation depth of the Python lines -- including the context lines of
  new_bintrie_tests.py -- is inferred from the code's syntax. Verify it
  against the target tree before applying. In particular,
  `assert len(used) == len(proof)` is assumed to sit after the
  `for node in leaves` loop.
* compress(): `used` is keyed by the serialized bytes for leaves
  (`used[node]`) but by sha3 hash for parent nodes (`used[h]`).
* parse_proof_node(): every branch returns a 2-tuple (payload, node type).
  The "Output:" comments on the KV_COMPRESS_TYPE and LEAF_TYPE branches
  ("keypath: child", "None, value") do not match that.
* expand(): `lasthash` is still None if the first proof entry is not a
  leaf, and would be passed as-is to encode_kv_node / encode_branch_node.

 trie_research/compress_witness.py  | 84 ++++++++++++++++++++++++++++++
 trie_research/new_bintrie_tests.py | 13 ++++-
 2 files changed, 96 insertions(+), 1 deletion(-)
 create mode 100644 trie_research/compress_witness.py

diff --git a/trie_research/compress_witness.py b/trie_research/compress_witness.py
new file mode 100644
index 0000000..74e59f6
--- /dev/null
+++ b/trie_research/compress_witness.py
@@ -0,0 +1,84 @@
+from ethereum.utils import sha3, encode_hex
+from new_bintrie import parse_node, KV_TYPE, BRANCH_TYPE, LEAF_TYPE, encode_bin_path, encode_kv_node, encode_branch_node, decode_bin_path
+
+KV_COMPRESS_TYPE = 128
+BRANCH_LEFT_TYPE = 129
+BRANCH_RIGHT_TYPE = 130
+
+def compress(witness):
+    parentmap = {}
+    leaves = []
+    for w in witness:
+        L, R, nodetype = parse_node(w)
+        if nodetype == LEAF_TYPE:
+            leaves.append(w)
+        elif nodetype == KV_TYPE:
+            parentmap[R] = w
+        elif nodetype == BRANCH_TYPE:
+            parentmap[L] = w
+            parentmap[R] = w
+    used = {}
+    proof = []
+    for node in leaves:
+        proof.append(node)
+        used[node] = True
+        h = sha3(node)
+        while h in parentmap:
+            node = parentmap[h]
+            L, R, nodetype = parse_node(node)
+            if nodetype == KV_TYPE:
+                proof.append(bytes([KV_COMPRESS_TYPE]) + encode_bin_path(L))
+            elif nodetype == BRANCH_TYPE and L == h:
+                proof.append(bytes([BRANCH_LEFT_TYPE]) + R)
+            elif nodetype == BRANCH_TYPE and R == h:
+                proof.append(bytes([BRANCH_RIGHT_TYPE]) + L)
+            else:
+                raise Exception("something is wrong")
+            h = sha3(node)
+            if h in used:
+                proof.pop()
+                break
+            used[h] = True
+    assert len(used) == len(proof)
+    return proof
+
+# Input: a serialized node
+def parse_proof_node(node):
+    if node[0] == BRANCH_LEFT_TYPE:
+        # Output: right child, node type
+        return node[1:33], BRANCH_LEFT_TYPE
+    elif node[0] == BRANCH_RIGHT_TYPE:
+        # Output: left child, node type
+        return node[1:33], BRANCH_RIGHT_TYPE
+    elif node[0] == KV_COMPRESS_TYPE:
+        # Output: keypath: child, node type
+        return decode_bin_path(node[1:]), KV_COMPRESS_TYPE
+    elif node[0] == LEAF_TYPE:
+        # Output: None, value, node type
+        return node[1:], LEAF_TYPE
+    else:
+        raise Exception("Bad node")
+
+def expand(proof):
+    witness = []
+    lasthash = None
+    for p in proof:
+        sub, nodetype = parse_proof_node(p)
+        if nodetype == LEAF_TYPE:
+            witness.append(p)
+            lasthash = sha3(p)
+        elif nodetype == KV_COMPRESS_TYPE:
+            fullnode = encode_kv_node(sub, lasthash)
+            witness.append(fullnode)
+            lasthash = sha3(fullnode)
+        elif nodetype == BRANCH_LEFT_TYPE:
+            fullnode = encode_branch_node(lasthash, sub)
+            witness.append(fullnode)
+            lasthash = sha3(fullnode)
+        elif nodetype == BRANCH_RIGHT_TYPE:
+            fullnode = encode_branch_node(sub, lasthash)
+            witness.append(fullnode)
+            lasthash = sha3(fullnode)
+        else:
+            raise Exception("Bad node")
+    return witness
diff --git a/trie_research/new_bintrie_tests.py b/trie_research/new_bintrie_tests.py
index be71d4a..11cc018 100644
--- a/trie_research/new_bintrie_tests.py
+++ b/trie_research/new_bintrie_tests.py
@@ -1,5 +1,6 @@
 from new_bintrie import Trie, EphemDB, encode_bin, encode_bin_path, decode_bin_path
 from ethereum.utils import sha3, encode_hex
+from compress_witness import compress, expand
 import random
 import rlp
 
@@ -25,7 +26,12 @@ for _ in range(3):
         if not i % 50:
             if not i % 250:
                 t.to_dict()
-            print("Length of long-format branch at %d nodes: %d" % (i, len(rlp.encode(t.get_long_format_branch(k)))))
+            b = t.get_long_format_branch(k)
+            print("Length of long-format branch at %d nodes: %d" % (i, len(rlp.encode(b))))
+            c = compress(b)
+            b2 = expand(c)
+            print("Length of compressed witness: %d" % len(rlp.encode(c)))
+            assert sorted(b2) == sorted(b), "Witness compression fails"
     print('Added 1000 values, doing checks')
     assert r1 is None or t.root == r1
     r1 = t.root
@@ -39,6 +45,11 @@ for _ in range(3):
     for _ in range(16):
         byte = random.randrange(256)
         witness = t.get_prefix_witness(bytearray([byte]))
+        c = compress(witness)
+        w2 = expand(c)
+        assert sorted(w2) == sorted(witness), "Witness compression fails"
+        print('Witness compression for prefix witnesses: %d original %d compressed' %
+            (len(rlp.encode(witness)), len(rlp.encode(c))))
         subtrie = Trie(EphemDB({sha3(x): x for x in witness}), t.root)
         print('auditing byte', byte, 'with', len([k for k,v in kvpairs if k[0] == byte]), 'keys')
         for k, v in sorted(kvpairs):