From e326630e7bcaff0494410df28e9d80736cd28975 Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Tue, 8 Oct 2024 14:21:12 +0200 Subject: [PATCH] add safe MT --- .../src/circuits/merkle_tree_circuit.rs | 8 +- .../src/merkle_tree/capped_tree.rs | 378 +++++++++++++ .../src/merkle_tree/merkle_safe.rs | 514 ++++++++++++++++++ codex-plonky2-circuits/src/merkle_tree/mod.rs | 380 +------------ 4 files changed, 898 insertions(+), 382 deletions(-) create mode 100644 codex-plonky2-circuits/src/merkle_tree/capped_tree.rs create mode 100644 codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs diff --git a/codex-plonky2-circuits/src/circuits/merkle_tree_circuit.rs b/codex-plonky2-circuits/src/circuits/merkle_tree_circuit.rs index f1a0da9..4f3bce2 100644 --- a/codex-plonky2-circuits/src/circuits/merkle_tree_circuit.rs +++ b/codex-plonky2-circuits/src/circuits/merkle_tree_circuit.rs @@ -11,11 +11,11 @@ use plonky2::plonk::proof::ProofWithPublicInputs; use std::marker::PhantomData; use itertools::Itertools; -use crate::merkle_tree::MerkleTree; +use crate::merkle_tree::capped_tree::MerkleTree; use plonky2::hash::poseidon::PoseidonHash; use plonky2::hash::hash_types::{HashOutTarget, MerkleCapTarget, NUM_HASH_OUT_ELTS}; -use crate::merkle_tree::{MerkleProof, MerkleProofTarget}; +use crate::merkle_tree::capped_tree::{MerkleProof, MerkleProofTarget}; use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; use plonky2::field::goldilocks_field::GoldilocksField; @@ -24,7 +24,7 @@ use plonky2::plonk::proof::Proof; use plonky2::hash::hashing::PlonkyPermutation; use plonky2::plonk::circuit_data::VerifierCircuitTarget; -use crate::merkle_tree::MerkleCap; +use crate::merkle_tree::capped_tree::MerkleCap; // size of leaf data (in number of field elements) pub const LEAF_LEN: usize = 4; @@ -359,7 +359,7 @@ pub mod tests { use super::*; use plonky2::field::types::Field; - use crate::merkle_tree::MerkleTree; + use crate::merkle_tree::capped_tree::MerkleTree; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; diff --git a/codex-plonky2-circuits/src/merkle_tree/capped_tree.rs b/codex-plonky2-circuits/src/merkle_tree/capped_tree.rs new file mode 100644 index 0000000..7985c78 --- /dev/null +++ b/codex-plonky2-circuits/src/merkle_tree/capped_tree.rs @@ -0,0 +1,378 @@ +// An adapted implementation of Merkle tree +// based on the original plonky2 merkle tree implementation + +use core::mem::MaybeUninit; +use core::slice; +use anyhow::{ensure, Result}; +use plonky2_maybe_rayon::*; +use serde::{Deserialize, Serialize}; + +use plonky2::hash::hash_types::{HashOutTarget, RichField}; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2::util::log2_strict; + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[serde(bound = "")] +pub struct MerkleCap>(pub Vec); + +impl> Default for MerkleCap { + fn default() -> Self { + Self(Vec::new()) + } +} + +impl> MerkleCap { + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn height(&self) -> usize { + log2_strict(self.len()) + } + + pub fn flatten(&self) -> Vec { + self.0.iter().flat_map(|&h| h.to_vec()).collect() + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MerkleTree> { + pub leaves: Vec>, + + pub digests: Vec, + + pub cap: MerkleCap, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[serde(bound = "")] +pub struct MerkleProof> { + /// The Merkle digest of each sibling subtree, staying from the bottommost layer. + pub siblings: Vec, +} + +impl> MerkleProof { + pub fn len(&self) -> usize { + self.siblings.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MerkleProofTarget { + /// The Merkle digest of each sibling subtree, staying from the bottommost layer. + pub siblings: Vec, +} + +impl> Default for MerkleTree { + fn default() -> Self { + Self { + leaves: Vec::new(), + digests: Vec::new(), + cap: MerkleCap::default(), + } + } +} + +pub(crate) fn capacity_up_to_mut(v: &mut Vec, len: usize) -> &mut [MaybeUninit] { + assert!(v.capacity() >= len); + let v_ptr = v.as_mut_ptr().cast::>(); + unsafe { + slice::from_raw_parts_mut(v_ptr, len) + } +} + +pub(crate) fn fill_subtree>( + digests_buf: &mut [MaybeUninit], + leaves: &[Vec], +) -> H::Hash { + assert_eq!(leaves.len(), digests_buf.len() / 2 + 1); + if digests_buf.is_empty() { + H::hash_or_noop(&leaves[0]) + } else { + let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2); + let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap(); + let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap(); + + let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2); + + let (left_digest, right_digest) = plonky2_maybe_rayon::join( + || fill_subtree::(left_digests_buf, left_leaves), + || fill_subtree::(right_digests_buf, right_leaves), + ); + + left_digest_mem.write(left_digest); + right_digest_mem.write(right_digest); + H::two_to_one(left_digest, right_digest) + } +} + +pub(crate) fn fill_digests_buf>( + digests_buf: &mut [MaybeUninit], + cap_buf: &mut [MaybeUninit], + leaves: &[Vec], + cap_height: usize, +) { + + if digests_buf.is_empty() { + debug_assert_eq!(cap_buf.len(), leaves.len()); + cap_buf + .par_iter_mut() + .zip(leaves) + .for_each(|(cap_buf, leaf)| { + cap_buf.write(H::hash_or_noop(leaf)); + }); + return; + } + + let subtree_digests_len = digests_buf.len() >> cap_height; + let subtree_leaves_len = leaves.len() >> cap_height; + let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len); + let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len); + assert_eq!(digests_chunks.len(), cap_buf.len()); + assert_eq!(digests_chunks.len(), leaves_chunks.len()); + digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each( + |((subtree_digests, subtree_cap), subtree_leaves)| { + + subtree_cap.write(fill_subtree::(subtree_digests, subtree_leaves)); + }, + ); +} + +pub(crate) fn merkle_tree_prove>( + leaf_index: usize, + leaves_len: usize, + cap_height: usize, + digests: &[H::Hash], +) -> Vec { + let num_layers = log2_strict(leaves_len) - cap_height; + debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0); + + let digest_len = 2 * (leaves_len - (1 << cap_height)); + assert_eq!(digest_len, digests.len()); + + let digest_tree: &[H::Hash] = { + let tree_index = leaf_index >> num_layers; + let tree_len = digest_len >> cap_height; + &digests[tree_len * tree_index..tree_len * (tree_index + 1)] + }; + + // Mask out high bits to get the index within the sub-tree. + let mut pair_index = leaf_index & ((1 << num_layers) - 1); + (0..num_layers) + .map(|i| { + let parity = pair_index & 1; + pair_index >>= 1; + + // The layers' data is interleaved as follows: + // [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...]. + // Each of the above is a pair of siblings. + // `pair_index` is the index of the pair within layer `i`. + // The index of that the pair within `digests` is + // `pair_index * 2 ** (i + 1) + (2 ** i - 1)`. + let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1; + // We have an index for the _pair_, but we want the index of the _sibling_. + // Double the pair index to get the index of the left sibling. Conditionally add `1` + // if we are to retrieve the right sibling. + let sibling_index = 2 * siblings_index + (1 - parity); + digest_tree[sibling_index] + }) + .collect() +} + +impl> MerkleTree { + pub fn new(leaves: Vec>, cap_height: usize) -> Self { + let log2_leaves_len = log2_strict(leaves.len()); + assert!( + cap_height <= log2_leaves_len, + "cap_height={} should be at most log2(leaves.len())={}", + cap_height, + log2_leaves_len + ); + + let num_digests = 2 * (leaves.len() - (1 << cap_height)); + let mut digests = Vec::with_capacity(num_digests); + + let len_cap = 1 << cap_height; + let mut cap = Vec::with_capacity(len_cap); + + let digests_buf = capacity_up_to_mut(&mut digests, num_digests); + let cap_buf = capacity_up_to_mut(&mut cap, len_cap); + fill_digests_buf::(digests_buf, cap_buf, &leaves[..], cap_height); + + unsafe { + // SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to + // `num_digests` and `len_cap`, resp. + digests.set_len(num_digests); + cap.set_len(len_cap); + } + + Self { + leaves, + digests, + cap: MerkleCap(cap), + } + } + + pub fn get(&self, i: usize) -> &[F] { + &self.leaves[i] + } + + // Create a Merkle proof from a leaf index. + pub fn prove(&self, leaf_index: usize) -> MerkleProof { + let cap_height = log2_strict(self.cap.len()); + let siblings = + merkle_tree_prove::(leaf_index, self.leaves.len(), cap_height, &self.digests); + + MerkleProof { siblings } + } +} + +/// Verifies that the given leaf data is present at the given index in the Merkle tree with the +/// given root. +pub fn verify_merkle_proof>( + leaf_data: Vec, + leaf_index: usize, + merkle_root: H::Hash, + proof: &MerkleProof, +) -> Result<()> { + let merkle_cap = MerkleCap(vec![merkle_root]); + verify_merkle_proof_to_cap(leaf_data, leaf_index, &merkle_cap, proof) +} + +/// Verifies that the given leaf data is present at the given index in the Merkle tree with the +/// given cap. +pub fn verify_merkle_proof_to_cap>( + leaf_data: Vec, + leaf_index: usize, + merkle_cap: &MerkleCap, + proof: &MerkleProof, +) -> Result<()> { + verify_batch_merkle_proof_to_cap( + &[leaf_data.clone()], + &[proof.siblings.len()], + leaf_index, + merkle_cap, + proof, + ) +} + +/// Verifies that the given leaf data is present at the given index in the Field Merkle tree with the +/// given cap. +pub fn verify_batch_merkle_proof_to_cap>( + leaf_data: &[Vec], + leaf_heights: &[usize], + mut leaf_index: usize, + merkle_cap: &MerkleCap, + proof: &MerkleProof, +) -> Result<()> { + assert_eq!(leaf_data.len(), leaf_heights.len()); + let mut current_digest = H::hash_or_noop(&leaf_data[0]); + let mut current_height = leaf_heights[0]; + let mut leaf_data_index = 1; + for &sibling_digest in &proof.siblings { + let bit = leaf_index & 1; + leaf_index >>= 1; + current_digest = if bit == 1 { + H::two_to_one(sibling_digest, current_digest) + } else { + H::two_to_one(current_digest, sibling_digest) + }; + current_height -= 1; + + if leaf_data_index < leaf_heights.len() && current_height == leaf_heights[leaf_data_index] { + let mut new_leaves = current_digest.to_vec(); + new_leaves.extend_from_slice(&leaf_data[leaf_data_index]); + current_digest = H::hash_or_noop(&new_leaves); + leaf_data_index += 1; + } + } + assert_eq!(leaf_data_index, leaf_data.len()); + ensure!( + current_digest == merkle_cap.0[leaf_index], + "Invalid Merkle proof." + ); + + Ok(()) +} + +#[cfg(test)] +pub(crate) mod tests { + use anyhow::Result; + + use super::*; + use plonky2::field::extension::Extendable; + use crate::merkle_tree::capped_tree::verify_merkle_proof_to_cap; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + pub(crate) fn random_data(n: usize, k: usize) -> Vec> { + (0..n).map(|_| F::rand_vec(k)).collect() + } + + fn verify_all_leaves< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, + >( + leaves: Vec>, + cap_height: usize, + ) -> Result<()> { + let tree = MerkleTree::::new(leaves.clone(), cap_height); + for (i, leaf) in leaves.into_iter().enumerate() { + let proof = tree.prove(i); + verify_merkle_proof_to_cap(leaf, i, &tree.cap, &proof)?; + } + Ok(()) + } + + #[test] + #[should_panic] + fn test_cap_height_too_big() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let log_n = 8; + let cap_height = log_n + 1; // Should panic if `cap_height > len_n`. + + let leaves = random_data::(1 << log_n, 7); + let _ = MerkleTree::>::Hasher>::new(leaves, cap_height); + } + + #[test] + fn test_cap_height_eq_log2_len() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let log_n = 8; + let n = 1 << log_n; + let leaves = random_data::(n, 7); + + verify_all_leaves::(leaves, log_n)?; + + Ok(()) + } + + #[test] + fn test_merkle_trees() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let log_n = 8; + let n = 1 << log_n; + let leaves = random_data::(n, 7); + + verify_all_leaves::(leaves, 1)?; + + Ok(()) + } +} diff --git a/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs b/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs new file mode 100644 index 0000000..3aa995e --- /dev/null +++ b/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs @@ -0,0 +1,514 @@ +// Implementation of "safe" merkle tree +// consistent with the one in codex: +// https://github.com/codex-storage/nim-codex/blob/master/codex/merkletree/merkletree.nim + +use anyhow::{ensure, Result}; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField}; +use plonky2::hash::poseidon::PoseidonHash; +use plonky2::plonk::config::Hasher; +use std::ops::Shr; +use plonky2_field::types::Field; + +// Constants for the keys used in compression +const KEY_NONE: u64 = 0x0; +const KEY_BOTTOM_LAYER: u64 = 0x1; +const KEY_ODD: u64 = 0x2; +const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3; + +/// Trait for a hash function that supports keyed compression. +pub trait KeyedHasher: Hasher { + fn compress(x: Self::Hash, y: Self::Hash, key: u64) -> Self::Hash; +} + +impl KeyedHasher for PoseidonHash { + fn compress(x: Self::Hash, y: Self::Hash, key: u64) -> Self::Hash { + let key_field = GoldilocksField::from_canonical_u64(key); + let mut inputs = Vec::new(); + inputs.extend_from_slice(&x.elements); + inputs.extend_from_slice(&y.elements); + inputs.push(key_field); + PoseidonHash::hash_no_pad(&inputs) // TODO: double-check this function + } +} + +/// Merkle tree struct, containing the layers, compression function, and zero hash. +#[derive(Clone)] +pub struct MerkleTree> { + layers: Vec>, + compress: fn(H::Hash, H::Hash, u64) -> H::Hash, + zero: H::Hash, +} + +impl> MerkleTree { + /// Constructs a new Merkle tree from the given leaves. + pub fn new( + leaves: &[H::Hash], + zero: H::Hash, + compress: fn(H::Hash, H::Hash, u64) -> H::Hash, + ) -> Result { + let layers = merkle_tree_worker::(leaves, zero, compress, true)?; + Ok(Self { + layers, + compress, + zero, + }) + } + + /// Returns the depth of the Merkle tree. + pub fn depth(&self) -> usize { + self.layers.len() - 1 + } + + /// Returns the number of leaves in the Merkle tree. + pub fn leaves_count(&self) -> usize { + self.layers[0].len() + } + + /// Returns the root hash of the Merkle tree. + pub fn root(&self) -> Result { + let last_layer = self.layers.last().ok_or_else(|| anyhow::anyhow!("Empty tree"))?; + ensure!(last_layer.len() == 1, "Invalid Merkle tree"); + Ok(last_layer[0]) + } + + /// Generates a Merkle proof for a given leaf index. + pub fn get_proof(&self, index: usize) -> Result> { + let depth = self.depth(); + let nleaves = self.leaves_count(); + + ensure!(index < nleaves, "Index out of bounds"); + + let mut path = Vec::with_capacity(depth); + let mut k = index; + let mut m = nleaves; + + for i in 0..depth { + let j = k ^ 1; + let sibling = if j < m { + self.layers[i][j] + } else { + self.zero + }; + path.push(sibling); + k = k >> 1; + m = (m + 1) >> 1; + } + + Ok(MerkleProof { + index, + path, + nleaves, + compress: self.compress, + zero: self.zero, + }) + } +} + +/// Build the Merkle tree layers. +fn merkle_tree_worker>( + xs: &[H::Hash], + zero: H::Hash, + compress: fn(H::Hash, H::Hash, u64) -> H::Hash, + is_bottom_layer: bool, +) -> Result>> { + let m = xs.len(); + if !is_bottom_layer && m == 1 { + return Ok(vec![xs.to_vec()]); + } + + let halfn = m / 2; + let n = 2 * halfn; + let is_odd = n != m; + + let mut ys = Vec::with_capacity(halfn + if is_odd { 1 } else { 0 }); + + for i in 0..halfn { + let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE }; + let h = compress(xs[2 * i], xs[2 * i + 1], key); + ys.push(h); + } + + if is_odd { + let key = if is_bottom_layer { + KEY_ODD_AND_BOTTOM_LAYER + } else { + KEY_ODD + }; + let h = compress(xs[n], zero, key); + ys.push(h); + } + + let mut layers = vec![xs.to_vec()]; + let mut upper_layers = merkle_tree_worker::(&ys, zero, compress, false)?; + layers.append(&mut upper_layers); + + Ok(layers) +} + +/// Merkle proof struct, containing the index, path, and other necessary data. +#[derive(Clone)] +pub struct MerkleProof> { + pub index: usize, // Index of the leaf + pub path: Vec, // Sibling hashes from the leaf to the root + pub nleaves: usize, // Total number of leaves + pub compress: fn(H::Hash, H::Hash, u64) -> H::Hash, // compression function - TODO: make it generic instead + pub zero: H::Hash, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MerkleProofTarget { + /// The Merkle digest of each sibling subtree, staying from the bottommost layer. + pub path: Vec, +} + +impl> MerkleProof { + /// Reconstructs the root hash from the proof and the given leaf. + pub fn reconstruct_root(&self, leaf: H::Hash) -> Result { + let mut m = self.nleaves; + let mut j = self.index; + let mut h = leaf; + let mut bottom_flag = KEY_BOTTOM_LAYER; + + for p in &self.path { + let odd_index = (j & 1) != 0; + if odd_index { + // The index of the child is odd + h = (self.compress)(*p, h, bottom_flag); + } else { + if j == m - 1 { + // Single child -> so odd node + h = (self.compress)(h, *p, bottom_flag + 2); + } else { + // Even node + h = (self.compress)(h, *p, bottom_flag); + } + } + bottom_flag = KEY_NONE; + j = j.shr(1); + m = (m + 1).shr(1); + } + + Ok(h) + } + + /// Verifies the proof against a given root and leaf. + pub fn verify(&self, leaf: H::Hash, root: H::Hash) -> Result { + let reconstructed_root = self.reconstruct_root(leaf)?; + Ok(reconstructed_root == root) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use plonky2::field::types::Field; + + // Constants for the keys used in compression + // const KEY_NONE: u64 = 0x0; + // const KEY_BOTTOM_LAYER: u64 = 0x1; + // const KEY_ODD: u64 = 0x2; + // const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3; + + fn compress( + x: HashOut, + y: HashOut, + key: u64, + ) -> HashOut { + let key_field = GoldilocksField::from_canonical_u64(key); + let mut inputs = Vec::new(); + inputs.extend_from_slice(&x.elements); + inputs.extend_from_slice(&y.elements); + inputs.push(key_field); + PoseidonHash::hash_no_pad(&inputs) + } + + fn make_tree( + data: &[GoldilocksField], + zero: HashOut, + ) -> Result> { + let compress_fn = PoseidonHash::compress; + + // Hash the data to obtain leaf hashes + let leaves: Vec> = data + .iter() + .map(|&element| { + // Hash each field element to get the leaf hash + PoseidonHash::hash_no_pad(&[element]) + }) + .collect(); + + MerkleTree::::new(&leaves, zero, compress_fn) + } + + #[test] + fn single_proof_test() -> Result<()> { + let data = (1u64..=8) + .map(|i| GoldilocksField::from_canonical_u64(i)) + .collect::>(); + + // Hash the data to obtain leaf hashes + let leaves: Vec> = data + .iter() + .map(|&element| { + // Hash each field element to get the leaf hash + PoseidonHash::hash_no_pad(&[element]) + }) + .collect(); + + let zero = HashOut { + elements: [GoldilocksField::ZERO; 4], + }; + + let compress_fn = PoseidonHash::compress; + + // Build the Merkle tree + let tree = MerkleTree::::new(&leaves, zero, compress_fn)?; + + // Get the root + let root = tree.root()?; + + // Get a proof for the first leaf + let proof = tree.get_proof(0)?; + + // Verify the proof + let is_valid = proof.verify(leaves[0], root)?; + assert!(is_valid, "Merkle proof verification failed"); + + Ok(()) + } + + #[test] + fn test_correctness_even_bottom_layer() -> Result<()> { + // Data for the test (field elements) + let data = (1u64..=8) + .map(|i| GoldilocksField::from_canonical_u64(i)) + .collect::>(); + + // Hash the data to get leaf hashes + let leaf_hashes: Vec> = data + .iter() + .map(|&element| PoseidonHash::hash_no_pad(&[element])) + .collect(); + + // zero hash + let zero = HashOut { + elements: [GoldilocksField::ZERO; 4], + }; + + let expected_root = + compress( + compress( + compress( + leaf_hashes[0], + leaf_hashes[1], + KEY_BOTTOM_LAYER, + ), + compress( + leaf_hashes[2], + leaf_hashes[3], + KEY_BOTTOM_LAYER, + ), + KEY_NONE, + ), + compress( + compress( + leaf_hashes[4], + leaf_hashes[5], + KEY_BOTTOM_LAYER, + ), + compress( + leaf_hashes[6], + leaf_hashes[7], + KEY_BOTTOM_LAYER, + ), + KEY_NONE, + ), + KEY_NONE, + ); + + // Build the tree + let tree = make_tree(&data, zero)?; + + // Get the computed root + let computed_root = tree.root()?; + + // Check that the computed root matches the expected root + assert_eq!(computed_root, expected_root); + + Ok(()) + } + + #[test] + fn test_correctness_odd_bottom_layer() -> Result<()> { + // Data for the test (field elements) + let data = (1u64..=7) + .map(|i| GoldilocksField::from_canonical_u64(i)) + .collect::>(); + + // Hash the data to get leaf hashes + let leaf_hashes: Vec> = data + .iter() + .map(|&element| PoseidonHash::hash_no_pad(&[element])) + .collect(); + + // zero hash + let zero = HashOut { + elements: [GoldilocksField::ZERO; 4], + }; + + let expected_root = + compress( + compress( + compress( + leaf_hashes[0], + leaf_hashes[1], + KEY_BOTTOM_LAYER, + ), + compress( + leaf_hashes[2], + leaf_hashes[3], + KEY_BOTTOM_LAYER, + ), + KEY_NONE, + ), + compress( + compress( + leaf_hashes[4], + leaf_hashes[5], + KEY_BOTTOM_LAYER, + ), + compress( + leaf_hashes[6], + zero, + KEY_ODD_AND_BOTTOM_LAYER, + ), + KEY_NONE, + ), + KEY_NONE, + ); + + // Build the tree + let tree = make_tree(&data, zero)?; + + // Get the computed root + let computed_root = tree.root()?; + + // Check that the computed root matches the expected root + assert_eq!(computed_root, expected_root); + + Ok(()) + } + + #[test] + fn test_correctness_even_bottom_odd_upper_layers() -> Result<()> { + // Data for the test (field elements) + let data = (1u64..=10) + .map(|i| GoldilocksField::from_canonical_u64(i)) + .collect::>(); + + // Hash the data to get leaf hashes + let leaf_hashes: Vec> = data + .iter() + .map(|&element| PoseidonHash::hash_no_pad(&[element])) + .collect(); + + // zero hash + let zero = HashOut { + elements: [GoldilocksField::ZERO; 4], + }; + + let expected_root = compress( + compress( + compress( + compress( + leaf_hashes[0], + leaf_hashes[1], + KEY_BOTTOM_LAYER, + ), + compress( + leaf_hashes[2], + leaf_hashes[3], + KEY_BOTTOM_LAYER, + ), + KEY_NONE, + ), + compress( + compress( + leaf_hashes[4], + leaf_hashes[5], + KEY_BOTTOM_LAYER, + ), + compress( + leaf_hashes[6], + leaf_hashes[7], + KEY_BOTTOM_LAYER, + ), + KEY_NONE, + ), + KEY_NONE, + ), + compress( + compress( + compress( + leaf_hashes[8], + leaf_hashes[9], + KEY_BOTTOM_LAYER, + ), + zero, + KEY_ODD, + ), + zero, + KEY_ODD, + ), + KEY_NONE, + ); + + // Build the tree + let tree = make_tree(&data, zero)?; + + // Get the computed root + let computed_root = tree.root()?; + + // Check that the computed root matches the expected root + assert_eq!(computed_root, expected_root); + + Ok(()) + } + + #[test] + fn test_proofs() -> Result<()> { + // Data for the test (field elements) + let data = (1u64..=10) + .map(|i| GoldilocksField::from_canonical_u64(i)) + .collect::>(); + + // Hash the data to get leaf hashes + let leaf_hashes: Vec> = data + .iter() + .map(|&element| PoseidonHash::hash_no_pad(&[element])) + .collect(); + + // zero hash + let zero = HashOut { + elements: [GoldilocksField::ZERO; 4], + }; + + let compress_fn = PoseidonHash::compress; + + // Build the tree + let tree = MerkleTree::::new(&leaf_hashes, zero, compress_fn)?; + + // Get the root + let expected_root = tree.root()?; + + // Verify proofs for all leaves + for (i, &leaf_hash) in leaf_hashes.iter().enumerate() { + let proof = tree.get_proof(i)?; + let is_valid = proof.verify(leaf_hash, expected_root)?; + assert!(is_valid, "Proof verification failed for leaf {}", i); + } + + Ok(()) + } +} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/merkle_tree/mod.rs b/codex-plonky2-circuits/src/merkle_tree/mod.rs index 2c8be2b..e977a83 100644 --- a/codex-plonky2-circuits/src/merkle_tree/mod.rs +++ b/codex-plonky2-circuits/src/merkle_tree/mod.rs @@ -1,378 +1,2 @@ -// An adapted implementation of Merkle tree -// based on the original plonky2 merkle tree implementation - -use core::mem::MaybeUninit; -use core::slice; -use anyhow::{ensure, Result}; -use plonky2_maybe_rayon::*; -use serde::{Deserialize, Serialize}; - -use plonky2::hash::hash_types::{HashOutTarget, RichField}; -use plonky2::plonk::config::{GenericHashOut, Hasher}; -use plonky2::util::log2_strict; - -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -#[serde(bound = "")] -pub struct MerkleCap>(pub Vec); - -impl> Default for MerkleCap { - fn default() -> Self { - Self(Vec::new()) - } -} - -impl> MerkleCap { - pub fn len(&self) -> usize { - self.0.len() - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn height(&self) -> usize { - log2_strict(self.len()) - } - - pub fn flatten(&self) -> Vec { - self.0.iter().flat_map(|&h| h.to_vec()).collect() - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct MerkleTree> { - pub leaves: Vec>, - - pub digests: Vec, - - pub cap: MerkleCap, -} - -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -#[serde(bound = "")] -pub struct MerkleProof> { - /// The Merkle digest of each sibling subtree, staying from the bottommost layer. - pub siblings: Vec, -} - -impl> MerkleProof { - pub fn len(&self) -> usize { - self.siblings.len() - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct MerkleProofTarget { - /// The Merkle digest of each sibling subtree, staying from the bottommost layer. - pub siblings: Vec, -} - -impl> Default for MerkleTree { - fn default() -> Self { - Self { - leaves: Vec::new(), - digests: Vec::new(), - cap: MerkleCap::default(), - } - } -} - -pub(crate) fn capacity_up_to_mut(v: &mut Vec, len: usize) -> &mut [MaybeUninit] { - assert!(v.capacity() >= len); - let v_ptr = v.as_mut_ptr().cast::>(); - unsafe { - slice::from_raw_parts_mut(v_ptr, len) - } -} - -pub(crate) fn fill_subtree>( - digests_buf: &mut [MaybeUninit], - leaves: &[Vec], -) -> H::Hash { - assert_eq!(leaves.len(), digests_buf.len() / 2 + 1); - if digests_buf.is_empty() { - H::hash_or_noop(&leaves[0]) - } else { - let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2); - let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap(); - let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap(); - - let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2); - - let (left_digest, right_digest) = plonky2_maybe_rayon::join( - || fill_subtree::(left_digests_buf, left_leaves), - || fill_subtree::(right_digests_buf, right_leaves), - ); - - left_digest_mem.write(left_digest); - right_digest_mem.write(right_digest); - H::two_to_one(left_digest, right_digest) - } -} - -pub(crate) fn fill_digests_buf>( - digests_buf: &mut [MaybeUninit], - cap_buf: &mut [MaybeUninit], - leaves: &[Vec], - cap_height: usize, -) { - - if digests_buf.is_empty() { - debug_assert_eq!(cap_buf.len(), leaves.len()); - cap_buf - .par_iter_mut() - .zip(leaves) - .for_each(|(cap_buf, leaf)| { - cap_buf.write(H::hash_or_noop(leaf)); - }); - return; - } - - let subtree_digests_len = digests_buf.len() >> cap_height; - let subtree_leaves_len = leaves.len() >> cap_height; - let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len); - let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len); - assert_eq!(digests_chunks.len(), cap_buf.len()); - assert_eq!(digests_chunks.len(), leaves_chunks.len()); - digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each( - |((subtree_digests, subtree_cap), subtree_leaves)| { - - subtree_cap.write(fill_subtree::(subtree_digests, subtree_leaves)); - }, - ); -} - -pub(crate) fn merkle_tree_prove>( - leaf_index: usize, - leaves_len: usize, - cap_height: usize, - digests: &[H::Hash], -) -> Vec { - let num_layers = log2_strict(leaves_len) - cap_height; - debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0); - - let digest_len = 2 * (leaves_len - (1 << cap_height)); - assert_eq!(digest_len, digests.len()); - - let digest_tree: &[H::Hash] = { - let tree_index = leaf_index >> num_layers; - let tree_len = digest_len >> cap_height; - &digests[tree_len * tree_index..tree_len * (tree_index + 1)] - }; - - // Mask out high bits to get the index within the sub-tree. - let mut pair_index = leaf_index & ((1 << num_layers) - 1); - (0..num_layers) - .map(|i| { - let parity = pair_index & 1; - pair_index >>= 1; - - // The layers' data is interleaved as follows: - // [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...]. - // Each of the above is a pair of siblings. - // `pair_index` is the index of the pair within layer `i`. - // The index of that the pair within `digests` is - // `pair_index * 2 ** (i + 1) + (2 ** i - 1)`. - let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1; - // We have an index for the _pair_, but we want the index of the _sibling_. - // Double the pair index to get the index of the left sibling. Conditionally add `1` - // if we are to retrieve the right sibling. - let sibling_index = 2 * siblings_index + (1 - parity); - digest_tree[sibling_index] - }) - .collect() -} - -impl> MerkleTree { - pub fn new(leaves: Vec>, cap_height: usize) -> Self { - let log2_leaves_len = log2_strict(leaves.len()); - assert!( - cap_height <= log2_leaves_len, - "cap_height={} should be at most log2(leaves.len())={}", - cap_height, - log2_leaves_len - ); - - let num_digests = 2 * (leaves.len() - (1 << cap_height)); - let mut digests = Vec::with_capacity(num_digests); - - let len_cap = 1 << cap_height; - let mut cap = Vec::with_capacity(len_cap); - - let digests_buf = capacity_up_to_mut(&mut digests, num_digests); - let cap_buf = capacity_up_to_mut(&mut cap, len_cap); - fill_digests_buf::(digests_buf, cap_buf, &leaves[..], cap_height); - - unsafe { - // SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to - // `num_digests` and `len_cap`, resp. - digests.set_len(num_digests); - cap.set_len(len_cap); - } - - Self { - leaves, - digests, - cap: MerkleCap(cap), - } - } - - pub fn get(&self, i: usize) -> &[F] { - &self.leaves[i] - } - - // Create a Merkle proof from a leaf index. - pub fn prove(&self, leaf_index: usize) -> MerkleProof { - let cap_height = log2_strict(self.cap.len()); - let siblings = - merkle_tree_prove::(leaf_index, self.leaves.len(), cap_height, &self.digests); - - MerkleProof { siblings } - } -} - -/// Verifies that the given leaf data is present at the given index in the Merkle tree with the -/// given root. -pub fn verify_merkle_proof>( - leaf_data: Vec, - leaf_index: usize, - merkle_root: H::Hash, - proof: &MerkleProof, -) -> Result<()> { - let merkle_cap = MerkleCap(vec![merkle_root]); - verify_merkle_proof_to_cap(leaf_data, leaf_index, &merkle_cap, proof) -} - -/// Verifies that the given leaf data is present at the given index in the Merkle tree with the -/// given cap. -pub fn verify_merkle_proof_to_cap>( - leaf_data: Vec, - leaf_index: usize, - merkle_cap: &MerkleCap, - proof: &MerkleProof, -) -> Result<()> { - verify_batch_merkle_proof_to_cap( - &[leaf_data.clone()], - &[proof.siblings.len()], - leaf_index, - merkle_cap, - proof, - ) -} - -/// Verifies that the given leaf data is present at the given index in the Field Merkle tree with the -/// given cap. -pub fn verify_batch_merkle_proof_to_cap>( - leaf_data: &[Vec], - leaf_heights: &[usize], - mut leaf_index: usize, - merkle_cap: &MerkleCap, - proof: &MerkleProof, -) -> Result<()> { - assert_eq!(leaf_data.len(), leaf_heights.len()); - let mut current_digest = H::hash_or_noop(&leaf_data[0]); - let mut current_height = leaf_heights[0]; - let mut leaf_data_index = 1; - for &sibling_digest in &proof.siblings { - let bit = leaf_index & 1; - leaf_index >>= 1; - current_digest = if bit == 1 { - H::two_to_one(sibling_digest, current_digest) - } else { - H::two_to_one(current_digest, sibling_digest) - }; - current_height -= 1; - - if leaf_data_index < leaf_heights.len() && current_height == leaf_heights[leaf_data_index] { - let mut new_leaves = current_digest.to_vec(); - new_leaves.extend_from_slice(&leaf_data[leaf_data_index]); - current_digest = H::hash_or_noop(&new_leaves); - leaf_data_index += 1; - } - } - assert_eq!(leaf_data_index, leaf_data.len()); - ensure!( - current_digest == merkle_cap.0[leaf_index], - "Invalid Merkle proof." - ); - - Ok(()) -} - -#[cfg(test)] -pub(crate) mod tests { - use anyhow::Result; - - use super::*; - use plonky2::field::extension::Extendable; - use crate::merkle_tree::verify_merkle_proof_to_cap; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - pub(crate) fn random_data(n: usize, k: usize) -> Vec> { - (0..n).map(|_| F::rand_vec(k)).collect() - } - - fn verify_all_leaves< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - leaves: Vec>, - cap_height: usize, - ) -> Result<()> { - let tree = MerkleTree::::new(leaves.clone(), cap_height); - for (i, leaf) in leaves.into_iter().enumerate() { - let proof = tree.prove(i); - verify_merkle_proof_to_cap(leaf, i, &tree.cap, &proof)?; - } - Ok(()) - } - - #[test] - #[should_panic] - fn test_cap_height_too_big() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let log_n = 8; - let cap_height = log_n + 1; // Should panic if `cap_height > len_n`. - - let leaves = random_data::(1 << log_n, 7); - let _ = MerkleTree::>::Hasher>::new(leaves, cap_height); - } - - #[test] - fn test_cap_height_eq_log2_len() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let log_n = 8; - let n = 1 << log_n; - let leaves = random_data::(n, 7); - - verify_all_leaves::(leaves, log_n)?; - - Ok(()) - } - - #[test] - fn test_merkle_trees() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let log_n = 8; - let n = 1 << log_n; - let leaves = random_data::(n, 7); - - verify_all_leaves::(leaves, 1)?; - - Ok(()) - } -} +pub mod capped_tree; +pub mod merkle_safe; \ No newline at end of file