SSZ impl fixes (#960)

* fix serialization mixup of array types, fix variable size vector serialization, document, structure and de-duplicate code
* more cleanup + minor fixes in earlier improvements
* Make type-inference stable on empty lists/vectors
* Add get_zero_value
This commit is contained in:
Diederik Loerakker 2019-04-18 22:54:17 +10:00 committed by Justin
parent d4ce0d20a0
commit d8a4a48ed8
1 changed files with 143 additions and 29 deletions

View File

@ -1,5 +1,6 @@
from .hash_function import hash from .hash_function import hash
from typing import Any
BYTES_PER_CHUNK = 32 BYTES_PER_CHUNK = 32
BYTES_PER_LENGTH_PREFIX = 4 BYTES_PER_LENGTH_PREFIX = 4
@ -9,9 +10,10 @@ ZERO_CHUNK = b'\x00' * BYTES_PER_CHUNK
def SSZType(fields): def SSZType(fields):
class SSZObject(): class SSZObject():
def __init__(self, **kwargs): def __init__(self, **kwargs):
for f in fields: for f, t in fields.items():
if f not in kwargs: if f not in kwargs:
raise Exception("Missing constructor argument: %s" % f) setattr(self, f, get_zero_value(t))
else:
setattr(self, f, kwargs[f]) setattr(self, f, kwargs[f])
def __eq__(self, other): def __eq__(self, other):
@ -58,18 +60,40 @@ class Vector():
def is_basic(typ):
    """
    Tell whether ``typ`` names an SSZ basic type.

    Basic types are the strings "uintN" (N in 8, 16, 32, 64, 128, 256),
    "bool", and "byte" (alias of "uint8"). Prefix matches such as "uint7"
    or "boolean" are rejected.

    :param typ: The SSZ type descriptor to classify.
    :return: True if ``typ`` is a basic type, False otherwise.
    """
    # If not a string, it is a complex type descriptor and cannot be basic.
    if not isinstance(typ, str):
        return False
    # "uintN": N-bit unsigned integer, only the listed bit widths are valid.
    if typ[:4] == 'uint':
        return typ[4:] in ('8', '16', '32', '64', '128', '256')
    # "bool": True or False; "byte": alias of "uint8".
    return typ in ('bool', 'byte')
def is_constant_sized(typ): def is_constant_sized(typ):
# basic objects are fixed size by definition
if is_basic(typ): if is_basic(typ):
return True return True
# dynamic size array type, "list": [elem_type].
# Not constant size by definition.
elif isinstance(typ, list) and len(typ) == 1: elif isinstance(typ, list) and len(typ) == 1:
return is_constant_sized(typ[0])
elif isinstance(typ, list) and len(typ) == 2:
return False return False
# fixed size array type, "vector": [elem_type, length]
# Constant size, but only if the elements are.
elif isinstance(typ, list) and len(typ) == 2:
return is_constant_sized(typ[0])
# bytes array (fixed or dynamic size)
elif isinstance(typ, str) and typ[:5] == 'bytes': elif isinstance(typ, str) and typ[:5] == 'bytes':
return len(typ) > 5 # if no length suffix, it has a dynamic size
return typ != 'bytes'
# containers are only constant-size if all of the fields are constant size.
elif hasattr(typ, 'fields'): elif hasattr(typ, 'fields'):
for subtype in typ.fields.values(): for subtype in typ.fields.values():
if not is_constant_sized(subtype): if not is_constant_sized(subtype):
@ -90,40 +114,98 @@ def coerce_to_bytes(x):
raise Exception("Expecting bytes") raise Exception("Expecting bytes")
def encode_bytes(value):
    """
    Serialize a byte string with a length prefix.

    The value is passed through ``coerce_to_bytes``; the result is the payload
    length as a ``BYTES_PER_LENGTH_PREFIX``-byte little-endian integer,
    followed by the payload itself.

    :param value: The value to encode (anything ``coerce_to_bytes`` accepts).
    :return: The length prefix concatenated with the payload bytes.
    """
    payload = coerce_to_bytes(value)
    # The length must be representable in the fixed-width prefix.
    length_limit = 2 ** (8 * BYTES_PER_LENGTH_PREFIX)
    assert len(payload) < length_limit
    prefix = len(payload).to_bytes(BYTES_PER_LENGTH_PREFIX, 'little')
    return prefix + payload
def encode_variable_size_container(values, types):
    """
    Serialize a sequence of values and length-prefix the result.

    :param values: The values to serialize, in order.
    :param types: The SSZ type of each value, paired positionally with ``values``.
    :return: ``encode_bytes`` applied to the concatenated element encodings.
    """
    body = encode_fixed_size_container(values, types)
    return encode_bytes(body)
def encode_fixed_size_container(values, types):
    """
    Serialize a sequence of values back to back, without a length prefix.

    :param values: The values to serialize, in order.
    :param types: The SSZ type of each value, paired positionally with ``values``.
    :return: The concatenation of ``serialize_value(v, t)`` for each pair.
    """
    # map() over two iterables stops at the shorter one, exactly like zip().
    return b''.join(map(serialize_value, values, types))
def serialize_value(value, typ=None):
    """
    Serialize ``value`` according to the SSZ type descriptor ``typ``.

    Supported descriptors: "uintN", "bool", "byte" (alias of "uint8"),
    "bytes", "bytesN", ``[elem_type]`` (list), ``[elem_type, length]``
    (vector), and container types exposing a ``fields`` mapping.

    :param value: The value to serialize.
    :param typ: The SSZ type; inferred with ``infer_type`` when omitted.
    :return: The serialized bytes.
    :raises Exception: If ``typ`` is not a recognized type descriptor.
    """
    if typ is None:
        typ = infer_type(value)
    # alias: "byte" -> "uint8", so it takes the same path as any other uintN.
    if isinstance(typ, str) and typ == 'byte':
        typ = 'uint8'
    # "uintN"
    if isinstance(typ, str) and typ[:4] == 'uint':
        length = int(typ[4:])
        assert length in (8, 16, 32, 64, 128, 256)
        return value.to_bytes(length // 8, 'little')
    # "bool"
    elif isinstance(typ, str) and typ == 'bool':
        assert value in (True, False)
        # Truthiness instead of `is True`: the assert above also admits 1 and 0,
        # and an identity test would silently encode 1 as b'\x00'.
        return b'\x01' if value else b'\x00'
    # Vector
    elif isinstance(typ, list) and len(typ) == 2:
        # Regardless of element type, sanity-check that the length reported
        # in the vector type matches the value length.
        assert len(value) == typ[1]
        element_types = [typ[0]] * len(value)
        # Fixed-size element type: elements are simply concatenated.
        if is_constant_sized(typ):
            return encode_fixed_size_container(value, element_types)
        # Variable-size element type: the concatenation is length-prefixed.
        else:
            return encode_variable_size_container(value, element_types)
    # "bytes" (variable size)
    elif isinstance(typ, str) and typ == 'bytes':
        return encode_bytes(value)
    # List
    elif isinstance(typ, list) and len(typ) == 1:
        return encode_variable_size_container(value, [typ[0]] * len(value))
    # "bytesN" (fixed size)
    elif isinstance(typ, str) and len(typ) > 5 and typ[:5] == 'bytes':
        assert len(value) == int(typ[5:]), (value, int(typ[5:]))
        return coerce_to_bytes(value)
    # Containers
    elif hasattr(typ, 'fields'):
        values = [getattr(value, field) for field in typ.fields.keys()]
        types = list(typ.fields.values())
        if is_constant_sized(typ):
            return encode_fixed_size_container(values, types)
        else:
            return encode_variable_size_container(values, types)
    else:
        # Carry the diagnostics in the exception instead of printing to stdout.
        raise Exception("Type not recognized: value=%r, typ=%r" % (value, typ))
def get_zero_value(typ: Any) -> Any:
    """
    Build the default ("zero") value for an SSZ type descriptor.

    :param typ: The SSZ type: a type-name string, ``[elem_type]`` (list),
        ``[elem_type, length]`` (vector), or a container type with ``fields``.
    :return: ``b''`` for "bytes", N zero bytes for "bytesN", ``False`` for
        "bool", ``0`` for "uintN" and "byte", a plain Python list of
        ``length`` zero values for a vector, ``[]`` for a list, and an
        instance of the container built from the zero value of each field.
    :raises ValueError: If ``typ`` is a string that names no known type.
    :raises Exception: If ``typ`` is not a recognized descriptor at all.
    """
    # NOTE(review): the vector zero value is a plain list, not wrapped in the
    # file's Vector class - confirm this is what callers expect.
    if isinstance(typ, str):
        # Bytes array
        if typ == 'bytes':
            return b''
        # bytesN: the suffix must be a number, otherwise fall through to the
        # "not recognized" error instead of surfacing a raw int() failure.
        elif typ[:5] == 'bytes' and typ[5:].isdecimal():
            return b'\x00' * int(typ[5:])
        # Basic types
        elif typ == 'bool':
            return False
        # Only the bit widths accepted elsewhere in this file are valid.
        elif typ[:4] == 'uint' and typ[4:] in ('8', '16', '32', '64', '128', '256'):
            return 0
        elif typ == 'byte':
            return 0x00
        else:
            raise ValueError("Type not recognized: %r" % (typ,))
    # Vector:
    elif isinstance(typ, list) and len(typ) == 2:
        return [get_zero_value(typ[0]) for _ in range(typ[1])]
    # List:
    elif isinstance(typ, list) and len(typ) == 1:
        return []
    # Container:
    elif hasattr(typ, 'fields'):
        return typ(**{field: get_zero_value(subtype) for field, subtype in typ.fields.items()})
    else:
        # Carry the diagnostics in the exception instead of printing to stdout.
        raise Exception("Type not recognized: %r" % (typ,))
def chunkify(bytez):
    """
    Split a byte string into ``BYTES_PER_CHUNK``-sized chunks.

    The input is right-padded with zero bytes up to the next multiple of
    ``BYTES_PER_CHUNK``; an empty input yields an empty list.

    :param bytez: The bytes to split.
    :return: A list of chunks, each ``BYTES_PER_CHUNK`` bytes long.
    """
    # Build a new padded object rather than using `+=`, which would extend a
    # caller-supplied bytearray in place.
    padded = bytez + b'\x00' * (-len(bytez) % BYTES_PER_CHUNK)
    return [padded[i:i + BYTES_PER_CHUNK] for i in range(0, len(padded), BYTES_PER_CHUNK)]
@ -152,12 +234,27 @@ def mix_in_length(root, length):
def infer_type(value): def infer_type(value):
"""
Note: defaults to uint64 for integer type inference due to lack of information.
Other integer sizes are still supported, see spec.
:param value: The value to infer a SSZ type for.
:return: The SSZ type.
"""
if hasattr(value.__class__, 'fields'): if hasattr(value.__class__, 'fields'):
return value.__class__ return value.__class__
elif isinstance(value, Vector): elif isinstance(value, Vector):
return [infer_type(value[0]) if len(value) > 0 else 'uint64', len(value)] if len(value) > 0:
return [infer_type(value[0]), len(value)]
else:
# Element type does not matter too much,
# assumed to be a basic type for size-encoding purposes, vector is empty.
return ['uint64']
elif isinstance(value, list): elif isinstance(value, list):
return [infer_type(value[0])] if len(value) > 0 else ['uint64'] if len(value) > 0:
return [infer_type(value[0])]
else:
# Element type does not matter, list-content size will be encoded regardless, list is empty.
return ['uint64']
elif isinstance(value, (bytes, str)): elif isinstance(value, (bytes, str)):
return 'bytes' return 'bytes'
elif isinstance(value, int): elif isinstance(value, int):
@ -169,24 +266,41 @@ def infer_type(value):
def hash_tree_root(value, typ=None):
    """
    Compute the hash tree root of ``value`` for the SSZ type ``typ``.

    The branches below choose how the helpers ``pack``, ``chunkify``,
    ``merkleize`` and ``mix_in_length`` are combined for each kind of type.

    :param value: The value to hash.
    :param typ: The SSZ type; inferred with ``infer_type`` when omitted.
    :return: The result of ``merkleize`` (wrapped in ``mix_in_length`` for
        lists and variable-size "bytes").
    :raises Exception: If ``typ`` is not a recognized type descriptor.
    """
    if typ is None:
        typ = infer_type(value)
    # -------------------------------------
    # merkleize(pack(value))
    # basic object: merkleize the packed version
    if is_basic(typ):
        return merkleize(pack([value], typ))
    # or a vector of basic objects
    elif isinstance(typ, list) and len(typ) == 2 and is_basic(typ[0]):
        assert len(value) == typ[1]
        return merkleize(pack(value, typ[0]))
    # -------------------------------------
    # mix_in_length(merkleize(pack(value)), len(value))
    # if value is a list of basic objects
    elif isinstance(typ, list) and len(typ) == 1 and is_basic(typ[0]):
        return mix_in_length(merkleize(pack(value, typ[0])), len(value))
    # (needs some extra work for non-fixed-sized bytes array)
    elif isinstance(typ, str) and typ == 'bytes':
        return mix_in_length(merkleize(chunkify(coerce_to_bytes(value))), len(value))
    # -------------------------------------
    # merkleize([hash_tree_root(element) for element in value])
    # if value is a vector of composite objects
    elif isinstance(typ, list) and len(typ) == 2 and not is_basic(typ[0]):
        # Same length sanity check as the basic-vector branch above.
        assert len(value) == typ[1]
        return merkleize([hash_tree_root(element, typ[0]) for element in value])
    # (needs some extra work for fixed-sized bytes array)
    elif isinstance(typ, str) and typ[:5] == 'bytes' and len(typ) > 5:
        assert len(value) == int(typ[5:])
        return merkleize(chunkify(coerce_to_bytes(value)))
    # or a container
    elif hasattr(typ, 'fields'):
        return merkleize([hash_tree_root(getattr(value, field), subtype) for field, subtype in typ.fields.items()])
    # -------------------------------------
    # mix_in_length(merkleize([hash_tree_root(element) for element in value]), len(value))
    # if value is a list of composite objects
    elif isinstance(typ, list) and len(typ) == 1 and not is_basic(typ[0]):
        return mix_in_length(merkleize([hash_tree_root(element, typ[0]) for element in value]), len(value))
    # -------------------------------------
    else:
        raise Exception("Type not recognized")