highly experimental typing
This commit is contained in:
parent
108410d862
commit
4ebdceaf12
|
@ -1,7 +1,7 @@
|
|||
from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32
|
||||
from ..merkle_minimal import merkleize_chunks
|
||||
from ..hash_function import hash
|
||||
from .ssz_typing import (
|
||||
get_zero_value, Container, List, Vector, Bytes, BytesN, uint
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint
|
||||
)
|
||||
|
||||
# SSZ Serialization
|
||||
|
@ -10,79 +10,48 @@ from .ssz_typing import (
|
|||
BYTES_PER_LENGTH_OFFSET = 4
|
||||
|
||||
|
||||
def is_basic_type(typ):
|
||||
return issubclass(typ, (bool, uint))
|
||||
|
||||
|
||||
def serialize_basic(value, typ):
|
||||
if issubclass(typ, uint):
|
||||
return value.to_bytes(typ.byte_len, 'little')
|
||||
elif issubclass(typ, bool):
|
||||
def serialize_basic(value: SSZValue):
|
||||
if isinstance(value, uint):
|
||||
return value.to_bytes(value.type().byte_len, 'little')
|
||||
elif isinstance(value, Bit):
|
||||
if value:
|
||||
return b'\x01'
|
||||
else:
|
||||
return b'\x00'
|
||||
else:
|
||||
raise Exception("Type not supported: {}".format(typ))
|
||||
raise Exception(f"Type not supported: {type(value)}")
|
||||
|
||||
|
||||
def deserialize_basic(value, typ):
|
||||
def deserialize_basic(value, typ: BasicType):
|
||||
if issubclass(typ, uint):
|
||||
return typ(int.from_bytes(value, 'little'))
|
||||
elif issubclass(typ, bool):
|
||||
elif issubclass(typ, Bit):
|
||||
assert value in (b'\x00', b'\x01')
|
||||
return True if value == b'\x01' else False
|
||||
return Bit(value == b'\x01')
|
||||
else:
|
||||
raise Exception("Type not supported: {}".format(typ))
|
||||
raise Exception(f"Type not supported: {typ}")
|
||||
|
||||
|
||||
def is_list_kind(typ):
|
||||
return issubclass(typ, (List, Bytes))
|
||||
def is_empty(obj: SSZValue):
|
||||
return type(obj).default() == obj
|
||||
|
||||
|
||||
def is_vector_kind(typ):
|
||||
return issubclass(typ, (Vector, BytesN))
|
||||
|
||||
|
||||
def is_container_type(typ):
|
||||
return issubclass(typ, Container)
|
||||
|
||||
|
||||
def is_fixed_size(typ):
|
||||
if is_basic_type(typ):
|
||||
return True
|
||||
elif is_list_kind(typ):
|
||||
return False
|
||||
elif is_vector_kind(typ):
|
||||
return is_fixed_size(typ.elem_type)
|
||||
elif is_container_type(typ):
|
||||
return all(is_fixed_size(t) for t in typ.get_field_types())
|
||||
def serialize(obj: SSZValue):
|
||||
if isinstance(obj, BasicValue):
|
||||
return serialize_basic(obj)
|
||||
elif isinstance(obj, Series):
|
||||
return encode_series(obj)
|
||||
else:
|
||||
raise Exception("Type not supported: {}".format(typ))
|
||||
raise Exception(f"Type not supported: {type(obj)}")
|
||||
|
||||
|
||||
def is_empty(obj):
|
||||
return get_zero_value(type(obj)) == obj
|
||||
|
||||
|
||||
def serialize(obj, typ):
|
||||
if is_basic_type(typ):
|
||||
return serialize_basic(obj, typ)
|
||||
elif is_list_kind(typ) or is_vector_kind(typ):
|
||||
return encode_series(obj, [typ.elem_type] * len(obj))
|
||||
elif is_container_type(typ):
|
||||
return encode_series(obj.get_field_values(), typ.get_field_types())
|
||||
else:
|
||||
raise Exception("Type not supported: {}".format(typ))
|
||||
|
||||
|
||||
def encode_series(values, types):
|
||||
def encode_series(values: Series):
|
||||
# bytes and bytesN are already in the right format.
|
||||
if isinstance(values, bytes):
|
||||
return values
|
||||
|
||||
# Recursively serialize
|
||||
parts = [(is_fixed_size(types[i]), serialize(values[i], typ=types[i])) for i in range(len(values))]
|
||||
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
|
||||
|
||||
# Compute and check lengths
|
||||
fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET
|
||||
|
@ -114,10 +83,10 @@ def encode_series(values, types):
|
|||
# -----------------------------
|
||||
|
||||
|
||||
def pack(values, subtype):
|
||||
def pack(values: Series):
|
||||
if isinstance(values, bytes):
|
||||
return values
|
||||
return b''.join([serialize_basic(value, subtype) for value in values])
|
||||
return b''.join([serialize_basic(value) for value in values])
|
||||
|
||||
|
||||
def chunkify(bytez):
|
||||
|
@ -130,52 +99,49 @@ def mix_in_length(root, length):
|
|||
return hash(root + length.to_bytes(32, 'little'))
|
||||
|
||||
|
||||
def is_bottom_layer_kind(typ):
|
||||
def is_bottom_layer_kind(typ: SSZType):
|
||||
return (
|
||||
is_basic_type(typ) or
|
||||
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(typ.elem_type)
|
||||
issubclass(typ, BasicType) or
|
||||
(issubclass(typ, Elements) and issubclass(typ.elem_type, BasicType))
|
||||
)
|
||||
|
||||
|
||||
def get_typed_values(obj, typ):
|
||||
if is_container_type(typ):
|
||||
return obj.get_typed_values()
|
||||
elif is_list_kind(typ) or is_vector_kind(typ):
|
||||
return list(zip(obj, [typ.elem_type] * len(obj)))
|
||||
else:
|
||||
raise Exception("Invalid type")
|
||||
|
||||
|
||||
def item_length(typ):
|
||||
if typ == bool:
|
||||
return 1
|
||||
elif issubclass(typ, uint):
|
||||
def item_length(typ: SSZType) -> int:
|
||||
if issubclass(typ, BasicType):
|
||||
return typ.byte_len
|
||||
else:
|
||||
return 32
|
||||
def chunk_count(typ):
|
||||
if is_basic_type(typ):
|
||||
return 1
|
||||
elif is_list_kind(typ) or is_vector_kind(typ):
|
||||
return (typ.length * item_length(typ.elem_type) + 31) // 32
|
||||
else:
|
||||
return len(typ.get_fields())
|
||||
|
||||
def hash_tree_root(obj, typ):
|
||||
if is_bottom_layer_kind(typ):
|
||||
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type)
|
||||
leaves = chunkify(data)
|
||||
|
||||
def chunk_count(typ: SSZType) -> int:
|
||||
if issubclass(typ, BasicType):
|
||||
return 1
|
||||
elif issubclass(typ, Elements):
|
||||
return (typ.length * item_length(typ.elem_type) + 31) // 32
|
||||
elif issubclass(typ, Container):
|
||||
return len(typ.get_fields())
|
||||
else:
|
||||
fields = get_typed_values(obj, typ=typ)
|
||||
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
|
||||
if is_list_kind(typ):
|
||||
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(typ)), len(obj))
|
||||
raise Exception(f"Type not supported: {typ}")
|
||||
|
||||
|
||||
def hash_tree_root(obj: SSZValue):
|
||||
if isinstance(obj, Series):
|
||||
if is_bottom_layer_kind(obj.type()):
|
||||
leaves = chunkify(pack(obj))
|
||||
else:
|
||||
leaves = [hash_tree_root(value) for value in obj]
|
||||
elif isinstance(obj, BasicValue):
|
||||
leaves = chunkify(serialize_basic(obj))
|
||||
else:
|
||||
raise Exception(f"Type not supported: {obj.type()}")
|
||||
|
||||
if isinstance(obj, (List, Bytes)):
|
||||
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj))
|
||||
else:
|
||||
return merkleize_chunks(leaves)
|
||||
|
||||
|
||||
def signing_root(obj, typ):
|
||||
assert is_container_type(typ)
|
||||
def signing_root(obj: Container):
|
||||
# ignore last field
|
||||
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in obj.get_typed_values()[:-1]]
|
||||
leaves = [hash_tree_root(field) for field in obj[:-1]]
|
||||
return merkleize_chunks(chunkify(b''.join(leaves)))
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import NewType, Union
|
||||
from typing import Tuple, Dict, Iterator
|
||||
from types import GeneratorType
|
||||
|
||||
|
||||
|
@ -10,24 +10,42 @@ class DefaultingTypeMeta(type):
|
|||
def default(cls):
|
||||
raise Exception("Not implemented")
|
||||
|
||||
# Every type is subclassed and has a default() method, except bool.
|
||||
TypeWithDefault = Union[DefaultingTypeMeta, bool]
|
||||
|
||||
class SSZType(DefaultingTypeMeta):
|
||||
|
||||
def is_fixed_size(cls):
|
||||
raise Exception("Not implemented")
|
||||
|
||||
|
||||
def get_zero_value(typ: TypeWithDefault):
|
||||
if issubclass(typ, bool):
|
||||
return False
|
||||
else:
|
||||
return typ.default()
|
||||
class SSZValue(object, metaclass=SSZType):
|
||||
|
||||
def type(self):
|
||||
return self.__class__
|
||||
|
||||
|
||||
# SSZ integers
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class uint(int, metaclass=DefaultingTypeMeta):
|
||||
class BasicType(SSZType):
|
||||
byte_len = 0
|
||||
|
||||
def is_fixed_size(cls):
|
||||
return True
|
||||
|
||||
|
||||
class BasicValue(int, SSZValue, metaclass=BasicType):
|
||||
pass
|
||||
|
||||
|
||||
class Bit(BasicValue): # can't subclass bool.
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls(False)
|
||||
|
||||
def __bool__(self):
|
||||
return self > 0
|
||||
|
||||
|
||||
class uint(BasicValue, metaclass=BasicType):
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value < 0:
|
||||
raise ValueError("unsigned types must not be negative")
|
||||
|
@ -69,36 +87,39 @@ class uint256(uint):
|
|||
byte_len = 32
|
||||
|
||||
|
||||
# SSZ Container base class
|
||||
# -----------------------------
|
||||
class Series(SSZValue):
|
||||
|
||||
def __iter__(self) -> Iterator[SSZValue]:
|
||||
raise Exception("Not implemented")
|
||||
|
||||
|
||||
# Note: importing ssz functionality locally, to avoid import loop
|
||||
|
||||
class Container(object, metaclass=DefaultingTypeMeta):
|
||||
class Container(Series, metaclass=SSZType):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
cls = self.__class__
|
||||
for f, t in cls.get_fields():
|
||||
if f not in kwargs:
|
||||
setattr(self, f, get_zero_value(t))
|
||||
setattr(self, f, t.default())
|
||||
else:
|
||||
setattr(self, f, kwargs[f])
|
||||
|
||||
def serialize(self):
|
||||
from .ssz_impl import serialize
|
||||
return serialize(self, self.__class__)
|
||||
return serialize(self)
|
||||
|
||||
def hash_tree_root(self):
|
||||
from .ssz_impl import hash_tree_root
|
||||
return hash_tree_root(self, self.__class__)
|
||||
return hash_tree_root(self)
|
||||
|
||||
def signing_root(self):
|
||||
from .ssz_impl import signing_root
|
||||
return signing_root(self, self.__class__)
|
||||
return signing_root(self)
|
||||
|
||||
def get_field_values(self):
|
||||
def get_field_values(self) -> Tuple[SSZValue, ...]:
|
||||
cls = self.__class__
|
||||
return [getattr(self, field) for field in cls.get_field_names()]
|
||||
return tuple(getattr(self, field) for field in cls.get_field_names())
|
||||
|
||||
def __repr__(self):
|
||||
return repr({field: getattr(self, field) for field in self.get_field_names()})
|
||||
|
@ -116,31 +137,38 @@ class Container(object, metaclass=DefaultingTypeMeta):
|
|||
return hash(self.hash_tree_root())
|
||||
|
||||
@classmethod
|
||||
def get_fields_dict(cls):
|
||||
def get_fields_dict(cls) -> Dict[str, SSZType]:
|
||||
return dict(cls.__annotations__)
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return list(dict(cls.__annotations__).items())
|
||||
def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]:
|
||||
return tuple((f, SSZType(t)) for f, t in dict(cls.__annotations__).items())
|
||||
|
||||
def get_typed_values(self):
|
||||
return list(zip(self.get_field_values(), self.get_field_types()))
|
||||
return tuple(zip(self.get_field_values(), self.get_field_types()))
|
||||
|
||||
@classmethod
|
||||
def get_field_names(cls):
|
||||
return list(cls.__annotations__.keys())
|
||||
def get_field_names(cls) -> Tuple[str]:
|
||||
return tuple(cls.__annotations__.keys())
|
||||
|
||||
@classmethod
|
||||
def get_field_types(cls):
|
||||
def get_field_types(cls) -> Tuple[SSZType, ...]:
|
||||
# values of annotations are the types corresponding to the fields, not instance values.
|
||||
return list(cls.__annotations__.values())
|
||||
return tuple(cls.__annotations__.values())
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls(**{f: get_zero_value(t) for f, t in cls.get_fields()})
|
||||
return cls(**{f: t.default() for f, t in cls.get_fields()})
|
||||
|
||||
@classmethod
|
||||
def is_fixed_size(cls):
|
||||
return all(t.is_fixed_size() for t in cls.get_field_types())
|
||||
|
||||
def __iter__(self) -> Iterator[SSZValue]:
|
||||
return iter(self.get_field_values())
|
||||
|
||||
|
||||
class ParamsBase:
|
||||
class ParamsBase(Series):
|
||||
_bare = True
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
|
@ -149,7 +177,7 @@ class ParamsBase:
|
|||
return super().__new__(cls, **kwargs)
|
||||
|
||||
|
||||
class ParamsMeta(DefaultingTypeMeta):
|
||||
class ParamsMeta(SSZType):
|
||||
|
||||
def __new__(cls, class_name, parents, attrs):
|
||||
out = type.__new__(cls, class_name, parents, attrs)
|
||||
|
@ -168,14 +196,14 @@ class ParamsMeta(DefaultingTypeMeta):
|
|||
res = {}
|
||||
i = 0
|
||||
for (name, typ) in self.__annotations__.items():
|
||||
param = params[i]
|
||||
if hasattr(self.__class__, name):
|
||||
res[name] = getattr(self.__class__, name)
|
||||
else:
|
||||
if typ == TypeWithDefault:
|
||||
if not (isinstance(param, bool) or isinstance(param, DefaultingTypeMeta)):
|
||||
raise TypeError("expected param {} as {} to have a type default".format(param, name, typ))
|
||||
elif not isinstance(param, typ):
|
||||
if i >= len(params):
|
||||
i += 1
|
||||
continue
|
||||
param = params[i]
|
||||
if not isinstance(param, typ):
|
||||
raise TypeError(
|
||||
"cannot create parametrized class with param {} as {} of type {}".format(param, name, typ))
|
||||
res[name] = param
|
||||
|
@ -194,12 +222,12 @@ class ParamsMeta(DefaultingTypeMeta):
|
|||
return True
|
||||
|
||||
|
||||
class AbstractListMeta(ParamsMeta):
|
||||
elem_type: TypeWithDefault
|
||||
class Elements(ParamsMeta):
|
||||
elem_type: SSZType
|
||||
length: int
|
||||
|
||||
|
||||
class AbstractList(ParamsBase, metaclass=AbstractListMeta):
|
||||
class ElementsBase(ParamsBase, metaclass=Elements):
|
||||
|
||||
def __init__(self, *args):
|
||||
items = self.extract_args(*args)
|
||||
|
@ -223,7 +251,7 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta):
|
|||
cls = self.__class__
|
||||
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})"
|
||||
|
||||
def __getitem__(self, i):
|
||||
def __getitem__(self, i) -> SSZValue:
|
||||
return self.items[i]
|
||||
|
||||
def __setitem__(self, k, v):
|
||||
|
@ -235,21 +263,25 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta):
|
|||
def __repr__(self):
|
||||
return repr(self.items)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[SSZValue]:
|
||||
return iter(self.items)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.items == other.items
|
||||
|
||||
|
||||
class List(AbstractList):
|
||||
class List(ElementsBase):
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def is_fixed_size(cls):
|
||||
return False
|
||||
|
||||
class Vector(AbstractList, metaclass=AbstractListMeta):
|
||||
|
||||
class Vector(ElementsBase):
|
||||
|
||||
@classmethod
|
||||
def value_check(cls, value):
|
||||
|
@ -257,15 +289,19 @@ class Vector(AbstractList, metaclass=AbstractListMeta):
|
|||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return [get_zero_value(cls.elem_type) for _ in range(cls.length)]
|
||||
return [cls.elem_type.default() for _ in range(cls.length)]
|
||||
|
||||
@classmethod
|
||||
def is_fixed_size(cls):
|
||||
return cls.elem_type.is_fixed_size()
|
||||
|
||||
|
||||
class BytesMeta(AbstractListMeta):
|
||||
elem_type: TypeWithDefault = byte
|
||||
class BytesMeta(Elements):
|
||||
elem_type: SSZType = byte
|
||||
length: int
|
||||
|
||||
|
||||
class BytesLike(AbstractList, metaclass=BytesMeta):
|
||||
class BytesLike(ElementsBase, metaclass=BytesMeta):
|
||||
|
||||
@classmethod
|
||||
def extract_args(cls, args):
|
||||
|
@ -293,6 +329,10 @@ class Bytes(BytesLike):
|
|||
def default(cls):
|
||||
return b''
|
||||
|
||||
@classmethod
|
||||
def is_fixed_size(cls):
|
||||
return False
|
||||
|
||||
|
||||
class BytesN(BytesLike):
|
||||
|
||||
|
@ -304,3 +344,6 @@ class BytesN(BytesLike):
|
|||
def value_check(cls, value):
|
||||
return len(value) == cls.length and super().value_check(value)
|
||||
|
||||
@classmethod
|
||||
def is_fixed_size(cls):
|
||||
return True
|
||||
|
|
Loading…
Reference in New Issue