diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm
index aa8fbf0c..b5b89354 100644
--- a/evm/src/cpu/kernel/asm/core/call.asm
+++ b/evm/src/cpu/kernel/asm/core/call.asm
@@ -387,12 +387,10 @@ call_too_deep:
     %checkpoint // Checkpoint
     %increment_call_depth
     // Perform jumpdest analysis
-    PUSH %%after
     %mload_context_metadata(@CTX_METADATA_CODE_SIZE)
     GET_CONTEXT
     // stack: ctx, code_size, retdest
-    %jump(jumpdest_analysis)
-%%after:
+    %jumpdest_analysis
     PUSH 0 // jump dest
     EXIT_KERNEL
     // (Old context) stack: new_ctx
diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm
index a43f301a..9dee5d2b 100644
--- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm
+++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm
@@ -1,29 +1,26 @@
-// Populates @SEGMENT_JUMPDEST_BITS for the given context's code.
-// Pre stack: ctx, code_len, retdest
+// Walk through the given context's code from init_pos, setting
+// @SEGMENT_JUMPDEST_BITS to 1 for every JUMPDEST found, and check that the
+// walk lands exactly on final_pos.
+// Pre stack: init_pos, ctx, final_pos, retdest
 // Post stack: (empty)
-global jumpdest_analysis:
-    // stack: ctx, code_len, retdest
-    PUSH 0 // i = 0
-
+global verify_path_and_write_jumpdest_table:
 loop:
-    // stack: i, ctx, code_len, retdest
-    // Ideally we would break if i >= code_len, but checking i > code_len is
-    // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is
-    // a no-op.
-    DUP3 DUP2 GT // i > code_len
-    %jumpi(return)
+    // stack: i, ctx, final_pos, retdest
+    DUP3 DUP2 EQ // i == final_pos
+    %jumpi(proof_ok)
+    DUP3 DUP2 GT // i > final_pos
+    %jumpi(proof_not_ok)
 
-    // stack: i, ctx, code_len, retdest
+    // stack: i, ctx, final_pos, retdest
     %stack (i, ctx) -> (ctx, i, i, ctx)
     ADD // combine context and offset to make an address (SEGMENT_CODE == 0)
     MLOAD_GENERAL
-    // stack: opcode, i, ctx, code_len, retdest
+    // stack: opcode, i, ctx, final_pos, retdest
     DUP1
     // Slightly more efficient than `%eq_const(0x5b) ISZERO`
     PUSH 0x5b
     SUB
-    // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest
+    // stack: opcode != JUMPDEST, opcode, i, ctx, final_pos, retdest
     %jumpi(continue)
 
     // stack: JUMPDEST, i, ctx, code_len, retdest
@@ -34,16 +31,23 @@ loop:
     MSTORE_GENERAL
 
 continue:
-    // stack: opcode, i, ctx, code_len, retdest
+    // stack: opcode, i, ctx, final_pos, retdest
    %add_const(code_bytes_to_skip)
    %mload_kernel_code
-    // stack: bytes_to_skip, i, ctx, code_len, retdest
+    // stack: bytes_to_skip, i, ctx, final_pos, retdest
     ADD
-    // stack: i, ctx, code_len, retdest
+    // stack: i, ctx, final_pos, retdest
     %jump(loop)
 
-return:
-    // stack: i, ctx, code_len, retdest
+proof_ok:
+    // stack: i, ctx, final_pos, retdest
+    // We already know final_pos is a jumpdest
+    %stack (i, ctx, final_pos) -> (ctx, @SEGMENT_JUMPDEST_BITS, i)
+    %build_address
+    PUSH 1
+    MSTORE_GENERAL
+    JUMP
+proof_not_ok:
     %pop3
     JUMP
 
@@ -93,3 +97,237 @@ code_bytes_to_skip:
     %rep 128
         BYTES 1 // 0x80-0xff
     %endrep
+
+
+// A proof attesting that jumpdest is a valid jump destination is
+// either 0 or an index 0 < i <= jumpdest - 32.
+// A proof is valid if:
+// - i == 0, we can go from the first opcode to jumpdest, and code[jumpdest] = 0x5b;
+// - i > 0 and:
+//   a) for j in {i, ..., i+31}, code[j] != PUSHk for all k >= 32 - (j - i),
+//   b) we can go from opcode i+32 to jumpdest,
+//   c) code[jumpdest] = 0x5b.
+// To reduce the number of instructions, when i > 0 we load all the bytes
+// code[i], ..., code[i+31] in a single 32-byte word, and check a) directly
+// on the packed bytes.
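For reference, the same validity check can be written out naively in Rust. The helper below is illustrative only and is not part of this diff; it checks conditions a), b) and c) byte by byte, which is exactly what the packed 32-byte comparison that follows optimizes:

```rust
/// Hypothetical reference checker for the proof format described above.
/// PUSH1..PUSH32 occupy opcodes 0x60..=0x7f, so PUSHk = 0x5f + k, and a PUSHk
/// at offset j covers bytes j+1..=j+k with its arguments.
fn is_proof_valid(code: &[u8], i: usize, jumpdest: usize) -> bool {
    // Condition c): the attested address must hold a JUMPDEST opcode.
    if code.get(jumpdest) != Some(&0x5b) {
        return false;
    }
    let start = if i == 0 {
        // With a zero proof we simply walk from the first opcode.
        0
    } else {
        // The spec requires 0 < i <= jumpdest - 32.
        if i + 32 > jumpdest {
            return false;
        }
        // Condition a): no PUSHk in code[i..i+32] may reach position i+32,
        // i.e. reject k >= 32 - (j - i) at offset j.
        for (offset, &byte) in code[i..i + 32].iter().enumerate() {
            if (0x60..=0x7f).contains(&byte) && (byte - 0x5f) as usize >= 32 - offset {
                return false;
            }
        }
        i + 32
    };
    // Condition b): walking opcodes (skipping PUSH arguments) from `start`
    // must land exactly on `jumpdest`.
    let mut pos = start;
    while pos < jumpdest {
        let op = code[pos];
        pos += if (0x60..=0x7f).contains(&op) {
            (op - 0x5f) as usize + 1
        } else {
            1
        };
    }
    pos == jumpdest
}
```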
+// We perform the "packed verification" computing a boolean formula evaluated on the bits of +// code[j],..., code[j+31] of the form p_1 AND p_2 AND p_3 AND p_4 AND p_5, where: +// - p_k is either TRUE, for one subset of the j's which depends on k (for example, +// for k = 1, it is TRUE for the first 15 positions), or has_prefix_k => bit_{k + 1}_is_0 +// for the j's not in the subset. +// - has_prefix_k is a predicate that is TRUE if and only if code[j] has the same prefix of size k + 2 +// as PUSH{32-(j-i)}. +// stack: proof_prefix_addr, jumpdest, ctx, retdest +// stack: (empty) +global write_table_if_jumpdest: + // stack: proof_prefix_addr, jumpdest, ctx, retdest + %stack + (proof_prefix_addr, jumpdest, ctx) -> + (ctx, jumpdest, jumpdest, ctx, proof_prefix_addr) + ADD // combine context and offset to make an address (SEGMENT_CODE == 0) + MLOAD_GENERAL + // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest + + %jump_neq_const(0x5b, return) + + //stack: jumpdest, ctx, proof_prefix_addr, retdest + SWAP2 DUP1 + // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest + ISZERO + %jumpi(verify_path_and_write_jumpdest_table) + + + // stack: proof_prefix_addr, ctx, jumpdest, retdest + // If we are here we need to check that the next 32 bytes are less + // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, + // or larger than 127 + + %stack + (proof_prefix_addr, ctx) -> + (ctx, proof_prefix_addr, 32, proof_prefix_addr, ctx) + ADD // combine context and offset to make an address (SEGMENT_CODE == 0) + %mload_packing + // packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP1 %shl_const(1) + DUP2 %shl_const(2) + AND + // stack: (is_1_at_pos_2_and_3|(X)⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + // X denotes any value in {0,1} and Z^i is Z repeated i times + NOT + // stack: (is_0_at_2_or_3|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP2 + OR + // stack: (is_1_at_1 or is_0_at_2_or_3|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + // stack: (~has_prefix|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + + // Compute in_range = + // - (0xFF|X⁷)³² for the first 15 bytes + // - (has_prefix => is_0_at_4 |X⁷)³² for the next 15 bytes + // - (~has_prefix|X⁷)³² for the last byte + // Compute also ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. We don't need to update ~has_prefix + // for the second half but it takes less cycles if we do it. + DUP2 %shl_const(3) + NOT + // stack: (is_0_at_4|X⁷)³², (~has_prefix|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00 + AND + // stack: (is_0_at_4|X⁷)³¹|0⁸, (~has_prefix|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + DUP2 + DUP2 + OR + // pos 0102030405060708091011121314151617181920212223242526272829303132 + PUSH 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0000000000000000000000000000000000 + OR + // stack: (in_range|X⁷)³², (is_0_at_4|X⁷)³², (~has_prefix|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + SWAP2 + OR + // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest + + // Compute in_range' = in_range AND + // - (0xFF|X⁷)³² for bytes in positions 1-7 and 16-23 + // - (has_prefix => is_0_at_5 |X⁷)³² on the rest + // Compute also ~has_prefix = ~has_prefix OR is_0_at_5 for all bytes. 
+
+    DUP3 %shl_const(4)
+    NOT
+    // stack: (is_0_at_5|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    DUP2
+    DUP2
+    OR
+    // pos  0102030405060708091011121314151617181920212223242526272829303132
+    PUSH 0xFFFFFFFFFFFFFF0000000000000000FFFFFFFFFFFFFFFF000000000000000000
+    OR
+    // stack: (in_range'|X⁷)³², (is_0_at_5|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    SWAP2
+    OR
+    // stack: (~has_prefix|X⁷)³², (in_range'|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    SWAP2
+    AND
+    SWAP1
+
+    // Compute in_range' = in_range AND
+    // - (0xFF|X⁷)³² for bytes in positions 1-3, 8-11, 16-19, and 24-27,
+    // - (has_prefix => is_0_at_6|X⁷)³² on the rest.
+    // Compute also ~has_prefix = ~has_prefix OR is_0_at_6 for all bytes.
+
+    // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    DUP3 %shl_const(5)
+    NOT
+    // stack: (is_0_at_6|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    DUP2
+    DUP2
+    OR
+    // pos  0102030405060708091011121314151617181920212223242526272829303132
+    PUSH 0xFFFFFF00000000FFFFFFFF00000000FFFFFFFF00000000FFFFFFFF0000000000
+    OR
+    // stack: (in_range'|X⁷)³², (is_0_at_6|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    SWAP2
+    OR
+    // stack: (~has_prefix|X⁷)³², (in_range'|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    SWAP2
+    AND
+    SWAP1
+
+    // Compute in_range' = in_range AND
+    // - (0xFF|X⁷)³² for bytes in positions 1, 4-5, 8-9, 12-13, 16-17, 20-21, 24-25, and 28-29,
+    // - (has_prefix => is_0_at_7|X⁷)³² on the rest.
+    // Compute also ~has_prefix = ~has_prefix OR is_0_at_7 for all bytes.
+
+    // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    DUP3 %shl_const(6)
+    NOT
+    // stack: (is_0_at_7|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    DUP2
+    DUP2
+    OR
+    // pos  0102030405060708091011121314151617181920212223242526272829303132
+    PUSH 0xFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF000000
+    OR
+    // stack: (in_range'|X⁷)³², (is_0_at_7|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    SWAP2
+    OR
+    // stack: (~has_prefix|X⁷)³², (in_range'|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    SWAP2
+    AND
+    SWAP1
+
+    // Compute in_range' = in_range AND
+    // - (0xFF|X⁷)³² for bytes in even positions,
+    // - (has_prefix => is_0_at_8|X⁷)³² on the rest.
+
+    // stack: (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    DUP3 %shl_const(7)
+    NOT
+    // stack: (is_0_at_8|X⁷)³², (~has_prefix|X⁷)³², (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    OR
+    // pos  0102030405060708091011121314151617181920212223242526272829303132
+    PUSH 0x00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF
+    OR
+    AND
+    // stack: (in_range|X⁷)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+
+    // Get rid of the irrelevant bits.
+    // pos  0102030405060708091011121314151617181920212223242526272829303132
+    PUSH 0x8080808080808080808080808080808080808080808080808080808080808080
+    AND
+    %jump_neq_const(0x8080808080808080808080808080808080808080808080808080808080808080, return_pop_packed)
+    POP
+    %add_const(32)
+
+    // Check the remaining path.
+    %jump(verify_path_and_write_jumpdest_table)
+return_pop_packed:
+    // stack: packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
+    POP
+return:
+    // stack: jumpdest, ctx, proof_prefix_addr, retdest
+    %pop3
+    JUMP
+
+%macro write_table_if_jumpdest
+    %stack (proof_prefix_addr, jumpdest, ctx) -> (proof_prefix_addr, jumpdest, ctx, %%after)
+    %jump(write_table_if_jumpdest)
+%%after:
+%endmacro
+
+// Write the jumpdest table. This is done by non-deterministically guessing
+// the sequence of jumpdest addresses used during program execution within
+// the current context. For each jumpdest address we also non-deterministically
+// guess a proof, which is another address in the code attesting that the
+// address is a valid JUMPDEST. If the proof checks out when it is at the top
+// of the stack with the jumpdest address below it (see write_table_if_jumpdest),
+// we set the corresponding bit in @SEGMENT_JUMPDEST_BITS to 1.
+//
+// Pre stack: ctx, code_len, retdest
+// Post stack: (empty)
+global jumpdest_analysis:
+    // If address > 0 then address is interpreted as address' + 1,
+    // and the next prover input should contain a proof for address'.
+    PROVER_INPUT(jumpdest_table::next_address)
+    DUP1 %jumpi(check_proof)
+    // If address == 0 there are no more jump destinations to check.
+    POP
+// This is just a hook used to skip the verification of the jumpdest table.
+// It is useful during proof generation, where we can avoid verifying the
+// table when simulating user code.
+global jumpdest_analysis_end:
+    %pop2
+    JUMP
+check_proof:
+    // stack: address, ctx, code_len, retdest
+    // Check address <= code_len, then recover the actual jumpdest address.
+    DUP3 DUP2 %assert_le
+    %decrement
+    // stack: jumpdest, ctx, code_len, retdest
+    DUP2 SWAP1
+    // stack: jumpdest, ctx, ctx, code_len, retdest
+    // Read the proof.
+    PROVER_INPUT(jumpdest_table::next_proof)
+    // stack: proof, jumpdest, ctx, ctx, code_len, retdest
+    %write_table_if_jumpdest
+    // stack: ctx, code_len, retdest
+
+    %jump(jumpdest_analysis)
+
+%macro jumpdest_analysis
+    %stack (ctx, code_len) -> (ctx, code_len, %%after)
+    %jump(jumpdest_analysis)
+%%after:
+%endmacro
diff --git a/evm/src/cpu/kernel/asm/util/basic_macros.asm b/evm/src/cpu/kernel/asm/util/basic_macros.asm
index 44d734a3..76def0c7 100644
--- a/evm/src/cpu/kernel/asm/util/basic_macros.asm
+++ b/evm/src/cpu/kernel/asm/util/basic_macros.asm
@@ -8,6 +8,18 @@
         jumpi
 %endmacro
 
+// Jump to `jumpdest` if the top of the stack is != c, consuming it.
+%macro jump_neq_const(c, jumpdest)
+    PUSH $c
+    SUB
+    %jumpi($jumpdest)
+%endmacro
+
+// Jump to `jumpdest` if the top of the stack is < c, consuming it.
+%macro jumpi_lt_const(c, jumpdest)
+    %lt_const($c)
+    %jumpi($jumpdest)
+%endmacro
+
 %macro pop2
     %rep 2
         POP
diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs
index 227df021..0a60ccff 100644
--- a/evm/src/cpu/kernel/interpreter.rs
+++ b/evm/src/cpu/kernel/interpreter.rs
@@ -1,7 +1,7 @@
 //! An EVM interpreter for testing and debugging purposes.
 
 use core::cmp::Ordering;
-use std::collections::HashMap;
+use std::collections::{BTreeSet, HashMap, HashSet};
 use std::ops::Range;
 
 use anyhow::bail;
@@ -11,6 +11,7 @@
 use plonky2::field::goldilocks_field::GoldilocksField;
 
 use super::assembler::BYTES_PER_OFFSET;
+use super::utils::u256_from_bool;
 use crate::cpu::kernel::aggregator::KERNEL;
 use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
 use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
@@ -585,7 +586,23 @@ impl<'a> Interpreter<'a> {
         .collect()
     }
 
-    fn incr(&mut self, n: usize) {
+    pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
+        self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits.unscale()]
+            .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect();
+        self.generation_state
+            .set_proofs_and_jumpdests(HashMap::from([(
+                context,
+                BTreeSet::from_iter(
+                    jumpdest_bits
+                        .into_iter()
+                        .enumerate()
+                        .filter(|&(_, x)| x)
+                        .map(|(i, _)| i),
+                ),
+            )]));
+    }
+
+    pub(crate) fn incr(&mut self, n: usize) {
         self.generation_state.registers.program_counter += n;
     }
diff --git a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs
index 1d686d62..17601491 100644
--- a/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs
+++ b/evm/src/cpu/kernel/tests/core/jumpdest_analysis.rs
@@ -27,7 +27,7 @@ fn test_jumpdest_analysis() -> Result<()> {
         jumpdest,
     ];
 
-    let expected_jumpdest_bits = vec![false, true, false, false, false, true, false, true];
+    let jumpdest_bits = vec![false, true, false, false, false, true, false, true];
 
     // Contract creation transaction.
     let initial_stack = vec![
     ];
     let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack);
     interpreter.set_code(CONTEXT, code);
+    interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits);
+
     interpreter.run()?;
     assert_eq!(interpreter.stack(), vec![]);
-    assert_eq!(
-        interpreter.get_jumpdest_bits(CONTEXT),
-        expected_jumpdest_bits
-    );
 
     Ok(())
 }
diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs
index 515238ec..e14de9b9 100644
--- a/evm/src/generation/mod.rs
+++ b/evm/src/generation/mod.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::collections::{BTreeSet, HashMap};
 use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 
@@ -8,6 +8,7 @@
 use ethereum_types::{Address, BigEndianHash, H256, U256};
 use itertools::enumerate;
 use plonky2::field::extension::Extendable;
 use plonky2::field::polynomial::PolynomialValues;
+use plonky2::field::types::Field;
 use plonky2::hash::hash_types::RichField;
 use plonky2::timed;
 use plonky2::util::timing::TimingTree;
@@ -21,13 +22,16 @@
 use crate::all_stark::{AllStark, NUM_TABLES};
 use crate::config::StarkConfig;
 use crate::cpu::columns::CpuColumnsView;
 use crate::cpu::kernel::aggregator::KERNEL;
+use crate::cpu::kernel::assembler::Kernel;
 use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
+use crate::cpu::kernel::opcodes::get_opcode;
 use crate::generation::state::GenerationState;
 use crate::generation::trie_extractor::{get_receipt_trie, get_state_trie, get_txn_trie};
 use crate::memory::segments::Segment;
 use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots};
 use crate::prover::check_abort_signal;
-use crate::util::{h2u, u256_to_usize};
+use crate::util::{h2u, u256_to_u8, u256_to_usize};
+use crate::witness::errors::{ProgramError, ProverInputError};
 use crate::witness::memory::{MemoryAddress, MemoryChannel};
 use crate::witness::transition::transition;
 
@@ -38,7 +42,7 @@
 pub(crate) mod state;
 mod trie_extractor;
 
 use self::mpt::{load_all_mpts, TrieRootPtrs};
-use crate::witness::util::mem_write_log;
+use crate::witness::util::{mem_write_log, stack_peek};
 
 /// Inputs needed for trace generation.
 #[derive(Clone, Debug, Deserialize, Serialize, Default)]
@@ -303,9 +307,7 @@ pub fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
     Ok((tables, public_values))
 }
 
-fn simulate_cpu<F: RichField + Extendable<D>, const D: usize>(
-    state: &mut GenerationState<F>,
-) -> anyhow::Result<()> {
+fn simulate_cpu<F: Field>(state: &mut GenerationState<F>) -> anyhow::Result<()> {
     let halt_pc = KERNEL.global_labels["halt"];
 
     loop {
@@ -340,3 +342,85 @@
         transition(state)?;
     }
 }
+
+fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
+    initial_label: &str,
+    final_label: &str,
+    state: &mut GenerationState<F>,
+) -> Option<HashMap<usize, BTreeSet<usize>>> {
+    if state.jumpdest_proofs.is_some() {
+        None
+    } else {
+        const JUMP_OPCODE: u8 = 0x56;
+        const JUMPI_OPCODE: u8 = 0x57;
+
+        let halt_pc = KERNEL.global_labels[final_label];
+        let mut jumpdest_addresses: HashMap<_, BTreeSet<usize>> = HashMap::new();
+
+        state.registers.program_counter = KERNEL.global_labels[initial_label];
+        let initial_clock = state.traces.clock();
+        let initial_context = state.registers.context;
+
+        log::debug!("Simulating CPU for jumpdest analysis.");
+
+        loop {
+            // Skip jumpdest table validation in simulations.
+            if state.registers.program_counter == KERNEL.global_labels["jumpdest_analysis"] {
+                state.registers.program_counter = KERNEL.global_labels["jumpdest_analysis_end"];
+            }
+            let pc = state.registers.program_counter;
+            let context = state.registers.context;
+            let halt = state.registers.is_kernel
+                && pc == halt_pc
+                && state.registers.context == initial_context;
+            let Ok(opcode) =
+                u256_to_u8(state.memory.get(MemoryAddress::new(context, Segment::Code, pc)))
+            else {
+                log::debug!(
+                    "Simulated CPU for jumpdest analysis halted after {} cycles",
+                    state.traces.clock() - initial_clock
+                );
+                return Some(jumpdest_addresses);
+            };
+            let cond = if let Ok(cond) = stack_peek(state, 1) {
+                cond != U256::zero()
+            } else {
+                false
+            };
+            if !state.registers.is_kernel
+                && (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond))
+            {
+                // Avoid deeper calls to abort.
+                let Ok(jumpdest) = u256_to_usize(state.registers.stack_top) else {
+                    log::debug!(
+                        "Simulated CPU for jumpdest analysis halted after {} cycles",
+                        state.traces.clock() - initial_clock
+                    );
+                    return Some(jumpdest_addresses);
+                };
+                state.memory.set(
+                    MemoryAddress::new(context, Segment::JumpdestBits, jumpdest),
+                    U256::one(),
+                );
+                if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) {
+                    ctx_addresses.insert(jumpdest);
+                } else {
+                    jumpdest_addresses.insert(context, BTreeSet::from([jumpdest]));
+                }
+            }
+            if halt || transition(state).is_err() {
+                log::debug!(
+                    "Simulated CPU for jumpdest analysis halted after {} cycles",
+                    state.traces.clock() - initial_clock
+                );
+                return Some(jumpdest_addresses);
+            }
+        }
+    }
+}
diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs
index b60233d9..f3078239 100644
--- a/evm/src/generation/prover_input.rs
+++ b/evm/src/generation/prover_input.rs
@@ -1,3 +1,5 @@
+use std::cmp::min;
+use std::collections::HashMap;
 use std::mem::transmute;
 use std::str::FromStr;
 
@@ -5,20 +7,26 @@
 use anyhow::{bail, Error};
 use ethereum_types::{BigEndianHash, H256, U256, U512};
 use itertools::{enumerate, Itertools};
 use num_bigint::BigUint;
+use plonky2::field::extension::Extendable;
 use plonky2::field::types::Field;
+use plonky2::hash::hash_types::RichField;
 use serde::{Deserialize, Serialize};
 
+use crate::cpu::kernel::aggregator::KERNEL;
+use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
+use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
 use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254};
 use crate::generation::prover_input::EvmField::{
     Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar,
 };
 use crate::generation::prover_input::FieldOp::{Inverse, Sqrt};
+use crate::generation::simulate_cpu_between_labels_and_get_user_jumps;
 use crate::generation::state::GenerationState;
 use crate::memory::segments::Segment;
 use crate::memory::segments::Segment::BnPairing;
-use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_usize};
-use crate::witness::errors::ProgramError;
+use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_u8, u256_to_usize};
 use crate::witness::errors::ProverInputError::*;
+use crate::witness::errors::{ProgramError, ProverInputError};
 use crate::witness::memory::MemoryAddress;
 use crate::witness::operation::CONTEXT_SCALING_FACTOR;
 use crate::witness::util::{current_context_peek, stack_peek};
@@ -48,6 +56,7 @@ impl<F: Field> GenerationState<F> {
             "bignum_modmul" => self.run_bignum_modmul(),
             "withdrawal" => self.run_withdrawal(),
             "num_bits" => self.run_num_bits(),
+            "jumpdest_table" => self.run_jumpdest_table(input_fn),
             _ => Err(ProgramError::ProverInputError(InvalidFunction)),
         }
     }
@@ -230,6 +239,236 @@
             Ok(num_bits.into())
         }
     }
+
+    /// Generate either the next used jump address or the proof for the last jump address.
+    fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result<U256, ProgramError> {
+        match input_fn.0[1].as_str() {
+            "next_address" => self.run_next_jumpdest_table_address(),
+            "next_proof" => self.run_next_jumpdest_table_proof(),
+            _ => Err(ProgramError::ProverInputError(InvalidInput)),
+        }
+    }
+
+    /// Returns the next used jump address.
+    fn run_next_jumpdest_table_address(&mut self) -> Result<U256, ProgramError> {
+        let context = self.registers.context;
+
+        if self.jumpdest_proofs.is_none() {
+            self.generate_jumpdest_proofs()?;
+        }
+
+        let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
+            return Err(ProgramError::ProverInputError(
+                ProverInputError::InvalidJumpdestSimulation,
+            ));
+        };
+
+        if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&context)
+            && let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop()
+        {
+            Ok((next_jumpdest_address + 1).into())
+        } else {
+            self.jumpdest_proofs = None;
+            Ok(U256::zero())
+        }
+    }
+
+    /// Returns the proof for the last jump address.
+    fn run_next_jumpdest_table_proof(&mut self) -> Result<U256, ProgramError> {
+        let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
+            return Err(ProgramError::ProverInputError(
+                ProverInputError::InvalidJumpdestSimulation,
+            ));
+        };
+        if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
+            && let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop()
+        {
+            Ok(next_jumpdest_proof.into())
+        } else {
+            Err(ProgramError::ProverInputError(
+                ProverInputError::InvalidJumpdestSimulation,
+            ))
+        }
+    }
+}
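The two prover-input handlers above drain the per-context proof vectors from the back. Under the layout produced by `get_proofs_and_jumpdests` below, `[proof_0, addr_0, proof_1, addr_1, ...]` with addresses in increasing order, the kernel therefore receives jumpdests in descending order, each address offset by one so that zero can signal the end of the table. A standalone sketch of that consumption order, with made-up values:

```rust
use std::collections::HashMap;

fn main() {
    // Hypothetical proof table for context 0: jumpdest 2 proved by 0,
    // jumpdest 70 proved by 34, stored flat as [proof, addr, proof, addr].
    let mut jumpdest_proofs: HashMap<usize, Vec<usize>> =
        HashMap::from([(0, vec![0, 2, 34, 70])]);
    let ctx = 0;
    // next_address pops an address (the kernel sees addr + 1), then
    // next_proof pops its proof: 70/34 first, then 2/0, then "0" for done.
    while let Some(addr) = jumpdest_proofs.get_mut(&ctx).and_then(Vec::pop) {
        let proof = jumpdest_proofs.get_mut(&ctx).and_then(Vec::pop).unwrap();
        println!("jumpdest {addr} attested by proof at {proof}");
    }
}
```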
+
+impl<F: Field> GenerationState<F> {
+    /// Simulate the user's code and store all the jump addresses with their respective contexts.
+    fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> {
+        let checkpoint = self.checkpoint();
+        let memory = self.memory.clone();
+
+        let code = self.get_current_code()?;
+        // We need to set the simulated jumpdest bits to one, as otherwise
+        // the simulation will fail.
+        self.set_jumpdest_bits(&code);
+
+        // Simulate the user's code and (unnecessarily) part of the kernel code,
+        // skipping the jumpdest table validation.
+        let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps(
+            "jumpdest_analysis_end",
+            "terminate_common",
+            self,
+        ) else {
+            self.jumpdest_proofs = Some(HashMap::new());
+            return Ok(());
+        };
+
+        // Return to the state before starting the simulation.
+        self.rollback(checkpoint);
+        self.memory = memory;
+
+        // Find proofs for all contexts.
+        self.set_proofs_and_jumpdests(jumpdest_table);
+
+        Ok(())
+    }
+
+    /// Given a map from contexts to their sets of jumpdest addresses, computes
+    /// the respective proofs by calling `get_proofs_and_jumpdests`.
+    pub(crate) fn set_proofs_and_jumpdests(
+        &mut self,
+        jumpdest_table: HashMap<usize, std::collections::BTreeSet<usize>>,
+    ) {
+        self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map(
+            |(ctx, jumpdest_table)| {
+                let code = self.get_code(ctx).unwrap();
+                if let Some(&largest_address) = jumpdest_table.last() {
+                    let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table);
+                    (ctx, proofs)
+                } else {
+                    (ctx, vec![])
+                }
+            },
+        )));
+    }
+
+    fn get_current_code(&self) -> Result<Vec<u8>, ProgramError> {
+        self.get_code(self.registers.context)
+    }
+
+    fn get_code(&self, context: usize) -> Result<Vec<u8>, ProgramError> {
+        let code_len = self.get_code_len(context)?;
+        let code = (0..code_len)
+            .map(|i| u256_to_u8(self.memory.get(MemoryAddress::new(context, Segment::Code, i))))
+            .collect::<Result<Vec<u8>, _>>()?;
+        Ok(code)
+    }
+
+    fn get_code_len(&self, context: usize) -> Result<usize, ProgramError> {
+        let code_len = u256_to_usize(self.memory.get(MemoryAddress::new(
+            context,
+            Segment::ContextMetadata,
+            ContextMetadata::CodeSize.unscale(),
+        )))?;
+        Ok(code_len)
+    }
+
+    fn set_jumpdest_bits(&mut self, code: &[u8]) {
+        const JUMPDEST_OPCODE: u8 = 0x5b;
+        for (pos, opcode) in CodeIterator::new(code) {
+            if opcode == JUMPDEST_OPCODE {
+                self.memory.set(
+                    MemoryAddress::new(self.registers.context, Segment::JumpdestBits, pos),
+                    U256::one(),
+                );
+            }
+        }
+    }
+}
+
+/// For each address in `jumpdest_table`, all bounded by `largest_address`,
+/// this function searches for a proof. A proof is derived from the closest
+/// preceding position for which none of the previous 32 bytes in the code
+/// (including opcodes and pushed bytes) is a PUSHXX whose arguments reach that
+/// position: the proof is that position minus 32, or 0 if no such position
+/// exists. It returns a vector of even length, alternating proofs and their
+/// jumpdest addresses.
+fn get_proofs_and_jumpdests(
+    code: &[u8],
+    largest_address: usize,
+    jumpdest_table: std::collections::BTreeSet<usize>,
+) -> Vec<usize> {
+    const PUSH1_OPCODE: u8 = 0x60;
+    const PUSH32_OPCODE: u8 = 0x7f;
+    let (proofs, _) = CodeIterator::until(code, largest_address + 1).fold(
+        (vec![], 0),
+        |(mut proofs, acc), (pos, opcode)| {
+            let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
+                code[prefix_start..pos]
+                    .iter()
+                    .enumerate()
+                    .fold(true, |acc, (prefix_pos, &byte)| {
+                        // Either the byte is not a PUSH opcode at all...
+                        let cond1 = byte > PUSH32_OPCODE;
+                        // ...or its potential push arguments end before `pos`.
+                        let cond2 = (prefix_start + prefix_pos) as i32
+                            + (byte as i32 - PUSH1_OPCODE as i32)
+                            + 1
+                            < pos as i32;
+                        acc && (cond1 || cond2)
+                    })
+            } else {
+                false
+            };
+            let acc = if has_prefix { pos - 32 } else { acc };
+            if jumpdest_table.contains(&pos) {
+                // Push the proof.
+                proofs.push(acc);
+                // Push the address.
+                proofs.push(pos);
+            }
+            (proofs, acc)
+        },
+    );
+    proofs
+}
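A hypothetical unit test for `get_proofs_and_jumpdests` (not part of the diff) on code shorter than 32 bytes: no position can have a full 32-byte prefix, so every jumpdest receives the trivial proof 0, meaning "walk from the first opcode":

```rust
#[test]
fn trivial_proofs_for_short_code() {
    use std::collections::BTreeSet;

    // PUSH1 0x02; JUMPDEST; JUMP
    let code = [0x60, 0x02, 0x5b, 0x56];
    let jumpdests = BTreeSet::from([2]);
    // The output is flat: [proof_0, addr_0, ...].
    assert_eq!(get_proofs_and_jumpdests(&code, 2, jumpdests), vec![0, 2]);
}
```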
+/// An iterator over the EVM code contained in `code`, which skips the bytes
+/// that are the arguments of a PUSHXX opcode.
+struct CodeIterator<'a> {
+    code: &'a [u8],
+    pos: usize,
+    end: usize,
+}
+
+impl<'a> CodeIterator<'a> {
+    fn new(code: &'a [u8]) -> Self {
+        CodeIterator {
+            end: code.len(),
+            code,
+            pos: 0,
+        }
+    }
+
+    fn until(code: &'a [u8], end: usize) -> Self {
+        CodeIterator {
+            end: min(code.len(), end),
+            code,
+            pos: 0,
+        }
+    }
+}
+
+impl<'a> Iterator for CodeIterator<'a> {
+    type Item = (usize, u8);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        const PUSH1_OPCODE: u8 = 0x60;
+        const PUSH32_OPCODE: u8 = 0x7f;
+        let CodeIterator { code, pos, end } = self;
+        if *pos >= *end {
+            return None;
+        }
+        let opcode = code[*pos];
+        let old_pos = *pos;
+        *pos += if (PUSH1_OPCODE..=PUSH32_OPCODE).contains(&opcode) {
+            // Skip the PUSH opcode and its arguments.
+            (opcode - PUSH1_OPCODE + 2).into()
+        } else {
+            1
+        };
+        Some((old_pos, opcode))
+    }
 }
 
 enum EvmField {
diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs
index c2822d1c..54512b2f 100644
--- a/evm/src/generation/state.rs
+++ b/evm/src/generation/state.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::collections::{BTreeSet, HashMap};
 
 use ethereum_types::{Address, BigEndianHash, H160, H256, U256};
 use keccak_hash::keccak;
@@ -50,6 +50,12 @@
 
     /// Pointers, within the `TrieData` segment, of the three MPTs.
     pub(crate) trie_root_ptrs: TrieRootPtrs,
+
+    /// A hash map where the key is a context of the user's code and the value
+    /// holds the jump destinations with their corresponding "proofs". A "proof"
+    /// for a jump destination is either 0 or an address 0 < i <= jumpdest - 32
+    /// in the code (not necessarily pointing to an opcode) such that for every
+    /// j in [i, i+32) it holds that code[j] < 0x7f - (j - i) or code[j] > 0x7f.
+    pub(crate) jumpdest_proofs: Option<HashMap<usize, Vec<usize>>>,
 }
 
 impl<F: Field> GenerationState<F> {
@@ -91,6 +97,7 @@
                 txn_root_ptr: 0,
                 receipt_root_ptr: 0,
             },
+            jumpdest_proofs: None,
         };
 
         let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries);
@@ -165,6 +172,26 @@
             .map(|i| stack_peek(self, i).unwrap())
             .collect()
     }
+
+    /// Clones everything but the traces.
+    pub(crate) fn soft_clone(&self) -> GenerationState<F> {
+        Self {
+            inputs: self.inputs.clone(),
+            registers: self.registers,
+            memory: self.memory.clone(),
+            traces: Traces::default(),
+            rlp_prover_inputs: self.rlp_prover_inputs.clone(),
+            state_key_to_address: self.state_key_to_address.clone(),
+            bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(),
+            withdrawal_prover_inputs: self.withdrawal_prover_inputs.clone(),
+            trie_root_ptrs: TrieRootPtrs {
+                state_root_ptr: 0,
+                txn_root_ptr: 0,
+                receipt_root_ptr: 0,
+            },
+            jumpdest_proofs: None,
+        }
+    }
 }
 
 /// Withdrawals prover input array is of the form `[addr0, amount0, ..., addrN, amountN, U256::MAX, U256::MAX]`.
diff --git a/evm/src/generation/trie_extractor.rs b/evm/src/generation/trie_extractor.rs
index d55a1fbf..4d3a745a 100644
--- a/evm/src/generation/trie_extractor.rs
+++ b/evm/src/generation/trie_extractor.rs
@@ -255,7 +255,7 @@ pub(crate) fn get_trie_helper(
 ) -> Result<N, ProgramError> {
     let load = |offset| memory.get(MemoryAddress::new(0, Segment::TrieData, offset));
     let load_slice_from = |init_offset| {
-        &memory.contexts[0].segments[Segment::TrieData as usize].content[init_offset..]
+        &memory.contexts[0].segments[Segment::TrieData.unscale()].content[init_offset..]
     };
 
     let trie_type = PartialTrieType::all()[u256_to_usize(load(ptr))?];
diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs
index 5a0fcbfb..1b266aef 100644
--- a/evm/src/witness/errors.rs
+++ b/evm/src/witness/errors.rs
@@ -36,4 +36,6 @@ pub enum ProverInputError {
     InvalidInput,
     InvalidFunction,
     NumBitsError,
+    InvalidJumpDestination,
+    InvalidJumpdestSimulation,
 }
diff --git a/evm/tests/empty_txn_list.rs b/evm/tests/empty_txn_list.rs
index ff4e7637..2130e51b 100644
--- a/evm/tests/empty_txn_list.rs
+++ b/evm/tests/empty_txn_list.rs
@@ -77,7 +77,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
     // Initialize the preprocessed circuits for the zkEVM.
     let all_circuits = AllRecursiveCircuits::<F, C, D>::new(
         &all_stark,
-        &[16..17, 10..11, 12..13, 14..15, 9..11, 12..13, 17..18], // Minimal ranges to prove an empty list
+        &[16..17, 9..11, 12..13, 14..15, 9..11, 12..13, 17..18], // Minimal ranges to prove an empty list
         &config,
     );
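Finally, a small sanity check on `CodeIterator`'s PUSH skipping. This test is hypothetical (it is not part of the diff, and would have to live in prover_input.rs where the iterator is private): bytes that are PUSH arguments are never yielded, even when they happen to equal 0x5b.

```rust
#[test]
fn code_iterator_skips_push_arguments() {
    // PUSH2 0x5b5b; JUMPDEST; STOP
    let code = [0x61, 0x5b, 0x5b, 0x5b, 0x00];
    let opcodes: Vec<(usize, u8)> = CodeIterator::new(&code).collect();
    // The two 0x5b bytes at positions 1 and 2 are PUSH2 arguments, so the
    // only real JUMPDEST is the one at position 3.
    assert_eq!(opcodes, vec![(0, 0x61), (3, 0x5b), (4, 0x00)]);
}
```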