Merkle tree optimizations (#433)

* Merkle tree optimizations

* Replace spawn with parallel iterators

Co-authored-by: Daniel Lubarov <daniel@lubarov.com>

* Missing imports

Co-authored-by: Daniel Lubarov <daniel@lubarov.com>
This commit is contained in:
Jakub Nabaglo 2022-01-26 11:54:39 -08:00 committed by GitHub
parent c0ac79e2e1
commit ab48cca1f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 125 additions and 31 deletions

View File

@ -69,7 +69,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
// Squeeze until we have the desired number of outputs.
let mut outputs = Vec::new();
let mut outputs = Vec::with_capacity(num_outputs);
loop {
for i in 0..SPONGE_RATE {
outputs.push(state[i]);

View File

@ -1,3 +1,7 @@
use std::mem::MaybeUninit;
use std::slice;
use plonky2_util::log2_strict;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
@ -27,36 +31,103 @@ pub struct MerkleTree<F: RichField, H: Hasher<F>> {
/// The data in the leaves of the Merkle tree.
pub leaves: Vec<Vec<F>>,
/// The layers of hashes in the tree. The first layer is the one at the bottom.
pub layers: Vec<Vec<H::Hash>>,
/// The digests in the tree. Consists of `cap.len()` sub-trees, each corresponding to one
/// element in `cap`. Each subtree is contiguous and located at
/// `digests[digests.len() / cap.len() * i..digests.len() / cap.len() * (i + 1)]`.
/// Within each subtree, siblings are stored next to each other. The layout is,
/// left_child_subtree || left_child_digest || right_child_digest || right_child_subtree, where
/// left_child_digest and right_child_digest are H::Hash and left_child_subtree and
/// right_child_subtree recurse. Observe that the digest of a node is stored by its _parent_.
/// Consequently, the digests of the roots are not stored here (they can be found in `cap`).
pub digests: Vec<H::Hash>,
/// The Merkle cap.
pub cap: MerkleCap<F, H>,
}
fn capacity_up_to_mut<T>(v: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
assert!(v.capacity() >= len);
let v_ptr = v.as_mut_ptr().cast::<MaybeUninit<T>>();
unsafe {
// SAFETY: `v_ptr` is a valid pointer to a buffer of length at least `len`. Upon return, the
// lifetime will be bound to that of `v`. The underlying memory will not be deallocated as
// we hold the sole mutable reference to `v`. The contents of the slice may be
// uninitialized, but the `MaybeUninit` makes it safe.
slice::from_raw_parts_mut(v_ptr, len)
}
}
fn fill_subtree<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
leaves: &[Vec<F>],
) -> H::Hash {
assert_eq!(leaves.len(), digests_buf.len() / 2 + 1);
if digests_buf.is_empty() {
H::hash(&leaves[0], false)
} else {
// Layout is: left recursive output || left child digest
// || right child digest || right recursive output.
// Split `digests_buf` into the two recursive outputs (slices) and two child digests
// (references).
let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2);
let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap();
let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap();
// Split `leaves` between both children.
let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2);
let (left_digest, right_digest) = rayon::join(
|| fill_subtree::<F, H>(left_digests_buf, left_leaves),
|| fill_subtree::<F, H>(right_digests_buf, right_leaves),
);
left_digest_mem.write(left_digest);
right_digest_mem.write(right_digest);
H::two_to_one(left_digest, right_digest)
}
}
fn fill_digests_buf<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
cap_buf: &mut [MaybeUninit<H::Hash>],
leaves: &[Vec<F>],
cap_height: usize,
) {
let subtree_digests_len = digests_buf.len() >> cap_height;
let subtree_leaves_len = leaves.len() >> cap_height;
let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len);
let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len);
assert_eq!(digests_chunks.len(), cap_buf.len());
assert_eq!(digests_chunks.len(), leaves_chunks.len());
digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each(
|((subtree_digests, subtree_cap), subtree_leaves)| {
// We have `1 << cap_height` sub-trees, one for each entry in `cap`. They are totally
// independent, so we schedule one task for each. `digests_buf` and `leaves` are split
// into `1 << cap_height` slices, one for each sub-tree.
subtree_cap.write(fill_subtree::<F, H>(subtree_digests, subtree_leaves));
},
);
}
impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
pub fn new(leaves: Vec<Vec<F>>, cap_height: usize) -> Self {
let mut current_layer = leaves
.par_iter()
.map(|l| H::hash(l, false))
.collect::<Vec<_>>();
let num_digests = 2 * (leaves.len() - (1 << cap_height));
let mut digests = Vec::with_capacity(num_digests);
let mut layers = vec![];
let cap = loop {
if current_layer.len() == 1 << cap_height {
break current_layer;
}
let next_layer = current_layer
.par_chunks(2)
.map(|chunk| H::two_to_one(chunk[0], chunk[1]))
.collect::<Vec<_>>();
layers.push(current_layer);
current_layer = next_layer;
};
let len_cap = 1 << cap_height;
let mut cap = Vec::with_capacity(len_cap);
let digests_buf = capacity_up_to_mut(&mut digests, num_digests);
let cap_buf = capacity_up_to_mut(&mut cap, len_cap);
fill_digests_buf::<F, H>(digests_buf, cap_buf, &leaves[..], cap_height);
unsafe {
// SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to
// `num_digests` and `len_cap`, resp.
digests.set_len(num_digests);
cap.set_len(len_cap);
}
Self {
leaves,
layers,
digests,
cap: MerkleCap(cap),
}
}
@ -67,17 +138,40 @@ impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
/// Create a Merkle proof from a leaf index.
pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
MerkleProof {
siblings: self
.layers
.iter()
.scan(leaf_index, |acc, layer| {
let index = *acc ^ 1;
*acc >>= 1;
Some(layer[index])
})
.collect(),
}
let cap_height = log2_strict(self.cap.len());
let num_layers = log2_strict(self.leaves.len()) - cap_height;
debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0);
let digest_tree = {
let tree_index = leaf_index >> num_layers;
let tree_len = self.digests.len() >> cap_height;
&self.digests[tree_len * tree_index..tree_len * (tree_index + 1)]
};
// Mask out high bits to get the index within the sub-tree.
let mut pair_index = leaf_index & ((1 << num_layers) - 1);
let siblings = (0..num_layers)
.into_iter()
.map(|i| {
let parity = pair_index & 1;
pair_index >>= 1;
// The layers' data is interleaved as follows:
// [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...].
// Each of the above is a pair of siblings.
// `pair_index` is the index of the pair within layer `i`.
// The index of that the pair within `digests` is
// `pair_index * 2 ** (i + 1) + (2 ** i - 1)`.
let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
// We have an index for the _pair_, but we want the index of the _sibling_.
// Double the pair index to get the index of the left sibling. Conditionally add `1`
// if we are to retrieve the right sibling.
let sibling_index = 2 * siblings_index + (1 - parity);
digest_tree[sibling_index]
})
.collect();
MerkleProof { siblings }
}
}