From 341e1ebeec714c79f3af2ec144d161eb5236f9ed Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 26 Oct 2022 10:58:38 +0200 Subject: [PATCH] Working --- plonky2/src/recursion/cyclic_recursion.rs | 116 ++++++++-------------- 1 file changed, 40 insertions(+), 76 deletions(-) diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs index 10fe9dd0..2b3693e1 100644 --- a/plonky2/src/recursion/cyclic_recursion.rs +++ b/plonky2/src/recursion/cyclic_recursion.rs @@ -10,7 +10,7 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::{ - CircuitData, CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, + CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, }; use crate::plonk::config::Hasher; use crate::plonk::config::{AlgebraicHasher, GenericConfig}; @@ -43,7 +43,6 @@ pub struct CyclicPublicInputs< > { pub circuit_digest: HashOut, pub constants_sigmas_cap: MerkleCap, - pub base_case: bool, } impl, C: GenericConfig, const D: usize> @@ -53,30 +52,23 @@ impl, C: GenericConfig, const D: usize> where C::Hasher: AlgebraicHasher, { - // The structure of the public inputs is `[...,circuit_digest, constants_sigmas_cap, base_case]`. + // The structure of the public inputs is `[..., circuit_digest, constants_sigmas_cap]`. let cap_len = common_data.config.fri_config.num_cap_elements(); let len = slice.len(); - ensure!(len >= 4 + 4 * cap_len + 1, "Not enough public inputs"); - let base_case = slice[len - 1]; - ensure!( - base_case.is_one() || base_case.is_zero(), - "Base case flag {:?} is not binary", - base_case - ); + ensure!(len >= 4 + 4 * cap_len, "Not enough public inputs"); let constants_sigmas_cap = MerkleCap( (0..cap_len) .map(|i| HashOut { - elements: std::array::from_fn(|j| slice[len - 1 - 4 * (cap_len - i) + j]), + elements: std::array::from_fn(|j| slice[len - 4 * (cap_len - i) + j]), }) .collect(), ); let circuit_digest = - HashOut::from_partial(&slice[len - 5 - 4 * cap_len..len - 1 - 4 * cap_len]); + HashOut::from_partial(&slice[len - 4 - 4 * cap_len..len - 4 * cap_len]); Ok(Self { circuit_digest, constants_sigmas_cap, - base_case: base_case.is_one(), }) } } @@ -84,7 +76,6 @@ impl, C: GenericConfig, const D: usize> pub struct CyclicPublicInputsTarget { pub circuit_digest: HashOutTarget, pub constants_sigmas_cap: MerkleCapTarget, - pub base_case: BoolTarget, } impl CyclicPublicInputsTarget { @@ -94,34 +85,31 @@ impl CyclicPublicInputsTarget { ) -> Result { let cap_len = common_data.config.fri_config.num_cap_elements(); let len = slice.len(); - ensure!(len >= 4 + 4 * cap_len + 1, "Not enough public inputs"); - let base_case = BoolTarget::new_unsafe(slice[len - 1]); + ensure!(len >= 4 + 4 * cap_len, "Not enough public inputs"); let constants_sigmas_cap = MerkleCapTarget( (0..cap_len) .map(|i| HashOutTarget { - elements: std::array::from_fn(|j| slice[len - 1 - 4 * (cap_len - i) + j]), + elements: std::array::from_fn(|j| slice[len - 4 * (cap_len - i) + j]), }) .collect(), ); let circuit_digest = HashOutTarget { - elements: std::array::from_fn(|i| slice[len - 5 - 4 * cap_len + i]), + elements: std::array::from_fn(|i| slice[len - 4 - 4 * cap_len + i]), }; Ok(Self { circuit_digest, constants_sigmas_cap, - base_case, }) } } impl, const D: usize> CircuitBuilder { pub fn cyclic_recursion>( - mut self, + &mut self, previous_virtual_public_inputs: &[Target], - previous_base_case: Target, - mut common_data: CommonCircuitData, - ) -> Result<(CircuitData, CyclicRecursionTarget)> + common_data: &mut CommonCircuitData, + ) -> Result> where C::Hasher: AlgebraicHasher, [(); C::Hasher::HASH_SIZE]:, @@ -147,22 +135,15 @@ impl, const D: usize> CircuitBuilder { }; // Flag set to true for the base case of the cycle where we verify a dummy proof to bootstrap the cycle. Set to false otherwise. - // Unsafe is ok since `base_case` is a public input and its booleaness should be checked in the verifier. - let base_case = self.add_virtual_bool_target_unsafe(); - self.register_public_input(base_case.target); + let base_case = self.add_virtual_bool_target_safe(); common_data.num_public_inputs = self.num_public_inputs(); - let proof = self.add_virtual_proof_with_pis::(&common_data); - let dummy_proof = self.add_virtual_proof_with_pis::(&common_data); + let proof = self.add_virtual_proof_with_pis::(common_data); + let dummy_proof = self.add_virtual_proof_with_pis::(common_data); let pis = - CyclicPublicInputsTarget::from_slice::(&proof.public_inputs, &common_data)?; - // Check that the previous base case flag was boolean. - self.assert_bool(pis.base_case); - // Check that we cannot go from a non-base case to a base case by checking `previous_base_case - base_case \in {0,1}`. - let decrease = BoolTarget::new_unsafe(self.sub(pis.base_case.target, base_case.target)); - self.assert_bool(decrease); + CyclicPublicInputsTarget::from_slice::(&proof.public_inputs, common_data)?; // Connect previous verifier data to current one. This guarantees that every proof in the cycle uses the same verifier data. self.connect_hashes(pis.circuit_digest, verifier_data.circuit_digest); for (h0, h1) in pis @@ -174,7 +155,6 @@ impl, const D: usize> CircuitBuilder { self.connect_hashes(*h0, *h1); } - self.connect(previous_base_case, pis.base_case.target); for (x, y) in previous_virtual_public_inputs .iter() .zip(&proof.public_inputs) @@ -189,7 +169,7 @@ impl, const D: usize> CircuitBuilder { &dummy_verifier_data, &proof, &verifier_data, - &common_data, + common_data, ); // Make sure we have enough gates to match `common_data`. @@ -201,24 +181,13 @@ impl, const D: usize> CircuitBuilder { self.add_gate_to_gate_set(g.clone()); } - let data = self.build::(); - ensure!( - data.common == common_data, - "Common data does not match. Final circuit has common data {:?} instead of {:?}.", - data.common, - common_data - ); - - Ok(( - data, - CyclicRecursionTarget { - proof, - verifier_data, - dummy_proof, - dummy_verifier_data, - base_case, - }, - )) + Ok(CyclicRecursionTarget { + proof, + verifier_data, + dummy_proof, + dummy_verifier_data, + base_case, + }) } } @@ -256,8 +225,6 @@ where let mut proof = dummy_proof.clone(); proof.public_inputs[0..public_inputs.len()].copy_from_slice(public_inputs); let pis_len = proof.public_inputs.len(); - // A base case must be following another base case. - proof.public_inputs[pis_len - 1] = F::ONE; // The circuit checks that the verifier data is the same throughout the cycle, so // we set the verifier data to the "real" verifier data even though it's unused in the base case. let num_cap = cyclic_recursion_data @@ -265,7 +232,7 @@ where .config .fri_config .num_cap_elements(); - let s = pis_len - 5 - 4 * num_cap; + let s = pis_len - 4 - 4 * num_cap; proof.public_inputs[s..s + 4] .copy_from_slice(&cyclic_recursion_data.verifier_data.circuit_digest.elements); for i in 0..num_cap { @@ -305,10 +272,8 @@ where C::Hasher: AlgebraicHasher, { let pis = CyclicPublicInputs::::from_slice(&proof.public_inputs, common_data)?; - if !pis.base_case { - ensure!(verifier_data.constants_sigmas_cap == pis.constants_sigmas_cap); - ensure!(verifier_data.circuit_digest == pis.circuit_digest); - } + ensure!(verifier_data.constants_sigmas_cap == pis.constants_sigmas_cap); + ensure!(verifier_data.circuit_digest == pis.circuit_digest); Ok(()) } @@ -325,7 +290,6 @@ mod tests { use crate::hash::hash_types::RichField; use crate::hash::hashing::hash_n_to_hash_no_pad; use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation}; - use crate::iop::target::BoolTarget; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget}; @@ -386,24 +350,15 @@ mod tests { builder.register_public_inputs(&initial_hash.elements); // Hash from the previous proof. let old_hash = builder.add_virtual_hash(); - // Flag set to true if the last proof was a base case. - let old_base_case = builder.add_virtual_target(); // The input hash is either the previous hash or the initial hash depending on whether // the last proof was a base case. - let input_hash = builder.select_hash( - BoolTarget::new_unsafe(old_base_case), - initial_hash, - old_hash, - ); + let input_hash = builder.add_virtual_hash(); let h = builder.hash_n_to_hash_no_pad::(input_hash.elements.to_vec()); builder.register_public_inputs(&h.elements); // Previous counter. let old_counter = builder.add_virtual_target(); let one = builder.one(); - let old_not_base_case = builder.sub(one, old_base_case); - // New counter is the previous counter +1 if the previous proof wasn't a base case. - let new_counter = builder.add(old_counter, old_not_base_case); - builder.register_public_input(new_counter); + let new_counter = builder.add_virtual_public_input(); let old_pis = [ initial_hash.elements.as_slice(), old_hash.elements.as_slice(), @@ -411,11 +366,19 @@ mod tests { ] .concat(); - let common_data = common_data_for_recursion::(); + let mut common_data = common_data_for_recursion::(); // Add cyclic recursion gadget. - let (cyclic_circuit_data, cyclic_data_target) = - builder.cyclic_recursion::(&old_pis, old_base_case, common_data)?; + let cyclic_data_target = builder.cyclic_recursion::(&old_pis, &mut common_data)?; + let input_hash_bis = + builder.select_hash(cyclic_data_target.base_case, initial_hash, old_hash); + builder.connect_hashes(input_hash, input_hash_bis); + let not_base_case = builder.sub(one, cyclic_data_target.base_case.target); + // New counter is the previous counter +1 if the previous proof wasn't a base case. + let new_counter_bis = builder.add(old_counter, not_base_case); + builder.connect(new_counter, new_counter_bis); + + let cyclic_circuit_data = builder.build::(); let cyclic_recursion_data = CyclicRecursionData { proof: &None, // Base case: We don't have a proof to put here yet. @@ -482,6 +445,7 @@ mod tests { let initial_hash = &proof.public_inputs[..4]; let hash = &proof.public_inputs[4..8]; let counter = proof.public_inputs[8]; + dbg!(counter); let mut h: [F; 4] = initial_hash.try_into().unwrap(); assert_eq!( hash,