From b96c22a4f44824ca5436eb3948cb689e6a8f2d75 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Fri, 2 Dec 2022 19:18:37 -0800 Subject: [PATCH] Interpreter fixes --- evm/src/cpu/kernel/asm/exp.asm | 2 +- evm/src/cpu/kernel/interpreter.rs | 154 ++++++++++++++--------- evm/src/cpu/kernel/tests/account_code.rs | 17 +-- evm/src/cpu/kernel/tests/balance.rs | 8 +- evm/src/cpu/kernel/tests/mpt/hash.rs | 2 +- evm/src/cpu/kernel/tests/mpt/insert.rs | 4 +- evm/src/cpu/kernel/tests/mpt/read.rs | 2 +- 7 files changed, 113 insertions(+), 76 deletions(-) diff --git a/evm/src/cpu/kernel/asm/exp.asm b/evm/src/cpu/kernel/asm/exp.asm index f025e312..0aa40048 100644 --- a/evm/src/cpu/kernel/asm/exp.asm +++ b/evm/src/cpu/kernel/asm/exp.asm @@ -73,4 +73,4 @@ recursion_return: jump global sys_exp: - PANIC + PANIC // TODO: Implement. diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index d2ce43ed..1ab5f734 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -16,6 +16,7 @@ use crate::generation::state::GenerationState; use crate::generation::GenerationInputs; use crate::memory::segments::Segment; use crate::witness::memory::{MemoryContextState, MemorySegmentState, MemoryState}; +use crate::witness::util::stack_peek; type F = GoldilocksField; @@ -23,13 +24,6 @@ type F = GoldilocksField; const DEFAULT_HALT_OFFSET: usize = 0xdeadbeef; impl MemoryState { - fn with_code_and_stack(code: &[u8], stack: Vec) -> Self { - let mut mem = Self::new(code); - mem.contexts[0].segments[Segment::Stack as usize].content = stack; - - mem - } - fn mload_general(&self, context: usize, segment: Segment, offset: usize) -> U256 { let value = self.contexts[context].segments[segment as usize].get(offset); assert!( @@ -53,9 +47,7 @@ impl MemoryState { pub struct Interpreter<'a> { kernel_mode: bool, jumpdests: Vec, - pub(crate) offset: usize, pub(crate) context: usize, - pub(crate) memory: MemoryState, pub(crate) generation_state: GenerationState, prover_inputs_map: &'a HashMap, pub(crate) halt_offsets: Vec, @@ -103,11 +95,9 @@ impl<'a> Interpreter<'a> { initial_stack: Vec, prover_inputs: &'a HashMap, ) -> Self { - Self { + let mut result = Self { kernel_mode: true, jumpdests: find_jumpdests(code), - offset: initial_offset, - memory: MemoryState::with_code_and_stack(code, initial_stack), generation_state: GenerationState::new(GenerationInputs::default(), code), prover_inputs_map: prover_inputs, context: 0, @@ -115,7 +105,11 @@ impl<'a> Interpreter<'a> { debug_offsets: vec![], running: false, opcode_count: [0; 0x100], - } + }; + result.generation_state.registers.program_counter = initial_offset; + result.generation_state.registers.stack_len = initial_stack.len(); + *result.stack_mut() = initial_stack; + result } pub(crate) fn run(&mut self) -> anyhow::Result<()> { @@ -133,47 +127,51 @@ impl<'a> Interpreter<'a> { } fn code(&self) -> &MemorySegmentState { - &self.memory.contexts[self.context].segments[Segment::Code as usize] + &self.generation_state.memory.contexts[self.context].segments[Segment::Code as usize] } fn code_slice(&self, n: usize) -> Vec { - self.code().content[self.offset..self.offset + n] + let pc = self.generation_state.registers.program_counter; + self.code().content[pc..pc + n] .iter() .map(|u256| u256.byte(0)) .collect::>() } pub(crate) fn get_txn_field(&self, field: NormalizedTxnField) -> U256 { - self.memory.contexts[0].segments[Segment::TxnFields as usize].get(field as usize) + self.generation_state.memory.contexts[0].segments[Segment::TxnFields as usize] + .get(field as usize) } pub(crate) fn set_txn_field(&mut self, field: NormalizedTxnField, value: U256) { - self.memory.contexts[0].segments[Segment::TxnFields as usize].set(field as usize, value); + self.generation_state.memory.contexts[0].segments[Segment::TxnFields as usize] + .set(field as usize, value); } pub(crate) fn get_txn_data(&self) -> &[U256] { - &self.memory.contexts[0].segments[Segment::TxnData as usize].content + &self.generation_state.memory.contexts[0].segments[Segment::TxnData as usize].content } pub(crate) fn get_global_metadata_field(&self, field: GlobalMetadata) -> U256 { - self.memory.contexts[0].segments[Segment::GlobalMetadata as usize].get(field as usize) + self.generation_state.memory.contexts[0].segments[Segment::GlobalMetadata as usize] + .get(field as usize) } pub(crate) fn set_global_metadata_field(&mut self, field: GlobalMetadata, value: U256) { - self.memory.contexts[0].segments[Segment::GlobalMetadata as usize] + self.generation_state.memory.contexts[0].segments[Segment::GlobalMetadata as usize] .set(field as usize, value) } pub(crate) fn get_trie_data(&self) -> &[U256] { - &self.memory.contexts[0].segments[Segment::TrieData as usize].content + &self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content } pub(crate) fn get_trie_data_mut(&mut self) -> &mut Vec { - &mut self.memory.contexts[0].segments[Segment::TrieData as usize].content + &mut self.generation_state.memory.contexts[0].segments[Segment::TrieData as usize].content } pub(crate) fn get_rlp_memory(&self) -> Vec { - self.memory.contexts[0].segments[Segment::RlpRaw as usize] + self.generation_state.memory.contexts[0].segments[Segment::RlpRaw as usize] .content .iter() .map(|x| x.as_u32() as u8) @@ -181,21 +179,24 @@ impl<'a> Interpreter<'a> { } pub(crate) fn set_rlp_memory(&mut self, rlp: Vec) { - self.memory.contexts[0].segments[Segment::RlpRaw as usize].content = + self.generation_state.memory.contexts[0].segments[Segment::RlpRaw as usize].content = rlp.into_iter().map(U256::from).collect(); } pub(crate) fn set_code(&mut self, context: usize, code: Vec) { assert_ne!(context, 0, "Can't modify kernel code."); - while self.memory.contexts.len() <= context { - self.memory.contexts.push(MemoryContextState::default()); + while self.generation_state.memory.contexts.len() <= context { + self.generation_state + .memory + .contexts + .push(MemoryContextState::default()); } - self.memory.contexts[context].segments[Segment::Code as usize].content = + self.generation_state.memory.contexts[context].segments[Segment::Code as usize].content = code.into_iter().map(U256::from).collect(); } pub(crate) fn get_jumpdest_bits(&self, context: usize) -> Vec { - self.memory.contexts[context].segments[Segment::JumpdestBits as usize] + self.generation_state.memory.contexts[context].segments[Segment::JumpdestBits as usize] .content .iter() .map(|x| x.bit(0)) @@ -203,19 +204,22 @@ impl<'a> Interpreter<'a> { } fn incr(&mut self, n: usize) { - self.offset += n; + self.generation_state.registers.program_counter += n; } pub(crate) fn stack(&self) -> &[U256] { - &self.memory.contexts[self.context].segments[Segment::Stack as usize].content + &self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize] + .content } fn stack_mut(&mut self) -> &mut Vec { - &mut self.memory.contexts[self.context].segments[Segment::Stack as usize].content + &mut self.generation_state.memory.contexts[self.context].segments[Segment::Stack as usize] + .content } pub(crate) fn push(&mut self, x: U256) { self.stack_mut().push(x); + self.generation_state.registers.stack_len += 1; } fn push_bool(&mut self, x: bool) { @@ -223,11 +227,15 @@ impl<'a> Interpreter<'a> { } pub(crate) fn pop(&mut self) -> U256 { - self.stack_mut().pop().expect("Pop on empty stack.") + let result = stack_peek(&self.generation_state, 0); + self.generation_state.registers.stack_len -= 1; + let new_len = self.stack_len(); + self.stack_mut().truncate(new_len); + result.expect("Empty stack") } fn run_opcode(&mut self) -> anyhow::Result<()> { - let opcode = self.code().get(self.offset).byte(0); + let opcode = self.code().get(self.generation_state.registers.program_counter).byte(0); self.opcode_count[opcode as usize] += 1; self.incr(1); match opcode { @@ -327,7 +335,7 @@ impl<'a> Interpreter<'a> { _ => bail!("Unrecognized opcode {}.", opcode), }; - if self.debug_offsets.contains(&self.offset) { + if self.debug_offsets.contains(&self.generation_state.registers.program_counter) { println!("At {}, stack={:?}", self.offset_name(), self.stack()); } else if let Some(label) = self.offset_label() { println!("At {label}"); @@ -337,11 +345,11 @@ impl<'a> Interpreter<'a> { } fn offset_name(&self) -> String { - KERNEL.offset_name(self.offset) + KERNEL.offset_name(self.generation_state.registers.program_counter) } fn offset_label(&self) -> Option { - KERNEL.offset_label(self.offset) + KERNEL.offset_label(self.generation_state.registers.program_counter) } fn run_stop(&mut self) { @@ -503,7 +511,8 @@ impl<'a> Interpreter<'a> { let size = self.pop().as_usize(); let bytes = (offset..offset + size) .map(|i| { - self.memory + self.generation_state + .memory .mload_general(self.context, Segment::MainMemory, i) .byte(0) }) @@ -520,7 +529,12 @@ impl<'a> Interpreter<'a> { let offset = self.pop().as_usize(); let size = self.pop().as_usize(); let bytes = (offset..offset + size) - .map(|i| self.memory.mload_general(context, segment, i).byte(0)) + .map(|i| { + self.generation_state + .memory + .mload_general(context, segment, i) + .byte(0) + }) .collect::>(); println!("Hashing {:?}", &bytes); let hash = keccak(bytes); @@ -529,7 +543,8 @@ impl<'a> Interpreter<'a> { fn run_callvalue(&mut self) { self.push( - self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize] + self.generation_state.memory.contexts[self.context].segments + [Segment::ContextMetadata as usize] .get(ContextMetadata::CallValue as usize), ) } @@ -539,7 +554,8 @@ impl<'a> Interpreter<'a> { let value = U256::from_big_endian( &(0..32) .map(|i| { - self.memory + self.generation_state + .memory .mload_general(self.context, Segment::Calldata, offset + i) .byte(0) }) @@ -550,7 +566,8 @@ impl<'a> Interpreter<'a> { fn run_calldatasize(&mut self) { self.push( - self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize] + self.generation_state.memory.contexts[self.context].segments + [Segment::ContextMetadata as usize] .get(ContextMetadata::CalldataSize as usize), ) } @@ -560,10 +577,12 @@ impl<'a> Interpreter<'a> { let offset = self.pop().as_usize(); let size = self.pop().as_usize(); for i in 0..size { - let calldata_byte = - self.memory - .mload_general(self.context, Segment::Calldata, offset + i); - self.memory.mstore_general( + let calldata_byte = self.generation_state.memory.mload_general( + self.context, + Segment::Calldata, + offset + i, + ); + self.generation_state.memory.mstore_general( self.context, Segment::MainMemory, dest_offset + i, @@ -575,7 +594,7 @@ impl<'a> Interpreter<'a> { fn run_prover_input(&mut self) -> anyhow::Result<()> { let prover_input_fn = self .prover_inputs_map - .get(&(self.offset - 1)) + .get(&(self.generation_state.registers.program_counter - 1)) .ok_or_else(|| anyhow!("Offset not in prover inputs."))?; let output = self.generation_state.prover_input(prover_input_fn); self.push(output); @@ -591,7 +610,8 @@ impl<'a> Interpreter<'a> { let value = U256::from_big_endian( &(0..32) .map(|i| { - self.memory + self.generation_state + .memory .mload_general(self.context, Segment::MainMemory, offset + i) .byte(0) }) @@ -606,15 +626,19 @@ impl<'a> Interpreter<'a> { let mut bytes = [0; 32]; value.to_big_endian(&mut bytes); for (i, byte) in (0..32).zip(bytes) { - self.memory - .mstore_general(self.context, Segment::MainMemory, offset + i, byte.into()); + self.generation_state.memory.mstore_general( + self.context, + Segment::MainMemory, + offset + i, + byte.into(), + ); } } fn run_mstore8(&mut self) { let offset = self.pop().as_usize(); let value = self.pop(); - self.memory.mstore_general( + self.generation_state.memory.mstore_general( self.context, Segment::MainMemory, offset, @@ -636,12 +660,13 @@ impl<'a> Interpreter<'a> { } fn run_pc(&mut self) { - self.push((self.offset - 1).into()); + self.push((self.generation_state.registers.program_counter - 1).into()); } fn run_msize(&mut self) { self.push( - self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize] + self.generation_state.memory.contexts[self.context].segments + [Segment::ContextMetadata as usize] .get(ContextMetadata::MSize as usize), ) } @@ -656,7 +681,7 @@ impl<'a> Interpreter<'a> { panic!("Destination is not a JUMPDEST."); } - self.offset = offset; + self.generation_state.registers.program_counter = offset; if self.halt_offsets.contains(&offset) { self.running = false; @@ -670,11 +695,11 @@ impl<'a> Interpreter<'a> { } fn run_dup(&mut self, n: u8) { - self.push(self.stack()[self.stack().len() - n as usize]); + self.push(self.stack()[self.stack_len() - n as usize]); } fn run_swap(&mut self, n: u8) -> anyhow::Result<()> { - let len = self.stack().len(); + let len = self.stack_len(); ensure!(len > n as usize); self.stack_mut().swap(len - 1, len - n as usize - 1); Ok(()) @@ -693,7 +718,10 @@ impl<'a> Interpreter<'a> { let context = self.pop().as_usize(); let segment = Segment::all()[self.pop().as_usize()]; let offset = self.pop().as_usize(); - let value = self.memory.mload_general(context, segment, offset); + let value = self + .generation_state + .memory + .mload_general(context, segment, offset); assert!(value.bits() <= segment.bit_range()); self.push(value); } @@ -710,7 +738,13 @@ impl<'a> Interpreter<'a> { segment, segment.bit_range() ); - self.memory.mstore_general(context, segment, offset, value); + self.generation_state + .memory + .mstore_general(context, segment, offset, value); + } + + fn stack_len(&self) -> usize { + self.generation_state.registers.stack_len } } @@ -932,11 +966,13 @@ mod tests { let run = run(&code, 0, vec![], &pis)?; assert_eq!(run.stack(), &[0xff.into(), 0xff00.into()]); assert_eq!( - run.memory.contexts[0].segments[Segment::MainMemory as usize].get(0x27), + run.generation_state.memory.contexts[0].segments[Segment::MainMemory as usize] + .get(0x27), 0x42.into() ); assert_eq!( - run.memory.contexts[0].segments[Segment::MainMemory as usize].get(0x1f), + run.generation_state.memory.contexts[0].segments[Segment::MainMemory as usize] + .get(0x1f), 0xff.into() ); Ok(()) diff --git a/evm/src/cpu/kernel/tests/account_code.rs b/evm/src/cpu/kernel/tests/account_code.rs index 445ae3da..c6d7f156 100644 --- a/evm/src/cpu/kernel/tests/account_code.rs +++ b/evm/src/cpu/kernel/tests/account_code.rs @@ -42,7 +42,7 @@ fn prepare_interpreter( let mut state_trie: PartialTrie = Default::default(); let trie_inputs = Default::default(); - interpreter.offset = load_all_mpts; + interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); @@ -53,7 +53,7 @@ fn prepare_interpreter( keccak(address.to_fixed_bytes()).as_bytes(), )); // Next, execute mpt_insert_state_trie. - interpreter.offset = mpt_insert_state_trie; + interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; let trie_data = interpreter.get_trie_data_mut(); if trie_data.is_empty() { // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. @@ -83,7 +83,7 @@ fn prepare_interpreter( ); // Now, execute mpt_hash_state_trie. - interpreter.offset = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; @@ -115,7 +115,7 @@ fn test_extcodesize() -> Result<()> { let extcodesize = KERNEL.global_labels["extcodesize"]; // Test `extcodesize` - interpreter.offset = extcodesize; + interpreter.generation_state.registers.program_counter = extcodesize; interpreter.pop(); assert!(interpreter.stack().is_empty()); interpreter.push(0xDEADBEEFu32.into()); @@ -144,9 +144,10 @@ fn test_extcodecopy() -> Result<()> { // Put random data in main memory and the `KernelAccountCode` segment for realism. let mut rng = thread_rng(); for i in 0..2000 { - interpreter.memory.contexts[interpreter.context].segments[Segment::MainMemory as usize] + interpreter.generation_state.memory.contexts[interpreter.context].segments + [Segment::MainMemory as usize] .set(i, U256::from(rng.gen::())); - interpreter.memory.contexts[interpreter.context].segments + interpreter.generation_state.memory.contexts[interpreter.context].segments [Segment::KernelAccountCode as usize] .set(i, U256::from(rng.gen::())); } @@ -157,7 +158,7 @@ fn test_extcodecopy() -> Result<()> { let size = rng.gen_range(0..1500); // Test `extcodecopy` - interpreter.offset = extcodecopy; + interpreter.generation_state.registers.program_counter = extcodecopy; interpreter.pop(); assert!(interpreter.stack().is_empty()); interpreter.push(0xDEADBEEFu32.into()); @@ -172,7 +173,7 @@ fn test_extcodecopy() -> Result<()> { assert!(interpreter.stack().is_empty()); // Check that the code was correctly copied to memory. for i in 0..size { - let memory = interpreter.memory.contexts[interpreter.context].segments + let memory = interpreter.generation_state.memory.contexts[interpreter.context].segments [Segment::MainMemory as usize] .get(dest_offset + i); assert_eq!( diff --git a/evm/src/cpu/kernel/tests/balance.rs b/evm/src/cpu/kernel/tests/balance.rs index 1e784e85..b0e087a9 100644 --- a/evm/src/cpu/kernel/tests/balance.rs +++ b/evm/src/cpu/kernel/tests/balance.rs @@ -33,7 +33,7 @@ fn prepare_interpreter( let mut state_trie: PartialTrie = Default::default(); let trie_inputs = Default::default(); - interpreter.offset = load_all_mpts; + interpreter.generation_state.registers.program_counter = load_all_mpts; interpreter.push(0xDEADBEEFu32.into()); interpreter.generation_state.mpt_prover_inputs = all_mpt_prover_inputs_reversed(&trie_inputs); @@ -44,7 +44,7 @@ fn prepare_interpreter( keccak(address.to_fixed_bytes()).as_bytes(), )); // Next, execute mpt_insert_state_trie. - interpreter.offset = mpt_insert_state_trie; + interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; let trie_data = interpreter.get_trie_data_mut(); if trie_data.is_empty() { // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. @@ -74,7 +74,7 @@ fn prepare_interpreter( ); // Now, execute mpt_hash_state_trie. - interpreter.offset = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; @@ -105,7 +105,7 @@ fn test_balance() -> Result<()> { prepare_interpreter(&mut interpreter, address, &account)?; // Test `balance` - interpreter.offset = KERNEL.global_labels["balance"]; + interpreter.generation_state.registers.program_counter = KERNEL.global_labels["balance"]; interpreter.pop(); assert!(interpreter.stack().is_empty()); interpreter.push(0xDEADBEEFu32.into()); diff --git a/evm/src/cpu/kernel/tests/mpt/hash.rs b/evm/src/cpu/kernel/tests/mpt/hash.rs index 6321fb4b..6c6c6f63 100644 --- a/evm/src/cpu/kernel/tests/mpt/hash.rs +++ b/evm/src/cpu/kernel/tests/mpt/hash.rs @@ -113,7 +113,7 @@ fn test_state_trie(trie_inputs: TrieInputs) -> Result<()> { assert_eq!(interpreter.stack(), vec![]); // Now, execute mpt_hash_state_trie. - interpreter.offset = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; diff --git a/evm/src/cpu/kernel/tests/mpt/insert.rs b/evm/src/cpu/kernel/tests/mpt/insert.rs index 6e1ad573..cf546969 100644 --- a/evm/src/cpu/kernel/tests/mpt/insert.rs +++ b/evm/src/cpu/kernel/tests/mpt/insert.rs @@ -164,7 +164,7 @@ fn test_state_trie(mut state_trie: PartialTrie, k: Nibbles, mut account: Account assert_eq!(interpreter.stack(), vec![]); // Next, execute mpt_insert_state_trie. - interpreter.offset = mpt_insert_state_trie; + interpreter.generation_state.registers.program_counter = mpt_insert_state_trie; let trie_data = interpreter.get_trie_data_mut(); if trie_data.is_empty() { // In the assembly we skip over 0, knowing trie_data[0] = 0 by default. @@ -194,7 +194,7 @@ fn test_state_trie(mut state_trie: PartialTrie, k: Nibbles, mut account: Account ); // Now, execute mpt_hash_state_trie. - interpreter.offset = mpt_hash_state_trie; + interpreter.generation_state.registers.program_counter = mpt_hash_state_trie; interpreter.push(0xDEADBEEFu32.into()); interpreter.run()?; diff --git a/evm/src/cpu/kernel/tests/mpt/read.rs b/evm/src/cpu/kernel/tests/mpt/read.rs index d8808e24..62313f62 100644 --- a/evm/src/cpu/kernel/tests/mpt/read.rs +++ b/evm/src/cpu/kernel/tests/mpt/read.rs @@ -27,7 +27,7 @@ fn mpt_read() -> Result<()> { assert_eq!(interpreter.stack(), vec![]); // Now, execute mpt_read on the state trie. - interpreter.offset = mpt_read; + interpreter.generation_state.registers.program_counter = mpt_read; interpreter.push(0xdeadbeefu32.into()); interpreter.push(0xABCDEFu64.into()); interpreter.push(6.into());