diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index de54bbf05..9aafb5294 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -1,52 +1,56 @@ from types import GeneratorType -from typing import List, Iterable, TypeVar, Type, NewType -from typing import Union -from typing_inspect import get_origin + + +class DefaultingTypeMeta(type): + def default(cls): + raise Exception("Not implemented") # SSZ integers # ----------------------------- -class uint(int): + +class uint(int, metaclass=DefaultingTypeMeta): byte_len = 0 def __new__(cls, value, *args, **kwargs): if value < 0: raise ValueError("unsigned types must not be negative") - if value.byte_len and value.bit_length() > value.byte_len: - raise ValueError("value out of bounds for uint{}".format(value.byte_len)) + if cls.byte_len and (value.bit_length() >> 3) > cls.byte_len: + raise ValueError("value out of bounds for uint{}".format(cls.byte_len)) return super().__new__(cls, value) + @classmethod + def default(cls): + return cls(0) + class uint8(uint): byte_len = 1 + # Alias for uint8 byte = NewType('byte', uint8) + class uint16(uint): byte_len = 2 + class uint32(uint): byte_len = 4 + class uint64(uint): byte_len = 8 + class uint128(uint): byte_len = 16 + class uint256(uint): byte_len = 32 -def is_uint_type(typ): - # All integers are uint in the scope of the spec here. - # Since we default to uint64. Bounds can be checked elsewhere. - # However, some are wrapped in a NewType - if hasattr(typ, '__supertype__'): - # get the type that the NewType is wrapping - typ = typ.__supertype__ - - return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool) # SSZ Container base class # ----------------------------- @@ -59,7 +63,7 @@ class Container(object): cls = self.__class__ for f, t in cls.get_fields(): if f not in kwargs: - setattr(self, f, get_zero_value(t)) + setattr(self, f, t.default()) else: setattr(self, f, kwargs[f]) @@ -83,9 +87,9 @@ class Container(object): return repr({field: getattr(self, field) for field in self.get_field_names()}) def __str__(self): - output = [] + output = [f'{self.__class__.__name__}'] for field in self.get_field_names(): - output.append(f'{field}: {getattr(self, field)}') + output.append(f' {field}: {getattr(self, field)}') return "\n".join(output) def __eq__(self, other): @@ -114,67 +118,94 @@ class Container(object): # values of annotations are the types corresponding to the fields, not instance values. return list(cls.__annotations__.values()) + @classmethod + def default(cls): + return cls(**{f: t.default() for f, t in cls.get_fields()}) -def get_zero_value(typ): - if typ == int: - return 0 - elif is_container_type(typ): - return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()}) - else: - return typ.default() -def type_check(typ, value): - if typ == int or typ == uint64: - return isinstance(value, int) - else: - return typ.value_check(value) +class ParamsBase: + _bare = True + + def __new__(cls, *args, **kwargs): + if cls._bare: + raise Exception("cannot init bare type without params") + return super().__new__(cls, **kwargs) + + +class ParamsMeta(DefaultingTypeMeta): -class AbstractListMeta(type): def __new__(cls, class_name, parents, attrs): out = type.__new__(cls, class_name, parents, attrs) - if 'elem_type' in attrs and 'length' in attrs: - setattr(out, 'elem_type', attrs['elem_type']) - setattr(out, 'length', attrs['length']) + for k, v in attrs.items(): + setattr(out, k, v) return out def __getitem__(self, params): - if not isinstance(params, tuple) or len(params) != 2: - raise Exception("List must be instantiated with two args: elem type and length") - o = self.__class__(self.__name__, (self,), {'elem_type': params[0], 'length': params[1]}) - o._name = 'AbstractList' + o = self.__class__(self.__name__, (self,), self.attr_from_params(params)) + o._bare = False return o + def attr_from_params(self, p): + # single key params are valid too. Wrap them in a tuple. + params = p if isinstance(p, tuple) else (p,) + res = {} + i = 0 + for (name, typ) in self.__annotations__.items(): + param = params[i] + if hasattr(self.__class__, name): + res[name] = getattr(self.__class__, name) + else: + if not isinstance(param, typ): + raise TypeError( + "cannot create parametrized class with param {} as {} of type {}".format(param, name, typ)) + res[name] = param + i += 1 + if len(params) != i: + raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i)) + return res + def __instancecheck__(self, obj): if obj.__class__.__name__ != self.__name__: return False - if hasattr(self, 'elem_type') and obj.__class__.elem_type != self.elem_type: - return False - if hasattr(self, 'length') and obj.__class__.length != self.length: - return False + for name, typ in self.__annotations__: + if hasattr(self, name) and hasattr(obj.__class__, name) \ + and getattr(obj.__class__, name) != getattr(self, name): + return False return True + class ValueCheckError(Exception): pass -class AbstractList(metaclass=AbstractListMeta): + +class AbstractListMeta(ParamsMeta): + elem_type: DefaultingTypeMeta + length: int + + +class AbstractList(ParamsBase, metaclass=AbstractListMeta): + def __init__(self, *args): - items = self.extract_args(args) - + items = self.extract_args(*args) + if not self.value_check(items): raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items)) self.items = items - - def value_check(self, value): - for v in value: - if not type_check(self.__class__.elem_type, v): - return False - return True - def extract_args(self, args): - return list(args) if len(args) > 0 else self.default() + @classmethod + def value_check(cls, value): + return all(isinstance(v, cls.elem_type) for v in value) - def default(self): - raise Exception("Not implemented") + @classmethod + def extract_args(cls, *args): + x = list(args) + if len(x) == 1 and isinstance(x[0], GeneratorType): + x = list(x[0]) + return x if len(x) > 0 else cls.default() + + def __str__(self): + cls = self.__class__ + return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})" def __getitem__(self, i): return self.items[i] @@ -194,137 +225,65 @@ class AbstractList(metaclass=AbstractListMeta): def __eq__(self, other): return self.items == other.items -class List(AbstractList, metaclass=AbstractListMeta): + +class List(AbstractList): def value_check(self, value): return len(value) <= self.__class__.length and super().value_check(value) - def default(self): - return [] + @classmethod + def default(cls): + return cls() + class Vector(AbstractList, metaclass=AbstractListMeta): def value_check(self, value): return len(value) == self.__class__.length and super().value_check(value) - def default(self): - return [get_zero_value(self.__class__.elem_type) for _ in range(self.__class__.length)] + @classmethod + def default(cls): + return [cls.elem_type.default() for _ in range(cls.length)] + class BytesMeta(AbstractListMeta): - def __getitem__(self, params): - if not isinstance(params, int): - raise Exception("Bytes must be instantiated with one arg: length") - o = self.__class__(self.__name__, (self,), {'length': params}) - o._name = 'Bytes' - return o + elem_type: DefaultingTypeMeta = byte + length: int -def single_item_extractor(cls, args): - assert len(args) < 2 - return args[0] if len(args) > 0 else cls.default() -class Bytes(AbstractList, metaclass=BytesMeta): +class BytesLike(AbstractList, metaclass=BytesMeta): + + @classmethod + def extract_args(cls, args): + if isinstance(args, bytes): + return args + elif isinstance(args, BytesLike): + return args.items + elif isinstance(args, GeneratorType): + return bytes(args) + else: + return bytes(args) + + @classmethod + def value_check(cls, value): + return len(value) == cls.length and isinstance(value, bytes) + + def __str__(self): + cls = self.__class__ + return f"{cls.__name__}[{cls.length}]: {self.items.hex()}" + + +class Bytes(BytesLike): + def value_check(self, value): return len(value) <= self.__class__.length and isinstance(value, bytes) - extract_args = single_item_extractor - - def default(self): + @classmethod + def default(cls): return b'' -class BytesN(AbstractList, metaclass=BytesMeta): - def value_check(self, value): - return len(value) == self.__class__.length and isinstance(value, bytes) - extract_args = single_item_extractor +class BytesN(BytesLike): - def default(self): - return b'\x00' * self.__class__.length + @classmethod + def default(cls): + return b'\x00' * cls.length - -# Type helpers -# ----------------------------- - -def infer_type(obj): - if is_uint_type(obj.__class__): - return obj.__class__ - elif isinstance(obj, int): - return uint64 - elif isinstance(obj, (List, Vector, Container, bool, BytesN, Bytes)): - return obj.__class__ - else: - raise Exception("Unknown type for {}".format(obj)) - - -def infer_input_type(fn): - """ - Decorator to run infer_type on the obj if typ argument is None - """ - def infer_helper(obj, typ=None, **kwargs): - if typ is None: - typ = infer_type(obj) - return fn(obj, typ=typ, **kwargs) - return infer_helper - - -def is_bool_type(typ): - """ - Check if the given type is a bool. - """ - if hasattr(typ, '__supertype__'): - typ = typ.__supertype__ - return isinstance(typ, type) and issubclass(typ, bool) - - -def is_list_type(typ): - """ - Check if the given type is a list. - """ - return isinstance(typ, type) and issubclass(typ, List) - - -def is_bytes_type(typ): - """ - Check if the given type is a ``bytes``. - """ - return isinstance(typ, type) and issubclass(typ, Bytes) - - -def is_bytesn_type(typ): - """ - Check if the given type is a BytesN. - """ - return isinstance(typ, type) and issubclass(typ, BytesN) - - -def is_list_kind(typ): - """ - Check if the given type is a kind of list. Can be bytes. - """ - return is_list_type(typ) or is_bytes_type(typ) - - -def is_vector_type(typ): - """ - Check if the given type is a vector. - """ - return isinstance(typ, type) and issubclass(typ, Vector) - - -def is_vector_kind(typ): - """ - Check if the given type is a kind of vector. Can be BytesN. - """ - return is_vector_type(typ) or is_bytesn_type(typ) - - -def is_container_type(typ): - """ - Check if the given type is a container. - """ - return isinstance(typ, type) and issubclass(typ, Container) - - -T = TypeVar('T') -L = TypeVar('L') - - -def read_elem_type(typ): - return typ.elem_type