update yaml encoder/decoder and obj randomizer for typed ssz usage

This commit is contained in:
protolambda 2019-05-27 22:19:18 +02:00
parent c68944bd53
commit 0f79ed709b
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
3 changed files with 68 additions and 71 deletions

View File

@ -3,22 +3,22 @@ from eth2spec.utils.ssz.ssz_typing import *
def decode(data, typ): def decode(data, typ):
if is_uint(typ): if is_uint_type(typ):
return data return data
elif is_bool_type(typ): elif is_bool_type(typ):
assert data in (True, False) assert data in (True, False)
return data return data
elif issubclass(typ, list): elif is_list_type(typ):
elem_typ = read_list_elem_typ(typ) elem_typ = read_list_elem_type(typ)
return [decode(element, elem_typ) for element in data] return [decode(element, elem_typ) for element in data]
elif issubclass(typ, Vector): elif is_vector_type(typ):
elem_typ = read_vector_elem_typ(typ) elem_typ = read_vector_elem_type(typ)
return Vector(decode(element, elem_typ) for element in data) return Vector(decode(element, elem_typ) for element in data)
elif issubclass(typ, bytes): elif is_bytes_type(typ):
return bytes.fromhex(data[2:]) return bytes.fromhex(data[2:])
elif issubclass(typ, BytesN): elif is_bytesn_type(typ):
return BytesN(bytes.fromhex(data[2:])) return BytesN(bytes.fromhex(data[2:]))
elif is_container_typ(typ): elif is_container_type(typ):
temp = {} temp = {}
for field, subtype in typ.get_fields(): for field, subtype in typ.get_fields():
temp[field] = decode(data[field], subtype) temp[field] = decode(data[field], subtype)

View File

@ -3,19 +3,20 @@ from eth2spec.utils.ssz.ssz_typing import *
def encode(value, typ, include_hash_tree_roots=False): def encode(value, typ, include_hash_tree_roots=False):
if is_uint(typ): if is_uint_type(typ):
# Larger uints are boxed and the class declares their byte length
if issubclass(typ, uint) and typ.byte_len > 8: if issubclass(typ, uint) and typ.byte_len > 8:
return str(value) return str(value)
return value return value
elif is_bool_type(typ): elif is_bool_type(typ):
assert value in (True, False) assert value in (True, False)
return value return value
elif issubclass(typ, list) or issubclass(typ, Vector): elif is_list_type(typ) or is_vector_type(typ):
elem_typ = read_elem_typ(typ) elem_typ = read_elem_type(typ)
return [encode(element, elem_typ, include_hash_tree_roots) for element in value] return [encode(element, elem_typ, include_hash_tree_roots) for element in value]
elif issubclass(typ, bytes): elif issubclass(typ, bytes): # both bytes and BytesN
return '0x' + value.hex() return '0x' + value.hex()
elif is_container_typ(typ): elif is_container_type(typ):
ret = {} ret = {}
for field, subtype in typ.get_fields(): for field, subtype in typ.get_fields():
field_value = getattr(value, field) field_value = getattr(value, field)

View File

@ -2,10 +2,11 @@ from random import Random
from typing import Any from typing import Any
from enum import Enum from enum import Enum
from eth2spec.utils.ssz.ssz_typing import *
from eth2spec.utils.ssz.ssz_impl import is_basic_type
UINT_SIZES = [8, 16, 32, 64, 128, 256] # in bytes
UINT_SIZES = [1, 2, 4, 8, 16, 32]
basic_types = ["uint%d" % v for v in UINT_SIZES] + ['bool', 'byte']
random_mode_names = ["random", "zero", "max", "nil", "one", "lengthy"] random_mode_names = ["random", "zero", "max", "nil", "one", "lengthy"]
@ -49,60 +50,61 @@ def get_random_ssz_object(rng: Random,
""" """
if chaos: if chaos:
mode = rng.choice(list(RandomizationMode)) mode = rng.choice(list(RandomizationMode))
if isinstance(typ, str): # Bytes array
# Bytes array if is_bytes_type(typ):
if typ == 'bytes': if mode == RandomizationMode.mode_nil_count:
if mode == RandomizationMode.mode_nil_count: return b''
return b'' if mode == RandomizationMode.mode_max_count:
if mode == RandomizationMode.mode_max_count: return get_random_bytes_list(rng, max_bytes_length)
return get_random_bytes_list(rng, max_bytes_length) if mode == RandomizationMode.mode_one_count:
if mode == RandomizationMode.mode_one_count: return get_random_bytes_list(rng, 1)
return get_random_bytes_list(rng, 1) if mode == RandomizationMode.mode_zero:
if mode == RandomizationMode.mode_zero: return b'\x00'
return b'\x00' if mode == RandomizationMode.mode_max:
if mode == RandomizationMode.mode_max: return b'\xff'
return b'\xff' return get_random_bytes_list(rng, rng.randint(0, max_bytes_length))
return get_random_bytes_list(rng, rng.randint(0, max_bytes_length)) elif is_bytesn_type(typ):
elif typ[:5] == 'bytes' and len(typ) > 5: length = typ.length
length = int(typ[5:]) # Sanity, don't generate absurdly big random values
# Sanity, don't generate absurdly big random values # If a client is aiming to performance-test, they should create a benchmark suite.
# If a client is aiming to performance-test, they should create a benchmark suite. assert length <= max_bytes_length
assert length <= max_bytes_length if mode == RandomizationMode.mode_zero:
if mode == RandomizationMode.mode_zero: return b'\x00' * length
return b'\x00' * length if mode == RandomizationMode.mode_max:
if mode == RandomizationMode.mode_max: return b'\xff' * length
return b'\xff' * length return get_random_bytes_list(rng, length)
return get_random_bytes_list(rng, length) # Basic types
# Basic types elif is_basic_type(typ):
else: if mode == RandomizationMode.mode_zero:
if mode == RandomizationMode.mode_zero: return get_min_basic_value(typ)
return get_min_basic_value(typ) if mode == RandomizationMode.mode_max:
if mode == RandomizationMode.mode_max: return get_max_basic_value(typ)
return get_max_basic_value(typ) return get_random_basic_value(rng, typ)
return get_random_basic_value(rng, typ)
# Vector: # Vector:
elif isinstance(typ, list) and len(typ) == 2: elif is_vector_type(typ):
elem_typ = read_vector_elem_type(typ)
return [ return [
get_random_ssz_object(rng, typ[0], max_bytes_length, max_list_length, mode, chaos) get_random_ssz_object(rng, elem_typ, max_bytes_length, max_list_length, mode, chaos)
for _ in range(typ[1]) for _ in range(typ.length)
] ]
# List: # List:
elif isinstance(typ, list) and len(typ) == 1: elif is_list_type(typ):
elem_typ = read_list_elem_type(typ)
length = rng.randint(0, max_list_length) length = rng.randint(0, max_list_length)
if mode == RandomizationMode.mode_one_count: if mode == RandomizationMode.mode_one_count:
length = 1 length = 1
if mode == RandomizationMode.mode_max_count: if mode == RandomizationMode.mode_max_count:
length = max_list_length length = max_list_length
return [ return [
get_random_ssz_object(rng, typ[0], max_bytes_length, max_list_length, mode, chaos) get_random_ssz_object(rng, elem_typ, max_bytes_length, max_list_length, mode, chaos)
for _ in range(length) for _ in range(length)
] ]
# Container: # Container:
elif hasattr(typ, 'fields'): elif is_container_type(typ):
return typ(**{ return typ(**{
field: field:
get_random_ssz_object(rng, subtype, max_bytes_length, max_list_length, mode, chaos) get_random_ssz_object(rng, subtype, max_bytes_length, max_list_length, mode, chaos)
for field, subtype in typ.fields.items() for field, subtype in typ.get_fields()
}) })
else: else:
print(typ) print(typ)
@ -114,39 +116,33 @@ def get_random_bytes_list(rng: Random, length: int) -> bytes:
def get_random_basic_value(rng: Random, typ: str) -> Any: def get_random_basic_value(rng: Random, typ: str) -> Any:
if typ == 'bool': if is_bool_type(typ):
return rng.choice((True, False)) return rng.choice((True, False))
if typ[:4] == 'uint': if is_uint_type(typ):
size = int(typ[4:]) size = uint_byte_size(typ)
assert size in UINT_SIZES assert size in UINT_SIZES
return rng.randint(0, 2**size - 1) return rng.randint(0, 256**size - 1)
if typ == 'byte':
return rng.randint(0, 8)
else: else:
raise ValueError("Not a basic type") raise ValueError("Not a basic type")
def get_min_basic_value(typ: str) -> Any: def get_min_basic_value(typ: str) -> Any:
if typ == 'bool': if is_bool_type(typ):
return False return False
if typ[:4] == 'uint': if is_uint_type(typ):
size = int(typ[4:]) size = uint_byte_size(typ)
assert size in UINT_SIZES assert size in UINT_SIZES
return 0 return 0
if typ == 'byte':
return 0x00
else: else:
raise ValueError("Not a basic type") raise ValueError("Not a basic type")
def get_max_basic_value(typ: str) -> Any: def get_max_basic_value(typ: str) -> Any:
if typ == 'bool': if is_bool_type(typ):
return True return True
if typ[:4] == 'uint': if is_uint_type(typ):
size = int(typ[4:]) size = uint_byte_size(typ)
assert size in UINT_SIZES assert size in UINT_SIZES
return 2**size - 1 return 256**size - 1
if typ == 'byte':
return 0xff
else: else:
raise ValueError("Not a basic type") raise ValueError("Not a basic type")