diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs index 4302023e..539b5de3 100644 --- a/ecdsa/src/gadgets/glv.rs +++ b/ecdsa/src/gadgets/glv.rs @@ -55,8 +55,8 @@ impl, const D: usize> CircuitBuilderGlv ) { let k1 = self.add_virtual_nonnative_target_sized::(4); let k2 = self.add_virtual_nonnative_target_sized::(4); - let k1_neg = self.add_virtual_bool_target(); - let k2_neg = self.add_virtual_bool_target(); + let k1_neg = self.add_virtual_bool_target_unsafe(); + let k2_neg = self.add_virtual_bool_target_unsafe(); self.add_simple_generator(GLVDecompositionGenerator:: { k: k.clone(), diff --git a/ecdsa/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs index c6ff4753..29520bed 100644 --- a/ecdsa/src/gadgets/nonnative.rs +++ b/ecdsa/src/gadgets/nonnative.rs @@ -183,7 +183,7 @@ impl, const D: usize> CircuitBuilderNonNative b: &NonNativeTarget, ) -> NonNativeTarget { let sum = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_bool_target(); + let overflow = self.add_virtual_bool_target_unsafe(); self.add_simple_generator(NonNativeAdditionGenerator:: { a: a.clone(), @@ -282,7 +282,7 @@ impl, const D: usize> CircuitBuilderNonNative b: &NonNativeTarget, ) -> NonNativeTarget { let diff = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_bool_target(); + let overflow = self.add_virtual_bool_target_unsafe(); self.add_simple_generator(NonNativeSubtractionGenerator:: { a: a.clone(), diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index f4722df4..33facd74 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -345,7 +345,7 @@ impl, const D: usize> CircuitBuilder { pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget { let zero = self.zero(); - let equal = self.add_virtual_bool_target(); + let equal = self.add_virtual_bool_target_unsafe(); let not_equal = self.not(equal); let inv = self.add_virtual_target(); self.add_simple_generator(EqualityGenerator { x, y, equal, inv }); diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index d5a748e5..24e83d01 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -202,8 +202,7 @@ impl, const D: usize> CircuitBuilder { PolynomialCoeffsExtTarget(coeffs) } - // TODO: Unsafe - pub fn add_virtual_bool_target(&mut self) -> BoolTarget { + pub fn add_virtual_bool_target_unsafe(&mut self) -> BoolTarget { BoolTarget::new_unsafe(self.add_virtual_target()) } diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs index 0a3308e6..abf7f7a2 100644 --- a/plonky2/src/recursion/cyclic_recursion.rs +++ b/plonky2/src/recursion/cyclic_recursion.rs @@ -1,9 +1,12 @@ -use anyhow::Result; +use anyhow::{ensure, Result}; +use itertools::Itertools; use plonky2_field::extension::Extendable; +use plonky2_field::types::Field; use crate::gates::noop::NoopGate; -use crate::hash::hash_types::RichField; -use crate::iop::target::BoolTarget; +use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget, RichField}; +use crate::hash::merkle_tree::MerkleCap; +use crate::iop::target::{BoolTarget, Target}; use crate::iop::witness::{PartialWitness, Witness}; use crate::plonk::circuit_builder::CircuitBuilder; use crate::plonk::circuit_data::{ @@ -33,6 +36,85 @@ pub struct CyclicRecursionTarget { pub base_case: BoolTarget, } +pub struct CyclicPublicInputs< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + pub circuit_digest: HashOut, + pub constants_sigmas_cap: MerkleCap, + pub base_case: bool, +} + +impl, C: GenericConfig, const D: usize> + CyclicPublicInputs +{ + fn from_slice(slice: &[F], common_data: &CommonCircuitData) -> Result + where + C::Hasher: AlgebraicHasher, + { + // The structure of the public inputs is `[...,circuit_digest, constants_sigmas_cap, base_case]`. + let cap_len = common_data.config.fri_config.num_cap_elements(); + let len = slice.len(); + ensure!(len >= 4 + 4 * cap_len + 1, "Not enough public inputs"); + let base_case = slice[len - 1]; + ensure!( + base_case.is_one() || base_case.is_zero(), + "Base case flag {:?} is not binary", + base_case + ); + let constants_sigmas_cap = MerkleCap( + (0..cap_len) + .map(|i| HashOut { + elements: std::array::from_fn(|j| slice[len - 1 - 4 * (cap_len - i) + j]), + }) + .collect(), + ); + let circuit_digest = + HashOut::from_partial(&slice[len - 5 - 4 * cap_len..len - 1 - 4 * cap_len]); + + Ok(Self { + circuit_digest, + constants_sigmas_cap, + base_case: base_case.is_one(), + }) + } +} + +pub struct CyclicPublicInputsTarget { + pub circuit_digest: HashOutTarget, + pub constants_sigmas_cap: MerkleCapTarget, + pub base_case: Target, +} + +impl CyclicPublicInputsTarget { + fn from_slice, C: GenericConfig, const D: usize>( + slice: &[Target], + common_data: &CommonCircuitData, + ) -> Result { + let cap_len = common_data.config.fri_config.num_cap_elements(); + let len = slice.len(); + ensure!(len >= 4 + 4 * cap_len + 1, "Not enough public inputs"); + let base_case = slice[len - 1]; + let constants_sigmas_cap = MerkleCapTarget( + (0..cap_len) + .map(|i| HashOutTarget { + elements: std::array::from_fn(|j| slice[len - 1 - 4 * (cap_len - i) + j]), + }) + .collect(), + ); + let circuit_digest = HashOutTarget { + elements: std::array::from_fn(|i| slice[len - 5 - 4 * cap_len + i]), + }; + + Ok(Self { + circuit_digest, + constants_sigmas_cap, + base_case, + }) + } +} + impl, const D: usize> CircuitBuilder { pub fn cyclic_recursion>( mut self, @@ -46,24 +128,39 @@ impl, const D: usize> CircuitBuilder { constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), circuit_digest: self.add_virtual_hash(), }; + // The verifier data are public inputs. self.register_public_inputs(&verifier_data.circuit_digest.elements); for i in 0..self.config.fri_config.num_cap_elements() { self.register_public_inputs(&verifier_data.constants_sigmas_cap.0[i].elements); } + let dummy_verifier_data = VerifierCircuitTarget { constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), circuit_digest: self.add_virtual_hash(), }; - let base_case = self.add_virtual_bool_target(); + // Unsafe is ok since `base_case` is a public input and its booleaness should be checked in the verifier. + let base_case = self.add_virtual_bool_target_unsafe(); self.register_public_input(base_case.target); common_data.num_public_inputs = self.num_public_inputs(); + // The `conditionally_verify_proof` gadget below takes 2^12 gates, so `degree_bits` cannot be smaller than 13. common_data.degree_bits = common_data.degree_bits.max(13); common_data.fri_params.degree_bits = common_data.fri_params.degree_bits.max(13); let proof = self.add_virtual_proof_with_pis(&common_data); let dummy_proof = self.add_virtual_proof_with_pis(&common_data); + let pis = CyclicPublicInputsTarget::from_slice(&proof.public_inputs, &common_data)?; + self.connect_hashes(pis.circuit_digest, verifier_data.circuit_digest); + for (h0, h1) in pis + .constants_sigmas_cap + .0 + .iter() + .zip_eq(&verifier_data.constants_sigmas_cap.0) + { + self.connect_hashes(*h0, *h1); + } + self.conditionally_verify_proof( base_case, &dummy_proof, @@ -107,7 +204,6 @@ pub fn set_cyclic_recursion_data_target< cyclic_recursion_data: &CyclicRecursionData, ) -> Result<()> where - F: RichField + Extendable, C::Hasher: AlgebraicHasher, [(); C::Hasher::HASH_SIZE]:, { @@ -124,12 +220,29 @@ where cyclic_recursion_data.verifier_data, ); } else { + dbg!("hi"); let (dummy_proof, dummy_data) = dummy_proof(cyclic_recursion_data.common_data)?; pw.set_bool_target(cyclic_recursion_data_target.base_case, true); - pw.set_proof_with_pis_target(&cyclic_recursion_data_target.proof, &dummy_proof); + let mut dummy_proof_real_vd = dummy_proof.clone(); + let pis_len = dummy_proof_real_vd.public_inputs.len(); + let num_cap = cyclic_recursion_data + .common_data + .config + .fri_config + .num_cap_elements(); + let s = pis_len - 5 - 4 * num_cap; + dummy_proof_real_vd.public_inputs[s..s + 4] + .copy_from_slice(&cyclic_recursion_data.verifier_data.circuit_digest.elements); + for i in 0..num_cap { + dummy_proof_real_vd.public_inputs[s + 4 * (1 + i)..s + 4 * (2 + i)].copy_from_slice( + &cyclic_recursion_data.verifier_data.constants_sigmas_cap.0[i].elements, + ); + } + pw.set_proof_with_pis_target(&cyclic_recursion_data_target.proof, &dummy_proof_real_vd); + dbg!(cyclic_recursion_data.verifier_data.circuit_digest); pw.set_verifier_data_target( &cyclic_recursion_data_target.verifier_data, - &dummy_data.verifier_only, + cyclic_recursion_data.verifier_data, ); pw.set_proof_with_pis_target(&cyclic_recursion_data_target.dummy_proof, &dummy_proof); pw.set_verifier_data_target( @@ -141,6 +254,29 @@ where Ok(()) } +pub fn check_cyclic_proof_verifier_data< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + proof: &ProofWithPublicInputs, + verifier_data: &VerifierOnlyCircuitData, + common_data: &CommonCircuitData, +) -> Result<()> +where + C::Hasher: AlgebraicHasher, +{ + let pis = CyclicPublicInputs::from_slice(&proof.public_inputs, common_data)?; + dbg!(pis.circuit_digest); + dbg!(verifier_data.circuit_digest); + if !pis.base_case { + ensure!(verifier_data.constants_sigmas_cap == pis.constants_sigmas_cap); + ensure!(verifier_data.circuit_digest == pis.circuit_digest); + } + + Ok(()) +} + #[cfg(test)] mod tests { use anyhow::Result; @@ -154,7 +290,7 @@ mod tests { use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget}; use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig}; use crate::recursion::cyclic_recursion::{ - set_cyclic_recursion_data_target, CyclicRecursionData, + check_cyclic_proof_verifier_data, set_cyclic_recursion_data_target, CyclicRecursionData, }; fn common_data_for_recursion< @@ -183,7 +319,7 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let config = CircuitConfig::standard_recursion_config(); - let mut pw = PartialWitness::::new(); + let pw = PartialWitness::::new(); let mut builder = CircuitBuilder::::new(config); let proof = builder.add_virtual_proof_with_pis(&data.common); let verifier_data = VerifierCircuitTarget { @@ -220,8 +356,47 @@ mod tests { common_data: &cyclic_circuit_data.common, }; set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?; + dbg!("yo"); let proof = cyclic_circuit_data.prove(pw)?; - cyclic_circuit_data.verify(proof); + check_cyclic_proof_verifier_data( + &proof, + &cyclic_recursion_data.verifier_data, + cyclic_recursion_data.common_data, + )?; + cyclic_circuit_data.verify(proof.clone())?; + + let mut pw = PartialWitness::new(); + pw.set_target(t, F::rand()); + let cyclic_recursion_data = CyclicRecursionData { + proof: &Some(proof), + verifier_data: &cyclic_circuit_data.verifier_only, + common_data: &cyclic_circuit_data.common, + }; + set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?; + dbg!("yo"); + let proof = cyclic_circuit_data.prove(pw)?; + check_cyclic_proof_verifier_data( + &proof, + &cyclic_recursion_data.verifier_data, + cyclic_recursion_data.common_data, + )?; + cyclic_circuit_data.verify(proof.clone())?; + + let mut pw = PartialWitness::new(); + pw.set_target(t, F::rand()); + let cyclic_recursion_data = CyclicRecursionData { + proof: &Some(proof), + verifier_data: &cyclic_circuit_data.verifier_only, + common_data: &cyclic_circuit_data.common, + }; + set_cyclic_recursion_data_target(&mut pw, &cyclic_data_target, &cyclic_recursion_data)?; + let proof = cyclic_circuit_data.prove(pw)?; + check_cyclic_proof_verifier_data( + &proof, + &cyclic_recursion_data.verifier_data, + cyclic_recursion_data.common_data, + )?; + cyclic_circuit_data.verify(proof.clone())?; Ok(()) }