Improve proof generation

This commit is contained in:
4l0n50 2023-12-19 14:05:51 +01:00
parent c4025063de
commit 4e569484c2
4 changed files with 148 additions and 99 deletions

View File

@ -417,16 +417,17 @@ impl<'a> Interpreter<'a> {
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) { pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
.content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect(); .content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect();
self.generation_state.jumpdest_addresses = Some(HashMap::from([( self.generation_state
context, .set_proofs_and_jumpdests(HashMap::from([(
BTreeSet::from_iter( context,
jumpdest_bits BTreeSet::from_iter(
.into_iter() jumpdest_bits
.enumerate() .into_iter()
.filter(|&(_, x)| x) .enumerate()
.map(|(i, _)| i), .filter(|&(_, x)| x)
), .map(|(i, _)| i),
)])); ),
)]));
} }
pub(crate) fn incr(&mut self, n: usize) { pub(crate) fn incr(&mut self, n: usize) {

View File

@ -340,9 +340,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
initial_label: &str, initial_label: &str,
final_label: &str, final_label: &str,
state: &mut GenerationState<F>, state: &mut GenerationState<F>,
) -> Result<(), ProgramError> { ) -> Result<Option<HashMap<usize, BTreeSet<usize>>>, ProgramError> {
if state.jumpdest_addresses.is_some() { if state.jumpdest_proofs.is_some() {
Ok(()) Ok(None)
} else { } else {
const JUMP_OPCODE: u8 = 0x56; const JUMP_OPCODE: u8 = 0x56;
const JUMPI_OPCODE: u8 = 0x57; const JUMPI_OPCODE: u8 = 0x57;
@ -356,6 +356,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
log::debug!("Simulating CPU for jumpdest analysis."); log::debug!("Simulating CPU for jumpdest analysis.");
loop { loop {
// skip jumdest table validations in simulations
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] { if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
state.registers.program_counter = state.registers.program_counter =
KERNEL.global_labels["validate_jumpdest_table_end"] KERNEL.global_labels["validate_jumpdest_table_end"]
@ -396,8 +397,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
} }
if halt { if halt {
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock()); log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
state.jumpdest_addresses = Some(jumpdest_addresses); return Ok(Some(jumpdest_addresses));
return Ok(());
} }
transition(state).map_err(|_| { transition(state).map_err(|_| {
ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation) ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)

View File

@ -1,11 +1,10 @@
use std::cmp::min; use std::cmp::min;
use std::collections::HashSet; use std::collections::HashMap;
use std::mem::transmute; use std::mem::transmute;
use std::str::FromStr; use std::str::FromStr;
use anyhow::{bail, Error}; use anyhow::{bail, Error};
use ethereum_types::{BigEndianHash, H256, U256, U512}; use ethereum_types::{BigEndianHash, H256, U256, U512};
use hashbrown::HashMap;
use itertools::{enumerate, Itertools}; use itertools::{enumerate, Itertools};
use num_bigint::BigUint; use num_bigint::BigUint;
use plonky2::field::extension::Extendable; use plonky2::field::extension::Extendable;
@ -256,87 +255,98 @@ impl<F: Field> GenerationState<F> {
virt: ContextMetadata::CodeSize as usize, virt: ContextMetadata::CodeSize as usize,
}))?; }))?;
if self.jumpdest_addresses.is_none() { if self.jumpdest_proofs.is_none() {
self.generate_jumpdest_table()?; self.generate_jumpdest_proofs()?;
} }
let Some(jumpdest_tables) = &mut self.jumpdest_addresses else { let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
return Err(ProgramError::ProverInputError( return Err(ProgramError::ProverInputError(
ProverInputError::InvalidJumpdestSimulation, ProverInputError::InvalidJumpdestSimulation,
)); ));
}; };
if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
&& let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last() && let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop()
{ {
self.last_jumpdest_address = next_jumpdest_address;
Ok((next_jumpdest_address + 1).into()) Ok((next_jumpdest_address + 1).into())
} else { } else {
self.jumpdest_addresses = None; self.jumpdest_proofs = None;
Ok(U256::zero()) Ok(U256::zero())
} }
} }
/// Returns the proof for the last jump address. /// Returns the proof for the last jump address.
fn run_next_jumpdest_table_proof(&mut self) -> Result<U256, ProgramError> { fn run_next_jumpdest_table_proof(&mut self) -> Result<U256, ProgramError> {
let code = (0..self.last_jumpdest_address) let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
.map(|i| { return Err(ProgramError::ProverInputError(
u256_to_u8(self.memory.get(MemoryAddress { ProverInputError::InvalidJumpdestSimulation,
context: self.registers.context, ));
segment: Segment::Code as usize, };
virt: i, if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
})) && let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop()
}) {
.collect::<Result<Vec<u8>, _>>()?; Ok(next_jumpdest_proof.into())
} else {
// TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem Err(ProgramError::ProverInputError(
// a problem as is done natively. ProverInputError::InvalidJumpdestSimulation,
))
// Search the closest address to `last_jumpdest_address` for which none of }
// the previous 32 bytes in the code (including opcodes and pushed bytes)
// are PUSHXX and the address is in its range.
const PUSH1_OPCODE: u8 = 0x60;
const PUSH32_OPCODE: u8 = 0x7f;
let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold(
0,
|acc, (pos, opcode)| {
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
code[prefix_start..pos].iter().enumerate().fold(
true,
|acc, (prefix_pos, &byte)| {
acc && (byte > PUSH32_OPCODE
|| (prefix_start + prefix_pos) as i32
+ (byte as i32 - PUSH1_OPCODE as i32)
+ 1
< pos as i32)
},
)
} else {
false
};
if has_prefix {
pos - 32
} else {
acc
}
},
);
Ok(proof.into())
} }
} }
impl<F: Field> GenerationState<F> { impl<F: Field> GenerationState<F> {
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> { fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> {
const JUMPDEST_OPCODE: u8 = 0x5b; let checkpoint = self.checkpoint();
let mut state = self.soft_clone(); let memory = self.memory.clone();
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
context: self.registers.context, let code = self.get_current_code()?;
segment: Segment::ContextMetadata as usize, // We need to set the simulated jumpdest bits to one as otherwise
virt: ContextMetadata::CodeSize as usize, // the simulation will fail.
}))?; self.set_jumpdest_bits(&code);
// Generate the jumpdest table
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps(
"validate_jumpdest_table_end",
"terminate_common",
self,
)?
else {
return Ok(());
};
// Return to the state before starting the simulation
self.rollback(checkpoint);
self.memory = memory;
// Find proofs for all context
self.set_proofs_and_jumpdests(jumpdest_table);
Ok(())
}
pub(crate) fn set_proofs_and_jumpdests(
&mut self,
jumpdest_table: HashMap<usize, std::collections::BTreeSet<usize>>,
) {
self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map(
|(ctx, jumpdest_table)| {
let code = self.get_code(ctx).unwrap();
if let Some(&largest_address) = jumpdest_table.last() {
let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table);
(ctx, proofs)
} else {
(ctx, vec![])
}
},
)));
}
fn get_current_code(&self) -> Result<Vec<u8>, ProgramError> {
self.get_code(self.registers.context)
}
fn get_code(&self, context: usize) -> Result<Vec<u8>, ProgramError> {
let code_len = self.get_code_len()?;
let code = (0..code_len) let code = (0..code_len)
.map(|i| { .map(|i| {
u256_to_u8(self.memory.get(MemoryAddress { u256_to_u8(self.memory.get(MemoryAddress {
@ -346,16 +356,25 @@ impl<F: Field> GenerationState<F> {
})) }))
}) })
.collect::<Result<Vec<u8>, _>>()?; .collect::<Result<Vec<u8>, _>>()?;
Ok(code)
}
// We need to set the simulated jumpdest bits to one as otherwise fn get_code_len(&self) -> Result<usize, ProgramError> {
// the simulation will fail. let code_len = u256_to_usize(self.memory.get(MemoryAddress {
let mut jumpdest_table = Vec::with_capacity(code.len()); context: self.registers.context,
segment: Segment::ContextMetadata as usize,
virt: ContextMetadata::CodeSize as usize,
}))?;
Ok(code_len)
}
fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec<u8>) {
const JUMPDEST_OPCODE: u8 = 0x5b;
for (pos, opcode) in CodeIterator::new(&code) { for (pos, opcode) in CodeIterator::new(&code) {
jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE));
if opcode == JUMPDEST_OPCODE { if opcode == JUMPDEST_OPCODE {
state.memory.set( self.memory.set(
MemoryAddress { MemoryAddress {
context: state.registers.context, context: self.registers.context,
segment: Segment::JumpdestBits as usize, segment: Segment::JumpdestBits as usize,
virt: pos, virt: pos,
}, },
@ -363,18 +382,50 @@ impl<F: Field> GenerationState<F> {
); );
} }
} }
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
simulate_cpu_between_labels_and_get_user_jumps(
"validate_jumpdest_table_end",
"terminate_common",
&mut state,
)?;
self.jumpdest_addresses = state.jumpdest_addresses;
Ok(())
} }
} }
/// For each address in `jumpdest_table` it search a proof, that is the closest address
/// for which none of the previous 32 bytes in the code (including opcodes
/// and pushed bytes are PUSHXX and the address is in its range. It returns
/// a vector of even size containing proofs followed by their addresses
fn get_proofs_and_jumpdests<'a>(
code: &'a Vec<u8>,
largest_address: usize,
jumpdest_table: std::collections::BTreeSet<usize>,
) -> Vec<usize> {
const PUSH1_OPCODE: u8 = 0x60;
const PUSH32_OPCODE: u8 = 0x7f;
let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold(
(vec![], 0),
|(mut proofs, acc), (pos, opcode)| {
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
code[prefix_start..pos]
.iter()
.enumerate()
.fold(true, |acc, (prefix_pos, &byte)| {
acc && (byte > PUSH32_OPCODE
|| (prefix_start + prefix_pos) as i32
+ (byte as i32 - PUSH1_OPCODE as i32)
+ 1
< pos as i32)
})
} else {
false
};
let acc = if has_prefix { pos - 32 } else { acc };
if jumpdest_table.contains(&pos) {
// Push the proof
proofs.push(acc);
// Push the address
proofs.push(pos);
}
(proofs, acc)
},
);
proofs
}
struct CodeIterator<'a> { struct CodeIterator<'a> {
code: &'a [u8], code: &'a [u8],
pos: usize, pos: usize,

View File

@ -51,8 +51,7 @@ pub(crate) struct GenerationState<F: Field> {
/// Pointers, within the `TrieData` segment, of the three MPTs. /// Pointers, within the `TrieData` segment, of the three MPTs.
pub(crate) trie_root_ptrs: TrieRootPtrs, pub(crate) trie_root_ptrs: TrieRootPtrs,
pub(crate) last_jumpdest_address: usize, pub(crate) jumpdest_proofs: Option<HashMap<usize, Vec<usize>>>,
pub(crate) jumpdest_addresses: Option<HashMap<usize, BTreeSet<usize>>>,
} }
impl<F: Field> GenerationState<F> { impl<F: Field> GenerationState<F> {
@ -94,8 +93,7 @@ impl<F: Field> GenerationState<F> {
txn_root_ptr: 0, txn_root_ptr: 0,
receipt_root_ptr: 0, receipt_root_ptr: 0,
}, },
last_jumpdest_address: 0, jumpdest_proofs: None,
jumpdest_addresses: None,
}; };
let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries); let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries);
@ -189,8 +187,7 @@ impl<F: Field> GenerationState<F> {
txn_root_ptr: 0, txn_root_ptr: 0,
receipt_root_ptr: 0, receipt_root_ptr: 0,
}, },
last_jumpdest_address: 0, jumpdest_proofs: None,
jumpdest_addresses: None,
} }
} }
} }