diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 4930321a..2d55c6db 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -8,7 +8,6 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; -use plonky2::iop::challenger::Challenger; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -17,10 +16,8 @@ use plonky2::plonk::config::GenericConfig; use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::permutation::{ - get_grand_product_challenge_set, GrandProductChallenge, GrandProductChallengeSet, -}; -use crate::proof::{StarkProof, StarkProofTarget}; +use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet}; +use crate::proof::{StarkProofTarget, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -163,6 +160,7 @@ pub struct CrossTableLookup { pub(crate) looking_tables: Vec>, pub(crate) looked_table: TableWithColumns, /// Default value if filters are not used. + // TODO: Remove? Ended up not using it. default: Option>, } @@ -234,13 +232,11 @@ impl CtlData { } } -pub fn cross_table_lookup_data, const D: usize>( - config: &StarkConfig, +pub(crate) fn cross_table_lookup_data, const D: usize>( trace_poly_values: &[Vec>; NUM_TABLES], cross_table_lookups: &[CrossTableLookup], - challenger: &mut Challenger, + ctl_challenges: &GrandProductChallengeSet, ) -> [CtlData; NUM_TABLES] { - let challenges = get_grand_product_challenge_set(challenger, config.num_challenges); let mut ctl_data_per_table = [0; NUM_TABLES].map(|_| CtlData::default()); for CrossTableLookup { looking_tables, @@ -249,7 +245,7 @@ pub fn cross_table_lookup_data, const D } in cross_table_lookups { log::debug!("Processing CTL for {:?}", looked_table.table); - for &challenge in &challenges.challenges { + for &challenge in &ctl_challenges.challenges { let zs_looking = looking_tables.iter().map(|table| { partial_products( &trace_poly_values[table.table as usize], @@ -358,7 +354,7 @@ impl<'a, F: RichField + Extendable, const D: usize> CtlCheckVars<'a, F, F::Extension, F::Extension, D> { pub(crate) fn from_proofs>( - proofs: &[StarkProof; NUM_TABLES], + proofs: &[StarkProofWithMetadata; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, num_permutation_zs: &[usize; NUM_TABLES], @@ -367,7 +363,7 @@ impl<'a, F: RichField + Extendable, const D: usize> .iter() .zip(num_permutation_zs) .map(|(p, &num_perms)| { - let openings = &p.openings; + let openings = &p.proof.openings; let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_perms); let ctl_zs_next = openings.permutation_ctl_zs_next.iter().skip(num_perms); ctl_zs.zip(ctl_zs_next) @@ -582,7 +578,7 @@ pub(crate) fn verify_cross_table_lookups< C: GenericConfig, const D: usize, >( - cross_table_lookups: Vec>, + cross_table_lookups: &[CrossTableLookup], ctl_zs_lasts: [Vec; NUM_TABLES], degrees_bits: [usize; NUM_TABLES], challenges: GrandProductChallengeSet, @@ -597,7 +593,7 @@ pub(crate) fn verify_cross_table_lookups< default, .. }, - ) in cross_table_lookups.into_iter().enumerate() + ) in cross_table_lookups.iter().enumerate() { for _ in 0..config.num_challenges { let looking_degrees_sum = looking_tables @@ -635,47 +631,23 @@ pub(crate) fn verify_cross_table_lookups_circuit< builder: &mut CircuitBuilder, cross_table_lookups: Vec>, ctl_zs_lasts: [Vec; NUM_TABLES], - degrees_bits: [usize; NUM_TABLES], - challenges: GrandProductChallengeSet, inner_config: &StarkConfig, ) { let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); - for ( - i, - CrossTableLookup { - looking_tables, - looked_table, - default, - .. - }, - ) in cross_table_lookups.into_iter().enumerate() + for CrossTableLookup { + looking_tables, + looked_table, + .. + } in cross_table_lookups.into_iter() { for _ in 0..inner_config.num_challenges { - let looking_degrees_sum = looking_tables - .iter() - .map(|table| 1 << degrees_bits[table.table as usize]) - .sum::(); - let looked_degree = 1 << degrees_bits[looked_table.table as usize]; let looking_zs_prod = builder.mul_many( looking_tables .iter() .map(|table| *ctl_zs_openings[table.table as usize].next().unwrap()), ); let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); - let challenge = challenges.challenges[i % inner_config.num_challenges]; - if let Some(default) = default.as_ref() { - let default = default - .iter() - .map(|&x| builder.constant(x)) - .collect::>(); - let combined_default = challenge.combine_base_circuit(builder, &default); - - let pad = builder.exp_u64(combined_default, looking_degrees_sum - looked_degree); - let padded_looked_z = builder.mul(looked_z, pad); - builder.connect(looking_zs_prod, padded_looked_z); - } else { - builder.connect(looking_zs_prod, looked_z); - } + builder.connect(looking_zs_prod, looked_z); } } debug_assert!(ctl_zs_openings.iter_mut().all(|iter| iter.next().is_none())); diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs new file mode 100644 index 00000000..8fa19b3b --- /dev/null +++ b/evm/src/fixed_recursive_verifier.rs @@ -0,0 +1,410 @@ +use std::collections::BTreeMap; +use std::ops::Range; + +use itertools::Itertools; +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::iop::challenger::RecursiveChallenger; +use plonky2::iop::target::Target; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; +use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; +use plonky2::util::timing::TimingTree; + +use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; +use crate::config::StarkConfig; +use crate::cpu::cpu_stark::CpuStark; +use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; +use crate::generation::GenerationInputs; +use crate::keccak::keccak_stark::KeccakStark; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; +use crate::logic::LogicStark; +use crate::memory::memory_stark::MemoryStark; +use crate::permutation::{get_grand_product_challenge_set_target, GrandProductChallengeSet}; +use crate::proof::StarkProofWithMetadata; +use crate::prover::prove; +use crate::recursive_verifier::{ + add_common_recursion_gates, recursive_stark_circuit, PlonkWrapperCircuit, PublicInputs, + StarkWrapperCircuit, +}; +use crate::stark::Stark; + +/// The recursion threshold. We end a chain of recursive proofs once we reach this size. +const THRESHOLD_DEGREE_BITS: usize = 13; + +/// Contains all recursive circuits used in the system. For each STARK and each initial +/// `degree_bits`, this contains a chain of recursive circuits for shrinking that STARK from +/// `degree_bits` to a constant `THRESHOLD_DEGREE_BITS`. It also contains a special root circuit +/// for combining each STARK's shrunk wrapper proof into a single proof. +pub struct AllRecursiveCircuits +where + F: RichField + Extendable, + C: GenericConfig, +{ + /// The root circuit, which aggregates the (shrunk) per-table recursive proofs. + pub root: RootCircuitData, + /// Holds chains of circuits for each table and for each initial `degree_bits`. + by_table: [RecursiveCircuitsForTable; NUM_TABLES], +} + +/// Data for the special root circuit, which is used to combine each STARK's shrunk wrapper proof +/// into a single proof. +pub struct RootCircuitData +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub circuit: CircuitData, + proof_with_pis: [ProofWithPublicInputsTarget; NUM_TABLES], + /// For each table, various inner circuits may be used depending on the initial table size. + /// This target holds the index of the circuit (within `final_circuits()`) that was used. + index_verifier_data: [Target; NUM_TABLES], +} + +impl AllRecursiveCircuits +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, + [(); CpuStark::::COLUMNS]:, + [(); KeccakStark::::COLUMNS]:, + [(); KeccakSpongeStark::::COLUMNS]:, + [(); LogicStark::::COLUMNS]:, + [(); MemoryStark::::COLUMNS]:, +{ + /// Preprocess all recursive circuits used by the system. + pub fn new( + all_stark: &AllStark, + degree_bits_range: Range, + stark_config: &StarkConfig, + ) -> Self { + let cpu = RecursiveCircuitsForTable::new( + Table::Cpu, + &all_stark.cpu_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let keccak = RecursiveCircuitsForTable::new( + Table::Keccak, + &all_stark.keccak_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let keccak_sponge = RecursiveCircuitsForTable::new( + Table::KeccakSponge, + &all_stark.keccak_sponge_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let logic = RecursiveCircuitsForTable::new( + Table::Logic, + &all_stark.logic_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let memory = RecursiveCircuitsForTable::new( + Table::Memory, + &all_stark.memory_stark, + degree_bits_range, + &all_stark.cross_table_lookups, + stark_config, + ); + + let by_table = [cpu, keccak, keccak_sponge, logic, memory]; + let root = Self::create_root_circuit(&by_table, stark_config); + Self { root, by_table } + } + + fn create_root_circuit( + by_table: &[RecursiveCircuitsForTable; NUM_TABLES], + stark_config: &StarkConfig, + ) -> RootCircuitData { + let inner_common_data: [_; NUM_TABLES] = + std::array::from_fn(|i| &by_table[i].final_circuits()[0].common); + + let mut builder = CircuitBuilder::new(CircuitConfig::standard_recursion_config()); + let recursive_proofs = + std::array::from_fn(|i| builder.add_virtual_proof_with_pis::(inner_common_data[i])); + let pis: [_; NUM_TABLES] = std::array::from_fn(|i| { + PublicInputs::from_vec(&recursive_proofs[i].public_inputs, stark_config) + }); + let index_verifier_data = std::array::from_fn(|_i| builder.add_virtual_target()); + + let mut challenger = RecursiveChallenger::::new(&mut builder); + for pi in &pis { + for h in &pi.trace_cap { + challenger.observe_elements(h); + } + } + let ctl_challenges = get_grand_product_challenge_set_target( + &mut builder, + &mut challenger, + stark_config.num_challenges, + ); + // Check that the correct CTL challenges are used in every proof. + for pi in &pis { + for i in 0..stark_config.num_challenges { + builder.connect( + ctl_challenges.challenges[i].beta, + pi.ctl_challenges.challenges[i].beta, + ); + builder.connect( + ctl_challenges.challenges[i].gamma, + pi.ctl_challenges.challenges[i].gamma, + ); + } + } + + let state = challenger.compact(&mut builder); + for k in 0..SPONGE_WIDTH { + builder.connect(state[k], pis[0].challenger_state_before[k]); + } + // Check that the challenger state is consistent between proofs. + for i in 1..NUM_TABLES { + for k in 0..SPONGE_WIDTH { + builder.connect( + pis[i].challenger_state_before[k], + pis[i - 1].challenger_state_after[k], + ); + } + } + + // Verify the CTL checks. + verify_cross_table_lookups_circuit::( + &mut builder, + all_cross_table_lookups(), + pis.map(|p| p.ctl_zs_last), + stark_config, + ); + + for (i, table_circuits) in by_table.iter().enumerate() { + let final_circuits = table_circuits.final_circuits(); + for final_circuit in &final_circuits { + assert_eq!( + &final_circuit.common, inner_common_data[i], + "common_data mismatch" + ); + } + let mut possible_vks = final_circuits + .into_iter() + .map(|c| builder.constant_verifier_data(&c.verifier_only)) + .collect_vec(); + // random_access_verifier_data expects a vector whose length is a power of two. + // To satisfy this, we will just add some duplicates of the first VK. + while !possible_vks.len().is_power_of_two() { + possible_vks.push(possible_vks[0].clone()); + } + let inner_verifier_data = + builder.random_access_verifier_data(index_verifier_data[i], possible_vks); + + builder.verify_proof::( + &recursive_proofs[i], + &inner_verifier_data, + inner_common_data[i], + ); + } + + RootCircuitData { + circuit: builder.build(), + proof_with_pis: recursive_proofs, + index_verifier_data, + } + } + + /// Create a proof for each STARK, then combine them, eventually culminating in a root proof. + pub fn prove_root( + &self, + all_stark: &AllStark, + config: &StarkConfig, + generation_inputs: GenerationInputs, + timing: &mut TimingTree, + ) -> anyhow::Result> { + let all_proof = prove::(all_stark, config, generation_inputs, timing)?; + let mut root_inputs = PartialWitness::new(); + for table in 0..NUM_TABLES { + let stark_proof = &all_proof.stark_proofs[table]; + let original_degree_bits = stark_proof.proof.recover_degree_bits(config); + let table_circuits = &self.by_table[table]; + let shrunk_proof = table_circuits.by_stark_size[&original_degree_bits] + .shrink(stark_proof, &all_proof.ctl_challenges)?; + let index_verifier_data = table_circuits + .by_stark_size + .keys() + .position(|&size| size == original_degree_bits) + .unwrap(); + root_inputs.set_target( + self.root.index_verifier_data[table], + F::from_canonical_usize(index_verifier_data), + ); + root_inputs.set_proof_with_pis_target(&self.root.proof_with_pis[table], &shrunk_proof); + } + self.root.circuit.prove(root_inputs) + } +} + +struct RecursiveCircuitsForTable +where + F: RichField + Extendable, + C: GenericConfig, +{ + /// A map from `log_2(height)` to a chain of shrinking recursion circuits starting at that + /// height. + by_stark_size: BTreeMap>, +} + +impl RecursiveCircuitsForTable +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, +{ + fn new>( + table: Table, + stark: &S, + degree_bits_range: Range, + all_ctls: &[CrossTableLookup], + stark_config: &StarkConfig, + ) -> Self + where + [(); S::COLUMNS]:, + { + let by_stark_size = degree_bits_range + .map(|degree_bits| { + ( + degree_bits, + RecursiveCircuitsForTableSize::new::( + table, + stark, + degree_bits, + all_ctls, + stark_config, + ), + ) + }) + .collect(); + Self { by_stark_size } + } + + /// For each initial `degree_bits`, get the final circuit at the end of that shrinking chain. + /// Each of these final circuits should have degree `THRESHOLD_DEGREE_BITS`. + fn final_circuits(&self) -> Vec<&CircuitData> { + self.by_stark_size + .values() + .map(|chain| { + chain + .shrinking_wrappers + .last() + .map(|wrapper| &wrapper.circuit) + .unwrap_or(&chain.initial_wrapper.circuit) + }) + .collect() + } +} + +/// A chain of shrinking wrapper circuits, ending with a final circuit with `degree_bits` +/// `THRESHOLD_DEGREE_BITS`. +struct RecursiveCircuitsForTableSize +where + F: RichField + Extendable, + C: GenericConfig, +{ + initial_wrapper: StarkWrapperCircuit, + shrinking_wrappers: Vec>, +} + +impl RecursiveCircuitsForTableSize +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, +{ + fn new>( + table: Table, + stark: &S, + degree_bits: usize, + all_ctls: &[CrossTableLookup], + stark_config: &StarkConfig, + ) -> Self + where + [(); S::COLUMNS]:, + { + let initial_wrapper = recursive_stark_circuit( + table, + stark, + degree_bits, + all_ctls, + stark_config, + &shrinking_config(), + THRESHOLD_DEGREE_BITS, + ); + let mut shrinking_wrappers = vec![]; + + // Shrinking recursion loop. + loop { + let last = shrinking_wrappers + .last() + .map(|wrapper: &PlonkWrapperCircuit| &wrapper.circuit) + .unwrap_or(&initial_wrapper.circuit); + let last_degree_bits = last.common.degree_bits(); + assert!(last_degree_bits >= THRESHOLD_DEGREE_BITS); + if last_degree_bits == THRESHOLD_DEGREE_BITS { + break; + } + + let mut builder = CircuitBuilder::new(shrinking_config()); + let proof_with_pis_target = builder.add_virtual_proof_with_pis::(&last.common); + let last_vk = builder.constant_verifier_data(&last.verifier_only); + builder.verify_proof::(&proof_with_pis_target, &last_vk, &last.common); + builder.register_public_inputs(&proof_with_pis_target.public_inputs); // carry PIs forward + add_common_recursion_gates(&mut builder); + let circuit = builder.build(); + + assert!( + circuit.common.degree_bits() < last_degree_bits, + "Couldn't shrink to expected recursion threshold of 2^{}; stalled at 2^{}", + THRESHOLD_DEGREE_BITS, + circuit.common.degree_bits() + ); + shrinking_wrappers.push(PlonkWrapperCircuit { + circuit, + proof_with_pis_target, + }); + } + + Self { + initial_wrapper, + shrinking_wrappers, + } + } + + fn shrink( + &self, + stark_proof_with_metadata: &StarkProofWithMetadata, + ctl_challenges: &GrandProductChallengeSet, + ) -> anyhow::Result> { + let mut proof = self + .initial_wrapper + .prove(stark_proof_with_metadata, ctl_challenges)?; + for wrapper_circuit in &self.shrinking_wrappers { + proof = wrapper_circuit.prove(&proof)?; + } + Ok(proof) + } +} + +fn shrinking_config() -> CircuitConfig { + CircuitConfig { + num_routed_wires: 40, + ..CircuitConfig::standard_recursion_config() + } +} diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index ede7c466..f368b7c2 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -23,7 +23,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { - challenger.observe_cap(&proof.trace_cap); + challenger.observe_cap(&proof.proof.trace_cap); } // TODO: Observe public values. @@ -37,7 +37,7 @@ impl, C: GenericConfig, const D: usize> A AllProofChallenges { stark_challenges: std::array::from_fn(|i| { challenger.compact(); - self.stark_proofs[i].get_challenges( + self.stark_proofs[i].proof.get_challenges( &mut challenger, num_permutation_zs[i] > 0, num_permutation_batch_sizes[i], @@ -57,7 +57,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { - challenger.observe_cap(&proof.trace_cap); + challenger.observe_cap(&proof.proof.trace_cap); } // TODO: Observe public values. @@ -70,7 +70,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger_states = vec![challenger.compact()]; for i in 0..NUM_TABLES { - self.stark_proofs[i].get_challenges( + self.stark_proofs[i].proof.get_challenges( &mut challenger, num_permutation_zs[i] > 0, num_permutation_batch_sizes[i], diff --git a/evm/src/lib.rs b/evm/src/lib.rs index 4c368491..6ca956c4 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -12,6 +12,7 @@ pub mod config; pub mod constraint_consumer; pub mod cpu; pub mod cross_table_lookup; +pub mod fixed_recursive_verifier; pub mod generation; mod get_challenges; pub mod keccak; diff --git a/evm/src/permutation.rs b/evm/src/permutation.rs index b081c309..4f42a4aa 100644 --- a/evm/src/permutation.rs +++ b/evm/src/permutation.rs @@ -15,9 +15,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; -use plonky2::plonk::plonk_common::{ - reduce_with_powers, reduce_with_powers_circuit, reduce_with_powers_ext_circuit, -}; +use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; use plonky2::util::reducing::{ReducingFactor, ReducingFactorTarget}; use crate::config::StarkConfig; @@ -82,15 +80,6 @@ impl GrandProductChallenge { let gamma = builder.convert_to_ext(self.gamma); builder.add_extension(reduced, gamma) } - - pub(crate) fn combine_base_circuit, const D: usize>( - &self, - builder: &mut CircuitBuilder, - terms: &[Target], - ) -> Target { - let reduced = reduce_with_powers_circuit(builder, terms, self.beta); - builder.add(reduced, self.gamma) - } } /// Like `PermutationChallenge`, but with `num_challenges` copies to boost soundness. diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 4cd03a65..46b88ca2 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -19,15 +19,17 @@ use crate::all_stark::NUM_TABLES; use crate::config::StarkConfig; use crate::permutation::GrandProductChallengeSet; +/// A STARK proof for each table, plus some metadata used to create recursive wrapper proofs. #[derive(Debug, Clone)] pub struct AllProof, C: GenericConfig, const D: usize> { - pub stark_proofs: [StarkProof; NUM_TABLES], + pub stark_proofs: [StarkProofWithMetadata; NUM_TABLES], + pub(crate) ctl_challenges: GrandProductChallengeSet, pub public_values: PublicValues, } impl, C: GenericConfig, const D: usize> AllProof { pub fn degree_bits(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { - std::array::from_fn(|i| self.stark_proofs[i].recover_degree_bits(config)) + std::array::from_fn(|i| self.stark_proofs[i].proof.recover_degree_bits(config)) } } @@ -113,6 +115,18 @@ pub struct StarkProof, C: GenericConfig, pub opening_proof: FriProof, } +/// A `StarkProof` along with some metadata about the initial Fiat-Shamir state, which is used when +/// creating a recursive wrapper proof around a STARK proof. +#[derive(Debug, Clone)] +pub struct StarkProofWithMetadata +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) init_challenger_state: [F; SPONGE_WIDTH], + pub(crate) proof: StarkProof, +} + impl, C: GenericConfig, const D: usize> StarkProof { /// Recover the length of the trace from a STARK proof and a STARK config. pub fn recover_degree_bits(&self, config: &StarkConfig) -> usize { diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 55b57437..a777e1ad 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -29,10 +29,10 @@ use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::{ - compute_permutation_z_polys, get_n_grand_product_challenge_sets, GrandProductChallengeSet, - PermutationCheckVars, + compute_permutation_z_polys, get_grand_product_challenge_set, + get_n_grand_product_challenge_sets, GrandProductChallengeSet, PermutationCheckVars, }; -use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof}; +use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::StarkEvaluationVars; @@ -117,14 +117,14 @@ where challenger.observe_cap(cap); } + let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); let ctl_data_per_table = timed!( timing, "compute CTL data", cross_table_lookup_data::( - config, &trace_poly_values, &all_stark.cross_table_lookups, - &mut challenger, + &ctl_challenges, ) ); @@ -144,6 +144,7 @@ where Ok(AllProof { stark_proofs, + ctl_challenges, public_values, }) } @@ -156,7 +157,7 @@ fn prove_with_commitments( ctl_data_per_table: [CtlData; NUM_TABLES], challenger: &mut Challenger, timing: &mut TimingTree, -) -> Result<[StarkProof; NUM_TABLES]> +) -> Result<[StarkProofWithMetadata; NUM_TABLES]> where F: RichField + Extendable, C: GenericConfig, @@ -250,7 +251,7 @@ pub(crate) fn prove_single_table( ctl_data: &CtlData, challenger: &mut Challenger, timing: &mut TimingTree, -) -> Result> +) -> Result> where F: RichField + Extendable, C: GenericConfig, @@ -268,7 +269,7 @@ where "FRI total reduction arity is too large.", ); - challenger.compact(); + let init_challenger_state = challenger.compact(); // Permutation arguments. let permutation_challenges = stark.uses_permutation_args().then(|| { @@ -411,12 +412,16 @@ where ) ); - Ok(StarkProof { + let proof = StarkProof { trace_cap: trace_commitment.merkle_tree.cap.clone(), permutation_ctl_zs_cap, quotient_polys_cap, openings, opening_proof, + }; + Ok(StarkProofWithMetadata { + init_challenger_state, + proof, }) } diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 0f713e32..59a37602 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -5,18 +5,24 @@ use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::fri::witness_util::set_fri_proof_target; +use plonky2::gates::exponentiation::ExponentiationGate; +use plonky2::gates::gate::GateRef; +use plonky2::gates::noop::NoopGate; use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2::hash::hashing::SPONGE_WIDTH; use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; -use plonky2::iop::witness::Witness; +use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData, VerifierCircuitTarget}; +use plonky2::plonk::circuit_data::{ + CircuitConfig, CircuitData, VerifierCircuitData, VerifierCircuitTarget, +}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2::util::reducing::ReducingFactorTarget; use plonky2::with_context; +use plonky2_util::log2_ceil; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; use crate::config::StarkConfig; @@ -36,8 +42,8 @@ use crate::permutation::{ }; use crate::proof::{ AllProof, AllProofTarget, BlockMetadata, BlockMetadataTarget, PublicValues, PublicValuesTarget, - StarkOpeningSetTarget, StarkProof, StarkProofChallengesTarget, StarkProofTarget, TrieRoots, - TrieRootsTarget, + StarkOpeningSetTarget, StarkProof, StarkProofChallengesTarget, StarkProofTarget, + StarkProofWithMetadata, TrieRoots, TrieRootsTarget, }; use crate::stark::Stark; use crate::util::{h160_limbs, h256_limbs}; @@ -58,12 +64,12 @@ pub struct RecursiveAllProofTargetWithData { pub verifier_data: [VerifierCircuitTarget; NUM_TABLES], } -struct PublicInputs { - trace_cap: Vec>, - ctl_zs_last: Vec, - ctl_challenges: GrandProductChallengeSet, - challenger_state_before: [T; SPONGE_WIDTH], - challenger_state_after: [T; SPONGE_WIDTH], +pub(crate) struct PublicInputs { + pub(crate) trace_cap: Vec>, + pub(crate) ctl_zs_last: Vec, + pub(crate) ctl_challenges: GrandProductChallengeSet, + pub(crate) challenger_state_before: [T; SPONGE_WIDTH], + pub(crate) challenger_state_after: [T; SPONGE_WIDTH], } /// Similar to the unstable `Iterator::next_chunk`. Could be replaced with that when it's stable. @@ -76,9 +82,10 @@ fn next_chunk(iter: &mut impl Iterator) -> [ } impl PublicInputs { - fn from_vec(v: &[T], config: &StarkConfig) -> Self { + pub(crate) fn from_vec(v: &[T], config: &StarkConfig) -> Self { + log::info!("from_vec {}", v.len()); let mut iter = v.iter().copied(); - let trace_cap = (0..1 << config.fri_config.cap_height) + let trace_cap = (0..config.fri_config.num_cap_elements()) .map(|_| next_chunk::<_, 4>(&mut iter).to_vec()) .collect(); let ctl_challenges = GrandProductChallengeSet { @@ -91,7 +98,8 @@ impl PublicInputs { }; let challenger_state_before = next_chunk(&mut iter); let challenger_state_after = next_chunk(&mut iter); - let ctl_zs_last = iter.collect(); + let ctl_zs_last: Vec<_> = iter.collect(); + log::info!("from_vec num Zs: {}", ctl_zs_last.len()); // TODO Self { trace_cap, @@ -143,7 +151,7 @@ impl, C: GenericConfig, const D: usize> // Verify the CTL checks. let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups::( - cross_table_lookups, + &cross_table_lookups, pis.map(|p| p.ctl_zs_last), degrees_bits, ctl_challenges, @@ -158,6 +166,7 @@ impl, C: GenericConfig, const D: usize> } /// Recursively verify every recursive proof. + // TODO: Remove? Replaced by fixed_recursive_verifier. pub fn verify_circuit( builder: &mut CircuitBuilder, recursive_all_proof_target: RecursiveAllProofTargetWithData, @@ -215,13 +224,10 @@ impl, C: GenericConfig, const D: usize> } // Verify the CTL checks. - let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups_circuit::( builder, all_cross_table_lookups(), pis.map(|p| p.ctl_zs_last), - degrees_bits, - ctl_challenges, inner_config, ); for (i, (recursive_proof, verifier_data_target)) in recursive_proofs @@ -238,33 +244,113 @@ impl, C: GenericConfig, const D: usize> } } -/// Returns the verifier data for the recursive Stark circuit. -fn verifier_data_recursive_stark_proof< +/// Represents a circuit which recursively verifies a STARK proof. +pub(crate) struct StarkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) circuit: CircuitData, + pub(crate) stark_proof_target: StarkProofTarget, + pub(crate) ctl_challenges_target: GrandProductChallengeSet, + pub(crate) init_challenger_state_target: [Target; SPONGE_WIDTH], + pub(crate) zero_target: Target, +} + +impl StarkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, +{ + pub(crate) fn prove( + &self, + proof_with_metadata: &StarkProofWithMetadata, + ctl_challenges: &GrandProductChallengeSet, + ) -> Result> { + let mut inputs = PartialWitness::new(); + + set_stark_proof_target( + &mut inputs, + &self.stark_proof_target, + &proof_with_metadata.proof, + self.zero_target, + ); + + for (challenge_target, challenge) in self + .ctl_challenges_target + .challenges + .iter() + .zip(&ctl_challenges.challenges) + { + inputs.set_target(challenge_target.beta, challenge.beta); + inputs.set_target(challenge_target.gamma, challenge.gamma); + } + + inputs.set_target_arr( + self.init_challenger_state_target, + proof_with_metadata.init_challenger_state, + ); + + self.circuit.prove(inputs) + } +} + +/// Represents a circuit which recursively verifies a PLONK proof. +pub(crate) struct PlonkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) circuit: CircuitData, + pub(crate) proof_with_pis_target: ProofWithPublicInputsTarget, +} + +impl PlonkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, +{ + pub(crate) fn prove( + &self, + proof: &ProofWithPublicInputs, + ) -> Result> { + let mut inputs = PartialWitness::new(); + inputs.set_proof_with_pis_target(&self.proof_with_pis_target, proof); + self.circuit.prove(inputs) + } +} + +/// Returns the recursive Stark circuit. +pub(crate) fn recursive_stark_circuit< F: RichField + Extendable, C: GenericConfig, S: Stark, const D: usize, >( table: Table, - stark: S, + stark: &S, degree_bits: usize, cross_table_lookups: &[CrossTableLookup], inner_config: &StarkConfig, circuit_config: &CircuitConfig, -) -> VerifierCircuitData + min_degree_bits: usize, +) -> StarkWrapperCircuit where [(); S::COLUMNS]:, [(); C::Hasher::HASH_SIZE]:, C::Hasher: AlgebraicHasher, { let mut builder = CircuitBuilder::::new(circuit_config.clone()); + let zero_target = builder.zero(); let num_permutation_zs = stark.num_permutation_batches(inner_config); let num_permutation_batch_size = stark.permutation_batch_size(); let num_ctl_zs = CrossTableLookup::num_ctl_zs(cross_table_lookups, table, inner_config.num_challenges); let proof_target = - add_virtual_stark_proof(&mut builder, &stark, inner_config, degree_bits, num_ctl_zs); + add_virtual_stark_proof(&mut builder, stark, inner_config, degree_bits, num_ctl_zs); builder.register_public_inputs( &proof_target .trace_cap @@ -291,8 +377,9 @@ where num_permutation_zs, ); - let challenger_state = std::array::from_fn(|_| builder.add_virtual_public_input()); - let mut challenger = RecursiveChallenger::::from_state(challenger_state); + let init_challenger_state_target = std::array::from_fn(|_| builder.add_virtual_public_input()); + let mut challenger = + RecursiveChallenger::::from_state(init_challenger_state_target); let challenges = proof_target.get_challenges::( &mut builder, &mut challenger, @@ -307,14 +394,39 @@ where verify_stark_proof_with_challenges_circuit::( &mut builder, - &stark, + stark, &proof_target, &challenges, &ctl_vars, inner_config, ); - builder.build_verifier::() + add_common_recursion_gates(&mut builder); + + // Pad to the minimum degree. + while log2_ceil(builder.num_gates()) < min_degree_bits { + builder.add_gate(NoopGate, vec![]); + } + + let circuit = builder.build::(); + StarkWrapperCircuit { + circuit, + stark_proof_target: proof_target, + ctl_challenges_target, + init_challenger_state_target, + zero_target, + } +} + +/// Add gates that are sometimes used by recursive circuits, even if it's not actually used by this +/// particular recursive circuit. This is done for uniformity. We sometimes want all recursion +/// circuits to have the same gate set, so that we can do 1-of-n conditional recursion efficiently. +pub(crate) fn add_common_recursion_gates, const D: usize>( + builder: &mut CircuitBuilder, +) { + builder.add_gate_to_gate_set(GateRef::new(ExponentiationGate::new_from_config( + &builder.config, + ))); } /// Returns the recursive Stark circuit verifier data for every Stark in `AllStark`. @@ -338,46 +450,61 @@ where C::Hasher: AlgebraicHasher, { [ - verifier_data_recursive_stark_proof( + recursive_stark_circuit( Table::Cpu, - all_stark.cpu_stark, + &all_stark.cpu_stark, degree_bits[Table::Cpu as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), - verifier_data_recursive_stark_proof( + 0, + ) + .circuit + .verifier_data(), + recursive_stark_circuit( Table::Keccak, - all_stark.keccak_stark, + &all_stark.keccak_stark, degree_bits[Table::Keccak as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), - verifier_data_recursive_stark_proof( + 0, + ) + .circuit + .verifier_data(), + recursive_stark_circuit( Table::KeccakSponge, - all_stark.keccak_sponge_stark, + &all_stark.keccak_sponge_stark, degree_bits[Table::KeccakSponge as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), - verifier_data_recursive_stark_proof( + 0, + ) + .circuit + .verifier_data(), + recursive_stark_circuit( Table::Logic, - all_stark.logic_stark, + &all_stark.logic_stark, degree_bits[Table::Logic as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), - verifier_data_recursive_stark_proof( + 0, + ) + .circuit + .verifier_data(), + recursive_stark_circuit( Table::Memory, - all_stark.memory_stark, + &all_stark.memory_stark, degree_bits[Table::Memory as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), + 0, + ) + .circuit + .verifier_data(), ] } @@ -580,7 +707,7 @@ where VerifierCircuitTarget { constants_sigmas_cap: builder .constant_merkle_cap(&verifier_data.verifier_only.constants_sigmas_cap), - circuit_digest: builder.add_virtual_hash(), + circuit_digest: builder.constant_hash(verifier_data.verifier_only.circuit_digest), } }); RecursiveAllProofTargetWithData { @@ -714,7 +841,7 @@ pub fn set_all_proof_target, W, const D: usize>( .iter() .zip_eq(&all_proof.stark_proofs) { - set_stark_proof_target(witness, pt, p, zero); + set_stark_proof_target(witness, pt, &p.proof, zero); } set_public_value_targets( witness, @@ -994,7 +1121,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::Cpu, all_stark.cpu_stark, - &all_proof.stark_proofs[Table::Cpu as usize], + &all_proof.stark_proofs[Table::Cpu as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[0], @@ -1005,7 +1132,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::Keccak, all_stark.keccak_stark, - &all_proof.stark_proofs[Table::Keccak as usize], + &all_proof.stark_proofs[Table::Keccak as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[1], @@ -1016,7 +1143,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::KeccakSponge, all_stark.keccak_sponge_stark, - &all_proof.stark_proofs[Table::KeccakSponge as usize], + &all_proof.stark_proofs[Table::KeccakSponge as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[2], @@ -1027,7 +1154,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::Logic, all_stark.logic_stark, - &all_proof.stark_proofs[Table::Logic as usize], + &all_proof.stark_proofs[Table::Logic as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[3], @@ -1038,7 +1165,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::Memory, all_stark.memory_stark, - &all_proof.stark_proofs[Table::Memory as usize], + &all_proof.stark_proofs[Table::Memory as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[4], diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index a0329d04..cb75ad33 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -26,7 +26,7 @@ use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::StarkEvaluationVars; pub fn verify_proof, C: GenericConfig, const D: usize>( - all_stark: AllStark, + all_stark: &AllStark, all_proof: AllProof, config: &StarkConfig, ) -> Result<()> @@ -41,7 +41,7 @@ where let AllProofChallenges { stark_challenges, ctl_challenges, - } = all_proof.get_challenges(&all_stark, config); + } = all_proof.get_challenges(all_stark, config); let nums_permutation_zs = all_stark.nums_permutation_zs(config); @@ -56,52 +56,52 @@ where let ctl_vars_per_table = CtlCheckVars::from_proofs( &all_proof.stark_proofs, - &cross_table_lookups, + cross_table_lookups, &ctl_challenges, &nums_permutation_zs, ); verify_stark_proof_with_challenges( cpu_stark, - &all_proof.stark_proofs[Table::Cpu as usize], + &all_proof.stark_proofs[Table::Cpu as usize].proof, &stark_challenges[Table::Cpu as usize], &ctl_vars_per_table[Table::Cpu as usize], config, )?; verify_stark_proof_with_challenges( keccak_stark, - &all_proof.stark_proofs[Table::Keccak as usize], + &all_proof.stark_proofs[Table::Keccak as usize].proof, &stark_challenges[Table::Keccak as usize], &ctl_vars_per_table[Table::Keccak as usize], config, )?; verify_stark_proof_with_challenges( keccak_sponge_stark, - &all_proof.stark_proofs[Table::KeccakSponge as usize], + &all_proof.stark_proofs[Table::KeccakSponge as usize].proof, &stark_challenges[Table::KeccakSponge as usize], &ctl_vars_per_table[Table::KeccakSponge as usize], config, )?; verify_stark_proof_with_challenges( memory_stark, - &all_proof.stark_proofs[Table::Memory as usize], + &all_proof.stark_proofs[Table::Memory as usize].proof, &stark_challenges[Table::Memory as usize], &ctl_vars_per_table[Table::Memory as usize], config, )?; verify_stark_proof_with_challenges( logic_stark, - &all_proof.stark_proofs[Table::Logic as usize], + &all_proof.stark_proofs[Table::Logic as usize].proof, &stark_challenges[Table::Logic as usize], &ctl_vars_per_table[Table::Logic as usize], config, )?; let degrees_bits = - std::array::from_fn(|i| all_proof.stark_proofs[i].recover_degree_bits(config)); + std::array::from_fn(|i| all_proof.stark_proofs[i].proof.recover_degree_bits(config)); verify_cross_table_lookups::( cross_table_lookups, - all_proof.stark_proofs.map(|p| p.openings.ctl_zs_last), + all_proof.stark_proofs.map(|p| p.proof.openings.ctl_zs_last), degrees_bits, ctl_challenges, config, @@ -114,7 +114,7 @@ pub(crate) fn verify_stark_proof_with_challenges< S: Stark, const D: usize, >( - stark: S, + stark: &S, proof: &StarkProof, challenges: &StarkProofChallenges, ctl_vars: &[CtlCheckVars], @@ -125,7 +125,7 @@ where [(); C::Hasher::HASH_SIZE]:, { log::debug!("Checking proof: {}", type_name::()); - validate_proof_shape(&stark, proof, config, ctl_vars.len())?; + validate_proof_shape(stark, proof, config, ctl_vars.len())?; let StarkOpeningSet { local_values, next_values, @@ -160,7 +160,7 @@ where permutation_challenge_sets: challenges.permutation_challenge_sets.clone().unwrap(), }); eval_vanishing_poly::( - &stark, + stark, config, vars, permutation_data, diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index aa7b60b9..1e564758 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -8,6 +8,7 @@ use plonky2::plonk::config::PoseidonGoldilocksConfig; use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; +use plonky2_evm::fixed_recursive_verifier::AllRecursiveCircuits; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::BlockMetadata; use plonky2_evm::prover::prove; @@ -49,7 +50,10 @@ fn test_empty_txn_list() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + // TODO: This is redundant; prove_root below calls this prove method internally. + // Just keeping it for now because the root proof returned by prove_root doesn't contain public + // values yet, and we want those for the assertions below. + let proof = prove::(&all_stark, &config, inputs.clone(), &mut timing)?; timing.filter(Duration::from_millis(100)).print(); assert_eq!( @@ -77,7 +81,11 @@ fn test_empty_txn_list() -> anyhow::Result<()> { receipts_trie_root ); - verify_proof(all_stark, proof, &config) + verify_proof(&all_stark, proof, &config)?; + + let all_circuits = AllRecursiveCircuits::::new(&all_stark, 9..19, &config); + let root_proof = all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?; + all_circuits.root.circuit.verify(root_proof) } fn init_logger() { diff --git a/evm/tests/transfer_to_new_addr.rs b/evm/tests/transfer_to_new_addr.rs index e4fe8eb4..4a5e69e3 100644 --- a/evm/tests/transfer_to_new_addr.rs +++ b/evm/tests/transfer_to_new_addr.rs @@ -104,7 +104,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { expected_state_trie_after.calc_hash() ); - verify_proof(all_stark, proof, &config) + verify_proof(&all_stark, proof, &config) } fn eth_to_wei(eth: U256) -> U256 { diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index 0ad12e76..bf6b0e6b 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -14,9 +14,7 @@ use plonky2::gates::noop::NoopGate; use plonky2::hash::hash_types::RichField; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{ - CircuitConfig, CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, -}; +use plonky2::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierOnlyCircuitData}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; use plonky2::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use plonky2::plonk::prover::prove; @@ -107,10 +105,7 @@ where let mut builder = CircuitBuilder::::new(config.clone()); let pt = builder.add_virtual_proof_with_pis::(inner_cd); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = builder.add_virtual_verifier_data(inner_cd.config.fri_config.cap_height); builder.verify_proof::(&pt, &inner_data, inner_cd); builder.print_gate_counts(0); diff --git a/plonky2/src/gadgets/random_access.rs b/plonky2/src/gadgets/random_access.rs index d3a3ff1b..73e3de8c 100644 --- a/plonky2/src/gadgets/random_access.rs +++ b/plonky2/src/gadgets/random_access.rs @@ -2,15 +2,15 @@ use alloc::vec::Vec; use crate::field::extension::Extendable; use crate::gates::random_access::RandomAccessGate; -use crate::hash::hash_types::RichField; +use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::VerifierCircuitTarget; use crate::util::log2_strict; impl, const D: usize> CircuitBuilder { - /// Checks that a `Target` matches a vector at a non-deterministic index. - /// Note: `access_index` is not range-checked. + /// Checks that a `Target` matches a vector at a particular index. pub fn random_access(&mut self, access_index: Target, v: Vec) -> Target { let vec_size = v.len(); let bits = log2_strict(vec_size); @@ -38,18 +38,64 @@ impl, const D: usize> CircuitBuilder { claimed_element } - /// Checks that an `ExtensionTarget` matches a vector at a non-deterministic index. - /// Note: `access_index` is not range-checked. + /// Like `random_access`, but with `ExtensionTarget`s rather than simple `Target`s. pub fn random_access_extension( &mut self, access_index: Target, v: Vec>, ) -> ExtensionTarget { - let v: Vec<_> = (0..D) + let selected: Vec<_> = (0..D) .map(|i| self.random_access(access_index, v.iter().map(|et| et.0[i]).collect())) .collect(); - ExtensionTarget(v.try_into().unwrap()) + ExtensionTarget(selected.try_into().unwrap()) + } + + /// Like `random_access`, but with `HashOutTarget`s rather than simple `Target`s. + pub fn random_access_hash( + &mut self, + access_index: Target, + v: Vec, + ) -> HashOutTarget { + let selected = std::array::from_fn(|i| { + self.random_access( + access_index, + v.iter().map(|hash| hash.elements[i]).collect(), + ) + }); + selected.into() + } + + /// Like `random_access`, but with `MerkleCapTarget`s rather than simple `Target`s. + pub fn random_access_merkle_cap( + &mut self, + access_index: Target, + v: Vec, + ) -> MerkleCapTarget { + let cap_size = v[0].0.len(); + assert!(v.iter().all(|cap| cap.0.len() == cap_size)); + + let selected = (0..cap_size) + .map(|i| self.random_access_hash(access_index, v.iter().map(|cap| cap.0[i]).collect())) + .collect(); + MerkleCapTarget(selected) + } + + /// Like `random_access`, but with `VerifierCircuitTarget`s rather than simple `Target`s. + pub fn random_access_verifier_data( + &mut self, + access_index: Target, + v: Vec, + ) -> VerifierCircuitTarget { + let constants_sigmas_caps = v.iter().map(|vk| vk.constants_sigmas_cap.clone()).collect(); + let circuit_digests = v.iter().map(|vk| vk.circuit_digest).collect(); + let constants_sigmas_cap = + self.random_access_merkle_cap(access_index, constants_sigmas_caps); + let circuit_digest = self.random_access_hash(access_index, circuit_digests); + VerifierCircuitTarget { + constants_sigmas_cap, + circuit_digest, + } } } diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index cd696e55..3726d847 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -11,6 +11,7 @@ use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::VerifierCircuitTarget; use crate::plonk::config::{AlgebraicHasher, Hasher}; #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -152,6 +153,11 @@ impl, const D: usize> CircuitBuilder { self.connect_hashes(*h0, *h1); } } + + pub fn connect_verifier_data(&mut self, x: &VerifierCircuitTarget, y: &VerifierCircuitTarget) { + self.connect_merkle_caps(&x.constants_sigmas_cap, &y.constants_sigmas_cap); + self.connect_hashes(x.circuit_digest, y.circuit_digest); + } } #[cfg(test)] diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 86871701..7a9cd1f3 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -14,6 +14,7 @@ use crate::util::log2_strict; /// It can be used in place of the root to verify Merkle paths, which are `h` elements shorter. #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(bound = "")] +// TODO: Change H to GenericHashOut, since this only cares about the hash, not the hasher. pub struct MerkleCap>(pub Vec); impl> MerkleCap { diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index e9ade3d2..14213d0d 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -14,7 +14,7 @@ use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_data::{VerifierCircuitTarget, VerifierOnlyCircuitData}; -use crate::plonk::config::{AlgebraicHasher, GenericConfig}; +use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use crate::plonk::proof::{Proof, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget}; pub trait WitnessWrite { @@ -224,6 +224,19 @@ pub trait Witness: WitnessWrite { } } + fn get_merkle_cap_target>(&self, cap_target: MerkleCapTarget) -> MerkleCap + where + F: RichField, + H: AlgebraicHasher, + { + let cap = cap_target + .0 + .iter() + .map(|hash_target| self.get_hash_target(*hash_target)) + .collect(); + MerkleCap(cap) + } + fn get_wire(&self, wire: Wire) -> F { self.get_target(Target::Wire(wire)) } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 9017a059..21305f4d 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -245,6 +245,13 @@ impl, const D: usize> CircuitBuilder { t } + pub fn add_virtual_verifier_data(&mut self, cap_height: usize) -> VerifierCircuitTarget { + VerifierCircuitTarget { + constants_sigmas_cap: self.add_virtual_cap(cap_height), + circuit_digest: self.add_virtual_hash(), + } + } + /// Add a virtual verifier data, register it as a public input and set it to `self.verifier_data_public_input`. /// WARNING: Do not register any public input after calling this! TODO: relax this pub fn add_verifier_data_public_inputs(&mut self) -> VerifierCircuitTarget { @@ -253,10 +260,7 @@ impl, const D: usize> CircuitBuilder { "add_verifier_data_public_inputs only needs to be called once" ); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), - circuit_digest: self.add_virtual_hash(), - }; + let verifier_data = self.add_virtual_verifier_data(self.config.fri_config.cap_height); // The verifier data are public inputs. self.register_public_inputs(&verifier_data.circuit_digest.elements); for i in 0..self.config.fri_config.num_cap_elements() { @@ -784,13 +788,13 @@ impl, const D: usize> CircuitBuilder { self.add_simple_generator(const_gen); } - info!( + debug!( "Degree before blinding & padding: {}", self.gate_instances.len() ); self.blind_and_pad(); let degree = self.gate_instances.len(); - info!("Degree after blinding & padding: {}", degree); + debug!("Degree after blinding & padding: {}", degree); let degree_bits = log2_strict(degree); let fri_params = self.fri_params(degree_bits); assert!( diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 6df986fa..c57e206d 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -149,15 +149,15 @@ impl, C: GenericConfig, const D: usize> proof.decompress(&self.verifier_only.circuit_digest, &self.common) } - pub fn verifier_data(self) -> VerifierCircuitData { + pub fn verifier_data(&self) -> VerifierCircuitData { let CircuitData { verifier_only, common, .. } = self; VerifierCircuitData { - verifier_only, - common, + verifier_only: verifier_only.clone(), + common: common.clone(), } } @@ -258,7 +258,7 @@ pub struct ProverOnlyCircuitData< } /// Circuit data required by the verifier, but not the prover. -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct VerifierOnlyCircuitData, const D: usize> { /// A commitment to each constant polynomial and each permutation polynomial. pub constants_sigmas_cap: MerkleCap, @@ -289,7 +289,7 @@ pub struct CommonCircuitData, const D: usize> { /// The number of constant wires. pub(crate) num_constants: usize, - pub(crate) num_public_inputs: usize, + pub num_public_inputs: usize, /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, diff --git a/plonky2/src/recursion/conditional_recursive_verifier.rs b/plonky2/src/recursion/conditional_recursive_verifier.rs index ace47cab..2596e1f0 100644 --- a/plonky2/src/recursion/conditional_recursive_verifier.rs +++ b/plonky2/src/recursion/conditional_recursive_verifier.rs @@ -378,15 +378,11 @@ mod tests { pw.set_proof_with_pis_target(&pt, &proof); let dummy_pt = builder.add_virtual_proof_with_pis::(&data.common); pw.set_proof_with_pis_target::(&dummy_pt, &dummy_proof); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); pw.set_verifier_data_target(&inner_data, &data.verifier_only); - let dummy_inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let dummy_inner_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); pw.set_verifier_data_target(&dummy_inner_data, &dummy_data.verifier_only); let b = builder.constant_bool(F::rand().0 % 2 == 0); builder.conditionally_verify_proof::( diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs index a12c31d4..656ff3b9 100644 --- a/plonky2/src/recursion/cyclic_recursion.rs +++ b/plonky2/src/recursion/cyclic_recursion.rs @@ -188,7 +188,7 @@ mod tests { use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation}; use crate::iop::witness::{PartialWitness, WitnessWrite}; use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget}; + use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; use crate::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; use crate::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; use crate::recursion::dummy_circuit::cyclic_base_proof; @@ -208,20 +208,16 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let proof = builder.add_virtual_proof_with_pis::(&data.common); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let verifier_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); builder.verify_proof::(&proof, &verifier_data, &data.common); let data = builder.build::(); let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let proof = builder.add_virtual_proof_with_pis::(&data.common); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let verifier_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); builder.verify_proof::(&proof, &verifier_data, &data.common); while builder.num_gates() < 1 << 12 { builder.add_gate(NoopGate, vec![]); diff --git a/plonky2/src/recursion/dummy_circuit.rs b/plonky2/src/recursion/dummy_circuit.rs index 34b20c71..38f51aea 100644 --- a/plonky2/src/recursion/dummy_circuit.rs +++ b/plonky2/src/recursion/dummy_circuit.rs @@ -113,11 +113,8 @@ impl, const D: usize> CircuitBuilder { let dummy_circuit = dummy_circuit::(common_data); let dummy_proof_with_pis = dummy_proof(&dummy_circuit, HashMap::new())?; let dummy_proof_with_pis_target = self.add_virtual_proof_with_pis::(common_data); - - let dummy_verifier_data_target = VerifierCircuitTarget { - constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), - circuit_digest: self.add_virtual_hash(), - }; + let dummy_verifier_data_target = + self.add_virtual_verifier_data(self.config.fri_config.cap_height); self.add_simple_generator(DummyProofGenerator { proof_with_pis_target: dummy_proof_with_pis_target.clone(), diff --git a/plonky2/src/recursion/recursive_verifier.rs b/plonky2/src/recursion/recursive_verifier.rs index 15943a87..9aafb1f5 100644 --- a/plonky2/src/recursion/recursive_verifier.rs +++ b/plonky2/src/recursion/recursive_verifier.rs @@ -366,10 +366,7 @@ mod tests { let pt = builder.add_virtual_proof_with_pis::(&inner_cd); pw.set_proof_with_pis_target(&pt, &inner_proof); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = builder.add_virtual_verifier_data(inner_cd.config.fri_config.cap_height); pw.set_cap_target( &inner_data.constants_sigmas_cap, &inner_vd.constants_sigmas_cap,