diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index fcb4eb32..91204cc9 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -367,9 +367,10 @@ call_too_deep: %checkpoint // Checkpoint %increment_call_depth // Perform jumpdest analyis + %mload_context_metadata(@CTX_METADATA_CODE_SIZE) GET_CONTEXT // stack: ctx, code_size, retdest - %jumpdest_analisys + %jumpdest_analysis PUSH 0 // jump dest EXIT_KERNEL // (Old context) stack: new_ctx diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 97224b3e..8debead3 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -102,7 +102,7 @@ code_bytes_to_skip: // - we can go from opcode i+32 to jumpdest, // - code[jumpdest] = 0x5b. // stack: proof_prefix_addr, jumpdest, ctx, retdest -// stack: (empty) abort if jumpdest is not a valid destination +// stack: (empty) global write_table_if_jumpdest: // stack: proof_prefix_addr, jumpdest, ctx, retdest %stack @@ -129,7 +129,7 @@ global write_table_if_jumpdest: %check_and_step(111) %check_and_step(110) %check_and_step(109) %check_and_step(108) %check_and_step(107) %check_and_step(106) %check_and_step(105) %check_and_step(104) %check_and_step(103) %check_and_step(102) %check_and_step(101) %check_and_step(100) - %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) + %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) // check the remaining path %jump(verify_path_and_write_table) @@ -173,9 +173,9 @@ return: // an the jumpdest address below. If that's the case we set the // corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. // -// stack: ctx, retdest +// stack: ctx, code_len, retdest // stack: (empty) -global jumpdest_analisys: +global jumpdest_analysis: // If address > 0 then address is interpreted as address' + 1 // and the next prover input should contain a proof for address'. PROVER_INPUT(jumpdest_table::next_address) @@ -185,22 +185,26 @@ global jumpdest_analisys: // This is just a hook used for avoiding verification of the jumpdest // table in another contexts. It is useful during proof generation, // allowing the avoidance of table verification when simulating user code. -global jumpdest_analisys_end: - POP +global jumpdest_analysis_end: + %pop2 JUMP check_proof: + // stack: proof, ctx, code_len, retdest + DUP3 DUP2 %assert_le %decrement + // stack: proof, ctx, code_len, retdest DUP2 SWAP1 - // stack: address, ctx, ctx + // stack: address, ctx, ctx, code_len, retdest // We read the proof PROVER_INPUT(jumpdest_table::next_proof) - // stack: proof, address, ctx, ctx + // stack: proof, address, ctx, ctx, code_len, retdest %write_table_if_jumpdest + // stack: ctx, code_len, retdest - %jump(jumpdest_analisys) + %jump(jumpdest_analysis) -%macro jumpdest_analisys - %stack (ctx) -> (ctx, %%after) - %jump(jumpdest_analisys) +%macro jumpdest_analysis + %stack (ctx, code_len) -> (ctx, code_len, %%after) + %jump(jumpdest_analysis) %%after: %endmacro diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs index 3d97251c..6ce52783 100644 --- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -5,8 +5,8 @@ use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; #[test] -fn test_jumpdest_analisys() -> Result<()> { - let jumpdest_analisys = KERNEL.global_labels["jumpdest_analisys"]; +fn test_jumpdest_analysis() -> Result<()> { + let jumpdest_analysis = KERNEL.global_labels["jumpdest_analysis"]; const CONTEXT: usize = 3; // arbitrary let add = get_opcode("ADD"); @@ -28,8 +28,8 @@ fn test_jumpdest_analisys() -> Result<()> { let jumpdest_bits = vec![false, true, false, false, false, true, false, true]; // Contract creation transaction. - let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; - let mut interpreter = Interpreter::new_with_kernel(jumpdest_analisys, initial_stack); + let initial_stack = vec![0xDEADBEEFu32.into(), code.len().into(), CONTEXT.into()]; + let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack); interpreter.set_code(CONTEXT, code); interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 9c1b6c2b..ddd2bae9 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -340,9 +340,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps( initial_label: &str, final_label: &str, state: &mut GenerationState, -) -> Result>>, ProgramError> { +) -> Option>> { if state.jumpdest_proofs.is_some() { - Ok(None) + None } else { const JUMP_OPCODE: u8 = 0x56; const JUMPI_OPCODE: u8 = 0x57; @@ -358,19 +358,25 @@ fn simulate_cpu_between_labels_and_get_user_jumps( loop { // skip jumdest table validations in simulations - if state.registers.program_counter == KERNEL.global_labels["jumpdest_analisys"] { - state.registers.program_counter = KERNEL.global_labels["jumpdest_analisys_end"] + if state.registers.program_counter == KERNEL.global_labels["jumpdest_analysis"] { + state.registers.program_counter = KERNEL.global_labels["jumpdest_analysis_end"] } let pc = state.registers.program_counter; let context = state.registers.context; - let halt = state.registers.is_kernel + let mut halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == initial_context; - let opcode = u256_to_u8(state.memory.get(MemoryAddress { + let Ok(opcode) = u256_to_u8(state.memory.get(MemoryAddress { context, segment: Segment::Code as usize, virt: state.registers.program_counter, - }))?; + })) else { + log::debug!( + "Simulated CPU halted after {} cycles", + state.traces.clock() - initial_clock + ); + return Some(jumpdest_addresses); + }; let cond = if let Ok(cond) = stack_peek(state, 1) { cond != U256::zero() } else { @@ -380,7 +386,13 @@ fn simulate_cpu_between_labels_and_get_user_jumps( && (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond)) { // Avoid deeper calls to abort - let jumpdest = u256_to_usize(state.registers.stack_top)?; + let Ok(jumpdest) = u256_to_usize(state.registers.stack_top) else { + log::debug!( + "Simulated CPU halted after {} cycles", + state.traces.clock() - initial_clock + ); + return Some(jumpdest_addresses); + }; state.memory.set( MemoryAddress { context, @@ -400,16 +412,13 @@ fn simulate_cpu_between_labels_and_get_user_jumps( jumpdest_addresses.insert(context, BTreeSet::from([jumpdest])); } } - if halt { + if halt || transition(state).is_err() { log::debug!( "Simulated CPU halted after {} cycles", state.traces.clock() - initial_clock ); - return Ok(Some(jumpdest_addresses)); + return Some(jumpdest_addresses); } - transition(state).map_err(|_| { - ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) - })?; } } } diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index ca9208ea..60192c74 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -302,15 +302,15 @@ impl GenerationState { let code = self.get_current_code()?; // We need to set the simulated jumpdest bits to one as otherwise // the simulation will fail. - self.set_jumpdest_bits(&code); + // self.set_jumpdest_bits(&code); // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( - "jumpdest_analisys_end", + "jumpdest_analysis_end", "terminate_common", self, - )? - else { + ) else { + self.jumpdest_proofs = Some(HashMap::new()); return Ok(()); }; @@ -318,7 +318,7 @@ impl GenerationState { self.rollback(checkpoint); self.memory = memory; - // Find proofs for all context + // Find proofs for all contexts self.set_proofs_and_jumpdests(jumpdest_table); Ok(()) @@ -368,9 +368,9 @@ impl GenerationState { Ok(code_len) } - fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec) { + fn set_jumpdest_bits(&mut self, code: &[u8]) { const JUMPDEST_OPCODE: u8 = 0x5b; - for (pos, opcode) in CodeIterator::new(&code) { + for (pos, opcode) in CodeIterator::new(code) { if opcode == JUMPDEST_OPCODE { self.memory.set( MemoryAddress { @@ -385,13 +385,13 @@ impl GenerationState { } } -/// For each address in `jumpdest_table`, each bounded by larges_address, +/// For all address in `jumpdest_table`, each bounded by larges_address, /// this function searches for a proof. A proof is the closest address /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes are PUSHXX and the address is in its range. It returns /// a vector of even size containing proofs followed by their addresses -fn get_proofs_and_jumpdests<'a>( - code: &'a Vec, +fn get_proofs_and_jumpdests( + code: &[u8], largest_address: usize, jumpdest_table: std::collections::BTreeSet, ) -> Vec { @@ -455,7 +455,7 @@ impl<'a> Iterator for CodeIterator<'a> { fn next(&mut self) -> Option { const PUSH1_OPCODE: u8 = 0x60; - const PUSH32_OPCODE: u8 = 0x70; + const PUSH32_OPCODE: u8 = 0x7f; let CodeIterator { code, pos, end } = self; if *pos >= *end { return None; diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index b8f962e7..96313b0d 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -396,8 +396,9 @@ fn try_perform_instruction( log_kernel_instruction(state, op); } else { log::debug!( - "User instruction: {:?}, ctx = {:?}, stack = {:?}", + "User instruction: {:?}, pc = {:?}, ctx = {:?}, stack = {:?}", op, + state.registers.program_counter, state.registers.context, state.stack() ); diff --git a/evm/tests/simple_transfer.rs b/evm/tests/simple_transfer.rs index 5fd252df..b5c43ff8 100644 --- a/evm/tests/simple_transfer.rs +++ b/evm/tests/simple_transfer.rs @@ -154,6 +154,9 @@ fn test_simple_transfer() -> anyhow::Result<()> { }, }; + let bytes = std::fs::read("jumpi_d18g0v0_Shanghai.json").unwrap(); + let inputs = serde_json::from_slice(&bytes).unwrap(); + let mut timing = TimingTree::new("prove", log::Level::Debug); let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print();