From 595e751ac18efafa7f2d3352f6a41297090a216c Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 27 Dec 2022 18:15:18 -0800 Subject: [PATCH] Shrink STARK proofs to a constant degree The goal here is to end up with a single "root" circuit representing any EVM proof. I.e. it must verify each STARK, but be general enough to work with any combination of STARK sizes (within some range of sizes that we chose to support). This root circuit can then be plugged into our aggregation circuit. In particular, for each STARK, and for each initial `degree_bits` (within a range that we choose to support), this adds a "shrinking chain" of circuits. Such a chain shrinks a STARK proof from that initial `degree_bits` down to a constant, `THRESHOLD_DEGREE_BITS`. The root circuit then combines these shrunk-to-constant proofs for each table. It's similar to `RecursiveAllProof::verify_circuit`; I adapted the code from there and I think we can remove it after. The main difference is that now instead of having one verification key per STARK, we have several possible VKs, one per initial `degree_bits`. We bake the list of possible VKs into the root circuit, and have the prover indicate the index of the VK they're actually using. This also partially removes the default feature of CTLs. So far we've used filters instead of defaults. Until now it was easy to keep supporting defaults just in case, but here maintaining support would require some more work. E.g. we couldn't use `exp_u64` any more, since the size delta is now dynamic, it can't be hardcoded. If there are no concerns, I'll fully remove the feature after. --- evm/src/cross_table_lookup.rs | 60 +-- evm/src/fixed_recursive_verifier.rs | 410 ++++++++++++++++++ evm/src/get_challenges.rs | 8 +- evm/src/lib.rs | 1 + evm/src/permutation.rs | 13 +- evm/src/proof.rs | 18 +- evm/src/prover.rs | 23 +- evm/src/recursive_verifier.rs | 223 ++++++++-- evm/src/verifier.rs | 26 +- evm/tests/empty_txn_list.rs | 12 +- evm/tests/transfer_to_new_addr.rs | 2 +- plonky2/examples/bench_recursion.rs | 9 +- plonky2/src/gadgets/random_access.rs | 60 ++- plonky2/src/hash/merkle_proofs.rs | 6 + plonky2/src/hash/merkle_tree.rs | 1 + plonky2/src/iop/witness.rs | 15 +- plonky2/src/plonk/circuit_builder.rs | 16 +- plonky2/src/plonk/circuit_data.rs | 10 +- .../conditional_recursive_verifier.rs | 12 +- plonky2/src/recursion/cyclic_recursion.rs | 14 +- plonky2/src/recursion/dummy_circuit.rs | 7 +- plonky2/src/recursion/recursive_verifier.rs | 5 +- 22 files changed, 764 insertions(+), 187 deletions(-) create mode 100644 evm/src/fixed_recursive_verifier.rs diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 4930321a..2d55c6db 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -8,7 +8,6 @@ use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; -use plonky2::iop::challenger::Challenger; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -17,10 +16,8 @@ use plonky2::plonk::config::GenericConfig; use crate::all_stark::{Table, NUM_TABLES}; use crate::config::StarkConfig; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::permutation::{ - get_grand_product_challenge_set, GrandProductChallenge, GrandProductChallengeSet, -}; -use crate::proof::{StarkProof, StarkProofTarget}; +use crate::permutation::{GrandProductChallenge, GrandProductChallengeSet}; +use crate::proof::{StarkProofTarget, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -163,6 +160,7 @@ pub struct CrossTableLookup { pub(crate) looking_tables: Vec>, pub(crate) looked_table: TableWithColumns, /// Default value if filters are not used. + // TODO: Remove? Ended up not using it. default: Option>, } @@ -234,13 +232,11 @@ impl CtlData { } } -pub fn cross_table_lookup_data, const D: usize>( - config: &StarkConfig, +pub(crate) fn cross_table_lookup_data, const D: usize>( trace_poly_values: &[Vec>; NUM_TABLES], cross_table_lookups: &[CrossTableLookup], - challenger: &mut Challenger, + ctl_challenges: &GrandProductChallengeSet, ) -> [CtlData; NUM_TABLES] { - let challenges = get_grand_product_challenge_set(challenger, config.num_challenges); let mut ctl_data_per_table = [0; NUM_TABLES].map(|_| CtlData::default()); for CrossTableLookup { looking_tables, @@ -249,7 +245,7 @@ pub fn cross_table_lookup_data, const D } in cross_table_lookups { log::debug!("Processing CTL for {:?}", looked_table.table); - for &challenge in &challenges.challenges { + for &challenge in &ctl_challenges.challenges { let zs_looking = looking_tables.iter().map(|table| { partial_products( &trace_poly_values[table.table as usize], @@ -358,7 +354,7 @@ impl<'a, F: RichField + Extendable, const D: usize> CtlCheckVars<'a, F, F::Extension, F::Extension, D> { pub(crate) fn from_proofs>( - proofs: &[StarkProof; NUM_TABLES], + proofs: &[StarkProofWithMetadata; NUM_TABLES], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, num_permutation_zs: &[usize; NUM_TABLES], @@ -367,7 +363,7 @@ impl<'a, F: RichField + Extendable, const D: usize> .iter() .zip(num_permutation_zs) .map(|(p, &num_perms)| { - let openings = &p.openings; + let openings = &p.proof.openings; let ctl_zs = openings.permutation_ctl_zs.iter().skip(num_perms); let ctl_zs_next = openings.permutation_ctl_zs_next.iter().skip(num_perms); ctl_zs.zip(ctl_zs_next) @@ -582,7 +578,7 @@ pub(crate) fn verify_cross_table_lookups< C: GenericConfig, const D: usize, >( - cross_table_lookups: Vec>, + cross_table_lookups: &[CrossTableLookup], ctl_zs_lasts: [Vec; NUM_TABLES], degrees_bits: [usize; NUM_TABLES], challenges: GrandProductChallengeSet, @@ -597,7 +593,7 @@ pub(crate) fn verify_cross_table_lookups< default, .. }, - ) in cross_table_lookups.into_iter().enumerate() + ) in cross_table_lookups.iter().enumerate() { for _ in 0..config.num_challenges { let looking_degrees_sum = looking_tables @@ -635,47 +631,23 @@ pub(crate) fn verify_cross_table_lookups_circuit< builder: &mut CircuitBuilder, cross_table_lookups: Vec>, ctl_zs_lasts: [Vec; NUM_TABLES], - degrees_bits: [usize; NUM_TABLES], - challenges: GrandProductChallengeSet, inner_config: &StarkConfig, ) { let mut ctl_zs_openings = ctl_zs_lasts.iter().map(|v| v.iter()).collect::>(); - for ( - i, - CrossTableLookup { - looking_tables, - looked_table, - default, - .. - }, - ) in cross_table_lookups.into_iter().enumerate() + for CrossTableLookup { + looking_tables, + looked_table, + .. + } in cross_table_lookups.into_iter() { for _ in 0..inner_config.num_challenges { - let looking_degrees_sum = looking_tables - .iter() - .map(|table| 1 << degrees_bits[table.table as usize]) - .sum::(); - let looked_degree = 1 << degrees_bits[looked_table.table as usize]; let looking_zs_prod = builder.mul_many( looking_tables .iter() .map(|table| *ctl_zs_openings[table.table as usize].next().unwrap()), ); let looked_z = *ctl_zs_openings[looked_table.table as usize].next().unwrap(); - let challenge = challenges.challenges[i % inner_config.num_challenges]; - if let Some(default) = default.as_ref() { - let default = default - .iter() - .map(|&x| builder.constant(x)) - .collect::>(); - let combined_default = challenge.combine_base_circuit(builder, &default); - - let pad = builder.exp_u64(combined_default, looking_degrees_sum - looked_degree); - let padded_looked_z = builder.mul(looked_z, pad); - builder.connect(looking_zs_prod, padded_looked_z); - } else { - builder.connect(looking_zs_prod, looked_z); - } + builder.connect(looking_zs_prod, looked_z); } } debug_assert!(ctl_zs_openings.iter_mut().all(|iter| iter.next().is_none())); diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs new file mode 100644 index 00000000..8fa19b3b --- /dev/null +++ b/evm/src/fixed_recursive_verifier.rs @@ -0,0 +1,410 @@ +use std::collections::BTreeMap; +use std::ops::Range; + +use itertools::Itertools; +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::hashing::SPONGE_WIDTH; +use plonky2::iop::challenger::RecursiveChallenger; +use plonky2::iop::target::Target; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{CircuitConfig, CircuitData}; +use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; +use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; +use plonky2::util::timing::TimingTree; + +use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; +use crate::config::StarkConfig; +use crate::cpu::cpu_stark::CpuStark; +use crate::cross_table_lookup::{verify_cross_table_lookups_circuit, CrossTableLookup}; +use crate::generation::GenerationInputs; +use crate::keccak::keccak_stark::KeccakStark; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; +use crate::logic::LogicStark; +use crate::memory::memory_stark::MemoryStark; +use crate::permutation::{get_grand_product_challenge_set_target, GrandProductChallengeSet}; +use crate::proof::StarkProofWithMetadata; +use crate::prover::prove; +use crate::recursive_verifier::{ + add_common_recursion_gates, recursive_stark_circuit, PlonkWrapperCircuit, PublicInputs, + StarkWrapperCircuit, +}; +use crate::stark::Stark; + +/// The recursion threshold. We end a chain of recursive proofs once we reach this size. +const THRESHOLD_DEGREE_BITS: usize = 13; + +/// Contains all recursive circuits used in the system. For each STARK and each initial +/// `degree_bits`, this contains a chain of recursive circuits for shrinking that STARK from +/// `degree_bits` to a constant `THRESHOLD_DEGREE_BITS`. It also contains a special root circuit +/// for combining each STARK's shrunk wrapper proof into a single proof. +pub struct AllRecursiveCircuits +where + F: RichField + Extendable, + C: GenericConfig, +{ + /// The root circuit, which aggregates the (shrunk) per-table recursive proofs. + pub root: RootCircuitData, + /// Holds chains of circuits for each table and for each initial `degree_bits`. + by_table: [RecursiveCircuitsForTable; NUM_TABLES], +} + +/// Data for the special root circuit, which is used to combine each STARK's shrunk wrapper proof +/// into a single proof. +pub struct RootCircuitData +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub circuit: CircuitData, + proof_with_pis: [ProofWithPublicInputsTarget; NUM_TABLES], + /// For each table, various inner circuits may be used depending on the initial table size. + /// This target holds the index of the circuit (within `final_circuits()`) that was used. + index_verifier_data: [Target; NUM_TABLES], +} + +impl AllRecursiveCircuits +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, + [(); CpuStark::::COLUMNS]:, + [(); KeccakStark::::COLUMNS]:, + [(); KeccakSpongeStark::::COLUMNS]:, + [(); LogicStark::::COLUMNS]:, + [(); MemoryStark::::COLUMNS]:, +{ + /// Preprocess all recursive circuits used by the system. + pub fn new( + all_stark: &AllStark, + degree_bits_range: Range, + stark_config: &StarkConfig, + ) -> Self { + let cpu = RecursiveCircuitsForTable::new( + Table::Cpu, + &all_stark.cpu_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let keccak = RecursiveCircuitsForTable::new( + Table::Keccak, + &all_stark.keccak_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let keccak_sponge = RecursiveCircuitsForTable::new( + Table::KeccakSponge, + &all_stark.keccak_sponge_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let logic = RecursiveCircuitsForTable::new( + Table::Logic, + &all_stark.logic_stark, + degree_bits_range.clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let memory = RecursiveCircuitsForTable::new( + Table::Memory, + &all_stark.memory_stark, + degree_bits_range, + &all_stark.cross_table_lookups, + stark_config, + ); + + let by_table = [cpu, keccak, keccak_sponge, logic, memory]; + let root = Self::create_root_circuit(&by_table, stark_config); + Self { root, by_table } + } + + fn create_root_circuit( + by_table: &[RecursiveCircuitsForTable; NUM_TABLES], + stark_config: &StarkConfig, + ) -> RootCircuitData { + let inner_common_data: [_; NUM_TABLES] = + std::array::from_fn(|i| &by_table[i].final_circuits()[0].common); + + let mut builder = CircuitBuilder::new(CircuitConfig::standard_recursion_config()); + let recursive_proofs = + std::array::from_fn(|i| builder.add_virtual_proof_with_pis::(inner_common_data[i])); + let pis: [_; NUM_TABLES] = std::array::from_fn(|i| { + PublicInputs::from_vec(&recursive_proofs[i].public_inputs, stark_config) + }); + let index_verifier_data = std::array::from_fn(|_i| builder.add_virtual_target()); + + let mut challenger = RecursiveChallenger::::new(&mut builder); + for pi in &pis { + for h in &pi.trace_cap { + challenger.observe_elements(h); + } + } + let ctl_challenges = get_grand_product_challenge_set_target( + &mut builder, + &mut challenger, + stark_config.num_challenges, + ); + // Check that the correct CTL challenges are used in every proof. + for pi in &pis { + for i in 0..stark_config.num_challenges { + builder.connect( + ctl_challenges.challenges[i].beta, + pi.ctl_challenges.challenges[i].beta, + ); + builder.connect( + ctl_challenges.challenges[i].gamma, + pi.ctl_challenges.challenges[i].gamma, + ); + } + } + + let state = challenger.compact(&mut builder); + for k in 0..SPONGE_WIDTH { + builder.connect(state[k], pis[0].challenger_state_before[k]); + } + // Check that the challenger state is consistent between proofs. + for i in 1..NUM_TABLES { + for k in 0..SPONGE_WIDTH { + builder.connect( + pis[i].challenger_state_before[k], + pis[i - 1].challenger_state_after[k], + ); + } + } + + // Verify the CTL checks. + verify_cross_table_lookups_circuit::( + &mut builder, + all_cross_table_lookups(), + pis.map(|p| p.ctl_zs_last), + stark_config, + ); + + for (i, table_circuits) in by_table.iter().enumerate() { + let final_circuits = table_circuits.final_circuits(); + for final_circuit in &final_circuits { + assert_eq!( + &final_circuit.common, inner_common_data[i], + "common_data mismatch" + ); + } + let mut possible_vks = final_circuits + .into_iter() + .map(|c| builder.constant_verifier_data(&c.verifier_only)) + .collect_vec(); + // random_access_verifier_data expects a vector whose length is a power of two. + // To satisfy this, we will just add some duplicates of the first VK. + while !possible_vks.len().is_power_of_two() { + possible_vks.push(possible_vks[0].clone()); + } + let inner_verifier_data = + builder.random_access_verifier_data(index_verifier_data[i], possible_vks); + + builder.verify_proof::( + &recursive_proofs[i], + &inner_verifier_data, + inner_common_data[i], + ); + } + + RootCircuitData { + circuit: builder.build(), + proof_with_pis: recursive_proofs, + index_verifier_data, + } + } + + /// Create a proof for each STARK, then combine them, eventually culminating in a root proof. + pub fn prove_root( + &self, + all_stark: &AllStark, + config: &StarkConfig, + generation_inputs: GenerationInputs, + timing: &mut TimingTree, + ) -> anyhow::Result> { + let all_proof = prove::(all_stark, config, generation_inputs, timing)?; + let mut root_inputs = PartialWitness::new(); + for table in 0..NUM_TABLES { + let stark_proof = &all_proof.stark_proofs[table]; + let original_degree_bits = stark_proof.proof.recover_degree_bits(config); + let table_circuits = &self.by_table[table]; + let shrunk_proof = table_circuits.by_stark_size[&original_degree_bits] + .shrink(stark_proof, &all_proof.ctl_challenges)?; + let index_verifier_data = table_circuits + .by_stark_size + .keys() + .position(|&size| size == original_degree_bits) + .unwrap(); + root_inputs.set_target( + self.root.index_verifier_data[table], + F::from_canonical_usize(index_verifier_data), + ); + root_inputs.set_proof_with_pis_target(&self.root.proof_with_pis[table], &shrunk_proof); + } + self.root.circuit.prove(root_inputs) + } +} + +struct RecursiveCircuitsForTable +where + F: RichField + Extendable, + C: GenericConfig, +{ + /// A map from `log_2(height)` to a chain of shrinking recursion circuits starting at that + /// height. + by_stark_size: BTreeMap>, +} + +impl RecursiveCircuitsForTable +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, +{ + fn new>( + table: Table, + stark: &S, + degree_bits_range: Range, + all_ctls: &[CrossTableLookup], + stark_config: &StarkConfig, + ) -> Self + where + [(); S::COLUMNS]:, + { + let by_stark_size = degree_bits_range + .map(|degree_bits| { + ( + degree_bits, + RecursiveCircuitsForTableSize::new::( + table, + stark, + degree_bits, + all_ctls, + stark_config, + ), + ) + }) + .collect(); + Self { by_stark_size } + } + + /// For each initial `degree_bits`, get the final circuit at the end of that shrinking chain. + /// Each of these final circuits should have degree `THRESHOLD_DEGREE_BITS`. + fn final_circuits(&self) -> Vec<&CircuitData> { + self.by_stark_size + .values() + .map(|chain| { + chain + .shrinking_wrappers + .last() + .map(|wrapper| &wrapper.circuit) + .unwrap_or(&chain.initial_wrapper.circuit) + }) + .collect() + } +} + +/// A chain of shrinking wrapper circuits, ending with a final circuit with `degree_bits` +/// `THRESHOLD_DEGREE_BITS`. +struct RecursiveCircuitsForTableSize +where + F: RichField + Extendable, + C: GenericConfig, +{ + initial_wrapper: StarkWrapperCircuit, + shrinking_wrappers: Vec>, +} + +impl RecursiveCircuitsForTableSize +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, +{ + fn new>( + table: Table, + stark: &S, + degree_bits: usize, + all_ctls: &[CrossTableLookup], + stark_config: &StarkConfig, + ) -> Self + where + [(); S::COLUMNS]:, + { + let initial_wrapper = recursive_stark_circuit( + table, + stark, + degree_bits, + all_ctls, + stark_config, + &shrinking_config(), + THRESHOLD_DEGREE_BITS, + ); + let mut shrinking_wrappers = vec![]; + + // Shrinking recursion loop. + loop { + let last = shrinking_wrappers + .last() + .map(|wrapper: &PlonkWrapperCircuit| &wrapper.circuit) + .unwrap_or(&initial_wrapper.circuit); + let last_degree_bits = last.common.degree_bits(); + assert!(last_degree_bits >= THRESHOLD_DEGREE_BITS); + if last_degree_bits == THRESHOLD_DEGREE_BITS { + break; + } + + let mut builder = CircuitBuilder::new(shrinking_config()); + let proof_with_pis_target = builder.add_virtual_proof_with_pis::(&last.common); + let last_vk = builder.constant_verifier_data(&last.verifier_only); + builder.verify_proof::(&proof_with_pis_target, &last_vk, &last.common); + builder.register_public_inputs(&proof_with_pis_target.public_inputs); // carry PIs forward + add_common_recursion_gates(&mut builder); + let circuit = builder.build(); + + assert!( + circuit.common.degree_bits() < last_degree_bits, + "Couldn't shrink to expected recursion threshold of 2^{}; stalled at 2^{}", + THRESHOLD_DEGREE_BITS, + circuit.common.degree_bits() + ); + shrinking_wrappers.push(PlonkWrapperCircuit { + circuit, + proof_with_pis_target, + }); + } + + Self { + initial_wrapper, + shrinking_wrappers, + } + } + + fn shrink( + &self, + stark_proof_with_metadata: &StarkProofWithMetadata, + ctl_challenges: &GrandProductChallengeSet, + ) -> anyhow::Result> { + let mut proof = self + .initial_wrapper + .prove(stark_proof_with_metadata, ctl_challenges)?; + for wrapper_circuit in &self.shrinking_wrappers { + proof = wrapper_circuit.prove(&proof)?; + } + Ok(proof) + } +} + +fn shrinking_config() -> CircuitConfig { + CircuitConfig { + num_routed_wires: 40, + ..CircuitConfig::standard_recursion_config() + } +} diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index ede7c466..f368b7c2 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -23,7 +23,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { - challenger.observe_cap(&proof.trace_cap); + challenger.observe_cap(&proof.proof.trace_cap); } // TODO: Observe public values. @@ -37,7 +37,7 @@ impl, C: GenericConfig, const D: usize> A AllProofChallenges { stark_challenges: std::array::from_fn(|i| { challenger.compact(); - self.stark_proofs[i].get_challenges( + self.stark_proofs[i].proof.get_challenges( &mut challenger, num_permutation_zs[i] > 0, num_permutation_batch_sizes[i], @@ -57,7 +57,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger = Challenger::::new(); for proof in &self.stark_proofs { - challenger.observe_cap(&proof.trace_cap); + challenger.observe_cap(&proof.proof.trace_cap); } // TODO: Observe public values. @@ -70,7 +70,7 @@ impl, C: GenericConfig, const D: usize> A let mut challenger_states = vec![challenger.compact()]; for i in 0..NUM_TABLES { - self.stark_proofs[i].get_challenges( + self.stark_proofs[i].proof.get_challenges( &mut challenger, num_permutation_zs[i] > 0, num_permutation_batch_sizes[i], diff --git a/evm/src/lib.rs b/evm/src/lib.rs index 4c368491..6ca956c4 100644 --- a/evm/src/lib.rs +++ b/evm/src/lib.rs @@ -12,6 +12,7 @@ pub mod config; pub mod constraint_consumer; pub mod cpu; pub mod cross_table_lookup; +pub mod fixed_recursive_verifier; pub mod generation; mod get_challenges; pub mod keccak; diff --git a/evm/src/permutation.rs b/evm/src/permutation.rs index b081c309..4f42a4aa 100644 --- a/evm/src/permutation.rs +++ b/evm/src/permutation.rs @@ -15,9 +15,7 @@ use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; -use plonky2::plonk::plonk_common::{ - reduce_with_powers, reduce_with_powers_circuit, reduce_with_powers_ext_circuit, -}; +use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; use plonky2::util::reducing::{ReducingFactor, ReducingFactorTarget}; use crate::config::StarkConfig; @@ -82,15 +80,6 @@ impl GrandProductChallenge { let gamma = builder.convert_to_ext(self.gamma); builder.add_extension(reduced, gamma) } - - pub(crate) fn combine_base_circuit, const D: usize>( - &self, - builder: &mut CircuitBuilder, - terms: &[Target], - ) -> Target { - let reduced = reduce_with_powers_circuit(builder, terms, self.beta); - builder.add(reduced, self.gamma) - } } /// Like `PermutationChallenge`, but with `num_challenges` copies to boost soundness. diff --git a/evm/src/proof.rs b/evm/src/proof.rs index 4cd03a65..46b88ca2 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -19,15 +19,17 @@ use crate::all_stark::NUM_TABLES; use crate::config::StarkConfig; use crate::permutation::GrandProductChallengeSet; +/// A STARK proof for each table, plus some metadata used to create recursive wrapper proofs. #[derive(Debug, Clone)] pub struct AllProof, C: GenericConfig, const D: usize> { - pub stark_proofs: [StarkProof; NUM_TABLES], + pub stark_proofs: [StarkProofWithMetadata; NUM_TABLES], + pub(crate) ctl_challenges: GrandProductChallengeSet, pub public_values: PublicValues, } impl, C: GenericConfig, const D: usize> AllProof { pub fn degree_bits(&self, config: &StarkConfig) -> [usize; NUM_TABLES] { - std::array::from_fn(|i| self.stark_proofs[i].recover_degree_bits(config)) + std::array::from_fn(|i| self.stark_proofs[i].proof.recover_degree_bits(config)) } } @@ -113,6 +115,18 @@ pub struct StarkProof, C: GenericConfig, pub opening_proof: FriProof, } +/// A `StarkProof` along with some metadata about the initial Fiat-Shamir state, which is used when +/// creating a recursive wrapper proof around a STARK proof. +#[derive(Debug, Clone)] +pub struct StarkProofWithMetadata +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) init_challenger_state: [F; SPONGE_WIDTH], + pub(crate) proof: StarkProof, +} + impl, C: GenericConfig, const D: usize> StarkProof { /// Recover the length of the trace from a STARK proof and a STARK config. pub fn recover_degree_bits(&self, config: &StarkConfig) -> usize { diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 55b57437..a777e1ad 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -29,10 +29,10 @@ use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic::LogicStark; use crate::memory::memory_stark::MemoryStark; use crate::permutation::{ - compute_permutation_z_polys, get_n_grand_product_challenge_sets, GrandProductChallengeSet, - PermutationCheckVars, + compute_permutation_z_polys, get_grand_product_challenge_set, + get_n_grand_product_challenge_sets, GrandProductChallengeSet, PermutationCheckVars, }; -use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof}; +use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::StarkEvaluationVars; @@ -117,14 +117,14 @@ where challenger.observe_cap(cap); } + let ctl_challenges = get_grand_product_challenge_set(&mut challenger, config.num_challenges); let ctl_data_per_table = timed!( timing, "compute CTL data", cross_table_lookup_data::( - config, &trace_poly_values, &all_stark.cross_table_lookups, - &mut challenger, + &ctl_challenges, ) ); @@ -144,6 +144,7 @@ where Ok(AllProof { stark_proofs, + ctl_challenges, public_values, }) } @@ -156,7 +157,7 @@ fn prove_with_commitments( ctl_data_per_table: [CtlData; NUM_TABLES], challenger: &mut Challenger, timing: &mut TimingTree, -) -> Result<[StarkProof; NUM_TABLES]> +) -> Result<[StarkProofWithMetadata; NUM_TABLES]> where F: RichField + Extendable, C: GenericConfig, @@ -250,7 +251,7 @@ pub(crate) fn prove_single_table( ctl_data: &CtlData, challenger: &mut Challenger, timing: &mut TimingTree, -) -> Result> +) -> Result> where F: RichField + Extendable, C: GenericConfig, @@ -268,7 +269,7 @@ where "FRI total reduction arity is too large.", ); - challenger.compact(); + let init_challenger_state = challenger.compact(); // Permutation arguments. let permutation_challenges = stark.uses_permutation_args().then(|| { @@ -411,12 +412,16 @@ where ) ); - Ok(StarkProof { + let proof = StarkProof { trace_cap: trace_commitment.merkle_tree.cap.clone(), permutation_ctl_zs_cap, quotient_polys_cap, openings, opening_proof, + }; + Ok(StarkProofWithMetadata { + init_challenger_state, + proof, }) } diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 0f713e32..59a37602 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -5,18 +5,24 @@ use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::fri::witness_util::set_fri_proof_target; +use plonky2::gates::exponentiation::ExponentiationGate; +use plonky2::gates::gate::GateRef; +use plonky2::gates::noop::NoopGate; use plonky2::hash::hash_types::{HashOut, RichField}; use plonky2::hash::hashing::SPONGE_WIDTH; use plonky2::iop::challenger::{Challenger, RecursiveChallenger}; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::iop::target::Target; -use plonky2::iop::witness::Witness; +use plonky2::iop::witness::{PartialWitness, Witness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{CircuitConfig, VerifierCircuitData, VerifierCircuitTarget}; +use plonky2::plonk::circuit_data::{ + CircuitConfig, CircuitData, VerifierCircuitData, VerifierCircuitTarget, +}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2::util::reducing::ReducingFactorTarget; use plonky2::with_context; +use plonky2_util::log2_ceil; use crate::all_stark::{all_cross_table_lookups, AllStark, Table, NUM_TABLES}; use crate::config::StarkConfig; @@ -36,8 +42,8 @@ use crate::permutation::{ }; use crate::proof::{ AllProof, AllProofTarget, BlockMetadata, BlockMetadataTarget, PublicValues, PublicValuesTarget, - StarkOpeningSetTarget, StarkProof, StarkProofChallengesTarget, StarkProofTarget, TrieRoots, - TrieRootsTarget, + StarkOpeningSetTarget, StarkProof, StarkProofChallengesTarget, StarkProofTarget, + StarkProofWithMetadata, TrieRoots, TrieRootsTarget, }; use crate::stark::Stark; use crate::util::{h160_limbs, h256_limbs}; @@ -58,12 +64,12 @@ pub struct RecursiveAllProofTargetWithData { pub verifier_data: [VerifierCircuitTarget; NUM_TABLES], } -struct PublicInputs { - trace_cap: Vec>, - ctl_zs_last: Vec, - ctl_challenges: GrandProductChallengeSet, - challenger_state_before: [T; SPONGE_WIDTH], - challenger_state_after: [T; SPONGE_WIDTH], +pub(crate) struct PublicInputs { + pub(crate) trace_cap: Vec>, + pub(crate) ctl_zs_last: Vec, + pub(crate) ctl_challenges: GrandProductChallengeSet, + pub(crate) challenger_state_before: [T; SPONGE_WIDTH], + pub(crate) challenger_state_after: [T; SPONGE_WIDTH], } /// Similar to the unstable `Iterator::next_chunk`. Could be replaced with that when it's stable. @@ -76,9 +82,10 @@ fn next_chunk(iter: &mut impl Iterator) -> [ } impl PublicInputs { - fn from_vec(v: &[T], config: &StarkConfig) -> Self { + pub(crate) fn from_vec(v: &[T], config: &StarkConfig) -> Self { + log::info!("from_vec {}", v.len()); let mut iter = v.iter().copied(); - let trace_cap = (0..1 << config.fri_config.cap_height) + let trace_cap = (0..config.fri_config.num_cap_elements()) .map(|_| next_chunk::<_, 4>(&mut iter).to_vec()) .collect(); let ctl_challenges = GrandProductChallengeSet { @@ -91,7 +98,8 @@ impl PublicInputs { }; let challenger_state_before = next_chunk(&mut iter); let challenger_state_after = next_chunk(&mut iter); - let ctl_zs_last = iter.collect(); + let ctl_zs_last: Vec<_> = iter.collect(); + log::info!("from_vec num Zs: {}", ctl_zs_last.len()); // TODO Self { trace_cap, @@ -143,7 +151,7 @@ impl, C: GenericConfig, const D: usize> // Verify the CTL checks. let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups::( - cross_table_lookups, + &cross_table_lookups, pis.map(|p| p.ctl_zs_last), degrees_bits, ctl_challenges, @@ -158,6 +166,7 @@ impl, C: GenericConfig, const D: usize> } /// Recursively verify every recursive proof. + // TODO: Remove? Replaced by fixed_recursive_verifier. pub fn verify_circuit( builder: &mut CircuitBuilder, recursive_all_proof_target: RecursiveAllProofTargetWithData, @@ -215,13 +224,10 @@ impl, C: GenericConfig, const D: usize> } // Verify the CTL checks. - let degrees_bits = std::array::from_fn(|i| verifier_data[i].common.degree_bits()); verify_cross_table_lookups_circuit::( builder, all_cross_table_lookups(), pis.map(|p| p.ctl_zs_last), - degrees_bits, - ctl_challenges, inner_config, ); for (i, (recursive_proof, verifier_data_target)) in recursive_proofs @@ -238,33 +244,113 @@ impl, C: GenericConfig, const D: usize> } } -/// Returns the verifier data for the recursive Stark circuit. -fn verifier_data_recursive_stark_proof< +/// Represents a circuit which recursively verifies a STARK proof. +pub(crate) struct StarkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) circuit: CircuitData, + pub(crate) stark_proof_target: StarkProofTarget, + pub(crate) ctl_challenges_target: GrandProductChallengeSet, + pub(crate) init_challenger_state_target: [Target; SPONGE_WIDTH], + pub(crate) zero_target: Target, +} + +impl StarkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, +{ + pub(crate) fn prove( + &self, + proof_with_metadata: &StarkProofWithMetadata, + ctl_challenges: &GrandProductChallengeSet, + ) -> Result> { + let mut inputs = PartialWitness::new(); + + set_stark_proof_target( + &mut inputs, + &self.stark_proof_target, + &proof_with_metadata.proof, + self.zero_target, + ); + + for (challenge_target, challenge) in self + .ctl_challenges_target + .challenges + .iter() + .zip(&ctl_challenges.challenges) + { + inputs.set_target(challenge_target.beta, challenge.beta); + inputs.set_target(challenge_target.gamma, challenge.gamma); + } + + inputs.set_target_arr( + self.init_challenger_state_target, + proof_with_metadata.init_challenger_state, + ); + + self.circuit.prove(inputs) + } +} + +/// Represents a circuit which recursively verifies a PLONK proof. +pub(crate) struct PlonkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub(crate) circuit: CircuitData, + pub(crate) proof_with_pis_target: ProofWithPublicInputsTarget, +} + +impl PlonkWrapperCircuit +where + F: RichField + Extendable, + C: GenericConfig, + C::Hasher: AlgebraicHasher, +{ + pub(crate) fn prove( + &self, + proof: &ProofWithPublicInputs, + ) -> Result> { + let mut inputs = PartialWitness::new(); + inputs.set_proof_with_pis_target(&self.proof_with_pis_target, proof); + self.circuit.prove(inputs) + } +} + +/// Returns the recursive Stark circuit. +pub(crate) fn recursive_stark_circuit< F: RichField + Extendable, C: GenericConfig, S: Stark, const D: usize, >( table: Table, - stark: S, + stark: &S, degree_bits: usize, cross_table_lookups: &[CrossTableLookup], inner_config: &StarkConfig, circuit_config: &CircuitConfig, -) -> VerifierCircuitData + min_degree_bits: usize, +) -> StarkWrapperCircuit where [(); S::COLUMNS]:, [(); C::Hasher::HASH_SIZE]:, C::Hasher: AlgebraicHasher, { let mut builder = CircuitBuilder::::new(circuit_config.clone()); + let zero_target = builder.zero(); let num_permutation_zs = stark.num_permutation_batches(inner_config); let num_permutation_batch_size = stark.permutation_batch_size(); let num_ctl_zs = CrossTableLookup::num_ctl_zs(cross_table_lookups, table, inner_config.num_challenges); let proof_target = - add_virtual_stark_proof(&mut builder, &stark, inner_config, degree_bits, num_ctl_zs); + add_virtual_stark_proof(&mut builder, stark, inner_config, degree_bits, num_ctl_zs); builder.register_public_inputs( &proof_target .trace_cap @@ -291,8 +377,9 @@ where num_permutation_zs, ); - let challenger_state = std::array::from_fn(|_| builder.add_virtual_public_input()); - let mut challenger = RecursiveChallenger::::from_state(challenger_state); + let init_challenger_state_target = std::array::from_fn(|_| builder.add_virtual_public_input()); + let mut challenger = + RecursiveChallenger::::from_state(init_challenger_state_target); let challenges = proof_target.get_challenges::( &mut builder, &mut challenger, @@ -307,14 +394,39 @@ where verify_stark_proof_with_challenges_circuit::( &mut builder, - &stark, + stark, &proof_target, &challenges, &ctl_vars, inner_config, ); - builder.build_verifier::() + add_common_recursion_gates(&mut builder); + + // Pad to the minimum degree. + while log2_ceil(builder.num_gates()) < min_degree_bits { + builder.add_gate(NoopGate, vec![]); + } + + let circuit = builder.build::(); + StarkWrapperCircuit { + circuit, + stark_proof_target: proof_target, + ctl_challenges_target, + init_challenger_state_target, + zero_target, + } +} + +/// Add gates that are sometimes used by recursive circuits, even if it's not actually used by this +/// particular recursive circuit. This is done for uniformity. We sometimes want all recursion +/// circuits to have the same gate set, so that we can do 1-of-n conditional recursion efficiently. +pub(crate) fn add_common_recursion_gates, const D: usize>( + builder: &mut CircuitBuilder, +) { + builder.add_gate_to_gate_set(GateRef::new(ExponentiationGate::new_from_config( + &builder.config, + ))); } /// Returns the recursive Stark circuit verifier data for every Stark in `AllStark`. @@ -338,46 +450,61 @@ where C::Hasher: AlgebraicHasher, { [ - verifier_data_recursive_stark_proof( + recursive_stark_circuit( Table::Cpu, - all_stark.cpu_stark, + &all_stark.cpu_stark, degree_bits[Table::Cpu as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), - verifier_data_recursive_stark_proof( + 0, + ) + .circuit + .verifier_data(), + recursive_stark_circuit( Table::Keccak, - all_stark.keccak_stark, + &all_stark.keccak_stark, degree_bits[Table::Keccak as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), - verifier_data_recursive_stark_proof( + 0, + ) + .circuit + .verifier_data(), + recursive_stark_circuit( Table::KeccakSponge, - all_stark.keccak_sponge_stark, + &all_stark.keccak_sponge_stark, degree_bits[Table::KeccakSponge as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), - verifier_data_recursive_stark_proof( + 0, + ) + .circuit + .verifier_data(), + recursive_stark_circuit( Table::Logic, - all_stark.logic_stark, + &all_stark.logic_stark, degree_bits[Table::Logic as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), - verifier_data_recursive_stark_proof( + 0, + ) + .circuit + .verifier_data(), + recursive_stark_circuit( Table::Memory, - all_stark.memory_stark, + &all_stark.memory_stark, degree_bits[Table::Memory as usize], &all_stark.cross_table_lookups, inner_config, circuit_config, - ), + 0, + ) + .circuit + .verifier_data(), ] } @@ -580,7 +707,7 @@ where VerifierCircuitTarget { constants_sigmas_cap: builder .constant_merkle_cap(&verifier_data.verifier_only.constants_sigmas_cap), - circuit_digest: builder.add_virtual_hash(), + circuit_digest: builder.constant_hash(verifier_data.verifier_only.circuit_digest), } }); RecursiveAllProofTargetWithData { @@ -714,7 +841,7 @@ pub fn set_all_proof_target, W, const D: usize>( .iter() .zip_eq(&all_proof.stark_proofs) { - set_stark_proof_target(witness, pt, p, zero); + set_stark_proof_target(witness, pt, &p.proof, zero); } set_public_value_targets( witness, @@ -994,7 +1121,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::Cpu, all_stark.cpu_stark, - &all_proof.stark_proofs[Table::Cpu as usize], + &all_proof.stark_proofs[Table::Cpu as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[0], @@ -1005,7 +1132,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::Keccak, all_stark.keccak_stark, - &all_proof.stark_proofs[Table::Keccak as usize], + &all_proof.stark_proofs[Table::Keccak as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[1], @@ -1016,7 +1143,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::KeccakSponge, all_stark.keccak_sponge_stark, - &all_proof.stark_proofs[Table::KeccakSponge as usize], + &all_proof.stark_proofs[Table::KeccakSponge as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[2], @@ -1027,7 +1154,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::Logic, all_stark.logic_stark, - &all_proof.stark_proofs[Table::Logic as usize], + &all_proof.stark_proofs[Table::Logic as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[3], @@ -1038,7 +1165,7 @@ pub(crate) mod tests { recursively_verify_stark_proof( Table::Memory, all_stark.memory_stark, - &all_proof.stark_proofs[Table::Memory as usize], + &all_proof.stark_proofs[Table::Memory as usize].proof, &all_stark.cross_table_lookups, &ctl_challenges, states[4], diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index a0329d04..cb75ad33 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -26,7 +26,7 @@ use crate::vanishing_poly::eval_vanishing_poly; use crate::vars::StarkEvaluationVars; pub fn verify_proof, C: GenericConfig, const D: usize>( - all_stark: AllStark, + all_stark: &AllStark, all_proof: AllProof, config: &StarkConfig, ) -> Result<()> @@ -41,7 +41,7 @@ where let AllProofChallenges { stark_challenges, ctl_challenges, - } = all_proof.get_challenges(&all_stark, config); + } = all_proof.get_challenges(all_stark, config); let nums_permutation_zs = all_stark.nums_permutation_zs(config); @@ -56,52 +56,52 @@ where let ctl_vars_per_table = CtlCheckVars::from_proofs( &all_proof.stark_proofs, - &cross_table_lookups, + cross_table_lookups, &ctl_challenges, &nums_permutation_zs, ); verify_stark_proof_with_challenges( cpu_stark, - &all_proof.stark_proofs[Table::Cpu as usize], + &all_proof.stark_proofs[Table::Cpu as usize].proof, &stark_challenges[Table::Cpu as usize], &ctl_vars_per_table[Table::Cpu as usize], config, )?; verify_stark_proof_with_challenges( keccak_stark, - &all_proof.stark_proofs[Table::Keccak as usize], + &all_proof.stark_proofs[Table::Keccak as usize].proof, &stark_challenges[Table::Keccak as usize], &ctl_vars_per_table[Table::Keccak as usize], config, )?; verify_stark_proof_with_challenges( keccak_sponge_stark, - &all_proof.stark_proofs[Table::KeccakSponge as usize], + &all_proof.stark_proofs[Table::KeccakSponge as usize].proof, &stark_challenges[Table::KeccakSponge as usize], &ctl_vars_per_table[Table::KeccakSponge as usize], config, )?; verify_stark_proof_with_challenges( memory_stark, - &all_proof.stark_proofs[Table::Memory as usize], + &all_proof.stark_proofs[Table::Memory as usize].proof, &stark_challenges[Table::Memory as usize], &ctl_vars_per_table[Table::Memory as usize], config, )?; verify_stark_proof_with_challenges( logic_stark, - &all_proof.stark_proofs[Table::Logic as usize], + &all_proof.stark_proofs[Table::Logic as usize].proof, &stark_challenges[Table::Logic as usize], &ctl_vars_per_table[Table::Logic as usize], config, )?; let degrees_bits = - std::array::from_fn(|i| all_proof.stark_proofs[i].recover_degree_bits(config)); + std::array::from_fn(|i| all_proof.stark_proofs[i].proof.recover_degree_bits(config)); verify_cross_table_lookups::( cross_table_lookups, - all_proof.stark_proofs.map(|p| p.openings.ctl_zs_last), + all_proof.stark_proofs.map(|p| p.proof.openings.ctl_zs_last), degrees_bits, ctl_challenges, config, @@ -114,7 +114,7 @@ pub(crate) fn verify_stark_proof_with_challenges< S: Stark, const D: usize, >( - stark: S, + stark: &S, proof: &StarkProof, challenges: &StarkProofChallenges, ctl_vars: &[CtlCheckVars], @@ -125,7 +125,7 @@ where [(); C::Hasher::HASH_SIZE]:, { log::debug!("Checking proof: {}", type_name::()); - validate_proof_shape(&stark, proof, config, ctl_vars.len())?; + validate_proof_shape(stark, proof, config, ctl_vars.len())?; let StarkOpeningSet { local_values, next_values, @@ -160,7 +160,7 @@ where permutation_challenge_sets: challenges.permutation_challenge_sets.clone().unwrap(), }); eval_vanishing_poly::( - &stark, + stark, config, vars, permutation_data, diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index aa7b60b9..1e564758 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -8,6 +8,7 @@ use plonky2::plonk::config::PoseidonGoldilocksConfig; use plonky2::util::timing::TimingTree; use plonky2_evm::all_stark::AllStark; use plonky2_evm::config::StarkConfig; +use plonky2_evm::fixed_recursive_verifier::AllRecursiveCircuits; use plonky2_evm::generation::{GenerationInputs, TrieInputs}; use plonky2_evm::proof::BlockMetadata; use plonky2_evm::prover::prove; @@ -49,7 +50,10 @@ fn test_empty_txn_list() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + // TODO: This is redundant; prove_root below calls this prove method internally. + // Just keeping it for now because the root proof returned by prove_root doesn't contain public + // values yet, and we want those for the assertions below. + let proof = prove::(&all_stark, &config, inputs.clone(), &mut timing)?; timing.filter(Duration::from_millis(100)).print(); assert_eq!( @@ -77,7 +81,11 @@ fn test_empty_txn_list() -> anyhow::Result<()> { receipts_trie_root ); - verify_proof(all_stark, proof, &config) + verify_proof(&all_stark, proof, &config)?; + + let all_circuits = AllRecursiveCircuits::::new(&all_stark, 9..19, &config); + let root_proof = all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?; + all_circuits.root.circuit.verify(root_proof) } fn init_logger() { diff --git a/evm/tests/transfer_to_new_addr.rs b/evm/tests/transfer_to_new_addr.rs index e4fe8eb4..4a5e69e3 100644 --- a/evm/tests/transfer_to_new_addr.rs +++ b/evm/tests/transfer_to_new_addr.rs @@ -104,7 +104,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { expected_state_trie_after.calc_hash() ); - verify_proof(all_stark, proof, &config) + verify_proof(&all_stark, proof, &config) } fn eth_to_wei(eth: U256) -> U256 { diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index 0ad12e76..bf6b0e6b 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -14,9 +14,7 @@ use plonky2::gates::noop::NoopGate; use plonky2::hash::hash_types::RichField; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; use plonky2::plonk::circuit_builder::CircuitBuilder; -use plonky2::plonk::circuit_data::{ - CircuitConfig, CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, -}; +use plonky2::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierOnlyCircuitData}; use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; use plonky2::plonk::proof::{CompressedProofWithPublicInputs, ProofWithPublicInputs}; use plonky2::plonk::prover::prove; @@ -107,10 +105,7 @@ where let mut builder = CircuitBuilder::::new(config.clone()); let pt = builder.add_virtual_proof_with_pis::(inner_cd); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = builder.add_virtual_verifier_data(inner_cd.config.fri_config.cap_height); builder.verify_proof::(&pt, &inner_data, inner_cd); builder.print_gate_counts(0); diff --git a/plonky2/src/gadgets/random_access.rs b/plonky2/src/gadgets/random_access.rs index d3a3ff1b..73e3de8c 100644 --- a/plonky2/src/gadgets/random_access.rs +++ b/plonky2/src/gadgets/random_access.rs @@ -2,15 +2,15 @@ use alloc::vec::Vec; use crate::field::extension::Extendable; use crate::gates::random_access::RandomAccessGate; -use crate::hash::hash_types::RichField; +use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::VerifierCircuitTarget; use crate::util::log2_strict; impl, const D: usize> CircuitBuilder { - /// Checks that a `Target` matches a vector at a non-deterministic index. - /// Note: `access_index` is not range-checked. + /// Checks that a `Target` matches a vector at a particular index. pub fn random_access(&mut self, access_index: Target, v: Vec) -> Target { let vec_size = v.len(); let bits = log2_strict(vec_size); @@ -38,18 +38,64 @@ impl, const D: usize> CircuitBuilder { claimed_element } - /// Checks that an `ExtensionTarget` matches a vector at a non-deterministic index. - /// Note: `access_index` is not range-checked. + /// Like `random_access`, but with `ExtensionTarget`s rather than simple `Target`s. pub fn random_access_extension( &mut self, access_index: Target, v: Vec>, ) -> ExtensionTarget { - let v: Vec<_> = (0..D) + let selected: Vec<_> = (0..D) .map(|i| self.random_access(access_index, v.iter().map(|et| et.0[i]).collect())) .collect(); - ExtensionTarget(v.try_into().unwrap()) + ExtensionTarget(selected.try_into().unwrap()) + } + + /// Like `random_access`, but with `HashOutTarget`s rather than simple `Target`s. + pub fn random_access_hash( + &mut self, + access_index: Target, + v: Vec, + ) -> HashOutTarget { + let selected = std::array::from_fn(|i| { + self.random_access( + access_index, + v.iter().map(|hash| hash.elements[i]).collect(), + ) + }); + selected.into() + } + + /// Like `random_access`, but with `MerkleCapTarget`s rather than simple `Target`s. + pub fn random_access_merkle_cap( + &mut self, + access_index: Target, + v: Vec, + ) -> MerkleCapTarget { + let cap_size = v[0].0.len(); + assert!(v.iter().all(|cap| cap.0.len() == cap_size)); + + let selected = (0..cap_size) + .map(|i| self.random_access_hash(access_index, v.iter().map(|cap| cap.0[i]).collect())) + .collect(); + MerkleCapTarget(selected) + } + + /// Like `random_access`, but with `VerifierCircuitTarget`s rather than simple `Target`s. + pub fn random_access_verifier_data( + &mut self, + access_index: Target, + v: Vec, + ) -> VerifierCircuitTarget { + let constants_sigmas_caps = v.iter().map(|vk| vk.constants_sigmas_cap.clone()).collect(); + let circuit_digests = v.iter().map(|vk| vk.circuit_digest).collect(); + let constants_sigmas_cap = + self.random_access_merkle_cap(access_index, constants_sigmas_caps); + let circuit_digest = self.random_access_hash(access_index, circuit_digests); + VerifierCircuitTarget { + constants_sigmas_cap, + circuit_digest, + } } } diff --git a/plonky2/src/hash/merkle_proofs.rs b/plonky2/src/hash/merkle_proofs.rs index cd696e55..3726d847 100644 --- a/plonky2/src/hash/merkle_proofs.rs +++ b/plonky2/src/hash/merkle_proofs.rs @@ -11,6 +11,7 @@ use crate::hash::hashing::SPONGE_WIDTH; use crate::hash::merkle_tree::MerkleCap; use crate::iop::target::{BoolTarget, Target}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::VerifierCircuitTarget; use crate::plonk::config::{AlgebraicHasher, Hasher}; #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -152,6 +153,11 @@ impl, const D: usize> CircuitBuilder { self.connect_hashes(*h0, *h1); } } + + pub fn connect_verifier_data(&mut self, x: &VerifierCircuitTarget, y: &VerifierCircuitTarget) { + self.connect_merkle_caps(&x.constants_sigmas_cap, &y.constants_sigmas_cap); + self.connect_hashes(x.circuit_digest, y.circuit_digest); + } } #[cfg(test)] diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 86871701..7a9cd1f3 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -14,6 +14,7 @@ use crate::util::log2_strict; /// It can be used in place of the root to verify Merkle paths, which are `h` elements shorter. #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(bound = "")] +// TODO: Change H to GenericHashOut, since this only cares about the hash, not the hasher. pub struct MerkleCap>(pub Vec); impl> MerkleCap { diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index e9ade3d2..14213d0d 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -14,7 +14,7 @@ use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_data::{VerifierCircuitTarget, VerifierOnlyCircuitData}; -use crate::plonk::config::{AlgebraicHasher, GenericConfig}; +use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher}; use crate::plonk::proof::{Proof, ProofTarget, ProofWithPublicInputs, ProofWithPublicInputsTarget}; pub trait WitnessWrite { @@ -224,6 +224,19 @@ pub trait Witness: WitnessWrite { } } + fn get_merkle_cap_target>(&self, cap_target: MerkleCapTarget) -> MerkleCap + where + F: RichField, + H: AlgebraicHasher, + { + let cap = cap_target + .0 + .iter() + .map(|hash_target| self.get_hash_target(*hash_target)) + .collect(); + MerkleCap(cap) + } + fn get_wire(&self, wire: Wire) -> F { self.get_target(Target::Wire(wire)) } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 9017a059..21305f4d 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -245,6 +245,13 @@ impl, const D: usize> CircuitBuilder { t } + pub fn add_virtual_verifier_data(&mut self, cap_height: usize) -> VerifierCircuitTarget { + VerifierCircuitTarget { + constants_sigmas_cap: self.add_virtual_cap(cap_height), + circuit_digest: self.add_virtual_hash(), + } + } + /// Add a virtual verifier data, register it as a public input and set it to `self.verifier_data_public_input`. /// WARNING: Do not register any public input after calling this! TODO: relax this pub fn add_verifier_data_public_inputs(&mut self) -> VerifierCircuitTarget { @@ -253,10 +260,7 @@ impl, const D: usize> CircuitBuilder { "add_verifier_data_public_inputs only needs to be called once" ); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), - circuit_digest: self.add_virtual_hash(), - }; + let verifier_data = self.add_virtual_verifier_data(self.config.fri_config.cap_height); // The verifier data are public inputs. self.register_public_inputs(&verifier_data.circuit_digest.elements); for i in 0..self.config.fri_config.num_cap_elements() { @@ -784,13 +788,13 @@ impl, const D: usize> CircuitBuilder { self.add_simple_generator(const_gen); } - info!( + debug!( "Degree before blinding & padding: {}", self.gate_instances.len() ); self.blind_and_pad(); let degree = self.gate_instances.len(); - info!("Degree after blinding & padding: {}", degree); + debug!("Degree after blinding & padding: {}", degree); let degree_bits = log2_strict(degree); let fri_params = self.fri_params(degree_bits); assert!( diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 6df986fa..c57e206d 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -149,15 +149,15 @@ impl, C: GenericConfig, const D: usize> proof.decompress(&self.verifier_only.circuit_digest, &self.common) } - pub fn verifier_data(self) -> VerifierCircuitData { + pub fn verifier_data(&self) -> VerifierCircuitData { let CircuitData { verifier_only, common, .. } = self; VerifierCircuitData { - verifier_only, - common, + verifier_only: verifier_only.clone(), + common: common.clone(), } } @@ -258,7 +258,7 @@ pub struct ProverOnlyCircuitData< } /// Circuit data required by the verifier, but not the prover. -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct VerifierOnlyCircuitData, const D: usize> { /// A commitment to each constant polynomial and each permutation polynomial. pub constants_sigmas_cap: MerkleCap, @@ -289,7 +289,7 @@ pub struct CommonCircuitData, const D: usize> { /// The number of constant wires. pub(crate) num_constants: usize, - pub(crate) num_public_inputs: usize, + pub num_public_inputs: usize, /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, diff --git a/plonky2/src/recursion/conditional_recursive_verifier.rs b/plonky2/src/recursion/conditional_recursive_verifier.rs index ace47cab..2596e1f0 100644 --- a/plonky2/src/recursion/conditional_recursive_verifier.rs +++ b/plonky2/src/recursion/conditional_recursive_verifier.rs @@ -378,15 +378,11 @@ mod tests { pw.set_proof_with_pis_target(&pt, &proof); let dummy_pt = builder.add_virtual_proof_with_pis::(&data.common); pw.set_proof_with_pis_target::(&dummy_pt, &dummy_proof); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); pw.set_verifier_data_target(&inner_data, &data.verifier_only); - let dummy_inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let dummy_inner_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); pw.set_verifier_data_target(&dummy_inner_data, &dummy_data.verifier_only); let b = builder.constant_bool(F::rand().0 % 2 == 0); builder.conditionally_verify_proof::( diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs index a12c31d4..656ff3b9 100644 --- a/plonky2/src/recursion/cyclic_recursion.rs +++ b/plonky2/src/recursion/cyclic_recursion.rs @@ -188,7 +188,7 @@ mod tests { use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation}; use crate::iop::witness::{PartialWitness, WitnessWrite}; use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget}; + use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData}; use crate::plonk::config::{AlgebraicHasher, GenericConfig, PoseidonGoldilocksConfig}; use crate::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; use crate::recursion::dummy_circuit::cyclic_base_proof; @@ -208,20 +208,16 @@ mod tests { let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let proof = builder.add_virtual_proof_with_pis::(&data.common); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let verifier_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); builder.verify_proof::(&proof, &verifier_data, &data.common); let data = builder.build::(); let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); let proof = builder.add_virtual_proof_with_pis::(&data.common); - let verifier_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let verifier_data = + builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); builder.verify_proof::(&proof, &verifier_data, &data.common); while builder.num_gates() < 1 << 12 { builder.add_gate(NoopGate, vec![]); diff --git a/plonky2/src/recursion/dummy_circuit.rs b/plonky2/src/recursion/dummy_circuit.rs index 34b20c71..38f51aea 100644 --- a/plonky2/src/recursion/dummy_circuit.rs +++ b/plonky2/src/recursion/dummy_circuit.rs @@ -113,11 +113,8 @@ impl, const D: usize> CircuitBuilder { let dummy_circuit = dummy_circuit::(common_data); let dummy_proof_with_pis = dummy_proof(&dummy_circuit, HashMap::new())?; let dummy_proof_with_pis_target = self.add_virtual_proof_with_pis::(common_data); - - let dummy_verifier_data_target = VerifierCircuitTarget { - constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), - circuit_digest: self.add_virtual_hash(), - }; + let dummy_verifier_data_target = + self.add_virtual_verifier_data(self.config.fri_config.cap_height); self.add_simple_generator(DummyProofGenerator { proof_with_pis_target: dummy_proof_with_pis_target.clone(), diff --git a/plonky2/src/recursion/recursive_verifier.rs b/plonky2/src/recursion/recursive_verifier.rs index 15943a87..9aafb1f5 100644 --- a/plonky2/src/recursion/recursive_verifier.rs +++ b/plonky2/src/recursion/recursive_verifier.rs @@ -366,10 +366,7 @@ mod tests { let pt = builder.add_virtual_proof_with_pis::(&inner_cd); pw.set_proof_with_pis_target(&pt, &inner_proof); - let inner_data = VerifierCircuitTarget { - constants_sigmas_cap: builder.add_virtual_cap(inner_cd.config.fri_config.cap_height), - circuit_digest: builder.add_virtual_hash(), - }; + let inner_data = builder.add_virtual_verifier_data(inner_cd.config.fri_config.cap_height); pw.set_cap_target( &inner_data.constants_sigmas_cap, &inner_vd.constants_sigmas_cap,