From 3e8ad0868845cec242a4c2085add7c1462cf5baa Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 13 Dec 2023 17:33:53 +0100 Subject: [PATCH 01/37] Rebase to main --- evm/src/cpu/kernel/asm/core/call.asm | 9 +- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 182 ++++++++++++++++++ evm/src/generation/mod.rs | 67 ++++++- evm/src/generation/prover_input.rs | 153 ++++++++++++++- evm/src/generation/state.rs | 26 +++ evm/src/prover.rs | 4 +- evm/src/util.rs | 5 + evm/src/witness/errors.rs | 2 + evm/src/witness/transition.rs | 2 +- 9 files changed, 435 insertions(+), 15 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 2e7d1d73..5a2a14c4 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -367,11 +367,12 @@ call_too_deep: %checkpoint // Checkpoint %increment_call_depth // Perform jumpdest analyis - PUSH %%after - %mload_context_metadata(@CTX_METADATA_CODE_SIZE) - GET_CONTEXT + // PUSH %%after + // %mload_context_metadata(@CTX_METADATA_CODE_SIZE) + // GET_CONTEXT // stack: ctx, code_size, retdest - %jump(jumpdest_analysis) + // %jump(jumpdest_analysis) + %validate_jumpdest_table %%after: PUSH 0 // jump dest EXIT_KERNEL diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index bda6f96e..09bb35fa 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -1,3 +1,48 @@ +// Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos], +// for the given context's code. Panics if we never hit final_pos +// Pre stack: init_pos, ctx, final_pos, retdest +// Post stack: (empty) +global verify_path: +loop_new: + // stack: i, ctx, final_pos, retdest + // Ideally we would break if i >= final_pos, but checking i > final_pos is + // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is + // a no-op. + DUP3 DUP2 EQ // i == final_pos + %jumpi(return_new) + DUP3 DUP2 GT // i > final_pos + %jumpi(panic) + + // stack: i, ctx, final_pos, retdest + %stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx) + MLOAD_GENERAL + // stack: opcode, i, ctx, final_pos, retdest + + DUP1 + // Slightly more efficient than `%eq_const(0x5b) ISZERO` + PUSH 0x5b + SUB + // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest + %jumpi(continue_new) + + // stack: JUMPDEST, i, ctx, code_len, retdest + %stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx) + MSTORE_GENERAL + +continue_new: + // stack: opcode, i, ctx, code_len, retdest + %add_const(code_bytes_to_skip) + %mload_kernel_code + // stack: bytes_to_skip, i, ctx, code_len, retdest + ADD + // stack: i, ctx, code_len, retdest + %jump(loop_new) + +return_new: + // stack: i, ctx, code_len, retdest + %pop3 + JUMP + // Populates @SEGMENT_JUMPDEST_BITS for the given context's code. // Pre stack: ctx, code_len, retdest // Post stack: (empty) @@ -89,3 +134,140 @@ code_bytes_to_skip: %rep 128 BYTES 1 // 0x80-0xff %endrep + + +// A proof attesting that jumpdest is a valid jump destinations is +// either 0 or an index 0 < i <= jumpdest - 32. +// A proof is valid if: +// - i == 0 and we can go from the first opcode to jumpdest and code[jumpdest] = 0x5b +// - i > 0 and: +// - for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - j - i, +// - we can go from opcode i+32 to jumpdest, +// - code[jumpdest] = 0x5b. +// stack: proof_prefix_addr, jumpdest, retdest +// stack: (empty) abort if jumpdest is not a valid destination +global is_jumpdest: + GET_CONTEXT + // stack: ctx, proof_prefix_addr, jumpdest, retdest + %stack + (ctx, proof_prefix_addr, jumpdest) -> + (ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr) + MLOAD_GENERAL + // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest + + // Slightly more efficient than `%eq_const(0x5b) ISZERO` + PUSH 0x5b + SUB + %jumpi(panic) + + //stack: jumpdest, ctx, proof_prefix_addr, retdest + SWAP2 DUP1 + // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest + %eq_const(0) + %jumpi(verify_path) + //stack: proof_prefix_addr, ctx, jumpdest, retdest + // If we are here we need to check that the next 32 bytes are less + // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, + // or larger than 127 + %check_and_step(127) %check_and_step(126) %check_and_step(125) %check_and_step(124) + %check_and_step(123) %check_and_step(122) %check_and_step(121) %check_and_step(120) + %check_and_step(119) %check_and_step(118) %check_and_step(117) %check_and_step(116) + %check_and_step(115) %check_and_step(114) %check_and_step(113) %check_and_step(112) + %check_and_step(111) %check_and_step(110) %check_and_step(109) %check_and_step(108) + %check_and_step(107) %check_and_step(106) %check_and_step(105) %check_and_step(104) + %check_and_step(103) %check_and_step(102) %check_and_step(101) %check_and_step(100) + %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) + + // check the remaining path + %jump(verify_path) + +return_is_jumpdest: + //stack: proof_prefix_addr, jumpdest, retdest + %pop2 + JUMP + + +// Chek if the opcode pointed by proof_prefix address is +// less than max and increment proof_prefix_addr +%macro check_and_step(max) + %stack + (proof_prefix_addr, ctx, jumpdest) -> + (ctx, @SEGMENT_CODE, proof_prefix_addr, proof_prefix_addr, ctx, jumpdest) + MLOAD_GENERAL + // stack: opcode, proof_prefix_addr, ctx, jumpdest + DUP1 + %gt_const(127) + %jumpi(%%ok) + %assert_lt_const($max) + // stack: proof_prefix_addr, ctx, jumpdest + PUSH 0 // We need something to pop +%%ok: + POP + %increment +%endmacro + +%macro is_jumpdest + %stack (proof, addr) -> (proof, addr, %%after) + %jump(is_jumpdest) +%%after: +%endmacro + +// Check if the jumpdest table is correct. This is done by +// non-deterministically guessing the sequence of jumpdest +// addresses used during program execution within the current context. +// For each jumpdest address we also non-deterministically guess +// a proof, which is another address in the code, such that +// is_jumpdest don't abort when the proof is on the top of the stack +// an the jumpdest address below. If that's the case we set the +// corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. +// +// stack: retdest +// stack: (empty) +global validate_jumpdest_table: + // If address > 0 it is interpreted as address' = address - 1 + // and the next prover input should contain a proof for address'. + PROVER_INPUT(jumpdest_table::next_address) + DUP1 %jumpi(check_proof) + // If proof == 0 there are no more jump destionations to check + POP +global validate_jumpdest_table_end: + JUMP + // were set to 0 + //%mload_context_metadata(@CTX_METADATA_CODE_SIZE) + // get the code length in bytes + //%add_const(31) + //%div_const(32) + //GET_CONTEXT + //SWAP2 +//verify_chunk: + // stack: i (= proof), code_len, ctx = 0 + //%stack (i, code_len, ctx) -> (code_len, i, ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx) + //GT + //%jumpi(valid_table) + //%mload_packing + // stack: packed_bits, code_len, i, ctx + //%assert_eq_const(0) + //%increment + //%jump(verify_chunk) + +check_proof: + %sub_const(1) + DUP1 + // We read the proof + PROVER_INPUT(jumpdest_table::next_proof) + // stack: proof, address + %is_jumpdest + GET_CONTEXT + %stack (ctx, address) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address) + MSTORE_GENERAL + %jump(validate_jumpdest_table) +valid_table: + // stack: ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx, retdest + %pop7 + JUMP + +%macro validate_jumpdest_table + PUSH %%after + %jump(validate_jumpdest_table) +%%after: +%endmacro diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 77b6fd36..94c0e432 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use anyhow::anyhow; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; @@ -6,6 +6,7 @@ use ethereum_types::{Address, BigEndianHash, H256, U256}; use itertools::enumerate; use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; @@ -19,11 +20,13 @@ use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::assembler::Kernel; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::opcodes::get_opcode; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; -use crate::util::h2u; +use crate::util::{h2u, u256_to_u8, u256_to_usize}; use crate::witness::memory::{MemoryAddress, MemoryChannel}; use crate::witness::transition::transition; @@ -34,7 +37,7 @@ pub(crate) mod state; mod trie_extractor; use self::mpt::{load_all_mpts, TrieRootPtrs}; -use crate::witness::util::mem_write_log; +use crate::witness::util::{mem_write_log, stack_peek}; /// Inputs needed for trace generation. #[derive(Clone, Debug, Deserialize, Serialize, Default)] @@ -244,9 +247,7 @@ pub fn generate_traces, const D: usize>( Ok((tables, public_values)) } -fn simulate_cpu, const D: usize>( - state: &mut GenerationState, -) -> anyhow::Result<()> { +fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> { let halt_pc = KERNEL.global_labels["halt"]; loop { @@ -281,3 +282,57 @@ fn simulate_cpu, const D: usize>( transition(state)?; } } + +fn simulate_cpu_between_labels_and_get_user_jumps( + initial_label: &str, + final_label: &str, + state: &mut GenerationState, +) -> anyhow::Result> { + let halt_pc = KERNEL.global_labels[final_label]; + let mut jumpdest_addresses = HashSet::new(); + state.registers.program_counter = KERNEL.global_labels[initial_label]; + let context = state.registers.context; + + loop { + if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { + state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] + } + let pc = state.registers.program_counter; + let halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == context; + let opcode = u256_to_u8(state.memory.get(MemoryAddress { + context: state.registers.context, + segment: Segment::Code as usize, + virt: state.registers.program_counter, + })) + .map_err(|_| anyhow::Error::msg("Invalid opcode."))?; + let cond = if let Ok(cond) = stack_peek(state, 1) { + cond != U256::zero() + } else { + false + }; + if !state.registers.is_kernel + && (opcode == get_opcode("JUMP") || (opcode == get_opcode("JUMPI") && cond)) + { + // TODO: hotfix for avoiding deeper calls to abort + let jumpdest = u256_to_usize(state.registers.stack_top) + .map_err(|_| anyhow::Error::msg("Not a valid jump destination"))?; + state.memory.set( + MemoryAddress { + context: state.registers.context, + segment: Segment::JumpdestBits as usize, + virt: jumpdest, + }, + U256::one(), + ); + if (state.registers.context == context) { + jumpdest_addresses.insert(jumpdest); + } + } + if halt { + log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); + return Ok(jumpdest_addresses.into_iter().collect()); + } + + transition(state)?; + } +} diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index b2a8f0ce..852cd77d 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,24 +1,34 @@ +use std::cmp::min; +use std::collections::HashSet; use std::mem::transmute; use std::str::FromStr; use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256, U512}; +use hashbrown::HashMap; use itertools::{enumerate, Itertools}; use num_bigint::BigUint; +use plonky2::field::extension::Extendable; use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; use serde::{Deserialize, Serialize}; +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::context_metadata::ContextMetadata; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254}; use crate::generation::prover_input::EvmField::{ Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, }; use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; +use crate::generation::simulate_cpu_between_labels_and_get_user_jumps; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::segments::Segment::BnPairing; -use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_usize}; -use crate::witness::errors::ProgramError; +use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_u8, u256_to_usize}; use crate::witness::errors::ProverInputError::*; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::MemoryAddress; use crate::witness::util::{current_context_peek, stack_peek}; @@ -47,6 +57,7 @@ impl GenerationState { "bignum_modmul" => self.run_bignum_modmul(), "withdrawal" => self.run_withdrawal(), "num_bits" => self.run_num_bits(), + "jumpdest_table" => self.run_jumpdest_table(input_fn), _ => Err(ProgramError::ProverInputError(InvalidFunction)), } } @@ -229,6 +240,144 @@ impl GenerationState { Ok(num_bits.into()) } } + + fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result { + match input_fn.0[1].as_str() { + "next_address" => self.run_next_jumpdest_table_address(), + "next_proof" => self.run_next_jumpdest_table_proof(), + _ => Err(ProgramError::ProverInputError(InvalidInput)), + } + } + /// Return the next used jump addres + fn run_next_jumpdest_table_address(&mut self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + + if self.jumpdest_addresses.is_none() { + let mut state: GenerationState = self.soft_clone(); + + let mut jumpdest_addresses = vec![]; + // Generate the jumpdest table + let code = (0..code_len) + .map(|i| { + u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i, + })) + }) + .collect::, _>>()?; + let mut i = 0; + while i < code_len { + if code[i] == get_opcode("JUMPDEST") { + jumpdest_addresses.push(i); + state.memory.set( + MemoryAddress { + context: state.registers.context, + segment: Segment::JumpdestBits as usize, + virt: i, + }, + U256::one(), + ); + log::debug!("jumpdest at {i}"); + } + i += if code[i] >= get_push_opcode(1) && code[i] <= get_push_opcode(32) { + (code[i] - get_push_opcode(1) + 2).into() + } else { + 1 + } + } + + // We need to skip the validate table call + self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( + "validate_jumpdest_table_end", + "terminate_common", + &mut state, + ) + .ok(); + log::debug!("code len = {code_len}"); + log::debug!("all jumpdest addresses = {:?}", jumpdest_addresses); + log::debug!("user's jumdest addresses = {:?}", self.jumpdest_addresses); + // self.jumpdest_addresses = Some(jumpdest_addresses); + } + + let Some(jumpdest_table) = &mut self.jumpdest_addresses else { + // TODO: Add another error + return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)); + }; + + if let Some(next_jumpdest_address) = jumpdest_table.pop() { + self.last_jumpdest_address = next_jumpdest_address; + Ok((next_jumpdest_address + 1).into()) + } else { + self.jumpdest_addresses = None; + Ok(U256::zero()) + } + } + + /// Return the proof for the last jump adddress + fn run_next_jumpdest_table_proof(&mut self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + + let mut address = MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: 0, + }; + let mut proof = 0; + let mut prefix_size = 0; + + // TODO: The proof searching algorithm is not very eficient. But luckyly it doesn't seem + // a problem because is done natively. + + // Search the closest address to last_jumpdest_address for which none of + // the previous 32 bytes in the code (including opcodes and pushed bytes) + // are PUSHXX and the address is in its range + while address.virt < self.last_jumpdest_address { + let opcode = u256_to_u8(self.memory.get(address))?; + let is_push = + opcode >= get_push_opcode(1).into() && opcode <= get_push_opcode(32).into(); + + address.virt += if is_push { + (opcode - get_push_opcode(1) + 2).into() + } else { + 1 + }; + // Check if the new address has a prefix of size >= 32 + let mut has_prefix = true; + for i in address.virt as i32 - 32..address.virt as i32 { + let opcode = u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i as usize, + }))?; + if i < 0 + || (opcode >= get_push_opcode(1) + && opcode <= get_push_opcode(32) + && i + (opcode - get_push_opcode(1)) as i32 + 1 >= address.virt as i32) + { + has_prefix = false; + break; + } + } + if has_prefix { + proof = address.virt - 32; + } + } + if address.virt > self.last_jumpdest_address { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpDestination, + )); + } + Ok(proof.into()) + } } enum EvmField { diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 89ff0c5a..79dd94fb 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -50,6 +50,9 @@ pub(crate) struct GenerationState { /// Pointers, within the `TrieData` segment, of the three MPTs. pub(crate) trie_root_ptrs: TrieRootPtrs, + + pub(crate) last_jumpdest_address: usize, + pub(crate) jumpdest_addresses: Option>, } impl GenerationState { @@ -91,6 +94,8 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, + last_jumpdest_address: 0, + jumpdest_addresses: None, }; let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); @@ -167,6 +172,27 @@ impl GenerationState { .map(|i| stack_peek(self, i).unwrap()) .collect() } + + /// Clone everything but the traces + pub(crate) fn soft_clone(&self) -> GenerationState { + Self { + inputs: self.inputs.clone(), + registers: self.registers.clone(), + memory: self.memory.clone(), + traces: Traces::default(), + rlp_prover_inputs: self.rlp_prover_inputs.clone(), + state_key_to_address: self.state_key_to_address.clone(), + bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(), + withdrawal_prover_inputs: self.withdrawal_prover_inputs.clone(), + trie_root_ptrs: TrieRootPtrs { + state_root_ptr: 0, + txn_root_ptr: 0, + receipt_root_ptr: 0, + }, + last_jumpdest_address: 0, + jumpdest_addresses: None, + } + } } /// Withdrawals prover input array is of the form `[addr0, amount0, ..., addrN, amountN, U256::MAX, U256::MAX]`. diff --git a/evm/src/prover.rs b/evm/src/prover.rs index ab33a661..51fea9a4 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -377,7 +377,7 @@ where let alphas = challenger.get_n_challenges(config.num_challenges); - #[cfg(test)] + // #[cfg(test)] { check_constraints( stark, @@ -636,7 +636,7 @@ where .collect() } -#[cfg(test)] +// #[cfg(test)] /// Check that all constraints evaluate to zero on `H`. /// Can also be used to check the degree of the constraints by evaluating on a larger subgroup. fn check_constraints<'a, F, C, S, const D: usize>( diff --git a/evm/src/util.rs b/evm/src/util.rs index 9cac52c6..7a635c0e 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,6 +70,11 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } +/// Safe alternative to `U256::as_u8()`, which errors in case of overflow instead of panicking. +pub(crate) fn u256_to_u8(u256: U256) -> Result { + u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) +} + /// Safe alternative to `U256::as_usize()`, which errors in case of overflow instead of panicking. pub(crate) fn u256_to_usize(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index 5a0fcbfb..1b266aef 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -36,4 +36,6 @@ pub enum ProverInputError { InvalidInput, InvalidFunction, NumBitsError, + InvalidJumpDestination, + InvalidJumpdestSimulation, } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index cf2e3bbe..0fa14321 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,7 +395,7 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!("User instruction: {:?}", op); + log::debug!("User instruction: {:?} stack = {:?}", op, state.stack()); } fill_op_flag(op, &mut row); From f76ab777417e757ec6d0e4d93d7524efdf9f5c4d Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 13 Dec 2023 14:11:43 +0100 Subject: [PATCH 02/37] Refactor run_next_jumpdest_table_proof --- evm/src/cpu/kernel/asm/core/call.asm | 7 +- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 114 +++------- evm/src/cpu/kernel/interpreter.rs | 9 + .../kernel/tests/core/jumpdest_analysis.rs | 60 +++-- evm/src/generation/mod.rs | 6 +- evm/src/generation/prover_input.rs | 205 +++++++++++------- evm/src/prover.rs | 4 +- evm/src/witness/transition.rs | 7 +- 8 files changed, 202 insertions(+), 210 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 5a2a14c4..46765954 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -367,13 +367,10 @@ call_too_deep: %checkpoint // Checkpoint %increment_call_depth // Perform jumpdest analyis - // PUSH %%after - // %mload_context_metadata(@CTX_METADATA_CODE_SIZE) - // GET_CONTEXT + GET_CONTEXT // stack: ctx, code_size, retdest - // %jump(jumpdest_analysis) %validate_jumpdest_table -%%after: + PUSH 0 // jump dest EXIT_KERNEL // (Old context) stack: new_ctx diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 09bb35fa..76c25fa0 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -3,13 +3,13 @@ // Pre stack: init_pos, ctx, final_pos, retdest // Post stack: (empty) global verify_path: -loop_new: +loop: // stack: i, ctx, final_pos, retdest // Ideally we would break if i >= final_pos, but checking i > final_pos is // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is // a no-op. DUP3 DUP2 EQ // i == final_pos - %jumpi(return_new) + %jumpi(return) DUP3 DUP2 GT // i > final_pos %jumpi(panic) @@ -18,51 +18,6 @@ loop_new: MLOAD_GENERAL // stack: opcode, i, ctx, final_pos, retdest - DUP1 - // Slightly more efficient than `%eq_const(0x5b) ISZERO` - PUSH 0x5b - SUB - // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest - %jumpi(continue_new) - - // stack: JUMPDEST, i, ctx, code_len, retdest - %stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx) - MSTORE_GENERAL - -continue_new: - // stack: opcode, i, ctx, code_len, retdest - %add_const(code_bytes_to_skip) - %mload_kernel_code - // stack: bytes_to_skip, i, ctx, code_len, retdest - ADD - // stack: i, ctx, code_len, retdest - %jump(loop_new) - -return_new: - // stack: i, ctx, code_len, retdest - %pop3 - JUMP - -// Populates @SEGMENT_JUMPDEST_BITS for the given context's code. -// Pre stack: ctx, code_len, retdest -// Post stack: (empty) -global jumpdest_analysis: - // stack: ctx, code_len, retdest - PUSH 0 // i = 0 - -loop: - // stack: i, ctx, code_len, retdest - // Ideally we would break if i >= code_len, but checking i > code_len is - // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is - // a no-op. - DUP3 DUP2 GT // i > code_len - %jumpi(return) - - // stack: i, ctx, code_len, retdest - %stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx) - MLOAD_GENERAL - // stack: opcode, i, ctx, code_len, retdest - DUP1 // Slightly more efficient than `%eq_const(0x5b) ISZERO` PUSH 0x5b @@ -144,13 +99,17 @@ code_bytes_to_skip: // - for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - j - i, // - we can go from opcode i+32 to jumpdest, // - code[jumpdest] = 0x5b. -// stack: proof_prefix_addr, jumpdest, retdest +// stack: proof_prefix_addr, jumpdest, ctx, retdest // stack: (empty) abort if jumpdest is not a valid destination global is_jumpdest: - GET_CONTEXT - // stack: ctx, proof_prefix_addr, jumpdest, retdest + // stack: proof_prefix_addr, jumpdest, ctx, retdest + //%stack + // (proof_prefix_addr, jumpdest, ctx) -> + // (ctx, @SEGMENT_JUMPDEST_BITS, jumpdest, proof_prefix_addr, jumpdest, ctx) + //MLOAD_GENERAL + //%jumpi(return_is_jumpdest) %stack - (ctx, proof_prefix_addr, jumpdest) -> + (proof_prefix_addr, jumpdest, ctx) -> (ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr) MLOAD_GENERAL // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest @@ -182,8 +141,8 @@ global is_jumpdest: %jump(verify_path) return_is_jumpdest: - //stack: proof_prefix_addr, jumpdest, retdest - %pop2 + //stack: proof_prefix_addr, jumpdest, ctx, retdest + %pop3 JUMP @@ -194,7 +153,7 @@ return_is_jumpdest: (proof_prefix_addr, ctx, jumpdest) -> (ctx, @SEGMENT_CODE, proof_prefix_addr, proof_prefix_addr, ctx, jumpdest) MLOAD_GENERAL - // stack: opcode, proof_prefix_addr, ctx, jumpdest + // stack: opcode, ctx, proof_prefix_addr, jumpdest DUP1 %gt_const(127) %jumpi(%%ok) @@ -207,7 +166,7 @@ return_is_jumpdest: %endmacro %macro is_jumpdest - %stack (proof, addr) -> (proof, addr, %%after) + %stack (proof, addr, ctx) -> (proof, addr, ctx, %%after) %jump(is_jumpdest) %%after: %endmacro @@ -216,58 +175,41 @@ return_is_jumpdest: // non-deterministically guessing the sequence of jumpdest // addresses used during program execution within the current context. // For each jumpdest address we also non-deterministically guess -// a proof, which is another address in the code, such that -// is_jumpdest don't abort when the proof is on the top of the stack +// a proof, which is another address in the code such that +// is_jumpdest don't abort, when the proof is at the top of the stack // an the jumpdest address below. If that's the case we set the // corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. // -// stack: retdest +// stack: ctx, retdest // stack: (empty) global validate_jumpdest_table: - // If address > 0 it is interpreted as address' = address - 1 + // If address > 0 then address is interpreted as address' + 1 // and the next prover input should contain a proof for address'. PROVER_INPUT(jumpdest_table::next_address) DUP1 %jumpi(check_proof) // If proof == 0 there are no more jump destionations to check POP +// This is just a hook used for avoiding verification of the jumpdest +// table in another contexts. It is useful during proof generation, +// allowing the avoidance of table verification when simulating user code. global validate_jumpdest_table_end: + POP JUMP - // were set to 0 - //%mload_context_metadata(@CTX_METADATA_CODE_SIZE) - // get the code length in bytes - //%add_const(31) - //%div_const(32) - //GET_CONTEXT - //SWAP2 -//verify_chunk: - // stack: i (= proof), code_len, ctx = 0 - //%stack (i, code_len, ctx) -> (code_len, i, ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx) - //GT - //%jumpi(valid_table) - //%mload_packing - // stack: packed_bits, code_len, i, ctx - //%assert_eq_const(0) - //%increment - //%jump(verify_chunk) - check_proof: %sub_const(1) - DUP1 + DUP2 DUP2 + // stack: address, ctx, address, ctx // We read the proof PROVER_INPUT(jumpdest_table::next_proof) - // stack: proof, address + // stack: proof, address, ctx, address, ctx %is_jumpdest - GET_CONTEXT - %stack (ctx, address) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address) + %stack (address, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address, ctx) MSTORE_GENERAL + %jump(validate_jumpdest_table) -valid_table: - // stack: ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx, retdest - %pop7 - JUMP %macro validate_jumpdest_table - PUSH %%after + %stack (ctx) -> (ctx, %%after) %jump(validate_jumpdest_table) %%after: %endmacro diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index f187a480..78691632 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -10,6 +10,7 @@ use keccak_hash::keccak; use plonky2::field::goldilocks_field::GoldilocksField; use super::assembler::BYTES_PER_OFFSET; +use super::utils::u256_from_bool; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; @@ -289,6 +290,14 @@ impl<'a> Interpreter<'a> { .collect() } + pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { + self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] + .content = jumpdest_bits + .into_iter() + .map(|x| u256_from_bool(x)) + .collect(); + } + fn incr(&mut self, n: usize) { self.generation_state.registers.program_counter += n; } diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs index 022a18d7..d3edc17b 100644 --- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -4,39 +4,37 @@ use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; -#[test] -fn test_jumpdest_analysis() -> Result<()> { - let jumpdest_analysis = KERNEL.global_labels["jumpdest_analysis"]; - const CONTEXT: usize = 3; // arbitrary +// #[test] +// fn test_jumpdest_analysis() -> Result<()> { +// let jumpdest_analysis = KERNEL.global_labels["validate_jumpdest_table"]; +// const CONTEXT: usize = 3; // arbitrary - let add = get_opcode("ADD"); - let push2 = get_push_opcode(2); - let jumpdest = get_opcode("JUMPDEST"); +// let add = get_opcode("ADD"); +// let push2 = get_push_opcode(2); +// let jumpdest = get_opcode("JUMPDEST"); - #[rustfmt::skip] - let code: Vec = vec![ - add, - jumpdest, - push2, - jumpdest, // part of PUSH2 - jumpdest, // part of PUSH2 - jumpdest, - add, - jumpdest, - ]; +// #[rustfmt::skip] +// let code: Vec = vec![ +// add, +// jumpdest, +// push2, +// jumpdest, // part of PUSH2 +// jumpdest, // part of PUSH2 +// jumpdest, +// add, +// jumpdest, +// ]; - let expected_jumpdest_bits = vec![false, true, false, false, false, true, false, true]; +// let jumpdest_bits = vec![false, true, false, false, false, true, false, true]; - // Contract creation transaction. - let initial_stack = vec![0xDEADBEEFu32.into(), code.len().into(), CONTEXT.into()]; - let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack); - interpreter.set_code(CONTEXT, code); - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - assert_eq!( - interpreter.get_jumpdest_bits(CONTEXT), - expected_jumpdest_bits - ); +// // Contract creation transaction. +// let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; +// let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack); +// interpreter.set_code(CONTEXT, code); +// interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); - Ok(()) -} +// interpreter.run()?; +// assert_eq!(interpreter.stack(), vec![]); + +// Ok(()) +// } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 94c0e432..71176da5 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -293,6 +293,8 @@ fn simulate_cpu_between_labels_and_get_user_jumps( state.registers.program_counter = KERNEL.global_labels[initial_label]; let context = state.registers.context; + log::debug!("Simulating CPU for jumpdest analysis "); + loop { if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] @@ -330,7 +332,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps( } if halt { log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); - return Ok(jumpdest_addresses.into_iter().collect()); + let mut jumpdest_addresses: Vec = jumpdest_addresses.into_iter().collect(); + jumpdest_addresses.sort(); + return Ok(jumpdest_addresses); } transition(state)?; diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 852cd77d..2435f4eb 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -257,51 +257,7 @@ impl GenerationState { }))?; if self.jumpdest_addresses.is_none() { - let mut state: GenerationState = self.soft_clone(); - - let mut jumpdest_addresses = vec![]; - // Generate the jumpdest table - let code = (0..code_len) - .map(|i| { - u256_to_u8(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::Code as usize, - virt: i, - })) - }) - .collect::, _>>()?; - let mut i = 0; - while i < code_len { - if code[i] == get_opcode("JUMPDEST") { - jumpdest_addresses.push(i); - state.memory.set( - MemoryAddress { - context: state.registers.context, - segment: Segment::JumpdestBits as usize, - virt: i, - }, - U256::one(), - ); - log::debug!("jumpdest at {i}"); - } - i += if code[i] >= get_push_opcode(1) && code[i] <= get_push_opcode(32) { - (code[i] - get_push_opcode(1) + 2).into() - } else { - 1 - } - } - - // We need to skip the validate table call - self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( - "validate_jumpdest_table_end", - "terminate_common", - &mut state, - ) - .ok(); - log::debug!("code len = {code_len}"); - log::debug!("all jumpdest addresses = {:?}", jumpdest_addresses); - log::debug!("user's jumdest addresses = {:?}", self.jumpdest_addresses); - // self.jumpdest_addresses = Some(jumpdest_addresses); + self.generate_jumpdest_table()?; } let Some(jumpdest_table) = &mut self.jumpdest_addresses else { @@ -326,57 +282,138 @@ impl GenerationState { virt: ContextMetadata::CodeSize as usize, }))?; - let mut address = MemoryAddress { - context: self.registers.context, - segment: Segment::Code as usize, - virt: 0, - }; - let mut proof = 0; - let mut prefix_size = 0; + let code = (0..self.last_jumpdest_address) + .map(|i| { + u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i, + })) + }) + .collect::, _>>()?; // TODO: The proof searching algorithm is not very eficient. But luckyly it doesn't seem - // a problem because is done natively. + // a problem as is done natively. // Search the closest address to last_jumpdest_address for which none of // the previous 32 bytes in the code (including opcodes and pushed bytes) // are PUSHXX and the address is in its range - while address.virt < self.last_jumpdest_address { - let opcode = u256_to_u8(self.memory.get(address))?; - let is_push = - opcode >= get_push_opcode(1).into() && opcode <= get_push_opcode(32).into(); - address.virt += if is_push { - (opcode - get_push_opcode(1) + 2).into() - } else { - 1 - }; - // Check if the new address has a prefix of size >= 32 - let mut has_prefix = true; - for i in address.virt as i32 - 32..address.virt as i32 { - let opcode = u256_to_u8(self.memory.get(MemoryAddress { + let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( + 0, + |acc, (pos, opcode)| { + let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { + code[prefix_start..pos].iter().enumerate().fold( + true, + |acc, (prefix_pos, &byte)| { + acc && (byte > get_push_opcode(32) + || (prefix_start + prefix_pos) as i32 + + (byte as i32 - get_push_opcode(1) as i32) + + 1 + < pos as i32) + }, + ) + } else { + false + }; + if has_prefix { + pos - 32 + } else { + acc + } + }, + ); + Ok(proof.into()) + } +} + +impl GenerationState { + fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { + let mut state = self.soft_clone(); + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + // Generate the jumpdest table + let code = (0..code_len) + .map(|i| { + u256_to_u8(self.memory.get(MemoryAddress { context: self.registers.context, segment: Segment::Code as usize, - virt: i as usize, - }))?; - if i < 0 - || (opcode >= get_push_opcode(1) - && opcode <= get_push_opcode(32) - && i + (opcode - get_push_opcode(1)) as i32 + 1 >= address.virt as i32) - { - has_prefix = false; - break; - } - } - if has_prefix { - proof = address.virt - 32; + virt: i, + })) + }) + .collect::, _>>()?; + + // We need to set the the simulated jumpdest bits to one as otherwise + // the simulation will fail + let mut jumpdest_table = vec![]; + for (pos, opcode) in CodeIterator::new(&code) { + jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST"))); + if opcode == get_opcode("JUMPDEST") { + state.memory.set( + MemoryAddress { + context: state.registers.context, + segment: Segment::JumpdestBits as usize, + virt: pos, + }, + U256::one(), + ); } } - if address.virt > self.last_jumpdest_address { - return Err(ProgramError::ProverInputError( - ProverInputError::InvalidJumpDestination, - )); + + // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call + self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( + "validate_jumpdest_table_end", + "terminate_common", + &mut state, + ) + .ok(); + + Ok(()) + } +} + +struct CodeIterator<'a> { + code: &'a Vec, + pos: usize, + end: usize, +} + +impl<'a> CodeIterator<'a> { + fn new(code: &'a Vec) -> Self { + CodeIterator { + end: code.len(), + code, + pos: 0, } - Ok(proof.into()) + } + fn until(code: &'a Vec, end: usize) -> Self { + CodeIterator { + end: std::cmp::min(code.len(), end), + code, + pos: 0, + } + } +} + +impl<'a> Iterator for CodeIterator<'a> { + type Item = (usize, u8); + + fn next(&mut self) -> Option { + let CodeIterator { code, pos, end } = self; + if *pos >= *end { + return None; + } + let opcode = code[*pos]; + let old_pos = *pos; + *pos += if opcode >= get_push_opcode(1) && opcode <= get_push_opcode(32) { + (opcode - get_push_opcode(1) + 2).into() + } else { + 1 + }; + Some((old_pos, opcode)) } } diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 51fea9a4..ab33a661 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -377,7 +377,7 @@ where let alphas = challenger.get_n_challenges(config.num_challenges); - // #[cfg(test)] + #[cfg(test)] { check_constraints( stark, @@ -636,7 +636,7 @@ where .collect() } -// #[cfg(test)] +#[cfg(test)] /// Check that all constraints evaluate to zero on `H`. /// Can also be used to check the degree of the constraints by evaluating on a larger subgroup. fn check_constraints<'a, F, C, S, const D: usize>( diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 0fa14321..04688543 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,7 +395,12 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!("User instruction: {:?} stack = {:?}", op, state.stack()); + log::debug!( + "User instruction: {:?} ctx = {:?} stack = {:?}", + op, + state.registers.context, + state.stack() + ); } fill_op_flag(op, &mut row); From 746e13448b0c8469b81b0da7fb3a280455344c6b Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 13 Dec 2023 17:06:42 +0100 Subject: [PATCH 03/37] Fix jumpdest analisys test --- evm/src/cpu/kernel/interpreter.rs | 13 +++-- .../kernel/tests/core/jumpdest_analysis.rs | 56 +++++++++---------- evm/src/generation/prover_input.rs | 6 -- 3 files changed, 37 insertions(+), 38 deletions(-) diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 78691632..a16d2d3a 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -292,10 +292,15 @@ impl<'a> Interpreter<'a> { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] - .content = jumpdest_bits - .into_iter() - .map(|x| u256_from_bool(x)) - .collect(); + .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); + self.generation_state.jumpdest_addresses = Some( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i) + .collect(), + ) } fn incr(&mut self, n: usize) { diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs index d3edc17b..58e9f936 100644 --- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -4,37 +4,37 @@ use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; -// #[test] -// fn test_jumpdest_analysis() -> Result<()> { -// let jumpdest_analysis = KERNEL.global_labels["validate_jumpdest_table"]; -// const CONTEXT: usize = 3; // arbitrary +#[test] +fn test_validate_jumpdest_table() -> Result<()> { + let validate_jumpdest_table = KERNEL.global_labels["validate_jumpdest_table"]; + const CONTEXT: usize = 3; // arbitrary -// let add = get_opcode("ADD"); -// let push2 = get_push_opcode(2); -// let jumpdest = get_opcode("JUMPDEST"); + let add = get_opcode("ADD"); + let push2 = get_push_opcode(2); + let jumpdest = get_opcode("JUMPDEST"); -// #[rustfmt::skip] -// let code: Vec = vec![ -// add, -// jumpdest, -// push2, -// jumpdest, // part of PUSH2 -// jumpdest, // part of PUSH2 -// jumpdest, -// add, -// jumpdest, -// ]; + #[rustfmt::skip] + let code: Vec = vec![ + add, + jumpdest, + push2, + jumpdest, // part of PUSH2 + jumpdest, // part of PUSH2 + jumpdest, + add, + jumpdest, + ]; -// let jumpdest_bits = vec![false, true, false, false, false, true, false, true]; + let jumpdest_bits = vec![false, true, false, false, false, true, false, true]; -// // Contract creation transaction. -// let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; -// let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack); -// interpreter.set_code(CONTEXT, code); -// interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); + // Contract creation transaction. + let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; + let mut interpreter = Interpreter::new_with_kernel(validate_jumpdest_table, initial_stack); + interpreter.set_code(CONTEXT, code); + interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); -// interpreter.run()?; -// assert_eq!(interpreter.stack(), vec![]); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); -// Ok(()) -// } + Ok(()) +} diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 2435f4eb..53e25db4 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -276,12 +276,6 @@ impl GenerationState { /// Return the proof for the last jump adddress fn run_next_jumpdest_table_proof(&mut self) -> Result { - let code_len = u256_to_usize(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::ContextMetadata as usize, - virt: ContextMetadata::CodeSize as usize, - }))?; - let code = (0..self.last_jumpdest_address) .map(|i| { u256_to_u8(self.memory.get(MemoryAddress { From 2c5347c45f07f31cdfd7a183cdf5f9873d5c32e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alonso=20Gonz=C3=A1lez?= Date: Fri, 15 Dec 2023 09:49:19 +0100 Subject: [PATCH 04/37] Apply suggestions from code review Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> --- evm/src/cpu/kernel/asm/core/call.asm | 1 - evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 10 +++++----- evm/src/cpu/kernel/interpreter.rs | 2 +- evm/src/generation/mod.rs | 4 ++-- evm/src/generation/prover_input.rs | 14 +++++++------- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 46765954..5173d358 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -370,7 +370,6 @@ call_too_deep: GET_CONTEXT // stack: ctx, code_size, retdest %validate_jumpdest_table - PUSH 0 // jump dest EXIT_KERNEL // (Old context) stack: new_ctx diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 76c25fa0..79475b37 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -122,9 +122,9 @@ global is_jumpdest: //stack: jumpdest, ctx, proof_prefix_addr, retdest SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest - %eq_const(0) + IS_ZERO %jumpi(verify_path) - //stack: proof_prefix_addr, ctx, jumpdest, retdest + // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are less // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, // or larger than 127 @@ -141,7 +141,7 @@ global is_jumpdest: %jump(verify_path) return_is_jumpdest: - //stack: proof_prefix_addr, jumpdest, ctx, retdest + // stack: proof_prefix_addr, jumpdest, ctx, retdest %pop3 JUMP @@ -187,7 +187,7 @@ global validate_jumpdest_table: // and the next prover input should contain a proof for address'. PROVER_INPUT(jumpdest_table::next_address) DUP1 %jumpi(check_proof) - // If proof == 0 there are no more jump destionations to check + // If proof == 0 there are no more jump destinations to check POP // This is just a hook used for avoiding verification of the jumpdest // table in another contexts. It is useful during proof generation, @@ -196,7 +196,7 @@ global validate_jumpdest_table_end: POP JUMP check_proof: - %sub_const(1) + %decrement DUP2 DUP2 // stack: address, ctx, address, ctx // We read the proof diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index a16d2d3a..9b66b5de 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -303,7 +303,7 @@ impl<'a> Interpreter<'a> { ) } - fn incr(&mut self, n: usize) { + const fn incr(&mut self, n: usize) { self.generation_state.registers.program_counter += n; } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 71176da5..9af830d4 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -293,7 +293,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( state.registers.program_counter = KERNEL.global_labels[initial_label]; let context = state.registers.context; - log::debug!("Simulating CPU for jumpdest analysis "); + log::debug!("Simulating CPU for jumpdest analysis."); loop { if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { @@ -317,7 +317,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( { // TODO: hotfix for avoiding deeper calls to abort let jumpdest = u256_to_usize(state.registers.stack_top) - .map_err(|_| anyhow::Error::msg("Not a valid jump destination"))?; + .map_err(|_| anyhow!("Not a valid jump destination"))?; state.memory.set( MemoryAddress { context: state.registers.context, diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 53e25db4..1808f3f3 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -274,7 +274,7 @@ impl GenerationState { } } - /// Return the proof for the last jump adddress + /// Returns the proof for the last jump address. fn run_next_jumpdest_table_proof(&mut self) -> Result { let code = (0..self.last_jumpdest_address) .map(|i| { @@ -286,12 +286,12 @@ impl GenerationState { }) .collect::, _>>()?; - // TODO: The proof searching algorithm is not very eficient. But luckyly it doesn't seem + // TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem // a problem as is done natively. - // Search the closest address to last_jumpdest_address for which none of + // Search the closest address to `last_jumpdest_address` for which none of // the previous 32 bytes in the code (including opcodes and pushed bytes) - // are PUSHXX and the address is in its range + // are PUSHXX and the address is in its range. let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( 0, @@ -340,9 +340,9 @@ impl GenerationState { }) .collect::, _>>()?; - // We need to set the the simulated jumpdest bits to one as otherwise - // the simulation will fail - let mut jumpdest_table = vec![]; + // We need to set the simulated jumpdest bits to one as otherwise + // the simulation will fail. + let mut jumpdest_table = Vec::with_capacity(code.len()); for (pos, opcode) in CodeIterator::new(&code) { jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST"))); if opcode == get_opcode("JUMPDEST") { From 81f13f3f8a1aa60d0b2ab4caa8eacc871bee1eed Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 17:11:00 +0100 Subject: [PATCH 05/37] Eliminate nested simulations --- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 7 +- evm/src/cpu/kernel/interpreter.rs | 22 ++-- evm/src/generation/mod.rs | 109 ++++++++++-------- evm/src/generation/prover_input.rs | 45 ++++---- evm/src/generation/state.rs | 4 +- evm/src/witness/transition.rs | 7 +- 6 files changed, 103 insertions(+), 91 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 79475b37..cfc3575b 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -114,15 +114,12 @@ global is_jumpdest: MLOAD_GENERAL // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest - // Slightly more efficient than `%eq_const(0x5b) ISZERO` - PUSH 0x5b - SUB - %jumpi(panic) + %assert_eq_const(0x5b) //stack: jumpdest, ctx, proof_prefix_addr, retdest SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest - IS_ZERO + ISZERO %jumpi(verify_path) // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are less diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 9b66b5de..f007595a 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -1,7 +1,7 @@ //! An EVM interpreter for testing and debugging purposes. use core::cmp::Ordering; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::ops::Range; use anyhow::{anyhow, bail, ensure}; @@ -293,17 +293,19 @@ impl<'a> Interpreter<'a> { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); - self.generation_state.jumpdest_addresses = Some( - jumpdest_bits - .into_iter() - .enumerate() - .filter(|&(_, x)| x) - .map(|(i, _)| i) - .collect(), - ) + self.generation_state.jumpdest_addresses = Some(HashMap::from([( + context, + BTreeSet::from_iter( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i), + ), + )])); } - const fn incr(&mut self, n: usize) { + pub(crate) fn incr(&mut self, n: usize) { self.generation_state.registers.program_counter += n; } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 9af830d4..5163767e 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use anyhow::anyhow; use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie}; @@ -27,6 +27,7 @@ use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use crate::util::{h2u, u256_to_u8, u256_to_usize}; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::{MemoryAddress, MemoryChannel}; use crate::witness::transition::transition; @@ -287,56 +288,68 @@ fn simulate_cpu_between_labels_and_get_user_jumps( initial_label: &str, final_label: &str, state: &mut GenerationState, -) -> anyhow::Result> { - let halt_pc = KERNEL.global_labels[final_label]; - let mut jumpdest_addresses = HashSet::new(); - state.registers.program_counter = KERNEL.global_labels[initial_label]; - let context = state.registers.context; +) -> Result<(), ProgramError> { + if let Some(_) = state.jumpdest_addresses { + Ok(()) + } else { + const JUMP_OPCODE: u8 = 0x56; + const JUMPI_OPCODE: u8 = 0x57; - log::debug!("Simulating CPU for jumpdest analysis."); + let halt_pc = KERNEL.global_labels[final_label]; + let mut jumpdest_addresses: HashMap<_, BTreeSet> = HashMap::new(); - loop { - if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { - state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] - } - let pc = state.registers.program_counter; - let halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == context; - let opcode = u256_to_u8(state.memory.get(MemoryAddress { - context: state.registers.context, - segment: Segment::Code as usize, - virt: state.registers.program_counter, - })) - .map_err(|_| anyhow::Error::msg("Invalid opcode."))?; - let cond = if let Ok(cond) = stack_peek(state, 1) { - cond != U256::zero() - } else { - false - }; - if !state.registers.is_kernel - && (opcode == get_opcode("JUMP") || (opcode == get_opcode("JUMPI") && cond)) - { - // TODO: hotfix for avoiding deeper calls to abort - let jumpdest = u256_to_usize(state.registers.stack_top) - .map_err(|_| anyhow!("Not a valid jump destination"))?; - state.memory.set( - MemoryAddress { - context: state.registers.context, - segment: Segment::JumpdestBits as usize, - virt: jumpdest, - }, - U256::one(), - ); - if (state.registers.context == context) { - jumpdest_addresses.insert(jumpdest); + state.registers.program_counter = KERNEL.global_labels[initial_label]; + let initial_context = state.registers.context; + + log::debug!("Simulating CPU for jumpdest analysis."); + + loop { + if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { + state.registers.program_counter = + KERNEL.global_labels["validate_jumpdest_table_end"] } + let pc = state.registers.program_counter; + let context = state.registers.context; + let halt = state.registers.is_kernel + && pc == halt_pc + && state.registers.context == initial_context; + let opcode = u256_to_u8(state.memory.get(MemoryAddress { + context, + segment: Segment::Code as usize, + virt: state.registers.program_counter, + }))?; + let cond = if let Ok(cond) = stack_peek(state, 1) { + cond != U256::zero() + } else { + false + }; + if !state.registers.is_kernel + && (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond)) + { + // Avoid deeper calls to abort + let jumpdest = u256_to_usize(state.registers.stack_top)?; + state.memory.set( + MemoryAddress { + context, + segment: Segment::JumpdestBits as usize, + virt: jumpdest, + }, + U256::one(), + ); + if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) { + ctx_addresses.insert(jumpdest); + } else { + jumpdest_addresses.insert(context, BTreeSet::from([jumpdest])); + } + } + if halt { + log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); + state.jumpdest_addresses = Some(jumpdest_addresses); + return Ok(()); + } + transition(state).map_err(|_| { + ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) + })?; } - if halt { - log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); - let mut jumpdest_addresses: Vec = jumpdest_addresses.into_iter().collect(); - jumpdest_addresses.sort(); - return Ok(jumpdest_addresses); - } - - transition(state)?; } } diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 1808f3f3..35571dcb 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -16,7 +16,6 @@ use serde::{Deserialize, Serialize}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254}; use crate::generation::prover_input::EvmField::{ Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, @@ -250,8 +249,9 @@ impl GenerationState { } /// Return the next used jump addres fn run_next_jumpdest_table_address(&mut self) -> Result { + let context = self.registers.context; let code_len = u256_to_usize(self.memory.get(MemoryAddress { - context: self.registers.context, + context, segment: Segment::ContextMetadata as usize, virt: ContextMetadata::CodeSize as usize, }))?; @@ -260,14 +260,14 @@ impl GenerationState { self.generate_jumpdest_table()?; } - let Some(jumpdest_table) = &mut self.jumpdest_addresses else { - // TODO: Add another error + let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)); }; - if let Some(next_jumpdest_address) = jumpdest_table.pop() { - self.last_jumpdest_address = next_jumpdest_address; - Ok((next_jumpdest_address + 1).into()) + if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() + { + self.last_jumpdest_address = next_jumpdest_address; + Ok((next_jumpdest_address + 1).into()) } else { self.jumpdest_addresses = None; Ok(U256::zero()) @@ -293,6 +293,9 @@ impl GenerationState { // the previous 32 bytes in the code (including opcodes and pushed bytes) // are PUSHXX and the address is in its range. + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x7f; + let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( 0, |acc, (pos, opcode)| { @@ -300,9 +303,9 @@ impl GenerationState { code[prefix_start..pos].iter().enumerate().fold( true, |acc, (prefix_pos, &byte)| { - acc && (byte > get_push_opcode(32) + acc && (byte > PUSH32_OPCODE || (prefix_start + prefix_pos) as i32 - + (byte as i32 - get_push_opcode(1) as i32) + + (byte as i32 - PUSH1_OPCODE as i32) + 1 < pos as i32) }, @@ -323,6 +326,7 @@ impl GenerationState { impl GenerationState { fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { + const JUMPDEST_OPCODE: u8 = 0x5b; let mut state = self.soft_clone(); let code_len = u256_to_usize(self.memory.get(MemoryAddress { context: self.registers.context, @@ -344,8 +348,8 @@ impl GenerationState { // the simulation will fail. let mut jumpdest_table = Vec::with_capacity(code.len()); for (pos, opcode) in CodeIterator::new(&code) { - jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST"))); - if opcode == get_opcode("JUMPDEST") { + jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE)); + if opcode == JUMPDEST_OPCODE { state.memory.set( MemoryAddress { context: state.registers.context, @@ -358,32 +362,31 @@ impl GenerationState { } // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call - self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( + simulate_cpu_between_labels_and_get_user_jumps( "validate_jumpdest_table_end", "terminate_common", &mut state, - ) - .ok(); - + )?; + self.jumpdest_addresses = state.jumpdest_addresses; Ok(()) } } struct CodeIterator<'a> { - code: &'a Vec, + code: &'a [u8], pos: usize, end: usize, } impl<'a> CodeIterator<'a> { - fn new(code: &'a Vec) -> Self { + fn new(code: &'a [u8]) -> Self { CodeIterator { end: code.len(), code, pos: 0, } } - fn until(code: &'a Vec, end: usize) -> Self { + fn until(code: &'a [u8], end: usize) -> Self { CodeIterator { end: std::cmp::min(code.len(), end), code, @@ -396,14 +399,16 @@ impl<'a> Iterator for CodeIterator<'a> { type Item = (usize, u8); fn next(&mut self) -> Option { + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x70; let CodeIterator { code, pos, end } = self; if *pos >= *end { return None; } let opcode = code[*pos]; let old_pos = *pos; - *pos += if opcode >= get_push_opcode(1) && opcode <= get_push_opcode(32) { - (opcode - get_push_opcode(1) + 2).into() + *pos += if opcode >= PUSH1_OPCODE && opcode <= PUSH32_OPCODE { + (opcode - PUSH1_OPCODE + 2).into() } else { 1 }; diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 79dd94fb..de07c942 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use ethereum_types::{Address, BigEndianHash, H160, H256, U256}; use keccak_hash::keccak; @@ -52,7 +52,7 @@ pub(crate) struct GenerationState { pub(crate) trie_root_ptrs: TrieRootPtrs, pub(crate) last_jumpdest_address: usize, - pub(crate) jumpdest_addresses: Option>, + pub(crate) jumpdest_addresses: Option>>, } impl GenerationState { diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 04688543..cf2e3bbe 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,12 +395,7 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!( - "User instruction: {:?} ctx = {:?} stack = {:?}", - op, - state.registers.context, - state.stack() - ); + log::debug!("User instruction: {:?}", op); } fill_op_flag(op, &mut row); From ad8c2df84a3584d187b17053a87b14c491849d60 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 17:13:52 +0100 Subject: [PATCH 06/37] Remove U256::as_u8 in comment --- evm/src/util.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/src/util.rs b/evm/src/util.rs index 7a635c0e..43545a6b 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,7 +70,7 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } -/// Safe alternative to `U256::as_u8()`, which errors in case of overflow instead of panicking. +/// Safe conversion from U256 to u8, which errors in case of overflow instead of panicking. pub(crate) fn u256_to_u8(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) } From 5a0c1ad8b78ee87e478e6918b732345a50b7213d Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 18:14:47 +0100 Subject: [PATCH 07/37] Fix fmt --- evm/src/generation/prover_input.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 35571dcb..3a298e81 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -261,13 +261,16 @@ impl GenerationState { } let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { - return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)); + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )); }; - if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() + if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) + && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() { - self.last_jumpdest_address = next_jumpdest_address; - Ok((next_jumpdest_address + 1).into()) + self.last_jumpdest_address = next_jumpdest_address; + Ok((next_jumpdest_address + 1).into()) } else { self.jumpdest_addresses = None; Ok(U256::zero()) From 77f1cd34968f98eecdb6e142db5b16c6ea986f33 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 18:52:40 +0100 Subject: [PATCH 08/37] Clippy --- evm/src/generation/mod.rs | 2 +- evm/src/generation/prover_input.rs | 2 +- evm/src/generation/state.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 5163767e..a49e291d 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -289,7 +289,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( final_label: &str, state: &mut GenerationState, ) -> Result<(), ProgramError> { - if let Some(_) = state.jumpdest_addresses { + if state.jumpdest_addresses.is_some() { Ok(()) } else { const JUMP_OPCODE: u8 = 0x56; diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 3a298e81..926b876d 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -410,7 +410,7 @@ impl<'a> Iterator for CodeIterator<'a> { } let opcode = code[*pos]; let old_pos = *pos; - *pos += if opcode >= PUSH1_OPCODE && opcode <= PUSH32_OPCODE { + *pos += if (PUSH1_OPCODE..=PUSH32_OPCODE).contains(&opcode) { (opcode - PUSH1_OPCODE + 2).into() } else { 1 diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index de07c942..1c50cc29 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -177,7 +177,7 @@ impl GenerationState { pub(crate) fn soft_clone(&self) -> GenerationState { Self { inputs: self.inputs.clone(), - registers: self.registers.clone(), + registers: self.registers, memory: self.memory.clone(), traces: Traces::default(), rlp_prover_inputs: self.rlp_prover_inputs.clone(), From 829ae64fc42e167990a0fda7fb512c8008582987 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Tue, 19 Dec 2023 14:05:51 +0100 Subject: [PATCH 09/37] Improve proof generation --- evm/src/cpu/kernel/interpreter.rs | 21 +-- evm/src/generation/mod.rs | 10 +- evm/src/generation/prover_input.rs | 207 ++++++++++++++++++----------- evm/src/generation/state.rs | 9 +- 4 files changed, 148 insertions(+), 99 deletions(-) diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index f007595a..db76ac4b 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -293,16 +293,17 @@ impl<'a> Interpreter<'a> { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); - self.generation_state.jumpdest_addresses = Some(HashMap::from([( - context, - BTreeSet::from_iter( - jumpdest_bits - .into_iter() - .enumerate() - .filter(|&(_, x)| x) - .map(|(i, _)| i), - ), - )])); + self.generation_state + .set_proofs_and_jumpdests(HashMap::from([( + context, + BTreeSet::from_iter( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i), + ), + )])); } pub(crate) fn incr(&mut self, n: usize) { diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index a49e291d..2d5bca1a 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -288,9 +288,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps( initial_label: &str, final_label: &str, state: &mut GenerationState, -) -> Result<(), ProgramError> { - if state.jumpdest_addresses.is_some() { - Ok(()) +) -> Result>>, ProgramError> { + if state.jumpdest_proofs.is_some() { + Ok(None) } else { const JUMP_OPCODE: u8 = 0x56; const JUMPI_OPCODE: u8 = 0x57; @@ -304,6 +304,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( log::debug!("Simulating CPU for jumpdest analysis."); loop { + // skip jumdest table validations in simulations if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] @@ -344,8 +345,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( } if halt { log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); - state.jumpdest_addresses = Some(jumpdest_addresses); - return Ok(()); + return Ok(Some(jumpdest_addresses)); } transition(state).map_err(|_| { ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 926b876d..a5f73ae6 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,11 +1,10 @@ use std::cmp::min; -use std::collections::HashSet; +use std::collections::HashMap; use std::mem::transmute; use std::str::FromStr; use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256, U512}; -use hashbrown::HashMap; use itertools::{enumerate, Itertools}; use num_bigint::BigUint; use plonky2::field::extension::Extendable; @@ -256,87 +255,98 @@ impl GenerationState { virt: ContextMetadata::CodeSize as usize, }))?; - if self.jumpdest_addresses.is_none() { - self.generate_jumpdest_table()?; + if self.jumpdest_proofs.is_none() { + self.generate_jumpdest_proofs()?; } - let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { + let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { return Err(ProgramError::ProverInputError( ProverInputError::InvalidJumpdestSimulation, )); }; - if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) - && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() + if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) + && let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop() { - self.last_jumpdest_address = next_jumpdest_address; Ok((next_jumpdest_address + 1).into()) } else { - self.jumpdest_addresses = None; + self.jumpdest_proofs = None; Ok(U256::zero()) } } /// Returns the proof for the last jump address. fn run_next_jumpdest_table_proof(&mut self) -> Result { - let code = (0..self.last_jumpdest_address) - .map(|i| { - u256_to_u8(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::Code as usize, - virt: i, - })) - }) - .collect::, _>>()?; - - // TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem - // a problem as is done natively. - - // Search the closest address to `last_jumpdest_address` for which none of - // the previous 32 bytes in the code (including opcodes and pushed bytes) - // are PUSHXX and the address is in its range. - - const PUSH1_OPCODE: u8 = 0x60; - const PUSH32_OPCODE: u8 = 0x7f; - - let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( - 0, - |acc, (pos, opcode)| { - let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { - code[prefix_start..pos].iter().enumerate().fold( - true, - |acc, (prefix_pos, &byte)| { - acc && (byte > PUSH32_OPCODE - || (prefix_start + prefix_pos) as i32 - + (byte as i32 - PUSH1_OPCODE as i32) - + 1 - < pos as i32) - }, - ) - } else { - false - }; - if has_prefix { - pos - 32 - } else { - acc - } - }, - ); - Ok(proof.into()) + let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )); + }; + if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) + && let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop() + { + Ok(next_jumpdest_proof.into()) + } else { + Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )) + } } } impl GenerationState { - fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { - const JUMPDEST_OPCODE: u8 = 0x5b; - let mut state = self.soft_clone(); - let code_len = u256_to_usize(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::ContextMetadata as usize, - virt: ContextMetadata::CodeSize as usize, - }))?; - // Generate the jumpdest table + fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> { + let checkpoint = self.checkpoint(); + let memory = self.memory.clone(); + + let code = self.get_current_code()?; + // We need to set the simulated jumpdest bits to one as otherwise + // the simulation will fail. + self.set_jumpdest_bits(&code); + + // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call + let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( + "validate_jumpdest_table_end", + "terminate_common", + self, + )? + else { + return Ok(()); + }; + + // Return to the state before starting the simulation + self.rollback(checkpoint); + self.memory = memory; + + // Find proofs for all context + self.set_proofs_and_jumpdests(jumpdest_table); + + Ok(()) + } + + pub(crate) fn set_proofs_and_jumpdests( + &mut self, + jumpdest_table: HashMap>, + ) { + self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map( + |(ctx, jumpdest_table)| { + let code = self.get_code(ctx).unwrap(); + if let Some(&largest_address) = jumpdest_table.last() { + let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table); + (ctx, proofs) + } else { + (ctx, vec![]) + } + }, + ))); + } + + fn get_current_code(&self) -> Result, ProgramError> { + self.get_code(self.registers.context) + } + + fn get_code(&self, context: usize) -> Result, ProgramError> { + let code_len = self.get_code_len()?; let code = (0..code_len) .map(|i| { u256_to_u8(self.memory.get(MemoryAddress { @@ -346,16 +356,25 @@ impl GenerationState { })) }) .collect::, _>>()?; + Ok(code) + } - // We need to set the simulated jumpdest bits to one as otherwise - // the simulation will fail. - let mut jumpdest_table = Vec::with_capacity(code.len()); + fn get_code_len(&self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + Ok(code_len) + } + + fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec) { + const JUMPDEST_OPCODE: u8 = 0x5b; for (pos, opcode) in CodeIterator::new(&code) { - jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE)); if opcode == JUMPDEST_OPCODE { - state.memory.set( + self.memory.set( MemoryAddress { - context: state.registers.context, + context: self.registers.context, segment: Segment::JumpdestBits as usize, virt: pos, }, @@ -363,18 +382,50 @@ impl GenerationState { ); } } - - // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call - simulate_cpu_between_labels_and_get_user_jumps( - "validate_jumpdest_table_end", - "terminate_common", - &mut state, - )?; - self.jumpdest_addresses = state.jumpdest_addresses; - Ok(()) } } +/// For each address in `jumpdest_table` it search a proof, that is the closest address +/// for which none of the previous 32 bytes in the code (including opcodes +/// and pushed bytes are PUSHXX and the address is in its range. It returns +/// a vector of even size containing proofs followed by their addresses +fn get_proofs_and_jumpdests<'a>( + code: &'a Vec, + largest_address: usize, + jumpdest_table: std::collections::BTreeSet, +) -> Vec { + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x7f; + let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold( + (vec![], 0), + |(mut proofs, acc), (pos, opcode)| { + let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { + code[prefix_start..pos] + .iter() + .enumerate() + .fold(true, |acc, (prefix_pos, &byte)| { + acc && (byte > PUSH32_OPCODE + || (prefix_start + prefix_pos) as i32 + + (byte as i32 - PUSH1_OPCODE as i32) + + 1 + < pos as i32) + }) + } else { + false + }; + let acc = if has_prefix { pos - 32 } else { acc }; + if jumpdest_table.contains(&pos) { + // Push the proof + proofs.push(acc); + // Push the address + proofs.push(pos); + } + (proofs, acc) + }, + ); + proofs +} + struct CodeIterator<'a> { code: &'a [u8], pos: usize, diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 1c50cc29..cc1df091 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -51,8 +51,7 @@ pub(crate) struct GenerationState { /// Pointers, within the `TrieData` segment, of the three MPTs. pub(crate) trie_root_ptrs: TrieRootPtrs, - pub(crate) last_jumpdest_address: usize, - pub(crate) jumpdest_addresses: Option>>, + pub(crate) jumpdest_proofs: Option>>, } impl GenerationState { @@ -94,8 +93,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, - last_jumpdest_address: 0, - jumpdest_addresses: None, + jumpdest_proofs: None, }; let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); @@ -189,8 +187,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, - last_jumpdest_address: 0, - jumpdest_addresses: None, + jumpdest_proofs: None, } } } From 6ababc96ec6f2d12bbab2c2b45e0a0f0116d90de Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 20 Dec 2023 14:13:36 +0100 Subject: [PATCH 10/37] Remove aborts for invalid jumps --- evm/src/cpu/kernel/asm/core/call.asm | 2 +- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 74 +++++++++---------- evm/src/cpu/kernel/asm/util/basic_macros.asm | 13 ++++ .../kernel/tests/core/jumpdest_analysis.rs | 6 +- evm/src/generation/mod.rs | 16 +++- evm/src/generation/prover_input.rs | 5 +- evm/src/witness/transition.rs | 7 +- 7 files changed, 72 insertions(+), 51 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 5173d358..fcb4eb32 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -369,7 +369,7 @@ call_too_deep: // Perform jumpdest analyis GET_CONTEXT // stack: ctx, code_size, retdest - %validate_jumpdest_table + %jumpdest_analisys PUSH 0 // jump dest EXIT_KERNEL // (Old context) stack: new_ctx diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index cfc3575b..97224b3e 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -1,17 +1,14 @@ // Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos], -// for the given context's code. Panics if we never hit final_pos +// for the given context's code. // Pre stack: init_pos, ctx, final_pos, retdest // Post stack: (empty) -global verify_path: +global verify_path_and_write_table: loop: // stack: i, ctx, final_pos, retdest - // Ideally we would break if i >= final_pos, but checking i > final_pos is - // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is - // a no-op. DUP3 DUP2 EQ // i == final_pos - %jumpi(return) + %jumpi(proof_ok) DUP3 DUP2 GT // i > final_pos - %jumpi(panic) + %jumpi(proof_not_ok) // stack: i, ctx, final_pos, retdest %stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx) @@ -22,24 +19,29 @@ loop: // Slightly more efficient than `%eq_const(0x5b) ISZERO` PUSH 0x5b SUB - // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest + // stack: opcode != JUMPDEST, opcode, i, ctx, final_pos, retdest %jumpi(continue) - // stack: JUMPDEST, i, ctx, code_len, retdest + // stack: JUMPDEST, i, ctx, final_pos, retdest %stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx) MSTORE_GENERAL continue: - // stack: opcode, i, ctx, code_len, retdest + // stack: opcode, i, ctx, final_pos, retdest %add_const(code_bytes_to_skip) %mload_kernel_code - // stack: bytes_to_skip, i, ctx, code_len, retdest + // stack: bytes_to_skip, i, ctx, final_pos, retdest ADD - // stack: i, ctx, code_len, retdest + // stack: i, ctx, final_pos, retdest %jump(loop) -return: - // stack: i, ctx, code_len, retdest +proof_ok: + // stack: i, ctx, final_pos, retdest + // We already know final pos is a jumpdest + %stack (i, ctx, final_pos) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i) + MSTORE_GENERAL + JUMP +proof_not_ok: %pop3 JUMP @@ -101,26 +103,21 @@ code_bytes_to_skip: // - code[jumpdest] = 0x5b. // stack: proof_prefix_addr, jumpdest, ctx, retdest // stack: (empty) abort if jumpdest is not a valid destination -global is_jumpdest: +global write_table_if_jumpdest: // stack: proof_prefix_addr, jumpdest, ctx, retdest - //%stack - // (proof_prefix_addr, jumpdest, ctx) -> - // (ctx, @SEGMENT_JUMPDEST_BITS, jumpdest, proof_prefix_addr, jumpdest, ctx) - //MLOAD_GENERAL - //%jumpi(return_is_jumpdest) %stack (proof_prefix_addr, jumpdest, ctx) -> (ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr) MLOAD_GENERAL // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest - %assert_eq_const(0x5b) + %jump_eq_const(0x5b, return) //stack: jumpdest, ctx, proof_prefix_addr, retdest SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest ISZERO - %jumpi(verify_path) + %jumpi(verify_path_and_write_table) // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are less // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, @@ -135,9 +132,8 @@ global is_jumpdest: %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) // check the remaining path - %jump(verify_path) - -return_is_jumpdest: + %jump(verify_path_and_write_table) +return: // stack: proof_prefix_addr, jumpdest, ctx, retdest %pop3 JUMP @@ -154,7 +150,7 @@ return_is_jumpdest: DUP1 %gt_const(127) %jumpi(%%ok) - %assert_lt_const($max) + %jumpi_lt_const($max, return) // stack: proof_prefix_addr, ctx, jumpdest PUSH 0 // We need something to pop %%ok: @@ -162,13 +158,13 @@ return_is_jumpdest: %increment %endmacro -%macro is_jumpdest +%macro write_table_if_jumpdest %stack (proof, addr, ctx) -> (proof, addr, ctx, %%after) - %jump(is_jumpdest) + %jump(write_table_if_jumpdest) %%after: %endmacro -// Check if the jumpdest table is correct. This is done by +// Write the jumpdest table. This is done by // non-deterministically guessing the sequence of jumpdest // addresses used during program execution within the current context. // For each jumpdest address we also non-deterministically guess @@ -179,7 +175,7 @@ return_is_jumpdest: // // stack: ctx, retdest // stack: (empty) -global validate_jumpdest_table: +global jumpdest_analisys: // If address > 0 then address is interpreted as address' + 1 // and the next prover input should contain a proof for address'. PROVER_INPUT(jumpdest_table::next_address) @@ -189,24 +185,22 @@ global validate_jumpdest_table: // This is just a hook used for avoiding verification of the jumpdest // table in another contexts. It is useful during proof generation, // allowing the avoidance of table verification when simulating user code. -global validate_jumpdest_table_end: +global jumpdest_analisys_end: POP JUMP check_proof: %decrement - DUP2 DUP2 - // stack: address, ctx, address, ctx + DUP2 SWAP1 + // stack: address, ctx, ctx // We read the proof PROVER_INPUT(jumpdest_table::next_proof) - // stack: proof, address, ctx, address, ctx - %is_jumpdest - %stack (address, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address, ctx) - MSTORE_GENERAL + // stack: proof, address, ctx, ctx + %write_table_if_jumpdest - %jump(validate_jumpdest_table) + %jump(jumpdest_analisys) -%macro validate_jumpdest_table +%macro jumpdest_analisys %stack (ctx) -> (ctx, %%after) - %jump(validate_jumpdest_table) + %jump(jumpdest_analisys) %%after: %endmacro diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index fc2472b3..d62dc27e 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -8,6 +8,19 @@ jumpi %endmacro +%macro jump_eq_const(c, jumpdest) + PUSH $c + SUB + %jumpi($jumpdest) +%endmacro + +%macro jumpi_lt_const(c, jumpdest) + // %assert_zero is cheaper than %assert_nonzero, so we will leverage the + // fact that (x < c) == !(x >= c). + %ge_const($c) + %jumpi($jumpdest) +%endmacro + %macro pop2 %rep 2 POP diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs index 58e9f936..3d97251c 100644 --- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -5,8 +5,8 @@ use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; #[test] -fn test_validate_jumpdest_table() -> Result<()> { - let validate_jumpdest_table = KERNEL.global_labels["validate_jumpdest_table"]; +fn test_jumpdest_analisys() -> Result<()> { + let jumpdest_analisys = KERNEL.global_labels["jumpdest_analisys"]; const CONTEXT: usize = 3; // arbitrary let add = get_opcode("ADD"); @@ -29,7 +29,7 @@ fn test_validate_jumpdest_table() -> Result<()> { // Contract creation transaction. let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; - let mut interpreter = Interpreter::new_with_kernel(validate_jumpdest_table, initial_stack); + let mut interpreter = Interpreter::new_with_kernel(jumpdest_analisys, initial_stack); interpreter.set_code(CONTEXT, code); interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 2d5bca1a..9599f335 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -299,15 +299,15 @@ fn simulate_cpu_between_labels_and_get_user_jumps( let mut jumpdest_addresses: HashMap<_, BTreeSet> = HashMap::new(); state.registers.program_counter = KERNEL.global_labels[initial_label]; + let initial_clock = state.traces.clock(); let initial_context = state.registers.context; log::debug!("Simulating CPU for jumpdest analysis."); loop { // skip jumdest table validations in simulations - if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { - state.registers.program_counter = - KERNEL.global_labels["validate_jumpdest_table_end"] + if state.registers.program_counter == KERNEL.global_labels["jumpdest_analisys"] { + state.registers.program_counter = KERNEL.global_labels["jumpdest_analisys_end"] } let pc = state.registers.program_counter; let context = state.registers.context; @@ -337,6 +337,11 @@ fn simulate_cpu_between_labels_and_get_user_jumps( }, U256::one(), ); + let jumpdest_opcode = state.memory.get(MemoryAddress { + context, + segment: Segment::Code as usize, + virt: jumpdest, + }); if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) { ctx_addresses.insert(jumpdest); } else { @@ -344,7 +349,10 @@ fn simulate_cpu_between_labels_and_get_user_jumps( } } if halt { - log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); + log::debug!( + "Simulated CPU halted after {} cycles", + state.traces.clock() - initial_clock + ); return Ok(Some(jumpdest_addresses)); } transition(state).map_err(|_| { diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index a5f73ae6..ca9208ea 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -306,7 +306,7 @@ impl GenerationState { // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( - "validate_jumpdest_table_end", + "jumpdest_analisys_end", "terminate_common", self, )? @@ -385,7 +385,8 @@ impl GenerationState { } } -/// For each address in `jumpdest_table` it search a proof, that is the closest address +/// For each address in `jumpdest_table`, each bounded by larges_address, +/// this function searches for a proof. A proof is the closest address /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes are PUSHXX and the address is in its range. It returns /// a vector of even size containing proofs followed by their addresses diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index cf2e3bbe..b8f962e7 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,7 +395,12 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!("User instruction: {:?}", op); + log::debug!( + "User instruction: {:?}, ctx = {:?}, stack = {:?}", + op, + state.registers.context, + state.stack() + ); } fill_op_flag(op, &mut row); From 9e39d88ab808c8f9d8a1b1619064fa20ec83fdb3 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 13 Dec 2023 17:33:53 +0100 Subject: [PATCH 11/37] Rebase to main --- evm/src/cpu/kernel/asm/core/call.asm | 9 +- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 182 ++++++++++++++++++ evm/src/generation/mod.rs | 65 ++++++- evm/src/generation/prover_input.rs | 153 ++++++++++++++- evm/src/generation/state.rs | 26 +++ evm/src/util.rs | 5 + evm/src/witness/errors.rs | 2 + evm/src/witness/transition.rs | 2 +- 8 files changed, 432 insertions(+), 12 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 2e7d1d73..5a2a14c4 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -367,11 +367,12 @@ call_too_deep: %checkpoint // Checkpoint %increment_call_depth // Perform jumpdest analyis - PUSH %%after - %mload_context_metadata(@CTX_METADATA_CODE_SIZE) - GET_CONTEXT + // PUSH %%after + // %mload_context_metadata(@CTX_METADATA_CODE_SIZE) + // GET_CONTEXT // stack: ctx, code_size, retdest - %jump(jumpdest_analysis) + // %jump(jumpdest_analysis) + %validate_jumpdest_table %%after: PUSH 0 // jump dest EXIT_KERNEL diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index bda6f96e..09bb35fa 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -1,3 +1,48 @@ +// Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos], +// for the given context's code. Panics if we never hit final_pos +// Pre stack: init_pos, ctx, final_pos, retdest +// Post stack: (empty) +global verify_path: +loop_new: + // stack: i, ctx, final_pos, retdest + // Ideally we would break if i >= final_pos, but checking i > final_pos is + // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is + // a no-op. + DUP3 DUP2 EQ // i == final_pos + %jumpi(return_new) + DUP3 DUP2 GT // i > final_pos + %jumpi(panic) + + // stack: i, ctx, final_pos, retdest + %stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx) + MLOAD_GENERAL + // stack: opcode, i, ctx, final_pos, retdest + + DUP1 + // Slightly more efficient than `%eq_const(0x5b) ISZERO` + PUSH 0x5b + SUB + // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest + %jumpi(continue_new) + + // stack: JUMPDEST, i, ctx, code_len, retdest + %stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx) + MSTORE_GENERAL + +continue_new: + // stack: opcode, i, ctx, code_len, retdest + %add_const(code_bytes_to_skip) + %mload_kernel_code + // stack: bytes_to_skip, i, ctx, code_len, retdest + ADD + // stack: i, ctx, code_len, retdest + %jump(loop_new) + +return_new: + // stack: i, ctx, code_len, retdest + %pop3 + JUMP + // Populates @SEGMENT_JUMPDEST_BITS for the given context's code. // Pre stack: ctx, code_len, retdest // Post stack: (empty) @@ -89,3 +134,140 @@ code_bytes_to_skip: %rep 128 BYTES 1 // 0x80-0xff %endrep + + +// A proof attesting that jumpdest is a valid jump destinations is +// either 0 or an index 0 < i <= jumpdest - 32. +// A proof is valid if: +// - i == 0 and we can go from the first opcode to jumpdest and code[jumpdest] = 0x5b +// - i > 0 and: +// - for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - j - i, +// - we can go from opcode i+32 to jumpdest, +// - code[jumpdest] = 0x5b. +// stack: proof_prefix_addr, jumpdest, retdest +// stack: (empty) abort if jumpdest is not a valid destination +global is_jumpdest: + GET_CONTEXT + // stack: ctx, proof_prefix_addr, jumpdest, retdest + %stack + (ctx, proof_prefix_addr, jumpdest) -> + (ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr) + MLOAD_GENERAL + // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest + + // Slightly more efficient than `%eq_const(0x5b) ISZERO` + PUSH 0x5b + SUB + %jumpi(panic) + + //stack: jumpdest, ctx, proof_prefix_addr, retdest + SWAP2 DUP1 + // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest + %eq_const(0) + %jumpi(verify_path) + //stack: proof_prefix_addr, ctx, jumpdest, retdest + // If we are here we need to check that the next 32 bytes are less + // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, + // or larger than 127 + %check_and_step(127) %check_and_step(126) %check_and_step(125) %check_and_step(124) + %check_and_step(123) %check_and_step(122) %check_and_step(121) %check_and_step(120) + %check_and_step(119) %check_and_step(118) %check_and_step(117) %check_and_step(116) + %check_and_step(115) %check_and_step(114) %check_and_step(113) %check_and_step(112) + %check_and_step(111) %check_and_step(110) %check_and_step(109) %check_and_step(108) + %check_and_step(107) %check_and_step(106) %check_and_step(105) %check_and_step(104) + %check_and_step(103) %check_and_step(102) %check_and_step(101) %check_and_step(100) + %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) + + // check the remaining path + %jump(verify_path) + +return_is_jumpdest: + //stack: proof_prefix_addr, jumpdest, retdest + %pop2 + JUMP + + +// Chek if the opcode pointed by proof_prefix address is +// less than max and increment proof_prefix_addr +%macro check_and_step(max) + %stack + (proof_prefix_addr, ctx, jumpdest) -> + (ctx, @SEGMENT_CODE, proof_prefix_addr, proof_prefix_addr, ctx, jumpdest) + MLOAD_GENERAL + // stack: opcode, proof_prefix_addr, ctx, jumpdest + DUP1 + %gt_const(127) + %jumpi(%%ok) + %assert_lt_const($max) + // stack: proof_prefix_addr, ctx, jumpdest + PUSH 0 // We need something to pop +%%ok: + POP + %increment +%endmacro + +%macro is_jumpdest + %stack (proof, addr) -> (proof, addr, %%after) + %jump(is_jumpdest) +%%after: +%endmacro + +// Check if the jumpdest table is correct. This is done by +// non-deterministically guessing the sequence of jumpdest +// addresses used during program execution within the current context. +// For each jumpdest address we also non-deterministically guess +// a proof, which is another address in the code, such that +// is_jumpdest don't abort when the proof is on the top of the stack +// an the jumpdest address below. If that's the case we set the +// corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. +// +// stack: retdest +// stack: (empty) +global validate_jumpdest_table: + // If address > 0 it is interpreted as address' = address - 1 + // and the next prover input should contain a proof for address'. + PROVER_INPUT(jumpdest_table::next_address) + DUP1 %jumpi(check_proof) + // If proof == 0 there are no more jump destionations to check + POP +global validate_jumpdest_table_end: + JUMP + // were set to 0 + //%mload_context_metadata(@CTX_METADATA_CODE_SIZE) + // get the code length in bytes + //%add_const(31) + //%div_const(32) + //GET_CONTEXT + //SWAP2 +//verify_chunk: + // stack: i (= proof), code_len, ctx = 0 + //%stack (i, code_len, ctx) -> (code_len, i, ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx) + //GT + //%jumpi(valid_table) + //%mload_packing + // stack: packed_bits, code_len, i, ctx + //%assert_eq_const(0) + //%increment + //%jump(verify_chunk) + +check_proof: + %sub_const(1) + DUP1 + // We read the proof + PROVER_INPUT(jumpdest_table::next_proof) + // stack: proof, address + %is_jumpdest + GET_CONTEXT + %stack (ctx, address) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address) + MSTORE_GENERAL + %jump(validate_jumpdest_table) +valid_table: + // stack: ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx, retdest + %pop7 + JUMP + +%macro validate_jumpdest_table + PUSH %%after + %jump(validate_jumpdest_table) +%%after: +%endmacro diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index d691d34e..2c3ee900 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -8,6 +8,7 @@ use ethereum_types::{Address, BigEndianHash, H256, U256}; use itertools::enumerate; use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; @@ -21,13 +22,15 @@ use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::assembler::Kernel; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::opcodes::get_opcode; use crate::generation::state::GenerationState; use crate::generation::trie_extractor::{get_receipt_trie, get_state_trie, get_txn_trie}; use crate::memory::segments::Segment; use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use crate::prover::check_abort_signal; -use crate::util::{h2u, u256_to_usize}; +use crate::util::{h2u, u256_to_u8, u256_to_usize}; use crate::witness::memory::{MemoryAddress, MemoryChannel}; use crate::witness::transition::transition; @@ -38,7 +41,7 @@ pub(crate) mod state; mod trie_extractor; use self::mpt::{load_all_mpts, TrieRootPtrs}; -use crate::witness::util::mem_write_log; +use crate::witness::util::{mem_write_log, stack_peek}; /// Inputs needed for trace generation. #[derive(Clone, Debug, Deserialize, Serialize, Default)] @@ -296,9 +299,7 @@ pub fn generate_traces, const D: usize>( Ok((tables, public_values)) } -fn simulate_cpu, const D: usize>( - state: &mut GenerationState, -) -> anyhow::Result<()> { +fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> { let halt_pc = KERNEL.global_labels["halt"]; loop { @@ -333,3 +334,57 @@ fn simulate_cpu, const D: usize>( transition(state)?; } } + +fn simulate_cpu_between_labels_and_get_user_jumps( + initial_label: &str, + final_label: &str, + state: &mut GenerationState, +) -> anyhow::Result> { + let halt_pc = KERNEL.global_labels[final_label]; + let mut jumpdest_addresses = HashSet::new(); + state.registers.program_counter = KERNEL.global_labels[initial_label]; + let context = state.registers.context; + + loop { + if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { + state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] + } + let pc = state.registers.program_counter; + let halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == context; + let opcode = u256_to_u8(state.memory.get(MemoryAddress { + context: state.registers.context, + segment: Segment::Code as usize, + virt: state.registers.program_counter, + })) + .map_err(|_| anyhow::Error::msg("Invalid opcode."))?; + let cond = if let Ok(cond) = stack_peek(state, 1) { + cond != U256::zero() + } else { + false + }; + if !state.registers.is_kernel + && (opcode == get_opcode("JUMP") || (opcode == get_opcode("JUMPI") && cond)) + { + // TODO: hotfix for avoiding deeper calls to abort + let jumpdest = u256_to_usize(state.registers.stack_top) + .map_err(|_| anyhow::Error::msg("Not a valid jump destination"))?; + state.memory.set( + MemoryAddress { + context: state.registers.context, + segment: Segment::JumpdestBits as usize, + virt: jumpdest, + }, + U256::one(), + ); + if (state.registers.context == context) { + jumpdest_addresses.insert(jumpdest); + } + } + if halt { + log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); + return Ok(jumpdest_addresses.into_iter().collect()); + } + + transition(state)?; + } +} diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index b2a8f0ce..852cd77d 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,24 +1,34 @@ +use std::cmp::min; +use std::collections::HashSet; use std::mem::transmute; use std::str::FromStr; use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256, U512}; +use hashbrown::HashMap; use itertools::{enumerate, Itertools}; use num_bigint::BigUint; +use plonky2::field::extension::Extendable; use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; use serde::{Deserialize, Serialize}; +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::context_metadata::ContextMetadata; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254}; use crate::generation::prover_input::EvmField::{ Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, }; use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; +use crate::generation::simulate_cpu_between_labels_and_get_user_jumps; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::segments::Segment::BnPairing; -use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_usize}; -use crate::witness::errors::ProgramError; +use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_u8, u256_to_usize}; use crate::witness::errors::ProverInputError::*; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::MemoryAddress; use crate::witness::util::{current_context_peek, stack_peek}; @@ -47,6 +57,7 @@ impl GenerationState { "bignum_modmul" => self.run_bignum_modmul(), "withdrawal" => self.run_withdrawal(), "num_bits" => self.run_num_bits(), + "jumpdest_table" => self.run_jumpdest_table(input_fn), _ => Err(ProgramError::ProverInputError(InvalidFunction)), } } @@ -229,6 +240,144 @@ impl GenerationState { Ok(num_bits.into()) } } + + fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result { + match input_fn.0[1].as_str() { + "next_address" => self.run_next_jumpdest_table_address(), + "next_proof" => self.run_next_jumpdest_table_proof(), + _ => Err(ProgramError::ProverInputError(InvalidInput)), + } + } + /// Return the next used jump addres + fn run_next_jumpdest_table_address(&mut self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + + if self.jumpdest_addresses.is_none() { + let mut state: GenerationState = self.soft_clone(); + + let mut jumpdest_addresses = vec![]; + // Generate the jumpdest table + let code = (0..code_len) + .map(|i| { + u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i, + })) + }) + .collect::, _>>()?; + let mut i = 0; + while i < code_len { + if code[i] == get_opcode("JUMPDEST") { + jumpdest_addresses.push(i); + state.memory.set( + MemoryAddress { + context: state.registers.context, + segment: Segment::JumpdestBits as usize, + virt: i, + }, + U256::one(), + ); + log::debug!("jumpdest at {i}"); + } + i += if code[i] >= get_push_opcode(1) && code[i] <= get_push_opcode(32) { + (code[i] - get_push_opcode(1) + 2).into() + } else { + 1 + } + } + + // We need to skip the validate table call + self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( + "validate_jumpdest_table_end", + "terminate_common", + &mut state, + ) + .ok(); + log::debug!("code len = {code_len}"); + log::debug!("all jumpdest addresses = {:?}", jumpdest_addresses); + log::debug!("user's jumdest addresses = {:?}", self.jumpdest_addresses); + // self.jumpdest_addresses = Some(jumpdest_addresses); + } + + let Some(jumpdest_table) = &mut self.jumpdest_addresses else { + // TODO: Add another error + return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)); + }; + + if let Some(next_jumpdest_address) = jumpdest_table.pop() { + self.last_jumpdest_address = next_jumpdest_address; + Ok((next_jumpdest_address + 1).into()) + } else { + self.jumpdest_addresses = None; + Ok(U256::zero()) + } + } + + /// Return the proof for the last jump adddress + fn run_next_jumpdest_table_proof(&mut self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + + let mut address = MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: 0, + }; + let mut proof = 0; + let mut prefix_size = 0; + + // TODO: The proof searching algorithm is not very eficient. But luckyly it doesn't seem + // a problem because is done natively. + + // Search the closest address to last_jumpdest_address for which none of + // the previous 32 bytes in the code (including opcodes and pushed bytes) + // are PUSHXX and the address is in its range + while address.virt < self.last_jumpdest_address { + let opcode = u256_to_u8(self.memory.get(address))?; + let is_push = + opcode >= get_push_opcode(1).into() && opcode <= get_push_opcode(32).into(); + + address.virt += if is_push { + (opcode - get_push_opcode(1) + 2).into() + } else { + 1 + }; + // Check if the new address has a prefix of size >= 32 + let mut has_prefix = true; + for i in address.virt as i32 - 32..address.virt as i32 { + let opcode = u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i as usize, + }))?; + if i < 0 + || (opcode >= get_push_opcode(1) + && opcode <= get_push_opcode(32) + && i + (opcode - get_push_opcode(1)) as i32 + 1 >= address.virt as i32) + { + has_prefix = false; + break; + } + } + if has_prefix { + proof = address.virt - 32; + } + } + if address.virt > self.last_jumpdest_address { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpDestination, + )); + } + Ok(proof.into()) + } } enum EvmField { diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 89ff0c5a..79dd94fb 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -50,6 +50,9 @@ pub(crate) struct GenerationState { /// Pointers, within the `TrieData` segment, of the three MPTs. pub(crate) trie_root_ptrs: TrieRootPtrs, + + pub(crate) last_jumpdest_address: usize, + pub(crate) jumpdest_addresses: Option>, } impl GenerationState { @@ -91,6 +94,8 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, + last_jumpdest_address: 0, + jumpdest_addresses: None, }; let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); @@ -167,6 +172,27 @@ impl GenerationState { .map(|i| stack_peek(self, i).unwrap()) .collect() } + + /// Clone everything but the traces + pub(crate) fn soft_clone(&self) -> GenerationState { + Self { + inputs: self.inputs.clone(), + registers: self.registers.clone(), + memory: self.memory.clone(), + traces: Traces::default(), + rlp_prover_inputs: self.rlp_prover_inputs.clone(), + state_key_to_address: self.state_key_to_address.clone(), + bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(), + withdrawal_prover_inputs: self.withdrawal_prover_inputs.clone(), + trie_root_ptrs: TrieRootPtrs { + state_root_ptr: 0, + txn_root_ptr: 0, + receipt_root_ptr: 0, + }, + last_jumpdest_address: 0, + jumpdest_addresses: None, + } + } } /// Withdrawals prover input array is of the form `[addr0, amount0, ..., addrN, amountN, U256::MAX, U256::MAX]`. diff --git a/evm/src/util.rs b/evm/src/util.rs index 3d9564b5..bbbd8af1 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,6 +70,11 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } +/// Safe alternative to `U256::as_u8()`, which errors in case of overflow instead of panicking. +pub(crate) fn u256_to_u8(u256: U256) -> Result { + u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) +} + /// Safe alternative to `U256::as_usize()`, which errors in case of overflow instead of panicking. pub(crate) fn u256_to_usize(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index 5a0fcbfb..1b266aef 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -36,4 +36,6 @@ pub enum ProverInputError { InvalidInput, InvalidFunction, NumBitsError, + InvalidJumpDestination, + InvalidJumpdestSimulation, } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index cf2e3bbe..0fa14321 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,7 +395,7 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!("User instruction: {:?}", op); + log::debug!("User instruction: {:?} stack = {:?}", op, state.stack()); } fill_op_flag(op, &mut row); From ff3dc2e51615229c054e53f570e91576363cd73f Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 13 Dec 2023 14:11:43 +0100 Subject: [PATCH 12/37] Refactor run_next_jumpdest_table_proof --- evm/src/cpu/kernel/asm/core/call.asm | 7 +- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 114 +++------- evm/src/cpu/kernel/interpreter.rs | 9 + .../kernel/tests/core/jumpdest_analysis.rs | 60 +++-- evm/src/generation/mod.rs | 6 +- evm/src/generation/prover_input.rs | 205 +++++++++++------- evm/src/witness/transition.rs | 7 +- 7 files changed, 200 insertions(+), 208 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 5a2a14c4..46765954 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -367,13 +367,10 @@ call_too_deep: %checkpoint // Checkpoint %increment_call_depth // Perform jumpdest analyis - // PUSH %%after - // %mload_context_metadata(@CTX_METADATA_CODE_SIZE) - // GET_CONTEXT + GET_CONTEXT // stack: ctx, code_size, retdest - // %jump(jumpdest_analysis) %validate_jumpdest_table -%%after: + PUSH 0 // jump dest EXIT_KERNEL // (Old context) stack: new_ctx diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 09bb35fa..76c25fa0 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -3,13 +3,13 @@ // Pre stack: init_pos, ctx, final_pos, retdest // Post stack: (empty) global verify_path: -loop_new: +loop: // stack: i, ctx, final_pos, retdest // Ideally we would break if i >= final_pos, but checking i > final_pos is // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is // a no-op. DUP3 DUP2 EQ // i == final_pos - %jumpi(return_new) + %jumpi(return) DUP3 DUP2 GT // i > final_pos %jumpi(panic) @@ -18,51 +18,6 @@ loop_new: MLOAD_GENERAL // stack: opcode, i, ctx, final_pos, retdest - DUP1 - // Slightly more efficient than `%eq_const(0x5b) ISZERO` - PUSH 0x5b - SUB - // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest - %jumpi(continue_new) - - // stack: JUMPDEST, i, ctx, code_len, retdest - %stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx) - MSTORE_GENERAL - -continue_new: - // stack: opcode, i, ctx, code_len, retdest - %add_const(code_bytes_to_skip) - %mload_kernel_code - // stack: bytes_to_skip, i, ctx, code_len, retdest - ADD - // stack: i, ctx, code_len, retdest - %jump(loop_new) - -return_new: - // stack: i, ctx, code_len, retdest - %pop3 - JUMP - -// Populates @SEGMENT_JUMPDEST_BITS for the given context's code. -// Pre stack: ctx, code_len, retdest -// Post stack: (empty) -global jumpdest_analysis: - // stack: ctx, code_len, retdest - PUSH 0 // i = 0 - -loop: - // stack: i, ctx, code_len, retdest - // Ideally we would break if i >= code_len, but checking i > code_len is - // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is - // a no-op. - DUP3 DUP2 GT // i > code_len - %jumpi(return) - - // stack: i, ctx, code_len, retdest - %stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx) - MLOAD_GENERAL - // stack: opcode, i, ctx, code_len, retdest - DUP1 // Slightly more efficient than `%eq_const(0x5b) ISZERO` PUSH 0x5b @@ -144,13 +99,17 @@ code_bytes_to_skip: // - for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - j - i, // - we can go from opcode i+32 to jumpdest, // - code[jumpdest] = 0x5b. -// stack: proof_prefix_addr, jumpdest, retdest +// stack: proof_prefix_addr, jumpdest, ctx, retdest // stack: (empty) abort if jumpdest is not a valid destination global is_jumpdest: - GET_CONTEXT - // stack: ctx, proof_prefix_addr, jumpdest, retdest + // stack: proof_prefix_addr, jumpdest, ctx, retdest + //%stack + // (proof_prefix_addr, jumpdest, ctx) -> + // (ctx, @SEGMENT_JUMPDEST_BITS, jumpdest, proof_prefix_addr, jumpdest, ctx) + //MLOAD_GENERAL + //%jumpi(return_is_jumpdest) %stack - (ctx, proof_prefix_addr, jumpdest) -> + (proof_prefix_addr, jumpdest, ctx) -> (ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr) MLOAD_GENERAL // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest @@ -182,8 +141,8 @@ global is_jumpdest: %jump(verify_path) return_is_jumpdest: - //stack: proof_prefix_addr, jumpdest, retdest - %pop2 + //stack: proof_prefix_addr, jumpdest, ctx, retdest + %pop3 JUMP @@ -194,7 +153,7 @@ return_is_jumpdest: (proof_prefix_addr, ctx, jumpdest) -> (ctx, @SEGMENT_CODE, proof_prefix_addr, proof_prefix_addr, ctx, jumpdest) MLOAD_GENERAL - // stack: opcode, proof_prefix_addr, ctx, jumpdest + // stack: opcode, ctx, proof_prefix_addr, jumpdest DUP1 %gt_const(127) %jumpi(%%ok) @@ -207,7 +166,7 @@ return_is_jumpdest: %endmacro %macro is_jumpdest - %stack (proof, addr) -> (proof, addr, %%after) + %stack (proof, addr, ctx) -> (proof, addr, ctx, %%after) %jump(is_jumpdest) %%after: %endmacro @@ -216,58 +175,41 @@ return_is_jumpdest: // non-deterministically guessing the sequence of jumpdest // addresses used during program execution within the current context. // For each jumpdest address we also non-deterministically guess -// a proof, which is another address in the code, such that -// is_jumpdest don't abort when the proof is on the top of the stack +// a proof, which is another address in the code such that +// is_jumpdest don't abort, when the proof is at the top of the stack // an the jumpdest address below. If that's the case we set the // corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. // -// stack: retdest +// stack: ctx, retdest // stack: (empty) global validate_jumpdest_table: - // If address > 0 it is interpreted as address' = address - 1 + // If address > 0 then address is interpreted as address' + 1 // and the next prover input should contain a proof for address'. PROVER_INPUT(jumpdest_table::next_address) DUP1 %jumpi(check_proof) // If proof == 0 there are no more jump destionations to check POP +// This is just a hook used for avoiding verification of the jumpdest +// table in another contexts. It is useful during proof generation, +// allowing the avoidance of table verification when simulating user code. global validate_jumpdest_table_end: + POP JUMP - // were set to 0 - //%mload_context_metadata(@CTX_METADATA_CODE_SIZE) - // get the code length in bytes - //%add_const(31) - //%div_const(32) - //GET_CONTEXT - //SWAP2 -//verify_chunk: - // stack: i (= proof), code_len, ctx = 0 - //%stack (i, code_len, ctx) -> (code_len, i, ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx) - //GT - //%jumpi(valid_table) - //%mload_packing - // stack: packed_bits, code_len, i, ctx - //%assert_eq_const(0) - //%increment - //%jump(verify_chunk) - check_proof: %sub_const(1) - DUP1 + DUP2 DUP2 + // stack: address, ctx, address, ctx // We read the proof PROVER_INPUT(jumpdest_table::next_proof) - // stack: proof, address + // stack: proof, address, ctx, address, ctx %is_jumpdest - GET_CONTEXT - %stack (ctx, address) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address) + %stack (address, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address, ctx) MSTORE_GENERAL + %jump(validate_jumpdest_table) -valid_table: - // stack: ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx, retdest - %pop7 - JUMP %macro validate_jumpdest_table - PUSH %%after + %stack (ctx) -> (ctx, %%after) %jump(validate_jumpdest_table) %%after: %endmacro diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index c4376721..1efaa71d 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -10,6 +10,7 @@ use keccak_hash::keccak; use plonky2::field::goldilocks_field::GoldilocksField; use super::assembler::BYTES_PER_OFFSET; +use super::utils::u256_from_bool; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; @@ -413,6 +414,14 @@ impl<'a> Interpreter<'a> { .collect() } + pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { + self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] + .content = jumpdest_bits + .into_iter() + .map(|x| u256_from_bool(x)) + .collect(); + } + fn incr(&mut self, n: usize) { self.generation_state.registers.program_counter += n; } diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs index 022a18d7..d3edc17b 100644 --- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -4,39 +4,37 @@ use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; -#[test] -fn test_jumpdest_analysis() -> Result<()> { - let jumpdest_analysis = KERNEL.global_labels["jumpdest_analysis"]; - const CONTEXT: usize = 3; // arbitrary +// #[test] +// fn test_jumpdest_analysis() -> Result<()> { +// let jumpdest_analysis = KERNEL.global_labels["validate_jumpdest_table"]; +// const CONTEXT: usize = 3; // arbitrary - let add = get_opcode("ADD"); - let push2 = get_push_opcode(2); - let jumpdest = get_opcode("JUMPDEST"); +// let add = get_opcode("ADD"); +// let push2 = get_push_opcode(2); +// let jumpdest = get_opcode("JUMPDEST"); - #[rustfmt::skip] - let code: Vec = vec![ - add, - jumpdest, - push2, - jumpdest, // part of PUSH2 - jumpdest, // part of PUSH2 - jumpdest, - add, - jumpdest, - ]; +// #[rustfmt::skip] +// let code: Vec = vec![ +// add, +// jumpdest, +// push2, +// jumpdest, // part of PUSH2 +// jumpdest, // part of PUSH2 +// jumpdest, +// add, +// jumpdest, +// ]; - let expected_jumpdest_bits = vec![false, true, false, false, false, true, false, true]; +// let jumpdest_bits = vec![false, true, false, false, false, true, false, true]; - // Contract creation transaction. - let initial_stack = vec![0xDEADBEEFu32.into(), code.len().into(), CONTEXT.into()]; - let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack); - interpreter.set_code(CONTEXT, code); - interpreter.run()?; - assert_eq!(interpreter.stack(), vec![]); - assert_eq!( - interpreter.get_jumpdest_bits(CONTEXT), - expected_jumpdest_bits - ); +// // Contract creation transaction. +// let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; +// let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack); +// interpreter.set_code(CONTEXT, code); +// interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); - Ok(()) -} +// interpreter.run()?; +// assert_eq!(interpreter.stack(), vec![]); + +// Ok(()) +// } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 2c3ee900..995e067b 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -345,6 +345,8 @@ fn simulate_cpu_between_labels_and_get_user_jumps( state.registers.program_counter = KERNEL.global_labels[initial_label]; let context = state.registers.context; + log::debug!("Simulating CPU for jumpdest analysis "); + loop { if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] @@ -382,7 +384,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps( } if halt { log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); - return Ok(jumpdest_addresses.into_iter().collect()); + let mut jumpdest_addresses: Vec = jumpdest_addresses.into_iter().collect(); + jumpdest_addresses.sort(); + return Ok(jumpdest_addresses); } transition(state)?; diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 852cd77d..2435f4eb 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -257,51 +257,7 @@ impl GenerationState { }))?; if self.jumpdest_addresses.is_none() { - let mut state: GenerationState = self.soft_clone(); - - let mut jumpdest_addresses = vec![]; - // Generate the jumpdest table - let code = (0..code_len) - .map(|i| { - u256_to_u8(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::Code as usize, - virt: i, - })) - }) - .collect::, _>>()?; - let mut i = 0; - while i < code_len { - if code[i] == get_opcode("JUMPDEST") { - jumpdest_addresses.push(i); - state.memory.set( - MemoryAddress { - context: state.registers.context, - segment: Segment::JumpdestBits as usize, - virt: i, - }, - U256::one(), - ); - log::debug!("jumpdest at {i}"); - } - i += if code[i] >= get_push_opcode(1) && code[i] <= get_push_opcode(32) { - (code[i] - get_push_opcode(1) + 2).into() - } else { - 1 - } - } - - // We need to skip the validate table call - self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( - "validate_jumpdest_table_end", - "terminate_common", - &mut state, - ) - .ok(); - log::debug!("code len = {code_len}"); - log::debug!("all jumpdest addresses = {:?}", jumpdest_addresses); - log::debug!("user's jumdest addresses = {:?}", self.jumpdest_addresses); - // self.jumpdest_addresses = Some(jumpdest_addresses); + self.generate_jumpdest_table()?; } let Some(jumpdest_table) = &mut self.jumpdest_addresses else { @@ -326,57 +282,138 @@ impl GenerationState { virt: ContextMetadata::CodeSize as usize, }))?; - let mut address = MemoryAddress { - context: self.registers.context, - segment: Segment::Code as usize, - virt: 0, - }; - let mut proof = 0; - let mut prefix_size = 0; + let code = (0..self.last_jumpdest_address) + .map(|i| { + u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i, + })) + }) + .collect::, _>>()?; // TODO: The proof searching algorithm is not very eficient. But luckyly it doesn't seem - // a problem because is done natively. + // a problem as is done natively. // Search the closest address to last_jumpdest_address for which none of // the previous 32 bytes in the code (including opcodes and pushed bytes) // are PUSHXX and the address is in its range - while address.virt < self.last_jumpdest_address { - let opcode = u256_to_u8(self.memory.get(address))?; - let is_push = - opcode >= get_push_opcode(1).into() && opcode <= get_push_opcode(32).into(); - address.virt += if is_push { - (opcode - get_push_opcode(1) + 2).into() - } else { - 1 - }; - // Check if the new address has a prefix of size >= 32 - let mut has_prefix = true; - for i in address.virt as i32 - 32..address.virt as i32 { - let opcode = u256_to_u8(self.memory.get(MemoryAddress { + let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( + 0, + |acc, (pos, opcode)| { + let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { + code[prefix_start..pos].iter().enumerate().fold( + true, + |acc, (prefix_pos, &byte)| { + acc && (byte > get_push_opcode(32) + || (prefix_start + prefix_pos) as i32 + + (byte as i32 - get_push_opcode(1) as i32) + + 1 + < pos as i32) + }, + ) + } else { + false + }; + if has_prefix { + pos - 32 + } else { + acc + } + }, + ); + Ok(proof.into()) + } +} + +impl GenerationState { + fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { + let mut state = self.soft_clone(); + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + // Generate the jumpdest table + let code = (0..code_len) + .map(|i| { + u256_to_u8(self.memory.get(MemoryAddress { context: self.registers.context, segment: Segment::Code as usize, - virt: i as usize, - }))?; - if i < 0 - || (opcode >= get_push_opcode(1) - && opcode <= get_push_opcode(32) - && i + (opcode - get_push_opcode(1)) as i32 + 1 >= address.virt as i32) - { - has_prefix = false; - break; - } - } - if has_prefix { - proof = address.virt - 32; + virt: i, + })) + }) + .collect::, _>>()?; + + // We need to set the the simulated jumpdest bits to one as otherwise + // the simulation will fail + let mut jumpdest_table = vec![]; + for (pos, opcode) in CodeIterator::new(&code) { + jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST"))); + if opcode == get_opcode("JUMPDEST") { + state.memory.set( + MemoryAddress { + context: state.registers.context, + segment: Segment::JumpdestBits as usize, + virt: pos, + }, + U256::one(), + ); } } - if address.virt > self.last_jumpdest_address { - return Err(ProgramError::ProverInputError( - ProverInputError::InvalidJumpDestination, - )); + + // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call + self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( + "validate_jumpdest_table_end", + "terminate_common", + &mut state, + ) + .ok(); + + Ok(()) + } +} + +struct CodeIterator<'a> { + code: &'a Vec, + pos: usize, + end: usize, +} + +impl<'a> CodeIterator<'a> { + fn new(code: &'a Vec) -> Self { + CodeIterator { + end: code.len(), + code, + pos: 0, } - Ok(proof.into()) + } + fn until(code: &'a Vec, end: usize) -> Self { + CodeIterator { + end: std::cmp::min(code.len(), end), + code, + pos: 0, + } + } +} + +impl<'a> Iterator for CodeIterator<'a> { + type Item = (usize, u8); + + fn next(&mut self) -> Option { + let CodeIterator { code, pos, end } = self; + if *pos >= *end { + return None; + } + let opcode = code[*pos]; + let old_pos = *pos; + *pos += if opcode >= get_push_opcode(1) && opcode <= get_push_opcode(32) { + (opcode - get_push_opcode(1) + 2).into() + } else { + 1 + }; + Some((old_pos, opcode)) } } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 0fa14321..04688543 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,7 +395,12 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!("User instruction: {:?} stack = {:?}", op, state.stack()); + log::debug!( + "User instruction: {:?} ctx = {:?} stack = {:?}", + op, + state.registers.context, + state.stack() + ); } fill_op_flag(op, &mut row); From ed260980b2983f6fb34ed94b93900887195543f6 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 13 Dec 2023 17:06:42 +0100 Subject: [PATCH 13/37] Fix jumpdest analisys test --- evm/src/cpu/kernel/interpreter.rs | 13 +++-- .../kernel/tests/core/jumpdest_analysis.rs | 56 +++++++++---------- evm/src/generation/prover_input.rs | 6 -- 3 files changed, 37 insertions(+), 38 deletions(-) diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 1efaa71d..c67d793d 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -416,10 +416,15 @@ impl<'a> Interpreter<'a> { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] - .content = jumpdest_bits - .into_iter() - .map(|x| u256_from_bool(x)) - .collect(); + .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); + self.generation_state.jumpdest_addresses = Some( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i) + .collect(), + ) } fn incr(&mut self, n: usize) { diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs index d3edc17b..58e9f936 100644 --- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -4,37 +4,37 @@ use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; -// #[test] -// fn test_jumpdest_analysis() -> Result<()> { -// let jumpdest_analysis = KERNEL.global_labels["validate_jumpdest_table"]; -// const CONTEXT: usize = 3; // arbitrary +#[test] +fn test_validate_jumpdest_table() -> Result<()> { + let validate_jumpdest_table = KERNEL.global_labels["validate_jumpdest_table"]; + const CONTEXT: usize = 3; // arbitrary -// let add = get_opcode("ADD"); -// let push2 = get_push_opcode(2); -// let jumpdest = get_opcode("JUMPDEST"); + let add = get_opcode("ADD"); + let push2 = get_push_opcode(2); + let jumpdest = get_opcode("JUMPDEST"); -// #[rustfmt::skip] -// let code: Vec = vec![ -// add, -// jumpdest, -// push2, -// jumpdest, // part of PUSH2 -// jumpdest, // part of PUSH2 -// jumpdest, -// add, -// jumpdest, -// ]; + #[rustfmt::skip] + let code: Vec = vec![ + add, + jumpdest, + push2, + jumpdest, // part of PUSH2 + jumpdest, // part of PUSH2 + jumpdest, + add, + jumpdest, + ]; -// let jumpdest_bits = vec![false, true, false, false, false, true, false, true]; + let jumpdest_bits = vec![false, true, false, false, false, true, false, true]; -// // Contract creation transaction. -// let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; -// let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack); -// interpreter.set_code(CONTEXT, code); -// interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); + // Contract creation transaction. + let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; + let mut interpreter = Interpreter::new_with_kernel(validate_jumpdest_table, initial_stack); + interpreter.set_code(CONTEXT, code); + interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); -// interpreter.run()?; -// assert_eq!(interpreter.stack(), vec![]); + interpreter.run()?; + assert_eq!(interpreter.stack(), vec![]); -// Ok(()) -// } + Ok(()) +} diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 2435f4eb..53e25db4 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -276,12 +276,6 @@ impl GenerationState { /// Return the proof for the last jump adddress fn run_next_jumpdest_table_proof(&mut self) -> Result { - let code_len = u256_to_usize(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::ContextMetadata as usize, - virt: ContextMetadata::CodeSize as usize, - }))?; - let code = (0..self.last_jumpdest_address) .map(|i| { u256_to_u8(self.memory.get(MemoryAddress { From 0bec6278996adeb3035372dd11013f8184bcf5d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alonso=20Gonz=C3=A1lez?= Date: Fri, 15 Dec 2023 09:49:19 +0100 Subject: [PATCH 14/37] Apply suggestions from code review Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> --- evm/src/cpu/kernel/asm/core/call.asm | 1 - evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 10 +++++----- evm/src/cpu/kernel/interpreter.rs | 2 +- evm/src/generation/mod.rs | 4 ++-- evm/src/generation/prover_input.rs | 14 +++++++------- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 46765954..5173d358 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -370,7 +370,6 @@ call_too_deep: GET_CONTEXT // stack: ctx, code_size, retdest %validate_jumpdest_table - PUSH 0 // jump dest EXIT_KERNEL // (Old context) stack: new_ctx diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 76c25fa0..79475b37 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -122,9 +122,9 @@ global is_jumpdest: //stack: jumpdest, ctx, proof_prefix_addr, retdest SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest - %eq_const(0) + IS_ZERO %jumpi(verify_path) - //stack: proof_prefix_addr, ctx, jumpdest, retdest + // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are less // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, // or larger than 127 @@ -141,7 +141,7 @@ global is_jumpdest: %jump(verify_path) return_is_jumpdest: - //stack: proof_prefix_addr, jumpdest, ctx, retdest + // stack: proof_prefix_addr, jumpdest, ctx, retdest %pop3 JUMP @@ -187,7 +187,7 @@ global validate_jumpdest_table: // and the next prover input should contain a proof for address'. PROVER_INPUT(jumpdest_table::next_address) DUP1 %jumpi(check_proof) - // If proof == 0 there are no more jump destionations to check + // If proof == 0 there are no more jump destinations to check POP // This is just a hook used for avoiding verification of the jumpdest // table in another contexts. It is useful during proof generation, @@ -196,7 +196,7 @@ global validate_jumpdest_table_end: POP JUMP check_proof: - %sub_const(1) + %decrement DUP2 DUP2 // stack: address, ctx, address, ctx // We read the proof diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index c67d793d..5645045c 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -427,7 +427,7 @@ impl<'a> Interpreter<'a> { ) } - fn incr(&mut self, n: usize) { + const fn incr(&mut self, n: usize) { self.generation_state.registers.program_counter += n; } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 995e067b..b6146260 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -345,7 +345,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( state.registers.program_counter = KERNEL.global_labels[initial_label]; let context = state.registers.context; - log::debug!("Simulating CPU for jumpdest analysis "); + log::debug!("Simulating CPU for jumpdest analysis."); loop { if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { @@ -369,7 +369,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( { // TODO: hotfix for avoiding deeper calls to abort let jumpdest = u256_to_usize(state.registers.stack_top) - .map_err(|_| anyhow::Error::msg("Not a valid jump destination"))?; + .map_err(|_| anyhow!("Not a valid jump destination"))?; state.memory.set( MemoryAddress { context: state.registers.context, diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 53e25db4..1808f3f3 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -274,7 +274,7 @@ impl GenerationState { } } - /// Return the proof for the last jump adddress + /// Returns the proof for the last jump address. fn run_next_jumpdest_table_proof(&mut self) -> Result { let code = (0..self.last_jumpdest_address) .map(|i| { @@ -286,12 +286,12 @@ impl GenerationState { }) .collect::, _>>()?; - // TODO: The proof searching algorithm is not very eficient. But luckyly it doesn't seem + // TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem // a problem as is done natively. - // Search the closest address to last_jumpdest_address for which none of + // Search the closest address to `last_jumpdest_address` for which none of // the previous 32 bytes in the code (including opcodes and pushed bytes) - // are PUSHXX and the address is in its range + // are PUSHXX and the address is in its range. let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( 0, @@ -340,9 +340,9 @@ impl GenerationState { }) .collect::, _>>()?; - // We need to set the the simulated jumpdest bits to one as otherwise - // the simulation will fail - let mut jumpdest_table = vec![]; + // We need to set the simulated jumpdest bits to one as otherwise + // the simulation will fail. + let mut jumpdest_table = Vec::with_capacity(code.len()); for (pos, opcode) in CodeIterator::new(&code) { jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST"))); if opcode == get_opcode("JUMPDEST") { From 5acabad72d31d244cb68a4214e473b9616b03e59 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 17:11:00 +0100 Subject: [PATCH 15/37] Eliminate nested simulations --- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 7 +- evm/src/cpu/kernel/interpreter.rs | 22 ++-- evm/src/generation/mod.rs | 109 ++++++++++-------- evm/src/generation/prover_input.rs | 45 ++++---- evm/src/generation/state.rs | 4 +- evm/src/witness/transition.rs | 7 +- 6 files changed, 103 insertions(+), 91 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 79475b37..cfc3575b 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -114,15 +114,12 @@ global is_jumpdest: MLOAD_GENERAL // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest - // Slightly more efficient than `%eq_const(0x5b) ISZERO` - PUSH 0x5b - SUB - %jumpi(panic) + %assert_eq_const(0x5b) //stack: jumpdest, ctx, proof_prefix_addr, retdest SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest - IS_ZERO + ISZERO %jumpi(verify_path) // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are less diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 5645045c..177ac4f0 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -1,7 +1,7 @@ //! An EVM interpreter for testing and debugging purposes. use core::cmp::Ordering; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::ops::Range; use anyhow::bail; @@ -417,17 +417,19 @@ impl<'a> Interpreter<'a> { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); - self.generation_state.jumpdest_addresses = Some( - jumpdest_bits - .into_iter() - .enumerate() - .filter(|&(_, x)| x) - .map(|(i, _)| i) - .collect(), - ) + self.generation_state.jumpdest_addresses = Some(HashMap::from([( + context, + BTreeSet::from_iter( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i), + ), + )])); } - const fn incr(&mut self, n: usize) { + pub(crate) fn incr(&mut self, n: usize) { self.generation_state.registers.program_counter += n; } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index b6146260..1919b40d 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashMap, HashSet}; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -31,6 +31,7 @@ use crate::memory::segments::Segment; use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use crate::prover::check_abort_signal; use crate::util::{h2u, u256_to_u8, u256_to_usize}; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::{MemoryAddress, MemoryChannel}; use crate::witness::transition::transition; @@ -339,56 +340,68 @@ fn simulate_cpu_between_labels_and_get_user_jumps( initial_label: &str, final_label: &str, state: &mut GenerationState, -) -> anyhow::Result> { - let halt_pc = KERNEL.global_labels[final_label]; - let mut jumpdest_addresses = HashSet::new(); - state.registers.program_counter = KERNEL.global_labels[initial_label]; - let context = state.registers.context; +) -> Result<(), ProgramError> { + if let Some(_) = state.jumpdest_addresses { + Ok(()) + } else { + const JUMP_OPCODE: u8 = 0x56; + const JUMPI_OPCODE: u8 = 0x57; - log::debug!("Simulating CPU for jumpdest analysis."); + let halt_pc = KERNEL.global_labels[final_label]; + let mut jumpdest_addresses: HashMap<_, BTreeSet> = HashMap::new(); - loop { - if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { - state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] - } - let pc = state.registers.program_counter; - let halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == context; - let opcode = u256_to_u8(state.memory.get(MemoryAddress { - context: state.registers.context, - segment: Segment::Code as usize, - virt: state.registers.program_counter, - })) - .map_err(|_| anyhow::Error::msg("Invalid opcode."))?; - let cond = if let Ok(cond) = stack_peek(state, 1) { - cond != U256::zero() - } else { - false - }; - if !state.registers.is_kernel - && (opcode == get_opcode("JUMP") || (opcode == get_opcode("JUMPI") && cond)) - { - // TODO: hotfix for avoiding deeper calls to abort - let jumpdest = u256_to_usize(state.registers.stack_top) - .map_err(|_| anyhow!("Not a valid jump destination"))?; - state.memory.set( - MemoryAddress { - context: state.registers.context, - segment: Segment::JumpdestBits as usize, - virt: jumpdest, - }, - U256::one(), - ); - if (state.registers.context == context) { - jumpdest_addresses.insert(jumpdest); + state.registers.program_counter = KERNEL.global_labels[initial_label]; + let initial_context = state.registers.context; + + log::debug!("Simulating CPU for jumpdest analysis."); + + loop { + if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { + state.registers.program_counter = + KERNEL.global_labels["validate_jumpdest_table_end"] } + let pc = state.registers.program_counter; + let context = state.registers.context; + let halt = state.registers.is_kernel + && pc == halt_pc + && state.registers.context == initial_context; + let opcode = u256_to_u8(state.memory.get(MemoryAddress { + context, + segment: Segment::Code as usize, + virt: state.registers.program_counter, + }))?; + let cond = if let Ok(cond) = stack_peek(state, 1) { + cond != U256::zero() + } else { + false + }; + if !state.registers.is_kernel + && (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond)) + { + // Avoid deeper calls to abort + let jumpdest = u256_to_usize(state.registers.stack_top)?; + state.memory.set( + MemoryAddress { + context, + segment: Segment::JumpdestBits as usize, + virt: jumpdest, + }, + U256::one(), + ); + if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) { + ctx_addresses.insert(jumpdest); + } else { + jumpdest_addresses.insert(context, BTreeSet::from([jumpdest])); + } + } + if halt { + log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); + state.jumpdest_addresses = Some(jumpdest_addresses); + return Ok(()); + } + transition(state).map_err(|_| { + ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) + })?; } - if halt { - log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); - let mut jumpdest_addresses: Vec = jumpdest_addresses.into_iter().collect(); - jumpdest_addresses.sort(); - return Ok(jumpdest_addresses); - } - - transition(state)?; } } diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 1808f3f3..35571dcb 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -16,7 +16,6 @@ use serde::{Deserialize, Serialize}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; -use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254}; use crate::generation::prover_input::EvmField::{ Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, @@ -250,8 +249,9 @@ impl GenerationState { } /// Return the next used jump addres fn run_next_jumpdest_table_address(&mut self) -> Result { + let context = self.registers.context; let code_len = u256_to_usize(self.memory.get(MemoryAddress { - context: self.registers.context, + context, segment: Segment::ContextMetadata as usize, virt: ContextMetadata::CodeSize as usize, }))?; @@ -260,14 +260,14 @@ impl GenerationState { self.generate_jumpdest_table()?; } - let Some(jumpdest_table) = &mut self.jumpdest_addresses else { - // TODO: Add another error + let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)); }; - if let Some(next_jumpdest_address) = jumpdest_table.pop() { - self.last_jumpdest_address = next_jumpdest_address; - Ok((next_jumpdest_address + 1).into()) + if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() + { + self.last_jumpdest_address = next_jumpdest_address; + Ok((next_jumpdest_address + 1).into()) } else { self.jumpdest_addresses = None; Ok(U256::zero()) @@ -293,6 +293,9 @@ impl GenerationState { // the previous 32 bytes in the code (including opcodes and pushed bytes) // are PUSHXX and the address is in its range. + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x7f; + let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( 0, |acc, (pos, opcode)| { @@ -300,9 +303,9 @@ impl GenerationState { code[prefix_start..pos].iter().enumerate().fold( true, |acc, (prefix_pos, &byte)| { - acc && (byte > get_push_opcode(32) + acc && (byte > PUSH32_OPCODE || (prefix_start + prefix_pos) as i32 - + (byte as i32 - get_push_opcode(1) as i32) + + (byte as i32 - PUSH1_OPCODE as i32) + 1 < pos as i32) }, @@ -323,6 +326,7 @@ impl GenerationState { impl GenerationState { fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { + const JUMPDEST_OPCODE: u8 = 0x5b; let mut state = self.soft_clone(); let code_len = u256_to_usize(self.memory.get(MemoryAddress { context: self.registers.context, @@ -344,8 +348,8 @@ impl GenerationState { // the simulation will fail. let mut jumpdest_table = Vec::with_capacity(code.len()); for (pos, opcode) in CodeIterator::new(&code) { - jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST"))); - if opcode == get_opcode("JUMPDEST") { + jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE)); + if opcode == JUMPDEST_OPCODE { state.memory.set( MemoryAddress { context: state.registers.context, @@ -358,32 +362,31 @@ impl GenerationState { } // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call - self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( + simulate_cpu_between_labels_and_get_user_jumps( "validate_jumpdest_table_end", "terminate_common", &mut state, - ) - .ok(); - + )?; + self.jumpdest_addresses = state.jumpdest_addresses; Ok(()) } } struct CodeIterator<'a> { - code: &'a Vec, + code: &'a [u8], pos: usize, end: usize, } impl<'a> CodeIterator<'a> { - fn new(code: &'a Vec) -> Self { + fn new(code: &'a [u8]) -> Self { CodeIterator { end: code.len(), code, pos: 0, } } - fn until(code: &'a Vec, end: usize) -> Self { + fn until(code: &'a [u8], end: usize) -> Self { CodeIterator { end: std::cmp::min(code.len(), end), code, @@ -396,14 +399,16 @@ impl<'a> Iterator for CodeIterator<'a> { type Item = (usize, u8); fn next(&mut self) -> Option { + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x70; let CodeIterator { code, pos, end } = self; if *pos >= *end { return None; } let opcode = code[*pos]; let old_pos = *pos; - *pos += if opcode >= get_push_opcode(1) && opcode <= get_push_opcode(32) { - (opcode - get_push_opcode(1) + 2).into() + *pos += if opcode >= PUSH1_OPCODE && opcode <= PUSH32_OPCODE { + (opcode - PUSH1_OPCODE + 2).into() } else { 1 }; diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 79dd94fb..de07c942 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use ethereum_types::{Address, BigEndianHash, H160, H256, U256}; use keccak_hash::keccak; @@ -52,7 +52,7 @@ pub(crate) struct GenerationState { pub(crate) trie_root_ptrs: TrieRootPtrs, pub(crate) last_jumpdest_address: usize, - pub(crate) jumpdest_addresses: Option>, + pub(crate) jumpdest_addresses: Option>>, } impl GenerationState { diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index 04688543..cf2e3bbe 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,12 +395,7 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!( - "User instruction: {:?} ctx = {:?} stack = {:?}", - op, - state.registers.context, - state.stack() - ); + log::debug!("User instruction: {:?}", op); } fill_op_flag(op, &mut row); From 08982498d6e75174826a941a9f277920b47d1149 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 17:13:52 +0100 Subject: [PATCH 16/37] Remove U256::as_u8 in comment --- evm/src/util.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/src/util.rs b/evm/src/util.rs index bbbd8af1..fa19ef09 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,7 +70,7 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } -/// Safe alternative to `U256::as_u8()`, which errors in case of overflow instead of panicking. +/// Safe conversion from U256 to u8, which errors in case of overflow instead of panicking. pub(crate) fn u256_to_u8(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) } From aaa38b33ba898599007c71ed865c9b3b4edb2e41 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 18:14:47 +0100 Subject: [PATCH 17/37] Fix fmt --- evm/src/generation/prover_input.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 35571dcb..3a298e81 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -261,13 +261,16 @@ impl GenerationState { } let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { - return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)); + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )); }; - if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() + if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) + && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() { - self.last_jumpdest_address = next_jumpdest_address; - Ok((next_jumpdest_address + 1).into()) + self.last_jumpdest_address = next_jumpdest_address; + Ok((next_jumpdest_address + 1).into()) } else { self.jumpdest_addresses = None; Ok(U256::zero()) From c4025063dedfbf9bdc7a6e032bdc0fefd20a4c93 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 15 Dec 2023 18:52:40 +0100 Subject: [PATCH 18/37] Clippy --- evm/src/generation/mod.rs | 2 +- evm/src/generation/prover_input.rs | 2 +- evm/src/generation/state.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 1919b40d..8f568d90 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -341,7 +341,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( final_label: &str, state: &mut GenerationState, ) -> Result<(), ProgramError> { - if let Some(_) = state.jumpdest_addresses { + if state.jumpdest_addresses.is_some() { Ok(()) } else { const JUMP_OPCODE: u8 = 0x56; diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 3a298e81..926b876d 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -410,7 +410,7 @@ impl<'a> Iterator for CodeIterator<'a> { } let opcode = code[*pos]; let old_pos = *pos; - *pos += if opcode >= PUSH1_OPCODE && opcode <= PUSH32_OPCODE { + *pos += if (PUSH1_OPCODE..=PUSH32_OPCODE).contains(&opcode) { (opcode - PUSH1_OPCODE + 2).into() } else { 1 diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index de07c942..1c50cc29 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -177,7 +177,7 @@ impl GenerationState { pub(crate) fn soft_clone(&self) -> GenerationState { Self { inputs: self.inputs.clone(), - registers: self.registers.clone(), + registers: self.registers, memory: self.memory.clone(), traces: Traces::default(), rlp_prover_inputs: self.rlp_prover_inputs.clone(), From 4e569484c2c4e1b15a19bb24fd1ac01a16b4a9ad Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Tue, 19 Dec 2023 14:05:51 +0100 Subject: [PATCH 19/37] Improve proof generation --- evm/src/cpu/kernel/interpreter.rs | 21 +-- evm/src/generation/mod.rs | 10 +- evm/src/generation/prover_input.rs | 207 ++++++++++++++++++----------- evm/src/generation/state.rs | 9 +- 4 files changed, 148 insertions(+), 99 deletions(-) diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 177ac4f0..30f862cd 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -417,16 +417,17 @@ impl<'a> Interpreter<'a> { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); - self.generation_state.jumpdest_addresses = Some(HashMap::from([( - context, - BTreeSet::from_iter( - jumpdest_bits - .into_iter() - .enumerate() - .filter(|&(_, x)| x) - .map(|(i, _)| i), - ), - )])); + self.generation_state + .set_proofs_and_jumpdests(HashMap::from([( + context, + BTreeSet::from_iter( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i), + ), + )])); } pub(crate) fn incr(&mut self, n: usize) { diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 8f568d90..81bacb75 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -340,9 +340,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps( initial_label: &str, final_label: &str, state: &mut GenerationState, -) -> Result<(), ProgramError> { - if state.jumpdest_addresses.is_some() { - Ok(()) +) -> Result>>, ProgramError> { + if state.jumpdest_proofs.is_some() { + Ok(None) } else { const JUMP_OPCODE: u8 = 0x56; const JUMPI_OPCODE: u8 = 0x57; @@ -356,6 +356,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( log::debug!("Simulating CPU for jumpdest analysis."); loop { + // skip jumdest table validations in simulations if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] @@ -396,8 +397,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( } if halt { log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); - state.jumpdest_addresses = Some(jumpdest_addresses); - return Ok(()); + return Ok(Some(jumpdest_addresses)); } transition(state).map_err(|_| { ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 926b876d..a5f73ae6 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,11 +1,10 @@ use std::cmp::min; -use std::collections::HashSet; +use std::collections::HashMap; use std::mem::transmute; use std::str::FromStr; use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256, U512}; -use hashbrown::HashMap; use itertools::{enumerate, Itertools}; use num_bigint::BigUint; use plonky2::field::extension::Extendable; @@ -256,87 +255,98 @@ impl GenerationState { virt: ContextMetadata::CodeSize as usize, }))?; - if self.jumpdest_addresses.is_none() { - self.generate_jumpdest_table()?; + if self.jumpdest_proofs.is_none() { + self.generate_jumpdest_proofs()?; } - let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { + let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { return Err(ProgramError::ProverInputError( ProverInputError::InvalidJumpdestSimulation, )); }; - if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) - && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() + if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) + && let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop() { - self.last_jumpdest_address = next_jumpdest_address; Ok((next_jumpdest_address + 1).into()) } else { - self.jumpdest_addresses = None; + self.jumpdest_proofs = None; Ok(U256::zero()) } } /// Returns the proof for the last jump address. fn run_next_jumpdest_table_proof(&mut self) -> Result { - let code = (0..self.last_jumpdest_address) - .map(|i| { - u256_to_u8(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::Code as usize, - virt: i, - })) - }) - .collect::, _>>()?; - - // TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem - // a problem as is done natively. - - // Search the closest address to `last_jumpdest_address` for which none of - // the previous 32 bytes in the code (including opcodes and pushed bytes) - // are PUSHXX and the address is in its range. - - const PUSH1_OPCODE: u8 = 0x60; - const PUSH32_OPCODE: u8 = 0x7f; - - let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold( - 0, - |acc, (pos, opcode)| { - let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { - code[prefix_start..pos].iter().enumerate().fold( - true, - |acc, (prefix_pos, &byte)| { - acc && (byte > PUSH32_OPCODE - || (prefix_start + prefix_pos) as i32 - + (byte as i32 - PUSH1_OPCODE as i32) - + 1 - < pos as i32) - }, - ) - } else { - false - }; - if has_prefix { - pos - 32 - } else { - acc - } - }, - ); - Ok(proof.into()) + let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )); + }; + if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) + && let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop() + { + Ok(next_jumpdest_proof.into()) + } else { + Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )) + } } } impl GenerationState { - fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { - const JUMPDEST_OPCODE: u8 = 0x5b; - let mut state = self.soft_clone(); - let code_len = u256_to_usize(self.memory.get(MemoryAddress { - context: self.registers.context, - segment: Segment::ContextMetadata as usize, - virt: ContextMetadata::CodeSize as usize, - }))?; - // Generate the jumpdest table + fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> { + let checkpoint = self.checkpoint(); + let memory = self.memory.clone(); + + let code = self.get_current_code()?; + // We need to set the simulated jumpdest bits to one as otherwise + // the simulation will fail. + self.set_jumpdest_bits(&code); + + // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call + let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( + "validate_jumpdest_table_end", + "terminate_common", + self, + )? + else { + return Ok(()); + }; + + // Return to the state before starting the simulation + self.rollback(checkpoint); + self.memory = memory; + + // Find proofs for all context + self.set_proofs_and_jumpdests(jumpdest_table); + + Ok(()) + } + + pub(crate) fn set_proofs_and_jumpdests( + &mut self, + jumpdest_table: HashMap>, + ) { + self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map( + |(ctx, jumpdest_table)| { + let code = self.get_code(ctx).unwrap(); + if let Some(&largest_address) = jumpdest_table.last() { + let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table); + (ctx, proofs) + } else { + (ctx, vec![]) + } + }, + ))); + } + + fn get_current_code(&self) -> Result, ProgramError> { + self.get_code(self.registers.context) + } + + fn get_code(&self, context: usize) -> Result, ProgramError> { + let code_len = self.get_code_len()?; let code = (0..code_len) .map(|i| { u256_to_u8(self.memory.get(MemoryAddress { @@ -346,16 +356,25 @@ impl GenerationState { })) }) .collect::, _>>()?; + Ok(code) + } - // We need to set the simulated jumpdest bits to one as otherwise - // the simulation will fail. - let mut jumpdest_table = Vec::with_capacity(code.len()); + fn get_code_len(&self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + Ok(code_len) + } + + fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec) { + const JUMPDEST_OPCODE: u8 = 0x5b; for (pos, opcode) in CodeIterator::new(&code) { - jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE)); if opcode == JUMPDEST_OPCODE { - state.memory.set( + self.memory.set( MemoryAddress { - context: state.registers.context, + context: self.registers.context, segment: Segment::JumpdestBits as usize, virt: pos, }, @@ -363,18 +382,50 @@ impl GenerationState { ); } } - - // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call - simulate_cpu_between_labels_and_get_user_jumps( - "validate_jumpdest_table_end", - "terminate_common", - &mut state, - )?; - self.jumpdest_addresses = state.jumpdest_addresses; - Ok(()) } } +/// For each address in `jumpdest_table` it search a proof, that is the closest address +/// for which none of the previous 32 bytes in the code (including opcodes +/// and pushed bytes are PUSHXX and the address is in its range. It returns +/// a vector of even size containing proofs followed by their addresses +fn get_proofs_and_jumpdests<'a>( + code: &'a Vec, + largest_address: usize, + jumpdest_table: std::collections::BTreeSet, +) -> Vec { + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x7f; + let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold( + (vec![], 0), + |(mut proofs, acc), (pos, opcode)| { + let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { + code[prefix_start..pos] + .iter() + .enumerate() + .fold(true, |acc, (prefix_pos, &byte)| { + acc && (byte > PUSH32_OPCODE + || (prefix_start + prefix_pos) as i32 + + (byte as i32 - PUSH1_OPCODE as i32) + + 1 + < pos as i32) + }) + } else { + false + }; + let acc = if has_prefix { pos - 32 } else { acc }; + if jumpdest_table.contains(&pos) { + // Push the proof + proofs.push(acc); + // Push the address + proofs.push(pos); + } + (proofs, acc) + }, + ); + proofs +} + struct CodeIterator<'a> { code: &'a [u8], pos: usize, diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 1c50cc29..cc1df091 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -51,8 +51,7 @@ pub(crate) struct GenerationState { /// Pointers, within the `TrieData` segment, of the three MPTs. pub(crate) trie_root_ptrs: TrieRootPtrs, - pub(crate) last_jumpdest_address: usize, - pub(crate) jumpdest_addresses: Option>>, + pub(crate) jumpdest_proofs: Option>>, } impl GenerationState { @@ -94,8 +93,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, - last_jumpdest_address: 0, - jumpdest_addresses: None, + jumpdest_proofs: None, }; let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); @@ -189,8 +187,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, - last_jumpdest_address: 0, - jumpdest_addresses: None, + jumpdest_proofs: None, } } } From 11d668f5e6ceabd3ef10d93123a033a320a347b0 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 20 Dec 2023 14:13:36 +0100 Subject: [PATCH 20/37] Remove aborts for invalid jumps --- evm/src/cpu/kernel/asm/core/call.asm | 2 +- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 74 +++++++++---------- evm/src/cpu/kernel/asm/util/basic_macros.asm | 13 ++++ .../kernel/tests/core/jumpdest_analysis.rs | 6 +- evm/src/generation/mod.rs | 16 +++- evm/src/generation/prover_input.rs | 5 +- evm/src/witness/transition.rs | 7 +- 7 files changed, 72 insertions(+), 51 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 5173d358..fcb4eb32 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -369,7 +369,7 @@ call_too_deep: // Perform jumpdest analyis GET_CONTEXT // stack: ctx, code_size, retdest - %validate_jumpdest_table + %jumpdest_analisys PUSH 0 // jump dest EXIT_KERNEL // (Old context) stack: new_ctx diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index cfc3575b..97224b3e 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -1,17 +1,14 @@ // Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos], -// for the given context's code. Panics if we never hit final_pos +// for the given context's code. // Pre stack: init_pos, ctx, final_pos, retdest // Post stack: (empty) -global verify_path: +global verify_path_and_write_table: loop: // stack: i, ctx, final_pos, retdest - // Ideally we would break if i >= final_pos, but checking i > final_pos is - // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is - // a no-op. DUP3 DUP2 EQ // i == final_pos - %jumpi(return) + %jumpi(proof_ok) DUP3 DUP2 GT // i > final_pos - %jumpi(panic) + %jumpi(proof_not_ok) // stack: i, ctx, final_pos, retdest %stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx) @@ -22,24 +19,29 @@ loop: // Slightly more efficient than `%eq_const(0x5b) ISZERO` PUSH 0x5b SUB - // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest + // stack: opcode != JUMPDEST, opcode, i, ctx, final_pos, retdest %jumpi(continue) - // stack: JUMPDEST, i, ctx, code_len, retdest + // stack: JUMPDEST, i, ctx, final_pos, retdest %stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx) MSTORE_GENERAL continue: - // stack: opcode, i, ctx, code_len, retdest + // stack: opcode, i, ctx, final_pos, retdest %add_const(code_bytes_to_skip) %mload_kernel_code - // stack: bytes_to_skip, i, ctx, code_len, retdest + // stack: bytes_to_skip, i, ctx, final_pos, retdest ADD - // stack: i, ctx, code_len, retdest + // stack: i, ctx, final_pos, retdest %jump(loop) -return: - // stack: i, ctx, code_len, retdest +proof_ok: + // stack: i, ctx, final_pos, retdest + // We already know final pos is a jumpdest + %stack (i, ctx, final_pos) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i) + MSTORE_GENERAL + JUMP +proof_not_ok: %pop3 JUMP @@ -101,26 +103,21 @@ code_bytes_to_skip: // - code[jumpdest] = 0x5b. // stack: proof_prefix_addr, jumpdest, ctx, retdest // stack: (empty) abort if jumpdest is not a valid destination -global is_jumpdest: +global write_table_if_jumpdest: // stack: proof_prefix_addr, jumpdest, ctx, retdest - //%stack - // (proof_prefix_addr, jumpdest, ctx) -> - // (ctx, @SEGMENT_JUMPDEST_BITS, jumpdest, proof_prefix_addr, jumpdest, ctx) - //MLOAD_GENERAL - //%jumpi(return_is_jumpdest) %stack (proof_prefix_addr, jumpdest, ctx) -> (ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr) MLOAD_GENERAL // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest - %assert_eq_const(0x5b) + %jump_eq_const(0x5b, return) //stack: jumpdest, ctx, proof_prefix_addr, retdest SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest ISZERO - %jumpi(verify_path) + %jumpi(verify_path_and_write_table) // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are less // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, @@ -135,9 +132,8 @@ global is_jumpdest: %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) // check the remaining path - %jump(verify_path) - -return_is_jumpdest: + %jump(verify_path_and_write_table) +return: // stack: proof_prefix_addr, jumpdest, ctx, retdest %pop3 JUMP @@ -154,7 +150,7 @@ return_is_jumpdest: DUP1 %gt_const(127) %jumpi(%%ok) - %assert_lt_const($max) + %jumpi_lt_const($max, return) // stack: proof_prefix_addr, ctx, jumpdest PUSH 0 // We need something to pop %%ok: @@ -162,13 +158,13 @@ return_is_jumpdest: %increment %endmacro -%macro is_jumpdest +%macro write_table_if_jumpdest %stack (proof, addr, ctx) -> (proof, addr, ctx, %%after) - %jump(is_jumpdest) + %jump(write_table_if_jumpdest) %%after: %endmacro -// Check if the jumpdest table is correct. This is done by +// Write the jumpdest table. This is done by // non-deterministically guessing the sequence of jumpdest // addresses used during program execution within the current context. // For each jumpdest address we also non-deterministically guess @@ -179,7 +175,7 @@ return_is_jumpdest: // // stack: ctx, retdest // stack: (empty) -global validate_jumpdest_table: +global jumpdest_analisys: // If address > 0 then address is interpreted as address' + 1 // and the next prover input should contain a proof for address'. PROVER_INPUT(jumpdest_table::next_address) @@ -189,24 +185,22 @@ global validate_jumpdest_table: // This is just a hook used for avoiding verification of the jumpdest // table in another contexts. It is useful during proof generation, // allowing the avoidance of table verification when simulating user code. -global validate_jumpdest_table_end: +global jumpdest_analisys_end: POP JUMP check_proof: %decrement - DUP2 DUP2 - // stack: address, ctx, address, ctx + DUP2 SWAP1 + // stack: address, ctx, ctx // We read the proof PROVER_INPUT(jumpdest_table::next_proof) - // stack: proof, address, ctx, address, ctx - %is_jumpdest - %stack (address, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address, ctx) - MSTORE_GENERAL + // stack: proof, address, ctx, ctx + %write_table_if_jumpdest - %jump(validate_jumpdest_table) + %jump(jumpdest_analisys) -%macro validate_jumpdest_table +%macro jumpdest_analisys %stack (ctx) -> (ctx, %%after) - %jump(validate_jumpdest_table) + %jump(jumpdest_analisys) %%after: %endmacro diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index fc2472b3..d62dc27e 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -8,6 +8,19 @@ jumpi %endmacro +%macro jump_eq_const(c, jumpdest) + PUSH $c + SUB + %jumpi($jumpdest) +%endmacro + +%macro jumpi_lt_const(c, jumpdest) + // %assert_zero is cheaper than %assert_nonzero, so we will leverage the + // fact that (x < c) == !(x >= c). + %ge_const($c) + %jumpi($jumpdest) +%endmacro + %macro pop2 %rep 2 POP diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs index 58e9f936..3d97251c 100644 --- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -5,8 +5,8 @@ use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; #[test] -fn test_validate_jumpdest_table() -> Result<()> { - let validate_jumpdest_table = KERNEL.global_labels["validate_jumpdest_table"]; +fn test_jumpdest_analisys() -> Result<()> { + let jumpdest_analisys = KERNEL.global_labels["jumpdest_analisys"]; const CONTEXT: usize = 3; // arbitrary let add = get_opcode("ADD"); @@ -29,7 +29,7 @@ fn test_validate_jumpdest_table() -> Result<()> { // Contract creation transaction. let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; - let mut interpreter = Interpreter::new_with_kernel(validate_jumpdest_table, initial_stack); + let mut interpreter = Interpreter::new_with_kernel(jumpdest_analisys, initial_stack); interpreter.set_code(CONTEXT, code); interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 81bacb75..72a0dca9 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -351,15 +351,15 @@ fn simulate_cpu_between_labels_and_get_user_jumps( let mut jumpdest_addresses: HashMap<_, BTreeSet> = HashMap::new(); state.registers.program_counter = KERNEL.global_labels[initial_label]; + let initial_clock = state.traces.clock(); let initial_context = state.registers.context; log::debug!("Simulating CPU for jumpdest analysis."); loop { // skip jumdest table validations in simulations - if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { - state.registers.program_counter = - KERNEL.global_labels["validate_jumpdest_table_end"] + if state.registers.program_counter == KERNEL.global_labels["jumpdest_analisys"] { + state.registers.program_counter = KERNEL.global_labels["jumpdest_analisys_end"] } let pc = state.registers.program_counter; let context = state.registers.context; @@ -389,6 +389,11 @@ fn simulate_cpu_between_labels_and_get_user_jumps( }, U256::one(), ); + let jumpdest_opcode = state.memory.get(MemoryAddress { + context, + segment: Segment::Code as usize, + virt: jumpdest, + }); if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) { ctx_addresses.insert(jumpdest); } else { @@ -396,7 +401,10 @@ fn simulate_cpu_between_labels_and_get_user_jumps( } } if halt { - log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); + log::debug!( + "Simulated CPU halted after {} cycles", + state.traces.clock() - initial_clock + ); return Ok(Some(jumpdest_addresses)); } transition(state).map_err(|_| { diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index a5f73ae6..ca9208ea 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -306,7 +306,7 @@ impl GenerationState { // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( - "validate_jumpdest_table_end", + "jumpdest_analisys_end", "terminate_common", self, )? @@ -385,7 +385,8 @@ impl GenerationState { } } -/// For each address in `jumpdest_table` it search a proof, that is the closest address +/// For each address in `jumpdest_table`, each bounded by larges_address, +/// this function searches for a proof. A proof is the closest address /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes are PUSHXX and the address is in its range. It returns /// a vector of even size containing proofs followed by their addresses diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index cf2e3bbe..b8f962e7 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,7 +395,12 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!("User instruction: {:?}", op); + log::debug!( + "User instruction: {:?}, ctx = {:?}, stack = {:?}", + op, + state.registers.context, + state.stack() + ); } fill_op_flag(op, &mut row); From 3e78865d644218836768a811dce902cc43b26a6a Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 13 Dec 2023 17:33:53 +0100 Subject: [PATCH 21/37] Remove aborts for invalid jumps and Rebase --- evm/src/cpu/kernel/asm/core/call.asm | 5 +- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 157 +++++++++-- evm/src/cpu/kernel/asm/util/basic_macros.asm | 13 + evm/src/cpu/kernel/interpreter.rs | 21 +- .../kernel/tests/core/jumpdest_analysis.rs | 16 +- evm/src/generation/mod.rs | 92 ++++++- evm/src/generation/prover_input.rs | 244 +++++++++++++++++- evm/src/generation/state.rs | 25 +- evm/src/util.rs | 5 + evm/src/witness/errors.rs | 2 + evm/src/witness/transition.rs | 7 +- 11 files changed, 541 insertions(+), 46 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 2e7d1d73..fcb4eb32 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -367,12 +367,9 @@ call_too_deep: %checkpoint // Checkpoint %increment_call_depth // Perform jumpdest analyis - PUSH %%after - %mload_context_metadata(@CTX_METADATA_CODE_SIZE) GET_CONTEXT // stack: ctx, code_size, retdest - %jump(jumpdest_analysis) -%%after: + %jumpdest_analisys PUSH 0 // jump dest EXIT_KERNEL // (Old context) stack: new_ctx diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index bda6f96e..97224b3e 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -1,45 +1,47 @@ -// Populates @SEGMENT_JUMPDEST_BITS for the given context's code. -// Pre stack: ctx, code_len, retdest +// Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos], +// for the given context's code. +// Pre stack: init_pos, ctx, final_pos, retdest // Post stack: (empty) -global jumpdest_analysis: - // stack: ctx, code_len, retdest - PUSH 0 // i = 0 - +global verify_path_and_write_table: loop: - // stack: i, ctx, code_len, retdest - // Ideally we would break if i >= code_len, but checking i > code_len is - // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is - // a no-op. - DUP3 DUP2 GT // i > code_len - %jumpi(return) + // stack: i, ctx, final_pos, retdest + DUP3 DUP2 EQ // i == final_pos + %jumpi(proof_ok) + DUP3 DUP2 GT // i > final_pos + %jumpi(proof_not_ok) - // stack: i, ctx, code_len, retdest + // stack: i, ctx, final_pos, retdest %stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx) MLOAD_GENERAL - // stack: opcode, i, ctx, code_len, retdest + // stack: opcode, i, ctx, final_pos, retdest DUP1 // Slightly more efficient than `%eq_const(0x5b) ISZERO` PUSH 0x5b SUB - // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest + // stack: opcode != JUMPDEST, opcode, i, ctx, final_pos, retdest %jumpi(continue) - // stack: JUMPDEST, i, ctx, code_len, retdest + // stack: JUMPDEST, i, ctx, final_pos, retdest %stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx) MSTORE_GENERAL continue: - // stack: opcode, i, ctx, code_len, retdest + // stack: opcode, i, ctx, final_pos, retdest %add_const(code_bytes_to_skip) %mload_kernel_code - // stack: bytes_to_skip, i, ctx, code_len, retdest + // stack: bytes_to_skip, i, ctx, final_pos, retdest ADD - // stack: i, ctx, code_len, retdest + // stack: i, ctx, final_pos, retdest %jump(loop) -return: - // stack: i, ctx, code_len, retdest +proof_ok: + // stack: i, ctx, final_pos, retdest + // We already know final pos is a jumpdest + %stack (i, ctx, final_pos) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i) + MSTORE_GENERAL + JUMP +proof_not_ok: %pop3 JUMP @@ -89,3 +91,116 @@ code_bytes_to_skip: %rep 128 BYTES 1 // 0x80-0xff %endrep + + +// A proof attesting that jumpdest is a valid jump destinations is +// either 0 or an index 0 < i <= jumpdest - 32. +// A proof is valid if: +// - i == 0 and we can go from the first opcode to jumpdest and code[jumpdest] = 0x5b +// - i > 0 and: +// - for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - j - i, +// - we can go from opcode i+32 to jumpdest, +// - code[jumpdest] = 0x5b. +// stack: proof_prefix_addr, jumpdest, ctx, retdest +// stack: (empty) abort if jumpdest is not a valid destination +global write_table_if_jumpdest: + // stack: proof_prefix_addr, jumpdest, ctx, retdest + %stack + (proof_prefix_addr, jumpdest, ctx) -> + (ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr) + MLOAD_GENERAL + // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest + + %jump_eq_const(0x5b, return) + + //stack: jumpdest, ctx, proof_prefix_addr, retdest + SWAP2 DUP1 + // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest + ISZERO + %jumpi(verify_path_and_write_table) + // stack: proof_prefix_addr, ctx, jumpdest, retdest + // If we are here we need to check that the next 32 bytes are less + // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, + // or larger than 127 + %check_and_step(127) %check_and_step(126) %check_and_step(125) %check_and_step(124) + %check_and_step(123) %check_and_step(122) %check_and_step(121) %check_and_step(120) + %check_and_step(119) %check_and_step(118) %check_and_step(117) %check_and_step(116) + %check_and_step(115) %check_and_step(114) %check_and_step(113) %check_and_step(112) + %check_and_step(111) %check_and_step(110) %check_and_step(109) %check_and_step(108) + %check_and_step(107) %check_and_step(106) %check_and_step(105) %check_and_step(104) + %check_and_step(103) %check_and_step(102) %check_and_step(101) %check_and_step(100) + %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) + + // check the remaining path + %jump(verify_path_and_write_table) +return: + // stack: proof_prefix_addr, jumpdest, ctx, retdest + %pop3 + JUMP + + +// Chek if the opcode pointed by proof_prefix address is +// less than max and increment proof_prefix_addr +%macro check_and_step(max) + %stack + (proof_prefix_addr, ctx, jumpdest) -> + (ctx, @SEGMENT_CODE, proof_prefix_addr, proof_prefix_addr, ctx, jumpdest) + MLOAD_GENERAL + // stack: opcode, ctx, proof_prefix_addr, jumpdest + DUP1 + %gt_const(127) + %jumpi(%%ok) + %jumpi_lt_const($max, return) + // stack: proof_prefix_addr, ctx, jumpdest + PUSH 0 // We need something to pop +%%ok: + POP + %increment +%endmacro + +%macro write_table_if_jumpdest + %stack (proof, addr, ctx) -> (proof, addr, ctx, %%after) + %jump(write_table_if_jumpdest) +%%after: +%endmacro + +// Write the jumpdest table. This is done by +// non-deterministically guessing the sequence of jumpdest +// addresses used during program execution within the current context. +// For each jumpdest address we also non-deterministically guess +// a proof, which is another address in the code such that +// is_jumpdest don't abort, when the proof is at the top of the stack +// an the jumpdest address below. If that's the case we set the +// corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. +// +// stack: ctx, retdest +// stack: (empty) +global jumpdest_analisys: + // If address > 0 then address is interpreted as address' + 1 + // and the next prover input should contain a proof for address'. + PROVER_INPUT(jumpdest_table::next_address) + DUP1 %jumpi(check_proof) + // If proof == 0 there are no more jump destinations to check + POP +// This is just a hook used for avoiding verification of the jumpdest +// table in another contexts. It is useful during proof generation, +// allowing the avoidance of table verification when simulating user code. +global jumpdest_analisys_end: + POP + JUMP +check_proof: + %decrement + DUP2 SWAP1 + // stack: address, ctx, ctx + // We read the proof + PROVER_INPUT(jumpdest_table::next_proof) + // stack: proof, address, ctx, ctx + %write_table_if_jumpdest + + %jump(jumpdest_analisys) + +%macro jumpdest_analisys + %stack (ctx) -> (ctx, %%after) + %jump(jumpdest_analisys) +%%after: +%endmacro diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index fc2472b3..d62dc27e 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -8,6 +8,19 @@ jumpi %endmacro +%macro jump_eq_const(c, jumpdest) + PUSH $c + SUB + %jumpi($jumpdest) +%endmacro + +%macro jumpi_lt_const(c, jumpdest) + // %assert_zero is cheaper than %assert_nonzero, so we will leverage the + // fact that (x < c) == !(x >= c). + %ge_const($c) + %jumpi($jumpdest) +%endmacro + %macro pop2 %rep 2 POP diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index c4376721..30f862cd 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -1,7 +1,7 @@ //! An EVM interpreter for testing and debugging purposes. use core::cmp::Ordering; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::ops::Range; use anyhow::bail; @@ -10,6 +10,7 @@ use keccak_hash::keccak; use plonky2::field::goldilocks_field::GoldilocksField; use super::assembler::BYTES_PER_OFFSET; +use super::utils::u256_from_bool; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; @@ -413,7 +414,23 @@ impl<'a> Interpreter<'a> { .collect() } - fn incr(&mut self, n: usize) { + pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec) { + self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] + .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); + self.generation_state + .set_proofs_and_jumpdests(HashMap::from([( + context, + BTreeSet::from_iter( + jumpdest_bits + .into_iter() + .enumerate() + .filter(|&(_, x)| x) + .map(|(i, _)| i), + ), + )])); + } + + pub(crate) fn incr(&mut self, n: usize) { self.generation_state.registers.program_counter += n; } diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs index 022a18d7..3d97251c 100644 --- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -5,8 +5,8 @@ use crate::cpu::kernel::interpreter::Interpreter; use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; #[test] -fn test_jumpdest_analysis() -> Result<()> { - let jumpdest_analysis = KERNEL.global_labels["jumpdest_analysis"]; +fn test_jumpdest_analisys() -> Result<()> { + let jumpdest_analisys = KERNEL.global_labels["jumpdest_analisys"]; const CONTEXT: usize = 3; // arbitrary let add = get_opcode("ADD"); @@ -25,18 +25,16 @@ fn test_jumpdest_analysis() -> Result<()> { jumpdest, ]; - let expected_jumpdest_bits = vec![false, true, false, false, false, true, false, true]; + let jumpdest_bits = vec![false, true, false, false, false, true, false, true]; // Contract creation transaction. - let initial_stack = vec![0xDEADBEEFu32.into(), code.len().into(), CONTEXT.into()]; - let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack); + let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()]; + let mut interpreter = Interpreter::new_with_kernel(jumpdest_analisys, initial_stack); interpreter.set_code(CONTEXT, code); + interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits); + interpreter.run()?; assert_eq!(interpreter.stack(), vec![]); - assert_eq!( - interpreter.get_jumpdest_bits(CONTEXT), - expected_jumpdest_bits - ); Ok(()) } diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index d691d34e..4aa35afb 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::sync::atomic::AtomicBool; use std::sync::Arc; @@ -8,6 +8,7 @@ use ethereum_types::{Address, BigEndianHash, H256, U256}; use itertools::enumerate; use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; @@ -21,13 +22,16 @@ use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::assembler::Kernel; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::opcodes::get_opcode; use crate::generation::state::GenerationState; use crate::generation::trie_extractor::{get_receipt_trie, get_state_trie, get_txn_trie}; use crate::memory::segments::Segment; use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use crate::prover::check_abort_signal; -use crate::util::{h2u, u256_to_usize}; +use crate::util::{h2u, u256_to_u8, u256_to_usize}; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::{MemoryAddress, MemoryChannel}; use crate::witness::transition::transition; @@ -38,7 +42,7 @@ pub(crate) mod state; mod trie_extractor; use self::mpt::{load_all_mpts, TrieRootPtrs}; -use crate::witness::util::mem_write_log; +use crate::witness::util::{mem_write_log, stack_peek}; /// Inputs needed for trace generation. #[derive(Clone, Debug, Deserialize, Serialize, Default)] @@ -296,9 +300,7 @@ pub fn generate_traces, const D: usize>( Ok((tables, public_values)) } -fn simulate_cpu, const D: usize>( - state: &mut GenerationState, -) -> anyhow::Result<()> { +fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> { let halt_pc = KERNEL.global_labels["halt"]; loop { @@ -333,3 +335,81 @@ fn simulate_cpu, const D: usize>( transition(state)?; } } + +fn simulate_cpu_between_labels_and_get_user_jumps( + initial_label: &str, + final_label: &str, + state: &mut GenerationState, +) -> Result>>, ProgramError> { + if state.jumpdest_proofs.is_some() { + Ok(None) + } else { + const JUMP_OPCODE: u8 = 0x56; + const JUMPI_OPCODE: u8 = 0x57; + + let halt_pc = KERNEL.global_labels[final_label]; + let mut jumpdest_addresses: HashMap<_, BTreeSet> = HashMap::new(); + + state.registers.program_counter = KERNEL.global_labels[initial_label]; + let initial_clock = state.traces.clock(); + let initial_context = state.registers.context; + + log::debug!("Simulating CPU for jumpdest analysis."); + + loop { + // skip jumdest table validations in simulations + if state.registers.program_counter == KERNEL.global_labels["jumpdest_analisys"] { + state.registers.program_counter = KERNEL.global_labels["jumpdest_analisys_end"] + } + let pc = state.registers.program_counter; + let context = state.registers.context; + let halt = state.registers.is_kernel + && pc == halt_pc + && state.registers.context == initial_context; + let opcode = u256_to_u8(state.memory.get(MemoryAddress { + context, + segment: Segment::Code as usize, + virt: state.registers.program_counter, + }))?; + let cond = if let Ok(cond) = stack_peek(state, 1) { + cond != U256::zero() + } else { + false + }; + if !state.registers.is_kernel + && (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond)) + { + // Avoid deeper calls to abort + let jumpdest = u256_to_usize(state.registers.stack_top)?; + state.memory.set( + MemoryAddress { + context, + segment: Segment::JumpdestBits as usize, + virt: jumpdest, + }, + U256::one(), + ); + let jumpdest_opcode = state.memory.get(MemoryAddress { + context, + segment: Segment::Code as usize, + virt: jumpdest, + }); + if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) { + ctx_addresses.insert(jumpdest); + } else { + jumpdest_addresses.insert(context, BTreeSet::from([jumpdest])); + } + } + if halt { + log::debug!( + "Simulated CPU halted after {} cycles", + state.traces.clock() - initial_clock + ); + return Ok(Some(jumpdest_addresses)); + } + transition(state).map_err(|_| { + ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) + })?; + } + } +} diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index b2a8f0ce..ca9208ea 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,3 +1,5 @@ +use std::cmp::min; +use std::collections::HashMap; use std::mem::transmute; use std::str::FromStr; @@ -5,20 +7,26 @@ use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256, U512}; use itertools::{enumerate, Itertools}; use num_bigint::BigUint; +use plonky2::field::extension::Extendable; use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; use serde::{Deserialize, Serialize}; +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::context_metadata::ContextMetadata; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254}; use crate::generation::prover_input::EvmField::{ Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, }; use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; +use crate::generation::simulate_cpu_between_labels_and_get_user_jumps; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::segments::Segment::BnPairing; -use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_usize}; -use crate::witness::errors::ProgramError; +use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_u8, u256_to_usize}; use crate::witness::errors::ProverInputError::*; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::MemoryAddress; use crate::witness::util::{current_context_peek, stack_peek}; @@ -47,6 +55,7 @@ impl GenerationState { "bignum_modmul" => self.run_bignum_modmul(), "withdrawal" => self.run_withdrawal(), "num_bits" => self.run_num_bits(), + "jumpdest_table" => self.run_jumpdest_table(input_fn), _ => Err(ProgramError::ProverInputError(InvalidFunction)), } } @@ -229,6 +238,237 @@ impl GenerationState { Ok(num_bits.into()) } } + + fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result { + match input_fn.0[1].as_str() { + "next_address" => self.run_next_jumpdest_table_address(), + "next_proof" => self.run_next_jumpdest_table_proof(), + _ => Err(ProgramError::ProverInputError(InvalidInput)), + } + } + /// Return the next used jump addres + fn run_next_jumpdest_table_address(&mut self) -> Result { + let context = self.registers.context; + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + + if self.jumpdest_proofs.is_none() { + self.generate_jumpdest_proofs()?; + } + + let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )); + }; + + if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) + && let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop() + { + Ok((next_jumpdest_address + 1).into()) + } else { + self.jumpdest_proofs = None; + Ok(U256::zero()) + } + } + + /// Returns the proof for the last jump address. + fn run_next_jumpdest_table_proof(&mut self) -> Result { + let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )); + }; + if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context) + && let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop() + { + Ok(next_jumpdest_proof.into()) + } else { + Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpdestSimulation, + )) + } + } +} + +impl GenerationState { + fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> { + let checkpoint = self.checkpoint(); + let memory = self.memory.clone(); + + let code = self.get_current_code()?; + // We need to set the simulated jumpdest bits to one as otherwise + // the simulation will fail. + self.set_jumpdest_bits(&code); + + // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call + let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( + "jumpdest_analisys_end", + "terminate_common", + self, + )? + else { + return Ok(()); + }; + + // Return to the state before starting the simulation + self.rollback(checkpoint); + self.memory = memory; + + // Find proofs for all context + self.set_proofs_and_jumpdests(jumpdest_table); + + Ok(()) + } + + pub(crate) fn set_proofs_and_jumpdests( + &mut self, + jumpdest_table: HashMap>, + ) { + self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map( + |(ctx, jumpdest_table)| { + let code = self.get_code(ctx).unwrap(); + if let Some(&largest_address) = jumpdest_table.last() { + let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table); + (ctx, proofs) + } else { + (ctx, vec![]) + } + }, + ))); + } + + fn get_current_code(&self) -> Result, ProgramError> { + self.get_code(self.registers.context) + } + + fn get_code(&self, context: usize) -> Result, ProgramError> { + let code_len = self.get_code_len()?; + let code = (0..code_len) + .map(|i| { + u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i, + })) + }) + .collect::, _>>()?; + Ok(code) + } + + fn get_code_len(&self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + Ok(code_len) + } + + fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec) { + const JUMPDEST_OPCODE: u8 = 0x5b; + for (pos, opcode) in CodeIterator::new(&code) { + if opcode == JUMPDEST_OPCODE { + self.memory.set( + MemoryAddress { + context: self.registers.context, + segment: Segment::JumpdestBits as usize, + virt: pos, + }, + U256::one(), + ); + } + } + } +} + +/// For each address in `jumpdest_table`, each bounded by larges_address, +/// this function searches for a proof. A proof is the closest address +/// for which none of the previous 32 bytes in the code (including opcodes +/// and pushed bytes are PUSHXX and the address is in its range. It returns +/// a vector of even size containing proofs followed by their addresses +fn get_proofs_and_jumpdests<'a>( + code: &'a Vec, + largest_address: usize, + jumpdest_table: std::collections::BTreeSet, +) -> Vec { + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x7f; + let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold( + (vec![], 0), + |(mut proofs, acc), (pos, opcode)| { + let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { + code[prefix_start..pos] + .iter() + .enumerate() + .fold(true, |acc, (prefix_pos, &byte)| { + acc && (byte > PUSH32_OPCODE + || (prefix_start + prefix_pos) as i32 + + (byte as i32 - PUSH1_OPCODE as i32) + + 1 + < pos as i32) + }) + } else { + false + }; + let acc = if has_prefix { pos - 32 } else { acc }; + if jumpdest_table.contains(&pos) { + // Push the proof + proofs.push(acc); + // Push the address + proofs.push(pos); + } + (proofs, acc) + }, + ); + proofs +} + +struct CodeIterator<'a> { + code: &'a [u8], + pos: usize, + end: usize, +} + +impl<'a> CodeIterator<'a> { + fn new(code: &'a [u8]) -> Self { + CodeIterator { + end: code.len(), + code, + pos: 0, + } + } + fn until(code: &'a [u8], end: usize) -> Self { + CodeIterator { + end: std::cmp::min(code.len(), end), + code, + pos: 0, + } + } +} + +impl<'a> Iterator for CodeIterator<'a> { + type Item = (usize, u8); + + fn next(&mut self) -> Option { + const PUSH1_OPCODE: u8 = 0x60; + const PUSH32_OPCODE: u8 = 0x70; + let CodeIterator { code, pos, end } = self; + if *pos >= *end { + return None; + } + let opcode = code[*pos]; + let old_pos = *pos; + *pos += if (PUSH1_OPCODE..=PUSH32_OPCODE).contains(&opcode) { + (opcode - PUSH1_OPCODE + 2).into() + } else { + 1 + }; + Some((old_pos, opcode)) + } } enum EvmField { diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 89ff0c5a..cc1df091 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use ethereum_types::{Address, BigEndianHash, H160, H256, U256}; use keccak_hash::keccak; @@ -50,6 +50,8 @@ pub(crate) struct GenerationState { /// Pointers, within the `TrieData` segment, of the three MPTs. pub(crate) trie_root_ptrs: TrieRootPtrs, + + pub(crate) jumpdest_proofs: Option>>, } impl GenerationState { @@ -91,6 +93,7 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, + jumpdest_proofs: None, }; let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); @@ -167,6 +170,26 @@ impl GenerationState { .map(|i| stack_peek(self, i).unwrap()) .collect() } + + /// Clone everything but the traces + pub(crate) fn soft_clone(&self) -> GenerationState { + Self { + inputs: self.inputs.clone(), + registers: self.registers, + memory: self.memory.clone(), + traces: Traces::default(), + rlp_prover_inputs: self.rlp_prover_inputs.clone(), + state_key_to_address: self.state_key_to_address.clone(), + bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(), + withdrawal_prover_inputs: self.withdrawal_prover_inputs.clone(), + trie_root_ptrs: TrieRootPtrs { + state_root_ptr: 0, + txn_root_ptr: 0, + receipt_root_ptr: 0, + }, + jumpdest_proofs: None, + } + } } /// Withdrawals prover input array is of the form `[addr0, amount0, ..., addrN, amountN, U256::MAX, U256::MAX]`. diff --git a/evm/src/util.rs b/evm/src/util.rs index 3d9564b5..fa19ef09 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,6 +70,11 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } +/// Safe conversion from U256 to u8, which errors in case of overflow instead of panicking. +pub(crate) fn u256_to_u8(u256: U256) -> Result { + u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) +} + /// Safe alternative to `U256::as_usize()`, which errors in case of overflow instead of panicking. pub(crate) fn u256_to_usize(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index 5a0fcbfb..1b266aef 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -36,4 +36,6 @@ pub enum ProverInputError { InvalidInput, InvalidFunction, NumBitsError, + InvalidJumpDestination, + InvalidJumpdestSimulation, } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index cf2e3bbe..b8f962e7 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,7 +395,12 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!("User instruction: {:?}", op); + log::debug!( + "User instruction: {:?}, ctx = {:?}, stack = {:?}", + op, + state.registers.context, + state.stack() + ); } fill_op_flag(op, &mut row); From 24ae0d9de09a380fc2b981090f309acda0f27eed Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 20 Dec 2023 15:27:27 +0100 Subject: [PATCH 22/37] Clippy --- evm/src/generation/prover_input.rs | 10 +++++----- evm/src/util.rs | 7 +------ 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index ca9208ea..21b7c572 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -368,9 +368,9 @@ impl GenerationState { Ok(code_len) } - fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec) { + fn set_jumpdest_bits(&mut self, code: &[u8]) { const JUMPDEST_OPCODE: u8 = 0x5b; - for (pos, opcode) in CodeIterator::new(&code) { + for (pos, opcode) in CodeIterator::new(code) { if opcode == JUMPDEST_OPCODE { self.memory.set( MemoryAddress { @@ -390,14 +390,14 @@ impl GenerationState { /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes are PUSHXX and the address is in its range. It returns /// a vector of even size containing proofs followed by their addresses -fn get_proofs_and_jumpdests<'a>( - code: &'a Vec, +fn get_proofs_and_jumpdests( + code: &[u8], largest_address: usize, jumpdest_table: std::collections::BTreeSet, ) -> Vec { const PUSH1_OPCODE: u8 = 0x60; const PUSH32_OPCODE: u8 = 0x7f; - let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold( + let (proofs, _) = CodeIterator::until(code, largest_address + 1).fold( (vec![], 0), |(mut proofs, acc), (pos, opcode)| { let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { diff --git a/evm/src/util.rs b/evm/src/util.rs index fa19ef09..29be65dd 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,7 +70,7 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } -/// Safe conversion from U256 to u8, which errors in case of overflow instead of panicking. +/// Safe conversion from U256 to u8, which errors in case of overflow. pub(crate) fn u256_to_u8(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) } @@ -80,11 +80,6 @@ pub(crate) fn u256_to_usize(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) } -/// Converts a `U256` to a `u8`, erroring in case of overlow instead of panicking. -pub(crate) fn u256_to_u8(u256: U256) -> Result { - u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) -} - /// Converts a `U256` to a `bool`, erroring in case of overlow instead of panicking. pub(crate) fn u256_to_bool(u256: U256) -> Result { if u256 == U256::zero() { From 9c573a07d4da1aa35ebb2039a443aab31fff9018 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Thu, 28 Dec 2023 14:10:36 +0100 Subject: [PATCH 23/37] Restore simple_transfer and Clippy --- evm/tests/simple_transfer.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/evm/tests/simple_transfer.rs b/evm/tests/simple_transfer.rs index b5c43ff8..5fd252df 100644 --- a/evm/tests/simple_transfer.rs +++ b/evm/tests/simple_transfer.rs @@ -154,9 +154,6 @@ fn test_simple_transfer() -> anyhow::Result<()> { }, }; - let bytes = std::fs::read("jumpi_d18g0v0_Shanghai.json").unwrap(); - let inputs = serde_json::from_slice(&bytes).unwrap(); - let mut timing = TimingTree::new("prove", log::Level::Debug); let proof = prove::(&all_stark, &config, inputs, &mut timing, None)?; timing.filter(Duration::from_millis(100)).print(); From 1a95f7aa7222d30012a82c29eb241c5d2640e33d Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Thu, 28 Dec 2023 16:39:12 +0100 Subject: [PATCH 24/37] Clippy --- evm/src/generation/prover_input.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 60192c74..d7d3082c 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -397,7 +397,7 @@ fn get_proofs_and_jumpdests( ) -> Vec { const PUSH1_OPCODE: u8 = 0x60; const PUSH32_OPCODE: u8 = 0x7f; - let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold( + let (proofs, _) = CodeIterator::until(code, largest_address + 1).fold( (vec![], 0), |(mut proofs, acc), (pos, opcode)| { let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) { From ab4508fc8b2cef8d93458b208269586a1e900037 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 5 Jan 2024 16:07:33 +0100 Subject: [PATCH 25/37] Add packed verification --- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 148 ++++++++++++++++-- evm/src/cpu/kernel/asm/util/basic_macros.asm | 2 +- 2 files changed, 139 insertions(+), 11 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 8debead3..8625683a 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -111,25 +111,151 @@ global write_table_if_jumpdest: MLOAD_GENERAL // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest - %jump_eq_const(0x5b, return) + %jump_neq_const(0x5b, return) //stack: jumpdest, ctx, proof_prefix_addr, retdest SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest ISZERO %jumpi(verify_path_and_write_table) + + // stack: proof_prefix_addr, ctx, jumpdest, retdest // If we are here we need to check that the next 32 bytes are less // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, // or larger than 127 - %check_and_step(127) %check_and_step(126) %check_and_step(125) %check_and_step(124) - %check_and_step(123) %check_and_step(122) %check_and_step(121) %check_and_step(120) - %check_and_step(119) %check_and_step(118) %check_and_step(117) %check_and_step(116) - %check_and_step(115) %check_and_step(114) %check_and_step(113) %check_and_step(112) - %check_and_step(111) %check_and_step(110) %check_and_step(109) %check_and_step(108) - %check_and_step(107) %check_and_step(106) %check_and_step(105) %check_and_step(104) - %check_and_step(103) %check_and_step(102) %check_and_step(101) %check_and_step(100) - %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) + + %stack + (proof_prefix_addr, ctx) -> + (ctx, @SEGMENT_CODE, proof_prefix_addr, 32, proof_prefix_addr, ctx) + %mload_packing + // packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP1 %shl_const(1) + DUP2 %shl_const(2) + AND + // stack: (is_1_at_pos_2_and_3|(X)⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + // X denotes any value in {0,1} and Z^i is Z repeated i times + NOT + // stack: (is_0_at_2_or_3|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP2 + OR + // stack: (is_1_at_1 or is_0_at_2_or_3|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + // stack: (~has_prefix|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + + // Compute in_range = + // - (0xFF|X⁷)³² for the first 15 bytes + // - (has_prefix => is_0_at_4 |X⁷)³² for the next 15 bytes + // - (~has_prefix|X⁷)³² for the last byte + // Compute also that ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. We don't need to update ~hash_prefix + // for the second half but it takes less cycles if we do it. + DUP2 %shl_const(3) + NOT + // stack: (is_0_at_4|X⁷)³², (~has_prefix|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00 + AND + // stack: (is_0_at_4|X⁷)³¹|0⁸, (~has_prefix|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP2 + DUP2 + OR + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0000000000000000000000000000000000 + OR + // stack: (in_range|X⁷)³², (is_0_at_4|X⁷)³², (~has_prefix|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + SWAP2 + OR + // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + + // Compute in_range' = in_range AND + // - (0xFF|X⁷)³² for bytes in positions 1-7 and 16-23 + // - (has_prefix => is_0_at_5 |X⁷)³² on the rest + // Compute also ~has_prefix = ~has_prefix OR is_0_at_5 for all bytes. + + DUP3 %shl_const(4) + NOT + // stack: (is_0_at_5|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP2 + DUP2 + OR + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0xFFFFFFFFFFFFFF0000000000000000FFFFFFFFFFFFFFFF000000000000000000 + OR + // stack: (in_range'|X⁷)³², (is_0_at_5|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + SWAP2 + OR + // stack: (~has_prefix|X⁷)³², (in_range'|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + SWAP2 + AND + SWAP1 + + // Compute in_range' = in_range AND + // - (0xFF|X⁷)³² for bytes in positions 1-2, 8-11, 16-19, and 24-27 + // - (has_prefix => is_0_at_6 |X⁷)³² on the rest + // Compute also that ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. + + // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP3 %shl_const(5) + NOT + // stack: (is_0_at_6|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP2 + DUP2 + OR + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0xFFFFFF00000000FFFFFFFF00000000FFFFFFFF00000000FFFFFFFF0000000000 + OR + // stack: (in_range'|X⁷)³², (is_0_at_6|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + SWAP2 + OR + // stack: (~has_prefix|X⁷)³², (in_range'|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + SWAP2 + AND + SWAP1 + + // Compute in_range' = in_range AND + // - (0xFF|X⁷)³² for bytes in positions 1-7 and 16-23 + // - (has_prefix => is_0_at_7 |X⁷)³² on the rest + // Compute also that ~has_prefix = ~has_prefix OR is_0_at_7 for all bytes. + + // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP3 %shl_const(6) + NOT + // stack: (is_0_at_7|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP2 + DUP2 + OR + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0xFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF000000 + OR + // stack: (in_range'|X⁷)³², (is_0_at_7|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + SWAP2 + OR + // stack: (~has_prefix|X⁷)³², (in_range'|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + SWAP2 + AND + SWAP1 + + // Compute in_range' = in_range AND + // - (0xFF|X⁷)³² for bytes in odd positions + // - (has_prefix => is_0_at_8 |X⁷)³² on the rest + // Compute also that ~has_prefix = ~has_prefix OR is_0_at_7 for all bytes. + + // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP3 %shl_const(7) + NOT + // stack: (is_0_at_8|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + OR + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0x00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF + OR + AND + // stack: (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + + // Get rid of the irrelevant bits + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0x8080808080808080808080808080808080808080808080808080808080808080 + AND + %assert_eq_const(0x8080808080808080808080808080808080808080808080808080808080808080) + POP // check the remaining path %jump(verify_path_and_write_table) @@ -139,7 +265,9 @@ return: JUMP -// Chek if the opcode pointed by proof_prefix address is + + +// Check if the opcode pointed by proof_prefix address is // less than max and increment proof_prefix_addr %macro check_and_step(max) %stack diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index d62dc27e..55debe12 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -8,7 +8,7 @@ jumpi %endmacro -%macro jump_eq_const(c, jumpdest) +%macro jump_neq_const(c, jumpdest) PUSH $c SUB %jumpi($jumpdest) From 8f1efa155436c2cb2a45faf1dc53aa685a11a1dc Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 5 Jan 2024 16:15:16 +0100 Subject: [PATCH 26/37] Fix minor error --- evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 1 + 1 file changed, 1 insertion(+) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 8625683a..0e92fcac 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -256,6 +256,7 @@ global write_table_if_jumpdest: AND %assert_eq_const(0x8080808080808080808080808080808080808080808080808080808080808080) POP + %add_const(32) // check the remaining path %jump(verify_path_and_write_table) From 247d655b3903c94167d3b5960180798239c846dd Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 5 Jan 2024 17:00:32 +0100 Subject: [PATCH 27/37] Minor --- evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 0e92fcac..59fbaa03 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -146,7 +146,7 @@ global write_table_if_jumpdest: // - (0xFF|X⁷)³² for the first 15 bytes // - (has_prefix => is_0_at_4 |X⁷)³² for the next 15 bytes // - (~has_prefix|X⁷)³² for the last byte - // Compute also that ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. We don't need to update ~hash_prefix + // Compute also ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. We don't need to update ~hash_prefix // for the second half but it takes less cycles if we do it. DUP2 %shl_const(3) NOT @@ -237,7 +237,6 @@ global write_table_if_jumpdest: // Compute in_range' = in_range AND // - (0xFF|X⁷)³² for bytes in odd positions // - (has_prefix => is_0_at_8 |X⁷)³² on the rest - // Compute also that ~has_prefix = ~has_prefix OR is_0_at_7 for all bytes. // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest DUP3 %shl_const(7) From 897ba5856aecf3948a900877b67fb8644c8f3b85 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 5 Jan 2024 17:57:46 +0100 Subject: [PATCH 28/37] Remove assertion in packed verif --- evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 1 + 1 file changed, 1 insertion(+) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 59fbaa03..bce399bb 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -254,6 +254,7 @@ global write_table_if_jumpdest: PUSH 0x8080808080808080808080808080808080808080808080808080808080808080 AND %assert_eq_const(0x8080808080808080808080808080808080808080808080808080808080808080) + %jumpi(return) POP %add_const(32) From 18a14bf2f293e8a1bd9704a8238aa0a1ee929ce1 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Fri, 5 Jan 2024 18:15:57 +0100 Subject: [PATCH 29/37] Remove assertion --- evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index bce399bb..8f66b369 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -253,8 +253,7 @@ global write_table_if_jumpdest: // pos 0102030405060708091011121314151617181920212223242526272829303132 PUSH 0x8080808080808080808080808080808080808080808080808080808080808080 AND - %assert_eq_const(0x8080808080808080808080808080808080808080808080808080808080808080) - %jumpi(return) + %jump_neq_const(0x8080808080808080808080808080808080808080808080808080808080808080, return) POP %add_const(32) From 47b428569d303eae8d51ea736ec60df3f033d5a3 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Mon, 8 Jan 2024 13:53:52 +0100 Subject: [PATCH 30/37] Remove unused macro --- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 8f66b369..aa37d23d 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -264,34 +264,6 @@ return: %pop3 JUMP - - - -// Check if the opcode pointed by proof_prefix address is -// less than max and increment proof_prefix_addr -%macro check_and_step(max) - %stack - (proof_prefix_addr, ctx, jumpdest) -> - (ctx, @SEGMENT_CODE, proof_prefix_addr, proof_prefix_addr, ctx, jumpdest) - MLOAD_GENERAL - // stack: opcode, ctx, proof_prefix_addr, jumpdest - DUP1 - %gt_const(127) - %jumpi(%%ok) - %jumpi_lt_const($max, return) - // stack: proof_prefix_addr, ctx, jumpdest - PUSH 0 // We need something to pop -%%ok: - POP - %increment -%endmacro - -%macro write_table_if_jumpdest - %stack (proof, addr, ctx) -> (proof, addr, ctx, %%after) - %jump(write_table_if_jumpdest) -%%after: -%endmacro - // Write the jumpdest table. This is done by // non-deterministically guessing the sequence of jumpdest // addresses used during program execution within the current context. From 92aaa404da05bd51434cbea3f15415007ed8eb82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alonso=20Gonz=C3=A1lez?= Date: Wed, 10 Jan 2024 13:55:09 +0100 Subject: [PATCH 31/37] Apply suggestions from code review Co-authored-by: Robin Salen <30937548+Nashtare@users.noreply.github.com> --- evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 4 ++-- evm/src/generation/mod.rs | 2 +- evm/src/generation/prover_input.rs | 6 +++--- evm/src/generation/state.rs | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index ed9bb06f..aed6c186 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -41,7 +41,7 @@ continue: proof_ok: // stack: i, ctx, final_pos, retdest - // We already know final pos is a jumpdest + // We already know final_pos is a jumpdest %stack (i, ctx, final_pos) -> (ctx, @SEGMENT_JUMPDEST_BITS, i) %build_address PUSH 1 @@ -99,7 +99,7 @@ code_bytes_to_skip: %endrep -// A proof attesting that jumpdest is a valid jump destinations is +// A proof attesting that jumpdest is a valid jump destination is // either 0 or an index 0 < i <= jumpdest - 32. // A proof is valid if: // - i == 0 and we can go from the first opcode to jumpdest and code[jumpdest] = 0x5b diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 318c7d4f..238011d8 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -364,7 +364,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( log::debug!("Simulating CPU for jumpdest analysis."); loop { - // skip jumdest table validations in simulations + // skip jumpdest table validations in simulations if state.registers.program_counter == KERNEL.global_labels["jumpdest_analysis"] { state.registers.program_counter = KERNEL.global_labels["jumpdest_analysis_end"] } diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index b5541f1b..97177a8c 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -247,7 +247,8 @@ impl GenerationState { _ => Err(ProgramError::ProverInputError(InvalidInput)), } } - /// Return the next used jump addres + + /// Returns the next used jump address. fn run_next_jumpdest_table_address(&mut self) -> Result { let context = self.registers.context; let code_len = u256_to_usize(self.get_code_len()?.into()); @@ -299,7 +300,6 @@ impl GenerationState { let code = self.get_current_code()?; // We need to set the simulated jumpdest bits to one as otherwise // the simulation will fail. - // self.set_jumpdest_bits(&code); // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps( @@ -383,7 +383,7 @@ impl GenerationState { /// this function searches for a proof. A proof is the closest address /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes are PUSHXX and the address is in its range. It returns -/// a vector of even size containing proofs followed by their addresses +/// a vector of even size containing proofs followed by their addresses. fn get_proofs_and_jumpdests( code: &[u8], largest_address: usize, diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index f844c5ce..f15ab317 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -169,7 +169,7 @@ impl GenerationState { .collect() } - /// Clone everything but the traces + /// Clones everything but the traces. pub(crate) fn soft_clone(&self) -> GenerationState { Self { inputs: self.inputs.clone(), From ae4a720a746277c20136f20e8abe306fdf0d759a Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 10 Jan 2024 17:26:34 +0100 Subject: [PATCH 32/37] Address comments --- evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 8 ++++---- evm/src/generation/mod.rs | 6 +++--- evm/src/generation/prover_input.rs | 4 ++++ evm/src/util.rs | 12 ++++++------ 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index aed6c186..c197fb6a 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -2,7 +2,7 @@ // for the given context's code. // Pre stack: init_pos, ctx, final_pos, retdest // Post stack: (empty) -global verify_path_and_write_table: +global verify_path_and_write_jumpdest_table: loop: // stack: i, ctx, final_pos, retdest DUP3 DUP2 EQ // i == final_pos @@ -10,7 +10,7 @@ loop: DUP3 DUP2 GT // i > final_pos %jumpi(proof_not_ok) - // stack: i, ctx, code_len, retdest + // stack: i, ctx, final_pos, retdest %stack (i, ctx) -> (ctx, i, i, ctx) ADD // combine context and offset to make an address (SEGMENT_CODE == 0) MLOAD_GENERAL @@ -124,7 +124,7 @@ global write_table_if_jumpdest: SWAP2 DUP1 // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest ISZERO - %jumpi(verify_path_and_write_table) + %jumpi(verify_path_and_write_jumpdest_table) // stack: proof_prefix_addr, ctx, jumpdest, retdest @@ -266,7 +266,7 @@ global write_table_if_jumpdest: %add_const(32) // check the remaining path - %jump(verify_path_and_write_table) + %jump(verify_path_and_write_jumpdest_table) return: // stack: proof_prefix_addr, jumpdest, ctx, retdest %pop3 diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index 238011d8..e14de9b9 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -379,7 +379,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( state.registers.program_counter, ))) else { log::debug!( - "Simulated CPU halted after {} cycles", + "Simulated CPU for jumpdest analysis halted after {} cycles", state.traces.clock() - initial_clock ); return Some(jumpdest_addresses); @@ -395,7 +395,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( // Avoid deeper calls to abort let Ok(jumpdest) = u256_to_usize(state.registers.stack_top) else { log::debug!( - "Simulated CPU halted after {} cycles", + "Simulated CPU for jumpdest analysis halted after {} cycles", state.traces.clock() - initial_clock ); return Some(jumpdest_addresses); @@ -416,7 +416,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps( } if halt || transition(state).is_err() { log::debug!( - "Simulated CPU halted after {} cycles", + "Simulated CPU for jumpdest analysis halted after {} cycles", state.traces.clock() - initial_clock ); return Some(jumpdest_addresses); diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index 97177a8c..ccec9c45 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -240,6 +240,7 @@ impl GenerationState { } } + /// Generate the either the next used jump address or the the proof for the last jump address. fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[1].as_str() { "next_address" => self.run_next_jumpdest_table_address(), @@ -293,6 +294,7 @@ impl GenerationState { } impl GenerationState { + /// Simulate the user's code and store all the jump addresses with their respective contexts. fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> { let checkpoint = self.checkpoint(); let memory = self.memory.clone(); @@ -322,6 +324,8 @@ impl GenerationState { Ok(()) } + /// Given a HashMap containing the contexts and the jumpdest addresses, compute their respective proofs, + /// by calling `get_proofs_and_jumpdests` pub(crate) fn set_proofs_and_jumpdests( &mut self, jumpdest_table: HashMap>, diff --git a/evm/src/util.rs b/evm/src/util.rs index ee2d9607..f4d80859 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,17 +70,17 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } -/// Safe conversion from U256 to u8, which errors in case of overflow instead of panicking. -pub(crate) fn u256_to_u8(u256: U256) -> Result { - u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) -} - /// Safe alternative to `U256::as_usize()`, which errors in case of overflow instead of panicking. pub(crate) fn u256_to_usize(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) } -/// Converts a `U256` to a `bool`, erroring in case of overlow instead of panicking. +/// Converts a `U256` to a `u8`, erroring in case of overflow instead of panicking. +pub(crate) fn u256_to_u8(u256: U256) -> Result { + u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) +} + +/// Converts a `U256` to a `bool`, erroring in case of overflow instead of panicking. pub(crate) fn u256_to_bool(u256: U256) -> Result { if u256 == U256::zero() { Ok(false) From 99a1eb5c85cb8d6d094b916ae9db9c0df4c4ce72 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 10 Jan 2024 17:58:41 +0100 Subject: [PATCH 33/37] Missing review comments --- evm/src/generation/prover_input.rs | 16 +++++++++------- evm/src/generation/state.rs | 4 ++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index ccec9c45..c8b78e62 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -312,7 +312,6 @@ impl GenerationState { self.jumpdest_proofs = Some(HashMap::new()); return Ok(()); }; - log::debug!("jumpdest_table = {:?}", jumpdest_table); // Return to the state before starting the simulation self.rollback(checkpoint); @@ -383,7 +382,7 @@ impl GenerationState { } } -/// For all address in `jumpdest_table`, each bounded by larges_address, +/// For all address in `jumpdest_table`, each bounded by `largest_address`, /// this function searches for a proof. A proof is the closest address /// for which none of the previous 32 bytes in the code (including opcodes /// and pushed bytes are PUSHXX and the address is in its range. It returns @@ -403,11 +402,12 @@ fn get_proofs_and_jumpdests( .iter() .enumerate() .fold(true, |acc, (prefix_pos, &byte)| { - acc && (byte > PUSH32_OPCODE - || (prefix_start + prefix_pos) as i32 - + (byte as i32 - PUSH1_OPCODE as i32) - + 1 - < pos as i32) + let cond1 = byte > PUSH32_OPCODE; + let cond2 = (prefix_start + prefix_pos) as i32 + + (byte as i32 - PUSH1_OPCODE as i32) + + 1 + < pos as i32; + acc && (cond1 || cond2) }) } else { false @@ -425,6 +425,8 @@ fn get_proofs_and_jumpdests( proofs } +/// An iterator over the EVM code contained in `code`, which skips the bytes +/// that are the arguments of a PUSHXX opcode. struct CodeIterator<'a> { code: &'a [u8], pos: usize, diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index f15ab317..ddc2f359 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -51,6 +51,10 @@ pub(crate) struct GenerationState { /// Pointers, within the `TrieData` segment, of the three MPTs. pub(crate) trie_root_ptrs: TrieRootPtrs, + /// A hash map where the key is a context in the user's code and the value is the set of + /// jump destinations with its corresponding "proof". A "proof" for a jump destination is + /// either 0 or an address i > 32 in the code (not necessarily pointing to an opcode) such that + /// for every j in [i, i+32] it holds that code[j] < 0x7f - j + i. pub(crate) jumpdest_proofs: Option>>, } From 6ef0a3c738d48a6ebda9dba267b203e31742a93a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alonso=20Gonz=C3=A1lez?= Date: Thu, 11 Jan 2024 10:55:04 +0100 Subject: [PATCH 34/37] Apply suggestions from code review Co-authored-by: Linda Guiga <101227802+LindaGuiga@users.noreply.github.com> --- evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 6 +++--- evm/src/generation/prover_input.rs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index c197fb6a..267bf528 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -154,7 +154,7 @@ global write_table_if_jumpdest: // - (0xFF|X⁷)³² for the first 15 bytes // - (has_prefix => is_0_at_4 |X⁷)³² for the next 15 bytes // - (~has_prefix|X⁷)³² for the last byte - // Compute also ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. We don't need to update ~hash_prefix + // Compute also ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. We don't need to update ~has_prefix // for the second half but it takes less cycles if we do it. DUP2 %shl_const(3) NOT @@ -283,7 +283,7 @@ return: // addresses used during program execution within the current context. // For each jumpdest address we also non-deterministically guess // a proof, which is another address in the code such that -// is_jumpdest don't abort, when the proof is at the top of the stack +// is_jumpdest doesn't abort, when the proof is at the top of the stack // an the jumpdest address below. If that's the case we set the // corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. // @@ -297,7 +297,7 @@ global jumpdest_analysis: // If proof == 0 there are no more jump destinations to check POP // This is just a hook used for avoiding verification of the jumpdest -// table in another contexts. It is useful during proof generation, +// table in another context. It is useful during proof generation, // allowing the avoidance of table verification when simulating user code. global jumpdest_analysis_end: %pop2 diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index c8b78e62..f3078239 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -240,7 +240,7 @@ impl GenerationState { } } - /// Generate the either the next used jump address or the the proof for the last jump address. + /// Generate either the next used jump address or the proof for the last jump address. fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[1].as_str() { "next_address" => self.run_next_jumpdest_table_address(), @@ -385,7 +385,7 @@ impl GenerationState { /// For all address in `jumpdest_table`, each bounded by `largest_address`, /// this function searches for a proof. A proof is the closest address /// for which none of the previous 32 bytes in the code (including opcodes -/// and pushed bytes are PUSHXX and the address is in its range. It returns +/// and pushed bytes) are PUSHXX and the address is in its range. It returns /// a vector of even size containing proofs followed by their addresses. fn get_proofs_and_jumpdests( code: &[u8], From bead1d60a75d8ace1bdc03d05d1c0f8e078226a5 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Thu, 11 Jan 2024 13:32:22 +0100 Subject: [PATCH 35/37] Adress review comments --- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 21 +++++++++++++------ evm/src/cpu/kernel/asm/util/basic_macros.asm | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 267bf528..fef192ce 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -104,9 +104,18 @@ code_bytes_to_skip: // A proof is valid if: // - i == 0 and we can go from the first opcode to jumpdest and code[jumpdest] = 0x5b // - i > 0 and: -// - for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - j - i, -// - we can go from opcode i+32 to jumpdest, -// - code[jumpdest] = 0x5b. +// a) for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - j - i, +// b) we can go from opcode i+32 to jumpdest, +// c) code[jumpdest] = 0x5b. +// To reduce the number of instructions, when i > 32 we load all the bytes code[j], ..., +// code[j + 31] in a single 32-byte word, and check a) directly on the packed bytes. +// We perform the "packed verification" computing a boolean formula evaluated on the bits of +// code[j],..., code[j+31] of the form p_1 AND p_2 AND p_3 AND p_4 AND p_5, where: +// - p_k is either TRUE, for one subset of the j's which depends on k (for example, +// for k = 1, it is TRUE for the first 15 positions), or has_prefix_k => bit_{k + 1}_is_0 +// for the j's not in the subset. +// - has_prefix_k is a predicate that is TRUE if and only if code[j] has the same prefix of size k + 2 +// as PUSH{32-(j-i)}. // stack: proof_prefix_addr, jumpdest, ctx, retdest // stack: (empty) global write_table_if_jumpdest: @@ -197,7 +206,7 @@ global write_table_if_jumpdest: SWAP1 // Compute in_range' = in_range AND - // - (0xFF|X⁷)³² for bytes in positions 1-2, 8-11, 16-19, and 24-27 + // - (0xFF|X⁷)³² for bytes in positions 1-3, 8-11, 16-19, and 24-27 // - (has_prefix => is_0_at_6 |X⁷)³² on the rest // Compute also that ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. @@ -294,7 +303,7 @@ global jumpdest_analysis: // and the next prover input should contain a proof for address'. PROVER_INPUT(jumpdest_table::next_address) DUP1 %jumpi(check_proof) - // If proof == 0 there are no more jump destinations to check + // If address == 0 there are no more jump destinations to check POP // This is just a hook used for avoiding verification of the jumpdest // table in another context. It is useful during proof generation, @@ -303,7 +312,7 @@ global jumpdest_analysis_end: %pop2 JUMP check_proof: - // stack: proof, ctx, code_len, retdest + // stack: address, ctx, code_len, retdest DUP3 DUP2 %assert_le %decrement // stack: proof, ctx, code_len, retdest diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm index 566614bd..855eb4bb 100644 --- a/evm/src/cpu/kernel/asm/util/basic_macros.asm +++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm @@ -8,15 +8,15 @@ jumpi %endmacro +// Jump to `jumpdest` if the top of the stack is != c %macro jump_neq_const(c, jumpdest) PUSH $c SUB %jumpi($jumpdest) %endmacro +// Jump to `jumpdest` if the top of the stack is < c %macro jumpi_lt_const(c, jumpdest) - // %assert_zero is cheaper than %assert_nonzero, so we will leverage the - // fact that (x < c) == !(x >= c). %ge_const($c) %jumpi($jumpdest) %endmacro From f9c3ad6646c1047c3a44d2eb89d48f001c4e6ae6 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Thu, 11 Jan 2024 13:44:20 +0100 Subject: [PATCH 36/37] Update empty_txn_list --- evm/tests/empty_txn_list.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs index ff4e7637..2130e51b 100644 --- a/evm/tests/empty_txn_list.rs +++ b/evm/tests/empty_txn_list.rs @@ -77,7 +77,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { // Initialize the preprocessed circuits for the zkEVM. let all_circuits = AllRecursiveCircuits::::new( &all_stark, - &[16..17, 10..11, 12..13, 14..15, 9..11, 12..13, 17..18], // Minimal ranges to prove an empty list + &[16..17, 9..11, 12..13, 14..15, 9..11, 12..13, 17..18], // Minimal ranges to prove an empty list &config, ); From ac9f704f97caebefeb934960653015acb261759c Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Thu, 11 Jan 2024 14:26:28 +0100 Subject: [PATCH 37/37] Fix comment --- evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index fef192ce..9dee5d2b 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -229,7 +229,7 @@ global write_table_if_jumpdest: SWAP1 // Compute in_range' = in_range AND - // - (0xFF|X⁷)³² for bytes in positions 1-7 and 16-23 + // - (0xFF|X⁷)³² for bytes in 1, 4-5, 8-9, 12-13, 16-17, 20-21, 24-25, 28-29 // - (has_prefix => is_0_at_7 |X⁷)³² on the rest // Compute also that ~has_prefix = ~has_prefix OR is_0_at_7 for all bytes.