bugfixes and typing improvements
This commit is contained in:
parent
08e6f32f38
commit
8bd2e878ef
|
@ -44,7 +44,7 @@ def next_power_of_two(v: int) -> int:
|
||||||
return 1 << (v - 1).bit_length()
|
return 1 << (v - 1).bit_length()
|
||||||
|
|
||||||
|
|
||||||
def merkleize_chunks(chunks, pad_to: int = None):
|
def merkleize_chunks(chunks, pad_to: int = 1):
|
||||||
count = len(chunks)
|
count = len(chunks)
|
||||||
depth = max(count - 1, 0).bit_length()
|
depth = max(count - 1, 0).bit_length()
|
||||||
max_depth = max(depth, (pad_to - 1).bit_length())
|
max_depth = max(depth, (pad_to - 1).bit_length())
|
||||||
|
@ -55,7 +55,7 @@ def merkleize_chunks(chunks, pad_to: int = None):
|
||||||
while True:
|
while True:
|
||||||
if i & (1 << j) == 0:
|
if i & (1 << j) == 0:
|
||||||
if i == count and j < depth:
|
if i == count and j < depth:
|
||||||
h = hash(h + zerohashes[j])
|
h = hash(h + zerohashes[j]) # keep going if we are complementing the void to the next power of 2
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@ -63,11 +63,15 @@ def merkleize_chunks(chunks, pad_to: int = None):
|
||||||
j += 1
|
j += 1
|
||||||
tmp[j] = h
|
tmp[j] = h
|
||||||
|
|
||||||
|
# merge in leaf by leaf.
|
||||||
for i in range(count):
|
for i in range(count):
|
||||||
merge(chunks[i], i)
|
merge(chunks[i], i)
|
||||||
|
|
||||||
merge(zerohashes[0], count)
|
# complement with 0 if empty, or if not the right power of 2
|
||||||
|
if 1 << depth != count:
|
||||||
|
merge(zerohashes[0], count)
|
||||||
|
|
||||||
|
# the next power of two may be smaller than the ultimate virtual size, complement with zero-hashes at each depth.
|
||||||
for j in range(depth, max_depth):
|
for j in range(depth, max_depth):
|
||||||
tmp[j + 1] = hash(tmp[j] + zerohashes[j])
|
tmp[j + 1] = hash(tmp[j] + zerohashes[j])
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from ..merkle_minimal import merkleize_chunks
|
from ..merkle_minimal import merkleize_chunks
|
||||||
from ..hash_function import hash
|
from ..hash_function import hash
|
||||||
from .ssz_typing import (
|
from .ssz_typing import (
|
||||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint
|
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint
|
||||||
)
|
)
|
||||||
|
|
||||||
# SSZ Serialization
|
# SSZ Serialization
|
||||||
|
@ -47,8 +47,8 @@ def serialize(obj: SSZValue):
|
||||||
|
|
||||||
def encode_series(values: Series):
|
def encode_series(values: Series):
|
||||||
# bytes and bytesN are already in the right format.
|
# bytes and bytesN are already in the right format.
|
||||||
if isinstance(values, bytes):
|
if isinstance(values, (Bytes, BytesN)):
|
||||||
return values
|
return values.items
|
||||||
|
|
||||||
# Recursively serialize
|
# Recursively serialize
|
||||||
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
|
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
|
||||||
|
@ -84,8 +84,8 @@ def encode_series(values: Series):
|
||||||
|
|
||||||
|
|
||||||
def pack(values: Series):
|
def pack(values: Series):
|
||||||
if isinstance(values, bytes):
|
if isinstance(values, (Bytes, BytesN)):
|
||||||
return values
|
return values.items
|
||||||
return b''.join([serialize_basic(value) for value in values])
|
return b''.join([serialize_basic(value) for value in values])
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,8 +101,8 @@ def mix_in_length(root, length):
|
||||||
|
|
||||||
def is_bottom_layer_kind(typ: SSZType):
|
def is_bottom_layer_kind(typ: SSZType):
|
||||||
return (
|
return (
|
||||||
issubclass(typ, BasicType) or
|
isinstance(typ, BasicType) or
|
||||||
(issubclass(typ, Elements) and issubclass(typ.elem_type, BasicType))
|
(issubclass(typ, Elements) and isinstance(typ.elem_type, BasicType))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ def item_length(typ: SSZType) -> int:
|
||||||
|
|
||||||
|
|
||||||
def chunk_count(typ: SSZType) -> int:
|
def chunk_count(typ: SSZType) -> int:
|
||||||
if issubclass(typ, BasicType):
|
if isinstance(typ, BasicType):
|
||||||
return 1
|
return 1
|
||||||
elif issubclass(typ, Elements):
|
elif issubclass(typ, Elements):
|
||||||
return (typ.length * item_length(typ.elem_type) + 31) // 32
|
return (typ.length * item_length(typ.elem_type) + 31) // 32
|
||||||
|
@ -133,7 +133,7 @@ def hash_tree_root(obj: SSZValue):
|
||||||
elif isinstance(obj, BasicValue):
|
elif isinstance(obj, BasicValue):
|
||||||
leaves = chunkify(serialize_basic(obj))
|
leaves = chunkify(serialize_basic(obj))
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Type not supported: {obj.type()}")
|
raise Exception(f"Type not supported: {type(obj)}")
|
||||||
|
|
||||||
if isinstance(obj, (List, Bytes)):
|
if isinstance(obj, (List, Bytes)):
|
||||||
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj))
|
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj))
|
||||||
|
|
|
@ -92,6 +92,22 @@ class uint256(uint):
|
||||||
byte_len = 32
|
byte_len = 32
|
||||||
|
|
||||||
|
|
||||||
|
def coerce_type_maybe(v, typ: SSZType):
|
||||||
|
v_typ = type(v)
|
||||||
|
# shortcut if it's already the type we are looking for
|
||||||
|
if v_typ == typ:
|
||||||
|
return v
|
||||||
|
elif isinstance(v, int):
|
||||||
|
return typ(v)
|
||||||
|
elif isinstance(v, (list, tuple)):
|
||||||
|
return typ(*v)
|
||||||
|
elif isinstance(v, GeneratorType):
|
||||||
|
return typ(v)
|
||||||
|
else:
|
||||||
|
# just return as-is, Value-checkers will take care of it not being coerced.
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class Series(SSZValue):
|
class Series(SSZValue):
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[SSZValue]:
|
def __iter__(self) -> Iterator[SSZValue]:
|
||||||
|
@ -108,7 +124,11 @@ class Container(Series, metaclass=SSZType):
|
||||||
if f not in kwargs:
|
if f not in kwargs:
|
||||||
setattr(self, f, t.default())
|
setattr(self, f, t.default())
|
||||||
else:
|
else:
|
||||||
setattr(self, f, kwargs[f])
|
value = coerce_type_maybe(kwargs[f], t)
|
||||||
|
if not isinstance(value, t):
|
||||||
|
raise ValueCheckError(f"Bad input for class {self.__class__}:"
|
||||||
|
f" field: {f} type: {t} value: {value} value type: {type(value)}")
|
||||||
|
setattr(self, f, value)
|
||||||
|
|
||||||
def serialize(self):
|
def serialize(self):
|
||||||
from .ssz_impl import serialize
|
from .ssz_impl import serialize
|
||||||
|
@ -141,23 +161,22 @@ class Container(Series, metaclass=SSZType):
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.hash_tree_root())
|
return hash(self.hash_tree_root())
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_fields_dict(cls) -> Dict[str, SSZType]:
|
|
||||||
return dict(cls.__annotations__)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]:
|
def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]:
|
||||||
|
if not hasattr(cls, '__annotations__'): # no container fields
|
||||||
|
return ()
|
||||||
return tuple((f, t) for f, t in cls.__annotations__.items())
|
return tuple((f, t) for f, t in cls.__annotations__.items())
|
||||||
|
|
||||||
def get_typed_values(self):
|
|
||||||
return tuple(zip(self.get_field_values(), self.get_field_types()))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_field_names(cls) -> Tuple[str]:
|
def get_field_names(cls) -> Tuple[str, ...]:
|
||||||
|
if not hasattr(cls, '__annotations__'): # no container fields
|
||||||
|
return ()
|
||||||
return tuple(cls.__annotations__.keys())
|
return tuple(cls.__annotations__.keys())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_field_types(cls) -> Tuple[SSZType, ...]:
|
def get_field_types(cls) -> Tuple[SSZType, ...]:
|
||||||
|
if not hasattr(cls, '__annotations__'): # no container fields
|
||||||
|
return ()
|
||||||
# values of annotations are the types corresponding to the fields, not instance values.
|
# values of annotations are the types corresponding to the fields, not instance values.
|
||||||
return tuple(cls.__annotations__.values())
|
return tuple(cls.__annotations__.values())
|
||||||
|
|
||||||
|
@ -233,12 +252,12 @@ class ParamsMeta(SSZType):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class Elements(ParamsMeta):
|
class ElementsType(ParamsMeta):
|
||||||
elem_type: SSZType
|
elem_type: SSZType
|
||||||
length: int
|
length: int
|
||||||
|
|
||||||
|
|
||||||
class ElementsBase(ParamsBase, metaclass=Elements):
|
class Elements(ParamsBase, metaclass=ElementsType):
|
||||||
|
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
items = self.extract_args(*args)
|
items = self.extract_args(*args)
|
||||||
|
@ -256,6 +275,7 @@ class ElementsBase(ParamsBase, metaclass=Elements):
|
||||||
x = list(args)
|
x = list(args)
|
||||||
if len(x) == 1 and isinstance(x[0], GeneratorType):
|
if len(x) == 1 and isinstance(x[0], GeneratorType):
|
||||||
x = list(x[0])
|
x = list(x[0])
|
||||||
|
x = [coerce_type_maybe(v, cls.elem_type) for v in x]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -281,7 +301,7 @@ class ElementsBase(ParamsBase, metaclass=Elements):
|
||||||
return self.items == other.items
|
return self.items == other.items
|
||||||
|
|
||||||
|
|
||||||
class List(ElementsBase):
|
class List(Elements):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls):
|
def default(cls):
|
||||||
|
@ -292,7 +312,7 @@ class List(ElementsBase):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class Vector(ElementsBase):
|
class Vector(Elements):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_check(cls, value):
|
def value_check(cls, value):
|
||||||
|
@ -307,23 +327,26 @@ class Vector(ElementsBase):
|
||||||
return cls.elem_type.is_fixed_size()
|
return cls.elem_type.is_fixed_size()
|
||||||
|
|
||||||
|
|
||||||
class BytesMeta(Elements):
|
class BytesType(ElementsType):
|
||||||
elem_type: SSZType = byte
|
elem_type: SSZType = byte
|
||||||
length: int
|
length: int
|
||||||
|
|
||||||
|
|
||||||
class BytesLike(ElementsBase, metaclass=BytesMeta):
|
class BytesLike(Elements, metaclass=BytesType):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_args(cls, args):
|
def extract_args(cls, *args):
|
||||||
if isinstance(args, bytes):
|
x = list(args)
|
||||||
return args
|
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes, BytesLike)):
|
||||||
elif isinstance(args, BytesLike):
|
x = x[0]
|
||||||
return args.items
|
if isinstance(x, bytes):
|
||||||
elif isinstance(args, GeneratorType):
|
return x
|
||||||
return bytes(args)
|
elif isinstance(x, BytesLike):
|
||||||
|
return x.items
|
||||||
|
elif isinstance(x, GeneratorType):
|
||||||
|
return bytes(x)
|
||||||
else:
|
else:
|
||||||
return bytes(args)
|
return bytes(x)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_check(cls, value):
|
def value_check(cls, value):
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
from .ssz_impl import serialize, serialize_basic, encode_series, signing_root, hash_tree_root
|
||||||
|
from .ssz_typing import (
|
||||||
|
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Bit, Container, List, Vector, Bytes, BytesN,
|
||||||
|
uint, uint8, uint16, uint32, uint64, uint128, uint256, byte
|
||||||
|
)
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyTestStruct(Container):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SingleFieldTestStruct(Container):
|
||||||
|
A: byte
|
||||||
|
|
||||||
|
|
||||||
|
class SmallTestStruct(Container):
|
||||||
|
A: uint16
|
||||||
|
B: uint16
|
||||||
|
|
||||||
|
|
||||||
|
class FixedTestStruct(Container):
|
||||||
|
A: uint8
|
||||||
|
B: uint64
|
||||||
|
C: uint32
|
||||||
|
|
||||||
|
|
||||||
|
class VarTestStruct(Container):
|
||||||
|
A: uint16
|
||||||
|
B: List[uint16, 1024]
|
||||||
|
C: uint8
|
||||||
|
|
||||||
|
|
||||||
|
class ComplexTestStruct(Container):
|
||||||
|
A: uint16
|
||||||
|
B: List[uint16, 128]
|
||||||
|
C: uint8
|
||||||
|
D: Bytes[256]
|
||||||
|
E: VarTestStruct
|
||||||
|
F: Vector[FixedTestStruct, 4]
|
||||||
|
G: Vector[VarTestStruct, 2]
|
||||||
|
|
||||||
|
|
||||||
|
sig_test_data = [0 for i in range(96)]
|
||||||
|
for k, v in {0: 1, 32: 2, 64: 3, 95: 0xff}.items():
|
||||||
|
sig_test_data[k] = v
|
||||||
|
|
||||||
|
test_data = [
|
||||||
|
("bool F", Bit(False), "00"),
|
||||||
|
("bool T", Bit(True), "01"),
|
||||||
|
("uint8 00", uint8(0x00), "00"),
|
||||||
|
("uint8 01", uint8(0x01), "01"),
|
||||||
|
("uint8 ab", uint8(0xab), "ab"),
|
||||||
|
("uint16 0000", uint16(0x0000), "0000"),
|
||||||
|
("uint16 abcd", uint16(0xabcd), "cdab"),
|
||||||
|
("uint32 00000000", uint32(0x00000000), "00000000"),
|
||||||
|
("uint32 01234567", uint32(0x01234567), "67452301"),
|
||||||
|
("small (4567, 0123)", SmallTestStruct(A=0x4567, B=0x0123), "67452301"),
|
||||||
|
("small [4567, 0123]::2", Vector[uint16, 2](uint16(0x4567), uint16(0x0123)), "67452301"),
|
||||||
|
("uint32 01234567", uint32(0x01234567), "67452301"),
|
||||||
|
("uint64 0000000000000000", uint64(0x00000000), "0000000000000000"),
|
||||||
|
("uint64 0123456789abcdef", uint64(0x0123456789abcdef), "efcdab8967452301"),
|
||||||
|
("sig", BytesN[96](*sig_test_data),
|
||||||
|
"0100000000000000000000000000000000000000000000000000000000000000"
|
||||||
|
"0200000000000000000000000000000000000000000000000000000000000000"
|
||||||
|
"03000000000000000000000000000000000000000000000000000000000000ff"),
|
||||||
|
("emptyTestStruct", EmptyTestStruct(), ""),
|
||||||
|
("singleFieldTestStruct", SingleFieldTestStruct(A=0xab), "ab"),
|
||||||
|
("fixedTestStruct", FixedTestStruct(A=0xab, B=0xaabbccdd00112233, C=0x12345678), "ab33221100ddccbbaa78563412"),
|
||||||
|
("varTestStruct nil", VarTestStruct(A=0xabcd, C=0xff), "cdab07000000ff"),
|
||||||
|
("varTestStruct empty", VarTestStruct(A=0xabcd, B=List[uint16, 1024](), C=0xff), "cdab07000000ff"),
|
||||||
|
("varTestStruct some", VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
|
||||||
|
"cdab07000000ff010002000300"),
|
||||||
|
("complexTestStruct",
|
||||||
|
ComplexTestStruct(
|
||||||
|
A=0xaabb,
|
||||||
|
B=List[uint16, 128](0x1122, 0x3344),
|
||||||
|
C=0xff,
|
||||||
|
D=Bytes[256](b"foobar"),
|
||||||
|
E=VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
|
||||||
|
F=Vector[FixedTestStruct, 4](
|
||||||
|
FixedTestStruct(A=0xcc, B=0x4242424242424242, C=0x13371337),
|
||||||
|
FixedTestStruct(A=0xdd, B=0x3333333333333333, C=0xabcdabcd),
|
||||||
|
FixedTestStruct(A=0xee, B=0x4444444444444444, C=0x00112233),
|
||||||
|
FixedTestStruct(A=0xff, B=0x5555555555555555, C=0x44556677)),
|
||||||
|
G=Vector[VarTestStruct, 2](
|
||||||
|
VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
|
||||||
|
VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff)),
|
||||||
|
),
|
||||||
|
"bbaa"
|
||||||
|
"47000000" # offset of B, []uint16
|
||||||
|
"ff"
|
||||||
|
"4b000000" # offset of foobar
|
||||||
|
"51000000" # offset of E
|
||||||
|
"cc424242424242424237133713"
|
||||||
|
"dd3333333333333333cdabcdab"
|
||||||
|
"ee444444444444444433221100"
|
||||||
|
"ff555555555555555577665544"
|
||||||
|
"5e000000" # pointer to G
|
||||||
|
"22114433" # contents of B
|
||||||
|
"666f6f626172" # foobar
|
||||||
|
"cdab07000000ff010002000300" # contents of E
|
||||||
|
"08000000" "15000000" # [start G]: local offsets of [2]varTestStruct
|
||||||
|
"cdab07000000ff010002000300"
|
||||||
|
"cdab07000000ff010002000300",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name, value, serialized", test_data)
|
||||||
|
def test_serialize(name, value, serialized):
|
||||||
|
assert serialize(value) == bytes.fromhex(serialized)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name, value, _", test_data)
|
||||||
|
def test_hash_tree_root(name, value, _):
|
||||||
|
hash_tree_root(value)
|
|
@ -1,5 +1,5 @@
|
||||||
from .ssz_typing import (
|
from .ssz_typing import (
|
||||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN,
|
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Elements, Bit, Container, List, Vector, Bytes, BytesN,
|
||||||
uint, uint8, uint16, uint32, uint64, uint128, uint256
|
uint, uint8, uint16, uint32, uint64, uint128, uint256
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,7 +23,8 @@ def test_subclasses():
|
||||||
assert not isinstance(c, BasicType)
|
assert not isinstance(c, BasicType)
|
||||||
|
|
||||||
for c in [List, Vector, Bytes, BytesN]:
|
for c in [List, Vector, Bytes, BytesN]:
|
||||||
assert isinstance(c, Elements)
|
assert issubclass(c, Elements)
|
||||||
|
assert isinstance(c, ElementsType)
|
||||||
|
|
||||||
|
|
||||||
def test_basic_instances():
|
def test_basic_instances():
|
||||||
|
@ -109,7 +110,8 @@ def test_list():
|
||||||
assert issubclass(typ, List)
|
assert issubclass(typ, List)
|
||||||
assert issubclass(typ, SSZValue)
|
assert issubclass(typ, SSZValue)
|
||||||
assert issubclass(typ, Series)
|
assert issubclass(typ, Series)
|
||||||
assert isinstance(typ, Elements)
|
assert issubclass(typ, Elements)
|
||||||
|
assert isinstance(typ, ElementsType)
|
||||||
|
|
||||||
assert not typ.is_fixed_size()
|
assert not typ.is_fixed_size()
|
||||||
|
|
||||||
|
@ -128,4 +130,5 @@ def test_list():
|
||||||
assert isinstance(v, typ)
|
assert isinstance(v, typ)
|
||||||
assert isinstance(v, SSZValue)
|
assert isinstance(v, SSZValue)
|
||||||
assert isinstance(v, Series)
|
assert isinstance(v, Series)
|
||||||
assert isinstance(v.type(), Elements)
|
assert issubclass(v.type(), Elements)
|
||||||
|
assert isinstance(v.type(), ElementsType)
|
||||||
|
|
Loading…
Reference in New Issue