diff --git a/plonky2/src/recursion/conditional_recursive_verifier.rs b/plonky2/src/recursion/conditional_recursive_verifier.rs index b4da993e..510bb438 100644 --- a/plonky2/src/recursion/conditional_recursive_verifier.rs +++ b/plonky2/src/recursion/conditional_recursive_verifier.rs @@ -182,7 +182,7 @@ impl, const D: usize> CircuitBuilder { .collect() } - fn select_hash( + pub(crate) fn select_hash( &mut self, b: BoolTarget, h0: HashOutTarget, diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs index 6260a584..e968e5e7 100644 --- a/plonky2/src/recursion/cyclic_recursion.rs +++ b/plonky2/src/recursion/cyclic_recursion.rs @@ -118,12 +118,19 @@ impl CyclicPublicInputsTarget { impl, const D: usize> CircuitBuilder { pub fn cyclic_recursion>( mut self, + previous_virtual_public_inputs: &[Target], + previous_base_case: Target, mut common_data: CommonCircuitData, ) -> Result<(CircuitData, CyclicRecursionTarget)> where C::Hasher: AlgebraicHasher, [(); C::Hasher::HASH_SIZE]:, { + ensure!( + previous_virtual_public_inputs.len() == self.num_public_inputs(), + "Incorrect number of public inputs." + ); + let verifier_data = VerifierCircuitTarget { constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), circuit_digest: self.add_virtual_hash(), @@ -168,6 +175,14 @@ impl, const D: usize> CircuitBuilder { self.connect_hashes(*h0, *h1); } + self.connect(previous_base_case, pis.base_case.target); + for (x, y) in previous_virtual_public_inputs + .iter() + .zip(&proof.public_inputs) + { + self.connect(*x, *y); + } + // Verify the dummy proof if `base_case` is set to true, otherwise verify the "real" proof. self.conditionally_verify_proof( base_case, @@ -212,6 +227,8 @@ pub fn set_cyclic_recursion_data_target< pw: &mut PartialWitness, cyclic_recursion_data_target: &CyclicRecursionTarget, cyclic_recursion_data: &CyclicRecursionData, + // Public inputs to set in the base case to seed some initial data. + public_inputs: &[F], ) -> Result<()> where C::Hasher: AlgebraicHasher, @@ -233,6 +250,7 @@ where let (dummy_proof, dummy_data) = dummy_proof(cyclic_recursion_data.common_data)?; pw.set_bool_target(cyclic_recursion_data_target.base_case, true); let mut proof = dummy_proof.clone(); + proof.public_inputs[0..public_inputs.len()].copy_from_slice(public_inputs); let pis_len = proof.public_inputs.len(); // A base case must be following another base case. proof.public_inputs[pis_len - 1] = F::ONE; @@ -293,13 +311,17 @@ where #[cfg(test)] mod tests { + use anyhow::Result; use plonky2_field::extension::Extendable; + use plonky2_field::types::PrimeField64; use crate::field::types::Field; use crate::hash::hash_types::RichField; - use crate::hash::poseidon::PoseidonHash; - use crate::iop::witness::{PartialWitness, Witness}; + use crate::hash::hashing::hash_n_to_hash_no_pad; + use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation}; + use crate::iop::target::BoolTarget; + use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget}; use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig}; @@ -351,23 +373,54 @@ mod tests { let mut pw = PartialWitness::new(); let mut builder = CircuitBuilder::::new(config); - // Build realistic circuit - let t = builder.add_virtual_target(); - pw.set_target(t, F::rand()); - let t_inv = builder.inverse(t); - let h = builder.hash_n_to_hash_no_pad::(vec![t_inv]); + // Circuit that computes a repeated hash. + let initial_hash = builder.add_virtual_hash(); + builder.register_public_inputs(&initial_hash.elements); + // Hash from the previous proof. + let old_hash = builder.add_virtual_hash(); + // Flag set to true if the last proof was a base case. + let old_base_case = builder.add_virtual_target(); + // The input hash is either the previous hash or the initial hash depending on whether + // the last proof was a base case. + let input_hash = builder.select_hash( + BoolTarget::new_unsafe(old_base_case), + initial_hash, + old_hash, + ); + let h = builder.hash_n_to_hash_no_pad::(input_hash.elements.to_vec()); builder.register_public_inputs(&h.elements); + // Previous counter. + let old_counter = builder.add_virtual_target(); + let one = builder.one(); + let old_not_base_case = builder.sub(one, old_base_case); + // New counter is the previous counter +1 if the previous proof wasn't a base case. + let new_counter = builder.add(old_counter, old_not_base_case); + builder.register_public_input(new_counter); + let old_pis = [ + initial_hash.elements.as_slice(), + old_hash.elements.as_slice(), + [old_counter].as_slice(), + ] + .concat(); let common_data = common_data_for_recursion::(); // Add cyclic recursion gadget. - let (cyclic_circuit_data, cyclic_data_target) = builder.cyclic_recursion(common_data)?; + let (cyclic_circuit_data, cyclic_data_target) = + builder.cyclic_recursion(&old_pis, old_base_case, common_data)?; + let cyclic_recursion_data = CyclicRecursionData { proof: &None, // Base case: We don't have a proof to put here yet. verifier_data: &cyclic_circuit_data.verifier_only, common_data: &cyclic_circuit_data.common, }; - set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?; + let initial_hash = [F::ZERO, F::ONE, F::TWO, F::from_canonical_usize(3)]; + set_cyclic_recursion_data_target( + &mut pw, + &cyclic_data_target, + &cyclic_recursion_data, + &initial_hash, + )?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( &proof, @@ -378,13 +431,17 @@ mod tests { // 1st recursive layer. let mut pw = PartialWitness::new(); - pw.set_target(t, F::rand()); let cyclic_recursion_data = CyclicRecursionData { proof: &Some(proof), // Input previous proof. verifier_data: &cyclic_circuit_data.verifier_only, common_data: &cyclic_circuit_data.common, }; - set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?; + set_cyclic_recursion_data_target( + &mut pw, + &cyclic_data_target, + &cyclic_recursion_data, + &[], + )?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( &proof, @@ -395,19 +452,39 @@ mod tests { // 2nd recursive layer. let mut pw = PartialWitness::new(); - pw.set_target(t, F::rand()); let cyclic_recursion_data = CyclicRecursionData { proof: &Some(proof), // Input previous proof. verifier_data: &cyclic_circuit_data.verifier_only, common_data: &cyclic_circuit_data.common, }; - set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?; + set_cyclic_recursion_data_target( + &mut pw, + &cyclic_data_target, + &cyclic_recursion_data, + &[], + )?; let proof = cyclic_circuit_data.prove(pw)?; check_cyclic_proof_verifier_data( &proof, cyclic_recursion_data.verifier_data, cyclic_recursion_data.common_data, )?; + + // Verify that the proof correctly computes a repeated hash. + let initial_hash = &proof.public_inputs[..4]; + let hash = &proof.public_inputs[4..8]; + let counter = proof.public_inputs[8]; + let mut h: [F; 4] = initial_hash.try_into().unwrap(); + assert_eq!( + hash, + std::iter::repeat_with(|| { + h = hash_n_to_hash_no_pad::(&h).elements; + h + }) + .nth(counter.to_canonical_u64() as usize) + .unwrap() + ); + cyclic_circuit_data.verify(proof) } }