more improvements, and implement new space-efficient merkleization with padding support
This commit is contained in:
parent
d1ecfd510e
commit
b6cf809d9b
|
@ -44,11 +44,31 @@ def next_power_of_two(v: int) -> int:
|
|||
return 1 << (v - 1).bit_length()
|
||||
|
||||
|
||||
def merkleize_chunks(chunks):
|
||||
tree = chunks[::]
|
||||
margin = next_power_of_two(len(chunks)) - len(chunks)
|
||||
tree.extend([ZERO_BYTES32] * margin)
|
||||
tree = [ZERO_BYTES32] * len(tree) + tree
|
||||
for i in range(len(tree) // 2 - 1, 0, -1):
|
||||
tree[i] = hash(tree[i * 2] + tree[i * 2 + 1])
|
||||
return tree[1]
|
||||
def merkleize_chunks(chunks, pad_to: int = None):
|
||||
count = len(chunks)
|
||||
depth = max(count - 1, 0).bit_length()
|
||||
max_depth = max(depth, (pad_to - 1).bit_length())
|
||||
tmp = [None for _ in range(max_depth + 1)]
|
||||
|
||||
def merge(h, i):
|
||||
j = 0
|
||||
while True:
|
||||
if i & (1 << j) == 0:
|
||||
if i == count and j < depth:
|
||||
h = hash(h + zerohashes[j])
|
||||
else:
|
||||
break
|
||||
else:
|
||||
h = hash(tmp[j] + h)
|
||||
j += 1
|
||||
tmp[j] = h
|
||||
|
||||
for i in range(count):
|
||||
merge(chunks[i], i)
|
||||
|
||||
merge(zerohashes[0], count)
|
||||
|
||||
for j in range(depth, max_depth):
|
||||
tmp[j + 1] = hash(tmp[j] + zerohashes[j])
|
||||
|
||||
return tmp[max_depth]
|
||||
|
|
|
@ -1,11 +1,7 @@
|
|||
from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32
|
||||
from .hash_function import hash
|
||||
from ..hash_function import hash
|
||||
from .ssz_typing import (
|
||||
is_uint_type, is_bool_type, is_container_type,
|
||||
is_list_kind, is_vector_kind,
|
||||
read_elem_type,
|
||||
infer_input_type,
|
||||
get_zero_value,
|
||||
get_zero_value, Container, List, Vector, Bytes, BytesN, uint
|
||||
)
|
||||
|
||||
# SSZ Serialization
|
||||
|
@ -15,13 +11,13 @@ BYTES_PER_LENGTH_OFFSET = 4
|
|||
|
||||
|
||||
def is_basic_type(typ):
|
||||
return is_uint_type(typ) or is_bool_type(typ)
|
||||
return issubclass(typ, (bool, uint))
|
||||
|
||||
|
||||
def serialize_basic(value, typ):
|
||||
if is_uint_type(typ):
|
||||
if issubclass(typ, uint):
|
||||
return value.to_bytes(typ.byte_len, 'little')
|
||||
elif is_bool_type(typ):
|
||||
elif issubclass(typ, bool):
|
||||
if value:
|
||||
return b'\x01'
|
||||
else:
|
||||
|
@ -31,22 +27,34 @@ def serialize_basic(value, typ):
|
|||
|
||||
|
||||
def deserialize_basic(value, typ):
|
||||
if is_uint_type(typ):
|
||||
if issubclass(typ, uint):
|
||||
return typ(int.from_bytes(value, 'little'))
|
||||
elif is_bool_type(typ):
|
||||
elif issubclass(typ, bool):
|
||||
assert value in (b'\x00', b'\x01')
|
||||
return True if value == b'\x01' else False
|
||||
else:
|
||||
raise Exception("Type not supported: {}".format(typ))
|
||||
|
||||
|
||||
def is_list_kind(typ):
|
||||
return issubclass(typ, (List, Bytes))
|
||||
|
||||
|
||||
def is_vector_kind(typ):
|
||||
return issubclass(typ, (Vector, BytesN))
|
||||
|
||||
|
||||
def is_container_type(typ):
|
||||
return issubclass(typ, Container)
|
||||
|
||||
|
||||
def is_fixed_size(typ):
|
||||
if is_basic_type(typ):
|
||||
return True
|
||||
elif is_list_kind(typ):
|
||||
return False
|
||||
elif is_vector_kind(typ):
|
||||
return is_fixed_size(read_vector_elem_type(typ))
|
||||
return is_fixed_size(typ.elem_type)
|
||||
elif is_container_type(typ):
|
||||
return all(is_fixed_size(t) for t in typ.get_field_types())
|
||||
else:
|
||||
|
@ -57,12 +65,11 @@ def is_empty(obj):
|
|||
return get_zero_value(type(obj)) == obj
|
||||
|
||||
|
||||
@infer_input_type
|
||||
def serialize(obj, typ=None):
|
||||
def serialize(obj, typ):
|
||||
if is_basic_type(typ):
|
||||
return serialize_basic(obj, typ)
|
||||
elif is_list_kind(typ) or is_vector_kind(typ):
|
||||
return encode_series(obj, [read_elem_type(typ)] * len(obj))
|
||||
return encode_series(obj, [typ.elem_type] * len(obj))
|
||||
elif is_container_type(typ):
|
||||
return encode_series(obj.get_field_values(), typ.get_field_types())
|
||||
else:
|
||||
|
@ -126,40 +133,41 @@ def mix_in_length(root, length):
|
|||
def is_bottom_layer_kind(typ):
|
||||
return (
|
||||
is_basic_type(typ) or
|
||||
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(read_elem_type(typ))
|
||||
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(typ.elem_type)
|
||||
)
|
||||
|
||||
|
||||
@infer_input_type
|
||||
def get_typed_values(obj, typ=None):
|
||||
def get_typed_values(obj, typ):
|
||||
if is_container_type(typ):
|
||||
return obj.get_typed_values()
|
||||
elif is_list_kind(typ) or is_vector_kind(typ):
|
||||
elem_type = read_elem_type(typ)
|
||||
return list(zip(obj, [elem_type] * len(obj)))
|
||||
return list(zip(obj, [typ.elem_type] * len(obj)))
|
||||
else:
|
||||
raise Exception("Invalid type")
|
||||
|
||||
def item_length(typ):
|
||||
return 1 if typ == bool else typ.byte_len if is_uint_type(typ) else 32
|
||||
|
||||
@infer_input_type
|
||||
def hash_tree_root(obj, typ=None):
|
||||
def item_length(typ):
|
||||
if typ == bool:
|
||||
return 1
|
||||
elif issubclass(typ, uint):
|
||||
return typ.byte_len
|
||||
else:
|
||||
return 32
|
||||
|
||||
|
||||
def hash_tree_root(obj, typ):
|
||||
if is_bottom_layer_kind(typ):
|
||||
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ))
|
||||
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type)
|
||||
leaves = chunkify(data)
|
||||
else:
|
||||
fields = get_typed_values(obj, typ=typ)
|
||||
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
|
||||
if is_list_kind(typ):
|
||||
full_chunk_length = (item_length(read_elem_type(typ)) * typ.length + 31) // 32
|
||||
leaves += [ZERO_BYTES32] * (full_chunk_length - len(obj))
|
||||
return mix_in_length(merkleize_chunks(leaves), len(obj))
|
||||
return mix_in_length(merkleize_chunks(leaves, pad_to=typ.length), len(obj))
|
||||
else:
|
||||
return merkleize_chunks(leaves)
|
||||
|
||||
|
||||
@infer_input_type
|
||||
def signing_root(obj, typ):
|
||||
assert is_container_type(typ)
|
||||
# ignore last field
|
||||
|
|
|
@ -1,10 +1,26 @@
|
|||
from typing import NewType, Union
|
||||
from types import GeneratorType
|
||||
|
||||
|
||||
class ValueCheckError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DefaultingTypeMeta(type):
|
||||
def default(cls):
|
||||
raise Exception("Not implemented")
|
||||
|
||||
# Every type is subclassed and has a default() method, except bool.
|
||||
TypeWithDefault = Union[DefaultingTypeMeta, bool]
|
||||
|
||||
|
||||
def get_zero_value(typ: TypeWithDefault):
|
||||
if issubclass(typ, bool):
|
||||
return False
|
||||
else:
|
||||
return typ.default()
|
||||
|
||||
|
||||
# SSZ integers
|
||||
# -----------------------------
|
||||
|
||||
|
@ -63,7 +79,7 @@ class Container(object):
|
|||
cls = self.__class__
|
||||
for f, t in cls.get_fields():
|
||||
if f not in kwargs:
|
||||
setattr(self, f, t.default())
|
||||
setattr(self, f, get_zero_value(t))
|
||||
else:
|
||||
setattr(self, f, kwargs[f])
|
||||
|
||||
|
@ -120,7 +136,7 @@ class Container(object):
|
|||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls(**{f: t.default() for f, t in cls.get_fields()})
|
||||
return cls(**{f: get_zero_value(t) for f, t in cls.get_fields()})
|
||||
|
||||
|
||||
class ParamsBase:
|
||||
|
@ -174,12 +190,8 @@ class ParamsMeta(DefaultingTypeMeta):
|
|||
return True
|
||||
|
||||
|
||||
class ValueCheckError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AbstractListMeta(ParamsMeta):
|
||||
elem_type: DefaultingTypeMeta
|
||||
elem_type: TypeWithDefault
|
||||
length: int
|
||||
|
||||
|
||||
|
@ -227,8 +239,6 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta):
|
|||
|
||||
|
||||
class List(AbstractList):
|
||||
def value_check(self, value):
|
||||
return len(value) <= self.__class__.length and super().value_check(value)
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
|
@ -241,11 +251,11 @@ class Vector(AbstractList, metaclass=AbstractListMeta):
|
|||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return [cls.elem_type.default() for _ in range(cls.length)]
|
||||
return [get_zero_value(cls.elem_type) for _ in range(cls.length)]
|
||||
|
||||
|
||||
class BytesMeta(AbstractListMeta):
|
||||
elem_type: DefaultingTypeMeta = byte
|
||||
elem_type: TypeWithDefault = byte
|
||||
length: int
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue