diff --git a/plonky2/src/hash/hashing.rs b/plonky2/src/hash/hashing.rs index 0867eaa8..997a6b12 100644 --- a/plonky2/src/hash/hashing.rs +++ b/plonky2/src/hash/hashing.rs @@ -69,7 +69,7 @@ impl, const D: usize> CircuitBuilder { } // Squeeze until we have the desired number of outputs. - let mut outputs = Vec::new(); + let mut outputs = Vec::with_capacity(num_outputs); loop { for i in 0..SPONGE_RATE { outputs.push(state[i]); diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 8f191366..e10b5019 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -1,3 +1,7 @@ +use std::mem::MaybeUninit; +use std::slice; + +use plonky2_util::log2_strict; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -27,36 +31,103 @@ pub struct MerkleTree> { /// The data in the leaves of the Merkle tree. pub leaves: Vec>, - /// The layers of hashes in the tree. The first layer is the one at the bottom. - pub layers: Vec>, + /// The digests in the tree. Consists of `cap.len()` sub-trees, each corresponding to one + /// element in `cap`. Each subtree is contiguous and located at + /// `digests[digests.len() / cap.len() * i..digests.len() / cap.len() * (i + 1)]`. + /// Within each subtree, siblings are stored next to each other. The layout is, + /// left_child_subtree || left_child_digest || right_child_digest || right_child_subtree, where + /// left_child_digest and right_child_digest are H::Hash and left_child_subtree and + /// right_child_subtree recurse. Observe that the digest of a node is stored by its _parent_. + /// Consequently, the digests of the roots are not stored here (they can be found in `cap`). + pub digests: Vec, /// The Merkle cap. pub cap: MerkleCap, } +fn capacity_up_to_mut(v: &mut Vec, len: usize) -> &mut [MaybeUninit] { + assert!(v.capacity() >= len); + let v_ptr = v.as_mut_ptr().cast::>(); + unsafe { + // SAFETY: `v_ptr` is a valid pointer to a buffer of length at least `len`. Upon return, the + // lifetime will be bound to that of `v`. The underlying memory will not be deallocated as + // we hold the sole mutable reference to `v`. The contents of the slice may be + // uninitialized, but the `MaybeUninit` makes it safe. + slice::from_raw_parts_mut(v_ptr, len) + } +} + +fn fill_subtree>( + digests_buf: &mut [MaybeUninit], + leaves: &[Vec], +) -> H::Hash { + assert_eq!(leaves.len(), digests_buf.len() / 2 + 1); + if digests_buf.is_empty() { + H::hash(&leaves[0], false) + } else { + // Layout is: left recursive output || left child digest + // || right child digest || right recursive output. + // Split `digests_buf` into the two recursive outputs (slices) and two child digests + // (references). + let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2); + let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap(); + let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap(); + // Split `leaves` between both children. + let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2); + let (left_digest, right_digest) = rayon::join( + || fill_subtree::(left_digests_buf, left_leaves), + || fill_subtree::(right_digests_buf, right_leaves), + ); + left_digest_mem.write(left_digest); + right_digest_mem.write(right_digest); + H::two_to_one(left_digest, right_digest) + } +} + +fn fill_digests_buf>( + digests_buf: &mut [MaybeUninit], + cap_buf: &mut [MaybeUninit], + leaves: &[Vec], + cap_height: usize, +) { + let subtree_digests_len = digests_buf.len() >> cap_height; + let subtree_leaves_len = leaves.len() >> cap_height; + let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len); + let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len); + assert_eq!(digests_chunks.len(), cap_buf.len()); + assert_eq!(digests_chunks.len(), leaves_chunks.len()); + digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each( + |((subtree_digests, subtree_cap), subtree_leaves)| { + // We have `1 << cap_height` sub-trees, one for each entry in `cap`. They are totally + // independent, so we schedule one task for each. `digests_buf` and `leaves` are split + // into `1 << cap_height` slices, one for each sub-tree. + subtree_cap.write(fill_subtree::(subtree_digests, subtree_leaves)); + }, + ); +} + impl> MerkleTree { pub fn new(leaves: Vec>, cap_height: usize) -> Self { - let mut current_layer = leaves - .par_iter() - .map(|l| H::hash(l, false)) - .collect::>(); + let num_digests = 2 * (leaves.len() - (1 << cap_height)); + let mut digests = Vec::with_capacity(num_digests); - let mut layers = vec![]; - let cap = loop { - if current_layer.len() == 1 << cap_height { - break current_layer; - } - let next_layer = current_layer - .par_chunks(2) - .map(|chunk| H::two_to_one(chunk[0], chunk[1])) - .collect::>(); - layers.push(current_layer); - current_layer = next_layer; - }; + let len_cap = 1 << cap_height; + let mut cap = Vec::with_capacity(len_cap); + + let digests_buf = capacity_up_to_mut(&mut digests, num_digests); + let cap_buf = capacity_up_to_mut(&mut cap, len_cap); + fill_digests_buf::(digests_buf, cap_buf, &leaves[..], cap_height); + + unsafe { + // SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to + // `num_digests` and `len_cap`, resp. + digests.set_len(num_digests); + cap.set_len(len_cap); + } Self { leaves, - layers, + digests, cap: MerkleCap(cap), } } @@ -67,17 +138,40 @@ impl> MerkleTree { /// Create a Merkle proof from a leaf index. pub fn prove(&self, leaf_index: usize) -> MerkleProof { - MerkleProof { - siblings: self - .layers - .iter() - .scan(leaf_index, |acc, layer| { - let index = *acc ^ 1; - *acc >>= 1; - Some(layer[index]) - }) - .collect(), - } + let cap_height = log2_strict(self.cap.len()); + let num_layers = log2_strict(self.leaves.len()) - cap_height; + debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0); + + let digest_tree = { + let tree_index = leaf_index >> num_layers; + let tree_len = self.digests.len() >> cap_height; + &self.digests[tree_len * tree_index..tree_len * (tree_index + 1)] + }; + + // Mask out high bits to get the index within the sub-tree. + let mut pair_index = leaf_index & ((1 << num_layers) - 1); + let siblings = (0..num_layers) + .into_iter() + .map(|i| { + let parity = pair_index & 1; + pair_index >>= 1; + + // The layers' data is interleaved as follows: + // [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...]. + // Each of the above is a pair of siblings. + // `pair_index` is the index of the pair within layer `i`. + // The index of that the pair within `digests` is + // `pair_index * 2 ** (i + 1) + (2 ** i - 1)`. + let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1; + // We have an index for the _pair_, but we want the index of the _sibling_. + // Double the pair index to get the index of the left sibling. Conditionally add `1` + // if we are to retrieve the right sibling. + let sibling_index = 2 * siblings_index + (1 - parity); + digest_tree[sibling_index] + }) + .collect(); + + MerkleProof { siblings } } }