From b6cf809d9b686f090dc4e19a68fbf311cd247831 Mon Sep 17 00:00:00 2001 From: protolambda Date: Thu, 20 Jun 2019 19:54:59 +0200 Subject: [PATCH] more improvements, and implement new space-efficient merkleization with padding support --- .../pyspec/eth2spec/utils/merkle_minimal.py | 36 +++++++--- .../pyspec/eth2spec/utils/ssz/ssz_impl.py | 66 +++++++++++-------- .../pyspec/eth2spec/utils/ssz/ssz_typing.py | 32 +++++---- 3 files changed, 86 insertions(+), 48 deletions(-) diff --git a/test_libs/pyspec/eth2spec/utils/merkle_minimal.py b/test_libs/pyspec/eth2spec/utils/merkle_minimal.py index c508f0df2..ebfb4faf6 100644 --- a/test_libs/pyspec/eth2spec/utils/merkle_minimal.py +++ b/test_libs/pyspec/eth2spec/utils/merkle_minimal.py @@ -44,11 +44,31 @@ def next_power_of_two(v: int) -> int: return 1 << (v - 1).bit_length() -def merkleize_chunks(chunks): - tree = chunks[::] - margin = next_power_of_two(len(chunks)) - len(chunks) - tree.extend([ZERO_BYTES32] * margin) - tree = [ZERO_BYTES32] * len(tree) + tree - for i in range(len(tree) // 2 - 1, 0, -1): - tree[i] = hash(tree[i * 2] + tree[i * 2 + 1]) - return tree[1] +def merkleize_chunks(chunks, pad_to: int = None): + count = len(chunks) + depth = max(count - 1, 0).bit_length() + max_depth = max(depth, (pad_to - 1).bit_length()) + tmp = [None for _ in range(max_depth + 1)] + + def merge(h, i): + j = 0 + while True: + if i & (1 << j) == 0: + if i == count and j < depth: + h = hash(h + zerohashes[j]) + else: + break + else: + h = hash(tmp[j] + h) + j += 1 + tmp[j] = h + + for i in range(count): + merge(chunks[i], i) + + merge(zerohashes[0], count) + + for j in range(depth, max_depth): + tmp[j + 1] = hash(tmp[j] + zerohashes[j]) + + return tmp[max_depth] diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index b08a3d4e2..1a556bc7d 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -1,11 +1,7 @@ from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32 -from .hash_function import hash +from ..hash_function import hash from .ssz_typing import ( - is_uint_type, is_bool_type, is_container_type, - is_list_kind, is_vector_kind, - read_elem_type, - infer_input_type, - get_zero_value, + get_zero_value, Container, List, Vector, Bytes, BytesN, uint ) # SSZ Serialization @@ -15,13 +11,13 @@ BYTES_PER_LENGTH_OFFSET = 4 def is_basic_type(typ): - return is_uint_type(typ) or is_bool_type(typ) + return issubclass(typ, (bool, uint)) def serialize_basic(value, typ): - if is_uint_type(typ): + if issubclass(typ, uint): return value.to_bytes(typ.byte_len, 'little') - elif is_bool_type(typ): + elif issubclass(typ, bool): if value: return b'\x01' else: @@ -31,22 +27,34 @@ def serialize_basic(value, typ): def deserialize_basic(value, typ): - if is_uint_type(typ): + if issubclass(typ, uint): return typ(int.from_bytes(value, 'little')) - elif is_bool_type(typ): + elif issubclass(typ, bool): assert value in (b'\x00', b'\x01') return True if value == b'\x01' else False else: raise Exception("Type not supported: {}".format(typ)) +def is_list_kind(typ): + return issubclass(typ, (List, Bytes)) + + +def is_vector_kind(typ): + return issubclass(typ, (Vector, BytesN)) + + +def is_container_type(typ): + return issubclass(typ, Container) + + def is_fixed_size(typ): if is_basic_type(typ): return True elif is_list_kind(typ): return False elif is_vector_kind(typ): - return is_fixed_size(read_vector_elem_type(typ)) + return is_fixed_size(typ.elem_type) elif is_container_type(typ): return all(is_fixed_size(t) for t in typ.get_field_types()) else: @@ -57,12 +65,11 @@ def is_empty(obj): return get_zero_value(type(obj)) == obj -@infer_input_type -def serialize(obj, typ=None): +def serialize(obj, typ): if is_basic_type(typ): return serialize_basic(obj, typ) elif is_list_kind(typ) or is_vector_kind(typ): - return encode_series(obj, [read_elem_type(typ)] * len(obj)) + return encode_series(obj, [typ.elem_type] * len(obj)) elif is_container_type(typ): return encode_series(obj.get_field_values(), typ.get_field_types()) else: @@ -126,40 +133,41 @@ def mix_in_length(root, length): def is_bottom_layer_kind(typ): return ( is_basic_type(typ) or - (is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(read_elem_type(typ)) + (is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(typ.elem_type) ) -@infer_input_type -def get_typed_values(obj, typ=None): +def get_typed_values(obj, typ): if is_container_type(typ): return obj.get_typed_values() elif is_list_kind(typ) or is_vector_kind(typ): - elem_type = read_elem_type(typ) - return list(zip(obj, [elem_type] * len(obj))) + return list(zip(obj, [typ.elem_type] * len(obj))) else: raise Exception("Invalid type") -def item_length(typ): - return 1 if typ == bool else typ.byte_len if is_uint_type(typ) else 32 -@infer_input_type -def hash_tree_root(obj, typ=None): +def item_length(typ): + if typ == bool: + return 1 + elif issubclass(typ, uint): + return typ.byte_len + else: + return 32 + + +def hash_tree_root(obj, typ): if is_bottom_layer_kind(typ): - data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ)) + data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type) leaves = chunkify(data) else: fields = get_typed_values(obj, typ=typ) leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields] if is_list_kind(typ): - full_chunk_length = (item_length(read_elem_type(typ)) * typ.length + 31) // 32 - leaves += [ZERO_BYTES32] * (full_chunk_length - len(obj)) - return mix_in_length(merkleize_chunks(leaves), len(obj)) + return mix_in_length(merkleize_chunks(leaves, pad_to=typ.length), len(obj)) else: return merkleize_chunks(leaves) -@infer_input_type def signing_root(obj, typ): assert is_container_type(typ) # ignore last field diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index 9aafb5294..30f71f87d 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -1,10 +1,26 @@ +from typing import NewType, Union from types import GeneratorType +class ValueCheckError(Exception): + pass + + class DefaultingTypeMeta(type): def default(cls): raise Exception("Not implemented") +# Every type is subclassed and has a default() method, except bool. +TypeWithDefault = Union[DefaultingTypeMeta, bool] + + +def get_zero_value(typ: TypeWithDefault): + if issubclass(typ, bool): + return False + else: + return typ.default() + + # SSZ integers # ----------------------------- @@ -63,7 +79,7 @@ class Container(object): cls = self.__class__ for f, t in cls.get_fields(): if f not in kwargs: - setattr(self, f, t.default()) + setattr(self, f, get_zero_value(t)) else: setattr(self, f, kwargs[f]) @@ -120,7 +136,7 @@ class Container(object): @classmethod def default(cls): - return cls(**{f: t.default() for f, t in cls.get_fields()}) + return cls(**{f: get_zero_value(t) for f, t in cls.get_fields()}) class ParamsBase: @@ -174,12 +190,8 @@ class ParamsMeta(DefaultingTypeMeta): return True -class ValueCheckError(Exception): - pass - - class AbstractListMeta(ParamsMeta): - elem_type: DefaultingTypeMeta + elem_type: TypeWithDefault length: int @@ -227,8 +239,6 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta): class List(AbstractList): - def value_check(self, value): - return len(value) <= self.__class__.length and super().value_check(value) @classmethod def default(cls): @@ -241,11 +251,11 @@ class Vector(AbstractList, metaclass=AbstractListMeta): @classmethod def default(cls): - return [cls.elem_type.default() for _ in range(cls.length)] + return [get_zero_value(cls.elem_type) for _ in range(cls.length)] class BytesMeta(AbstractListMeta): - elem_type: DefaultingTypeMeta = byte + elem_type: TypeWithDefault = byte length: int