Improve proof generation

This commit is contained in:
4l0n50 2023-12-19 14:05:51 +01:00
parent 77f1cd3496
commit 829ae64fc4
4 changed files with 148 additions and 99 deletions

View File

@ -293,16 +293,17 @@ impl<'a> Interpreter<'a> {
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
.content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect();
self.generation_state.jumpdest_addresses = Some(HashMap::from([(
context,
BTreeSet::from_iter(
jumpdest_bits
.into_iter()
.enumerate()
.filter(|&(_, x)| x)
.map(|(i, _)| i),
),
)]));
self.generation_state
.set_proofs_and_jumpdests(HashMap::from([(
context,
BTreeSet::from_iter(
jumpdest_bits
.into_iter()
.enumerate()
.filter(|&(_, x)| x)
.map(|(i, _)| i),
),
)]));
}
pub(crate) fn incr(&mut self, n: usize) {

View File

@ -288,9 +288,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
initial_label: &str,
final_label: &str,
state: &mut GenerationState<F>,
) -> Result<(), ProgramError> {
if state.jumpdest_addresses.is_some() {
Ok(())
) -> Result<Option<HashMap<usize, BTreeSet<usize>>>, ProgramError> {
if state.jumpdest_proofs.is_some() {
Ok(None)
} else {
const JUMP_OPCODE: u8 = 0x56;
const JUMPI_OPCODE: u8 = 0x57;
@ -304,6 +304,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
log::debug!("Simulating CPU for jumpdest analysis.");
loop {
// skip jumdest table validations in simulations
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
state.registers.program_counter =
KERNEL.global_labels["validate_jumpdest_table_end"]
@ -344,8 +345,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
}
if halt {
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
state.jumpdest_addresses = Some(jumpdest_addresses);
return Ok(());
return Ok(Some(jumpdest_addresses));
}
transition(state).map_err(|_| {
ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)

View File

@ -1,11 +1,10 @@
use std::cmp::min;
use std::collections::HashSet;
use std::collections::HashMap;
use std::mem::transmute;
use std::str::FromStr;
use anyhow::{bail, Error};
use ethereum_types::{BigEndianHash, H256, U256, U512};
use hashbrown::HashMap;
use itertools::{enumerate, Itertools};
use num_bigint::BigUint;
use plonky2::field::extension::Extendable;
@ -256,87 +255,98 @@ impl<F: Field> GenerationState<F> {
virt: ContextMetadata::CodeSize as usize,
}))?;
if self.jumpdest_addresses.is_none() {
self.generate_jumpdest_table()?;
if self.jumpdest_proofs.is_none() {
self.generate_jumpdest_proofs()?;
}
let Some(jumpdest_tables) = &mut self.jumpdest_addresses else {
let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
return Err(ProgramError::ProverInputError(
ProverInputError::InvalidJumpdestSimulation,
));
};
if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context)
&& let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last()
if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
&& let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop()
{
self.last_jumpdest_address = next_jumpdest_address;
Ok((next_jumpdest_address + 1).into())
} else {
self.jumpdest_addresses = None;
self.jumpdest_proofs = None;
Ok(U256::zero())
}
}
/// Returns the proof for the last jump address.
fn run_next_jumpdest_table_proof(&mut self) -> Result<U256, ProgramError> {
let code = (0..self.last_jumpdest_address)
.map(|i| {
u256_to_u8(self.memory.get(MemoryAddress {
context: self.registers.context,
segment: Segment::Code as usize,
virt: i,
}))
})
.collect::<Result<Vec<u8>, _>>()?;
// TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem
// a problem as is done natively.
// Search the closest address to `last_jumpdest_address` for which none of
// the previous 32 bytes in the code (including opcodes and pushed bytes)
// are PUSHXX and the address is in its range.
const PUSH1_OPCODE: u8 = 0x60;
const PUSH32_OPCODE: u8 = 0x7f;
let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold(
0,
|acc, (pos, opcode)| {
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
code[prefix_start..pos].iter().enumerate().fold(
true,
|acc, (prefix_pos, &byte)| {
acc && (byte > PUSH32_OPCODE
|| (prefix_start + prefix_pos) as i32
+ (byte as i32 - PUSH1_OPCODE as i32)
+ 1
< pos as i32)
},
)
} else {
false
};
if has_prefix {
pos - 32
} else {
acc
}
},
);
Ok(proof.into())
let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
return Err(ProgramError::ProverInputError(
ProverInputError::InvalidJumpdestSimulation,
));
};
if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
&& let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop()
{
Ok(next_jumpdest_proof.into())
} else {
Err(ProgramError::ProverInputError(
ProverInputError::InvalidJumpdestSimulation,
))
}
}
}
impl<F: Field> GenerationState<F> {
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> {
const JUMPDEST_OPCODE: u8 = 0x5b;
let mut state = self.soft_clone();
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
context: self.registers.context,
segment: Segment::ContextMetadata as usize,
virt: ContextMetadata::CodeSize as usize,
}))?;
// Generate the jumpdest table
fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> {
let checkpoint = self.checkpoint();
let memory = self.memory.clone();
let code = self.get_current_code()?;
// We need to set the simulated jumpdest bits to one as otherwise
// the simulation will fail.
self.set_jumpdest_bits(&code);
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps(
"validate_jumpdest_table_end",
"terminate_common",
self,
)?
else {
return Ok(());
};
// Return to the state before starting the simulation
self.rollback(checkpoint);
self.memory = memory;
// Find proofs for all context
self.set_proofs_and_jumpdests(jumpdest_table);
Ok(())
}
pub(crate) fn set_proofs_and_jumpdests(
&mut self,
jumpdest_table: HashMap<usize, std::collections::BTreeSet<usize>>,
) {
self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map(
|(ctx, jumpdest_table)| {
let code = self.get_code(ctx).unwrap();
if let Some(&largest_address) = jumpdest_table.last() {
let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table);
(ctx, proofs)
} else {
(ctx, vec![])
}
},
)));
}
fn get_current_code(&self) -> Result<Vec<u8>, ProgramError> {
self.get_code(self.registers.context)
}
fn get_code(&self, context: usize) -> Result<Vec<u8>, ProgramError> {
let code_len = self.get_code_len()?;
let code = (0..code_len)
.map(|i| {
u256_to_u8(self.memory.get(MemoryAddress {
@ -346,16 +356,25 @@ impl<F: Field> GenerationState<F> {
}))
})
.collect::<Result<Vec<u8>, _>>()?;
Ok(code)
}
// We need to set the simulated jumpdest bits to one as otherwise
// the simulation will fail.
let mut jumpdest_table = Vec::with_capacity(code.len());
fn get_code_len(&self) -> Result<usize, ProgramError> {
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
context: self.registers.context,
segment: Segment::ContextMetadata as usize,
virt: ContextMetadata::CodeSize as usize,
}))?;
Ok(code_len)
}
fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec<u8>) {
const JUMPDEST_OPCODE: u8 = 0x5b;
for (pos, opcode) in CodeIterator::new(&code) {
jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE));
if opcode == JUMPDEST_OPCODE {
state.memory.set(
self.memory.set(
MemoryAddress {
context: state.registers.context,
context: self.registers.context,
segment: Segment::JumpdestBits as usize,
virt: pos,
},
@ -363,18 +382,50 @@ impl<F: Field> GenerationState<F> {
);
}
}
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
simulate_cpu_between_labels_and_get_user_jumps(
"validate_jumpdest_table_end",
"terminate_common",
&mut state,
)?;
self.jumpdest_addresses = state.jumpdest_addresses;
Ok(())
}
}
/// For each address in `jumpdest_table` it search a proof, that is the closest address
/// for which none of the previous 32 bytes in the code (including opcodes
/// and pushed bytes are PUSHXX and the address is in its range. It returns
/// a vector of even size containing proofs followed by their addresses
fn get_proofs_and_jumpdests<'a>(
code: &'a Vec<u8>,
largest_address: usize,
jumpdest_table: std::collections::BTreeSet<usize>,
) -> Vec<usize> {
const PUSH1_OPCODE: u8 = 0x60;
const PUSH32_OPCODE: u8 = 0x7f;
let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold(
(vec![], 0),
|(mut proofs, acc), (pos, opcode)| {
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
code[prefix_start..pos]
.iter()
.enumerate()
.fold(true, |acc, (prefix_pos, &byte)| {
acc && (byte > PUSH32_OPCODE
|| (prefix_start + prefix_pos) as i32
+ (byte as i32 - PUSH1_OPCODE as i32)
+ 1
< pos as i32)
})
} else {
false
};
let acc = if has_prefix { pos - 32 } else { acc };
if jumpdest_table.contains(&pos) {
// Push the proof
proofs.push(acc);
// Push the address
proofs.push(pos);
}
(proofs, acc)
},
);
proofs
}
struct CodeIterator<'a> {
code: &'a [u8],
pos: usize,

View File

@ -51,8 +51,7 @@ pub(crate) struct GenerationState<F: Field> {
/// Pointers, within the `TrieData` segment, of the three MPTs.
pub(crate) trie_root_ptrs: TrieRootPtrs,
pub(crate) last_jumpdest_address: usize,
pub(crate) jumpdest_addresses: Option<HashMap<usize, BTreeSet<usize>>>,
pub(crate) jumpdest_proofs: Option<HashMap<usize, Vec<usize>>>,
}
impl<F: Field> GenerationState<F> {
@ -94,8 +93,7 @@ impl<F: Field> GenerationState<F> {
txn_root_ptr: 0,
receipt_root_ptr: 0,
},
last_jumpdest_address: 0,
jumpdest_addresses: None,
jumpdest_proofs: None,
};
let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries);
@ -189,8 +187,7 @@ impl<F: Field> GenerationState<F> {
txn_root_ptr: 0,
receipt_root_ptr: 0,
},
last_jumpdest_address: 0,
jumpdest_addresses: None,
jumpdest_proofs: None,
}
}
}