Merge pull request #1292 from ethereum/correct-merkle

Correct merkleization
Danny Ryan 2019-07-14 17:19:39 -06:00 committed by GitHub
commit 07a0e7b7dd
4 changed files with 93 additions and 61 deletions


@@ -25,8 +25,6 @@
- [Vectors, containers, lists, unions](#vectors-containers-lists-unions)
- [Deserialization](#deserialization)
- [Merkleization](#merkleization)
- [Self-signed containers](#self-signed-containers)
- [Implementations](#implementations)
@@ -177,38 +175,37 @@ Note that deserialization requires hardening against invalid inputs. A non-exhau

We first define helper functions:
* `size_of(B)`, where `B` is a basic type: the length, in bytes, of the serialized form of the basic type.
* `chunk_count(type)`: calculate the number of leaves for merkleization of the type.
    * all basic types: `1`
    * `Bitlist[N]` and `Bitvector[N]`: `(N + 255) // 256` (dividing by chunk size, rounding up)
    * `List[B, N]` and `Vector[B, N]`, where `B` is a basic type: `(N * size_of(B) + 31) // 32` (dividing by chunk size, rounding up)
    * `List[C, N]` and `Vector[C, N]`, where `C` is a composite type: `N`
    * containers: `len(fields)`
* `bitfield_bytes(bits)`: return the bits of the bitlist or bitvector, packed in bytes, aligned to the start. The length-delimiting bit for bitlists is excluded.
* `pack`: Given ordered objects of the same basic type, serialize them, pack them into `BYTES_PER_CHUNK`-byte chunks, right-pad the last chunk with zero bytes, and return the chunks.
* `next_pow_of_two(i)`: get the next power of 2 of `i`, if not already a power of 2, with 0 mapping to 1. Examples: `0->1, 1->1, 2->2, 3->4, 4->4, 6->8, 9->16`
* `merkleize(chunks, limit=None)`: Given ordered `BYTES_PER_CHUNK`-byte chunks, merkleize the chunks, and return the root:
    * The merkleization depends on the effective input, which can be padded/limited:
        - if no limit: pad the `chunks` with zeroed chunks to `next_pow_of_two(len(chunks))` (virtually for memory efficiency).
        - if `limit > len(chunks)`: pad the `chunks` with zeroed chunks to `next_pow_of_two(limit)` (virtually for memory efficiency).
        - if `limit < len(chunks)`: do not merkleize; the input exceeds the limit, so raise an error instead.
    * Then, merkleize the chunks (empty input is padded to 1 zero chunk):
        - If `1` chunk: the root is the chunk itself.
        - If `> 1` chunks: merkleize as binary tree.
* `mix_in_length`: Given a Merkle root `root` and a length `length` (`"uint256"` little-endian serialization) return `hash(root + length)`.
* `mix_in_type`: Given a Merkle root `root` and a type_index `type_index` (`"uint256"` little-endian serialization) return `hash(root + type_index)`.
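For concreteness, here is a minimal Python sketch of these helpers, assuming the spec's 32-byte SHA-256 `hash`. It pads literally rather than virtually, so unlike the implementation further below it is not memory-efficient:

```python
from hashlib import sha256

BYTES_PER_CHUNK = 32


def hash_fn(data: bytes) -> bytes:
    # The spec's `hash` function (SHA-256).
    return sha256(data).digest()


def next_pow_of_two(i: int) -> int:
    # 0 -> 1, 1 -> 1, 2 -> 2, 3 -> 4, 6 -> 8, 9 -> 16, ...
    return 1 if i == 0 else 1 << (i - 1).bit_length()


def merkleize(chunks, limit=None):
    if limit is not None and len(chunks) > limit:
        raise ValueError("input exceeds limit")
    # Effective width: the chunk count (or the limit, if given), rounded up to a power of two.
    width = next_pow_of_two(len(chunks) if limit is None else limit)
    layer = list(chunks) + [b"\x00" * BYTES_PER_CHUNK] * (width - len(chunks))
    # Hash pairwise until a single root remains; empty input was padded to 1 zero chunk.
    while len(layer) > 1:
        layer = [hash_fn(layer[i] + layer[i + 1]) for i in range(0, len(layer), 2)]
    return layer[0]
```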
We now define Merkleization `hash_tree_root(value)` of an object `value` recursively:

* `merkleize(pack(value))` if `value` is a basic object or a vector of basic objects.
* `merkleize(bitfield_bytes(value), limit=chunk_count(type))` if `value` is a bitvector.
* `mix_in_length(merkleize(pack(value), limit=chunk_count(type)), len(value))` if `value` is a list of basic objects.
* `mix_in_length(merkleize(bitfield_bytes(value), limit=chunk_count(type)), len(value))` if `value` is a bitlist.
* `merkleize([hash_tree_root(element) for element in value])` if `value` is a vector of composite objects or a container.
* `mix_in_length(merkleize([hash_tree_root(element) for element in value], limit=chunk_count(type)), len(value))` if `value` is a list of composite objects.
* `mix_in_type(merkleize(value.value), value.type_index)` if `value` is of union type.
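A worked example of the list-of-basic-objects rule, reusing the sketch above (`pack` here is a simplified uint64-only stand-in for the spec's serializer; the list type is hypothetical):

```python
def pack(values, elem_size=8):
    # Serialize little-endian and right-pad the last chunk with zero bytes.
    data = b"".join(v.to_bytes(elem_size, "little") for v in values)
    data += b"\x00" * (-len(data) % BYTES_PER_CHUNK)
    return [data[i:i + BYTES_PER_CHUNK] for i in range(0, len(data), BYTES_PER_CHUNK)]


def mix_in_length(root: bytes, length: int) -> bytes:
    # hash(root + length), with the length as a little-endian uint256.
    return hash_fn(root + length.to_bytes(32, "little"))


# hash_tree_root of a List[uint64, 100] holding [1, 2, 3]:
# chunk_count = (100 * 8 + 31) // 32 = 25, so the single packed chunk is merkleized
# in a next_pow_of_two(25) = 32-leaf tree before the length is mixed in.
root = mix_in_length(merkleize(pack([1, 2, 3]), limit=25), 3)
```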
## Self-signed containers

Let `value` be a self-signed container object. The convention is that the signature (e.g. a `"bytes96"` BLS12-381 signature) be the last field of `value`. Further, the signed message for `value` is `signing_root(value) = hash_tree_root(truncate_last(value))` where `truncate_last` truncates the last element of `value`.
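A sketch of that convention in terms of the helpers above, assuming the container's field roots are available as a list (a hypothetical helper, not part of the spec):

```python
def signing_root(field_roots: list) -> bytes:
    # hash_tree_root(truncate_last(value)): merkleize every field root
    # except the last one, which is the signature.
    return merkleize(field_roots[:-1])
```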


@@ -1,4 +1,4 @@
from eth2spec.utils.hash_function import hash
from math import log2
@@ -21,6 +21,8 @@ def calc_merkle_tree_from_leaves(values, layer_count=32):
def get_merkle_root(values, pad_to=1):
    if pad_to == 0:
        return zerohashes[0]
    layer_count = int(log2(pad_to))
    if len(values) == 0:
        return zerohashes[layer_count]
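The new `pad_to == 0` guard matters because `int(log2(0))` raises a `ValueError` (math domain error); a zero-capacity tree now explicitly yields the zero hash:

```python
# Before this change: get_merkle_root([], pad_to=0) crashed on int(log2(0)).
# After: it returns zerohashes[0], matching merkleize_chunks([], limit=0).
```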
@@ -35,10 +37,21 @@ def get_merkle_proof(tree, item_index):
    return proof


def merkleize_chunks(chunks, limit=None):
    # If no limit is defined, we are just merkleizing chunks (e.g. SSZ container).
    if limit is None:
        limit = len(chunks)

    count = len(chunks)
    # See if the input is within expected size.
    # If not, a list-limit is set incorrectly, or a value is unexpectedly large.
    assert count <= limit

    if limit == 0:
        return zerohashes[0]

    depth = max(count - 1, 0).bit_length()
    max_depth = (limit - 1).bit_length()
    tmp = [None for _ in range(max_depth + 1)]

    def merge(h, i):
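The hunk is cut off at `merge(h, i)`, but the intended behavior of the new `limit` parameter is fixed by the test vectors below. For instance (using the `e()`/`z()` helpers from those tests):

```python
merkleize_chunks([], limit=0)                  # == z(0): a zero-capacity tree is the zero hash
merkleize_chunks([e(0)], limit=4)              # == h(h(e(0), z(0)), z(1)): padded up to the limit
merkleize_chunks([e(0), e(1), e(2)], limit=2)  # raises AssertionError: count exceeds limit
```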


@@ -126,6 +126,7 @@ def item_length(typ: SSZType) -> int:
def chunk_count(typ: SSZType) -> int:
    # note that for lists, .length *on the type* describes the list limit.
    if isinstance(typ, BasicType):
        return 1
    elif issubclass(typ, Bits):
@@ -150,7 +151,7 @@ def hash_tree_root(obj: SSZValue):
        raise Exception(f"Type not supported: {type(obj)}")

    if isinstance(obj, (List, Bytes, Bitlist)):
        return mix_in_length(merkleize_chunks(leaves, limit=chunk_count(obj.type())), len(obj))
    else:
        return merkleize_chunks(leaves)
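For reference, here is what `chunk_count` yields as the merkleization `limit`, following the spec's rules (the type parameters are illustrative):

```python
# chunk_count examples:
#   uint64 (basic type)      -> 1
#   Bitvector[300]           -> (300 + 255) // 256 == 2
#   List[uint64, 100]        -> (100 * 8 + 31) // 32 == 25  (size_of(uint64) == 8)
#   Vector[SomeContainer, 7] -> 7  (one leaf per composite element)
#   a container              -> len(fields)
```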


@@ -8,7 +8,8 @@ def h(a: bytes, b: bytes) -> bytes:
def e(v: int) -> bytes:
    # prefix with 0xfff... to make it non-zero
    return b'\xff' * 28 + v.to_bytes(length=4, byteorder='little')

def z(i: int) -> bytes:
@@ -16,44 +17,64 @@ def z(i: int) -> bytes:
cases = [
    # limit 0: always zero hash
    (0, 0, z(0)),
    (1, 0, None),  # cut-off due to limit
    (2, 0, None),  # cut-off due to limit
    # limit 1: padded to 1 element if not already. Returned (like identity func)
    (0, 1, z(0)),
    (1, 1, e(0)),
    (2, 1, None),  # cut-off due to limit
    (1, 1, e(0)),
    (0, 2, h(z(0), z(0))),
    (1, 2, h(e(0), z(0))),
    (2, 2, h(e(0), e(1))),
    (3, 2, None),  # cut-off due to limit
    (16, 2, None),  # bigger cut-off due to limit
    (0, 4, h(h(z(0), z(0)), z(1))),
    (1, 4, h(h(e(0), z(0)), z(1))),
    (2, 4, h(h(e(0), e(1)), z(1))),
    (3, 4, h(h(e(0), e(1)), h(e(2), z(0)))),
    (4, 4, h(h(e(0), e(1)), h(e(2), e(3)))),
    (5, 4, None),  # cut-off due to limit
    (0, 8, h(h(h(z(0), z(0)), z(1)), z(2))),
    (1, 8, h(h(h(e(0), z(0)), z(1)), z(2))),
    (2, 8, h(h(h(e(0), e(1)), z(1)), z(2))),
    (3, 8, h(h(h(e(0), e(1)), h(e(2), z(0))), z(2))),
    (4, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), z(2))),
    (5, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), z(0)), z(1)))),
    (6, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(z(0), z(0))))),
    (7, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), z(0))))),
    (8, 8, h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7))))),
    (9, 8, None),  # cut-off due to limit
    (0, 16, h(h(h(h(z(0), z(0)), z(1)), z(2)), z(3))),
    (1, 16, h(h(h(h(e(0), z(0)), z(1)), z(2)), z(3))),
    (2, 16, h(h(h(h(e(0), e(1)), z(1)), z(2)), z(3))),
    (3, 16, h(h(h(h(e(0), e(1)), h(e(2), z(0))), z(2)), z(3))),
    (4, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), z(2)), z(3))),
    (5, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), z(0)), z(1))), z(3))),
    (6, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(z(0), z(0)))), z(3))),
    (7, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), z(0)))), z(3))),
    (8, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7)))), z(3))),
    (9, 16, h(h(h(h(e(0), e(1)), h(e(2), e(3))), h(h(e(4), e(5)), h(e(6), e(7)))), h(h(h(e(8), z(0)), z(1)), z(2)))),
]
@pytest.mark.parametrize(
    'count,limit,value',
    cases,
)
def test_merkleize_chunks_and_get_merkle_root(count, limit, value):
    chunks = [e(i) for i in range(count)]
    if value is None:
        # An expected-invalid case: merkleization must fail with an assertion error.
        bad = False
        try:
            merkleize_chunks(chunks, limit=limit)
            bad = True  # only reached if no assertion was raised
        except AssertionError:
            pass
        if bad:
            assert False, "expected merkleization to be invalid"
    else:
        assert merkleize_chunks(chunks, limit=limit) == value
        assert get_merkle_root(chunks, pad_to=limit) == value