ssz typing now subclasses list/bytes, much easier to work with than wrapped list/bytes functionality

This commit is contained in:
protolambda 2019-06-20 20:30:42 +02:00
parent 4e747fb887
commit 977856b06f
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
3 changed files with 58 additions and 45 deletions

View File

@ -1,7 +1,7 @@
from ..merkle_minimal import merkleize_chunks from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash from ..hash_function import hash
from .ssz_typing import ( from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint, SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, uint,
) )
# SSZ Serialization # SSZ Serialization
@ -46,9 +46,8 @@ def serialize(obj: SSZValue):
def encode_series(values: Series): def encode_series(values: Series):
# bytes and bytesN are already in the right format. if isinstance(values, bytes): # Bytes and BytesN are already like serialized output
if isinstance(values, (Bytes, BytesN)): return values
return values.items
# Recursively serialize # Recursively serialize
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values] parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
@ -84,8 +83,8 @@ def encode_series(values: Series):
def pack(values: Series): def pack(values: Series):
if isinstance(values, (Bytes, BytesN)): if isinstance(values, bytes): # Bytes and BytesN are already packed
return values.items return values
return b''.join([serialize_basic(value) for value in values]) return b''.join([serialize_basic(value) for value in values])

View File

@ -188,10 +188,10 @@ class Container(Series, metaclass=SSZType):
class ParamsBase(Series): class ParamsBase(Series):
_bare = True _has_params = False
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if cls._bare: if not cls._has_params:
raise Exception("cannot init bare type without params") raise Exception("cannot init bare type without params")
return super().__new__(cls, **kwargs) return super().__new__(cls, **kwargs)
@ -200,13 +200,13 @@ class ParamsMeta(SSZType):
def __new__(cls, class_name, parents, attrs): def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs) out = type.__new__(cls, class_name, parents, attrs)
for k, v in attrs.items(): if hasattr(out, "_has_params") and getattr(out, "_has_params"):
setattr(out, k, v) for k, v in attrs.items():
setattr(out, k, v)
return out return out
def __getitem__(self, params): def __getitem__(self, params):
o = self.__class__(self.__name__, (self,), self.attr_from_params(params)) o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
o._bare = False
return o return o
def __str__(self): def __str__(self):
@ -218,7 +218,7 @@ class ParamsMeta(SSZType):
def attr_from_params(self, p): def attr_from_params(self, p):
# single key params are valid too. Wrap them in a tuple. # single key params are valid too. Wrap them in a tuple.
params = p if isinstance(p, tuple) else (p,) params = p if isinstance(p, tuple) else (p,)
res = {} res = {'_has_params': True}
i = 0 i = 0
for (name, typ) in self.__annotations__.items(): for (name, typ) in self.__annotations__.items():
if hasattr(self.__class__, name): if hasattr(self.__class__, name):
@ -262,13 +262,17 @@ class ElementsType(ParamsMeta):
class Elements(ParamsBase, metaclass=ElementsType): class Elements(ParamsBase, metaclass=ElementsType):
pass
class BaseList(list, Elements):
def __init__(self, *args): def __init__(self, *args):
items = self.extract_args(*args) items = self.extract_args(*args)
if not self.value_check(items): if not self.value_check(items):
raise ValueError(f"Bad input for class {self.__class__}: {items}") raise ValueError(f"Bad input for class {self.__class__}: {items}")
self.items = items super().__init__(items)
@classmethod @classmethod
def value_check(cls, value): def value_check(cls, value):
@ -284,39 +288,32 @@ class Elements(ParamsBase, metaclass=ElementsType):
def __str__(self): def __str__(self):
cls = self.__class__ cls = self.__class__
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})" return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self)})"
def __getitem__(self, i) -> SSZValue: def __getitem__(self, i) -> SSZValue:
return self.items[i] if i < 0:
raise IndexError(f"cannot get item in type {self.__class__} at negative index {i}")
if i > len(self):
raise IndexError(f"cannot get item in type {self.__class__}"
f" at out of bounds index {i}")
return super().__getitem__(i)
def __setitem__(self, k, v): def __setitem__(self, k, v):
if k < 0: if k < 0:
raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})") raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})")
if k > len(self.items): if k > len(self):
raise IndexError(f"cannot set item in type {self.__class__}" raise IndexError(f"cannot set item in type {self.__class__}"
f" at out of bounds index {k} (to {v}, bound: {len(self.items)})") f" at out of bounds index {k} (to {v}, bound: {len(self)})")
self.items[k] = coerce_type_maybe(v, self.__class__.elem_type, strict=True) super().__setitem__(k, coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def append(self, v): def append(self, v):
self.items.append(coerce_type_maybe(v, self.__class__.elem_type, strict=True)) super().append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def pop(self):
if len(self.items) == 0:
raise IndexError("Pop from empty list")
else:
return self.items.pop()
def __len__(self):
return len(self.items)
def __repr__(self):
return repr(self.items)
def __iter__(self) -> Iterator[SSZValue]: def __iter__(self) -> Iterator[SSZValue]:
return iter(self.items) return super().__iter__()
class List(Elements): class List(BaseList):
@classmethod @classmethod
def default(cls): def default(cls):
@ -327,7 +324,7 @@ class List(Elements):
return False return False
class Vector(Elements): class Vector(BaseList):
@classmethod @classmethod
def value_check(cls, value): def value_check(cls, value):
@ -342,27 +339,35 @@ class Vector(Elements):
def is_fixed_size(cls): def is_fixed_size(cls):
return cls.elem_type.is_fixed_size() return cls.elem_type.is_fixed_size()
def append(self, v):
raise Exception("cannot modify vector length")
def pop(self, *args):
raise Exception("cannot modify vector length")
class BytesType(ElementsType): class BytesType(ElementsType):
elem_type: SSZType = byte elem_type: SSZType = byte
length: int length: int
class BytesLike(Elements, metaclass=BytesType): class BaseBytes(bytes, Elements, metaclass=BytesType):
def __new__(cls, *args) -> "BaseBytes":
extracted_val = cls.extract_args(*args)
if not cls.value_check(extracted_val):
raise ValueError(f"Bad input for class {cls}: {extracted_val}")
return super().__new__(cls, extracted_val)
@classmethod @classmethod
def extract_args(cls, *args): def extract_args(cls, *args):
x = list(args) x = args
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes, BytesLike)): if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes)):
x = x[0] x = x[0]
if isinstance(x, bytes): if isinstance(x, bytes): # Includes BytesLike
return x return x
elif isinstance(x, BytesLike):
return x.items
elif isinstance(x, GeneratorType):
return bytes(x)
else: else:
return bytes(x) return bytes(x) # E.g. GeneratorType put into bytes.
@classmethod @classmethod
def value_check(cls, value): def value_check(cls, value):
@ -374,7 +379,7 @@ class BytesLike(Elements, metaclass=BytesType):
return f"{cls.__name__}[{cls.length}]: {self.hex()}" return f"{cls.__name__}[{cls.length}]: {self.hex()}"
class Bytes(BytesLike): class Bytes(BaseBytes):
@classmethod @classmethod
def default(cls): def default(cls):
@ -385,7 +390,14 @@ class Bytes(BytesLike):
return False return False
class BytesN(BytesLike): class BytesN(BaseBytes):
@classmethod
def extract_args(cls, *args):
if len(args) == 0:
return cls.default()
else:
return super().extract_args(*args)
@classmethod @classmethod
def default(cls): def default(cls):

View File

@ -211,3 +211,5 @@ def test_bytesn_subclass():
assert issubclass(Hash, Bytes32) assert issubclass(Hash, Bytes32)
assert not issubclass(Bytes48, Bytes32) assert not issubclass(Bytes48, Bytes32)
assert len(Bytes32() + Bytes48()) == 80