Fix deprecated utility code, avoid wrong helper function name, add tests
This commit is contained in:
parent
7117d2e75a
commit
3f765f55ca
|
@ -26,25 +26,34 @@
|
||||||
## Helper functions
|
## Helper functions
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def get_next_power_of_two(x: int) -> int:
|
def get_power_of_two_ceil(x: int) -> int:
|
||||||
"""
|
"""
|
||||||
Get next power of 2 >= the input.
|
Get the power of 2 for given input, or the closest higher power of 2 if the input is not a power of 2.
|
||||||
|
Commonly used for "how many nodes do I need for a bottom tree layer fitting x elements?"
|
||||||
|
Example: 0->1, 1->1, 2->2, 3->4, 4->4, 5->8, 6->8, 7->8, 8->8, 9->16.
|
||||||
"""
|
"""
|
||||||
if x <= 2:
|
if x <= 1:
|
||||||
return x
|
return 1
|
||||||
|
elif x == 2:
|
||||||
|
return 2
|
||||||
else:
|
else:
|
||||||
return 2 * get_next_power_of_two((x + 1) // 2)
|
return 2 * get_power_of_two_ceil((x + 1) // 2)
|
||||||
```
|
```
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def get_previous_power_of_two(x: int) -> int:
|
def get_power_of_two_floor(x: int) -> int:
|
||||||
"""
|
"""
|
||||||
Get the previous power of 2 >= the input.
|
Get the power of 2 for given input, or the closest lower power of 2 if the input is not a power of 2.
|
||||||
|
The zero case is a placeholder and not used for math with generalized indices.
|
||||||
|
Commonly used for "what power of two makes up the root bit of the generalized index?"
|
||||||
|
Example: 0->1, 1->1, 2->2, 3->2, 4->4, 5->4, 6->4, 7->4, 8->8, 9->8
|
||||||
"""
|
"""
|
||||||
if x <= 2:
|
if x <= 1:
|
||||||
|
return 1
|
||||||
|
if x == 2:
|
||||||
return x
|
return x
|
||||||
else:
|
else:
|
||||||
return 2 * get_previous_power_of_two(x // 2)
|
return 2 * get_power_of_two_floor(x // 2)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Generalized Merkle tree index
|
## Generalized Merkle tree index
|
||||||
|
@ -62,9 +71,14 @@ Note that the generalized index has the convenient property that the two childre
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def merkle_tree(leaves: Sequence[Bytes32]) -> Sequence[Bytes32]:
|
def merkle_tree(leaves: Sequence[Bytes32]) -> Sequence[Bytes32]:
|
||||||
padded_length = get_next_power_of_two(len(leaves))
|
"""
|
||||||
o = [Bytes32()] * padded_length + list(leaves) + [Bytes32()] * (padded_length - len(leaves))
|
Return an array representing the tree nodes by generalized index:
|
||||||
for i in range(padded_length - 1, 0, -1):
|
[0, 1, 2, 3, 4, 5, 6, 7], where each layer is a power of 2. The 0 index is ignored. The 1 index is the root.
|
||||||
|
The result will be twice the size as the padded bottom layer for the input leaves.
|
||||||
|
"""
|
||||||
|
bottom_length = get_power_of_two_ceil(len(leaves))
|
||||||
|
o = [Bytes32()] * bottom_length + list(leaves) + [Bytes32()] * (bottom_length - len(leaves))
|
||||||
|
for i in range(bottom_length - 1, 0, -1):
|
||||||
o[i] = hash(o[i * 2] + o[i * 2 + 1])
|
o[i] = hash(o[i * 2] + o[i * 2 + 1])
|
||||||
return o
|
return o
|
||||||
```
|
```
|
||||||
|
@ -169,7 +183,7 @@ def get_generalized_index(typ: SSZType, path: Sequence[Union[int, SSZVariableNam
|
||||||
else:
|
else:
|
||||||
pos, _, _ = get_item_position(typ, p)
|
pos, _, _ = get_item_position(typ, p)
|
||||||
base_index = (GeneralizedIndex(2) if issubclass(typ, (List, ByteList)) else GeneralizedIndex(1))
|
base_index = (GeneralizedIndex(2) if issubclass(typ, (List, ByteList)) else GeneralizedIndex(1))
|
||||||
root = GeneralizedIndex(root * base_index * get_next_power_of_two(chunk_count(typ)) + pos)
|
root = GeneralizedIndex(root * base_index * get_power_of_two_ceil(chunk_count(typ)) + pos)
|
||||||
typ = get_elem_type(typ, p)
|
typ = get_elem_type(typ, p)
|
||||||
return root
|
return root
|
||||||
```
|
```
|
||||||
|
@ -188,7 +202,7 @@ def concat_generalized_indices(*indices: GeneralizedIndex) -> GeneralizedIndex:
|
||||||
"""
|
"""
|
||||||
o = GeneralizedIndex(1)
|
o = GeneralizedIndex(1)
|
||||||
for i in indices:
|
for i in indices:
|
||||||
o = GeneralizedIndex(o * get_previous_power_of_two(i) + (i - get_previous_power_of_two(i)))
|
o = GeneralizedIndex(o * get_power_of_two_floor(i) + (i - get_power_of_two_floor(i)))
|
||||||
return o
|
return o
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# Note: these functions are extract from merkle-proofs.md (deprecated),
|
||||||
|
# the tests are temporary to show correctness while the document is still there.
|
||||||
|
|
||||||
|
def get_power_of_two_ceil(x: int) -> int:
|
||||||
|
if x <= 1:
|
||||||
|
return 1
|
||||||
|
elif x == 2:
|
||||||
|
return 2
|
||||||
|
else:
|
||||||
|
return 2 * get_power_of_two_ceil((x + 1) // 2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_power_of_two_floor(x: int) -> int:
|
||||||
|
if x <= 1:
|
||||||
|
return 1
|
||||||
|
if x == 2:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return 2 * get_power_of_two_floor(x // 2)
|
||||||
|
|
||||||
|
|
||||||
|
power_of_two_ceil_cases = [
|
||||||
|
(0, 1), (1, 1), (2, 2), (3, 4), (4, 4), (5, 8), (6, 8), (7, 8), (8, 8), (9, 16),
|
||||||
|
]
|
||||||
|
|
||||||
|
power_of_two_floor_cases = [
|
||||||
|
(0, 1), (1, 1), (2, 2), (3, 2), (4, 4), (5, 4), (6, 4), (7, 4), (8, 8), (9, 8),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'value,expected',
|
||||||
|
power_of_two_ceil_cases,
|
||||||
|
)
|
||||||
|
def test_get_power_of_two_ceil(value, expected):
|
||||||
|
assert get_power_of_two_ceil(value) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'value,expected',
|
||||||
|
power_of_two_floor_cases,
|
||||||
|
)
|
||||||
|
def test_get_power_of_two_floor(value, expected):
|
||||||
|
assert get_power_of_two_floor(value) == expected
|
Loading…
Reference in New Issue