Merge pull request #1024 from ethereum/sos_ssz_py

implement SOS serialization in minimal_ssz.py
This commit is contained in:
Danny Ryan 2019-05-03 15:43:07 -06:00 committed by GitHub
commit c011feb3c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 24 deletions

View File

@ -3,7 +3,7 @@ from typing import Any
from .hash_function import hash from .hash_function import hash
BYTES_PER_CHUNK = 32 BYTES_PER_CHUNK = 32
BYTES_PER_LENGTH_PREFIX = 4 BYTES_PER_LENGTH_OFFSET = 4
ZERO_CHUNK = b'\x00' * BYTES_PER_CHUNK ZERO_CHUNK = b'\x00' * BYTES_PER_CHUNK
@ -111,19 +111,34 @@ def coerce_to_bytes(x):
raise Exception("Expecting bytes") raise Exception("Expecting bytes")
def encode_bytes(value): def encode_series(values, types):
serialized_bytes = coerce_to_bytes(value) # Recursively serialize
assert len(serialized_bytes) < 2 ** (8 * BYTES_PER_LENGTH_PREFIX) parts = [(is_constant_sized(types[i]), serialize_value(values[i], types[i])) for i in range(len(values))]
serialized_length = len(serialized_bytes).to_bytes(BYTES_PER_LENGTH_PREFIX, 'little')
return serialized_length + serialized_bytes
# Compute and check lengths
fixed_lengths = [len(serialized) if constant_size else BYTES_PER_LENGTH_OFFSET
for (constant_size, serialized) in parts]
variable_lengths = [len(serialized) if not constant_size else 0
for (constant_size, serialized) in parts]
def encode_variable_size_container(values, types): # Check if integer is not out of bounds (Python)
return encode_bytes(encode_fixed_size_container(values, types)) assert sum(fixed_lengths + variable_lengths) < 2 ** (BYTES_PER_LENGTH_OFFSET * 8)
# Interleave offsets of variable-size parts with fixed-size parts.
# Avoid quadratic complexity in calculation of offsets.
offset = sum(fixed_lengths)
variable_parts = []
fixed_parts = []
for (constant_size, serialized) in parts:
if constant_size:
fixed_parts.append(serialized)
else:
fixed_parts.append(offset.to_bytes(BYTES_PER_LENGTH_OFFSET, 'little'))
variable_parts.append(serialized)
offset += len(serialized)
def encode_fixed_size_container(values, types): # Return the concatenation of the fixed-size parts (offsets interleaved) with the variable-size parts
return b''.join([serialize_value(v, typ) for (v, typ) in zip(values, types)]) return b"".join(fixed_parts + variable_parts)
def serialize_value(value, typ=None): def serialize_value(value, typ=None):
@ -142,18 +157,13 @@ def serialize_value(value, typ=None):
elif isinstance(typ, list) and len(typ) == 2: elif isinstance(typ, list) and len(typ) == 2:
# (regardless of element type, sanity-check if the length reported in the vector type matches the value length) # (regardless of element type, sanity-check if the length reported in the vector type matches the value length)
assert len(value) == typ[1] assert len(value) == typ[1]
# If value is fixed-size (i.e. element type is fixed-size): return encode_series(value, [typ[0]] * len(value))
if is_constant_sized(typ):
return encode_fixed_size_container(value, [typ[0]] * len(value))
# If value is variable-size (i.e. element type is variable-size)
else:
return encode_variable_size_container(value, [typ[0]] * len(value))
# "bytes" (variable size)
elif isinstance(typ, str) and typ == 'bytes':
return encode_bytes(value)
# List # List
elif isinstance(typ, list) and len(typ) == 1: elif isinstance(typ, list) and len(typ) == 1:
return encode_variable_size_container(value, [typ[0]] * len(value)) return encode_series(value, [typ[0]] * len(value))
# "bytes" (variable size)
elif isinstance(typ, str) and typ == 'bytes':
return coerce_to_bytes(value)
# "bytesN" (fixed size) # "bytesN" (fixed size)
elif isinstance(typ, str) and len(typ) > 5 and typ[:5] == 'bytes': elif isinstance(typ, str) and len(typ) > 5 and typ[:5] == 'bytes':
assert len(value) == int(typ[5:]), (value, int(typ[5:])) assert len(value) == int(typ[5:]), (value, int(typ[5:]))
@ -162,10 +172,7 @@ def serialize_value(value, typ=None):
elif hasattr(typ, 'fields'): elif hasattr(typ, 'fields'):
values = [getattr(value, field) for field in typ.fields.keys()] values = [getattr(value, field) for field in typ.fields.keys()]
types = list(typ.fields.values()) types = list(typ.fields.values())
if is_constant_sized(typ): return encode_series(values, types)
return encode_fixed_size_container(values, types)
else:
return encode_variable_size_container(values, types)
else: else:
print(value, typ) print(value, typ)
raise Exception("Type not recognized") raise Exception("Type not recognized")