From 5ddfe34f0c3363589b402e88abde2c0aa1b05225 Mon Sep 17 00:00:00 2001 From: protolambda Date: Thu, 20 Jun 2019 19:51:38 +0200 Subject: [PATCH] Simplified SSZ impl --- .../pyspec/eth2spec/utils/ssz/ssz_impl.py | 13 +- .../pyspec/eth2spec/utils/ssz/ssz_typing.py | 407 ++++-------------- 2 files changed, 102 insertions(+), 318 deletions(-) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index b3c877d48..c88cfed1f 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -1,9 +1,8 @@ -from ..merkle_minimal import merkleize_chunks, hash -from eth2spec.utils.ssz.ssz_typing import ( +from ..merkle_minimal import merkleize_chunks, hash, ZERO_BYTES32 +from .ssz_typing import ( is_uint_type, is_bool_type, is_container_type, is_list_kind, is_vector_kind, - read_vector_elem_type, read_elem_type, - uint_byte_size, + read_elem_type, infer_input_type, get_zero_value, ) @@ -20,7 +19,7 @@ def is_basic_type(typ): def serialize_basic(value, typ): if is_uint_type(typ): - return value.to_bytes(uint_byte_size(typ), 'little') + return value.to_bytes(typ.byte_len, 'little') elif is_bool_type(typ): if value: return b'\x01' @@ -140,6 +139,8 @@ def get_typed_values(obj, typ=None): else: raise Exception("Invalid type") +def item_length(typ): + return 1 if typ == bool else typ.byte_len if is_uint_type(typ) else 32 @infer_input_type def hash_tree_root(obj, typ=None): @@ -150,6 +151,8 @@ def hash_tree_root(obj, typ=None): fields = get_typed_values(obj, typ=typ) leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields] if is_list_kind(typ): + full_chunk_length = (item_length(read_elem_type(typ)) * typ.length + 31) // 32 + leaves += [ZERO_BYTES32] * (full_chunk_length - len(obj)) return mix_in_length(merkleize_chunks(leaves), len(obj)) else: return merkleize_chunks(leaves) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index f870336e8..dbc3f9523 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -6,74 +6,38 @@ from typing_inspect import get_origin # SSZ integers # ----------------------------- - class uint(int): byte_len = 0 def __new__(cls, value, *args, **kwargs): if value < 0: raise ValueError("unsigned types must not be negative") + if value.byte_len and value.bit_length() > value.byte_len: + raise ValueError("value out of bounds for uint{}".format(value.byte_len)) return super().__new__(cls, value) class uint8(uint): byte_len = 1 - def __new__(cls, value, *args, **kwargs): - if value.bit_length() > 8: - raise ValueError("value out of bounds for uint8") - return super().__new__(cls, value) - - # Alias for uint8 byte = NewType('byte', uint8) - class uint16(uint): byte_len = 2 - def __new__(cls, value, *args, **kwargs): - if value.bit_length() > 16: - raise ValueError("value out of bounds for uint16") - return super().__new__(cls, value) - - class uint32(uint): byte_len = 4 - def __new__(cls, value, *args, **kwargs): - if value.bit_length() > 32: - raise ValueError("value out of bounds for uint16") - return super().__new__(cls, value) - - class uint64(uint): byte_len = 8 - def __new__(cls, value, *args, **kwargs): - if value.bit_length() > 64: - raise ValueError("value out of bounds for uint64") - return super().__new__(cls, value) - - class uint128(uint): byte_len = 16 - def __new__(cls, value, *args, **kwargs): - if value.bit_length() > 128: - raise ValueError("value out of bounds for uint128") - return super().__new__(cls, value) - - class uint256(uint): byte_len = 32 - def __new__(cls, value, *args, **kwargs): - if value.bit_length() > 256: - raise ValueError("value out of bounds for uint256") - return super().__new__(cls, value) - - def is_uint_type(typ): # All integers are uint in the scope of the spec here. # Since we default to uint64. Bounds can be checked elsewhere. @@ -84,21 +48,6 @@ def is_uint_type(typ): return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool) - -def uint_byte_size(typ): - if hasattr(typ, '__supertype__'): - typ = typ.__supertype__ - - if isinstance(typ, type): - if issubclass(typ, uint): - return typ.byte_len - elif issubclass(typ, int): - # Default to uint64 - return 8 - else: - raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ) - - # SSZ Container base class # ----------------------------- @@ -166,45 +115,19 @@ class Container(object): return list(cls.__annotations__.values()) -# SSZ vector -# ----------------------------- - - -def _is_vector_instance_of(a, b): - # Other must not be a BytesN - if issubclass(b, bytes): - return False - elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'): - # Vector (b) is not an instance of Vector[X, Y] (a) - return False - elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'): - # Vector[X, Y] (b) is an instance of Vector (a) - return True +def get_zero_value(typ): + if typ == int: + return 0 else: - # Vector[X, Y] (a) is an instance of Vector[X, Y] (b) - return a.elem_type == b.elem_type and a.length == b.length + return typ.default() - -def _is_equal_vector_type(a, b): - # Other must not be a BytesN - if issubclass(b, bytes): - return False - elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'): - if not hasattr(b, 'elem_type') or not hasattr(b, 'length'): - # Vector == Vector - return True - else: - # Vector != Vector[X, Y] - return False - elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'): - # Vector[X, Y] != Vector - return False +def type_check(typ, value): + if typ == int or typ == uint64: + return isinstance(value, int) else: - # Vector[X, Y] == Vector[X, Y] - return a.elem_type == b.elem_type and a.length == b.length + return typ.value_check(value) - -class VectorMeta(type): +class AbstractListMeta(type): def __new__(cls, class_name, parents, attrs): out = type.__new__(cls, class_name, parents, attrs) if 'elem_type' in attrs and 'length' in attrs: @@ -214,239 +137,115 @@ class VectorMeta(type): def __getitem__(self, params): if not isinstance(params, tuple) or len(params) != 2: - raise Exception("Vector must be instantiated with two args: elem type and length") - o = self.__class__(self.__name__, (Vector,), {'elem_type': params[0], 'length': params[1]}) - o._name = 'Vector' + raise Exception("List must be instantiated with two args: elem type and length") + o = self.__class__(self.__name__, (self,), {'elem_type': params[0], 'length': params[1]}) + o._name = 'AbstractList' return o - def __subclasscheck__(self, sub): - return _is_vector_instance_of(self, sub) + def __instancecheck__(self, obj): + if obj.__class__.__name__ != self.__name__: + return False + if hasattr(self, 'elem_type') and obj.__class__.elem_type != self.elem_type: + return False + if hasattr(self, 'length') and obj.__class__.length != self.length: + return False + return True - def __instancecheck__(self, other): - return _is_vector_instance_of(self, other.__class__) +class ValueCheckError(Exception): + pass - def __eq__(self, other): - return _is_equal_vector_type(self, other) +class AbstractList(metaclass=AbstractListMeta): + def __init__(self, *args): + items = self.extract_args(args) + + if not self.value_check(items): + raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items)) + self.items = items + + def value_check(self, value): + for v in value: + if not type_check(self.__class__.elem_type, v): + return False + return True - def __ne__(self, other): - return not _is_equal_vector_type(self, other) + def extract_args(self, args): + return list(args) if len(args) > 0 else self.default() - def __hash__(self): - return hash(self.__class__) + def default(self): + raise Exception("Not implemented") + def __getitem__(self, i): + return self.items[i] -class Vector(metaclass=VectorMeta): - - def __init__(self, *args: Iterable): - cls = self.__class__ - if not hasattr(cls, 'elem_type'): - raise TypeError("Type Vector without elem_type data cannot be instantiated") - elif not hasattr(cls, 'length'): - raise TypeError("Type Vector without length data cannot be instantiated") - - if len(args) != cls.length: - if len(args) == 0: - args = [get_zero_value(cls.elem_type) for _ in range(cls.length)] - else: - raise TypeError("Typed vector with length %d cannot hold %d items" % (cls.length, len(args))) - - self.items = list(args) - - # cannot check non-type objects, or parametrized types - if isinstance(cls.elem_type, type) and not hasattr(cls.elem_type, '__args__'): - for i, item in enumerate(self.items): - if not issubclass(cls.elem_type, type(item)): - raise TypeError("Typed vector cannot hold differently typed value" - " at index %d. Got type: %s, expected type: %s" % (i, type(item), cls.elem_type)) - - def serialize(self): - from .ssz_impl import serialize - return serialize(self, self.__class__) - - def hash_tree_root(self): - from .ssz_impl import hash_tree_root - return hash_tree_root(self, self.__class__) - - def __repr__(self): - return repr({'length': self.__class__.length, 'items': self.items}) - - def __getitem__(self, key): - return self.items[key] - - def __setitem__(self, key, value): - self.items[key] = value - - def __iter__(self): - return iter(self.items) + def __setitem__(self, k, v): + self.items[k] = v def __len__(self): return len(self.items) - def __eq__(self, other): - return self.hash_tree_root() == other.hash_tree_root() + def __repr__(self): + return repr(self.items) - -# SSZ BytesN -# ----------------------------- - - -def _is_bytes_n_instance_of(a, b): - # Other has to be a Bytes derivative class to be a BytesN - if not issubclass(b, bytes): - return False - elif not hasattr(b, 'length'): - # BytesN (b) is not an instance of BytesN[X] (a) - return False - elif not hasattr(a, 'length'): - # BytesN[X] (b) is an instance of BytesN (a) - return True - else: - # BytesN[X] (a) is an instance of BytesN[X] (b) - return a.length == b.length - - -def _is_equal_bytes_n_type(a, b): - # Other has to be a Bytes derivative class to be a BytesN - if not issubclass(b, bytes): - return False - elif not hasattr(a, 'length'): - if not hasattr(b, 'length'): - # BytesN == BytesN - return True - else: - # BytesN != BytesN[X] - return False - elif not hasattr(b, 'length'): - # BytesN[X] != BytesN - return False - else: - # BytesN[X] == BytesN[X] - return a.length == b.length - - -class BytesNMeta(type): - def __new__(cls, class_name, parents, attrs): - out = type.__new__(cls, class_name, parents, attrs) - if 'length' in attrs: - setattr(out, 'length', attrs['length']) - out._name = 'BytesN' - out.elem_type = byte - return out - - def __getitem__(self, n): - return self.__class__(self.__name__, (BytesN,), {'length': n}) - - def __subclasscheck__(self, sub): - return _is_bytes_n_instance_of(self, sub) - - def __instancecheck__(self, other): - return _is_bytes_n_instance_of(self, other.__class__) + def __iter__(self): + return iter(self.items) def __eq__(self, other): - return _is_equal_bytes_n_type(self, other) + return self.items == other.items - def __ne__(self, other): - return not _is_equal_bytes_n_type(self, other) +class List(AbstractList, metaclass=AbstractListMeta): + def value_check(self, value): + return len(value) <= self.__class__.length and super().value_check(value) - def __hash__(self): - return hash(self.__class__) - - -def parse_bytes(val): - if val is None: - return None - elif isinstance(val, str): - # TODO: import from eth-utils instead, and do: hexstr_if_str(to_bytes, val) - return None - elif isinstance(val, bytes): - return val - elif isinstance(val, int): - return bytes([val]) - elif isinstance(val, (list, GeneratorType)): - return bytes(val) - else: - return None - - -class BytesN(bytes, metaclass=BytesNMeta): - def __new__(cls, *args): - if not hasattr(cls, 'length'): - return - bytesval = None - if len(args) == 1: - val: Union[bytes, int, str] = args[0] - bytesval = parse_bytes(val) - elif len(args) > 1: - # TODO: each int is 1 byte, check size, create bytesval - bytesval = bytes(args) - - if bytesval is None: - if cls.length == 0: - bytesval = b'' - else: - bytesval = b'\x00' * cls.length - if len(bytesval) != cls.length: - raise TypeError("BytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval))) - return super().__new__(cls, bytesval) - - def serialize(self): - from .ssz_impl import serialize - return serialize(self, self.__class__) - - def hash_tree_root(self): - from .ssz_impl import hash_tree_root - return hash_tree_root(self, self.__class__) - - -class Bytes4(BytesN): - length = 4 - - -class Bytes32(BytesN): - length = 32 - - -class Bytes48(BytesN): - length = 48 - - -class Bytes96(BytesN): - length = 96 - - -# SSZ Defaults -# ----------------------------- -def get_zero_value(typ): - if is_uint_type(typ): - return uint64(0) - elif is_list_type(typ): + def default(self): return [] - elif is_bool_type(typ): - return False - elif is_vector_type(typ): - return typ() - elif is_bytesn_type(typ): - return typ() - elif is_bytes_type(typ): + +class Vector(AbstractList, metaclass=AbstractListMeta): + def value_check(self, value): + return len(value) == self.__class__.length and super().value_check(value) + + def default(self): + return [get_zero_value(self.__class__.elem_type) for _ in range(self.__class__.length)] + +class BytesMeta(AbstractListMeta): + def __getitem__(self, params): + if not isinstance(params, int): + raise Exception("Bytes must be instantiated with one arg: length") + o = self.__class__(self.__name__, (self,), {'length': params}) + o._name = 'Bytes' + return o + +def single_item_extractor(cls, args): + assert len(args) < 2 + return args[0] if len(args) > 0 else cls.default() + +class Bytes(AbstractList, metaclass=BytesMeta): + def value_check(self, value): + return len(value) <= self.__class__.length and isinstance(value, bytes) + + extract_args = single_item_extractor + + def default(self): return b'' - elif is_container_type(typ): - return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()}) - else: - raise Exception("Type not supported: {}".format(typ)) + +class BytesN(AbstractList, metaclass=BytesMeta): + def value_check(self, value): + return len(value) == self.__class__.length and isinstance(value, bytes) + + extract_args = single_item_extractor + + def default(self): + return b'\x00' * self.__class__.length # Type helpers # ----------------------------- - def infer_type(obj): if is_uint_type(obj.__class__): return obj.__class__ elif isinstance(obj, int): return uint64 - elif isinstance(obj, list): - return List[infer_type(obj[0])] - elif isinstance(obj, (Vector, Container, bool, BytesN, bytes)): + elif isinstance(obj, (List, Vector, Container, bool, BytesN, Bytes)): return obj.__class__ else: raise Exception("Unknown type for {}".format(obj)) @@ -476,15 +275,14 @@ def is_list_type(typ): """ Check if the given type is a list. """ - return get_origin(typ) is List or get_origin(typ) is list + return isinstance(typ, type) and issubclass(typ, List) def is_bytes_type(typ): """ Check if the given type is a ``bytes``. """ - # Do not accept subclasses of bytes here, to avoid confusion with BytesN - return typ == bytes + return isinstance(typ, type) and issubclass(typ, Bytes) def is_bytesn_type(typ): @@ -526,22 +324,5 @@ T = TypeVar('T') L = TypeVar('L') -def read_list_elem_type(list_typ: Type[List[T]]) -> T: - if list_typ.__args__ is None or len(list_typ.__args__) != 1: - raise TypeError("Supplied list-type is invalid, no element type found.") - return list_typ.__args__[0] - - -def read_vector_elem_type(vector_typ: Type[Vector[T, L]]) -> T: - return vector_typ.elem_type - - def read_elem_type(typ): - if typ == bytes or (isinstance(typ, type) and issubclass(typ, bytes)): # bytes or bytesN - return byte - elif is_list_type(typ): - return read_list_elem_type(typ) - elif is_vector_type(typ): - return read_vector_elem_type(typ) - else: - raise TypeError("Unexpected type: {}".format(typ)) + return typ.elem_type