Updated type checkers for generalized index functions.
This commit is contained in:
parent
446ad3c392
commit
55f5f106f1
|
@ -17,12 +17,6 @@
|
||||||
|
|
||||||
<!-- /TOC -->
|
<!-- /TOC -->
|
||||||
|
|
||||||
## Constants
|
|
||||||
|
|
||||||
| Name | Value |
|
|
||||||
| - | - |
|
|
||||||
| `LENGTH_FLAG` | `2**64 - 1` |
|
|
||||||
|
|
||||||
## Generalized Merkle tree index
|
## Generalized Merkle tree index
|
||||||
|
|
||||||
In a binary Merkle tree, we define a "generalized index" of a node as `2**depth + index`. Visually, this looks as follows:
|
In a binary Merkle tree, we define a "generalized index" of a node as `2**depth + index`. Visually, this looks as follows:
|
||||||
|
@ -38,7 +32,8 @@ Note that the generalized index has the convenient property that the two childre
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def merkle_tree(leaves: List[Bytes32]) -> List[Bytes32]:
|
def merkle_tree(leaves: List[Bytes32]) -> List[Bytes32]:
|
||||||
o = [0] * len(leaves) + leaves
|
padded_length = next_power_of_2(len(leaves))
|
||||||
|
o = [ZERO_HASH] * padded_length + leaves + [ZERO_HASH] * (padded_length - len(leaves))
|
||||||
for i in range(len(leaves) - 1, 0, -1):
|
for i in range(len(leaves) - 1, 0, -1):
|
||||||
o[i] = hash(o[i * 2] + o[i * 2 + 1])
|
o[i] = hash(o[i * 2] + o[i * 2 + 1])
|
||||||
return o
|
return o
|
||||||
|
@ -64,27 +59,24 @@ y_data_root len(y)
|
||||||
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`.
|
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def item_length(typ: Type) -> int:
|
def item_length(typ: SSZType) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the number of bytes in a basic type, or 32 (a full hash) for compound types.
|
Returns the number of bytes in a basic type, or 32 (a full hash) for compound types.
|
||||||
"""
|
"""
|
||||||
if typ == bool:
|
if issubclass(typ, BasicValue):
|
||||||
return 1
|
|
||||||
elif issubclass(typ, uint):
|
|
||||||
return typ.byte_len
|
return typ.byte_len
|
||||||
else:
|
else:
|
||||||
return 32
|
return 32
|
||||||
|
|
||||||
|
|
||||||
def get_elem_type(typ: Type, index: int) -> Type:
|
def get_elem_type(typ: ComplexType, index: int) -> Type:
|
||||||
"""
|
"""
|
||||||
Returns the type of the element of an object of the given type with the given index
|
Returns the type of the element of an object of the given type with the given index
|
||||||
or member variable name (eg. `7` for `x[7]`, `"foo"` for `x.foo`)
|
or member variable name (eg. `7` for `x[7]`, `"foo"` for `x.foo`)
|
||||||
"""
|
"""
|
||||||
return typ.get_fields_dict()[index] if is_container_type(typ) else typ.elem_type
|
return typ.get_fields()[key] if issubclass(typ, Container) else typ.elem_type
|
||||||
|
|
||||||
|
def chunk_count(typ: SSZType) -> int:
|
||||||
def get_chunk_count(typ: Type) -> int:
|
|
||||||
"""
|
"""
|
||||||
Returns the number of hashes needed to represent the top-level elements in the given type
|
Returns the number of hashes needed to represent the top-level elements in the given type
|
||||||
(eg. `x.foo` or `x[7]` but not `x[7].bar` or `x.foo.baz`). In all cases except lists/vectors
|
(eg. `x.foo` or `x[7]` but not `x[7].bar` or `x.foo.baz`). In all cases except lists/vectors
|
||||||
|
@ -92,24 +84,28 @@ def get_chunk_count(typ: Type) -> int:
|
||||||
hash. For lists/vectors of basic types, it is often fewer because multiple basic elements
|
hash. For lists/vectors of basic types, it is often fewer because multiple basic elements
|
||||||
can be packed into one 32-byte chunk.
|
can be packed into one 32-byte chunk.
|
||||||
"""
|
"""
|
||||||
if is_basic_type(typ):
|
if issubclass(typ, BasicValue):
|
||||||
return 1
|
return 1
|
||||||
elif issubclass(typ, (List, Vector, Bytes, BytesN)):
|
elif issubclass(typ, Bits):
|
||||||
|
return (typ.length + 255) // 256
|
||||||
|
elif issubclass(typ, Elements):
|
||||||
return (typ.length * item_length(typ.elem_type) + 31) // 32
|
return (typ.length * item_length(typ.elem_type) + 31) // 32
|
||||||
else:
|
elif issubclass(typ, Container):
|
||||||
return len(typ.get_fields())
|
return len(typ.get_fields())
|
||||||
|
else:
|
||||||
|
raise Exception(f"Type not supported: {typ}")
|
||||||
|
|
||||||
|
|
||||||
def get_item_position(typ: Type, index: Union[int, str]) -> Tuple[int, int, int]:
|
def get_item_position(typ: SSZType, index: Union[int, str]) -> Tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
Returns three variables: (i) the index of the chunk in which the given element of the item is
|
Returns three variables: (i) the index of the chunk in which the given element of the item is
|
||||||
represented, (ii) the starting byte position, (iii) the ending byte position. For example for
|
represented, (ii) the starting byte position, (iii) the ending byte position. For example for
|
||||||
a 6-item list of uint64 values, index=2 will return (0, 16, 24), index=5 will return (1, 8, 16)
|
a 6-item list of uint64 values, index=2 will return (0, 16, 24), index=5 will return (1, 8, 16)
|
||||||
"""
|
"""
|
||||||
if issubclass(typ, (List, Vector, Bytes, BytesN)):
|
if issubclass(typ, Elements):
|
||||||
start = index * item_length(typ.elem_type)
|
start = index * item_length(typ.elem_type)
|
||||||
return start // 32, start % 32, start % 32 + item_length(typ.elem_type)
|
return start // 32, start % 32, start % 32 + item_length(typ.elem_type)
|
||||||
elif is_container_type(typ):
|
elif issubclass(typ, Container):
|
||||||
return typ.get_field_names().index(index), 0, item_length(get_elem_type(typ, index))
|
return typ.get_field_names().index(index), 0, item_length(get_elem_type(typ, index))
|
||||||
else:
|
else:
|
||||||
raise Exception("Only lists/vectors/containers supported")
|
raise Exception("Only lists/vectors/containers supported")
|
||||||
|
@ -122,12 +118,12 @@ def get_generalized_index(typ: Type, path: List[Union[int, str]]) -> Generalized
|
||||||
"""
|
"""
|
||||||
root = 1
|
root = 1
|
||||||
for p in path:
|
for p in path:
|
||||||
assert not is_basic_type(typ) # If we descend to a basic type, the path cannot continue further
|
assert not issubclass(typ, BasicValue) # If we descend to a basic type, the path cannot continue further
|
||||||
if p == '__len__':
|
if p == '__len__':
|
||||||
typ, root = uint256, root * 2 + 1 if issubclass(typ, (List, Bytes)) else None
|
typ, root = uint256, root * 2 + 1 if issubclass(typ, (List, Bytes)) else None
|
||||||
else:
|
else:
|
||||||
pos, _, _ = get_item_position(typ, p)
|
pos, _, _ = get_item_position(typ, p)
|
||||||
root = root * (2 if issubclass(typ, (List, Bytes)) else 1) * next_power_of_two(get_chunk_count(typ)) + pos
|
root = root * (2 if issubclass(typ, (List, Bytes)) else 1) * next_power_of_two(chunk_count(typ)) + pos
|
||||||
typ = get_elem_type(typ, p)
|
typ = get_elem_type(typ, p)
|
||||||
return root
|
return root
|
||||||
```
|
```
|
||||||
|
@ -197,7 +193,7 @@ def get_branch_indices(tree_index: int) -> List[int]:
|
||||||
def get_expanded_indices(indices: List[int]) -> List[int]:
|
def get_expanded_indices(indices: List[int]) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Get the generalized indices of all chunks in the tree needed to prove the chunks with the given
|
Get the generalized indices of all chunks in the tree needed to prove the chunks with the given
|
||||||
generalized indices.
|
generalized indices, including the leaves.
|
||||||
"""
|
"""
|
||||||
branches = set()
|
branches = set()
|
||||||
for index in indices:
|
for index in indices:
|
||||||
|
|
Loading…
Reference in New Issue