lots of bugfixes

This commit is contained in:
protolambda 2019-05-27 23:40:05 +02:00
parent 0f79ed709b
commit d023d2d20f
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
5 changed files with 81 additions and 31 deletions

View File

@ -4,6 +4,8 @@ from eth2spec.utils.ssz.ssz_typing import *
def encode(value, typ, include_hash_tree_roots=False): def encode(value, typ, include_hash_tree_roots=False):
if is_uint_type(typ): if is_uint_type(typ):
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
# Larger uints are boxed and the class declares their byte length # Larger uints are boxed and the class declares their byte length
if issubclass(typ, uint) and typ.byte_len > 8: if issubclass(typ, uint) and typ.byte_len > 8:
return str(value) return str(value)
@ -14,7 +16,7 @@ def encode(value, typ, include_hash_tree_roots=False):
elif is_list_type(typ) or is_vector_type(typ): elif is_list_type(typ) or is_vector_type(typ):
elem_typ = read_elem_type(typ) elem_typ = read_elem_type(typ)
return [encode(element, elem_typ, include_hash_tree_roots) for element in value] return [encode(element, elem_typ, include_hash_tree_roots) for element in value]
elif issubclass(typ, bytes): # both bytes and BytesN elif isinstance(typ, type) and issubclass(typ, bytes): # both bytes and BytesN
return '0x' + value.hex() return '0x' + value.hex()
elif is_container_type(typ): elif is_container_type(typ):
ret = {} ret = {}

View File

@ -115,7 +115,7 @@ def get_random_bytes_list(rng: Random, length: int) -> bytes:
return bytes(rng.getrandbits(8) for _ in range(length)) return bytes(rng.getrandbits(8) for _ in range(length))
def get_random_basic_value(rng: Random, typ: str) -> Any: def get_random_basic_value(rng: Random, typ) -> Any:
if is_bool_type(typ): if is_bool_type(typ):
return rng.choice((True, False)) return rng.choice((True, False))
if is_uint_type(typ): if is_uint_type(typ):
@ -126,7 +126,7 @@ def get_random_basic_value(rng: Random, typ: str) -> Any:
raise ValueError("Not a basic type") raise ValueError("Not a basic type")
def get_min_basic_value(typ: str) -> Any: def get_min_basic_value(typ) -> Any:
if is_bool_type(typ): if is_bool_type(typ):
return False return False
if is_uint_type(typ): if is_uint_type(typ):
@ -137,7 +137,7 @@ def get_min_basic_value(typ: str) -> Any:
raise ValueError("Not a basic type") raise ValueError("Not a basic type")
def get_max_basic_value(typ: str) -> Any: def get_max_basic_value(typ) -> Any:
if is_bool_type(typ): if is_bool_type(typ):
return True return True
if is_uint_type(typ): if is_uint_type(typ):

View File

@ -34,10 +34,13 @@ def get_merkle_proof(tree, item_index):
def next_power_of_two(v: int) -> int: def next_power_of_two(v: int) -> int:
""" """
Get the next power of 2. (for 64 bit range ints) Get the next power of 2. (for 64 bit range ints).
0 is a special case, to have non-empty defaults.
Examples: Examples:
0 -> 0, 1 -> 1, 2 -> 2, 3 -> 4, 32 -> 32, 33 -> 64 0 -> 1, 1 -> 1, 2 -> 2, 3 -> 4, 32 -> 32, 33 -> 64
""" """
if v == 0:
return 1
# effectively fill the bitstring (1 less, do not want to with ones, then increment for next power of 2. # effectively fill the bitstring (1 less, do not want to with ones, then increment for next power of 2.
v -= 1 v -= 1
v |= v >> (1 << 0) v |= v >> (1 << 0)

View File

@ -112,7 +112,7 @@ def hash_tree_root(obj, typ):
leaf_root = merkleize_chunks(leaves) leaf_root = merkleize_chunks(leaves)
return mix_in_length(leaf_root, len(obj)) if is_list_type(typ) else leaf_root return mix_in_length(leaf_root, len(obj)) if is_list_type(typ) else leaf_root
elif is_container_type(typ): elif is_container_type(typ):
leaves = [hash_tree_root(elem, subtyp) for elem, subtyp in obj.get_fields()] leaves = [hash_tree_root(field_value, field_typ) for field_value, field_typ in obj.get_typed_values()]
return merkleize_chunks(chunkify(b''.join(leaves))) return merkleize_chunks(chunkify(b''.join(leaves)))
else: else:
raise Exception("Type not supported: obj {} type {}".format(obj, typ)) raise Exception("Type not supported: obj {} type {}".format(obj, typ))
@ -121,6 +121,7 @@ def hash_tree_root(obj, typ):
@infer_input_type @infer_input_type
def signing_root(obj, typ): def signing_root(obj, typ):
assert is_container_type(typ) assert is_container_type(typ)
leaves = [hash_tree_root(elem, subtyp) for elem, subtyp in obj.get_fields()[:-1]] # ignore last field
leaves = [hash_tree_root(field_value, field_typ) for field_value, field_typ in obj.get_typed_values()[:-1]]
return merkleize_chunks(chunkify(b''.join(leaves))) return merkleize_chunks(chunkify(b''.join(leaves)))

View File

@ -1,6 +1,6 @@
from typing import List, Iterable, Type, NewType
from typing import Union
from inspect import isclass from inspect import isclass
from typing import List, Iterable, TypeVar, Type, NewType
from typing import Union
# SSZ integers # SSZ integers
@ -64,17 +64,25 @@ class uint256(uint):
def is_uint_type(typ): def is_uint_type(typ):
# All integers are uint in the scope of the spec here. # All integers are uint in the scope of the spec here.
# Since we default to uint64. Bounds can be checked elsewhere. # Since we default to uint64. Bounds can be checked elsewhere.
return issubclass(typ, int)
# However, some are wrapped in a NewType
if hasattr(typ, '__supertype__'):
# get the type that the NewType is wrapping
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, int)
def uint_byte_size(typ): def uint_byte_size(typ):
if issubclass(typ, uint): if hasattr(typ, '__supertype__'):
return typ.byte_len typ = typ.__supertype__
elif issubclass(typ, int): if isinstance(typ, type):
# Default to uint64 if issubclass(typ, uint):
return 8 return typ.byte_len
else: elif issubclass(typ, int):
raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ) # Default to uint64
return 8
raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ)
# SSZ Container base class # SSZ Container base class
@ -86,7 +94,7 @@ class Container(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
cls = self.__class__ cls = self.__class__
for f, t in cls.get_fields().items(): for f, t in cls.get_fields():
if f not in kwargs: if f not in kwargs:
setattr(self, f, get_zero_value(t)) setattr(self, f, get_zero_value(t))
else: else:
@ -117,7 +125,10 @@ class Container(object):
@classmethod @classmethod
def get_fields(cls): def get_fields(cls):
return dict(cls.__annotations__).items() return list(dict(cls.__annotations__).items())
def get_typed_values(self):
return list(zip(self.get_field_values(), self.get_field_types()))
@classmethod @classmethod
def get_field_names(cls): def get_field_names(cls):
@ -134,6 +145,9 @@ class Container(object):
def _is_vector_instance_of(a, b): def _is_vector_instance_of(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'): if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector (b) is not an instance of Vector[X, Y] (a) # Vector (b) is not an instance of Vector[X, Y] (a)
return False return False
@ -146,6 +160,9 @@ def _is_vector_instance_of(a, b):
def _is_equal_vector_type(a, b): def _is_equal_vector_type(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
if not hasattr(a, 'elem_type') or not hasattr(a, 'length'): if not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'): if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector == Vector # Vector == Vector
@ -237,6 +254,9 @@ class Vector(metaclass=VectorMeta):
def _is_bytes_n_instance_of(a, b): def _is_bytes_n_instance_of(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
if not hasattr(b, 'length'): if not hasattr(b, 'length'):
# BytesN (b) is not an instance of BytesN[X] (a) # BytesN (b) is not an instance of BytesN[X] (a)
return False return False
@ -249,6 +269,9 @@ def _is_bytes_n_instance_of(a, b):
def _is_equal_bytes_n_type(a, b): def _is_equal_bytes_n_type(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
if not hasattr(a, 'length'): if not hasattr(a, 'length'):
if not hasattr(b, 'length'): if not hasattr(b, 'length'):
# BytesN == BytesN # BytesN == BytesN
@ -267,7 +290,7 @@ class BytesNMeta(type):
out = type.__new__(cls, class_name, parents, attrs) out = type.__new__(cls, class_name, parents, attrs)
if 'length' in attrs: if 'length' in attrs:
setattr(out, 'length', attrs['length']) setattr(out, 'length', attrs['length'])
out._name = 'Vector' out._name = 'BytesN'
out.elem_type = byte out.elem_type = byte
return out return out
@ -318,7 +341,7 @@ class BytesN(bytes, metaclass=BytesNMeta):
else: else:
bytesval = b'\x00' * cls.length bytesval = b'\x00' * cls.length
if len(bytesval) != cls.length: if len(bytesval) != cls.length:
raise TypeError("bytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval))) raise TypeError("BytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval)))
return super().__new__(cls, bytesval) return super().__new__(cls, bytesval)
def serialize(self): def serialize(self):
@ -334,7 +357,7 @@ class BytesN(bytes, metaclass=BytesNMeta):
# ----------------------------- # -----------------------------
def get_zero_value(typ): def get_zero_value(typ):
if is_uint(typ): if is_uint_type(typ):
return 0 return 0
if issubclass(typ, bool): if issubclass(typ, bool):
return False return False
@ -354,7 +377,7 @@ def get_zero_value(typ):
# ----------------------------- # -----------------------------
def infer_type(obj): def infer_type(obj):
if is_uint(obj.__class__): if is_uint_type(obj.__class__):
return obj.__class__ return obj.__class__
elif isinstance(obj, int): elif isinstance(obj, int):
return uint64 return uint64
@ -370,39 +393,50 @@ def infer_input_type(fn):
""" """
Decorator to run infer_type on the obj if typ argument is None Decorator to run infer_type on the obj if typ argument is None
""" """
def infer_helper(obj, typ=None): def infer_helper(obj, typ=None):
if typ is None: if typ is None:
typ = infer_type(obj) typ = infer_type(obj)
return fn(obj, typ) return fn(obj, typ)
return infer_helper return infer_helper
def is_bool_type(typ): def is_bool_type(typ):
return issubclass(typ, bool) if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, bool)
def is_list_type(typ): def is_list_type(typ):
""" """
Checks if the given type is a list. Checks if the given type is a list.
""" """
return (hasattr(typ, '_name') and typ._name == 'List') return hasattr(typ, '_name') and typ._name == 'List'
def is_bytes_type(typ): def is_bytes_type(typ):
# Do not accept subclasses of bytes here, to avoid confusion with BytesN # Do not accept subclasses of bytes here, to avoid confusion with BytesN
return typ == bytes return typ == bytes
def is_list_kind(typ): def is_list_kind(typ):
""" """
Checks if the given type is a kind of list. Can be bytes. Checks if the given type is a kind of list. Can be bytes.
""" """
return is_list_type(typ) or is_bytes_type(typ) return is_list_type(typ) or is_bytes_type(typ)
def is_vector_type(typ): def is_vector_type(typ):
""" """
Checks if the given type is a vector. Checks if the given type is a vector.
""" """
return issubclass(typ, Vector) return isinstance(typ, type) and issubclass(typ, Vector)
def is_bytesn_type(typ): def is_bytesn_type(typ):
return issubclass(typ, BytesN) return isinstance(typ, type) and issubclass(typ, BytesN)
def is_vector_kind(typ): def is_vector_kind(typ):
""" """
@ -410,23 +444,33 @@ def is_vector_kind(typ):
""" """
return is_vector_type(typ) or is_bytesn_type(typ) return is_vector_type(typ) or is_bytesn_type(typ)
def is_container_type(typ): def is_container_type(typ):
return issubclass(typ, Container) return isinstance(typ, type) and issubclass(typ, Container)
T = TypeVar('T')
L = TypeVar('L')
def read_list_elem_type(list_typ: Type[List[T]]) -> T: def read_list_elem_type(list_typ: Type[List[T]]) -> T:
if list_typ.__args__ is None or len(list_typ.__args__) != 1: if list_typ.__args__ is None or len(list_typ.__args__) != 1:
raise TypeError("Supplied list-type is invalid, no element type found.") raise TypeError("Supplied list-type is invalid, no element type found.")
return list_typ.__args__[0] return list_typ.__args__[0]
def read_vector_elem_type(vector_typ: Type[Vector[T, L]]) -> T: def read_vector_elem_type(vector_typ: Type[Vector[T, L]]) -> T:
return vector_typ.elem_type return vector_typ.elem_type
def read_elem_type(typ): def read_elem_type(typ):
if typ == bytes: if typ == bytes:
return byte return byte
elif is_list_type(typ): elif is_list_type(typ):
return read_list_elem_typ(typ) return read_list_elem_type(typ)
elif is_vector_type(typ): elif is_vector_type(typ):
return read_vector_elem_typ(typ) return read_vector_elem_type(typ)
elif issubclass(typ, bytes):
return byte
else: else:
raise TypeError("Unexpected type: {}".format(typ)) raise TypeError("Unexpected type: {}".format(typ))