diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 5ec4041b..40e67ba7 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -114,7 +114,7 @@ fn ctl_keccak() -> CrossTableLookup { keccak_stark::ctl_data(), Some(keccak_stark::ctl_filter()), ); - CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked, None) + CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) } fn ctl_keccak_sponge() -> CrossTableLookup { @@ -128,7 +128,7 @@ fn ctl_keccak_sponge() -> CrossTableLookup { keccak_sponge_stark::ctl_looked_data(), Some(keccak_sponge_stark::ctl_looked_filter()), ); - CrossTableLookup::new(vec![cpu_looking], keccak_sponge_looked, None) + CrossTableLookup::new(vec![cpu_looking], keccak_sponge_looked) } fn ctl_logic() -> CrossTableLookup { @@ -148,7 +148,7 @@ fn ctl_logic() -> CrossTableLookup { } let logic_looked = TableWithColumns::new(Table::Logic, logic::ctl_data(), Some(logic::ctl_filter())); - CrossTableLookup::new(all_lookers, logic_looked, None) + CrossTableLookup::new(all_lookers, logic_looked) } fn ctl_memory() -> CrossTableLookup { @@ -180,5 +180,5 @@ fn ctl_memory() -> CrossTableLookup { memory_stark::ctl_data(), Some(memory_stark::ctl_filter()), ); - CrossTableLookup::new(all_lookers, memory_looked, None) + CrossTableLookup::new(all_lookers, memory_looked) } diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 4930321a..1b184a50 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -8,7 +8,6 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; -use plonky2::iop::challenger::Challenger; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -17,10 +16,8 @@ use plonky2::plonk::config::GenericConfig; use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::permutation::{ - get_grand_product_challenge_set, GrandProductChallenge, GrandProductChallengeSet, -}; -use crate::proof::{StarkProof, StarkProofTarget}; +use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet}; +use crate::proof::{StarkProofTarget, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -162,33 +159,19 @@ impl TableWithColumns { pub struct CrossTableLookup { pub(crate) looking_tables: Vec>, pub(crate) looked_table: TableWithColumns, - /// Default value if filters are not used. - default: Option>, } impl CrossTableLookup { pub fn new( looking_tables: Vec>, looked_table: TableWithColumns, - default: Option>, ) -> Self { assert!(looking_tables .iter() .all(|twc| twc.columns.len() == looked_table.columns.len())); - assert!( - looking_tables - .iter() - .all(|twc| twc.filter_column.is_none() == default.is_some()) - && default.is_some() == looked_table.filter_column.is_none(), - "Default values should be provided iff there are no filter columns." - ); - if let Some(default) = &default { - assert_eq!(default.len(), looked_table.columns.len()); - } Self { looking_tables, looked_table, - default, } } @@ -234,22 +217,19 @@ impl CtlData { } } -pub fn cross_table_lookup_data, const D: usize>( - config: &StarkConfig, +pub(crate) fn cross_table_lookup_data, const D: usize>( trace_poly_values: &[Vec>; NUM_TABLES], cross_table_lookups: &[CrossTableLookup], - challenger: &mut Challenger, + ctl_challenges: &GrandProductChallengeSet, ) -> [CtlData; NUM_TABLES] { - let challenges = get_grand_product_challenge_set(challenger, config.num_challenges); let mut ctl_data_per_table = [0; NUM_TABLES].map(|_| CtlData::default()); for CrossTableLookup { looking_tables, looked_table, - default, } in cross_table_lookups { log::debug!("Processing CTL for {:?}", looked_table.table); - for &challenge in &challenges.challenges { + for &challenge in &ctl_challenges.challenges { let zs_looking = looking_tables.iter().map(|table| { partial_products( &trace_poly_values[table.table as usize], @@ -271,21 +251,6 @@ pub fn cross_table_lookup_data, const D .map(|z| *z.values.last().unwrap()) .product::(), *z_looked.values.last().unwrap() - * default - .as_ref() - .map(|default| { - challenge.combine(default).exp_u64( - looking_tables - .iter() - .map(|table| { - trace_poly_values[table.table as usize][0].len() as u64 - }) - .sum::() - - trace_poly_values[looked_table.table as usize][0].len() - as u64, - ) - }) - .unwrap_or(F::ONE) ); for (table, z) in looking_tables.iter().zip(zs_looking) { @@ -358,7 +323,7 @@ impl<'a, F: RichField + Extendable, const D: usize> CtlCheckVars<'a, F, F::Extension, F::Extension, D> { pub(crate) fn from_proofs>( - proofs: &[StarkProof; NUM_TABLES], + proofs: &[StarkProofWithMetadata; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, num_permutation_zs: &[usize; NUM_TABLES], @@ -367,7 +332,7 @@ impl<'a, F: RichField + Extendable, const D: usize> .iter() .zip(num_permutation_zs) .map(|(p, &num_perms)| { - let openings = &p.openings; + let openings = &p.proof.openings; let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_perms); let ctl_zs_next = openings.permutation_ctl_zs_next.iter().skip(num_perms); ctl_zs.zip(ctl_zs_next) @@ -378,7 +343,6 @@ impl<'a, F: RichField + Extendable, const D: usize> for CrossTableLookup { looking_tables, looked_table, - .. } in cross_table_lookups { for &challenges in &ctl_challenges.challenges { @@ -481,7 +445,6 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { for CrossTableLookup { looking_tables, looked_table, - .. } in cross_table_lookups { for &challenges in &ctl_challenges.challenges { @@ -582,44 +545,27 @@ pub(crate) fn verify_cross_table_lookups< C: GenericConfig, const D: usize, >( - cross_table_lookups: Vec>, + cross_table_lookups: &[CrossTableLookup], ctl_zs_lasts: [Vec; NUM_TABLES], - degrees_bits: [usize; NUM_TABLES], - challenges: GrandProductChallengeSet, config: &StarkConfig, ) -> Result<()> { let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); - for ( - i, - CrossTableLookup { - looking_tables, - looked_table, - default, - .. - }, - ) in cross_table_lookups.into_iter().enumerate() + for CrossTableLookup { + looking_tables, + looked_table, + } in cross_table_lookups.iter() { for _ in 0..config.num_challenges { - let looking_degrees_sum = looking_tables - .iter() - .map(|table| 1 << degrees_bits[table.table as usize]) - .sum::(); - let looked_degree = 1 << degrees_bits[looked_table.table as usize]; let looking_zs_prod = looking_tables .iter() .map(|table| *ctl_zs_openings[table.table as usize].next().unwrap()) .product::(); let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); - let challenge = challenges.challenges[i % config.num_challenges]; - if let Some(default) = default.as_ref() { - let combined_default = challenge.combine(default.iter()); - ensure!( - looking_zs_prod - == looked_z * combined_default.exp_u64(looking_degrees_sum - looked_degree), - "Cross-table lookup verification failed." - ); - } + ensure!( + looking_zs_prod == looked_z, + "Cross-table lookup verification failed." + ); } } debug_assert!(ctl_zs_openings.iter_mut().all(|iter| iter.next().is_none())); @@ -635,47 +581,22 @@ pub(crate) fn verify_cross_table_lookups_circuit< builder: &mut CircuitBuilder, cross_table_lookups: Vec>, ctl_zs_lasts: [Vec; NUM_TABLES], - degrees_bits: [usize; NUM_TABLES], - challenges: GrandProductChallengeSet, inner_config: &StarkConfig, ) { let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); - for ( - i, - CrossTableLookup { - looking_tables, - looked_table, - default, - .. - }, - ) in cross_table_lookups.into_iter().enumerate() + for CrossTableLookup { + looking_tables, + looked_table, + } in cross_table_lookups.into_iter() { for _ in 0..inner_config.num_challenges { - let looking_degrees_sum = looking_tables - .iter() - .map(|table| 1 << degrees_bits[table.table as usize]) - .sum::(); - let looked_degree = 1 << degrees_bits[looked_table.table as usize]; let looking_zs_prod = builder.mul_many( looking_tables .iter() .map(|table| *ctl_zs_openings[table.table as usize].next().unwrap()), ); let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); - let challenge = challenges.challenges[i % inner_config.num_challenges]; - if let Some(default) = default.as_ref() { - let default = default - .iter() - .map(|&x| builder.constant(x)) - .collect::>(); - let combined_default = challenge.combine_base_circuit(builder, &default); - - let pad = builder.exp_u64(combined_default, looking_degrees_sum - looked_degree); - let padded_looked_z = builder.mul(looked_z, pad); - builder.connect(looking_zs_prod, padded_looked_z); - } else { - builder.connect(looking_zs_prod, looked_z); - } + builder.connect(looking_zs_prod, looked_z); } } debug_assert!(ctl_zs_openings.iter_mut().all(|iter| iter.next().is_none())); @@ -712,7 +633,6 @@ pub(crate) mod testutils { let CrossTableLookup { looking_tables, looked_table, - default, } = ctl; // Maps `m` with `(table, i) in m[row]` iff the `i`-th row of `table` is equal to `row` and @@ -726,45 +646,11 @@ pub(crate) mod testutils { process_table(trace_poly_values, looked_table, &mut looked_multiset); let empty = &vec![]; - // Check that every row in the looking tables appears in the looked table the same number of times - // with some special logic for the default row. + // Check that every row in the looking tables appears in the looked table the same number of times. for (row, looking_locations) in &looking_multiset { let looked_locations = looked_multiset.get(row).unwrap_or(empty); - if let Some(default) = default { - if row == default { - continue; - } - } check_locations(looking_locations, looked_locations, ctl_index, row); } - let extra_default_count = default.as_ref().map(|d| { - let looking_default_locations = looking_multiset.get(d).unwrap_or(empty); - let looked_default_locations = looked_multiset.get(d).unwrap_or(empty); - looking_default_locations - .len() - .checked_sub(looked_default_locations.len()) - .unwrap_or_else(|| { - // If underflow, panic. There should be more default rows in the looking side. - check_locations( - looking_default_locations, - looked_default_locations, - ctl_index, - d, - ); - unreachable!() - }) - }); - // Check that the number of extra default rows is correct. - if let Some(count) = extra_default_count { - assert_eq!( - count, - looking_tables - .iter() - .map(|table| trace_poly_values[table.table as usize][0].len()) - .sum::() - - trace_poly_values[looked_table.table as usize][0].len() - ); - } // Check that every row in the looked tables appears in the looked table the same number of times. for (row, looked_locations) in &looked_multiset { let looking_locations = looking_multiset.get(row).unwrap_or(empty); diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs new file mode 100644 index 00000000..5a3a5013 --- /dev/null +++ b/evm/src/fixed_recursive_verifier.rs @@ -0,0 +1,540 @@ +use std::collections::BTreeMap; +use std::ops::Range; + +use itertools::Itertools; +use plonky2::field::extension::Extendable; +use plonky2::gates::noop::NoopGate; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::iop::challenger::RecursiveChallenger; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitTarget}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; +use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; +use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; +use plonky2::util::timing::TimingTree; +use plonky2_util::log2_ceil; + +use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; +use crate::config::StarkConfig; +use crate::cpu::cpu_stark::CpuStark; +use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; +use crate::generation::GenerationInputs; +use crate::keccak::keccak_stark::KeccakStark; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; +use crate::logic::LogicStark; +use crate::memory::memory_stark::MemoryStark; +use crate::permutation::{get_grand_product_challenge_set_target, GrandProductChallengeSet}; +use crate::proof::StarkProofWithMetadata; +use crate::prover::prove; +use crate::recursive_verifier::{ + add_common_recursion_gates, recursive_stark_circuit, PlonkWrapperCircuit, PublicInputs, + StarkWrapperCircuit, +}; +use crate::stark::Stark; + +/// The recursion threshold. We end a chain of recursive proofs once we reach this size. +const THRESHOLD_DEGREE_BITS: usize = 13; + +/// Contains all recursive circuits used in the system. For each STARK and each initial +/// `degree_bits`, this contains a chain of recursive circuits for shrinking that STARK from +/// `degree_bits` to a constant `THRESHOLD_DEGREE_BITS`. It also contains a special root circuit +/// for combining each STARK's shrunk wrapper proof into a single proof. +pub struct AllRecursiveCircuits +where + F: RichField + Extendable, + C: GenericConfig, +{ + /// The EVM root circuit, which aggregates the (shrunk) per-table recursive proofs. + pub root: RootCircuitData, + pub aggregation: AggregationCircuitData, + /// Holds chains of circuits for each table and for each initial `degree_bits`. + by_table: [RecursiveCircuitsForTable; NUM_TABLES], +} + +/// Data for the EVM root circuit, which is used to combine each STARK's shrunk wrapper proof +/// into a single proof. +pub struct RootCircuitData +where + F: RichField + Extendable, + C: GenericConfig, +{ + circuit: CircuitData, + proof_with_pis: [ProofWithPublicInputsTarget; NUM_TABLES], + /// For each table, various inner circuits may be used depending on the initial table size. + /// This target holds the index of the circuit (within `final_circuits()`) that was used. + index_verifier_data: [Target; NUM_TABLES], + /// Public inputs used for cyclic verification. These aren't actually used for EVM root + /// proofs; the circuit has them just to match the structure of aggregation proofs. + cyclic_vk: VerifierCircuitTarget, +} + +/// Data for the aggregation circuit, which is used to compress two proofs into one. Each inner +/// proof can be either an EVM root proof or another aggregation proof. +pub struct AggregationCircuitData +where + F: RichField + Extendable, + C: GenericConfig, +{ + circuit: CircuitData, + lhs: AggregationChildTarget, + rhs: AggregationChildTarget, + cyclic_vk: VerifierCircuitTarget, +} + +pub struct AggregationChildTarget { + is_agg: BoolTarget, + agg_proof: ProofWithPublicInputsTarget, + evm_proof: ProofWithPublicInputsTarget, +} + +impl AllRecursiveCircuits +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, + [(); CpuStark::::COLUMNS]:, + [(); KeccakStark::::COLUMNS]:, + [(); KeccakSpongeStark::::COLUMNS]:, + [(); LogicStark::::COLUMNS]:, + [(); MemoryStark::::COLUMNS]:, +{ + /// Preprocess all recursive circuits used by the system. + pub fn new( + all_stark: &AllStark, + degree_bits_range: Range, + stark_config: &StarkConfig, + ) -> Self { + let cpu = RecursiveCircuitsForTable::new( + Table::Cpu, + &all_stark.cpu_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let keccak = RecursiveCircuitsForTable::new( + Table::Keccak, + &all_stark.keccak_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let keccak_sponge = RecursiveCircuitsForTable::new( + Table::KeccakSponge, + &all_stark.keccak_sponge_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let logic = RecursiveCircuitsForTable::new( + Table::Logic, + &all_stark.logic_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let memory = RecursiveCircuitsForTable::new( + Table::Memory, + &all_stark.memory_stark, + degree_bits_range, + &all_stark.cross_table_lookups, + stark_config, + ); + + let by_table = [cpu, keccak, keccak_sponge, logic, memory]; + let root = Self::create_root_circuit(&by_table, stark_config); + let aggregation = Self::create_aggregation_circuit(&root); + Self { + root, + aggregation, + by_table, + } + } + + fn create_root_circuit( + by_table: &[RecursiveCircuitsForTable; NUM_TABLES], + stark_config: &StarkConfig, + ) -> RootCircuitData { + let inner_common_data: [_; NUM_TABLES] = + std::array::from_fn(|i| &by_table[i].final_circuits()[0].common); + + let mut builder = CircuitBuilder::new(CircuitConfig::standard_recursion_config()); + let recursive_proofs = + std::array::from_fn(|i| builder.add_virtual_proof_with_pis::(inner_common_data[i])); + let pis: [_; NUM_TABLES] = std::array::from_fn(|i| { + PublicInputs::from_vec(&recursive_proofs[i].public_inputs, stark_config) + }); + let index_verifier_data = std::array::from_fn(|_i| builder.add_virtual_target()); + + let mut challenger = RecursiveChallenger::::new(&mut builder); + for pi in &pis { + for h in &pi.trace_cap { + challenger.observe_elements(h); + } + } + let ctl_challenges = get_grand_product_challenge_set_target( + &mut builder, + &mut challenger, + stark_config.num_challenges, + ); + // Check that the correct CTL challenges are used in every proof. + for pi in &pis { + for i in 0..stark_config.num_challenges { + builder.connect( + ctl_challenges.challenges[i].beta, + pi.ctl_challenges.challenges[i].beta, + ); + builder.connect( + ctl_challenges.challenges[i].gamma, + pi.ctl_challenges.challenges[i].gamma, + ); + } + } + + let state = challenger.compact(&mut builder); + for k in 0..SPONGE_WIDTH { + builder.connect(state[k], pis[0].challenger_state_before[k]); + } + // Check that the challenger state is consistent between proofs. + for i in 1..NUM_TABLES { + for k in 0..SPONGE_WIDTH { + builder.connect( + pis[i].challenger_state_before[k], + pis[i - 1].challenger_state_after[k], + ); + } + } + + // Verify the CTL checks. + verify_cross_table_lookups_circuit::( + &mut builder, + all_cross_table_lookups(), + pis.map(|p| p.ctl_zs_last), + stark_config, + ); + + for (i, table_circuits) in by_table.iter().enumerate() { + let final_circuits = table_circuits.final_circuits(); + for final_circuit in &final_circuits { + assert_eq!( + &final_circuit.common, inner_common_data[i], + "common_data mismatch" + ); + } + let mut possible_vks = final_circuits + .into_iter() + .map(|c| builder.constant_verifier_data(&c.verifier_only)) + .collect_vec(); + // random_access_verifier_data expects a vector whose length is a power of two. + // To satisfy this, we will just add some duplicates of the first VK. + while !possible_vks.len().is_power_of_two() { + possible_vks.push(possible_vks[0].clone()); + } + let inner_verifier_data = + builder.random_access_verifier_data(index_verifier_data[i], possible_vks); + + builder.verify_proof::( + &recursive_proofs[i], + &inner_verifier_data, + inner_common_data[i], + ); + } + + // We want EVM root proofs to have the exact same structure as aggregation proofs, so we add + // public inputs for cyclic verification, even though they'll be ignored. + let cyclic_vk = builder.add_verifier_data_public_inputs(); + + RootCircuitData { + circuit: builder.build(), + proof_with_pis: recursive_proofs, + index_verifier_data, + cyclic_vk, + } + } + + fn create_aggregation_circuit( + root: &RootCircuitData, + ) -> AggregationCircuitData { + let mut builder = CircuitBuilder::::new(root.circuit.common.config.clone()); + let cyclic_vk = builder.add_verifier_data_public_inputs(); + let lhs = Self::add_agg_child(&mut builder, root); + let rhs = Self::add_agg_child(&mut builder, root); + + // Pad to match the root circuit's degree. + while log2_ceil(builder.num_gates()) < root.circuit.common.degree_bits() { + builder.add_gate(NoopGate, vec![]); + } + + let circuit = builder.build::(); + AggregationCircuitData { + circuit, + lhs, + rhs, + cyclic_vk, + } + } + + fn add_agg_child( + builder: &mut CircuitBuilder, + root: &RootCircuitData, + ) -> AggregationChildTarget { + let common = &root.circuit.common; + let root_vk = builder.constant_verifier_data(&root.circuit.verifier_only); + let is_agg = builder.add_virtual_bool_target_safe(); + let agg_proof = builder.add_virtual_proof_with_pis::(common); + let evm_proof = builder.add_virtual_proof_with_pis::(common); + builder + .conditionally_verify_cyclic_proof::( + is_agg, &agg_proof, &evm_proof, &root_vk, common, + ) + .expect("Failed to build cyclic recursion circuit"); + AggregationChildTarget { + is_agg, + agg_proof, + evm_proof, + } + } + + /// Create a proof for each STARK, then combine them, eventually culminating in a root proof. + pub fn prove_root( + &self, + all_stark: &AllStark, + config: &StarkConfig, + generation_inputs: GenerationInputs, + timing: &mut TimingTree, + ) -> anyhow::Result> { + let all_proof = prove::(all_stark, config, generation_inputs, timing)?; + let mut root_inputs = PartialWitness::new(); + + for table in 0..NUM_TABLES { + let stark_proof = &all_proof.stark_proofs[table]; + let original_degree_bits = stark_proof.proof.recover_degree_bits(config); + let table_circuits = &self.by_table[table]; + let shrunk_proof = table_circuits.by_stark_size[&original_degree_bits] + .shrink(stark_proof, &all_proof.ctl_challenges)?; + let index_verifier_data = table_circuits + .by_stark_size + .keys() + .position(|&size| size == original_degree_bits) + .unwrap(); + root_inputs.set_target( + self.root.index_verifier_data[table], + F::from_canonical_usize(index_verifier_data), + ); + root_inputs.set_proof_with_pis_target(&self.root.proof_with_pis[table], &shrunk_proof); + } + + root_inputs.set_verifier_data_target( + &self.root.cyclic_vk, + &self.aggregation.circuit.verifier_only, + ); + + self.root.circuit.prove(root_inputs) + } + + pub fn verify_root(&self, agg_proof: ProofWithPublicInputs) -> anyhow::Result<()> { + self.root.circuit.verify(agg_proof) + } + + pub fn prove_aggregation( + &self, + lhs_is_agg: bool, + lhs_proof: &ProofWithPublicInputs, + rhs_is_agg: bool, + rhs_proof: &ProofWithPublicInputs, + ) -> anyhow::Result> { + let mut agg_inputs = PartialWitness::new(); + + agg_inputs.set_bool_target(self.aggregation.lhs.is_agg, lhs_is_agg); + agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.agg_proof, lhs_proof); + agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.evm_proof, lhs_proof); + + agg_inputs.set_bool_target(self.aggregation.rhs.is_agg, rhs_is_agg); + agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.agg_proof, rhs_proof); + agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.evm_proof, rhs_proof); + + agg_inputs.set_verifier_data_target( + &self.aggregation.cyclic_vk, + &self.aggregation.circuit.verifier_only, + ); + + self.aggregation.circuit.prove(agg_inputs) + } + + pub fn verify_aggregation( + &self, + agg_proof: &ProofWithPublicInputs, + ) -> anyhow::Result<()> { + self.aggregation.circuit.verify(agg_proof.clone())?; + check_cyclic_proof_verifier_data( + agg_proof, + &self.aggregation.circuit.verifier_only, + &self.aggregation.circuit.common, + ) + } +} + +struct RecursiveCircuitsForTable +where + F: RichField + Extendable, + C: GenericConfig, +{ + /// A map from `log_2(height)` to a chain of shrinking recursion circuits starting at that + /// height. + by_stark_size: BTreeMap>, +} + +impl RecursiveCircuitsForTable +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, +{ + fn new>( + table: Table, + stark: &S, + degree_bits_range: Range, + all_ctls: &[CrossTableLookup], + stark_config: &StarkConfig, + ) -> Self + where + [(); S::COLUMNS]:, + { + let by_stark_size = degree_bits_range + .map(|degree_bits| { + ( + degree_bits, + RecursiveCircuitsForTableSize::new::( + table, + stark, + degree_bits, + all_ctls, + stark_config, + ), + ) + }) + .collect(); + Self { by_stark_size } + } + + /// For each initial `degree_bits`, get the final circuit at the end of that shrinking chain. + /// Each of these final circuits should have degree `THRESHOLD_DEGREE_BITS`. + fn final_circuits(&self) -> Vec<&CircuitData> { + self.by_stark_size + .values() + .map(|chain| { + chain + .shrinking_wrappers + .last() + .map(|wrapper| &wrapper.circuit) + .unwrap_or(&chain.initial_wrapper.circuit) + }) + .collect() + } +} + +/// A chain of shrinking wrapper circuits, ending with a final circuit with `degree_bits` +/// `THRESHOLD_DEGREE_BITS`. +struct RecursiveCircuitsForTableSize +where + F: RichField + Extendable, + C: GenericConfig, +{ + initial_wrapper: StarkWrapperCircuit, + shrinking_wrappers: Vec>, +} + +impl RecursiveCircuitsForTableSize +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, +{ + fn new>( + table: Table, + stark: &S, + degree_bits: usize, + all_ctls: &[CrossTableLookup], + stark_config: &StarkConfig, + ) -> Self + where + [(); S::COLUMNS]:, + { + let initial_wrapper = recursive_stark_circuit( + table, + stark, + degree_bits, + all_ctls, + stark_config, + &shrinking_config(), + THRESHOLD_DEGREE_BITS, + ); + let mut shrinking_wrappers = vec![]; + + // Shrinking recursion loop. + loop { + let last = shrinking_wrappers + .last() + .map(|wrapper: &PlonkWrapperCircuit| &wrapper.circuit) + .unwrap_or(&initial_wrapper.circuit); + let last_degree_bits = last.common.degree_bits(); + assert!(last_degree_bits >= THRESHOLD_DEGREE_BITS); + if last_degree_bits == THRESHOLD_DEGREE_BITS { + break; + } + + let mut builder = CircuitBuilder::new(shrinking_config()); + let proof_with_pis_target = builder.add_virtual_proof_with_pis::(&last.common); + let last_vk = builder.constant_verifier_data(&last.verifier_only); + builder.verify_proof::(&proof_with_pis_target, &last_vk, &last.common); + builder.register_public_inputs(&proof_with_pis_target.public_inputs); // carry PIs forward + add_common_recursion_gates(&mut builder); + let circuit = builder.build(); + + assert!( + circuit.common.degree_bits() < last_degree_bits, + "Couldn't shrink to expected recursion threshold of 2^{}; stalled at 2^{}", + THRESHOLD_DEGREE_BITS, + circuit.common.degree_bits() + ); + shrinking_wrappers.push(PlonkWrapperCircuit { + circuit, + proof_with_pis_target, + }); + } + + Self { + initial_wrapper, + shrinking_wrappers, + } + } + + fn shrink( + &self, + stark_proof_with_metadata: &StarkProofWithMetadata, + ctl_challenges: &GrandProductChallengeSet, + ) -> anyhow::Result> { + let mut proof = self + .initial_wrapper + .prove(stark_proof_with_metadata, ctl_challenges)?; + for wrapper_circuit in &self.shrinking_wrappers { + proof = wrapper_circuit.prove(&proof)?; + } + Ok(proof) + } +} + +/// Our usual recursion threshold is 2^12 gates, but for these shrinking circuits, we use a few more +/// gates for a constant inner VK and for public inputs. This pushes us over the threshold to 2^13. +/// As long as we're at 2^13 gates, we might as well use a narrower witness. +fn shrinking_config() -> CircuitConfig { + CircuitConfig { + num_routed_wires: 40, + ..CircuitConfig::standard_recursion_config() + } +} diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index ede7c466..f368b7c2 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -23,7 +23,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { - challenger.observe_cap(&proof.trace_cap); + challenger.observe_cap(&proof.proof.trace_cap); } // TODO: Observe public values. @@ -37,7 +37,7 @@ impl, C: GenericConfig, const D: usize> A AllProofChallenges { stark_challenges: std::array::from_fn(|i| { challenger.compact(); - self.stark_proofs[i].get_challenges( + self.stark_proofs[i].proof.get_challenges( &mut challenger, num_permutation_zs[i] > 0, num_permutation_batch_sizes[i], @@ -57,7 +57,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { - challenger.observe_cap(&proof.trace_cap); + challenger.observe_cap(&proof.proof.trace_cap); } // TODO: Observe public values. @@ -70,7 +70,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger_states = vec![challenger.compact()]; for i in 0..NUM_TABLES { - self.stark_proofs[i].get_challenges( + self.stark_proofs[i].proof.get_challenges( &mut challenger, num_permutation_zs[i] > 0, num_permutation_batch_sizes[i], diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index df842a41..c8fe8086 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -74,11 +74,8 @@ impl, const D: usize> KeccakStark { rows } - fn generate_trace_rows_for_perm( - &self, - input: [u64; NUM_INPUTS], - ) -> [[F; NUM_COLUMNS]; NUM_ROUNDS] { - let mut rows = [[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS]; + fn generate_trace_rows_for_perm(&self, input: [u64; NUM_INPUTS]) -> Vec<[F; NUM_COLUMNS]> { + let mut rows = vec![[F::ZERO; NUM_COLUMNS]; NUM_ROUNDS]; // Populate the preimage for each row. for round in 0..24 { diff --git a/evm/src/lib.rs b/evm/src/lib.rs index 7ac3e1e7..bd9ba261 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -13,6 +13,7 @@ pub mod config; pub mod constraint_consumer; pub mod cpu; pub mod cross_table_lookup; +pub mod fixed_recursive_verifier; pub mod generation; mod get_challenges; pub mod keccak; diff --git a/evm/src/permutation.rs b/evm/src/permutation.rs index b081c309..4f42a4aa 100644 --- a/evm/src/permutation.rs +++ b/evm/src/permutation.rs @@ -15,9 +15,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; -use plonky2::plonk::plonk_common::{ - reduce_with_powers, reduce_with_powers_circuit, reduce_with_powers_ext_circuit, -}; +use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; use plonky2::util::reducing::{ReducingFactor, ReducingFactorTarget}; use crate::config::StarkConfig; @@ -82,15 +80,6 @@ impl GrandProductChallenge { let gamma = builder.convert_to_ext(self.gamma); builder.add_extension(reduced, gamma) } - - pub(crate) fn combine_base_circuit, const D: usize>( - &self, - builder: &mut CircuitBuilder, - terms: &[Target], - ) -> Target { - let reduced = reduce_with_powers_circuit(builder, terms, self.beta); - builder.add(reduced, self.gamma) - } } /// Like `PermutationChallenge`, but with `num_challenges` copies to boost soundness. diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 4cd03a65..7025697f 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -19,15 +19,17 @@ use crate::all_stark::NUM_TABLES; use crate::config::StarkConfig; use crate::permutation::GrandProductChallengeSet; +/// A STARK proof for each table, plus some metadata used to create recursive wrapper proofs. #[derive(Debug, Clone)] pub struct AllProof, C: GenericConfig, const D: usize> { - pub stark_proofs: [StarkProof; NUM_TABLES], + pub stark_proofs: [StarkProofWithMetadata; NUM_TABLES], + pub(crate) ctl_challenges: GrandProductChallengeSet, pub public_values: PublicValues, } impl, C: GenericConfig, const D: usize> AllProof { pub fn degree_bits(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { - std::array::from_fn(|i| self.stark_proofs[i].recover_degree_bits(config)) + std::array::from_fn(|i| self.stark_proofs[i].proof.recover_degree_bits(config)) } } @@ -44,11 +46,6 @@ pub(crate) struct AllChallengerState, const D: usiz pub ctl_challenges: GrandProductChallengeSet, } -pub struct AllProofTarget { - pub stark_proofs: [StarkProofTarget; NUM_TABLES], - pub public_values: PublicValuesTarget, -} - /// Memory values which are public. #[derive(Debug, Clone, Default)] pub struct PublicValues { @@ -113,6 +110,18 @@ pub struct StarkProof, C: GenericConfig, pub opening_proof: FriProof, } +/// A `StarkProof` along with some metadata about the initial Fiat-Shamir state, which is used when +/// creating a recursive wrapper proof around a STARK proof. +#[derive(Debug, Clone)] +pub struct StarkProofWithMetadata +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) init_challenger_state: [F; SPONGE_WIDTH], + pub(crate) proof: StarkProof, +} + impl, C: GenericConfig, const D: usize> StarkProof { /// Recover the length of the trace from a STARK proof and a STARK config. pub fn recover_degree_bits(&self, config: &StarkConfig) -> usize { diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 55b57437..a777e1ad 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -29,10 +29,10 @@ use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::{ - compute_permutation_z_polys, get_n_grand_product_challenge_sets, GrandProductChallengeSet, - PermutationCheckVars, + compute_permutation_z_polys, get_grand_product_challenge_set, + get_n_grand_product_challenge_sets, GrandProductChallengeSet, PermutationCheckVars, }; -use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof}; +use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::StarkEvaluationVars; @@ -117,14 +117,14 @@ where challenger.observe_cap(cap); } + let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); let ctl_data_per_table = timed!( timing, "compute CTL data", cross_table_lookup_data::( - config, &trace_poly_values, &all_stark.cross_table_lookups, - &mut challenger, + &ctl_challenges, ) ); @@ -144,6 +144,7 @@ where Ok(AllProof { stark_proofs, + ctl_challenges, public_values, }) } @@ -156,7 +157,7 @@ fn prove_with_commitments( ctl_data_per_table: [CtlData; NUM_TABLES], challenger: &mut Challenger, timing: &mut TimingTree, -) -> Result<[StarkProof; NUM_TABLES]> +) -> Result<[StarkProofWithMetadata; NUM_TABLES]> where F: RichField + Extendable, C: GenericConfig, @@ -250,7 +251,7 @@ pub(crate) fn prove_single_table( ctl_data: &CtlData, challenger: &mut Challenger, timing: &mut TimingTree, -) -> Result> +) -> Result> where F: RichField + Extendable, C: GenericConfig, @@ -268,7 +269,7 @@ where "FRI total reduction arity is too large.", ); - challenger.compact(); + let init_challenger_state = challenger.compact(); // Permutation arguments. let permutation_challenges = stark.uses_permutation_args().then(|| { @@ -411,12 +412,16 @@ where ) ); - Ok(StarkProof { + let proof = StarkProof { trace_cap: trace_commitment.merkle_tree.cap.clone(), permutation_ctl_zs_cap, quotient_polys_cap, openings, opening_proof, + }; + Ok(StarkProofWithMetadata { + init_challenger_state, + proof, }) } diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 0f713e32..4de74921 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -5,38 +5,34 @@ use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::fri::witness_util::set_fri_proof_target; -use plonky2::hash::hash_types::{HashOut, RichField}; +use plonky2::gates::exponentiation::ExponentiationGate; +use plonky2::gates::gate::GateRef; +use plonky2::gates::noop::NoopGate; +use plonky2::hash::hash_types::RichField; use plonky2::hash::hashing::SPONGE_WIDTH; use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; -use plonky2::iop::witness::Witness; +use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData, VerifierCircuitTarget}; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData, VerifierCircuitData}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2::util::reducing::ReducingFactorTarget; use plonky2::with_context; +use plonky2_util::log2_ceil; -use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; +use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::RecursiveConstraintConsumer; -use crate::cpu::cpu_stark::CpuStark; -use crate::cross_table_lookup::{ - verify_cross_table_lookups, verify_cross_table_lookups_circuit, CrossTableLookup, - CtlCheckVarsTarget, -}; -use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; -use crate::logic::LogicStark; -use crate::memory::memory_stark::MemoryStark; +use crate::cross_table_lookup::{verify_cross_table_lookups, CrossTableLookup, CtlCheckVarsTarget}; use crate::permutation::{ - get_grand_product_challenge_set, get_grand_product_challenge_set_target, GrandProductChallenge, - GrandProductChallengeSet, PermutationCheckDataTarget, + get_grand_product_challenge_set, GrandProductChallenge, GrandProductChallengeSet, + PermutationCheckDataTarget, }; use crate::proof::{ - AllProof, AllProofTarget, BlockMetadata, BlockMetadataTarget, PublicValues, PublicValuesTarget, - StarkOpeningSetTarget, StarkProof, StarkProofChallengesTarget, StarkProofTarget, TrieRoots, + BlockMetadata, BlockMetadataTarget, PublicValues, PublicValuesTarget, StarkOpeningSetTarget, + StarkProof, StarkProofChallengesTarget, StarkProofTarget, StarkProofWithMetadata, TrieRoots, TrieRootsTarget, }; use crate::stark::Stark; @@ -53,17 +49,12 @@ pub struct RecursiveAllProof< pub recursive_proofs: [ProofWithPublicInputs; NUM_TABLES], } -pub struct RecursiveAllProofTargetWithData { - pub recursive_proofs: [ProofWithPublicInputsTarget; NUM_TABLES], - pub verifier_data: [VerifierCircuitTarget; NUM_TABLES], -} - -struct PublicInputs { - trace_cap: Vec>, - ctl_zs_last: Vec, - ctl_challenges: GrandProductChallengeSet, - challenger_state_before: [T; SPONGE_WIDTH], - challenger_state_after: [T; SPONGE_WIDTH], +pub(crate) struct PublicInputs { + pub(crate) trace_cap: Vec>, + pub(crate) ctl_zs_last: Vec, + pub(crate) ctl_challenges: GrandProductChallengeSet, + pub(crate) challenger_state_before: [T; SPONGE_WIDTH], + pub(crate) challenger_state_after: [T; SPONGE_WIDTH], } /// Similar to the unstable `Iterator::next_chunk`. Could be replaced with that when it's stable. @@ -76,9 +67,9 @@ fn next_chunk(iter: &mut impl Iterator) -> [ } impl PublicInputs { - fn from_vec(v: &[T], config: &StarkConfig) -> Self { + pub(crate) fn from_vec(v: &[T], config: &StarkConfig) -> Self { let mut iter = v.iter().copied(); - let trace_cap = (0..1 << config.fri_config.cap_height) + let trace_cap = (0..config.fri_config.num_cap_elements()) .map(|_| next_chunk::<_, 4>(&mut iter).to_vec()) .collect(); let ctl_challenges = GrandProductChallengeSet { @@ -91,7 +82,7 @@ impl PublicInputs { }; let challenger_state_before = next_chunk(&mut iter); let challenger_state_after = next_chunk(&mut iter); - let ctl_zs_last = iter.collect(); + let ctl_zs_last: Vec<_> = iter.collect(); Self { trace_cap, @@ -141,12 +132,9 @@ impl, C: GenericConfig, const D: usize> } // Verify the CTL checks. - let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups::( - cross_table_lookups, + &cross_table_lookups, pis.map(|p| p.ctl_zs_last), - degrees_bits, - ctl_challenges, inner_config, )?; @@ -156,115 +144,115 @@ impl, C: GenericConfig, const D: usize> } Ok(()) } +} - /// Recursively verify every recursive proof. - pub fn verify_circuit( - builder: &mut CircuitBuilder, - recursive_all_proof_target: RecursiveAllProofTargetWithData, - verifier_data: &[VerifierCircuitData; NUM_TABLES], - inner_config: &StarkConfig, - ) where - [(); C::Hasher::HASH_SIZE]:, - >::Hasher: AlgebraicHasher, - { - let RecursiveAllProofTargetWithData { - recursive_proofs, - verifier_data: verifier_data_target, - } = recursive_all_proof_target; - let pis: [_; NUM_TABLES] = std::array::from_fn(|i| { - PublicInputs::from_vec(&recursive_proofs[i].public_inputs, inner_config) - }); +/// Represents a circuit which recursively verifies a STARK proof. +pub(crate) struct StarkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) circuit: CircuitData, + pub(crate) stark_proof_target: StarkProofTarget, + pub(crate) ctl_challenges_target: GrandProductChallengeSet, + pub(crate) init_challenger_state_target: [Target; SPONGE_WIDTH], + pub(crate) zero_target: Target, +} - let mut challenger = RecursiveChallenger::::new(builder); - for pi in &pis { - for h in &pi.trace_cap { - challenger.observe_elements(h); - } - } - let ctl_challenges = get_grand_product_challenge_set_target( - builder, - &mut challenger, - inner_config.num_challenges, +impl StarkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, +{ + pub(crate) fn prove( + &self, + proof_with_metadata: &StarkProofWithMetadata, + ctl_challenges: &GrandProductChallengeSet, + ) -> Result> { + let mut inputs = PartialWitness::new(); + + set_stark_proof_target( + &mut inputs, + &self.stark_proof_target, + &proof_with_metadata.proof, + self.zero_target, ); - // Check that the correct CTL challenges are used in every proof. - for pi in &pis { - for i in 0..inner_config.num_challenges { - builder.connect( - ctl_challenges.challenges[i].beta, - pi.ctl_challenges.challenges[i].beta, - ); - builder.connect( - ctl_challenges.challenges[i].gamma, - pi.ctl_challenges.challenges[i].gamma, - ); - } - } - let state = challenger.compact(builder); - for k in 0..SPONGE_WIDTH { - builder.connect(state[k], pis[0].challenger_state_before[k]); - } - // Check that the challenger state is consistent between proofs. - for i in 1..NUM_TABLES { - for k in 0..SPONGE_WIDTH { - builder.connect( - pis[i].challenger_state_before[k], - pis[i - 1].challenger_state_after[k], - ); - } - } - - // Verify the CTL checks. - let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); - verify_cross_table_lookups_circuit::( - builder, - all_cross_table_lookups(), - pis.map(|p| p.ctl_zs_last), - degrees_bits, - ctl_challenges, - inner_config, - ); - for (i, (recursive_proof, verifier_data_target)) in recursive_proofs - .into_iter() - .zip(verifier_data_target) - .enumerate() + for (challenge_target, challenge) in self + .ctl_challenges_target + .challenges + .iter() + .zip(&ctl_challenges.challenges) { - builder.verify_proof::( - &recursive_proof, - &verifier_data_target, - &verifier_data[i].common, - ); + inputs.set_target(challenge_target.beta, challenge.beta); + inputs.set_target(challenge_target.gamma, challenge.gamma); } + + inputs.set_target_arr( + self.init_challenger_state_target, + proof_with_metadata.init_challenger_state, + ); + + self.circuit.prove(inputs) } } -/// Returns the verifier data for the recursive Stark circuit. -fn verifier_data_recursive_stark_proof< +/// Represents a circuit which recursively verifies a PLONK proof. +pub(crate) struct PlonkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) circuit: CircuitData, + pub(crate) proof_with_pis_target: ProofWithPublicInputsTarget, +} + +impl PlonkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, +{ + pub(crate) fn prove( + &self, + proof: &ProofWithPublicInputs, + ) -> Result> { + let mut inputs = PartialWitness::new(); + inputs.set_proof_with_pis_target(&self.proof_with_pis_target, proof); + self.circuit.prove(inputs) + } +} + +/// Returns the recursive Stark circuit. +pub(crate) fn recursive_stark_circuit< F: RichField + Extendable, C: GenericConfig, S: Stark, const D: usize, >( table: Table, - stark: S, + stark: &S, degree_bits: usize, cross_table_lookups: &[CrossTableLookup], inner_config: &StarkConfig, circuit_config: &CircuitConfig, -) -> VerifierCircuitData + min_degree_bits: usize, +) -> StarkWrapperCircuit where [(); S::COLUMNS]:, [(); C::Hasher::HASH_SIZE]:, C::Hasher: AlgebraicHasher, { let mut builder = CircuitBuilder::::new(circuit_config.clone()); + let zero_target = builder.zero(); let num_permutation_zs = stark.num_permutation_batches(inner_config); let num_permutation_batch_size = stark.permutation_batch_size(); let num_ctl_zs = CrossTableLookup::num_ctl_zs(cross_table_lookups, table, inner_config.num_challenges); let proof_target = - add_virtual_stark_proof(&mut builder, &stark, inner_config, degree_bits, num_ctl_zs); + add_virtual_stark_proof(&mut builder, stark, inner_config, degree_bits, num_ctl_zs); builder.register_public_inputs( &proof_target .trace_cap @@ -291,8 +279,9 @@ where num_permutation_zs, ); - let challenger_state = std::array::from_fn(|_| builder.add_virtual_public_input()); - let mut challenger = RecursiveChallenger::::from_state(challenger_state); + let init_challenger_state_target = std::array::from_fn(|_| builder.add_virtual_public_input()); + let mut challenger = + RecursiveChallenger::::from_state(init_challenger_state_target); let challenges = proof_target.get_challenges::( &mut builder, &mut challenger, @@ -307,78 +296,39 @@ where verify_stark_proof_with_challenges_circuit::( &mut builder, - &stark, + stark, &proof_target, &challenges, &ctl_vars, inner_config, ); - builder.build_verifier::() + add_common_recursion_gates(&mut builder); + + // Pad to the minimum degree. + while log2_ceil(builder.num_gates()) < min_degree_bits { + builder.add_gate(NoopGate, vec![]); + } + + let circuit = builder.build::(); + StarkWrapperCircuit { + circuit, + stark_proof_target: proof_target, + ctl_challenges_target, + init_challenger_state_target, + zero_target, + } } -/// Returns the recursive Stark circuit verifier data for every Stark in `AllStark`. -pub fn all_verifier_data_recursive_stark_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, ->( - all_stark: &AllStark, - degree_bits: [usize; NUM_TABLES], - inner_config: &StarkConfig, - circuit_config: &CircuitConfig, -) -> [VerifierCircuitData; NUM_TABLES] -where - [(); CpuStark::::COLUMNS]:, - [(); KeccakStark::::COLUMNS]:, - [(); KeccakSpongeStark::::COLUMNS]:, - [(); LogicStark::::COLUMNS]:, - [(); MemoryStark::::COLUMNS]:, - [(); C::Hasher::HASH_SIZE]:, - C::Hasher: AlgebraicHasher, -{ - [ - verifier_data_recursive_stark_proof( - Table::Cpu, - all_stark.cpu_stark, - degree_bits[Table::Cpu as usize], - &all_stark.cross_table_lookups, - inner_config, - circuit_config, - ), - verifier_data_recursive_stark_proof( - Table::Keccak, - all_stark.keccak_stark, - degree_bits[Table::Keccak as usize], - &all_stark.cross_table_lookups, - inner_config, - circuit_config, - ), - verifier_data_recursive_stark_proof( - Table::KeccakSponge, - all_stark.keccak_sponge_stark, - degree_bits[Table::KeccakSponge as usize], - &all_stark.cross_table_lookups, - inner_config, - circuit_config, - ), - verifier_data_recursive_stark_proof( - Table::Logic, - all_stark.logic_stark, - degree_bits[Table::Logic as usize], - &all_stark.cross_table_lookups, - inner_config, - circuit_config, - ), - verifier_data_recursive_stark_proof( - Table::Memory, - all_stark.memory_stark, - degree_bits[Table::Memory as usize], - &all_stark.cross_table_lookups, - inner_config, - circuit_config, - ), - ] +/// Add gates that are sometimes used by recursive circuits, even if it's not actually used by this +/// particular recursive circuit. This is done for uniformity. We sometimes want all recursion +/// circuits to have the same gate set, so that we can do 1-of-n conditional recursion efficiently. +pub(crate) fn add_common_recursion_gates, const D: usize>( + builder: &mut CircuitBuilder, +) { + builder.add_gate_to_gate_set(GateRef::new(ExponentiationGate::new_from_config( + &builder.config, + ))); } /// Recursively verifies an inner proof. @@ -509,87 +459,8 @@ fn eval_l_0_and_l_last_circuit, const D: usize>( ) } -pub fn add_virtual_all_proof, const D: usize>( - builder: &mut CircuitBuilder, - all_stark: &AllStark, - config: &StarkConfig, - degree_bits: &[usize], - nums_ctl_zs: &[usize], -) -> AllProofTarget { - let stark_proofs = [ - add_virtual_stark_proof( - builder, - &all_stark.cpu_stark, - config, - degree_bits[Table::Cpu as usize], - nums_ctl_zs[Table::Cpu as usize], - ), - add_virtual_stark_proof( - builder, - &all_stark.keccak_stark, - config, - degree_bits[Table::Keccak as usize], - nums_ctl_zs[Table::Keccak as usize], - ), - add_virtual_stark_proof( - builder, - &all_stark.keccak_sponge_stark, - config, - degree_bits[Table::KeccakSponge as usize], - nums_ctl_zs[Table::KeccakSponge as usize], - ), - add_virtual_stark_proof( - builder, - &all_stark.logic_stark, - config, - degree_bits[Table::Logic as usize], - nums_ctl_zs[Table::Logic as usize], - ), - add_virtual_stark_proof( - builder, - &all_stark.memory_stark, - config, - degree_bits[Table::Memory as usize], - nums_ctl_zs[Table::Memory as usize], - ), - ]; - - let public_values = add_virtual_public_values(builder); - AllProofTarget { - stark_proofs, - public_values, - } -} - -/// Returns `RecursiveAllProofTargetWithData` where the proofs targets are virtual and the -/// verifier data targets are constants. -pub fn add_virtual_recursive_all_proof, H, C, const D: usize>( - builder: &mut CircuitBuilder, - verifier_data: &[VerifierCircuitData; NUM_TABLES], -) -> RecursiveAllProofTargetWithData -where - H: Hasher>, - C: GenericConfig, -{ - let recursive_proofs = std::array::from_fn(|i| { - let verifier_data = &verifier_data[i]; - builder.add_virtual_proof_with_pis::(&verifier_data.common) - }); - let verifier_data = std::array::from_fn(|i| { - let verifier_data = &verifier_data[i]; - VerifierCircuitTarget { - constants_sigmas_cap: builder - .constant_merkle_cap(&verifier_data.verifier_only.constants_sigmas_cap), - circuit_digest: builder.add_virtual_hash(), - } - }); - RecursiveAllProofTargetWithData { - recursive_proofs, - verifier_data, - } -} - -pub fn add_virtual_public_values, const D: usize>( +#[allow(unused)] // TODO: used later? +pub(crate) fn add_virtual_public_values, const D: usize>( builder: &mut CircuitBuilder, ) -> PublicValuesTarget { let trie_roots_before = add_virtual_trie_roots(builder); @@ -602,7 +473,7 @@ pub fn add_virtual_public_values, const D: usize>( } } -pub fn add_virtual_trie_roots, const D: usize>( +pub(crate) fn add_virtual_trie_roots, const D: usize>( builder: &mut CircuitBuilder, ) -> TrieRootsTarget { let state_root = builder.add_virtual_target_arr(); @@ -615,7 +486,7 @@ pub fn add_virtual_trie_roots, const D: usize>( } } -pub fn add_virtual_block_metadata, const D: usize>( +pub(crate) fn add_virtual_block_metadata, const D: usize>( builder: &mut CircuitBuilder, ) -> BlockMetadataTarget { let block_beneficiary = builder.add_virtual_target_arr(); @@ -636,7 +507,11 @@ pub fn add_virtual_block_metadata, const D: usize>( } } -pub fn add_virtual_stark_proof, S: Stark, const D: usize>( +pub(crate) fn add_virtual_stark_proof< + F: RichField + Extendable, + S: Stark, + const D: usize, +>( builder: &mut CircuitBuilder, stark: &S, config: &StarkConfig, @@ -683,47 +558,7 @@ fn add_virtual_stark_opening_set, S: Stark, c } } -pub fn set_recursive_all_proof_target, W, const D: usize>( - witness: &mut W, - recursive_all_proof_target: &RecursiveAllProofTargetWithData, - all_proof: &RecursiveAllProof, -) where - F: RichField + Extendable, - C::Hasher: AlgebraicHasher, - W: Witness, -{ - for i in 0..NUM_TABLES { - witness.set_proof_with_pis_target( - &recursive_all_proof_target.recursive_proofs[i], - &all_proof.recursive_proofs[i], - ); - } -} -pub fn set_all_proof_target, W, const D: usize>( - witness: &mut W, - all_proof_target: &AllProofTarget, - all_proof: &AllProof, - zero: Target, -) where - F: RichField + Extendable, - C::Hasher: AlgebraicHasher, - W: Witness, -{ - for (pt, p) in all_proof_target - .stark_proofs - .iter() - .zip_eq(&all_proof.stark_proofs) - { - set_stark_proof_target(witness, pt, p, zero); - } - set_public_value_targets( - witness, - &all_proof_target.public_values, - &all_proof.public_values, - ) -} - -pub fn set_stark_proof_target, W, const D: usize>( +pub(crate) fn set_stark_proof_target, W, const D: usize>( witness: &mut W, proof_target: &StarkProofTarget, proof: &StarkProof, @@ -749,7 +584,8 @@ pub fn set_stark_proof_target, W, const D: usize>( set_fri_proof_target(witness, &proof_target.opening_proof, &proof.opening_proof); } -pub fn set_public_value_targets( +#[allow(unused)] // TODO: used later? +pub(crate) fn set_public_value_targets( witness: &mut W, public_values_target: &PublicValuesTarget, public_values: &PublicValues, @@ -774,7 +610,7 @@ pub fn set_public_value_targets( ); } -pub fn set_trie_roots_target( +pub(crate) fn set_trie_roots_target( witness: &mut W, trie_roots_target: &TrieRootsTarget, trie_roots: &TrieRoots, @@ -796,7 +632,7 @@ pub fn set_trie_roots_target( ); } -pub fn set_block_metadata_target( +pub(crate) fn set_block_metadata_target( witness: &mut W, block_metadata_target: &BlockMetadataTarget, block_metadata: &BlockMetadata, @@ -833,220 +669,3 @@ pub fn set_block_metadata_target( F::from_canonical_u64(block_metadata.block_base_fee.as_u64()), ); } - -#[cfg(test)] -pub(crate) mod tests { - use anyhow::Result; - use plonky2::field::extension::Extendable; - use plonky2::hash::hash_types::RichField; - use plonky2::hash::hashing::SPONGE_WIDTH; - use plonky2::iop::challenger::RecursiveChallenger; - use plonky2::iop::witness::{PartialWitness, WitnessWrite}; - use plonky2::plonk::circuit_builder::CircuitBuilder; - use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData}; - use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; - use plonky2::plonk::proof::ProofWithPublicInputs; - - use crate::all_stark::{AllStark, Table}; - use crate::config::StarkConfig; - use crate::cpu::cpu_stark::CpuStark; - use crate::cross_table_lookup::{CrossTableLookup, CtlCheckVarsTarget}; - use crate::keccak::keccak_stark::KeccakStark; - use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; - use crate::logic::LogicStark; - use crate::memory::memory_stark::MemoryStark; - use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet}; - use crate::proof::{AllChallengerState, AllProof, StarkProof}; - use crate::recursive_verifier::{ - add_virtual_stark_proof, set_stark_proof_target, - verify_stark_proof_with_challenges_circuit, RecursiveAllProof, - }; - use crate::stark::Stark; - - /// Recursively verify a Stark proof. - /// Outputs the recursive proof and the associated verifier data. - #[allow(unused)] // TODO: used later? - fn recursively_verify_stark_proof< - F: RichField + Extendable, - C: GenericConfig, - S: Stark, - const D: usize, - >( - table: Table, - stark: S, - proof: &StarkProof, - cross_table_lookups: &[CrossTableLookup], - ctl_challenges: &GrandProductChallengeSet, - challenger_state_before_vals: [F; SPONGE_WIDTH], - inner_config: &StarkConfig, - circuit_config: &CircuitConfig, - ) -> Result<(ProofWithPublicInputs, VerifierCircuitData)> - where - [(); S::COLUMNS]:, - [(); C::Hasher::HASH_SIZE]:, - C::Hasher: AlgebraicHasher, - { - let mut builder = CircuitBuilder::::new(circuit_config.clone()); - let mut pw = PartialWitness::new(); - - let num_permutation_zs = stark.num_permutation_batches(inner_config); - let num_permutation_batch_size = stark.permutation_batch_size(); - let proof_target = add_virtual_stark_proof( - &mut builder, - &stark, - inner_config, - proof.recover_degree_bits(inner_config), - proof.num_ctl_zs(), - ); - set_stark_proof_target(&mut pw, &proof_target, proof, builder.zero()); - builder.register_public_inputs( - &proof_target - .trace_cap - .0 - .iter() - .flat_map(|h| h.elements) - .collect::>(), - ); - - let ctl_challenges_target = GrandProductChallengeSet { - challenges: (0..inner_config.num_challenges) - .map(|_| GrandProductChallenge { - beta: builder.add_virtual_public_input(), - gamma: builder.add_virtual_public_input(), - }) - .collect(), - }; - for i in 0..inner_config.num_challenges { - pw.set_target( - ctl_challenges_target.challenges[i].beta, - ctl_challenges.challenges[i].beta, - ); - pw.set_target( - ctl_challenges_target.challenges[i].gamma, - ctl_challenges.challenges[i].gamma, - ); - } - - let ctl_vars = CtlCheckVarsTarget::from_proof( - table, - &proof_target, - cross_table_lookups, - &ctl_challenges_target, - num_permutation_zs, - ); - - let challenger_state_before = std::array::from_fn(|_| builder.add_virtual_public_input()); - pw.set_target_arr(challenger_state_before, challenger_state_before_vals); - let mut challenger = - RecursiveChallenger::::from_state(challenger_state_before); - let challenges = proof_target.get_challenges::( - &mut builder, - &mut challenger, - num_permutation_zs > 0, - num_permutation_batch_size, - inner_config, - ); - let challenger_state_after = challenger.compact(&mut builder); - builder.register_public_inputs(&challenger_state_after); - - builder.register_public_inputs(&proof_target.openings.ctl_zs_last); - - verify_stark_proof_with_challenges_circuit::( - &mut builder, - &stark, - &proof_target, - &challenges, - &ctl_vars, - inner_config, - ); - - let data = builder.build::(); - Ok((data.prove(pw)?, data.verifier_data())) - } - - /// Recursively verify every Stark proof in an `AllProof`. - #[allow(unused)] // TODO: used later? - pub fn recursively_verify_all_proof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - >( - all_stark: &AllStark, - all_proof: &AllProof, - inner_config: &StarkConfig, - circuit_config: &CircuitConfig, - ) -> Result> - where - [(); CpuStark::::COLUMNS]:, - [(); KeccakStark::::COLUMNS]:, - [(); KeccakSpongeStark::::COLUMNS]:, - [(); LogicStark::::COLUMNS]:, - [(); MemoryStark::::COLUMNS]:, - [(); C::Hasher::HASH_SIZE]:, - C::Hasher: AlgebraicHasher, - { - let AllChallengerState { - states, - ctl_challenges, - } = all_proof.get_challenger_states(all_stark, inner_config); - Ok(RecursiveAllProof { - recursive_proofs: [ - recursively_verify_stark_proof( - Table::Cpu, - all_stark.cpu_stark, - &all_proof.stark_proofs[Table::Cpu as usize], - &all_stark.cross_table_lookups, - &ctl_challenges, - states[0], - inner_config, - circuit_config, - )? - .0, - recursively_verify_stark_proof( - Table::Keccak, - all_stark.keccak_stark, - &all_proof.stark_proofs[Table::Keccak as usize], - &all_stark.cross_table_lookups, - &ctl_challenges, - states[1], - inner_config, - circuit_config, - )? - .0, - recursively_verify_stark_proof( - Table::KeccakSponge, - all_stark.keccak_sponge_stark, - &all_proof.stark_proofs[Table::KeccakSponge as usize], - &all_stark.cross_table_lookups, - &ctl_challenges, - states[2], - inner_config, - circuit_config, - )? - .0, - recursively_verify_stark_proof( - Table::Logic, - all_stark.logic_stark, - &all_proof.stark_proofs[Table::Logic as usize], - &all_stark.cross_table_lookups, - &ctl_challenges, - states[3], - inner_config, - circuit_config, - )? - .0, - recursively_verify_stark_proof( - Table::Memory, - all_stark.memory_stark, - &all_proof.stark_proofs[Table::Memory as usize], - &all_stark.cross_table_lookups, - &ctl_challenges, - states[4], - inner_config, - circuit_config, - )? - .0, - ], - }) - } -} diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index a0329d04..c6d0373e 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -26,7 +26,7 @@ use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::StarkEvaluationVars; pub fn verify_proof, C: GenericConfig, const D: usize>( - all_stark: AllStark, + all_stark: &AllStark, all_proof: AllProof, config: &StarkConfig, ) -> Result<()> @@ -41,7 +41,7 @@ where let AllProofChallenges { stark_challenges, ctl_challenges, - } = all_proof.get_challenges(&all_stark, config); + } = all_proof.get_challenges(all_stark, config); let nums_permutation_zs = all_stark.nums_permutation_zs(config); @@ -56,54 +56,50 @@ where let ctl_vars_per_table = CtlCheckVars::from_proofs( &all_proof.stark_proofs, - &cross_table_lookups, + cross_table_lookups, &ctl_challenges, &nums_permutation_zs, ); verify_stark_proof_with_challenges( cpu_stark, - &all_proof.stark_proofs[Table::Cpu as usize], + &all_proof.stark_proofs[Table::Cpu as usize].proof, &stark_challenges[Table::Cpu as usize], &ctl_vars_per_table[Table::Cpu as usize], config, )?; verify_stark_proof_with_challenges( keccak_stark, - &all_proof.stark_proofs[Table::Keccak as usize], + &all_proof.stark_proofs[Table::Keccak as usize].proof, &stark_challenges[Table::Keccak as usize], &ctl_vars_per_table[Table::Keccak as usize], config, )?; verify_stark_proof_with_challenges( keccak_sponge_stark, - &all_proof.stark_proofs[Table::KeccakSponge as usize], + &all_proof.stark_proofs[Table::KeccakSponge as usize].proof, &stark_challenges[Table::KeccakSponge as usize], &ctl_vars_per_table[Table::KeccakSponge as usize], config, )?; verify_stark_proof_with_challenges( memory_stark, - &all_proof.stark_proofs[Table::Memory as usize], + &all_proof.stark_proofs[Table::Memory as usize].proof, &stark_challenges[Table::Memory as usize], &ctl_vars_per_table[Table::Memory as usize], config, )?; verify_stark_proof_with_challenges( logic_stark, - &all_proof.stark_proofs[Table::Logic as usize], + &all_proof.stark_proofs[Table::Logic as usize].proof, &stark_challenges[Table::Logic as usize], &ctl_vars_per_table[Table::Logic as usize], config, )?; - let degrees_bits = - std::array::from_fn(|i| all_proof.stark_proofs[i].recover_degree_bits(config)); verify_cross_table_lookups::( cross_table_lookups, - all_proof.stark_proofs.map(|p| p.openings.ctl_zs_last), - degrees_bits, - ctl_challenges, + all_proof.stark_proofs.map(|p| p.proof.openings.ctl_zs_last), config, ) } @@ -114,7 +110,7 @@ pub(crate) fn verify_stark_proof_with_challenges< S: Stark, const D: usize, >( - stark: S, + stark: &S, proof: &StarkProof, challenges: &StarkProofChallenges, ctl_vars: &[CtlCheckVars], @@ -125,7 +121,7 @@ where [(); C::Hasher::HASH_SIZE]:, { log::debug!("Checking proof: {}", type_name::()); - validate_proof_shape(&stark, proof, config, ctl_vars.len())?; + validate_proof_shape(stark, proof, config, ctl_vars.len())?; let StarkOpeningSet { local_values, next_values, @@ -160,7 +156,7 @@ where permutation_challenge_sets: challenges.permutation_challenge_sets.clone().unwrap(), }); eval_vanishing_poly::( - &stark, + stark, config, vars, permutation_data, diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index aa7b60b9..2bd9a116 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -8,6 +8,7 @@ use plonky2::plonk::config::PoseidonGoldilocksConfig; use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; +use plonky2_evm::fixed_recursive_verifier::AllRecursiveCircuits; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::BlockMetadata; use plonky2_evm::prover::prove; @@ -19,6 +20,7 @@ type C = PoseidonGoldilocksConfig; /// Execute the empty list of transactions, i.e. a no-op. #[test] +#[ignore] // Too slow to run on CI. fn test_empty_txn_list() -> anyhow::Result<()> { init_logger(); @@ -49,7 +51,10 @@ fn test_empty_txn_list() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + // TODO: This is redundant; prove_root below calls this prove method internally. + // Just keeping it for now because the root proof returned by prove_root doesn't contain public + // values yet, and we want those for the assertions below. + let proof = prove::(&all_stark, &config, inputs.clone(), &mut timing)?; timing.filter(Duration::from_millis(100)).print(); assert_eq!( @@ -77,7 +82,14 @@ fn test_empty_txn_list() -> anyhow::Result<()> { receipts_trie_root ); - verify_proof(all_stark, proof, &config) + verify_proof(&all_stark, proof, &config)?; + + let all_circuits = AllRecursiveCircuits::::new(&all_stark, 9..19, &config); + let root_proof = all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?; + all_circuits.verify_root(root_proof.clone())?; + + let agg_proof = all_circuits.prove_aggregation(false, &root_proof, false, &root_proof)?; + all_circuits.verify_aggregation(&agg_proof) } fn init_logger() { diff --git a/evm/tests/transfer_to_new_addr.rs b/evm/tests/transfer_to_new_addr.rs index e4fe8eb4..4a5e69e3 100644 --- a/evm/tests/transfer_to_new_addr.rs +++ b/evm/tests/transfer_to_new_addr.rs @@ -104,7 +104,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { expected_state_trie_after.calc_hash() ); - verify_proof(all_stark, proof, &config) + verify_proof(&all_stark, proof, &config) } fn eth_to_wei(eth: U256) -> U256 { diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index 0ad12e76..bf6b0e6b 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -14,9 +14,7 @@ use plonky2::gates::noop::NoopGate; use plonky2::hash::hash_types::RichField; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{ - CircuitConfig, CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, -}; +use plonky2::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierOnlyCircuitData}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; use plonky2::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use plonky2::plonk::prover::prove; @@ -107,10 +105,7 @@ where let mut builder = CircuitBuilder::::new(config.clone()); let pt = builder.add_virtual_proof_with_pis::(inner_cd); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = builder.add_virtual_verifier_data(inner_cd.config.fri_config.cap_height); builder.verify_proof::(&pt, &inner_data, inner_cd); builder.print_gate_counts(0); diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index cc114d98..90890134 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -3,12 +3,12 @@ use alloc::vec::Vec; use itertools::Itertools; use maybe_rayon::*; +use plonky2_field::types::Field; use crate::field::extension::Extendable; use crate::field::fft::FftRootTable; use crate::field::packed::PackedField; use crate::field::polynomial::{PolynomialCoeffs, PolynomialValues}; -use crate::field::types::Field; use crate::fri::proof::FriProof; use crate::fri::prover::fri_proof; use crate::fri::structure::{FriBatchInfo, FriInstanceInfo}; @@ -189,13 +189,11 @@ impl, C: GenericConfig, const D: usize> &format!("reduce batch of {} polynomials", polynomials.len()), alpha.reduce_polys_base(polys_coeff) ); - let quotient = composition_poly.divide_by_linear(*point); + let mut quotient = composition_poly.divide_by_linear(*point); + quotient.coeffs.push(F::Extension::ZERO); // pad back to power of two alpha.shift_poly(&mut final_poly); final_poly += quotient; } - // Multiply the final polynomial by `X`, so that `final_poly` has the maximum degree for - // which the LDT will pass. See github.com/mir-protocol/plonky2/pull/436 for details. - final_poly.coeffs.insert(0, F::Extension::ZERO); let lde_final_poly = final_poly.lde(fri_params.config.rate_bits); let lde_final_values = timed!( diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs index 822dd559..e7e48f82 100644 --- a/plonky2/src/fri/recursive_verifier.rs +++ b/plonky2/src/fri/recursive_verifier.rs @@ -262,9 +262,7 @@ impl, const D: usize> CircuitBuilder { sum = self.div_add_extension(numerator, denominator, sum); } - // Multiply the final polynomial by `X`, so that `final_poly` has the maximum degree for - // which the LDT will pass. See github.com/mir-protocol/plonky2/pull/436 for details. - self.mul_extension(sum, subgroup_x) + sum } fn fri_verifier_query_round>( diff --git a/plonky2/src/fri/verifier.rs b/plonky2/src/fri/verifier.rs index 6644b971..f860ba30 100644 --- a/plonky2/src/fri/verifier.rs +++ b/plonky2/src/fri/verifier.rs @@ -157,9 +157,7 @@ pub(crate) fn fri_combine_initial< sum += numerator / denominator; } - // Multiply the final polynomial by `X`, so that `final_poly` has the maximum degree for - // which the LDT will pass. See github.com/mir-protocol/plonky2/pull/436 for details. - sum * subgroup_x + sum } fn fri_verifier_query_round< diff --git a/plonky2/src/gadgets/random_access.rs b/plonky2/src/gadgets/random_access.rs index d3a3ff1b..73e3de8c 100644 --- a/plonky2/src/gadgets/random_access.rs +++ b/plonky2/src/gadgets/random_access.rs @@ -2,15 +2,15 @@ use alloc::vec::Vec; use crate::field::extension::Extendable; use crate::gates::random_access::RandomAccessGate; -use crate::hash::hash_types::RichField; +use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::VerifierCircuitTarget; use crate::util::log2_strict; impl, const D: usize> CircuitBuilder { - /// Checks that a `Target` matches a vector at a non-deterministic index. - /// Note: `access_index` is not range-checked. + /// Checks that a `Target` matches a vector at a particular index. pub fn random_access(&mut self, access_index: Target, v: Vec) -> Target { let vec_size = v.len(); let bits = log2_strict(vec_size); @@ -38,18 +38,64 @@ impl, const D: usize> CircuitBuilder { claimed_element } - /// Checks that an `ExtensionTarget` matches a vector at a non-deterministic index. - /// Note: `access_index` is not range-checked. + /// Like `random_access`, but with `ExtensionTarget`s rather than simple `Target`s. pub fn random_access_extension( &mut self, access_index: Target, v: Vec>, ) -> ExtensionTarget { - let v: Vec<_> = (0..D) + let selected: Vec<_> = (0..D) .map(|i| self.random_access(access_index, v.iter().map(|et| et.0[i]).collect())) .collect(); - ExtensionTarget(v.try_into().unwrap()) + ExtensionTarget(selected.try_into().unwrap()) + } + + /// Like `random_access`, but with `HashOutTarget`s rather than simple `Target`s. + pub fn random_access_hash( + &mut self, + access_index: Target, + v: Vec, + ) -> HashOutTarget { + let selected = std::array::from_fn(|i| { + self.random_access( + access_index, + v.iter().map(|hash| hash.elements[i]).collect(), + ) + }); + selected.into() + } + + /// Like `random_access`, but with `MerkleCapTarget`s rather than simple `Target`s. + pub fn random_access_merkle_cap( + &mut self, + access_index: Target, + v: Vec, + ) -> MerkleCapTarget { + let cap_size = v[0].0.len(); + assert!(v.iter().all(|cap| cap.0.len() == cap_size)); + + let selected = (0..cap_size) + .map(|i| self.random_access_hash(access_index, v.iter().map(|cap| cap.0[i]).collect())) + .collect(); + MerkleCapTarget(selected) + } + + /// Like `random_access`, but with `VerifierCircuitTarget`s rather than simple `Target`s. + pub fn random_access_verifier_data( + &mut self, + access_index: Target, + v: Vec, + ) -> VerifierCircuitTarget { + let constants_sigmas_caps = v.iter().map(|vk| vk.constants_sigmas_cap.clone()).collect(); + let circuit_digests = v.iter().map(|vk| vk.circuit_digest).collect(); + let constants_sigmas_cap = + self.random_access_merkle_cap(access_index, constants_sigmas_caps); + let circuit_digest = self.random_access_hash(access_index, circuit_digests); + VerifierCircuitTarget { + constants_sigmas_cap, + circuit_digest, + } } } diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index cd696e55..3726d847 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -11,6 +11,7 @@ use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::VerifierCircuitTarget; use crate::plonk::config::{AlgebraicHasher, Hasher}; #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -152,6 +153,11 @@ impl, const D: usize> CircuitBuilder { self.connect_hashes(*h0, *h1); } } + + pub fn connect_verifier_data(&mut self, x: &VerifierCircuitTarget, y: &VerifierCircuitTarget) { + self.connect_merkle_caps(&x.constants_sigmas_cap, &y.constants_sigmas_cap); + self.connect_hashes(x.circuit_digest, y.circuit_digest); + } } #[cfg(test)] diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 86871701..7a9cd1f3 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -14,6 +14,7 @@ use crate::util::log2_strict; /// It can be used in place of the root to verify Merkle paths, which are `h` elements shorter. #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(bound = "")] +// TODO: Change H to GenericHashOut, since this only cares about the hash, not the hasher. pub struct MerkleCap>(pub Vec); impl> MerkleCap { diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index e9ade3d2..14213d0d 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -14,7 +14,7 @@ use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_data::{VerifierCircuitTarget, VerifierOnlyCircuitData}; -use crate::plonk::config::{AlgebraicHasher, GenericConfig}; +use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use crate::plonk::proof::{Proof, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget}; pub trait WitnessWrite { @@ -224,6 +224,19 @@ pub trait Witness: WitnessWrite { } } + fn get_merkle_cap_target>(&self, cap_target: MerkleCapTarget) -> MerkleCap + where + F: RichField, + H: AlgebraicHasher, + { + let cap = cap_target + .0 + .iter() + .map(|hash_target| self.get_hash_target(*hash_target)) + .collect(); + MerkleCap(cap) + } + fn get_wire(&self, wire: Wire) -> F { self.get_target(Target::Wire(wire)) } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 9017a059..21305f4d 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -245,6 +245,13 @@ impl, const D: usize> CircuitBuilder { t } + pub fn add_virtual_verifier_data(&mut self, cap_height: usize) -> VerifierCircuitTarget { + VerifierCircuitTarget { + constants_sigmas_cap: self.add_virtual_cap(cap_height), + circuit_digest: self.add_virtual_hash(), + } + } + /// Add a virtual verifier data, register it as a public input and set it to `self.verifier_data_public_input`. /// WARNING: Do not register any public input after calling this! TODO: relax this pub fn add_verifier_data_public_inputs(&mut self) -> VerifierCircuitTarget { @@ -253,10 +260,7 @@ impl, const D: usize> CircuitBuilder { "add_verifier_data_public_inputs only needs to be called once" ); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), - circuit_digest: self.add_virtual_hash(), - }; + let verifier_data = self.add_virtual_verifier_data(self.config.fri_config.cap_height); // The verifier data are public inputs. self.register_public_inputs(&verifier_data.circuit_digest.elements); for i in 0..self.config.fri_config.num_cap_elements() { @@ -784,13 +788,13 @@ impl, const D: usize> CircuitBuilder { self.add_simple_generator(const_gen); } - info!( + debug!( "Degree before blinding & padding: {}", self.gate_instances.len() ); self.blind_and_pad(); let degree = self.gate_instances.len(); - info!("Degree after blinding & padding: {}", degree); + debug!("Degree after blinding & padding: {}", degree); let degree_bits = log2_strict(degree); let fri_params = self.fri_params(degree_bits); assert!( diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 6df986fa..c57e206d 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -149,15 +149,15 @@ impl, C: GenericConfig, const D: usize> proof.decompress(&self.verifier_only.circuit_digest, &self.common) } - pub fn verifier_data(self) -> VerifierCircuitData { + pub fn verifier_data(&self) -> VerifierCircuitData { let CircuitData { verifier_only, common, .. } = self; VerifierCircuitData { - verifier_only, - common, + verifier_only: verifier_only.clone(), + common: common.clone(), } } @@ -258,7 +258,7 @@ pub struct ProverOnlyCircuitData< } /// Circuit data required by the verifier, but not the prover. -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct VerifierOnlyCircuitData, const D: usize> { /// A commitment to each constant polynomial and each permutation polynomial. pub constants_sigmas_cap: MerkleCap, @@ -289,7 +289,7 @@ pub struct CommonCircuitData, const D: usize> { /// The number of constant wires. pub(crate) num_constants: usize, - pub(crate) num_public_inputs: usize, + pub num_public_inputs: usize, /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, diff --git a/plonky2/src/recursion/conditional_recursive_verifier.rs b/plonky2/src/recursion/conditional_recursive_verifier.rs index ace47cab..2596e1f0 100644 --- a/plonky2/src/recursion/conditional_recursive_verifier.rs +++ b/plonky2/src/recursion/conditional_recursive_verifier.rs @@ -378,15 +378,11 @@ mod tests { pw.set_proof_with_pis_target(&pt, &proof); let dummy_pt = builder.add_virtual_proof_with_pis::(&data.common); pw.set_proof_with_pis_target::(&dummy_pt, &dummy_proof); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); pw.set_verifier_data_target(&inner_data, &data.verifier_only); - let dummy_inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let dummy_inner_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); pw.set_verifier_data_target(&dummy_inner_data, &dummy_data.verifier_only); let b = builder.constant_bool(F::rand().0 % 2 == 0); builder.conditionally_verify_proof::( diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs index a12c31d4..656ff3b9 100644 --- a/plonky2/src/recursion/cyclic_recursion.rs +++ b/plonky2/src/recursion/cyclic_recursion.rs @@ -188,7 +188,7 @@ mod tests { use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation}; use crate::iop::witness::{PartialWitness, WitnessWrite}; use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget}; + use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; use crate::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; use crate::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; use crate::recursion::dummy_circuit::cyclic_base_proof; @@ -208,20 +208,16 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let proof = builder.add_virtual_proof_with_pis::(&data.common); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let verifier_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); builder.verify_proof::(&proof, &verifier_data, &data.common); let data = builder.build::(); let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let proof = builder.add_virtual_proof_with_pis::(&data.common); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let verifier_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); builder.verify_proof::(&proof, &verifier_data, &data.common); while builder.num_gates() < 1 << 12 { builder.add_gate(NoopGate, vec![]); diff --git a/plonky2/src/recursion/dummy_circuit.rs b/plonky2/src/recursion/dummy_circuit.rs index 34b20c71..38f51aea 100644 --- a/plonky2/src/recursion/dummy_circuit.rs +++ b/plonky2/src/recursion/dummy_circuit.rs @@ -113,11 +113,8 @@ impl, const D: usize> CircuitBuilder { let dummy_circuit = dummy_circuit::(common_data); let dummy_proof_with_pis = dummy_proof(&dummy_circuit, HashMap::new())?; let dummy_proof_with_pis_target = self.add_virtual_proof_with_pis::(common_data); - - let dummy_verifier_data_target = VerifierCircuitTarget { - constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), - circuit_digest: self.add_virtual_hash(), - }; + let dummy_verifier_data_target = + self.add_virtual_verifier_data(self.config.fri_config.cap_height); self.add_simple_generator(DummyProofGenerator { proof_with_pis_target: dummy_proof_with_pis_target.clone(), diff --git a/plonky2/src/recursion/recursive_verifier.rs b/plonky2/src/recursion/recursive_verifier.rs index 15943a87..9aafb1f5 100644 --- a/plonky2/src/recursion/recursive_verifier.rs +++ b/plonky2/src/recursion/recursive_verifier.rs @@ -366,10 +366,7 @@ mod tests { let pt = builder.add_virtual_proof_with_pis::(&inner_cd); pw.set_proof_with_pis_target(&pt, &inner_proof); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = builder.add_virtual_verifier_data(inner_cd.config.fri_config.cap_height); pw.set_cap_target( &inner_data.constants_sigmas_cap, &inner_vd.constants_sigmas_cap,