mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-05-04 01:03:09 +00:00
Improve proof generation
This commit is contained in:
parent
c4025063de
commit
4e569484c2
@ -417,16 +417,17 @@ impl<'a> Interpreter<'a> {
|
|||||||
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
|
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
|
||||||
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
|
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
|
||||||
.content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect();
|
.content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect();
|
||||||
self.generation_state.jumpdest_addresses = Some(HashMap::from([(
|
self.generation_state
|
||||||
context,
|
.set_proofs_and_jumpdests(HashMap::from([(
|
||||||
BTreeSet::from_iter(
|
context,
|
||||||
jumpdest_bits
|
BTreeSet::from_iter(
|
||||||
.into_iter()
|
jumpdest_bits
|
||||||
.enumerate()
|
.into_iter()
|
||||||
.filter(|&(_, x)| x)
|
.enumerate()
|
||||||
.map(|(i, _)| i),
|
.filter(|&(_, x)| x)
|
||||||
),
|
.map(|(i, _)| i),
|
||||||
)]));
|
),
|
||||||
|
)]));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn incr(&mut self, n: usize) {
|
pub(crate) fn incr(&mut self, n: usize) {
|
||||||
|
|||||||
@ -340,9 +340,9 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
|
|||||||
initial_label: &str,
|
initial_label: &str,
|
||||||
final_label: &str,
|
final_label: &str,
|
||||||
state: &mut GenerationState<F>,
|
state: &mut GenerationState<F>,
|
||||||
) -> Result<(), ProgramError> {
|
) -> Result<Option<HashMap<usize, BTreeSet<usize>>>, ProgramError> {
|
||||||
if state.jumpdest_addresses.is_some() {
|
if state.jumpdest_proofs.is_some() {
|
||||||
Ok(())
|
Ok(None)
|
||||||
} else {
|
} else {
|
||||||
const JUMP_OPCODE: u8 = 0x56;
|
const JUMP_OPCODE: u8 = 0x56;
|
||||||
const JUMPI_OPCODE: u8 = 0x57;
|
const JUMPI_OPCODE: u8 = 0x57;
|
||||||
@ -356,6 +356,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
|
|||||||
log::debug!("Simulating CPU for jumpdest analysis.");
|
log::debug!("Simulating CPU for jumpdest analysis.");
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
// skip jumdest table validations in simulations
|
||||||
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
|
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
|
||||||
state.registers.program_counter =
|
state.registers.program_counter =
|
||||||
KERNEL.global_labels["validate_jumpdest_table_end"]
|
KERNEL.global_labels["validate_jumpdest_table_end"]
|
||||||
@ -396,8 +397,7 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
|
|||||||
}
|
}
|
||||||
if halt {
|
if halt {
|
||||||
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
|
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
|
||||||
state.jumpdest_addresses = Some(jumpdest_addresses);
|
return Ok(Some(jumpdest_addresses));
|
||||||
return Ok(());
|
|
||||||
}
|
}
|
||||||
transition(state).map_err(|_| {
|
transition(state).map_err(|_| {
|
||||||
ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)
|
ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashMap;
|
||||||
use std::mem::transmute;
|
use std::mem::transmute;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
use anyhow::{bail, Error};
|
use anyhow::{bail, Error};
|
||||||
use ethereum_types::{BigEndianHash, H256, U256, U512};
|
use ethereum_types::{BigEndianHash, H256, U256, U512};
|
||||||
use hashbrown::HashMap;
|
|
||||||
use itertools::{enumerate, Itertools};
|
use itertools::{enumerate, Itertools};
|
||||||
use num_bigint::BigUint;
|
use num_bigint::BigUint;
|
||||||
use plonky2::field::extension::Extendable;
|
use plonky2::field::extension::Extendable;
|
||||||
@ -256,87 +255,98 @@ impl<F: Field> GenerationState<F> {
|
|||||||
virt: ContextMetadata::CodeSize as usize,
|
virt: ContextMetadata::CodeSize as usize,
|
||||||
}))?;
|
}))?;
|
||||||
|
|
||||||
if self.jumpdest_addresses.is_none() {
|
if self.jumpdest_proofs.is_none() {
|
||||||
self.generate_jumpdest_table()?;
|
self.generate_jumpdest_proofs()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(jumpdest_tables) = &mut self.jumpdest_addresses else {
|
let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
|
||||||
return Err(ProgramError::ProverInputError(
|
return Err(ProgramError::ProverInputError(
|
||||||
ProverInputError::InvalidJumpdestSimulation,
|
ProverInputError::InvalidJumpdestSimulation,
|
||||||
));
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context)
|
if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
|
||||||
&& let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last()
|
&& let Some(next_jumpdest_address) = ctx_jumpdest_proofs.pop()
|
||||||
{
|
{
|
||||||
self.last_jumpdest_address = next_jumpdest_address;
|
|
||||||
Ok((next_jumpdest_address + 1).into())
|
Ok((next_jumpdest_address + 1).into())
|
||||||
} else {
|
} else {
|
||||||
self.jumpdest_addresses = None;
|
self.jumpdest_proofs = None;
|
||||||
Ok(U256::zero())
|
Ok(U256::zero())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the proof for the last jump address.
|
/// Returns the proof for the last jump address.
|
||||||
fn run_next_jumpdest_table_proof(&mut self) -> Result<U256, ProgramError> {
|
fn run_next_jumpdest_table_proof(&mut self) -> Result<U256, ProgramError> {
|
||||||
let code = (0..self.last_jumpdest_address)
|
let Some(jumpdest_proofs) = &mut self.jumpdest_proofs else {
|
||||||
.map(|i| {
|
return Err(ProgramError::ProverInputError(
|
||||||
u256_to_u8(self.memory.get(MemoryAddress {
|
ProverInputError::InvalidJumpdestSimulation,
|
||||||
context: self.registers.context,
|
));
|
||||||
segment: Segment::Code as usize,
|
};
|
||||||
virt: i,
|
if let Some(ctx_jumpdest_proofs) = jumpdest_proofs.get_mut(&self.registers.context)
|
||||||
}))
|
&& let Some(next_jumpdest_proof) = ctx_jumpdest_proofs.pop()
|
||||||
})
|
{
|
||||||
.collect::<Result<Vec<u8>, _>>()?;
|
Ok(next_jumpdest_proof.into())
|
||||||
|
} else {
|
||||||
// TODO: The proof searching algorithm is not very efficient. But luckily it doesn't seem
|
Err(ProgramError::ProverInputError(
|
||||||
// a problem as is done natively.
|
ProverInputError::InvalidJumpdestSimulation,
|
||||||
|
))
|
||||||
// Search the closest address to `last_jumpdest_address` for which none of
|
}
|
||||||
// the previous 32 bytes in the code (including opcodes and pushed bytes)
|
|
||||||
// are PUSHXX and the address is in its range.
|
|
||||||
|
|
||||||
const PUSH1_OPCODE: u8 = 0x60;
|
|
||||||
const PUSH32_OPCODE: u8 = 0x7f;
|
|
||||||
|
|
||||||
let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold(
|
|
||||||
0,
|
|
||||||
|acc, (pos, opcode)| {
|
|
||||||
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
|
|
||||||
code[prefix_start..pos].iter().enumerate().fold(
|
|
||||||
true,
|
|
||||||
|acc, (prefix_pos, &byte)| {
|
|
||||||
acc && (byte > PUSH32_OPCODE
|
|
||||||
|| (prefix_start + prefix_pos) as i32
|
|
||||||
+ (byte as i32 - PUSH1_OPCODE as i32)
|
|
||||||
+ 1
|
|
||||||
< pos as i32)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
if has_prefix {
|
|
||||||
pos - 32
|
|
||||||
} else {
|
|
||||||
acc
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
Ok(proof.into())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Field> GenerationState<F> {
|
impl<F: Field> GenerationState<F> {
|
||||||
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> {
|
fn generate_jumpdest_proofs(&mut self) -> Result<(), ProgramError> {
|
||||||
const JUMPDEST_OPCODE: u8 = 0x5b;
|
let checkpoint = self.checkpoint();
|
||||||
let mut state = self.soft_clone();
|
let memory = self.memory.clone();
|
||||||
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
|
|
||||||
context: self.registers.context,
|
let code = self.get_current_code()?;
|
||||||
segment: Segment::ContextMetadata as usize,
|
// We need to set the simulated jumpdest bits to one as otherwise
|
||||||
virt: ContextMetadata::CodeSize as usize,
|
// the simulation will fail.
|
||||||
}))?;
|
self.set_jumpdest_bits(&code);
|
||||||
// Generate the jumpdest table
|
|
||||||
|
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
|
||||||
|
let Some(jumpdest_table) = simulate_cpu_between_labels_and_get_user_jumps(
|
||||||
|
"validate_jumpdest_table_end",
|
||||||
|
"terminate_common",
|
||||||
|
self,
|
||||||
|
)?
|
||||||
|
else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return to the state before starting the simulation
|
||||||
|
self.rollback(checkpoint);
|
||||||
|
self.memory = memory;
|
||||||
|
|
||||||
|
// Find proofs for all context
|
||||||
|
self.set_proofs_and_jumpdests(jumpdest_table);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_proofs_and_jumpdests(
|
||||||
|
&mut self,
|
||||||
|
jumpdest_table: HashMap<usize, std::collections::BTreeSet<usize>>,
|
||||||
|
) {
|
||||||
|
self.jumpdest_proofs = Some(HashMap::from_iter(jumpdest_table.into_iter().map(
|
||||||
|
|(ctx, jumpdest_table)| {
|
||||||
|
let code = self.get_code(ctx).unwrap();
|
||||||
|
if let Some(&largest_address) = jumpdest_table.last() {
|
||||||
|
let proofs = get_proofs_and_jumpdests(&code, largest_address, jumpdest_table);
|
||||||
|
(ctx, proofs)
|
||||||
|
} else {
|
||||||
|
(ctx, vec![])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_current_code(&self) -> Result<Vec<u8>, ProgramError> {
|
||||||
|
self.get_code(self.registers.context)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_code(&self, context: usize) -> Result<Vec<u8>, ProgramError> {
|
||||||
|
let code_len = self.get_code_len()?;
|
||||||
let code = (0..code_len)
|
let code = (0..code_len)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
u256_to_u8(self.memory.get(MemoryAddress {
|
u256_to_u8(self.memory.get(MemoryAddress {
|
||||||
@ -346,16 +356,25 @@ impl<F: Field> GenerationState<F> {
|
|||||||
}))
|
}))
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<u8>, _>>()?;
|
.collect::<Result<Vec<u8>, _>>()?;
|
||||||
|
Ok(code)
|
||||||
|
}
|
||||||
|
|
||||||
// We need to set the simulated jumpdest bits to one as otherwise
|
fn get_code_len(&self) -> Result<usize, ProgramError> {
|
||||||
// the simulation will fail.
|
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
|
||||||
let mut jumpdest_table = Vec::with_capacity(code.len());
|
context: self.registers.context,
|
||||||
|
segment: Segment::ContextMetadata as usize,
|
||||||
|
virt: ContextMetadata::CodeSize as usize,
|
||||||
|
}))?;
|
||||||
|
Ok(code_len)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_jumpdest_bits<'a>(&mut self, code: &'a Vec<u8>) {
|
||||||
|
const JUMPDEST_OPCODE: u8 = 0x5b;
|
||||||
for (pos, opcode) in CodeIterator::new(&code) {
|
for (pos, opcode) in CodeIterator::new(&code) {
|
||||||
jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE));
|
|
||||||
if opcode == JUMPDEST_OPCODE {
|
if opcode == JUMPDEST_OPCODE {
|
||||||
state.memory.set(
|
self.memory.set(
|
||||||
MemoryAddress {
|
MemoryAddress {
|
||||||
context: state.registers.context,
|
context: self.registers.context,
|
||||||
segment: Segment::JumpdestBits as usize,
|
segment: Segment::JumpdestBits as usize,
|
||||||
virt: pos,
|
virt: pos,
|
||||||
},
|
},
|
||||||
@ -363,18 +382,50 @@ impl<F: Field> GenerationState<F> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
|
|
||||||
simulate_cpu_between_labels_and_get_user_jumps(
|
|
||||||
"validate_jumpdest_table_end",
|
|
||||||
"terminate_common",
|
|
||||||
&mut state,
|
|
||||||
)?;
|
|
||||||
self.jumpdest_addresses = state.jumpdest_addresses;
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// For each address in `jumpdest_table` it search a proof, that is the closest address
|
||||||
|
/// for which none of the previous 32 bytes in the code (including opcodes
|
||||||
|
/// and pushed bytes are PUSHXX and the address is in its range. It returns
|
||||||
|
/// a vector of even size containing proofs followed by their addresses
|
||||||
|
fn get_proofs_and_jumpdests<'a>(
|
||||||
|
code: &'a Vec<u8>,
|
||||||
|
largest_address: usize,
|
||||||
|
jumpdest_table: std::collections::BTreeSet<usize>,
|
||||||
|
) -> Vec<usize> {
|
||||||
|
const PUSH1_OPCODE: u8 = 0x60;
|
||||||
|
const PUSH32_OPCODE: u8 = 0x7f;
|
||||||
|
let (proofs, _) = CodeIterator::until(&code, largest_address + 1).fold(
|
||||||
|
(vec![], 0),
|
||||||
|
|(mut proofs, acc), (pos, opcode)| {
|
||||||
|
let has_prefix = if let Some(prefix_start) = pos.checked_sub(32) {
|
||||||
|
code[prefix_start..pos]
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.fold(true, |acc, (prefix_pos, &byte)| {
|
||||||
|
acc && (byte > PUSH32_OPCODE
|
||||||
|
|| (prefix_start + prefix_pos) as i32
|
||||||
|
+ (byte as i32 - PUSH1_OPCODE as i32)
|
||||||
|
+ 1
|
||||||
|
< pos as i32)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
let acc = if has_prefix { pos - 32 } else { acc };
|
||||||
|
if jumpdest_table.contains(&pos) {
|
||||||
|
// Push the proof
|
||||||
|
proofs.push(acc);
|
||||||
|
// Push the address
|
||||||
|
proofs.push(pos);
|
||||||
|
}
|
||||||
|
(proofs, acc)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
proofs
|
||||||
|
}
|
||||||
|
|
||||||
struct CodeIterator<'a> {
|
struct CodeIterator<'a> {
|
||||||
code: &'a [u8],
|
code: &'a [u8],
|
||||||
pos: usize,
|
pos: usize,
|
||||||
|
|||||||
@ -51,8 +51,7 @@ pub(crate) struct GenerationState<F: Field> {
|
|||||||
/// Pointers, within the `TrieData` segment, of the three MPTs.
|
/// Pointers, within the `TrieData` segment, of the three MPTs.
|
||||||
pub(crate) trie_root_ptrs: TrieRootPtrs,
|
pub(crate) trie_root_ptrs: TrieRootPtrs,
|
||||||
|
|
||||||
pub(crate) last_jumpdest_address: usize,
|
pub(crate) jumpdest_proofs: Option<HashMap<usize, Vec<usize>>>,
|
||||||
pub(crate) jumpdest_addresses: Option<HashMap<usize, BTreeSet<usize>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F: Field> GenerationState<F> {
|
impl<F: Field> GenerationState<F> {
|
||||||
@ -94,8 +93,7 @@ impl<F: Field> GenerationState<F> {
|
|||||||
txn_root_ptr: 0,
|
txn_root_ptr: 0,
|
||||||
receipt_root_ptr: 0,
|
receipt_root_ptr: 0,
|
||||||
},
|
},
|
||||||
last_jumpdest_address: 0,
|
jumpdest_proofs: None,
|
||||||
jumpdest_addresses: None,
|
|
||||||
};
|
};
|
||||||
let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries);
|
let trie_root_ptrs = state.preinitialize_mpts(&inputs.tries);
|
||||||
|
|
||||||
@ -189,8 +187,7 @@ impl<F: Field> GenerationState<F> {
|
|||||||
txn_root_ptr: 0,
|
txn_root_ptr: 0,
|
||||||
receipt_root_ptr: 0,
|
receipt_root_ptr: 0,
|
||||||
},
|
},
|
||||||
last_jumpdest_address: 0,
|
jumpdest_proofs: None,
|
||||||
jumpdest_addresses: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user