mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-08 08:43:06 +00:00
Eliminate nested simulations
This commit is contained in:
parent
2c5347c45f
commit
81f13f3f8a
@ -114,15 +114,12 @@ global is_jumpdest:
|
||||
MLOAD_GENERAL
|
||||
// stack: opcode, jumpdest, ctx, proof_prefix_addr, retdest
|
||||
|
||||
// Slightly more efficient than `%eq_const(0x5b) ISZERO`
|
||||
PUSH 0x5b
|
||||
SUB
|
||||
%jumpi(panic)
|
||||
%assert_eq_const(0x5b)
|
||||
|
||||
//stack: jumpdest, ctx, proof_prefix_addr, retdest
|
||||
SWAP2 DUP1
|
||||
// stack: proof_prefix_addr, proof_prefix_addr, ctx, jumpdest
|
||||
IS_ZERO
|
||||
ISZERO
|
||||
%jumpi(verify_path)
|
||||
// stack: proof_prefix_addr, ctx, jumpdest, retdest
|
||||
// If we are here we need to check that the next 32 bytes are less
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
//! An EVM interpreter for testing and debugging purposes.
|
||||
|
||||
use core::cmp::Ordering;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::ops::Range;
|
||||
|
||||
use anyhow::{anyhow, bail, ensure};
|
||||
@ -293,17 +293,19 @@ impl<'a> Interpreter<'a> {
|
||||
pub(crate) fn set_jumpdest_bits(&mut self, context: usize, jumpdest_bits: Vec<bool>) {
|
||||
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
|
||||
.content = jumpdest_bits.iter().map(|&x| u256_from_bool(x)).collect();
|
||||
self.generation_state.jumpdest_addresses = Some(
|
||||
jumpdest_bits
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|&(_, x)| x)
|
||||
.map(|(i, _)| i)
|
||||
.collect(),
|
||||
)
|
||||
self.generation_state.jumpdest_addresses = Some(HashMap::from([(
|
||||
context,
|
||||
BTreeSet::from_iter(
|
||||
jumpdest_bits
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter(|&(_, x)| x)
|
||||
.map(|(i, _)| i),
|
||||
),
|
||||
)]));
|
||||
}
|
||||
|
||||
const fn incr(&mut self, n: usize) {
|
||||
pub(crate) fn incr(&mut self, n: usize) {
|
||||
self.generation_state.registers.program_counter += n;
|
||||
}
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use eth_trie_utils::partial_trie::{HashedPartialTrie, PartialTrie};
|
||||
@ -27,6 +27,7 @@ use crate::generation::state::GenerationState;
|
||||
use crate::memory::segments::Segment;
|
||||
use crate::proof::{BlockHashes, BlockMetadata, ExtraBlockData, PublicValues, TrieRoots};
|
||||
use crate::util::{h2u, u256_to_u8, u256_to_usize};
|
||||
use crate::witness::errors::{ProgramError, ProverInputError};
|
||||
use crate::witness::memory::{MemoryAddress, MemoryChannel};
|
||||
use crate::witness::transition::transition;
|
||||
|
||||
@ -287,56 +288,68 @@ fn simulate_cpu_between_labels_and_get_user_jumps<F: Field>(
|
||||
initial_label: &str,
|
||||
final_label: &str,
|
||||
state: &mut GenerationState<F>,
|
||||
) -> anyhow::Result<Vec<usize>> {
|
||||
let halt_pc = KERNEL.global_labels[final_label];
|
||||
let mut jumpdest_addresses = HashSet::new();
|
||||
state.registers.program_counter = KERNEL.global_labels[initial_label];
|
||||
let context = state.registers.context;
|
||||
) -> Result<(), ProgramError> {
|
||||
if let Some(_) = state.jumpdest_addresses {
|
||||
Ok(())
|
||||
} else {
|
||||
const JUMP_OPCODE: u8 = 0x56;
|
||||
const JUMPI_OPCODE: u8 = 0x57;
|
||||
|
||||
log::debug!("Simulating CPU for jumpdest analysis.");
|
||||
let halt_pc = KERNEL.global_labels[final_label];
|
||||
let mut jumpdest_addresses: HashMap<_, BTreeSet<usize>> = HashMap::new();
|
||||
|
||||
loop {
|
||||
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
|
||||
state.registers.program_counter = KERNEL.global_labels["validate_jumpdest_table_end"]
|
||||
}
|
||||
let pc = state.registers.program_counter;
|
||||
let halt = state.registers.is_kernel && pc == halt_pc && state.registers.context == context;
|
||||
let opcode = u256_to_u8(state.memory.get(MemoryAddress {
|
||||
context: state.registers.context,
|
||||
segment: Segment::Code as usize,
|
||||
virt: state.registers.program_counter,
|
||||
}))
|
||||
.map_err(|_| anyhow::Error::msg("Invalid opcode."))?;
|
||||
let cond = if let Ok(cond) = stack_peek(state, 1) {
|
||||
cond != U256::zero()
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if !state.registers.is_kernel
|
||||
&& (opcode == get_opcode("JUMP") || (opcode == get_opcode("JUMPI") && cond))
|
||||
{
|
||||
// TODO: hotfix for avoiding deeper calls to abort
|
||||
let jumpdest = u256_to_usize(state.registers.stack_top)
|
||||
.map_err(|_| anyhow!("Not a valid jump destination"))?;
|
||||
state.memory.set(
|
||||
MemoryAddress {
|
||||
context: state.registers.context,
|
||||
segment: Segment::JumpdestBits as usize,
|
||||
virt: jumpdest,
|
||||
},
|
||||
U256::one(),
|
||||
);
|
||||
if (state.registers.context == context) {
|
||||
jumpdest_addresses.insert(jumpdest);
|
||||
state.registers.program_counter = KERNEL.global_labels[initial_label];
|
||||
let initial_context = state.registers.context;
|
||||
|
||||
log::debug!("Simulating CPU for jumpdest analysis.");
|
||||
|
||||
loop {
|
||||
if state.registers.program_counter == KERNEL.global_labels["validate_jumpdest_table"] {
|
||||
state.registers.program_counter =
|
||||
KERNEL.global_labels["validate_jumpdest_table_end"]
|
||||
}
|
||||
let pc = state.registers.program_counter;
|
||||
let context = state.registers.context;
|
||||
let halt = state.registers.is_kernel
|
||||
&& pc == halt_pc
|
||||
&& state.registers.context == initial_context;
|
||||
let opcode = u256_to_u8(state.memory.get(MemoryAddress {
|
||||
context,
|
||||
segment: Segment::Code as usize,
|
||||
virt: state.registers.program_counter,
|
||||
}))?;
|
||||
let cond = if let Ok(cond) = stack_peek(state, 1) {
|
||||
cond != U256::zero()
|
||||
} else {
|
||||
false
|
||||
};
|
||||
if !state.registers.is_kernel
|
||||
&& (opcode == JUMP_OPCODE || (opcode == JUMPI_OPCODE && cond))
|
||||
{
|
||||
// Avoid deeper calls to abort
|
||||
let jumpdest = u256_to_usize(state.registers.stack_top)?;
|
||||
state.memory.set(
|
||||
MemoryAddress {
|
||||
context,
|
||||
segment: Segment::JumpdestBits as usize,
|
||||
virt: jumpdest,
|
||||
},
|
||||
U256::one(),
|
||||
);
|
||||
if let Some(ctx_addresses) = jumpdest_addresses.get_mut(&context) {
|
||||
ctx_addresses.insert(jumpdest);
|
||||
} else {
|
||||
jumpdest_addresses.insert(context, BTreeSet::from([jumpdest]));
|
||||
}
|
||||
}
|
||||
if halt {
|
||||
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
|
||||
state.jumpdest_addresses = Some(jumpdest_addresses);
|
||||
return Ok(());
|
||||
}
|
||||
transition(state).map_err(|_| {
|
||||
ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation)
|
||||
})?;
|
||||
}
|
||||
if halt {
|
||||
log::debug!("Simulated CPU halted after {} cycles", state.traces.clock());
|
||||
let mut jumpdest_addresses: Vec<usize> = jumpdest_addresses.into_iter().collect();
|
||||
jumpdest_addresses.sort();
|
||||
return Ok(jumpdest_addresses);
|
||||
}
|
||||
|
||||
transition(state)?;
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,7 +16,6 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::cpu::kernel::aggregator::KERNEL;
|
||||
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
|
||||
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
|
||||
use crate::cpu::kernel::opcodes::{get_opcode, get_push_opcode};
|
||||
use crate::extension_tower::{FieldExt, Fp12, BLS381, BN254};
|
||||
use crate::generation::prover_input::EvmField::{
|
||||
Bls381Base, Bls381Scalar, Bn254Base, Bn254Scalar, Secp256k1Base, Secp256k1Scalar,
|
||||
@ -250,8 +249,9 @@ impl<F: Field> GenerationState<F> {
|
||||
}
|
||||
/// Return the next used jump addres
|
||||
fn run_next_jumpdest_table_address(&mut self) -> Result<U256, ProgramError> {
|
||||
let context = self.registers.context;
|
||||
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
|
||||
context: self.registers.context,
|
||||
context,
|
||||
segment: Segment::ContextMetadata as usize,
|
||||
virt: ContextMetadata::CodeSize as usize,
|
||||
}))?;
|
||||
@ -260,14 +260,14 @@ impl<F: Field> GenerationState<F> {
|
||||
self.generate_jumpdest_table()?;
|
||||
}
|
||||
|
||||
let Some(jumpdest_table) = &mut self.jumpdest_addresses else {
|
||||
// TODO: Add another error
|
||||
let Some(jumpdest_tables) = &mut self.jumpdest_addresses else {
|
||||
return Err(ProgramError::ProverInputError(ProverInputError::InvalidJumpdestSimulation));
|
||||
};
|
||||
|
||||
if let Some(next_jumpdest_address) = jumpdest_table.pop() {
|
||||
self.last_jumpdest_address = next_jumpdest_address;
|
||||
Ok((next_jumpdest_address + 1).into())
|
||||
if let Some(ctx_jumpdest_table) = jumpdest_tables.get_mut(&context) && let Some(next_jumpdest_address) = ctx_jumpdest_table.pop_last()
|
||||
{
|
||||
self.last_jumpdest_address = next_jumpdest_address;
|
||||
Ok((next_jumpdest_address + 1).into())
|
||||
} else {
|
||||
self.jumpdest_addresses = None;
|
||||
Ok(U256::zero())
|
||||
@ -293,6 +293,9 @@ impl<F: Field> GenerationState<F> {
|
||||
// the previous 32 bytes in the code (including opcodes and pushed bytes)
|
||||
// are PUSHXX and the address is in its range.
|
||||
|
||||
const PUSH1_OPCODE: u8 = 0x60;
|
||||
const PUSH32_OPCODE: u8 = 0x7f;
|
||||
|
||||
let proof = CodeIterator::until(&code, self.last_jumpdest_address + 1).fold(
|
||||
0,
|
||||
|acc, (pos, opcode)| {
|
||||
@ -300,9 +303,9 @@ impl<F: Field> GenerationState<F> {
|
||||
code[prefix_start..pos].iter().enumerate().fold(
|
||||
true,
|
||||
|acc, (prefix_pos, &byte)| {
|
||||
acc && (byte > get_push_opcode(32)
|
||||
acc && (byte > PUSH32_OPCODE
|
||||
|| (prefix_start + prefix_pos) as i32
|
||||
+ (byte as i32 - get_push_opcode(1) as i32)
|
||||
+ (byte as i32 - PUSH1_OPCODE as i32)
|
||||
+ 1
|
||||
< pos as i32)
|
||||
},
|
||||
@ -323,6 +326,7 @@ impl<F: Field> GenerationState<F> {
|
||||
|
||||
impl<F: Field> GenerationState<F> {
|
||||
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> {
|
||||
const JUMPDEST_OPCODE: u8 = 0x5b;
|
||||
let mut state = self.soft_clone();
|
||||
let code_len = u256_to_usize(self.memory.get(MemoryAddress {
|
||||
context: self.registers.context,
|
||||
@ -344,8 +348,8 @@ impl<F: Field> GenerationState<F> {
|
||||
// the simulation will fail.
|
||||
let mut jumpdest_table = Vec::with_capacity(code.len());
|
||||
for (pos, opcode) in CodeIterator::new(&code) {
|
||||
jumpdest_table.push((pos, opcode == get_opcode("JUMPDEST")));
|
||||
if opcode == get_opcode("JUMPDEST") {
|
||||
jumpdest_table.push((pos, opcode == JUMPDEST_OPCODE));
|
||||
if opcode == JUMPDEST_OPCODE {
|
||||
state.memory.set(
|
||||
MemoryAddress {
|
||||
context: state.registers.context,
|
||||
@ -358,32 +362,31 @@ impl<F: Field> GenerationState<F> {
|
||||
}
|
||||
|
||||
// Simulate the user's code and (unnecessarily) part of the kernel code, skipping the validate table call
|
||||
self.jumpdest_addresses = simulate_cpu_between_labels_and_get_user_jumps(
|
||||
simulate_cpu_between_labels_and_get_user_jumps(
|
||||
"validate_jumpdest_table_end",
|
||||
"terminate_common",
|
||||
&mut state,
|
||||
)
|
||||
.ok();
|
||||
|
||||
)?;
|
||||
self.jumpdest_addresses = state.jumpdest_addresses;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct CodeIterator<'a> {
|
||||
code: &'a Vec<u8>,
|
||||
code: &'a [u8],
|
||||
pos: usize,
|
||||
end: usize,
|
||||
}
|
||||
|
||||
impl<'a> CodeIterator<'a> {
|
||||
fn new(code: &'a Vec<u8>) -> Self {
|
||||
fn new(code: &'a [u8]) -> Self {
|
||||
CodeIterator {
|
||||
end: code.len(),
|
||||
code,
|
||||
pos: 0,
|
||||
}
|
||||
}
|
||||
fn until(code: &'a Vec<u8>, end: usize) -> Self {
|
||||
fn until(code: &'a [u8], end: usize) -> Self {
|
||||
CodeIterator {
|
||||
end: std::cmp::min(code.len(), end),
|
||||
code,
|
||||
@ -396,14 +399,16 @@ impl<'a> Iterator for CodeIterator<'a> {
|
||||
type Item = (usize, u8);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
const PUSH1_OPCODE: u8 = 0x60;
|
||||
const PUSH32_OPCODE: u8 = 0x70;
|
||||
let CodeIterator { code, pos, end } = self;
|
||||
if *pos >= *end {
|
||||
return None;
|
||||
}
|
||||
let opcode = code[*pos];
|
||||
let old_pos = *pos;
|
||||
*pos += if opcode >= get_push_opcode(1) && opcode <= get_push_opcode(32) {
|
||||
(opcode - get_push_opcode(1) + 2).into()
|
||||
*pos += if opcode >= PUSH1_OPCODE && opcode <= PUSH32_OPCODE {
|
||||
(opcode - PUSH1_OPCODE + 2).into()
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
|
||||
use ethereum_types::{Address, BigEndianHash, H160, H256, U256};
|
||||
use keccak_hash::keccak;
|
||||
@ -52,7 +52,7 @@ pub(crate) struct GenerationState<F: Field> {
|
||||
pub(crate) trie_root_ptrs: TrieRootPtrs,
|
||||
|
||||
pub(crate) last_jumpdest_address: usize,
|
||||
pub(crate) jumpdest_addresses: Option<Vec<usize>>,
|
||||
pub(crate) jumpdest_addresses: Option<HashMap<usize, BTreeSet<usize>>>,
|
||||
}
|
||||
|
||||
impl<F: Field> GenerationState<F> {
|
||||
|
||||
@ -395,12 +395,7 @@ fn try_perform_instruction<F: Field>(
|
||||
if state.registers.is_kernel {
|
||||
log_kernel_instruction(state, op);
|
||||
} else {
|
||||
log::debug!(
|
||||
"User instruction: {:?} ctx = {:?} stack = {:?}",
|
||||
op,
|
||||
state.registers.context,
|
||||
state.stack()
|
||||
);
|
||||
log::debug!("User instruction: {:?}", op);
|
||||
}
|
||||
|
||||
fill_op_flag(op, &mut row);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user