diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index b4687a89..ab105449 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -16,10 +16,21 @@ pub struct AllStark, const D: usize> { impl, const D: usize> AllStark { pub(crate) fn nums_permutation_zs(&self, config: &StarkConfig) -> Vec { - vec![ + let ans = vec![ self.cpu_stark.num_permutation_batches(config), self.keccak_stark.num_permutation_batches(config), - ] + ]; + debug_assert_eq!(ans.len(), Table::num_tables()); + ans + } + + pub(crate) fn permutation_batch_sizes(&self) -> Vec { + let ans = vec![ + self.cpu_stark.permutation_batch_size(), + self.keccak_stark.permutation_batch_size(), + ]; + debug_assert_eq!(ans.len(), Table::num_tables()); + ans } } @@ -29,6 +40,12 @@ pub enum Table { Keccak = 1, } +impl Table { + pub(crate) fn num_tables() -> usize { + Table::Keccak as usize + 1 + } +} + #[cfg(test)] mod tests { use anyhow::Result; diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 2f6ac84c..d111fc9f 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -155,7 +155,7 @@ impl<'a, F: RichField + Extendable, const D: usize> CtlCheckVars<'a, F, F::Extension, F::Extension, D> { pub(crate) fn from_proofs>( - proofs: &[&StarkProofWithPublicInputs], + proofs: &[StarkProofWithPublicInputs], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, num_permutation_zs: &[usize], @@ -236,7 +236,7 @@ pub(crate) fn verify_cross_table_lookups< const D: usize, >( cross_table_lookups: Vec>, - proofs: &[&StarkProofWithPublicInputs], + proofs: &[StarkProofWithPublicInputs], challenges: GrandProductChallengeSet, config: &StarkConfig, ) -> Result<()> { diff --git a/evm/src/get_challenges.rs b/evm/src/get_challenges.rs index 64eaeb53..6293a6f3 100644 --- a/evm/src/get_challenges.rs +++ b/evm/src/get_challenges.rs @@ -1,3 +1,4 @@ +use itertools::izip; use plonky2::field::extension_field::Extendable; use plonky2::fri::proof::FriProof; use plonky2::hash::hash_types::RichField; @@ -23,7 +24,7 @@ impl, C: GenericConfig, const D: usize> A ) -> AllProofChallenges { let mut challenger = Challenger::::new(); - for proof in self.proofs() { + for proof in &self.stark_proofs { challenger.observe_cap(&proof.proof.trace_cap); } @@ -31,16 +32,15 @@ impl, C: GenericConfig, const D: usize> A get_grand_product_challenge_set(&mut challenger, config.num_challenges); AllProofChallenges { - cpu_challenges: self.cpu_proof.get_challenges( - &mut challenger, - &all_stark.cpu_stark, - config, - ), - keccak_challenges: self.keccak_proof.get_challenges( - &mut challenger, - &all_stark.keccak_stark, - config, - ), + stark_challenges: izip!( + &self.stark_proofs, + all_stark.nums_permutation_zs(config), + all_stark.permutation_batch_sizes() + ) + .map(|(proof, num_perm, batch_size)| { + proof.get_challenges(&mut challenger, num_perm > 0, batch_size, config) + }) + .collect(), ctl_challenges, } } @@ -52,10 +52,11 @@ where C: GenericConfig, { /// Computes all Fiat-Shamir challenges used in the STARK proof. - pub(crate) fn get_challenges>( + pub(crate) fn get_challenges( &self, challenger: &mut Challenger, - stark: &S, + stark_use_permutation: bool, + stark_permutation_batch_size: usize, config: &StarkConfig, ) -> StarkProofChallenges { let degree_bits = self.proof.recover_degree_bits(config); @@ -76,11 +77,11 @@ where let num_challenges = config.num_challenges; - let permutation_challenge_sets = stark.uses_permutation_args().then(|| { + let permutation_challenge_sets = stark_use_permutation.then(|| { get_n_grand_product_challenge_sets( challenger, num_challenges, - stark.permutation_batch_size(), + stark_permutation_batch_size, ) }); diff --git a/evm/src/proof.rs b/evm/src/proof.rs index c04ff9dc..d03bc397 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -19,19 +19,11 @@ use crate::permutation::GrandProductChallengeSet; #[derive(Debug, Clone)] pub struct AllProof, C: GenericConfig, const D: usize> { - pub cpu_proof: StarkProofWithPublicInputs, - pub keccak_proof: StarkProofWithPublicInputs, -} - -impl, C: GenericConfig, const D: usize> AllProof { - pub fn proofs(&self) -> [&StarkProofWithPublicInputs; 2] { - [&self.cpu_proof, &self.keccak_proof] - } + pub stark_proofs: Vec>, } pub(crate) struct AllProofChallenges, const D: usize> { - pub cpu_challenges: StarkProofChallenges, - pub keccak_challenges: StarkProofChallenges, + pub stark_challenges: Vec>, pub ctl_challenges: GrandProductChallengeSet, } diff --git a/evm/src/prover.rs b/evm/src/prover.rs index bc6cd99b..df3b6db8 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -48,7 +48,7 @@ where [(); KeccakStark::::COLUMNS]:, [(); KeccakStark::::PUBLIC_INPUTS]:, { - let num_starks = Table::Keccak as usize + 1; + let num_starks = Table::num_tables(); debug_assert_eq!(num_starks, trace_poly_values.len()); debug_assert_eq!(num_starks, public_inputs.len()); @@ -118,10 +118,10 @@ where timing, )?; - Ok(AllProof { - cpu_proof, - keccak_proof, - }) + let stark_proofs = vec![cpu_proof, keccak_proof]; + debug_assert_eq!(stark_proofs.len(), num_starks); + + Ok(AllProof { stark_proofs }) } /// Compute proof for a single STARK table. diff --git a/evm/src/verifier.rs b/evm/src/verifier.rs index 52882488..76e739a2 100644 --- a/evm/src/verifier.rs +++ b/evm/src/verifier.rs @@ -32,8 +32,7 @@ where [(); C::Hasher::HASH_SIZE]:, { let AllProofChallenges { - cpu_challenges, - keccak_challenges, + stark_challenges, ctl_challenges, } = all_proof.get_challenges(&all_stark, config); @@ -46,7 +45,7 @@ where } = all_stark; let ctl_vars_per_table = CtlCheckVars::from_proofs( - &all_proof.proofs(), + &all_proof.stark_proofs, &cross_table_lookups, &ctl_challenges, &nums_permutation_zs, @@ -54,22 +53,22 @@ where verify_stark_proof_with_challenges( cpu_stark, - &all_proof.cpu_proof, - cpu_challenges, + &all_proof.stark_proofs[Table::Cpu as usize], + &stark_challenges[Table::Cpu as usize], &ctl_vars_per_table[Table::Cpu as usize], config, )?; verify_stark_proof_with_challenges( keccak_stark, - &all_proof.keccak_proof, - keccak_challenges, + &all_proof.stark_proofs[Table::Keccak as usize], + &stark_challenges[Table::Keccak as usize], &ctl_vars_per_table[Table::Keccak as usize], config, )?; verify_cross_table_lookups( cross_table_lookups, - &all_proof.proofs(), + &all_proof.stark_proofs, ctl_challenges, config, ) @@ -83,7 +82,7 @@ pub(crate) fn verify_stark_proof_with_challenges< >( stark: S, proof_with_pis: &StarkProofWithPublicInputs, - challenges: StarkProofChallenges, + challenges: &StarkProofChallenges, ctl_vars: &[CtlCheckVars], config: &StarkConfig, ) -> Result<()> @@ -134,7 +133,7 @@ where let permutation_data = stark.uses_permutation_args().then(|| PermutationCheckVars { local_zs: permutation_ctl_zs[..num_permutation_zs].to_vec(), next_zs: permutation_ctl_zs_right[..num_permutation_zs].to_vec(), - permutation_challenge_sets: challenges.permutation_challenge_sets.unwrap(), + permutation_challenge_sets: challenges.permutation_challenge_sets.clone().unwrap(), }); eval_vanishing_poly::( &stark,