Fix bug in jumpdest proof generation and check that jumpdest addr < code_len

This commit is contained in:
4l0n50 2023-12-28 14:04:23 +01:00
commit a85f9872f1
7 changed files with 60 additions and 42 deletions

View File

@ -367,9 +367,10 @@ call_too_deep:
%checkpoint // Checkpoint
%increment_call_depth
// Perform jumpdest analyis
%mload_context_metadata(@CTX_METADATA_CODE_SIZE)
GET_CONTEXT
// stack: ctx, code_size, retdest
%jumpdest_analisys
%jumpdest_analysis
PUSH 0 // jump dest
EXIT_KERNEL
// (Old context) stack: new_ctx

View File

@ -102,7 +102,7 @@ code_bytes_to_skip:
// - we can go from opcode i+32 to jumpdest,
// - code[jumpdest] = 0x5b.
// stack: proof_prefix_addr, jumpdest, ctx, retdest
// stack: (empty) abort if jumpdest is not a valid destination
// stack: (empty)
global write_table_if_jumpdest:
// stack: proof_prefix_addr, jumpdest, ctx, retdest
%stack
@ -129,7 +129,7 @@ global write_table_if_jumpdest:
%check_and_step(111) %check_and_step(110) %check_and_step(109) %check_and_step(108)
%check_and_step(107) %check_and_step(106) %check_and_step(105) %check_and_step(104)
%check_and_step(103) %check_and_step(102) %check_and_step(101) %check_and_step(100)
%check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96)
%check_and_step(99) %check_and_step(98) %check_and_step(97) %check_and_step(96)
// check the remaining path
%jump(verify_path_and_write_table)
@ -173,9 +173,9 @@ return:
// an the jumpdest address below. If that's the case we set the
// corresponding bit in @SEGMENT_JUMPDEST_BITS to 1.
//
// stack: ctx, retdest
// stack: ctx, code_len, retdest
// stack: (empty)
global jumpdest_analisys:
global jumpdest_analysis:
// If address > 0 then address is interpreted as address' + 1
// and the next prover input should contain a proof for address'.
PROVER_INPUT(jumpdest_table::next_address)
@ -185,22 +185,26 @@ global jumpdest_analisys:
// This is just a hook used for avoiding verification of the jumpdest
// table in another contexts. It is useful during proof generation,
// allowing the avoidance of table verification when simulating user code.
global jumpdest_analisys_end:
POP
global jumpdest_analysis_end:
%pop2
JUMP
check_proof:
// stack: proof, ctx, code_len, retdest
DUP3 DUP2 %assert_le
%decrement
// stack: proof, ctx, code_len, retdest
DUP2 SWAP1
// stack: address, ctx, ctx
// stack: address, ctx, ctx, code_len, retdest
// We read the proof
PROVER_INPUT(jumpdest_table::next_proof)
// stack: proof, address, ctx, ctx
// stack: proof, address, ctx, ctx, code_len, retdest
%write_table_if_jumpdest
// stack: ctx, code_len, retdest
%jump(jumpdest_analisys)
%jump(jumpdest_analysis)
%macro jumpdest_analisys
%stack (ctx) -> (ctx, %%after)
%jump(jumpdest_analisys)
%macro jumpdest_analysis
%stack (ctx, code_len) -> (ctx, code_len, %%after)
%jump(jumpdest_analysis)
%%after:
%endmacro

View File

@ -5,8 +5,8 @@ use crate::cpu::kernel::interpreter::Interpreter;
use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode};
#[test]
fn test_jumpdest_analisys() -> Result<()> {
let jumpdest_analisys = KERNEL.global_labels["jumpdest_analisys"];
fn test_jumpdest_analysis() -> Result<()> {
let jumpdest_analysis = KERNEL.global_labels["jumpdest_analysis"];
const CONTEXT: usize = 3; // arbitrary
let add = get_opcode("ADD");
@ -28,8 +28,8 @@ fn test_jumpdest_analisys() -> Result<()> {
let jumpdest_bits = vec![false, true, false, false, false, true, false, true];
// Contract creation transaction.
let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()];
let mut interpreter = Interpreter::new_with_kernel(jumpdest_analisys, initial_stack);
let initial_stack = vec![0xDEADBEEFu32.into(), code.len().into(), CONTEXT.into()];
let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack);
interpreter.set_code(CONTEXT, code);
interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits);

View File

@ -340,9 +340,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
initial_label: &str,
final_label: &str,
state: &mut GenerationState<F>,
) -> Result<Option<HashMap<usize, BTreeSet<usize>>>, ProgramError> {
) -> Option<HashMap<usize, BTreeSet<usize>>> {
if state.jumpdest_proofs.is_some() {
Ok(None)
None
} else {
const JUMP_OPCODE: u8 = 0x56;
const JUMPI_OPCODE: u8 = 0x57;
@ -358,19 +358,25 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
loop {
// skip jumdest table validations in simulations
if state.registers.program_counter == KERNEL.global_labels["jumpdest_analisys"] {
state.registers.program_counter = KERNEL.global_labels["jumpdest_analisys_end"]
if state.registers.program_counter == KERNEL.global_labels["jumpdest_analysis"] {
state.registers.program_counter = KERNEL.global_labels["jumpdest_analysis_end"]
}
let pc = state.registers.program_counter;
let context = state.registers.context;
let halt = state.registers.is_kernel
let mut halt = state.registers.is_kernel
&& pc == halt_pc
&& state.registers.context == initial_context;
let opcode = u256_to_u8(state.memory.get(MemoryAddress {
let Ok(opcode) = u256_to_u8(state.memory.get(MemoryAddress {
context,
segment: Segment::Code as usize,
virt: state.registers.program_counter,
}))?;
})) else {
log::debug!(
"Simulated CPU halted after {} cycles",
state.traces.clock() - initial_clock
);
return Some(jumpdest_addresses);
};
let cond = if let Ok(cond) = stack_peek(state, 1) {
cond != U256::zero()
} else {
@ -380,7 +386,13 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
&& (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond))
{
// Avoid deeper calls to abort
let jumpdest = u256_to_usize(state.registers.stack_top)?;
let Ok(jumpdest) = u256_to_usize(state.registers.stack_top) else {
log::debug!(
"Simulated CPU halted after {} cycles",
state.traces.clock() - initial_clock
);
return Some(jumpdest_addresses);
};
state.memory.set(
MemoryAddress {
context,
@ -400,16 +412,13 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
jumpdest_addresses.insert(context, BTreeSet::from([jumpdest]));
}
}
if halt {
if halt || transition(state).is_err() {
log::debug!(
"Simulated CPU halted after {} cycles",
state.traces.clock() - initial_clock
);
return Ok(Some(jumpdest_addresses));
return Some(jumpdest_addresses);
}
transition(state).map_err(|_| {
ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)
})?;
}
}
}

View File

@ -302,15 +302,15 @@ impl<F: Field> GenerationState<F> {
let code = self.get_current_code()?;
// We need to set the simulated jumpdest bits to one as otherwise
// the simulation will fail.
self.set_jumpdest_bits(&code);
// self.set_jumpdest_bits(&code);
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps(
"jumpdest_analisys_end",
"jumpdest_analysis_end",
"terminate_common",
self,
)?
else {
) else {
self.jumpdest_proofs = Some(HashMap::new());
return Ok(());
};
@ -318,7 +318,7 @@ impl<F: Field> GenerationState<F> {
self.rollback(checkpoint);
self.memory = memory;
// Find proofs for all context
// Find proofs for all contexts
self.set_proofs_and_jumpdests(jumpdest_table);
Ok(())
@ -368,9 +368,9 @@ impl<F: Field> GenerationState<F> {
Ok(code_len)
}
fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec<u8>) {
fn set_jumpdest_bits(&mut self, code: &[u8]) {
const JUMPDEST_OPCODE: u8 = 0x5b;
for (pos, opcode) in CodeIterator::new(&code) {
for (pos, opcode) in CodeIterator::new(code) {
if opcode == JUMPDEST_OPCODE {
self.memory.set(
MemoryAddress {
@ -385,13 +385,13 @@ impl<F: Field> GenerationState<F> {
}
}
/// For each address in `jumpdest_table`, each bounded by larges_address,
/// For all address in `jumpdest_table`, each bounded by larges_address,
/// this function searches for a proof. A proof is the closest address
/// for which none of the previous 32 bytes in the code (including opcodes
/// and pushed bytes are PUSHXX and the address is in its range. It returns
/// a vector of even size containing proofs followed by their addresses
fn get_proofs_and_jumpdests<'a>(
code: &'a Vec<u8>,
fn get_proofs_and_jumpdests(
code: &[u8],
largest_address: usize,
jumpdest_table: std::collections::BTreeSet<usize>,
) -> Vec<usize> {
@ -455,7 +455,7 @@ impl<'a> Iterator for CodeIterator<'a> {
fn next(&mut self) -> Option<Self::Item> {
const PUSH1_OPCODE: u8 = 0x60;
const PUSH32_OPCODE: u8 = 0x70;
const PUSH32_OPCODE: u8 = 0x7f;
let CodeIterator { code, pos, end } = self;
if *pos >= *end {
return None;

View File

@ -396,8 +396,9 @@ fn try_perform_instruction<F: Field>(
log_kernel_instruction(state, op);
} else {
log::debug!(
"User instruction: {:?}, ctx = {:?}, stack = {:?}",
"User instruction: {:?}, pc = {:?}, ctx = {:?}, stack = {:?}",
op,
state.registers.program_counter,
state.registers.context,
state.stack()
);

View File

@ -154,6 +154,9 @@ fn test_simple_transfer() -> anyhow::Result<()> {
},
};
let bytes = std::fs::read("jumpi_d18g0v0_Shanghai.json").unwrap();
let inputs = serde_json::from_slice(&bytes).unwrap();
let mut timing = TimingTree::new("prove", log::Level::Debug);
let proof = prove::<F, C, D>(&all_stark, &config, inputs, &mut timing, None)?;
timing.filter(Duration::from_millis(100)).print();