Merge pull request #1363 from ethereum/executable_merkle_proofs

Executable Merkle proofs
This commit is contained in:
Diederik Loerakker 2019-08-23 14:50:12 +02:00 committed by GitHub
commit 5d2f34f882
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 312 additions and 85 deletions

View File

@ -89,7 +89,7 @@ $(PY_SPEC_PHASE_0_TARGETS): $(PY_SPEC_PHASE_0_DEPS)
python3 $(SCRIPT_DIR)/build_spec.py -p0 $(SPEC_DIR)/core/0_beacon-chain.md $(SPEC_DIR)/core/0_fork-choice.md $(SPEC_DIR)/validator/0_beacon-chain-validator.md $@
$(PY_SPEC_DIR)/eth2spec/phase1/spec.py: $(PY_SPEC_PHASE_1_DEPS)
python3 $(SCRIPT_DIR)/build_spec.py -p1 $(SPEC_DIR)/core/0_beacon-chain.md $(SPEC_DIR)/core/1_custody-game.md $(SPEC_DIR)/core/1_shard-data-chains.md $(SPEC_DIR)/core/0_fork-choice.md $@
python3 $(SCRIPT_DIR)/build_spec.py -p1 $(SPEC_DIR)/core/0_beacon-chain.md $(SPEC_DIR)/core/0_fork-choice.md $(SPEC_DIR)/core/1_custody-game.md $(SPEC_DIR)/core/1_shard-data-chains.md $(SPEC_DIR)/light_client/merkle_proofs.md $@
CURRENT_DIR = ${CURDIR}

View File

@ -37,7 +37,10 @@ from eth2spec.utils.bls import (
from eth2spec.utils.hash_function import hash
'''
PHASE1_IMPORTS = '''from typing import (
Any, Dict, Optional, Set, Sequence, MutableSequence, Tuple, Union,
Any, Dict, Optional, Set, Sequence, MutableSequence, NewType, Tuple, Union,
)
from math import (
log2,
)
from dataclasses import (
@ -51,8 +54,10 @@ from eth2spec.utils.ssz.ssz_impl import (
is_zero,
)
from eth2spec.utils.ssz.ssz_typing import (
uint64, bit, boolean, Container, List, Vector, Bytes, BytesN,
Bytes1, Bytes4, Bytes8, Bytes32, Bytes48, Bytes96, Bitlist, Bitvector,
BasicValue, Elements, BaseBytes, BaseList, SSZType,
Container, List, Vector, Bytes, BytesN, Bitlist, Bitvector, Bits,
Bytes1, Bytes4, Bytes8, Bytes32, Bytes48, Bytes96,
uint64, bit, boolean,
)
from eth2spec.utils.bls import (
bls_aggregate_pubkeys,
@ -62,6 +67,10 @@ from eth2spec.utils.bls import (
)
from eth2spec.utils.hash_function import hash
SSZVariableName = str
GeneralizedIndex = NewType('GeneralizedIndex', int)
'''
SUNDRY_CONSTANTS_FUNCTIONS = '''
def ceillog2(x: uint64) -> int:
@ -281,17 +290,23 @@ def build_phase0_spec(phase0_sourcefile: str, fork_choice_sourcefile: str,
def build_phase1_spec(phase0_sourcefile: str,
fork_choice_sourcefile: str,
phase1_custody_sourcefile: str,
phase1_shard_sourcefile: str,
fork_choice_sourcefile: str,
merkle_proofs_sourcefile: str,
outfile: str=None) -> Optional[str]:
phase0_spec = get_spec(phase0_sourcefile)
remove_for_phase1(phase0_spec[0])
phase1_custody = get_spec(phase1_custody_sourcefile)
phase1_shard_data = get_spec(phase1_shard_sourcefile)
fork_choice_spec = get_spec(fork_choice_sourcefile)
spec_objects = phase0_spec
for value in [phase1_custody, phase1_shard_data, fork_choice_spec]:
all_sourcefiles = (
phase0_sourcefile,
fork_choice_sourcefile,
phase1_custody_sourcefile,
phase1_shard_sourcefile,
merkle_proofs_sourcefile,
)
all_spescs = [get_spec(spec) for spec in all_sourcefiles]
for spec in all_spescs:
remove_for_phase1(spec[0])
spec_objects = all_spescs[0]
for value in all_spescs[1:]:
spec_objects = combine_spec_objects(spec_objects, value)
spec = objects_to_spec(*spec_objects, PHASE1_IMPORTS)
if outfile is not None:
@ -304,17 +319,18 @@ if __name__ == '__main__':
description = '''
Build the specs from the md docs.
If building phase 0:
1st argument is input spec.md
2nd argument is input fork_choice.md
3rd argument is input validator_guide.md
1st argument is input /core/0_beacon-chain.md
2nd argument is input /core/0_fork-choice.md
3rd argument is input /core/0_beacon-chain-validator.md
4th argument is output spec.py
If building phase 1:
1st argument is input spec_phase0.md
2nd argument is input spec_phase1_custody.md
3rd argument is input spec_phase1_shard_data.md
4th argument is input fork_choice.md
5th argument is output spec.py
1st argument is input /core/0_beacon-chain.md
2nd argument is input /core/0_fork-choice.md
3rd argument is input /core/1_custody-game.md
4th argument is input /core/1_shard-data-chains.md
5th argument is input /light_client/merkle_proofs.md
6th argument is output spec.py
'''
parser = ArgumentParser(description=description)
parser.add_argument("-p", "--phase", dest="phase", type=int, default=0, help="Build for phase #")
@ -327,10 +343,15 @@ If building phase 1:
else:
print(" Phase 0 requires spec, forkchoice, and v-guide inputs as well as an output file.")
elif args.phase == 1:
if len(args.files) == 5:
if len(args.files) == 6:
build_phase1_spec(*args.files)
else:
print(" Phase 1 requires 4 input files as well as an output file: "
+ "(phase0.md and phase1.md, phase1.md, fork_choice.md, output.py)")
print(
" Phase 1 requires input files as well as an output file:\n"
"\t core/phase_0: (0_beacon-chain.md, 0_fork-choice.md)\n"
"\t core/phase_1: (1_custody-game.md, 1_shard-data-chains.md)\n"
"\t light_client: (merkle_proofs.md)\n"
"\t and output.py"
)
else:
print("Invalid phase: {0}".format(args.phase))

View File

@ -7,6 +7,7 @@
- [Merkle proof formats](#merkle-proof-formats)
- [Table of contents](#table-of-contents)
- [Helper functions](#helper-functions)
- [Generalized Merkle tree index](#generalized-merkle-tree-index)
- [SSZ object to index](#ssz-object-to-index)
- [Helpers for generalized indices](#helpers-for-generalized-indices)
@ -20,6 +21,30 @@
<!-- /TOC -->
## Helper functions
```python
def get_next_power_of_two(x: int) -> int:
"""
Get next power of 2 >= the input.
"""
if x <= 2:
return x
else:
return 2 * get_next_power_of_two((x + 1) // 2)
```
```python
def get_previous_power_of_two(x: int) -> int:
"""
Get the previous power of 2 >= the input.
"""
if x <= 2:
return x
else:
return 2 * get_previous_power_of_two(x // 2)
```
## Generalized Merkle tree index
In a binary Merkle tree, we define a "generalized index" of a node as `2**depth + index`. Visually, this looks as follows:
@ -34,14 +59,16 @@ In a binary Merkle tree, we define a "generalized index" of a node as `2**depth
Note that the generalized index has the convenient property that the two children of node `k` are `2k` and `2k+1`, and also that it equals the position of a node in the linear representation of the Merkle tree that's computed by this function:
```python
def merkle_tree(leaves: List[Bytes32]) -> List[Bytes32]:
padded_length = next_power_of_2(len(leaves))
o = [ZERO_HASH] * padded_length + leaves + [ZERO_HASH] * (padded_length - len(leaves))
def merkle_tree(leaves: Sequence[Hash]) -> Sequence[Hash]:
padded_length = get_next_power_of_two(len(leaves))
o = [Hash()] * padded_length + list(leaves) + [Hash()] * (padded_length - len(leaves))
for i in range(len(leaves) - 1, 0, -1):
o[i] = hash(o[i * 2] + o[i * 2 + 1])
return o
```
We define a custom type `GeneralizedIndex` as a Python integer type in this document. It can be represented as a Bitvector/Bitlist object as well.
We will define Merkle proofs in terms of generalized indices.
## SSZ object to index
@ -59,30 +86,33 @@ y_data_root len(y)
.......
```
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`.
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`. We define `SSZVariableName` as the member variable name string, i.e., a path is presented as a sequence of integers and `SSZVariableName`.
```python
def item_length(typ: SSZType) -> int:
"""
Returns the number of bytes in a basic type, or 32 (a full hash) for compound types.
Return the number of bytes in a basic type, or 32 (a full hash) for compound types.
"""
if issubclass(typ, BasicValue):
return typ.byte_len
else:
return 32
```
def get_elem_type(typ: ComplexType, index: Union[int, str]) -> Type:
```python
def get_elem_type(typ: Union[BaseBytes, BaseList, Container],
index_or_variable_name: Union[int, SSZVariableName]) -> SSZType:
"""
Returns the type of the element of an object of the given type with the given index
Return the type of the element of an object of the given type with the given index
or member variable name (eg. `7` for `x[7]`, `"foo"` for `x.foo`)
"""
return typ.get_fields()[index] if issubclass(typ, Container) else typ.elem_type
return typ.get_fields()[index_or_variable_name] if issubclass(typ, Container) else typ.elem_type
```
```python
def chunk_count(typ: SSZType) -> int:
"""
Returns the number of hashes needed to represent the top-level elements in the given type
Return the number of hashes needed to represent the top-level elements in the given type
(eg. `x.foo` or `x[7]` but not `x[7].bar` or `x.foo.baz`). In all cases except lists/vectors
of basic types, this is simply the number of top-level elements, as each element gets one
hash. For lists/vectors of basic types, it is often fewer because multiple basic elements
@ -99,36 +129,47 @@ def chunk_count(typ: SSZType) -> int:
return len(typ.get_fields())
else:
raise Exception(f"Type not supported: {typ}")
```
def get_item_position(typ: SSZType, index: Union[int, str]) -> Tuple[int, int, int]:
```python
def get_item_position(typ: SSZType, index_or_variable_name: Union[int, SSZVariableName]) -> Tuple[int, int, int]:
"""
Returns three variables: (i) the index of the chunk in which the given element of the item is
represented, (ii) the starting byte position within the chunk, (iii) the ending byte position within the chunk. For example for
a 6-item list of uint64 values, index=2 will return (0, 16, 24), index=5 will return (1, 8, 16)
Return three variables:
(i) the index of the chunk in which the given element of the item is represented;
(ii) the starting byte position within the chunk;
(iii) the ending byte position within the chunk.
For example: for a 6-item list of uint64 values, index=2 will return (0, 16, 24), index=5 will return (1, 8, 16)
"""
if issubclass(typ, Elements):
index = int(index_or_variable_name)
start = index * item_length(typ.elem_type)
return start // 32, start % 32, start % 32 + item_length(typ.elem_type)
elif issubclass(typ, Container):
return typ.get_field_names().index(index), 0, item_length(get_elem_type(typ, index))
variable_name = index_or_variable_name
return typ.get_field_names().index(variable_name), 0, item_length(get_elem_type(typ, variable_name))
else:
raise Exception("Only lists/vectors/containers supported")
```
def get_generalized_index(typ: Type, path: List[Union[int, str]]) -> GeneralizedIndex:
```python
def get_generalized_index(typ: SSZType, path: Sequence[Union[int, SSZVariableName]]) -> Optional[GeneralizedIndex]:
"""
Converts a path (eg. `[7, "foo", 3]` for `x[7].foo[3]`, `[12, "bar", "__len__"]` for
`len(x[12].bar)`) into the generalized index representing its position in the Merkle tree.
"""
root = 1
root = GeneralizedIndex(1)
for p in path:
assert not issubclass(typ, BasicValue) # If we descend to a basic type, the path cannot continue further
if p == '__len__':
typ, root = uint64, root * 2 + 1 if issubclass(typ, (List, Bytes)) else None
typ = uint64
if issubclass(typ, (List, Bytes)):
root = GeneralizedIndex(root * 2 + 1)
else:
return None
else:
pos, _, _ = get_item_position(typ, p)
root = root * (2 if issubclass(typ, (List, Bytes)) else 1) * next_power_of_two(chunk_count(typ)) + pos
base_index = (GeneralizedIndex(2) if issubclass(typ, (List, Bytes)) else GeneralizedIndex(1))
root = GeneralizedIndex(root * base_index * get_next_power_of_two(chunk_count(typ)) + pos)
typ = get_elem_type(typ, p)
return root
```
@ -140,14 +181,14 @@ _Usage note: functions outside this section should manipulate generalized indice
#### `concat_generalized_indices`
```python
def concat_generalized_indices(*indices: Sequence[GeneralizedIndex]) -> GeneralizedIndex:
def concat_generalized_indices(indices: Sequence[GeneralizedIndex]) -> GeneralizedIndex:
"""
Given generalized indices i1 for A -> B, i2 for B -> C .... i_n for Y -> Z, returns
the generalized index for A -> Z.
"""
o = GeneralizedIndex(1)
for i in indices:
o = o * get_previous_power_of_2(i) + (i - get_previous_power_of_2(i))
o = GeneralizedIndex(o * get_previous_power_of_two(i) + (i - get_previous_power_of_two(i)))
return o
```
@ -156,9 +197,9 @@ def concat_generalized_indices(*indices: Sequence[GeneralizedIndex]) -> Generali
```python
def get_generalized_index_length(index: GeneralizedIndex) -> int:
"""
Returns the length of a path represented by a generalized index.
Return the length of a path represented by a generalized index.
"""
return log2(index)
return int(log2(index))
```
#### `get_generalized_index_bit`
@ -166,7 +207,7 @@ def get_generalized_index_length(index: GeneralizedIndex) -> int:
```python
def get_generalized_index_bit(index: GeneralizedIndex, position: int) -> bool:
"""
Returns the given bit of a generalized index.
Return the given bit of a generalized index.
"""
return (index & (1 << position)) > 0
```
@ -175,21 +216,21 @@ def get_generalized_index_bit(index: GeneralizedIndex, position: int) -> bool:
```python
def generalized_index_sibling(index: GeneralizedIndex) -> GeneralizedIndex:
return index ^ 1
return GeneralizedIndex(index ^ 1)
```
#### `generalized_index_child`
```python
def generalized_index_child(index: GeneralizedIndex, right_side: bool) -> GeneralizedIndex:
return index * 2 + right_side
return GeneralizedIndex(index * 2 + right_side)
```
#### `generalized_index_parent`
```python
def generalized_index_parent(index: GeneralizedIndex) -> GeneralizedIndex:
return index // 2
return GeneralizedIndex(index // 2)
```
## Merkle multiproofs
@ -208,7 +249,7 @@ x x . . . . x *
First, we provide a method for computing the generalized indices of the auxiliary tree nodes that a proof of a given set of generalized indices will require:
```python
def get_branch_indices(tree_index: GeneralizedIndex) -> List[GeneralizedIndex]:
def get_branch_indices(tree_index: GeneralizedIndex) -> Sequence[GeneralizedIndex]:
"""
Get the generalized indices of the sister chunks along the path from the chunk with the
given tree index to the root.
@ -217,21 +258,26 @@ def get_branch_indices(tree_index: GeneralizedIndex) -> List[GeneralizedIndex]:
while o[-1] > 1:
o.append(generalized_index_sibling(generalized_index_parent(o[-1])))
return o[:-1]
```
def get_helper_indices(indices: List[GeneralizedIndex]) -> List[GeneralizedIndex]:
```python
def get_helper_indices(indices: Sequence[GeneralizedIndex]) -> Sequence[GeneralizedIndex]:
"""
Get the generalized indices of all "extra" chunks in the tree needed to prove the chunks with the given
generalized indices. Note that the decreasing order is chosen deliberately to ensure equivalence to the
order of hashes in a regular single-item Merkle proof in the single-item case.
"""
all_indices = set()
all_indices: Set[GeneralizedIndex] = set()
for index in indices:
all_indices = all_indices.union(set(get_branch_indices(index) + [index]))
all_indices = all_indices.union(set(list(get_branch_indices(index)) + [index]))
return sorted([
x for x in all_indices if not
(generalized_index_child(x, 0) in all_indices and generalized_index_child(x, 1) in all_indices) and not
(x in indices)
x for x in all_indices if (
not (
generalized_index_child(x, False) in all_indices and
generalized_index_child(x, True) in all_indices
) and not (x in indices)
)
], reverse=True)
```
@ -251,23 +297,29 @@ def verify_merkle_proof(leaf: Hash, proof: Sequence[Hash], index: GeneralizedInd
Now for multi-item proofs:
```python
def verify_merkle_multiproof(leaves: Sequence[Hash], proof: Sequence[Hash], indices: Sequence[GeneralizedIndex], root: Hash) -> bool:
def verify_merkle_multiproof(leaves: Sequence[Hash],
proof: Sequence[Hash],
indices: Sequence[GeneralizedIndex],
root: Hash) -> bool:
assert len(leaves) == len(indices)
helper_indices = get_helper_indices(indices)
assert len(proof) == len(helper_indices)
objects = {
**{index:node for index, node in zip(indices, leaves)},
**{index:node for index, node in zip(helper_indices, proof)}
**{index: node for index, node in zip(indices, leaves)},
**{index: node for index, node in zip(helper_indices, proof)}
}
keys = sorted(objects.keys(), reverse=True)
pos = 0
while pos < len(keys):
k = keys[pos]
if k in objects and k ^ 1 in objects and k // 2 not in objects:
objects[k // 2] = hash(objects[(k | 1) ^ 1] + objects[k | 1])
keys.append(k // 2)
objects[GeneralizedIndex(k // 2)] = hash(
objects[GeneralizedIndex((k | 1) ^ 1)] +
objects[GeneralizedIndex(k | 1)]
)
keys.append(GeneralizedIndex(k // 2))
pos += 1
return objects[1] == root
return objects[GeneralizedIndex(1)] == root
```
Note that the single-item proof is a special case of a multi-item proof; a valid single-item proof verifies correctly when put into the multi-item verification function (making the natural trivial changes to input arguments, `index -> [index]` and `leaf -> [leaf]`).

View File

@ -0,0 +1,148 @@
import re
from eth_utils import (
to_tuple,
)
from eth2spec.test.context import (
spec_state_test,
with_all_phases_except,
)
from eth2spec.utils.ssz.ssz_typing import (
Bytes32,
Container,
List,
uint64,
)
class Foo(Container):
x: uint64
y: List[Bytes32, 2]
# Tree
# root
# / \
# x y_root
# / \
# y_data_root len(y)
# / \
# / \ / \
#
# Generalized indices
# 1
# / \
# 2 (x) 3 (y_root)
# / \
# 6 7
# / \
# 12 13
@to_tuple
def ssz_object_to_path(start, end):
is_len = False
len_findall = re.findall(r"(?<=len\().*(?=\))", end)
if len_findall:
is_len = True
end = len_findall[0]
route = ''
if end.startswith(start):
route = end[len(start):]
segments = route.split('.')
for word in segments:
index_match = re.match(r"(\w+)\[(\d+)]", word)
if index_match:
yield from index_match.groups()
elif len(word):
yield word
if is_len:
yield '__len__'
to_path_test_cases = [
('foo', 'foo.x', ('x',)),
('foo', 'foo.x[100].y', ('x', '100', 'y')),
('foo', 'foo.x[100].y[1].z[2]', ('x', '100', 'y', '1', 'z', '2')),
('foo', 'len(foo.x[100].y[1].z[2])', ('x', '100', 'y', '1', 'z', '2', '__len__')),
]
def test_to_path():
for test_case in to_path_test_cases:
start, end, expected = test_case
assert ssz_object_to_path(start, end) == expected
generalized_index_cases = [
(Foo, ('x',), 2),
(Foo, ('y',), 3),
(Foo, ('y', 0), 12),
(Foo, ('y', 1), 13),
(Foo, ('y', '__len__'), None),
]
@with_all_phases_except(['phase0'])
@spec_state_test
def test_get_generalized_index(spec, state):
for typ, path, generalized_index in generalized_index_cases:
assert spec.get_generalized_index(
typ=typ,
path=path,
) == generalized_index
yield 'typ', typ
yield 'path', path
yield 'generalized_index', generalized_index
@with_all_phases_except(['phase0'])
@spec_state_test
def test_verify_merkle_proof(spec, state):
h = spec.hash
a = b'\x11' * 32
b = b'\x22' * 32
c = b'\x33' * 32
d = b'\x44' * 32
root = h(h(a + b) + h(c + d))
leaf = a
generalized_index = 4
proof = [b, h(c + d)]
is_valid = spec.verify_merkle_proof(
leaf=leaf,
proof=proof,
index=generalized_index,
root=root,
)
assert is_valid
yield 'proof', proof
yield 'is_valid', is_valid
@with_all_phases_except(['phase0'])
@spec_state_test
def test_verify_merkle_multiproof(spec, state):
h = spec.hash
a = b'\x11' * 32
b = b'\x22' * 32
c = b'\x33' * 32
d = b'\x44' * 32
root = h(h(a + b) + h(c + d))
leaves = [a, d]
generalized_indices = [4, 7]
proof = [c, b] # helper_indices = [6, 5]
is_valid = spec.verify_merkle_multiproof(
leaves=leaves,
proof=proof,
indices=generalized_indices,
root=root,
)
assert is_valid
yield 'proof', proof
yield 'is_valid', is_valid

View File

@ -1,4 +1,4 @@
from typing import Dict, Iterator
from typing import Dict, Iterator, Iterable
import copy
from types import GeneratorType
@ -195,6 +195,12 @@ class Container(Series, metaclass=SSZType):
return {}
return dict(cls.__annotations__)
@classmethod
def get_field_names(cls) -> Iterable[SSZType]:
if not hasattr(cls, '__annotations__'): # no container fields
return ()
return list(cls.__annotations__.keys())
@classmethod
def default(cls):
return cls(**{f: t.default() for f, t in cls.get_fields().items()})