diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index 4b64c9162..144201d83 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -1,7 +1,7 @@ from ..merkle_minimal import merkleize_chunks from ..hash_function import hash from .ssz_typing import ( - SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint, + SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, uint, ) # SSZ Serialization @@ -46,9 +46,8 @@ def serialize(obj: SSZValue): def encode_series(values: Series): - # bytes and bytesN are already in the right format. - if isinstance(values, (Bytes, BytesN)): - return values.items + if isinstance(values, bytes): # Bytes and BytesN are already like serialized output + return values # Recursively serialize parts = [(v.type().is_fixed_size(), serialize(v)) for v in values] @@ -84,8 +83,8 @@ def encode_series(values: Series): def pack(values: Series): - if isinstance(values, (Bytes, BytesN)): - return values.items + if isinstance(values, bytes): # Bytes and BytesN are already packed + return values return b''.join([serialize_basic(value) for value in values]) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index 381dadf9e..341df880a 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -188,10 +188,10 @@ class Container(Series, metaclass=SSZType): class ParamsBase(Series): - _bare = True + _has_params = False def __new__(cls, *args, **kwargs): - if cls._bare: + if not cls._has_params: raise Exception("cannot init bare type without params") return super().__new__(cls, **kwargs) @@ -200,13 +200,13 @@ class ParamsMeta(SSZType): def __new__(cls, class_name, parents, attrs): out = type.__new__(cls, class_name, parents, attrs) - for k, v in attrs.items(): - setattr(out, k, v) + if hasattr(out, "_has_params") and getattr(out, "_has_params"): + for k, v in attrs.items(): + setattr(out, k, v) return out def __getitem__(self, params): o = self.__class__(self.__name__, (self,), self.attr_from_params(params)) - o._bare = False return o def __str__(self): @@ -218,7 +218,7 @@ class ParamsMeta(SSZType): def attr_from_params(self, p): # single key params are valid too. Wrap them in a tuple. params = p if isinstance(p, tuple) else (p,) - res = {} + res = {'_has_params': True} i = 0 for (name, typ) in self.__annotations__.items(): if hasattr(self.__class__, name): @@ -262,13 +262,17 @@ class ElementsType(ParamsMeta): class Elements(ParamsBase, metaclass=ElementsType): + pass + + +class BaseList(list, Elements): def __init__(self, *args): items = self.extract_args(*args) if not self.value_check(items): raise ValueError(f"Bad input for class {self.__class__}: {items}") - self.items = items + super().__init__(items) @classmethod def value_check(cls, value): @@ -284,39 +288,32 @@ class Elements(ParamsBase, metaclass=ElementsType): def __str__(self): cls = self.__class__ - return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})" + return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self)})" def __getitem__(self, i) -> SSZValue: - return self.items[i] + if i < 0: + raise IndexError(f"cannot get item in type {self.__class__} at negative index {i}") + if i > len(self): + raise IndexError(f"cannot get item in type {self.__class__}" + f" at out of bounds index {i}") + return super().__getitem__(i) def __setitem__(self, k, v): if k < 0: raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})") - if k > len(self.items): + if k > len(self): raise IndexError(f"cannot set item in type {self.__class__}" - f" at out of bounds index {k} (to {v}, bound: {len(self.items)})") - self.items[k] = coerce_type_maybe(v, self.__class__.elem_type, strict=True) + f" at out of bounds index {k} (to {v}, bound: {len(self)})") + super().__setitem__(k, coerce_type_maybe(v, self.__class__.elem_type, strict=True)) def append(self, v): - self.items.append(coerce_type_maybe(v, self.__class__.elem_type, strict=True)) - - def pop(self): - if len(self.items) == 0: - raise IndexError("Pop from empty list") - else: - return self.items.pop() - - def __len__(self): - return len(self.items) - - def __repr__(self): - return repr(self.items) + super().append(coerce_type_maybe(v, self.__class__.elem_type, strict=True)) def __iter__(self) -> Iterator[SSZValue]: - return iter(self.items) + return super().__iter__() -class List(Elements): +class List(BaseList): @classmethod def default(cls): @@ -327,7 +324,7 @@ class List(Elements): return False -class Vector(Elements): +class Vector(BaseList): @classmethod def value_check(cls, value): @@ -342,27 +339,35 @@ class Vector(Elements): def is_fixed_size(cls): return cls.elem_type.is_fixed_size() + def append(self, v): + raise Exception("cannot modify vector length") + + def pop(self, *args): + raise Exception("cannot modify vector length") + class BytesType(ElementsType): elem_type: SSZType = byte length: int -class BytesLike(Elements, metaclass=BytesType): +class BaseBytes(bytes, Elements, metaclass=BytesType): + + def __new__(cls, *args) -> "BaseBytes": + extracted_val = cls.extract_args(*args) + if not cls.value_check(extracted_val): + raise ValueError(f"Bad input for class {cls}: {extracted_val}") + return super().__new__(cls, extracted_val) @classmethod def extract_args(cls, *args): - x = list(args) - if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes, BytesLike)): + x = args + if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes)): x = x[0] - if isinstance(x, bytes): + if isinstance(x, bytes): # Includes BytesLike return x - elif isinstance(x, BytesLike): - return x.items - elif isinstance(x, GeneratorType): - return bytes(x) else: - return bytes(x) + return bytes(x) # E.g. GeneratorType put into bytes. @classmethod def value_check(cls, value): @@ -374,7 +379,7 @@ class BytesLike(Elements, metaclass=BytesType): return f"{cls.__name__}[{cls.length}]: {self.hex()}" -class Bytes(BytesLike): +class Bytes(BaseBytes): @classmethod def default(cls): @@ -385,7 +390,14 @@ class Bytes(BytesLike): return False -class BytesN(BytesLike): +class BytesN(BaseBytes): + + @classmethod + def extract_args(cls, *args): + if len(args) == 0: + return cls.default() + else: + return super().extract_args(*args) @classmethod def default(cls): diff --git a/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py index daa923aa7..895a074a9 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/test_ssz_typing.py @@ -211,3 +211,5 @@ def test_bytesn_subclass(): assert issubclass(Hash, Bytes32) assert not issubclass(Bytes48, Bytes32) + + assert len(Bytes32() + Bytes48()) == 80