PR feedback, fix type hinting, add missing `Container.get_field_names()` method

This commit is contained in:
Hsiao-Wei Wang 2019-08-20 18:55:30 +08:00
parent bb0b5b09cc
commit 663d43d07f
No known key found for this signature in database
GPG Key ID: 95B070122902DEA4
3 changed files with 22 additions and 12 deletions

View File

@ -55,7 +55,6 @@ from eth2spec.utils.ssz.ssz_impl import (
) )
from eth2spec.utils.ssz.ssz_typing import ( from eth2spec.utils.ssz.ssz_typing import (
BasicValue, Elements, BaseList, SSZType, BasicValue, Elements, BaseList, SSZType,
SSZVariableName,
Container, List, Vector, Bytes, BytesN, Bitlist, Bitvector, Bits, Container, List, Vector, Bytes, BytesN, Bitlist, Bitvector, Bits,
Bytes1, Bytes4, Bytes8, Bytes32, Bytes48, Bytes96, Bytes1, Bytes4, Bytes8, Bytes32, Bytes48, Bytes96,
uint64, bit, boolean, uint64, bit, boolean,
@ -68,6 +67,9 @@ from eth2spec.utils.bls import (
) )
from eth2spec.utils.hash_function import hash from eth2spec.utils.hash_function import hash
SSZVariableName = str
''' '''
SUNDRY_CONSTANTS_FUNCTIONS = ''' SUNDRY_CONSTANTS_FUNCTIONS = '''
def ceillog2(x: uint64) -> int: def ceillog2(x: uint64) -> int:

View File

@ -38,8 +38,6 @@ def get_next_power_of_two(x: int) -> int:
""" """
if x <= 2: if x <= 2:
return x return x
elif x % 2 == 0:
return 2 * get_next_power_of_two(x // 2)
else: else:
return 2 * get_next_power_of_two((x + 1) // 2) return 2 * get_next_power_of_two((x + 1) // 2)
``` ```
@ -49,7 +47,10 @@ def get_previous_power_of_two(x: int) -> int:
""" """
Get the previous power of 2 >= the input. Get the previous power of 2 >= the input.
""" """
return x if x <= 2 else 2 * get_previous_power_of_two(x // 2) if x <= 2:
return x
else:
return 2 * get_previous_power_of_two(x // 2)
``` ```
## Generalized Merkle tree index ## Generalized Merkle tree index
@ -91,7 +92,7 @@ y_data_root len(y)
....... .......
``` ```
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`. We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`. We define `SSZVariableName` as the member variable name string, i.e., a path is presented as a sequence of integers and `SSZVariableName`.
```python ```python
def item_length(typ: SSZType) -> int: def item_length(typ: SSZType) -> int:
@ -149,7 +150,7 @@ def get_item_position(typ: SSZType, index_or_variable_name: Union[int, SSZVariab
start = index * item_length(typ.elem_type) start = index * item_length(typ.elem_type)
return start // 32, start % 32, start % 32 + item_length(typ.elem_type) return start // 32, start % 32, start % 32 + item_length(typ.elem_type)
elif issubclass(typ, Container): elif issubclass(typ, Container):
variable_name = int(index_or_variable_name) variable_name = index_or_variable_name
return typ.get_field_names().index(variable_name), 0, item_length(get_elem_type(typ, variable_name)) return typ.get_field_names().index(variable_name), 0, item_length(get_elem_type(typ, variable_name))
else: else:
raise Exception("Only lists/vectors/containers supported") raise Exception("Only lists/vectors/containers supported")
@ -161,11 +162,15 @@ def get_generalized_index(typ: SSZType, path: Sequence[Union[int, SSZVariableNam
Converts a path (eg. `[7, "foo", 3]` for `x[7].foo[3]`, `[12, "bar", "__len__"]` for Converts a path (eg. `[7, "foo", 3]` for `x[7].foo[3]`, `[12, "bar", "__len__"]` for
`len(x[12].bar)`) into the generalized index representing its position in the Merkle tree. `len(x[12].bar)`) into the generalized index representing its position in the Merkle tree.
""" """
root: Optional[GeneralizedIndex] = GeneralizedIndex(1) root = GeneralizedIndex(1)
for p in path: for p in path:
assert not issubclass(typ, BasicValue) # If we descend to a basic type, the path cannot continue further assert not issubclass(typ, BasicValue) # If we descend to a basic type, the path cannot continue further
if p == '__len__': if p == '__len__':
typ, root = uint64, root * 2 + 1 if issubclass(typ, (List, Bytes)) else None typ = uint64
if issubclass(typ, (List, Bytes)):
root = GeneralizedIndex(root * 2 + 1)
else:
return None
else: else:
pos, _, _ = get_item_position(typ, p) pos, _, _ = get_item_position(typ, p)
base_index = (GeneralizedIndex(2) if issubclass(typ, (List, Bytes)) else GeneralizedIndex(1)) base_index = (GeneralizedIndex(2) if issubclass(typ, (List, Bytes)) else GeneralizedIndex(1))

View File

@ -1,11 +1,8 @@
from typing import Dict, Iterator, NewType from typing import Dict, Iterator, Iterable
import copy import copy
from types import GeneratorType from types import GeneratorType
SSZVariableName = NewType('SSZVariableName', str)
class DefaultingTypeMeta(type): class DefaultingTypeMeta(type):
def default(cls): def default(cls):
raise Exception("Not implemented") raise Exception("Not implemented")
@ -198,6 +195,12 @@ class Container(Series, metaclass=SSZType):
return {} return {}
return dict(cls.__annotations__) return dict(cls.__annotations__)
@classmethod
def get_field_names(cls) -> Iterable[SSZType]:
if not hasattr(cls, '__annotations__'): # no container fields
return ()
return list(cls.__annotations__.keys())
@classmethod @classmethod
def default(cls): def default(cls):
return cls(**{f: t.default() for f, t in cls.get_fields().items()}) return cls(**{f: t.default() for f, t in cls.get_fields().items()})