bugfixes and typing improvements

This commit is contained in:
protolambda 2019-06-17 01:39:39 +02:00
parent 08e6f32f38
commit 8bd2e878ef
5 changed files with 187 additions and 39 deletions

View File

@ -44,7 +44,7 @@ def next_power_of_two(v: int) -> int:
return 1 << (v - 1).bit_length()
def merkleize_chunks(chunks, pad_to: int = None):
def merkleize_chunks(chunks, pad_to: int = 1):
count = len(chunks)
depth = max(count - 1, 0).bit_length()
max_depth = max(depth, (pad_to - 1).bit_length())
@ -55,7 +55,7 @@ def merkleize_chunks(chunks, pad_to: int = None):
while True:
if i & (1 << j) == 0:
if i == count and j < depth:
h = hash(h + zerohashes[j])
h = hash(h + zerohashes[j]) # keep going if we are complementing the void to the next power of 2
else:
break
else:
@ -63,11 +63,15 @@ def merkleize_chunks(chunks, pad_to: int = None):
j += 1
tmp[j] = h
# merge in leaf by leaf.
for i in range(count):
merge(chunks[i], i)
merge(zerohashes[0], count)
# complement with 0 if empty, or if not the right power of 2
if 1 << depth != count:
merge(zerohashes[0], count)
# the next power of two may be smaller than the ultimate virtual size, complement with zero-hashes at each depth.
for j in range(depth, max_depth):
tmp[j + 1] = hash(tmp[j] + zerohashes[j])

View File

@ -1,7 +1,7 @@
from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Elements, Bit, Container, List, Vector, Bytes, BytesN, uint
)
# SSZ Serialization
@ -47,8 +47,8 @@ def serialize(obj: SSZValue):
def encode_series(values: Series):
# bytes and bytesN are already in the right format.
if isinstance(values, bytes):
return values
if isinstance(values, (Bytes, BytesN)):
return values.items
# Recursively serialize
parts = [(v.type().is_fixed_size(), serialize(v)) for v in values]
@ -84,8 +84,8 @@ def encode_series(values: Series):
def pack(values: Series):
if isinstance(values, bytes):
return values
if isinstance(values, (Bytes, BytesN)):
return values.items
return b''.join([serialize_basic(value) for value in values])
@ -101,8 +101,8 @@ def mix_in_length(root, length):
def is_bottom_layer_kind(typ: SSZType):
return (
issubclass(typ, BasicType) or
(issubclass(typ, Elements) and issubclass(typ.elem_type, BasicType))
isinstance(typ, BasicType) or
(issubclass(typ, Elements) and isinstance(typ.elem_type, BasicType))
)
@ -114,7 +114,7 @@ def item_length(typ: SSZType) -> int:
def chunk_count(typ: SSZType) -> int:
if issubclass(typ, BasicType):
if isinstance(typ, BasicType):
return 1
elif issubclass(typ, Elements):
return (typ.length * item_length(typ.elem_type) + 31) // 32
@ -133,7 +133,7 @@ def hash_tree_root(obj: SSZValue):
elif isinstance(obj, BasicValue):
leaves = chunkify(serialize_basic(obj))
else:
raise Exception(f"Type not supported: {obj.type()}")
raise Exception(f"Type not supported: {type(obj)}")
if isinstance(obj, (List, Bytes)):
return mix_in_length(merkleize_chunks(leaves, pad_to=chunk_count(obj.type())), len(obj))

View File

@ -92,6 +92,22 @@ class uint256(uint):
byte_len = 32
def coerce_type_maybe(v, typ: SSZType):
v_typ = type(v)
# shortcut if it's already the type we are looking for
if v_typ == typ:
return v
elif isinstance(v, int):
return typ(v)
elif isinstance(v, (list, tuple)):
return typ(*v)
elif isinstance(v, GeneratorType):
return typ(v)
else:
# just return as-is, Value-checkers will take care of it not being coerced.
return v
class Series(SSZValue):
def __iter__(self) -> Iterator[SSZValue]:
@ -108,7 +124,11 @@ class Container(Series, metaclass=SSZType):
if f not in kwargs:
setattr(self, f, t.default())
else:
setattr(self, f, kwargs[f])
value = coerce_type_maybe(kwargs[f], t)
if not isinstance(value, t):
raise ValueCheckError(f"Bad input for class {self.__class__}:"
f" field: {f} type: {t} value: {value} value type: {type(value)}")
setattr(self, f, value)
def serialize(self):
from .ssz_impl import serialize
@ -141,23 +161,22 @@ class Container(Series, metaclass=SSZType):
def __hash__(self):
return hash(self.hash_tree_root())
@classmethod
def get_fields_dict(cls) -> Dict[str, SSZType]:
return dict(cls.__annotations__)
@classmethod
def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]:
if not hasattr(cls, '__annotations__'): # no container fields
return ()
return tuple((f, t) for f, t in cls.__annotations__.items())
def get_typed_values(self):
return tuple(zip(self.get_field_values(), self.get_field_types()))
@classmethod
def get_field_names(cls) -> Tuple[str]:
def get_field_names(cls) -> Tuple[str, ...]:
if not hasattr(cls, '__annotations__'): # no container fields
return ()
return tuple(cls.__annotations__.keys())
@classmethod
def get_field_types(cls) -> Tuple[SSZType, ...]:
if not hasattr(cls, '__annotations__'): # no container fields
return ()
# values of annotations are the types corresponding to the fields, not instance values.
return tuple(cls.__annotations__.values())
@ -233,12 +252,12 @@ class ParamsMeta(SSZType):
return True
class Elements(ParamsMeta):
class ElementsType(ParamsMeta):
elem_type: SSZType
length: int
class ElementsBase(ParamsBase, metaclass=Elements):
class Elements(ParamsBase, metaclass=ElementsType):
def __init__(self, *args):
items = self.extract_args(*args)
@ -256,6 +275,7 @@ class ElementsBase(ParamsBase, metaclass=Elements):
x = list(args)
if len(x) == 1 and isinstance(x[0], GeneratorType):
x = list(x[0])
x = [coerce_type_maybe(v, cls.elem_type) for v in x]
return x
def __str__(self):
@ -281,7 +301,7 @@ class ElementsBase(ParamsBase, metaclass=Elements):
return self.items == other.items
class List(ElementsBase):
class List(Elements):
@classmethod
def default(cls):
@ -292,7 +312,7 @@ class List(ElementsBase):
return False
class Vector(ElementsBase):
class Vector(Elements):
@classmethod
def value_check(cls, value):
@ -307,23 +327,26 @@ class Vector(ElementsBase):
return cls.elem_type.is_fixed_size()
class BytesMeta(Elements):
class BytesType(ElementsType):
elem_type: SSZType = byte
length: int
class BytesLike(ElementsBase, metaclass=BytesMeta):
class BytesLike(Elements, metaclass=BytesType):
@classmethod
def extract_args(cls, args):
if isinstance(args, bytes):
return args
elif isinstance(args, BytesLike):
return args.items
elif isinstance(args, GeneratorType):
return bytes(args)
def extract_args(cls, *args):
x = list(args)
if len(x) == 1 and isinstance(x[0], (GeneratorType, bytes, BytesLike)):
x = x[0]
if isinstance(x, bytes):
return x
elif isinstance(x, BytesLike):
return x.items
elif isinstance(x, GeneratorType):
return bytes(x)
else:
return bytes(args)
return bytes(x)
@classmethod
def value_check(cls, value):

View File

@ -0,0 +1,118 @@
from .ssz_impl import serialize, serialize_basic, encode_series, signing_root, hash_tree_root
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Bit, Container, List, Vector, Bytes, BytesN,
uint, uint8, uint16, uint32, uint64, uint128, uint256, byte
)
import pytest
class EmptyTestStruct(Container):
pass
class SingleFieldTestStruct(Container):
A: byte
class SmallTestStruct(Container):
A: uint16
B: uint16
class FixedTestStruct(Container):
A: uint8
B: uint64
C: uint32
class VarTestStruct(Container):
A: uint16
B: List[uint16, 1024]
C: uint8
class ComplexTestStruct(Container):
A: uint16
B: List[uint16, 128]
C: uint8
D: Bytes[256]
E: VarTestStruct
F: Vector[FixedTestStruct, 4]
G: Vector[VarTestStruct, 2]
sig_test_data = [0 for i in range(96)]
for k, v in {0: 1, 32: 2, 64: 3, 95: 0xff}.items():
sig_test_data[k] = v
test_data = [
("bool F", Bit(False), "00"),
("bool T", Bit(True), "01"),
("uint8 00", uint8(0x00), "00"),
("uint8 01", uint8(0x01), "01"),
("uint8 ab", uint8(0xab), "ab"),
("uint16 0000", uint16(0x0000), "0000"),
("uint16 abcd", uint16(0xabcd), "cdab"),
("uint32 00000000", uint32(0x00000000), "00000000"),
("uint32 01234567", uint32(0x01234567), "67452301"),
("small (4567, 0123)", SmallTestStruct(A=0x4567, B=0x0123), "67452301"),
("small [4567, 0123]::2", Vector[uint16, 2](uint16(0x4567), uint16(0x0123)), "67452301"),
("uint32 01234567", uint32(0x01234567), "67452301"),
("uint64 0000000000000000", uint64(0x00000000), "0000000000000000"),
("uint64 0123456789abcdef", uint64(0x0123456789abcdef), "efcdab8967452301"),
("sig", BytesN[96](*sig_test_data),
"0100000000000000000000000000000000000000000000000000000000000000"
"0200000000000000000000000000000000000000000000000000000000000000"
"03000000000000000000000000000000000000000000000000000000000000ff"),
("emptyTestStruct", EmptyTestStruct(), ""),
("singleFieldTestStruct", SingleFieldTestStruct(A=0xab), "ab"),
("fixedTestStruct", FixedTestStruct(A=0xab, B=0xaabbccdd00112233, C=0x12345678), "ab33221100ddccbbaa78563412"),
("varTestStruct nil", VarTestStruct(A=0xabcd, C=0xff), "cdab07000000ff"),
("varTestStruct empty", VarTestStruct(A=0xabcd, B=List[uint16, 1024](), C=0xff), "cdab07000000ff"),
("varTestStruct some", VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
"cdab07000000ff010002000300"),
("complexTestStruct",
ComplexTestStruct(
A=0xaabb,
B=List[uint16, 128](0x1122, 0x3344),
C=0xff,
D=Bytes[256](b"foobar"),
E=VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
F=Vector[FixedTestStruct, 4](
FixedTestStruct(A=0xcc, B=0x4242424242424242, C=0x13371337),
FixedTestStruct(A=0xdd, B=0x3333333333333333, C=0xabcdabcd),
FixedTestStruct(A=0xee, B=0x4444444444444444, C=0x00112233),
FixedTestStruct(A=0xff, B=0x5555555555555555, C=0x44556677)),
G=Vector[VarTestStruct, 2](
VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff),
VarTestStruct(A=0xabcd, B=List[uint16, 1024](1, 2, 3), C=0xff)),
),
"bbaa"
"47000000" # offset of B, []uint16
"ff"
"4b000000" # offset of foobar
"51000000" # offset of E
"cc424242424242424237133713"
"dd3333333333333333cdabcdab"
"ee444444444444444433221100"
"ff555555555555555577665544"
"5e000000" # pointer to G
"22114433" # contents of B
"666f6f626172" # foobar
"cdab07000000ff010002000300" # contents of E
"08000000" "15000000" # [start G]: local offsets of [2]varTestStruct
"cdab07000000ff010002000300"
"cdab07000000ff010002000300",
)
]
@pytest.mark.parametrize("name, value, serialized", test_data)
def test_serialize(name, value, serialized):
assert serialize(value) == bytes.fromhex(serialized)
@pytest.mark.parametrize("name, value, _", test_data)
def test_hash_tree_root(name, value, _):
hash_tree_root(value)

View File

@ -1,5 +1,5 @@
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bit, Container, List, Vector, Bytes, BytesN,
SSZValue, SSZType, BasicValue, BasicType, Series, ElementsType, Elements, Bit, Container, List, Vector, Bytes, BytesN,
uint, uint8, uint16, uint32, uint64, uint128, uint256
)
@ -23,7 +23,8 @@ def test_subclasses():
assert not isinstance(c, BasicType)
for c in [List, Vector, Bytes, BytesN]:
assert isinstance(c, Elements)
assert issubclass(c, Elements)
assert isinstance(c, ElementsType)
def test_basic_instances():
@ -109,7 +110,8 @@ def test_list():
assert issubclass(typ, List)
assert issubclass(typ, SSZValue)
assert issubclass(typ, Series)
assert isinstance(typ, Elements)
assert issubclass(typ, Elements)
assert isinstance(typ, ElementsType)
assert not typ.is_fixed_size()
@ -128,4 +130,5 @@ def test_list():
assert isinstance(v, typ)
assert isinstance(v, SSZValue)
assert isinstance(v, Series)
assert isinstance(v.type(), Elements)
assert issubclass(v.type(), Elements)
assert isinstance(v.type(), ElementsType)