lots of bugfixes

This commit is contained in:
protolambda 2019-05-27 23:40:05 +02:00
parent 0f79ed709b
commit d023d2d20f
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
5 changed files with 81 additions and 31 deletions

View File

@ -4,6 +4,8 @@ from eth2spec.utils.ssz.ssz_typing import *
def encode(value, typ, include_hash_tree_roots=False):
if is_uint_type(typ):
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
# Larger uints are boxed and the class declares their byte length
if issubclass(typ, uint) and typ.byte_len > 8:
return str(value)
@ -14,7 +16,7 @@ def encode(value, typ, include_hash_tree_roots=False):
elif is_list_type(typ) or is_vector_type(typ):
elem_typ = read_elem_type(typ)
return [encode(element, elem_typ, include_hash_tree_roots) for element in value]
elif issubclass(typ, bytes): # both bytes and BytesN
elif isinstance(typ, type) and issubclass(typ, bytes): # both bytes and BytesN
return '0x' + value.hex()
elif is_container_type(typ):
ret = {}

View File

@ -115,7 +115,7 @@ def get_random_bytes_list(rng: Random, length: int) -> bytes:
return bytes(rng.getrandbits(8) for _ in range(length))
def get_random_basic_value(rng: Random, typ: str) -> Any:
def get_random_basic_value(rng: Random, typ) -> Any:
if is_bool_type(typ):
return rng.choice((True, False))
if is_uint_type(typ):
@ -126,7 +126,7 @@ def get_random_basic_value(rng: Random, typ: str) -> Any:
raise ValueError("Not a basic type")
def get_min_basic_value(typ: str) -> Any:
def get_min_basic_value(typ) -> Any:
if is_bool_type(typ):
return False
if is_uint_type(typ):
@ -137,7 +137,7 @@ def get_min_basic_value(typ: str) -> Any:
raise ValueError("Not a basic type")
def get_max_basic_value(typ: str) -> Any:
def get_max_basic_value(typ) -> Any:
if is_bool_type(typ):
return True
if is_uint_type(typ):

View File

@ -34,10 +34,13 @@ def get_merkle_proof(tree, item_index):
def next_power_of_two(v: int) -> int:
"""
Get the next power of 2. (for 64 bit range ints)
Get the next power of 2. (for 64 bit range ints).
0 is a special case, to have non-empty defaults.
Examples:
0 -> 0, 1 -> 1, 2 -> 2, 3 -> 4, 32 -> 32, 33 -> 64
0 -> 1, 1 -> 1, 2 -> 2, 3 -> 4, 32 -> 32, 33 -> 64
"""
if v == 0:
return 1
# effectively fill the bitstring (1 less, do not want to with ones, then increment for next power of 2.
v -= 1
v |= v >> (1 << 0)

View File

@ -112,7 +112,7 @@ def hash_tree_root(obj, typ):
leaf_root = merkleize_chunks(leaves)
return mix_in_length(leaf_root, len(obj)) if is_list_type(typ) else leaf_root
elif is_container_type(typ):
leaves = [hash_tree_root(elem, subtyp) for elem, subtyp in obj.get_fields()]
leaves = [hash_tree_root(field_value, field_typ) for field_value, field_typ in obj.get_typed_values()]
return merkleize_chunks(chunkify(b''.join(leaves)))
else:
raise Exception("Type not supported: obj {} type {}".format(obj, typ))
@ -121,6 +121,7 @@ def hash_tree_root(obj, typ):
@infer_input_type
def signing_root(obj, typ):
assert is_container_type(typ)
leaves = [hash_tree_root(elem, subtyp) for elem, subtyp in obj.get_fields()[:-1]]
# ignore last field
leaves = [hash_tree_root(field_value, field_typ) for field_value, field_typ in obj.get_typed_values()[:-1]]
return merkleize_chunks(chunkify(b''.join(leaves)))

View File

@ -1,6 +1,6 @@
from typing import List, Iterable, Type, NewType
from typing import Union
from inspect import isclass
from typing import List, Iterable, TypeVar, Type, NewType
from typing import Union
# SSZ integers
@ -64,17 +64,25 @@ class uint256(uint):
def is_uint_type(typ):
# All integers are uint in the scope of the spec here.
# Since we default to uint64. Bounds can be checked elsewhere.
return issubclass(typ, int)
# However, some are wrapped in a NewType
if hasattr(typ, '__supertype__'):
# get the type that the NewType is wrapping
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, int)
def uint_byte_size(typ):
if issubclass(typ, uint):
return typ.byte_len
elif issubclass(typ, int):
# Default to uint64
return 8
else:
raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ)
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
if isinstance(typ, type):
if issubclass(typ, uint):
return typ.byte_len
elif issubclass(typ, int):
# Default to uint64
return 8
raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ)
# SSZ Container base class
@ -86,7 +94,7 @@ class Container(object):
def __init__(self, **kwargs):
cls = self.__class__
for f, t in cls.get_fields().items():
for f, t in cls.get_fields():
if f not in kwargs:
setattr(self, f, get_zero_value(t))
else:
@ -117,7 +125,10 @@ class Container(object):
@classmethod
def get_fields(cls):
return dict(cls.__annotations__).items()
return list(dict(cls.__annotations__).items())
def get_typed_values(self):
return list(zip(self.get_field_values(), self.get_field_types()))
@classmethod
def get_field_names(cls):
@ -134,6 +145,9 @@ class Container(object):
def _is_vector_instance_of(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector (b) is not an instance of Vector[X, Y] (a)
return False
@ -146,6 +160,9 @@ def _is_vector_instance_of(a, b):
def _is_equal_vector_type(a, b):
# Other must not be a BytesN
if issubclass(b, bytes):
return False
if not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector == Vector
@ -237,6 +254,9 @@ class Vector(metaclass=VectorMeta):
def _is_bytes_n_instance_of(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
if not hasattr(b, 'length'):
# BytesN (b) is not an instance of BytesN[X] (a)
return False
@ -249,6 +269,9 @@ def _is_bytes_n_instance_of(a, b):
def _is_equal_bytes_n_type(a, b):
# Other has to be a Bytes derivative class to be a BytesN
if not issubclass(b, bytes):
return False
if not hasattr(a, 'length'):
if not hasattr(b, 'length'):
# BytesN == BytesN
@ -267,7 +290,7 @@ class BytesNMeta(type):
out = type.__new__(cls, class_name, parents, attrs)
if 'length' in attrs:
setattr(out, 'length', attrs['length'])
out._name = 'Vector'
out._name = 'BytesN'
out.elem_type = byte
return out
@ -318,7 +341,7 @@ class BytesN(bytes, metaclass=BytesNMeta):
else:
bytesval = b'\x00' * cls.length
if len(bytesval) != cls.length:
raise TypeError("bytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval)))
raise TypeError("BytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval)))
return super().__new__(cls, bytesval)
def serialize(self):
@ -334,7 +357,7 @@ class BytesN(bytes, metaclass=BytesNMeta):
# -----------------------------
def get_zero_value(typ):
if is_uint(typ):
if is_uint_type(typ):
return 0
if issubclass(typ, bool):
return False
@ -354,7 +377,7 @@ def get_zero_value(typ):
# -----------------------------
def infer_type(obj):
if is_uint(obj.__class__):
if is_uint_type(obj.__class__):
return obj.__class__
elif isinstance(obj, int):
return uint64
@ -370,39 +393,50 @@ def infer_input_type(fn):
"""
Decorator to run infer_type on the obj if typ argument is None
"""
def infer_helper(obj, typ=None):
if typ is None:
typ = infer_type(obj)
return fn(obj, typ)
return infer_helper
def is_bool_type(typ):
return issubclass(typ, bool)
if hasattr(typ, '__supertype__'):
typ = typ.__supertype__
return isinstance(typ, type) and issubclass(typ, bool)
def is_list_type(typ):
"""
Checks if the given type is a list.
"""
return (hasattr(typ, '_name') and typ._name == 'List')
return hasattr(typ, '_name') and typ._name == 'List'
def is_bytes_type(typ):
# Do not accept subclasses of bytes here, to avoid confusion with BytesN
return typ == bytes
def is_list_kind(typ):
"""
Checks if the given type is a kind of list. Can be bytes.
"""
return is_list_type(typ) or is_bytes_type(typ)
def is_vector_type(typ):
"""
Checks if the given type is a vector.
"""
return issubclass(typ, Vector)
return isinstance(typ, type) and issubclass(typ, Vector)
def is_bytesn_type(typ):
return issubclass(typ, BytesN)
return isinstance(typ, type) and issubclass(typ, BytesN)
def is_vector_kind(typ):
"""
@ -410,23 +444,33 @@ def is_vector_kind(typ):
"""
return is_vector_type(typ) or is_bytesn_type(typ)
def is_container_type(typ):
return issubclass(typ, Container)
return isinstance(typ, type) and issubclass(typ, Container)
T = TypeVar('T')
L = TypeVar('L')
def read_list_elem_type(list_typ: Type[List[T]]) -> T:
if list_typ.__args__ is None or len(list_typ.__args__) != 1:
raise TypeError("Supplied list-type is invalid, no element type found.")
return list_typ.__args__[0]
def read_vector_elem_type(vector_typ: Type[Vector[T, L]]) -> T:
return vector_typ.elem_type
def read_elem_type(typ):
if typ == bytes:
return byte
elif is_list_type(typ):
return read_list_elem_typ(typ)
return read_list_elem_type(typ)
elif is_vector_type(typ):
return read_vector_elem_typ(typ)
return read_vector_elem_type(typ)
elif issubclass(typ, bytes):
return byte
else:
raise TypeError("Unexpected type: {}".format(typ))