bugfixes and typing improvements
This commit is contained in:
parent
08e6f32f38
commit
8bd2e878ef
|
@ -44,7 +44,7 @@ def next_power_of_two(v: int) -> int:
|
|||
return 1 << (v - 1).bit_length()
|
||||
|
||||
|
||||
def merkleize_chunks(chunks, pad_to: int = None):
|
||||
def merkleize_chunks(chunks, pad_to: int = 1):
|
||||
count = len(chunks)
|
||||
depth = max(count - 1, 0).bit_length()
|
||||
max_depth = max(depth, (pad_to - 1).bit_length())
|
||||
|
@ -55,7 +55,7 @@ def merkleize_chunks(chunks, pad_to: int = None):
|
|||
while True:
|
||||
if i & (1 << j) == 0:
|
||||
if i == count and j < depth:
|
||||
h = hash(h + zerohashes[j])
|
||||
h = hash(h + zerohashes[j]) # keep going if we are complementing the void to the next power of 2
|
||||
else:
|
||||
break
|
||||
else:
|
||||
|
@ -63,11 +63,15 @@ def merkleize_chunks(chunks, pad_to: int = None):
|
|||
j += 1
|
||||
tmp[j] = h
|
||||
|
||||
# merge in leaf by leaf.
|
||||
for i in range(count):
|
||||
merge(chunks[i], i)
|
||||
|
||||
# complement with 0 if empty, or if not the right power of 2
|
||||
if 1 << depth != count:
|
||||
merge(zerohashes[0], count)
|
||||
|
||||
# the next power of two may be smaller than the ultimate virtual size, complement with zero-hashes at each depth.
|
||||
for j in range(depth, max_depth):
|
||||
tmp[j + 1] = hash(tmp[j] + zerohashes[j])
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from ..merkle_minimal import merkleize_chunks
|
||||
from ..hash_function import hash
|
||||
from .ssz_typing import (
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint
|
||||
)
|
||||
|
||||
# SSZ Serialization
|
||||
|
@ -47,8 +47,8 @@ def serialize(obj: SSZValue):
|
|||
|
||||
def encode_series(values: Series):
|
||||
# bytes and bytesN are already in the right format.
|
||||
if isinstance(values, bytes):
|
||||
return values
|
||||
if isinstance(values, (Bytes, BytesN)):
|
||||
return values.items
|
||||
|
||||
# Recursively serialize
|
||||
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
|
||||
|
@ -84,8 +84,8 @@ def encode_series(values: Series):
|
|||
|
||||
|
||||
def pack(values: Series):
|
||||
if isinstance(values, bytes):
|
||||
return values
|
||||
if isinstance(values, (Bytes, BytesN)):
|
||||
return values.items
|
||||
return b''.join([serialize_basic(value) for value in values])
|
||||
|
||||
|
||||
|
@ -101,8 +101,8 @@ def mix_in_length(root, length):
|
|||
|
||||
def is_bottom_layer_kind(typ: SSZType):
|
||||
return (
|
||||
issubclass(typ, BasicType) or
|
||||
(issubclass(typ, Elements) and issubclass(typ.elem_type, BasicType))
|
||||
isinstance(typ, BasicType) or
|
||||
(issubclass(typ, Elements) and isinstance(typ.elem_type, BasicType))
|
||||
)
|
||||
|
||||
|
||||
|
@ -114,7 +114,7 @@ def item_length(typ: SSZType) -> int:
|
|||
|
||||
|
||||
def chunk_count(typ: SSZType) -> int:
|
||||
if issubclass(typ, BasicType):
|
||||
if isinstance(typ, BasicType):
|
||||
return 1
|
||||
elif issubclass(typ, Elements):
|
||||
return (typ.length * item_length(typ.elem_type) + 31) // 32
|
||||
|
@ -133,7 +133,7 @@ def hash_tree_root(obj: SSZValue):
|
|||
elif isinstance(obj, BasicValue):
|
||||
leaves = chunkify(serialize_basic(obj))
|
||||
else:
|
||||
raise Exception(f"Type not supported: {obj.type()}")
|
||||
raise Exception(f"Type not supported: {type(obj)}")
|
||||
|
||||
if isinstance(obj, (List, Bytes)):
|
||||
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj))
|
||||
|
|
|
@ -92,6 +92,22 @@ class uint256(uint):
|
|||
byte_len = 32
|
||||
|
||||
|
||||
def coerce_type_maybe(v, typ: SSZType):
|
||||
v_typ = type(v)
|
||||
# shortcut if it's already the type we are looking for
|
||||
if v_typ == typ:
|
||||
return v
|
||||
elif isinstance(v, int):
|
||||
return typ(v)
|
||||
elif isinstance(v, (list, tuple)):
|
||||
return typ(*v)
|
||||
elif isinstance(v, GeneratorType):
|
||||
return typ(v)
|
||||
else:
|
||||
# just return as-is, Value-checkers will take care of it not being coerced.
|
||||
return v
|
||||
|
||||
|
||||
class Series(SSZValue):
|
||||
|
||||
def __iter__(self) -> Iterator[SSZValue]:
|
||||
|
@ -108,7 +124,11 @@ class Container(Series, metaclass=SSZType):
|
|||
if f not in kwargs:
|
||||
setattr(self, f, t.default())
|
||||
else:
|
||||
setattr(self, f, kwargs[f])
|
||||
value = coerce_type_maybe(kwargs[f], t)
|
||||
if not isinstance(value, t):
|
||||
raise ValueCheckError(f"Bad input for class {self.__class__}:"
|
||||
f" field: {f} type: {t} value: {value} value type: {type(value)}")
|
||||
setattr(self, f, value)
|
||||
|
||||
def serialize(self):
|
||||
from .ssz_impl import serialize
|
||||
|
@ -141,23 +161,22 @@ class Container(Series, metaclass=SSZType):
|
|||
def __hash__(self):
|
||||
return hash(self.hash_tree_root())
|
||||
|
||||
@classmethod
|
||||
def get_fields_dict(cls) -> Dict[str, SSZType]:
|
||||
return dict(cls.__annotations__)
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]:
|
||||
if not hasattr(cls, '__annotations__'): # no container fields
|
||||
return ()
|
||||
return tuple((f, t) for f, t in cls.__annotations__.items())
|
||||
|
||||
def get_typed_values(self):
|
||||
return tuple(zip(self.get_field_values(), self.get_field_types()))
|
||||
|
||||
@classmethod
|
||||
def get_field_names(cls) -> Tuple[str]:
|
||||
def get_field_names(cls) -> Tuple[str, ...]:
|
||||
if not hasattr(cls, '__annotations__'): # no container fields
|
||||
return ()
|
||||
return tuple(cls.__annotations__.keys())
|
||||
|
||||
@classmethod
|
||||
def get_field_types(cls) -> Tuple[SSZType, ...]:
|
||||
if not hasattr(cls, '__annotations__'): # no container fields
|
||||
return ()
|
||||
# values of annotations are the types corresponding to the fields, not instance values.
|
||||
return tuple(cls.__annotations__.values())
|
||||
|
||||
|
@ -233,12 +252,12 @@ class ParamsMeta(SSZType):
|
|||
return True
|
||||
|
||||
|
||||
class Elements(ParamsMeta):
|
||||
class ElementsType(ParamsMeta):
|
||||
elem_type: SSZType
|
||||
length: int
|
||||
|
||||
|
||||
class ElementsBase(ParamsBase, metaclass=Elements):
|
||||
class Elements(ParamsBase, metaclass=ElementsType):
|
||||
|
||||
def __init__(self, *args):
|
||||
items = self.extract_args(*args)
|
||||
|
@ -256,6 +275,7 @@ class ElementsBase(ParamsBase, metaclass=Elements):
|
|||
x = list(args)
|
||||
if len(x) == 1 and isinstance(x[0], GeneratorType):
|
||||
x = list(x[0])
|
||||
x = [coerce_type_maybe(v, cls.elem_type) for v in x]
|
||||
return x
|
||||
|
||||
def __str__(self):
|
||||
|
@ -281,7 +301,7 @@ class ElementsBase(ParamsBase, metaclass=Elements):
|
|||
return self.items == other.items
|
||||
|
||||
|
||||
class List(ElementsBase):
|
||||
class List(Elements):
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
|
@ -292,7 +312,7 @@ class List(ElementsBase):
|
|||
return False
|
||||
|
||||
|
||||
class Vector(ElementsBase):
|
||||
class Vector(Elements):
|
||||
|
||||
@classmethod
|
||||
def value_check(cls, value):
|
||||
|
@ -307,23 +327,26 @@ class Vector(ElementsBase):
|
|||
return cls.elem_type.is_fixed_size()
|
||||
|
||||
|
||||
class BytesMeta(Elements):
|
||||
class BytesType(ElementsType):
|
||||
elem_type: SSZType = byte
|
||||
length: int
|
||||
|
||||
|
||||
class BytesLike(ElementsBase, metaclass=BytesMeta):
|
||||
class BytesLike(Elements, metaclass=BytesType):
|
||||
|
||||
@classmethod
|
||||
def extract_args(cls, args):
|
||||
if isinstance(args, bytes):
|
||||
return args
|
||||
elif isinstance(args, BytesLike):
|
||||
return args.items
|
||||
elif isinstance(args, GeneratorType):
|
||||
return bytes(args)
|
||||
def extract_args(cls, *args):
|
||||
x = list(args)
|
||||
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes, BytesLike)):
|
||||
x = x[0]
|
||||
if isinstance(x, bytes):
|
||||
return x
|
||||
elif isinstance(x, BytesLike):
|
||||
return x.items
|
||||
elif isinstance(x, GeneratorType):
|
||||
return bytes(x)
|
||||
else:
|
||||
return bytes(args)
|
||||
return bytes(x)
|
||||
|
||||
@classmethod
|
||||
def value_check(cls, value):
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
from .ssz_impl import serialize, serialize_basic, encode_series, signing_root, hash_tree_root
|
||||
from .ssz_typing import (
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Bit, Container, List, Vector, Bytes, BytesN,
|
||||
uint, uint8, uint16, uint32, uint64, uint128, uint256, byte
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class EmptyTestStruct(Container):
|
||||
pass
|
||||
|
||||
|
||||
class SingleFieldTestStruct(Container):
|
||||
A: byte
|
||||
|
||||
|
||||
class SmallTestStruct(Container):
|
||||
A: uint16
|
||||
B: uint16
|
||||
|
||||
|
||||
class FixedTestStruct(Container):
|
||||
A: uint8
|
||||
B: uint64
|
||||
C: uint32
|
||||
|
||||
|
||||
class VarTestStruct(Container):
|
||||
A: uint16
|
||||
B: List[uint16, 1024]
|
||||
C: uint8
|
||||
|
||||
|
||||
class ComplexTestStruct(Container):
|
||||
A: uint16
|
||||
B: List[uint16, 128]
|
||||
C: uint8
|
||||
D: Bytes[256]
|
||||
E: VarTestStruct
|
||||
F: Vector[FixedTestStruct, 4]
|
||||
G: Vector[VarTestStruct, 2]
|
||||
|
||||
|
||||
sig_test_data = [0 for i in range(96)]
|
||||
for k, v in {0: 1, 32: 2, 64: 3, 95: 0xff}.items():
|
||||
sig_test_data[k] = v
|
||||
|
||||
test_data = [
|
||||
("bool F", Bit(False), "00"),
|
||||
("bool T", Bit(True), "01"),
|
||||
("uint8 00", uint8(0x00), "00"),
|
||||
("uint8 01", uint8(0x01), "01"),
|
||||
("uint8 ab", uint8(0xab), "ab"),
|
||||
("uint16 0000", uint16(0x0000), "0000"),
|
||||
("uint16 abcd", uint16(0xabcd), "cdab"),
|
||||
("uint32 00000000", uint32(0x00000000), "00000000"),
|
||||
("uint32 01234567", uint32(0x01234567), "67452301"),
|
||||
("small (4567, 0123)", SmallTestStruct(A=0x4567, B=0x0123), "67452301"),
|
||||
("small [4567, 0123]::2", Vector[uint16, 2](uint16(0x4567), uint16(0x0123)), "67452301"),
|
||||
("uint32 01234567", uint32(0x01234567), "67452301"),
|
||||
("uint64 0000000000000000", uint64(0x00000000), "0000000000000000"),
|
||||
("uint64 0123456789abcdef", uint64(0x0123456789abcdef), "efcdab8967452301"),
|
||||
("sig", BytesN[96](*sig_test_data),
|
||||
"0100000000000000000000000000000000000000000000000000000000000000"
|
||||
"0200000000000000000000000000000000000000000000000000000000000000"
|
||||
"03000000000000000000000000000000000000000000000000000000000000ff"),
|
||||
("emptyTestStruct", EmptyTestStruct(), ""),
|
||||
("singleFieldTestStruct", SingleFieldTestStruct(A=0xab), "ab"),
|
||||
("fixedTestStruct", FixedTestStruct(A=0xab, B=0xaabbccdd00112233, C=0x12345678), "ab33221100ddccbbaa78563412"),
|
||||
("varTestStruct nil", VarTestStruct(A=0xabcd, C=0xff), "cdab07000000ff"),
|
||||
("varTestStruct empty", VarTestStruct(A=0xabcd, B=List[uint16, 1024](), C=0xff), "cdab07000000ff"),
|
||||
("varTestStruct some", VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
|
||||
"cdab07000000ff010002000300"),
|
||||
("complexTestStruct",
|
||||
ComplexTestStruct(
|
||||
A=0xaabb,
|
||||
B=List[uint16, 128](0x1122, 0x3344),
|
||||
C=0xff,
|
||||
D=Bytes[256](b"foobar"),
|
||||
E=VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
|
||||
F=Vector[FixedTestStruct, 4](
|
||||
FixedTestStruct(A=0xcc, B=0x4242424242424242, C=0x13371337),
|
||||
FixedTestStruct(A=0xdd, B=0x3333333333333333, C=0xabcdabcd),
|
||||
FixedTestStruct(A=0xee, B=0x4444444444444444, C=0x00112233),
|
||||
FixedTestStruct(A=0xff, B=0x5555555555555555, C=0x44556677)),
|
||||
G=Vector[VarTestStruct, 2](
|
||||
VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
|
||||
VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff)),
|
||||
),
|
||||
"bbaa"
|
||||
"47000000" # offset of B, []uint16
|
||||
"ff"
|
||||
"4b000000" # offset of foobar
|
||||
"51000000" # offset of E
|
||||
"cc424242424242424237133713"
|
||||
"dd3333333333333333cdabcdab"
|
||||
"ee444444444444444433221100"
|
||||
"ff555555555555555577665544"
|
||||
"5e000000" # pointer to G
|
||||
"22114433" # contents of B
|
||||
"666f6f626172" # foobar
|
||||
"cdab07000000ff010002000300" # contents of E
|
||||
"08000000" "15000000" # [start G]: local offsets of [2]varTestStruct
|
||||
"cdab07000000ff010002000300"
|
||||
"cdab07000000ff010002000300",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name, value, serialized", test_data)
|
||||
def test_serialize(name, value, serialized):
|
||||
assert serialize(value) == bytes.fromhex(serialized)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name, value, _", test_data)
|
||||
def test_hash_tree_root(name, value, _):
|
||||
hash_tree_root(value)
|
|
@ -1,5 +1,5 @@
|
|||
from .ssz_typing import (
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN,
|
||||
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Elements, Bit, Container, List, Vector, Bytes, BytesN,
|
||||
uint, uint8, uint16, uint32, uint64, uint128, uint256
|
||||
)
|
||||
|
||||
|
@ -23,7 +23,8 @@ def test_subclasses():
|
|||
assert not isinstance(c, BasicType)
|
||||
|
||||
for c in [List, Vector, Bytes, BytesN]:
|
||||
assert isinstance(c, Elements)
|
||||
assert issubclass(c, Elements)
|
||||
assert isinstance(c, ElementsType)
|
||||
|
||||
|
||||
def test_basic_instances():
|
||||
|
@ -109,7 +110,8 @@ def test_list():
|
|||
assert issubclass(typ, List)
|
||||
assert issubclass(typ, SSZValue)
|
||||
assert issubclass(typ, Series)
|
||||
assert isinstance(typ, Elements)
|
||||
assert issubclass(typ, Elements)
|
||||
assert isinstance(typ, ElementsType)
|
||||
|
||||
assert not typ.is_fixed_size()
|
||||
|
||||
|
@ -128,4 +130,5 @@ def test_list():
|
|||
assert isinstance(v, typ)
|
||||
assert isinstance(v, SSZValue)
|
||||
assert isinstance(v, Series)
|
||||
assert isinstance(v.type(), Elements)
|
||||
assert issubclass(v.type(), Elements)
|
||||
assert isinstance(v.type(), ElementsType)
|
||||
|
|
Loading…
Reference in New Issue