diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index bcdef3988..679574891 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -1,7 +1,7 @@ -from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32 +from ..merkle_minimal import merkleize_chunks from ..hash_function import hash from .ssz_typing import ( - get_zero_value, Container, List, Vector, Bytes, BytesN, uint + SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint ) # SSZ Serialization @@ -10,79 +10,48 @@ from .ssz_typing import ( BYTES_PER_LENGTH_OFFSET = 4 -def is_basic_type(typ): - return issubclass(typ, (bool, uint)) - - -def serialize_basic(value, typ): - if issubclass(typ, uint): - return value.to_bytes(typ.byte_len, 'little') - elif issubclass(typ, bool): +def serialize_basic(value: SSZValue): + if isinstance(value, uint): + return value.to_bytes(value.type().byte_len, 'little') + elif isinstance(value, Bit): if value: return b'\x01' else: return b'\x00' else: - raise Exception("Type not supported: {}".format(typ)) + raise Exception(f"Type not supported: {type(value)}") -def deserialize_basic(value, typ): +def deserialize_basic(value, typ: BasicType): if issubclass(typ, uint): return typ(int.from_bytes(value, 'little')) - elif issubclass(typ, bool): + elif issubclass(typ, Bit): assert value in (b'\x00', b'\x01') - return True if value == b'\x01' else False + return Bit(value == b'\x01') else: - raise Exception("Type not supported: {}".format(typ)) + raise Exception(f"Type not supported: {typ}") -def is_list_kind(typ): - return issubclass(typ, (List, Bytes)) +def is_empty(obj: SSZValue): + return type(obj).default() == obj -def is_vector_kind(typ): - return issubclass(typ, (Vector, BytesN)) - - -def is_container_type(typ): - return issubclass(typ, Container) - - -def is_fixed_size(typ): - if is_basic_type(typ): - return True - elif is_list_kind(typ): - return False - elif is_vector_kind(typ): - return is_fixed_size(typ.elem_type) - elif is_container_type(typ): - return all(is_fixed_size(t) for t in typ.get_field_types()) +def serialize(obj: SSZValue): + if isinstance(obj, BasicValue): + return serialize_basic(obj) + elif isinstance(obj, Series): + return encode_series(obj) else: - raise Exception("Type not supported: {}".format(typ)) + raise Exception(f"Type not supported: {type(obj)}") -def is_empty(obj): - return get_zero_value(type(obj)) == obj - - -def serialize(obj, typ): - if is_basic_type(typ): - return serialize_basic(obj, typ) - elif is_list_kind(typ) or is_vector_kind(typ): - return encode_series(obj, [typ.elem_type] * len(obj)) - elif is_container_type(typ): - return encode_series(obj.get_field_values(), typ.get_field_types()) - else: - raise Exception("Type not supported: {}".format(typ)) - - -def encode_series(values, types): +def encode_series(values: Series): # bytes and bytesN are already in the right format. if isinstance(values, bytes): return values # Recursively serialize - parts = [(is_fixed_size(types[i]), serialize(values[i], typ=types[i])) for i in range(len(values))] + parts = [(v.type().is_fixed_size(), serialize(v)) for v in values] # Compute and check lengths fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET @@ -114,10 +83,10 @@ def encode_series(values, types): # ----------------------------- -def pack(values, subtype): +def pack(values: Series): if isinstance(values, bytes): return values - return b''.join([serialize_basic(value, subtype) for value in values]) + return b''.join([serialize_basic(value) for value in values]) def chunkify(bytez): @@ -130,52 +99,49 @@ def mix_in_length(root, length): return hash(root + length.to_bytes(32, 'little')) -def is_bottom_layer_kind(typ): +def is_bottom_layer_kind(typ: SSZType): return ( - is_basic_type(typ) or - (is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(typ.elem_type) + issubclass(typ, BasicType) or + (issubclass(typ, Elements) and issubclass(typ.elem_type, BasicType)) ) -def get_typed_values(obj, typ): - if is_container_type(typ): - return obj.get_typed_values() - elif is_list_kind(typ) or is_vector_kind(typ): - return list(zip(obj, [typ.elem_type] * len(obj))) - else: - raise Exception("Invalid type") - - -def item_length(typ): - if typ == bool: - return 1 - elif issubclass(typ, uint): +def item_length(typ: SSZType) -> int: + if issubclass(typ, BasicType): return typ.byte_len else: return 32 -def chunk_count(typ): - if is_basic_type(typ): - return 1 - elif is_list_kind(typ) or is_vector_kind(typ): - return (typ.length * item_length(typ.elem_type) + 31) // 32 - else: - return len(typ.get_fields()) -def hash_tree_root(obj, typ): - if is_bottom_layer_kind(typ): - data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type) - leaves = chunkify(data) + +def chunk_count(typ: SSZType) -> int: + if issubclass(typ, BasicType): + return 1 + elif issubclass(typ, Elements): + return (typ.length * item_length(typ.elem_type) + 31) // 32 + elif issubclass(typ, Container): + return len(typ.get_fields()) else: - fields = get_typed_values(obj, typ=typ) - leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields] - if is_list_kind(typ): - return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(typ)), len(obj)) + raise Exception(f"Type not supported: {typ}") + + +def hash_tree_root(obj: SSZValue): + if isinstance(obj, Series): + if is_bottom_layer_kind(obj.type()): + leaves = chunkify(pack(obj)) + else: + leaves = [hash_tree_root(value) for value in obj] + elif isinstance(obj, BasicValue): + leaves = chunkify(serialize_basic(obj)) + else: + raise Exception(f"Type not supported: {obj.type()}") + + if isinstance(obj, (List, Bytes)): + return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj)) else: return merkleize_chunks(leaves) -def signing_root(obj, typ): - assert is_container_type(typ) +def signing_root(obj: Container): # ignore last field - leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in obj.get_typed_values()[:-1]] + leaves = [hash_tree_root(field) for field in obj[:-1]] return merkleize_chunks(chunkify(b''.join(leaves))) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index 40901ad97..b79789f27 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -1,4 +1,4 @@ -from typing import NewType, Union +from typing import Tuple, Dict, Iterator from types import GeneratorType @@ -10,24 +10,42 @@ class DefaultingTypeMeta(type): def default(cls): raise Exception("Not implemented") -# Every type is subclassed and has a default() method, except bool. -TypeWithDefault = Union[DefaultingTypeMeta, bool] + +class SSZType(DefaultingTypeMeta): + + def is_fixed_size(cls): + raise Exception("Not implemented") -def get_zero_value(typ: TypeWithDefault): - if issubclass(typ, bool): - return False - else: - return typ.default() +class SSZValue(object, metaclass=SSZType): + + def type(self): + return self.__class__ -# SSZ integers -# ----------------------------- - - -class uint(int, metaclass=DefaultingTypeMeta): +class BasicType(SSZType): byte_len = 0 + def is_fixed_size(cls): + return True + + +class BasicValue(int, SSZValue, metaclass=BasicType): + pass + + +class Bit(BasicValue): # can't subclass bool. + + @classmethod + def default(cls): + return cls(False) + + def __bool__(self): + return self > 0 + + +class uint(BasicValue, metaclass=BasicType): + def __new__(cls, value, *args, **kwargs): if value < 0: raise ValueError("unsigned types must not be negative") @@ -69,36 +87,39 @@ class uint256(uint): byte_len = 32 -# SSZ Container base class -# ----------------------------- +class Series(SSZValue): + + def __iter__(self) -> Iterator[SSZValue]: + raise Exception("Not implemented") + # Note: importing ssz functionality locally, to avoid import loop -class Container(object, metaclass=DefaultingTypeMeta): +class Container(Series, metaclass=SSZType): def __init__(self, **kwargs): cls = self.__class__ for f, t in cls.get_fields(): if f not in kwargs: - setattr(self, f, get_zero_value(t)) + setattr(self, f, t.default()) else: setattr(self, f, kwargs[f]) def serialize(self): from .ssz_impl import serialize - return serialize(self, self.__class__) + return serialize(self) def hash_tree_root(self): from .ssz_impl import hash_tree_root - return hash_tree_root(self, self.__class__) + return hash_tree_root(self) def signing_root(self): from .ssz_impl import signing_root - return signing_root(self, self.__class__) + return signing_root(self) - def get_field_values(self): + def get_field_values(self) -> Tuple[SSZValue, ...]: cls = self.__class__ - return [getattr(self, field) for field in cls.get_field_names()] + return tuple(getattr(self, field) for field in cls.get_field_names()) def __repr__(self): return repr({field: getattr(self, field) for field in self.get_field_names()}) @@ -116,31 +137,38 @@ class Container(object, metaclass=DefaultingTypeMeta): return hash(self.hash_tree_root()) @classmethod - def get_fields_dict(cls): + def get_fields_dict(cls) -> Dict[str, SSZType]: return dict(cls.__annotations__) @classmethod - def get_fields(cls): - return list(dict(cls.__annotations__).items()) + def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]: + return tuple((f, SSZType(t)) for f, t in dict(cls.__annotations__).items()) def get_typed_values(self): - return list(zip(self.get_field_values(), self.get_field_types())) + return tuple(zip(self.get_field_values(), self.get_field_types())) @classmethod - def get_field_names(cls): - return list(cls.__annotations__.keys()) + def get_field_names(cls) -> Tuple[str]: + return tuple(cls.__annotations__.keys()) @classmethod - def get_field_types(cls): + def get_field_types(cls) -> Tuple[SSZType, ...]: # values of annotations are the types corresponding to the fields, not instance values. - return list(cls.__annotations__.values()) + return tuple(cls.__annotations__.values()) @classmethod def default(cls): - return cls(**{f: get_zero_value(t) for f, t in cls.get_fields()}) + return cls(**{f: t.default() for f, t in cls.get_fields()}) + + @classmethod + def is_fixed_size(cls): + return all(t.is_fixed_size() for t in cls.get_field_types()) + + def __iter__(self) -> Iterator[SSZValue]: + return iter(self.get_field_values()) -class ParamsBase: +class ParamsBase(Series): _bare = True def __new__(cls, *args, **kwargs): @@ -149,7 +177,7 @@ class ParamsBase: return super().__new__(cls, **kwargs) -class ParamsMeta(DefaultingTypeMeta): +class ParamsMeta(SSZType): def __new__(cls, class_name, parents, attrs): out = type.__new__(cls, class_name, parents, attrs) @@ -168,14 +196,14 @@ class ParamsMeta(DefaultingTypeMeta): res = {} i = 0 for (name, typ) in self.__annotations__.items(): - param = params[i] if hasattr(self.__class__, name): res[name] = getattr(self.__class__, name) else: - if typ == TypeWithDefault: - if not (isinstance(param, bool) or isinstance(param, DefaultingTypeMeta)): - raise TypeError("expected param {} as {} to have a type default".format(param, name, typ)) - elif not isinstance(param, typ): + if i >= len(params): + i += 1 + continue + param = params[i] + if not isinstance(param, typ): raise TypeError( "cannot create parametrized class with param {} as {} of type {}".format(param, name, typ)) res[name] = param @@ -194,12 +222,12 @@ class ParamsMeta(DefaultingTypeMeta): return True -class AbstractListMeta(ParamsMeta): - elem_type: TypeWithDefault +class Elements(ParamsMeta): + elem_type: SSZType length: int -class AbstractList(ParamsBase, metaclass=AbstractListMeta): +class ElementsBase(ParamsBase, metaclass=Elements): def __init__(self, *args): items = self.extract_args(*args) @@ -223,7 +251,7 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta): cls = self.__class__ return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})" - def __getitem__(self, i): + def __getitem__(self, i) -> SSZValue: return self.items[i] def __setitem__(self, k, v): @@ -235,21 +263,25 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta): def __repr__(self): return repr(self.items) - def __iter__(self): + def __iter__(self) -> Iterator[SSZValue]: return iter(self.items) def __eq__(self, other): return self.items == other.items -class List(AbstractList): +class List(ElementsBase): @classmethod def default(cls): return cls() + @classmethod + def is_fixed_size(cls): + return False -class Vector(AbstractList, metaclass=AbstractListMeta): + +class Vector(ElementsBase): @classmethod def value_check(cls, value): @@ -257,15 +289,19 @@ class Vector(AbstractList, metaclass=AbstractListMeta): @classmethod def default(cls): - return [get_zero_value(cls.elem_type) for _ in range(cls.length)] + return [cls.elem_type.default() for _ in range(cls.length)] + + @classmethod + def is_fixed_size(cls): + return cls.elem_type.is_fixed_size() -class BytesMeta(AbstractListMeta): - elem_type: TypeWithDefault = byte +class BytesMeta(Elements): + elem_type: SSZType = byte length: int -class BytesLike(AbstractList, metaclass=BytesMeta): +class BytesLike(ElementsBase, metaclass=BytesMeta): @classmethod def extract_args(cls, args): @@ -293,6 +329,10 @@ class Bytes(BytesLike): def default(cls): return b'' + @classmethod + def is_fixed_size(cls): + return False + class BytesN(BytesLike): @@ -304,3 +344,6 @@ class BytesN(BytesLike): def value_check(cls, value): return len(value) == cls.length and super().value_check(value) + @classmethod + def is_fixed_size(cls): + return True