ssz typing now subclasses list/bytes, much easier to work with than wrapped list/bytes functionality
This commit is contained in:
parent
4e747fb887
commit
977856b06f
|
@ -1,7 +1,7 @@
|
||||||
from ..merkle_minimal import merkleize_chunks
|
from ..merkle_minimal import merkleize_chunks
|
||||||
from ..hash_function import hash
|
from ..hash_function import hash
|
||||||
from .ssz_typing import (
|
from .ssz_typing import (
|
||||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, BytesN, uint,
|
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Bytes, uint,
|
||||||
)
|
)
|
||||||
|
|
||||||
# SSZ Serialization
|
# SSZ Serialization
|
||||||
|
@ -46,9 +46,8 @@ def serialize(obj: SSZValue):
|
||||||
|
|
||||||
|
|
||||||
def encode_series(values: Series):
|
def encode_series(values: Series):
|
||||||
# bytes and bytesN are already in the right format.
|
if isinstance(values, bytes): # Bytes and BytesN are already like serialized output
|
||||||
if isinstance(values, (Bytes, BytesN)):
|
return values
|
||||||
return values.items
|
|
||||||
|
|
||||||
# Recursively serialize
|
# Recursively serialize
|
||||||
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
|
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
|
||||||
|
@ -84,8 +83,8 @@ def encode_series(values: Series):
|
||||||
|
|
||||||
|
|
||||||
def pack(values: Series):
|
def pack(values: Series):
|
||||||
if isinstance(values, (Bytes, BytesN)):
|
if isinstance(values, bytes): # Bytes and BytesN are already packed
|
||||||
return values.items
|
return values
|
||||||
return b''.join([serialize_basic(value) for value in values])
|
return b''.join([serialize_basic(value) for value in values])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -188,10 +188,10 @@ class Container(Series, metaclass=SSZType):
|
||||||
|
|
||||||
|
|
||||||
class ParamsBase(Series):
|
class ParamsBase(Series):
|
||||||
_bare = True
|
_has_params = False
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
if cls._bare:
|
if not cls._has_params:
|
||||||
raise Exception("cannot init bare type without params")
|
raise Exception("cannot init bare type without params")
|
||||||
return super().__new__(cls, **kwargs)
|
return super().__new__(cls, **kwargs)
|
||||||
|
|
||||||
|
@ -200,13 +200,13 @@ class ParamsMeta(SSZType):
|
||||||
|
|
||||||
def __new__(cls, class_name, parents, attrs):
|
def __new__(cls, class_name, parents, attrs):
|
||||||
out = type.__new__(cls, class_name, parents, attrs)
|
out = type.__new__(cls, class_name, parents, attrs)
|
||||||
|
if hasattr(out, "_has_params") and getattr(out, "_has_params"):
|
||||||
for k, v in attrs.items():
|
for k, v in attrs.items():
|
||||||
setattr(out, k, v)
|
setattr(out, k, v)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __getitem__(self, params):
|
def __getitem__(self, params):
|
||||||
o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
|
o = self.__class__(self.__name__, (self,), self.attr_from_params(params))
|
||||||
o._bare = False
|
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -218,7 +218,7 @@ class ParamsMeta(SSZType):
|
||||||
def attr_from_params(self, p):
|
def attr_from_params(self, p):
|
||||||
# single key params are valid too. Wrap them in a tuple.
|
# single key params are valid too. Wrap them in a tuple.
|
||||||
params = p if isinstance(p, tuple) else (p,)
|
params = p if isinstance(p, tuple) else (p,)
|
||||||
res = {}
|
res = {'_has_params': True}
|
||||||
i = 0
|
i = 0
|
||||||
for (name, typ) in self.__annotations__.items():
|
for (name, typ) in self.__annotations__.items():
|
||||||
if hasattr(self.__class__, name):
|
if hasattr(self.__class__, name):
|
||||||
|
@ -262,13 +262,17 @@ class ElementsType(ParamsMeta):
|
||||||
|
|
||||||
|
|
||||||
class Elements(ParamsBase, metaclass=ElementsType):
|
class Elements(ParamsBase, metaclass=ElementsType):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseList(list, Elements):
|
||||||
|
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
items = self.extract_args(*args)
|
items = self.extract_args(*args)
|
||||||
|
|
||||||
if not self.value_check(items):
|
if not self.value_check(items):
|
||||||
raise ValueError(f"Bad input for class {self.__class__}: {items}")
|
raise ValueError(f"Bad input for class {self.__class__}: {items}")
|
||||||
self.items = items
|
super().__init__(items)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_check(cls, value):
|
def value_check(cls, value):
|
||||||
|
@ -284,39 +288,32 @@ class Elements(ParamsBase, metaclass=ElementsType):
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self.items)})"
|
return f"{cls.__name__}[{cls.elem_type.__name__}, {cls.length}]({', '.join(str(v) for v in self)})"
|
||||||
|
|
||||||
def __getitem__(self, i) -> SSZValue:
|
def __getitem__(self, i) -> SSZValue:
|
||||||
return self.items[i]
|
if i < 0:
|
||||||
|
raise IndexError(f"cannot get item in type {self.__class__} at negative index {i}")
|
||||||
|
if i > len(self):
|
||||||
|
raise IndexError(f"cannot get item in type {self.__class__}"
|
||||||
|
f" at out of bounds index {i}")
|
||||||
|
return super().__getitem__(i)
|
||||||
|
|
||||||
def __setitem__(self, k, v):
|
def __setitem__(self, k, v):
|
||||||
if k < 0:
|
if k < 0:
|
||||||
raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})")
|
raise IndexError(f"cannot set item in type {self.__class__} at negative index {k} (to {v})")
|
||||||
if k > len(self.items):
|
if k > len(self):
|
||||||
raise IndexError(f"cannot set item in type {self.__class__}"
|
raise IndexError(f"cannot set item in type {self.__class__}"
|
||||||
f" at out of bounds index {k} (to {v}, bound: {len(self.items)})")
|
f" at out of bounds index {k} (to {v}, bound: {len(self)})")
|
||||||
self.items[k] = coerce_type_maybe(v, self.__class__.elem_type, strict=True)
|
super().__setitem__(k, coerce_type_maybe(v, self.__class__.elem_type, strict=True))
|
||||||
|
|
||||||
def append(self, v):
|
def append(self, v):
|
||||||
self.items.append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
|
super().append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
|
||||||
|
|
||||||
def pop(self):
|
|
||||||
if len(self.items) == 0:
|
|
||||||
raise IndexError("Pop from empty list")
|
|
||||||
else:
|
|
||||||
return self.items.pop()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.items)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return repr(self.items)
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[SSZValue]:
|
def __iter__(self) -> Iterator[SSZValue]:
|
||||||
return iter(self.items)
|
return super().__iter__()
|
||||||
|
|
||||||
|
|
||||||
class List(Elements):
|
class List(BaseList):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls):
|
def default(cls):
|
||||||
|
@ -327,7 +324,7 @@ class List(Elements):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class Vector(Elements):
|
class Vector(BaseList):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_check(cls, value):
|
def value_check(cls, value):
|
||||||
|
@ -342,27 +339,35 @@ class Vector(Elements):
|
||||||
def is_fixed_size(cls):
|
def is_fixed_size(cls):
|
||||||
return cls.elem_type.is_fixed_size()
|
return cls.elem_type.is_fixed_size()
|
||||||
|
|
||||||
|
def append(self, v):
|
||||||
|
raise Exception("cannot modify vector length")
|
||||||
|
|
||||||
|
def pop(self, *args):
|
||||||
|
raise Exception("cannot modify vector length")
|
||||||
|
|
||||||
|
|
||||||
class BytesType(ElementsType):
|
class BytesType(ElementsType):
|
||||||
elem_type: SSZType = byte
|
elem_type: SSZType = byte
|
||||||
length: int
|
length: int
|
||||||
|
|
||||||
|
|
||||||
class BytesLike(Elements, metaclass=BytesType):
|
class BaseBytes(bytes, Elements, metaclass=BytesType):
|
||||||
|
|
||||||
|
def __new__(cls, *args) -> "BaseBytes":
|
||||||
|
extracted_val = cls.extract_args(*args)
|
||||||
|
if not cls.value_check(extracted_val):
|
||||||
|
raise ValueError(f"Bad input for class {cls}: {extracted_val}")
|
||||||
|
return super().__new__(cls, extracted_val)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_args(cls, *args):
|
def extract_args(cls, *args):
|
||||||
x = list(args)
|
x = args
|
||||||
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes, BytesLike)):
|
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes)):
|
||||||
x = x[0]
|
x = x[0]
|
||||||
if isinstance(x, bytes):
|
if isinstance(x, bytes): # Includes BytesLike
|
||||||
return x
|
return x
|
||||||
elif isinstance(x, BytesLike):
|
|
||||||
return x.items
|
|
||||||
elif isinstance(x, GeneratorType):
|
|
||||||
return bytes(x)
|
|
||||||
else:
|
else:
|
||||||
return bytes(x)
|
return bytes(x) # E.g. GeneratorType put into bytes.
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_check(cls, value):
|
def value_check(cls, value):
|
||||||
|
@ -374,7 +379,7 @@ class BytesLike(Elements, metaclass=BytesType):
|
||||||
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
|
return f"{cls.__name__}[{cls.length}]: {self.hex()}"
|
||||||
|
|
||||||
|
|
||||||
class Bytes(BytesLike):
|
class Bytes(BaseBytes):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls):
|
def default(cls):
|
||||||
|
@ -385,7 +390,14 @@ class Bytes(BytesLike):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class BytesN(BytesLike):
|
class BytesN(BaseBytes):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract_args(cls, *args):
|
||||||
|
if len(args) == 0:
|
||||||
|
return cls.default()
|
||||||
|
else:
|
||||||
|
return super().extract_args(*args)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls):
|
def default(cls):
|
||||||
|
|
|
@ -211,3 +211,5 @@ def test_bytesn_subclass():
|
||||||
assert issubclass(Hash, Bytes32)
|
assert issubclass(Hash, Bytes32)
|
||||||
|
|
||||||
assert not issubclass(Bytes48, Bytes32)
|
assert not issubclass(Bytes48, Bytes32)
|
||||||
|
|
||||||
|
assert len(Bytes32() + Bytes48()) == 80
|
||||||
|
|
Loading…
Reference in New Issue