Add aborting signal (#1429)

* Add aborting signal

* Clippy

* Update to Option following comment
This commit is contained in:
Robin Salen 2023-12-15 19:35:27 +01:00 committed by GitHub
parent fdd7ee46fe
commit a64311cfd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 77 additions and 18 deletions

View File

@ -1,6 +1,8 @@
use core::mem::{self, MaybeUninit};
use std::collections::BTreeMap;
use std::ops::Range;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use eth_trie_utils::partial_trie::{HashedPartialTrie, Node, PartialTrie};
use hashbrown::HashMap;
@ -39,7 +41,7 @@ use crate::proof::{
BlockHashesTarget, BlockMetadataTarget, ExtraBlockData, ExtraBlockDataTarget, PublicValues,
PublicValuesTarget, StarkProofWithMetadata, TrieRoots, TrieRootsTarget,
};
use crate::prover::prove;
use crate::prover::{check_abort_signal, prove};
use crate::recursive_verifier::{
add_common_recursion_gates, add_virtual_public_values, get_memory_extra_looking_sum_circuit,
recursive_stark_circuit, set_public_value_targets, PlonkWrapperCircuit, PublicInputs,
@ -967,8 +969,15 @@ where
config: &StarkConfig,
generation_inputs: GenerationInputs,
timing: &mut TimingTree,
abort_signal: Option<Arc<AtomicBool>>,
) -> anyhow::Result<(ProofWithPublicInputs<F, C, D>, PublicValues)> {
let all_proof = prove::<F, C, D>(all_stark, config, generation_inputs, timing)?;
let all_proof = prove::<F, C, D>(
all_stark,
config,
generation_inputs,
timing,
abort_signal.clone(),
)?;
let mut root_inputs = PartialWitness::new();
for table in 0..NUM_TABLES {
@ -996,6 +1005,8 @@ where
F::from_canonical_usize(index_verifier_data),
);
root_inputs.set_proof_with_pis_target(&self.root.proof_with_pis[table], &shrunk_proof);
check_abort_signal(abort_signal.clone())?;
}
root_inputs.set_verifier_data_target(

View File

@ -1,4 +1,6 @@
use std::collections::HashMap;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use anyhow::anyhow;
use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie};
@ -23,6 +25,7 @@ use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots};
use crate::prover::check_abort_signal;
use crate::util::h2u;
use crate::witness::memory::{MemoryAddress, MemoryChannel};
use crate::witness::transition::transition;

View File

@ -765,6 +765,7 @@ mod tests {
},
&mut Challenger::new(),
&mut timing,
None,
)?;
timing.print();

View File

@ -1,4 +1,7 @@
use anyhow::{ensure, Result};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use anyhow::{anyhow, ensure, Result};
use itertools::Itertools;
use once_cell::sync::Lazy;
use plonky2::field::extension::Extendable;
@ -32,6 +35,7 @@ use crate::lookup::{lookup_helper_columns, Lookup, LookupCheckVars};
use crate::proof::{AllProof, PublicValues, StarkOpeningSet, StarkProof, StarkProofWithMetadata};
use crate::stark::Stark;
use crate::vanishing_poly::eval_vanishing_poly;
use crate::witness::errors::ProgramError;
#[cfg(test)]
use crate::{
cross_table_lookup::testutils::check_ctls, verifier::testutils::get_memory_extra_looking_values,
@ -43,6 +47,7 @@ pub fn prove<F, C, const D: usize>(
config: &StarkConfig,
inputs: GenerationInputs,
timing: &mut TimingTree,
abort_signal: Option<Arc<AtomicBool>>,
) -> Result<AllProof<F, C, D>>
where
F: RichField + Extendable<D>,
@ -54,7 +59,16 @@ where
"generate all traces",
generate_traces(all_stark, inputs, config, timing)?
);
let proof = prove_with_traces(all_stark, config, traces, public_values, timing)?;
check_abort_signal(abort_signal.clone())?;
let proof = prove_with_traces(
all_stark,
config,
traces,
public_values,
timing,
abort_signal,
)?;
Ok(proof)
}
@ -65,6 +79,7 @@ pub(crate) fn prove_with_traces<F, C, const D: usize>(
trace_poly_values: [Vec<PolynomialValues<F>>; NUM_TABLES],
public_values: PublicValues,
timing: &mut TimingTree,
abort_signal: Option<Arc<AtomicBool>>,
) -> Result<AllProof<F, C, D>>
where
F: RichField + Extendable<D>,
@ -136,7 +151,8 @@ where
ctl_data_per_table,
&mut challenger,
&ctl_challenges,
timing
timing,
abort_signal,
)?
);
@ -172,6 +188,7 @@ fn prove_with_commitments<F, C, const D: usize>(
challenger: &mut Challenger<F, C::Hasher>,
ctl_challenges: &GrandProductChallengeSet<F>,
timing: &mut TimingTree,
abort_signal: Option<Arc<AtomicBool>>,
) -> Result<[StarkProofWithMetadata<F, C, D>; NUM_TABLES]>
where
F: RichField + Extendable<D>,
@ -189,6 +206,7 @@ where
ctl_challenges,
challenger,
timing,
abort_signal.clone(),
)?
);
let byte_packing_proof = timed!(
@ -203,6 +221,7 @@ where
ctl_challenges,
challenger,
timing,
abort_signal.clone(),
)?
);
let cpu_proof = timed!(
@ -217,6 +236,7 @@ where
ctl_challenges,
challenger,
timing,
abort_signal.clone(),
)?
);
let keccak_proof = timed!(
@ -231,6 +251,7 @@ where
ctl_challenges,
challenger,
timing,
abort_signal.clone(),
)?
);
let keccak_sponge_proof = timed!(
@ -245,6 +266,7 @@ where
ctl_challenges,
challenger,
timing,
abort_signal.clone(),
)?
);
let logic_proof = timed!(
@ -259,6 +281,7 @@ where
ctl_challenges,
challenger,
timing,
abort_signal.clone(),
)?
);
let memory_proof = timed!(
@ -273,6 +296,7 @@ where
ctl_challenges,
challenger,
timing,
abort_signal,
)?
);
@ -300,12 +324,15 @@ pub(crate) fn prove_single_table<F, C, S, const D: usize>(
ctl_challenges: &GrandProductChallengeSet<F>,
challenger: &mut Challenger<F, C::Hasher>,
timing: &mut TimingTree,
abort_signal: Option<Arc<AtomicBool>>,
) -> Result<StarkProofWithMetadata<F, C, D>>
where
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
S: Stark<F, D>,
{
check_abort_signal(abort_signal.clone())?;
let degree = trace_poly_values[0].len();
let degree_bits = log2_strict(degree);
let fri_params = config.fri_params(degree_bits);
@ -392,6 +419,8 @@ where
);
}
check_abort_signal(abort_signal.clone())?;
let quotient_polys = timed!(
timing,
"compute quotient polys",
@ -469,6 +498,8 @@ where
&quotient_commitment,
];
check_abort_signal(abort_signal.clone())?;
let opening_proof = timed!(
timing,
"compute openings proof",
@ -636,6 +667,19 @@ where
.collect()
}
/// Utility method that checks whether a kill signal has been emitted by one of the workers,
/// which will result in an early abort for all the other processes involved in the same set
/// of transactions.
pub(crate) fn check_abort_signal(abort_signal: Option<Arc<AtomicBool>>) -> Result<()> {
if let Some(signal) = abort_signal {
if signal.load(Ordering::Relaxed) {
return Err(anyhow!("Stopping job from abort signal."));
}
}
Ok(())
}
#[cfg(test)]
/// Check that all constraints evaluate to zero on `H`.
/// Can also be used to check the degree of the constraints by evaluating on a larger subgroup.

View File

@ -168,7 +168,7 @@ fn add11_yml() -> anyhow::Result<()> {
};
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
verify_proof(&all_stark, proof, &config)

View File

@ -200,7 +200,7 @@ fn test_basic_smart_contract() -> anyhow::Result<()> {
};
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
verify_proof(&all_stark, proof, &config)

View File

@ -114,7 +114,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
let mut timing = TimingTree::new("prove", log::Level::Info);
// We're missing some preprocessed circuits.
assert!(all_circuits
.prove_root(&all_stark, &config, inputs.clone(), &mut timing)
.prove_root(&all_stark, &config, inputs.clone(), &mut timing, None)
.is_err());
// Expand the preprocessed circuits.
@ -127,7 +127,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
let mut timing = TimingTree::new("prove", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?;
all_circuits.prove_root(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_root(root_proof.clone())?;

View File

@ -176,7 +176,7 @@ fn test_erc20() -> anyhow::Result<()> {
};
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
verify_proof(&all_stark, proof, &config)

View File

@ -234,7 +234,7 @@ fn test_log_opcodes() -> anyhow::Result<()> {
};
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
// Assert that the proof leads to the correct state and receipt roots.
@ -450,7 +450,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> {
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (root_proof_first, public_values_first) =
all_circuits.prove_root(&all_stark, &config, inputs_first, &mut timing)?;
all_circuits.prove_root(&all_stark, &config, inputs_first, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_root(root_proof_first.clone())?;
@ -570,7 +570,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> {
let mut timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof_second, public_values_second) =
all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?;
all_circuits.prove_root(&all_stark, &config, inputs, &mut timing, None.clone())?;
timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_root(root_proof_second.clone())?;
@ -635,7 +635,7 @@ fn test_log_with_aggreg() -> anyhow::Result<()> {
};
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &config, inputs, &mut timing)?;
all_circuits.prove_root(&all_stark, &config, inputs, &mut timing, None)?;
all_circuits.verify_root(root_proof.clone())?;
// We can just duplicate the initial proof as the state didn't change.

View File

@ -187,7 +187,7 @@ fn self_balance_gas_cost() -> anyhow::Result<()> {
};
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
verify_proof(&all_stark, proof, &config)

View File

@ -139,7 +139,7 @@ fn test_selfdestruct() -> anyhow::Result<()> {
};
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
verify_proof(&all_stark, proof, &config)

View File

@ -155,7 +155,7 @@ fn test_simple_transfer() -> anyhow::Result<()> {
};
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
verify_proof(&all_stark, proof, &config)

View File

@ -85,7 +85,7 @@ fn test_withdrawals() -> anyhow::Result<()> {
};
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing)?;
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();
verify_proof(&all_stark, proof, &config)