Added witness prefix functions
This commit is contained in:
parent
82c7393dcf
commit
b0de8d352f
|
@ -1,289 +0,0 @@
|
|||
# All nodes are of the form [path1, child1, path2, child2]
|
||||
# or <value>
|
||||
|
||||
from ethereum import utils
|
||||
from ethereum.db import EphemDB, ListeningDB
|
||||
import rlp, sys
|
||||
import copy
|
||||
|
||||
hashfunc = utils.sha3
|
||||
|
||||
HASHLEN = 32
|
||||
|
||||
|
||||
# 0100000101010111010000110100100101001001 -> ASCII
|
||||
def decode_bin(x):
    """Turn a string of '0'/'1' characters into the characters it spells.

    The input is consumed in groups of eight characters; each group is
    parsed as a base-2 integer and mapped through ``chr``. A trailing
    group shorter than eight characters is parsed as-is.
    """
    chars = []
    for start in range(0, len(x), 8):
        chars.append(chr(int(x[start:start + 8], 2)))
    return ''.join(chars)
|
||||
|
||||
|
||||
# ASCII -> 0100000101010111010000110100100101001001
|
||||
def encode_bin(x):
    """Expand every character of ``x`` into eight '0'/'1' characters.

    Each character is converted with ``ord`` and written most significant
    bit first; only the low eight bits of the code point are kept.

    The bits are produced with integer formatting rather than repeated
    ``/ 2``: true division yields floats on Python 3, which would leak
    text such as '0.5' into the output.
    """
    return ''.join(format(ord(c) & 0xFF, '08b') for c in x)
|
||||
|
||||
|
||||
# Encodes a binary list [0,1,0,1,1,0] of any length into bytes
|
||||
def encode_bin_path(li):
    """Pack a list of bits of any length into a string via ``decode_bin``.

    The empty list encodes to the empty string. Otherwise the bits are
    left-padded with zeros to a multiple of four, preceded by a two-bit
    field holding ``len(bits) % 4`` and by a header chosen so that the
    total is a whole number of 8-bit groups: '00' when the padded bits
    leave half a group over, '100000' when they do not.
    """
    if li == []:
        return ''
    bits = ''.join(map(str, li))
    padded = '0' * ((4 - len(bits)) % 4) + bits
    prefix = ('00', '01', '10', '11')[len(bits) % 4]
    header = '00' if len(padded) % 8 == 4 else '100000'
    return decode_bin(header + prefix + padded)
|
||||
|
||||
|
||||
# Decodes bytes into a binary list
|
||||
def decode_bin_path(p):
    """Unpack a string produced by ``encode_bin_path`` into a list of bits.

    The empty string decodes to the empty list. Otherwise the string is
    expanded with ``encode_bin``; a leading '1' marks four extra header
    bits to drop, the next two bits must be '00', and the two bits after
    that give the original length modulo four, which determines how many
    padding zeros to skip.
    """
    if p == '':
        return []
    bits = encode_bin(p)
    if bits[0] == '1':
        # Long header form: discard its first four bits.
        bits = bits[4:]
    assert bits[0:2] == '00'
    remainder = ['00', '01', '10', '11'].index(bits[2:4])
    # Skip the 4-bit header plus the zero padding added by the encoder.
    skip = 4 + ((4 - remainder) % 4)
    return [1 if ch == '1' else 0 for ch in bits[skip:]]
|
||||
|
||||
|
||||
# Get a node from a database if needed
|
||||
def dbget(node, db):
    """Resolve ``node`` through ``db`` when it looks like a hash reference.

    A node whose length equals ``HASHLEN`` is treated as a key: it is
    fetched with ``db.get`` and RLP-decoded. Anything else is returned
    unchanged.
    """
    if len(node) != HASHLEN:
        return node
    return rlp.decode(db.get(node))
|
||||
|
||||
|
||||
# Place a node into a database if needed
|
||||
def dbput(node, db):
    """Store ``node`` in ``db`` under its hash when its encoding calls for it.

    The node is RLP-encoded. If the encoding is exactly ``HASHLEN`` bytes
    long or longer than twice ``HASHLEN``, it is written to ``db`` keyed by
    ``hashfunc`` of the encoding and that hash is returned. Otherwise
    nothing is stored and the node itself is returned.
    """
    encoded = rlp.encode(node)
    size = len(encoded)
    if size != HASHLEN and size <= HASHLEN * 2:
        return node
    digest = hashfunc(encoded)
    db.put(digest, encoded)
    return digest
|
||||
|
||||
|
||||
# Get a value from a tree
|
||||
def get(node, db, key):
    """Look up ``key`` (a list of bits) below ``node`` and return its value.

    Returns the first element of the node once the key is exhausted, and
    '' when the node has fewer than two elements or when the stored path
    of the selected child does not match the key.
    """
    node = dbget(node, db)
    if key == []:
        return node[0]
    if len(node) in (0, 1):
        return ''
    # The first key bit selects the child; a two-element child carries
    # an encoded path, any other child is taken to have an empty path.
    child = dbget(node[key[0]], db)
    if len(child) == 2:
        encoded_path, target = child
    else:
        encoded_path, target = '', child[0]
    path_bits = decode_bin_path(encoded_path)
    consumed = len(path_bits) + 1
    if key[1:consumed] != path_bits:
        return ''
    return get(target, db, key[consumed:])
|
||||
|
||||
|
||||
# Get length of shared prefix of inputs
|
||||
def get_shared_length(l1, l2):
    """Count the leading positions at which ``l1`` and ``l2`` hold equal items."""
    shared = 0
    for a, b in zip(l1, l2):
        if not (a == b):
            break
        shared += 1
    return shared
|
||||
|
||||
|
||||
# Replace ['', v] with [v] and compact nodes into hashes
|
||||
# if needed
|
||||
def contract_node(n, db):
    """Normalise both children of the two-child node ``n`` and store it.

    ``n`` is modified in place. A child of the form ``['', v]`` is first
    replaced by ``[v]``. Then every child whose length differs from
    ``HASHLEN`` is passed through ``dbput``, which may swap it for a hash.
    Finally ``n`` itself is passed through ``dbput`` and that result is
    returned.
    """
    # Collapse both empty-path children before anything is written,
    # keeping the original order of operations.
    for side in (0, 1):
        child = n[side]
        if len(child) == 2 and child[0] == '':
            n[side] = [child[1]]
    # Use the shared HASHLEN constant rather than a literal 32 so this
    # stays in step with dbget/dbput.
    for side in (0, 1):
        if len(n[side]) != HASHLEN:
            n[side] = dbput(n[side], db)
    return dbput(n, db)
|
||||
|
||||
|
||||
# Update a trie
|
||||
def update(node, db, key, val):
    """Set ``key`` (a list of bits) to ``val`` below ``node``.

    Returns whatever ``contract_node`` returns for the updated node.
    Raises ``Exception("DB must be prefix-free")`` when key bits remain
    but the current node has a single element.
    """
    node = dbget(node, db)
    # Unfortunately this particular design does not allow a node to have
    # one child, so at the root of an empty trie two dummy children are
    # added.
    if node == '':
        node = [dbput([encode_bin_path([]), ''], db),
                dbput([encode_bin_path([1]), ''], db)]
    if key == []:
        # NOTE(review): this single-element node is handed to
        # contract_node below, which indexes n[1] - confirm this path
        # behaves as intended.
        node = [val]
    elif len(node) == 1:
        raise Exception("DB must be prefix-free")
    else:
        assert len(node) == 2, node
        branch = key[0]
        child = dbget(node[branch], db)
        if len(child) == 2:
            encoded_path, target = child
        else:
            encoded_path, target = '', child[0]
        path_bits = decode_bin_path(encoded_path)
        shared = get_shared_length(path_bits, key[1:])
        if shared == len(path_bits):
            # The child's whole path matches: descend with the rest of
            # the key.
            node[branch] = [encoded_path,
                            update(target, db, key[shared+1:], val)]
        else:
            # Paths diverge after `shared` bits: fork into a new
            # two-child node with the old subtree on one side and the
            # new value on the other.
            old_bit = path_bits[shared]
            fork = [0, 0]
            fork[old_bit] = [encode_bin_path(path_bits[shared+1:]), target]
            fork[1 - old_bit] = [encode_bin_path(key[shared+2:]), [val]]
            fork = contract_node(fork, db)
            node[branch] = dbput([encode_bin_path(path_bits[:shared]), fork], db)
    return contract_node(node, db)
|
||||
|
||||
|
||||
# Compression algorithm specialized for merkle proof databases
|
||||
# The idea is similar to standard compression algorithms, where
|
||||
# you replace an instance of a repeat with a pointer to the repeat,
|
||||
# except that here you replace an instance of a hash of a value
|
||||
# with the pointer of a value. This is useful since merkle branches
|
||||
# usually include nodes which contain hashes of each other
|
||||
magic = '\xff\x39'
|
||||
|
||||
|
||||
def compress_db(db):
    """Compress the values of ``db.kv`` into a single RLP-encoded list.

    Each value is scanned left to right. An occurrence of ``magic`` is
    escaped as ``magic + magic``. An occurrence of the hash of any value
    in the database is replaced by ``magic`` followed by the two-byte
    index of that value. All other characters are copied through.

    The database must hold fewer than 65300 values.
    """
    out = []
    values = db.kv.values()
    keys = [hashfunc(x) for x in values]
    # Indices are written as two bytes; presumably the bound also keeps
    # an index from ever spelling out `magic` itself - confirm.
    assert len(keys) < 65300
    for v in values:
        pieces = []
        pos = 0
        while pos < len(v):
            done = False
            if v[pos:pos+2] == magic:
                pieces.append(magic + magic)
                done = True
                pos += 2
            # Deliberately not skipped after an escape: a hash starting
            # right behind the escaped magic is handled in the same pass.
            for i, k in enumerate(keys):
                # startswith with an offset avoids copying the tail of v
                # for every key at every position.
                if v.startswith(k, pos):
                    pieces.append(magic + chr(i // 256) + chr(i % 256))
                    done = True
                    pos += len(k)
                    break
            if not done:
                pieces.append(v[pos])
                pos += 1
        out.append(''.join(pieces))
    return rlp.encode(out)
|
||||
|
||||
|
||||
def decompress_db(ins):
    """Rebuild an ``EphemDB`` from the output of ``compress_db``.

    ``ins`` is RLP-decoded into a list of compressed entries. In each
    entry, ``magic + magic`` stands for a literal ``magic`` and ``magic``
    followed by two other bytes is a pointer: the index of another entry
    whose hash is substituted. Every expanded entry is stored in the
    returned database under ``hashfunc`` of itself.
    """
    entries = rlp.decode(ins)
    vals = [None] * len(entries)

    def decipher(i):
        # Entries are expanded on demand and memoised in `vals`, because
        # a pointer may refer to an entry not yet expanded.
        if vals[i] is not None:
            return vals[i]
        raw = entries[i]
        pieces = []
        pos = 0
        while pos < len(raw):
            if raw[pos:pos+2] != magic:
                pieces.append(raw[pos])
                pos += 1
                continue
            if raw[pos+2:pos+4] == magic:
                pieces.append(magic)
            else:
                ind = ord(raw[pos+2]) * 256 + ord(raw[pos+3])
                pieces.append(hashfunc(decipher(ind)))
            pos += 4
        vals[i] = ''.join(pieces)
        return vals[i]

    for i in range(len(entries)):
        decipher(i)

    result = EphemDB()
    for v in vals:
        result.put(hashfunc(v), v)
    return result
|
||||
|
||||
|
||||
# Convert a merkle branch directly into RLP (ie. remove
|
||||
# the hashing indirection). As it turns out, this is a
|
||||
# really compact way to represent a branch
|
||||
def compress_branch(db, root):
    """Inline every hash reference reachable from ``root`` and RLP-encode it.

    Starting from the node that ``root`` resolves to, each element that
    is ``HASHLEN`` long and present in ``db.kv`` is replaced by the
    decoded node it refers to, recursively. The resulting nested
    structure is returned RLP-encoded.
    """
    top = dbget(copy.copy(root), db)

    def _inline(x):
        for i in range(len(x)):
            item = x[i]
            if len(item) == HASHLEN and item in db.kv:
                x[i] = _inline(dbget(item, db))
            elif isinstance(x, list):
                # The check is on the container: strings are also walked
                # element by element but are never assigned into.
                x[i] = _inline(item)
        return x

    return rlp.encode(_inline(top))
|
||||
|
||||
|
||||
def decompress_branch(branch):
    """Rebuild a node database from the output of ``compress_branch``.

    ``branch`` is RLP-decoded and walked bottom-up; every list node is
    rebuilt from its processed children and passed through ``dbput`` on a
    fresh ``EphemDB``, which is returned.
    """
    decoded = rlp.decode(branch)
    db = EphemDB()

    def _store(x):
        if not isinstance(x, list):
            return x
        children = [_store(child) for child in x]
        return dbput(children, db)

    _store(decoded)
    return db
|
||||
|
||||
|
||||
# Test with n nodes and m branch picks
|
||||
def test(n, m=100):
    """Build a trie of ``n`` entries, check ``m`` lookups, report sizes.

    Keys are ``hashfunc(str(i))`` and values ``hashfunc('v' + str(i))``
    for ``i`` in ``range(n)``. For the first ``m`` keys the value is read
    back through a fresh ``ListeningDB``, and the entries found in that
    wrapper's ``kv`` are round-tripped through ``compress_db`` /
    ``decompress_db`` and ``compress_branch`` / ``decompress_branch``.

    Requires ``m <= n``. Returns a dict with the keys ``total_db_size``,
    ``avg_proof_size``, ``avg_compressed_proof_size``, ``avg_branch_size``
    and ``compressed_db_size``.
    """
    assert m <= n
    db = EphemDB()
    x = ''
    for i in range(n):
        k = hashfunc(str(i))
        v = hashfunc('v'+str(i))
        x = update(x, db, [int(a) for a in encode_bin(rlp.encode(k))], v)
    print(x)
    print(sum([len(val) for key, val in db.db.items()]))
    l1 = ListeningDB(db)
    proof_total = 0       # summed size of the per-lookup entries
    compressed_total = 0  # summed size of compress_db output
    branch_total = 0      # summed size of compress_branch output
    ecks = x
    for i in range(m):
        x = copy.deepcopy(ecks)
        k = hashfunc(str(i))
        v = hashfunc('v'+str(i))
        keybits = [int(a) for a in encode_bin(rlp.encode(k))]
        # A fresh wrapper per lookup, so l2.kv holds only what this one
        # lookup put there (presumably the nodes it read - confirm
        # against ListeningDB).
        l2 = ListeningDB(l1)
        v2 = get(x, l2, keybits)
        assert v == v2
        proof_total += sum(len(val) for val in l2.kv.values())
        cdb = compress_db(l2)
        compressed_total += len(cdb)
        assert decompress_db(cdb).kv == l2.kv
        cbr = compress_branch(l2, x)
        branch_total += len(cbr)
        dbranch = decompress_branch(cbr)
        assert v == get(x, dbranch, keybits)
    picks = min(n, m)
    return {
        'total_db_size': sum(len(val) for val in l1.kv.values()),
        # Previously this repeated the total_db_size expression and the
        # accumulated proof sizes were never used.
        'avg_proof_size': (proof_total // picks),
        'avg_compressed_proof_size': (compressed_total // picks),
        'avg_branch_size': (branch_total // picks),
        'compressed_db_size': len(compress_db(l1)),
    }
|
|
@ -1,13 +0,0 @@
|
|||
import bintrie
|
||||
|
||||
# Run bintrie.test for every pair (n, m) drawn from `datapoints` with
# m <= n and collect the reported 'compressed_db_size' values, one row
# per n.
datapoints = [1, 3, 10, 31, 100, 316, 1000, 3162]
o = []

for i, n in enumerate(datapoints):
    p = []
    for j in range(i + 1):
        # print() with a single argument behaves the same on Python 2
        # and Python 3; the bare print statement is a syntax error on 3.
        print('Running with: %d %d' % (n, datapoints[j]))
        p.append(bintrie.test(n, datapoints[j])['compressed_db_size'])
    o.append(p)

print(o)
|
|
@ -0,0 +1,46 @@
|
|||
# Get a Merkle proof
|
||||
def _get_branch(db, node, keypath):
    """Collect the Merkle branch for ``keypath`` starting at ``node``.

    Returns a list of byte strings. Every element but the last is a
    one-byte marker followed by data:

    * ``\\x01`` + the encoded path of a key/value node,
    * ``\\x02`` + the right child, emitted when descending left,
    * ``\\x03`` + the left child, emitted when descending right.

    The last element is a raw database entry: the node reached when the
    keypath is exhausted, or the child of a key/value node whose path
    does not match the keypath. Falls through (returning ``None``) for
    any other node type.
    """
    if not keypath:
        return [db.get(node)]
    left, right, nodetype = parse_node(db.get(node))
    if nodetype == KV_TYPE:
        step = b'\x01' + encode_bin_path(left)
        if keypath[:len(left)] != left:
            # Path mismatch: stop here and hand back the child entry.
            return [step, db.get(right)]
        return [step] + _get_branch(db, right, keypath[len(left):])
    if nodetype == BRANCH_TYPE:
        if keypath[:1] == b0:
            return [b'\x02' + right] + _get_branch(db, left, keypath[1:])
        return [b'\x03' + left] + _get_branch(db, right, keypath[1:])
|
||||
|
||||
# Verify a Merkle proof
|
||||
def _verify_branch(branch, root, keypath, value):
|
||||
nodes = [branch[-1]]
|
||||
_keypath = b''
|
||||
for data in branch[-2::-1]:
|
||||
marker, node = data[0], data[1:]
|
||||
# it's a keypath
|
||||
if marker == 1:
|
||||
node = decode_bin_path(node)
|
||||
_keypath = node + _keypath
|
||||
nodes.insert(0, encode_kv_node(node, sha3(nodes[0])))
|
||||
# it's a right-side branch
|
||||
elif marker == 2:
|
||||
_keypath = b0 + _keypath
|
||||
nodes.insert(0, encode_branch_node(sha3(nodes[0]), node))
|
||||
# it's a left-side branch
|
||||
elif marker == 3:
|
||||
_keypath = b1 + _keypath
|
||||
nodes.insert(0, encode_branch_node(node, sha3(nodes[0])))
|
||||
else:
|
||||
raise Exception("Foo")
|
||||
L, R, nodetype = parse_node(nodes[0])
|
||||
if value:
|
||||
assert _keypath == keypath
|
||||
assert sha3(nodes[0]) == root
|
||||
db = EphemDB()
|
||||
db.kv = {sha3(node): node for node in nodes}
|
||||
assert _get(db, root, keypath) == value
|
||||
return True
|
|
@ -2,8 +2,8 @@ from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, en
|
|||
from ethereum.utils import sha3, encode_hex
|
||||
|
||||
class EphemDB():
|
||||
def __init__(self):
|
||||
self.kv = {}
|
||||
def __init__(self, kv=None):
|
||||
self.kv = kv or {}
|
||||
|
||||
def get(self, k):
|
||||
return self.kv.get(k, None)
|
||||
|
@ -212,23 +212,6 @@ def print_nodes(db, node, prefix=b''):
|
|||
print_nodes(db, L, prefix + b0)
|
||||
print_nodes(db, R, prefix + b1)
|
||||
|
||||
# Get a Merkle proof
|
||||
def _get_branch(db, node, keypath):
|
||||
if not keypath:
|
||||
return [db.get(node)]
|
||||
L, R, nodetype = parse_node(db.get(node))
|
||||
if nodetype == KV_TYPE:
|
||||
path = encode_bin_path(L)
|
||||
if keypath[:len(L)] == L:
|
||||
return [b'\x01'+path] + _get_branch(db, R, keypath[len(L):])
|
||||
else:
|
||||
return [b'\x01'+path, db.get(R)]
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
if keypath[:1] == b0:
|
||||
return [b'\x02'+R] + _get_branch(db, L, keypath[1:])
|
||||
else:
|
||||
return [b'\x03'+L] + _get_branch(db, R, keypath[1:])
|
||||
|
||||
# Get a long-format Merkle branch
|
||||
def _get_long_format_branch(db, node, keypath):
|
||||
if not keypath:
|
||||
|
@ -237,14 +220,14 @@ def _get_long_format_branch(db, node, keypath):
|
|||
if nodetype == KV_TYPE:
|
||||
path = encode_bin_path(L)
|
||||
if keypath[:len(L)] == L:
|
||||
return [db.get(node)] + _get_branch(db, R, keypath[len(L):])
|
||||
return [db.get(node)] + _get_long_format_branch(db, R, keypath[len(L):])
|
||||
else:
|
||||
return [db.get(node), db.get(R)]
|
||||
return [db.get(node)]
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
if keypath[:1] == b0:
|
||||
return [db.get(node)] + _get_branch(db, L, keypath[1:])
|
||||
return [db.get(node)] + _get_long_format_branch(db, L, keypath[1:])
|
||||
else:
|
||||
return [db.get(node)] + _get_branch(db, R, keypath[1:])
|
||||
return [db.get(node)] + _get_long_format_branch(db, R, keypath[1:])
|
||||
|
||||
def _verify_long_format_branch(branch, root, keypath, value):
|
||||
db = EphemDB()
|
||||
|
@ -252,35 +235,37 @@ def _verify_long_format_branch(branch, root, keypath, value):
|
|||
assert _get(db, root, keypath) == value
|
||||
return True
|
||||
|
||||
# Verify a Merkle proof
|
||||
def _verify_branch(branch, root, keypath, value):
|
||||
nodes = [branch[-1]]
|
||||
_keypath = b''
|
||||
for data in branch[-2::-1]:
|
||||
marker, node = data[0], data[1:]
|
||||
# it's a keypath
|
||||
if marker == 1:
|
||||
node = decode_bin_path(node)
|
||||
_keypath = node + _keypath
|
||||
nodes.insert(0, encode_kv_node(node, sha3(nodes[0])))
|
||||
# it's a right-side branch
|
||||
elif marker == 2:
|
||||
_keypath = b0 + _keypath
|
||||
nodes.insert(0, encode_branch_node(sha3(nodes[0]), node))
|
||||
# it's a left-side branch
|
||||
elif marker == 3:
|
||||
_keypath = b1 + _keypath
|
||||
nodes.insert(0, encode_branch_node(node, sha3(nodes[0])))
|
||||
# Get full subtrie
|
||||
def _get_subtrie(db, node):
|
||||
dbnode = db.get(node)
|
||||
L, R, nodetype = parse_node(dbnode)
|
||||
if nodetype == KV_TYPE:
|
||||
return [dbnode] + _get_subtrie(db, R)
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
return [dbnode] + _get_subtrie(db, L) + _get_subtrie(db, R)
|
||||
elif nodetype == LEAF_TYPE:
|
||||
return [dbnode]
|
||||
|
||||
# Get witness for prefix
|
||||
def _get_prefix_witness(db, node, keypath):
|
||||
dbnode = db.get(node)
|
||||
if not keypath:
|
||||
return _get_subtrie(db, node)
|
||||
L, R, nodetype = parse_node(dbnode)
|
||||
if nodetype == KV_TYPE:
|
||||
path = encode_bin_path(L)
|
||||
if len(keypath) < len(L) and L[:len(keypath)] == keypath:
|
||||
return [dbnode] + _get_subtrie(db, R)
|
||||
if keypath[:len(L)] == L:
|
||||
return [dbnode] + _get_prefix_witness(db, R, keypath[len(L):])
|
||||
else:
|
||||
raise Exception("Foo")
|
||||
L, R, nodetype = parse_node(nodes[0])
|
||||
if value:
|
||||
assert _keypath == keypath
|
||||
assert sha3(nodes[0]) == root
|
||||
db = EphemDB()
|
||||
db.kv = {sha3(node): node for node in nodes}
|
||||
assert _get(db, root, keypath) == value
|
||||
return True
|
||||
return [dbnode]
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
if keypath[:1] == b0:
|
||||
return [dbnode] + _get_prefix_witness(db, L, keypath[1:])
|
||||
else:
|
||||
return [dbnode] + _get_prefix_witness(db, R, keypath[1:])
|
||||
|
||||
|
||||
# Trie wrapper class
|
||||
class Trie():
|
||||
|
@ -290,21 +275,24 @@ class Trie():
|
|||
assert isinstance(self.root, bytes)
|
||||
|
||||
def get(self, key):
|
||||
assert len(key) == 20
|
||||
#assert len(key) == 20
|
||||
return _get(self.db, self.root, encode_bin(key))
|
||||
|
||||
def get_branch(self, key):
|
||||
o = _get_branch(self.db, self.root, encode_bin(key))
|
||||
assert _verify_branch(o, self.root, encode_bin(key), self.get(key))
|
||||
return o
|
||||
#def get_branch(self, key):
|
||||
# o = _get_branch(self.db, self.root, encode_bin(key))
|
||||
# assert _verify_branch(o, self.root, encode_bin(key), self.get(key))
|
||||
# return o
|
||||
|
||||
def get_long_format_branch(self, key):
|
||||
o = _get_long_format_branch(self.db, self.root, encode_bin(key))
|
||||
assert _verify_long_format_branch(o, self.root, encode_bin(key), self.get(key))
|
||||
return o
|
||||
|
||||
def get_prefix_witness(self, key):
|
||||
return _get_prefix_witness(self.db, self.root, encode_bin(key))
|
||||
|
||||
def update(self, key, value):
|
||||
assert len(key) == 20
|
||||
#assert len(key) == 20
|
||||
self.root = _update(self.db, self.root, encode_bin(key), value)
|
||||
|
||||
def to_dict(self, hexify=False):
|
||||
|
|
|
@ -25,18 +25,30 @@ for _ in range(3):
|
|||
if not i % 50:
|
||||
if not i % 250:
|
||||
t.to_dict()
|
||||
print("Length of branch at %d nodes: %d" % (i, len(rlp.encode(t.get_branch(k)))))
|
||||
print("Length of long-format branch at %d nodes: %d" % (i, len(rlp.encode(t.get_long_format_branch(k)))))
|
||||
print('Added 1000 values, doing checks')
|
||||
assert r1 is None or t.root == r1
|
||||
r1 = t.root
|
||||
t.update(kvpairs[0][0], kvpairs[0][1])
|
||||
assert t.root == r1
|
||||
print(t.get_branch(kvpairs[0][0]))
|
||||
print(t.get_branch(kvpairs[0][0][::-1]))
|
||||
print(encode_hex(t.root))
|
||||
print('Checking that single-key witnesses are the same as branches')
|
||||
for k, v in sorted(kvpairs):
|
||||
assert t.get_prefix_witness(k) == t.get_long_format_branch(k)
|
||||
print('Checking byte-wide witnesses')
|
||||
for _ in range(16):
|
||||
byte = random.randrange(256)
|
||||
witness = t.get_prefix_witness(bytearray([byte]))
|
||||
subtrie = Trie(EphemDB({sha3(x): x for x in witness}), t.root)
|
||||
print('auditing byte', byte, 'with', len([k for k,v in kvpairs if k[0] == byte]), 'keys')
|
||||
for k, v in sorted(kvpairs):
|
||||
if k[0] == byte:
|
||||
assert subtrie.get(k) == v
|
||||
assert subtrie.get(bytearray([byte] + [0] * 19)) == None
|
||||
assert subtrie.get(bytearray([byte] + [255] * 19)) == None
|
||||
for k, v in shuffle_in_place(kvpairs):
|
||||
t.update(k, b'')
|
||||
if not random.randrange(100):
|
||||
t.to_dict()
|
||||
#t.print_nodes()
|
||||
assert t.root == b''
|
||||
|
||||
|
|
Loading…
Reference in New Issue