Eliminate nested simulations

This commit is contained in:
4l0n50 2023-12-15 17:11:00 +01:00
parent 2c5347c45f
commit 81f13f3f8a
6 changed files with 103 additions and 91 deletions

View File

@@ -114,15 +114,12 @@ global is_jumpdest:
MLOAD_GENERAL
// stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest
// Slightly more efficient than `%eq_const(0x5b) ISZERO`
PUSH 0x5b
SUB
%jumpi(panic)
%assert_eq_const(0x5b)
//stack: jumpdest, ctx, proof_prefix_addr, retdest
SWAP2 DUP1
// stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest
IS_ZERO
ISZERO
%jumpi(verify_path)
// stack: proof_prefix_addr, ctx, jumpdest, retdest
// If we are here we need to check that the next 32 bytes are less

View File

@@ -1,7 +1,7 @@
//! An EVM interpreter for testing and debugging purposes.
use core::cmp::Ordering;
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::ops::Range;
use anyhow::{anyhow, bail, ensure};
@@ -293,17 +293,19 @@ impl<'a> Interpreter<'a> {
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
.content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect();
self.generation_state.jumpdest_addresses = Some(
jumpdest_bits
.into_iter()
.enumerate()
.filter(|&(_, x)| x)
.map(|(i, _)| i)
.collect(),
)
self.generation_state.jumpdest_addresses = Some(HashMap::from([(
context,
BTreeSet::from_iter(
jumpdest_bits
.into_iter()
.enumerate()
.filter(|&(_, x)| x)
.map(|(i, _)| i),
),
)]));
}
const fn incr(&mut self, n: usize) {
pub(crate) fn incr(&mut self, n: usize) {
self.generation_state.registers.program_counter += n;
}

View File

@@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::collections::{BTreeSet, HashMap, HashSet};
use anyhow::anyhow;
use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie};
@@ -27,6 +27,7 @@ use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots};
use crate::util::{h2u, u256_to_u8, u256_to_usize};
use crate::witness::errors::{ProgramError, ProverInputError};
use crate::witness::memory::{MemoryAddress, MemoryChannel};
use crate::witness::transition::transition;
@@ -287,56 +288,68 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
initial_label: &str,
final_label: &str,
state: &mut GenerationState<F>,
) -> anyhow::Result<Vec<usize>> {
let halt_pc = KERNEL.global_labels[final_label];
let mut jumpdest_addresses = HashSet::new();
state.registers.program_counter = KERNEL.global_labels[initial_label];
let context = state.registers.context;
) -> Result<(), ProgramError> {
if let Some(_) = state.jumpdest_addresses {
Ok(())
} else {
const JUMP_OPCODE: u8 = 0x56;
const JUMPI_OPCODE: u8 = 0x57;
log::debug!("Simulating CPU for jumpdest analysis.");
let halt_pc = KERNEL.global_labels[final_label];
let mut jumpdest_addresses: HashMap<_, BTreeSet<usize>> = HashMap::new();
loop {
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"]
}
let pc = state.registers.program_counter;
let halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == context;
let opcode = u256_to_u8(state.memory.get(MemoryAddress {
context: state.registers.context,
segment: Segment::Code as usize,
virt: state.registers.program_counter,
}))
.map_err(|_| anyhow::Error::msg("Invalid opcode."))?;
let cond = if let Ok(cond) = stack_peek(state, 1) {
cond != U256::zero()
} else {
false
};
if !state.registers.is_kernel
&& (opcode == get_opcode("JUMP") || (opcode == get_opcode("JUMPI") && cond))
{
// TODO: hotfix for avoiding deeper calls to abort
let jumpdest = u256_to_usize(state.registers.stack_top)
.map_err(|_| anyhow!("Not a valid jump destination"))?;
state.memory.set(
MemoryAddress {
context: state.registers.context,
segment: Segment::JumpdestBits as usize,
virt: jumpdest,
},
U256::one(),
);
if (state.registers.context == context) {
jumpdest_addresses.insert(jumpdest);
state.registers.program_counter = KERNEL.global_labels[initial_label];
let initial_context = state.registers.context;
log::debug!("Simulating CPU for jumpdest analysis.");
loop {
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
state.registers.program_counter =
KERNEL.global_labels["validate_jumpdest_table_end"]
}
let pc = state.registers.program_counter;
let context = state.registers.context;
let halt = state.registers.is_kernel
&& pc == halt_pc
&& state.registers.context == initial_context;
let opcode = u256_to_u8(state.memory.get(MemoryAddress {
context,
segment: Segment::Code as usize,
virt: state.registers.program_counter,
}))?;
let cond = if let Ok(cond) = stack_peek(state, 1) {
cond != U256::zero()
} else {
false
};
if !state.registers.is_kernel
&& (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond))
{
// Avoid deeper calls to abort
let jumpdest = u256_to_usize(state.registers.stack_top)?;
state.memory.set(
MemoryAddress {
context,
segment: Segment::JumpdestBits as usize,
virt: jumpdest,
},
U256::one(),
);
if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) {
ctx_addresses.insert(jumpdest);
} else {
jumpdest_addresses.insert(context, BTreeSet::from([jumpdest]));
}
}
if halt {
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
state.jumpdest_addresses = Some(jumpdest_addresses);
return Ok(());
}
transition(state).map_err(|_| {
ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)
})?;
}
if halt {
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
let mut jumpdest_addresses: Vec<usize> = jumpdest_addresses.into_iter().collect();
jumpdest_addresses.sort();
return Ok(jumpdest_addresses);
}
transition(state)?;
}
}

View File

@@ -16,7 +16,6 @@ use serde::{Deserialize, Serialize};
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode};
use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254};
use crate::generation::prover_input::EvmField::{
Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar,
@@ -250,8 +249,9 @@ impl<F: Field> GenerationState<F> {
}
/// Return the next used jump address
fn run_next_jumpdest_table_address(&mut self) -> Result<U256, ProgramError> {
let context = self.registers.context;
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
context: self.registers.context,
context,
segment: Segment::ContextMetadata as usize,
virt: ContextMetadata::CodeSize as usize,
}))?;
@@ -260,14 +260,14 @@ impl<F: Field> GenerationState<F> {
self.generate_jumpdest_table()?;
}
let Some(jumpdest_table) = &mut self.jumpdest_addresses else {
// TODO: Add another error
let Some(jumpdest_tables) = &mut self.jumpdest_addresses else {
return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation));
};
if let Some(next_jumpdest_address) = jumpdest_table.pop() {
self.last_jumpdest_address = next_jumpdest_address;
Ok((next_jumpdest_address + 1).into())
if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last()
{
self.last_jumpdest_address = next_jumpdest_address;
Ok((next_jumpdest_address + 1).into())
} else {
self.jumpdest_addresses = None;
Ok(U256::zero())
@@ -293,6 +293,9 @@ impl<F: Field> GenerationState<F> {
// the previous 32 bytes in the code (including opcodes and pushed bytes)
// are PUSHXX and the address is in its range.
const PUSH1_OPCODE: u8 = 0x60;
const PUSH32_OPCODE: u8 = 0x7f;
let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold(
0,
|acc, (pos, opcode)| {
@@ -300,9 +303,9 @@ impl<F: Field> GenerationState<F> {
code[prefix_start..pos].iter().enumerate().fold(
true,
|acc, (prefix_pos, &byte)| {
acc && (byte > get_push_opcode(32)
acc && (byte > PUSH32_OPCODE
|| (prefix_start + prefix_pos) as i32
+ (byte as i32 - get_push_opcode(1) as i32)
+ (byte as i32 - PUSH1_OPCODE as i32)
+ 1
< pos as i32)
},
@@ -323,6 +326,7 @@ impl<F: Field> GenerationState<F> {
impl<F: Field> GenerationState<F> {
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> {
const JUMPDEST_OPCODE: u8 = 0x5b;
let mut state = self.soft_clone();
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
context: self.registers.context,
@@ -344,8 +348,8 @@ impl<F: Field> GenerationState<F> {
// the simulation will fail.
let mut jumpdest_table = Vec::with_capacity(code.len());
for (pos, opcode) in CodeIterator::new(&code) {
jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST")));
if opcode == get_opcode("JUMPDEST") {
jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE));
if opcode == JUMPDEST_OPCODE {
state.memory.set(
MemoryAddress {
context: state.registers.context,
@@ -358,32 +362,31 @@ impl<F: Field> GenerationState<F> {
}
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps(
simulate_cpu_between_labels_and_get_user_jumps(
"validate_jumpdest_table_end",
"terminate_common",
&mut state,
)
.ok();
)?;
self.jumpdest_addresses = state.jumpdest_addresses;
Ok(())
}
}
struct CodeIterator<'a> {
code: &'a Vec<u8>,
code: &'a [u8],
pos: usize,
end: usize,
}
impl<'a> CodeIterator<'a> {
fn new(code: &'a Vec<u8>) -> Self {
fn new(code: &'a [u8]) -> Self {
CodeIterator {
end: code.len(),
code,
pos: 0,
}
}
fn until(code: &'a Vec<u8>, end: usize) -> Self {
fn until(code: &'a [u8], end: usize) -> Self {
CodeIterator {
end: std::cmp::min(code.len(), end),
code,
@@ -396,14 +399,16 @@ impl<'a> Iterator for CodeIterator<'a> {
type Item = (usize, u8);
fn next(&mut self) -> Option<Self::Item> {
const PUSH1_OPCODE: u8 = 0x60;
const PUSH32_OPCODE: u8 = 0x70;
let CodeIterator { code, pos, end } = self;
if *pos >= *end {
return None;
}
let opcode = code[*pos];
let old_pos = *pos;
*pos += if opcode >= get_push_opcode(1) && opcode <= get_push_opcode(32) {
(opcode - get_push_opcode(1) + 2).into()
*pos += if opcode >= PUSH1_OPCODE && opcode <= PUSH32_OPCODE {
(opcode - PUSH1_OPCODE + 2).into()
} else {
1
};

View File

@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use ethereum_types::{Address, BigEndianHash, H160, H256, U256};
use keccak_hash::keccak;
@@ -52,7 +52,7 @@ pub(crate) struct GenerationState<F: Field> {
pub(crate) trie_root_ptrs: TrieRootPtrs,
pub(crate) last_jumpdest_address: usize,
pub(crate) jumpdest_addresses: Option<Vec<usize>>,
pub(crate) jumpdest_addresses: Option<HashMap<usize, BTreeSet<usize>>>,
}
impl<F: Field> GenerationState<F> {

View File

@@ -395,12 +395,7 @@ fn try_perform_instruction<F: Field>(
if state.registers.is_kernel {
log_kernel_instruction(state, op);
} else {
log::debug!(
"User instruction: {:?} ctx = {:?} stack = {:?}",
op,
state.registers.context,
state.stack()
);
log::debug!("User instruction: {:?}", op);
}
fill_op_flag(op, &mut row);