diff --git a/codex-plonky2-circuits/README.md b/codex-plonky2-circuits/README.md index 2e405f1..86da0bd 100644 --- a/codex-plonky2-circuits/README.md +++ b/codex-plonky2-circuits/README.md @@ -5,15 +5,9 @@ This crate is an implementation of the [codex storage proofs circuits](https://g ## Code organization -- [`capped_tree`](./src/merkle_tree/capped_tree.rs) is an adapted implementation of Merkle tree based on the original plonky2 merkle tree implementation. - -- [`capped_tree_circuit`](./src/circuits/capped_tree_circuit.rs) is the circuit implementation for regular merkle tree implementation (non-safe version) based on the above merkle tree implementation. - - [`merkle_safe`](./src/merkle_tree/merkle_safe.rs) is the implementation of "safe" merkle tree used in codex, consistent with the one [here](https://github.com/codex-storage/nim-codex/blob/master/codex/merkletree/merkletree.nim). -- [`safe_tree_circuit`](./src/circuits/safe_tree_circuit.rs) is the Plonky2 Circuit implementation of "safe" merkle tree above. - -- [`prove_single_cell`](./src/circuits/prove_single_cell.rs) is the Plonky2 Circuit implementation for proving a single cell in slot merkle tree. +- [`merkle_circuit`](./src/circuits/merkle_circuit) is the Plonky2 Circuit implementation of "safe" merkle tree above. - [`sample_cells`](./src/circuits/sample_cells.rs) is the Plonky2 Circuit implementation for sampling cells in dataset merkle tree. diff --git a/codex-plonky2-circuits/benches/prove_cells.rs b/codex-plonky2-circuits/benches/prove_cells.rs index 3a561e9..20b2914 100644 --- a/codex-plonky2-circuits/benches/prove_cells.rs +++ b/codex-plonky2-circuits/benches/prove_cells.rs @@ -4,7 +4,7 @@ use std::time::{Duration, Instant}; use codex_plonky2_circuits::{ merkle_tree::merkle_safe::MerkleProof, - circuits::safe_tree_circuit::MerkleTreeCircuit, + circuits::merkle_circuit::MerkleTreeCircuit, }; use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig}; diff --git a/codex-plonky2-circuits/benches/safe_circuit.rs b/codex-plonky2-circuits/benches/safe_circuit.rs index 8ae9b00..bc3afcf 100644 --- a/codex-plonky2-circuits/benches/safe_circuit.rs +++ b/codex-plonky2-circuits/benches/safe_circuit.rs @@ -1,7 +1,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use anyhow::Result; -use codex_plonky2_circuits::{merkle_tree::merkle_safe::MerkleTree, circuits::safe_tree_circuit::MerkleTreeCircuit}; +use codex_plonky2_circuits::{merkle_tree::merkle_safe::MerkleTree, circuits::merkle_circuit::MerkleTreeCircuit}; use plonky2::field::types::Field; use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig}; diff --git a/codex-plonky2-circuits/src/circuits/capped_tree_circuit.rs b/codex-plonky2-circuits/src/circuits/capped_tree_circuit.rs deleted file mode 100644 index 39773eb..0000000 --- a/codex-plonky2-circuits/src/circuits/capped_tree_circuit.rs +++ /dev/null @@ -1,446 +0,0 @@ -// circuit for regular merkle tree implementation (non-safe version) -// the circuit uses caps in similar way as in Plonky2 Merkle tree implementation -// NOTE: this might be deleted at later time, since we don't use it for codex - -use anyhow::Result; -use plonky2::field::extension::Extendable; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::iop::witness::{PartialWitness, WitnessWrite, Witness}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; -use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, GenericHashOut}; -use plonky2::plonk::proof::ProofWithPublicInputs; -use std::marker::PhantomData; -use itertools::Itertools; - -use crate::merkle_tree::capped_tree::MerkleTree; -use plonky2::hash::poseidon::PoseidonHash; - -use plonky2::hash::hash_types::{HashOutTarget, MerkleCapTarget, NUM_HASH_OUT_ELTS}; -use crate::merkle_tree::capped_tree::MerkleProofTarget; -use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; - -use plonky2::plonk::config::PoseidonGoldilocksConfig; -use plonky2::plonk::proof::Proof; - -use plonky2::hash::hashing::PlonkyPermutation; -use plonky2::plonk::circuit_data::VerifierCircuitTarget; - -// size of leaf data (in number of field elements) -pub const LEAF_LEN: usize = 4; - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct MerkleTreeTargets< - F: RichField + Extendable + Poseidon2, - C: GenericConfig, - const D: usize, - H: Hasher + AlgebraicHasher, -> { - pub proof_target: MerkleProofTarget, - pub cap_target: MerkleCapTarget, - pub leaf: Vec, - pub leaf_index_target: Target, - _phantom: PhantomData<(C,H)>, -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct MerkleTreeCircuit< - F: RichField + Extendable + Poseidon2, - C: GenericConfig, - const D: usize, - H: Hasher + AlgebraicHasher, -> { - pub tree: MerkleTree, - pub _phantom: PhantomData, -} - -impl< - F: RichField + Extendable + Poseidon2, - C: GenericConfig, - const D: usize, - H: Hasher + AlgebraicHasher, -> MerkleTreeCircuit{ - - pub fn tree_height(&self) -> usize { - self.tree.leaves.len().trailing_zeros() as usize - } - - // build the circuit and returns the circuit data - pub fn build_circuit(&mut self, builder: &mut CircuitBuilder::) -> MerkleTreeTargets{ - - let proof_t = MerkleProofTarget { - siblings: builder.add_virtual_hashes(self.tree_height()-self.tree.cap.height()), - }; - - let cap_t = builder.add_virtual_cap(self.tree.cap.height()); - - let leaf_index_t = builder.add_virtual_target(); - - let leaf_index_bits = builder.split_le(leaf_index_t, self.tree_height()); - - // NOTE: takes the length from const LEAF_LEN and assume all lengths are the same - let leaf_t: [Target; LEAF_LEN] = builder.add_virtual_targets(LEAF_LEN).try_into().unwrap(); - - let zero = builder.zero(); - // let mut mt = MT(self.tree.clone()); - self.verify_merkle_proof_to_cap_circuit( - builder, leaf_t.to_vec(), &leaf_index_bits, &cap_t, &proof_t, - ); - - MerkleTreeTargets{ - proof_target: proof_t, - cap_target: cap_t, - leaf: leaf_t.to_vec(), - leaf_index_target: leaf_index_t, - _phantom: Default::default(), - } - } - - pub fn fill_targets( - &self, - pw: &mut PartialWitness, - // leaf_data: Vec, - leaf_index: usize, - targets: MerkleTreeTargets, - ) { - let proof = self.tree.prove(leaf_index); - - for i in 0..proof.siblings.len() { - pw.set_hash_target(targets.proof_target.siblings[i], proof.siblings[i]); - } - - // set cap target manually - // pw.set_cap_target(&cap_t, &tree.cap); - for (ht, h) in targets.cap_target.0.iter().zip(&self.tree.cap.0) { - pw.set_hash_target(*ht, *h); - } - - pw.set_target( - targets.leaf_index_target, - F::from_canonical_usize(leaf_index), - ); - - for j in 0..targets.leaf.len() { - pw.set_target(targets.leaf[j], self.tree.leaves[leaf_index][j]); - } - - } - - pub fn prove( - &self, - data: CircuitData, - pw: PartialWitness - ) -> Result> { - let proof = data.prove(pw); - return proof - } - - // function to automate build and prove, useful for quick testing - pub fn build_and_prove( - &mut self, - // builder: &mut CircuitBuilder::, - config: CircuitConfig, - // pw: &mut PartialWitness, - leaf_index: usize, - // data: CircuitData, - ) -> Result<(CircuitData,ProofWithPublicInputs)> { - let mut builder = CircuitBuilder::::new(config); - let mut pw = PartialWitness::new(); - // merkle proof - let merkle_proof = self.tree.prove(leaf_index); - let proof_t = MerkleProofTarget { - siblings: builder.add_virtual_hashes(merkle_proof.siblings.len()), - }; - - for i in 0..merkle_proof.siblings.len() { - pw.set_hash_target(proof_t.siblings[i], merkle_proof.siblings[i]); - } - - // merkle cap target - let cap_t = builder.add_virtual_cap(self.tree.cap.height()); - // set cap target manually - // pw.set_cap_target(&cap_t, &tree.cap); - for (ht, h) in cap_t.0.iter().zip(&self.tree.cap.0) { - pw.set_hash_target(*ht, *h); - } - - // leaf index target - let leaf_index_t = builder.constant(F::from_canonical_usize(leaf_index)); - let leaf_index_bits = builder.split_le(leaf_index_t, self.tree_height()); - - // leaf targets - // NOTE: takes the length from const LEAF_LEN and assume all lengths are the same - // let leaf_t = builder.add_virtual_targets(LEAF_LEN); - let leaf_t = builder.add_virtual_targets(self.tree.leaves[leaf_index].len()); - for j in 0..leaf_t.len() { - pw.set_target(leaf_t[j], self.tree.leaves[leaf_index][j]); - } - - // let mut mt = MT(self.tree.clone()); - self.verify_merkle_proof_to_cap_circuit( - &mut builder, leaf_t.to_vec(), &leaf_index_bits, &cap_t, &proof_t, - ); - let data = builder.build::(); - let proof = data.prove(pw).unwrap(); - - Ok((data, proof)) - } - - pub fn verify( - &self, - verifier_data: &VerifierCircuitData, - public_inputs: Vec, - proof: Proof - ) -> Result<()> { - verifier_data.verify(ProofWithPublicInputs { - proof, - public_inputs, - }) - } -} - -impl + Poseidon2, const D: usize, C: GenericConfig, H: Hasher + AlgebraicHasher,> MerkleTreeCircuit { - - pub fn verify_merkle_proof_circuit( - &mut self, - builder: &mut CircuitBuilder, - leaf_data: Vec, - leaf_index_bits: &[BoolTarget], - merkle_root: HashOutTarget, - proof: &MerkleProofTarget, - ) { - let merkle_cap = MerkleCapTarget(vec![merkle_root]); - self.verify_merkle_proof_to_cap_circuit(builder, leaf_data, leaf_index_bits, &merkle_cap, proof); - } - - pub fn verify_merkle_proof_to_cap_circuit( - &mut self, - builder: &mut CircuitBuilder, - leaf_data: Vec, - leaf_index_bits: &[BoolTarget], - merkle_cap: &MerkleCapTarget, - proof: &MerkleProofTarget, - ) { - let cap_index = builder.le_sum(leaf_index_bits[proof.siblings.len()..].iter().copied()); - self.verify_merkle_proof_to_cap_with_cap_index_circuit( - builder, - leaf_data, - leaf_index_bits, - cap_index, - merkle_cap, - proof, - ); - } - - pub fn verify_merkle_proof_to_cap_with_cap_index_circuit( - &mut self, - builder: &mut CircuitBuilder, - leaf_data: Vec, - leaf_index_bits: &[BoolTarget], - cap_index: Target, - merkle_cap: &MerkleCapTarget, - proof: &MerkleProofTarget, - ) { - debug_assert!(H::AlgebraicPermutation::RATE >= NUM_HASH_OUT_ELTS); - - let zero = builder.zero(); - let mut state: HashOutTarget = builder.hash_or_noop::(leaf_data); - debug_assert_eq!(state.elements.len(), NUM_HASH_OUT_ELTS); - - for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); - - let mut perm_inputs = H::AlgebraicPermutation::default(); - perm_inputs.set_from_slice(&state.elements, 0); - perm_inputs.set_from_slice(&sibling.elements, NUM_HASH_OUT_ELTS); - // Ensure the rest of the state, if any, is zero: - perm_inputs.set_from_iter(core::iter::repeat(zero), 2 * NUM_HASH_OUT_ELTS); - // let perm_outs = builder.permute_swapped::(perm_inputs, bit); - let perm_outs = H::permute_swapped(perm_inputs, bit, builder); - let hash_outs = perm_outs.squeeze()[0..NUM_HASH_OUT_ELTS] - .try_into() - .unwrap(); - state = HashOutTarget { - elements: hash_outs, - }; - } - - for i in 0..NUM_HASH_OUT_ELTS { - let result = builder.random_access( - cap_index, - merkle_cap.0.iter().map(|h| h.elements[i]).collect(), - ); - builder.connect(result, state.elements[i]); - } - } - - pub fn verify_batch_merkle_proof_to_cap_with_cap_index_circuit( - &mut self, - builder: &mut CircuitBuilder, - leaf_data: &[Vec], - leaf_heights: &[usize], - leaf_index_bits: &[BoolTarget], - cap_index: Target, - merkle_cap: &MerkleCapTarget, - proof: &MerkleProofTarget, - ) { - debug_assert!(H::AlgebraicPermutation::RATE >= NUM_HASH_OUT_ELTS); - - let zero = builder.zero(); - let mut state: HashOutTarget = builder.hash_or_noop::(leaf_data[0].clone()); - debug_assert_eq!(state.elements.len(), NUM_HASH_OUT_ELTS); - - let mut current_height = leaf_heights[0]; - let mut leaf_data_index = 1; - for (&bit, &sibling) in leaf_index_bits.iter().zip(&proof.siblings) { - debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); - - let mut perm_inputs = H::AlgebraicPermutation::default(); - perm_inputs.set_from_slice(&state.elements, 0); - perm_inputs.set_from_slice(&sibling.elements, NUM_HASH_OUT_ELTS); - // Ensure the rest of the state, if any, is zero: - perm_inputs.set_from_iter(core::iter::repeat(zero), 2 * NUM_HASH_OUT_ELTS); - // let perm_outs = builder.permute_swapped::(perm_inputs, bit); - let perm_outs = H::permute_swapped(perm_inputs, bit, builder); - let hash_outs = perm_outs.squeeze()[0..NUM_HASH_OUT_ELTS] - .try_into() - .unwrap(); - state = HashOutTarget { - elements: hash_outs, - }; - current_height -= 1; - - if leaf_data_index < leaf_heights.len() - && current_height == leaf_heights[leaf_data_index] - { - let mut new_leaves = state.elements.to_vec(); - new_leaves.extend_from_slice(&leaf_data[leaf_data_index]); - state = builder.hash_or_noop::(new_leaves); - - leaf_data_index += 1; - } - } - - for i in 0..NUM_HASH_OUT_ELTS { - let result = builder.random_access( - cap_index, - merkle_cap.0.iter().map(|h| h.elements[i]).collect(), - ); - builder.connect(result, state.elements[i]); - } - } - - pub fn connect_hashes(&mut self, builder: &mut CircuitBuilder, x: HashOutTarget, y: HashOutTarget) { - for i in 0..NUM_HASH_OUT_ELTS { - builder.connect(x.elements[i], y.elements[i]); - } - } - - pub fn connect_merkle_caps(&mut self, builder: &mut CircuitBuilder, x: &MerkleCapTarget, y: &MerkleCapTarget) { - for (h0, h1) in x.0.iter().zip_eq(&y.0) { - self.connect_hashes(builder, *h0, *h1); - } - } - - pub fn connect_verifier_data(&mut self, builder: &mut CircuitBuilder, x: &VerifierCircuitTarget, y: &VerifierCircuitTarget) { - self.connect_merkle_caps(builder, &x.constants_sigmas_cap, &y.constants_sigmas_cap); - self.connect_hashes(builder, x.circuit_digest, y.circuit_digest); - } -} - -#[cfg(test)] -pub mod tests { - use std::time::Instant; - use rand::rngs::OsRng; - use rand::Rng; - - use super::*; - use plonky2::field::types::Field; - use crate::merkle_tree::capped_tree::MerkleTree; - use plonky2::iop::witness::PartialWitness; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - pub fn random_data(n: usize, k: usize) -> Vec> { - (0..n).map(|_| F::rand_vec(k)).collect() - } - - #[test] - fn test_merkle_circuit() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - type H = PoseidonHash; - - // create Merkle tree - let log_n = 8; - let n = 1 << log_n; - let cap_height = 1; - let leaves = random_data::(n, LEAF_LEN); - let tree = MerkleTree::>::Hasher>::new(leaves, cap_height); - - // ---- prover zone ---- - // Build and prove - let start_build = Instant::now(); - let mut mt_circuit = MerkleTreeCircuit::{ tree: tree.clone(), _phantom: Default::default() }; - let leaf_index: usize = OsRng.gen_range(0..n); - let config = CircuitConfig::standard_recursion_config(); - let (data, proof_with_pub_input) = mt_circuit.build_and_prove(config,leaf_index).unwrap(); - println!("build and prove time is: {:?}", start_build.elapsed()); - - let vd = data.verifier_data(); - let pub_input = proof_with_pub_input.public_inputs; - let proof = proof_with_pub_input.proof; - - // ---- verifier zone ---- - let start_verifier = Instant::now(); - assert!(mt_circuit.verify(&vd,pub_input,proof).is_ok()); - println!("verify time is: {:?}", start_verifier.elapsed()); - - Ok(()) - } - - #[test] - fn mod_test_merkle_circuit() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - // create Merkle tree - let log_n = 8; - let n = 1 << log_n; - let cap_height = 0; - let leaves = random_data::(n, LEAF_LEN); - let tree = MerkleTree::>::Hasher>::new(leaves, cap_height); - - // Build circuit - let start_build = Instant::now(); - let mut mt_circuit = MerkleTreeCircuit{ tree: tree.clone(), _phantom: Default::default() }; - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let targets = mt_circuit.build_circuit(&mut builder); - let data = builder.build::(); - let vd = data.verifier_data(); - println!("build time is: {:?}", start_build.elapsed()); - - // Prover Zone - let start_prover = Instant::now(); - let mut pw = PartialWitness::new(); - let leaf_index: usize = OsRng.gen_range(0..n); - let proof = tree.prove(leaf_index); - mt_circuit.fill_targets(&mut pw, leaf_index, targets); - let proof_with_pub_input = mt_circuit.prove(data,pw).unwrap(); - let pub_input = proof_with_pub_input.public_inputs; - let proof = proof_with_pub_input.proof; - println!("prove time is: {:?}", start_prover.elapsed()); - - // Verifier zone - let start_verifier = Instant::now(); - assert!(mt_circuit.verify(&vd,pub_input,proof).is_ok()); - println!("verify time is: {:?}", start_verifier.elapsed()); - - Ok(()) - } -} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/circuits/keyed_compress.rs b/codex-plonky2-circuits/src/circuits/keyed_compress.rs index 2cd6c9e..6f2dfc2 100644 --- a/codex-plonky2-circuits/src/circuits/keyed_compress.rs +++ b/codex-plonky2-circuits/src/circuits/keyed_compress.rs @@ -8,7 +8,10 @@ use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; /// Compression function which takes two 256 bit inputs (HashOut) and u64 key (which is converted to field element in the function) /// and returns a 256 bit output (HashOut). -pub fn key_compress >(x: HashOut, y: HashOut, key: u64) -> HashOut { +pub fn key_compress< + F: RichField, + H:Hasher +>(x: HashOut, y: HashOut, key: u64) -> HashOut { debug_assert_eq!(x.elements.len(), NUM_HASH_OUT_ELTS); debug_assert_eq!(y.elements.len(), NUM_HASH_OUT_ELTS); diff --git a/codex-plonky2-circuits/src/circuits/merkle_circuit.rs b/codex-plonky2-circuits/src/circuits/merkle_circuit.rs new file mode 100644 index 0000000..a20269f --- /dev/null +++ b/codex-plonky2-circuits/src/circuits/merkle_circuit.rs @@ -0,0 +1,176 @@ +// Plonky2 Circuit implementation of "safe" merkle tree +// consistent with the one in codex: +// https://github.com/codex-storage/codex-storage-proofs-circuits/blob/master/circuit/codex/merkle.circom + +use anyhow::Result; +use plonky2::field::extension::Extendable; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::{HashOutTarget, RichField, NUM_HASH_OUT_ELTS}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; +use std::marker::PhantomData; +use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; +use crate::circuits::keyed_compress::key_compress_circuit; +use crate::circuits::params::HF; +use crate::circuits::utils::{add_assign_hash_out_target, assign_bool_targets, assign_hash_out_targets, mul_hash_out_target, usize_to_bits_le_padded}; +use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER}; + +/// Merkle tree targets representing the input to the circuit +#[derive(Clone)] +pub struct MerkleTreeTargets{ + pub leaf: HashOutTarget, + pub path_bits: Vec, + pub last_bits: Vec, + pub mask_bits: Vec, + pub merkle_path: MerkleProofTarget, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MerkleProofTarget { + /// The Merkle digest of each sibling subtree, staying from the bottommost layer. + pub path: Vec, +} + +/// Merkle tree circuit contains the functions for +/// building, proving and verifying the circuit. +#[derive(Clone)] +pub struct MerkleTreeCircuit< + F: RichField + Extendable + Poseidon2, + const D: usize, +> { + pub phantom_data: PhantomData, +} + +impl< + F: RichField + Extendable + Poseidon2, + const D: usize, +> MerkleTreeCircuit { + + pub fn new() -> Self{ + Self{ + phantom_data: Default::default(), + } + } + + /// takes the params from the targets struct + /// outputs the reconstructed merkle root + pub fn reconstruct_merkle_root_circuit( + builder: &mut CircuitBuilder, + targets: &mut MerkleTreeTargets, + max_depth: usize, + ) -> HashOutTarget { + let mut state: HashOutTarget = targets.leaf; + let zero = builder.zero(); + let one = builder.one(); + let two = builder.two(); + debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len()); + + // compute is_last + let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1]; + is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true) + for i in (0..max_depth).rev() { + let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target); + is_last[i] = builder.and( is_last[i + 1] , eq_out); + } + + let mut i: usize = 0; + for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) { + debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); + + let bottom = if i == 0 { + builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER)) + } else { + builder.constant(F::from_canonical_u64(KEY_NONE)) + }; + + // compute: odd = isLast[i] * (1-pathBits[i]); + // compute: key = bottom + 2*odd + let mut odd = builder.sub(one, targets.path_bits[i].target); + odd = builder.mul(is_last[i].target, odd); + odd = builder.mul(two, odd); + let key = builder.add(bottom,odd); + + // select left and right based on path_bit + let mut left = vec![]; + let mut right = vec![]; + for i in 0..NUM_HASH_OUT_ELTS { + left.push( builder.select(bit, sibling.elements[i], state.elements[i])); + right.push( builder.select(bit, state.elements[i], sibling.elements[i])); + } + + state = key_compress_circuit::(builder,left,right,key); + + i += 1; + } + + return state; + } + + /// takes the params from the targets struct + /// outputs the reconstructed merkle root + /// this one uses the mask bits + pub fn reconstruct_merkle_root_circuit_with_mask( + builder: &mut CircuitBuilder, + targets: &mut MerkleTreeTargets, + max_depth: usize, + ) -> HashOutTarget { + let mut state: Vec = Vec::with_capacity(max_depth+1); + state.push(targets.leaf); + let zero = builder.zero(); + let one = builder.one(); + let two = builder.two(); + debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len()); + + // compute is_last + let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1]; + is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true) + for i in (0..max_depth).rev() { + let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target); + is_last[i] = builder.and( is_last[i + 1] , eq_out); + } + + let mut i: usize = 0; + for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) { + debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); + + let bottom = if i == 0 { + builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER)) + } else { + builder.constant(F::from_canonical_u64(KEY_NONE)) + }; + + // compute: odd = isLast[i] * (1-pathBits[i]); + // compute: key = bottom + 2*odd + let mut odd = builder.sub(one, targets.path_bits[i].target); + odd = builder.mul(is_last[i].target, odd); + odd = builder.mul(two, odd); + let key = builder.add(bottom,odd); + + // select left and right based on path_bit + let mut left = vec![]; + let mut right = vec![]; + for j in 0..NUM_HASH_OUT_ELTS { + left.push( builder.select(bit, sibling.elements[j], state[i].elements[j])); + right.push( builder.select(bit, state[i].elements[j], sibling.elements[j])); + } + + state.push(key_compress_circuit::(builder,left,right,key)); + + i += 1; + } + + // another way to do this is to use builder.select + // but that might be less efficient & more constraints + let mut reconstructed_root = HashOutTarget::from_vec([builder.zero();4].to_vec()); + for k in 0..max_depth { + let diff = builder.sub(targets.mask_bits[k].target, targets.mask_bits[k+1].target); + let mul_result = mul_hash_out_target(builder,&diff,&mut state[k+1]); + add_assign_hash_out_target(builder,&mut reconstructed_root, &mul_result); + } + + // reconstructed_root + state[max_depth] + + } +} diff --git a/codex-plonky2-circuits/src/circuits/mod.rs b/codex-plonky2-circuits/src/circuits/mod.rs index 523eb76..93cf79c 100644 --- a/codex-plonky2-circuits/src/circuits/mod.rs +++ b/codex-plonky2-circuits/src/circuits/mod.rs @@ -1,6 +1,6 @@ -pub mod capped_tree_circuit; -pub mod safe_tree_circuit; -pub mod prove_single_cell; +// pub mod capped_tree_circuit; +pub mod merkle_circuit; +// pub mod prove_single_cell; pub mod sample_cells; pub mod utils; pub mod params; diff --git a/codex-plonky2-circuits/src/circuits/params.rs b/codex-plonky2-circuits/src/circuits/params.rs index 305d87e..528f171 100644 --- a/codex-plonky2-circuits/src/circuits/params.rs +++ b/codex-plonky2-circuits/src/circuits/params.rs @@ -1,25 +1,10 @@ // global params for the circuits -// only change params here use plonky2::hash::poseidon::PoseidonHash; - -// constants and types used throughout the circuit -pub const N_FIELD_ELEMS_PER_CELL: usize = 256; -pub const BOT_DEPTH: usize = 5; // block depth - depth of the block merkle tree -pub const MAX_DEPTH: usize = 16; // depth of big tree (slot tree depth + block tree depth) -pub const N_CELLS_IN_BLOCKS: usize = 1< and causing a lot of headache // will look into this later. -pub type HF = PoseidonHash; \ No newline at end of file +pub type HF = PoseidonHash; + diff --git a/codex-plonky2-circuits/src/circuits/prove_single_cell.rs b/codex-plonky2-circuits/src/circuits/prove_single_cell.rs deleted file mode 100644 index bed9629..0000000 --- a/codex-plonky2-circuits/src/circuits/prove_single_cell.rs +++ /dev/null @@ -1,475 +0,0 @@ -// prove single cell -// consistent with: -// https://github.com/codex-storage/codex-storage-proofs-circuits/blob/master/circuit/codex/single_cell.circom -// circuit consists of: -// - reconstruct the block merkle root -// - use merkle root as leaf and reconstruct slot root -// - check equality with given slot root - -use anyhow::Result; -use plonky2::field::extension::Extendable; -use plonky2::hash::hash_types::{HashOut, RichField}; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::iop::witness::{PartialWitness, WitnessWrite, Witness}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::CircuitConfig; -use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, GenericHashOut}; -use std::marker::PhantomData; -use itertools::Itertools; - -use crate::merkle_tree::merkle_safe::MerkleTree; -use plonky2::hash::poseidon::PoseidonHash; - -use plonky2::hash::hash_types::{HashOutTarget, NUM_HASH_OUT_ELTS}; -use crate::merkle_tree::merkle_safe::{MerkleProof, MerkleProofTarget}; -use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; - -use plonky2::plonk::config::PoseidonGoldilocksConfig; - -use plonky2::hash::hashing::PlonkyPermutation; -use crate::circuits::safe_tree_circuit::{MerkleTreeCircuit, MerkleTreeTargets}; -use crate::circuits::utils::usize_to_bits_le_padded; -use crate::circuits::params::{MAX_DEPTH, BOT_DEPTH, N_FIELD_ELEMS_PER_CELL, N_CELLS_IN_BLOCKS, N_BLOCKS, N_CELLS, HF}; - -// ------ Slot Tree -------- - -#[derive(Clone)] -pub struct SlotTreeCircuit< - F: RichField + Extendable + Poseidon2, - const D: usize, -> { - pub tree: MerkleTreeCircuit, // slot tree - pub block_trees: Vec>, // vec of block trees - pub cell_data: Vec>, // cell data as field elements -} - -impl< - F: RichField + Extendable + Poseidon2, - const D: usize, -> Default for SlotTreeCircuit{ - /// slot tree with fake data, for testing only - fn default() -> Self { - // generate fake cell data - let mut cell_data = (0..N_CELLS) - .map(|i|{ - (0..N_FIELD_ELEMS_PER_CELL) - .map(|j| F::from_canonical_u64((j+i) as u64)) - .collect::>() - }) - .collect::>(); - // hash it - let leaves: Vec> = cell_data - .iter() - .map(|element| { - HF::hash_no_pad(&element) - }) - .collect(); - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - // create block tree - let block_trees = (0..N_BLOCKS) - .map(|i| { - let start = i * N_CELLS_IN_BLOCKS; - let end = (i + 1) * N_CELLS_IN_BLOCKS; - let b_tree = Self::get_block_tree(&leaves[start..end].to_vec()); // use helper function - MerkleTreeCircuit::{ tree:b_tree} - }) - .collect::>(); - // get the roots or block trees - let block_roots = block_trees.iter() - .map(|t| { - t.tree.root().unwrap() - }) - .collect::>(); - // create slot tree - let slot_tree = MerkleTree::::new(&block_roots, zero).unwrap(); - - Self{ - tree: MerkleTreeCircuit::{ tree:slot_tree}, - block_trees, - cell_data, - } - } -} - -impl< - F: RichField + Extendable + Poseidon2, - const D: usize, -> SlotTreeCircuit { - - /// Slot tree with fake data, for testing only - pub fn new_for_testing() -> Self { - // Generate fake cell data for one block - let cell_data_block = (0..N_CELLS_IN_BLOCKS) - .map(|i| { - (0..N_FIELD_ELEMS_PER_CELL) - .map(|j| F::from_canonical_u64((j + i) as u64)) - .collect::>() - }) - .collect::>(); - - // Hash the cell data block to create leaves for one block - let leaves_block: Vec> = cell_data_block - .iter() - .map(|element| { - HF::hash_no_pad(&element) - }) - .collect(); - - // Zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - - // Create a block tree from the leaves of one block - let b_tree = Self::get_block_tree(&leaves_block); - - // Create a block tree circuit - let block_tree_circuit = MerkleTreeCircuit:: { - tree: b_tree, - // _phantom: Default::default(), - }; - - // Now replicate this block tree for all N_BLOCKS blocks - let block_trees = vec![block_tree_circuit.clone(); N_BLOCKS]; - - // Get the roots of block trees - let block_roots = block_trees - .iter() - .map(|t| t.tree.root().unwrap()) - .collect::>(); - - // Create the slot tree from block roots - let slot_tree = MerkleTree::::new(&block_roots, zero).unwrap(); - - // Create the full cell data and cell hash by repeating the block data - let cell_data = vec![cell_data_block.clone(); N_BLOCKS].concat(); - - // Return the constructed Self - Self { - tree: MerkleTreeCircuit:: { - tree: slot_tree, - }, - block_trees, - cell_data, - } - } - - - /// same as default but with supplied cell data - pub fn new(cell_data: Vec>) -> Self{ - let leaves: Vec> = cell_data - .iter() - .map(|element| { - HF::hash_no_pad(element) - }) - .collect(); - let zero = HashOut { - elements: [F::ZERO; 4], - }; - let block_trees = (0..N_BLOCKS as usize) - .map(|i| { - let start = i * N_CELLS_IN_BLOCKS; - let end = (i + 1) * N_CELLS_IN_BLOCKS; - let b_tree = Self::get_block_tree(&leaves[start..end].to_vec()); - MerkleTreeCircuit::{ tree:b_tree} - }) - .collect::>(); - let block_roots = block_trees.iter() - .map(|t| { - t.tree.root().unwrap() - }) - .collect::>(); - let slot_tree = MerkleTree::::new(&block_roots, zero).unwrap(); - Self{ - tree: MerkleTreeCircuit::{ tree:slot_tree}, - block_trees, - cell_data, - } - } - - /// generates a proof for given leaf index - /// the path in the proof is a combined block and slot path to make up the full path - pub fn get_proof(&self, index: usize) -> MerkleProof { - let block_index = index/ N_CELLS_IN_BLOCKS; - let leaf_index = index % N_CELLS_IN_BLOCKS; - let block_proof = self.block_trees[block_index].tree.get_proof(leaf_index).unwrap(); - let slot_proof = self.tree.tree.get_proof(block_index).unwrap(); - - // Combine the paths from the block and slot proofs - let mut combined_path = block_proof.path.clone(); - combined_path.extend(slot_proof.path.clone()); - - MerkleProof:: { - index: index, - path: combined_path, - nleaves: self.cell_data.len(), - zero: block_proof.zero.clone(), - } - - } - - /// verify the given proof for slot tree, checks equality with given root - pub fn verify_cell_proof(&self, proof: MerkleProof, root: HashOut) -> Result{ - let mut block_path_bits = usize_to_bits_le_padded(proof.index, MAX_DEPTH); - let last_index = N_CELLS - 1; - let mut block_last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH); - - let split_point = BOT_DEPTH; - - let slot_last_bits = block_last_bits.split_off(split_point); - let slot_path_bits = block_path_bits.split_off(split_point); - - let leaf_hash = HF::hash_no_pad(&self.cell_data[proof.index]); - - let mut block_path = proof.path; - let slot_path = block_path.split_off(split_point); - - let block_res = MerkleProof::::reconstruct_root2(leaf_hash,block_path_bits.clone(),block_last_bits.clone(),block_path); - let reconstructed_root = MerkleProof::::reconstruct_root2(block_res.unwrap(),slot_path_bits,slot_last_bits,slot_path); - - Ok(reconstructed_root.unwrap() == root) - } - - fn get_block_tree(leaves: &Vec>) -> MerkleTree { - let zero = HashOut { - elements: [F::ZERO; 4], - }; - // Build the Merkle tree - let block_tree = MerkleTree::::new(leaves, zero).unwrap(); - block_tree - } -} - -//------- single cell struct ------ -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct SingleCellTargets{ - pub expected_slot_root_target: HashOutTarget, - pub proof_target: MerkleProofTarget, - pub leaf_target: Vec, - pub path_bits: Vec, - pub last_bits: Vec, -} - -//------- circuit impl -------- - -impl< - F: RichField + Extendable + Poseidon2, - const D: usize, -> SlotTreeCircuit { - - pub fn prove_single_cell( - builder: &mut CircuitBuilder:: - ) -> SingleCellTargets { - - // Retrieve tree depth - let depth = MAX_DEPTH; - - // Create virtual targets - let mut leaf = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::>(); - - let mut hash_inputs:Vec= Vec::new(); - hash_inputs.extend_from_slice(&leaf); - let leaf_hash = builder.hash_n_to_hash_no_pad::(hash_inputs); - - // path bits (binary decomposition of leaf_index) - let mut block_path_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - let mut slot_path_bits = (0..(depth - BOT_DEPTH)).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - - // last bits (binary decomposition of last_index = nleaves - 1) - let block_last_bits = (0..BOT_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - let slot_last_bits = (0..(depth-BOT_DEPTH)).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - - // mask bits (binary decomposition of last_index = nleaves - 1) - let block_mask_bits = (0..BOT_DEPTH+1).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - let slot_mask_bits = (0..(depth-BOT_DEPTH)+1).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - - - // Merkle path (sibling hashes from leaf to root) - let mut block_merkle_path = MerkleProofTarget { - path: (0..BOT_DEPTH).map(|_| builder.add_virtual_hash()).collect(), - }; - let mut slot_merkle_path = MerkleProofTarget { - path: (0..(depth - BOT_DEPTH)).map(|_| builder.add_virtual_hash()).collect(), - }; - - // expected Merkle root - let slot_expected_root = builder.add_virtual_hash(); - - let mut block_targets = MerkleTreeTargets { - leaf: leaf_hash, - path_bits:block_path_bits, - last_bits: block_last_bits, - mask_bits: block_mask_bits, - merkle_path: block_merkle_path, - }; - - // reconstruct block root - let block_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit(builder, &mut block_targets); - - // create MerkleTreeTargets struct - let mut slot_targets = MerkleTreeTargets { - leaf: block_root, - path_bits:slot_path_bits, - last_bits:slot_last_bits, - mask_bits:slot_mask_bits, - merkle_path:slot_merkle_path, - }; - - // reconstruct slot root with block root as leaf - let slot_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit(builder, &mut slot_targets); - - // check equality with expected root - for i in 0..NUM_HASH_OUT_ELTS { - builder.connect(slot_expected_root.elements[i], slot_root.elements[i]); - } - - let mut proof_target = MerkleProofTarget{ - path: block_targets.merkle_path.path, - }; - proof_target.path.extend_from_slice(&slot_targets.merkle_path.path); - - let mut path_bits = block_targets.path_bits; - path_bits.extend_from_slice(&slot_targets.path_bits); - - let mut last_bits = block_targets.last_bits; - last_bits.extend_from_slice(&slot_targets.last_bits); - - let mut cell_targets = SingleCellTargets { - expected_slot_root_target: slot_expected_root, - proof_target, - leaf_target: leaf, - path_bits, - last_bits, - }; - - // Return MerkleTreeTargets - cell_targets - } - - /// assign the witness values in the circuit targets - /// this takes leaf_index, leaf, and proof (generated from slot_tree) - /// and fills all required circuit targets(circuit inputs) - pub fn single_cell_assign_witness( - &self, - pw: &mut PartialWitness, - targets: &mut SingleCellTargets, - leaf_index: usize, - leaf: &Vec, - proof: MerkleProof, - )-> Result<()> { - - // Assign the leaf to the leaf target - for i in 0..targets.leaf_target.len(){ - pw.set_target(targets.leaf_target[i], leaf[i]); - } - - // Convert `leaf_index` to binary bits and assign as path_bits - let path_bits = usize_to_bits_le_padded(leaf_index, MAX_DEPTH); - for (i, bit) in path_bits.iter().enumerate() { - pw.set_bool_target(targets.path_bits[i], *bit); - } - - // get `last_index` (nleaves - 1) in binary bits and assign - let last_index = N_CELLS - 1; - let last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH); - for (i, bit) in last_bits.iter().enumerate() { - pw.set_bool_target(targets.last_bits[i], *bit); - } - - // assign the Merkle path (sibling hashes) to the targets - for (i, sibling_hash) in proof.path.iter().enumerate() { - // This is a bit hacky because it should be HashOutTarget, but it is H:Hash - // pw.set_hash_target(targets.merkle_path.path[i],sibling_hash); - // TODO: fix this HashOutTarget later - let sibling_hash_out = sibling_hash.to_vec(); - for j in 0..sibling_hash_out.len() { - pw.set_target(targets.proof_target.path[i].elements[j], sibling_hash_out[j]); - } - } - - // assign the expected Merkle root to the target - let expected_root = self.tree.tree.root()?; - // TODO: fix this HashOutTarget later same issue as above - let expected_root_hash_out = expected_root.to_vec(); - for j in 0..expected_root_hash_out.len() { - pw.set_target(targets.expected_slot_root_target.elements[j], expected_root_hash_out[j]); - } - - Ok(()) - } - - fn hash_leaf(builder: &mut CircuitBuilder, leaf: &mut Vec){ - builder.hash_n_to_hash_no_pad::(leaf.to_owned()); - } -} - -#[cfg(test)] -mod tests { - use std::time::Instant; - use super::*; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use plonky2::iop::witness::PartialWitness; - - //types for tests - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type H = PoseidonHash; - - #[test] - fn test_prove_single_cell(){ - let slot_t = SlotTreeCircuit::::default(); - let index = 8; - let proof = slot_t.get_proof(index); - let res = slot_t.verify_cell_proof(proof,slot_t.tree.tree.root().unwrap()).unwrap(); - assert_eq!(res, true); - } - - #[test] - fn test_cell_build_circuit() -> Result<()> { - - let slot_t = SlotTreeCircuit::::default(); - - // select leaf index to prove - let leaf_index: usize = 8; - - let proof = slot_t.get_proof(leaf_index); - // get the expected Merkle root - let expected_root = slot_t.tree.tree.root().unwrap(); - let res = slot_t.verify_cell_proof(proof.clone(),expected_root).unwrap(); - assert_eq!(res, true); - - // create the circuit - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - - let mut targets = SlotTreeCircuit::::prove_single_cell(&mut builder); - - // create a PartialWitness and assign - let mut pw = PartialWitness::new(); - slot_t.single_cell_assign_witness(&mut pw, &mut targets, leaf_index, &slot_t.cell_data[leaf_index], proof)?; - - // build the circuit - let data = builder.build::(); - println!("circuit size = {:?}", data.common.degree_bits()); - - // Prove the circuit with the assigned witness - let start_time = Instant::now(); - let proof_with_pis = data.prove(pw)?; - println!("prove_time = {:?}", start_time.elapsed()); - - // verify the proof - let verifier_data = data.verifier_data(); - assert!( - verifier_data.verify(proof_with_pis).is_ok(), - "Merkle proof verification failed" - ); - - Ok(()) - } -} - diff --git a/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs b/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs deleted file mode 100644 index 1040cc4..0000000 --- a/codex-plonky2-circuits/src/circuits/safe_tree_circuit.rs +++ /dev/null @@ -1,439 +0,0 @@ -// Plonky2 Circuit implementation of "safe" merkle tree -// consistent with the one in codex: -// https://github.com/codex-storage/codex-storage-proofs-circuits/blob/master/circuit/codex/merkle.circom - -use anyhow::Result; -use plonky2::field::extension::Extendable; -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS}; -use plonky2::hash::hashing::PlonkyPermutation; -use plonky2::hash::poseidon::PoseidonHash; -use plonky2::iop::target::{BoolTarget, Target}; -use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite}; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; -use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; -use plonky2::plonk::proof::{Proof, ProofWithPublicInputs}; -use std::marker::PhantomData; -use std::os::macos::raw::stat; -use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -use serde::Serialize; -use crate::circuits::keyed_compress::key_compress_circuit; -use crate::circuits::params::{HF, MAX_DEPTH}; -use crate::circuits::utils::usize_to_bits_le_padded; - -use crate::merkle_tree::merkle_safe::{MerkleTree, MerkleProofTarget}; -use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER}; - - -/// Merkle tree targets representing the input to the circuit -// note: this omits the mask bits since in plonky2 we can -// uses the Plonk's permutation argument to check that two elements are equal. -// TODO: double check the need for mask -#[derive(Clone)] -pub struct MerkleTreeTargets{ - pub leaf: HashOutTarget, - pub path_bits: Vec, - pub last_bits: Vec, - pub mask_bits: Vec, - pub merkle_path: MerkleProofTarget, -} - -/// Merkle tree circuit contains the tree and functions for -/// building, proving and verifying the circuit. -#[derive(Clone)] -pub struct MerkleTreeCircuit< - F: RichField + Extendable + Poseidon2, - const D: usize, -> { - pub tree: MerkleTree, -} - -impl< - F: RichField + Extendable + Poseidon2, - const D: usize, -> MerkleTreeCircuit { - - /// defines the computations inside the circuit and returns the targets used - pub fn build_circuit( - &mut self, - builder: &mut CircuitBuilder:: - ) -> (MerkleTreeTargets, HashOutTarget) { - // Retrieve tree depth - let depth = self.tree.depth(); - - // Create virtual targets - let leaf = builder.add_virtual_hash(); - - // path bits (binary decomposition of leaf_index) - let path_bits = (0..MAX_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - - // last bits (binary decomposition of last_index = nleaves - 1) - let last_bits = (0..MAX_DEPTH).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - - // last bits (binary decomposition of last_index = nleaves - 1) - let mask_bits = (0..MAX_DEPTH+1).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); - - // Merkle path (sibling hashes from leaf to root) - let merkle_path = MerkleProofTarget { - path: (0..MAX_DEPTH).map(|_| builder.add_virtual_hash()).collect(), - }; - - // create MerkleTreeTargets struct - let mut targets = MerkleTreeTargets{ - leaf, - path_bits, - last_bits, - mask_bits, - merkle_path, - }; - - // Add Merkle proof verification constraints to the circuit - let expected_root_target = Self::reconstruct_merkle_root_circuit_with_mask(builder, &mut targets); - - // Return MerkleTreeTargets - (targets, expected_root_target) - } - - /// assign the witness values in the circuit targets - /// this takes leaf_index and fills all required circuit targets(inputs) - pub fn assign_witness( - &mut self, - pw: &mut PartialWitness, - targets: &mut MerkleTreeTargets, - leaf_index: usize, - )-> Result<()> { - // Get the total number of leaves and tree depth - let nleaves = self.tree.leaves_count(); - let depth = self.tree.depth(); - - // get the Merkle proof for the specified leaf index - let proof = self.tree.get_proof(leaf_index)?; - - // get the leaf hash from the Merkle tree - let leaf_hash = self.tree.layers[0][leaf_index].clone(); - - // Assign the leaf hash to the leaf target - pw.set_hash_target(targets.leaf, leaf_hash); - - // Convert `leaf_index` to binary bits and assign as path_bits - let path_bits = usize_to_bits_le_padded(leaf_index, MAX_DEPTH); - for (i, bit) in path_bits.iter().enumerate() { - pw.set_bool_target(targets.path_bits[i], *bit); - } - - // get `last_index` (nleaves - 1) in binary bits and assign - let last_index = nleaves - 1; - let last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH); - for (i, bit) in last_bits.iter().enumerate() { - pw.set_bool_target(targets.last_bits[i], *bit); - } - - // get mask bits - let mask_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH+1); - for (i, bit) in mask_bits.iter().enumerate() { - pw.set_bool_target(targets.mask_bits[i], *bit); - } - - // assign the Merkle path (sibling hashes) to the targets - for i in 0..MAX_DEPTH { - if i>=proof.path.len() { - for j in 0..NUM_HASH_OUT_ELTS { - pw.set_target(targets.merkle_path.path[i].elements[j], F::ZERO); - } - continue - } - // This is a bit hacky because it should be HashOutTarget, but it is H:Hash - // pw.set_hash_target(targets.merkle_path.path[i],sibling_hash); - // TODO: fix this HashOutTarget later - let sibling_hash = proof.path[i]; - let sibling_hash_out = sibling_hash.to_vec(); - for j in 0..sibling_hash_out.len() { - pw.set_target(targets.merkle_path.path[i].elements[j], sibling_hash_out[j]); - } - } - - Ok(()) - } - - /// takes the params from the targets struct - /// outputs the reconstructed merkle root - pub fn reconstruct_merkle_root_circuit( - builder: &mut CircuitBuilder, - targets: &mut MerkleTreeTargets, - ) -> HashOutTarget { - let max_depth = targets.path_bits.len(); - let mut state: HashOutTarget = targets.leaf; - let zero = builder.zero(); - let one = builder.one(); - let two = builder.two(); - debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len()); - - // compute is_last - let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1]; - is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true) - for i in (0..max_depth).rev() { - let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target); - is_last[i] = builder.and( is_last[i + 1] , eq_out); - } - - let mut i: usize = 0; - for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) { - debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); - - let bottom = if i == 0 { - builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER)) - } else { - builder.constant(F::from_canonical_u64(KEY_NONE)) - }; - - // compute: odd = isLast[i] * (1-pathBits[i]); - // compute: key = bottom + 2*odd - let mut odd = builder.sub(one, targets.path_bits[i].target); - odd = builder.mul(is_last[i].target, odd); - odd = builder.mul(two, odd); - let key = builder.add(bottom,odd); - - // select left and right based on path_bit - let mut left = vec![]; - let mut right = vec![]; - for i in 0..NUM_HASH_OUT_ELTS { - left.push( builder.select(bit, sibling.elements[i], state.elements[i])); - right.push( builder.select(bit, state.elements[i], sibling.elements[i])); - } - - state = key_compress_circuit::(builder,left,right,key); - - i += 1; - } - - return state; - } - - /// takes the params from the targets struct - /// outputs the reconstructed merkle root - /// this one uses the mask bits - pub fn reconstruct_merkle_root_circuit_with_mask( - builder: &mut CircuitBuilder, - targets: &mut MerkleTreeTargets, - ) -> HashOutTarget { - let max_depth = targets.path_bits.len(); - let mut state: Vec = Vec::with_capacity(max_depth+1); - state.push(targets.leaf); - let zero = builder.zero(); - let one = builder.one(); - let two = builder.two(); - debug_assert_eq!(targets.path_bits.len(), targets.merkle_path.path.len()); - - // compute is_last - let mut is_last = vec![BoolTarget::new_unsafe(zero); max_depth + 1]; - is_last[max_depth] = BoolTarget::new_unsafe(one); // set isLast[max_depth] to 1 (true) - for i in (0..max_depth).rev() { - let eq_out = builder.is_equal(targets.path_bits[i].target , targets.last_bits[i].target); - is_last[i] = builder.and( is_last[i + 1] , eq_out); - } - - let mut i: usize = 0; - for (&bit, &sibling) in targets.path_bits.iter().zip(&targets.merkle_path.path) { - debug_assert_eq!(sibling.elements.len(), NUM_HASH_OUT_ELTS); - - let bottom = if i == 0 { - builder.constant(F::from_canonical_u64(KEY_BOTTOM_LAYER)) - } else { - builder.constant(F::from_canonical_u64(KEY_NONE)) - }; - - // compute: odd = isLast[i] * (1-pathBits[i]); - // compute: key = bottom + 2*odd - let mut odd = builder.sub(one, targets.path_bits[i].target); - odd = builder.mul(is_last[i].target, odd); - odd = builder.mul(two, odd); - let key = builder.add(bottom,odd); - - // select left and right based on path_bit - let mut left = vec![]; - let mut right = vec![]; - for j in 0..NUM_HASH_OUT_ELTS { - left.push( builder.select(bit, sibling.elements[j], state[i].elements[j])); - right.push( builder.select(bit, state[i].elements[j], sibling.elements[j])); - } - - state.push(key_compress_circuit::(builder,left,right,key)); - - i += 1; - } - - // println!("mask = {}, last = {}", targets.mask_bits.len(), targets.last_bits.len(), ); - - // another way to do this is to use builder.select - // but that might be less efficient & more constraints - let mut reconstructed_root = HashOutTarget::from_vec([builder.zero();4].to_vec()); - for k in 0..MAX_DEPTH { - let diff = builder.sub(targets.mask_bits[k].target, targets.mask_bits[k+1].target); - let mul_result = Self::mul_hash_out_target(builder,&diff,&mut state[k+1]); - Self::add_assign_hash_out_target(builder,&mut reconstructed_root, &mul_result); - } - - reconstructed_root - } - - /// helper fn to multiply a HashOutTarget by a Target - pub(crate) fn mul_hash_out_target(builder: &mut CircuitBuilder, t: &Target, hash_target: &mut HashOutTarget) -> HashOutTarget { - let mut mul_elements = vec![]; - for i in 0..NUM_HASH_OUT_ELTS { - mul_elements.push(builder.mul(hash_target.elements[i], *t)); - } - HashOutTarget::from_vec(mul_elements) - } - - /// helper fn to add AND assign a HashOutTarget (hot) to a mutable HashOutTarget (mut_hot) - pub(crate) fn add_assign_hash_out_target(builder: &mut CircuitBuilder, mut_hot: &mut HashOutTarget, hot: &HashOutTarget) { - for i in 0..NUM_HASH_OUT_ELTS { - mut_hot.elements[i] = (builder.add(mut_hot.elements[i], hot.elements[i])); - } - } -} - -// NOTE: for now these tests don't check the reconstructed root is equal to expected_root -// will be fixed later, but for that test check the prove_single_cell tests -#[cfg(test)] -mod tests { - use super::*; - use plonky2::field::types::Field; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use plonky2::iop::witness::PartialWitness; - - #[test] - fn test_build_circuit() -> Result<()> { - // circuit params - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type H = PoseidonHash; - - // Generate random leaf data - let nleaves = 10; // Number of leaves - let data = (0..nleaves) - .map(|i| GoldilocksField::from_canonical_u64(i)) - .collect::>(); - // Hash the data to obtain leaf hashes - let leaves: Vec> = data - .iter() - .map(|&element| { - // Hash each field element to get the leaf hash - PoseidonHash::hash_no_pad(&[element]) - }) - .collect(); - - //initialize the Merkle tree - let zero_hash = HashOut { - elements: [GoldilocksField::ZERO; 4], - }; - let tree = MerkleTree::::new(&leaves, zero_hash)?; - - // select leaf index to prove - let leaf_index: usize = 8; - - // get the Merkle proof for the selected leaf - let proof = tree.get_proof(leaf_index)?; - // sanity check: - let check = proof.verify(tree.layers[0][leaf_index],tree.root().unwrap()).unwrap(); - assert_eq!(check, true); - - // get the expected Merkle root - let expected_root = tree.root()?; - - // create the circuit - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut circuit_instance = MerkleTreeCircuit:: { - tree: tree.clone(), - // _phantom: PhantomData, - }; - let (mut targets, expected_root_target) = circuit_instance.build_circuit(&mut builder); - - // create a PartialWitness and assign - let mut pw = PartialWitness::new(); - circuit_instance.assign_witness(&mut pw, &mut targets, leaf_index)?; - - // build the circuit - let data = builder.build::(); - - // Prove the circuit with the assigned witness - let proof_with_pis = data.prove(pw)?; - - // verify the proof - let verifier_data = data.verifier_data(); - assert!( - verifier_data.verify(proof_with_pis).is_ok(), - "Merkle proof verification failed" - ); - - Ok(()) - } - - // same as test above but for all leaves - #[test] - fn test_verify_all_leaves() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type H = PoseidonHash; - - let nleaves = 10; // Number of leaves - let data = (0..nleaves) - .map(|i| GoldilocksField::from_canonical_u64(i as u64)) - .collect::>(); - // Hash the data to obtain leaf hashes - let leaves: Vec> = data - .iter() - .map(|&element| { - // Hash each field element to get the leaf hash - PoseidonHash::hash_no_pad(&[element]) - }) - .collect(); - - let zero_hash = HashOut { - elements: [GoldilocksField::ZERO; 4], - }; - let tree = MerkleTree::::new(&leaves, zero_hash)?; - - let expected_root = tree.root()?; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let mut circuit_instance = MerkleTreeCircuit:: { - tree: tree.clone(), - }; - let (mut targets, expected_root_target) = circuit_instance.build_circuit(&mut builder); - - let data = builder.build::(); - - for leaf_index in 0..nleaves { - let proof = tree.get_proof(leaf_index)?; - let check = proof.verify(tree.layers[0][leaf_index], expected_root)?; - assert!( - check, - "Merkle proof verification failed for leaf index {}", - leaf_index - ); - - let mut pw = PartialWitness::new(); - - circuit_instance.assign_witness(&mut pw, &mut targets, leaf_index)?; - - let proof_with_pis = data.prove(pw)?; - - let verifier_data = data.verifier_data(); - assert!( - verifier_data.verify(proof_with_pis).is_ok(), - "Merkle proof verification failed in circuit for leaf index {}", - leaf_index - ); - } - - Ok(()) - } -} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/circuits/sample_cells.rs b/codex-plonky2-circuits/src/circuits/sample_cells.rs index 95cca06..9adc341 100644 --- a/codex-plonky2-circuits/src/circuits/sample_cells.rs +++ b/codex-plonky2-circuits/src/circuits/sample_cells.rs @@ -11,25 +11,14 @@ use plonky2::hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS, RichF use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartialWitness, WitnessWrite, Witness}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, GenericHashOut}; use std::marker::PhantomData; -use itertools::Itertools; - -use crate::merkle_tree::merkle_safe::MerkleTree; -use plonky2::hash::poseidon::PoseidonHash; - -use crate::merkle_tree::merkle_safe::{MerkleProof, MerkleProofTarget}; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; - -use plonky2::plonk::config::PoseidonGoldilocksConfig; - use plonky2::hash::hashing::PlonkyPermutation; -use crate::circuits::prove_single_cell::{SingleCellTargets, SlotTreeCircuit}; -use crate::circuits::params::{BOT_DEPTH, DATASET_DEPTH, HF, MAX_DEPTH, N_FIELD_ELEMS_PER_CELL, N_SAMPLES, TESTING_SLOT_INDEX}; +use crate::circuits::params::HF; -use crate::circuits::safe_tree_circuit::{MerkleTreeCircuit, MerkleTreeTargets}; -use crate::circuits::utils::{bits_le_padded_to_usize, calculate_cell_index_bits}; +use crate::circuits::merkle_circuit::{MerkleTreeCircuit, MerkleTreeTargets, MerkleProofTarget}; +use crate::circuits::utils::{assign_hash_out_targets, bits_le_padded_to_usize, calculate_cell_index_bits}; // ------ Dataset Tree -------- ///dataset tree containing all slot trees @@ -38,175 +27,81 @@ pub struct DatasetTreeCircuit< F: RichField + Extendable + Poseidon2, const D: usize, > { - pub tree: MerkleTreeCircuit, // dataset tree - pub slot_trees: Vec>, // vec of slot trees -} - -/// Dataset Merkle proof struct, containing the dataset proof and N_SAMPLES proofs. -#[derive(Clone)] -pub struct DatasetMerkleProof { - pub slot_index: usize, - pub entropy: usize, - pub dataset_proof: MerkleProof, // proof for dataset level tree - pub slot_proofs: Vec>, // proofs for sampled slot, contains N_SAMPLES proofs + params: CircuitParams, + phantom_data: PhantomData, } impl< F: RichField + Extendable + Poseidon2, const D: usize, -> Default for DatasetTreeCircuit { - /// dataset tree with fake data, for testing only - fn default() -> Self { - let mut slot_trees = vec![]; - let n_slots = 1<::default()); - } - // get the roots or slot trees - let slot_roots = slot_trees.iter() - .map(|t| { - t.tree.tree.root().unwrap() - }) - .collect::>(); - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); +> DatasetTreeCircuit { + pub fn new(params: CircuitParams) -> Self{ Self{ - tree: MerkleTreeCircuit::{ tree:dataset_tree}, - slot_trees, + params, + phantom_data: Default::default(), } } } - -impl< - F: RichField + Extendable + Poseidon2, - const D: usize, -> DatasetTreeCircuit { - /// dataset tree with fake data, for testing only - /// create data for only the TESTING_SLOT_INDEX in params file - pub fn new_for_testing() -> Self { - let mut slot_trees = vec![]; - let n_slots = 1<{ - tree: MerkleTreeCircuit { - tree: MerkleTree::::new(&[zero.clone()], zero.clone()).unwrap(), - }, - block_trees: vec![], - cell_data: vec![], - }; - for i in 0..n_slots { - if(i == TESTING_SLOT_INDEX) { - slot_trees.push(SlotTreeCircuit::::new_for_testing()); - }else{ - slot_trees.push(zero_slot.clone()); - } - - } - // get the roots or slot trees - let slot_roots = slot_trees.iter() - .map(|t| { - t.tree.tree.root().unwrap() - }) - .collect::>(); - let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); - Self{ - tree: MerkleTreeCircuit::{ tree:dataset_tree}, - slot_trees, - } - } - - /// same as default but with supplied slot trees - pub fn new(slot_trees: Vec>) -> Self{ - // get the roots or slot trees - let slot_roots = slot_trees.iter() - .map(|t| { - t.tree.tree.root().unwrap() - }) - .collect::>(); - // zero hash - let zero = HashOut { - elements: [F::ZERO; 4], - }; - let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); - Self{ - tree: MerkleTreeCircuit::{ tree:dataset_tree}, - slot_trees, - } - } - - /// generates a dataset level proof for given slot index - /// just a regular merkle tree proof - pub fn get_proof(&self, index: usize) -> MerkleProof { - let dataset_proof = self.tree.tree.get_proof(index).unwrap(); - dataset_proof - } - - /// generates a proof for given slot index - /// also takes entropy so it can use it sample the slot - pub fn sample_slot(&self, index: usize, entropy: usize) -> DatasetMerkleProof { - let dataset_proof = self.tree.tree.get_proof(index).unwrap(); - let slot = &self.slot_trees[index]; - let slot_root = slot.tree.tree.root().unwrap(); - let mut slot_proofs = vec![]; - // get the index for cell from H(slot_root|counter|entropy) - for i in 0..N_SAMPLES { - let cell_index_bits = calculate_cell_index_bits(entropy, slot_root, i); - let cell_index = bits_le_padded_to_usize(&cell_index_bits); - slot_proofs.push(slot.get_proof(cell_index)); - } - - DatasetMerkleProof{ - slot_index: index, - entropy, - dataset_proof, - slot_proofs, - } - } - - // verify the sampling - non-circuit version - pub fn verify_sampling(&self, proof: DatasetMerkleProof) -> Result{ - let slot = &self.slot_trees[proof.slot_index]; - let slot_root = slot.tree.tree.root().unwrap(); - // check dataset level proof - let d_res = proof.dataset_proof.verify(slot_root,self.tree.tree.root().unwrap()); - if(d_res.unwrap() == false){ - return Ok(false); - } - // sanity check - assert_eq!(N_SAMPLES, proof.slot_proofs.len()); - // get the index for cell from H(slot_root|counter|entropy) - for i in 0..N_SAMPLES { - let cell_index_bits = calculate_cell_index_bits(proof.entropy, slot_root, i); - let cell_index = bits_le_padded_to_usize(&cell_index_bits); - //check the cell_index is the same as one in the proof - assert_eq!(cell_index, proof.slot_proofs[i].index); - let s_res = slot.verify_cell_proof(proof.slot_proofs[i].clone(),slot_root); - if(s_res.unwrap() == false){ - return Ok(false); - } - } - Ok(true) - } +// params used for the circuits +// should be defined prior to building the circuit +#[derive(Clone)] +pub struct CircuitParams{ + pub max_depth: usize, + pub max_log2_n_slots: usize, + pub block_tree_depth: usize, + pub n_field_elems_per_cell: usize, + pub n_samples: usize, } #[derive(Clone)] -pub struct DatasetTargets{ - pub dataset_proof: MerkleProofTarget, // proof that slot_root in dataset tree +pub struct SampleTargets { + + pub entropy: HashOutTarget, pub dataset_root: HashOutTarget, + pub slot_index: Target, + + pub slot_root: HashOutTarget, + pub n_cells_per_slot: Target, + pub n_slots_per_dataset: Target, + + pub slot_proof: MerkleProofTarget, // proof that slot_root in dataset tree pub cell_data: Vec>, - pub entropy: HashOutTarget, - pub slot_index: Target, - pub slot_root: HashOutTarget, - pub slot_proofs: Vec, + pub merkle_paths: Vec, +} +#[derive(Clone)] +pub struct SampleCircuitInput< + F: RichField + Extendable + Poseidon2, + const D: usize, +>{ + pub entropy: Vec, + pub dataset_root: HashOut, + pub slot_index: F, + + pub slot_root: HashOut, + pub n_cells_per_slot: F, + pub n_slots_per_dataset: F, + + pub slot_proof: Vec>, // proof that slot_root in dataset tree + + pub cell_data: Vec>, + pub merkle_paths: Vec>>, + +} + +#[derive(Clone)] +pub struct MerklePath< + F: RichField + Extendable + Poseidon2, + const D: usize, +> { + path: Vec> +} + +#[derive(Clone)] +pub struct CellTarget { + pub data: Vec } //------- circuit impl -------- @@ -218,9 +113,17 @@ impl< // in-circuit sampling // TODO: make it more modular pub fn sample_slot_circuit( - &mut self, + &self, builder: &mut CircuitBuilder::, - )-> DatasetTargets { + )-> SampleTargets { + // circuit params + let CircuitParams { + max_depth, + max_log2_n_slots, + block_tree_depth, + n_field_elems_per_cell, + n_samples, + } = self.params; // constants let zero = builder.zero(); @@ -229,27 +132,24 @@ impl< // ***** prove slot root is in dataset tree ********* - // Retrieve dataset tree depth - let d_depth = DATASET_DEPTH; - // Create virtual target for slot root and index let slot_root = builder.add_virtual_hash(); let slot_index = builder.add_virtual_target(); // dataset path bits (binary decomposition of leaf_index) - let d_path_bits = builder.split_le(slot_index,d_depth); + let d_path_bits = builder.split_le(slot_index,max_log2_n_slots); + + // create virtual target for n_slots_per_dataset + let n_slots_per_dataset = builder.add_virtual_target(); // dataset last bits (binary decomposition of last_index = nleaves - 1) - let depth_target = builder.constant(F::from_canonical_u64(d_depth as u64)); - let mut d_last_index = builder.exp(two,depth_target,d_depth); - d_last_index = builder.sub(d_last_index, one); - let d_last_bits = builder.split_le(d_last_index,d_depth); - - let d_mask_bits = builder.split_le(d_last_index,d_depth+1); + let dataset_last_index = builder.sub(n_slots_per_dataset, one); + let d_last_bits = builder.split_le(dataset_last_index,max_log2_n_slots); + let d_mask_bits = builder.split_le(dataset_last_index,max_log2_n_slots+1); // dataset Merkle path (sibling hashes from leaf to root) let d_merkle_path = MerkleProofTarget { - path: (0..d_depth).map(|_| builder.add_virtual_hash()).collect(), + path: (0..max_log2_n_slots).map(|_| builder.add_virtual_hash()).collect(), }; // create MerkleTreeTargets struct @@ -263,7 +163,7 @@ impl< // dataset reconstructed root let d_reconstructed_root = - MerkleTreeCircuit::::reconstruct_merkle_root_circuit(builder, &mut d_targets); + MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut d_targets, max_log2_n_slots); // expected Merkle root let d_expected_root = builder.add_virtual_hash(); @@ -279,24 +179,23 @@ impl< let mut slot_sample_proofs = vec![]; let entropy_target = builder.add_virtual_hash(); - //TODO: this can probably be optimized by supplying nCellsPerSlot as input to the circuit - let b_depth_target = builder.constant(F::from_canonical_u64(BOT_DEPTH as u64)); - let mut b_last_index = builder.exp(two,b_depth_target,BOT_DEPTH); - b_last_index = builder.sub(b_last_index, one); - let b_last_bits = builder.split_le(b_last_index,BOT_DEPTH); + // virtual target for n_cells_per_slot + let n_cells_per_slot = builder.add_virtual_target(); - let b_mask_bits = builder.split_le(b_last_index,BOT_DEPTH+1); + let slot_last_index = builder.sub(n_cells_per_slot, one); + let mut b_last_bits = builder.split_le(slot_last_index,max_depth); + let mut b_mask_bits = builder.split_le(slot_last_index,max_depth); - let s_depth_target = builder.constant(F::from_canonical_u64(MAX_DEPTH as u64)); - let mut s_last_index = builder.exp(two,s_depth_target,MAX_DEPTH); - s_last_index = builder.sub(s_last_index, one); - let s_last_bits = builder.split_le(s_last_index,MAX_DEPTH); - let s_mask_bits = builder.split_le(b_last_index,BOT_DEPTH+1); + let mut s_last_bits = b_last_bits.split_off(block_tree_depth); + let mut s_mask_bits = b_mask_bits.split_off(block_tree_depth); - for i in 0..N_SAMPLES{ + b_mask_bits.push(BoolTarget::new_unsafe(zero.clone())); + s_mask_bits.push(BoolTarget::new_unsafe(zero.clone())); + + for i in 0..n_samples{ // cell data targets - let mut data_i = (0..N_FIELD_ELEMS_PER_CELL).map(|_| builder.add_virtual_target()).collect::>(); + let mut data_i = (0..n_field_elems_per_cell).map(|_| builder.add_virtual_target()).collect::>(); let mut hash_inputs:Vec= Vec::new(); hash_inputs.extend_from_slice(&data_i); @@ -312,15 +211,15 @@ impl< } } // paths - let mut b_path_bits = Self::calculate_cell_index_bits(builder, &entropy_target, &d_targets.leaf, &ctr); - let mut s_path_bits = b_path_bits.split_off(BOT_DEPTH); + let mut b_path_bits = self.calculate_cell_index_bits(builder, &entropy_target, &d_targets.leaf, &ctr); + let mut s_path_bits = b_path_bits.split_off(block_tree_depth); let mut b_merkle_path = MerkleProofTarget { - path: (0..BOT_DEPTH).map(|_| builder.add_virtual_hash()).collect(), + path: (0..block_tree_depth).map(|_| builder.add_virtual_hash()).collect(), }; let mut s_merkle_path = MerkleProofTarget { - path: (0..(MAX_DEPTH - BOT_DEPTH)).map(|_| builder.add_virtual_hash()).collect(), + path: (0..(max_depth - block_tree_depth)).map(|_| builder.add_virtual_hash()).collect(), }; let mut block_targets = MerkleTreeTargets { @@ -332,7 +231,7 @@ impl< }; // reconstruct block root - let b_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut block_targets); + let b_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut block_targets, block_tree_depth); let mut slot_targets = MerkleTreeTargets { leaf: b_root, @@ -343,7 +242,7 @@ impl< }; // reconstruct slot root with block root as leaf - let slot_reconstructed_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut slot_targets); + let slot_reconstructed_root = MerkleTreeCircuit::::reconstruct_merkle_root_circuit_with_mask(builder, &mut slot_targets, max_depth-block_tree_depth); // check equality with expected root for i in 0..NUM_HASH_OUT_ELTS { @@ -361,204 +260,84 @@ impl< } - DatasetTargets{ - dataset_proof: d_targets.merkle_path, - dataset_root: d_expected_root, - cell_data: data_targets, + SampleTargets { entropy: entropy_target, + dataset_root: d_expected_root, slot_index, slot_root: d_targets.leaf, - slot_proofs: slot_sample_proofs, + n_cells_per_slot, + n_slots_per_dataset, + slot_proof: d_targets.merkle_path, + cell_data: data_targets, + merkle_paths: slot_sample_proofs, } } - pub fn calculate_cell_index_bits(builder: &mut CircuitBuilder::, entropy: &HashOutTarget, slot_root: &HashOutTarget, ctr: &HashOutTarget) -> Vec { + pub fn calculate_cell_index_bits(&self, builder: &mut CircuitBuilder::, entropy: &HashOutTarget, slot_root: &HashOutTarget, ctr: &HashOutTarget) -> Vec { let mut hash_inputs:Vec= Vec::new(); hash_inputs.extend_from_slice(&entropy.elements); hash_inputs.extend_from_slice(&slot_root.elements); hash_inputs.extend_from_slice(&ctr.elements); let hash_out = builder.hash_n_to_hash_no_pad::(hash_inputs); - let cell_index_bits = builder.low_bits(hash_out.elements[0], MAX_DEPTH, 64); + let cell_index_bits = builder.low_bits(hash_out.elements[0], self.params.max_depth, 64); cell_index_bits } pub fn sample_slot_assign_witness( - &mut self, + &self, pw: &mut PartialWitness, - targets: &mut DatasetTargets, - slot_index:usize, - entropy:usize, + targets: &mut SampleTargets, + witnesses: SampleCircuitInput, ){ - // dataset proof - let d_proof = self.tree.tree.get_proof(slot_index).unwrap(); + // circuit params + let CircuitParams { + max_depth, + max_log2_n_slots, + block_tree_depth, + n_field_elems_per_cell, + n_samples, + } = self.params; + + // assign n_cells_per_slot + pw.set_target(targets.n_cells_per_slot, witnesses.n_cells_per_slot); + + // assign n_slots_per_dataset + pw.set_target(targets.n_slots_per_dataset, witnesses.n_slots_per_dataset); // assign dataset proof - for (i, sibling_hash) in d_proof.path.iter().enumerate() { - // TODO: fix this HashOutTarget later - let sibling_hash_out = sibling_hash.to_vec(); - for j in 0..sibling_hash_out.len() { - pw.set_target(targets.dataset_proof.path[i].elements[j], sibling_hash_out[j]); - } + for (i, sibling_hash) in witnesses.slot_proof.iter().enumerate() { + pw.set_hash_target(targets.slot_proof.path[i], *sibling_hash); } // assign slot index - pw.set_target(targets.slot_index, F::from_canonical_u64(slot_index as u64)); + pw.set_target(targets.slot_index, witnesses.slot_index); // assign the expected Merkle root of dataset to the target - let expected_root = self.tree.tree.root().unwrap(); - let expected_root_hash_out = expected_root.to_vec(); - for j in 0..expected_root_hash_out.len() { - pw.set_target(targets.dataset_root.elements[j], expected_root_hash_out[j]); - } + pw.set_hash_target(targets.dataset_root, witnesses.dataset_root); - // the sampled slot - let slot = &self.slot_trees[slot_index]; - let slot_root = slot.tree.tree.root().unwrap(); - pw.set_hash_target(targets.slot_root, slot_root); + // assign the sampled slot + pw.set_hash_target(targets.slot_root, witnesses.slot_root); // assign entropy - for (i, element) in targets.entropy.elements.iter().enumerate() { - if(i==0) { - pw.set_target(*element, F::from_canonical_u64(entropy as u64)); - }else { - pw.set_target(*element, F::from_canonical_u64(0)); - } - } - // pw.set_target(targets.entropy, F::from_canonical_u64(entropy as u64)); + assign_hash_out_targets(pw, &targets.entropy.elements, &witnesses.entropy); // do the sample N times - for i in 0..N_SAMPLES { - let cell_index_bits = calculate_cell_index_bits(entropy,slot_root,i+1); + for i in 0..n_samples { + let cell_index_bits = calculate_cell_index_bits(&witnesses.entropy,witnesses.slot_root,i+1,max_depth); let cell_index = bits_le_padded_to_usize(&cell_index_bits); // assign cell data - let leaf = &slot.cell_data[cell_index]; - for j in 0..leaf.len(){ + let leaf = witnesses.cell_data[i].clone(); + for j in 0..n_field_elems_per_cell{ pw.set_target(targets.cell_data[i][j], leaf[j]); } + // assign proof for that cell - let cell_proof = slot.get_proof(cell_index); - for (k, sibling_hash) in cell_proof.path.iter().enumerate() { - let sibling_hash_out = sibling_hash.to_vec(); - for j in 0..sibling_hash_out.len() { - pw.set_target(targets.slot_proofs[i].path[k].elements[j], sibling_hash_out[j]); - } + let cell_proof = witnesses.merkle_paths[i].clone(); + for k in 0..max_depth { + pw.set_hash_target(targets.merkle_paths[i].path[k], cell_proof[k]) } } } } - -#[cfg(test)] -mod tests { - use std::time::Instant; - use super::*; - use plonky2::plonk::circuit_data::CircuitConfig; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use plonky2::iop::witness::PartialWitness; - - //types for tests - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type H = PoseidonHash; - - #[test] - fn test_sample_cells() { - let dataset_t = DatasetTreeCircuit::::default(); - let slot_index = 2; - let entropy = 123; - let proof = dataset_t.sample_slot(slot_index,entropy); - let res = dataset_t.verify_sampling(proof).unwrap(); - assert_eq!(res, true); - } - - // sample cells with full set of fake data - // this test takes too long, see next test - #[test] - fn test_sample_cells_circuit() -> Result<()> { - - let mut dataset_t = DatasetTreeCircuit::::default(); - - let slot_index = 2; - let entropy = 123; - - // sanity check - let proof = dataset_t.sample_slot(slot_index,entropy); - let slot_root = dataset_t.slot_trees[slot_index].tree.tree.root().unwrap(); - let res = dataset_t.verify_sampling(proof).unwrap(); - assert_eq!(res, true); - - // create the circuit - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - - let mut targets = dataset_t.sample_slot_circuit(&mut builder); - - // create a PartialWitness and assign - let mut pw = PartialWitness::new(); - dataset_t.sample_slot_assign_witness(&mut pw, &mut targets,slot_index,entropy); - - // build the circuit - let data = builder.build::(); - println!("circuit size = {:?}", data.common.degree_bits()); - - // Prove the circuit with the assigned witness - let start_time = Instant::now(); - let proof_with_pis = data.prove(pw)?; - println!("prove_time = {:?}", start_time.elapsed()); - - // verify the proof - let verifier_data = data.verifier_data(); - assert!( - verifier_data.verify(proof_with_pis).is_ok(), - "Merkle proof verification failed" - ); - - Ok(()) - } - - // same as above but with fake data for the specific slot to be sampled - #[test] - fn test_sample_cells_circuit_from_selected_slot() -> Result<()> { - - let mut dataset_t = DatasetTreeCircuit::::new_for_testing(); - - let slot_index = TESTING_SLOT_INDEX; - let entropy = 123; - - // sanity check - let proof = dataset_t.sample_slot(slot_index,entropy); - let slot_root = dataset_t.slot_trees[slot_index].tree.tree.root().unwrap(); - let res = dataset_t.verify_sampling(proof).unwrap(); - assert_eq!(res, true); - - // create the circuit - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - - let mut targets = dataset_t.sample_slot_circuit(&mut builder); - - // create a PartialWitness and assign - let mut pw = PartialWitness::new(); - dataset_t.sample_slot_assign_witness(&mut pw, &mut targets,slot_index,entropy); - - // build the circuit - let data = builder.build::(); - println!("circuit size = {:?}", data.common.degree_bits()); - - // Prove the circuit with the assigned witness - let start_time = Instant::now(); - let proof_with_pis = data.prove(pw)?; - println!("prove_time = {:?}", start_time.elapsed()); - - // verify the proof - let verifier_data = data.verifier_data(); - assert!( - verifier_data.verify(proof_with_pis).is_ok(), - "Merkle proof verification failed" - ); - - Ok(()) - } -} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/circuits/utils.rs b/codex-plonky2-circuits/src/circuits/utils.rs index 6721c98..efe613c 100644 --- a/codex-plonky2-circuits/src/circuits/utils.rs +++ b/codex-plonky2-circuits/src/circuits/utils.rs @@ -1,12 +1,14 @@ -use plonky2::hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS, RichField}; -use plonky2::iop::witness::PartialWitness; +use plonky2::hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS, RichField}; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_data::{CircuitData, VerifierCircuitData}; -use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use plonky2::plonk::proof::{Proof, ProofWithPublicInputs}; use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; -use crate::circuits::params::{HF, MAX_DEPTH}; +use crate::circuits::params::HF; use anyhow::Result; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::plonk::circuit_builder::CircuitBuilder; // --------- helper functions --------- @@ -23,30 +25,28 @@ pub(crate) fn usize_to_bits_le_padded(index: usize, bit_length: usize) -> Vec(entropy: usize, slot_root: HashOut, ctr: usize) -> Vec { - let entropy_field = F::from_canonical_u64(entropy as u64); - let mut entropy_as_digest = HashOut::::ZERO; - entropy_as_digest.elements[0] = entropy_field; +pub(crate) fn calculate_cell_index_bits(entropy: &Vec, slot_root: HashOut, ctr: usize, depth: usize) -> Vec { let ctr_field = F::from_canonical_u64(ctr as u64); let mut ctr_as_digest = HashOut::::ZERO; ctr_as_digest.elements[0] = ctr_field; let mut hash_inputs = Vec::new(); - hash_inputs.extend_from_slice(&entropy_as_digest.elements); + hash_inputs.extend_from_slice(&entropy); hash_inputs.extend_from_slice(&slot_root.elements); hash_inputs.extend_from_slice(&ctr_as_digest.elements); let hash_output = HF::hash_no_pad(&hash_inputs); let cell_index_bytes = hash_output.elements[0].to_canonical_u64(); - // let p_bits = take_n_bits_from_bytes(&p_bytes, MAX_DEPTH); - let cell_index_bits = usize_to_bits_le_padded(cell_index_bytes as usize, MAX_DEPTH); + let cell_index_bits = usize_to_bits_le_padded(cell_index_bytes as usize, depth); cell_index_bits } + pub(crate) fn take_n_bits_from_bytes(bytes: &[u8], n: usize) -> Vec { bytes.iter() .flat_map(|byte| (0..8u8).map(move |i| (byte >> i) & 1 == 1)) .take(n) .collect() } + /// Converts a vector of bits (LSB first) into an index (usize). pub(crate) fn bits_le_padded_to_usize(bits: &[bool]) -> usize { bits.iter().enumerate().fold(0usize, |acc, (i, &bit)| { @@ -87,4 +87,54 @@ pub fn verify< proof, public_inputs, }) +} + +/// assign a vec of bool values to a vec of BoolTargets +pub(crate) fn assign_bool_targets< + F: RichField + Extendable + Poseidon2, + const D: usize, +>( + pw: &mut PartialWitness, + bool_targets: &Vec, + bools: Vec, +){ + for (i, bit) in bools.iter().enumerate() { + pw.set_bool_target(bool_targets[i], *bit); + } +} + +/// assign a vec of field elems to hash out target elements +pub(crate) fn assign_hash_out_targets< + F: RichField + Extendable + Poseidon2, + const D: usize, +>( + pw: &mut PartialWitness, + hash_out_elements_targets: &[Target], + hash_out_elements: &[F], +){ + for j in 0..NUM_HASH_OUT_ELTS { + pw.set_target(hash_out_elements_targets[j], hash_out_elements[j]); + } +} + +/// helper fn to multiply a HashOutTarget by a Target +pub(crate) fn mul_hash_out_target< + F: RichField + Extendable + Poseidon2, + const D: usize, +>(builder: &mut CircuitBuilder, t: &Target, hash_target: &mut HashOutTarget) -> HashOutTarget { + let mut mul_elements = vec![]; + for i in 0..NUM_HASH_OUT_ELTS { + mul_elements.push(builder.mul(hash_target.elements[i], *t)); + } + HashOutTarget::from_vec(mul_elements) +} + +/// helper fn to add AND assign a HashOutTarget (hot) to a mutable HashOutTarget (mut_hot) +pub(crate) fn add_assign_hash_out_target< + F: RichField + Extendable + Poseidon2, + const D: usize, +>(builder: &mut CircuitBuilder, mut_hot: &mut HashOutTarget, hot: &HashOutTarget) { + for i in 0..NUM_HASH_OUT_ELTS { + mut_hot.elements[i] = (builder.add(mut_hot.elements[i], hot.elements[i])); + } } \ No newline at end of file diff --git a/codex-plonky2-circuits/src/lib.rs b/codex-plonky2-circuits/src/lib.rs index 1034758..d52d3ad 100644 --- a/codex-plonky2-circuits/src/lib.rs +++ b/codex-plonky2-circuits/src/lib.rs @@ -1,2 +1,4 @@ pub mod circuits; -pub mod merkle_tree; \ No newline at end of file +pub mod merkle_tree; +pub mod proof_input; +pub mod tests; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/merkle_tree/capped_tree.rs b/codex-plonky2-circuits/src/merkle_tree/capped_tree.rs deleted file mode 100644 index 7985c78..0000000 --- a/codex-plonky2-circuits/src/merkle_tree/capped_tree.rs +++ /dev/null @@ -1,378 +0,0 @@ -// An adapted implementation of Merkle tree -// based on the original plonky2 merkle tree implementation - -use core::mem::MaybeUninit; -use core::slice; -use anyhow::{ensure, Result}; -use plonky2_maybe_rayon::*; -use serde::{Deserialize, Serialize}; - -use plonky2::hash::hash_types::{HashOutTarget, RichField}; -use plonky2::plonk::config::{GenericHashOut, Hasher}; -use plonky2::util::log2_strict; - -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -#[serde(bound = "")] -pub struct MerkleCap>(pub Vec); - -impl> Default for MerkleCap { - fn default() -> Self { - Self(Vec::new()) - } -} - -impl> MerkleCap { - pub fn len(&self) -> usize { - self.0.len() - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn height(&self) -> usize { - log2_strict(self.len()) - } - - pub fn flatten(&self) -> Vec { - self.0.iter().flat_map(|&h| h.to_vec()).collect() - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct MerkleTree> { - pub leaves: Vec>, - - pub digests: Vec, - - pub cap: MerkleCap, -} - -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -#[serde(bound = "")] -pub struct MerkleProof> { - /// The Merkle digest of each sibling subtree, staying from the bottommost layer. - pub siblings: Vec, -} - -impl> MerkleProof { - pub fn len(&self) -> usize { - self.siblings.len() - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct MerkleProofTarget { - /// The Merkle digest of each sibling subtree, staying from the bottommost layer. - pub siblings: Vec, -} - -impl> Default for MerkleTree { - fn default() -> Self { - Self { - leaves: Vec::new(), - digests: Vec::new(), - cap: MerkleCap::default(), - } - } -} - -pub(crate) fn capacity_up_to_mut(v: &mut Vec, len: usize) -> &mut [MaybeUninit] { - assert!(v.capacity() >= len); - let v_ptr = v.as_mut_ptr().cast::>(); - unsafe { - slice::from_raw_parts_mut(v_ptr, len) - } -} - -pub(crate) fn fill_subtree>( - digests_buf: &mut [MaybeUninit], - leaves: &[Vec], -) -> H::Hash { - assert_eq!(leaves.len(), digests_buf.len() / 2 + 1); - if digests_buf.is_empty() { - H::hash_or_noop(&leaves[0]) - } else { - let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2); - let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap(); - let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap(); - - let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2); - - let (left_digest, right_digest) = plonky2_maybe_rayon::join( - || fill_subtree::(left_digests_buf, left_leaves), - || fill_subtree::(right_digests_buf, right_leaves), - ); - - left_digest_mem.write(left_digest); - right_digest_mem.write(right_digest); - H::two_to_one(left_digest, right_digest) - } -} - -pub(crate) fn fill_digests_buf>( - digests_buf: &mut [MaybeUninit], - cap_buf: &mut [MaybeUninit], - leaves: &[Vec], - cap_height: usize, -) { - - if digests_buf.is_empty() { - debug_assert_eq!(cap_buf.len(), leaves.len()); - cap_buf - .par_iter_mut() - .zip(leaves) - .for_each(|(cap_buf, leaf)| { - cap_buf.write(H::hash_or_noop(leaf)); - }); - return; - } - - let subtree_digests_len = digests_buf.len() >> cap_height; - let subtree_leaves_len = leaves.len() >> cap_height; - let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len); - let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len); - assert_eq!(digests_chunks.len(), cap_buf.len()); - assert_eq!(digests_chunks.len(), leaves_chunks.len()); - digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each( - |((subtree_digests, subtree_cap), subtree_leaves)| { - - subtree_cap.write(fill_subtree::(subtree_digests, subtree_leaves)); - }, - ); -} - -pub(crate) fn merkle_tree_prove>( - leaf_index: usize, - leaves_len: usize, - cap_height: usize, - digests: &[H::Hash], -) -> Vec { - let num_layers = log2_strict(leaves_len) - cap_height; - debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0); - - let digest_len = 2 * (leaves_len - (1 << cap_height)); - assert_eq!(digest_len, digests.len()); - - let digest_tree: &[H::Hash] = { - let tree_index = leaf_index >> num_layers; - let tree_len = digest_len >> cap_height; - &digests[tree_len * tree_index..tree_len * (tree_index + 1)] - }; - - // Mask out high bits to get the index within the sub-tree. - let mut pair_index = leaf_index & ((1 << num_layers) - 1); - (0..num_layers) - .map(|i| { - let parity = pair_index & 1; - pair_index >>= 1; - - // The layers' data is interleaved as follows: - // [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...]. - // Each of the above is a pair of siblings. - // `pair_index` is the index of the pair within layer `i`. - // The index of that the pair within `digests` is - // `pair_index * 2 ** (i + 1) + (2 ** i - 1)`. - let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1; - // We have an index for the _pair_, but we want the index of the _sibling_. - // Double the pair index to get the index of the left sibling. Conditionally add `1` - // if we are to retrieve the right sibling. - let sibling_index = 2 * siblings_index + (1 - parity); - digest_tree[sibling_index] - }) - .collect() -} - -impl> MerkleTree { - pub fn new(leaves: Vec>, cap_height: usize) -> Self { - let log2_leaves_len = log2_strict(leaves.len()); - assert!( - cap_height <= log2_leaves_len, - "cap_height={} should be at most log2(leaves.len())={}", - cap_height, - log2_leaves_len - ); - - let num_digests = 2 * (leaves.len() - (1 << cap_height)); - let mut digests = Vec::with_capacity(num_digests); - - let len_cap = 1 << cap_height; - let mut cap = Vec::with_capacity(len_cap); - - let digests_buf = capacity_up_to_mut(&mut digests, num_digests); - let cap_buf = capacity_up_to_mut(&mut cap, len_cap); - fill_digests_buf::(digests_buf, cap_buf, &leaves[..], cap_height); - - unsafe { - // SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to - // `num_digests` and `len_cap`, resp. - digests.set_len(num_digests); - cap.set_len(len_cap); - } - - Self { - leaves, - digests, - cap: MerkleCap(cap), - } - } - - pub fn get(&self, i: usize) -> &[F] { - &self.leaves[i] - } - - // Create a Merkle proof from a leaf index. - pub fn prove(&self, leaf_index: usize) -> MerkleProof { - let cap_height = log2_strict(self.cap.len()); - let siblings = - merkle_tree_prove::(leaf_index, self.leaves.len(), cap_height, &self.digests); - - MerkleProof { siblings } - } -} - -/// Verifies that the given leaf data is present at the given index in the Merkle tree with the -/// given root. -pub fn verify_merkle_proof>( - leaf_data: Vec, - leaf_index: usize, - merkle_root: H::Hash, - proof: &MerkleProof, -) -> Result<()> { - let merkle_cap = MerkleCap(vec![merkle_root]); - verify_merkle_proof_to_cap(leaf_data, leaf_index, &merkle_cap, proof) -} - -/// Verifies that the given leaf data is present at the given index in the Merkle tree with the -/// given cap. -pub fn verify_merkle_proof_to_cap>( - leaf_data: Vec, - leaf_index: usize, - merkle_cap: &MerkleCap, - proof: &MerkleProof, -) -> Result<()> { - verify_batch_merkle_proof_to_cap( - &[leaf_data.clone()], - &[proof.siblings.len()], - leaf_index, - merkle_cap, - proof, - ) -} - -/// Verifies that the given leaf data is present at the given index in the Field Merkle tree with the -/// given cap. -pub fn verify_batch_merkle_proof_to_cap>( - leaf_data: &[Vec], - leaf_heights: &[usize], - mut leaf_index: usize, - merkle_cap: &MerkleCap, - proof: &MerkleProof, -) -> Result<()> { - assert_eq!(leaf_data.len(), leaf_heights.len()); - let mut current_digest = H::hash_or_noop(&leaf_data[0]); - let mut current_height = leaf_heights[0]; - let mut leaf_data_index = 1; - for &sibling_digest in &proof.siblings { - let bit = leaf_index & 1; - leaf_index >>= 1; - current_digest = if bit == 1 { - H::two_to_one(sibling_digest, current_digest) - } else { - H::two_to_one(current_digest, sibling_digest) - }; - current_height -= 1; - - if leaf_data_index < leaf_heights.len() && current_height == leaf_heights[leaf_data_index] { - let mut new_leaves = current_digest.to_vec(); - new_leaves.extend_from_slice(&leaf_data[leaf_data_index]); - current_digest = H::hash_or_noop(&new_leaves); - leaf_data_index += 1; - } - } - assert_eq!(leaf_data_index, leaf_data.len()); - ensure!( - current_digest == merkle_cap.0[leaf_index], - "Invalid Merkle proof." - ); - - Ok(()) -} - -#[cfg(test)] -pub(crate) mod tests { - use anyhow::Result; - - use super::*; - use plonky2::field::extension::Extendable; - use crate::merkle_tree::capped_tree::verify_merkle_proof_to_cap; - use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - pub(crate) fn random_data(n: usize, k: usize) -> Vec> { - (0..n).map(|_| F::rand_vec(k)).collect() - } - - fn verify_all_leaves< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - leaves: Vec>, - cap_height: usize, - ) -> Result<()> { - let tree = MerkleTree::::new(leaves.clone(), cap_height); - for (i, leaf) in leaves.into_iter().enumerate() { - let proof = tree.prove(i); - verify_merkle_proof_to_cap(leaf, i, &tree.cap, &proof)?; - } - Ok(()) - } - - #[test] - #[should_panic] - fn test_cap_height_too_big() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let log_n = 8; - let cap_height = log_n + 1; // Should panic if `cap_height > len_n`. - - let leaves = random_data::(1 << log_n, 7); - let _ = MerkleTree::>::Hasher>::new(leaves, cap_height); - } - - #[test] - fn test_cap_height_eq_log2_len() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let log_n = 8; - let n = 1 << log_n; - let leaves = random_data::(n, 7); - - verify_all_leaves::(leaves, log_n)?; - - Ok(()) - } - - #[test] - fn test_merkle_trees() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let log_n = 8; - let n = 1 << log_n; - let leaves = random_data::(n, 7); - - verify_all_leaves::(leaves, 1)?; - - Ok(()) - } -} diff --git a/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs b/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs index b23fa20..b9db4f0 100644 --- a/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs +++ b/codex-plonky2-circuits/src/merkle_tree/merkle_safe.rs @@ -5,7 +5,7 @@ use std::marker::PhantomData; use anyhow::{ensure, Result}; use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField}; +use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2::hash::poseidon::PoseidonHash; use plonky2::plonk::config::Hasher; use std::ops::Shr; @@ -137,12 +137,6 @@ pub struct MerkleProof { pub zero: HashOut, } -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct MerkleProofTarget { - /// The Merkle digest of each sibling subtree, staying from the bottommost layer. - pub path: Vec, -} - impl MerkleProof { /// Reconstructs the root hash from the proof and the given leaf. pub fn reconstruct_root(&self, leaf: HashOut) -> Result> { diff --git a/codex-plonky2-circuits/src/merkle_tree/mod.rs b/codex-plonky2-circuits/src/merkle_tree/mod.rs index e977a83..53eca8e 100644 --- a/codex-plonky2-circuits/src/merkle_tree/mod.rs +++ b/codex-plonky2-circuits/src/merkle_tree/mod.rs @@ -1,2 +1 @@ -pub mod capped_tree; pub mod merkle_safe; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/proof_input/gen_input.rs b/codex-plonky2-circuits/src/proof_input/gen_input.rs new file mode 100644 index 0000000..c3b2685 --- /dev/null +++ b/codex-plonky2-circuits/src/proof_input/gen_input.rs @@ -0,0 +1,440 @@ +use anyhow::Result; +use plonky2::hash::hash_types::{HashOut, RichField}; +use plonky2::plonk::config::{GenericConfig, Hasher}; +use plonky2_field::extension::Extendable; +use plonky2_field::types::Field; +use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; +use crate::circuits::params::HF; +use crate::proof_input::test_params::{BOT_DEPTH, DATASET_DEPTH, MAX_DEPTH, N_BLOCKS, N_CELLS, N_CELLS_IN_BLOCKS, N_FIELD_ELEMS_PER_CELL, N_SAMPLES, TESTING_SLOT_INDEX}; +use crate::circuits::utils::{bits_le_padded_to_usize, calculate_cell_index_bits, usize_to_bits_le_padded}; +use crate::merkle_tree::merkle_safe::{MerkleProof, MerkleTree}; + +#[derive(Clone)] +pub struct SlotTree< + F: RichField + Extendable + Poseidon2, + const D: usize, +> { + pub tree: MerkleTree, // slot tree + pub block_trees: Vec>, // vec of block trees + pub cell_data: Vec>, // cell data as field elements +} + +#[derive(Clone)] +pub struct Cell< + F: RichField + Extendable + Poseidon2, + const D: usize, +> { + pub data: Vec, // cell data as field elements +} + +impl< + F: RichField + Extendable + Poseidon2, + const D: usize, +> Default for Cell { + /// default cell with random data + fn default() -> Self { + let data = (0..N_FIELD_ELEMS_PER_CELL) + .map(|j| F::rand()) + .collect::>(); + Self{ + data, + } + } +} + +impl< + F: RichField + Extendable + Poseidon2, + const D: usize, +> Default for SlotTree { + /// slot tree with fake data, for testing only + fn default() -> Self { + // generate fake cell data + let mut cell_data = (0..N_CELLS) + .map(|i|{ + Cell::::default() + }) + .collect::>(); + Self::new(cell_data) + } +} + +impl< + F: RichField + Extendable + Poseidon2, + const D: usize, +> SlotTree { + /// Slot tree with fake data, for testing only + pub fn new_for_testing(cells: Vec>) -> Self { + // Hash the cell data block to create leaves for one block + let leaves_block: Vec> = cells + .iter() + .map(|element| { + HF::hash_no_pad(&element.data) + }) + .collect(); + + // Zero hash + let zero = HashOut { + elements: [F::ZERO; 4], + }; + + // Create a block tree from the leaves of one block + let b_tree = Self::get_block_tree(&leaves_block); + + // Now replicate this block tree for all N_BLOCKS blocks + let block_trees = vec![b_tree; N_BLOCKS]; + + // Get the roots of block trees + let block_roots = block_trees + .iter() + .map(|t| t.root().unwrap()) + .collect::>(); + + // Create the slot tree from block roots + let slot_tree = MerkleTree::::new(&block_roots, zero).unwrap(); + + // Create the full cell data and cell hash by repeating the block data + let cell_data = vec![cells.clone(); N_BLOCKS].concat(); + + // Return the constructed Self + Self { + tree: slot_tree, + block_trees, + cell_data, + } + } + /// same as default but with supplied cell data + pub fn new(cells: Vec>) -> Self { + let leaves: Vec> = cells + .iter() + .map(|element| { + HF::hash_no_pad(&element.data) + }) + .collect(); + let zero = HashOut { + elements: [F::ZERO; 4], + }; + let block_trees = (0..N_BLOCKS as usize) + .map(|i| { + let start = i * N_CELLS_IN_BLOCKS; + let end = (i + 1) * N_CELLS_IN_BLOCKS; + Self::get_block_tree(&leaves[start..end].to_vec()) + // MerkleTree:: { tree: b_tree } + }) + .collect::>(); + let block_roots = block_trees.iter() + .map(|t| { + t.root().unwrap() + }) + .collect::>(); + let slot_tree = MerkleTree::::new(&block_roots, zero).unwrap(); + Self { + tree: slot_tree, + block_trees, + cell_data: cells, + } + } + + /// generates a proof for given leaf index + /// the path in the proof is a combined block and slot path to make up the full path + pub fn get_proof(&self, index: usize) -> MerkleProof { + let block_index = index / N_CELLS_IN_BLOCKS; + let leaf_index = index % N_CELLS_IN_BLOCKS; + let block_proof = self.block_trees[block_index].get_proof(leaf_index).unwrap(); + let slot_proof = self.tree.get_proof(block_index).unwrap(); + + // Combine the paths from the block and slot proofs + let mut combined_path = block_proof.path.clone(); + combined_path.extend(slot_proof.path.clone()); + + MerkleProof:: { + index: index, + path: combined_path, + nleaves: self.cell_data.len(), + zero: block_proof.zero.clone(), + } + } + + /// verify the given proof for slot tree, checks equality with given root + pub fn verify_cell_proof(&self, proof: MerkleProof, root: HashOut) -> anyhow::Result { + let mut block_path_bits = usize_to_bits_le_padded(proof.index, MAX_DEPTH); + let last_index = N_CELLS - 1; + let mut block_last_bits = usize_to_bits_le_padded(last_index, MAX_DEPTH); + + let split_point = BOT_DEPTH; + + let slot_last_bits = block_last_bits.split_off(split_point); + let slot_path_bits = block_path_bits.split_off(split_point); + + let leaf_hash = HF::hash_no_pad(&self.cell_data[proof.index].data); + + let mut block_path = proof.path; + let slot_path = block_path.split_off(split_point); + + let block_res = MerkleProof::::reconstruct_root2(leaf_hash, block_path_bits.clone(), block_last_bits.clone(), block_path); + let reconstructed_root = MerkleProof::::reconstruct_root2(block_res.unwrap(), slot_path_bits, slot_last_bits, slot_path); + + Ok(reconstructed_root.unwrap() == root) + } + + fn get_block_tree(leaves: &Vec>) -> MerkleTree { + let zero = HashOut { + elements: [F::ZERO; 4], + }; + // Build the Merkle tree + let block_tree = MerkleTree::::new(leaves, zero).unwrap(); + block_tree + } +} + +// ------ Dataset Tree -------- +///dataset tree containing all slot trees +#[derive(Clone)] +pub struct DatasetTree< + F: RichField + Extendable + Poseidon2, + const D: usize, +> { + pub tree: MerkleTree, // dataset tree + pub slot_trees: Vec>, // vec of slot trees +} + +/// Dataset Merkle proof struct, containing the dataset proof and N_SAMPLES proofs. +#[derive(Clone)] +pub struct DatasetProof { + pub slot_index: F, + pub entropy: HashOut, + pub dataset_proof: MerkleProof, // proof for dataset level tree + pub slot_proofs: Vec>, // proofs for sampled slot, contains N_SAMPLES proofs + pub cell_data: Vec>, +} + +impl< + F: RichField + Extendable + Poseidon2, + const D: usize, +> Default for DatasetTree { + /// dataset tree with fake data, for testing only + fn default() -> Self { + let mut slot_trees = vec![]; + let n_slots = 1 << DATASET_DEPTH; + for i in 0..n_slots { + slot_trees.push(SlotTree::::default()); + } + Self::new(slot_trees) + } +} + +impl< + F: RichField + Extendable + Poseidon2, + const D: usize, +> DatasetTree { + /// dataset tree with fake data, for testing only + /// create data for only the TESTING_SLOT_INDEX in params file + pub fn new_for_testing() -> Self { + let mut slot_trees = vec![]; + let n_slots = 1 << DATASET_DEPTH; + // zero hash + let zero = HashOut { + elements: [F::ZERO; 4], + }; + let zero_slot = SlotTree:: { + tree: MerkleTree::::new(&[zero.clone()], zero.clone()).unwrap(), + block_trees: vec![], + cell_data: vec![], + }; + for i in 0..n_slots { + if (i == TESTING_SLOT_INDEX) { + slot_trees.push(SlotTree::::default()); + } else { + slot_trees.push(zero_slot.clone()); + } + } + // get the roots or slot trees + let slot_roots = slot_trees.iter() + .map(|t| { + t.tree.root().unwrap() + }) + .collect::>(); + let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); + Self { + tree: dataset_tree, + slot_trees, + } + } + + /// same as default but with supplied slot trees + pub fn new(slot_trees: Vec>) -> Self { + // get the roots or slot trees + let slot_roots = slot_trees.iter() + .map(|t| { + t.tree.root().unwrap() + }) + .collect::>(); + // zero hash + let zero = HashOut { + elements: [F::ZERO; 4], + }; + let dataset_tree = MerkleTree::::new(&slot_roots, zero).unwrap(); + Self { + tree: dataset_tree, + slot_trees, + } + } + + /// generates a dataset level proof for given slot index + /// just a regular merkle tree proof + pub fn get_proof(&self, index: usize) -> MerkleProof { + let dataset_proof = self.tree.get_proof(index).unwrap(); + dataset_proof + } + + /// generates a proof for given slot index + /// also takes entropy so it can use it sample the slot + pub fn sample_slot(&self, index: usize, entropy: usize) -> DatasetProof { + let dataset_proof = self.tree.get_proof(index).unwrap(); + let slot = &self.slot_trees[index]; + let slot_root = slot.tree.root().unwrap(); + let mut slot_proofs = vec![]; + let mut cell_data = vec![]; + let entropy_field = F::from_canonical_u64(entropy as u64); + let mut entropy_as_digest = HashOut::::ZERO; + entropy_as_digest.elements[0] = entropy_field; + // get the index for cell from H(slot_root|counter|entropy) + for i in 0..N_SAMPLES { + let cell_index_bits = calculate_cell_index_bits(&entropy_as_digest.elements.to_vec(), slot_root, i+1, MAX_DEPTH); + let cell_index = bits_le_padded_to_usize(&cell_index_bits); + let s_proof = slot.get_proof(cell_index); + slot_proofs.push(s_proof); + cell_data.push(slot.cell_data[cell_index].data.clone()); + } + + DatasetProof { + slot_index: F::from_canonical_u64(index as u64), + entropy: entropy_as_digest, + dataset_proof, + slot_proofs, + cell_data, + } + } + + // verify the sampling - non-circuit version + pub fn verify_sampling(&self, proof: DatasetProof) -> bool { + let slot = &self.slot_trees[proof.slot_index.to_canonical_u64() as usize]; + let slot_root = slot.tree.root().unwrap(); + // check dataset level proof + let d_res = proof.dataset_proof.verify(slot_root, self.tree.root().unwrap()); + if (d_res.unwrap() == false) { + return false; + } + // sanity check + assert_eq!(N_SAMPLES, proof.slot_proofs.len()); + // get the index for cell from H(slot_root|counter|entropy) + for i in 0..N_SAMPLES { + // let entropy_field = F::from_canonical_u64(proof.entropy as u64); + // let mut entropy_as_digest = HashOut::::ZERO; + // entropy_as_digest.elements[0] = entropy_field; + let cell_index_bits = calculate_cell_index_bits(&proof.entropy.elements.to_vec(), slot_root, i+1, MAX_DEPTH); + let cell_index = bits_le_padded_to_usize(&cell_index_bits); + //check the cell_index is the same as one in the proof + assert_eq!(cell_index, proof.slot_proofs[i].index); + let s_res = slot.verify_cell_proof(proof.slot_proofs[i].clone(), slot_root); + if (s_res.unwrap() == false) { + return false; + } + } + true + } +} + +#[cfg(test)] +mod tests { + use std::time::Instant; + use super::*; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::GenericConfig; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use crate::circuits::sample_cells::{CircuitParams, DatasetTreeCircuit, SampleCircuitInput}; + use crate::proof_input::test_params::{D, C, F, H, N_SLOTS}; + + #[test] + fn test_sample_cells() { + let dataset_t = DatasetTree::::new_for_testing(); + let slot_index = 2; + let entropy = 2; + let proof = dataset_t.sample_slot(slot_index,entropy); + let res = dataset_t.verify_sampling(proof); + assert_eq!(res, true); + } + + #[test] + fn test_sample_cells_circuit_from_selected_slot() -> anyhow::Result<()> { + + let mut dataset_t = DatasetTree::::new_for_testing(); + + let slot_index = TESTING_SLOT_INDEX; + let entropy = 123; + + // sanity check + let proof = dataset_t.sample_slot(slot_index,entropy); + let slot_root = dataset_t.slot_trees[slot_index].tree.root().unwrap(); + let res = dataset_t.verify_sampling(proof.clone()); + assert_eq!(res, true); + + // create the circuit + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let circuit_params = CircuitParams{ + max_depth: MAX_DEPTH, + max_log2_n_slots: DATASET_DEPTH, + block_tree_depth: BOT_DEPTH, + n_field_elems_per_cell: N_FIELD_ELEMS_PER_CELL, + n_samples: N_SAMPLES, + }; + let circ = DatasetTreeCircuit::new(circuit_params); + let mut targets = circ.sample_slot_circuit(&mut builder); + + // create a PartialWitness and assign + let mut pw = PartialWitness::new(); + + let mut slot_paths = vec![]; + for i in 0..N_SAMPLES{ + let path = proof.slot_proofs[i].path.clone(); + slot_paths.push(path); + //TODO: need to be padded + } + + let witness = SampleCircuitInput::{ + entropy: proof.entropy.elements.clone().to_vec(), + dataset_root: dataset_t.tree.root().unwrap(), + slot_index: proof.slot_index.clone(), + slot_root, + n_cells_per_slot: F::from_canonical_u64((2_u32.pow(MAX_DEPTH as u32)) as u64), + n_slots_per_dataset: F::from_canonical_u64((2_u32.pow(DATASET_DEPTH as u32)) as u64), + slot_proof: proof.dataset_proof.path.clone(), + cell_data: proof.cell_data.clone(), + merkle_paths: slot_paths, + }; + + println!("dataset ={:?}",dataset_t.slot_trees[0].tree.layers); + + circ.sample_slot_assign_witness(&mut pw, &mut targets,witness); + + // build the circuit + let data = builder.build::(); + println!("circuit size = {:?}", data.common.degree_bits()); + + // Prove the circuit with the assigned witness + let start_time = Instant::now(); + let proof_with_pis = data.prove(pw)?; + println!("prove_time = {:?}", start_time.elapsed()); + + // verify the proof + let verifier_data = data.verifier_data(); + assert!( + verifier_data.verify(proof_with_pis).is_ok(), + "Merkle proof verification failed" + ); + + Ok(()) + } +} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/proof_input/mod.rs b/codex-plonky2-circuits/src/proof_input/mod.rs new file mode 100644 index 0000000..558a61b --- /dev/null +++ b/codex-plonky2-circuits/src/proof_input/mod.rs @@ -0,0 +1,3 @@ +pub mod gen_input; +pub mod test_params; +pub mod utils; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/proof_input/test_params.rs b/codex-plonky2-circuits/src/proof_input/test_params.rs new file mode 100644 index 0000000..1820fdf --- /dev/null +++ b/codex-plonky2-circuits/src/proof_input/test_params.rs @@ -0,0 +1,39 @@ +// config for generating input for proof circuit + +use plonky2::hash::poseidon::PoseidonHash; +use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + +// fake input params + +// types +pub const D: usize = 2; +pub type C = PoseidonGoldilocksConfig; +pub type F = >::F; // this is the goldilocks field +pub type H = PoseidonHash; + + +// hardcoded params for generating proof input +pub const MAX_DEPTH: usize = 8; // depth of big tree (slot tree depth, includes block tree depth) +pub const MAX_SLOTS: usize = 256; // maximum number of slots +pub const CELL_SIZE: usize = 2048; // cell size in bytes +pub const BLOCK_SIZE: usize = 65536; // block size in bytes +pub const N_SAMPLES: usize = 5; // number of samples to prove + +pub const ENTROPY: usize = 1234567; // external randomness +pub const SEED: usize = 12345; // seed for creating fake data TODO: not used now + +pub const N_SLOTS: usize = 8; // number of slots in the dataset +pub const TESTING_SLOT_INDEX: usize = 2; //the index of the slot to be sampled +pub const N_CELLS: usize = 512; // number of cells in each slot + +// computed constants +pub const GOLDILOCKS_F_SIZE: usize = 64; +pub const N_FIELD_ELEMS_PER_CELL: usize = CELL_SIZE * 8 / GOLDILOCKS_F_SIZE; +pub const BOT_DEPTH: usize = (BLOCK_SIZE/CELL_SIZE).ilog2() as usize; // block tree depth + +pub const N_CELLS_IN_BLOCKS: usize = 1<< BOT_DEPTH; //2^BOT_DEPTH +pub const N_BLOCKS: usize = 1<<(MAX_DEPTH - BOT_DEPTH); // 2^(MAX_DEPTH - BOT_DEPTH) + +pub const DATASET_DEPTH: usize = MAX_SLOTS.ilog2() as usize; + +// TODO: load params \ No newline at end of file diff --git a/codex-plonky2-circuits/src/proof_input/utils.rs b/codex-plonky2-circuits/src/proof_input/utils.rs new file mode 100644 index 0000000..e69de29 diff --git a/codex-plonky2-circuits/src/tests/merkle_circuit.rs b/codex-plonky2-circuits/src/tests/merkle_circuit.rs new file mode 100644 index 0000000..8190296 --- /dev/null +++ b/codex-plonky2-circuits/src/tests/merkle_circuit.rs @@ -0,0 +1,308 @@ +use anyhow::Result; +use plonky2::field::extension::Extendable; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::{HashOut, HashOutTarget, RichField, NUM_HASH_OUT_ELTS}; +use plonky2::hash::hashing::PlonkyPermutation; +use plonky2::hash::poseidon::PoseidonHash; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; +use plonky2::plonk::proof::{Proof, ProofWithPublicInputs}; +use std::marker::PhantomData; +use std::os::macos::raw::stat; +use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; +use serde::Serialize; +use crate::circuits::keyed_compress::key_compress_circuit; +use crate::circuits::params::HF; +use crate::circuits::merkle_circuit::{MerkleProofTarget, MerkleTreeCircuit, MerkleTreeTargets}; +use crate::circuits::utils::{add_assign_hash_out_target, assign_bool_targets, assign_hash_out_targets, mul_hash_out_target, usize_to_bits_le_padded}; + +use crate::merkle_tree::merkle_safe::MerkleTree; +use crate::merkle_tree::merkle_safe::{KEY_NONE,KEY_BOTTOM_LAYER}; + +/// the input to the merkle tree circuit +#[derive(Clone)] +pub struct MerkleTreeCircuitInput< + F: RichField + Extendable + Poseidon2, + const D: usize, +>{ + pub leaf: HashOut, + pub path_bits: Vec, + pub last_bits: Vec, + pub mask_bits: Vec, + pub merkle_path: Vec>, +} + +/// defines the computations inside the circuit and returns the targets used +/// NOTE: this is not used in the sampling circuit, see reconstruct_merkle_root_circuit_with_mask +pub fn build_circuit< + F: RichField + Extendable + Poseidon2, + const D: usize, +>( + builder: &mut CircuitBuilder::, + depth: usize, +) -> (MerkleTreeTargets, HashOutTarget) { + + // Create virtual targets + let leaf = builder.add_virtual_hash(); + + // path bits (binary decomposition of leaf_index) + let path_bits = (0..depth).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + + // last bits (binary decomposition of last_index = nleaves - 1) + let last_bits = (0..depth).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + + // last bits (binary decomposition of last_index = nleaves - 1) + let mask_bits = (0..depth+1).map(|_| builder.add_virtual_bool_target_safe()).collect::>(); + + // Merkle path (sibling hashes from leaf to root) + let merkle_path = MerkleProofTarget { + path: (0..depth).map(|_| builder.add_virtual_hash()).collect(), + }; + + // create MerkleTreeTargets struct + let mut targets = MerkleTreeTargets{ + leaf, + path_bits, + last_bits, + mask_bits, + merkle_path, + }; + + // Add Merkle proof verification constraints to the circuit + let reconstructed_root_target = MerkleTreeCircuit::reconstruct_merkle_root_circuit_with_mask(builder, &mut targets, depth); + + // Return MerkleTreeTargets + (targets, reconstructed_root_target) +} + +/// assign the witness values in the circuit targets +/// this takes MerkleTreeCircuitInput and fills all required circuit targets +pub fn assign_witness< + F: RichField + Extendable + Poseidon2, + const D: usize, +>( + pw: &mut PartialWitness, + targets: &mut MerkleTreeTargets, + witnesses: MerkleTreeCircuitInput +)-> Result<()> { + // Assign the leaf hash to the leaf target + pw.set_hash_target(targets.leaf, witnesses.leaf); + + // Assign path bits + assign_bool_targets(pw, &targets.path_bits, witnesses.path_bits); + + // Assign last bits + assign_bool_targets(pw, &targets.last_bits, witnesses.last_bits); + + // Assign mask bits + assign_bool_targets(pw, &targets.mask_bits, witnesses.mask_bits); + + // assign the Merkle path (sibling hashes) to the targets + for i in 0..targets.merkle_path.path.len() { + if i>=witnesses.merkle_path.len() { // pad with zeros + assign_hash_out_targets(pw, &targets.merkle_path.path[i].elements, &[F::ZERO; NUM_HASH_OUT_ELTS]); + continue + } + assign_hash_out_targets(pw, &targets.merkle_path.path[i].elements, &witnesses.merkle_path[i].elements) + } + Ok(()) +} + + +#[cfg(test)] +mod tests { + use std::time::Instant; + use plonky2::hash::hash_types::HashOut; + use plonky2::hash::poseidon::PoseidonHash; + use super::*; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2_field::goldilocks_field::GoldilocksField; + use crate::circuits::merkle_circuit::{MerkleTreeCircuit, }; + use crate::circuits::sample_cells::{CircuitParams, DatasetTreeCircuit, SampleCircuitInput}; + use crate::circuits::utils::usize_to_bits_le_padded; + use crate::merkle_tree::merkle_safe::MerkleTree; + use crate::proof_input::test_params::{D, C, F, H, N_SLOTS, MAX_DEPTH}; + + // NOTE: for now these tests don't check the reconstructed root is equal to expected_root +// will be fixed later, but for that test check the prove_single_cell tests + #[test] + fn test_build_circuit() -> anyhow::Result<()> { + // circuit params + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type H = PoseidonHash; + + // Generate random leaf data + let nleaves = 16; // Number of leaves + let max_depth = 4; + let data = (0..nleaves) + .map(|i| GoldilocksField::from_canonical_u64(i)) + .collect::>(); + // Hash the data to obtain leaf hashes + let leaves: Vec> = data + .iter() + .map(|&element| { + // Hash each field element to get the leaf hash + PoseidonHash::hash_no_pad(&[element]) + }) + .collect(); + + //initialize the Merkle tree + let zero_hash = HashOut { + elements: [GoldilocksField::ZERO; 4], + }; + let tree = MerkleTree::::new(&leaves, zero_hash)?; + + // select leaf index to prove + let leaf_index: usize = 8; + + // get the Merkle proof for the selected leaf + let proof = tree.get_proof(leaf_index)?; + // sanity check: + let check = proof.verify(tree.layers[0][leaf_index],tree.root().unwrap()).unwrap(); + assert_eq!(check, true); + + // get the expected Merkle root + let expected_root = tree.root()?; + + // create the circuit + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let (mut targets, reconstructed_root_target) = build_circuit(&mut builder, max_depth); + + // expected Merkle root + let expected_root = builder.add_virtual_hash(); + + // check equality with expected root + for i in 0..NUM_HASH_OUT_ELTS { + builder.connect(expected_root.elements[i], reconstructed_root_target.elements[i]); + } + + let path_bits = usize_to_bits_le_padded(leaf_index, max_depth); + let last_index = (nleaves - 1) as usize; + let last_bits = usize_to_bits_le_padded(last_index, max_depth); + let mask_bits = usize_to_bits_le_padded(last_index, max_depth+1); + + // circuit input + let circuit_input = MerkleTreeCircuitInput::{ + leaf: tree.layers[0][leaf_index], + path_bits, + last_bits, + mask_bits, + merkle_path: proof.path, + }; + + // create a PartialWitness and assign + let mut pw = PartialWitness::new(); + assign_witness(&mut pw, &mut targets, circuit_input)?; + pw.set_hash_target(expected_root, tree.root().unwrap()); + + // build the circuit + let data = builder.build::(); + + // Prove the circuit with the assigned witness + let proof_with_pis = data.prove(pw)?; + + // verify the proof + let verifier_data = data.verifier_data(); + assert!( + verifier_data.verify(proof_with_pis).is_ok(), + "Merkle proof verification failed" + ); + + Ok(()) + } + + // same as test above but for all leaves + #[test] + fn test_verify_all_leaves() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type H = PoseidonHash; + + let nleaves = 16; // Number of leaves + let max_depth = 4; + let data = (0..nleaves) + .map(|i| GoldilocksField::from_canonical_u64(i as u64)) + .collect::>(); + // Hash the data to obtain leaf hashes + let leaves: Vec> = data + .iter() + .map(|&element| { + // Hash each field element to get the leaf hash + PoseidonHash::hash_no_pad(&[element]) + }) + .collect(); + + let zero_hash = HashOut { + elements: [GoldilocksField::ZERO; 4], + }; + let tree = MerkleTree::::new(&leaves, zero_hash)?; + + let expected_root = tree.root()?; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let (mut targets, reconstructed_root_target) = build_circuit(&mut builder, max_depth); + + // expected Merkle root + let expected_root_target = builder.add_virtual_hash(); + + // check equality with expected root + for i in 0..NUM_HASH_OUT_ELTS { + builder.connect(expected_root_target.elements[i], reconstructed_root_target.elements[i]); + } + + let data = builder.build::(); + + for leaf_index in 0..nleaves { + let proof = tree.get_proof(leaf_index)?; + let check = proof.verify(tree.layers[0][leaf_index], expected_root)?; + assert!( + check, + "Merkle proof verification failed for leaf index {}", + leaf_index + ); + + let mut pw = PartialWitness::new(); + + let path_bits = usize_to_bits_le_padded(leaf_index, max_depth); + let last_index = (nleaves - 1) as usize; + let last_bits = usize_to_bits_le_padded(last_index, max_depth); + let mask_bits = usize_to_bits_le_padded(last_index, max_depth+1); + + // circuit input + let circuit_input = MerkleTreeCircuitInput::{ + leaf: tree.layers[0][leaf_index], + path_bits, + last_bits, + mask_bits, + merkle_path: proof.path, + }; + + assign_witness(&mut pw, &mut targets, circuit_input)?; + pw.set_hash_target(expected_root_target, expected_root); + + let proof_with_pis = data.prove(pw)?; + + let verifier_data = data.verifier_data(); + assert!( + verifier_data.verify(proof_with_pis).is_ok(), + "Merkle proof verification failed in circuit for leaf index {}", + leaf_index + ); + } + + Ok(()) + } + +} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/tests/mod.rs b/codex-plonky2-circuits/src/tests/mod.rs new file mode 100644 index 0000000..57ffd28 --- /dev/null +++ b/codex-plonky2-circuits/src/tests/mod.rs @@ -0,0 +1 @@ +pub mod merkle_circuit;