Updates to SSZ partials
This commit is contained in:
parent
fd04f4129a
commit
2605dfba08
|
@ -64,43 +64,71 @@ y_data_root len(y)
|
|||
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`.
|
||||
|
||||
```python
|
||||
def path_to_encoded_form(obj: Any, path: List[Union[str, int]]) -> List[int]:
|
||||
if len(path) == 0:
|
||||
return []
|
||||
elif isinstance(path[0], "__len__"):
|
||||
assert len(path) == 1
|
||||
return [LENGTH_FLAG]
|
||||
elif isinstance(path[0], str) and hasattr(obj, "fields"):
|
||||
return [list(obj.fields.keys()).index(path[0])] + path_to_encoded_form(getattr(obj, path[0]), path[1:])
|
||||
elif isinstance(obj, (Vector, List)):
|
||||
return [path[0]] + path_to_encoded_form(obj[path[0]], path[1:])
|
||||
def item_length(typ: Type) -> int:
|
||||
"""
|
||||
Returns the number of bytes in a basic type, or 32 (a full hash) for compound types.
|
||||
"""
|
||||
if typ == bool:
|
||||
return 1
|
||||
elif issubclass(typ, uint):
|
||||
return typ.byte_len
|
||||
else:
|
||||
raise Exception("Unknown type / path")
|
||||
```
|
||||
|
||||
We can now define a function `get_generalized_indices(object: Any, path: List[int], root: int=1) -> List[int]` that converts an object and a path to a set of generalized indices (note that for constant-sized objects, there is only one generalized index and it only depends on the path, but for dynamically sized objects the indices may depend on the object itself too). For dynamically-sized objects, the set of indices will have more than one member because of the need to access an array's length to determine the correct generalized index for some array access.
|
||||
|
||||
```python
|
||||
def get_generalized_indices(obj: Any, path: List[int], root: int=1) -> List[int]:
|
||||
if len(path) == 0:
|
||||
return [root]
|
||||
elif isinstance(obj, Vector):
|
||||
items_per_chunk = (32 // len(serialize(x))) if isinstance(x, int) else 1
|
||||
new_root = root * next_power_of_2(len(obj) // items_per_chunk) + path[0] // items_per_chunk
|
||||
return get_generalized_indices(obj[path[0]], path[1:], new_root)
|
||||
elif isinstance(obj, List) and path[0] == LENGTH_FLAG:
|
||||
return [root * 2 + 1]
|
||||
elif isinstance(obj, List) and isinstance(path[0], int):
|
||||
assert path[0] < len(obj)
|
||||
items_per_chunk = (32 // len(serialize(x))) if isinstance(x, int) else 1
|
||||
new_root = root * 2 * next_power_of_2(len(obj) // items_per_chunk) + path[0] // items_per_chunk
|
||||
return [root *2 + 1] + get_generalized_indices(obj[path[0]], path[1:], new_root)
|
||||
elif hasattr(obj, "fields"):
|
||||
field = list(fields.keys())[path[0]]
|
||||
new_root = root * next_power_of_2(len(fields)) + path[0]
|
||||
return get_generalized_indices(getattr(obj, field), path[1:], new_root)
|
||||
return 32
|
||||
|
||||
|
||||
def get_elem_type(typ: Type, index: int) -> Type:
|
||||
"""
|
||||
Returns the type of the element of an object of the given type with the given index
|
||||
or member variable name (eg. `7` for `x[7]`, `"foo"` for `x.foo`)
|
||||
"""
|
||||
return typ.get_fields_dict()[index] if is_container_type(typ) else typ.elem_type
|
||||
|
||||
|
||||
def get_chunk_count(typ: Type) -> int:
|
||||
"""
|
||||
Returns the number of hashes needed to represent the top-level elements in the given type
|
||||
(eg. `x.foo` or `x[7]` but not `x[7].bar` or `x.foo.baz`). In all cases except lists/vectors
|
||||
of basic types, this is simply the number of top-level elements, as each element gets one
|
||||
hash. For lists/vectors of basic types, it is often fewer because multiple basic elements
|
||||
can be packed into one 32-byte chunk.
|
||||
"""
|
||||
if is_basic_type(typ):
|
||||
return 1
|
||||
elif is_list_kind(typ) or is_vector_kind(typ):
|
||||
return (typ.length * item_length(typ.elem_type) + 31) // 32
|
||||
else:
|
||||
raise Exception("Unknown type / path")
|
||||
return len(typ.get_fields())
|
||||
|
||||
|
||||
def get_item_position(typ: Type, index: Union[int, str]) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Returns three variables: (i) the index of the chunk in which the given element of the item is
|
||||
represented, (ii) the starting byte position, (iii) the ending byte position. For example for
|
||||
a 6-item list of uint64 values, index=2 will return (0, 16, 24), index=5 will return (1, 8, 16)
|
||||
"""
|
||||
if is_list_kind(typ) or is_vector_kind(typ):
|
||||
start = index * item_length(typ.elem_type)
|
||||
return start // 32, start % 32, start % 32 + item_length(typ.elem_type)
|
||||
elif is_container_type(typ):
|
||||
return typ.get_field_names().index(index), 0, item_length(get_elem_type(typ, index))
|
||||
else:
|
||||
raise Exception("Only lists/vectors/containers supported")
|
||||
|
||||
|
||||
def get_generalized_index(typ: Type, path: List[Union[int, str]]) -> int:
|
||||
"""
|
||||
Converts a path (eg. `[7, "foo", 3]` for `x[7].foo[3]`, `[12, "bar", "__len__"]` for
|
||||
`len(x[12].bar)`) into the generalized index representing its position in the Merkle tree.
|
||||
"""
|
||||
for p in path:
|
||||
assert not is_basic_type(typ) # If we descend to a basic type, the path cannot continue further
|
||||
if p == '__len__':
|
||||
typ, root = uint256, root * 2 + 1 if is_list_kind(typ) else None
|
||||
else:
|
||||
pos, _, _ = get_item_position(typ, p)
|
||||
root = root * (2 if is_list_kind(typ) else 1) * next_power_of_two(get_chunk_count(typ)) + pos
|
||||
typ = get_elem_type(typ, p)
|
||||
return root
|
||||
```
|
||||
|
||||
## Merkle multiproofs
|
||||
|
@ -116,72 +144,98 @@ x x . . . . x *
|
|||
|
||||
. are unused nodes, * are used nodes, x are the values we are trying to prove. Notice how despite being a multiproof for 3 values, it requires only 3 auxiliary nodes, only one node more than would be required to prove a single value. Normally the efficiency gains are not quite that extreme, but the savings relative to individual Merkle proofs are still significant. As a rule of thumb, a multiproof for k nodes at the same level of an n-node tree has size `k * (n/k + log(n/k))`.
|
||||
|
||||
Here is code for creating and verifying a multiproof. First, a method for computing the generalized indices of the auxiliary tree nodes that a proof of a given set of generalized indices will require:
|
||||
First, we provide a method for computing the generalized indices of the auxiliary tree nodes that a proof of a given set of generalized indices will require:
|
||||
|
||||
```python
|
||||
def get_proof_indices(tree_indices: List[int]) -> List[int]:
|
||||
# Get all indices touched by the proof
|
||||
maximal_indices = set()
|
||||
for i in tree_indices:
|
||||
x = i
|
||||
while x > 1:
|
||||
maximal_indices.add(x ^ 1)
|
||||
x //= 2
|
||||
maximal_indices = tree_indices + sorted(list(maximal_indices))[::-1]
|
||||
# Get indices that cannot be recalculated from earlier indices
|
||||
redundant_indices = set()
|
||||
proof = []
|
||||
for index in maximal_indices:
|
||||
if index not in redundant_indices:
|
||||
proof.append(index)
|
||||
while index > 1:
|
||||
redundant_indices.add(index)
|
||||
if (index ^ 1) not in redundant_indices:
|
||||
break
|
||||
index //= 2
|
||||
return [i for i in proof if i not in tree_indices]
|
||||
```
|
||||
def get_branch_indices(tree_index: int) -> List[int]:
|
||||
"""
|
||||
Get the generalized indices of the sister chunks along the path from the chunk with the
|
||||
given tree index to the root.
|
||||
"""
|
||||
o = [tree_index ^ 1]
|
||||
while o[-1] > 1:
|
||||
o.append((o[-1] // 2) ^ 1)
|
||||
return o[:-1]
|
||||
|
||||
def get_expanded_indices(indices: List[int]) -> List[int]:
|
||||
"""
|
||||
Get the generalized indices of all chunks in the tree needed to prove the chunks with the given
|
||||
generalized indices.
|
||||
"""
|
||||
branches = set()
|
||||
for index in indices:
|
||||
branches = branches.union(set(get_branch_indices(index) + [index]))
|
||||
return sorted(list([x for x in branches if x*2 not in branches or x*2+1 not in branches]))[::-1]
|
||||
```
|
||||
|
||||
Generating a proof is simply a matter of taking the node of the SSZ hash tree with the union of the given generalized indices for each index given by `get_proof_indices`, and outputting the list of nodes in the same order.
|
||||
Generating a proof that covers paths `p1 ... pn` is simply a matter of taking the chunks in the SSZ hash tree with generalized indices `get_expanded_indices([p1 ... pn])`.
|
||||
|
||||
Here is the verification function:
|
||||
We now provide the bulk of the proving machinery, a function that takes a `{generalized_index: chunk}` map and fills in chunks that can be inferred (inferring the parent by hashing its two children):
|
||||
|
||||
```python
|
||||
def verify_multi_proof(root: Bytes32, indices: List[int], leaves: List[Bytes32], proof: List[Bytes32]) -> bool:
|
||||
tree = {}
|
||||
for index, leaf in zip(indices, leaves):
|
||||
tree[index] = leaf
|
||||
for index, proof_item in zip(get_proof_indices(indices), proof):
|
||||
tree[index] = proof_item
|
||||
index_queue = sorted(tree.keys())[:-1]
|
||||
i = 0
|
||||
while i < len(index_queue):
|
||||
index = index_queue[i]
|
||||
if index >= 2 and index ^ 1 in tree:
|
||||
tree[index // 2] = hash(tree[index - index % 2] + tree[index - index % 2 + 1])
|
||||
index_queue.append(index // 2)
|
||||
i += 1
|
||||
return (indices == []) or (1 in tree and tree[1] == root)
|
||||
def fill(objects: Dict[int, Bytes32]) -> Dict[int, Bytes32]:
|
||||
"""
|
||||
Fills in chunks that can be inferred from other chunks. For a set of chunks that constitutes
|
||||
a valid proof, this includes the root (generalized index 1).
|
||||
"""
|
||||
objects = {k: v for k, v in objects.items()}
|
||||
keys = sorted(objects.keys())[::-1]
|
||||
pos = 0
|
||||
while pos < len(keys):
|
||||
k = keys[pos]
|
||||
if k in objects and k ^ 1 in objects and k // 2 not in objects:
|
||||
objects[k // 2] = hash(objects[k & - 2] + objects[k | 1])
|
||||
keys.append(k // 2)
|
||||
pos += 1
|
||||
# Completeness and consistency check
|
||||
assert 1 in objects
|
||||
for k in objects:
|
||||
if k > 1:
|
||||
assert objects[k // 2] == hash(objects[k & -2] + objects[k | 1])
|
||||
return objects
|
||||
```
|
||||
|
||||
## MerklePartial
|
||||
|
||||
We define:
|
||||
We define a container that encodes an SSZ partial, and provide the methods for converting it into a `{generalized_index: chunk}` map, for which we provide a method to extract individual values. To determine the hash tree root of an object represented by an SSZ partial, simply check `decode_ssz_partial(partial)[1]`.
|
||||
|
||||
### `SSZMerklePartial`
|
||||
|
||||
|
||||
```python
|
||||
{
|
||||
"root": "bytes32",
|
||||
"indices": ["uint64"],
|
||||
"values": ["bytes32"],
|
||||
"proof": ["bytes32"]
|
||||
}
|
||||
class SSZMerklePartial(Container):
|
||||
indices: List[uint64, 2**32]
|
||||
chunks: List[Bytes32, 2**32]
|
||||
```
|
||||
|
||||
### Proofs for execution
|
||||
### `decode_ssz_partial`
|
||||
|
||||
We define `MerklePartial(f, arg1, arg2..., focus=0)` as being a `SSZMerklePartial` object wrapping a Merkle multiproof of the set of nodes in the hash tree of the SSZ object `arg[focus]` that is needed to authenticate the parts of the object needed to compute `f(arg1, arg2...)`.
|
||||
```python
|
||||
def decode_ssz_partial(encoded: SSZMerklePartial) -> Dict[int, Bytes32]:
|
||||
"""
|
||||
Decodes an encoded SSZ partial into a generalized index -> chunk map, and verify hash consistency.
|
||||
"""
|
||||
full_indices = get_expanded_indices(encoded.indices)
|
||||
return fill({k:v for k,v in zip(full_indices, encoded.chunks)})
|
||||
```
|
||||
|
||||
Ideally, any function which accepts an SSZ object should also be able to accept a `SSZMerklePartial` object as a substitute.
|
||||
### `extract_value_at_path`
|
||||
|
||||
```python
|
||||
def extract_value_at_path(chunks: Dict[int, Bytes32], typ: Type, path: List[Union[int, str]]) -> Any:
|
||||
"""
|
||||
Provides the value of the element in the object represented by the given encoded SSZ partial at
|
||||
the given path. Returns a KeyError if that path is not covered by this SSZ partial.
|
||||
"""
|
||||
root = 1
|
||||
for p in path:
|
||||
if p == '__len__':
|
||||
return deserialize_basic(chunks[root * 2 + 1][:8], uint64)
|
||||
if is_list_kind(typ):
|
||||
assert 0 <= p < deserialize_basic(chunks[root * 2 + 1][:8], uint64)
|
||||
pos, start, end = get_item_position(typ, p)
|
||||
root = root * (2 if is_list_kind(typ) else 1) * next_power_of_two(get_chunk_count(typ)) + pos
|
||||
typ = get_elem_type(typ, p)
|
||||
return deserialize_basic(chunks[root][start: end], typ)
|
||||
```
|
||||
|
||||
Here [link TBD] is a python implementation of SSZ partials that represents them as a class that can be read and written to just like the underlying objects, so you can eg. perform state transitions on SSZ partials and compute the resulting root
|
||||
|
|
Loading…
Reference in New Issue