Simplified SSZ impl

This commit is contained in:
protolambda 2019-06-20 19:51:38 +02:00
parent 29dbe1b880
commit 5ddfe34f0c
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
2 changed files with 102 additions and 318 deletions

View File

@ -1,9 +1,8 @@
from ..merkle_minimal import merkleize_chunks, hash
from eth2spec.utils.ssz.ssz_typing import (
from ..merkle_minimal import merkleize_chunks, hash, ZERO_BYTES32
from .ssz_typing import (
is_uint_type, is_bool_type, is_container_type,
is_list_kind, is_vector_kind,
read_vector_elem_type, read_elem_type,
uint_byte_size,
read_elem_type,
infer_input_type,
get_zero_value,
)
@ -20,7 +19,7 @@ def is_basic_type(typ):
def serialize_basic(value, typ):
if is_uint_type(typ):
return value.to_bytes(uint_byte_size(typ), 'little')
return value.to_bytes(typ.byte_len, 'little')
elif is_bool_type(typ):
if value:
return b'\x01'
@ -140,6 +139,8 @@ def get_typed_values(obj, typ=None):
else:
raise Exception("Invalid type")
def item_length(typ):
    """
    Return the serialized size, in bytes, that one item of ``typ`` occupies.

    ``bool`` takes 1 byte, uint types take their ``byte_len``, and every
    other type is counted as 32 bytes.
    """
    if typ == bool:
        return 1
    if is_uint_type(typ):
        return typ.byte_len
    return 32
@infer_input_type
def hash_tree_root(obj, typ=None):
@ -150,6 +151,8 @@ def hash_tree_root(obj, typ=None):
fields = get_typed_values(obj, typ=typ)
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
if is_list_kind(typ):
full_chunk_length = (item_length(read_elem_type(typ)) * typ.length + 31) // 32
leaves += [ZERO_BYTES32] * (full_chunk_length - len(obj))
return mix_in_length(merkleize_chunks(leaves), len(obj))
else:
return merkleize_chunks(leaves)

View File

@ -6,74 +6,38 @@ from typing_inspect import get_origin
# SSZ integers
# -----------------------------
class uint(int):
    """
    Base class for SSZ unsigned integers.

    Subclasses set ``byte_len`` to their width in bytes. A ``byte_len`` of 0
    (the value on this base class) disables the upper-bound check.
    """
    byte_len = 0

    def __new__(cls, value, *args, **kwargs):
        """
        Create the integer after validating its range.

        Raises:
            ValueError: if ``value`` is negative, or needs more bits than
                ``cls.byte_len`` bytes can hold.
        """
        if value < 0:
            raise ValueError("unsigned types must not be negative")
        # The width lives on the class, not on the incoming value (which is
        # usually a plain int), and it is counted in bytes while bit_length()
        # counts bits.
        bit_len = cls.byte_len * 8
        if bit_len and value.bit_length() > bit_len:
            raise ValueError("value out of bounds for uint{}".format(bit_len))
        return super().__new__(cls, value)
class uint8(uint):
    """Unsigned integer stored in 1 byte (at most 8 bits)."""
    byte_len = 1

    def __new__(cls, value, *args, **kwargs):
        # Accept only values that fit in 8 bits; the base class does the rest.
        if value.bit_length() <= 8:
            return super().__new__(cls, value)
        raise ValueError("value out of bounds for uint8")
# Alias for uint8
byte = NewType('byte', uint8)
class uint16(uint):
    """Unsigned integer stored in 2 bytes (at most 16 bits)."""
    byte_len = 2

    def __new__(cls, value, *args, **kwargs):
        # Accept only values that fit in 16 bits; the base class does the rest.
        if value.bit_length() <= 16:
            return super().__new__(cls, value)
        raise ValueError("value out of bounds for uint16")
class uint32(uint):
    """Unsigned integer stored in 4 bytes (at most 32 bits)."""
    byte_len = 4

    def __new__(cls, value, *args, **kwargs):
        """
        Create the integer, rejecting values wider than 32 bits.

        Raises:
            ValueError: if ``value`` needs more than 32 bits.
        """
        if value.bit_length() > 32:
            # The message previously named uint16 by mistake.
            raise ValueError("value out of bounds for uint32")
        return super().__new__(cls, value)
class uint64(uint):
    """Unsigned integer stored in 8 bytes (at most 64 bits)."""
    byte_len = 8

    def __new__(cls, value, *args, **kwargs):
        # Accept only values that fit in 64 bits; the base class does the rest.
        if value.bit_length() <= 64:
            return super().__new__(cls, value)
        raise ValueError("value out of bounds for uint64")
class uint128(uint):
    """Unsigned integer stored in 16 bytes (at most 128 bits)."""
    byte_len = 16

    def __new__(cls, value, *args, **kwargs):
        # Accept only values that fit in 128 bits; the base class does the rest.
        if value.bit_length() <= 128:
            return super().__new__(cls, value)
        raise ValueError("value out of bounds for uint128")
class uint256(uint):
    """Unsigned integer stored in 32 bytes (at most 256 bits)."""
    byte_len = 32

    def __new__(cls, value, *args, **kwargs):
        # Accept only values that fit in 256 bits; the base class does the rest.
        if value.bit_length() <= 256:
            return super().__new__(cls, value)
        raise ValueError("value out of bounds for uint256")
def is_uint_type(typ):
# All integers are uint in the scope of the spec here.
# Since we default to uint64. Bounds can be checked elsewhere.
@ -84,21 +48,6 @@ def is_uint_type(typ):
return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool)
def uint_byte_size(typ):
    """
    Return the serialized width, in bytes, of an unsigned-integer type.

    ``uint`` subclasses report their own ``byte_len``; any other ``int``
    subclass is treated as the default uint64 and reports 8.

    Raises:
        TypeError: if ``typ`` is not a class, or is a class that is neither
            a ``uint`` nor an ``int`` subclass.
    """
    # Aliases that expose ``__supertype__`` are unwrapped to the class they wrap.
    if hasattr(typ, '__supertype__'):
        typ = typ.__supertype__

    if isinstance(typ, type):
        if issubclass(typ, uint):
            return typ.byte_len
        if issubclass(typ, int):
            # Default to uint64
            return 8
    # Every unsupported input ends here, so no path returns None silently.
    raise TypeError("Type %s is not an uint (or int-default uint64) type" % (typ,))
# SSZ Container base class
# -----------------------------
@ -166,45 +115,19 @@ class Container(object):
return list(cls.__annotations__.values())
# SSZ vector
# -----------------------------
def _is_vector_instance_of(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector (b) is not an instance of Vector[X, Y] (a)
return False
elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
# Vector[X, Y] (b) is an instance of Vector (a)
return True
def get_zero_value(typ):
if typ == int:
return 0
else:
# Vector[X, Y] (a) is an instance of Vector[X, Y] (b)
return a.elem_type == b.elem_type and a.length == b.length
return typ.default()
def _is_equal_vector_type(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
elif not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector == Vector
return True
else:
# Vector != Vector[X, Y]
return False
elif not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector[X, Y] != Vector
return False
def type_check(typ, value):
if typ == int or typ == uint64:
return isinstance(value, int)
else:
# Vector[X, Y] == Vector[X, Y]
return a.elem_type == b.elem_type and a.length == b.length
return typ.value_check(value)
class VectorMeta(type):
class AbstractListMeta(type):
def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs)
if 'elem_type' in attrs and 'length' in attrs:
@ -214,239 +137,115 @@ class VectorMeta(type):
def __getitem__(self, params):
if not isinstance(params, tuple) or len(params) != 2:
raise Exception("Vector must be instantiated with two args: elem type and length")
o = self.__class__(self.__name__, (Vector,), {'elem_type': params[0], 'length': params[1]})
o._name = 'Vector'
raise Exception("List must be instantiated with two args: elem type and length")
o = self.__class__(self.__name__, (self,), {'elem_type': params[0], 'length': params[1]})
o._name = 'AbstractList'
return o
def __subclasscheck__(self, sub):
return _is_vector_instance_of(self, sub)
def __instancecheck__(self, obj):
if obj.__class__.__name__ != self.__name__:
return False
if hasattr(self, 'elem_type') and obj.__class__.elem_type != self.elem_type:
return False
if hasattr(self, 'length') and obj.__class__.length != self.length:
return False
return True
def __instancecheck__(self, other):
return _is_vector_instance_of(self, other.__class__)
class ValueCheckError(Exception):
    """Raised when the values given to a typed SSZ collection fail its value check."""
def __eq__(self, other):
return _is_equal_vector_type(self, other)
class AbstractList(metaclass=AbstractListMeta):
def __init__(self, *args):
items = self.extract_args(args)
if not self.value_check(items):
raise ValueCheckError("Bad input for class {}: {}".format(self.__class__, items))
self.items = items
def value_check(self, value):
for v in value:
if not type_check(self.__class__.elem_type, v):
return False
return True
def __ne__(self, other):
return not _is_equal_vector_type(self, other)
def extract_args(self, args):
return list(args) if len(args) > 0 else self.default()
def __hash__(self):
return hash(self.__class__)
def default(self):
raise Exception("Not implemented")
def __getitem__(self, i):
return self.items[i]
class Vector(metaclass=VectorMeta):
def __init__(self, *args: Iterable):
cls = self.__class__
if not hasattr(cls, 'elem_type'):
raise TypeError("Type Vector without elem_type data cannot be instantiated")
elif not hasattr(cls, 'length'):
raise TypeError("Type Vector without length data cannot be instantiated")
if len(args) != cls.length:
if len(args) == 0:
args = [get_zero_value(cls.elem_type) for _ in range(cls.length)]
else:
raise TypeError("Typed vector with length %d cannot hold %d items" % (cls.length, len(args)))
self.items = list(args)
# cannot check non-type objects, or parametrized types
if isinstance(cls.elem_type, type) and not hasattr(cls.elem_type, '__args__'):
for i, item in enumerate(self.items):
if not issubclass(cls.elem_type, type(item)):
raise TypeError("Typed vector cannot hold differently typed value"
" at index %d. Got type: %s, expected type: %s" % (i, type(item), cls.elem_type))
def serialize(self):
from .ssz_impl import serialize
return serialize(self, self.__class__)
def hash_tree_root(self):
from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__)
def __repr__(self):
return repr({'length': self.__class__.length, 'items': self.items})
def __getitem__(self, key):
return self.items[key]
def __setitem__(self, key, value):
self.items[key] = value
def __iter__(self):
return iter(self.items)
def __setitem__(self, k, v):
self.items[k] = v
def __len__(self):
return len(self.items)
def __eq__(self, other):
return self.hash_tree_root() == other.hash_tree_root()
def __repr__(self):
return repr(self.items)
# SSZ BytesN
# -----------------------------
def _is_bytes_n_instance_of(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
elif not hasattr(b, 'length'):
# BytesN (b) is not an instance of BytesN[X] (a)
return False
elif not hasattr(a, 'length'):
# BytesN[X] (b) is an instance of BytesN (a)
return True
else:
# BytesN[X] (a) is an instance of BytesN[X] (b)
return a.length == b.length
def _is_equal_bytes_n_type(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
elif not hasattr(a, 'length'):
if not hasattr(b, 'length'):
# BytesN == BytesN
return True
else:
# BytesN != BytesN[X]
return False
elif not hasattr(b, 'length'):
# BytesN[X] != BytesN
return False
else:
# BytesN[X] == BytesN[X]
return a.length == b.length
class BytesNMeta(type):
def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs)
if 'length' in attrs:
setattr(out, 'length', attrs['length'])
out._name = 'BytesN'
out.elem_type = byte
return out
def __getitem__(self, n):
return self.__class__(self.__name__, (BytesN,), {'length': n})
def __subclasscheck__(self, sub):
return _is_bytes_n_instance_of(self, sub)
def __instancecheck__(self, other):
return _is_bytes_n_instance_of(self, other.__class__)
def __iter__(self):
return iter(self.items)
def __eq__(self, other):
return _is_equal_bytes_n_type(self, other)
return self.items == other.items
def __ne__(self, other):
return not _is_equal_bytes_n_type(self, other)
class List(AbstractList, metaclass=AbstractListMeta):
def value_check(self, value):
return len(value) <= self.__class__.length and super().value_check(value)
def __hash__(self):
return hash(self.__class__)
def parse_bytes(val):
    """
    Convert ``val`` to ``bytes`` where a conversion is defined.

    ``bytes`` pass through unchanged, an ``int`` becomes a single byte, and a
    list or generator of ints is packed into bytes. ``None``, strings and any
    other input produce ``None``.
    """
    if isinstance(val, bytes):
        return val
    if isinstance(val, int):
        return bytes([val])
    if isinstance(val, (list, GeneratorType)):
        return bytes(val)
    # TODO: import from eth-utils instead, and do: hexstr_if_str(to_bytes, val)
    # Until then strings, like None and unsupported types, are not converted.
    return None
class BytesN(bytes, metaclass=BytesNMeta):
    """
    Fixed-length byte string whose size is the class attribute ``length``.

    Construction rules:
    - no arguments: a zero-filled value of ``length`` bytes;
    - one argument: converted with ``parse_bytes``; if that yields ``None``
      the zero-filled value is used instead;
    - several arguments: packed with ``bytes(args)``.
    A result whose size differs from ``length`` raises ``TypeError``.
    Calling a class that has no ``length`` returns ``None``.
    """

    def __new__(cls, *args):
        # Without a length there is nothing to build.
        if not hasattr(cls, 'length'):
            return None

        size = cls.length
        if len(args) > 1:
            # TODO: each int is 1 byte, check size, create bytesval
            raw = bytes(args)
        elif len(args) == 1:
            raw = parse_bytes(args[0])
        else:
            raw = None

        # Fall back to all-zero content (empty when the length is 0).
        if raw is None:
            raw = b'\x00' * size

        if len(raw) != size:
            raise TypeError("BytesN[%d] cannot be initialized with value of %d bytes" % (size, len(raw)))
        return super().__new__(cls, raw)

    def serialize(self):
        """Serialize this value using its own class as the SSZ type."""
        from .ssz_impl import serialize
        return serialize(self, self.__class__)

    def hash_tree_root(self):
        """Compute the hash tree root using this value's own class as the SSZ type."""
        from .ssz_impl import hash_tree_root
        return hash_tree_root(self, self.__class__)
class Bytes4(BytesN):
    """BytesN with a fixed length of 4 bytes."""
    length = 4


class Bytes32(BytesN):
    """BytesN with a fixed length of 32 bytes."""
    length = 32


class Bytes48(BytesN):
    """BytesN with a fixed length of 48 bytes."""
    length = 48


class Bytes96(BytesN):
    """BytesN with a fixed length of 96 bytes."""
    length = 96
# SSZ Defaults
# -----------------------------
def get_zero_value(typ):
if is_uint_type(typ):
return uint64(0)
elif is_list_type(typ):
def default(self):
return []
elif is_bool_type(typ):
return False
elif is_vector_type(typ):
return typ()
elif is_bytesn_type(typ):
return typ()
elif is_bytes_type(typ):
class Vector(AbstractList, metaclass=AbstractListMeta):
    """Fixed-length typed collection: holds exactly ``length`` items of ``elem_type``."""

    def value_check(self, value):
        """Accept ``value`` only if it has exactly ``length`` items and passes the parent check."""
        if len(value) != self.__class__.length:
            return False
        return super().value_check(value)

    def default(self):
        """Return ``length`` zero values of the element type."""
        cls = self.__class__
        return [get_zero_value(cls.elem_type) for _ in range(cls.length)]
class BytesMeta(AbstractListMeta):
    """Metaclass that lets byte types be parametrized with a length, e.g. ``Bytes[32]``."""

    def __getitem__(self, params):
        """
        Build a subclass of ``self`` whose ``length`` is ``params``.

        Raises:
            Exception: if ``params`` is not an ``int``.
        """
        if not isinstance(params, int):
            raise Exception("Bytes must be instantiated with one arg: length")
        attrs = {'length': params}
        sized = self.__class__(self.__name__, (self,), attrs)
        sized._name = 'Bytes'
        return sized
def single_item_extractor(cls, args):
    """
    Pick the single constructor argument, or fall back to ``cls.default()``.

    ``args`` may hold at most one element; more trips the assertion.
    """
    assert len(args) < 2
    if len(args) == 0:
        return cls.default()
    return args[0]
class Bytes(AbstractList, metaclass=BytesMeta):
    """Variable-length byte string holding at most ``length`` bytes."""

    def value_check(self, value):
        """
        Return True when ``value`` is ``bytes`` no longer than ``length``.

        The type is tested before the size so that an input without a length
        (an int, say) fails the check instead of raising ``TypeError`` from
        ``len()``.
        """
        return isinstance(value, bytes) and len(value) <= self.__class__.length

    # The constructor takes the whole byte string as one argument.
    extract_args = single_item_extractor

    def default(self):
        """Return the empty byte string."""
        return b''
elif is_container_type(typ):
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()})
else:
raise Exception("Type not supported: {}".format(typ))
class BytesN(AbstractList, metaclass=BytesMeta):
    """Fixed-length byte string holding exactly ``length`` bytes."""

    def value_check(self, value):
        """
        Return True when ``value`` is ``bytes`` of exactly ``length`` bytes.

        The type is tested before the size so that an input without a length
        (an int, say) fails the check instead of raising ``TypeError`` from
        ``len()``.
        """
        return isinstance(value, bytes) and len(value) == self.__class__.length

    # The constructor takes the whole byte string as one argument.
    extract_args = single_item_extractor

    def default(self):
        """Return ``length`` zero bytes."""
        return b'\x00' * self.__class__.length
# Type helpers
# -----------------------------
def infer_type(obj):
if is_uint_type(obj.__class__):
return obj.__class__
elif isinstance(obj, int):
return uint64
elif isinstance(obj, list):
return List[infer_type(obj[0])]
elif isinstance(obj, (Vector, Container, bool, BytesN, bytes)):
elif isinstance(obj, (List, Vector, Container, bool, BytesN, Bytes)):
return obj.__class__
else:
raise Exception("Unknown type for {}".format(obj))
@ -476,15 +275,14 @@ def is_list_type(typ):
"""
Check if the given type is a list.
"""
return get_origin(typ) is List or get_origin(typ) is list
return isinstance(typ, type) and issubclass(typ, List)
def is_bytes_type(typ):
"""
Check if the given type is a ``bytes``.
"""
# Do not accept subclasses of bytes here, to avoid confusion with BytesN
return typ == bytes
return isinstance(typ, type) and issubclass(typ, Bytes)
def is_bytesn_type(typ):
@ -526,22 +324,5 @@ T = TypeVar('T')
L = TypeVar('L')
def read_list_elem_type(list_typ: Type[List[T]]) -> T:
    """
    Return the element type of a parametrized list type.

    Raises:
        TypeError: if ``list_typ`` has no ``__args__``, or ``__args__`` is
            ``None`` or does not hold exactly one entry.
    """
    # A type with no ``__args__`` at all is reported with the same TypeError
    # as a malformed one, rather than leaking an AttributeError.
    type_args = getattr(list_typ, '__args__', None)
    if type_args is None or len(type_args) != 1:
        raise TypeError("Supplied list-type is invalid, no element type found.")
    return type_args[0]
def read_vector_elem_type(vector_typ: Type[Vector[T, L]]) -> T:
    """Return the ``elem_type`` attribute of the given vector type."""
    elem_typ = vector_typ.elem_type
    return elem_typ
def read_elem_type(typ):
if typ == bytes or (isinstance(typ, type) and issubclass(typ, bytes)): # bytes or bytesN
return byte
elif is_list_type(typ):
return read_list_elem_type(typ)
elif is_vector_type(typ):
return read_vector_elem_type(typ)
else:
raise TypeError("Unexpected type: {}".format(typ))
return typ.elem_type