typing improvements
This commit is contained in:
parent
8919f628cb
commit
d1ecfd510e
|
@ -1,52 +1,56 @@
|
||||||
from types import GeneratorType
|
from types import GeneratorType
|
||||||
from typing import List, Iterable, TypeVar, Type, NewType
|
|
||||||
from typing import Union
|
|
||||||
from typing_inspect import get_origin
|
class DefaultingTypeMeta(type):
|
||||||
|
def default(cls):
|
||||||
|
raise Exception("Not implemented")
|
||||||
|
|
||||||
# SSZ integers
|
# SSZ integers
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
class uint(int):
|
|
||||||
|
class uint(int, metaclass=DefaultingTypeMeta):
|
||||||
byte_len = 0
|
byte_len = 0
|
||||||
|
|
||||||
def __new__(cls, value, *args, **kwargs):
|
def __new__(cls, value, *args, **kwargs):
|
||||||
if value < 0:
|
if value < 0:
|
||||||
raise ValueError("unsigned types must not be negative")
|
raise ValueError("unsigned types must not be negative")
|
||||||
if value.byte_len and value.bit_length() > value.byte_len:
|
if cls.byte_len and (value.bit_length() >> 3) > cls.byte_len:
|
||||||
raise ValueError("value out of bounds for uint{}".format(value.byte_len))
|
raise ValueError("value out of bounds for uint{}".format(cls.byte_len))
|
||||||
return super().__new__(cls, value)
|
return super().__new__(cls, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls):
|
||||||
|
return cls(0)
|
||||||
|
|
||||||
|
|
||||||
class uint8(uint):
|
class uint8(uint):
|
||||||
byte_len = 1
|
byte_len = 1
|
||||||
|
|
||||||
|
|
||||||
# Alias for uint8
|
# Alias for uint8
|
||||||
byte = NewType('byte', uint8)
|
byte = NewType('byte', uint8)
|
||||||
|
|
||||||
|
|
||||||
class uint16(uint):
|
class uint16(uint):
|
||||||
byte_len = 2
|
byte_len = 2
|
||||||
|
|
||||||
|
|
||||||
class uint32(uint):
|
class uint32(uint):
|
||||||
byte_len = 4
|
byte_len = 4
|
||||||
|
|
||||||
|
|
||||||
class uint64(uint):
|
class uint64(uint):
|
||||||
byte_len = 8
|
byte_len = 8
|
||||||
|
|
||||||
|
|
||||||
class uint128(uint):
|
class uint128(uint):
|
||||||
byte_len = 16
|
byte_len = 16
|
||||||
|
|
||||||
|
|
||||||
class uint256(uint):
|
class uint256(uint):
|
||||||
byte_len = 32
|
byte_len = 32
|
||||||
|
|
||||||
def is_uint_type(typ):
|
|
||||||
# All integers are uint in the scope of the spec here.
|
|
||||||
# Since we default to uint64. Bounds can be checked elsewhere.
|
|
||||||
# However, some are wrapped in a NewType
|
|
||||||
if hasattr(typ, '__supertype__'):
|
|
||||||
# get the type that the NewType is wrapping
|
|
||||||
typ = typ.__supertype__
|
|
||||||
|
|
||||||
return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool)
|
|
||||||
|
|
||||||
# SSZ Container base class
|
# SSZ Container base class
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
@ -59,7 +63,7 @@ class Container(object):
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
for f, t in cls.get_fields():
|
for f, t in cls.get_fields():
|
||||||
if f not in kwargs:
|
if f not in kwargs:
|
||||||
setattr(self, f, get_zero_value(t))
|
setattr(self, f, t.default())
|
||||||
else:
|
else:
|
||||||
setattr(self, f, kwargs[f])
|
setattr(self, f, kwargs[f])
|
||||||
|
|
||||||
|
@ -83,9 +87,9 @@ class Container(object):
|
||||||
return repr({field: getattr(self, field) for field in self.get_field_names()})
|
return repr({field: getattr(self, field) for field in self.get_field_names()})
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
output = []
|
output = [f'{self.__class__.__name__}']
|
||||||
for field in self.get_field_names():
|
for field in self.get_field_names():
|
||||||
output.append(f'{field}: {getattr(self, field)}')
|
output.append(f' {field}: {getattr(self, field)}')
|
||||||
return "\n".join(output)
|
return "\n".join(output)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
|
@ -114,67 +118,94 @@ class Container(object):
|
||||||
# values of annotations are the types corresponding to the fields, not instance values.
|
# values of annotations are the types corresponding to the fields, not instance values.
|
||||||
return list(cls.__annotations__.values())
|
return list(cls.__annotations__.values())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls):
|
||||||
|
return cls(**{f: t.default() for f, t in cls.get_fields()})
|
||||||
|
|
||||||
def get_zero_value(typ):
|
|
||||||
if typ == int:
|
|
||||||
return 0
|
|
||||||
elif is_container_type(typ):
|
|
||||||
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()})
|
|
||||||
else:
|
|
||||||
return typ.default()
|
|
||||||
|
|
||||||
def type_check(typ, value):
|
class ParamsBase:
|
||||||
if typ == int or typ == uint64:
|
_bare = True
|
||||||
return isinstance(value, int)
|
|
||||||
else:
|
def __new__(cls, *args, **kwargs):
|
||||||
return typ.value_check(value)
|
if cls._bare:
|
||||||
|
raise Exception("cannot init bare type without params")
|
||||||
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ParamsMeta(DefaultingTypeMeta):
|
||||||
|
|
||||||
class AbstractListMeta(type):
|
|
||||||
def __new__(cls, class_name, parents, attrs):
|
def __new__(cls, class_name, parents, attrs):
|
||||||
out = type.__new__(cls, class_name, parents, attrs)
|
out = type.__new__(cls, class_name, parents, attrs)
|
||||||
if 'elem_type' in attrs and 'length' in attrs:
|
for k, v in attrs.items():
|
||||||
setattr(out, 'elem_type', attrs['elem_type'])
|
setattr(out, k, v)
|
||||||
setattr(out, 'length', attrs['length'])
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __getitem__(self, params):
|
def __getitem__(self, params):
|
||||||
if not isinstance(params, tuple) or len(params) != 2:
|
o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
|
||||||
raise Exception("List must be instantiated with two args: elem type and length")
|
o._bare = False
|
||||||
o = self.__class__(self.__name__, (self,), {'elem_type': params[0], 'length': params[1]})
|
|
||||||
o._name = 'AbstractList'
|
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def attr_from_params(self, p):
|
||||||
|
# single key params are valid too. Wrap them in a tuple.
|
||||||
|
params = p if isinstance(p, tuple) else (p,)
|
||||||
|
res = {}
|
||||||
|
i = 0
|
||||||
|
for (name, typ) in self.__annotations__.items():
|
||||||
|
param = params[i]
|
||||||
|
if hasattr(self.__class__, name):
|
||||||
|
res[name] = getattr(self.__class__, name)
|
||||||
|
else:
|
||||||
|
if not isinstance(param, typ):
|
||||||
|
raise TypeError(
|
||||||
|
"cannot create parametrized class with param {} as {} of type {}".format(param, name, typ))
|
||||||
|
res[name] = param
|
||||||
|
i += 1
|
||||||
|
if len(params) != i:
|
||||||
|
raise TypeError("provided parameters {} mismatch required parameter count {}".format(params, i))
|
||||||
|
return res
|
||||||
|
|
||||||
def __instancecheck__(self, obj):
|
def __instancecheck__(self, obj):
|
||||||
if obj.__class__.__name__ != self.__name__:
|
if obj.__class__.__name__ != self.__name__:
|
||||||
return False
|
return False
|
||||||
if hasattr(self, 'elem_type') and obj.__class__.elem_type != self.elem_type:
|
for name, typ in self.__annotations__:
|
||||||
return False
|
if hasattr(self, name) and hasattr(obj.__class__, name) \
|
||||||
if hasattr(self, 'length') and obj.__class__.length != self.length:
|
and getattr(obj.__class__, name) != getattr(self, name):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ValueCheckError(Exception):
|
class ValueCheckError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class AbstractList(metaclass=AbstractListMeta):
|
|
||||||
|
class AbstractListMeta(ParamsMeta):
|
||||||
|
elem_type: DefaultingTypeMeta
|
||||||
|
length: int
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractList(ParamsBase, metaclass=AbstractListMeta):
|
||||||
|
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
items = self.extract_args(args)
|
items = self.extract_args(*args)
|
||||||
|
|
||||||
if not self.value_check(items):
|
if not self.value_check(items):
|
||||||
raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items))
|
raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items))
|
||||||
self.items = items
|
self.items = items
|
||||||
|
|
||||||
def value_check(self, value):
|
@classmethod
|
||||||
for v in value:
|
def value_check(cls, value):
|
||||||
if not type_check(self.__class__.elem_type, v):
|
return all(isinstance(v, cls.elem_type) for v in value)
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def extract_args(self, args):
|
@classmethod
|
||||||
return list(args) if len(args) > 0 else self.default()
|
def extract_args(cls, *args):
|
||||||
|
x = list(args)
|
||||||
|
if len(x) == 1 and isinstance(x[0], GeneratorType):
|
||||||
|
x = list(x[0])
|
||||||
|
return x if len(x) > 0 else cls.default()
|
||||||
|
|
||||||
def default(self):
|
def __str__(self):
|
||||||
raise Exception("Not implemented")
|
cls = self.__class__
|
||||||
|
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})"
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
return self.items[i]
|
return self.items[i]
|
||||||
|
@ -194,137 +225,65 @@ class AbstractList(metaclass=AbstractListMeta):
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return self.items == other.items
|
return self.items == other.items
|
||||||
|
|
||||||
class List(AbstractList, metaclass=AbstractListMeta):
|
|
||||||
|
class List(AbstractList):
|
||||||
def value_check(self, value):
|
def value_check(self, value):
|
||||||
return len(value) <= self.__class__.length and super().value_check(value)
|
return len(value) <= self.__class__.length and super().value_check(value)
|
||||||
|
|
||||||
def default(self):
|
@classmethod
|
||||||
return []
|
def default(cls):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
class Vector(AbstractList, metaclass=AbstractListMeta):
|
class Vector(AbstractList, metaclass=AbstractListMeta):
|
||||||
def value_check(self, value):
|
def value_check(self, value):
|
||||||
return len(value) == self.__class__.length and super().value_check(value)
|
return len(value) == self.__class__.length and super().value_check(value)
|
||||||
|
|
||||||
def default(self):
|
@classmethod
|
||||||
return [get_zero_value(self.__class__.elem_type) for _ in range(self.__class__.length)]
|
def default(cls):
|
||||||
|
return [cls.elem_type.default() for _ in range(cls.length)]
|
||||||
|
|
||||||
|
|
||||||
class BytesMeta(AbstractListMeta):
|
class BytesMeta(AbstractListMeta):
|
||||||
def __getitem__(self, params):
|
elem_type: DefaultingTypeMeta = byte
|
||||||
if not isinstance(params, int):
|
length: int
|
||||||
raise Exception("Bytes must be instantiated with one arg: length")
|
|
||||||
o = self.__class__(self.__name__, (self,), {'length': params})
|
|
||||||
o._name = 'Bytes'
|
|
||||||
return o
|
|
||||||
|
|
||||||
def single_item_extractor(cls, args):
|
|
||||||
assert len(args) < 2
|
|
||||||
return args[0] if len(args) > 0 else cls.default()
|
|
||||||
|
|
||||||
class Bytes(AbstractList, metaclass=BytesMeta):
|
class BytesLike(AbstractList, metaclass=BytesMeta):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract_args(cls, args):
|
||||||
|
if isinstance(args, bytes):
|
||||||
|
return args
|
||||||
|
elif isinstance(args, BytesLike):
|
||||||
|
return args.items
|
||||||
|
elif isinstance(args, GeneratorType):
|
||||||
|
return bytes(args)
|
||||||
|
else:
|
||||||
|
return bytes(args)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_check(cls, value):
|
||||||
|
return len(value) == cls.length and isinstance(value, bytes)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
cls = self.__class__
|
||||||
|
return f"{cls.__name__}[{cls.length}]: {self.items.hex()}"
|
||||||
|
|
||||||
|
|
||||||
|
class Bytes(BytesLike):
|
||||||
|
|
||||||
def value_check(self, value):
|
def value_check(self, value):
|
||||||
return len(value) <= self.__class__.length and isinstance(value, bytes)
|
return len(value) <= self.__class__.length and isinstance(value, bytes)
|
||||||
|
|
||||||
extract_args = single_item_extractor
|
@classmethod
|
||||||
|
def default(cls):
|
||||||
def default(self):
|
|
||||||
return b''
|
return b''
|
||||||
|
|
||||||
class BytesN(AbstractList, metaclass=BytesMeta):
|
|
||||||
def value_check(self, value):
|
|
||||||
return len(value) == self.__class__.length and isinstance(value, bytes)
|
|
||||||
|
|
||||||
extract_args = single_item_extractor
|
class BytesN(BytesLike):
|
||||||
|
|
||||||
def default(self):
|
@classmethod
|
||||||
return b'\x00' * self.__class__.length
|
def default(cls):
|
||||||
|
return b'\x00' * cls.length
|
||||||
|
|
||||||
|
|
||||||
# Type helpers
|
|
||||||
# -----------------------------
|
|
||||||
|
|
||||||
def infer_type(obj):
|
|
||||||
if is_uint_type(obj.__class__):
|
|
||||||
return obj.__class__
|
|
||||||
elif isinstance(obj, int):
|
|
||||||
return uint64
|
|
||||||
elif isinstance(obj, (List, Vector, Container, bool, BytesN, Bytes)):
|
|
||||||
return obj.__class__
|
|
||||||
else:
|
|
||||||
raise Exception("Unknown type for {}".format(obj))
|
|
||||||
|
|
||||||
|
|
||||||
def infer_input_type(fn):
|
|
||||||
"""
|
|
||||||
Decorator to run infer_type on the obj if typ argument is None
|
|
||||||
"""
|
|
||||||
def infer_helper(obj, typ=None, **kwargs):
|
|
||||||
if typ is None:
|
|
||||||
typ = infer_type(obj)
|
|
||||||
return fn(obj, typ=typ, **kwargs)
|
|
||||||
return infer_helper
|
|
||||||
|
|
||||||
|
|
||||||
def is_bool_type(typ):
|
|
||||||
"""
|
|
||||||
Check if the given type is a bool.
|
|
||||||
"""
|
|
||||||
if hasattr(typ, '__supertype__'):
|
|
||||||
typ = typ.__supertype__
|
|
||||||
return isinstance(typ, type) and issubclass(typ, bool)
|
|
||||||
|
|
||||||
|
|
||||||
def is_list_type(typ):
|
|
||||||
"""
|
|
||||||
Check if the given type is a list.
|
|
||||||
"""
|
|
||||||
return isinstance(typ, type) and issubclass(typ, List)
|
|
||||||
|
|
||||||
|
|
||||||
def is_bytes_type(typ):
|
|
||||||
"""
|
|
||||||
Check if the given type is a ``bytes``.
|
|
||||||
"""
|
|
||||||
return isinstance(typ, type) and issubclass(typ, Bytes)
|
|
||||||
|
|
||||||
|
|
||||||
def is_bytesn_type(typ):
|
|
||||||
"""
|
|
||||||
Check if the given type is a BytesN.
|
|
||||||
"""
|
|
||||||
return isinstance(typ, type) and issubclass(typ, BytesN)
|
|
||||||
|
|
||||||
|
|
||||||
def is_list_kind(typ):
|
|
||||||
"""
|
|
||||||
Check if the given type is a kind of list. Can be bytes.
|
|
||||||
"""
|
|
||||||
return is_list_type(typ) or is_bytes_type(typ)
|
|
||||||
|
|
||||||
|
|
||||||
def is_vector_type(typ):
|
|
||||||
"""
|
|
||||||
Check if the given type is a vector.
|
|
||||||
"""
|
|
||||||
return isinstance(typ, type) and issubclass(typ, Vector)
|
|
||||||
|
|
||||||
|
|
||||||
def is_vector_kind(typ):
|
|
||||||
"""
|
|
||||||
Check if the given type is a kind of vector. Can be BytesN.
|
|
||||||
"""
|
|
||||||
return is_vector_type(typ) or is_bytesn_type(typ)
|
|
||||||
|
|
||||||
|
|
||||||
def is_container_type(typ):
|
|
||||||
"""
|
|
||||||
Check if the given type is a container.
|
|
||||||
"""
|
|
||||||
return isinstance(typ, type) and issubclass(typ, Container)
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar('T')
|
|
||||||
L = TypeVar('L')
|
|
||||||
|
|
||||||
|
|
||||||
def read_elem_type(typ):
|
|
||||||
return typ.elem_type
|
|
||||||
|
|
Loading…
Reference in New Issue