Fix ssz_partials.py linter errors

This commit is contained in:
Hsiao-Wei Wang 2019-06-03 17:14:29 +08:00
parent 8e8ef2d0a4
commit a5576059f8
No known key found for this signature in database
GPG Key ID: 95B070122902DEA4

View File

@ -1,17 +1,44 @@
from ..merkle_minimal import hash, next_power_of_two
from .ssz_typing import *
from .ssz_impl import *
from .ssz_typing import (
Container,
infer_input_type,
is_bool_type,
is_bytes_type,
is_bytesn_type,
is_container_type,
is_list_kind,
is_list_type,
is_uint_type,
is_vector_kind,
is_vector_type,
read_elem_type,
uint_byte_size,
)
from .ssz_impl import (
chunkify,
deserialize_basic,
get_typed_values,
is_basic_type,
is_bottom_layer_kind,
pack,
serialize_basic,
)
ZERO_CHUNK = b'\x00' * 32
def last_power_of_two(x):
return next_power_of_two(x+1) // 2
return next_power_of_two(x + 1) // 2
def concat_generalized_indices(x, y):
return x * last_power_of_two(y) + y - last_power_of_two(y)
def rebase(objs, new_root):
return {concat_generalized_indices(new_root, k): v for k,v in objs.items()}
return {concat_generalized_indices(new_root, k): v for k, v in objs.items()}
def constrict_generalized_index(x, q):
depth = last_power_of_two(x // q)
@ -20,32 +47,36 @@ def constrict_generalized_index(x, q):
return None
return o
def unrebase(objs, q):
o = {}
for k,v in objs.items():
for k, v in objs.items():
new_k = constrict_generalized_index(k, q)
if new_k is not None:
o[new_k] = v
return o
def filler(starting_position, chunk_count):
at, skip, end = chunk_count, 1, next_power_of_two(chunk_count)
value = ZERO_CHUNK
o = {}
while at < end:
while at % (skip*2) == 0:
while at % (skip * 2) == 0:
skip *= 2
value = hash(value + value)
o[(starting_position + at) // skip] = value
at += skip
return o
def merkle_tree_of_chunks(chunks, root):
starting_index = root * next_power_of_two(len(chunks))
o = {starting_index+i: chunk for i,chunk in enumerate(chunks)}
o = {starting_index + i: chunk for i, chunk in enumerate(chunks)}
o = {**o, **filler(starting_index, len(chunks))}
return o
@infer_input_type
def ssz_leaves(obj, typ=None, root=1):
if is_list_kind(typ):
@ -57,33 +88,36 @@ def ssz_leaves(obj, typ=None, root=1):
if is_bottom_layer_kind(typ):
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ))
q = {**o, **merkle_tree_of_chunks(chunkify(data), base)}
#print(obj, root, typ, base, list(q.keys()))
# print(obj, root, typ, base, list(q.keys()))
return(q)
else:
fields = get_typed_values(obj, typ=typ)
sub_base = base * next_power_of_two(len(fields))
for i, (elem, elem_type) in enumerate(fields):
o = {**o, **ssz_leaves(elem, typ=elem_type, root=sub_base+i)}
o = {**o, **ssz_leaves(elem, typ=elem_type, root=sub_base + i)}
q = {**o, **filler(sub_base, len(fields))}
#print(obj, root, typ, base, list(q.keys()))
# print(obj, root, typ, base, list(q.keys()))
return(q)
def fill(objects):
objects = {k:v for k,v in objects.items()}
objects = {k: v for k, v in objects.items()}
keys = sorted(objects.keys())[::-1]
pos = 0
while pos < len(keys):
k = keys[pos]
if k in objects and k^1 in objects and k//2 not in objects:
objects[k//2] = hash(objects[k&-2] + objects[k|1])
if k in objects and k ^ 1 in objects and k // 2 not in objects:
objects[k // 2] = hash(objects[k & - 2] + objects[k | 1])
keys.append(k // 2)
pos += 1
return objects
@infer_input_type
def ssz_full(obj, typ=None):
return fill(ssz_leaves(obj, typ=typ))
def get_basic_type_size(typ):
if is_uint_type(typ):
return uint_byte_size(typ)
@ -92,6 +126,7 @@ def get_basic_type_size(typ):
else:
raise Exception("Type not basic: {}".format(typ))
def get_bottom_layer_element_position(typ, base, length, index):
"""
Returns the generalized index and the byte range of the index'th value
@ -104,7 +139,8 @@ def get_bottom_layer_element_position(typ, base, length, index):
chunk_count = (1 if is_basic_type(typ) else length) * elem_size // 32
generalized_index = base * next_power_of_two(chunk_count) + chunk_index
start = elem_size * index % 32
return generalized_index, start, start+elem_size
return generalized_index, start, start + elem_size
@infer_input_type
def get_generalized_indices(obj, path, typ=None, root=1):
@ -132,14 +168,17 @@ def get_generalized_indices(obj, path, typ=None, root=1):
root=base * next_power_of_two(field_count) + index
)
def get_branch_indices(tree_index):
o = [tree_index, tree_index ^ 1]
while o[-1] > 1:
o.append((o[-1] // 2) ^ 1)
return o[:-1]
def remove_redundant_indices(obj):
return {k:v for k,v in obj.items() if not (k*2 in obj and k*2+1 in obj)}
return {k: v for k, v in obj.items() if not (k * 2 in obj and k * 2 + 1 in obj)}
def merge(*args):
o = {}
@ -147,14 +186,19 @@ def merge(*args):
o = {**o, **arg}
return fill(o)
@infer_input_type
def get_nodes_along_path(obj, path, typ=None):
indices = get_generalized_indices(obj, path, typ=typ)
return remove_redundant_indices(merge(*({i:obj.objects[i] for i in get_branch_indices(index)} for index in indices)))
return remove_redundant_indices(merge(
*({i: obj.objects[i] for i in get_branch_indices(index)} for index in indices)
))
class OutOfRangeException(Exception):
pass
class SSZPartial():
def __init__(self, typ, objects, root=1):
assert not is_basic_type(typ)
@ -204,9 +248,9 @@ class SSZPartial():
def __len__(self):
if is_list_kind(self.typ):
if self.root*2+1 not in self.objects:
raise OutOfRangeException("Do not have required data: {}".format(self.root*2+1))
return int.from_bytes(self.objects[self.root*2+1], 'little')
if self.root * 2 + 1 not in self.objects:
raise OutOfRangeException("Do not have required data: {}".format(self.root * 2 + 1))
return int.from_bytes(self.objects[self.root * 2 + 1], 'little')
elif is_vector_kind(self.typ):
return self.typ.length
elif is_container_type(self.typ):
@ -222,7 +266,9 @@ class SSZPartial():
elif is_vector_kind(self.typ):
return self.typ(*(self[i] for i in range(len(self))))
elif is_container_type(self.typ):
full_value = lambda x: x.full_value() if hasattr(x, 'full_value') else x
def full_value(x):
return x.full_value() if hasattr(x, 'full_value') else x
return self.typ(**{field: full_value(self.getter(field)) for field in self.typ.get_field_names()})
elif is_basic_type(self.typ):
return self.getter(0)
@ -235,8 +281,8 @@ class SSZPartial():
pos = 0
while pos < len(keys):
k = keys[pos]
if k in o and k^1 in o and k//2 not in o:
o[k//2] = hash(o[k&-2] + o[k|1])
if k in o and k ^ 1 in o and k // 2 not in o:
o[k // 2] = hash(o[k & - 2] + o[k | 1])
keys.append(k // 2)
pos += 1
return o[self.root]
@ -244,13 +290,16 @@ class SSZPartial():
def __str__(self):
return str(self.full_value())
def ssz_partial(typ, objects, root=1):
ssz_type = (
Container if is_container_type(typ) else
typ if (is_vector_type(typ) or is_bytesn_type(typ)) else object
)
class Partial(SSZPartial, ssz_type):
pass
if is_container_type(typ):
Partial.__annotations__ = typ.__annotations__
o = Partial(typ, objects, root=root)