Merge pull request #1423 from topos-protocol/jumpdest_nd

Refactor jumpdest analysis
This commit is contained in:
Alonso González 2024-01-11 17:14:48 +01:00 committed by GitHub
commit a78a29a698
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 657 additions and 41 deletions

View File

@ -387,12 +387,10 @@ call_too_deep:
%checkpoint // Checkpoint
%increment_call_depth
// Perform jumpdest analysis
PUSH %%after
%mload_context_metadata(@CTX_METADATA_CODE_SIZE)
GET_CONTEXT
// stack: ctx, code_size, retdest
%jump(jumpdest_analysis)
%%after:
%jumpdest_analysis
PUSH 0 // jump dest
EXIT_KERNEL
// (Old context) stack: new_ctx

View File

@ -1,29 +1,26 @@
// Populates @SEGMENT_JUMPDEST_BITS for the given context's code.
// Pre stack: ctx, code_len, retdest
// Set @SEGMENT_JUMPDEST_BITS to one between positions [init_pos, final_pos],
// for the given context's code.
// Pre stack: init_pos, ctx, final_pos, retdest
// Post stack: (empty)
global jumpdest_analysis:
// stack: ctx, code_len, retdest
PUSH 0 // i = 0
global verify_path_and_write_jumpdest_table:
loop:
// stack: i, ctx, code_len, retdest
// Ideally we would break if i >= code_len, but checking i > code_len is
// cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is
// a no-op.
DUP3 DUP2 GT // i > code_len
%jumpi(return)
// stack: i, ctx, final_pos, retdest
DUP3 DUP2 EQ // i == final_pos
%jumpi(proof_ok)
DUP3 DUP2 GT // i > final_pos
%jumpi(proof_not_ok)
// stack: i, ctx, code_len, retdest
// stack: i, ctx, final_pos, retdest
%stack (i, ctx) -> (ctx, i, i, ctx)
ADD // combine context and offset to make an address (SEGMENT_CODE == 0)
MLOAD_GENERAL
// stack: opcode, i, ctx, code_len, retdest
// stack: opcode, i, ctx, final_pos, retdest
DUP1
// Slightly more efficient than `%eq_const(0x5b) ISZERO`
PUSH 0x5b
SUB
// stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest
// stack: opcode != JUMPDEST, opcode, i, ctx, final_pos, retdest
%jumpi(continue)
// stack: JUMPDEST, i, ctx, code_len, retdest
@ -34,16 +31,23 @@ loop:
MSTORE_GENERAL
continue:
// stack: opcode, i, ctx, code_len, retdest
// stack: opcode, i, ctx, final_pos, retdest
%add_const(code_bytes_to_skip)
%mload_kernel_code
// stack: bytes_to_skip, i, ctx, code_len, retdest
// stack: bytes_to_skip, i, ctx, final_pos, retdest
ADD
// stack: i, ctx, code_len, retdest
// stack: i, ctx, final_pos, retdest
%jump(loop)
return:
// stack: i, ctx, code_len, retdest
proof_ok:
// stack: i, ctx, final_pos, retdest
// We already know final_pos is a jumpdest
%stack (i, ctx, final_pos) -> (ctx, @SEGMENT_JUMPDEST_BITS, i)
%build_address
PUSH 1
MSTORE_GENERAL
JUMP
proof_not_ok:
%pop3
JUMP
@ -93,3 +97,237 @@ code_bytes_to_skip:
%rep 128
BYTES 1 // 0x80-0xff
%endrep
// A proof attesting that jumpdest is a valid jump destination is
// either 0 or an index 0 < i <= jumpdest - 32.
// A proof is valid if:
// - i == 0 and we can go from the first opcode to jumpdest and code[jumpdest] = 0x5b
// - i > 0 and:
// a) for j in {i+0,..., i+31} code[j] != PUSHk for all k >= 32 - (j - i),
// b) we can go from opcode i+32 to jumpdest,
// c) code[jumpdest] = 0x5b.
// To reduce the number of instructions, when i > 32 we load all the bytes code[j], ...,
// code[j + 31] in a single 32-byte word, and check a) directly on the packed bytes.
// We perform the "packed verification" computing a boolean formula evaluated on the bits of
// code[j],..., code[j+31] of the form p_1 AND p_2 AND p_3 AND p_4 AND p_5, where:
// - p_k is either TRUE, for one subset of the j's which depends on k (for example,
// for k = 1, it is TRUE for the first 15 positions), or has_prefix_k => bit_{k + 1}_is_0
// for the j's not in the subset.
// - has_prefix_k is a predicate that is TRUE if and only if code[j] has the same prefix of size k + 2
// as PUSH{32-(j-i)}.
// Pre stack: proof_prefix_addr, jumpdest, ctx, retdest
// Post stack: (empty)
// Sets the bit for `jumpdest` in @SEGMENT_JUMPDEST_BITS when code[jumpdest] == 0x5b and
// the supplied proof checks out; otherwise returns without writing anything.
// A proof of 0 means "walk the code from offset 0". A non-zero proof is the offset of a
// 32-byte window that is checked in packed form below, after which the walk starts at
// proof_prefix_addr + 32.
// Pre stack: proof_prefix_addr, jumpdest, ctx, retdest
// Post stack: (empty)
global write_table_if_jumpdest:
    // stack: proof_prefix_addr, jumpdest, ctx, retdest
    %stack
        (proof_prefix_addr, jumpdest, ctx) ->
        (ctx, jumpdest, jumpdest, ctx, proof_prefix_addr)
    ADD // combine context and offset to make an address (SEGMENT_CODE == 0)
    MLOAD_GENERAL
    // stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest

    %jump_neq_const(0x5b, return)

    // stack: jumpdest, ctx, proof_prefix_addr, retdest
    SWAP2 DUP1
    // stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest, retdest
    ISZERO
    %jumpi(verify_path_and_write_jumpdest_table)

    // stack: proof_prefix_addr, ctx, jumpdest, retdest
    // If we are here we need to check that the next 32 bytes are less
    // than PUSHXX for XX < 32 - i <=> opcode < 0x7f - i = 127 - i, 0 <= i < 32,
    // or larger than 127
    %stack
        (proof_prefix_addr, ctx) ->
        (ctx, proof_prefix_addr, 32, proof_prefix_addr, ctx)
    ADD // combine context and offset to make an address (SEGMENT_CODE == 0)
    %mload_packing
    // stack: packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP1 %shl_const(1)
    DUP2 %shl_const(2)
    AND
    // stack: (is_1_at_pos_2_and_3|(X))³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    // X denotes any value in {0,1} and Z^i is Z repeated i times
    NOT
    // stack: (is_0_at_2_or_3|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP2
    OR
    // stack: (is_1_at_1 or is_0_at_2_or_3|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    // stack: (~has_prefix|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest

    // Compute in_range =
    //   - (0xFF|X)³² for the first 15 bytes
    //   - (has_prefix => is_0_at_4 |X)³² for the next 15 bytes
    //   - (~has_prefix|X)³² for the last byte
    // Compute also ~has_prefix = ~has_prefix OR is_0_at_4 for all bytes. We don't need to update ~has_prefix
    // for the second half but it takes less cycles if we do it.
    DUP2 %shl_const(3)
    NOT
    // stack: (is_0_at_4|X)³², (~has_prefix|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    // pos    0102030405060708091011121314151617181920212223242526272829303132
    PUSH 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00
    AND
    // stack: (is_0_at_4|X)³¹|0, (~has_prefix|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP2
    DUP2
    OR
    // pos    0102030405060708091011121314151617181920212223242526272829303132
    PUSH 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0000000000000000000000000000000000
    OR
    // stack: (in_range|X)³², (is_0_at_4|X)³², (~has_prefix|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    SWAP2
    OR
    // stack: (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest

    // Compute in_range' = in_range AND
    //   - (0xFF|X)³² for bytes in positions 1-7 and 16-23
    //   - (has_prefix => is_0_at_5 |X)³² on the rest
    // Compute also ~has_prefix = ~has_prefix OR is_0_at_5 for all bytes.
    DUP3 %shl_const(4)
    NOT
    // stack: (is_0_at_5|X)³², (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP2
    DUP2
    OR
    // pos    0102030405060708091011121314151617181920212223242526272829303132
    PUSH 0xFFFFFFFFFFFFFF0000000000000000FFFFFFFFFFFFFFFF000000000000000000
    OR
    // stack: (in_range'|X)³², (is_0_at_5|X)³², (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    SWAP2
    OR
    // stack: (~has_prefix|X)³², (in_range'|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    SWAP2
    AND
    SWAP1

    // Compute in_range' = in_range AND
    //   - (0xFF|X)³² for bytes in positions 1-3, 8-11, 16-19, and 24-27
    //   - (has_prefix => is_0_at_6 |X)³² on the rest
    // Compute also that ~has_prefix = ~has_prefix OR is_0_at_6 for all bytes.
    // stack: (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP3 %shl_const(5)
    NOT
    // stack: (is_0_at_6|X)³², (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP2
    DUP2
    OR
    // pos    0102030405060708091011121314151617181920212223242526272829303132
    PUSH 0xFFFFFF00000000FFFFFFFF00000000FFFFFFFF00000000FFFFFFFF0000000000
    OR
    // stack: (in_range'|X)³², (is_0_at_6|X)³², (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    SWAP2
    OR
    // stack: (~has_prefix|X)³², (in_range'|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    SWAP2
    AND
    SWAP1

    // Compute in_range' = in_range AND
    //   - (0xFF|X)³² for bytes in 1, 4-5, 8-9, 12-13, 16-17, 20-21, 24-25, 28-29
    //   - (has_prefix => is_0_at_7 |X)³² on the rest
    // Compute also that ~has_prefix = ~has_prefix OR is_0_at_7 for all bytes.
    // stack: (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP3 %shl_const(6)
    NOT
    // stack: (is_0_at_7|X)³², (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP2
    DUP2
    OR
    // pos    0102030405060708091011121314151617181920212223242526272829303132
    PUSH 0xFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF0000FFFF000000
    OR
    // stack: (in_range'|X)³², (is_0_at_7|X)³², (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    SWAP2
    OR
    // stack: (~has_prefix|X)³², (in_range'|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    SWAP2
    AND
    SWAP1

    // Compute in_range' = in_range AND
    //   - (0xFF|X)³² for bytes in odd positions
    //   - (has_prefix => is_0_at_8 |X)³² on the rest
    // stack: (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    DUP3 %shl_const(7)
    NOT
    // stack: (is_0_at_8|X)³², (~has_prefix|X)³², (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    OR
    // pos    0102030405060708091011121314151617181920212223242526272829303132
    PUSH 0x00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF00FF
    OR
    AND
    // stack: (in_range|X)³², packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest

    // Get rid of the irrelevant bits
    // pos    0102030405060708091011121314151617181920212223242526272829303132
    PUSH 0x8080808080808080808080808080808080808080808080808080808080808080
    AND
    // On failure `packed_opcodes` is still on the stack, so it has to be dropped before the
    // common exit; jumping straight to `return` would leave `jumpdest` on top for the final JUMP.
    %jump_neq_const(0x8080808080808080808080808080808080808080808080808080808080808080, return_pop_opcode)
    // stack: packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    POP
    %add_const(32)
    // stack: proof_prefix_addr + 32, ctx, jumpdest, retdest

    // check the remaining path
    %jump(verify_path_and_write_jumpdest_table)
return_pop_opcode:
    // stack: packed_opcodes, proof_prefix_addr, ctx, jumpdest, retdest
    POP
return:
    // stack: proof_prefix_addr, ctx, jumpdest, retdest
    // or
    // stack: jumpdest, ctx, proof_prefix_addr, retdest
    %pop3
    JUMP
// Calls the `write_table_if_jumpdest` routine as a subroutine: the local label `%%after`
// is placed below the three arguments so that the routine returns right after the macro.
// Pre stack: proof_prefix_addr, jumpdest, ctx
// Post stack: (empty)
%macro write_table_if_jumpdest
%stack (proof_prefix_addr, jumpdest, ctx) -> (proof_prefix_addr, jumpdest, ctx, %%after)
%jump(write_table_if_jumpdest)
%%after:
%endmacro
// Write the jumpdest table. This is done by
// non-deterministically guessing the sequence of jumpdest
// addresses used during program execution within the current context.
// For each jumpdest address we also non-deterministically guess
// a proof, which is another address in the code such that
// `write_table_if_jumpdest` accepts it when the proof is at the top of the stack
// and the jumpdest address below. If that's the case the
// corresponding bit in @SEGMENT_JUMPDEST_BITS is set to 1.
//
// Pre stack: ctx, code_len, retdest
// Post stack: (empty)
global jumpdest_analysis:
// If address > 0 then address is interpreted as address' + 1
// and the next prover input should contain a proof for address'.
PROVER_INPUT(jumpdest_table::next_address)
// stack: address, ctx, code_len, retdest
DUP1 %jumpi(check_proof)
// If address == 0 there are no more jump destinations to check
POP
// stack: ctx, code_len, retdest
// This is just a hook used for avoiding verification of the jumpdest
// table in another context. It is useful during proof generation,
// allowing the avoidance of table verification when simulating user code.
global jumpdest_analysis_end:
%pop2
// stack: retdest
JUMP
check_proof:
// stack: address, ctx, code_len, retdest
// Here `address` is still address' + 1, so this asserts address' + 1 <= code_len.
DUP3 DUP2 %assert_le
%decrement
// stack: address, ctx, code_len, retdest (address is now address')
DUP2 SWAP1
// stack: address, ctx, ctx, code_len, retdest
// We read the proof
PROVER_INPUT(jumpdest_table::next_proof)
// stack: proof, address, ctx, ctx, code_len, retdest
%write_table_if_jumpdest
// stack: ctx, code_len, retdest
%jump(jumpdest_analysis)
// Calls the `jumpdest_analysis` routine as a subroutine, using the local label `%%after`
// as the return destination.
// Pre stack: ctx, code_len
// Post stack: (empty)
%macro jumpdest_analysis
%stack (ctx, code_len) -> (ctx, code_len, %%after)
%jump(jumpdest_analysis)
%%after:
%endmacro

View File

@ -8,6 +8,19 @@
jumpi
%endmacro
// Jump to `jumpdest` if the top of the stack is != c
// The top of the stack is consumed: `c - top` is computed and used as the jump condition,
// which is non-zero exactly when the two values differ.
%macro jump_neq_const(c, jumpdest)
// stack: x, ...
PUSH $c
// stack: c, x, ...
SUB
// stack: c - x, ...
%jumpi($jumpdest)
// stack: ...
%endmacro
// Jump to `jumpdest` if the top of the stack is < c
// The top of the stack is consumed by the comparison.
// NOTE(review): the previous body used `%ge_const`, which jumps on `top >= c` and so
// contradicted both the macro name and this description. `%lt_const` is assumed to be
// defined next to `%ge_const` -- confirm before merging.
%macro jumpi_lt_const(c, jumpdest)
    // stack: x, ...
    %lt_const($c)
    // stack: x < c, ...
    %jumpi($jumpdest)
    // stack: ...
%endmacro
%macro pop2
%rep 2
POP

View File

@ -1,7 +1,7 @@
//! An EVM interpreter for testing and debugging purposes.
use core::cmp::Ordering;
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::ops::Range;
use anyhow::bail;
@ -11,6 +11,7 @@ use keccak_hash::keccak;
use plonky2::field::goldilocks_field::GoldilocksField;
use super::assembler::BYTES_PER_OFFSET;
use super::utils::u256_from_bool;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
@ -585,7 +586,23 @@ impl<'a> Interpreter<'a> {
.collect()
}
fn incr(&mut self, n: usize) {
/// Overwrites the `JumpdestBits` segment of `context` with `jumpdest_bits` (one word per
/// flag, converted with `u256_from_bool`) and passes the offsets of the set flags to
/// `set_proofs_and_jumpdests` on the generation state, keyed by `context`.
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
    let segment = &mut self.generation_state.memory.contexts[context].segments
        [Segment::JumpdestBits.unscale()];
    segment.content = jumpdest_bits
        .iter()
        .map(|&flag| u256_from_bool(flag))
        .collect();

    // Offsets whose flag is set.
    let mut marked_offsets = BTreeSet::new();
    for (offset, is_set) in jumpdest_bits.into_iter().enumerate() {
        if is_set {
            marked_offsets.insert(offset);
        }
    }

    self.generation_state
        .set_proofs_and_jumpdests(HashMap::from([(context, marked_offsets)]));
}
pub(crate) fn incr(&mut self, n: usize) {
self.generation_state.registers.program_counter += n;
}

View File

@ -27,7 +27,7 @@ fn test_jumpdest_analysis() -> Result<()> {
jumpdest,
];
let expected_jumpdest_bits = vec![false, true, false, false, false, true, false, true];
let jumpdest_bits = vec![false, true, false, false, false, true, false, true];
// Contract creation transaction.
let initial_stack = vec![
@ -37,12 +37,10 @@ fn test_jumpdest_analysis() -> Result<()> {
];
let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack);
interpreter.set_code(CONTEXT, code);
interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits);
interpreter.run()?;
assert_eq!(interpreter.stack(), vec![]);
assert_eq!(
interpreter.get_jumpdest_bits(CONTEXT),
expected_jumpdest_bits
);
Ok(())
}

View File

@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
@ -8,6 +8,7 @@ use ethereum_types::{Address, BigEndianHash, H256, U256};
use itertools::enumerate;
use plonky2::field::extension::Extendable;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
@ -21,13 +22,16 @@ use crate::all_stark::{AllStark, NUM_TABLES};
use crate::config::StarkConfig;
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::assembler::Kernel;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::cpu::kernel::opcodes::get_opcode;
use crate::generation::state::GenerationState;
use crate::generation::trie_extractor::{get_receipt_trie, get_state_trie, get_txn_trie};
use crate::memory::segments::Segment;
use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots};
use crate::prover::check_abort_signal;
use crate::util::{h2u, u256_to_usize};
use crate::util::{h2u, u256_to_u8, u256_to_usize};
use crate::witness::errors::{ProgramError, ProverInputError};
use crate::witness::memory::{MemoryAddress, MemoryChannel};
use crate::witness::transition::transition;
@ -38,7 +42,7 @@ pub(crate) mod state;
mod trie_extractor;
use self::mpt::{load_all_mpts, TrieRootPtrs};
use crate::witness::util::mem_write_log;
use crate::witness::util::{mem_write_log, stack_peek};
/// Inputs needed for trace generation.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
@ -303,9 +307,7 @@ pub fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
Ok((tables, public_values))
}
fn simulate_cpu<F: RichField + Extendable<D>, const D: usize>(
state: &mut GenerationState<F>,
) -> anyhow::Result<()> {
fn simulate_cpu<F: Field>(state: &mut GenerationState<F>) -> anyhow::Result<()> {
let halt_pc = KERNEL.global_labels["halt"];
loop {
@ -340,3 +342,85 @@ fn simulate_cpu<F: RichField + Extendable<D>, const D: usize>(
transition(state)?;
}
}
/// Simulates the CPU starting at the kernel label `initial_label` and collects, per context,
/// the targets of every `JUMP` (and taken `JUMPI`) executed outside of kernel mode.
///
/// Returns `None` without simulating anything when `state.jumpdest_proofs` is already set.
/// Otherwise returns the collected targets once one of the following happens:
/// - the program counter reaches `final_label` in kernel mode within the initial context,
/// - the word at the current program counter does not fit in a `u8`,
/// - a jump target does not fit in a `usize`,
/// - `transition` returns an error.
///
/// The simulation mutates `state`: registers and memory advance, and every recorded target
/// has its bit set in the `JumpdestBits` segment. Restoring the previous state is left to the
/// caller.
///
/// # Panics
///
/// Panics if `initial_label`, `final_label`, `jumpdest_analysis` or `jumpdest_analysis_end`
/// is missing from `KERNEL.global_labels`.
fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
    initial_label: &str,
    final_label: &str,
    state: &mut GenerationState<F>,
) -> Option<HashMap<usize, BTreeSet<usize>>> {
    if state.jumpdest_proofs.is_some() {
        return None;
    }

    const JUMP_OPCODE: u8 = 0x56;
    const JUMPI_OPCODE: u8 = 0x57;

    let halt_pc = KERNEL.global_labels[final_label];
    // Loop-invariant label lookups, resolved once.
    let analysis_pc = KERNEL.global_labels["jumpdest_analysis"];
    let analysis_end_pc = KERNEL.global_labels["jumpdest_analysis_end"];

    let mut jumpdest_addresses: HashMap<usize, BTreeSet<usize>> = HashMap::new();

    state.registers.program_counter = KERNEL.global_labels[initial_label];
    let initial_clock = state.traces.clock();
    let initial_context = state.registers.context;

    log::debug!("Simulating CPU for jumpdest analysis.");

    loop {
        // skip jumpdest table validations in simulations
        if state.registers.program_counter == analysis_pc {
            state.registers.program_counter = analysis_end_pc;
        }

        let pc = state.registers.program_counter;
        let context = state.registers.context;
        let halt = state.registers.is_kernel && pc == halt_pc && context == initial_context;

        let Ok(opcode) = u256_to_u8(state.memory.get(MemoryAddress::new(
            context,
            Segment::Code,
            pc,
        ))) else {
            break;
        };

        // Second stack item, i.e. the condition of a potential JUMPI.
        let cond = stack_peek(state, 1).map_or(false, |value| value != U256::zero());

        if !state.registers.is_kernel
            && (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond))
        {
            // Avoid deeper calls to abort
            let Ok(jumpdest) = u256_to_usize(state.registers.stack_top) else {
                break;
            };
            // Mark the target as valid so that the simulated jump itself goes through.
            state.memory.set(
                MemoryAddress::new(context, Segment::JumpdestBits, jumpdest),
                U256::one(),
            );
            jumpdest_addresses
                .entry(context)
                .or_default()
                .insert(jumpdest);
        }

        if halt || transition(state).is_err() {
            break;
        }
    }

    log::debug!(
        "Simulated CPU for jumpdest analysis halted after {} cycles",
        state.traces.clock() - initial_clock
    );
    Some(jumpdest_addresses)
}

View File

@ -1,3 +1,5 @@
use std::cmp::min;
use std::collections::HashMap;
use std::mem::transmute;
use std::str::FromStr;
@ -5,20 +7,26 @@ use anyhow::{bail, Error};
use ethereum_types::{BigEndianHash, H256, U256, U512};
use itertools::{enumerate, Itertools};
use num_bigint::BigUint;
use plonky2::field::extension::Extendable;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use serde::{Deserialize, Serialize};
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254};
use crate::generation::prover_input::EvmField::{
Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar,
};
use crate::generation::prover_input::FieldOp::{Inverse, Sqrt};
use crate::generation::simulate_cpu_between_labels_and_get_user_jumps;
use crate::generation::state::GenerationState;
use crate::memory::segments::Segment;
use crate::memory::segments::Segment::BnPairing;
use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_usize};
use crate::witness::errors::ProgramError;
use crate::util::{biguint_to_mem_vec, mem_vec_to_biguint, u256_to_u8, u256_to_usize};
use crate::witness::errors::ProverInputError::*;
use crate::witness::errors::{ProgramError, ProverInputError};
use crate::witness::memory::MemoryAddress;
use crate::witness::operation::CONTEXT_SCALING_FACTOR;
use crate::witness::util::{current_context_peek, stack_peek};
@ -48,6 +56,7 @@ impl<F: Field> GenerationState<F> {
"bignum_modmul" => self.run_bignum_modmul(),
"withdrawal" => self.run_withdrawal(),
"num_bits" => self.run_num_bits(),
"jumpdest_table" => self.run_jumpdest_table(input_fn),
_ => Err(ProgramError::ProverInputError(InvalidFunction)),
}
}
@ -230,6 +239,236 @@ impl<F: Field> GenerationState<F> {
Ok(num_bits.into())
}
}
/// Generate either the next used jump address or the proof for the last jump address.
///
/// The request is selected by the second component of `input_fn`: `"next_address"` or
/// `"next_proof"`.
///
/// # Errors
///
/// Returns `ProverInputError::InvalidInput` when the second component is missing or is not
/// one of the two names above, and forwards any error of the selected handler.
fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result<U256, ProgramError> {
    // `get` instead of indexing: a missing component is reported as an error, not a panic.
    match input_fn.0.get(1).map(|name| name.as_str()) {
        Some("next_address") => self.run_next_jumpdest_table_address(),
        Some("next_proof") => self.run_next_jumpdest_table_proof(),
        _ => Err(ProgramError::ProverInputError(InvalidInput)),
    }
}
/// Returns the next used jump address.
///
/// The address is returned incremented by one, so that `0` can signal that the current
/// context has no jump destinations left. When that happens the stored table is dropped
/// (`jumpdest_proofs` becomes `None`), which makes the next call generate it again.
///
/// # Errors
///
/// Fails if the code length of the current context cannot be read, if generating the table
/// fails, or with `InvalidJumpdestSimulation` if no table is available after generation.
fn run_next_jumpdest_table_address(&mut self) -> Result<U256, ProgramError> {
    // The length itself is not needed here; the read is kept so that an unreadable code
    // size is reported as an error before anything else happens.
    self.get_code_len()?;

    if self.jumpdest_proofs.is_none() {
        self.generate_jumpdest_proofs()?;
    }

    let context = self.registers.context;
    let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
        return Err(ProgramError::ProverInputError(
            ProverInputError::InvalidJumpdestSimulation,
        ));
    };

    let next_jumpdest_address = jumpdest_proofs
        .get_mut(&context)
        .and_then(|ctx_jumpdest_proofs| ctx_jumpdest_proofs.pop());

    match next_jumpdest_address {
        Some(address) => Ok((address + 1).into()),
        None => {
            self.jumpdest_proofs = None;
            Ok(U256::zero())
        }
    }
}
/// Returns the proof for the last jump address.
///
/// The proof is popped from the entry stored for the current context.
///
/// # Errors
///
/// Returns `InvalidJumpdestSimulation` when there is no table, no entry for the current
/// context, or the entry is empty.
fn run_next_jumpdest_table_proof(&mut self) -> Result<U256, ProgramError> {
    let context = self.registers.context;

    self.jumpdest_proofs
        .as_mut()
        .and_then(|proofs_by_context| proofs_by_context.get_mut(&context))
        .and_then(|pending| pending.pop())
        .map(Into::into)
        .ok_or(ProgramError::ProverInputError(
            ProverInputError::InvalidJumpdestSimulation,
        ))
}
}
impl<F: Field> GenerationState<F> {
    /// Simulate the user's code and store all the jump addresses with their respective contexts.
    ///
    /// The state (checkpoint and memory) is saved before the simulation and restored after it,
    /// then the proofs for the collected jump destinations are computed and stored in
    /// `jumpdest_proofs`.
    ///
    /// # Errors
    ///
    /// Fails if the code of the current context cannot be read.
    fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> {
        let checkpoint = self.checkpoint();
        let memory = self.memory.clone();

        // The bytes are not needed here; the read is kept so that an unreadable code segment
        // is reported as an error before the simulation starts.
        self.get_current_code()?;

        // Simulate the user's code and (unnecessarily) part of the kernel code, skipping the
        // validate table call. The simulation sets the jumpdest bit of every jump target it
        // meets, which is why the memory saved above is restored afterwards.
        let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps(
            "jumpdest_analysis_end",
            "terminate_common",
            self,
        ) else {
            // No simulation took place: the helper declines only when a table already exists.
            self.jumpdest_proofs = Some(HashMap::new());
            return Ok(());
        };

        // Return to the state before starting the simulation
        self.rollback(checkpoint);
        self.memory = memory;

        // Find proofs for all contexts
        self.set_proofs_and_jumpdests(jumpdest_table);

        Ok(())
    }

    /// Given a HashMap containing the contexts and the jumpdest addresses, compute their respective proofs,
    /// by calling `get_proofs_and_jumpdests`
    ///
    /// # Panics
    ///
    /// Panics if the code of one of the contexts cannot be read.
    pub(crate) fn set_proofs_and_jumpdests(
        &mut self,
        jumpdest_table: HashMap<usize, std::collections::BTreeSet<usize>>,
    ) {
        self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map(
            |(ctx, jumpdest_table)| {
                let code = self.get_code(ctx).unwrap();
                if let Some(&largest_address) = jumpdest_table.last() {
                    let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table);
                    (ctx, proofs)
                } else {
                    (ctx, vec![])
                }
            },
        )));
    }

    /// Returns the code of the context currently held in the registers.
    fn get_current_code(&self) -> Result<Vec<u8>, ProgramError> {
        self.get_code(self.registers.context)
    }

    /// Returns the code of `context`, read byte by byte from its `Code` segment.
    ///
    /// Both the code size and the code bytes are taken from `context` itself. Previously the
    /// parameter was ignored and the current context was always read, so proofs requested
    /// for another context were computed against the wrong code.
    fn get_code(&self, context: usize) -> Result<Vec<u8>, ProgramError> {
        let code_len = self.get_context_code_len(context)?;
        let code = (0..code_len)
            .map(|i| {
                u256_to_u8(
                    self.memory
                        .get(MemoryAddress::new(context, Segment::Code, i)),
                )
            })
            .collect::<Result<Vec<u8>, _>>()?;
        Ok(code)
    }

    /// Returns the code size stored in the metadata of the current context.
    fn get_code_len(&self) -> Result<usize, ProgramError> {
        self.get_context_code_len(self.registers.context)
    }

    /// Returns the code size stored in the `ContextMetadata` segment of `context`.
    fn get_context_code_len(&self, context: usize) -> Result<usize, ProgramError> {
        let code_len = u256_to_usize(self.memory.get(MemoryAddress::new(
            context,
            Segment::ContextMetadata,
            ContextMetadata::CodeSize.unscale(),
        )))?;
        Ok(code_len)
    }

    /// Sets to one the `JumpdestBits` entry of the current context for every `JUMPDEST`
    /// (0x5b) that `CodeIterator` yields as an opcode of `code`.
    fn set_jumpdest_bits(&mut self, code: &[u8]) {
        const JUMPDEST_OPCODE: u8 = 0x5b;
        for (pos, opcode) in CodeIterator::new(code) {
            if opcode == JUMPDEST_OPCODE {
                self.memory.set(
                    MemoryAddress::new(self.registers.context, Segment::JumpdestBits, pos),
                    U256::one(),
                );
            }
        }
    }
}
/// Computes a proof for every address in `jumpdest_table` that is the position of an opcode
/// in `code`, scanning opcodes up to and including `largest_address`.
///
/// While walking the opcodes, the function remembers the latest window start `pos - 32`
/// (with `pos` an opcode position) such that no `PUSH1..=PUSH32` byte inside
/// `code[pos - 32..pos]` has immediate data reaching `pos` or beyond. The proof attached to
/// a jump destination is the latest such window start seen so far, or `0` if there is none.
///
/// Returns a vector of even size laid out as `[proof_0, address_0, proof_1, address_1, ...]`
/// in increasing address order. Addresses of `jumpdest_table` that are not opcode positions
/// (for instance because they fall inside push data) are left out.
fn get_proofs_and_jumpdests(
    code: &[u8],
    largest_address: usize,
    jumpdest_table: std::collections::BTreeSet<usize>,
) -> Vec<usize> {
    const PUSH1_OPCODE: u8 = 0x60;
    const PUSH32_OPCODE: u8 = 0x7f;

    let mut proofs = Vec::with_capacity(2 * jumpdest_table.len());
    let mut latest_proof = 0;

    for (pos, _opcode) in CodeIterator::until(code, largest_address.saturating_add(1)) {
        if let Some(window_start) = pos.checked_sub(32) {
            // Only PUSH bytes can "cover" later positions; every other byte is harmless.
            let window_is_clean = code[window_start..pos]
                .iter()
                .enumerate()
                .all(|(offset, &byte)| match byte {
                    PUSH1_OPCODE..=PUSH32_OPCODE => {
                        let pushed_bytes = usize::from(byte - PUSH1_OPCODE) + 1;
                        window_start + offset + pushed_bytes < pos
                    }
                    _ => true,
                });
            if window_is_clean {
                latest_proof = window_start;
            }
        }

        if jumpdest_table.contains(&pos) {
            // Push the proof
            proofs.push(latest_proof);
            // Push the address
            proofs.push(pos);
        }
    }

    proofs
}
/// Walks EVM bytecode opcode by opcode, yielding `(position, opcode)` pairs and jumping
/// over the immediate bytes that follow a `PUSH1..=PUSH32`.
struct CodeIterator<'a> {
    /// The bytecode being walked.
    code: &'a [u8],
    /// Position of the next opcode to yield.
    pos: usize,
    /// Exclusive upper bound on yielded positions; never larger than `code.len()`.
    end: usize,
}

impl<'a> CodeIterator<'a> {
    /// Iterates over the whole of `code`.
    fn new(code: &'a [u8]) -> Self {
        Self {
            code,
            pos: 0,
            end: code.len(),
        }
    }

    /// Iterates over the opcodes located before `end`, clamped to the length of `code`.
    fn until(code: &'a [u8], end: usize) -> Self {
        Self {
            code,
            pos: 0,
            end: end.min(code.len()),
        }
    }
}

impl<'a> Iterator for CodeIterator<'a> {
    type Item = (usize, u8);

    fn next(&mut self) -> Option<Self::Item> {
        const PUSH1_OPCODE: u8 = 0x60;
        const PUSH32_OPCODE: u8 = 0x7f;

        if self.pos >= self.end {
            return None;
        }

        let current = self.pos;
        let opcode = self.code[current];
        // A PUSHn occupies its own byte plus n immediate bytes.
        let width = match opcode {
            PUSH1_OPCODE..=PUSH32_OPCODE => usize::from(opcode - PUSH1_OPCODE) + 2,
            _ => 1,
        };
        self.pos = current + width;

        Some((current, opcode))
    }
}
enum EvmField {

View File

@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{BTreeSet, HashMap};
use ethereum_types::{Address, BigEndianHash, H160, H256, U256};
use keccak_hash::keccak;
@ -50,6 +50,12 @@ pub(crate) struct GenerationState<F: Field> {
/// Pointers, within the `TrieData` segment, of the three MPTs.
pub(crate) trie_root_ptrs: TrieRootPtrs,
/// A hash map where the key is a context in the user's code and the value is the set of
/// jump destinations with its corresponding "proof". A "proof" for a jump destination is
/// either 0 or an address i > 32 in the code (not necessarily pointing to an opcode) such that
/// for every j in [i, i+32] it holds that code[j] < 0x7f - j + i.
pub(crate) jumpdest_proofs: Option<HashMap<usize, Vec<usize>>>,
}
impl<F: Field> GenerationState<F> {
@ -91,6 +97,7 @@ impl<F: Field> GenerationState<F> {
txn_root_ptr: 0,
receipt_root_ptr: 0,
},
jumpdest_proofs: None,
};
let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries);
@ -165,6 +172,26 @@ impl<F: Field> GenerationState<F> {
.map(|i| stack_peek(self, i).unwrap())
.collect()
}
/// Builds a copy of this state that shares none of its traces.
///
/// `inputs`, `registers`, `memory`, `rlp_prover_inputs`, `state_key_to_address`,
/// `bignum_modmul_result_limbs` and `withdrawal_prover_inputs` are copied from `self`.
/// The remaining fields are reset rather than copied: `traces` starts from its default value,
/// all three `trie_root_ptrs` are set to zero and `jumpdest_proofs` is `None`.
pub(crate) fn soft_clone(&self) -> GenerationState<F> {
Self {
inputs: self.inputs.clone(),
registers: self.registers,
memory: self.memory.clone(),
// Fresh traces: nothing recorded so far is carried over.
traces: Traces::default(),
rlp_prover_inputs: self.rlp_prover_inputs.clone(),
state_key_to_address: self.state_key_to_address.clone(),
bignum_modmul_result_limbs: self.bignum_modmul_result_limbs.clone(),
withdrawal_prover_inputs: self.withdrawal_prover_inputs.clone(),
// NOTE(review): the pointers are zeroed instead of copied from `self` -- confirm that
// callers of `soft_clone` do not rely on them.
trie_root_ptrs: TrieRootPtrs {
state_root_ptr: 0,
txn_root_ptr: 0,
receipt_root_ptr: 0,
},
jumpdest_proofs: None,
}
}
}
/// Withdrawals prover input array is of the form `[addr0, amount0, ..., addrN, amountN, U256::MAX, U256::MAX]`.

View File

@ -255,7 +255,7 @@ pub(crate) fn get_trie_helper<N: PartialTrie>(
) -> Result<Node<N>, ProgramError> {
let load = |offset| memory.get(MemoryAddress::new(0, Segment::TrieData, offset));
let load_slice_from = |init_offset| {
&memory.contexts[0].segments[Segment::TrieData as usize].content[init_offset..]
&memory.contexts[0].segments[Segment::TrieData.unscale()].content[init_offset..]
};
let trie_type = PartialTrieType::all()[u256_to_usize(load(ptr))?];

View File

@ -36,4 +36,6 @@ pub enum ProverInputError {
InvalidInput,
InvalidFunction,
NumBitsError,
InvalidJumpDestination,
InvalidJumpdestSimulation,
}

View File

@ -77,7 +77,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> {
// Initialize the preprocessed circuits for the zkEVM.
let all_circuits = AllRecursiveCircuits::<F, C, D>::new(
&all_stark,
&[16..17, 10..11, 12..13, 14..15, 9..11, 12..13, 17..18], // Minimal ranges to prove an empty list
&[16..17, 9..11, 12..13, 14..15, 9..11, 12..13, 17..18], // Minimal ranges to prove an empty list
&config,
);