improve impl, box less used integer types, use uint64 as default

This commit is contained in:
protolambda 2019-05-27 18:01:46 +02:00
parent e14f789779
commit d5eab257d0
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
2 changed files with 116 additions and 61 deletions

View File

@ -6,18 +6,21 @@ from .ssz_typing import *
BYTES_PER_LENGTH_OFFSET = 4
def is_basic_type(typ):
    """Tell whether typ is a basic type: a uint type (per is_uint) or bool."""
    basic = is_uint(typ)
    if not basic:
        basic = typ == bool
    return basic
def serialize_basic(value, typ):
    """Serialize a basic value (boolean or unsigned integer) to bytes.

    :param value: the value to encode.
    :param typ: the type of ``value``; it decides the encoded width.
    :return: a single 0x01 / 0x00 byte for bool, otherwise the little-endian
        encoding of ``value`` in ``uint_byte_size(typ)`` bytes.
    :raises Exception: if ``typ`` is neither bool nor a uint type.
    """
    # bool is checked first: it subclasses int, so a subclass-based uint test
    # would otherwise claim it and encode it with the integer width.
    if isinstance(typ, type) and issubclass(typ, bool):
        return b'\x01' if value else b'\x00'
    if is_uint(typ):
        # uint_byte_size() also handles plain ints, which carry no byte_len.
        return value.to_bytes(uint_byte_size(typ), 'little')
    # Fail loudly instead of implicitly returning None.
    raise Exception("Type not supported: {}".format(typ))
def is_fixed_size(typ):
if is_basic_type(typ):
return True
@ -30,9 +33,9 @@ def is_fixed_size(typ):
else:
raise Exception("Type not supported: {}".format(typ))
def serialize(obj, typ=None):
if typ is None:
typ = infer_type(obj)
@infer_input_type
def serialize(obj, typ):
if is_basic_type(typ):
return serialize_basic(obj, typ)
elif is_list_type(typ) or is_vector_type(typ):
@ -42,6 +45,7 @@ def serialize(obj, typ=None):
else:
raise Exception("Type not supported: {}".format(typ))
def encode_series(values, types):
# bytes and bytesN are already in the right format.
if isinstance(values, bytes):
@ -75,24 +79,28 @@ def encode_series(values, types):
# Return the concatenation of the fixed-size parts (offsets interleaved) with the variable-size parts
return b''.join(fixed_parts + variable_parts)
# SSZ Hash-tree-root
# -----------------------------
def pack(values, subtype):
    """Concatenate the basic serialization of every value into one byte string.

    A ``bytes`` input is returned unchanged; any other iterable has each
    element serialized with serialize_basic(element, subtype) and joined.
    """
    if not isinstance(values, bytes):
        encoded = (serialize_basic(item, subtype) for item in values)
        return b''.join(encoded)
    return values
def chunkify(bytez):
    """Split a byte string into 32-byte chunks, zero-padding the final chunk.

    :param bytez: bytes (or bytearray) to split.
    :return: list of 32-byte slices; an empty input gives an empty list.
    """
    # Build a new object instead of using `+=`: in-place extension would
    # silently grow a bytearray owned by the caller.
    padded = bytez + b'\x00' * (-len(bytez) % 32)
    return [padded[start:start + 32] for start in range(0, len(padded), 32)]
def mix_in_length(root, length):
    """Hash root followed by length encoded as 32 little-endian bytes.

    ``hash`` is whatever this module has bound to that name.
    """
    length_bytes = length.to_bytes(32, 'little')
    return hash(root + length_bytes)
def hash_tree_root(obj, typ=None):
if typ is None:
typ = infer_type(obj)
@infer_input_type
def hash_tree_root(obj, typ):
if is_basic_type(typ):
return merkleize_chunks(chunkify(serialize_basic(obj, typ)))
elif is_list_type(typ) or is_vector_type(typ):
@ -104,31 +112,15 @@ def hash_tree_root(obj, typ=None):
leaf_root = merkleize_chunks(leaves)
return mix_in_length(leaf_root, len(obj)) if is_list_type(typ) else leaf_root
elif is_container_typ(typ):
leaves = [hash_tree_root(elem, subtyp) for elem, subtyp in obj.get_fields().items()]
leaves = [hash_tree_root(elem, subtyp) for elem, subtyp in obj.get_fields()]
return merkleize_chunks(chunkify(b''.join(leaves)))
else:
raise Exception("Type not supported: obj {} type {}".format(obj, typ))
@infer_input_type
def signing_root(obj, typ):
    """Merkleize a container while leaving out its last declared field.

    :param obj: container instance.
    :param typ: container type; filled in by the decorator when omitted.
    :return: result of merkleize_chunks over the hash-tree-roots of every
        field except the last one.
    :raises AssertionError: if ``typ`` is not a container type.
    """
    assert is_container_typ(typ)
    # get_fields() may hand back a dict view, which cannot be sliced; copy it
    # into a list before dropping the last (name, type) pair.
    fields = list(obj.get_fields())[:-1]
    # The pairs carry field *names*; read the actual value off the object.
    leaves = [hash_tree_root(getattr(obj, name), subtyp) for name, subtyp in fields]
    return merkleize_chunks(chunkify(b''.join(leaves)))
# Implementation notes:
# - Container,Vector/BytesN.hash_tree_root/serialize functions are for ease, implementation here
# - uint types have a 'byte_len' attribute
# - uint types are not classes. They use NewType(), for performance.
# This forces us to check type equivalence by exact reference.
# There's no class. The type data comes from an annotation/argument from the context of the value.
# - Vector is not valid to create instances with. Give it an elem-type and length: Vector[FooBar, 123]
# - *The class of* a Vector instance has a `elem_type` (type, may not be a class, see uint) and `length` (int)
# - BytesN is not valid to create instances with. Give it a length: BytesN[123]
# - *The class of* a BytesN instance has a `length` (int)
# Where possible, it is preferable to create helpers that just act on the type, and don't unnecessarily use a value
# E.g. is_basic_type(). This way, we can use them in type-only contexts and have no duplicate logic.
# For every class-instance, you can get the type with my_object.__class__
# For uints, and other NewType related, you have to rely on type information. It cannot be retrieved from the value.
# Note: we may just want to box integers instead. And then we can do bounds checking too. But it is SLOW and MEMORY INTENSIVE.
#

View File

@ -1,30 +1,82 @@
from typing import List, Iterable, TypeVar, Type, NewType
from typing import List, Iterable, Type, NewType
from typing import Union
from inspect import isclass
T = TypeVar('T')
L = TypeVar('L')
# SSZ integer types, with 0 computational overhead (NewType)
# SSZ integers
# -----------------------------
# Integer aliases declared with NewType over plain int.
# Each alias object gets a byte_len attribute holding its width in bytes.
uint8 = NewType('uint8', int)
uint8.byte_len = 1
uint16 = NewType('uint16', int)
uint16.byte_len = 2
uint32 = NewType('uint32', int)
uint32.byte_len = 4
uint64 = NewType('uint64', int)
uint64.byte_len = 8
uint128 = NewType('uint128', int)
uint128.byte_len = 16
uint256 = NewType('uint256', int)
uint256.byte_len = 32
class uint(int):
    """Base class for boxed unsigned integers; construction rejects negatives."""

    # Width in bytes; the base class declares 0 and subclasses override it.
    byte_len = 0

    def __new__(cls, value, *args, **kwargs):
        is_negative = value < 0
        if is_negative:
            raise ValueError("unsigned types must not be negative")
        return super(uint, cls).__new__(cls, value)
class uint8(uint):
    """Boxed unsigned integer limited to 8 bits."""

    byte_len = 1

    def __new__(cls, value, *args, **kwargs):
        too_wide = value.bit_length() > 8
        if too_wide:
            raise ValueError("value out of bounds for uint8")
        # The parent constructor performs the negativity check.
        return super(uint8, cls).__new__(cls, value)
# Alias for uint8, declared through NewType so `byte` is its own name on top of uint8.
byte = NewType('byte', uint8)
class uint16(uint):
    """Boxed unsigned integer limited to 16 bits."""

    byte_len = 2

    def __new__(cls, value, *args, **kwargs):
        too_wide = value.bit_length() > 16
        if too_wide:
            raise ValueError("value out of bounds for uint16")
        # The parent constructor performs the negativity check.
        return super(uint16, cls).__new__(cls, value)
class uint32(uint):
    """Boxed unsigned integer limited to 32 bits.

    Raises ValueError when the value needs more than 32 bits; the parent
    constructor rejects negative values.
    """

    byte_len = 4

    def __new__(cls, value, *args, **kwargs):
        if value.bit_length() > 32:
            # The message must name this type (it previously said uint16).
            raise ValueError("value out of bounds for uint32")
        return super().__new__(cls, value)
# Plain ints are treated as uint64 by default (uint_byte_size() returns 8 for them),
# so uint64 is just a readable NewType name over int instead of a boxed class.
uint64 = NewType('uint64', int)
class uint128(uint):
    """Boxed unsigned integer limited to 128 bits."""

    byte_len = 16

    def __new__(cls, value, *args, **kwargs):
        too_wide = value.bit_length() > 128
        if too_wide:
            raise ValueError("value out of bounds for uint128")
        # The parent constructor performs the negativity check.
        return super(uint128, cls).__new__(cls, value)
class uint256(uint):
    """Boxed unsigned integer limited to 256 bits."""

    byte_len = 32

    def __new__(cls, value, *args, **kwargs):
        too_wide = value.bit_length() > 256
        if too_wide:
            raise ValueError("value out of bounds for uint256")
        # The parent constructor performs the negativity check.
        return super(uint256, cls).__new__(cls, value)
def is_uint(typ):
    """Tell whether typ denotes an unsigned integer type.

    All integers are uint in the scope of the spec here, since plain int
    defaults to uint64. Bounds can be checked elsewhere.

    :param typ: a class, a NewType alias, or any other type-like object.
    :return: True for int and its subclasses (bool excluded) and for NewType
        aliases that resolve to one; False for everything else.
    """
    # NewType aliases are not classes; follow them to the wrapped type so
    # issubclass() does not raise on them.
    while hasattr(typ, '__supertype__'):
        typ = typ.__supertype__
    # bool subclasses int but is serialized as its own one-byte basic type.
    return isinstance(typ, type) and issubclass(typ, int) and not issubclass(typ, bool)
def uint_byte_size(typ):
    """Return the serialized width, in bytes, of an unsigned integer type.

    :param typ: a ``uint`` subclass, plain ``int`` (or another int subclass),
        or a NewType alias of one of those.
    :return: ``byte_len`` of a ``uint`` subclass, or 8 for other int types.
    :raises TypeError: if ``typ`` does not resolve to an integer class.
    """
    resolved = typ
    # NewType aliases (such as uint64 and byte) are not classes; follow them
    # to the wrapped type so issubclass() can be applied.
    while hasattr(resolved, '__supertype__'):
        resolved = resolved.__supertype__
    if isinstance(resolved, type):
        if issubclass(resolved, uint):
            return resolved.byte_len
        if issubclass(resolved, int):
            # Default to uint64
            return 8
    raise TypeError("Type %s is not an uint (or int-default uint64) type" % typ)
# SSZ Container base class
# -----------------------------
@ -57,9 +109,13 @@ class Container(object):
return [getattr(self, field) for field in cls.get_field_names()]
@classmethod
def get_fields_dict(cls):
    """Return a new dict built from the class annotations (field name -> declared type)."""
    return dict(cls.__annotations__)
@classmethod
def get_fields(cls):
    """Return the annotated fields as a list of (name, declared type) pairs.

    The pairs follow annotation order. A list is returned rather than a dict
    view so that callers can index or slice the result as well as iterate it.
    """
    return list(cls.__annotations__.items())
@classmethod
def get_field_names(cls):
    """Return the annotated field names as a list, in annotation order."""
    return [name for name in cls.__annotations__]
@ -70,14 +126,6 @@ class Container(object):
return list(cls.__annotations__.values())
def is_uint(typ):
    """Tell whether typ is one of the named integer type references."""
    # The references are compared with ==, one by one, rather than through
    # sub-classing, because only the type reference is available here.
    candidates = (uint8, uint16, uint32, uint64, uint128, uint256, byte)
    for candidate in candidates:
        matched = typ == candidate
        if matched:
            return matched
    return matched
# SSZ vector
# -----------------------------
@ -138,7 +186,7 @@ class VectorMeta(type):
class Vector(metaclass=VectorMeta):
def __init__(self, *args: Iterable[T]):
def __init__(self, *args: Iterable):
cls = self.__class__
if not hasattr(cls, 'elem_type'):
@ -275,6 +323,7 @@ class BytesN(bytes, metaclass=BytesNMeta):
from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__)
# SSZ Defaults
# -----------------------------
@ -292,7 +341,8 @@ def get_zero_value(typ):
if issubclass(typ, bytes):
return b''
if issubclass(typ, Container):
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields().items()}),
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields()}),
# Type helpers
# -----------------------------
@ -309,17 +359,30 @@ def infer_type(obj):
else:
raise Exception("Unknown type for {}".format(obj))
def infer_input_type(fn):
    """
    Decorator to run infer_type on the obj if typ argument is None

    :param fn: function taking ``(obj, typ)``.
    :return: wrapper taking ``(obj, typ=None)`` that forwards to ``fn`` and
        keeps the name and docstring of ``fn``.
    """
    # Imported locally so this block does not rely on the module import list.
    from functools import wraps

    @wraps(fn)
    def infer_helper(obj, typ=None):
        if typ is None:
            typ = infer_type(obj)
        return fn(obj, typ)
    return infer_helper
def is_list_type(typ):
    """Tell whether typ is a list type: its ``_name`` is 'List', or it is ``bytes``."""
    if hasattr(typ, '_name'):
        named_list = typ._name == 'List'
        if named_list:
            return named_list
    return typ == bytes
def is_vector_type(typ):
    """Tell whether typ is a class derived from Vector.

    Non-class arguments (for example typing generics) give False instead of
    the TypeError that issubclass() would raise for them.
    """
    return isinstance(typ, type) and issubclass(typ, Vector)
def is_container_typ(typ):
    """Tell whether typ is a class derived from Container.

    Non-class arguments (for example typing generics) give False instead of
    the TypeError that issubclass() would raise for them.
    """
    return isinstance(typ, type) and issubclass(typ, Container)
def read_list_elem_typ(list_typ: "Type[List[T]]") -> "T":
    """Return the element type of a parameterized list type.

    :param list_typ: a list type carrying exactly one type argument.
    :return: the single entry of ``list_typ.__args__``.
    :raises TypeError: if ``list_typ`` has no ``__args__`` or does not carry
        exactly one type argument.
    """
    # Annotations are quoted so they are not evaluated at definition time and
    # the function does not depend on T being bound when the module loads.
    # getattr keeps a missing __args__ on the TypeError path instead of
    # letting an AttributeError escape.
    args = getattr(list_typ, '__args__', None)
    if args is None or len(args) != 1:
        raise TypeError("Supplied list-type is invalid, no element type found.")
    return args[0]
def read_vector_elem_typ(vector_typ: Type[Vector[T, L]]) -> T:
@ -333,4 +396,4 @@ def read_elem_typ(typ):
elif is_vector_type(typ):
return read_vector_elem_typ(typ)
else:
raise Exception("Unexpected type: {}".format(typ))
raise TypeError("Unexpected type: {}".format(typ))