suggestion to implement bitfield like

This commit is contained in:
protolambda 2019-06-27 15:40:40 +02:00
parent f57387cc83
commit a5154da1ff
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
2 changed files with 27 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from ..merkle_minimal import merkleize_chunks from ..merkle_minimal import merkleize_chunks
from ..hash_function import hash from ..hash_function import hash
from .ssz_typing import ( from .ssz_typing import (
SSZValue, SSZType, BasicValue, BasicType, Series, Elements, boolean, Container, List, Bytes, SSZValue, SSZType, BasicValue, BasicType, Series, Elements, Bitfield, boolean, Container, List, Bytes,
Bitlist, Bitvector, uint, Bitlist, Bitvector, uint,
) )
@ -128,6 +128,8 @@ def item_length(typ: SSZType) -> int:
def chunk_count(typ: SSZType) -> int: def chunk_count(typ: SSZType) -> int:
if isinstance(typ, BasicType): if isinstance(typ, BasicType):
return 1 return 1
elif issubclass(typ, Bitfield):
return (typ.length + 7) // 8 // 32
elif issubclass(typ, Elements): elif issubclass(typ, Elements):
return (typ.length * item_length(typ.elem_type) + 31) // 32 return (typ.length * item_length(typ.elem_type) + 31) // 32
elif issubclass(typ, Container): elif issubclass(typ, Container):

View File

@ -281,10 +281,6 @@ class ElementsType(ParamsMeta):
length: int length: int
class BitElementsType(ElementsType):
elem_type = boolean
class Elements(ParamsBase, metaclass=ElementsType): class Elements(ParamsBase, metaclass=ElementsType):
pass pass
@ -346,11 +342,16 @@ class BaseList(list, Elements):
return self[len(self) - 1] return self[len(self) - 1]
class BaseBitfield(BaseList, metaclass=BitElementsType): class BitElementsType(ElementsType):
elem_type = bool elem_type: SSZType = boolean
length: int
class Bitlist(BaseBitfield): class Bitfield(BaseList, metaclass=BitElementsType):
pass
class Bitlist(Bitfield):
@classmethod @classmethod
def is_fixed_size(cls): def is_fixed_size(cls):
return False return False
@ -360,7 +361,20 @@ class Bitlist(BaseBitfield):
return cls() return cls()
class Bitvector(BaseBitfield): class Bitvector(Bitfield):
@classmethod
def extract_args(cls, *args):
if len(args) == 0:
return cls.default()
else:
return super().extract_args(*args)
@classmethod
def value_check(cls, value):
# check length limit strictly
return len(value) == cls.length and super().value_check(value)
@classmethod @classmethod
def is_fixed_size(cls): def is_fixed_size(cls):
return True return True
@ -369,6 +383,7 @@ class Bitvector(BaseBitfield):
def default(cls): def default(cls):
return cls(0 for _ in range(cls.length)) return cls(0 for _ in range(cls.length))
class List(BaseList): class List(BaseList):
@classmethod @classmethod