more improvements, and implement new space-efficient merkleization with padding support

This commit is contained in:
protolambda 2019-06-20 19:54:59 +02:00
parent d1ecfd510e
commit b6cf809d9b
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
3 changed files with 86 additions and 48 deletions

View File

@ -44,11 +44,31 @@ def next_power_of_two(v: int) -> int:
return 1 << (v - 1).bit_length() return 1 << (v - 1).bit_length()
def merkleize_chunks(chunks, pad_to: int = None):
    """Compute the Merkle root of ``chunks`` without materialising the whole tree.

    The leaves are folded in one at a time and only a single pending node per
    tree level is kept in ``tmp``, so the working memory is proportional to
    the tree depth rather than to the number of leaves.

    Args:
        chunks: Sequence of leaf chunks, in order. They are concatenated and
            passed to the module's ``hash`` (presumably 32-byte values -
            TODO confirm against callers).
        pad_to: Number of leaves the tree is virtually padded to with zero
            chunks. ``None`` (the default) or a value below 1 means "no
            padding beyond the next power of two of ``len(chunks)``".

    Returns:
        The root node, i.e. the pending node at the top level of the tree.
    """
    count = len(chunks)
    # Depth of the smallest complete binary tree that can hold `count` leaves.
    depth = max(count - 1, 0).bit_length()
    # The previous default of None crashed on `(pad_to - 1)`; treat "not given"
    # (and nonsensical values below 1) as "no extra padding".
    if pad_to is None or pad_to < 1:
        pad_to = 1
    max_depth = max(depth, (pad_to - 1).bit_length())
    # tmp[j] holds the pending left-hand node at level j (0 = leaf level).
    tmp = [None] * (max_depth + 1)

    def merge(h, i):
        # Fold leaf `h`, which sits at leaf index `i`, into the pending nodes.
        j = 0
        while True:
            if i & (1 << j) == 0:
                if i == count and j < depth:
                    # Only reached for the padding leaf: its right-hand
                    # sibling subtree is empty, so combine with zerohashes[j]
                    # (presumably the root of an all-zero subtree of height j
                    # - defined elsewhere, TODO confirm) and keep climbing.
                    h = hash(h + zerohashes[j])
                else:
                    # `h` is a left-hand node: park it until its sibling arrives.
                    break
            else:
                # `h` is a right-hand node: combine with the parked left sibling.
                h = hash(tmp[j] + h)
            j += 1
        tmp[j] = h

    # Merge in the real leaves one by one.
    for i, chunk in enumerate(chunks):
        merge(chunk, i)

    # Complement with a zero leaf only when the leaves do not already fill a
    # complete tree. When `count` is an exact power of two the root already
    # sits in tmp[depth]; an extra merge would climb one level too far and
    # write tmp[depth + 1], which is out of range whenever max_depth == depth.
    if 1 << depth != count:
        merge(zerohashes[0], count)

    # The requested virtual size may be larger than the tree built so far:
    # extend the root upwards, pairing it with an empty subtree at each level.
    for j in range(depth, max_depth):
        tmp[j + 1] = hash(tmp[j] + zerohashes[j])

    return tmp[max_depth]

View File

@ -1,11 +1,7 @@
from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32 from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32
from .hash_function import hash from ..hash_function import hash
from .ssz_typing import ( from .ssz_typing import (
is_uint_type, is_bool_type, is_container_type, get_zero_value, Container, List, Vector, Bytes, BytesN, uint
is_list_kind, is_vector_kind,
read_elem_type,
infer_input_type,
get_zero_value,
) )
# SSZ Serialization # SSZ Serialization
@ -15,13 +11,13 @@ BYTES_PER_LENGTH_OFFSET = 4
def is_basic_type(typ):
    """Return True when ``typ`` is ``bool`` or a subclass of ``uint``."""
    basic_bases = (bool, uint)
    return issubclass(typ, basic_bases)
def serialize_basic(value, typ): def serialize_basic(value, typ):
if is_uint_type(typ): if issubclass(typ, uint):
return value.to_bytes(typ.byte_len, 'little') return value.to_bytes(typ.byte_len, 'little')
elif is_bool_type(typ): elif issubclass(typ, bool):
if value: if value:
return b'\x01' return b'\x01'
else: else:
@ -31,22 +27,34 @@ def serialize_basic(value, typ):
def deserialize_basic(value, typ):
    """Decode the bytes ``value`` into an instance of the basic type ``typ``.

    ``uint`` subclasses are read as little-endian integers and wrapped in
    ``typ``; ``bool`` accepts exactly ``b'\\x00'`` or ``b'\\x01'``.

    Raises:
        AssertionError: if a bool is encoded as anything but a single 0/1 byte.
        Exception: if ``typ`` is neither a ``uint`` subclass nor ``bool``.
    """
    if issubclass(typ, uint):
        number = int.from_bytes(value, 'little')
        return typ(number)
    if not issubclass(typ, bool):
        raise Exception("Type not supported: {}".format(typ))
    assert value in (b'\x00', b'\x01')
    return value == b'\x01'
def is_list_kind(typ):
    """Return True when ``typ`` is a subclass of ``List`` or ``Bytes``."""
    list_like_bases = (List, Bytes)
    return issubclass(typ, list_like_bases)
def is_vector_kind(typ):
    """Return True when ``typ`` is a subclass of ``Vector`` or ``BytesN``."""
    vector_like_bases = (Vector, BytesN)
    return issubclass(typ, vector_like_bases)
def is_container_type(typ):
    """Return True when ``typ`` is a subclass of ``Container``."""
    derives_from_container = issubclass(typ, Container)
    return derives_from_container
def is_fixed_size(typ): def is_fixed_size(typ):
if is_basic_type(typ): if is_basic_type(typ):
return True return True
elif is_list_kind(typ): elif is_list_kind(typ):
return False return False
elif is_vector_kind(typ): elif is_vector_kind(typ):
return is_fixed_size(read_vector_elem_type(typ)) return is_fixed_size(typ.elem_type)
elif is_container_type(typ): elif is_container_type(typ):
return all(is_fixed_size(t) for t in typ.get_field_types()) return all(is_fixed_size(t) for t in typ.get_field_types())
else: else:
@ -57,12 +65,11 @@ def is_empty(obj):
return get_zero_value(type(obj)) == obj return get_zero_value(type(obj)) == obj
@infer_input_type def serialize(obj, typ):
def serialize(obj, typ=None):
if is_basic_type(typ): if is_basic_type(typ):
return serialize_basic(obj, typ) return serialize_basic(obj, typ)
elif is_list_kind(typ) or is_vector_kind(typ): elif is_list_kind(typ) or is_vector_kind(typ):
return encode_series(obj, [read_elem_type(typ)] * len(obj)) return encode_series(obj, [typ.elem_type] * len(obj))
elif is_container_type(typ): elif is_container_type(typ):
return encode_series(obj.get_field_values(), typ.get_field_types()) return encode_series(obj.get_field_values(), typ.get_field_types())
else: else:
@ -126,40 +133,41 @@ def mix_in_length(root, length):
def is_bottom_layer_kind(typ):
    """Return True for a basic type, or a list/vector kind whose elements are basic."""
    if is_basic_type(typ):
        return True
    if is_list_kind(typ) or is_vector_kind(typ):
        # A series qualifies only when its element type is itself basic.
        return is_basic_type(typ.elem_type)
    return False
def get_typed_values(obj, typ):
    """Pair every value held by ``obj`` with its type.

    Containers delegate to their own ``get_typed_values()``; list and vector
    kinds pair each element with ``typ.elem_type``.

    Raises:
        Exception: if ``typ`` is neither a container nor a list/vector kind.
    """
    if is_container_type(typ):
        return obj.get_typed_values()
    if not (is_list_kind(typ) or is_vector_kind(typ)):
        raise Exception("Invalid type")
    elem_types = [typ.elem_type] * len(obj)
    return [(value, elem_type) for value, elem_type in zip(obj, elem_types)]
def item_length(typ):
    """Return a length for ``typ``: 1 for bool, ``byte_len`` for uints, else 32."""
    if typ == bool:
        return 1
    if issubclass(typ, uint):
        return typ.byte_len
    # Everything else is counted as 32.
    return 32
def hash_tree_root(obj, typ):
    """Compute the hash-tree-root of ``obj``, interpreted as SSZ type ``typ``.

    Bottom-layer kinds (basic types and series of basic types) are serialized
    or packed and split into chunks; every other type contributes one leaf per
    field/element, namely that sub-object's own hash-tree-root. List kinds are
    merkleized against their maximum capacity and get their actual length
    mixed in.

    Args:
        obj: The value to hash.
        typ: The SSZ type describing ``obj``.

    Returns:
        The root produced by ``merkleize_chunks`` (wrapped by
        ``mix_in_length`` for list kinds).
    """
    if is_bottom_layer_kind(typ):
        if is_basic_type(typ):
            data = serialize_basic(obj, typ)
        else:
            data = pack(obj, typ.elem_type)
        leaves = chunkify(data)
    else:
        fields = get_typed_values(obj, typ=typ)
        leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]

    if is_list_kind(typ):
        # `typ.length` counts elements, but the tree is padded in chunks:
        # item_length() is the per-element byte count for basic elements and
        # 32 (one full chunk) otherwise, so this is the number of 32-byte
        # chunks a full list occupies. Passing `typ.length` directly
        # over-pads lists of basic elements and changes their root.
        chunk_limit = (item_length(typ.elem_type) * typ.length + 31) // 32
        return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_limit), len(obj))

    # Fixed-size kinds need no padding beyond their own leaves. The size is
    # passed explicitly so this call does not depend on the default value of
    # `pad_to`.
    return merkleize_chunks(leaves, pad_to=max(len(leaves), 1))
@infer_input_type
def signing_root(obj, typ): def signing_root(obj, typ):
assert is_container_type(typ) assert is_container_type(typ)
# ignore last field # ignore last field

View File

@ -1,10 +1,26 @@
from typing import NewType, Union
from types import GeneratorType from types import GeneratorType
class ValueCheckError(Exception):
    """Exception type for value checks; adds nothing to ``Exception``."""
class DefaultingTypeMeta(type):
    """Metaclass giving its classes a ``default()`` hook; this base version always raises."""

    def default(cls):
        """Placeholder for the class's default value; raises until overridden."""
        message = "Not implemented"
        raise Exception(message)
# Every type is subclassed and has a default() method, except bool.
# This alias names "a class created by DefaultingTypeMeta, or bool" for use
# in annotations.
# NOTE(review): inside a Union, `bool` denotes bool *instances* to a type
# checker; if the bool class itself is meant, `Type[bool]` would be the
# precise spelling - confirm before relying on static checking here.
TypeWithDefault = Union[DefaultingTypeMeta, bool]
def get_zero_value(typ: TypeWithDefault):
    """Return the zero value of ``typ``: ``False`` for bool, otherwise ``typ.default()``."""
    # bool has no default() method, so it is handled separately.
    return False if issubclass(typ, bool) else typ.default()
# SSZ integers # SSZ integers
# ----------------------------- # -----------------------------
@ -63,7 +79,7 @@ class Container(object):
cls = self.__class__ cls = self.__class__
for f, t in cls.get_fields(): for f, t in cls.get_fields():
if f not in kwargs: if f not in kwargs:
setattr(self, f, t.default()) setattr(self, f, get_zero_value(t))
else: else:
setattr(self, f, kwargs[f]) setattr(self, f, kwargs[f])
@ -120,7 +136,7 @@ class Container(object):
@classmethod
def default(cls):
    """Build an instance with every declared field set to its type's zero value."""
    zero_fields = {}
    for field_name, field_type in cls.get_fields():
        zero_fields[field_name] = get_zero_value(field_type)
    return cls(**zero_fields)
class ParamsBase: class ParamsBase:
@ -174,12 +190,8 @@ class ParamsMeta(DefaultingTypeMeta):
return True return True
class ValueCheckError(Exception):
pass
class AbstractListMeta(ParamsMeta): class AbstractListMeta(ParamsMeta):
elem_type: DefaultingTypeMeta elem_type: TypeWithDefault
length: int length: int
@ -227,8 +239,6 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta):
class List(AbstractList): class List(AbstractList):
def value_check(self, value):
return len(value) <= self.__class__.length and super().value_check(value)
@classmethod @classmethod
def default(cls): def default(cls):
@ -241,11 +251,11 @@ class Vector(AbstractList, metaclass=AbstractListMeta):
@classmethod
def default(cls):
    """Return a plain list of ``cls.length`` zero values of ``cls.elem_type``."""
    zeros = []
    for _ in range(cls.length):
        # Ask for a fresh zero value per slot so slots never share an object.
        zeros.append(get_zero_value(cls.elem_type))
    return zeros
class BytesMeta(AbstractListMeta): class BytesMeta(AbstractListMeta):
elem_type: DefaultingTypeMeta = byte elem_type: TypeWithDefault = byte
length: int length: int