diff --git a/evm/src/fixed_recursive_verifier.rs b/evm/src/fixed_recursive_verifier.rs index 87f076c1..063f479a 100644 --- a/evm/src/fixed_recursive_verifier.rs +++ b/evm/src/fixed_recursive_verifier.rs @@ -1,6 +1,8 @@ use core::mem::{self, MaybeUninit}; use std::collections::BTreeMap; use std::ops::Range; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; use eth_trie_utils::partial_trie::{HashedPartialTrie, Node, PartialTrie}; use hashbrown::HashMap; @@ -39,7 +41,7 @@ use crate::proof::{ BlockHashesTarget, BlockMetadataTarget, ExtraBlockData, ExtraBlockDataTarget, PublicValues, PublicValuesTarget, StarkProofWithMetadata, TrieRoots, TrieRootsTarget, }; -use crate::prover::prove; +use crate::prover::{check_abort_signal, prove}; use crate::recursive_verifier::{ add_common_recursion_gates, add_virtual_public_values, get_memory_extra_looking_sum_circuit, recursive_stark_circuit, set_public_value_targets, PlonkWrapperCircuit, PublicInputs, @@ -967,8 +969,15 @@ where config: &StarkConfig, generation_inputs: GenerationInputs, timing: &mut TimingTree, + abort_signal: Option>, ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { - let all_proof = prove::(all_stark, config, generation_inputs, timing)?; + let all_proof = prove::( + all_stark, + config, + generation_inputs, + timing, + abort_signal.clone(), + )?; let mut root_inputs = PartialWitness::new(); for table in 0..NUM_TABLES { @@ -996,6 +1005,8 @@ where F::from_canonical_usize(index_verifier_data), ); root_inputs.set_proof_with_pis_target(&self.root.proof_with_pis[table], &shrunk_proof); + + check_abort_signal(abort_signal.clone())?; } root_inputs.set_verifier_data_target( diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 77b6fd36..f692ae39 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,4 +1,6 @@ use std::collections::HashMap; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; use anyhow::anyhow; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; @@ -23,6 +25,7 @@ use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; +use crate::prover::check_abort_signal; use crate::util::h2u; use crate::witness::memory::{MemoryAddress, MemoryChannel}; use crate::witness::transition::transition; diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 18ba3a5f..3e1bf192 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -765,6 +765,7 @@ mod tests { }, &mut Challenger::new(), &mut timing, + None, )?; timing.print(); diff --git a/evm/src/prover.rs b/evm/src/prover.rs index ab33a661..32989c8f 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -1,4 +1,7 @@ -use anyhow::{ensure, Result}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use anyhow::{anyhow, ensure, Result}; use itertools::Itertools; use once_cell::sync::Lazy; use plonky2::field::extension::Extendable; @@ -32,6 +35,7 @@ use crate::lookup::{lookup_helper_columns, Lookup, LookupCheckVars}; use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof, StarkProofWithMetadata}; use crate::stark::Stark; use crate::vanishing_poly::eval_vanishing_poly; +use crate::witness::errors::ProgramError; #[cfg(test)] use crate::{ cross_table_lookup::testutils::check_ctls, verifier::testutils::get_memory_extra_looking_values, @@ -43,6 +47,7 @@ pub fn prove( config: &StarkConfig, inputs: GenerationInputs, timing: &mut TimingTree, + abort_signal: Option>, ) -> Result> where F: RichField + Extendable, @@ -54,7 +59,16 @@ where "generate all traces", generate_traces(all_stark, inputs, config, timing)? ); - let proof = prove_with_traces(all_stark, config, traces, public_values, timing)?; + check_abort_signal(abort_signal.clone())?; + + let proof = prove_with_traces( + all_stark, + config, + traces, + public_values, + timing, + abort_signal, + )?; Ok(proof) } @@ -65,6 +79,7 @@ pub(crate) fn prove_with_traces( trace_poly_values: [Vec>; NUM_TABLES], public_values: PublicValues, timing: &mut TimingTree, + abort_signal: Option>, ) -> Result> where F: RichField + Extendable, @@ -136,7 +151,8 @@ where ctl_data_per_table, &mut challenger, &ctl_challenges, - timing + timing, + abort_signal, )? ); @@ -172,6 +188,7 @@ fn prove_with_commitments( challenger: &mut Challenger, ctl_challenges: &GrandProductChallengeSet, timing: &mut TimingTree, + abort_signal: Option>, ) -> Result<[StarkProofWithMetadata; NUM_TABLES]> where F: RichField + Extendable, @@ -189,6 +206,7 @@ where ctl_challenges, challenger, timing, + abort_signal.clone(), )? ); let byte_packing_proof = timed!( @@ -203,6 +221,7 @@ where ctl_challenges, challenger, timing, + abort_signal.clone(), )? ); let cpu_proof = timed!( @@ -217,6 +236,7 @@ where ctl_challenges, challenger, timing, + abort_signal.clone(), )? ); let keccak_proof = timed!( @@ -231,6 +251,7 @@ where ctl_challenges, challenger, timing, + abort_signal.clone(), )? ); let keccak_sponge_proof = timed!( @@ -245,6 +266,7 @@ where ctl_challenges, challenger, timing, + abort_signal.clone(), )? ); let logic_proof = timed!( @@ -259,6 +281,7 @@ where ctl_challenges, challenger, timing, + abort_signal.clone(), )? ); let memory_proof = timed!( @@ -273,6 +296,7 @@ where ctl_challenges, challenger, timing, + abort_signal, )? ); @@ -300,12 +324,15 @@ pub(crate) fn prove_single_table( ctl_challenges: &GrandProductChallengeSet, challenger: &mut Challenger, timing: &mut TimingTree, + abort_signal: Option>, ) -> Result> where F: RichField + Extendable, C: GenericConfig, S: Stark, { + check_abort_signal(abort_signal.clone())?; + let degree = trace_poly_values[0].len(); let degree_bits = log2_strict(degree); let fri_params = config.fri_params(degree_bits); @@ -392,6 +419,8 @@ where ); } + check_abort_signal(abort_signal.clone())?; + let quotient_polys = timed!( timing, "compute quotient polys", @@ -469,6 +498,8 @@ where "ient_commitment, ]; + check_abort_signal(abort_signal.clone())?; + let opening_proof = timed!( timing, "compute openings proof", @@ -636,6 +667,19 @@ where .collect() } +/// Utility method that checks whether a kill signal has been emitted by one of the workers, +/// which will result in an early abort for all the other processes involved in the same set +/// of transactions. +pub(crate) fn check_abort_signal(abort_signal: Option>) -> Result<()> { + if let Some(signal) = abort_signal { + if signal.load(Ordering::Relaxed) { + return Err(anyhow!("Stopping job from abort signal.")); + } + } + + Ok(()) +} + #[cfg(test)] /// Check that all constraints evaluate to zero on `H`. /// Can also be used to check the degree of the constraints by evaluating on a larger subgroup. diff --git a/evm/tests/add11_yml.rs b/evm/tests/add11_yml.rs index 040750fa..6a15dfc0 100644 --- a/evm/tests/add11_yml.rs +++ b/evm/tests/add11_yml.rs @@ -168,7 +168,7 @@ fn add11_yml() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); verify_proof(&all_stark, proof, &config) diff --git a/evm/tests/basic_smart_contract.rs b/evm/tests/basic_smart_contract.rs index 28db58ed..7d07ca19 100644 --- a/evm/tests/basic_smart_contract.rs +++ b/evm/tests/basic_smart_contract.rs @@ -200,7 +200,7 @@ fn test_basic_smart_contract() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); verify_proof(&all_stark, proof, &config) diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index 684a8f36..16486677 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -114,7 +114,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let mut timing = TimingTree::new("prove", log::Level::Info); // We're missing some preprocessed circuits. assert!(all_circuits - .prove_root(&all_stark, &config, inputs.clone(), &mut timing) + .prove_root(&all_stark, &config, inputs.clone(), &mut timing, None) .is_err()); // Expand the preprocessed circuits. @@ -127,7 +127,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let mut timing = TimingTree::new("prove", log::Level::Info); let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?; + all_circuits.prove_root(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); all_circuits.verify_root(root_proof.clone())?; diff --git a/evm/tests/erc20.rs b/evm/tests/erc20.rs index 9c4bfa83..48d0d753 100644 --- a/evm/tests/erc20.rs +++ b/evm/tests/erc20.rs @@ -176,7 +176,7 @@ fn test_erc20() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); verify_proof(&all_stark, proof, &config) diff --git a/evm/tests/log_opcode.rs b/evm/tests/log_opcode.rs index b00fe36f..157b9fe6 100644 --- a/evm/tests/log_opcode.rs +++ b/evm/tests/log_opcode.rs @@ -234,7 +234,7 @@ fn test_log_opcodes() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); // Assert that the proof leads to the correct state and receipt roots. @@ -450,7 +450,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { let mut timing = TimingTree::new("prove root first", log::Level::Info); let (root_proof_first, public_values_first) = - all_circuits.prove_root(&all_stark, &config, inputs_first, &mut timing)?; + all_circuits.prove_root(&all_stark, &config, inputs_first, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); all_circuits.verify_root(root_proof_first.clone())?; @@ -570,7 +570,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { let mut timing = TimingTree::new("prove root second", log::Level::Info); let (root_proof_second, public_values_second) = - all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?; + all_circuits.prove_root(&all_stark, &config, inputs, &mut timing, None.clone())?; timing.filter(Duration::from_millis(100)).print(); all_circuits.verify_root(root_proof_second.clone())?; @@ -635,7 +635,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { }; let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?; + all_circuits.prove_root(&all_stark, &config, inputs, &mut timing, None)?; all_circuits.verify_root(root_proof.clone())?; // We can just duplicate the initial proof as the state didn't change. diff --git a/evm/tests/self_balance_gas_cost.rs b/evm/tests/self_balance_gas_cost.rs index 7346dc24..538f2aa7 100644 --- a/evm/tests/self_balance_gas_cost.rs +++ b/evm/tests/self_balance_gas_cost.rs @@ -187,7 +187,7 @@ fn self_balance_gas_cost() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); verify_proof(&all_stark, proof, &config) diff --git a/evm/tests/selfdestruct.rs b/evm/tests/selfdestruct.rs index fb33b18f..829e0b21 100644 --- a/evm/tests/selfdestruct.rs +++ b/evm/tests/selfdestruct.rs @@ -139,7 +139,7 @@ fn test_selfdestruct() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); verify_proof(&all_stark, proof, &config) diff --git a/evm/tests/simple_transfer.rs b/evm/tests/simple_transfer.rs index ede18bf8..5fd252df 100644 --- a/evm/tests/simple_transfer.rs +++ b/evm/tests/simple_transfer.rs @@ -155,7 +155,7 @@ fn test_simple_transfer() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); verify_proof(&all_stark, proof, &config) diff --git a/evm/tests/withdrawals.rs b/evm/tests/withdrawals.rs index 29c95817..ef2d19b0 100644 --- a/evm/tests/withdrawals.rs +++ b/evm/tests/withdrawals.rs @@ -85,7 +85,7 @@ fn test_withdrawals() -> anyhow::Result<()> { }; let mut timing = TimingTree::new("prove", log::Level::Debug); - let proof = prove::(&all_stark, &config, inputs, &mut timing)?; + let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); verify_proof(&all_stark, proof, &config)