Merge pull request #1907 from ethereum/fix_deprecated_merkle_util
Fix deprecated utility code, avoid wrong helper function name, add tests
This commit is contained in:
commit
e06bbd14f7
|
@ -26,25 +26,34 @@
|
|||
## Helper functions
|
||||
|
||||
```python
|
||||
def get_next_power_of_two(x: int) -> int:
|
||||
def get_power_of_two_ceil(x: int) -> int:
|
||||
"""
|
||||
Get next power of 2 >= the input.
|
||||
Get the power of 2 for given input, or the closest higher power of 2 if the input is not a power of 2.
|
||||
Commonly used for "how many nodes do I need for a bottom tree layer fitting x elements?"
|
||||
Example: 0->1, 1->1, 2->2, 3->4, 4->4, 5->8, 6->8, 7->8, 8->8, 9->16.
|
||||
"""
|
||||
if x <= 2:
|
||||
return x
|
||||
if x <= 1:
|
||||
return 1
|
||||
elif x == 2:
|
||||
return 2
|
||||
else:
|
||||
return 2 * get_next_power_of_two((x + 1) // 2)
|
||||
return 2 * get_power_of_two_ceil((x + 1) // 2)
|
||||
```
|
||||
|
||||
```python
|
||||
def get_previous_power_of_two(x: int) -> int:
|
||||
def get_power_of_two_floor(x: int) -> int:
|
||||
"""
|
||||
Get the previous power of 2 >= the input.
|
||||
Get the power of 2 for given input, or the closest lower power of 2 if the input is not a power of 2.
|
||||
The zero case is a placeholder and not used for math with generalized indices.
|
||||
Commonly used for "what power of two makes up the root bit of the generalized index?"
|
||||
Example: 0->1, 1->1, 2->2, 3->2, 4->4, 5->4, 6->4, 7->4, 8->8, 9->8
|
||||
"""
|
||||
if x <= 2:
|
||||
if x <= 1:
|
||||
return 1
|
||||
if x == 2:
|
||||
return x
|
||||
else:
|
||||
return 2 * get_previous_power_of_two(x // 2)
|
||||
return 2 * get_power_of_two_floor(x // 2)
|
||||
```
|
||||
|
||||
## Generalized Merkle tree index
|
||||
|
@ -62,9 +71,14 @@ Note that the generalized index has the convenient property that the two childre
|
|||
|
||||
```python
|
||||
def merkle_tree(leaves: Sequence[Bytes32]) -> Sequence[Bytes32]:
|
||||
padded_length = get_next_power_of_two(len(leaves))
|
||||
o = [Bytes32()] * padded_length + list(leaves) + [Bytes32()] * (padded_length - len(leaves))
|
||||
for i in range(padded_length - 1, 0, -1):
|
||||
"""
|
||||
Return an array representing the tree nodes by generalized index:
|
||||
[0, 1, 2, 3, 4, 5, 6, 7], where each layer is a power of 2. The 0 index is ignored. The 1 index is the root.
|
||||
The result will be twice the size as the padded bottom layer for the input leaves.
|
||||
"""
|
||||
bottom_length = get_power_of_two_ceil(len(leaves))
|
||||
o = [Bytes32()] * bottom_length + list(leaves) + [Bytes32()] * (bottom_length - len(leaves))
|
||||
for i in range(bottom_length - 1, 0, -1):
|
||||
o[i] = hash(o[i * 2] + o[i * 2 + 1])
|
||||
return o
|
||||
```
|
||||
|
@ -169,7 +183,7 @@ def get_generalized_index(typ: SSZType, path: Sequence[Union[int, SSZVariableNam
|
|||
else:
|
||||
pos, _, _ = get_item_position(typ, p)
|
||||
base_index = (GeneralizedIndex(2) if issubclass(typ, (List, ByteList)) else GeneralizedIndex(1))
|
||||
root = GeneralizedIndex(root * base_index * get_next_power_of_two(chunk_count(typ)) + pos)
|
||||
root = GeneralizedIndex(root * base_index * get_power_of_two_ceil(chunk_count(typ)) + pos)
|
||||
typ = get_elem_type(typ, p)
|
||||
return root
|
||||
```
|
||||
|
@ -188,7 +202,7 @@ def concat_generalized_indices(*indices: GeneralizedIndex) -> GeneralizedIndex:
|
|||
"""
|
||||
o = GeneralizedIndex(1)
|
||||
for i in indices:
|
||||
o = GeneralizedIndex(o * get_previous_power_of_two(i) + (i - get_previous_power_of_two(i)))
|
||||
o = GeneralizedIndex(o * get_power_of_two_floor(i) + (i - get_power_of_two_floor(i)))
|
||||
return o
|
||||
```
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
import pytest
|
||||
|
||||
|
||||
# Note: these functions are extract from merkle-proofs.md (deprecated),
|
||||
# the tests are temporary to show correctness while the document is still there.
|
||||
|
||||
def get_power_of_two_ceil(x: int) -> int:
|
||||
if x <= 1:
|
||||
return 1
|
||||
elif x == 2:
|
||||
return 2
|
||||
else:
|
||||
return 2 * get_power_of_two_ceil((x + 1) // 2)
|
||||
|
||||
|
||||
def get_power_of_two_floor(x: int) -> int:
|
||||
if x <= 1:
|
||||
return 1
|
||||
if x == 2:
|
||||
return x
|
||||
else:
|
||||
return 2 * get_power_of_two_floor(x // 2)
|
||||
|
||||
|
||||
power_of_two_ceil_cases = [
|
||||
(0, 1), (1, 1), (2, 2), (3, 4), (4, 4), (5, 8), (6, 8), (7, 8), (8, 8), (9, 16),
|
||||
]
|
||||
|
||||
power_of_two_floor_cases = [
|
||||
(0, 1), (1, 1), (2, 2), (3, 2), (4, 4), (5, 4), (6, 4), (7, 4), (8, 8), (9, 8),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'value,expected',
|
||||
power_of_two_ceil_cases,
|
||||
)
|
||||
def test_get_power_of_two_ceil(value, expected):
|
||||
assert get_power_of_two_ceil(value) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'value,expected',
|
||||
power_of_two_floor_cases,
|
||||
)
|
||||
def test_get_power_of_two_floor(value, expected):
|
||||
assert get_power_of_two_floor(value) == expected
|
Loading…
Reference in New Issue