From 67b281d5b0738852d9bac4031d48412b02c131d2 Mon Sep 17 00:00:00 2001
From: Vitalik Buterin
Date: Fri, 23 Mar 2018 23:24:35 -0400
Subject: [PATCH] Added witness compression and decompression and length tests

---
 trie_research/bintrie1/new_bintrie_tests.py | 23 ++++++++++++--------
 trie_research/bintrie2/new_bintrie.py       | 24 ++++++++++++++++++++-
 trie_research/bintrie2/new_bintrie_test.py  | 21 ++++++++++++------
 3 files changed, 52 insertions(+), 16 deletions(-)

diff --git a/trie_research/bintrie1/new_bintrie_tests.py b/trie_research/bintrie1/new_bintrie_tests.py
index 11cc018..19e6829 100644
--- a/trie_research/bintrie1/new_bintrie_tests.py
+++ b/trie_research/bintrie1/new_bintrie_tests.py
@@ -19,19 +19,24 @@
 r1 = None
 for _ in range(3):
     t = Trie(EphemDB(), b'')
+    total_long_length, total_short_length = 0, 0
     for i, (k, v) in enumerate(shuffle_in_place(kvpairs)):
         #print(t.to_dict())
         t.update(k, v)
         assert t.get(k) == v
-        if not i % 50:
-            if not i % 250:
-                t.to_dict()
-            b = t.get_long_format_branch(k)
-            print("Length of long-format branch at %d nodes: %d" % (i, len(rlp.encode(b))))
-            c = compress(b)
-            b2 = expand(c)
-            print("Length of compressed witness: %d" % len(rlp.encode(c)))
-            assert sorted(b2) == sorted(b), "Witness compression fails"
+        if not i % 250:
+            t.to_dict()
+        b = t.get_long_format_branch(k)
+        c = compress(b)
+        b2 = expand(c)
+        total_long_length += len(rlp.encode(b))
+        total_short_length += len(rlp.encode(c))
+        assert sorted(b2) == sorted(b), "Witness compression fails"
+        if i % 50 == 49:
+            print("Avg length of long-format branch at %d nodes: %d" % (i-24, total_long_length // 50))
+            print("Avg length of compressed witness: %d" % (total_short_length // 50))
+            total_long_length = 0
+            total_short_length = 0
     print('Added 1000 values, doing checks')
     assert r1 is None or t.root == r1
     r1 = t.root
diff --git a/trie_research/bintrie2/new_bintrie.py b/trie_research/bintrie2/new_bintrie.py
index 2bd354c..1244bad 100644
--- a/trie_research/bintrie2/new_bintrie.py
+++ b/trie_research/bintrie2/new_bintrie.py
@@ -14,7 +14,7 @@ class EphemDB():
         del self.kv[k]
 
 zerohashes = [b'\x00' * 32]
-for i in range(256):
+for i in range(255):
     zerohashes.insert(0, sha3(zerohashes[0] + zerohashes[0]))
 
 def new_tree(db):
@@ -101,3 +101,25 @@ def verify_proof(proof, root, key, value):
         path >>= 1
         v = newv
     return root == v
+
+def compress_proof(proof):
+    bits = bytearray(32)
+    oproof = b''
+    for i, p in enumerate(proof):
+        if p == zerohashes[i]:
+            bits[i // 8] ^= 1 << i % 8
+        else:
+            oproof += p
+    return bytes(bits) + oproof
+
+def decompress_proof(oproof):
+    proof = []
+    bits = bytearray(oproof[:32])
+    pos = 32
+    for i in range(256):
+        if bits[i // 8] & (1 << (i % 8)):
+            proof.append(zerohashes[i])
+        else:
+            proof.append(oproof[pos: pos + 32])
+            pos += 32
+    return proof
diff --git a/trie_research/bintrie2/new_bintrie_test.py b/trie_research/bintrie2/new_bintrie_test.py
index f573182..4e0553d 100644
--- a/trie_research/bintrie2/new_bintrie_test.py
+++ b/trie_research/bintrie2/new_bintrie_test.py
@@ -1,18 +1,27 @@
-from new_bintrie import EphemDB, new_tree, get, update, make_merkle_proof, verify_proof
+from new_bintrie import EphemDB, new_tree, get, update, make_merkle_proof, verify_proof, compress_proof, decompress_proof
 import random
 from ethereum.utils import sha3
 
+KEYS = 500
+
 db = EphemDB()
 t = new_tree(db)
-for i in range(500):
+for i in range(KEYS):
     t = update(db, t, sha3(str(i)), sha3(str(i**3)))
-for i in range(500):
+print('%d elements added' % KEYS)
+for i in range(KEYS):
     assert get(db, t, sha3(str(i))) == sha3(str(i**3))
-for i in range(501, 1000):
+print('Get requests for present elements successful')
+for i in range(KEYS + 1, KEYS * 2):
     assert get(db, t, sha3(str(i))) == b'\x00' * 32
+print('Get requests for absent elements successful')
 
-for i in range(1000):
+TL = 0
+for i in range(KEYS * 2):
     key = sha3(str(i))
-    value = sha3(str(i ** 3)) if i < 500 else b'\x00' * 32
+    value = sha3(str(i ** 3)) if i < KEYS else b'\x00' * 32
     proof = make_merkle_proof(db, t, key)
     assert verify_proof(proof, t, key, value)
+    assert decompress_proof(compress_proof(proof)) == proof
+    TL += len(compress_proof(proof))
+print('Average total length at %d keys: %d, %d including key' % (KEYS, TL // KEYS // 2, TL // KEYS // 2 + 32))