PR feedback, fix type hinting, add missing `Container.get_field_names()` method
This commit is contained in:
parent
bb0b5b09cc
commit
663d43d07f
|
@ -55,7 +55,6 @@ from eth2spec.utils.ssz.ssz_impl import (
|
|||
)
|
||||
from eth2spec.utils.ssz.ssz_typing import (
|
||||
BasicValue, Elements, BaseList, SSZType,
|
||||
SSZVariableName,
|
||||
Container, List, Vector, Bytes, BytesN, Bitlist, Bitvector, Bits,
|
||||
Bytes1, Bytes4, Bytes8, Bytes32, Bytes48, Bytes96,
|
||||
uint64, bit, boolean,
|
||||
|
@ -68,6 +67,9 @@ from eth2spec.utils.bls import (
|
|||
)
|
||||
|
||||
from eth2spec.utils.hash_function import hash
|
||||
|
||||
|
||||
SSZVariableName = str
|
||||
'''
|
||||
SUNDRY_CONSTANTS_FUNCTIONS = '''
|
||||
def ceillog2(x: uint64) -> int:
|
||||
|
|
|
@ -38,8 +38,6 @@ def get_next_power_of_two(x: int) -> int:
|
|||
"""
|
||||
if x <= 2:
|
||||
return x
|
||||
elif x % 2 == 0:
|
||||
return 2 * get_next_power_of_two(x // 2)
|
||||
else:
|
||||
return 2 * get_next_power_of_two((x + 1) // 2)
|
||||
```
|
||||
|
@ -49,7 +47,10 @@ def get_previous_power_of_two(x: int) -> int:
|
|||
"""
|
||||
Get the previous power of 2 >= the input.
|
||||
"""
|
||||
return x if x <= 2 else 2 * get_previous_power_of_two(x // 2)
|
||||
if x <= 2:
|
||||
return x
|
||||
else:
|
||||
return 2 * get_previous_power_of_two(x // 2)
|
||||
```
|
||||
|
||||
## Generalized Merkle tree index
|
||||
|
@ -91,7 +92,7 @@ y_data_root len(y)
|
|||
.......
|
||||
```
|
||||
|
||||
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`.
|
||||
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`. We define `SSZVariableName` as the member variable name string, i.e., a path is presented as a sequence of integers and `SSZVariableName`.
|
||||
|
||||
```python
|
||||
def item_length(typ: SSZType) -> int:
|
||||
|
@ -149,7 +150,7 @@ def get_item_position(typ: SSZType, index_or_variable_name: Union[int, SSZVariab
|
|||
start = index * item_length(typ.elem_type)
|
||||
return start // 32, start % 32, start % 32 + item_length(typ.elem_type)
|
||||
elif issubclass(typ, Container):
|
||||
variable_name = int(index_or_variable_name)
|
||||
variable_name = index_or_variable_name
|
||||
return typ.get_field_names().index(variable_name), 0, item_length(get_elem_type(typ, variable_name))
|
||||
else:
|
||||
raise Exception("Only lists/vectors/containers supported")
|
||||
|
@ -161,11 +162,15 @@ def get_generalized_index(typ: SSZType, path: Sequence[Union[int, SSZVariableNam
|
|||
Converts a path (eg. `[7, "foo", 3]` for `x[7].foo[3]`, `[12, "bar", "__len__"]` for
|
||||
`len(x[12].bar)`) into the generalized index representing its position in the Merkle tree.
|
||||
"""
|
||||
root: Optional[GeneralizedIndex] = GeneralizedIndex(1)
|
||||
root = GeneralizedIndex(1)
|
||||
for p in path:
|
||||
assert not issubclass(typ, BasicValue) # If we descend to a basic type, the path cannot continue further
|
||||
if p == '__len__':
|
||||
typ, root = uint64, root * 2 + 1 if issubclass(typ, (List, Bytes)) else None
|
||||
typ = uint64
|
||||
if issubclass(typ, (List, Bytes)):
|
||||
root = GeneralizedIndex(root * 2 + 1)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
pos, _, _ = get_item_position(typ, p)
|
||||
base_index = (GeneralizedIndex(2) if issubclass(typ, (List, Bytes)) else GeneralizedIndex(1))
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
from typing import Dict, Iterator, NewType
|
||||
from typing import Dict, Iterator, Iterable
|
||||
import copy
|
||||
from types import GeneratorType
|
||||
|
||||
|
||||
SSZVariableName = NewType('SSZVariableName', str)
|
||||
|
||||
|
||||
class DefaultingTypeMeta(type):
|
||||
def default(cls):
|
||||
raise Exception("Not implemented")
|
||||
|
@ -198,6 +195,12 @@ class Container(Series, metaclass=SSZType):
|
|||
return {}
|
||||
return dict(cls.__annotations__)
|
||||
|
||||
@classmethod
|
||||
def get_field_names(cls) -> Iterable[SSZType]:
|
||||
if not hasattr(cls, '__annotations__'): # no container fields
|
||||
return ()
|
||||
return list(cls.__annotations__.keys())
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls(**{f: t.default() for f, t in cls.get_fields().items()})
|
||||
|
|
Loading…
Reference in New Issue