typing improvements
This commit is contained in:
parent
8919f628cb
commit
d1ecfd510e
|
@ -1,52 +1,56 @@
|
|||
from types import GeneratorType
|
||||
from typing import List, Iterable, TypeVar, Type, NewType
|
||||
from typing import Union
|
||||
from typing_inspect import get_origin
|
||||
|
||||
|
||||
class DefaultingTypeMeta(type):
|
||||
def default(cls):
|
||||
raise Exception("Not implemented")
|
||||
|
||||
# SSZ integers
|
||||
# -----------------------------
|
||||
|
||||
class uint(int):
|
||||
|
||||
class uint(int, metaclass=DefaultingTypeMeta):
|
||||
byte_len = 0
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value < 0:
|
||||
raise ValueError("unsigned types must not be negative")
|
||||
if value.byte_len and value.bit_length() > value.byte_len:
|
||||
raise ValueError("value out of bounds for uint{}".format(value.byte_len))
|
||||
if cls.byte_len and (value.bit_length() >> 3) > cls.byte_len:
|
||||
raise ValueError("value out of bounds for uint{}".format(cls.byte_len))
|
||||
return super().__new__(cls, value)
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls(0)
|
||||
|
||||
|
||||
class uint8(uint):
|
||||
byte_len = 1
|
||||
|
||||
|
||||
# Alias for uint8
|
||||
byte = NewType('byte', uint8)
|
||||
|
||||
|
||||
class uint16(uint):
|
||||
byte_len = 2
|
||||
|
||||
|
||||
class uint32(uint):
|
||||
byte_len = 4
|
||||
|
||||
|
||||
class uint64(uint):
|
||||
byte_len = 8
|
||||
|
||||
|
||||
class uint128(uint):
|
||||
byte_len = 16
|
||||
|
||||
|
||||
class uint256(uint):
|
||||
byte_len = 32
|
||||
|
||||
def is_uint_type(typ):
|
||||
# All integers are uint in the scope of the spec here.
|
||||
# Since we default to uint64. Bounds can be checked elsewhere.
|
||||
# However, some are wrapped in a NewType
|
||||
if hasattr(typ, '__supertype__'):
|
||||
# get the type that the NewType is wrapping
|
||||
typ = typ.__supertype__
|
||||
|
||||
return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool)
|
||||
|
||||
# SSZ Container base class
|
||||
# -----------------------------
|
||||
|
@ -59,7 +63,7 @@ class Container(object):
|
|||
cls = self.__class__
|
||||
for f, t in cls.get_fields():
|
||||
if f not in kwargs:
|
||||
setattr(self, f, get_zero_value(t))
|
||||
setattr(self, f, t.default())
|
||||
else:
|
||||
setattr(self, f, kwargs[f])
|
||||
|
||||
|
@ -83,9 +87,9 @@ class Container(object):
|
|||
return repr({field: getattr(self, field) for field in self.get_field_names()})
|
||||
|
||||
def __str__(self):
|
||||
output = []
|
||||
output = [f'{self.__class__.__name__}']
|
||||
for field in self.get_field_names():
|
||||
output.append(f'{field}: {getattr(self, field)}')
|
||||
output.append(f' {field}: {getattr(self, field)}')
|
||||
return "\n".join(output)
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@ -114,67 +118,94 @@ class Container(object):
|
|||
# values of annotations are the types corresponding to the fields, not instance values.
|
||||
return list(cls.__annotations__.values())
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls(**{f: t.default() for f, t in cls.get_fields()})
|
||||
|
||||
def get_zero_value(typ):
|
||||
if typ == int:
|
||||
return 0
|
||||
elif is_container_type(typ):
|
||||
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()})
|
||||
else:
|
||||
return typ.default()
|
||||
|
||||
def type_check(typ, value):
|
||||
if typ == int or typ == uint64:
|
||||
return isinstance(value, int)
|
||||
else:
|
||||
return typ.value_check(value)
|
||||
class ParamsBase:
|
||||
_bare = True
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._bare:
|
||||
raise Exception("cannot init bare type without params")
|
||||
return super().__new__(cls, **kwargs)
|
||||
|
||||
|
||||
class ParamsMeta(DefaultingTypeMeta):
|
||||
|
||||
class AbstractListMeta(type):
|
||||
def __new__(cls, class_name, parents, attrs):
|
||||
out = type.__new__(cls, class_name, parents, attrs)
|
||||
if 'elem_type' in attrs and 'length' in attrs:
|
||||
setattr(out, 'elem_type', attrs['elem_type'])
|
||||
setattr(out, 'length', attrs['length'])
|
||||
for k, v in attrs.items():
|
||||
setattr(out, k, v)
|
||||
return out
|
||||
|
||||
def __getitem__(self, params):
|
||||
if not isinstance(params, tuple) or len(params) != 2:
|
||||
raise Exception("List must be instantiated with two args: elem type and length")
|
||||
o = self.__class__(self.__name__, (self,), {'elem_type': params[0], 'length': params[1]})
|
||||
o._name = 'AbstractList'
|
||||
o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
|
||||
o._bare = False
|
||||
return o
|
||||
|
||||
def attr_from_params(self, p):
|
||||
# single key params are valid too. Wrap them in a tuple.
|
||||
params = p if isinstance(p, tuple) else (p,)
|
||||
res = {}
|
||||
i = 0
|
||||
for (name, typ) in self.__annotations__.items():
|
||||
param = params[i]
|
||||
if hasattr(self.__class__, name):
|
||||
res[name] = getattr(self.__class__, name)
|
||||
else:
|
||||
if not isinstance(param, typ):
|
||||
raise TypeError(
|
||||
"cannot create parametrized class with param {} as {} of type {}".format(param, name, typ))
|
||||
res[name] = param
|
||||
i += 1
|
||||
if len(params) != i:
|
||||
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
|
||||
return res
|
||||
|
||||
def __instancecheck__(self, obj):
|
||||
if obj.__class__.__name__ != self.__name__:
|
||||
return False
|
||||
if hasattr(self, 'elem_type') and obj.__class__.elem_type != self.elem_type:
|
||||
return False
|
||||
if hasattr(self, 'length') and obj.__class__.length != self.length:
|
||||
return False
|
||||
for name, typ in self.__annotations__:
|
||||
if hasattr(self, name) and hasattr(obj.__class__, name) \
|
||||
and getattr(obj.__class__, name) != getattr(self, name):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ValueCheckError(Exception):
|
||||
pass
|
||||
|
||||
class AbstractList(metaclass=AbstractListMeta):
|
||||
|
||||
class AbstractListMeta(ParamsMeta):
|
||||
elem_type: DefaultingTypeMeta
|
||||
length: int
|
||||
|
||||
|
||||
class AbstractList(ParamsBase, metaclass=AbstractListMeta):
|
||||
|
||||
def __init__(self, *args):
|
||||
items = self.extract_args(args)
|
||||
|
||||
items = self.extract_args(*args)
|
||||
|
||||
if not self.value_check(items):
|
||||
raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items))
|
||||
self.items = items
|
||||
|
||||
def value_check(self, value):
|
||||
for v in value:
|
||||
if not type_check(self.__class__.elem_type, v):
|
||||
return False
|
||||
return True
|
||||
|
||||
def extract_args(self, args):
|
||||
return list(args) if len(args) > 0 else self.default()
|
||||
@classmethod
|
||||
def value_check(cls, value):
|
||||
return all(isinstance(v, cls.elem_type) for v in value)
|
||||
|
||||
def default(self):
|
||||
raise Exception("Not implemented")
|
||||
@classmethod
|
||||
def extract_args(cls, *args):
|
||||
x = list(args)
|
||||
if len(x) == 1 and isinstance(x[0], GeneratorType):
|
||||
x = list(x[0])
|
||||
return x if len(x) > 0 else cls.default()
|
||||
|
||||
def __str__(self):
|
||||
cls = self.__class__
|
||||
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})"
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.items[i]
|
||||
|
@ -194,137 +225,65 @@ class AbstractList(metaclass=AbstractListMeta):
|
|||
def __eq__(self, other):
|
||||
return self.items == other.items
|
||||
|
||||
class List(AbstractList, metaclass=AbstractListMeta):
|
||||
|
||||
class List(AbstractList):
|
||||
def value_check(self, value):
|
||||
return len(value) <= self.__class__.length and super().value_check(value)
|
||||
|
||||
def default(self):
|
||||
return []
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls()
|
||||
|
||||
|
||||
class Vector(AbstractList, metaclass=AbstractListMeta):
|
||||
def value_check(self, value):
|
||||
return len(value) == self.__class__.length and super().value_check(value)
|
||||
|
||||
def default(self):
|
||||
return [get_zero_value(self.__class__.elem_type) for _ in range(self.__class__.length)]
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return [cls.elem_type.default() for _ in range(cls.length)]
|
||||
|
||||
|
||||
class BytesMeta(AbstractListMeta):
|
||||
def __getitem__(self, params):
|
||||
if not isinstance(params, int):
|
||||
raise Exception("Bytes must be instantiated with one arg: length")
|
||||
o = self.__class__(self.__name__, (self,), {'length': params})
|
||||
o._name = 'Bytes'
|
||||
return o
|
||||
elem_type: DefaultingTypeMeta = byte
|
||||
length: int
|
||||
|
||||
def single_item_extractor(cls, args):
|
||||
assert len(args) < 2
|
||||
return args[0] if len(args) > 0 else cls.default()
|
||||
|
||||
class Bytes(AbstractList, metaclass=BytesMeta):
|
||||
class BytesLike(AbstractList, metaclass=BytesMeta):
|
||||
|
||||
@classmethod
|
||||
def extract_args(cls, args):
|
||||
if isinstance(args, bytes):
|
||||
return args
|
||||
elif isinstance(args, BytesLike):
|
||||
return args.items
|
||||
elif isinstance(args, GeneratorType):
|
||||
return bytes(args)
|
||||
else:
|
||||
return bytes(args)
|
||||
|
||||
@classmethod
|
||||
def value_check(cls, value):
|
||||
return len(value) == cls.length and isinstance(value, bytes)
|
||||
|
||||
def __str__(self):
|
||||
cls = self.__class__
|
||||
return f"{cls.__name__}[{cls.length}]: {self.items.hex()}"
|
||||
|
||||
|
||||
class Bytes(BytesLike):
|
||||
|
||||
def value_check(self, value):
|
||||
return len(value) <= self.__class__.length and isinstance(value, bytes)
|
||||
|
||||
extract_args = single_item_extractor
|
||||
|
||||
def default(self):
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return b''
|
||||
|
||||
class BytesN(AbstractList, metaclass=BytesMeta):
|
||||
def value_check(self, value):
|
||||
return len(value) == self.__class__.length and isinstance(value, bytes)
|
||||
|
||||
extract_args = single_item_extractor
|
||||
class BytesN(BytesLike):
|
||||
|
||||
def default(self):
|
||||
return b'\x00' * self.__class__.length
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return b'\x00' * cls.length
|
||||
|
||||
|
||||
# Type helpers
|
||||
# -----------------------------
|
||||
|
||||
def infer_type(obj):
|
||||
if is_uint_type(obj.__class__):
|
||||
return obj.__class__
|
||||
elif isinstance(obj, int):
|
||||
return uint64
|
||||
elif isinstance(obj, (List, Vector, Container, bool, BytesN, Bytes)):
|
||||
return obj.__class__
|
||||
else:
|
||||
raise Exception("Unknown type for {}".format(obj))
|
||||
|
||||
|
||||
def infer_input_type(fn):
|
||||
"""
|
||||
Decorator to run infer_type on the obj if typ argument is None
|
||||
"""
|
||||
def infer_helper(obj, typ=None, **kwargs):
|
||||
if typ is None:
|
||||
typ = infer_type(obj)
|
||||
return fn(obj, typ=typ, **kwargs)
|
||||
return infer_helper
|
||||
|
||||
|
||||
def is_bool_type(typ):
|
||||
"""
|
||||
Check if the given type is a bool.
|
||||
"""
|
||||
if hasattr(typ, '__supertype__'):
|
||||
typ = typ.__supertype__
|
||||
return isinstance(typ, type) and issubclass(typ, bool)
|
||||
|
||||
|
||||
def is_list_type(typ):
|
||||
"""
|
||||
Check if the given type is a list.
|
||||
"""
|
||||
return isinstance(typ, type) and issubclass(typ, List)
|
||||
|
||||
|
||||
def is_bytes_type(typ):
|
||||
"""
|
||||
Check if the given type is a ``bytes``.
|
||||
"""
|
||||
return isinstance(typ, type) and issubclass(typ, Bytes)
|
||||
|
||||
|
||||
def is_bytesn_type(typ):
|
||||
"""
|
||||
Check if the given type is a BytesN.
|
||||
"""
|
||||
return isinstance(typ, type) and issubclass(typ, BytesN)
|
||||
|
||||
|
||||
def is_list_kind(typ):
|
||||
"""
|
||||
Check if the given type is a kind of list. Can be bytes.
|
||||
"""
|
||||
return is_list_type(typ) or is_bytes_type(typ)
|
||||
|
||||
|
||||
def is_vector_type(typ):
|
||||
"""
|
||||
Check if the given type is a vector.
|
||||
"""
|
||||
return isinstance(typ, type) and issubclass(typ, Vector)
|
||||
|
||||
|
||||
def is_vector_kind(typ):
|
||||
"""
|
||||
Check if the given type is a kind of vector. Can be BytesN.
|
||||
"""
|
||||
return is_vector_type(typ) or is_bytesn_type(typ)
|
||||
|
||||
|
||||
def is_container_type(typ):
|
||||
"""
|
||||
Check if the given type is a container.
|
||||
"""
|
||||
return isinstance(typ, type) and issubclass(typ, Container)
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
L = TypeVar('L')
|
||||
|
||||
|
||||
def read_elem_type(typ):
|
||||
return typ.elem_type
|
||||
|
|
Loading…
Reference in New Issue