ssz typing now subclasses list/bytes, much easier to work with than wrapped list/bytes functionality

This commit is contained in:
protolambda 2019-06-20 20:30:42 +02:00
parent 4e747fb887
commit 977856b06f
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
3 changed files with 58 additions and 45 deletions

View File

@ -1,7 +1,7 @@
from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint,
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, uint,
)
# SSZ Serialization
@ -46,9 +46,8 @@ def serialize(obj: SSZValue):
def encode_series(values: Series):
# bytes and bytesN are already in the right format.
if isinstance(values, (Bytes, BytesN)):
return values.items
if isinstance(values, bytes): # Bytes and BytesN are already like serialized output
return values
# Recursively serialize
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
@ -84,8 +83,8 @@ def encode_series(values: Series):
def pack(values: Series):
if isinstance(values, (Bytes, BytesN)):
return values.items
if isinstance(values, bytes): # Bytes and BytesN are already packed
return values
return b''.join([serialize_basic(value) for value in values])

View File

@ -188,10 +188,10 @@ class Container(Series, metaclass=SSZType):
class ParamsBase(Series):
_bare = True
_has_params = False
def __new__(cls, *args, **kwargs):
if cls._bare:
if not cls._has_params:
raise Exception("cannot init bare type without params")
return super().__new__(cls, **kwargs)
@ -200,13 +200,13 @@ class ParamsMeta(SSZType):
def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs)
if hasattr(out, "_has_params") and getattr(out, "_has_params"):
for k, v in attrs.items():
setattr(out, k, v)
return out
def __getitem__(self, params):
o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
o._bare = False
return o
def __str__(self):
@ -218,7 +218,7 @@ class ParamsMeta(SSZType):
def attr_from_params(self, p):
# single key params are valid too. Wrap them in a tuple.
params = p if isinstance(p, tuple) else (p,)
res = {}
res = {'_has_params': True}
i = 0
for (name, typ) in self.__annotations__.items():
if hasattr(self.__class__, name):
@ -262,13 +262,17 @@ class ElementsType(ParamsMeta):
class Elements(ParamsBase, metaclass=ElementsType):
pass
class BaseList(list, Elements):
def __init__(self, *args):
items = self.extract_args(*args)
if not self.value_check(items):
raise ValueError(f"Bad input for class {self.__class__}: {items}")
self.items = items
super().__init__(items)
@classmethod
def value_check(cls, value):
@ -284,39 +288,32 @@ class Elements(ParamsBase, metaclass=ElementsType):
def __str__(self):
cls = self.__class__
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})"
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self)})"
def __getitem__(self, i) -> SSZValue:
return self.items[i]
if i < 0:
raise IndexError(f"cannot get item in type {self.__class__} at negative index {i}")
if i > len(self):
raise IndexError(f"cannot get item in type {self.__class__}"
f" at out of bounds index {i}")
return super().__getitem__(i)
def __setitem__(self, k, v):
if k < 0:
raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})")
if k > len(self.items):
if k > len(self):
raise IndexError(f"cannot set item in type {self.__class__}"
f" at out of bounds index {k} (to {v}, bound: {len(self.items)})")
self.items[k] = coerce_type_maybe(v, self.__class__.elem_type, strict=True)
f" at out of bounds index {k} (to {v}, bound: {len(self)})")
super().__setitem__(k, coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def append(self, v):
self.items.append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def pop(self):
if len(self.items) == 0:
raise IndexError("Pop from empty list")
else:
return self.items.pop()
def __len__(self):
return len(self.items)
def __repr__(self):
return repr(self.items)
super().append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def __iter__(self) -> Iterator[SSZValue]:
return iter(self.items)
return super().__iter__()
class List(Elements):
class List(BaseList):
@classmethod
def default(cls):
@ -327,7 +324,7 @@ class List(Elements):
return False
class Vector(Elements):
class Vector(BaseList):
@classmethod
def value_check(cls, value):
@ -342,27 +339,35 @@ class Vector(Elements):
def is_fixed_size(cls):
return cls.elem_type.is_fixed_size()
def append(self, v):
raise Exception("cannot modify vector length")
def pop(self, *args):
raise Exception("cannot modify vector length")
class BytesType(ElementsType):
elem_type: SSZType = byte
length: int
class BytesLike(Elements, metaclass=BytesType):
class BaseBytes(bytes, Elements, metaclass=BytesType):
def __new__(cls, *args) -> "BaseBytes":
extracted_val = cls.extract_args(*args)
if not cls.value_check(extracted_val):
raise ValueError(f"Bad input for class {cls}: {extracted_val}")
return super().__new__(cls, extracted_val)
@classmethod
def extract_args(cls, *args):
x = list(args)
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes, BytesLike)):
x = args
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes)):
x = x[0]
if isinstance(x, bytes):
if isinstance(x, bytes): # Includes BytesLike
return x
elif isinstance(x, BytesLike):
return x.items
elif isinstance(x, GeneratorType):
return bytes(x)
else:
return bytes(x)
return bytes(x) # E.g. GeneratorType put into bytes.
@classmethod
def value_check(cls, value):
@ -374,7 +379,7 @@ class BytesLike(Elements, metaclass=BytesType):
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
class Bytes(BytesLike):
class Bytes(BaseBytes):
@classmethod
def default(cls):
@ -385,7 +390,14 @@ class Bytes(BytesLike):
return False
class BytesN(BytesLike):
class BytesN(BaseBytes):
@classmethod
def extract_args(cls, *args):
if len(args) == 0:
return cls.default()
else:
return super().extract_args(*args)
@classmethod
def default(cls):

View File

@ -211,3 +211,5 @@ def test_bytesn_subclass():
assert issubclass(Hash, Bytes32)
assert not issubclass(Bytes48, Bytes32)
assert len(Bytes32() + Bytes48()) == 80