mirror of https://github.com/logos-storage/plonky2.git
synced 2026-01-05 15:23:06 +00:00
Merkle tree optimizations (#433)
* Merkle tree optimizations

* Replace spawn with parallel iterators

Co-authored-by: Daniel Lubarov <daniel@lubarov.com>

* Missing imports

Co-authored-by: Daniel Lubarov <daniel@lubarov.com>
This commit is contained in:
parent c0ac79e2e1
commit ab48cca1f3
@@ -69,7 +69,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
         }
 
         // Squeeze until we have the desired number of outputs.
-        let mut outputs = Vec::new();
+        let mut outputs = Vec::with_capacity(num_outputs);
         loop {
             for i in 0..SPONGE_RATE {
                 outputs.push(state[i]);
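
The only change in this hunk is pre-allocation: `Vec::new()` becomes `Vec::with_capacity(num_outputs)`, so pushing the squeezed elements never reallocates. A minimal standalone sketch of the pattern (the `u64` items and the free function are illustrative, not from the commit):

fn squeeze_like(num_outputs: usize) -> Vec<u64> {
    // Allocate once up front; every `push` below stays within the reserved capacity.
    let mut outputs = Vec::with_capacity(num_outputs);
    for i in 0..num_outputs as u64 {
        outputs.push(i);
    }
    outputs
}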
@@ -1,3 +1,7 @@
+use std::mem::MaybeUninit;
+use std::slice;
+
+use plonky2_util::log2_strict;
 use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 
@@ -27,36 +31,103 @@ pub struct MerkleTree<F: RichField, H: Hasher<F>> {
     /// The data in the leaves of the Merkle tree.
     pub leaves: Vec<Vec<F>>,
 
-    /// The layers of hashes in the tree. The first layer is the one at the bottom.
-    pub layers: Vec<Vec<H::Hash>>,
+    /// The digests in the tree. Consists of `cap.len()` sub-trees, each corresponding to one
+    /// element in `cap`. Each subtree is contiguous and located at
+    /// `digests[digests.len() / cap.len() * i..digests.len() / cap.len() * (i + 1)]`.
+    /// Within each subtree, siblings are stored next to each other. The layout is,
+    /// left_child_subtree || left_child_digest || right_child_digest || right_child_subtree, where
+    /// left_child_digest and right_child_digest are H::Hash and left_child_subtree and
+    /// right_child_subtree recurse. Observe that the digest of a node is stored by its _parent_.
+    /// Consequently, the digests of the roots are not stored here (they can be found in `cap`).
+    pub digests: Vec<H::Hash>,
 
     /// The Merkle cap.
     pub cap: MerkleCap<F, H>,
 }
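
To make the recursive layout of `digests` concrete, here is a small self-contained sketch (not part of the commit) that prints where the child digests of every node land for a single sub-tree of 8 leaves:

// Toy mirror of the `digests` layout: a sub-tree occupying slots `lo..hi` stores the two
// child digests in the middle, and each child's own sub-tree fills the remaining half-range
// on its side of that pair.
fn describe_subtree(lo: usize, hi: usize, depth: usize) {
    let len = hi - lo;
    if len == 0 {
        return; // Leaf level: a leaf's hash is stored by its parent (or ends up in `cap`).
    }
    let mid = lo + len / 2;
    println!(
        "{:indent$}left child digest at slot {}, right child digest at slot {}",
        "",
        mid - 1,
        mid,
        indent = 2 * depth,
    );
    describe_subtree(lo, mid - 1, depth + 1); // left_child_subtree
    describe_subtree(mid + 1, hi, depth + 1); // right_child_subtree
}

fn main() {
    // 8 leaves, cap_height = 0  =>  2 * (8 - 1) = 14 digest slots; the root itself lives in `cap`.
    describe_subtree(0, 14, 0);
}

Running this shows the root's children at slots 6 and 7, the left child's sub-tree filling slots 0..6 and the right child's filling 8..14, matching the left_child_subtree || left_child_digest || right_child_digest || right_child_subtree convention documented above.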
 
+fn capacity_up_to_mut<T>(v: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
+    assert!(v.capacity() >= len);
+    let v_ptr = v.as_mut_ptr().cast::<MaybeUninit<T>>();
+    unsafe {
+        // SAFETY: `v_ptr` is a valid pointer to a buffer of length at least `len`. Upon return, the
+        // lifetime will be bound to that of `v`. The underlying memory will not be deallocated as
+        // we hold the sole mutable reference to `v`. The contents of the slice may be
+        // uninitialized, but the `MaybeUninit` makes it safe.
+        slice::from_raw_parts_mut(v_ptr, len)
+    }
+}
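
A hedged sketch of the intended usage pattern for `capacity_up_to_mut` (the new constructor below follows exactly these three steps): reserve capacity, initialize every slot through the `MaybeUninit` slice, then `set_len`. The `u64` element type and the `filled_vec` helper are stand-ins, not part of the commit:

fn filled_vec(len: usize) -> Vec<u64> {
    let mut v: Vec<u64> = Vec::with_capacity(len);
    {
        // Borrow the spare capacity as `&mut [MaybeUninit<u64>]` and write every slot.
        let buf = capacity_up_to_mut(&mut v, len);
        for (i, slot) in buf.iter_mut().enumerate() {
            slot.write(i as u64);
        }
    }
    // SAFETY: all `len` slots were initialized above.
    unsafe { v.set_len(len) };
    v
}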
+
+fn fill_subtree<F: RichField, H: Hasher<F>>(
+    digests_buf: &mut [MaybeUninit<H::Hash>],
+    leaves: &[Vec<F>],
+) -> H::Hash {
+    assert_eq!(leaves.len(), digests_buf.len() / 2 + 1);
+    if digests_buf.is_empty() {
+        H::hash(&leaves[0], false)
+    } else {
+        // Layout is: left recursive output || left child digest
+        //            || right child digest || right recursive output.
+        // Split `digests_buf` into the two recursive outputs (slices) and two child digests
+        // (references).
+        let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2);
+        let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap();
+        let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap();
+        // Split `leaves` between both children.
+        let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2);
+        let (left_digest, right_digest) = rayon::join(
+            || fill_subtree::<F, H>(left_digests_buf, left_leaves),
+            || fill_subtree::<F, H>(right_digests_buf, right_leaves),
+        );
+        left_digest_mem.write(left_digest);
+        right_digest_mem.write(right_digest);
+        H::two_to_one(left_digest, right_digest)
+    }
+}
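
`fill_subtree` gets its parallelism from `rayon::join`: the two halves run as potentially parallel tasks and their digests are combined with `H::two_to_one`. A minimal standalone sketch of the same divide-and-conquer pattern, with addition standing in for the compression function:

use rayon::join;

// Split the slice in half, reduce each half as a separate rayon task, combine the results.
fn parallel_sum(values: &[u64]) -> u64 {
    if values.len() <= 1 {
        values.first().copied().unwrap_or(0)
    } else {
        let (left, right) = values.split_at(values.len() / 2);
        let (l, r) = join(|| parallel_sum(left), || parallel_sum(right));
        l + r
    }
}

`rayon::join` only hands the second closure to another worker if one is idle, so the deep recursion near the leaves typically runs sequentially without per-task spawning overhead.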
+
+fn fill_digests_buf<F: RichField, H: Hasher<F>>(
+    digests_buf: &mut [MaybeUninit<H::Hash>],
+    cap_buf: &mut [MaybeUninit<H::Hash>],
+    leaves: &[Vec<F>],
+    cap_height: usize,
+) {
+    let subtree_digests_len = digests_buf.len() >> cap_height;
+    let subtree_leaves_len = leaves.len() >> cap_height;
+    let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len);
+    let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len);
+    assert_eq!(digests_chunks.len(), cap_buf.len());
+    assert_eq!(digests_chunks.len(), leaves_chunks.len());
+    digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each(
+        |((subtree_digests, subtree_cap), subtree_leaves)| {
+            // We have `1 << cap_height` sub-trees, one for each entry in `cap`. They are totally
+            // independent, so we schedule one task for each. `digests_buf` and `leaves` are split
+            // into `1 << cap_height` slices, one for each sub-tree.
+            subtree_cap.write(fill_subtree::<F, H>(subtree_digests, subtree_leaves));
+        },
+    );
+}
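
This is where the commit message's "Replace spawn with parallel iterators" shows up: the `1 << cap_height` independent sub-trees are driven by rayon's `par_chunks_exact` / `par_chunks_exact_mut` iterators instead of manually spawned tasks. A worked example of the chunk arithmetic (illustrative numbers, not from the commit):

// For `num_leaves` leaves and a cap of `1 << cap_height` elements, return the number of
// digest slots and the number of leaves handed to each sub-tree.
fn subtree_chunk_sizes(num_leaves: usize, cap_height: usize) -> (usize, usize) {
    let num_digests = 2 * (num_leaves - (1 << cap_height));
    (num_digests >> cap_height, num_leaves >> cap_height)
}

#[test]
fn chunk_sizes_example() {
    // 8 leaves and a 2-element cap: two independent sub-trees of 6 digest slots and 4 leaves
    // each, which satisfies `fill_subtree`'s invariant `leaves == digests / 2 + 1`.
    let (subtree_digests_len, subtree_leaves_len) = subtree_chunk_sizes(8, 1);
    assert_eq!((subtree_digests_len, subtree_leaves_len), (6, 4));
    assert_eq!(subtree_leaves_len, subtree_digests_len / 2 + 1);
}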
+
 impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
     pub fn new(leaves: Vec<Vec<F>>, cap_height: usize) -> Self {
-        let mut current_layer = leaves
-            .par_iter()
-            .map(|l| H::hash(l, false))
-            .collect::<Vec<_>>();
+        let num_digests = 2 * (leaves.len() - (1 << cap_height));
+        let mut digests = Vec::with_capacity(num_digests);
 
-        let mut layers = vec![];
-        let cap = loop {
-            if current_layer.len() == 1 << cap_height {
-                break current_layer;
-            }
-            let next_layer = current_layer
-                .par_chunks(2)
-                .map(|chunk| H::two_to_one(chunk[0], chunk[1]))
-                .collect::<Vec<_>>();
-            layers.push(current_layer);
-            current_layer = next_layer;
-        };
+        let len_cap = 1 << cap_height;
+        let mut cap = Vec::with_capacity(len_cap);
+
+        let digests_buf = capacity_up_to_mut(&mut digests, num_digests);
+        let cap_buf = capacity_up_to_mut(&mut cap, len_cap);
+        fill_digests_buf::<F, H>(digests_buf, cap_buf, &leaves[..], cap_height);
 
+        unsafe {
+            // SAFETY: `fill_digests_buf` initialized the spare capacity of `digests` and `cap`
+            // up to `num_digests` and `len_cap`, resp.
+            digests.set_len(num_digests);
+            cap.set_len(len_cap);
+        }
 
         Self {
             leaves,
-            layers,
+            digests,
             cap: MerkleCap(cap),
         }
     }
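
A hedged usage sketch of the new constructor, assuming plonky2's Goldilocks field and Poseidon hasher under the import paths shown (the exact paths may differ between versions of the crate):

use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::hash::merkle_tree::MerkleTree;
use plonky2::hash::poseidon::PoseidonHash;

fn build_example_tree() -> MerkleTree<GoldilocksField, PoseidonHash> {
    // 8 leaves of 4 field elements each; cap_height = 1 keeps a 2-element Merkle cap.
    let leaves: Vec<Vec<GoldilocksField>> = (0..8u64)
        .map(|i| (0..4u64).map(|j| GoldilocksField(4 * i + j)).collect())
        .collect();
    MerkleTree::new(leaves, 1)
}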
@@ -67,17 +138,40 @@ impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
 
     /// Create a Merkle proof from a leaf index.
     pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
-        MerkleProof {
-            siblings: self
-                .layers
-                .iter()
-                .scan(leaf_index, |acc, layer| {
-                    let index = *acc ^ 1;
-                    *acc >>= 1;
-                    Some(layer[index])
-                })
-                .collect(),
-        }
+        let cap_height = log2_strict(self.cap.len());
+        let num_layers = log2_strict(self.leaves.len()) - cap_height;
+        debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0);
+
+        let digest_tree = {
+            let tree_index = leaf_index >> num_layers;
+            let tree_len = self.digests.len() >> cap_height;
+            &self.digests[tree_len * tree_index..tree_len * (tree_index + 1)]
+        };
+
+        // Mask out high bits to get the index within the sub-tree.
+        let mut pair_index = leaf_index & ((1 << num_layers) - 1);
+        let siblings = (0..num_layers)
+            .into_iter()
+            .map(|i| {
+                let parity = pair_index & 1;
+                pair_index >>= 1;
+
+                // The layers' data is interleaved as follows:
+                // [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...].
+                // Each of the above is a pair of siblings.
+                // `pair_index` is the index of the pair within layer `i`.
+                // The index of that pair within `digests` is
+                // `pair_index * 2 ** (i + 1) + (2 ** i - 1)`.
+                let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
+                // We have an index for the _pair_, but we want the index of the _sibling_.
+                // Double the pair index to get the index of the left sibling. Conditionally add `1`
+                // if we are to retrieve the right sibling.
+                let sibling_index = 2 * siblings_index + (1 - parity);
+                digest_tree[sibling_index]
+            })
+            .collect();
+
+        MerkleProof { siblings }
     }
 }
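
The index arithmetic in `prove` can be checked by hand. The sketch below (illustration only) reproduces the sibling slots read for leaf 5 in a single sub-tree of 8 leaves, so `num_layers = 3` and the sub-tree owns digest slots 0..14, as in the layout sketch earlier on this page:

// Mirror of the loop body in `prove`: walk up the tree, returning the `digests` slot of the
// sibling at each layer.
fn sibling_indices(mut pair_index: usize, num_layers: usize) -> Vec<usize> {
    (0..num_layers)
        .map(|i| {
            let parity = pair_index & 1;
            pair_index >>= 1;
            let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
            2 * siblings_index + (1 - parity)
        })
        .collect()
}

#[test]
fn prove_indexing_example() {
    // Leaf 5's Merkle path reads its sibling leaf 4 (slot 8), then the digest covering leaves
    // 6..=7 (slot 11), then the digest covering leaves 0..=3 (slot 6).
    assert_eq!(sibling_indices(5, 3), vec![8, 11, 6]);
}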