more improvements, and implement new space-efficient merkleization with padding support
This commit is contained in:
parent
d1ecfd510e
commit
b6cf809d9b
|
@ -44,11 +44,31 @@ def next_power_of_two(v: int) -> int:
|
||||||
return 1 << (v - 1).bit_length()
|
return 1 << (v - 1).bit_length()
|
||||||
|
|
||||||
|
|
||||||
def merkleize_chunks(chunks):
|
def merkleize_chunks(chunks, pad_to: int = None):
|
||||||
tree = chunks[::]
|
count = len(chunks)
|
||||||
margin = next_power_of_two(len(chunks)) - len(chunks)
|
depth = max(count - 1, 0).bit_length()
|
||||||
tree.extend([ZERO_BYTES32] * margin)
|
max_depth = max(depth, (pad_to - 1).bit_length())
|
||||||
tree = [ZERO_BYTES32] * len(tree) + tree
|
tmp = [None for _ in range(max_depth + 1)]
|
||||||
for i in range(len(tree) // 2 - 1, 0, -1):
|
|
||||||
tree[i] = hash(tree[i * 2] + tree[i * 2 + 1])
|
def merge(h, i):
|
||||||
return tree[1]
|
j = 0
|
||||||
|
while True:
|
||||||
|
if i & (1 << j) == 0:
|
||||||
|
if i == count and j < depth:
|
||||||
|
h = hash(h + zerohashes[j])
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
h = hash(tmp[j] + h)
|
||||||
|
j += 1
|
||||||
|
tmp[j] = h
|
||||||
|
|
||||||
|
for i in range(count):
|
||||||
|
merge(chunks[i], i)
|
||||||
|
|
||||||
|
merge(zerohashes[0], count)
|
||||||
|
|
||||||
|
for j in range(depth, max_depth):
|
||||||
|
tmp[j + 1] = hash(tmp[j] + zerohashes[j])
|
||||||
|
|
||||||
|
return tmp[max_depth]
|
||||||
|
|
|
@ -1,11 +1,7 @@
|
||||||
from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32
|
from ..merkle_minimal import merkleize_chunks, ZERO_BYTES32
|
||||||
from .hash_function import hash
|
from ..hash_function import hash
|
||||||
from .ssz_typing import (
|
from .ssz_typing import (
|
||||||
is_uint_type, is_bool_type, is_container_type,
|
get_zero_value, Container, List, Vector, Bytes, BytesN, uint
|
||||||
is_list_kind, is_vector_kind,
|
|
||||||
read_elem_type,
|
|
||||||
infer_input_type,
|
|
||||||
get_zero_value,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# SSZ Serialization
|
# SSZ Serialization
|
||||||
|
@ -15,13 +11,13 @@ BYTES_PER_LENGTH_OFFSET = 4
|
||||||
|
|
||||||
|
|
||||||
def is_basic_type(typ):
|
def is_basic_type(typ):
|
||||||
return is_uint_type(typ) or is_bool_type(typ)
|
return issubclass(typ, (bool, uint))
|
||||||
|
|
||||||
|
|
||||||
def serialize_basic(value, typ):
|
def serialize_basic(value, typ):
|
||||||
if is_uint_type(typ):
|
if issubclass(typ, uint):
|
||||||
return value.to_bytes(typ.byte_len, 'little')
|
return value.to_bytes(typ.byte_len, 'little')
|
||||||
elif is_bool_type(typ):
|
elif issubclass(typ, bool):
|
||||||
if value:
|
if value:
|
||||||
return b'\x01'
|
return b'\x01'
|
||||||
else:
|
else:
|
||||||
|
@ -31,22 +27,34 @@ def serialize_basic(value, typ):
|
||||||
|
|
||||||
|
|
||||||
def deserialize_basic(value, typ):
|
def deserialize_basic(value, typ):
|
||||||
if is_uint_type(typ):
|
if issubclass(typ, uint):
|
||||||
return typ(int.from_bytes(value, 'little'))
|
return typ(int.from_bytes(value, 'little'))
|
||||||
elif is_bool_type(typ):
|
elif issubclass(typ, bool):
|
||||||
assert value in (b'\x00', b'\x01')
|
assert value in (b'\x00', b'\x01')
|
||||||
return True if value == b'\x01' else False
|
return True if value == b'\x01' else False
|
||||||
else:
|
else:
|
||||||
raise Exception("Type not supported: {}".format(typ))
|
raise Exception("Type not supported: {}".format(typ))
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_kind(typ):
|
||||||
|
return issubclass(typ, (List, Bytes))
|
||||||
|
|
||||||
|
|
||||||
|
def is_vector_kind(typ):
|
||||||
|
return issubclass(typ, (Vector, BytesN))
|
||||||
|
|
||||||
|
|
||||||
|
def is_container_type(typ):
|
||||||
|
return issubclass(typ, Container)
|
||||||
|
|
||||||
|
|
||||||
def is_fixed_size(typ):
|
def is_fixed_size(typ):
|
||||||
if is_basic_type(typ):
|
if is_basic_type(typ):
|
||||||
return True
|
return True
|
||||||
elif is_list_kind(typ):
|
elif is_list_kind(typ):
|
||||||
return False
|
return False
|
||||||
elif is_vector_kind(typ):
|
elif is_vector_kind(typ):
|
||||||
return is_fixed_size(read_vector_elem_type(typ))
|
return is_fixed_size(typ.elem_type)
|
||||||
elif is_container_type(typ):
|
elif is_container_type(typ):
|
||||||
return all(is_fixed_size(t) for t in typ.get_field_types())
|
return all(is_fixed_size(t) for t in typ.get_field_types())
|
||||||
else:
|
else:
|
||||||
|
@ -57,12 +65,11 @@ def is_empty(obj):
|
||||||
return get_zero_value(type(obj)) == obj
|
return get_zero_value(type(obj)) == obj
|
||||||
|
|
||||||
|
|
||||||
@infer_input_type
|
def serialize(obj, typ):
|
||||||
def serialize(obj, typ=None):
|
|
||||||
if is_basic_type(typ):
|
if is_basic_type(typ):
|
||||||
return serialize_basic(obj, typ)
|
return serialize_basic(obj, typ)
|
||||||
elif is_list_kind(typ) or is_vector_kind(typ):
|
elif is_list_kind(typ) or is_vector_kind(typ):
|
||||||
return encode_series(obj, [read_elem_type(typ)] * len(obj))
|
return encode_series(obj, [typ.elem_type] * len(obj))
|
||||||
elif is_container_type(typ):
|
elif is_container_type(typ):
|
||||||
return encode_series(obj.get_field_values(), typ.get_field_types())
|
return encode_series(obj.get_field_values(), typ.get_field_types())
|
||||||
else:
|
else:
|
||||||
|
@ -126,40 +133,41 @@ def mix_in_length(root, length):
|
||||||
def is_bottom_layer_kind(typ):
|
def is_bottom_layer_kind(typ):
|
||||||
return (
|
return (
|
||||||
is_basic_type(typ) or
|
is_basic_type(typ) or
|
||||||
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(read_elem_type(typ))
|
(is_list_kind(typ) or is_vector_kind(typ)) and is_basic_type(typ.elem_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@infer_input_type
|
def get_typed_values(obj, typ):
|
||||||
def get_typed_values(obj, typ=None):
|
|
||||||
if is_container_type(typ):
|
if is_container_type(typ):
|
||||||
return obj.get_typed_values()
|
return obj.get_typed_values()
|
||||||
elif is_list_kind(typ) or is_vector_kind(typ):
|
elif is_list_kind(typ) or is_vector_kind(typ):
|
||||||
elem_type = read_elem_type(typ)
|
return list(zip(obj, [typ.elem_type] * len(obj)))
|
||||||
return list(zip(obj, [elem_type] * len(obj)))
|
|
||||||
else:
|
else:
|
||||||
raise Exception("Invalid type")
|
raise Exception("Invalid type")
|
||||||
|
|
||||||
def item_length(typ):
|
|
||||||
return 1 if typ == bool else typ.byte_len if is_uint_type(typ) else 32
|
|
||||||
|
|
||||||
@infer_input_type
|
def item_length(typ):
|
||||||
def hash_tree_root(obj, typ=None):
|
if typ == bool:
|
||||||
|
return 1
|
||||||
|
elif issubclass(typ, uint):
|
||||||
|
return typ.byte_len
|
||||||
|
else:
|
||||||
|
return 32
|
||||||
|
|
||||||
|
|
||||||
|
def hash_tree_root(obj, typ):
|
||||||
if is_bottom_layer_kind(typ):
|
if is_bottom_layer_kind(typ):
|
||||||
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ))
|
data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, typ.elem_type)
|
||||||
leaves = chunkify(data)
|
leaves = chunkify(data)
|
||||||
else:
|
else:
|
||||||
fields = get_typed_values(obj, typ=typ)
|
fields = get_typed_values(obj, typ=typ)
|
||||||
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
|
leaves = [hash_tree_root(field_value, typ=field_typ) for field_value, field_typ in fields]
|
||||||
if is_list_kind(typ):
|
if is_list_kind(typ):
|
||||||
full_chunk_length = (item_length(read_elem_type(typ)) * typ.length + 31) // 32
|
return mix_in_length(merkleize_chunks(leaves, pad_to=typ.length), len(obj))
|
||||||
leaves += [ZERO_BYTES32] * (full_chunk_length - len(obj))
|
|
||||||
return mix_in_length(merkleize_chunks(leaves), len(obj))
|
|
||||||
else:
|
else:
|
||||||
return merkleize_chunks(leaves)
|
return merkleize_chunks(leaves)
|
||||||
|
|
||||||
|
|
||||||
@infer_input_type
|
|
||||||
def signing_root(obj, typ):
|
def signing_root(obj, typ):
|
||||||
assert is_container_type(typ)
|
assert is_container_type(typ)
|
||||||
# ignore last field
|
# ignore last field
|
||||||
|
|
|
@ -1,10 +1,26 @@
|
||||||
|
from typing import NewType, Union
|
||||||
from types import GeneratorType
|
from types import GeneratorType
|
||||||
|
|
||||||
|
|
||||||
|
class ValueCheckError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DefaultingTypeMeta(type):
|
class DefaultingTypeMeta(type):
|
||||||
def default(cls):
|
def default(cls):
|
||||||
raise Exception("Not implemented")
|
raise Exception("Not implemented")
|
||||||
|
|
||||||
|
# Every type is subclassed and has a default() method, except bool.
|
||||||
|
TypeWithDefault = Union[DefaultingTypeMeta, bool]
|
||||||
|
|
||||||
|
|
||||||
|
def get_zero_value(typ: TypeWithDefault):
|
||||||
|
if issubclass(typ, bool):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return typ.default()
|
||||||
|
|
||||||
|
|
||||||
# SSZ integers
|
# SSZ integers
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
@ -63,7 +79,7 @@ class Container(object):
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
for f, t in cls.get_fields():
|
for f, t in cls.get_fields():
|
||||||
if f not in kwargs:
|
if f not in kwargs:
|
||||||
setattr(self, f, t.default())
|
setattr(self, f, get_zero_value(t))
|
||||||
else:
|
else:
|
||||||
setattr(self, f, kwargs[f])
|
setattr(self, f, kwargs[f])
|
||||||
|
|
||||||
|
@ -120,7 +136,7 @@ class Container(object):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls):
|
def default(cls):
|
||||||
return cls(**{f: t.default() for f, t in cls.get_fields()})
|
return cls(**{f: get_zero_value(t) for f, t in cls.get_fields()})
|
||||||
|
|
||||||
|
|
||||||
class ParamsBase:
|
class ParamsBase:
|
||||||
|
@ -174,12 +190,8 @@ class ParamsMeta(DefaultingTypeMeta):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ValueCheckError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractListMeta(ParamsMeta):
|
class AbstractListMeta(ParamsMeta):
|
||||||
elem_type: DefaultingTypeMeta
|
elem_type: TypeWithDefault
|
||||||
length: int
|
length: int
|
||||||
|
|
||||||
|
|
||||||
|
@ -227,8 +239,6 @@ class AbstractList(ParamsBase, metaclass=AbstractListMeta):
|
||||||
|
|
||||||
|
|
||||||
class List(AbstractList):
|
class List(AbstractList):
|
||||||
def value_check(self, value):
|
|
||||||
return len(value) <= self.__class__.length and super().value_check(value)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls):
|
def default(cls):
|
||||||
|
@ -241,11 +251,11 @@ class Vector(AbstractList, metaclass=AbstractListMeta):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls):
|
def default(cls):
|
||||||
return [cls.elem_type.default() for _ in range(cls.length)]
|
return [get_zero_value(cls.elem_type) for _ in range(cls.length)]
|
||||||
|
|
||||||
|
|
||||||
class BytesMeta(AbstractListMeta):
|
class BytesMeta(AbstractListMeta):
|
||||||
elem_type: DefaultingTypeMeta = byte
|
elem_type: TypeWithDefault = byte
|
||||||
length: int
|
length: int
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue