// Implementation of "safe" merkle tree // consistent with the one in codex: // https://github.com/codex-storage/nim-codex/blob/master/codex/merkletree/merkletree.nim use std::marker::PhantomData; use anyhow::{ensure, Result}; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::Hasher; use std::ops::Shr; use plonky2_field::extension::Extendable; use plonky2_field::types::Field; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::circuits::keyed_compress::key_compress; use crate::circuits::params::HF; // Constants for the keys used in compression pub const KEY_NONE: u64 = 0x0; pub const KEY_BOTTOM_LAYER: u64 = 0x1; pub const KEY_ODD: u64 = 0x2; pub const KEY_ODD_AND_BOTTOM_LAYER: u64 = 0x3; /// Merkle tree struct, containing the layers, compression function, and zero hash. #[derive(Clone)] pub struct MerkleTree< F: RichField + Extendable + Poseidon2, const D: usize, > { pub layers: Vec>>, pub zero: HashOut, } impl< F: RichField + Extendable + Poseidon2, const D: usize, > MerkleTree { /// Constructs a new Merkle tree from the given leaves. pub fn new( leaves: &[HashOut], zero: HashOut, ) -> Result { let layers = merkle_tree_worker::(leaves, zero, true)?; Ok(Self { layers, zero, }) } /// Returns the depth of the Merkle tree. pub fn depth(&self) -> usize { self.layers.len() - 1 } /// Returns the number of leaves in the Merkle tree. pub fn leaves_count(&self) -> usize { self.layers[0].len() } /// Returns the root hash of the Merkle tree. pub fn root(&self) -> Result> { let last_layer = self.layers.last().ok_or_else(|| anyhow::anyhow!("Empty tree"))?; ensure!(last_layer.len() == 1, "Invalid Merkle tree"); Ok(last_layer[0]) } /// Generates a Merkle proof for a given leaf index. pub fn get_proof(&self, index: usize) -> Result> { let depth = self.depth(); let nleaves = self.leaves_count(); ensure!(index < nleaves, "Index out of bounds"); let mut path = Vec::with_capacity(depth); let mut k = index; let mut m = nleaves; for i in 0..depth { let j = k ^ 1; let sibling = if j < m { self.layers[i][j] } else { self.zero }; path.push(sibling); k = k >> 1; m = (m + 1) >> 1; } Ok(MerkleProof { index, path, nleaves, zero: self.zero, }) } } /// Build the Merkle tree layers. fn merkle_tree_worker< F: RichField + Extendable + Poseidon2, const D: usize, >( xs: &[HashOut], zero: HashOut, is_bottom_layer: bool, ) -> Result>>> { let m = xs.len(); if !is_bottom_layer && m == 1 { return Ok(vec![xs.to_vec()]); } let halfn = m / 2; let n = 2 * halfn; let is_odd = n != m; let mut ys = Vec::with_capacity(halfn + if is_odd { 1 } else { 0 }); for i in 0..halfn { let key = if is_bottom_layer { KEY_BOTTOM_LAYER } else { KEY_NONE }; let h = key_compress::(xs[2 * i], xs[2 * i + 1], key); ys.push(h); } if is_odd { let key = if is_bottom_layer { KEY_ODD_AND_BOTTOM_LAYER } else { KEY_ODD }; let h = key_compress::(xs[n], zero, key); ys.push(h); } let mut layers = vec![xs.to_vec()]; let mut upper_layers = merkle_tree_worker::(&ys, zero, false)?; layers.append(&mut upper_layers); Ok(layers) } /// Merkle proof struct, containing the index, path, and other necessary data. #[derive(Clone)] pub struct MerkleProof< F: RichField + Extendable + Poseidon2, const D: usize, > { pub index: usize, // Index of the leaf pub path: Vec>, // Sibling hashes from the leaf to the root pub nleaves: usize, // Total number of leaves pub zero: HashOut, } impl< F: RichField + Extendable + Poseidon2, const D: usize, > MerkleProof { /// Reconstructs the root hash from the proof and the given leaf. pub fn reconstruct_root(&self, leaf: HashOut) -> Result> { let mut m = self.nleaves; let mut j = self.index; let mut h = leaf; let mut bottom_flag = KEY_BOTTOM_LAYER; for p in &self.path { let odd_index = (j & 1) != 0; if odd_index { // The index of the child is odd h = key_compress::(*p, h, bottom_flag); } else { if j == m - 1 { // Single child -> so odd node h = key_compress::(h, *p, bottom_flag + 2); } else { // Even node h = key_compress::(h, *p, bottom_flag); } } bottom_flag = KEY_NONE; j = j.shr(1); m = (m + 1).shr(1); } Ok(h) } /// reconstruct the root using path_bits and last_bits in similar way as the circuit /// this is used for testing - sanity check pub fn reconstruct_root2(leaf: HashOut, path_bits: Vec, last_bits:Vec, path: Vec>, mask_bits:Vec, depth: usize) -> Result> { let is_last = compute_is_last(path_bits.clone(),last_bits); let mut h = vec![]; h.push(leaf); let mut i = 0; for p in &path { let bottom = if(i==0){ KEY_BOTTOM_LAYER }else{ KEY_NONE }; let odd = (is_last[i] as usize) * (1-(path_bits[i] as usize)); let key = bottom + (2 * (odd as u64)); let odd_index = path_bits[i]; if odd_index { h.push(key_compress::(*p, h[i], key)); } else { h.push(key_compress::(h[i], *p, key)); } i += 1; } let mut reconstructed_root = HashOut::::ZERO; for k in 0..depth{ let diff = (mask_bits[k] as u64) - (mask_bits[k+1] as u64); let mul_res: Vec = h[k+1].elements.iter().map(|e| e.mul(F::from_canonical_u64(diff))).collect(); reconstructed_root = HashOut::::from_vec( mul_res.iter().zip(reconstructed_root.elements).map(|(e1,e2)| e1.add(e2)).collect() ); } Ok(reconstructed_root) } /// Verifies the proof against a given root and leaf. pub fn verify(&self, leaf: HashOut, root: HashOut) -> Result { let reconstructed_root = self.reconstruct_root(leaf)?; Ok(reconstructed_root == root) } } ///helper function to compute is_last fn compute_is_last(path_bits: Vec, last_bits: Vec) -> Vec { let max_depth = path_bits.len(); // Initialize isLast vector let mut is_last = vec![false; max_depth + 1]; is_last[max_depth] = true; // Set isLast[max_depth] to 1 (true) // Iterate over eq and isLast in reverse order for i in (0..max_depth).rev() { let eq_out = path_bits[i] == last_bits[i]; // eq[i].out is_last[i] = is_last[i + 1] && eq_out; // isLast[i] = isLast[i+1] * eq[i].out } is_last } #[cfg(test)] mod tests { use super::*; use plonky2::field::types::Field; use crate::circuits::keyed_compress::key_compress; // types used in all tests type F = GoldilocksField; const D: usize = 2; type H = PoseidonHash; fn compress( x: HashOut, y: HashOut, key: u64, ) -> HashOut { key_compress::(x,y,key) } fn make_tree( data: &[F], zero: HashOut, ) -> Result> { // Hash the data to obtain leaf hashes let leaves: Vec> = data .iter() .map(|&element| { // Hash each field element to get the leaf hash H::hash_no_pad(&[element]) }) .collect(); MerkleTree::::new(&leaves, zero) } #[test] fn single_proof_test() -> Result<()> { let data = (1u64..=8) .map(|i| F::from_canonical_u64(i)) .collect::>(); // Hash the data to obtain leaf hashes let leaves: Vec> = data .iter() .map(|&element| { // Hash each field element to get the leaf hash H::hash_no_pad(&[element]) }) .collect(); let zero = HashOut { elements: [F::ZERO; 4], }; // Build the Merkle tree let tree = MerkleTree::::new(&leaves, zero)?; // Get the root let root = tree.root()?; // Get a proof for the first leaf let proof = tree.get_proof(0)?; // Verify the proof let is_valid = proof.verify(leaves[0], root)?; assert!(is_valid, "Merkle proof verification failed"); Ok(()) } #[test] fn test_correctness_even_bottom_layer() -> Result<()> { // Data for the test (field elements) let data = (1u64..=8) .map(|i| F::from_canonical_u64(i)) .collect::>(); // Hash the data to get leaf hashes let leaf_hashes: Vec> = data .iter() .map(|&element| H::hash_no_pad(&[element])) .collect(); // zero hash let zero = HashOut { elements: [F::ZERO; 4], }; let expected_root = compress( compress( compress( leaf_hashes[0], leaf_hashes[1], KEY_BOTTOM_LAYER, ), compress( leaf_hashes[2], leaf_hashes[3], KEY_BOTTOM_LAYER, ), KEY_NONE, ), compress( compress( leaf_hashes[4], leaf_hashes[5], KEY_BOTTOM_LAYER, ), compress( leaf_hashes[6], leaf_hashes[7], KEY_BOTTOM_LAYER, ), KEY_NONE, ), KEY_NONE, ); // Build the tree let tree = make_tree(&data, zero)?; // Get the computed root let computed_root = tree.root()?; // Check that the computed root matches the expected root assert_eq!(computed_root, expected_root); Ok(()) } #[test] fn test_correctness_odd_bottom_layer() -> Result<()> { // Data for the test (field elements) let data = (1u64..=7) .map(|i| F::from_canonical_u64(i)) .collect::>(); // Hash the data to get leaf hashes let leaf_hashes: Vec> = data .iter() .map(|&element| H::hash_no_pad(&[element])) .collect(); // zero hash let zero = HashOut { elements: [F::ZERO; 4], }; let expected_root = compress( compress( compress( leaf_hashes[0], leaf_hashes[1], KEY_BOTTOM_LAYER, ), compress( leaf_hashes[2], leaf_hashes[3], KEY_BOTTOM_LAYER, ), KEY_NONE, ), compress( compress( leaf_hashes[4], leaf_hashes[5], KEY_BOTTOM_LAYER, ), compress( leaf_hashes[6], zero, KEY_ODD_AND_BOTTOM_LAYER, ), KEY_NONE, ), KEY_NONE, ); // Build the tree let tree = make_tree(&data, zero)?; // Get the computed root let computed_root = tree.root()?; // Check that the computed root matches the expected root assert_eq!(computed_root, expected_root); Ok(()) } #[test] fn test_correctness_even_bottom_odd_upper_layers() -> Result<()> { // Data for the test (field elements) let data = (1u64..=10) .map(|i| F::from_canonical_u64(i)) .collect::>(); // Hash the data to get leaf hashes let leaf_hashes: Vec> = data .iter() .map(|&element| H::hash_no_pad(&[element])) .collect(); // zero hash let zero = HashOut { elements: [F::ZERO; 4], }; let expected_root = compress( compress( compress( compress( leaf_hashes[0], leaf_hashes[1], KEY_BOTTOM_LAYER, ), compress( leaf_hashes[2], leaf_hashes[3], KEY_BOTTOM_LAYER, ), KEY_NONE, ), compress( compress( leaf_hashes[4], leaf_hashes[5], KEY_BOTTOM_LAYER, ), compress( leaf_hashes[6], leaf_hashes[7], KEY_BOTTOM_LAYER, ), KEY_NONE, ), KEY_NONE, ), compress( compress( compress( leaf_hashes[8], leaf_hashes[9], KEY_BOTTOM_LAYER, ), zero, KEY_ODD, ), zero, KEY_ODD, ), KEY_NONE, ); // Build the tree let tree = make_tree(&data, zero)?; // Get the computed root let computed_root = tree.root()?; // Check that the computed root matches the expected root assert_eq!(computed_root, expected_root); Ok(()) } #[test] fn test_proofs() -> Result<()> { // Data for the test (field elements) let data = (1u64..=10) .map(|i| F::from_canonical_u64(i)) .collect::>(); // Hash the data to get leaf hashes let leaf_hashes: Vec> = data .iter() .map(|&element| H::hash_no_pad(&[element])) .collect(); // zero hash let zero = HashOut { elements: [F::ZERO; 4], }; // Build the tree let tree = MerkleTree::::new(&leaf_hashes, zero)?; // Get the root let expected_root = tree.root()?; // Verify proofs for all leaves for (i, &leaf_hash) in leaf_hashes.iter().enumerate() { let proof = tree.get_proof(i)?; let is_valid = proof.verify(leaf_hash, expected_root)?; assert!(is_valid, "Proof verification failed for leaf {}", i); } Ok(()) } }