diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 177ac4f0..30f862cd 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -417,16 +417,17 @@ impl<'a> Interpreter<'a> { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); - self.generation_state.jumpdest_addresses = Some(HashMap::from([( - context, - BTreeSet::from_iter( - jumpdest_bits - .into_iter() - .enumerate() - .filter(|&(_, x)| x) - .map(|(i, _)| i), - ), - )])); + self.generation_state + .set_proofs_and_jumpdests(HashMap::from([( + context, + BTreeSet::from_iter( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i), + ), + )])); } pub(crate) fn incr(&mut self, n: usize) { diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 8f568d90..81bacb75 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -340,9 +340,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps( initial_label: &str, final_label: &str, state: &mut GenerationState, -) -> Result<(), ProgramError> { - if state.jumpdest_addresses.is_some() { - Ok(()) +) -> Result>>, ProgramError> { + if state.jumpdest_proofs.is_some() { + Ok(None) } else { const JUMP_OPCODE: u8 = 0x56; const JUMPI_OPCODE: u8 = 0x57; @@ -356,6 +356,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( log::debug!("Simulating CPU for jumpdest analysis."); loop { + // skip jumdest table validations in simulations if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] @@ -396,8 +397,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( } if halt { log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); - state.jumpdest_addresses = Some(jumpdest_addresses); - return Ok(()); + return Ok(Some(jumpdest_addresses)); } transition(state).map_err(|_| { ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 926b876d..a5f73ae6 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,11 +1,10 @@ use std::cmp::min; -use std::collections::HashSet; +use std::collections::HashMap; use std::mem::transmute; use std::str::FromStr; use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256, U512}; -use hashbrown::HashMap; use itertools::{enumerate, Itertools}; use num_bigint::BigUint; use plonky2::field::extension::Extendable; @@ -256,87 +255,98 @@ impl GenerationState { virt: ContextMetadata::CodeSize as usize, }))?; - if self.jumpdest_addresses.is_none() { - self.generate_jumpdest_table()?; + if self.jumpdest_proofs.is_none() { + self.generate_jumpdest_proofs()?; } - let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { + let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { return Err(ProgramError::ProverInputError( ProverInputError::InvalidJumpdestSimulation, )); }; - if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) - && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() + if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) + && let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop() { - self.last_jumpdest_address = next_jumpdest_address; Ok((next_jumpdest_address + 1).into()) } else { - self.jumpdest_addresses = None; + self.jumpdest_proofs = None; Ok(U256::zero()) } } /// Returns the proof for the last jump address. fn run_next_jumpdest_table_proof(&mut self) -> Result { - let code = (0..self.last_jumpdest_address) - .map(|i| { - u256_to_u8(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::Code as usize, - virt: i, - })) - }) - .collect::, _>>()?; - - // TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem - // a problem as is done natively. - - // Search the closest address to `last_jumpdest_address` for which none of - // the previous 32 bytes in the code (including opcodes and pushed bytes) - // are PUSHXX and the address is in its range. - - const PUSH1_OPCODE: u8 = 0x60; - const PUSH32_OPCODE: u8 = 0x7f; - - let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( - 0, - |acc, (pos, opcode)| { - let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { - code[prefix_start..pos].iter().enumerate().fold( - true, - |acc, (prefix_pos, &byte)| { - acc && (byte > PUSH32_OPCODE - || (prefix_start + prefix_pos) as i32 - + (byte as i32 - PUSH1_OPCODE as i32) - + 1 - < pos as i32) - }, - ) - } else { - false - }; - if has_prefix { - pos - 32 - } else { - acc - } - }, - ); - Ok(proof.into()) + let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )); + }; + if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) + && let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop() + { + Ok(next_jumpdest_proof.into()) + } else { + Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )) + } } } impl GenerationState { - fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { - const JUMPDEST_OPCODE: u8 = 0x5b; - let mut state = self.soft_clone(); - let code_len = u256_to_usize(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::ContextMetadata as usize, - virt: ContextMetadata::CodeSize as usize, - }))?; - // Generate the jumpdest table + fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> { + let checkpoint = self.checkpoint(); + let memory = self.memory.clone(); + + let code = self.get_current_code()?; + // We need to set the simulated jumpdest bits to one as otherwise + // the simulation will fail. + self.set_jumpdest_bits(&code); + + // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call + let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( + "validate_jumpdest_table_end", + "terminate_common", + self, + )? + else { + return Ok(()); + }; + + // Return to the state before starting the simulation + self.rollback(checkpoint); + self.memory = memory; + + // Find proofs for all context + self.set_proofs_and_jumpdests(jumpdest_table); + + Ok(()) + } + + pub(crate) fn set_proofs_and_jumpdests( + &mut self, + jumpdest_table: HashMap>, + ) { + self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map( + |(ctx, jumpdest_table)| { + let code = self.get_code(ctx).unwrap(); + if let Some(&largest_address) = jumpdest_table.last() { + let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table); + (ctx, proofs) + } else { + (ctx, vec![]) + } + }, + ))); + } + + fn get_current_code(&self) -> Result, ProgramError> { + self.get_code(self.registers.context) + } + + fn get_code(&self, context: usize) -> Result, ProgramError> { + let code_len = self.get_code_len()?; let code = (0..code_len) .map(|i| { u256_to_u8(self.memory.get(MemoryAddress { @@ -346,16 +356,25 @@ impl GenerationState { })) }) .collect::, _>>()?; + Ok(code) + } - // We need to set the simulated jumpdest bits to one as otherwise - // the simulation will fail. - let mut jumpdest_table = Vec::with_capacity(code.len()); + fn get_code_len(&self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + Ok(code_len) + } + + fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec) { + const JUMPDEST_OPCODE: u8 = 0x5b; for (pos, opcode) in CodeIterator::new(&code) { - jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE)); if opcode == JUMPDEST_OPCODE { - state.memory.set( + self.memory.set( MemoryAddress { - context: state.registers.context, + context: self.registers.context, segment: Segment::JumpdestBits as usize, virt: pos, }, @@ -363,18 +382,50 @@ impl GenerationState { ); } } - - // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call - simulate_cpu_between_labels_and_get_user_jumps( - "validate_jumpdest_table_end", - "terminate_common", - &mut state, - )?; - self.jumpdest_addresses = state.jumpdest_addresses; - Ok(()) } } +/// For each address in `jumpdest_table` it search a proof, that is the closest address +/// for which none of the previous 32 bytes in the code (including opcodes +/// and pushed bytes are PUSHXX and the address is in its range. It returns +/// a vector of even size containing proofs followed by their addresses +fn get_proofs_and_jumpdests<'a>( + code: &'a Vec, + largest_address: usize, + jumpdest_table: std::collections::BTreeSet, +) -> Vec { + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x7f; + let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold( + (vec![], 0), + |(mut proofs, acc), (pos, opcode)| { + let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { + code[prefix_start..pos] + .iter() + .enumerate() + .fold(true, |acc, (prefix_pos, &byte)| { + acc && (byte > PUSH32_OPCODE + || (prefix_start + prefix_pos) as i32 + + (byte as i32 - PUSH1_OPCODE as i32) + + 1 + < pos as i32) + }) + } else { + false + }; + let acc = if has_prefix { pos - 32 } else { acc }; + if jumpdest_table.contains(&pos) { + // Push the proof + proofs.push(acc); + // Push the address + proofs.push(pos); + } + (proofs, acc) + }, + ); + proofs +} + struct CodeIterator<'a> { code: &'a [u8], pos: usize, diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 1c50cc29..cc1df091 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -51,8 +51,7 @@ pub(crate) struct GenerationState { /// Pointers, within the `TrieData` segment, of the three MPTs. pub(crate) trie_root_ptrs: TrieRootPtrs, - pub(crate) last_jumpdest_address: usize, - pub(crate) jumpdest_addresses: Option>>, + pub(crate) jumpdest_proofs: Option>>, } impl GenerationState { @@ -94,8 +93,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, - last_jumpdest_address: 0, - jumpdest_addresses: None, + jumpdest_proofs: None, }; let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); @@ -189,8 +187,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, - last_jumpdest_address: 0, - jumpdest_addresses: None, + jumpdest_proofs: None, } } }