From 00ee7103544c48ddc7d8b1fa946b7d9dda13f75a Mon Sep 17 00:00:00 2001 From: M Alghazwi Date: Fri, 17 Jan 2025 10:05:04 +0100 Subject: [PATCH] add hybrid recursion --- .../{tree2 => circuits}/leaf_circuit.rs | 55 ++++--- .../src/recursion/circuits/mod.rs | 1 + .../src/recursion/hybrid/mod.rs | 2 + .../src/recursion/hybrid/node_circuit.rs | 104 +++++++++++++ .../src/recursion/hybrid/tree_circuit.rs | 142 ++++++++++++++++++ codex-plonky2-circuits/src/recursion/mod.rs | 1 + .../src/recursion/tree2/mod.rs | 1 - .../src/recursion/tree2/node_circuit.rs | 7 +- .../src/recursion/tree2/tree_circuit.rs | 7 +- 9 files changed, 294 insertions(+), 26 deletions(-) rename codex-plonky2-circuits/src/recursion/{tree2 => circuits}/leaf_circuit.rs (72%) create mode 100644 codex-plonky2-circuits/src/recursion/hybrid/mod.rs create mode 100644 codex-plonky2-circuits/src/recursion/hybrid/node_circuit.rs create mode 100644 codex-plonky2-circuits/src/recursion/hybrid/tree_circuit.rs diff --git a/codex-plonky2-circuits/src/recursion/tree2/leaf_circuit.rs b/codex-plonky2-circuits/src/recursion/circuits/leaf_circuit.rs similarity index 72% rename from codex-plonky2-circuits/src/recursion/tree2/leaf_circuit.rs rename to codex-plonky2-circuits/src/recursion/circuits/leaf_circuit.rs index 1a396c9..17bd06d 100644 --- a/codex-plonky2-circuits/src/recursion/tree2/leaf_circuit.rs +++ b/codex-plonky2-circuits/src/recursion/circuits/leaf_circuit.rs @@ -9,13 +9,15 @@ use plonky2_field::extension::Extendable; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use crate::recursion::circuits::inner_circuit::InnerCircuit; use crate::{error::CircuitError,Result}; +use crate::circuits::utils::vec_to_array; /// recursion leaf circuit for the recursion tree circuit #[derive(Clone, Debug)] pub struct LeafCircuit< F: RichField + Extendable + Poseidon2, const D: usize, - I: InnerCircuit + I: InnerCircuit, + const M: usize, > { pub inner_circ: I, phantom_data: PhantomData @@ -24,8 +26,9 @@ pub struct LeafCircuit< impl< F: RichField + Extendable + Poseidon2, const D: usize, - I: InnerCircuit -> LeafCircuit { + I: InnerCircuit, + const M: usize, +> LeafCircuit { pub fn new(inner_circ: I) -> Self { Self{ inner_circ, @@ -36,8 +39,9 @@ impl< #[derive(Clone, Debug)] pub struct LeafTargets < const D: usize, + const M: usize >{ - pub inner_proof: ProofWithPublicInputsTarget, + pub inner_proof: [ProofWithPublicInputsTarget; M], pub verifier_data: VerifierCircuitTarget, } #[derive(Clone, Debug)] @@ -45,8 +49,9 @@ pub struct LeafInput< F: RichField + Extendable + Poseidon2, const D: usize, C: GenericConfig, + const M: usize, >{ - pub inner_proof: ProofWithPublicInputs, + pub inner_proof: [ProofWithPublicInputs; M], pub verifier_data: VerifierCircuitData } @@ -54,13 +59,14 @@ impl< F: RichField + Extendable + Poseidon2, const D: usize, I: InnerCircuit, -> LeafCircuit{ + const M: usize, +> LeafCircuit{ /// build the leaf circuit pub fn build< C: GenericConfig, H: AlgebraicHasher, - >(&self, builder: &mut CircuitBuilder) -> Result> + >(&self, builder: &mut CircuitBuilder) -> Result> where >::Hasher: AlgebraicHasher { @@ -68,23 +74,32 @@ impl< let common = self.inner_circ.get_common_data()?; // the proof virtual targets - only one for now - // TODO: make it M proofs - let vir_proof = builder.add_virtual_proof_with_pis(&common); + let mut vir_proofs = vec![]; + let mut pub_input = vec![]; + for _i in 0..M { + let vir_proof = builder.add_virtual_proof_with_pis(&common); + let inner_pub_input = vir_proof.public_inputs.clone(); + vir_proofs.push(vir_proof); + pub_input.extend_from_slice(&inner_pub_input); + } // hash the public input & make it public - let inner_pub_input = vir_proof.public_inputs.clone(); - let hash_inner_pub_input = builder.hash_n_to_hash_no_pad::(inner_pub_input); + let hash_inner_pub_input = builder.hash_n_to_hash_no_pad::(pub_input); builder.register_public_inputs(&hash_inner_pub_input.elements); // virtual target for the verifier data let inner_verifier_data = builder.add_virtual_verifier_data(common.config.fri_config.cap_height); // verify the proofs in-circuit (only one now) - builder.verify_proof::(&vir_proof.clone(),&inner_verifier_data,&common); + for i in 0..M { + builder.verify_proof::(&vir_proofs[i], &inner_verifier_data, &common); + } + + let proofs = vec_to_array::>(vir_proofs)?; // return targets let t = LeafTargets { - inner_proof: vir_proof, + inner_proof: proofs, verifier_data: inner_verifier_data, }; Ok(t) @@ -95,15 +110,17 @@ impl< pub fn assign_targets< C: GenericConfig, H: AlgebraicHasher, - >(&self, pw: &mut PartialWitness, targets: &LeafTargets, input: &LeafInput) -> Result<()> + >(&self, pw: &mut PartialWitness, targets: &LeafTargets, input: &LeafInput) -> Result<()> where >::Hasher: AlgebraicHasher { - // assign the proof - pw.set_proof_with_pis_target(&targets.inner_proof,&input.inner_proof) - .map_err(|e| { - CircuitError::ProofTargetAssignmentError("inner-proof".to_string(), e.to_string()) - })?; + // assign the proofs + for i in 0..M { + pw.set_proof_with_pis_target(&targets.inner_proof[i], &input.inner_proof[i]) + .map_err(|e| { + CircuitError::ProofTargetAssignmentError("inner-proof".to_string(), e.to_string()) + })?; + } // assign the verifier data pw.set_verifier_data_target(&targets.verifier_data, &input.verifier_data.verifier_only) diff --git a/codex-plonky2-circuits/src/recursion/circuits/mod.rs b/codex-plonky2-circuits/src/recursion/circuits/mod.rs index d803ff2..9c7c554 100644 --- a/codex-plonky2-circuits/src/recursion/circuits/mod.rs +++ b/codex-plonky2-circuits/src/recursion/circuits/mod.rs @@ -1,2 +1,3 @@ pub mod inner_circuit; pub mod sampling_inner_circuit; +pub mod leaf_circuit; diff --git a/codex-plonky2-circuits/src/recursion/hybrid/mod.rs b/codex-plonky2-circuits/src/recursion/hybrid/mod.rs new file mode 100644 index 0000000..26bc7e1 --- /dev/null +++ b/codex-plonky2-circuits/src/recursion/hybrid/mod.rs @@ -0,0 +1,2 @@ +pub mod node_circuit; +pub mod tree_circuit; \ No newline at end of file diff --git a/codex-plonky2-circuits/src/recursion/hybrid/node_circuit.rs b/codex-plonky2-circuits/src/recursion/hybrid/node_circuit.rs new file mode 100644 index 0000000..0d41668 --- /dev/null +++ b/codex-plonky2-circuits/src/recursion/hybrid/node_circuit.rs @@ -0,0 +1,104 @@ +use std::marker::PhantomData; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; +use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; +use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; +use plonky2_field::extension::Extendable; +use crate::circuits::utils::{vec_to_array}; +use crate::{error::CircuitError, Result}; + +/// Node circuit struct +/// contains necessary data +/// N: number of proofs verified in-circuit (so num of child nodes) +pub struct NodeCircuit< + F: RichField + Extendable + Poseidon2, + const D: usize, + C: GenericConfig, + const N: usize, +>{ + phantom_data: PhantomData<(F,C)> +} + +/// Node circuit targets +/// assumes that all proofs use the same verifier data +#[derive(Clone, Debug)] +pub struct NodeCircuitTargets< + const D: usize, + const N: usize, +>{ + pub proof_targets: [ProofWithPublicInputsTarget; N], + pub verifier_data_target: VerifierCircuitTarget, +} + +impl< + F: RichField + Extendable + Poseidon2, + const D: usize, + C: GenericConfig + 'static, + const N: usize, +> NodeCircuit + where + >::Hasher: AlgebraicHasher +{ + + /// builds the node circuit + pub fn build_circuit< + H: AlgebraicHasher, + >( + builder: &mut CircuitBuilder, + common_data: &CommonCircuitData, + ) -> Result<(NodeCircuitTargets)>{ + + // the proof virtual targets + let mut proof_targets = vec![]; + let mut inner_pub_input = vec![]; + for _i in 0..N { + let vir_proof = builder.add_virtual_proof_with_pis(common_data); + // collect the public input + inner_pub_input.extend_from_slice(&vir_proof.public_inputs); + // collect the proof targets + proof_targets.push(vir_proof); + } + // hash the public input & make it public + let hash_inner_pub_input = builder.hash_n_to_hash_no_pad::(inner_pub_input); + builder.register_public_inputs(&hash_inner_pub_input.elements); + + // virtual target for the verifier data + let inner_verifier_data = builder.add_virtual_verifier_data(common_data.config.fri_config.cap_height); + + // verify the proofs in-circuit + for i in 0..N { + builder.verify_proof::(&proof_targets[i],&inner_verifier_data,&common_data); + } + let proof_target_array = vec_to_array::>(proof_targets)?; + + Ok(NodeCircuitTargets{ + proof_targets: proof_target_array, + verifier_data_target: inner_verifier_data, + }) + } + + /// assigns the targets for the Node circuit + pub fn assign_targets( + node_targets: NodeCircuitTargets, + proofs_with_pi: &[ProofWithPublicInputs; N], + verifier_data: &VerifierCircuitData, + pw: &mut PartialWitness, + ) -> Result<()>{ + for i in 0..N{ + pw.set_proof_with_pis_target(&node_targets.proof_targets[i],&proofs_with_pi[i]) + .map_err(|e| { + CircuitError::ProofTargetAssignmentError(format!("proof {}", i), e.to_string()) + })?; + } + // assign the verifier data + pw.set_verifier_data_target(&node_targets.verifier_data_target, &verifier_data.verifier_only) + .map_err(|e| { + CircuitError::VerifierDataTargetAssignmentError(e.to_string()) + })?; + + Ok(()) + } +} \ No newline at end of file diff --git a/codex-plonky2-circuits/src/recursion/hybrid/tree_circuit.rs b/codex-plonky2-circuits/src/recursion/hybrid/tree_circuit.rs new file mode 100644 index 0000000..3756481 --- /dev/null +++ b/codex-plonky2-circuits/src/recursion/hybrid/tree_circuit.rs @@ -0,0 +1,142 @@ +use plonky2::hash::hash_types::RichField; +use plonky2::iop::witness::PartialWitness; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; +use plonky2::plonk::proof::ProofWithPublicInputs; +use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; +use crate::recursion::circuits::inner_circuit::InnerCircuit; +use plonky2_field::extension::Extendable; +use crate::{error::CircuitError, Result}; +use crate::circuits::utils::vec_to_array; +use crate::recursion::circuits::leaf_circuit::{LeafCircuit, LeafInput}; +use crate::recursion::hybrid::node_circuit::NodeCircuit; + +/// Hybrid tree recursion - combines simple and tree recursion +/// - N: number of leaf proofs to verify in the node circuit +/// - M: number of inner proofs to verify in the leaf circuit +pub struct HybridTreeRecursion< + F: RichField + Extendable + Poseidon2, + const D: usize, + I: InnerCircuit, + const N: usize, + const M: usize, +> { + pub leaf: LeafCircuit, +} + +impl< + F: RichField + Extendable + Poseidon2, + const D: usize, + I: InnerCircuit, + const N: usize, + const M: usize, +> HybridTreeRecursion +{ + + pub fn new( + leaf: LeafCircuit + ) -> Self { + Self{ + leaf, + } + } + + pub fn prove_tree< + C: GenericConfig + 'static, + H: AlgebraicHasher, + >( + &mut self, + proofs_with_pi: &[ProofWithPublicInputs], + inner_verifier_data: VerifierCircuitData, + ) -> Result<(ProofWithPublicInputs, VerifierCircuitData)> where + >::Hasher: AlgebraicHasher + { + // process leaves + let (leaf_proofs, leaf_data) = self.get_leaf_proofs::( + proofs_with_pi, + inner_verifier_data + )?; + + // process nodes + let (root_proof, last_verifier_data) = self.prove::(&leaf_proofs,leaf_data.verifier_data())?; + + Ok((root_proof, last_verifier_data)) + } + + + fn get_leaf_proofs< + C: GenericConfig + 'static, + H: AlgebraicHasher, + >( + &mut self, + proofs_with_pi: &[ProofWithPublicInputs], + inner_verifier_data: VerifierCircuitData, + ) -> Result<(Vec>, CircuitData)> where + >::Hasher: AlgebraicHasher{ + // builder with standard recursion config + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + let leaf_targets = self.leaf.build::(&mut builder)?; + let leaf_data = builder.build::(); + + let mut leaf_proofs = vec![]; + + for chunk in proofs_with_pi.chunks(M){ + let mut pw = PartialWitness::::new(); + let chunk_arr = vec_to_array::>(chunk.to_vec())?; + let leaf_in = LeafInput{ + inner_proof: chunk_arr, + verifier_data: inner_verifier_data.clone(), + }; + self.leaf.assign_targets::(&mut pw,&leaf_targets,&leaf_in)?; + let proof = leaf_data.prove(pw).unwrap(); + leaf_proofs.push(proof); + } + + Ok((leaf_proofs, leaf_data)) + } + + /// generates a proof - only one node + /// takes N proofs + fn prove< + C: GenericConfig + 'static, + H: AlgebraicHasher, + >( + &mut self, + proofs_with_pi: &[ProofWithPublicInputs], + verifier_data: VerifierCircuitData, + ) -> Result<(ProofWithPublicInputs, VerifierCircuitData)> where + >::Hasher: AlgebraicHasher + { + + if proofs_with_pi.len() == 1 { + return Ok((proofs_with_pi[0].clone(), verifier_data)); + } + + let mut new_proofs = vec![]; + + let node_config = CircuitConfig::standard_recursion_config(); + let mut node_builder = CircuitBuilder::::new(node_config); + let node_targets = NodeCircuit::::build_circuit::(&mut node_builder, &verifier_data.common)?; + let node_data = node_builder.build::(); + + for chunk in proofs_with_pi.chunks(N) { + + let chunk_arr = vec_to_array::>(chunk.to_vec())?; + + let mut inner_pw = PartialWitness::new(); + + NodeCircuit::::assign_targets(node_targets.clone(),&chunk_arr,&verifier_data, &mut inner_pw)?; + + let proof = node_data.prove(inner_pw) + .map_err(|e| CircuitError::ProofGenerationError(e.to_string()))?; + new_proofs.push(proof); + } + + self.prove::(&new_proofs, node_data.verifier_data()) + } + +} + diff --git a/codex-plonky2-circuits/src/recursion/mod.rs b/codex-plonky2-circuits/src/recursion/mod.rs index d924c51..482bdbf 100644 --- a/codex-plonky2-circuits/src/recursion/mod.rs +++ b/codex-plonky2-circuits/src/recursion/mod.rs @@ -3,3 +3,4 @@ pub mod circuits; pub mod simple; pub mod tree1; pub mod tree2; +pub mod hybrid; diff --git a/codex-plonky2-circuits/src/recursion/tree2/mod.rs b/codex-plonky2-circuits/src/recursion/tree2/mod.rs index 827e1e4..571a038 100644 --- a/codex-plonky2-circuits/src/recursion/tree2/mod.rs +++ b/codex-plonky2-circuits/src/recursion/tree2/mod.rs @@ -1,4 +1,3 @@ -pub mod leaf_circuit; pub mod dummy_gen; pub mod node_circuit; pub mod tree_circuit; diff --git a/codex-plonky2-circuits/src/recursion/tree2/node_circuit.rs b/codex-plonky2-circuits/src/recursion/tree2/node_circuit.rs index 3148476..fd48c0c 100644 --- a/codex-plonky2-circuits/src/recursion/tree2/node_circuit.rs +++ b/codex-plonky2-circuits/src/recursion/tree2/node_circuit.rs @@ -12,7 +12,7 @@ use crate::recursion::circuits::inner_circuit::InnerCircuit; use plonky2_field::extension::Extendable; use crate::circuits::utils::{select_hash, select_vec, vec_to_array}; use crate::{error::CircuitError, Result}; -use crate::recursion::tree2::leaf_circuit::LeafCircuit; +use crate::recursion::circuits::leaf_circuit::LeafCircuit; /// Node circuit struct /// contains necessary data @@ -67,9 +67,10 @@ impl< /// TODO: make generic recursion config pub fn build_circuit< I: InnerCircuit, - H: AlgebraicHasher + H: AlgebraicHasher, + const M: usize, >( - leaf_circuit:LeafCircuit + leaf_circuit: LeafCircuit ) -> Result>{ // builder with standard recursion config diff --git a/codex-plonky2-circuits/src/recursion/tree2/tree_circuit.rs b/codex-plonky2-circuits/src/recursion/tree2/tree_circuit.rs index bde3de6..538e6b0 100644 --- a/codex-plonky2-circuits/src/recursion/tree2/tree_circuit.rs +++ b/codex-plonky2-circuits/src/recursion/tree2/tree_circuit.rs @@ -9,7 +9,7 @@ use plonky2_field::extension::Extendable; use crate::recursion::tree2::dummy_gen::DummyProofGen; use crate::{error::CircuitError, Result}; use crate::circuits::utils::vec_to_array; -use crate::recursion::tree2::leaf_circuit::LeafCircuit; +use crate::recursion::circuits::leaf_circuit::LeafCircuit; use crate::recursion::tree2::node_circuit::NodeCircuit; /// the tree recursion struct simplifies the process @@ -40,12 +40,13 @@ impl< pub fn build< I: InnerCircuit, H: AlgebraicHasher, + const M: usize, >( - leaf_circuit: LeafCircuit + leaf_circuit: LeafCircuit ) -> Result{ Ok( Self{ - node: NodeCircuit::::build_circuit::(leaf_circuit)?, + node: NodeCircuit::::build_circuit::(leaf_circuit)?, } ) }