diff --git a/ssz/merkle-proofs.md b/ssz/merkle-proofs.md index 44b85fdda..2f32e43eb 100644 --- a/ssz/merkle-proofs.md +++ b/ssz/merkle-proofs.md @@ -26,25 +26,34 @@ ## Helper functions ```python -def get_next_power_of_two(x: int) -> int: +def get_power_of_two_ceil(x: int) -> int: """ - Get next power of 2 >= the input. + Get the power of 2 for given input, or the closest higher power of 2 if the input is not a power of 2. + Commonly used for "how many nodes do I need for a bottom tree layer fitting x elements?" + Example: 0->1, 1->1, 2->2, 3->4, 4->4, 5->8, 6->8, 7->8, 8->8, 9->16. """ - if x <= 2: - return x + if x <= 1: + return 1 + elif x == 2: + return 2 else: - return 2 * get_next_power_of_two((x + 1) // 2) + return 2 * get_power_of_two_ceil((x + 1) // 2) ``` ```python -def get_previous_power_of_two(x: int) -> int: +def get_power_of_two_floor(x: int) -> int: """ - Get the previous power of 2 >= the input. + Get the power of 2 for given input, or the closest lower power of 2 if the input is not a power of 2. + The zero case is a placeholder and not used for math with generalized indices. + Commonly used for "what power of two makes up the root bit of the generalized index?" + Example: 0->1, 1->1, 2->2, 3->2, 4->4, 5->4, 6->4, 7->4, 8->8, 9->8 """ - if x <= 2: + if x <= 1: + return 1 + if x == 2: return x else: - return 2 * get_previous_power_of_two(x // 2) + return 2 * get_power_of_two_floor(x // 2) ``` ## Generalized Merkle tree index @@ -62,9 +71,14 @@ Note that the generalized index has the convenient property that the two childre ```python def merkle_tree(leaves: Sequence[Bytes32]) -> Sequence[Bytes32]: - padded_length = get_next_power_of_two(len(leaves)) - o = [Bytes32()] * padded_length + list(leaves) + [Bytes32()] * (padded_length - len(leaves)) - for i in range(padded_length - 1, 0, -1): + """ + Return an array representing the tree nodes by generalized index: + [0, 1, 2, 3, 4, 5, 6, 7], where each layer is a power of 2. The 0 index is ignored. The 1 index is the root. + The result will be twice the size as the padded bottom layer for the input leaves. + """ + bottom_length = get_power_of_two_ceil(len(leaves)) + o = [Bytes32()] * bottom_length + list(leaves) + [Bytes32()] * (bottom_length - len(leaves)) + for i in range(bottom_length - 1, 0, -1): o[i] = hash(o[i * 2] + o[i * 2 + 1]) return o ``` @@ -169,7 +183,7 @@ def get_generalized_index(typ: SSZType, path: Sequence[Union[int, SSZVariableNam else: pos, _, _ = get_item_position(typ, p) base_index = (GeneralizedIndex(2) if issubclass(typ, (List, ByteList)) else GeneralizedIndex(1)) - root = GeneralizedIndex(root * base_index * get_next_power_of_two(chunk_count(typ)) + pos) + root = GeneralizedIndex(root * base_index * get_power_of_two_ceil(chunk_count(typ)) + pos) typ = get_elem_type(typ, p) return root ``` @@ -188,7 +202,7 @@ def concat_generalized_indices(*indices: GeneralizedIndex) -> GeneralizedIndex: """ o = GeneralizedIndex(1) for i in indices: - o = GeneralizedIndex(o * get_previous_power_of_two(i) + (i - get_previous_power_of_two(i))) + o = GeneralizedIndex(o * get_power_of_two_floor(i) + (i - get_power_of_two_floor(i))) return o ``` diff --git a/tests/core/pyspec/eth2spec/utils/test_merkle_proof_util.py b/tests/core/pyspec/eth2spec/utils/test_merkle_proof_util.py new file mode 100644 index 000000000..e1d59fa8c --- /dev/null +++ b/tests/core/pyspec/eth2spec/utils/test_merkle_proof_util.py @@ -0,0 +1,47 @@ +import pytest + + +# Note: these functions are extract from merkle-proofs.md (deprecated), +# the tests are temporary to show correctness while the document is still there. + +def get_power_of_two_ceil(x: int) -> int: + if x <= 1: + return 1 + elif x == 2: + return 2 + else: + return 2 * get_power_of_two_ceil((x + 1) // 2) + + +def get_power_of_two_floor(x: int) -> int: + if x <= 1: + return 1 + if x == 2: + return x + else: + return 2 * get_power_of_two_floor(x // 2) + + +power_of_two_ceil_cases = [ + (0, 1), (1, 1), (2, 2), (3, 4), (4, 4), (5, 8), (6, 8), (7, 8), (8, 8), (9, 16), +] + +power_of_two_floor_cases = [ + (0, 1), (1, 1), (2, 2), (3, 2), (4, 4), (5, 4), (6, 4), (7, 4), (8, 8), (9, 8), +] + + +@pytest.mark.parametrize( + 'value,expected', + power_of_two_ceil_cases, +) +def test_get_power_of_two_ceil(value, expected): + assert get_power_of_two_ceil(value) == expected + + +@pytest.mark.parametrize( + 'value,expected', + power_of_two_floor_cases, +) +def test_get_power_of_two_floor(value, expected): + assert get_power_of_two_floor(value) == expected