From ae3003a9d7eec22384328079fd8b413ce7acb153 Mon Sep 17 00:00:00 2001 From: Robin Salen <30937548+Nashtare@users.noreply.github.com> Date: Fri, 22 Dec 2023 17:23:22 +0100 Subject: [PATCH] Add alternative method to prove txs without pre-loaded table circuits (#1438) --- evm/src/cross_table_lookup.rs | 2 +- evm/src/fixed_recursive_verifier.rs | 209 ++++++++++++---------------- evm/src/proof.rs | 10 +- evm/src/prover.rs | 2 +- evm/src/recursive_verifier.rs | 2 +- evm/tests/empty_txn_list.rs | 23 +-- 6 files changed, 103 insertions(+), 145 deletions(-) diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 65b27b13..21f94126 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -566,7 +566,7 @@ impl GrandProductChallenge { /// Like `PermutationChallenge`, but with `num_challenges` copies to boost soundness. #[derive(Clone, Eq, PartialEq, Debug)] -pub(crate) struct GrandProductChallengeSet { +pub struct GrandProductChallengeSet { pub(crate) challenges: Vec>, } diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 063f479a..3f405e52 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -1,9 +1,11 @@ use core::mem::{self, MaybeUninit}; use std::collections::BTreeMap; use std::ops::Range; +use std::path::Path; use std::sync::atomic::AtomicBool; use std::sync::Arc; +use anyhow::anyhow; use eth_trie_utils::partial_trie::{HashedPartialTrie, Node, PartialTrie}; use hashbrown::HashMap; use itertools::{zip_eq, Itertools}; @@ -23,6 +25,7 @@ use plonky2::plonk::config::{AlgebraicHasher, GenericConfig}; use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; use plonky2::recursion::cyclic_recursion::check_cyclic_proof_verifier_data; use plonky2::recursion::dummy_circuit::cyclic_base_proof; +use plonky2::util::serialization::gate_serialization::default; use plonky2::util::serialization::{ Buffer, GateSerializer, IoResult, Read, WitnessGeneratorSerializer, Write, }; @@ -38,8 +41,8 @@ use crate::cross_table_lookup::{ use crate::generation::GenerationInputs; use crate::get_challenges::observe_public_values_target; use crate::proof::{ - BlockHashesTarget, BlockMetadataTarget, ExtraBlockData, ExtraBlockDataTarget, PublicValues, - PublicValuesTarget, StarkProofWithMetadata, TrieRoots, TrieRootsTarget, + AllProof, BlockHashesTarget, BlockMetadataTarget, ExtraBlockData, ExtraBlockDataTarget, + PublicValues, PublicValuesTarget, StarkProofWithMetadata, TrieRoots, TrieRootsTarget, }; use crate::prover::{check_abort_signal, prove}; use crate::recursive_verifier::{ @@ -70,7 +73,7 @@ where /// The block circuit, which verifies an aggregation root proof and a previous block proof. pub block: BlockCircuitData, /// Holds chains of circuits for each table and for each initial `degree_bits`. - by_table: [RecursiveCircuitsForTable; NUM_TABLES], + pub by_table: [RecursiveCircuitsForTable; NUM_TABLES], } /// Data for the EVM root circuit, which is used to combine each STARK's shrunk wrapper proof @@ -297,6 +300,7 @@ where { pub fn to_bytes( &self, + skip_tables: bool, gate_serializer: &dyn GateSerializer, generator_serializer: &dyn WitnessGeneratorSerializer, ) -> IoResult> { @@ -308,14 +312,17 @@ where .to_buffer(&mut buffer, gate_serializer, generator_serializer)?; self.block .to_buffer(&mut buffer, gate_serializer, generator_serializer)?; - for table in &self.by_table { - table.to_buffer(&mut buffer, gate_serializer, generator_serializer)?; + if !skip_tables { + for table in &self.by_table { + table.to_buffer(&mut buffer, gate_serializer, generator_serializer)?; + } } Ok(buffer) } pub fn from_bytes( bytes: &[u8], + skip_tables: bool, gate_serializer: &dyn GateSerializer, generator_serializer: &dyn WitnessGeneratorSerializer, ) -> IoResult { @@ -330,21 +337,30 @@ where let block = BlockCircuitData::from_buffer(&mut buffer, gate_serializer, generator_serializer)?; - // Tricky use of MaybeUninit to remove the need for implementing Debug - // for all underlying types, necessary to convert a by_table Vec to an array. - let by_table = { - let mut by_table: [MaybeUninit>; NUM_TABLES] = - unsafe { MaybeUninit::uninit().assume_init() }; - for table in &mut by_table[..] { - let value = RecursiveCircuitsForTable::from_buffer( - &mut buffer, - gate_serializer, - generator_serializer, - )?; - *table = MaybeUninit::new(value); - } - unsafe { - mem::transmute::<_, [RecursiveCircuitsForTable; NUM_TABLES]>(by_table) + let by_table = match skip_tables { + true => (0..NUM_TABLES) + .map(|_| RecursiveCircuitsForTable { + by_stark_size: BTreeMap::default(), + }) + .collect_vec() + .try_into() + .unwrap(), + false => { + // Tricky use of MaybeUninit to remove the need for implementing Debug + // for all underlying types, necessary to convert a by_table Vec to an array. + let mut by_table: [MaybeUninit>; NUM_TABLES] = + unsafe { MaybeUninit::uninit().assume_init() }; + for table in &mut by_table[..] { + let value = RecursiveCircuitsForTable::from_buffer( + &mut buffer, + gate_serializer, + generator_serializer, + )?; + *table = MaybeUninit::new(value); + } + unsafe { + mem::transmute::<_, [RecursiveCircuitsForTable; NUM_TABLES]>(by_table) + } } }; @@ -432,72 +448,6 @@ where } } - /// Expand the preprocessed STARK table circuits with the provided ranges. - /// - /// If a range for a given table is contained within the current one, this will be a no-op. - /// Otherwise, it will add the circuits for the missing table sizes, and regenerate the upper circuits. - pub fn expand( - &mut self, - all_stark: &AllStark, - degree_bits_ranges: &[Range; NUM_TABLES], - stark_config: &StarkConfig, - ) { - self.by_table[Table::Arithmetic as usize].expand( - Table::Arithmetic, - &all_stark.arithmetic_stark, - degree_bits_ranges[Table::Arithmetic as usize].clone(), - &all_stark.cross_table_lookups, - stark_config, - ); - self.by_table[Table::BytePacking as usize].expand( - Table::BytePacking, - &all_stark.byte_packing_stark, - degree_bits_ranges[Table::BytePacking as usize].clone(), - &all_stark.cross_table_lookups, - stark_config, - ); - self.by_table[Table::Cpu as usize].expand( - Table::Cpu, - &all_stark.cpu_stark, - degree_bits_ranges[Table::Cpu as usize].clone(), - &all_stark.cross_table_lookups, - stark_config, - ); - self.by_table[Table::Keccak as usize].expand( - Table::Keccak, - &all_stark.keccak_stark, - degree_bits_ranges[Table::Keccak as usize].clone(), - &all_stark.cross_table_lookups, - stark_config, - ); - self.by_table[Table::KeccakSponge as usize].expand( - Table::KeccakSponge, - &all_stark.keccak_sponge_stark, - degree_bits_ranges[Table::KeccakSponge as usize].clone(), - &all_stark.cross_table_lookups, - stark_config, - ); - self.by_table[Table::Logic as usize].expand( - Table::Logic, - &all_stark.logic_stark, - degree_bits_ranges[Table::Logic as usize].clone(), - &all_stark.cross_table_lookups, - stark_config, - ); - self.by_table[Table::Memory as usize].expand( - Table::Memory, - &all_stark.memory_stark, - degree_bits_ranges[Table::Memory as usize].clone(), - &all_stark.cross_table_lookups, - stark_config, - ); - - // Regenerate the upper circuits. - self.root = Self::create_root_circuit(&self.by_table, stark_config); - self.aggregation = Self::create_aggregation_circuit(&self.root); - self.block = Self::create_block_circuit(&self.aggregation); - } - /// Outputs the `VerifierCircuitData` needed to verify any block proof /// generated by an honest prover. pub fn final_verifier_data(&self) -> VerifierCircuitData { @@ -988,7 +938,7 @@ where .by_stark_size .get(&original_degree_bits) .ok_or_else(|| { - anyhow::Error::msg(format!( + anyhow!(format!( "Missing preprocessed circuits for {:?} table with size {}.", Table::all()[table], original_degree_bits, @@ -1028,6 +978,55 @@ where Ok((root_proof, all_proof.public_values)) } + /// From an initial set of STARK proofs passed with their associated recursive table circuits, + /// generate a recursive transaction proof. + /// It is aimed at being used when preprocessed table circuits have not been loaded to memory. + pub fn prove_root_after_initial_stark( + &self, + all_stark: &AllStark, + config: &StarkConfig, + all_proof: AllProof, + table_circuits: &[(RecursiveCircuitsForTableSize, u8); NUM_TABLES], + timing: &mut TimingTree, + abort_signal: Option>, + ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { + let mut root_inputs = PartialWitness::new(); + + for table in 0..NUM_TABLES { + let (table_circuit, index_verifier_data) = &table_circuits[table]; + + let stark_proof = &all_proof.stark_proofs[table]; + let original_degree_bits = stark_proof.proof.recover_degree_bits(config); + + let shrunk_proof = table_circuit.shrink(stark_proof, &all_proof.ctl_challenges)?; + root_inputs.set_target( + self.root.index_verifier_data[table], + F::from_canonical_u8(*index_verifier_data), + ); + root_inputs.set_proof_with_pis_target(&self.root.proof_with_pis[table], &shrunk_proof); + + check_abort_signal(abort_signal.clone())?; + } + + root_inputs.set_verifier_data_target( + &self.root.cyclic_vk, + &self.aggregation.circuit.verifier_only, + ); + + set_public_value_targets( + &mut root_inputs, + &self.root.public_values, + &all_proof.public_values, + ) + .map_err(|_| { + anyhow::Error::msg("Invalid conversion when setting public values targets.") + })?; + + let root_proof = self.root.circuit.prove(root_inputs)?; + + Ok((root_proof, all_proof.public_values)) + } + pub fn verify_root(&self, agg_proof: ProofWithPublicInputs) -> anyhow::Result<()> { self.root.circuit.verify(agg_proof) } @@ -1255,7 +1254,7 @@ where { /// A map from `log_2(height)` to a chain of shrinking recursion circuits starting at that /// height. - by_stark_size: BTreeMap>, + pub by_stark_size: BTreeMap>, } impl RecursiveCircuitsForTable @@ -1321,32 +1320,6 @@ where Self { by_stark_size } } - fn expand>( - &mut self, - table: Table, - stark: &S, - degree_bits_range: Range, - all_ctls: &[CrossTableLookup], - stark_config: &StarkConfig, - ) { - let new_ranges = degree_bits_range - .filter(|degree_bits| !self.by_stark_size.contains_key(degree_bits)) - .collect_vec(); - - for degree_bits in new_ranges { - self.by_stark_size.insert( - degree_bits, - RecursiveCircuitsForTableSize::new::( - table, - stark, - degree_bits, - all_ctls, - stark_config, - ), - ); - } - } - /// For each initial `degree_bits`, get the final circuit at the end of that shrinking chain. /// Each of these final circuits should have degree `THRESHOLD_DEGREE_BITS`. fn final_circuits(&self) -> Vec<&CircuitData> { @@ -1366,7 +1339,7 @@ where /// A chain of shrinking wrapper circuits, ending with a final circuit with `degree_bits` /// `THRESHOLD_DEGREE_BITS`. #[derive(Eq, PartialEq, Debug)] -struct RecursiveCircuitsForTableSize +pub struct RecursiveCircuitsForTableSize where F: RichField + Extendable, C: GenericConfig, @@ -1382,7 +1355,7 @@ where C: GenericConfig, C::Hasher: AlgebraicHasher, { - fn to_buffer( + pub fn to_buffer( &self, buffer: &mut Vec, gate_serializer: &dyn GateSerializer, @@ -1409,7 +1382,7 @@ where Ok(()) } - fn from_buffer( + pub fn from_buffer( buffer: &mut Buffer, gate_serializer: &dyn GateSerializer, generator_serializer: &dyn WitnessGeneratorSerializer, @@ -1500,7 +1473,7 @@ where } } - fn shrink( + pub fn shrink( &self, stark_proof_with_metadata: &StarkProofWithMetadata, ctl_challenges: &GrandProductChallengeSet, diff --git a/evm/src/proof.rs b/evm/src/proof.rs index a5bf5756..3768f98f 100644 --- a/evm/src/proof.rs +++ b/evm/src/proof.rs @@ -270,7 +270,7 @@ impl ExtraBlockData { /// Memory values which are public. /// Note: All the larger integers are encoded with 32-bit limbs in little-endian order. #[derive(Eq, PartialEq, Debug)] -pub(crate) struct PublicValuesTarget { +pub struct PublicValuesTarget { /// Trie hashes before the execution of the local state transition. pub trie_roots_before: TrieRootsTarget, /// Trie hashes after the execution of the local state transition. @@ -485,7 +485,7 @@ impl PublicValuesTarget { /// Circuit version of `TrieRoots`. /// `Target`s for trie hashes. Since a `Target` holds a 32-bit limb, each hash requires 8 `Target`s. #[derive(Eq, PartialEq, Debug, Copy, Clone)] -pub(crate) struct TrieRootsTarget { +pub struct TrieRootsTarget { /// Targets for the state trie hash. pub(crate) state_root: [Target; 8], /// Targets for the transactions trie hash. @@ -556,7 +556,7 @@ impl TrieRootsTarget { /// Metadata contained in a block header. Those are identical between /// all state transition proofs within the same block. #[derive(Eq, PartialEq, Debug, Copy, Clone)] -pub(crate) struct BlockMetadataTarget { +pub struct BlockMetadataTarget { /// `Target`s for the address of this block's producer. pub(crate) block_beneficiary: [Target; 5], /// `Target` for the timestamp of this block. @@ -681,7 +681,7 @@ impl BlockMetadataTarget { /// When the block number is less than 256, dummy values, i.e. `H256::default()`, /// should be used for the additional block hashes. #[derive(Eq, PartialEq, Debug, Copy, Clone)] -pub(crate) struct BlockHashesTarget { +pub struct BlockHashesTarget { /// `Target`s for the previous 256 hashes to the current block. The leftmost hash, i.e. `prev_hashes[0..8]`, /// is the oldest, and the rightmost, i.e. `prev_hashes[255 * 7..255 * 8]` is the hash of the parent block. pub(crate) prev_hashes: [Target; 2048], @@ -739,7 +739,7 @@ impl BlockHashesTarget { /// Additional block data that are specific to the local transaction being proven, /// unlike `BlockMetadata`. #[derive(Eq, PartialEq, Debug, Copy, Clone)] -pub(crate) struct ExtraBlockDataTarget { +pub struct ExtraBlockDataTarget { /// `Target`s for the state trie digest of the checkpoint block. pub checkpoint_state_trie_root: [Target; 8], /// `Target` for the transaction count prior execution of the local state transition, starting diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 4d2daaf5..c90490b8 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -668,7 +668,7 @@ where /// Utility method that checks whether a kill signal has been emitted by one of the workers, /// which will result in an early abort for all the other processes involved in the same set /// of transactions. -pub(crate) fn check_abort_signal(abort_signal: Option>) -> Result<()> { +pub fn check_abort_signal(abort_signal: Option>) -> Result<()> { if let Some(signal) = abort_signal { if signal.load(Ordering::Relaxed) { return Err(anyhow!("Stopping job from abort signal.")); diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index 3103dd49..633a8d33 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -803,7 +803,7 @@ pub(crate) fn set_stark_proof_target, W, const D: set_fri_proof_target(witness, &proof_target.opening_proof, &proof.opening_proof); } -pub(crate) fn set_public_value_targets( +pub fn set_public_value_targets( witness: &mut W, public_values_target: &PublicValuesTarget, public_values: &PublicValues, diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index 16486677..5904b8a9 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -75,11 +75,9 @@ fn test_empty_txn_list() -> anyhow::Result<()> { }; // Initialize the preprocessed circuits for the zkEVM. - // The provided ranges are the minimal ones to prove an empty list, except the one of the CPU - // that is wrong for testing purposes, see below. - let mut all_circuits = AllRecursiveCircuits::::new( + let all_circuits = AllRecursiveCircuits::::new( &all_stark, - &[16..17, 10..11, 11..12, 14..15, 9..11, 12..13, 17..18], // Minimal ranges to prove an empty list + &[16..17, 10..11, 12..13, 14..15, 9..11, 12..13, 17..18], // Minimal ranges to prove an empty list &config, ); @@ -91,7 +89,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let timing = TimingTree::new("serialize AllRecursiveCircuits", log::Level::Info); let all_circuits_bytes = all_circuits - .to_bytes(&gate_serializer, &generator_serializer) + .to_bytes(false, &gate_serializer, &generator_serializer) .map_err(|_| anyhow::Error::msg("AllRecursiveCircuits serialization failed."))?; timing.filter(Duration::from_millis(100)).print(); info!( @@ -102,6 +100,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let timing = TimingTree::new("deserialize AllRecursiveCircuits", log::Level::Info); let all_circuits_from_bytes = AllRecursiveCircuits::::from_bytes( &all_circuits_bytes, + false, &gate_serializer, &generator_serializer, ) @@ -111,20 +110,6 @@ fn test_empty_txn_list() -> anyhow::Result<()> { assert_eq!(all_circuits, all_circuits_from_bytes); } - let mut timing = TimingTree::new("prove", log::Level::Info); - // We're missing some preprocessed circuits. - assert!(all_circuits - .prove_root(&all_stark, &config, inputs.clone(), &mut timing, None) - .is_err()); - - // Expand the preprocessed circuits. - // We pass an empty range if we don't want to add different table sizes. - all_circuits.expand( - &all_stark, - &[0..0, 0..0, 12..13, 0..0, 0..0, 0..0, 0..0], - &StarkConfig::standard_fast_config(), - ); - let mut timing = TimingTree::new("prove", log::Level::Info); let (root_proof, public_values) = all_circuits.prove_root(&all_stark, &config, inputs, &mut timing, None)?;