Fix deprecated utility code, avoid wrong helper function name, add tests

This commit is contained in:
protolambda 2020-06-17 22:34:43 +02:00
parent 7117d2e75a
commit 3f765f55ca
No known key found for this signature in database
GPG Key ID: EC89FDBB2B4C7623
2 changed files with 75 additions and 14 deletions

View File

@ -26,25 +26,34 @@
## Helper functions ## Helper functions
```python ```python
def get_next_power_of_two(x: int) -> int: def get_power_of_two_ceil(x: int) -> int:
""" """
Get next power of 2 >= the input. Get the power of 2 for given input, or the closest higher power of 2 if the input is not a power of 2.
Commonly used for "how many nodes do I need for a bottom tree layer fitting x elements?"
Example: 0->1, 1->1, 2->2, 3->4, 4->4, 5->8, 6->8, 7->8, 8->8, 9->16.
""" """
if x <= 2: if x <= 1:
return x return 1
elif x == 2:
return 2
else: else:
return 2 * get_next_power_of_two((x + 1) // 2) return 2 * get_power_of_two_ceil((x + 1) // 2)
``` ```
```python ```python
def get_previous_power_of_two(x: int) -> int: def get_power_of_two_floor(x: int) -> int:
""" """
Get the previous power of 2 >= the input. Get the power of 2 for given input, or the closest lower power of 2 if the input is not a power of 2.
The zero case is a placeholder and not used for math with generalized indices.
Commonly used for "what power of two makes up the root bit of the generalized index?"
Example: 0->1, 1->1, 2->2, 3->2, 4->4, 5->4, 6->4, 7->4, 8->8, 9->8
""" """
if x <= 2: if x <= 1:
return 1
if x == 2:
return x return x
else: else:
return 2 * get_previous_power_of_two(x // 2) return 2 * get_power_of_two_floor(x // 2)
``` ```
## Generalized Merkle tree index ## Generalized Merkle tree index
@ -62,9 +71,14 @@ Note that the generalized index has the convenient property that the two childre
```python ```python
def merkle_tree(leaves: Sequence[Bytes32]) -> Sequence[Bytes32]: def merkle_tree(leaves: Sequence[Bytes32]) -> Sequence[Bytes32]:
padded_length = get_next_power_of_two(len(leaves)) """
o = [Bytes32()] * padded_length + list(leaves) + [Bytes32()] * (padded_length - len(leaves)) Return an array representing the tree nodes by generalized index:
for i in range(padded_length - 1, 0, -1): [0, 1, 2, 3, 4, 5, 6, 7], where each layer is a power of 2. The 0 index is ignored. The 1 index is the root.
The result will be twice the size as the padded bottom layer for the input leaves.
"""
bottom_length = get_power_of_two_ceil(len(leaves))
o = [Bytes32()] * bottom_length + list(leaves) + [Bytes32()] * (bottom_length - len(leaves))
for i in range(bottom_length - 1, 0, -1):
o[i] = hash(o[i * 2] + o[i * 2 + 1]) o[i] = hash(o[i * 2] + o[i * 2 + 1])
return o return o
``` ```
@ -169,7 +183,7 @@ def get_generalized_index(typ: SSZType, path: Sequence[Union[int, SSZVariableNam
else: else:
pos, _, _ = get_item_position(typ, p) pos, _, _ = get_item_position(typ, p)
base_index = (GeneralizedIndex(2) if issubclass(typ, (List, ByteList)) else GeneralizedIndex(1)) base_index = (GeneralizedIndex(2) if issubclass(typ, (List, ByteList)) else GeneralizedIndex(1))
root = GeneralizedIndex(root * base_index * get_next_power_of_two(chunk_count(typ)) + pos) root = GeneralizedIndex(root * base_index * get_power_of_two_ceil(chunk_count(typ)) + pos)
typ = get_elem_type(typ, p) typ = get_elem_type(typ, p)
return root return root
``` ```
@ -188,7 +202,7 @@ def concat_generalized_indices(*indices: GeneralizedIndex) -> GeneralizedIndex:
""" """
o = GeneralizedIndex(1) o = GeneralizedIndex(1)
for i in indices: for i in indices:
o = GeneralizedIndex(o * get_previous_power_of_two(i) + (i - get_previous_power_of_two(i))) o = GeneralizedIndex(o * get_power_of_two_floor(i) + (i - get_power_of_two_floor(i)))
return o return o
``` ```

View File

@ -0,0 +1,47 @@
import pytest
# Note: these functions are extract from merkle-proofs.md (deprecated),
# the tests are temporary to show correctness while the document is still there.
def get_power_of_two_ceil(x: int) -> int:
if x <= 1:
return 1
elif x == 2:
return 2
else:
return 2 * get_power_of_two_ceil((x + 1) // 2)
def get_power_of_two_floor(x: int) -> int:
if x <= 1:
return 1
if x == 2:
return x
else:
return 2 * get_power_of_two_floor(x // 2)
power_of_two_ceil_cases = [
(0, 1), (1, 1), (2, 2), (3, 4), (4, 4), (5, 8), (6, 8), (7, 8), (8, 8), (9, 16),
]
power_of_two_floor_cases = [
(0, 1), (1, 1), (2, 2), (3, 2), (4, 4), (5, 4), (6, 4), (7, 4), (8, 8), (9, 8),
]
@pytest.mark.parametrize(
'value,expected',
power_of_two_ceil_cases,
)
def test_get_power_of_two_ceil(value, expected):
assert get_power_of_two_ceil(value) == expected
@pytest.mark.parametrize(
'value,expected',
power_of_two_floor_cases,
)
def test_get_power_of_two_floor(value, expected):
assert get_power_of_two_floor(value) == expected