more improvements, and implement new space-efficient merkleization with padding support

This commit is contained in:
protolambda 2019-06-20 19:54:59 +02:00
parent d1ecfd510e
commit b6cf809d9b
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
3 changed files with 86 additions and 48 deletions

View File

@ -44,11 +44,31 @@ def next_power_of_two(v: int) -> int:
return 1 << (v - 1).bit_length()
def merkleize_chunks(chunks):
tree = chunks[::]
margin = next_power_of_two(len(chunks)) - len(chunks)
tree.extend([ZERO_BYTES32] * margin)
tree = [ZERO_BYTES32] * len(tree) + tree
for i in range(len(tree) // 2 - 1, 0, -1):
tree[i] = hash(tree[i * 2] + tree[i * 2 + 1])
return tree[1]
def merkleize_chunks(chunks, pad_to: int = None):
count = len(chunks)
depth = max(count - 1, 0).bit_length()
max_depth = max(depth, (pad_to - 1).bit_length())
tmp = [None for _ in range(max_depth + 1)]
def merge(h, i):
j = 0
while True:
if i & (1 << j) == 0:
if i == count and j < depth:
h = hash(h + zerohashes[j])
else:
break
else:
h = hash(tmp[j] + h)
j += 1
tmp[j] = h
for i in range(count):
merge(chunks[i], i)
merge(zerohashes[0], count)
for j in range(depth, max_depth):
tmp[j + 1] = hash(tmp[j] + zerohashes[j])
return tmp[max_depth]

View File

@ -1,11 +1,7 @@
from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32
from .hash_function import hash
from ..hash_function import hash
from .ssz_typing import (
is_uint_type, is_bool_type, is_container_type,
is_list_kind, is_vector_kind,
read_elem_type,
infer_input_type,
get_zero_value,
get_zero_value, Container, List, Vector, Bytes, BytesN, uint
)
# SSZ Serialization
@ -15,13 +11,13 @@ BYTES_PER_LENGTH_OFFSET = 4
def is_basic_type(typ):
return is_uint_type(typ) or is_bool_type(typ)
return issubclass(typ, (bool, uint))
def serialize_basic(value, typ):
if is_uint_type(typ):
if issubclass(typ, uint):
return value.to_bytes(typ.byte_len, 'little')
elif is_bool_type(typ):
elif issubclass(typ, bool):
if value:
return b'\x01'
else:
@ -31,22 +27,34 @@ def serialize_basic(value, typ):
def deserialize_basic(value, typ):
if is_uint_type(typ):
if issubclass(typ, uint):
return typ(int.from_bytes(value, 'little'))
elif is_bool_type(typ):
elif issubclass(typ, bool):
assert value in (b'\x00', b'\x01')
return True if value == b'\x01' else False
else:
raise Exception("Type not supported: {}".format(typ))
def is_list_kind(typ):
return issubclass(typ, (List, Bytes))
def is_vector_kind(typ):
return issubclass(typ, (Vector, BytesN))
def is_container_type(typ):
return issubclass(typ, Container)
def is_fixed_size(typ):
if is_basic_type(typ):
return True
elif is_list_kind(typ):
return False
elif is_vector_kind(typ):
return is_fixed_size(read_vector_elem_type(typ))
return is_fixed_size(typ.elem_type)
elif is_container_type(typ):
return all(is_fixed_size(t) for t in typ.get_field_types())
else:
@ -57,12 +65,11 @@ def is_empty(obj):
return get_zero_value(type(obj)) == obj
@infer_input_type
def serialize(obj, typ=None):
def serialize(obj, typ):
if is_basic_type(typ):
return serialize_basic(obj, typ)
elif is_list_kind(typ) or is_vector_kind(typ):
return encode_series(obj, [read_elem_type(typ)] * len(obj))
return encode_series(obj, [typ.elem_type] * len(obj))
elif is_container_type(typ):
return encode_series(obj.get_field_values(), typ.get_field_types())
else:
@ -126,40 +133,41 @@ def mix_in_length(root, length):
def is_bottom_layer_kind(typ):
return (
is_basic_type(typ) or
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(read_elem_type(typ))
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(typ.elem_type)
)
@infer_input_type
def get_typed_values(obj, typ=None):
def get_typed_values(obj, typ):
if is_container_type(typ):
return obj.get_typed_values()
elif is_list_kind(typ) or is_vector_kind(typ):
elem_type = read_elem_type(typ)
return list(zip(obj, [elem_type] * len(obj)))
return list(zip(obj, [typ.elem_type] * len(obj)))
else:
raise Exception("Invalid type")
def item_length(typ):
return 1 if typ == bool else typ.byte_len if is_uint_type(typ) else 32
@infer_input_type
def hash_tree_root(obj, typ=None):
def item_length(typ):
if typ == bool:
return 1
elif issubclass(typ, uint):
return typ.byte_len
else:
return 32
def hash_tree_root(obj, typ):
if is_bottom_layer_kind(typ):
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ))
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type)
leaves = chunkify(data)
else:
fields = get_typed_values(obj, typ=typ)
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
if is_list_kind(typ):
full_chunk_length = (item_length(read_elem_type(typ)) * typ.length + 31) // 32
leaves += [ZERO_BYTES32] * (full_chunk_length - len(obj))
return mix_in_length(merkleize_chunks(leaves), len(obj))
return mix_in_length(merkleize_chunks(leaves, pad_to=typ.length), len(obj))
else:
return merkleize_chunks(leaves)
@infer_input_type
def signing_root(obj, typ):
assert is_container_type(typ)
# ignore last field

View File

@ -1,10 +1,26 @@
from typing import NewType, Union
from types import GeneratorType
class ValueCheckError(Exception):
pass
class DefaultingTypeMeta(type):
def default(cls):
raise Exception("Not implemented")
# Every type is subclassed and has a default() method, except bool.
TypeWithDefault = Union[DefaultingTypeMeta, bool]
def get_zero_value(typ: TypeWithDefault):
if issubclass(typ, bool):
return False
else:
return typ.default()
# SSZ integers
# -----------------------------
@ -63,7 +79,7 @@ class Container(object):
cls = self.__class__
for f, t in cls.get_fields():
if f not in kwargs:
setattr(self, f, t.default())
setattr(self, f, get_zero_value(t))
else:
setattr(self, f, kwargs[f])
@ -120,7 +136,7 @@ class Container(object):
@classmethod
def default(cls):
return cls(**{f: t.default() for f, t in cls.get_fields()})
return cls(**{f: get_zero_value(t) for f, t in cls.get_fields()})
class ParamsBase:
@ -174,12 +190,8 @@ class ParamsMeta(DefaultingTypeMeta):
return True
class ValueCheckError(Exception):
pass
class AbstractListMeta(ParamsMeta):
elem_type: DefaultingTypeMeta
elem_type: TypeWithDefault
length: int
@ -227,8 +239,6 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta):
class List(AbstractList):
def value_check(self, value):
return len(value) <= self.__class__.length and super().value_check(value)
@classmethod
def default(cls):
@ -241,11 +251,11 @@ class Vector(AbstractList, metaclass=AbstractListMeta):
@classmethod
def default(cls):
return [cls.elem_type.default() for _ in range(cls.length)]
return [get_zero_value(cls.elem_type) for _ in range(cls.length)]
class BytesMeta(AbstractListMeta):
elem_type: DefaultingTypeMeta = byte
elem_type: TypeWithDefault = byte
length: int