Refactor run_next_jumpdest_table_proof

4l0n50 2023-12-13 14:11:43 +01:00
parent 9e39d88ab8
commit ff3dc2e516
7 changed files with 200 additions and 208 deletions


@@ -367,13 +367,10 @@ call_too_deep:
%checkpoint // Checkpoint
%increment_call_depth
// Perform jumpdest analysis
// PUSH %%after
// %mload_context_metadata(@CTX_METADATA_CODE_SIZE)
// GET_CONTEXT
GET_CONTEXT
// stack: ctx, code_size, retdest
// %jump(jumpdest_analysis)
%validate_jumpdest_table
%%after:
PUSH 0 // jump dest
EXIT_KERNEL
// (Old context) stack: new_ctx


@@ -3,13 +3,13 @@
// Pre stack: init_pos, ctx, final_pos, retdest
// Post stack: (empty)
global verify_path:
loop_new:
loop:
// stack: i, ctx, final_pos, retdest
// Ideally we would break if i >= final_pos, but checking i > final_pos is
// cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is
// a no-op.
DUP3 DUP2 EQ // i == final_pos
%jumpi(return_new)
%jumpi(return)
DUP3 DUP2 GT // i > final_pos
%jumpi(panic)
@@ -18,51 +18,6 @@ loop_new:
MLOAD_GENERAL
// stack: opcode, i, ctx, final_pos, retdest
DUP1
// Slightly more efficient than `%eq_const(0x5b) ISZERO`
PUSH 0x5b
SUB
// stack: opcode != JUMPDEST, opcode, i, ctx, code_len, retdest
%jumpi(continue_new)
// stack: JUMPDEST, i, ctx, code_len, retdest
%stack (JUMPDEST, i, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, i, JUMPDEST, i, ctx)
MSTORE_GENERAL
continue_new:
// stack: opcode, i, ctx, code_len, retdest
%add_const(code_bytes_to_skip)
%mload_kernel_code
// stack: bytes_to_skip, i, ctx, code_len, retdest
ADD
// stack: i, ctx, code_len, retdest
%jump(loop_new)
return_new:
// stack: i, ctx, code_len, retdest
%pop3
JUMP
// Populates @SEGMENT_JUMPDEST_BITS for the given context's code.
// Pre stack: ctx, code_len, retdest
// Post stack: (empty)
global jumpdest_analysis:
// stack: ctx, code_len, retdest
PUSH 0 // i = 0
loop:
// stack: i, ctx, code_len, retdest
// Ideally we would break if i >= code_len, but checking i > code_len is
// cheaper. It doesn't hurt to over-read by 1, since we'll read 0 which is
// a no-op.
DUP3 DUP2 GT // i > code_len
%jumpi(return)
// stack: i, ctx, code_len, retdest
%stack (i, ctx) -> (ctx, @SEGMENT_CODE, i, i, ctx)
MLOAD_GENERAL
// stack: opcode, i, ctx, code_len, retdest
DUP1
// Slightly more efficient than `%eq_const(0x5b) ISZERO`
PUSH 0x5b
@@ -144,13 +99,17 @@ code_bytes_to_skip:
// - for j in {i+0, ..., i+31}, code[j] != PUSHk for all k >= 32 - (j - i),
// - we can go from opcode i+32 to jumpdest,
// - code[jumpdest] = 0x5b.
// stack: proof_prefix_addr, jumpdest, retdest
// stack: proof_prefix_addr, jumpdest, ctx, retdest
// stack: (empty); aborts if jumpdest is not a valid destination
global is_jumpdest:
GET_CONTEXT
// stack: ctx, proof_prefix_addr, jumpdest, retdest
// stack: proof_prefix_addr, jumpdest, ctx, retdest
//%stack
// (proof_prefix_addr, jumpdest, ctx) ->
// (ctx, @SEGMENT_JUMPDEST_BITS, jumpdest, proof_prefix_addr, jumpdest, ctx)
//MLOAD_GENERAL
//%jumpi(return_is_jumpdest)
%stack
(ctx, proof_prefix_addr, jumpdest) ->
(proof_prefix_addr, jumpdest, ctx) ->
(ctx, @SEGMENT_CODE, jumpdest, jumpdest, ctx, proof_prefix_addr)
MLOAD_GENERAL
// stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest
@@ -182,8 +141,8 @@ global is_jumpdest:
%jump(verify_path)
return_is_jumpdest:
//stack: proof_prefix_addr, jumpdest, retdest
%pop2
//stack: proof_prefix_addr, jumpdest, ctx, retdest
%pop3
JUMP
@@ -194,7 +153,7 @@ return_is_jumpdest:
(proof_prefix_addr, ctx, jumpdest) ->
(ctx, @SEGMENT_CODE, proof_prefix_addr, proof_prefix_addr, ctx, jumpdest)
MLOAD_GENERAL
// stack: opcode, proof_prefix_addr, ctx, jumpdest
// stack: opcode, ctx, proof_prefix_addr, jumpdest
DUP1
%gt_const(127)
%jumpi(%%ok)
@@ -207,7 +166,7 @@ return_is_jumpdest:
%endmacro
%macro is_jumpdest
%stack (proof, addr) -> (proof, addr, %%after)
%stack (proof, addr, ctx) -> (proof, addr, ctx, %%after)
%jump(is_jumpdest)
%%after:
%endmacro
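The comment in the hunk above lists the conditions a (proof_prefix_addr, jumpdest) pair must satisfy for is_jumpdest to accept it. A minimal native restatement of those conditions, as a Rust sketch only (the function name and the slice-based bounds handling are assumptions, not the kernel routine; 0x60..0x7f are the standard PUSH1..PUSH32 opcodes and 0x5b is JUMPDEST):

// Sketch of the proof conditions above, checked natively. Not the kernel routine.
fn is_valid_jumpdest_proof(code: &[u8], proof_prefix_addr: usize, jumpdest: usize) -> bool {
    const PUSH1: u8 = 0x60;
    const PUSH32: u8 = 0x7f;
    const JUMPDEST: u8 = 0x5b;
    let i = proof_prefix_addr;
    // code[jumpdest] must actually be a JUMPDEST, and the 32-byte prefix must fit before it.
    if i + 32 > jumpdest || code.get(jumpdest) != Some(&JUMPDEST) {
        return false;
    }
    // For j in {i, ..., i+31}: code[j] must not be a PUSHk whose pushed bytes reach i+32.
    for j in i..i + 32 {
        let b = code[j];
        if (PUSH1..=PUSH32).contains(&b) && j + (b - PUSH1 + 1) as usize >= i + 32 {
            return false;
        }
    }
    // We can then walk from the opcode at i+32 to the jumpdest, skipping PUSH immediates.
    let mut pc = i + 32;
    while pc < jumpdest {
        let b = code[pc];
        pc += if (PUSH1..=PUSH32).contains(&b) { (b - PUSH1 + 2) as usize } else { 1 };
    }
    pc == jumpdest
}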
@@ -216,58 +175,41 @@ return_is_jumpdest:
// non-deterministically guessing the sequence of jumpdest
// addresses used during program execution within the current context.
// For each jumpdest address we also non-deterministically guess
// a proof, which is another address in the code, such that
// is_jumpdest don't abort when the proof is on the top of the stack
// a proof, which is another address in the code such that
// is_jumpdest does not abort when the proof is at the top of the stack
// and the jumpdest address is just below it. If that's the case, we set the
// corresponding bit in @SEGMENT_JUMPDEST_BITS to 1.
//
// stack: retdest
// stack: ctx, retdest
// stack: (empty)
global validate_jumpdest_table:
// If address > 0 it is interpreted as address' = address - 1
// If address > 0 then address is interpreted as address' + 1
// and the next prover input should contain a proof for address'.
PROVER_INPUT(jumpdest_table::next_address)
DUP1 %jumpi(check_proof)
// If address == 0 there are no more jump destinations to check
POP
// This is just a hook used to avoid verifying the jumpdest
// table in other contexts. It is useful during proof generation,
// allowing table verification to be skipped when simulating user code.
global validate_jumpdest_table_end:
POP
JUMP
// were set to 0
//%mload_context_metadata(@CTX_METADATA_CODE_SIZE)
// get the code length in bytes
//%add_const(31)
//%div_const(32)
//GET_CONTEXT
//SWAP2
//verify_chunk:
// stack: i (= proof), code_len, ctx = 0
//%stack (i, code_len, ctx) -> (code_len, i, ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx)
//GT
//%jumpi(valid_table)
//%mload_packing
// stack: packed_bits, code_len, i, ctx
//%assert_eq_const(0)
//%increment
//%jump(verify_chunk)
check_proof:
%sub_const(1)
DUP1
DUP2 DUP2
// stack: address, ctx, address, ctx
// We read the proof
PROVER_INPUT(jumpdest_table::next_proof)
// stack: proof, address
// stack: proof, address, ctx, address, ctx
%is_jumpdest
GET_CONTEXT
%stack (ctx, address) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address)
%stack (address, ctx) -> (1, ctx, @SEGMENT_JUMPDEST_BITS, address, ctx)
MSTORE_GENERAL
%jump(validate_jumpdest_table)
valid_table:
// stack: ctx, @SEGMENT_JUMPDEST_BITS, i, 32, i, code_len, ctx, retdest
%pop7
JUMP
%macro validate_jumpdest_table
PUSH %%after
%stack (ctx) -> (ctx, %%after)
%jump(validate_jumpdest_table)
%%after:
%endmacro
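For orientation, here is a hedged host-side sketch of the prover-input protocol the loop above consumes: the table is a non-deterministically guessed list of (jumpdest, proof) pairs for the current context, jumpdest_table::next_address is served as address + 1 so that 0 can terminate the table, and each nonzero address is followed by a jumpdest_table::next_proof input. The struct and method names below are hypothetical, not the actual run_next_jumpdest_table_proof implementation.

use std::collections::VecDeque;

// Hypothetical host-side feed for the PROVER_INPUT calls in validate_jumpdest_table.
struct JumpdestTableFeed {
    // Guessed (jumpdest, proof) pairs for the current context.
    pending: VecDeque<(usize, usize)>,
    // Proof to serve on the next `next_proof` query.
    proof: Option<usize>,
}

impl JumpdestTableFeed {
    // PROVER_INPUT(jumpdest_table::next_address): address' + 1, or 0 once the table is exhausted.
    fn next_address(&mut self) -> usize {
        match self.pending.pop_front() {
            Some((jumpdest, proof)) => {
                self.proof = Some(proof);
                jumpdest + 1
            }
            None => 0,
        }
    }

    // PROVER_INPUT(jumpdest_table::next_proof): the proof for the address just served.
    fn next_proof(&mut self) -> usize {
        self.proof.take().expect("next_address must be served first")
    }
}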


@@ -10,6 +10,7 @@ use keccak_hash::keccak;
use plonky2::field::goldilocks_field::GoldilocksField;
use super::assembler::BYTES_PER_OFFSET;
use super::utils::u256_from_bool;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
@@ -413,6 +414,14 @@ impl<'a> Interpreter<'a> {
.collect()
}
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
.content = jumpdest_bits
.into_iter()
.map(|x| u256_from_bool(x))
.collect();
}
fn incr(&mut self, n: usize) {
self.generation_state.registers.program_counter += n;
}


@@ -4,39 +4,37 @@ use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::interpreter::Interpreter;
use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode};
#[test]
fn test_jumpdest_analysis() -> Result<()> {
let jumpdest_analysis = KERNEL.global_labels["jumpdest_analysis"];
const CONTEXT: usize = 3; // arbitrary
// #[test]
// fn test_jumpdest_analysis() -> Result<()> {
// let jumpdest_analysis = KERNEL.global_labels["validate_jumpdest_table"];
// const CONTEXT: usize = 3; // arbitrary
let add = get_opcode("ADD");
let push2 = get_push_opcode(2);
let jumpdest = get_opcode("JUMPDEST");
// let add = get_opcode("ADD");
// let push2 = get_push_opcode(2);
// let jumpdest = get_opcode("JUMPDEST");
#[rustfmt::skip]
let code: Vec<u8> = vec![
add,
jumpdest,
push2,
jumpdest, // part of PUSH2
jumpdest, // part of PUSH2
jumpdest,
add,
jumpdest,
];
// #[rustfmt::skip]
// let code: Vec<u8> = vec![
// add,
// jumpdest,
// push2,
// jumpdest, // part of PUSH2
// jumpdest, // part of PUSH2
// jumpdest,
// add,
// jumpdest,
// ];
let expected_jumpdest_bits = vec![false, true, false, false, false, true, false, true];
// let jumpdest_bits = vec![false, true, false, false, false, true, false, true];
// Contract creation transaction.
let initial_stack = vec![0xDEADBEEFu32.into(), code.len().into(), CONTEXT.into()];
let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack);
interpreter.set_code(CONTEXT, code);
interpreter.run()?;
assert_eq!(interpreter.stack(), vec![]);
assert_eq!(
interpreter.get_jumpdest_bits(CONTEXT),
expected_jumpdest_bits
);
// // Contract creation transaction.
// let initial_stack = vec![0xDEADBEEFu32.into(), CONTEXT.into()];
// let mut interpreter = Interpreter::new_with_kernel(jumpdest_analysis, initial_stack);
// interpreter.set_code(CONTEXT, code);
// interpreter.set_jumpdest_bits(CONTEXT, jumpdest_bits);
Ok(())
}
// interpreter.run()?;
// assert_eq!(interpreter.stack(), vec![]);
// Ok(())
// }


@@ -345,6 +345,8 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
state.registers.program_counter = KERNEL.global_labels[initial_label];
let context = state.registers.context;
log::debug!("Simulating CPU for jumpdest analysis ");
loop {
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"]
@@ -382,7 +384,9 @@
}
if halt {
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
return Ok(jumpdest_addresses.into_iter().collect());
let mut jumpdest_addresses: Vec<usize> = jumpdest_addresses.into_iter().collect();
jumpdest_addresses.sort();
return Ok(jumpdest_addresses);
}
transition(state)?;


@@ -257,51 +257,7 @@ impl<F: Field> GenerationState<F> {
}))?;
if self.jumpdest_addresses.is_none() {
let mut state: GenerationState<F> = self.soft_clone();
let mut jumpdest_addresses = vec![];
// Generate the jumpdest table
let code = (0..code_len)
.map(|i| {
u256_to_u8(self.memory.get(MemoryAddress {
context: self.registers.context,
segment: Segment::Code as usize,
virt: i,
}))
})
.collect::<Result<Vec<u8>, _>>()?;
let mut i = 0;
while i < code_len {
if code[i] == get_opcode("JUMPDEST") {
jumpdest_addresses.push(i);
state.memory.set(
MemoryAddress {
context: state.registers.context,
segment: Segment::JumpdestBits as usize,
virt: i,
},
U256::one(),
);
log::debug!("jumpdest at {i}");
}
i += if code[i] >= get_push_opcode(1) && code[i] <= get_push_opcode(32) {
(code[i] - get_push_opcode(1) + 2).into()
} else {
1
}
}
// We need to skip the validate table call
self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps(
"validate_jumpdest_table_end",
"terminate_common",
&mut state,
)
.ok();
log::debug!("code len = {code_len}");
log::debug!("all jumpdest addresses = {:?}", jumpdest_addresses);
log::debug!("user's jumdest addresses = {:?}", self.jumpdest_addresses);
// self.jumpdest_addresses = Some(jumpdest_addresses);
self.generate_jumpdest_table()?;
}
let Some(jumpdest_table) = &mut self.jumpdest_addresses else {
@@ -326,57 +282,138 @@ impl<F: Field> GenerationState<F> {
virt: ContextMetadata::CodeSize as usize,
}))?;
let mut address = MemoryAddress {
context: self.registers.context,
segment: Segment::Code as usize,
virt: 0,
};
let mut proof = 0;
let mut prefix_size = 0;
let code = (0..self.last_jumpdest_address)
.map(|i| {
u256_to_u8(self.memory.get(MemoryAddress {
context: self.registers.context,
segment: Segment::Code as usize,
virt: i,
}))
})
.collect::<Result<Vec<u8>, _>>()?;
// TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem
// a problem because is done natively.
// a problem, as it is done natively.
// Search for the closest address to last_jumpdest_address such that none of
// the previous 32 bytes in the code (including both opcodes and pushed bytes)
// is a PUSHXX whose pushed range reaches that address
while address.virt < self.last_jumpdest_address {
let opcode = u256_to_u8(self.memory.get(address))?;
let is_push =
opcode >= get_push_opcode(1).into() && opcode <= get_push_opcode(32).into();
address.virt += if is_push {
(opcode - get_push_opcode(1) + 2).into()
} else {
1
};
// Check if the new address has a prefix of size >= 32
let mut has_prefix = true;
for i in address.virt as i32 - 32..address.virt as i32 {
let opcode = u256_to_u8(self.memory.get(MemoryAddress {
let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold(
0,
|acc, (pos, opcode)| {
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
code[prefix_start..pos].iter().enumerate().fold(
true,
|acc, (prefix_pos, &byte)| {
acc && (byte > get_push_opcode(32)
|| (prefix_start + prefix_pos) as i32
+ (byte as i32 - get_push_opcode(1) as i32)
+ 1
< pos as i32)
},
)
} else {
false
};
if has_prefix {
pos - 32
} else {
acc
}
},
);
Ok(proof.into())
}
}
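The fold above performs the proof search for run_next_jumpdest_table_proof: it returns the start of the last 32-byte window that ends at an opcode boundary before the jumpdest and that no PUSH inside the window can reach past. A standalone sketch of the same search over the byte slice preceding the jumpdest (the function name is hypothetical; the real code folds over CodeIterator and uses the KERNEL push-opcode constants):

// Sketch only: find the proof prefix for a jumpdest, given the code bytes before it.
fn find_proof_prefix(code_before_jumpdest: &[u8]) -> usize {
    const PUSH1: u8 = 0x60;
    const PUSH32: u8 = 0x7f;
    let code = code_before_jumpdest;
    let mut proof = 0;
    let mut pos = 0;
    while pos < code.len() {
        if let Some(prefix_start) = pos.checked_sub(32) {
            // The 32 bytes before `pos` must not contain a PUSH whose data reaches `pos`.
            let window_ok = code[prefix_start..pos].iter().enumerate().all(|(off, &b)| {
                !(PUSH1..=PUSH32).contains(&b)
                    || prefix_start + off + (b - PUSH1) as usize + 1 < pos
            });
            if window_ok {
                proof = prefix_start;
            }
        }
        // Advance to the next opcode boundary, skipping PUSH immediates.
        let b = code[pos];
        pos += if (PUSH1..=PUSH32).contains(&b) { (b - PUSH1 + 2) as usize } else { 1 };
    }
    proof
}

For example, with 40 single-byte non-PUSH opcodes before the jumpdest, the last opcode boundary considered is 39, so find_proof_prefix(&[0x01; 40]) returns 39 - 32 = 7, and is_jumpdest can read the 32-byte prefix at offsets 7..39 and then step to the jumpdest.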
impl<F: Field> GenerationState<F> {
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> {
let mut state = self.soft_clone();
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
context: self.registers.context,
segment: Segment::ContextMetadata as usize,
virt: ContextMetadata::CodeSize as usize,
}))?;
// Generate the jumpdest table
let code = (0..code_len)
.map(|i| {
u256_to_u8(self.memory.get(MemoryAddress {
context: self.registers.context,
segment: Segment::Code as usize,
virt: i as usize,
}))?;
if i < 0
|| (opcode >= get_push_opcode(1)
&& opcode <= get_push_opcode(32)
&& i + (opcode - get_push_opcode(1)) as i32 + 1 >= address.virt as i32)
{
has_prefix = false;
break;
}
}
if has_prefix {
proof = address.virt - 32;
virt: i,
}))
})
.collect::<Result<Vec<u8>, _>>()?;
// We need to set the simulated jumpdest bits to one, as otherwise
// the simulation will fail
let mut jumpdest_table = vec![];
for (pos, opcode) in CodeIterator::new(&code) {
jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST")));
if opcode == get_opcode("JUMPDEST") {
state.memory.set(
MemoryAddress {
context: state.registers.context,
segment: Segment::JumpdestBits as usize,
virt: pos,
},
U256::one(),
);
}
}
if address.virt > self.last_jumpdest_address {
return Err(ProgramError::ProverInputError(
ProverInputError::InvalidJumpDestination,
));
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps(
"validate_jumpdest_table_end",
"terminate_common",
&mut state,
)
.ok();
Ok(())
}
}
struct CodeIterator<'a> {
code: &'a Vec<u8>,
pos: usize,
end: usize,
}
impl<'a> CodeIterator<'a> {
fn new(code: &'a Vec<u8>) -> Self {
CodeIterator {
end: code.len(),
code,
pos: 0,
}
Ok(proof.into())
}
fn until(code: &'a Vec<u8>, end: usize) -> Self {
CodeIterator {
end: std::cmp::min(code.len(), end),
code,
pos: 0,
}
}
}
impl<'a> Iterator for CodeIterator<'a> {
type Item = (usize, u8);
fn next(&mut self) -> Option<Self::Item> {
let CodeIterator { code, pos, end } = self;
if *pos >= *end {
return None;
}
let opcode = code[*pos];
let old_pos = *pos;
*pos += if opcode >= get_push_opcode(1) && opcode <= get_push_opcode(32) {
(opcode - get_push_opcode(1) + 2).into()
} else {
1
};
Some((old_pos, opcode))
}
}
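A small usage sketch of CodeIterator (illustrative bytes: 0x01 is ADD, 0x61 is PUSH2, 0x5b is JUMPDEST): it yields (offset, opcode) pairs at opcode boundaries only, skipping the pushed immediates.

// Illustrative only: CodeIterator skips the data bytes of a PUSH.
fn code_iterator_example() {
    let code: Vec<u8> = vec![0x01, 0x61, 0xaa, 0xbb, 0x5b]; // ADD, PUSH2 0xaabb, JUMPDEST
    let opcodes: Vec<(usize, u8)> = CodeIterator::new(&code).collect();
    assert_eq!(opcodes, vec![(0, 0x01), (1, 0x61), (4, 0x5b)]);
}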


@@ -395,7 +395,12 @@ fn try_perform_instruction<F: Field>(
if state.registers.is_kernel {
log_kernel_instruction(state, op);
} else {
log::debug!("User instruction: {:?} stack = {:?}", op, state.stack());
log::debug!(
"User instruction: {:?} ctx = {:?} stack = {:?}",
op,
state.registers.context,
state.stack()
);
}
fill_op_flag(op, &mut row);