add hybrid recursion

This commit is contained in:
M Alghazwi 2025-01-17 10:05:04 +01:00
parent eaf7e65c3d
commit 00ee710354
No known key found for this signature in database
GPG Key ID: 646E567CAD7DB607
9 changed files with 294 additions and 26 deletions

View File

@ -9,13 +9,15 @@ use plonky2_field::extension::Extendable;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2; use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use crate::recursion::circuits::inner_circuit::InnerCircuit; use crate::recursion::circuits::inner_circuit::InnerCircuit;
use crate::{error::CircuitError,Result}; use crate::{error::CircuitError,Result};
use crate::circuits::utils::vec_to_array;
/// recursion leaf circuit for the recursion tree circuit /// recursion leaf circuit for the recursion tree circuit
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct LeafCircuit< pub struct LeafCircuit<
F: RichField + Extendable<D> + Poseidon2, F: RichField + Extendable<D> + Poseidon2,
const D: usize, const D: usize,
I: InnerCircuit<F, D> I: InnerCircuit<F, D>,
const M: usize,
> { > {
pub inner_circ: I, pub inner_circ: I,
phantom_data: PhantomData<F> phantom_data: PhantomData<F>
@ -24,8 +26,9 @@ pub struct LeafCircuit<
impl< impl<
F: RichField + Extendable<D> + Poseidon2, F: RichField + Extendable<D> + Poseidon2,
const D: usize, const D: usize,
I: InnerCircuit<F, D> I: InnerCircuit<F, D>,
> LeafCircuit<F,D,I> { const M: usize,
> LeafCircuit<F,D,I, M> {
pub fn new(inner_circ: I) -> Self { pub fn new(inner_circ: I) -> Self {
Self{ Self{
inner_circ, inner_circ,
@ -36,8 +39,9 @@ impl<
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct LeafTargets < pub struct LeafTargets <
const D: usize, const D: usize,
const M: usize
>{ >{
pub inner_proof: ProofWithPublicInputsTarget<D>, pub inner_proof: [ProofWithPublicInputsTarget<D>; M],
pub verifier_data: VerifierCircuitTarget, pub verifier_data: VerifierCircuitTarget,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -45,8 +49,9 @@ pub struct LeafInput<
F: RichField + Extendable<D> + Poseidon2, F: RichField + Extendable<D> + Poseidon2,
const D: usize, const D: usize,
C: GenericConfig<D, F = F>, C: GenericConfig<D, F = F>,
const M: usize,
>{ >{
pub inner_proof: ProofWithPublicInputs<F, C, D>, pub inner_proof: [ProofWithPublicInputs<F, C, D>; M],
pub verifier_data: VerifierCircuitData<F, C, D> pub verifier_data: VerifierCircuitData<F, C, D>
} }
@ -54,13 +59,14 @@ impl<
F: RichField + Extendable<D> + Poseidon2, F: RichField + Extendable<D> + Poseidon2,
const D: usize, const D: usize,
I: InnerCircuit<F, D>, I: InnerCircuit<F, D>,
> LeafCircuit<F,D,I>{ const M: usize,
> LeafCircuit<F,D,I, M>{
/// build the leaf circuit /// build the leaf circuit
pub fn build< pub fn build<
C: GenericConfig<D, F = F>, C: GenericConfig<D, F = F>,
H: AlgebraicHasher<F>, H: AlgebraicHasher<F>,
>(&self, builder: &mut CircuitBuilder<F, D>) -> Result<LeafTargets<D>> >(&self, builder: &mut CircuitBuilder<F, D>) -> Result<LeafTargets<D,M>>
where where
<C as GenericConfig<D>>::Hasher: AlgebraicHasher<F> <C as GenericConfig<D>>::Hasher: AlgebraicHasher<F>
{ {
@ -68,23 +74,32 @@ impl<
let common = self.inner_circ.get_common_data()?; let common = self.inner_circ.get_common_data()?;
// the proof virtual targets - only one for now // the proof virtual targets - only one for now
// TODO: make it M proofs let mut vir_proofs = vec![];
let vir_proof = builder.add_virtual_proof_with_pis(&common); let mut pub_input = vec![];
for _i in 0..M {
let vir_proof = builder.add_virtual_proof_with_pis(&common);
let inner_pub_input = vir_proof.public_inputs.clone();
vir_proofs.push(vir_proof);
pub_input.extend_from_slice(&inner_pub_input);
}
// hash the public input & make it public // hash the public input & make it public
let inner_pub_input = vir_proof.public_inputs.clone(); let hash_inner_pub_input = builder.hash_n_to_hash_no_pad::<H>(pub_input);
let hash_inner_pub_input = builder.hash_n_to_hash_no_pad::<H>(inner_pub_input);
builder.register_public_inputs(&hash_inner_pub_input.elements); builder.register_public_inputs(&hash_inner_pub_input.elements);
// virtual target for the verifier data // virtual target for the verifier data
let inner_verifier_data = builder.add_virtual_verifier_data(common.config.fri_config.cap_height); let inner_verifier_data = builder.add_virtual_verifier_data(common.config.fri_config.cap_height);
// verify the proofs in-circuit (only one now) // verify the proofs in-circuit (only one now)
builder.verify_proof::<C>(&vir_proof.clone(),&inner_verifier_data,&common); for i in 0..M {
builder.verify_proof::<C>(&vir_proofs[i], &inner_verifier_data, &common);
}
let proofs = vec_to_array::<M, ProofWithPublicInputsTarget<D>>(vir_proofs)?;
// return targets // return targets
let t = LeafTargets { let t = LeafTargets {
inner_proof: vir_proof, inner_proof: proofs,
verifier_data: inner_verifier_data, verifier_data: inner_verifier_data,
}; };
Ok(t) Ok(t)
@ -95,15 +110,17 @@ impl<
pub fn assign_targets< pub fn assign_targets<
C: GenericConfig<D, F = F>, C: GenericConfig<D, F = F>,
H: AlgebraicHasher<F>, H: AlgebraicHasher<F>,
>(&self, pw: &mut PartialWitness<F>, targets: &LeafTargets<D>, input: &LeafInput<F, D, C>) -> Result<()> >(&self, pw: &mut PartialWitness<F>, targets: &LeafTargets<D,M>, input: &LeafInput<F, D, C, M>) -> Result<()>
where where
<C as GenericConfig<D>>::Hasher: AlgebraicHasher<F> <C as GenericConfig<D>>::Hasher: AlgebraicHasher<F>
{ {
// assign the proof // assign the proofs
pw.set_proof_with_pis_target(&targets.inner_proof,&input.inner_proof) for i in 0..M {
.map_err(|e| { pw.set_proof_with_pis_target(&targets.inner_proof[i], &input.inner_proof[i])
CircuitError::ProofTargetAssignmentError("inner-proof".to_string(), e.to_string()) .map_err(|e| {
})?; CircuitError::ProofTargetAssignmentError("inner-proof".to_string(), e.to_string())
})?;
}
// assign the verifier data // assign the verifier data
pw.set_verifier_data_target(&targets.verifier_data, &input.verifier_data.verifier_only) pw.set_verifier_data_target(&targets.verifier_data, &input.verifier_data.verifier_only)

View File

@ -1,2 +1,3 @@
pub mod inner_circuit; pub mod inner_circuit;
pub mod sampling_inner_circuit; pub mod sampling_inner_circuit;
pub mod leaf_circuit;

View File

@ -0,0 +1,2 @@
pub mod node_circuit;
pub mod tree_circuit;

View File

@ -0,0 +1,104 @@
use std::marker::PhantomData;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CommonCircuitData, VerifierCircuitData, VerifierCircuitTarget};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig};
use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget};
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use plonky2_field::extension::Extendable;
use crate::circuits::utils::{vec_to_array};
use crate::{error::CircuitError, Result};
/// Node circuit struct
/// contains necessary data
/// N: number of proofs verified in-circuit (so num of child nodes)
pub struct NodeCircuit<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
C: GenericConfig<D, F = F>,
const N: usize,
>{
phantom_data: PhantomData<(F,C)>
}
/// Node circuit targets
/// assumes that all proofs use the same verifier data
#[derive(Clone, Debug)]
pub struct NodeCircuitTargets<
const D: usize,
const N: usize,
>{
pub proof_targets: [ProofWithPublicInputsTarget<D>; N],
pub verifier_data_target: VerifierCircuitTarget,
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
C: GenericConfig<D, F = F> + 'static,
const N: usize,
> NodeCircuit<F, D, C, N>
where
<C as GenericConfig<D>>::Hasher: AlgebraicHasher<F>
{
/// builds the node circuit
pub fn build_circuit<
H: AlgebraicHasher<F>,
>(
builder: &mut CircuitBuilder<F, D>,
common_data: &CommonCircuitData<F, D>,
) -> Result<(NodeCircuitTargets<D, N>)>{
// the proof virtual targets
let mut proof_targets = vec![];
let mut inner_pub_input = vec![];
for _i in 0..N {
let vir_proof = builder.add_virtual_proof_with_pis(common_data);
// collect the public input
inner_pub_input.extend_from_slice(&vir_proof.public_inputs);
// collect the proof targets
proof_targets.push(vir_proof);
}
// hash the public input & make it public
let hash_inner_pub_input = builder.hash_n_to_hash_no_pad::<H>(inner_pub_input);
builder.register_public_inputs(&hash_inner_pub_input.elements);
// virtual target for the verifier data
let inner_verifier_data = builder.add_virtual_verifier_data(common_data.config.fri_config.cap_height);
// verify the proofs in-circuit
for i in 0..N {
builder.verify_proof::<C>(&proof_targets[i],&inner_verifier_data,&common_data);
}
let proof_target_array = vec_to_array::<N,ProofWithPublicInputsTarget<D>>(proof_targets)?;
Ok(NodeCircuitTargets{
proof_targets: proof_target_array,
verifier_data_target: inner_verifier_data,
})
}
/// assigns the targets for the Node circuit
pub fn assign_targets(
node_targets: NodeCircuitTargets<D, N>,
proofs_with_pi: &[ProofWithPublicInputs<F, C, D>; N],
verifier_data: &VerifierCircuitData<F, C, D>,
pw: &mut PartialWitness<F>,
) -> Result<()>{
for i in 0..N{
pw.set_proof_with_pis_target(&node_targets.proof_targets[i],&proofs_with_pi[i])
.map_err(|e| {
CircuitError::ProofTargetAssignmentError(format!("proof {}", i), e.to_string())
})?;
}
// assign the verifier data
pw.set_verifier_data_target(&node_targets.verifier_data_target, &verifier_data.verifier_only)
.map_err(|e| {
CircuitError::VerifierDataTargetAssignmentError(e.to_string())
})?;
Ok(())
}
}

View File

@ -0,0 +1,142 @@
use plonky2::hash::hash_types::RichField;
use plonky2::iop::witness::PartialWitness;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData};
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig};
use plonky2::plonk::proof::ProofWithPublicInputs;
use plonky2_poseidon2::poseidon2_hash::poseidon2::Poseidon2;
use crate::recursion::circuits::inner_circuit::InnerCircuit;
use plonky2_field::extension::Extendable;
use crate::{error::CircuitError, Result};
use crate::circuits::utils::vec_to_array;
use crate::recursion::circuits::leaf_circuit::{LeafCircuit, LeafInput};
use crate::recursion::hybrid::node_circuit::NodeCircuit;
/// Hybrid tree recursion - combines simple and tree recursion
/// - N: number of leaf proofs to verify in the node circuit
/// - M: number of inner proofs to verify in the leaf circuit
pub struct HybridTreeRecursion<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
I: InnerCircuit<F, D>,
const N: usize,
const M: usize,
> {
pub leaf: LeafCircuit<F, D, I, M>,
}
impl<
F: RichField + Extendable<D> + Poseidon2,
const D: usize,
I: InnerCircuit<F, D>,
const N: usize,
const M: usize,
> HybridTreeRecursion<F, D, I, N, M>
{
pub fn new(
leaf: LeafCircuit<F, D, I, M>
) -> Self {
Self{
leaf,
}
}
pub fn prove_tree<
C: GenericConfig<D, F = F> + 'static,
H: AlgebraicHasher<F>,
>(
&mut self,
proofs_with_pi: &[ProofWithPublicInputs<F, C, D>],
inner_verifier_data: VerifierCircuitData<F, C, D>,
) -> Result<(ProofWithPublicInputs<F, C, D>, VerifierCircuitData<F, C, D>)> where
<C as GenericConfig<D>>::Hasher: AlgebraicHasher<F>
{
// process leaves
let (leaf_proofs, leaf_data) = self.get_leaf_proofs::<C,H>(
proofs_with_pi,
inner_verifier_data
)?;
// process nodes
let (root_proof, last_verifier_data) = self.prove::<C,H>(&leaf_proofs,leaf_data.verifier_data())?;
Ok((root_proof, last_verifier_data))
}
fn get_leaf_proofs<
C: GenericConfig<D, F = F> + 'static,
H: AlgebraicHasher<F>,
>(
&mut self,
proofs_with_pi: &[ProofWithPublicInputs<F, C, D>],
inner_verifier_data: VerifierCircuitData<F, C, D>,
) -> Result<(Vec<ProofWithPublicInputs<F, C, D>>, CircuitData<F, C, D>)> where
<C as GenericConfig<D>>::Hasher: AlgebraicHasher<F>{
// builder with standard recursion config
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, D>::new(config);
let leaf_targets = self.leaf.build::<C,H>(&mut builder)?;
let leaf_data = builder.build::<C>();
let mut leaf_proofs = vec![];
for chunk in proofs_with_pi.chunks(M){
let mut pw = PartialWitness::<F>::new();
let chunk_arr = vec_to_array::<M,ProofWithPublicInputs<F, C, D>>(chunk.to_vec())?;
let leaf_in = LeafInput{
inner_proof: chunk_arr,
verifier_data: inner_verifier_data.clone(),
};
self.leaf.assign_targets::<C,H>(&mut pw,&leaf_targets,&leaf_in)?;
let proof = leaf_data.prove(pw).unwrap();
leaf_proofs.push(proof);
}
Ok((leaf_proofs, leaf_data))
}
/// generates a proof - only one node
/// takes N proofs
fn prove<
C: GenericConfig<D, F = F> + 'static,
H: AlgebraicHasher<F>,
>(
&mut self,
proofs_with_pi: &[ProofWithPublicInputs<F, C, D>],
verifier_data: VerifierCircuitData<F, C, D>,
) -> Result<(ProofWithPublicInputs<F, C, D>, VerifierCircuitData<F, C, D>)> where
<C as GenericConfig<D>>::Hasher: AlgebraicHasher<F>
{
if proofs_with_pi.len() == 1 {
return Ok((proofs_with_pi[0].clone(), verifier_data));
}
let mut new_proofs = vec![];
let node_config = CircuitConfig::standard_recursion_config();
let mut node_builder = CircuitBuilder::<F, D>::new(node_config);
let node_targets = NodeCircuit::<F,D,C,N>::build_circuit::<H>(&mut node_builder, &verifier_data.common)?;
let node_data = node_builder.build::<C>();
for chunk in proofs_with_pi.chunks(N) {
let chunk_arr = vec_to_array::<N,ProofWithPublicInputs<F, C, D>>(chunk.to_vec())?;
let mut inner_pw = PartialWitness::new();
NodeCircuit::<F,D,C,N>::assign_targets(node_targets.clone(),&chunk_arr,&verifier_data, &mut inner_pw)?;
let proof = node_data.prove(inner_pw)
.map_err(|e| CircuitError::ProofGenerationError(e.to_string()))?;
new_proofs.push(proof);
}
self.prove::<C,H>(&new_proofs, node_data.verifier_data())
}
}

View File

@ -3,3 +3,4 @@ pub mod circuits;
pub mod simple; pub mod simple;
pub mod tree1; pub mod tree1;
pub mod tree2; pub mod tree2;
pub mod hybrid;

View File

@ -1,4 +1,3 @@
pub mod leaf_circuit;
pub mod dummy_gen; pub mod dummy_gen;
pub mod node_circuit; pub mod node_circuit;
pub mod tree_circuit; pub mod tree_circuit;

View File

@ -12,7 +12,7 @@ use crate::recursion::circuits::inner_circuit::InnerCircuit;
use plonky2_field::extension::Extendable; use plonky2_field::extension::Extendable;
use crate::circuits::utils::{select_hash, select_vec, vec_to_array}; use crate::circuits::utils::{select_hash, select_vec, vec_to_array};
use crate::{error::CircuitError, Result}; use crate::{error::CircuitError, Result};
use crate::recursion::tree2::leaf_circuit::LeafCircuit; use crate::recursion::circuits::leaf_circuit::LeafCircuit;
/// Node circuit struct /// Node circuit struct
/// contains necessary data /// contains necessary data
@ -67,9 +67,10 @@ impl<
/// TODO: make generic recursion config /// TODO: make generic recursion config
pub fn build_circuit< pub fn build_circuit<
I: InnerCircuit<F, D>, I: InnerCircuit<F, D>,
H: AlgebraicHasher<F> H: AlgebraicHasher<F>,
const M: usize,
>( >(
leaf_circuit:LeafCircuit<F, D, I> leaf_circuit: LeafCircuit<F, D, I, M>
) -> Result<NodeCircuit<F, D, C, N>>{ ) -> Result<NodeCircuit<F, D, C, N>>{
// builder with standard recursion config // builder with standard recursion config

View File

@ -9,7 +9,7 @@ use plonky2_field::extension::Extendable;
use crate::recursion::tree2::dummy_gen::DummyProofGen; use crate::recursion::tree2::dummy_gen::DummyProofGen;
use crate::{error::CircuitError, Result}; use crate::{error::CircuitError, Result};
use crate::circuits::utils::vec_to_array; use crate::circuits::utils::vec_to_array;
use crate::recursion::tree2::leaf_circuit::LeafCircuit; use crate::recursion::circuits::leaf_circuit::LeafCircuit;
use crate::recursion::tree2::node_circuit::NodeCircuit; use crate::recursion::tree2::node_circuit::NodeCircuit;
/// the tree recursion struct simplifies the process /// the tree recursion struct simplifies the process
@ -40,12 +40,13 @@ impl<
pub fn build< pub fn build<
I: InnerCircuit<F, D>, I: InnerCircuit<F, D>,
H: AlgebraicHasher<F>, H: AlgebraicHasher<F>,
const M: usize,
>( >(
leaf_circuit: LeafCircuit<F, D, I> leaf_circuit: LeafCircuit<F, D, I, M>
) -> Result<Self>{ ) -> Result<Self>{
Ok( Ok(
Self{ Self{
node: NodeCircuit::<F, D, C, N>::build_circuit::<I,H>(leaf_circuit)?, node: NodeCircuit::<F, D, C, N>::build_circuit::<I,H, M>(leaf_circuit)?,
} }
) )
} }