ssz typing now subclasses list/bytes, much easier to work with than wrapped list/bytes functionality
This commit is contained in:
parent
4e747fb887
commit
977856b06f
|
@ -1,7 +1,7 @@
|
|||
from ..merkle_minimal import merkleize_chunks
|
||||
from ..hash_function import hash
|
||||
from .ssz_typing import (
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint,
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, uint,
|
||||
)
|
||||
|
||||
# SSZ Serialization
|
||||
|
@ -46,9 +46,8 @@ def serialize(obj: SSZValue):
|
|||
|
||||
|
||||
def encode_series(values: Series):
|
||||
# bytes and bytesN are already in the right format.
|
||||
if isinstance(values, (Bytes, BytesN)):
|
||||
return values.items
|
||||
if isinstance(values, bytes): # Bytes and BytesN are already like serialized output
|
||||
return values
|
||||
|
||||
# Recursively serialize
|
||||
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
|
||||
|
@ -84,8 +83,8 @@ def encode_series(values: Series):
|
|||
|
||||
|
||||
def pack(values: Series):
|
||||
if isinstance(values, (Bytes, BytesN)):
|
||||
return values.items
|
||||
if isinstance(values, bytes): # Bytes and BytesN are already packed
|
||||
return values
|
||||
return b''.join([serialize_basic(value) for value in values])
|
||||
|
||||
|
||||
|
|
|
@ -188,10 +188,10 @@ class Container(Series, metaclass=SSZType):
|
|||
|
||||
|
||||
class ParamsBase(Series):
|
||||
_bare = True
|
||||
_has_params = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._bare:
|
||||
if not cls._has_params:
|
||||
raise Exception("cannot init bare type without params")
|
||||
return super().__new__(cls, **kwargs)
|
||||
|
||||
|
@ -200,13 +200,13 @@ class ParamsMeta(SSZType):
|
|||
|
||||
def __new__(cls, class_name, parents, attrs):
|
||||
out = type.__new__(cls, class_name, parents, attrs)
|
||||
for k, v in attrs.items():
|
||||
setattr(out, k, v)
|
||||
if hasattr(out, "_has_params") and getattr(out, "_has_params"):
|
||||
for k, v in attrs.items():
|
||||
setattr(out, k, v)
|
||||
return out
|
||||
|
||||
def __getitem__(self, params):
|
||||
o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
|
||||
o._bare = False
|
||||
return o
|
||||
|
||||
def __str__(self):
|
||||
|
@ -218,7 +218,7 @@ class ParamsMeta(SSZType):
|
|||
def attr_from_params(self, p):
|
||||
# single key params are valid too. Wrap them in a tuple.
|
||||
params = p if isinstance(p, tuple) else (p,)
|
||||
res = {}
|
||||
res = {'_has_params': True}
|
||||
i = 0
|
||||
for (name, typ) in self.__annotations__.items():
|
||||
if hasattr(self.__class__, name):
|
||||
|
@ -262,13 +262,17 @@ class ElementsType(ParamsMeta):
|
|||
|
||||
|
||||
class Elements(ParamsBase, metaclass=ElementsType):
|
||||
pass
|
||||
|
||||
|
||||
class BaseList(list, Elements):
|
||||
|
||||
def __init__(self, *args):
|
||||
items = self.extract_args(*args)
|
||||
|
||||
if not self.value_check(items):
|
||||
raise ValueError(f"Bad input for class {self.__class__}: {items}")
|
||||
self.items = items
|
||||
super().__init__(items)
|
||||
|
||||
@classmethod
|
||||
def value_check(cls, value):
|
||||
|
@ -284,39 +288,32 @@ class Elements(ParamsBase, metaclass=ElementsType):
|
|||
|
||||
def __str__(self):
|
||||
cls = self.__class__
|
||||
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})"
|
||||
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self)})"
|
||||
|
||||
def __getitem__(self, i) -> SSZValue:
|
||||
return self.items[i]
|
||||
if i < 0:
|
||||
raise IndexError(f"cannot get item in type {self.__class__} at negative index {i}")
|
||||
if i > len(self):
|
||||
raise IndexError(f"cannot get item in type {self.__class__}"
|
||||
f" at out of bounds index {i}")
|
||||
return super().__getitem__(i)
|
||||
|
||||
def __setitem__(self, k, v):
|
||||
if k < 0:
|
||||
raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})")
|
||||
if k > len(self.items):
|
||||
if k > len(self):
|
||||
raise IndexError(f"cannot set item in type {self.__class__}"
|
||||
f" at out of bounds index {k} (to {v}, bound: {len(self.items)})")
|
||||
self.items[k] = coerce_type_maybe(v, self.__class__.elem_type, strict=True)
|
||||
f" at out of bounds index {k} (to {v}, bound: {len(self)})")
|
||||
super().__setitem__(k, coerce_type_maybe(v, self.__class__.elem_type, strict=True))
|
||||
|
||||
def append(self, v):
|
||||
self.items.append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
|
||||
|
||||
def pop(self):
|
||||
if len(self.items) == 0:
|
||||
raise IndexError("Pop from empty list")
|
||||
else:
|
||||
return self.items.pop()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.items)
|
||||
super().append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
|
||||
|
||||
def __iter__(self) -> Iterator[SSZValue]:
|
||||
return iter(self.items)
|
||||
return super().__iter__()
|
||||
|
||||
|
||||
class List(Elements):
|
||||
class List(BaseList):
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
|
@ -327,7 +324,7 @@ class List(Elements):
|
|||
return False
|
||||
|
||||
|
||||
class Vector(Elements):
|
||||
class Vector(BaseList):
|
||||
|
||||
@classmethod
|
||||
def value_check(cls, value):
|
||||
|
@ -342,27 +339,35 @@ class Vector(Elements):
|
|||
def is_fixed_size(cls):
|
||||
return cls.elem_type.is_fixed_size()
|
||||
|
||||
def append(self, v):
|
||||
raise Exception("cannot modify vector length")
|
||||
|
||||
def pop(self, *args):
|
||||
raise Exception("cannot modify vector length")
|
||||
|
||||
|
||||
class BytesType(ElementsType):
|
||||
elem_type: SSZType = byte
|
||||
length: int
|
||||
|
||||
|
||||
class BytesLike(Elements, metaclass=BytesType):
|
||||
class BaseBytes(bytes, Elements, metaclass=BytesType):
|
||||
|
||||
def __new__(cls, *args) -> "BaseBytes":
|
||||
extracted_val = cls.extract_args(*args)
|
||||
if not cls.value_check(extracted_val):
|
||||
raise ValueError(f"Bad input for class {cls}: {extracted_val}")
|
||||
return super().__new__(cls, extracted_val)
|
||||
|
||||
@classmethod
|
||||
def extract_args(cls, *args):
|
||||
x = list(args)
|
||||
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes, BytesLike)):
|
||||
x = args
|
||||
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes)):
|
||||
x = x[0]
|
||||
if isinstance(x, bytes):
|
||||
if isinstance(x, bytes): # Includes BytesLike
|
||||
return x
|
||||
elif isinstance(x, BytesLike):
|
||||
return x.items
|
||||
elif isinstance(x, GeneratorType):
|
||||
return bytes(x)
|
||||
else:
|
||||
return bytes(x)
|
||||
return bytes(x) # E.g. GeneratorType put into bytes.
|
||||
|
||||
@classmethod
|
||||
def value_check(cls, value):
|
||||
|
@ -374,7 +379,7 @@ class BytesLike(Elements, metaclass=BytesType):
|
|||
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
|
||||
|
||||
|
||||
class Bytes(BytesLike):
|
||||
class Bytes(BaseBytes):
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
|
@ -385,7 +390,14 @@ class Bytes(BytesLike):
|
|||
return False
|
||||
|
||||
|
||||
class BytesN(BytesLike):
|
||||
class BytesN(BaseBytes):
|
||||
|
||||
@classmethod
|
||||
def extract_args(cls, *args):
|
||||
if len(args) == 0:
|
||||
return cls.default()
|
||||
else:
|
||||
return super().extract_args(*args)
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
|
|
|
@ -211,3 +211,5 @@ def test_bytesn_subclass():
|
|||
assert issubclass(Hash, Bytes32)
|
||||
|
||||
assert not issubclass(Bytes48, Bytes32)
|
||||
|
||||
assert len(Bytes32() + Bytes48()) == 80
|
||||
|
|
Loading…
Reference in New Issue