From 5acabad72d31d244cb68a4214e473b9616b03e59 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 17:11:00 +0100 Subject: [PATCH] Eliminate nested simulations --- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 7 +- evm/src/cpu/kernel/interpreter.rs | 22 ++-- evm/src/generation/mod.rs | 109 ++++++++++-------- evm/src/generation/prover_input.rs | 45 ++++---- evm/src/generation/state.rs | 4 +- evm/src/witness/transition.rs | 7 +- 6 files changed, 103 insertions(+), 91 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 79475b37..cfc3575b 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -114,15 +114,12 @@ global is_jumpdest: MLOAD_GENERAL // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest - // Slightly more efficient than `%eq_const(0x5b) ISZERO` - PUSH 0x5b - SUB - %jumpi(panic) + %assert_eq_const(0x5b) //stack: jumpdest, ctx, proof_prefix_addr, retdest SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest - IS_ZERO + ISZERO %jumpi(verify_path) // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are less diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 5645045c..177ac4f0 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -1,7 +1,7 @@ //! An EVM interpreter for testing and debugging purposes. 
use core::cmp::Ordering; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::ops::Range; use anyhow::bail; @@ -417,17 +417,19 @@ impl<'a> Interpreter<'a> { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); - self.generation_state.jumpdest_addresses = Some( - jumpdest_bits - .into_iter() - .enumerate() - .filter(|&(_, x)| x) - .map(|(i, _)| i) - .collect(), - ) + self.generation_state.jumpdest_addresses = Some(HashMap::from([( + context, + BTreeSet::from_iter( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i), + ), + )])); } - const fn incr(&mut self, n: usize) { + pub(crate) fn incr(&mut self, n: usize) { self.generation_state.registers.program_counter += n; } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index b6146260..1919b40d 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -31,6 +31,7 @@ use crate::memory::segments::Segment; use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use crate::prover::check_abort_signal; use crate::util::{h2u, u256_to_u8, u256_to_usize}; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::{MemoryAddress, MemoryChannel}; use crate::witness::transition::transition; @@ -339,56 +340,68 @@ fn simulate_cpu_between_labels_and_get_user_jumps( initial_label: &str, final_label: &str, state: &mut GenerationState, -) -> anyhow::Result> { - let halt_pc = KERNEL.global_labels[final_label]; - let mut jumpdest_addresses = HashSet::new(); - state.registers.program_counter = 
KERNEL.global_labels[initial_label]; - let context = state.registers.context; +) -> Result<(), ProgramError> { + if let Some(_) = state.jumpdest_addresses { + Ok(()) + } else { + const JUMP_OPCODE: u8 = 0x56; + const JUMPI_OPCODE: u8 = 0x57; - log::debug!("Simulating CPU for jumpdest analysis."); + let halt_pc = KERNEL.global_labels[final_label]; + let mut jumpdest_addresses: HashMap<_, BTreeSet> = HashMap::new(); - loop { - if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { - state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] - } - let pc = state.registers.program_counter; - let halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == context; - let opcode = u256_to_u8(state.memory.get(MemoryAddress { - context: state.registers.context, - segment: Segment::Code as usize, - virt: state.registers.program_counter, - })) - .map_err(|_| anyhow::Error::msg("Invalid opcode."))?; - let cond = if let Ok(cond) = stack_peek(state, 1) { - cond != U256::zero() - } else { - false - }; - if !state.registers.is_kernel - && (opcode == get_opcode("JUMP") || (opcode == get_opcode("JUMPI") && cond)) - { - // TODO: hotfix for avoiding deeper calls to abort - let jumpdest = u256_to_usize(state.registers.stack_top) - .map_err(|_| anyhow!("Not a valid jump destination"))?; - state.memory.set( - MemoryAddress { - context: state.registers.context, - segment: Segment::JumpdestBits as usize, - virt: jumpdest, - }, - U256::one(), - ); - if (state.registers.context == context) { - jumpdest_addresses.insert(jumpdest); + state.registers.program_counter = KERNEL.global_labels[initial_label]; + let initial_context = state.registers.context; + + log::debug!("Simulating CPU for jumpdest analysis."); + + loop { + if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { + state.registers.program_counter = + KERNEL.global_labels["validate_jumpdest_table_end"] } + let pc = 
state.registers.program_counter; + let context = state.registers.context; + let halt = state.registers.is_kernel + && pc == halt_pc + && state.registers.context == initial_context; + let opcode = u256_to_u8(state.memory.get(MemoryAddress { + context, + segment: Segment::Code as usize, + virt: state.registers.program_counter, + }))?; + let cond = if let Ok(cond) = stack_peek(state, 1) { + cond != U256::zero() + } else { + false + }; + if !state.registers.is_kernel + && (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond)) + { + // Avoid deeper calls to abort + let jumpdest = u256_to_usize(state.registers.stack_top)?; + state.memory.set( + MemoryAddress { + context, + segment: Segment::JumpdestBits as usize, + virt: jumpdest, + }, + U256::one(), + ); + if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) { + ctx_addresses.insert(jumpdest); + } else { + jumpdest_addresses.insert(context, BTreeSet::from([jumpdest])); + } + } + if halt { + log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); + state.jumpdest_addresses = Some(jumpdest_addresses); + return Ok(()); + } + transition(state).map_err(|_| { + ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) + })?; } - if halt { - log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); - let mut jumpdest_addresses: Vec = jumpdest_addresses.into_iter().collect(); - jumpdest_addresses.sort(); - return Ok(jumpdest_addresses); - } - - transition(state)?; } } diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 1808f3f3..35571dcb 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -16,7 +16,6 @@ use serde::{Deserialize, Serialize}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cpu::kernel::opcodes::{get_opcode, 
get_push_opcode}; use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254}; use crate::generation::prover_input::EvmField::{ Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, @@ -250,8 +249,9 @@ impl GenerationState { } /// Return the next used jump addres fn run_next_jumpdest_table_address(&mut self) -> Result { + let context = self.registers.context; let code_len = u256_to_usize(self.memory.get(MemoryAddress { - context: self.registers.context, + context, segment: Segment::ContextMetadata as usize, virt: ContextMetadata::CodeSize as usize, }))?; @@ -260,14 +260,14 @@ impl GenerationState { self.generate_jumpdest_table()?; } - let Some(jumpdest_table) = &mut self.jumpdest_addresses else { - // TODO: Add another error + let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)); }; - if let Some(next_jumpdest_address) = jumpdest_table.pop() { - self.last_jumpdest_address = next_jumpdest_address; - Ok((next_jumpdest_address + 1).into()) + if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() + { + self.last_jumpdest_address = next_jumpdest_address; + Ok((next_jumpdest_address + 1).into()) } else { self.jumpdest_addresses = None; Ok(U256::zero()) @@ -293,6 +293,9 @@ impl GenerationState { // the previous 32 bytes in the code (including opcodes and pushed bytes) // are PUSHXX and the address is in its range. 
+ const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x7f; + let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( 0, |acc, (pos, opcode)| { @@ -300,9 +303,9 @@ impl GenerationState { code[prefix_start..pos].iter().enumerate().fold( true, |acc, (prefix_pos, &byte)| { - acc && (byte > get_push_opcode(32) + acc && (byte > PUSH32_OPCODE || (prefix_start + prefix_pos) as i32 - + (byte as i32 - get_push_opcode(1) as i32) + + (byte as i32 - PUSH1_OPCODE as i32) + 1 < pos as i32) }, @@ -323,6 +326,7 @@ impl GenerationState { impl GenerationState { fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { + const JUMPDEST_OPCODE: u8 = 0x5b; let mut state = self.soft_clone(); let code_len = u256_to_usize(self.memory.get(MemoryAddress { context: self.registers.context, @@ -344,8 +348,8 @@ impl GenerationState { // the simulation will fail. let mut jumpdest_table = Vec::with_capacity(code.len()); for (pos, opcode) in CodeIterator::new(&code) { - jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST"))); - if opcode == get_opcode("JUMPDEST") { + jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE)); + if opcode == JUMPDEST_OPCODE { state.memory.set( MemoryAddress { context: state.registers.context, @@ -358,32 +362,31 @@ impl GenerationState { } // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call - self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( + simulate_cpu_between_labels_and_get_user_jumps( "validate_jumpdest_table_end", "terminate_common", &mut state, - ) - .ok(); - + )?; + self.jumpdest_addresses = state.jumpdest_addresses; Ok(()) } } struct CodeIterator<'a> { - code: &'a Vec, + code: &'a [u8], pos: usize, end: usize, } impl<'a> CodeIterator<'a> { - fn new(code: &'a Vec) -> Self { + fn new(code: &'a [u8]) -> Self { CodeIterator { end: code.len(), code, pos: 0, } } - fn until(code: &'a Vec, end: usize) -> Self { + fn until(code: &'a [u8], end: usize) 
-> Self { CodeIterator { end: std::cmp::min(code.len(), end), code, @@ -396,14 +399,16 @@ impl<'a> Iterator for CodeIterator<'a> { type Item = (usize, u8); fn next(&mut self) -> Option { + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x7f; let CodeIterator { code, pos, end } = self; if *pos >= *end { return None; } let opcode = code[*pos]; let old_pos = *pos; - *pos += if opcode >= get_push_opcode(1) && opcode <= get_push_opcode(32) { - (opcode - get_push_opcode(1) + 2).into() + *pos += if opcode >= PUSH1_OPCODE && opcode <= PUSH32_OPCODE { + (opcode - PUSH1_OPCODE + 2).into() } else { 1 }; diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 79dd94fb..de07c942 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use ethereum_types::{Address, BigEndianHash, H160, H256, U256}; use keccak_hash::keccak; @@ -52,7 +52,7 @@ pub(crate) struct GenerationState { pub(crate) trie_root_ptrs: TrieRootPtrs, pub(crate) last_jumpdest_address: usize, - pub(crate) jumpdest_addresses: Option>, + pub(crate) jumpdest_addresses: Option>>, } impl GenerationState { diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 04688543..cf2e3bbe 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,12 +395,7 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!( - "User instruction: {:?} ctx = {:?} stack = {:?}", - op, - state.registers.context, - state.stack() - ); + log::debug!("User instruction: {:?}", op); } fill_op_flag(op, &mut row);