typing improvements

This commit is contained in:
protolambda 2019-06-20 19:53:32 +02:00
parent 8919f628cb
commit d1ecfd510e
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
1 changed files with 128 additions and 169 deletions

View File

@ -1,52 +1,56 @@
from types import GeneratorType from types import GeneratorType
from typing import List, Iterable, TypeVar, Type, NewType
from typing import Union
from typing_inspect import get_origin class DefaultingTypeMeta(type):
def default(cls):
raise Exception("Not implemented")
# SSZ integers # SSZ integers
# ----------------------------- # -----------------------------
class uint(int):
class uint(int, metaclass=DefaultingTypeMeta):
byte_len = 0 byte_len = 0
def __new__(cls, value, *args, **kwargs): def __new__(cls, value, *args, **kwargs):
if value < 0: if value < 0:
raise ValueError("unsigned types must not be negative") raise ValueError("unsigned types must not be negative")
if value.byte_len and value.bit_length() > value.byte_len: if cls.byte_len and (value.bit_length() >> 3) > cls.byte_len:
raise ValueError("value out of bounds for uint{}".format(value.byte_len)) raise ValueError("value out of bounds for uint{}".format(cls.byte_len))
return super().__new__(cls, value) return super().__new__(cls, value)
@classmethod
def default(cls):
return cls(0)
class uint8(uint): class uint8(uint):
byte_len = 1 byte_len = 1
# Alias for uint8 # Alias for uint8
byte = NewType('byte', uint8) byte = NewType('byte', uint8)
class uint16(uint): class uint16(uint):
byte_len = 2 byte_len = 2
class uint32(uint): class uint32(uint):
byte_len = 4 byte_len = 4
class uint64(uint): class uint64(uint):
byte_len = 8 byte_len = 8
class uint128(uint): class uint128(uint):
byte_len = 16 byte_len = 16
class uint256(uint): class uint256(uint):
byte_len = 32 byte_len = 32
def is_uint_type(typ):
# All integers are uint in the scope of the spec here.
# Since we default to uint64. Bounds can be checked elsewhere.
# However, some are wrapped in a NewType
if hasattr(typ, '__supertype__'):
# get the type that the NewType is wrapping
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool)
# SSZ Container base class # SSZ Container base class
# ----------------------------- # -----------------------------
@ -59,7 +63,7 @@ class Container(object):
cls = self.__class__ cls = self.__class__
for f, t in cls.get_fields(): for f, t in cls.get_fields():
if f not in kwargs: if f not in kwargs:
setattr(self, f, get_zero_value(t)) setattr(self, f, t.default())
else: else:
setattr(self, f, kwargs[f]) setattr(self, f, kwargs[f])
@ -83,9 +87,9 @@ class Container(object):
return repr({field: getattr(self, field) for field in self.get_field_names()}) return repr({field: getattr(self, field) for field in self.get_field_names()})
def __str__(self): def __str__(self):
output = [] output = [f'{self.__class__.__name__}']
for field in self.get_field_names(): for field in self.get_field_names():
output.append(f'{field}: {getattr(self, field)}') output.append(f' {field}: {getattr(self, field)}')
return "\n".join(output) return "\n".join(output)
def __eq__(self, other): def __eq__(self, other):
@ -114,67 +118,94 @@ class Container(object):
# values of annotations are the types corresponding to the fields, not instance values. # values of annotations are the types corresponding to the fields, not instance values.
return list(cls.__annotations__.values()) return list(cls.__annotations__.values())
@classmethod
def default(cls):
return cls(**{f: t.default() for f, t in cls.get_fields()})
def get_zero_value(typ):
if typ == int:
return 0
elif is_container_type(typ):
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()})
else:
return typ.default()
def type_check(typ, value): class ParamsBase:
if typ == int or typ == uint64: _bare = True
return isinstance(value, int)
else: def __new__(cls, *args, **kwargs):
return typ.value_check(value) if cls._bare:
raise Exception("cannot init bare type without params")
return super().__new__(cls, **kwargs)
class ParamsMeta(DefaultingTypeMeta):
class AbstractListMeta(type):
def __new__(cls, class_name, parents, attrs): def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs) out = type.__new__(cls, class_name, parents, attrs)
if 'elem_type' in attrs and 'length' in attrs: for k, v in attrs.items():
setattr(out, 'elem_type', attrs['elem_type']) setattr(out, k, v)
setattr(out, 'length', attrs['length'])
return out return out
def __getitem__(self, params): def __getitem__(self, params):
if not isinstance(params, tuple) or len(params) != 2: o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
raise Exception("List must be instantiated with two args: elem type and length") o._bare = False
o = self.__class__(self.__name__, (self,), {'elem_type': params[0], 'length': params[1]})
o._name = 'AbstractList'
return o return o
def attr_from_params(self, p):
# single key params are valid too. Wrap them in a tuple.
params = p if isinstance(p, tuple) else (p,)
res = {}
i = 0
for (name, typ) in self.__annotations__.items():
param = params[i]
if hasattr(self.__class__, name):
res[name] = getattr(self.__class__, name)
else:
if not isinstance(param, typ):
raise TypeError(
"cannot create parametrized class with param {} as {} of type {}".format(param, name, typ))
res[name] = param
i += 1
if len(params) != i:
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
return res
def __instancecheck__(self, obj): def __instancecheck__(self, obj):
if obj.__class__.__name__ != self.__name__: if obj.__class__.__name__ != self.__name__:
return False return False
if hasattr(self, 'elem_type') and obj.__class__.elem_type != self.elem_type: for name, typ in self.__annotations__:
return False if hasattr(self, name) and hasattr(obj.__class__, name) \
if hasattr(self, 'length') and obj.__class__.length != self.length: and getattr(obj.__class__, name) != getattr(self, name):
return False return False
return True return True
class ValueCheckError(Exception): class ValueCheckError(Exception):
pass pass
class AbstractList(metaclass=AbstractListMeta):
class AbstractListMeta(ParamsMeta):
elem_type: DefaultingTypeMeta
length: int
class AbstractList(ParamsBase, metaclass=AbstractListMeta):
def __init__(self, *args): def __init__(self, *args):
items = self.extract_args(args) items = self.extract_args(*args)
if not self.value_check(items): if not self.value_check(items):
raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items)) raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items))
self.items = items self.items = items
def value_check(self, value): @classmethod
for v in value: def value_check(cls, value):
if not type_check(self.__class__.elem_type, v): return all(isinstance(v, cls.elem_type) for v in value)
return False
return True
def extract_args(self, args): @classmethod
return list(args) if len(args) > 0 else self.default() def extract_args(cls, *args):
x = list(args)
if len(x) == 1 and isinstance(x[0], GeneratorType):
x = list(x[0])
return x if len(x) > 0 else cls.default()
def default(self): def __str__(self):
raise Exception("Not implemented") cls = self.__class__
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})"
def __getitem__(self, i): def __getitem__(self, i):
return self.items[i] return self.items[i]
@ -194,137 +225,65 @@ class AbstractList(metaclass=AbstractListMeta):
def __eq__(self, other): def __eq__(self, other):
return self.items == other.items return self.items == other.items
class List(AbstractList, metaclass=AbstractListMeta):
class List(AbstractList):
def value_check(self, value): def value_check(self, value):
return len(value) <= self.__class__.length and super().value_check(value) return len(value) <= self.__class__.length and super().value_check(value)
def default(self): @classmethod
return [] def default(cls):
return cls()
class Vector(AbstractList, metaclass=AbstractListMeta): class Vector(AbstractList, metaclass=AbstractListMeta):
def value_check(self, value): def value_check(self, value):
return len(value) == self.__class__.length and super().value_check(value) return len(value) == self.__class__.length and super().value_check(value)
def default(self): @classmethod
return [get_zero_value(self.__class__.elem_type) for _ in range(self.__class__.length)] def default(cls):
return [cls.elem_type.default() for _ in range(cls.length)]
class BytesMeta(AbstractListMeta): class BytesMeta(AbstractListMeta):
def __getitem__(self, params): elem_type: DefaultingTypeMeta = byte
if not isinstance(params, int): length: int
raise Exception("Bytes must be instantiated with one arg: length")
o = self.__class__(self.__name__, (self,), {'length': params})
o._name = 'Bytes'
return o
def single_item_extractor(cls, args):
assert len(args) < 2
return args[0] if len(args) > 0 else cls.default()
class Bytes(AbstractList, metaclass=BytesMeta): class BytesLike(AbstractList, metaclass=BytesMeta):
@classmethod
def extract_args(cls, args):
if isinstance(args, bytes):
return args
elif isinstance(args, BytesLike):
return args.items
elif isinstance(args, GeneratorType):
return bytes(args)
else:
return bytes(args)
@classmethod
def value_check(cls, value):
return len(value) == cls.length and isinstance(value, bytes)
def __str__(self):
cls = self.__class__
return f"{cls.__name__}[{cls.length}]: {self.items.hex()}"
class Bytes(BytesLike):
def value_check(self, value): def value_check(self, value):
return len(value) <= self.__class__.length and isinstance(value, bytes) return len(value) <= self.__class__.length and isinstance(value, bytes)
extract_args = single_item_extractor @classmethod
def default(cls):
def default(self):
return b'' return b''
class BytesN(AbstractList, metaclass=BytesMeta):
def value_check(self, value):
return len(value) == self.__class__.length and isinstance(value, bytes)
extract_args = single_item_extractor class BytesN(BytesLike):
def default(self): @classmethod
return b'\x00' * self.__class__.length def default(cls):
return b'\x00' * cls.length
# Type helpers
# -----------------------------
def infer_type(obj):
if is_uint_type(obj.__class__):
return obj.__class__
elif isinstance(obj, int):
return uint64
elif isinstance(obj, (List, Vector, Container, bool, BytesN, Bytes)):
return obj.__class__
else:
raise Exception("Unknown type for {}".format(obj))
def infer_input_type(fn):
"""
Decorator to run infer_type on the obj if typ argument is None
"""
def infer_helper(obj, typ=None, **kwargs):
if typ is None:
typ = infer_type(obj)
return fn(obj, typ=typ, **kwargs)
return infer_helper
def is_bool_type(typ):
"""
Check if the given type is a bool.
"""
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, bool)
def is_list_type(typ):
"""
Check if the given type is a list.
"""
return isinstance(typ, type) and issubclass(typ, List)
def is_bytes_type(typ):
"""
Check if the given type is a ``bytes``.
"""
return isinstance(typ, type) and issubclass(typ, Bytes)
def is_bytesn_type(typ):
"""
Check if the given type is a BytesN.
"""
return isinstance(typ, type) and issubclass(typ, BytesN)
def is_list_kind(typ):
"""
Check if the given type is a kind of list. Can be bytes.
"""
return is_list_type(typ) or is_bytes_type(typ)
def is_vector_type(typ):
"""
Check if the given type is a vector.
"""
return isinstance(typ, type) and issubclass(typ, Vector)
def is_vector_kind(typ):
"""
Check if the given type is a kind of vector. Can be BytesN.
"""
return is_vector_type(typ) or is_bytesn_type(typ)
def is_container_type(typ):
"""
Check if the given type is a container.
"""
return isinstance(typ, type) and issubclass(typ, Container)
T = TypeVar('T')
L = TypeVar('L')
def read_elem_type(typ):
return typ.elem_type