Interpreter fixes

This commit is contained in:
Daniel Lubarov 2022-12-02 19:18:37 -08:00
parent 82d0f08193
commit b96c22a4f4
7 changed files with 113 additions and 76 deletions

View File

@ -73,4 +73,4 @@ recursion_return:
jump
global sys_exp:
PANIC
PANIC // TODO: Implement.

View File

@ -16,6 +16,7 @@ use crate::generation::state::GenerationState;
use crate::generation::GenerationInputs;
use crate::memory::segments::Segment;
use crate::witness::memory::{MemoryContextState, MemorySegmentState, MemoryState};
use crate::witness::util::stack_peek;
type F = GoldilocksField;
@ -23,13 +24,6 @@ type F = GoldilocksField;
const DEFAULT_HALT_OFFSET: usize = 0xdeadbeef;
impl MemoryState {
fn with_code_and_stack(code: &[u8], stack: Vec<U256>) -> Self {
let mut mem = Self::new(code);
mem.contexts[0].segments[Segment::Stack as usize].content = stack;
mem
}
fn mload_general(&self, context: usize, segment: Segment, offset: usize) -> U256 {
let value = self.contexts[context].segments[segment as usize].get(offset);
assert!(
@ -53,9 +47,7 @@ impl MemoryState {
pub struct Interpreter<'a> {
kernel_mode: bool,
jumpdests: Vec<usize>,
pub(crate) offset: usize,
pub(crate) context: usize,
pub(crate) memory: MemoryState,
pub(crate) generation_state: GenerationState<F>,
prover_inputs_map: &'a HashMap<usize, ProverInputFn>,
pub(crate) halt_offsets: Vec<usize>,
@ -103,11 +95,9 @@ impl<'a> Interpreter<'a> {
initial_stack: Vec<U256>,
prover_inputs: &'a HashMap<usize, ProverInputFn>,
) -> Self {
Self {
let mut result = Self {
kernel_mode: true,
jumpdests: find_jumpdests(code),
offset: initial_offset,
memory: MemoryState::with_code_and_stack(code, initial_stack),
generation_state: GenerationState::new(GenerationInputs::default(), code),
prover_inputs_map: prover_inputs,
context: 0,
@ -115,7 +105,11 @@ impl<'a> Interpreter<'a> {
debug_offsets: vec![],
running: false,
opcode_count: [0; 0x100],
}
};
result.generation_state.registers.program_counter = initial_offset;
result.generation_state.registers.stack_len = initial_stack.len();
*result.stack_mut() = initial_stack;
result
}
pub(crate) fn run(&mut self) -> anyhow::Result<()> {
@ -133,47 +127,51 @@ impl<'a> Interpreter<'a> {
}
fn code(&self) -> &MemorySegmentState {
&self.memory.contexts[self.context].segments[Segment::Code as usize]
&self.generation_state.memory.contexts[self.context].segments[Segment::Code as usize]
}
fn code_slice(&self, n: usize) -> Vec<u8> {
self.code().content[self.offset..self.offset + n]
let pc = self.generation_state.registers.program_counter;
self.code().content[pc..pc + n]
.iter()
.map(|u256| u256.byte(0))
.collect::<Vec<_>>()
}
pub(crate) fn get_txn_field(&self, field: NormalizedTxnField) -> U256 {
self.memory.contexts[0].segments[Segment::TxnFields as usize].get(field as usize)
self.generation_state.memory.contexts[0].segments[Segment::TxnFields as usize]
.get(field as usize)
}
pub(crate) fn set_txn_field(&mut self, field: NormalizedTxnField, value: U256) {
self.memory.contexts[0].segments[Segment::TxnFields as usize].set(field as usize, value);
self.generation_state.memory.contexts[0].segments[Segment::TxnFields as usize]
.set(field as usize, value);
}
pub(crate) fn get_txn_data(&self) -> &[U256] {
&self.memory.contexts[0].segments[Segment::TxnData as usize].content
&self.generation_state.memory.contexts[0].segments[Segment::TxnData as usize].content
}
pub(crate) fn get_global_metadata_field(&self, field: GlobalMetadata) -> U256 {
self.memory.contexts[0].segments[Segment::GlobalMetadata as usize].get(field as usize)
self.generation_state.memory.contexts[0].segments[Segment::GlobalMetadata as usize]
.get(field as usize)
}
pub(crate) fn set_global_metadata_field(&mut self, field: GlobalMetadata, value: U256) {
self.memory.contexts[0].segments[Segment::GlobalMetadata as usize]
self.generation_state.memory.contexts[0].segments[Segment::GlobalMetadata as usize]
.set(field as usize, value)
}
pub(crate) fn get_trie_data(&self) -> &[U256] {
&self.memory.contexts[0].segments[Segment::TrieData as usize].content
&self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content
}
pub(crate) fn get_trie_data_mut(&mut self) -> &mut Vec<U256> {
&mut self.memory.contexts[0].segments[Segment::TrieData as usize].content
&mut self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content
}
pub(crate) fn get_rlp_memory(&self) -> Vec<u8> {
self.memory.contexts[0].segments[Segment::RlpRaw as usize]
self.generation_state.memory.contexts[0].segments[Segment::RlpRaw as usize]
.content
.iter()
.map(|x| x.as_u32() as u8)
@ -181,21 +179,24 @@ impl<'a> Interpreter<'a> {
}
pub(crate) fn set_rlp_memory(&mut self, rlp: Vec<u8>) {
self.memory.contexts[0].segments[Segment::RlpRaw as usize].content =
self.generation_state.memory.contexts[0].segments[Segment::RlpRaw as usize].content =
rlp.into_iter().map(U256::from).collect();
}
pub(crate) fn set_code(&mut self, context: usize, code: Vec<u8>) {
assert_ne!(context, 0, "Can't modify kernel code.");
while self.memory.contexts.len() <= context {
self.memory.contexts.push(MemoryContextState::default());
while self.generation_state.memory.contexts.len() <= context {
self.generation_state
.memory
.contexts
.push(MemoryContextState::default());
}
self.memory.contexts[context].segments[Segment::Code as usize].content =
self.generation_state.memory.contexts[context].segments[Segment::Code as usize].content =
code.into_iter().map(U256::from).collect();
}
pub(crate) fn get_jumpdest_bits(&self, context: usize) -> Vec<bool> {
self.memory.contexts[context].segments[Segment::JumpdestBits as usize]
self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize]
.content
.iter()
.map(|x| x.bit(0))
@ -203,19 +204,22 @@ impl<'a> Interpreter<'a> {
}
fn incr(&mut self, n: usize) {
self.offset += n;
self.generation_state.registers.program_counter += n;
}
pub(crate) fn stack(&self) -> &[U256] {
&self.memory.contexts[self.context].segments[Segment::Stack as usize].content
&self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize]
.content
}
fn stack_mut(&mut self) -> &mut Vec<U256> {
&mut self.memory.contexts[self.context].segments[Segment::Stack as usize].content
&mut self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize]
.content
}
pub(crate) fn push(&mut self, x: U256) {
self.stack_mut().push(x);
self.generation_state.registers.stack_len += 1;
}
fn push_bool(&mut self, x: bool) {
@ -223,11 +227,15 @@ impl<'a> Interpreter<'a> {
}
pub(crate) fn pop(&mut self) -> U256 {
self.stack_mut().pop().expect("Pop on empty stack.")
let result = stack_peek(&self.generation_state, 0);
self.generation_state.registers.stack_len -= 1;
let new_len = self.stack_len();
self.stack_mut().truncate(new_len);
result.expect("Empty stack")
}
fn run_opcode(&mut self) -> anyhow::Result<()> {
let opcode = self.code().get(self.offset).byte(0);
let opcode = self.code().get(self.generation_state.registers.program_counter).byte(0);
self.opcode_count[opcode as usize] += 1;
self.incr(1);
match opcode {
@ -327,7 +335,7 @@ impl<'a> Interpreter<'a> {
_ => bail!("Unrecognized opcode {}.", opcode),
};
if self.debug_offsets.contains(&self.offset) {
if self.debug_offsets.contains(&self.generation_state.registers.program_counter) {
println!("At {}, stack={:?}", self.offset_name(), self.stack());
} else if let Some(label) = self.offset_label() {
println!("At {label}");
@ -337,11 +345,11 @@ impl<'a> Interpreter<'a> {
}
fn offset_name(&self) -> String {
KERNEL.offset_name(self.offset)
KERNEL.offset_name(self.generation_state.registers.program_counter)
}
fn offset_label(&self) -> Option<String> {
KERNEL.offset_label(self.offset)
KERNEL.offset_label(self.generation_state.registers.program_counter)
}
fn run_stop(&mut self) {
@ -503,7 +511,8 @@ impl<'a> Interpreter<'a> {
let size = self.pop().as_usize();
let bytes = (offset..offset + size)
.map(|i| {
self.memory
self.generation_state
.memory
.mload_general(self.context, Segment::MainMemory, i)
.byte(0)
})
@ -520,7 +529,12 @@ impl<'a> Interpreter<'a> {
let offset = self.pop().as_usize();
let size = self.pop().as_usize();
let bytes = (offset..offset + size)
.map(|i| self.memory.mload_general(context, segment, i).byte(0))
.map(|i| {
self.generation_state
.memory
.mload_general(context, segment, i)
.byte(0)
})
.collect::<Vec<_>>();
println!("Hashing {:?}", &bytes);
let hash = keccak(bytes);
@ -529,7 +543,8 @@ impl<'a> Interpreter<'a> {
fn run_callvalue(&mut self) {
self.push(
self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize]
self.generation_state.memory.contexts[self.context].segments
[Segment::ContextMetadata as usize]
.get(ContextMetadata::CallValue as usize),
)
}
@ -539,7 +554,8 @@ impl<'a> Interpreter<'a> {
let value = U256::from_big_endian(
&(0..32)
.map(|i| {
self.memory
self.generation_state
.memory
.mload_general(self.context, Segment::Calldata, offset + i)
.byte(0)
})
@ -550,7 +566,8 @@ impl<'a> Interpreter<'a> {
fn run_calldatasize(&mut self) {
self.push(
self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize]
self.generation_state.memory.contexts[self.context].segments
[Segment::ContextMetadata as usize]
.get(ContextMetadata::CalldataSize as usize),
)
}
@ -560,10 +577,12 @@ impl<'a> Interpreter<'a> {
let offset = self.pop().as_usize();
let size = self.pop().as_usize();
for i in 0..size {
let calldata_byte =
self.memory
.mload_general(self.context, Segment::Calldata, offset + i);
self.memory.mstore_general(
let calldata_byte = self.generation_state.memory.mload_general(
self.context,
Segment::Calldata,
offset + i,
);
self.generation_state.memory.mstore_general(
self.context,
Segment::MainMemory,
dest_offset + i,
@ -575,7 +594,7 @@ impl<'a> Interpreter<'a> {
fn run_prover_input(&mut self) -> anyhow::Result<()> {
let prover_input_fn = self
.prover_inputs_map
.get(&(self.offset - 1))
.get(&(self.generation_state.registers.program_counter - 1))
.ok_or_else(|| anyhow!("Offset not in prover inputs."))?;
let output = self.generation_state.prover_input(prover_input_fn);
self.push(output);
@ -591,7 +610,8 @@ impl<'a> Interpreter<'a> {
let value = U256::from_big_endian(
&(0..32)
.map(|i| {
self.memory
self.generation_state
.memory
.mload_general(self.context, Segment::MainMemory, offset + i)
.byte(0)
})
@ -606,15 +626,19 @@ impl<'a> Interpreter<'a> {
let mut bytes = [0; 32];
value.to_big_endian(&mut bytes);
for (i, byte) in (0..32).zip(bytes) {
self.memory
.mstore_general(self.context, Segment::MainMemory, offset + i, byte.into());
self.generation_state.memory.mstore_general(
self.context,
Segment::MainMemory,
offset + i,
byte.into(),
);
}
}
fn run_mstore8(&mut self) {
let offset = self.pop().as_usize();
let value = self.pop();
self.memory.mstore_general(
self.generation_state.memory.mstore_general(
self.context,
Segment::MainMemory,
offset,
@ -636,12 +660,13 @@ impl<'a> Interpreter<'a> {
}
fn run_pc(&mut self) {
self.push((self.offset - 1).into());
self.push((self.generation_state.registers.program_counter - 1).into());
}
fn run_msize(&mut self) {
self.push(
self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize]
self.generation_state.memory.contexts[self.context].segments
[Segment::ContextMetadata as usize]
.get(ContextMetadata::MSize as usize),
)
}
@ -656,7 +681,7 @@ impl<'a> Interpreter<'a> {
panic!("Destination is not a JUMPDEST.");
}
self.offset = offset;
self.generation_state.registers.program_counter = offset;
if self.halt_offsets.contains(&offset) {
self.running = false;
@ -670,11 +695,11 @@ impl<'a> Interpreter<'a> {
}
fn run_dup(&mut self, n: u8) {
self.push(self.stack()[self.stack().len() - n as usize]);
self.push(self.stack()[self.stack_len() - n as usize]);
}
fn run_swap(&mut self, n: u8) -> anyhow::Result<()> {
let len = self.stack().len();
let len = self.stack_len();
ensure!(len > n as usize);
self.stack_mut().swap(len - 1, len - n as usize - 1);
Ok(())
@ -693,7 +718,10 @@ impl<'a> Interpreter<'a> {
let context = self.pop().as_usize();
let segment = Segment::all()[self.pop().as_usize()];
let offset = self.pop().as_usize();
let value = self.memory.mload_general(context, segment, offset);
let value = self
.generation_state
.memory
.mload_general(context, segment, offset);
assert!(value.bits() <= segment.bit_range());
self.push(value);
}
@ -710,7 +738,13 @@ impl<'a> Interpreter<'a> {
segment,
segment.bit_range()
);
self.memory.mstore_general(context, segment, offset, value);
self.generation_state
.memory
.mstore_general(context, segment, offset, value);
}
fn stack_len(&self) -> usize {
self.generation_state.registers.stack_len
}
}
@ -932,11 +966,13 @@ mod tests {
let run = run(&code, 0, vec![], &pis)?;
assert_eq!(run.stack(), &[0xff.into(), 0xff00.into()]);
assert_eq!(
run.memory.contexts[0].segments[Segment::MainMemory as usize].get(0x27),
run.generation_state.memory.contexts[0].segments[Segment::MainMemory as usize]
.get(0x27),
0x42.into()
);
assert_eq!(
run.memory.contexts[0].segments[Segment::MainMemory as usize].get(0x1f),
run.generation_state.memory.contexts[0].segments[Segment::MainMemory as usize]
.get(0x1f),
0xff.into()
);
Ok(())

View File

@ -42,7 +42,7 @@ fn prepare_interpreter(
let mut state_trie: PartialTrie = Default::default();
let trie_inputs = Default::default();
interpreter.offset = load_all_mpts;
interpreter.generation_state.registers.program_counter = load_all_mpts;
interpreter.push(0xDEADBEEFu32.into());
interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs);
@ -53,7 +53,7 @@ fn prepare_interpreter(
keccak(address.to_fixed_bytes()).as_bytes(),
));
// Next, execute mpt_insert_state_trie.
interpreter.offset = mpt_insert_state_trie;
interpreter.generation_state.registers.program_counter = mpt_insert_state_trie;
let trie_data = interpreter.get_trie_data_mut();
if trie_data.is_empty() {
// In the assembly we skip over 0, knowing trie_data[0] = 0 by default.
@ -83,7 +83,7 @@ fn prepare_interpreter(
);
// Now, execute mpt_hash_state_trie.
interpreter.offset = mpt_hash_state_trie;
interpreter.generation_state.registers.program_counter = mpt_hash_state_trie;
interpreter.push(0xDEADBEEFu32.into());
interpreter.run()?;
@ -115,7 +115,7 @@ fn test_extcodesize() -> Result<()> {
let extcodesize = KERNEL.global_labels["extcodesize"];
// Test `extcodesize`
interpreter.offset = extcodesize;
interpreter.generation_state.registers.program_counter = extcodesize;
interpreter.pop();
assert!(interpreter.stack().is_empty());
interpreter.push(0xDEADBEEFu32.into());
@ -144,9 +144,10 @@ fn test_extcodecopy() -> Result<()> {
// Put random data in main memory and the `KernelAccountCode` segment for realism.
let mut rng = thread_rng();
for i in 0..2000 {
interpreter.memory.contexts[interpreter.context].segments[Segment::MainMemory as usize]
interpreter.generation_state.memory.contexts[interpreter.context].segments
[Segment::MainMemory as usize]
.set(i, U256::from(rng.gen::<u8>()));
interpreter.memory.contexts[interpreter.context].segments
interpreter.generation_state.memory.contexts[interpreter.context].segments
[Segment::KernelAccountCode as usize]
.set(i, U256::from(rng.gen::<u8>()));
}
@ -157,7 +158,7 @@ fn test_extcodecopy() -> Result<()> {
let size = rng.gen_range(0..1500);
// Test `extcodecopy`
interpreter.offset = extcodecopy;
interpreter.generation_state.registers.program_counter = extcodecopy;
interpreter.pop();
assert!(interpreter.stack().is_empty());
interpreter.push(0xDEADBEEFu32.into());
@ -172,7 +173,7 @@ fn test_extcodecopy() -> Result<()> {
assert!(interpreter.stack().is_empty());
// Check that the code was correctly copied to memory.
for i in 0..size {
let memory = interpreter.memory.contexts[interpreter.context].segments
let memory = interpreter.generation_state.memory.contexts[interpreter.context].segments
[Segment::MainMemory as usize]
.get(dest_offset + i);
assert_eq!(

View File

@ -33,7 +33,7 @@ fn prepare_interpreter(
let mut state_trie: PartialTrie = Default::default();
let trie_inputs = Default::default();
interpreter.offset = load_all_mpts;
interpreter.generation_state.registers.program_counter = load_all_mpts;
interpreter.push(0xDEADBEEFu32.into());
interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs);
@ -44,7 +44,7 @@ fn prepare_interpreter(
keccak(address.to_fixed_bytes()).as_bytes(),
));
// Next, execute mpt_insert_state_trie.
interpreter.offset = mpt_insert_state_trie;
interpreter.generation_state.registers.program_counter = mpt_insert_state_trie;
let trie_data = interpreter.get_trie_data_mut();
if trie_data.is_empty() {
// In the assembly we skip over 0, knowing trie_data[0] = 0 by default.
@ -74,7 +74,7 @@ fn prepare_interpreter(
);
// Now, execute mpt_hash_state_trie.
interpreter.offset = mpt_hash_state_trie;
interpreter.generation_state.registers.program_counter = mpt_hash_state_trie;
interpreter.push(0xDEADBEEFu32.into());
interpreter.run()?;
@ -105,7 +105,7 @@ fn test_balance() -> Result<()> {
prepare_interpreter(&mut interpreter, address, &account)?;
// Test `balance`
interpreter.offset = KERNEL.global_labels["balance"];
interpreter.generation_state.registers.program_counter = KERNEL.global_labels["balance"];
interpreter.pop();
assert!(interpreter.stack().is_empty());
interpreter.push(0xDEADBEEFu32.into());

View File

@ -113,7 +113,7 @@ fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> {
assert_eq!(interpreter.stack(), vec![]);
// Now, execute mpt_hash_state_trie.
interpreter.offset = mpt_hash_state_trie;
interpreter.generation_state.registers.program_counter = mpt_hash_state_trie;
interpreter.push(0xDEADBEEFu32.into());
interpreter.run()?;

View File

@ -164,7 +164,7 @@ fn test_state_trie(mut state_trie: PartialTrie, k: Nibbles, mut account: Account
assert_eq!(interpreter.stack(), vec![]);
// Next, execute mpt_insert_state_trie.
interpreter.offset = mpt_insert_state_trie;
interpreter.generation_state.registers.program_counter = mpt_insert_state_trie;
let trie_data = interpreter.get_trie_data_mut();
if trie_data.is_empty() {
// In the assembly we skip over 0, knowing trie_data[0] = 0 by default.
@ -194,7 +194,7 @@ fn test_state_trie(mut state_trie: PartialTrie, k: Nibbles, mut account: Account
);
// Now, execute mpt_hash_state_trie.
interpreter.offset = mpt_hash_state_trie;
interpreter.generation_state.registers.program_counter = mpt_hash_state_trie;
interpreter.push(0xDEADBEEFu32.into());
interpreter.run()?;

View File

@ -27,7 +27,7 @@ fn mpt_read() -> Result<()> {
assert_eq!(interpreter.stack(), vec![]);
// Now, execute mpt_read on the state trie.
interpreter.offset = mpt_read;
interpreter.generation_state.registers.program_counter = mpt_read;
interpreter.push(0xdeadbeefu32.into());
interpreter.push(0xABCDEFu64.into());
interpreter.push(6.into());