From 1f85bd8d5f63577623d4fa97ed69faed79d65c01 Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Fri, 17 Jan 2025 10:06:09 +0100 Subject: [PATCH] add hybrid tests and refactor --- proof-input/src/recursion/hybrid.rs | 124 ++++++++++++++++++++++++++++ proof-input/src/recursion/mod.rs | 1 + proof-input/src/recursion/tree2.rs | 27 +++--- 3 files changed, 139 insertions(+), 13 deletions(-) create mode 100644 proof-input/src/recursion/hybrid.rs diff --git a/proof-input/src/recursion/hybrid.rs b/proof-input/src/recursion/hybrid.rs new file mode 100644 index 0000000..ff23090 --- /dev/null +++ b/proof-input/src/recursion/hybrid.rs @@ -0,0 +1,124 @@ +// some tests for approach 2 of the tree recursion + +#[cfg(test)] +mod tests { + use std::time::Instant; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, Hasher}; + use plonky2::plonk::proof::{ProofWithPublicInputs}; + use codex_plonky2_circuits::circuits::sample_cells::SampleCircuit; + use crate::params::{F, D, C, HF}; + use codex_plonky2_circuits::recursion::circuits::sampling_inner_circuit::SamplingRecursion; + use codex_plonky2_circuits::recursion::circuits::inner_circuit::InnerCircuit; + use codex_plonky2_circuits::recursion::circuits::leaf_circuit::{LeafCircuit}; + // use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; + use crate::gen_input::gen_testing_circuit_input; + use crate::params::Params; + use codex_plonky2_circuits::recursion::hybrid::tree_circuit::HybridTreeRecursion; + + + #[test] + fn test_hybrid_recursion() -> anyhow::Result<()> { + const N: usize = 2; // binary tree + const M: usize = 4; // number of proofs in leaves + const K: usize = 8; + + let config = CircuitConfig::standard_recursion_config(); + let mut sampling_builder = CircuitBuilder::::new(config); + + //------------ sampling inner circuit ---------------------- + // Circuit that does the sampling - default input + let mut params = Params::default(); + let one_circ_input = gen_testing_circuit_input::(¶ms.input_params); + let samp_circ = SampleCircuit::::new(params.circuit_params); + let inner_tar = samp_circ.sample_slot_circuit_with_public_input(&mut sampling_builder)?; + // get generate a sampling proof + let mut pw = PartialWitness::::new(); + samp_circ.sample_slot_assign_witness(&mut pw,&inner_tar,&one_circ_input)?; + let inner_data = sampling_builder.build::(); + println!("sampling circuit degree bits = {:?}", inner_data.common.degree_bits()); + let inner_proof = inner_data.prove(pw)?; + + // ------------------- leaf -------------------- + // leaf circuit that verifies the sampling proof + let inner_circ = SamplingRecursion::::new(Params::default().circuit_params); + let leaf_circuit = LeafCircuit::::new(inner_circ); + + // ------------- tree circuit ------------------ + + let mut tree = HybridTreeRecursion::::new(leaf_circuit); + + // prepare input + let input_proofs: Vec> = (0..K) + .map(|_| { + inner_proof.clone() + }) + .collect::>(); + + // prove tree + + let s = Instant::now(); + let (tree_root_proof, verifier_data) = tree.prove_tree::(&input_proofs, inner_data.verifier_data())?; + println!("prove = {:?}", s.elapsed()); + println!("num of pi = {}", tree_root_proof.public_inputs.len()); + let s = Instant::now(); + assert!( + verifier_data.verify(tree_root_proof.clone()).is_ok(), + "proof verification failed" + ); + + assert_eq!( + tree_root_proof.public_inputs[0..4].to_vec(), + get_expected_tree_root_pi_hash::(input_proofs), + "Public input of tree_root_proof does not match the expected root hash" + ); + println!("verify = {:?}", s.elapsed()); + + Ok(()) + } + + // ------------ Public Input Verification ------------ + /// Recompute the expected root public input hash outside the circuit + fn get_expected_tree_root_pi_hash(input_proofs: Vec>) -> Vec{ + // Compute the leaf hashes + + let mut current_hashes = vec![]; + for chunk in input_proofs.chunks(M){ + let chunk_f: Vec = chunk.iter() + .flat_map(|p| p.public_inputs.iter().cloned()) + .collect(); + + let hash = HF::hash_no_pad(&chunk_f); + current_hashes.push(hash); + } + + // compute parent hashes until one root hash remains + while current_hashes.len() > 1 { + let mut next_level_hashes = Vec::new(); + + for chunk in current_hashes.chunks(N) { + // Ensure each chunk has exactly N elements + assert!( + chunk.len() == N, + "Number of proofs is not divisible by N" + ); + + // collect field elements + let chunk_f: Vec = chunk.iter() + .flat_map(|h| h.elements.iter().cloned()) + .collect(); + + // Compute Poseidon2 hash of the concatenated chunk + let hash = HF::hash_no_pad(&chunk_f); + next_level_hashes.push(hash); + } + + current_hashes = next_level_hashes; + } + + //the expected root hash + current_hashes[0].elements.to_vec() + } +} \ No newline at end of file diff --git a/proof-input/src/recursion/mod.rs b/proof-input/src/recursion/mod.rs index bc53c20..1c2b066 100644 --- a/proof-input/src/recursion/mod.rs +++ b/proof-input/src/recursion/mod.rs @@ -3,3 +3,4 @@ pub mod simple_tree; pub mod cyclic_recursion; pub mod tree1; pub mod tree2; +mod hybrid; diff --git a/proof-input/src/recursion/tree2.rs b/proof-input/src/recursion/tree2.rs index 8bf683a..1c2dc0b 100644 --- a/proof-input/src/recursion/tree2.rs +++ b/proof-input/src/recursion/tree2.rs @@ -13,7 +13,7 @@ mod tests { use crate::params::{F, D, C, HF}; use codex_plonky2_circuits::recursion::circuits::sampling_inner_circuit::SamplingRecursion; use codex_plonky2_circuits::recursion::circuits::inner_circuit::InnerCircuit; - use codex_plonky2_circuits::recursion::tree2::leaf_circuit::{LeafCircuit, LeafInput}; + use codex_plonky2_circuits::recursion::circuits::leaf_circuit::{LeafCircuit, LeafInput}; // use plonky2_poseidon2::poseidon2_hash::poseidon2::{Poseidon2, Poseidon2Hash}; use crate::gen_input::gen_testing_circuit_input; use crate::params::Params; @@ -40,15 +40,15 @@ mod tests { let inner_d = builder.build::(); let inner_prf = inner_d.prove(pw)?; - let leaf_in = LeafInput{ - inner_proof:inner_prf, + let leaf_in = LeafInput::{ + inner_proof:[inner_prf; M], verifier_data: inner_d.verifier_data(), }; let config2 = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config2); let inner_circ = SamplingRecursion::::new(Params::default().circuit_params); - let leaf_circuit = LeafCircuit::::new(inner_circ); + let leaf_circuit = LeafCircuit::::new(inner_circ); let s = Instant::now(); let leaf_tar = leaf_circuit.build::(&mut builder)?; @@ -73,6 +73,7 @@ mod tests { #[test] fn test_node_circuit_approach2() -> anyhow::Result<()> { + const M: usize = 1; const N: usize = 2; // binary tree let config = CircuitConfig::standard_recursion_config(); @@ -93,10 +94,10 @@ mod tests { // ------------------- leaf -------------------- // leaf circuit that verifies the sampling proof let inner_circ = SamplingRecursion::::new(Params::default().circuit_params); - let leaf_circuit = LeafCircuit::::new(inner_circ); + let leaf_circuit = LeafCircuit::::new(inner_circ); - let leaf_in = LeafInput{ - inner_proof, + let leaf_in = LeafInput::{ + inner_proof:[inner_proof; M], verifier_data: inner_data.verifier_data(), }; let config = CircuitConfig::standard_recursion_config(); @@ -126,7 +127,7 @@ mod tests { // node circuit that verifies leafs or itself // build let s = Instant::now(); - let mut node = NodeCircuit::::build_circuit::<_,HF>(leaf_circuit)?; + let mut node = NodeCircuit::::build_circuit::<_,HF,M>(leaf_circuit)?; println!("build = {:?}", s.elapsed()); println!("leaf circuit size = {:?}", node.node_data.node_circuit_data.common.degree_bits()); @@ -204,6 +205,7 @@ mod tests { #[test] fn test_tree_recursion_approach2() -> anyhow::Result<()> { + const M: usize = 1; const N: usize = 2; // binary tree const K: usize = 4; // number of leaves/slots sampled - should be power of 2 @@ -226,10 +228,10 @@ mod tests { // ------------------- leaf -------------------- // leaf circuit that verifies the sampling proof let inner_circ = SamplingRecursion::::new(Params::default().circuit_params); - let leaf_circuit = LeafCircuit::::new(inner_circ); + let leaf_circuit = LeafCircuit::::new(inner_circ); - let leaf_in = LeafInput{ - inner_proof, + let leaf_in = LeafInput::{ + inner_proof:[inner_proof; M], verifier_data: inner_data.verifier_data(), }; let config = CircuitConfig::standard_recursion_config(); @@ -259,7 +261,7 @@ mod tests { // node circuit that verifies leafs or itself // build let s = Instant::now(); - let mut tree = TreeRecursion::::build::<_,HF>(leaf_circuit)?; + let mut tree = TreeRecursion::::build::<_,HF, M>(leaf_circuit)?; println!("build = {:?}", s.elapsed()); println!("node circuit degree bits = {:?}", tree.node.node_data.node_circuit_data.common.degree_bits()); @@ -295,7 +297,6 @@ mod tests { fn get_expected_tree_root_pi_hash(leaf_proofs: Vec>) -> Vec{ // Step 1: Extract relevant public inputs from each leaf proof - // Assuming the first public input is the hash used for tree hashing let mut current_hashes: Vec> = leaf_proofs .iter() .map(|p|HashOut::from_vec(p.public_inputs.clone())) // Adjust index if different