Updated type checkers for generalized index functions.

This commit is contained in:
vbuterin 2019-08-01 10:56:31 -04:00 committed by GitHub
parent 446ad3c392
commit 55f5f106f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 24 deletions

View File

@ -17,12 +17,6 @@
<!-- /TOC --> <!-- /TOC -->
## Constants
| Name | Value |
| - | - |
| `LENGTH_FLAG` | `2**64 - 1` |
## Generalized Merkle tree index ## Generalized Merkle tree index
In a binary Merkle tree, we define a "generalized index" of a node as `2**depth + index`. Visually, this looks as follows: In a binary Merkle tree, we define a "generalized index" of a node as `2**depth + index`. Visually, this looks as follows:
@ -38,7 +32,8 @@ Note that the generalized index has the convenient property that the two childre
```python ```python
def merkle_tree(leaves: List[Bytes32]) -> List[Bytes32]: def merkle_tree(leaves: List[Bytes32]) -> List[Bytes32]:
o = [0] * len(leaves) + leaves padded_length = next_power_of_2(len(leaves))
o = [ZERO_HASH] * padded_length + leaves + [ZERO_HASH] * (padded_length - len(leaves))
for i in range(len(leaves) - 1, 0, -1): for i in range(len(leaves) - 1, 0, -1):
o[i] = hash(o[i * 2] + o[i * 2 + 1]) o[i] = hash(o[i * 2] + o[i * 2 + 1])
return o return o
@ -64,27 +59,24 @@ y_data_root len(y)
We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`. We can now define a concept of a "path", a way of describing a function that takes as input an SSZ object and outputs some specific (possibly deeply nested) member. For example, `foo -> foo.x` is a path, as are `foo -> len(foo.y)` and `foo -> foo.y[5].w`. We'll describe paths as lists, which can have two representations. In "human-readable form", they are `["x"]`, `["y", "__len__"]` and `["y", 5, "w"]` respectively. In "encoded form", they are lists of `uint64` values, in these cases (assuming the fields of `foo` in order are `x` then `y`, and `w` is the first field of `y[i]`) `[0]`, `[1, 2**64-1]`, `[1, 5, 0]`.
```python ```python
def item_length(typ: Type) -> int: def item_length(typ: SSZType) -> int:
""" """
Returns the number of bytes in a basic type, or 32 (a full hash) for compound types. Returns the number of bytes in a basic type, or 32 (a full hash) for compound types.
""" """
if typ == bool: if issubclass(typ, BasicValue):
return 1
elif issubclass(typ, uint):
return typ.byte_len return typ.byte_len
else: else:
return 32 return 32
def get_elem_type(typ: Type, index: int) -> Type: def get_elem_type(typ: ComplexType, index: int) -> Type:
""" """
Returns the type of the element of an object of the given type with the given index Returns the type of the element of an object of the given type with the given index
or member variable name (eg. `7` for `x[7]`, `"foo"` for `x.foo`) or member variable name (eg. `7` for `x[7]`, `"foo"` for `x.foo`)
""" """
return typ.get_fields_dict()[index] if is_container_type(typ) else typ.elem_type return typ.get_fields()[key] if issubclass(typ, Container) else typ.elem_type
def chunk_count(typ: SSZType) -> int:
def get_chunk_count(typ: Type) -> int:
""" """
Returns the number of hashes needed to represent the top-level elements in the given type Returns the number of hashes needed to represent the top-level elements in the given type
(eg. `x.foo` or `x[7]` but not `x[7].bar` or `x.foo.baz`). In all cases except lists/vectors (eg. `x.foo` or `x[7]` but not `x[7].bar` or `x.foo.baz`). In all cases except lists/vectors
@ -92,24 +84,28 @@ def get_chunk_count(typ: Type) -> int:
hash. For lists/vectors of basic types, it is often fewer because multiple basic elements hash. For lists/vectors of basic types, it is often fewer because multiple basic elements
can be packed into one 32-byte chunk. can be packed into one 32-byte chunk.
""" """
if is_basic_type(typ): if issubclass(typ, BasicValue):
return 1 return 1
elif issubclass(typ, (List, Vector, Bytes, BytesN)): elif issubclass(typ, Bits):
return (typ.length + 255) // 256
elif issubclass(typ, Elements):
return (typ.length * item_length(typ.elem_type) + 31) // 32 return (typ.length * item_length(typ.elem_type) + 31) // 32
else: elif issubclass(typ, Container):
return len(typ.get_fields()) return len(typ.get_fields())
else:
raise Exception(f"Type not supported: {typ}")
def get_item_position(typ: Type, index: Union[int, str]) -> Tuple[int, int, int]: def get_item_position(typ: SSZType, index: Union[int, str]) -> Tuple[int, int, int]:
""" """
Returns three variables: (i) the index of the chunk in which the given element of the item is Returns three variables: (i) the index of the chunk in which the given element of the item is
represented, (ii) the starting byte position, (iii) the ending byte position. For example for represented, (ii) the starting byte position, (iii) the ending byte position. For example for
a 6-item list of uint64 values, index=2 will return (0, 16, 24), index=5 will return (1, 8, 16) a 6-item list of uint64 values, index=2 will return (0, 16, 24), index=5 will return (1, 8, 16)
""" """
if issubclass(typ, (List, Vector, Bytes, BytesN)): if issubclass(typ, Elements):
start = index * item_length(typ.elem_type) start = index * item_length(typ.elem_type)
return start // 32, start % 32, start % 32 + item_length(typ.elem_type) return start // 32, start % 32, start % 32 + item_length(typ.elem_type)
elif is_container_type(typ): elif issubclass(typ, Container):
return typ.get_field_names().index(index), 0, item_length(get_elem_type(typ, index)) return typ.get_field_names().index(index), 0, item_length(get_elem_type(typ, index))
else: else:
raise Exception("Only lists/vectors/containers supported") raise Exception("Only lists/vectors/containers supported")
@ -122,12 +118,12 @@ def get_generalized_index(typ: Type, path: List[Union[int, str]]) -> Generalized
""" """
root = 1 root = 1
for p in path: for p in path:
assert not is_basic_type(typ) # If we descend to a basic type, the path cannot continue further assert not issubclass(typ, BasicValue) # If we descend to a basic type, the path cannot continue further
if p == '__len__': if p == '__len__':
typ, root = uint256, root * 2 + 1 if issubclass(typ, (List, Bytes)) else None typ, root = uint256, root * 2 + 1 if issubclass(typ, (List, Bytes)) else None
else: else:
pos, _, _ = get_item_position(typ, p) pos, _, _ = get_item_position(typ, p)
root = root * (2 if issubclass(typ, (List, Bytes)) else 1) * next_power_of_two(get_chunk_count(typ)) + pos root = root * (2 if issubclass(typ, (List, Bytes)) else 1) * next_power_of_two(chunk_count(typ)) + pos
typ = get_elem_type(typ, p) typ = get_elem_type(typ, p)
return root return root
``` ```
@ -197,7 +193,7 @@ def get_branch_indices(tree_index: int) -> List[int]:
def get_expanded_indices(indices: List[int]) -> List[int]: def get_expanded_indices(indices: List[int]) -> List[int]:
""" """
Get the generalized indices of all chunks in the tree needed to prove the chunks with the given Get the generalized indices of all chunks in the tree needed to prove the chunks with the given
generalized indices. generalized indices, including the leaves.
""" """
branches = set() branches = set()
for index in indices: for index in indices: