mirror of
https://github.com/status-im/eth2.0-specs.git
synced 2025-01-12 19:54:34 +00:00
Simplified SSZ impl
This commit is contained in:
parent
29dbe1b880
commit
5ddfe34f0c
@ -1,9 +1,8 @@
|
||||
from ..merkle_minimal import merkleize_chunks, hash
|
||||
from eth2spec.utils.ssz.ssz_typing import (
|
||||
from ..merkle_minimal import merkleize_chunks, hash, ZERO_BYTES32
|
||||
from .ssz_typing import (
|
||||
is_uint_type, is_bool_type, is_container_type,
|
||||
is_list_kind, is_vector_kind,
|
||||
read_vector_elem_type, read_elem_type,
|
||||
uint_byte_size,
|
||||
read_elem_type,
|
||||
infer_input_type,
|
||||
get_zero_value,
|
||||
)
|
||||
@ -20,7 +19,7 @@ def is_basic_type(typ):
|
||||
|
||||
def serialize_basic(value, typ):
|
||||
if is_uint_type(typ):
|
||||
return value.to_bytes(uint_byte_size(typ), 'little')
|
||||
return value.to_bytes(typ.byte_len, 'little')
|
||||
elif is_bool_type(typ):
|
||||
if value:
|
||||
return b'\x01'
|
||||
@ -140,6 +139,8 @@ def get_typed_values(obj, typ=None):
|
||||
else:
|
||||
raise Exception("Invalid type")
|
||||
|
||||
def item_length(typ):
|
||||
return 1 if typ == bool else typ.byte_len if is_uint_type(typ) else 32
|
||||
|
||||
@infer_input_type
|
||||
def hash_tree_root(obj, typ=None):
|
||||
@ -150,6 +151,8 @@ def hash_tree_root(obj, typ=None):
|
||||
fields = get_typed_values(obj, typ=typ)
|
||||
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
|
||||
if is_list_kind(typ):
|
||||
full_chunk_length = (item_length(read_elem_type(typ)) * typ.length + 31) // 32
|
||||
leaves += [ZERO_BYTES32] * (full_chunk_length - len(obj))
|
||||
return mix_in_length(merkleize_chunks(leaves), len(obj))
|
||||
else:
|
||||
return merkleize_chunks(leaves)
|
||||
|
@ -6,74 +6,38 @@ from typing_inspect import get_origin
|
||||
# SSZ integers
|
||||
# -----------------------------
|
||||
|
||||
|
||||
class uint(int):
|
||||
byte_len = 0
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value < 0:
|
||||
raise ValueError("unsigned types must not be negative")
|
||||
if value.byte_len and value.bit_length() > value.byte_len:
|
||||
raise ValueError("value out of bounds for uint{}".format(value.byte_len))
|
||||
return super().__new__(cls, value)
|
||||
|
||||
|
||||
class uint8(uint):
|
||||
byte_len = 1
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value.bit_length() > 8:
|
||||
raise ValueError("value out of bounds for uint8")
|
||||
return super().__new__(cls, value)
|
||||
|
||||
|
||||
# Alias for uint8
|
||||
byte = NewType('byte', uint8)
|
||||
|
||||
|
||||
class uint16(uint):
|
||||
byte_len = 2
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value.bit_length() > 16:
|
||||
raise ValueError("value out of bounds for uint16")
|
||||
return super().__new__(cls, value)
|
||||
|
||||
|
||||
class uint32(uint):
|
||||
byte_len = 4
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value.bit_length() > 32:
|
||||
raise ValueError("value out of bounds for uint16")
|
||||
return super().__new__(cls, value)
|
||||
|
||||
|
||||
class uint64(uint):
|
||||
byte_len = 8
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value.bit_length() > 64:
|
||||
raise ValueError("value out of bounds for uint64")
|
||||
return super().__new__(cls, value)
|
||||
|
||||
|
||||
class uint128(uint):
|
||||
byte_len = 16
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value.bit_length() > 128:
|
||||
raise ValueError("value out of bounds for uint128")
|
||||
return super().__new__(cls, value)
|
||||
|
||||
|
||||
class uint256(uint):
|
||||
byte_len = 32
|
||||
|
||||
def __new__(cls, value, *args, **kwargs):
|
||||
if value.bit_length() > 256:
|
||||
raise ValueError("value out of bounds for uint256")
|
||||
return super().__new__(cls, value)
|
||||
|
||||
|
||||
def is_uint_type(typ):
|
||||
# All integers are uint in the scope of the spec here.
|
||||
# Since we default to uint64. Bounds can be checked elsewhere.
|
||||
@ -84,21 +48,6 @@ def is_uint_type(typ):
|
||||
|
||||
return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool)
|
||||
|
||||
|
||||
def uint_byte_size(typ):
|
||||
if hasattr(typ, '__supertype__'):
|
||||
typ = typ.__supertype__
|
||||
|
||||
if isinstance(typ, type):
|
||||
if issubclass(typ, uint):
|
||||
return typ.byte_len
|
||||
elif issubclass(typ, int):
|
||||
# Default to uint64
|
||||
return 8
|
||||
else:
|
||||
raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ)
|
||||
|
||||
|
||||
# SSZ Container base class
|
||||
# -----------------------------
|
||||
|
||||
@ -166,45 +115,19 @@ class Container(object):
|
||||
return list(cls.__annotations__.values())
|
||||
|
||||
|
||||
# SSZ vector
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _is_vector_instance_of(a, b):
|
||||
# Other must not be a BytesN
|
||||
if issubclass(b, bytes):
|
||||
return False
|
||||
elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
|
||||
# Vector (b) is not an instance of Vector[X, Y] (a)
|
||||
return False
|
||||
elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
|
||||
# Vector[X, Y] (b) is an instance of Vector (a)
|
||||
return True
|
||||
def get_zero_value(typ):
|
||||
if typ == int:
|
||||
return 0
|
||||
else:
|
||||
# Vector[X, Y] (a) is an instance of Vector[X, Y] (b)
|
||||
return a.elem_type == b.elem_type and a.length == b.length
|
||||
return typ.default()
|
||||
|
||||
|
||||
def _is_equal_vector_type(a, b):
|
||||
# Other must not be a BytesN
|
||||
if issubclass(b, bytes):
|
||||
return False
|
||||
elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
|
||||
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
|
||||
# Vector == Vector
|
||||
return True
|
||||
else:
|
||||
# Vector != Vector[X, Y]
|
||||
return False
|
||||
elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
|
||||
# Vector[X, Y] != Vector
|
||||
return False
|
||||
def type_check(typ, value):
|
||||
if typ == int or typ == uint64:
|
||||
return isinstance(value, int)
|
||||
else:
|
||||
# Vector[X, Y] == Vector[X, Y]
|
||||
return a.elem_type == b.elem_type and a.length == b.length
|
||||
return typ.value_check(value)
|
||||
|
||||
|
||||
class VectorMeta(type):
|
||||
class AbstractListMeta(type):
|
||||
def __new__(cls, class_name, parents, attrs):
|
||||
out = type.__new__(cls, class_name, parents, attrs)
|
||||
if 'elem_type' in attrs and 'length' in attrs:
|
||||
@ -214,239 +137,115 @@ class VectorMeta(type):
|
||||
|
||||
def __getitem__(self, params):
|
||||
if not isinstance(params, tuple) or len(params) != 2:
|
||||
raise Exception("Vector must be instantiated with two args: elem type and length")
|
||||
o = self.__class__(self.__name__, (Vector,), {'elem_type': params[0], 'length': params[1]})
|
||||
o._name = 'Vector'
|
||||
raise Exception("List must be instantiated with two args: elem type and length")
|
||||
o = self.__class__(self.__name__, (self,), {'elem_type': params[0], 'length': params[1]})
|
||||
o._name = 'AbstractList'
|
||||
return o
|
||||
|
||||
def __subclasscheck__(self, sub):
|
||||
return _is_vector_instance_of(self, sub)
|
||||
def __instancecheck__(self, obj):
|
||||
if obj.__class__.__name__ != self.__name__:
|
||||
return False
|
||||
if hasattr(self, 'elem_type') and obj.__class__.elem_type != self.elem_type:
|
||||
return False
|
||||
if hasattr(self, 'length') and obj.__class__.length != self.length:
|
||||
return False
|
||||
return True
|
||||
|
||||
def __instancecheck__(self, other):
|
||||
return _is_vector_instance_of(self, other.__class__)
|
||||
class ValueCheckError(Exception):
|
||||
pass
|
||||
|
||||
def __eq__(self, other):
|
||||
return _is_equal_vector_type(self, other)
|
||||
class AbstractList(metaclass=AbstractListMeta):
|
||||
def __init__(self, *args):
|
||||
items = self.extract_args(args)
|
||||
|
||||
if not self.value_check(items):
|
||||
raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items))
|
||||
self.items = items
|
||||
|
||||
def value_check(self, value):
|
||||
for v in value:
|
||||
if not type_check(self.__class__.elem_type, v):
|
||||
return False
|
||||
return True
|
||||
|
||||
def __ne__(self, other):
|
||||
return not _is_equal_vector_type(self, other)
|
||||
def extract_args(self, args):
|
||||
return list(args) if len(args) > 0 else self.default()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.__class__)
|
||||
def default(self):
|
||||
raise Exception("Not implemented")
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.items[i]
|
||||
|
||||
class Vector(metaclass=VectorMeta):
|
||||
|
||||
def __init__(self, *args: Iterable):
|
||||
cls = self.__class__
|
||||
if not hasattr(cls, 'elem_type'):
|
||||
raise TypeError("Type Vector without elem_type data cannot be instantiated")
|
||||
elif not hasattr(cls, 'length'):
|
||||
raise TypeError("Type Vector without length data cannot be instantiated")
|
||||
|
||||
if len(args) != cls.length:
|
||||
if len(args) == 0:
|
||||
args = [get_zero_value(cls.elem_type) for _ in range(cls.length)]
|
||||
else:
|
||||
raise TypeError("Typed vector with length %d cannot hold %d items" % (cls.length, len(args)))
|
||||
|
||||
self.items = list(args)
|
||||
|
||||
# cannot check non-type objects, or parametrized types
|
||||
if isinstance(cls.elem_type, type) and not hasattr(cls.elem_type, '__args__'):
|
||||
for i, item in enumerate(self.items):
|
||||
if not issubclass(cls.elem_type, type(item)):
|
||||
raise TypeError("Typed vector cannot hold differently typed value"
|
||||
" at index %d. Got type: %s, expected type: %s" % (i, type(item), cls.elem_type))
|
||||
|
||||
def serialize(self):
|
||||
from .ssz_impl import serialize
|
||||
return serialize(self, self.__class__)
|
||||
|
||||
def hash_tree_root(self):
|
||||
from .ssz_impl import hash_tree_root
|
||||
return hash_tree_root(self, self.__class__)
|
||||
|
||||
def __repr__(self):
|
||||
return repr({'length': self.__class__.length, 'items': self.items})
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.items[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.items[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items)
|
||||
def __setitem__(self, k, v):
|
||||
self.items[k] = v
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.hash_tree_root() == other.hash_tree_root()
|
||||
def __repr__(self):
|
||||
return repr(self.items)
|
||||
|
||||
|
||||
# SSZ BytesN
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _is_bytes_n_instance_of(a, b):
|
||||
# Other has to be a Bytes derivative class to be a BytesN
|
||||
if not issubclass(b, bytes):
|
||||
return False
|
||||
elif not hasattr(b, 'length'):
|
||||
# BytesN (b) is not an instance of BytesN[X] (a)
|
||||
return False
|
||||
elif not hasattr(a, 'length'):
|
||||
# BytesN[X] (b) is an instance of BytesN (a)
|
||||
return True
|
||||
else:
|
||||
# BytesN[X] (a) is an instance of BytesN[X] (b)
|
||||
return a.length == b.length
|
||||
|
||||
|
||||
def _is_equal_bytes_n_type(a, b):
|
||||
# Other has to be a Bytes derivative class to be a BytesN
|
||||
if not issubclass(b, bytes):
|
||||
return False
|
||||
elif not hasattr(a, 'length'):
|
||||
if not hasattr(b, 'length'):
|
||||
# BytesN == BytesN
|
||||
return True
|
||||
else:
|
||||
# BytesN != BytesN[X]
|
||||
return False
|
||||
elif not hasattr(b, 'length'):
|
||||
# BytesN[X] != BytesN
|
||||
return False
|
||||
else:
|
||||
# BytesN[X] == BytesN[X]
|
||||
return a.length == b.length
|
||||
|
||||
|
||||
class BytesNMeta(type):
|
||||
def __new__(cls, class_name, parents, attrs):
|
||||
out = type.__new__(cls, class_name, parents, attrs)
|
||||
if 'length' in attrs:
|
||||
setattr(out, 'length', attrs['length'])
|
||||
out._name = 'BytesN'
|
||||
out.elem_type = byte
|
||||
return out
|
||||
|
||||
def __getitem__(self, n):
|
||||
return self.__class__(self.__name__, (BytesN,), {'length': n})
|
||||
|
||||
def __subclasscheck__(self, sub):
|
||||
return _is_bytes_n_instance_of(self, sub)
|
||||
|
||||
def __instancecheck__(self, other):
|
||||
return _is_bytes_n_instance_of(self, other.__class__)
|
||||
def __iter__(self):
|
||||
return iter(self.items)
|
||||
|
||||
def __eq__(self, other):
|
||||
return _is_equal_bytes_n_type(self, other)
|
||||
return self.items == other.items
|
||||
|
||||
def __ne__(self, other):
|
||||
return not _is_equal_bytes_n_type(self, other)
|
||||
class List(AbstractList, metaclass=AbstractListMeta):
|
||||
def value_check(self, value):
|
||||
return len(value) <= self.__class__.length and super().value_check(value)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.__class__)
|
||||
|
||||
|
||||
def parse_bytes(val):
|
||||
if val is None:
|
||||
return None
|
||||
elif isinstance(val, str):
|
||||
# TODO: import from eth-utils instead, and do: hexstr_if_str(to_bytes, val)
|
||||
return None
|
||||
elif isinstance(val, bytes):
|
||||
return val
|
||||
elif isinstance(val, int):
|
||||
return bytes([val])
|
||||
elif isinstance(val, (list, GeneratorType)):
|
||||
return bytes(val)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class BytesN(bytes, metaclass=BytesNMeta):
|
||||
def __new__(cls, *args):
|
||||
if not hasattr(cls, 'length'):
|
||||
return
|
||||
bytesval = None
|
||||
if len(args) == 1:
|
||||
val: Union[bytes, int, str] = args[0]
|
||||
bytesval = parse_bytes(val)
|
||||
elif len(args) > 1:
|
||||
# TODO: each int is 1 byte, check size, create bytesval
|
||||
bytesval = bytes(args)
|
||||
|
||||
if bytesval is None:
|
||||
if cls.length == 0:
|
||||
bytesval = b''
|
||||
else:
|
||||
bytesval = b'\x00' * cls.length
|
||||
if len(bytesval) != cls.length:
|
||||
raise TypeError("BytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval)))
|
||||
return super().__new__(cls, bytesval)
|
||||
|
||||
def serialize(self):
|
||||
from .ssz_impl import serialize
|
||||
return serialize(self, self.__class__)
|
||||
|
||||
def hash_tree_root(self):
|
||||
from .ssz_impl import hash_tree_root
|
||||
return hash_tree_root(self, self.__class__)
|
||||
|
||||
|
||||
class Bytes4(BytesN):
|
||||
length = 4
|
||||
|
||||
|
||||
class Bytes32(BytesN):
|
||||
length = 32
|
||||
|
||||
|
||||
class Bytes48(BytesN):
|
||||
length = 48
|
||||
|
||||
|
||||
class Bytes96(BytesN):
|
||||
length = 96
|
||||
|
||||
|
||||
# SSZ Defaults
|
||||
# -----------------------------
|
||||
def get_zero_value(typ):
|
||||
if is_uint_type(typ):
|
||||
return uint64(0)
|
||||
elif is_list_type(typ):
|
||||
def default(self):
|
||||
return []
|
||||
elif is_bool_type(typ):
|
||||
return False
|
||||
elif is_vector_type(typ):
|
||||
return typ()
|
||||
elif is_bytesn_type(typ):
|
||||
return typ()
|
||||
elif is_bytes_type(typ):
|
||||
|
||||
class Vector(AbstractList, metaclass=AbstractListMeta):
|
||||
def value_check(self, value):
|
||||
return len(value) == self.__class__.length and super().value_check(value)
|
||||
|
||||
def default(self):
|
||||
return [get_zero_value(self.__class__.elem_type) for _ in range(self.__class__.length)]
|
||||
|
||||
class BytesMeta(AbstractListMeta):
|
||||
def __getitem__(self, params):
|
||||
if not isinstance(params, int):
|
||||
raise Exception("Bytes must be instantiated with one arg: length")
|
||||
o = self.__class__(self.__name__, (self,), {'length': params})
|
||||
o._name = 'Bytes'
|
||||
return o
|
||||
|
||||
def single_item_extractor(cls, args):
|
||||
assert len(args) < 2
|
||||
return args[0] if len(args) > 0 else cls.default()
|
||||
|
||||
class Bytes(AbstractList, metaclass=BytesMeta):
|
||||
def value_check(self, value):
|
||||
return len(value) <= self.__class__.length and isinstance(value, bytes)
|
||||
|
||||
extract_args = single_item_extractor
|
||||
|
||||
def default(self):
|
||||
return b''
|
||||
elif is_container_type(typ):
|
||||
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()})
|
||||
else:
|
||||
raise Exception("Type not supported: {}".format(typ))
|
||||
|
||||
class BytesN(AbstractList, metaclass=BytesMeta):
|
||||
def value_check(self, value):
|
||||
return len(value) == self.__class__.length and isinstance(value, bytes)
|
||||
|
||||
extract_args = single_item_extractor
|
||||
|
||||
def default(self):
|
||||
return b'\x00' * self.__class__.length
|
||||
|
||||
|
||||
# Type helpers
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def infer_type(obj):
|
||||
if is_uint_type(obj.__class__):
|
||||
return obj.__class__
|
||||
elif isinstance(obj, int):
|
||||
return uint64
|
||||
elif isinstance(obj, list):
|
||||
return List[infer_type(obj[0])]
|
||||
elif isinstance(obj, (Vector, Container, bool, BytesN, bytes)):
|
||||
elif isinstance(obj, (List, Vector, Container, bool, BytesN, Bytes)):
|
||||
return obj.__class__
|
||||
else:
|
||||
raise Exception("Unknown type for {}".format(obj))
|
||||
@ -476,15 +275,14 @@ def is_list_type(typ):
|
||||
"""
|
||||
Check if the given type is a list.
|
||||
"""
|
||||
return get_origin(typ) is List or get_origin(typ) is list
|
||||
return isinstance(typ, type) and issubclass(typ, List)
|
||||
|
||||
|
||||
def is_bytes_type(typ):
|
||||
"""
|
||||
Check if the given type is a ``bytes``.
|
||||
"""
|
||||
# Do not accept subclasses of bytes here, to avoid confusion with BytesN
|
||||
return typ == bytes
|
||||
return isinstance(typ, type) and issubclass(typ, Bytes)
|
||||
|
||||
|
||||
def is_bytesn_type(typ):
|
||||
@ -526,22 +324,5 @@ T = TypeVar('T')
|
||||
L = TypeVar('L')
|
||||
|
||||
|
||||
def read_list_elem_type(list_typ: Type[List[T]]) -> T:
|
||||
if list_typ.__args__ is None or len(list_typ.__args__) != 1:
|
||||
raise TypeError("Supplied list-type is invalid, no element type found.")
|
||||
return list_typ.__args__[0]
|
||||
|
||||
|
||||
def read_vector_elem_type(vector_typ: Type[Vector[T, L]]) -> T:
|
||||
return vector_typ.elem_type
|
||||
|
||||
|
||||
def read_elem_type(typ):
|
||||
if typ == bytes or (isinstance(typ, type) and issubclass(typ, bytes)): # bytes or bytesN
|
||||
return byte
|
||||
elif is_list_type(typ):
|
||||
return read_list_elem_type(typ)
|
||||
elif is_vector_type(typ):
|
||||
return read_vector_elem_type(typ)
|
||||
else:
|
||||
raise TypeError("Unexpected type: {}".format(typ))
|
||||
return typ.elem_type
|
||||
|
Loading…
x
Reference in New Issue
Block a user