diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index 7298fb3ca..f0ee944bd 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -1,7 +1,7 @@ from ..merkle_minimal import merkleize_chunks from ..hash_function import hash from .ssz_typing import ( - SSZValue, SSZType, BasicValue, BasicType, Series, Elements, boolean, Container, List, Bytes, + SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bitfield, boolean, Container, List, Bytes, Bitlist, Bitvector, uint, ) @@ -128,6 +128,8 @@ def item_length(typ: SSZType) -> int: def chunk_count(typ: SSZType) -> int: if isinstance(typ, BasicType): return 1 + elif issubclass(typ, Bitfield): + return (typ.length + 7) // 8 // 32 elif issubclass(typ, Elements): return (typ.length * item_length(typ.elem_type) + 31) // 32 elif issubclass(typ, Container): diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index 6ce2b1538..53ab42743 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -281,10 +281,6 @@ class ElementsType(ParamsMeta): length: int -class BitElementsType(ElementsType): - elem_type = boolean - - class Elements(ParamsBase, metaclass=ElementsType): pass @@ -346,11 +342,16 @@ class BaseList(list, Elements): return self[len(self) - 1] -class BaseBitfield(BaseList, metaclass=BitElementsType): - elem_type = bool +class BitElementsType(ElementsType): + elem_type: SSZType = boolean + length: int -class Bitlist(BaseBitfield): +class Bitfield(BaseList, metaclass=BitElementsType): + pass + + +class Bitlist(Bitfield): @classmethod def is_fixed_size(cls): return False @@ -360,15 +361,29 @@ class Bitlist(BaseBitfield): return cls() -class Bitvector(BaseBitfield): +class Bitvector(Bitfield): + + @classmethod + def extract_args(cls, *args): + if len(args) == 0: + return cls.default() + else: + return super().extract_args(*args) + + @classmethod + def value_check(cls, value): + # check length limit strictly + return len(value) == cls.length and super().value_check(value) + @classmethod def is_fixed_size(cls): return True - + @classmethod def default(cls): return cls(0 for _ in range(cls.length)) + class List(BaseList): @classmethod