mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-04 23:03:08 +00:00
Improve proof generation
This commit is contained in:
parent
77f1cd3496
commit
829ae64fc4
@ -293,16 +293,17 @@ impl<'a> Interpreter<'a> {
|
||||
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
|
||||
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
|
||||
.content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect();
|
||||
self.generation_state.jumpdest_addresses = Some(HashMap::from([(
|
||||
context,
|
||||
BTreeSet::from_iter(
|
||||
jumpdest_bits
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|&(_, x)| x)
|
||||
.map(|(i, _)| i),
|
||||
),
|
||||
)]));
|
||||
self.generation_state
|
||||
.set_proofs_and_jumpdests(HashMap::from([(
|
||||
context,
|
||||
BTreeSet::from_iter(
|
||||
jumpdest_bits
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|&(_, x)| x)
|
||||
.map(|(i, _)| i),
|
||||
),
|
||||
)]));
|
||||
}
|
||||
|
||||
pub(crate) fn incr(&mut self, n: usize) {
|
||||
|
||||
@ -288,9 +288,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
|
||||
initial_label: &str,
|
||||
final_label: &str,
|
||||
state: &mut GenerationState<F>,
|
||||
) -> Result<(), ProgramError> {
|
||||
if state.jumpdest_addresses.is_some() {
|
||||
Ok(())
|
||||
) -> Result<Option<HashMap<usize, BTreeSet<usize>>>, ProgramError> {
|
||||
if state.jumpdest_proofs.is_some() {
|
||||
Ok(None)
|
||||
} else {
|
||||
const JUMP_OPCODE: u8 = 0x56;
|
||||
const JUMPI_OPCODE: u8 = 0x57;
|
||||
@ -304,6 +304,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
|
||||
log::debug!("Simulating CPU for jumpdest analysis.");
|
||||
|
||||
loop {
|
||||
// skip jumdest table validations in simulations
|
||||
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
|
||||
state.registers.program_counter =
|
||||
KERNEL.global_labels["validate_jumpdest_table_end"]
|
||||
@ -344,8 +345,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
|
||||
}
|
||||
if halt {
|
||||
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
|
||||
state.jumpdest_addresses = Some(jumpdest_addresses);
|
||||
return Ok(());
|
||||
return Ok(Some(jumpdest_addresses));
|
||||
}
|
||||
transition(state).map_err(|_| {
|
||||
ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
use std::cmp::min;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::HashMap;
|
||||
use std::mem::transmute;
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::{bail, Error};
|
||||
use ethereum_types::{BigEndianHash, H256, U256, U512};
|
||||
use hashbrown::HashMap;
|
||||
use itertools::{enumerate, Itertools};
|
||||
use num_bigint::BigUint;
|
||||
use plonky2::field::extension::Extendable;
|
||||
@ -256,87 +255,98 @@ impl<F: Field> GenerationState<F> {
|
||||
virt: ContextMetadata::CodeSize as usize,
|
||||
}))?;
|
||||
|
||||
if self.jumpdest_addresses.is_none() {
|
||||
self.generate_jumpdest_table()?;
|
||||
if self.jumpdest_proofs.is_none() {
|
||||
self.generate_jumpdest_proofs()?;
|
||||
}
|
||||
|
||||
let Some(jumpdest_tables) = &mut self.jumpdest_addresses else {
|
||||
let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
|
||||
return Err(ProgramError::ProverInputError(
|
||||
ProverInputError::InvalidJumpdestSimulation,
|
||||
));
|
||||
};
|
||||
|
||||
if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context)
|
||||
&& let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last()
|
||||
if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
|
||||
&& let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop()
|
||||
{
|
||||
self.last_jumpdest_address = next_jumpdest_address;
|
||||
Ok((next_jumpdest_address + 1).into())
|
||||
} else {
|
||||
self.jumpdest_addresses = None;
|
||||
self.jumpdest_proofs = None;
|
||||
Ok(U256::zero())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the proof for the last jump address.
|
||||
fn run_next_jumpdest_table_proof(&mut self) -> Result<U256, ProgramError> {
|
||||
let code = (0..self.last_jumpdest_address)
|
||||
.map(|i| {
|
||||
u256_to_u8(self.memory.get(MemoryAddress {
|
||||
context: self.registers.context,
|
||||
segment: Segment::Code as usize,
|
||||
virt: i,
|
||||
}))
|
||||
})
|
||||
.collect::<Result<Vec<u8>, _>>()?;
|
||||
|
||||
// TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem
|
||||
// a problem as is done natively.
|
||||
|
||||
// Search the closest address to `last_jumpdest_address` for which none of
|
||||
// the previous 32 bytes in the code (including opcodes and pushed bytes)
|
||||
// are PUSHXX and the address is in its range.
|
||||
|
||||
const PUSH1_OPCODE: u8 = 0x60;
|
||||
const PUSH32_OPCODE: u8 = 0x7f;
|
||||
|
||||
let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold(
|
||||
0,
|
||||
|acc, (pos, opcode)| {
|
||||
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
|
||||
code[prefix_start..pos].iter().enumerate().fold(
|
||||
true,
|
||||
|acc, (prefix_pos, &byte)| {
|
||||
acc && (byte > PUSH32_OPCODE
|
||||
|| (prefix_start + prefix_pos) as i32
|
||||
+ (byte as i32 - PUSH1_OPCODE as i32)
|
||||
+ 1
|
||||
< pos as i32)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if has_prefix {
|
||||
pos - 32
|
||||
} else {
|
||||
acc
|
||||
}
|
||||
},
|
||||
);
|
||||
Ok(proof.into())
|
||||
let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
|
||||
return Err(ProgramError::ProverInputError(
|
||||
ProverInputError::InvalidJumpdestSimulation,
|
||||
));
|
||||
};
|
||||
if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
|
||||
&& let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop()
|
||||
{
|
||||
Ok(next_jumpdest_proof.into())
|
||||
} else {
|
||||
Err(ProgramError::ProverInputError(
|
||||
ProverInputError::InvalidJumpdestSimulation,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> GenerationState<F> {
|
||||
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> {
|
||||
const JUMPDEST_OPCODE: u8 = 0x5b;
|
||||
let mut state = self.soft_clone();
|
||||
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
|
||||
context: self.registers.context,
|
||||
segment: Segment::ContextMetadata as usize,
|
||||
virt: ContextMetadata::CodeSize as usize,
|
||||
}))?;
|
||||
// Generate the jumpdest table
|
||||
fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> {
|
||||
let checkpoint = self.checkpoint();
|
||||
let memory = self.memory.clone();
|
||||
|
||||
let code = self.get_current_code()?;
|
||||
// We need to set the simulated jumpdest bits to one as otherwise
|
||||
// the simulation will fail.
|
||||
self.set_jumpdest_bits(&code);
|
||||
|
||||
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
|
||||
let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps(
|
||||
"validate_jumpdest_table_end",
|
||||
"terminate_common",
|
||||
self,
|
||||
)?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Return to the state before starting the simulation
|
||||
self.rollback(checkpoint);
|
||||
self.memory = memory;
|
||||
|
||||
// Find proofs for all context
|
||||
self.set_proofs_and_jumpdests(jumpdest_table);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn set_proofs_and_jumpdests(
|
||||
&mut self,
|
||||
jumpdest_table: HashMap<usize, std::collections::BTreeSet<usize>>,
|
||||
) {
|
||||
self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map(
|
||||
|(ctx, jumpdest_table)| {
|
||||
let code = self.get_code(ctx).unwrap();
|
||||
if let Some(&largest_address) = jumpdest_table.last() {
|
||||
let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table);
|
||||
(ctx, proofs)
|
||||
} else {
|
||||
(ctx, vec![])
|
||||
}
|
||||
},
|
||||
)));
|
||||
}
|
||||
|
||||
fn get_current_code(&self) -> Result<Vec<u8>, ProgramError> {
|
||||
self.get_code(self.registers.context)
|
||||
}
|
||||
|
||||
fn get_code(&self, context: usize) -> Result<Vec<u8>, ProgramError> {
|
||||
let code_len = self.get_code_len()?;
|
||||
let code = (0..code_len)
|
||||
.map(|i| {
|
||||
u256_to_u8(self.memory.get(MemoryAddress {
|
||||
@ -346,16 +356,25 @@ impl<F: Field> GenerationState<F> {
|
||||
}))
|
||||
})
|
||||
.collect::<Result<Vec<u8>, _>>()?;
|
||||
Ok(code)
|
||||
}
|
||||
|
||||
// We need to set the simulated jumpdest bits to one as otherwise
|
||||
// the simulation will fail.
|
||||
let mut jumpdest_table = Vec::with_capacity(code.len());
|
||||
fn get_code_len(&self) -> Result<usize, ProgramError> {
|
||||
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
|
||||
context: self.registers.context,
|
||||
segment: Segment::ContextMetadata as usize,
|
||||
virt: ContextMetadata::CodeSize as usize,
|
||||
}))?;
|
||||
Ok(code_len)
|
||||
}
|
||||
|
||||
fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec<u8>) {
|
||||
const JUMPDEST_OPCODE: u8 = 0x5b;
|
||||
for (pos, opcode) in CodeIterator::new(&code) {
|
||||
jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE));
|
||||
if opcode == JUMPDEST_OPCODE {
|
||||
state.memory.set(
|
||||
self.memory.set(
|
||||
MemoryAddress {
|
||||
context: state.registers.context,
|
||||
context: self.registers.context,
|
||||
segment: Segment::JumpdestBits as usize,
|
||||
virt: pos,
|
||||
},
|
||||
@ -363,18 +382,50 @@ impl<F: Field> GenerationState<F> {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
|
||||
simulate_cpu_between_labels_and_get_user_jumps(
|
||||
"validate_jumpdest_table_end",
|
||||
"terminate_common",
|
||||
&mut state,
|
||||
)?;
|
||||
self.jumpdest_addresses = state.jumpdest_addresses;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// For each address in `jumpdest_table` it search a proof, that is the closest address
|
||||
/// for which none of the previous 32 bytes in the code (including opcodes
|
||||
/// and pushed bytes are PUSHXX and the address is in its range. It returns
|
||||
/// a vector of even size containing proofs followed by their addresses
|
||||
fn get_proofs_and_jumpdests<'a>(
|
||||
code: &'a Vec<u8>,
|
||||
largest_address: usize,
|
||||
jumpdest_table: std::collections::BTreeSet<usize>,
|
||||
) -> Vec<usize> {
|
||||
const PUSH1_OPCODE: u8 = 0x60;
|
||||
const PUSH32_OPCODE: u8 = 0x7f;
|
||||
let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold(
|
||||
(vec![], 0),
|
||||
|(mut proofs, acc), (pos, opcode)| {
|
||||
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
|
||||
code[prefix_start..pos]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.fold(true, |acc, (prefix_pos, &byte)| {
|
||||
acc && (byte > PUSH32_OPCODE
|
||||
|| (prefix_start + prefix_pos) as i32
|
||||
+ (byte as i32 - PUSH1_OPCODE as i32)
|
||||
+ 1
|
||||
< pos as i32)
|
||||
})
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let acc = if has_prefix { pos - 32 } else { acc };
|
||||
if jumpdest_table.contains(&pos) {
|
||||
// Push the proof
|
||||
proofs.push(acc);
|
||||
// Push the address
|
||||
proofs.push(pos);
|
||||
}
|
||||
(proofs, acc)
|
||||
},
|
||||
);
|
||||
proofs
|
||||
}
|
||||
|
||||
struct CodeIterator<'a> {
|
||||
code: &'a [u8],
|
||||
pos: usize,
|
||||
|
||||
@ -51,8 +51,7 @@ pub(crate) struct GenerationState<F: Field> {
|
||||
/// Pointers, within the `TrieData` segment, of the three MPTs.
|
||||
pub(crate) trie_root_ptrs: TrieRootPtrs,
|
||||
|
||||
pub(crate) last_jumpdest_address: usize,
|
||||
pub(crate) jumpdest_addresses: Option<HashMap<usize, BTreeSet<usize>>>,
|
||||
pub(crate) jumpdest_proofs: Option<HashMap<usize, Vec<usize>>>,
|
||||
}
|
||||
|
||||
impl<F: Field> GenerationState<F> {
|
||||
@ -94,8 +93,7 @@ impl<F: Field> GenerationState<F> {
|
||||
txn_root_ptr: 0,
|
||||
receipt_root_ptr: 0,
|
||||
},
|
||||
last_jumpdest_address: 0,
|
||||
jumpdest_addresses: None,
|
||||
jumpdest_proofs: None,
|
||||
};
|
||||
let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries);
|
||||
|
||||
@ -189,8 +187,7 @@ impl<F: Field> GenerationState<F> {
|
||||
txn_root_ptr: 0,
|
||||
receipt_root_ptr: 0,
|
||||
},
|
||||
last_jumpdest_address: 0,
|
||||
jumpdest_addresses: None,
|
||||
jumpdest_proofs: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user