suggestion to implement bitfield like

This commit is contained in:
protolambda 2019-06-27 15:40:40 +02:00
parent f57387cc83
commit a5154da1ff
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
2 changed files with 27 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash
from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, boolean, Container, List, Bytes,
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bitfield, boolean, Container, List, Bytes,
Bitlist, Bitvector, uint,
)
@ -128,6 +128,8 @@ def item_length(typ: SSZType) -> int:
def chunk_count(typ: SSZType) -> int:
if isinstance(typ, BasicType):
return 1
elif issubclass(typ, Bitfield):
return (typ.length + 7) // 8 // 32
elif issubclass(typ, Elements):
return (typ.length * item_length(typ.elem_type) + 31) // 32
elif issubclass(typ, Container):

View File

@ -281,10 +281,6 @@ class ElementsType(ParamsMeta):
length: int
class BitElementsType(ElementsType):
elem_type = boolean
class Elements(ParamsBase, metaclass=ElementsType):
pass
@ -346,11 +342,16 @@ class BaseList(list, Elements):
return self[len(self) - 1]
class BaseBitfield(BaseList, metaclass=BitElementsType):
elem_type = bool
class BitElementsType(ElementsType):
elem_type: SSZType = boolean
length: int
class Bitlist(BaseBitfield):
class Bitfield(BaseList, metaclass=BitElementsType):
pass
class Bitlist(Bitfield):
@classmethod
def is_fixed_size(cls):
return False
@ -360,7 +361,20 @@ class Bitlist(BaseBitfield):
return cls()
class Bitvector(BaseBitfield):
class Bitvector(Bitfield):
@classmethod
def extract_args(cls, *args):
if len(args) == 0:
return cls.default()
else:
return super().extract_args(*args)
@classmethod
def value_check(cls, value):
# check length limit strictly
return len(value) == cls.length and super().value_check(value)
@classmethod
def is_fixed_size(cls):
return True
@ -369,6 +383,7 @@ class Bitvector(BaseBitfield):
def default(cls):
return cls(0 for _ in range(cls.length))
class List(BaseList):
@classmethod