From 9e39d88ab808c8f9d8a1b1619064fa20ec83fdb3 Mon Sep 17 00:00:00 2001 From: 4l0n50 Date: Wed, 13 Dec 2023 17:33:53 +0100 Subject: [PATCH] Rebase to main --- evm/src/cpu/kernel/asm/core/call.asm | 9 +- .../cpu/kernel/asm/core/jumpdest_analysis.asm | 182 ++++++++++++++++++ evm/src/generation/mod.rs | 65 ++++++- evm/src/generation/prover_input.rs | 153 ++++++++++++++- evm/src/generation/state.rs | 26 +++ evm/src/util.rs | 5 + evm/src/witness/errors.rs | 2 + evm/src/witness/transition.rs | 2 +- 8 files changed, 432 insertions(+), 12 deletions(-) diff --git a/evm/src/cpu/kernel/asm/core/call.asm b/evm/src/cpu/kernel/asm/core/call.asm index 2e7d1d73..5a2a14c4 100644 --- a/evm/src/cpu/kernel/asm/core/call.asm +++ b/evm/src/cpu/kernel/asm/core/call.asm @@ -367,11 +367,12 @@ call_too_deep: %checkpoint // Checkpoint %increment_call_depth // Perform jumpdest analyis - PUSH %%after - %mload_context_metadata(@CTX_METADATA_CODE_SIZE) - GET_CONTEXT + // PUSH %%after + // %mload_context_metadata(@CTX_METADATA_CODE_SIZE) + // GET_CONTEXT // stack: ctx, code_size, retdest - %jump(jumpdest_analysis) + // %jump(jumpdest_analysis) + %validate_jumpdest_table %%after: PUSH 0 // jump dest EXIT_KERNEL diff --git a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm index bda6f96e..09bb35fa 100644 --- a/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -1,3 +1,48 @@ +// Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos], +// for the given context's code. Panics if we never hit final_pos +// Pre stack: init_pos, ctx, final_pos, retdest +// Post stack: (empty) +global verify_path: +loop_new: + // stack: i, ctx, final_pos, retdest + // Ideally we would break if i >= final_pos, but checking i > final_pos is + // cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is + // a no-op. + DUP3 DUP2 EQ // i == final_pos + %jumpi(return_new) + DUP3 DUP2 GT // i > final_pos + %jumpi(panic) + + // stack: i, ctx, final_pos, retdest + %stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx) + MLOAD_GENERAL + // stack: opcode, i, ctx, final_pos, retdest + + DUP1 + // Slightly more efficient than `%eq_const(0x5b) ISZERO` + PUSH 0x5b + SUB + // stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest + %jumpi(continue_new) + + // stack: JUMPDEST, i, ctx, code_len, retdest + %stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx) + MSTORE_GENERAL + +continue_new: + // stack: opcode, i, ctx, code_len, retdest + %add_const(code_bytes_to_skip) + %mload_kernel_code + // stack: bytes_to_skip, i, ctx, code_len, retdest + ADD + // stack: i, ctx, code_len, retdest + %jump(loop_new) + +return_new: + // stack: i, ctx, code_len, retdest + %pop3 + JUMP + // Populates @SEGMENT_JUMPDEST_BITS for the given context's code. // Pre stack: ctx, code_len, retdest // Post stack: (empty) @@ -89,3 +134,140 @@ code_bytes_to_skip: %rep 128 BYTES 1 // 0x80-0xff %endrep + + +// A proof attesting that jumpdest is a valid jump destinations is +// either 0 or an index 0 < i <= jumpdest - 32. +// A proof is valid if: +// - i == 0 and we can go from the first opcode to jumpdest and code[jumpdest] = 0x5b +// - i > 0 and: +// - for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - j - i, +// - we can go from opcode i+32 to jumpdest, +// - code[jumpdest] = 0x5b. +// stack: proof_prefix_addr, jumpdest, retdest +// stack: (empty) abort if jumpdest is not a valid destination +global is_jumpdest: + GET_CONTEXT + // stack: ctx, proof_prefix_addr, jumpdest, retdest + %stack + (ctx, proof_prefix_addr, jumpdest) -> + (ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr) + MLOAD_GENERAL + // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest + + // Slightly more efficient than `%eq_const(0x5b) ISZERO` + PUSH 0x5b + SUB + %jumpi(panic) + + //stack: jumpdest, ctx, proof_prefix_addr, retdest + SWAP2 DUP1 + // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest + %eq_const(0) + %jumpi(verify_path) + //stack: proof_prefix_addr, ctx, jumpdest, retdest + // If we are here we need to check that the next 32 bytes are less + // than JUMPXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32, + // or larger than 127 + %check_and_step(127) %check_and_step(126) %check_and_step(125) %check_and_step(124) + %check_and_step(123) %check_and_step(122) %check_and_step(121) %check_and_step(120) + %check_and_step(119) %check_and_step(118) %check_and_step(117) %check_and_step(116) + %check_and_step(115) %check_and_step(114) %check_and_step(113) %check_and_step(112) + %check_and_step(111) %check_and_step(110) %check_and_step(109) %check_and_step(108) + %check_and_step(107) %check_and_step(106) %check_and_step(105) %check_and_step(104) + %check_and_step(103) %check_and_step(102) %check_and_step(101) %check_and_step(100) + %check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96) + + // check the remaining path + %jump(verify_path) + +return_is_jumpdest: + //stack: proof_prefix_addr, jumpdest, retdest + %pop2 + JUMP + + +// Chek if the opcode pointed by proof_prefix address is +// less than max and increment proof_prefix_addr +%macro check_and_step(max) + %stack + (proof_prefix_addr, ctx, jumpdest) -> + (ctx, @SEGMENT_CODE, proof_prefix_addr, proof_prefix_addr, ctx, jumpdest) + MLOAD_GENERAL + // stack: opcode, proof_prefix_addr, ctx, jumpdest + DUP1 + %gt_const(127) + %jumpi(%%ok) + %assert_lt_const($max) + // stack: proof_prefix_addr, ctx, jumpdest + PUSH 0 // We need something to pop +%%ok: + POP + %increment +%endmacro + +%macro is_jumpdest + %stack (proof, addr) -> (proof, addr, %%after) + %jump(is_jumpdest) +%%after: +%endmacro + +// Check if the jumpdest table is correct. This is done by +// non-deterministically guessing the sequence of jumpdest +// addresses used during program execution within the current context. +// For each jumpdest address we also non-deterministically guess +// a proof, which is another address in the code, such that +// is_jumpdest don't abort when the proof is on the top of the stack +// an the jumpdest address below. If that's the case we set the +// corresponding bit in @SEGMENT_JUMPDEST_BITS to 1. +// +// stack: retdest +// stack: (empty) +global validate_jumpdest_table: + // If address > 0 it is interpreted as address' = address - 1 + // and the next prover input should contain a proof for address'. + PROVER_INPUT(jumpdest_table::next_address) + DUP1 %jumpi(check_proof) + // If proof == 0 there are no more jump destionations to check + POP +global validate_jumpdest_table_end: + JUMP + // were set to 0 + //%mload_context_metadata(@CTX_METADATA_CODE_SIZE) + // get the code length in bytes + //%add_const(31) + //%div_const(32) + //GET_CONTEXT + //SWAP2 +//verify_chunk: + // stack: i (= proof), code_len, ctx = 0 + //%stack (i, code_len, ctx) -> (code_len, i, ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx) + //GT + //%jumpi(valid_table) + //%mload_packing + // stack: packed_bits, code_len, i, ctx + //%assert_eq_const(0) + //%increment + //%jump(verify_chunk) + +check_proof: + %sub_const(1) + DUP1 + // We read the proof + PROVER_INPUT(jumpdest_table::next_proof) + // stack: proof, address + %is_jumpdest + GET_CONTEXT + %stack (ctx, address) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address) + MSTORE_GENERAL + %jump(validate_jumpdest_table) +valid_table: + // stack: ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx, retdest + %pop7 + JUMP + +%macro validate_jumpdest_table + PUSH %%after + %jump(validate_jumpdest_table) +%%after: +%endmacro diff --git a/evm/src/generation/mod.rs b/evm/src/generation/mod.rs index d691d34e..2c3ee900 100644 --- a/evm/src/generation/mod.rs +++ b/evm/src/generation/mod.rs @@ -8,6 +8,7 @@ use ethereum_types::{Address, BigEndianHash, H256, U256}; use itertools::enumerate; use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; @@ -21,13 +22,15 @@ use crate::all_stark::{AllStark, NUM_TABLES}; use crate::config::StarkConfig; use crate::cpu::columns::CpuColumnsView; use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::assembler::Kernel; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::opcodes::get_opcode; use crate::generation::state::GenerationState; use crate::generation::trie_extractor::{get_receipt_trie, get_state_trie, get_txn_trie}; use crate::memory::segments::Segment; use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots}; use crate::prover::check_abort_signal; -use crate::util::{h2u, u256_to_usize}; +use crate::util::{h2u, u256_to_u8, u256_to_usize}; use crate::witness::memory::{MemoryAddress, MemoryChannel}; use crate::witness::transition::transition; @@ -38,7 +41,7 @@ pub(crate) mod state; mod trie_extractor; use self::mpt::{load_all_mpts, TrieRootPtrs}; -use crate::witness::util::mem_write_log; +use crate::witness::util::{mem_write_log, stack_peek}; /// Inputs needed for trace generation. #[derive(Clone, Debug, Deserialize, Serialize, Default)] @@ -296,9 +299,7 @@ pub fn generate_traces, const D: usize>( Ok((tables, public_values)) } -fn simulate_cpu, const D: usize>( - state: &mut GenerationState, -) -> anyhow::Result<()> { +fn simulate_cpu(state: &mut GenerationState) -> anyhow::Result<()> { let halt_pc = KERNEL.global_labels["halt"]; loop { @@ -333,3 +334,57 @@ fn simulate_cpu, const D: usize>( transition(state)?; } } + +fn simulate_cpu_between_labels_and_get_user_jumps( + initial_label: &str, + final_label: &str, + state: &mut GenerationState, +) -> anyhow::Result> { + let halt_pc = KERNEL.global_labels[final_label]; + let mut jumpdest_addresses = HashSet::new(); + state.registers.program_counter = KERNEL.global_labels[initial_label]; + let context = state.registers.context; + + loop { + if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { + state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"] + } + let pc = state.registers.program_counter; + let halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == context; + let opcode = u256_to_u8(state.memory.get(MemoryAddress { + context: state.registers.context, + segment: Segment::Code as usize, + virt: state.registers.program_counter, + })) + .map_err(|_| anyhow::Error::msg("Invalid opcode."))?; + let cond = if let Ok(cond) = stack_peek(state, 1) { + cond != U256::zero() + } else { + false + }; + if !state.registers.is_kernel + && (opcode == get_opcode("JUMP") || (opcode == get_opcode("JUMPI") && cond)) + { + // TODO: hotfix for avoiding deeper calls to abort + let jumpdest = u256_to_usize(state.registers.stack_top) + .map_err(|_| anyhow::Error::msg("Not a valid jump destination"))?; + state.memory.set( + MemoryAddress { + context: state.registers.context, + segment: Segment::JumpdestBits as usize, + virt: jumpdest, + }, + U256::one(), + ); + if (state.registers.context == context) { + jumpdest_addresses.insert(jumpdest); + } + } + if halt { + log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); + return Ok(jumpdest_addresses.into_iter().collect()); + } + + transition(state)?; + } +} diff --git a/evm/src/generation/prover_input.rs b/evm/src/generation/prover_input.rs index b2a8f0ce..852cd77d 100644 --- a/evm/src/generation/prover_input.rs +++ b/evm/src/generation/prover_input.rs @@ -1,24 +1,34 @@ +use std::cmp::min; +use std::collections::HashSet; use std::mem::transmute; use std::str::FromStr; use anyhow::{bail, Error}; use ethereum_types::{BigEndianHash, H256, U256, U512}; +use hashbrown::HashMap; use itertools::{enumerate, Itertools}; use num_bigint::BigUint; +use plonky2::field::extension::Extendable; use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; use serde::{Deserialize, Serialize}; +use crate::cpu::kernel::aggregator::KERNEL; +use crate::cpu::kernel::constants::context_metadata::ContextMetadata; +use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; +use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode}; use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254}; use crate::generation::prover_input::EvmField::{ Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar, }; use crate::generation::prover_input::FieldOp::{Inverse, Sqrt}; +use crate::generation::simulate_cpu_between_labels_and_get_user_jumps; use crate::generation::state::GenerationState; use crate::memory::segments::Segment; use crate::memory::segments::Segment::BnPairing; -use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_usize}; -use crate::witness::errors::ProgramError; +use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_u8, u256_to_usize}; use crate::witness::errors::ProverInputError::*; +use crate::witness::errors::{ProgramError, ProverInputError}; use crate::witness::memory::MemoryAddress; use crate::witness::util::{current_context_peek, stack_peek}; @@ -47,6 +57,7 @@ impl GenerationState { "bignum_modmul" => self.run_bignum_modmul(), "withdrawal" => self.run_withdrawal(), "num_bits" => self.run_num_bits(), + "jumpdest_table" => self.run_jumpdest_table(input_fn), _ => Err(ProgramError::ProverInputError(InvalidFunction)), } } @@ -229,6 +240,144 @@ impl GenerationState { Ok(num_bits.into()) } } + + fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result { + match input_fn.0[1].as_str() { + "next_address" => self.run_next_jumpdest_table_address(), + "next_proof" => self.run_next_jumpdest_table_proof(), + _ => Err(ProgramError::ProverInputError(InvalidInput)), + } + } + /// Return the next used jump addres + fn run_next_jumpdest_table_address(&mut self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + + if self.jumpdest_addresses.is_none() { + let mut state: GenerationState = self.soft_clone(); + + let mut jumpdest_addresses = vec![]; + // Generate the jumpdest table + let code = (0..code_len) + .map(|i| { + u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i, + })) + }) + .collect::, _>>()?; + let mut i = 0; + while i < code_len { + if code[i] == get_opcode("JUMPDEST") { + jumpdest_addresses.push(i); + state.memory.set( + MemoryAddress { + context: state.registers.context, + segment: Segment::JumpdestBits as usize, + virt: i, + }, + U256::one(), + ); + log::debug!("jumpdest at {i}"); + } + i += if code[i] >= get_push_opcode(1) && code[i] <= get_push_opcode(32) { + (code[i] - get_push_opcode(1) + 2).into() + } else { + 1 + } + } + + // We need to skip the validate table call + self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps( + "validate_jumpdest_table_end", + "terminate_common", + &mut state, + ) + .ok(); + log::debug!("code len = {code_len}"); + log::debug!("all jumpdest addresses = {:?}", jumpdest_addresses); + log::debug!("user's jumdest addresses = {:?}", self.jumpdest_addresses); + // self.jumpdest_addresses = Some(jumpdest_addresses); + } + + let Some(jumpdest_table) = &mut self.jumpdest_addresses else { + // TODO: Add another error + return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)); + }; + + if let Some(next_jumpdest_address) = jumpdest_table.pop() { + self.last_jumpdest_address = next_jumpdest_address; + Ok((next_jumpdest_address + 1).into()) + } else { + self.jumpdest_addresses = None; + Ok(U256::zero()) + } + } + + /// Return the proof for the last jump adddress + fn run_next_jumpdest_table_proof(&mut self) -> Result { + let code_len = u256_to_usize(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::ContextMetadata as usize, + virt: ContextMetadata::CodeSize as usize, + }))?; + + let mut address = MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: 0, + }; + let mut proof = 0; + let mut prefix_size = 0; + + // TODO: The proof searching algorithm is not very eficient. But luckyly it doesn't seem + // a problem because is done natively. + + // Search the closest address to last_jumpdest_address for which none of + // the previous 32 bytes in the code (including opcodes and pushed bytes) + // are PUSHXX and the address is in its range + while address.virt < self.last_jumpdest_address { + let opcode = u256_to_u8(self.memory.get(address))?; + let is_push = + opcode >= get_push_opcode(1).into() && opcode <= get_push_opcode(32).into(); + + address.virt += if is_push { + (opcode - get_push_opcode(1) + 2).into() + } else { + 1 + }; + // Check if the new address has a prefix of size >= 32 + let mut has_prefix = true; + for i in address.virt as i32 - 32..address.virt as i32 { + let opcode = u256_to_u8(self.memory.get(MemoryAddress { + context: self.registers.context, + segment: Segment::Code as usize, + virt: i as usize, + }))?; + if i < 0 + || (opcode >= get_push_opcode(1) + && opcode <= get_push_opcode(32) + && i + (opcode - get_push_opcode(1)) as i32 + 1 >= address.virt as i32) + { + has_prefix = false; + break; + } + } + if has_prefix { + proof = address.virt - 32; + } + } + if address.virt > self.last_jumpdest_address { + return Err(ProgramError::ProverInputError( + ProverInputError::InvalidJumpDestination, + )); + } + Ok(proof.into()) + } } enum EvmField { diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 89ff0c5a..79dd94fb 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -50,6 +50,9 @@ pub(crate) struct GenerationState { /// Pointers, within the `TrieData` segment, of the three MPTs. pub(crate) trie_root_ptrs: TrieRootPtrs, + + pub(crate) last_jumpdest_address: usize, + pub(crate) jumpdest_addresses: Option>, } impl GenerationState { @@ -91,6 +94,8 @@ impl GenerationState { txn_root_ptr: 0, receipt_root_ptr: 0, }, + last_jumpdest_address: 0, + jumpdest_addresses: None, }; let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); @@ -167,6 +172,27 @@ impl GenerationState { .map(|i| stack_peek(self, i).unwrap()) .collect() } + + /// Clone everything but the traces + pub(crate) fn soft_clone(&self) -> GenerationState { + Self { + inputs: self.inputs.clone(), + registers: self.registers.clone(), + memory: self.memory.clone(), + traces: Traces::default(), + rlp_prover_inputs: self.rlp_prover_inputs.clone(), + state_key_to_address: self.state_key_to_address.clone(), + bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(), + withdrawal_prover_inputs: self.withdrawal_prover_inputs.clone(), + trie_root_ptrs: TrieRootPtrs { + state_root_ptr: 0, + txn_root_ptr: 0, + receipt_root_ptr: 0, + }, + last_jumpdest_address: 0, + jumpdest_addresses: None, + } + } } /// Withdrawals prover input array is of the form `[addr0, amount0, ..., addrN, amountN, U256::MAX, U256::MAX]`. diff --git a/evm/src/util.rs b/evm/src/util.rs index 3d9564b5..bbbd8af1 100644 --- a/evm/src/util.rs +++ b/evm/src/util.rs @@ -70,6 +70,11 @@ pub(crate) fn u256_to_u64(u256: U256) -> Result<(F, F), ProgramError> )) } +/// Safe alternative to `U256::as_u8()`, which errors in case of overflow instead of panicking. +pub(crate) fn u256_to_u8(u256: U256) -> Result { + u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) +} + /// Safe alternative to `U256::as_usize()`, which errors in case of overflow instead of panicking. pub(crate) fn u256_to_usize(u256: U256) -> Result { u256.try_into().map_err(|_| ProgramError::IntegerTooLarge) diff --git a/evm/src/witness/errors.rs b/evm/src/witness/errors.rs index 5a0fcbfb..1b266aef 100644 --- a/evm/src/witness/errors.rs +++ b/evm/src/witness/errors.rs @@ -36,4 +36,6 @@ pub enum ProverInputError { InvalidInput, InvalidFunction, NumBitsError, + InvalidJumpDestination, + InvalidJumpdestSimulation, } diff --git a/evm/src/witness/transition.rs b/evm/src/witness/transition.rs index cf2e3bbe..0fa14321 100644 --- a/evm/src/witness/transition.rs +++ b/evm/src/witness/transition.rs @@ -395,7 +395,7 @@ fn try_perform_instruction( if state.registers.is_kernel { log_kernel_instruction(state, op); } else { - log::debug!("User instruction: {:?}", op); + log::debug!("User instruction: {:?} stack = {:?}", op, state.stack()); } fill_op_flag(op, &mut row);