highly experimental typing

This commit is contained in:
protolambda 2019-06-20 19:57:50 +02:00
parent 108410d862
commit 4ebdceaf12
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
2 changed files with 146 additions and 137 deletions

View File

@ -1,7 +1,7 @@
from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32 from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash from ..hash_function import hash
from .ssz_typing import ( from .ssz_typing import (
get_zero_value, Container, List, Vector, Bytes, BytesN, uint SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint
) )
# SSZ Serialization # SSZ Serialization
@ -10,79 +10,48 @@ from .ssz_typing import (
BYTES_PER_LENGTH_OFFSET = 4 BYTES_PER_LENGTH_OFFSET = 4
def is_basic_type(typ): def serialize_basic(value: SSZValue):
return issubclass(typ, (bool, uint)) if isinstance(value, uint):
return value.to_bytes(value.type().byte_len, 'little')
elif isinstance(value, Bit):
def serialize_basic(value, typ):
if issubclass(typ, uint):
return value.to_bytes(typ.byte_len, 'little')
elif issubclass(typ, bool):
if value: if value:
return b'\x01' return b'\x01'
else: else:
return b'\x00' return b'\x00'
else: else:
raise Exception("Type not supported: {}".format(typ)) raise Exception(f"Type not supported: {type(value)}")
def deserialize_basic(value, typ): def deserialize_basic(value, typ: BasicType):
if issubclass(typ, uint): if issubclass(typ, uint):
return typ(int.from_bytes(value, 'little')) return typ(int.from_bytes(value, 'little'))
elif issubclass(typ, bool): elif issubclass(typ, Bit):
assert value in (b'\x00', b'\x01') assert value in (b'\x00', b'\x01')
return True if value == b'\x01' else False return Bit(value == b'\x01')
else: else:
raise Exception("Type not supported: {}".format(typ)) raise Exception(f"Type not supported: {typ}")
def is_list_kind(typ): def is_empty(obj: SSZValue):
return issubclass(typ, (List, Bytes)) return type(obj).default() == obj
def is_vector_kind(typ): def serialize(obj: SSZValue):
return issubclass(typ, (Vector, BytesN)) if isinstance(obj, BasicValue):
return serialize_basic(obj)
elif isinstance(obj, Series):
def is_container_type(typ): return encode_series(obj)
return issubclass(typ, Container)
def is_fixed_size(typ):
if is_basic_type(typ):
return True
elif is_list_kind(typ):
return False
elif is_vector_kind(typ):
return is_fixed_size(typ.elem_type)
elif is_container_type(typ):
return all(is_fixed_size(t) for t in typ.get_field_types())
else: else:
raise Exception("Type not supported: {}".format(typ)) raise Exception(f"Type not supported: {type(obj)}")
def is_empty(obj): def encode_series(values: Series):
return get_zero_value(type(obj)) == obj
def serialize(obj, typ):
if is_basic_type(typ):
return serialize_basic(obj, typ)
elif is_list_kind(typ) or is_vector_kind(typ):
return encode_series(obj, [typ.elem_type] * len(obj))
elif is_container_type(typ):
return encode_series(obj.get_field_values(), typ.get_field_types())
else:
raise Exception("Type not supported: {}".format(typ))
def encode_series(values, types):
# bytes and bytesN are already in the right format. # bytes and bytesN are already in the right format.
if isinstance(values, bytes): if isinstance(values, bytes):
return values return values
# Recursively serialize # Recursively serialize
parts = [(is_fixed_size(types[i]), serialize(values[i], typ=types[i])) for i in range(len(values))] parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
# Compute and check lengths # Compute and check lengths
fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET
@ -114,10 +83,10 @@ def encode_series(values, types):
# ----------------------------- # -----------------------------
def pack(values, subtype): def pack(values: Series):
if isinstance(values, bytes): if isinstance(values, bytes):
return values return values
return b''.join([serialize_basic(value, subtype) for value in values]) return b''.join([serialize_basic(value) for value in values])
def chunkify(bytez): def chunkify(bytez):
@ -130,52 +99,49 @@ def mix_in_length(root, length):
return hash(root + length.to_bytes(32, 'little')) return hash(root + length.to_bytes(32, 'little'))
def is_bottom_layer_kind(typ): def is_bottom_layer_kind(typ: SSZType):
return ( return (
is_basic_type(typ) or issubclass(typ, BasicType) or
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(typ.elem_type) (issubclass(typ, Elements) and issubclass(typ.elem_type, BasicType))
) )
def get_typed_values(obj, typ): def item_length(typ: SSZType) -> int:
if is_container_type(typ): if issubclass(typ, BasicType):
return obj.get_typed_values()
elif is_list_kind(typ) or is_vector_kind(typ):
return list(zip(obj, [typ.elem_type] * len(obj)))
else:
raise Exception("Invalid type")
def item_length(typ):
if typ == bool:
return 1
elif issubclass(typ, uint):
return typ.byte_len return typ.byte_len
else: else:
return 32 return 32
def chunk_count(typ):
if is_basic_type(typ):
return 1
elif is_list_kind(typ) or is_vector_kind(typ):
return (typ.length * item_length(typ.elem_type) + 31) // 32
else:
return len(typ.get_fields())
def hash_tree_root(obj, typ):
if is_bottom_layer_kind(typ): def chunk_count(typ: SSZType) -> int:
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type) if issubclass(typ, BasicType):
leaves = chunkify(data) return 1
elif issubclass(typ, Elements):
return (typ.length * item_length(typ.elem_type) + 31) // 32
elif issubclass(typ, Container):
return len(typ.get_fields())
else: else:
fields = get_typed_values(obj, typ=typ) raise Exception(f"Type not supported: {typ}")
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
if is_list_kind(typ):
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(typ)), len(obj)) def hash_tree_root(obj: SSZValue):
if isinstance(obj, Series):
if is_bottom_layer_kind(obj.type()):
leaves = chunkify(pack(obj))
else:
leaves = [hash_tree_root(value) for value in obj]
elif isinstance(obj, BasicValue):
leaves = chunkify(serialize_basic(obj))
else:
raise Exception(f"Type not supported: {obj.type()}")
if isinstance(obj, (List, Bytes)):
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj))
else: else:
return merkleize_chunks(leaves) return merkleize_chunks(leaves)
def signing_root(obj, typ): def signing_root(obj: Container):
assert is_container_type(typ)
# ignore last field # ignore last field
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in obj.get_typed_values()[:-1]] leaves = [hash_tree_root(field) for field in obj[:-1]]
return merkleize_chunks(chunkify(b''.join(leaves))) return merkleize_chunks(chunkify(b''.join(leaves)))

View File

@ -1,4 +1,4 @@
from typing import NewType, Union from typing import Tuple, Dict, Iterator
from types import GeneratorType from types import GeneratorType
@ -10,24 +10,42 @@ class DefaultingTypeMeta(type):
def default(cls): def default(cls):
raise Exception("Not implemented") raise Exception("Not implemented")
# Every type is subclassed and has a default() method, except bool.
TypeWithDefault = Union[DefaultingTypeMeta, bool] class SSZType(DefaultingTypeMeta):
def is_fixed_size(cls):
raise Exception("Not implemented")
def get_zero_value(typ: TypeWithDefault): class SSZValue(object, metaclass=SSZType):
if issubclass(typ, bool):
return False def type(self):
else: return self.__class__
return typ.default()
# SSZ integers class BasicType(SSZType):
# -----------------------------
class uint(int, metaclass=DefaultingTypeMeta):
byte_len = 0 byte_len = 0
def is_fixed_size(cls):
return True
class BasicValue(int, SSZValue, metaclass=BasicType):
pass
class Bit(BasicValue): # can't subclass bool.
@classmethod
def default(cls):
return cls(False)
def __bool__(self):
return self > 0
class uint(BasicValue, metaclass=BasicType):
def __new__(cls, value, *args, **kwargs): def __new__(cls, value, *args, **kwargs):
if value < 0: if value < 0:
raise ValueError("unsigned types must not be negative") raise ValueError("unsigned types must not be negative")
@ -69,36 +87,39 @@ class uint256(uint):
byte_len = 32 byte_len = 32
# SSZ Container base class class Series(SSZValue):
# -----------------------------
def __iter__(self) -> Iterator[SSZValue]:
raise Exception("Not implemented")
# Note: importing ssz functionality locally, to avoid import loop # Note: importing ssz functionality locally, to avoid import loop
class Container(object, metaclass=DefaultingTypeMeta): class Container(Series, metaclass=SSZType):
def __init__(self, **kwargs): def __init__(self, **kwargs):
cls = self.__class__ cls = self.__class__
for f, t in cls.get_fields(): for f, t in cls.get_fields():
if f not in kwargs: if f not in kwargs:
setattr(self, f, get_zero_value(t)) setattr(self, f, t.default())
else: else:
setattr(self, f, kwargs[f]) setattr(self, f, kwargs[f])
def serialize(self): def serialize(self):
from .ssz_impl import serialize from .ssz_impl import serialize
return serialize(self, self.__class__) return serialize(self)
def hash_tree_root(self): def hash_tree_root(self):
from .ssz_impl import hash_tree_root from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__) return hash_tree_root(self)
def signing_root(self): def signing_root(self):
from .ssz_impl import signing_root from .ssz_impl import signing_root
return signing_root(self, self.__class__) return signing_root(self)
def get_field_values(self): def get_field_values(self) -> Tuple[SSZValue, ...]:
cls = self.__class__ cls = self.__class__
return [getattr(self, field) for field in cls.get_field_names()] return tuple(getattr(self, field) for field in cls.get_field_names())
def __repr__(self): def __repr__(self):
return repr({field: getattr(self, field) for field in self.get_field_names()}) return repr({field: getattr(self, field) for field in self.get_field_names()})
@ -116,31 +137,38 @@ class Container(object, metaclass=DefaultingTypeMeta):
return hash(self.hash_tree_root()) return hash(self.hash_tree_root())
@classmethod @classmethod
def get_fields_dict(cls): def get_fields_dict(cls) -> Dict[str, SSZType]:
return dict(cls.__annotations__) return dict(cls.__annotations__)
@classmethod @classmethod
def get_fields(cls): def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]:
return list(dict(cls.__annotations__).items()) return tuple((f, SSZType(t)) for f, t in dict(cls.__annotations__).items())
def get_typed_values(self): def get_typed_values(self):
return list(zip(self.get_field_values(), self.get_field_types())) return tuple(zip(self.get_field_values(), self.get_field_types()))
@classmethod @classmethod
def get_field_names(cls): def get_field_names(cls) -> Tuple[str]:
return list(cls.__annotations__.keys()) return tuple(cls.__annotations__.keys())
@classmethod @classmethod
def get_field_types(cls): def get_field_types(cls) -> Tuple[SSZType, ...]:
# values of annotations are the types corresponding to the fields, not instance values. # values of annotations are the types corresponding to the fields, not instance values.
return list(cls.__annotations__.values()) return tuple(cls.__annotations__.values())
@classmethod @classmethod
def default(cls): def default(cls):
return cls(**{f: get_zero_value(t) for f, t in cls.get_fields()}) return cls(**{f: t.default() for f, t in cls.get_fields()})
@classmethod
def is_fixed_size(cls):
return all(t.is_fixed_size() for t in cls.get_field_types())
def __iter__(self) -> Iterator[SSZValue]:
return iter(self.get_field_values())
class ParamsBase: class ParamsBase(Series):
_bare = True _bare = True
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
@ -149,7 +177,7 @@ class ParamsBase:
return super().__new__(cls, **kwargs) return super().__new__(cls, **kwargs)
class ParamsMeta(DefaultingTypeMeta): class ParamsMeta(SSZType):
def __new__(cls, class_name, parents, attrs): def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs) out = type.__new__(cls, class_name, parents, attrs)
@ -168,14 +196,14 @@ class ParamsMeta(DefaultingTypeMeta):
res = {} res = {}
i = 0 i = 0
for (name, typ) in self.__annotations__.items(): for (name, typ) in self.__annotations__.items():
param = params[i]
if hasattr(self.__class__, name): if hasattr(self.__class__, name):
res[name] = getattr(self.__class__, name) res[name] = getattr(self.__class__, name)
else: else:
if typ == TypeWithDefault: if i >= len(params):
if not (isinstance(param, bool) or isinstance(param, DefaultingTypeMeta)): i += 1
raise TypeError("expected param {} as {} to have a type default".format(param, name, typ)) continue
elif not isinstance(param, typ): param = params[i]
if not isinstance(param, typ):
raise TypeError( raise TypeError(
"cannot create parametrized class with param {} as {} of type {}".format(param, name, typ)) "cannot create parametrized class with param {} as {} of type {}".format(param, name, typ))
res[name] = param res[name] = param
@ -194,12 +222,12 @@ class ParamsMeta(DefaultingTypeMeta):
return True return True
class AbstractListMeta(ParamsMeta): class Elements(ParamsMeta):
elem_type: TypeWithDefault elem_type: SSZType
length: int length: int
class AbstractList(ParamsBase, metaclass=AbstractListMeta): class ElementsBase(ParamsBase, metaclass=Elements):
def __init__(self, *args): def __init__(self, *args):
items = self.extract_args(*args) items = self.extract_args(*args)
@ -223,7 +251,7 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta):
cls = self.__class__ cls = self.__class__
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})" return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})"
def __getitem__(self, i): def __getitem__(self, i) -> SSZValue:
return self.items[i] return self.items[i]
def __setitem__(self, k, v): def __setitem__(self, k, v):
@ -235,21 +263,25 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta):
def __repr__(self): def __repr__(self):
return repr(self.items) return repr(self.items)
def __iter__(self): def __iter__(self) -> Iterator[SSZValue]:
return iter(self.items) return iter(self.items)
def __eq__(self, other): def __eq__(self, other):
return self.items == other.items return self.items == other.items
class List(AbstractList): class List(ElementsBase):
@classmethod @classmethod
def default(cls): def default(cls):
return cls() return cls()
@classmethod
def is_fixed_size(cls):
return False
class Vector(AbstractList, metaclass=AbstractListMeta):
class Vector(ElementsBase):
@classmethod @classmethod
def value_check(cls, value): def value_check(cls, value):
@ -257,15 +289,19 @@ class Vector(AbstractList, metaclass=AbstractListMeta):
@classmethod @classmethod
def default(cls): def default(cls):
return [get_zero_value(cls.elem_type) for _ in range(cls.length)] return [cls.elem_type.default() for _ in range(cls.length)]
@classmethod
def is_fixed_size(cls):
return cls.elem_type.is_fixed_size()
class BytesMeta(AbstractListMeta): class BytesMeta(Elements):
elem_type: TypeWithDefault = byte elem_type: SSZType = byte
length: int length: int
class BytesLike(AbstractList, metaclass=BytesMeta): class BytesLike(ElementsBase, metaclass=BytesMeta):
@classmethod @classmethod
def extract_args(cls, args): def extract_args(cls, args):
@ -293,6 +329,10 @@ class Bytes(BytesLike):
def default(cls): def default(cls):
return b'' return b''
@classmethod
def is_fixed_size(cls):
return False
class BytesN(BytesLike): class BytesN(BytesLike):
@ -304,3 +344,6 @@ class BytesN(BytesLike):
def value_check(cls, value): def value_check(cls, value):
return len(value) == cls.length and super().value_check(value) return len(value) == cls.length and super().value_check(value)
@classmethod
def is_fixed_size(cls):
return True