diff --git a/proof-input/src/gen_input.rs b/proof-input/src/gen_input.rs index c237d24..c2f5b71 100644 --- a/proof-input/src/gen_input.rs +++ b/proof-input/src/gen_input.rs @@ -12,7 +12,7 @@ use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; use plonky2::plonk::proof::ProofWithPublicInputs; -use crate::sponge::{hash_bytes_no_padding, hash_n_with_padding}; +use crate::sponge::hash_bytes_no_padding; use crate::params::{C, D, F}; /// generates circuit input (SampleCircuitInput) from fake data for testing @@ -388,16 +388,16 @@ impl< } } -// build the sampling circuit -// returns the proof and verifier ci +/// build the sampling circuit +/// returns the proof and circuit data pub fn build_circuit(n_samples: usize, slot_index: usize) -> anyhow::Result<(CircuitData, PartialWitness)>{ let (data, pw, _) = build_circuit_with_targets(n_samples, slot_index).unwrap(); Ok((data, pw)) } -// build the sampling circuit , -// returns the proof and verifier ci and targets +/// build the sampling circuit , +/// returns the proof, circuit data, and targets pub fn build_circuit_with_targets(n_samples: usize, slot_index: usize) -> anyhow::Result<(CircuitData, PartialWitness, SampleTargets)>{ // get input let mut params = TestParams::default(); @@ -428,7 +428,7 @@ pub fn build_circuit_with_targets(n_samples: usize, slot_index: usize) -> anyhow Ok((data, pw, targets)) } -// prove the circuit +/// prove the circuit pub fn prove_circuit(data: &CircuitData, pw: &PartialWitness) -> anyhow::Result>{ // Prove the circuit with the assigned witness let proof_with_pis = data.prove(pw.clone())?; @@ -436,22 +436,26 @@ pub fn prove_circuit(data: &CircuitData, pw: &PartialWitness) -> any Ok(proof_with_pis) } +/// returns exactly M default circuit input +pub fn get_m_default_circ_input() -> [SampleCircuitInput; M]{ + let params = TestParams::default(); + let one_circ_input = gen_testing_circuit_input::(¶ms); + let circ_input: [SampleCircuitInput; M] = (0..M) + .map(|_| one_circ_input.clone()) + .collect::>() + .try_into().unwrap(); + circ_input +} + #[cfg(test)] mod tests { use std::time::Instant; use super::*; - use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; - use plonky2::plonk::config::GenericConfig; + use plonky2::plonk::circuit_data::CircuitConfig; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use codex_plonky2_circuits::circuits::params::CircuitParams; - use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs, aggregate_sampling_proofs_tree, aggregate_sampling_proofs_tree2}; - use codex_plonky2_circuits::circuits::sample_cells::{SampleCircuit, SampleTargets}; - use codex_plonky2_circuits::recursion::params::RecursionTreeParams; - use plonky2::plonk::proof::ProofWithPublicInputs; - use plonky2_poseidon2::serialization::{DefaultGateSerializer, DefaultGeneratorSerializer}; - use crate::json::write_bytes_to_file; - use codex_plonky2_circuits::recursion::simple_recursion2::{SimpleRecursionCircuit, SimpleRecursionInput}; + use codex_plonky2_circuits::circuits::sample_cells::SampleCircuit; // use crate::params::{C, D, F}; // Test sample cells (non-circuit) diff --git a/proof-input/src/lib.rs b/proof-input/src/lib.rs index ebd00a1..66bdef2 100644 --- a/proof-input/src/lib.rs +++ b/proof-input/src/lib.rs @@ -5,6 +5,3 @@ pub mod utils; pub mod json; pub mod tests; mod sponge; -pub mod cyclic_recursion; -pub mod tree_recursion; -pub mod simple_recursion; diff --git a/proof-input/src/cyclic_recursion.rs b/proof-input/src/tests/cyclic_recursion.rs similarity index 81% rename from proof-input/src/cyclic_recursion.rs rename to proof-input/src/tests/cyclic_recursion.rs index a184483..b535738 100644 --- a/proof-input/src/cyclic_recursion.rs +++ b/proof-input/src/tests/cyclic_recursion.rs @@ -4,22 +4,18 @@ mod tests { use std::time::Instant; use anyhow::Result; - use plonky2::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget, RichField}; + use plonky2::hash::hash_types::HashOut; use plonky2::hash::hashing::hash_n_to_hash_no_pad; - use plonky2::hash::poseidon::{PoseidonHash, PoseidonPermutation}; - use plonky2::iop::witness::{PartialWitness, PartitionWitness, WitnessWrite}; + use plonky2::hash::poseidon::PoseidonPermutation; use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData}; - use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; - use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; - use plonky2::recursion::dummy_circuit::cyclic_base_proof; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::GenericConfig; use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof}; use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion; - use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit; use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; use crate::gen_input::gen_testing_circuit_input; use crate::params::TestParams; - use codex_plonky2_circuits::recursion::cyclic_recursion::{common_data_for_recursion, CyclicCircuit}; + use codex_plonky2_circuits::recursion::cyclic_recursion::CyclicCircuit; /// Uses cyclic recursion to sample the dataset @@ -33,7 +29,6 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let one = builder.one(); - // Circuit that does the sampling let inner_sampling_circuit = SamplingRecursion::default(); let mut params = TestParams::default(); params.n_samples = 10; @@ -87,7 +82,6 @@ mod tests { let mut builder = CircuitBuilder::::new(config); let one = builder.one(); - // Circuit that does the sampling let inner_sampling_circuit = SamplingRecursion::default(); let mut params = TestParams::default(); params.n_samples = 10; diff --git a/proof-input/src/tests/mod.rs b/proof-input/src/tests/mod.rs index 00913e7..f1c4b59 100644 --- a/proof-input/src/tests/mod.rs +++ b/proof-input/src/tests/mod.rs @@ -1,2 +1,6 @@ pub mod merkle_circuit; pub mod merkle; +pub mod simple_recursion; +pub mod cyclic_recursion; +pub mod tree_recursion1; +pub mod tree_recursion2; diff --git a/proof-input/src/simple_recursion.rs b/proof-input/src/tests/simple_recursion.rs similarity index 71% rename from proof-input/src/simple_recursion.rs rename to proof-input/src/tests/simple_recursion.rs index b3d60e3..e0fc346 100644 --- a/proof-input/src/simple_recursion.rs +++ b/proof-input/src/tests/simple_recursion.rs @@ -1,4 +1,4 @@ -// tests for simple recursion +// tests for simple recursion approaches use std::time::Instant; use plonky2::hash::hash_types::HashOut; @@ -6,17 +6,17 @@ use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; use plonky2_field::types::Field; -use codex_plonky2_circuits::recursion::params::RecursionTreeParams; -use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs, aggregate_sampling_proofs_tree}; -use codex_plonky2_circuits::recursion::simple_recursion2::{SimpleRecursionCircuit, SimpleRecursionInput}; +use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion; +use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs,SimpleRecursionCircuit, SimpleRecursionInput}; +use codex_plonky2_circuits::recursion::simple_tree_recursion::aggregate_sampling_proofs_tree; use plonky2_poseidon2::serialization::{DefaultGateSerializer, DefaultGeneratorSerializer}; use crate::gen_input::{build_circuit, prove_circuit}; use crate::json::write_bytes_to_file; -use crate::params::{C, F, D}; +use crate::params::{C, D, F}; -// Test recursion +// Test simple recursion #[test] -fn test_recursion() -> anyhow::Result<()> { +fn test_simple_recursion() -> anyhow::Result<()> { // number of samples in each proof let n_samples = 10; // number of inner proofs: @@ -66,7 +66,7 @@ fn test_recursion() -> anyhow::Result<()> { // Test simple tree recursion #[test] -fn test_tree_recursion() -> anyhow::Result<()> { +fn test_simple_tree_recursion() -> anyhow::Result<()> { // number of samples in each proof let n_samples = 10; // number of inner proofs: @@ -84,30 +84,32 @@ fn test_tree_recursion() -> anyhow::Result<()> { let data = data.unwrap(); println!("inner circuit size = {:?}", data.common.degree_bits()); - let gate_serializer = DefaultGateSerializer; - let generator_serializer =DefaultGeneratorSerializer::::default(); - let data_bytes = data.to_bytes(&gate_serializer, &generator_serializer).unwrap(); - println!("inner proof circuit data size = {} bytes", data_bytes.len()); - let file_path = "inner_circ_data.bin"; - // Write data to the file - write_bytes_to_file(data_bytes, file_path).unwrap(); - println!("Data written to {}", file_path); + // serialization + // let gate_serializer = DefaultGateSerializer; + // let generator_serializer =DefaultGeneratorSerializer::::default(); + // let data_bytes = data.to_bytes(&gate_serializer, &generator_serializer).unwrap(); + // println!("inner proof circuit data size = {} bytes", data_bytes.len()); + // let file_path = "inner_circ_data.bin"; + // // Write data to the file + // write_bytes_to_file(data_bytes, file_path).unwrap(); + // println!("Data written to {}", file_path); let start_time = Instant::now(); let (proof, vd_agg) = aggregate_sampling_proofs_tree(&proofs_with_pi, data)?; println!("prove_time = {:?}", start_time.elapsed()); println!("num of public inputs = {}", proof.public_inputs.len()); println!("agg pub input = {:?}", proof.public_inputs); - println!("outer circuit size = {:?}", vd_agg.common.degree_bits()); - // let gate_serializer = DefaultGateSerializer; - // let generator_serializer =DefaultGeneratorSerializer::::default(); - let outer_data_bytes = vd_agg.to_bytes(&gate_serializer, &generator_serializer).unwrap(); - println!("outer proof circuit data size = {} bytes", outer_data_bytes.len()); - let file_path = "outer_circ_data.bin"; - // Write data to the file - write_bytes_to_file(outer_data_bytes, file_path).unwrap(); - println!("Data written to {}", file_path); + + // serialization + // // let gate_serializer = DefaultGateSerializer; + // // let generator_serializer =DefaultGeneratorSerializer::::default(); + // let outer_data_bytes = vd_agg.to_bytes(&gate_serializer, &generator_serializer).unwrap(); + // println!("outer proof circuit data size = {} bytes", outer_data_bytes.len()); + // let file_path = "outer_circ_data.bin"; + // // Write data to the file + // write_bytes_to_file(outer_data_bytes, file_path).unwrap(); + // println!("Data written to {}", file_path); // Verify the proof let verifier_data = vd_agg.verifier_data(); @@ -119,13 +121,13 @@ fn test_tree_recursion() -> anyhow::Result<()> { Ok(()) } -// test another approach of the tree recursion +// test another approach of the simple recursion #[test] -pub fn test_tree_recursion2()-> anyhow::Result<()>{ +pub fn test_simple_recursion_approach2()-> anyhow::Result<()>{ // number of samples in each proof - let n_samples = 10; + let n_samples = 5; // number of inner proofs: - let n_inner = 4; + const n_inner: usize = 4; let mut data: Option> = None; // get proofs @@ -138,9 +140,10 @@ pub fn test_tree_recursion2()-> anyhow::Result<()>{ } let data = data.unwrap(); - let rt_params = RecursionTreeParams::new(n_inner); - - let rec_circuit = SimpleRecursionCircuit::new(rt_params, data.verifier_data()); + // careful here, the sampling recursion is the default so proofs should be for circuit + // with default params + let sampling_inner_circ = SamplingRecursion::default(); + let rec_circuit = SimpleRecursionCircuit::<_,n_inner>::new(sampling_inner_circ); // Create the circuit let config = CircuitConfig::standard_recursion_config(); @@ -148,7 +151,7 @@ pub fn test_tree_recursion2()-> anyhow::Result<()>{ // Create a PartialWitness let mut pw = PartialWitness::new(); - let targets = rec_circuit.build_circuit(&mut builder); + let targets = rec_circuit.build_circuit(&mut builder)?; let start = Instant::now(); let agg_data = builder.build::(); diff --git a/proof-input/src/tests/tree_recursion1.rs b/proof-input/src/tests/tree_recursion1.rs new file mode 100644 index 0000000..e302e7f --- /dev/null +++ b/proof-input/src/tests/tree_recursion1.rs @@ -0,0 +1,185 @@ +// some tests for approach 1 of the tree recursion + +#[cfg(test)] +mod tests { + use std::time::Instant; + use anyhow::{anyhow, Result}; + use plonky2::hash::poseidon::PoseidonHash; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData}; + use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; + use plonky2_field::types::Field; + use codex_plonky2_circuits::circuits::sample_cells::{SampleCircuit, SampleCircuitInput}; + use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof}; + use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion; + use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit; + use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; + use crate::gen_input::get_m_default_circ_input; + use codex_plonky2_circuits::recursion::tree_recursion::{NodeCircuit, TreeRecursion}; + + /// Uses node recursion to sample the dataset + #[test] + fn test_node_recursion() -> Result<()> { + // const D: usize = 2; + // type C = PoseidonGoldilocksConfig; + // type F = >::F; + const M: usize = 1; + const N: usize = 2; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let one = builder.one(); + + let inner_sampling_circuit = SamplingRecursion::default(); + + let mut node = NodeCircuit::<_,M,N>::new(inner_sampling_circuit); + let mut tree_circ = TreeRecursion::new(node); + let circ_input = get_m_default_circ_input::(); + + let s = Instant::now(); + tree_circ.build()?; + println!("build = {:?}", s.elapsed()); + let s = Instant::now(); + let proof = tree_circ.prove(&circ_input,None, true)?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", proof.public_inputs.len()); + println!("pub input: {:?}", proof.public_inputs); + let s = Instant::now(); + assert!( + tree_circ.verify_proof(proof).is_ok(), + "proof verification failed" + ); + println!("verify = {:?}", s.elapsed()); + + Ok(()) + } + + /// Uses node recursion to sample the dataset + #[test] + fn test_tree_recursion_approach1() -> Result<()> { + // const D: usize = 2; + // type C = PoseidonGoldilocksConfig; + // type F = >::F; + const M: usize = 1; + const N: usize = 2; + + const DEPTH: usize = 3; + const TOTAL_INPUT: usize = (N.pow(DEPTH as u32) - 1) / (N - 1); + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let inner_sampling_circuit = SamplingRecursion::default(); + + let mut node = NodeCircuit::<_,M,N>::new(inner_sampling_circuit); + let mut tree_circ = TreeRecursion::new(node); + + let all_circ_input = get_m_default_circ_input::().to_vec(); + + let s = Instant::now(); + tree_circ.build()?; + println!("build = {:?}", s.elapsed()); + let s = Instant::now(); + let proof = tree_circ.prove_tree(all_circ_input.clone(),DEPTH)?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", proof.public_inputs.len()); + println!("pub input: {:?}", proof.public_inputs); + + // Extract the final public input hash from the proof + let final_proof_hash = &proof.public_inputs[0..4]; + + // Recompute the expected final public input hash (outside the circuit) + let expected_hash = compute_expected_pub_input_hash::( + &all_circ_input, + DEPTH, + M, + N + )?; + + // Check that the final hash in the proof matches the expected hash + assert_eq!(final_proof_hash, expected_hash.as_slice(), "Public input hash mismatch"); + + let s = Instant::now(); + assert!( + tree_circ.verify_proof(proof).is_ok(), + "proof verification failed" + ); + println!("verify = {:?}", s.elapsed()); + + Ok(()) + } + + /// Recursively compute the final public input hash for a single node in the recursion tree. + /// This is the same logic from `NodeCircuit::build_circuit` + /// TODO: optimize this + fn compute_node_hash>>( + all_circ_inputs: &[I::Input], + depth: usize, + current_depth: usize, + node_idx: usize, + M: usize, + N: usize, + ) -> [F; 4] { + // Calculate the index in all_circ_inputs for this node's M inputs. + // Total inputs per layer: sum_{k=0}^{current_depth-1} M*N^k = M * ((N^current_depth - 1)/(N-1)) + let offset_for_layer = ((N.pow(current_depth as u32) - 1) / (N - 1)) * M; + let node_start = offset_for_layer + node_idx * M; + let node_inputs = &all_circ_inputs[node_start..node_start + M]; + + // Compute the outer public input hash: + // public inputs are [slot_index, dataset_root.elements, entropy.elements]. + let mut outer_pi_hashes = vec![]; + for inp in node_inputs { + let mut pi_vec = vec![inp.slot_index]; + pi_vec.extend_from_slice(&inp.dataset_root.elements); + pi_vec.extend_from_slice(&inp.entropy.elements); + let hash_res = PoseidonHash::hash_no_pad(&pi_vec); + outer_pi_hashes.extend_from_slice(&hash_res.elements); + } + + // hash all these M hashes into one + let outer_pi_hash = PoseidonHash::hash_no_pad(&outer_pi_hashes); + + let is_leaf = current_depth == depth - 1; + + // Compute the inner proof hash (or zero hash if leaf) + let inner_pi_hash_or_zero = if is_leaf { + // condition = false at leaf, so inner proofs = zero hash + [F::ZERO; 4] + } else { + // condition = true at non-leaf node -> recursively compute child hashes + let next_depth = current_depth + 1; + let child_start = node_idx * N; + + let mut inner_pub_input_hashes = vec![]; + for i in child_start..child_start + N { + let child_hash = compute_node_hash::(all_circ_inputs, depth, next_depth, i, M, N); + inner_pub_input_hashes.extend_from_slice(&child_hash); + } + + let inner_pub_input_hash = PoseidonHash::hash_no_pad(&inner_pub_input_hashes); + inner_pub_input_hash.elements + }; + + // Combine outer_pi_hash and inner_pi_hash_or_zero + let mut final_input = vec![]; + final_input.extend_from_slice(&outer_pi_hash.elements); + final_input.extend_from_slice(&inner_pi_hash_or_zero); + + let final_hash = PoseidonHash::hash_no_pad(&final_input); + final_hash.elements + } + + /// Compute the expected public input hash for the entire recursion tree. + /// This function calls `compute_node_hash` starting from the root (layer 0, node 0). + pub fn compute_expected_pub_input_hash>>( + all_circ_inputs: &[I::Input], + depth: usize, + M: usize, + N: usize, + ) -> Result> { + // The root node is at layer = 0 and node_idx = 0 + let final_hash = compute_node_hash::(all_circ_inputs, depth, 0, 0, M, N); + Ok(final_hash.to_vec()) + } +} \ No newline at end of file diff --git a/proof-input/src/tests/tree_recursion2.rs b/proof-input/src/tests/tree_recursion2.rs new file mode 100644 index 0000000..d6c4d0b --- /dev/null +++ b/proof-input/src/tests/tree_recursion2.rs @@ -0,0 +1,307 @@ +// some tests for approach 2 of the tree recursion + +#[cfg(test)] +mod tests { + use std::time::Instant; + use anyhow::{anyhow, Result}; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData}; + use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; + use plonky2::plonk::proof::ProofWithPublicInputs; + use codex_plonky2_circuits::circuits::params::CircuitParams; + use codex_plonky2_circuits::circuits::sample_cells::{SampleCircuit, SampleCircuitInput}; + use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof}; + use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion; + use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit; + use codex_plonky2_circuits::recursion::leaf_circuit::{LeafCircuit, LeafInput}; + use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; + use crate::gen_input::gen_testing_circuit_input; + use crate::params::TestParams; + use codex_plonky2_circuits::recursion::tree_recursion2::{NodeCircuit as nodeC, TreeRecursion as TR}; + use codex_plonky2_circuits::recursion::utils::{get_dummy_leaf_proof, get_dummy_node_proof}; + use crate::gen_input::get_m_default_circ_input; + + /// Uses node recursion to sample the dataset + #[test] + fn test_leaf_circuit() -> Result<()> { + // const D: usize = 2; + // type C = PoseidonGoldilocksConfig; + // type F = >::F; + const M: usize = 1; + const N: usize = 2; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + let params = TestParams::default(); + + let one_circ_input = gen_testing_circuit_input::(¶ms); + let samp_circ = SampleCircuit::::new(CircuitParams::default()); + let inner_tar = samp_circ.sample_slot_circuit_with_public_input(&mut builder); + let mut pw = PartialWitness::::new(); + samp_circ.sample_slot_assign_witness(&mut pw,&inner_tar,&one_circ_input); + let inner_d = builder.build::(); + let inner_prf = inner_d.prove(pw)?; + + let leaf_in = LeafInput{ + inner_proof:inner_prf, + verifier_data: inner_d.verifier_data(), + }; + let config2 = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config2); + + let inner_circ = SamplingRecursion::default(); + let leaf_circuit = LeafCircuit::new(inner_circ); + + let s = Instant::now(); + let leaf_tar = leaf_circuit.build(&mut builder)?; + let circ_data = builder.build::(); + println!("build = {:?}", s.elapsed()); + let s = Instant::now(); + // let proof = tree_circ.prove(&[leaf_in],None, true)?; + let mut pw = PartialWitness::::new(); + leaf_circuit.assign_targets(&mut pw, &leaf_tar, &leaf_in)?; + let proof = circ_data.prove(pw)?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", proof.public_inputs.len()); + println!("pub input: {:?}", proof.public_inputs); + let s = Instant::now(); + assert!( + circ_data.verify(proof).is_ok(), + "proof verification failed" + ); + println!("verify = {:?}", s.elapsed()); + + Ok(()) + } + + #[test] + fn test_node_circuit_approach2() -> Result<()> { + // use predefined: C, D, F c + const N: usize = 2; // binary tree + + let config = CircuitConfig::standard_recursion_config(); + let mut sampling_builder = CircuitBuilder::::new(config); + + //------------ sampling inner circuit ---------------------- + // Circuit that does the sampling - default input + let mut params = TestParams::default(); + // params.n_samples = 10; + let one_circ_input = gen_testing_circuit_input::(¶ms); + let samp_circ = SampleCircuit::::new(CircuitParams::default()); + let inner_tar = samp_circ.sample_slot_circuit_with_public_input(&mut sampling_builder); + // get generate a sampling proof + let mut pw = PartialWitness::::new(); + samp_circ.sample_slot_assign_witness(&mut pw,&inner_tar,&one_circ_input); + let inner_data = sampling_builder.build::(); + let inner_proof = inner_data.prove(pw)?; + + // ------------------- leaf -------------------- + // leaf circuit that verifies the sampling proof + let inner_circ = SamplingRecursion::default(); + let leaf_circuit = LeafCircuit::new(inner_circ); + + let leaf_in = LeafInput{ + inner_proof, + verifier_data: inner_data.verifier_data(), + }; + let config = CircuitConfig::standard_recursion_config(); + let mut leaf_builder = CircuitBuilder::::new(config); + // build + let s = Instant::now(); + let leaf_targets = leaf_circuit.build(&mut leaf_builder)?; + let leaf_circ_data = leaf_builder.build::(); + println!("build = {:?}", s.elapsed()); + // prove + let s = Instant::now(); + let mut pw = PartialWitness::::new(); + leaf_circuit.assign_targets(&mut pw, &leaf_targets, &leaf_in)?; + let leaf_proof = leaf_circ_data.prove(pw)?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", leaf_proof.public_inputs.len()); + println!("pub input: {:?}", leaf_proof.public_inputs); + // verify + let s = Instant::now(); + assert!( + leaf_circ_data.verify(leaf_proof.clone()).is_ok(), + "proof verification failed" + ); + println!("verify = {:?}", s.elapsed()); + + // ------------- Node circuit ------------------ + // node circuit that verifies leafs or itself + // build + let s = Instant::now(); + let mut node = nodeC::build_circuit()?; + println!("build = {:?}", s.elapsed()); + + // prove leaf + let s = Instant::now(); + let mut pw = PartialWitness::::new(); + let leaf_proofs: [ProofWithPublicInputs; N] = (0..N) + .map(|_| { + leaf_proof.clone() + }) + .collect::>() + .try_into() + .map_err(|_| anyhow!("Expected exactly M inner circuits"))?; + let dummy_node_proof = get_dummy_node_proof( + &node.node_data.inner_node_common_data, + &node.node_data.node_circuit_data.verifier_only, + ); + let dummy_node_proofs: [ProofWithPublicInputs; N] = (0..N) + .map(|_| { + dummy_node_proof.clone() + }) + .collect::>() + .try_into() + .map_err(|_| anyhow!("Expected exactly M inner circuits"))?; + nodeC::::assign_targets( + node.node_targets.clone(), //targets + Some(leaf_proofs), // leaf proofs + Some(dummy_node_proofs), // node proofs (dummy here) + &node.node_data.leaf_circuit_data.verifier_only, // leaf verifier data + &mut pw, // partial witness + true // is leaf + )?; + let node_proof = node.node_data.node_circuit_data.prove(pw)?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", node_proof.public_inputs.len()); + println!("pub input: {:?}", node_proof.public_inputs); + let s = Instant::now(); + assert!( + node.node_data.node_circuit_data.verify(node_proof.clone()).is_ok(), + "proof verification failed" + ); + println!("verify = {:?}", s.elapsed()); + + // prove node + let s = Instant::now(); + let mut pw = PartialWitness::::new(); + let node_proofs: [ProofWithPublicInputs; N] = (0..N) + .map(|_| { + node_proof.clone() + }) + .collect::>() + .try_into() + .map_err(|_| anyhow!("Expected exactly M inner circuits"))?; + let dummy_leaf_proof = get_dummy_leaf_proof( + &node.node_data.leaf_circuit_data.common + ); + let dummy_leaf_proofs: [ProofWithPublicInputs; N] = (0..N) + .map(|_| { + dummy_leaf_proof.clone() + }) + .collect::>() + .try_into() + .map_err(|_| anyhow!("Expected exactly M inner circuits"))?; + nodeC::::assign_targets( + node.node_targets.clone(), //targets + Some(dummy_leaf_proofs), // leaf proofs + Some(node_proofs), // node proofs (dummy here) + &node.node_data.leaf_circuit_data.verifier_only, // leaf verifier data + &mut pw, // partial witness + false // is leaf + )?; + let node_proof = node.node_data.node_circuit_data.prove(pw)?; + // let node_proof = node_d.prove(pw)?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", node_proof.public_inputs.len()); + println!("pub input: {:?}", node_proof.public_inputs); + let s = Instant::now(); + assert!( + node.node_data.node_circuit_data.verify(node_proof.clone()).is_ok(), + "proof verification failed" + ); + println!("verify = {:?}", s.elapsed()); + + Ok(()) + } + + #[test] + fn test_tree_recursion_approach2() -> Result<()> { + // use predefined: C, D, F c + const N: usize = 2; // binary tree + const K: usize = 4; // number of leaves/slots sampled - should be power of 2 + + let config = CircuitConfig::standard_recursion_config(); + let mut sampling_builder = CircuitBuilder::::new(config); + + //------------ sampling inner circuit ---------------------- + // Circuit that does the sampling - default input + let mut params = TestParams::default(); + params.n_samples = 10; + let one_circ_input = gen_testing_circuit_input::(¶ms); + let samp_circ = SampleCircuit::::new(CircuitParams::default()); + let inner_tar = samp_circ.sample_slot_circuit_with_public_input(&mut sampling_builder); + // get generate a sampling proof + let mut pw = PartialWitness::::new(); + samp_circ.sample_slot_assign_witness(&mut pw,&inner_tar,&one_circ_input); + let inner_data = sampling_builder.build::(); + println!("sampling circuit degree bits = {:?}", inner_data.common.degree_bits()); + let inner_proof = inner_data.prove(pw)?; + + // ------------------- leaf -------------------- + // leaf circuit that verifies the sampling proof + let inner_circ = SamplingRecursion::default(); + let leaf_circuit = LeafCircuit::new(inner_circ); + + let leaf_in = LeafInput{ + inner_proof, + verifier_data: inner_data.verifier_data(), + }; + let config = CircuitConfig::standard_recursion_config(); + let mut leaf_builder = CircuitBuilder::::new(config); + // build + let s = Instant::now(); + let leaf_targets = leaf_circuit.build(&mut leaf_builder)?; + let leaf_circ_data = leaf_builder.build::(); + println!("build = {:?}", s.elapsed()); + println!("leaf circuit degree bits = {:?}", leaf_circ_data.common.degree_bits()); + // prove + let s = Instant::now(); + let mut pw = PartialWitness::::new(); + leaf_circuit.assign_targets(&mut pw, &leaf_targets, &leaf_in)?; + let leaf_proof = leaf_circ_data.prove(pw)?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", leaf_proof.public_inputs.len()); + println!("pub input: {:?}", leaf_proof.public_inputs); + // verify + let s = Instant::now(); + assert!( + leaf_circ_data.verify(leaf_proof.clone()).is_ok(), + "proof verification failed" + ); + println!("verify = {:?}", s.elapsed()); + + // ------------- tree circuit ------------------ + // node circuit that verifies leafs or itself + // build + let s = Instant::now(); + let mut tree = TR::::build()?; + println!("build = {:?}", s.elapsed()); + println!("tree circuit degree bits = {:?}", tree.node.node_data.node_circuit_data.common.degree_bits()); + + // prove leaf + let s = Instant::now(); + // let mut pw = PartialWitness::::new(); + let leaf_proofs: Vec> = (0..K) + .map(|_| { + leaf_proof.clone() + }) + .collect::>(); + + let tree_root_proof = tree.prove_tree(leaf_proofs)?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", tree_root_proof.public_inputs.len()); + println!("pub input: {:?}", tree_root_proof.public_inputs); + let s = Instant::now(); + assert!( + tree.verify_proof(tree_root_proof.clone()).is_ok(), + "proof verification failed" + ); + println!("verify = {:?}", s.elapsed()); + + Ok(()) + } +} \ No newline at end of file diff --git a/proof-input/src/tree_recursion.rs b/proof-input/src/tree_recursion.rs deleted file mode 100644 index 96395aa..0000000 --- a/proof-input/src/tree_recursion.rs +++ /dev/null @@ -1,68 +0,0 @@ -// some tests for cyclic recursion - -#[cfg(test)] -mod tests { - use std::time::Instant; - use anyhow::{anyhow, Result}; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData}; - use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; - use codex_plonky2_circuits::circuits::sample_cells::SampleCircuitInput; - use codex_plonky2_circuits::recursion::params::{F, D, C, Plonky2Proof}; - use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion; - use codex_plonky2_circuits::recursion::inner_circuit::InnerCircuit; - use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; - use crate::gen_input::gen_testing_circuit_input; - use crate::params::TestParams; - use codex_plonky2_circuits::recursion::tree_recursion::{NodeCircuit, TreeRecursion}; - - fn get_m_default_circ_input() -> [SampleCircuitInput; M]{ - let mut params = TestParams::default(); - params.n_samples = 10; - let one_circ_input = gen_testing_circuit_input::(¶ms); - let circ_input: [SampleCircuitInput; M] = (0..M) - .map(|_| one_circ_input.clone()) - .collect::>() - .try_into().unwrap(); - circ_input - } - - /// Uses node recursion to sample the dataset - #[test] - fn test_node_recursion() -> Result<()> { - // const D: usize = 2; - // type C = PoseidonGoldilocksConfig; - // type F = >::F; - const M: usize = 1; - const N: usize = 2; - - let config = CircuitConfig::standard_recursion_config(); - let mut builder = CircuitBuilder::::new(config); - let one = builder.one(); - - // Circuit that does the sampling - let inner_sampling_circuit = SamplingRecursion::default(); - - let mut cyclic_circ = NodeCircuit::<_,M,N>::new(inner_sampling_circuit); - let mut tree_circ = TreeRecursion::new(cyclic_circ); - let circ_input = get_m_default_circ_input::(); - - let s = Instant::now(); - tree_circ.build()?; - println!("build = {:?}", s.elapsed()); - let s = Instant::now(); - let proof = tree_circ.prove(&circ_input,None, true)?; - println!("prove = {:?}", s.elapsed()); - println!("num of pi = {}", proof.public_inputs.len()); - println!("pub input: {:?}", proof.public_inputs); - let s = Instant::now(); - assert!( - tree_circ.verify_proof(proof).is_ok(), - "proof verification failed" - ); - println!("verify = {:?}", s.elapsed()); - - Ok(()) - } - -} \ No newline at end of file diff --git a/workflow/Cargo.toml b/workflow/Cargo.toml index e0b05f4..b7aebc1 100644 --- a/workflow/Cargo.toml +++ b/workflow/Cargo.toml @@ -52,4 +52,8 @@ harness = false [[bench]] name = "tree_recursion" +harness = false + +[[bench]] +name = "simple_tree_recursion" harness = false \ No newline at end of file diff --git a/workflow/benches/simple_tree_recursion.rs b/workflow/benches/simple_tree_recursion.rs new file mode 100644 index 0000000..e76bd24 --- /dev/null +++ b/workflow/benches/simple_tree_recursion.rs @@ -0,0 +1,58 @@ +use criterion::{Criterion, criterion_group, criterion_main}; +use plonky2::plonk::circuit_data::VerifierCircuitData; +use plonky2::plonk::config::GenericConfig; +use plonky2::plonk::proof::ProofWithPublicInputs; +use codex_plonky2_circuits::recursion::simple_tree_recursion::aggregate_sampling_proofs_tree2; +use proof_input::params::{C, D, F}; +use proof_input::gen_input::{build_circuit, prove_circuit}; + +/// Benchmark for building, proving, and verifying the Plonky2 recursion circuit. +fn bench_tree_recursion(c: &mut Criterion) { + // num of inner proofs + let num_of_inner_proofs = 4; + // number of samples in each proof + let n_samples = 10; + + let (data, pw) = build_circuit(n_samples, 3).unwrap(); + + // get proofs + let mut proofs_with_pi = vec![]; + for i in 0..num_of_inner_proofs{ + proofs_with_pi.push(prove_circuit(&data, &pw).unwrap()); + } + let vd = data.verifier_data(); + + let mut group = c.benchmark_group("bench simple tree recursion"); + let mut agg_proof_with_pis: Option> = None; + let mut agg_vd: Option> = None; + + // Benchmark the Circuit Building Phase + group.bench_function("build & prove Circuit", |b| { + b.iter(|| { + let (agg_p, agg_d) = aggregate_sampling_proofs_tree2(&proofs_with_pi, vd.clone()).unwrap(); + agg_proof_with_pis = Some(agg_p); + agg_vd = Some(agg_d); + }) + }); + + let proof = agg_proof_with_pis.unwrap(); + println!("Proof size: {} bytes", proof.to_bytes().len()); + + // Benchmark the Verifying Phase + let loc_vd = agg_vd.unwrap(); + group.bench_function("Verify Proof", |b| { + b.iter(|| { + loc_vd.clone().verify(proof.clone()).expect("Failed to verify proof"); + }) + }); + + group.finish(); +} + +/// Criterion benchmark group +criterion_group!{ + name = recursion; + config = Criterion::default().sample_size(10); + targets = bench_tree_recursion +} +criterion_main!(recursion); diff --git a/workflow/benches/tree_recursion.rs b/workflow/benches/tree_recursion.rs index fcb1b9b..677a26c 100644 --- a/workflow/benches/tree_recursion.rs +++ b/workflow/benches/tree_recursion.rs @@ -1,49 +1,80 @@ -use anyhow::Result; -use criterion::{criterion_group, criterion_main, Criterion}; -use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; +use criterion::{Criterion, criterion_group, criterion_main}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData}; use plonky2::plonk::config::GenericConfig; -use codex_plonky2_circuits::recursion::simple_recursion::{aggregate_sampling_proofs_tree, aggregate_sampling_proofs_tree2}; use plonky2::plonk::proof::ProofWithPublicInputs; -use proof_input::params::{D, C, F, Params, TestParams}; -use proof_input::gen_input::{build_circuit, prove_circuit}; +use codex_plonky2_circuits::circuits::sample_cells::SampleCircuitInput; +use codex_plonky2_circuits::recursion::sampling_inner_circuit::SamplingRecursion; +use codex_plonky2_circuits::recursion::tree_recursion::{NodeCircuit, TreeRecursion}; +use proof_input::params::{C, D, F, TestParams}; +use proof_input::gen_input::gen_testing_circuit_input; -/// Benchmark for building, proving, and verifying the Plonky2 recursion circuit. +fn get_m_default_circ_input() -> [SampleCircuitInput; M]{ + let mut params = TestParams::default(); + params.n_samples = 10; + let one_circ_input = gen_testing_circuit_input::(¶ms); + let circ_input: [SampleCircuitInput; M] = (0..M) + .map(|_| one_circ_input.clone()) + .collect::>() + .try_into().unwrap(); + circ_input +} + +/// Benchmark for building, proving, and verifying the Plonky2 tree recursion circuit. fn bench_tree_recursion(c: &mut Criterion) { - // num of inner proofs - let num_of_inner_proofs = 4; - // number of samples in each proof - let n_samples = 10; - let (data, pw) = build_circuit(n_samples, 3).unwrap(); + const M: usize = 1; + const N: usize = 2; - // get proofs - let mut proofs_with_pi = vec![]; - for i in 0..num_of_inner_proofs{ - proofs_with_pi.push(prove_circuit(&data, &pw).unwrap()); - } - let vd = data.verifier_data(); + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + // Circuit that does the sampling + let inner_sampling_circuit = SamplingRecursion::default(); + + let mut cyclic_circ = NodeCircuit::<_,M,N>::new(inner_sampling_circuit); + let mut tree_circ = TreeRecursion::new(cyclic_circ); + let circ_input = get_m_default_circ_input::(); + + // // get proofs + // let mut proofs_with_pi = vec![]; + // for i in 0..num_of_inner_proofs{ + // proofs_with_pi.push(prove_circuit(&data, &pw).unwrap()); + // } + // let vd = data.verifier_data(); let mut group = c.benchmark_group("bench simple tree recursion"); let mut agg_proof_with_pis: Option> = None; let mut agg_vd: Option> = None; // Benchmark the Circuit Building Phase - group.bench_function("build & prove Circuit", |b| { + group.bench_function("build", |b| { b.iter(|| { - let (agg_p, agg_d) = aggregate_sampling_proofs_tree2(&proofs_with_pi, vd.clone()).unwrap(); - agg_proof_with_pis = Some(agg_p); - agg_vd = Some(agg_d); + let mut cyclic_circ = NodeCircuit::<_,M,N>::new(inner_sampling_circuit.clone()); + let mut tree_circ = TreeRecursion::new(cyclic_circ); + tree_circ.build(); }) }); - let proof = agg_proof_with_pis.unwrap(); - println!("Proof size: {} bytes", proof.to_bytes().len()); + // let proof = agg_proof_with_pis.unwrap(); + // println!("Proof size: {} bytes", proof.to_bytes().len()); + + let mut proof: Option> = None; + + // Benchmark the Circuit prove Phase + group.bench_function("prove", |b| { + b.iter(|| { + proof = Some(tree_circ.prove(&circ_input,None, true).unwrap()); + }) + }); + + // let proof = tree_circ.prove(&circ_input,None, true)?; // Benchmark the Verifying Phase let loc_vd = agg_vd.unwrap(); group.bench_function("Verify Proof", |b| { b.iter(|| { - loc_vd.clone().verify(proof.clone()).expect("Failed to verify proof"); + tree_circ.verify_proof(proof.unwrap()).expect("Failed to verify proof"); }) }); @@ -53,7 +84,7 @@ fn bench_tree_recursion(c: &mut Criterion) { /// Criterion benchmark group criterion_group!{ name = recursion; - config = Criterion::default().sample_size(10); + config = Criterion::default().sample_size(3); targets = bench_tree_recursion } criterion_main!(recursion);