This commit is contained in:
Daniel Lubarov 2022-12-01 12:46:04 -08:00
parent 3069363d35
commit 9bf47ef8ac
10 changed files with 93 additions and 139 deletions

View File

@ -11,43 +11,27 @@ use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata; use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata; use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
use crate::cpu::kernel::constants::txn_fields::NormalizedTxnField; use crate::cpu::kernel::constants::txn_fields::NormalizedTxnField;
use crate::generation::memory::{MemoryContextState, MemorySegmentState};
use crate::generation::prover_input::ProverInputFn; use crate::generation::prover_input::ProverInputFn;
use crate::generation::state::GenerationState; use crate::generation::state::GenerationState;
use crate::generation::GenerationInputs; use crate::generation::GenerationInputs;
use crate::memory::segments::Segment; use crate::memory::segments::Segment;
use crate::witness::memory::{MemoryContextState, MemorySegmentState, MemoryState};
type F = GoldilocksField; type F = GoldilocksField;
/// Halt interpreter execution whenever a jump to this offset is done. /// Halt interpreter execution whenever a jump to this offset is done.
const DEFAULT_HALT_OFFSET: usize = 0xdeadbeef; const DEFAULT_HALT_OFFSET: usize = 0xdeadbeef;
#[derive(Clone, Debug)] impl MemoryState {
pub(crate) struct InterpreterMemory {
pub(crate) context_memory: Vec<MemoryContextState>,
}
impl Default for InterpreterMemory {
fn default() -> Self {
Self {
context_memory: vec![MemoryContextState::default()],
}
}
}
impl InterpreterMemory {
fn with_code_and_stack(code: &[u8], stack: Vec<U256>) -> Self { fn with_code_and_stack(code: &[u8], stack: Vec<U256>) -> Self {
let mut mem = Self::default(); let mut mem = Self::new(code);
for (i, b) in code.iter().copied().enumerate() { mem.contexts[0].segments[Segment::Stack as usize].content = stack;
mem.context_memory[0].segments[Segment::Code as usize].set(i, b.into());
}
mem.context_memory[0].segments[Segment::Stack as usize].content = stack;
mem mem
} }
fn mload_general(&self, context: usize, segment: Segment, offset: usize) -> U256 { fn mload_general(&self, context: usize, segment: Segment, offset: usize) -> U256 {
let value = self.context_memory[context].segments[segment as usize].get(offset); let value = self.contexts[context].segments[segment as usize].get(offset);
assert!( assert!(
value.bits() <= segment.bit_range(), value.bits() <= segment.bit_range(),
"Value read from memory exceeds expected range of {:?} segment", "Value read from memory exceeds expected range of {:?} segment",
@ -62,7 +46,7 @@ impl InterpreterMemory {
"Value written to memory exceeds expected range of {:?} segment", "Value written to memory exceeds expected range of {:?} segment",
segment segment
); );
self.context_memory[context].segments[segment as usize].set(offset, value) self.contexts[context].segments[segment as usize].set(offset, value)
} }
} }
@ -71,7 +55,7 @@ pub struct Interpreter<'a> {
jumpdests: Vec<usize>, jumpdests: Vec<usize>,
pub(crate) offset: usize, pub(crate) offset: usize,
pub(crate) context: usize, pub(crate) context: usize,
pub(crate) memory: InterpreterMemory, pub(crate) memory: MemoryState,
pub(crate) generation_state: GenerationState<F>, pub(crate) generation_state: GenerationState<F>,
prover_inputs_map: &'a HashMap<usize, ProverInputFn>, prover_inputs_map: &'a HashMap<usize, ProverInputFn>,
pub(crate) halt_offsets: Vec<usize>, pub(crate) halt_offsets: Vec<usize>,
@ -123,8 +107,8 @@ impl<'a> Interpreter<'a> {
kernel_mode: true, kernel_mode: true,
jumpdests: find_jumpdests(code), jumpdests: find_jumpdests(code),
offset: initial_offset, offset: initial_offset,
memory: InterpreterMemory::with_code_and_stack(code, initial_stack), memory: MemoryState::with_code_and_stack(code, initial_stack),
generation_state: GenerationState::new(GenerationInputs::default()), generation_state: GenerationState::new(GenerationInputs::default(), code),
prover_inputs_map: prover_inputs, prover_inputs_map: prover_inputs,
context: 0, context: 0,
halt_offsets: vec![DEFAULT_HALT_OFFSET], halt_offsets: vec![DEFAULT_HALT_OFFSET],
@ -149,7 +133,7 @@ impl<'a> Interpreter<'a> {
} }
fn code(&self) -> &MemorySegmentState { fn code(&self) -> &MemorySegmentState {
&self.memory.context_memory[self.context].segments[Segment::Code as usize] &self.memory.contexts[self.context].segments[Segment::Code as usize]
} }
fn code_slice(&self, n: usize) -> Vec<u8> { fn code_slice(&self, n: usize) -> Vec<u8> {
@ -160,37 +144,36 @@ impl<'a> Interpreter<'a> {
} }
pub(crate) fn get_txn_field(&self, field: NormalizedTxnField) -> U256 { pub(crate) fn get_txn_field(&self, field: NormalizedTxnField) -> U256 {
self.memory.context_memory[0].segments[Segment::TxnFields as usize].get(field as usize) self.memory.contexts[0].segments[Segment::TxnFields as usize].get(field as usize)
} }
pub(crate) fn set_txn_field(&mut self, field: NormalizedTxnField, value: U256) { pub(crate) fn set_txn_field(&mut self, field: NormalizedTxnField, value: U256) {
self.memory.context_memory[0].segments[Segment::TxnFields as usize] self.memory.contexts[0].segments[Segment::TxnFields as usize].set(field as usize, value);
.set(field as usize, value);
} }
pub(crate) fn get_txn_data(&self) -> &[U256] { pub(crate) fn get_txn_data(&self) -> &[U256] {
&self.memory.context_memory[0].segments[Segment::TxnData as usize].content &self.memory.contexts[0].segments[Segment::TxnData as usize].content
} }
pub(crate) fn get_global_metadata_field(&self, field: GlobalMetadata) -> U256 { pub(crate) fn get_global_metadata_field(&self, field: GlobalMetadata) -> U256 {
self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize].get(field as usize) self.memory.contexts[0].segments[Segment::GlobalMetadata as usize].get(field as usize)
} }
pub(crate) fn set_global_metadata_field(&mut self, field: GlobalMetadata, value: U256) { pub(crate) fn set_global_metadata_field(&mut self, field: GlobalMetadata, value: U256) {
self.memory.context_memory[0].segments[Segment::GlobalMetadata as usize] self.memory.contexts[0].segments[Segment::GlobalMetadata as usize]
.set(field as usize, value) .set(field as usize, value)
} }
pub(crate) fn get_trie_data(&self) -> &[U256] { pub(crate) fn get_trie_data(&self) -> &[U256] {
&self.memory.context_memory[0].segments[Segment::TrieData as usize].content &self.memory.contexts[0].segments[Segment::TrieData as usize].content
} }
pub(crate) fn get_trie_data_mut(&mut self) -> &mut Vec<U256> { pub(crate) fn get_trie_data_mut(&mut self) -> &mut Vec<U256> {
&mut self.memory.context_memory[0].segments[Segment::TrieData as usize].content &mut self.memory.contexts[0].segments[Segment::TrieData as usize].content
} }
pub(crate) fn get_rlp_memory(&self) -> Vec<u8> { pub(crate) fn get_rlp_memory(&self) -> Vec<u8> {
self.memory.context_memory[0].segments[Segment::RlpRaw as usize] self.memory.contexts[0].segments[Segment::RlpRaw as usize]
.content .content
.iter() .iter()
.map(|x| x.as_u32() as u8) .map(|x| x.as_u32() as u8)
@ -198,23 +181,21 @@ impl<'a> Interpreter<'a> {
} }
pub(crate) fn set_rlp_memory(&mut self, rlp: Vec<u8>) { pub(crate) fn set_rlp_memory(&mut self, rlp: Vec<u8>) {
self.memory.context_memory[0].segments[Segment::RlpRaw as usize].content = self.memory.contexts[0].segments[Segment::RlpRaw as usize].content =
rlp.into_iter().map(U256::from).collect(); rlp.into_iter().map(U256::from).collect();
} }
pub(crate) fn set_code(&mut self, context: usize, code: Vec<u8>) { pub(crate) fn set_code(&mut self, context: usize, code: Vec<u8>) {
assert_ne!(context, 0, "Can't modify kernel code."); assert_ne!(context, 0, "Can't modify kernel code.");
while self.memory.context_memory.len() <= context { while self.memory.contexts.len() <= context {
self.memory self.memory.contexts.push(MemoryContextState::default());
.context_memory
.push(MemoryContextState::default());
} }
self.memory.context_memory[context].segments[Segment::Code as usize].content = self.memory.contexts[context].segments[Segment::Code as usize].content =
code.into_iter().map(U256::from).collect(); code.into_iter().map(U256::from).collect();
} }
pub(crate) fn get_jumpdest_bits(&self, context: usize) -> Vec<bool> { pub(crate) fn get_jumpdest_bits(&self, context: usize) -> Vec<bool> {
self.memory.context_memory[context].segments[Segment::JumpdestBits as usize] self.memory.contexts[context].segments[Segment::JumpdestBits as usize]
.content .content
.iter() .iter()
.map(|x| x.bit(0)) .map(|x| x.bit(0))
@ -226,11 +207,11 @@ impl<'a> Interpreter<'a> {
} }
pub(crate) fn stack(&self) -> &[U256] { pub(crate) fn stack(&self) -> &[U256] {
&self.memory.context_memory[self.context].segments[Segment::Stack as usize].content &self.memory.contexts[self.context].segments[Segment::Stack as usize].content
} }
fn stack_mut(&mut self) -> &mut Vec<U256> { fn stack_mut(&mut self) -> &mut Vec<U256> {
&mut self.memory.context_memory[self.context].segments[Segment::Stack as usize].content &mut self.memory.contexts[self.context].segments[Segment::Stack as usize].content
} }
pub(crate) fn push(&mut self, x: U256) { pub(crate) fn push(&mut self, x: U256) {
@ -548,7 +529,7 @@ impl<'a> Interpreter<'a> {
fn run_callvalue(&mut self) { fn run_callvalue(&mut self) {
self.push( self.push(
self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize] self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize]
.get(ContextMetadata::CallValue as usize), .get(ContextMetadata::CallValue as usize),
) )
} }
@ -569,7 +550,7 @@ impl<'a> Interpreter<'a> {
fn run_calldatasize(&mut self) { fn run_calldatasize(&mut self) {
self.push( self.push(
self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize] self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize]
.get(ContextMetadata::CalldataSize as usize), .get(ContextMetadata::CalldataSize as usize),
) )
} }
@ -596,8 +577,7 @@ impl<'a> Interpreter<'a> {
.prover_inputs_map .prover_inputs_map
.get(&(self.offset - 1)) .get(&(self.offset - 1))
.ok_or_else(|| anyhow!("Offset not in prover inputs."))?; .ok_or_else(|| anyhow!("Offset not in prover inputs."))?;
let stack = self.stack().to_vec(); let output = self.generation_state.prover_input(prover_input_fn);
let output = self.generation_state.prover_input(&stack, prover_input_fn);
self.push(output); self.push(output);
Ok(()) Ok(())
} }
@ -661,7 +641,7 @@ impl<'a> Interpreter<'a> {
fn run_msize(&mut self) { fn run_msize(&mut self) {
self.push( self.push(
self.memory.context_memory[self.context].segments[Segment::ContextMetadata as usize] self.memory.contexts[self.context].segments[Segment::ContextMetadata as usize]
.get(ContextMetadata::MSize as usize), .get(ContextMetadata::MSize as usize),
) )
} }
@ -952,11 +932,11 @@ mod tests {
let run = run(&code, 0, vec![], &pis)?; let run = run(&code, 0, vec![], &pis)?;
assert_eq!(run.stack(), &[0xff.into(), 0xff00.into()]); assert_eq!(run.stack(), &[0xff.into(), 0xff00.into()]);
assert_eq!( assert_eq!(
run.memory.context_memory[0].segments[Segment::MainMemory as usize].get(0x27), run.memory.contexts[0].segments[Segment::MainMemory as usize].get(0x27),
0x42.into() 0x42.into()
); );
assert_eq!( assert_eq!(
run.memory.context_memory[0].segments[Segment::MainMemory as usize].get(0x1f), run.memory.contexts[0].segments[Segment::MainMemory as usize].get(0x1f),
0xff.into() 0xff.into()
); );
Ok(()) Ok(())

View File

@ -144,10 +144,9 @@ fn test_extcodecopy() -> Result<()> {
// Put random data in main memory and the `KernelAccountCode` segment for realism. // Put random data in main memory and the `KernelAccountCode` segment for realism.
let mut rng = thread_rng(); let mut rng = thread_rng();
for i in 0..2000 { for i in 0..2000 {
interpreter.memory.context_memory[interpreter.context].segments interpreter.memory.contexts[interpreter.context].segments[Segment::MainMemory as usize]
[Segment::MainMemory as usize]
.set(i, U256::from(rng.gen::<u8>())); .set(i, U256::from(rng.gen::<u8>()));
interpreter.memory.context_memory[interpreter.context].segments interpreter.memory.contexts[interpreter.context].segments
[Segment::KernelAccountCode as usize] [Segment::KernelAccountCode as usize]
.set(i, U256::from(rng.gen::<u8>())); .set(i, U256::from(rng.gen::<u8>()));
} }
@ -173,7 +172,7 @@ fn test_extcodecopy() -> Result<()> {
assert!(interpreter.stack().is_empty()); assert!(interpreter.stack().is_empty());
// Check that the code was correctly copied to memory. // Check that the code was correctly copied to memory.
for i in 0..size { for i in 0..size {
let memory = interpreter.memory.context_memory[interpreter.context].segments let memory = interpreter.memory.contexts[interpreter.context].segments
[Segment::MainMemory as usize] [Segment::MainMemory as usize]
.get(dest_offset + i); .get(dest_offset + i);
assert_eq!( assert_eq!(

View File

@ -1,51 +0,0 @@
// TODO: Remove?
// use ethereum_types::U256;
//
// use crate::memory::memory_stark::MemoryOp;
// use crate::memory::segments::Segment;
//
// #[allow(unused)] // TODO: Should be used soon.
// #[derive(Debug)]
// pub(crate) struct MemoryState {
// /// A log of each memory operation, in the order that it occurred.
// pub log: Vec<MemoryOp>,
//
// pub contexts: Vec<MemoryContextState>,
// }
//
// impl Default for MemoryState {
// fn default() -> Self {
// Self {
// log: vec![],
// // We start with an initial context for the kernel.
// contexts: vec![MemoryContextState::default()],
// }
// }
// }
//
// #[derive(Clone, Default, Debug)]
// pub(crate) struct MemoryContextState {
// /// The content of each memory segment.
// pub segments: [MemorySegmentState; Segment::COUNT],
// }
//
// #[derive(Clone, Default, Debug)]
// pub(crate) struct MemorySegmentState {
// pub content: Vec<U256>,
// }
//
// impl MemorySegmentState {
// pub(crate) fn get(&self, virtual_addr: usize) -> U256 {
// self.content
// .get(virtual_addr)
// .copied()
// .unwrap_or(U256::zero())
// }
//
// pub(crate) fn set(&mut self, virtual_addr: usize, value: U256) {
// if virtual_addr >= self.content.len() {
// self.content.resize(virtual_addr + 1, U256::zero());
// }
// self.content[virtual_addr] = value;
// }
// }

View File

@ -19,7 +19,6 @@ use crate::proof::{BlockMetadata, PublicValues, TrieRoots};
use crate::witness::memory::MemoryAddress; use crate::witness::memory::MemoryAddress;
use crate::witness::transition::transition; use crate::witness::transition::transition;
pub(crate) mod memory;
pub(crate) mod mpt; pub(crate) mod mpt;
pub(crate) mod prover_input; pub(crate) mod prover_input;
pub(crate) mod rlp; pub(crate) mod rlp;

View File

@ -1,7 +1,6 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use itertools::Itertools; use itertools::Itertools;
use log::info;
use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField; use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues; use plonky2::field::polynomial::PolynomialValues;
@ -55,7 +54,6 @@ impl<F: RichField + Extendable<D>, const D: usize> KeccakStark<F, D> {
inputs: Vec<[u64; NUM_INPUTS]>, inputs: Vec<[u64; NUM_INPUTS]>,
) -> Vec<[F; NUM_COLUMNS]> { ) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two(); let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
info!("{} rows", num_rows);
let mut rows = Vec::with_capacity(num_rows); let mut rows = Vec::with_capacity(num_rows);
for input in inputs.iter() { for input in inputs.iter() {
rows.extend(self.generate_trace_rows_for_perm(*input)); rows.extend(self.generate_trace_rows_for_perm(*input));

View File

@ -408,6 +408,7 @@ mod tests {
use crate::keccak_sponge::keccak_sponge_stark::{KeccakSpongeOp, KeccakSpongeStark}; use crate::keccak_sponge::keccak_sponge_stark::{KeccakSpongeOp, KeccakSpongeStark};
use crate::memory::segments::Segment; use crate::memory::segments::Segment;
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
use crate::witness::memory::MemoryAddress;
#[test] #[test]
fn test_stark_degree() -> Result<()> { fn test_stark_degree() -> Result<()> {
@ -441,9 +442,11 @@ mod tests {
let expected_output = keccak(&input); let expected_output = keccak(&input);
let op = KeccakSpongeOp { let op = KeccakSpongeOp {
context: 0, base_address: MemoryAddress {
segment: Segment::Code, context: 0,
virt: 0, segment: Segment::Code as usize,
virt: 0,
},
timestamp: 0, timestamp: 0,
len: input.len(), len: input.len(),
input, input,

View File

@ -464,6 +464,8 @@ pub(crate) mod tests {
use crate::memory::segments::Segment; use crate::memory::segments::Segment;
use crate::memory::NUM_CHANNELS; use crate::memory::NUM_CHANNELS;
use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree};
use crate::witness::memory::MemoryAddress;
use crate::witness::memory::MemoryOpKind::{Read, Write};
pub(crate) fn generate_random_memory_ops<R: Rng>(num_ops: usize, rng: &mut R) -> Vec<MemoryOp> { pub(crate) fn generate_random_memory_ops<R: Rng>(num_ops: usize, rng: &mut R) -> Vec<MemoryOp> {
let mut memory_ops = Vec::new(); let mut memory_ops = Vec::new();
@ -525,10 +527,12 @@ pub(crate) mod tests {
memory_ops.push(MemoryOp { memory_ops.push(MemoryOp {
filter: true, filter: true,
timestamp, timestamp,
is_read, address: MemoryAddress {
context, context,
segment, segment: segment as usize,
virt, virt,
},
op: if is_read { Read } else { Write },
value: vals, value: vals,
}); });
} }

View File

@ -45,6 +45,7 @@ pub fn trace_rows_to_poly_values<F: Field, const COLUMNS: usize>(
.collect() .collect()
} }
#[allow(unused)] // TODO: Remove?
/// Returns the 32-bit little-endian limbs of a `U256`. /// Returns the 32-bit little-endian limbs of a `U256`.
pub(crate) fn u256_limbs<F: Field>(u256: U256) -> [F; 8] { pub(crate) fn u256_limbs<F: Field>(u256: U256) -> [F; 8] {
u256.0 u256.0

View File

@ -1,5 +1,3 @@
use std::collections::HashMap;
use ethereum_types::U256; use ethereum_types::U256;
use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS}; use crate::cpu::membus::{NUM_CHANNELS, NUM_GP_CHANNELS};
@ -87,24 +85,17 @@ impl MemoryOp {
} }
} }
#[derive(Clone, Default, Debug)] #[derive(Clone, Debug)]
pub struct MemoryState { pub struct MemoryState {
contents: HashMap<MemoryAddress, U256>, pub(crate) contexts: Vec<MemoryContextState>,
} }
impl MemoryState { impl MemoryState {
pub fn new(kernel_code: &[u8]) -> Self { pub fn new(kernel_code: &[u8]) -> Self {
let mut contents = HashMap::new(); let code_u256s = kernel_code.iter().map(|&x| x.into()).collect();
let mut result = Self::default();
for (i, &byte) in kernel_code.iter().enumerate() { result.contexts[0].segments[Segment::Code as usize].content = code_u256s;
if byte != 0 { result
let address = MemoryAddress::new(0, Segment::Code, i);
let val = byte.into();
contents.insert(address, val);
}
}
Self { contents }
} }
pub fn apply_ops(&mut self, ops: &[MemoryOp]) { pub fn apply_ops(&mut self, ops: &[MemoryOp]) {
@ -119,17 +110,46 @@ impl MemoryState {
} }
pub fn get(&self, address: MemoryAddress) -> U256 { pub fn get(&self, address: MemoryAddress) -> U256 {
self.contents self.contexts[address.context].segments[address.segment].get(address.virt)
.get(&address)
.copied()
.unwrap_or_else(U256::zero)
} }
pub fn set(&mut self, address: MemoryAddress, val: U256) { pub fn set(&mut self, address: MemoryAddress, val: U256) {
if val.is_zero() { self.contexts[address.context].segments[address.segment].set(address.virt, val);
self.contents.remove(&address); }
} else { }
self.contents.insert(address, val);
impl Default for MemoryState {
fn default() -> Self {
Self {
// We start with an initial context for the kernel.
contexts: vec![MemoryContextState::default()],
} }
} }
} }
#[derive(Clone, Default, Debug)]
pub(crate) struct MemoryContextState {
/// The content of each memory segment.
pub(crate) segments: [MemorySegmentState; Segment::COUNT],
}
#[derive(Clone, Default, Debug)]
pub(crate) struct MemorySegmentState {
pub(crate) content: Vec<U256>,
}
impl MemorySegmentState {
pub(crate) fn get(&self, virtual_addr: usize) -> U256 {
self.content
.get(virtual_addr)
.copied()
.unwrap_or(U256::zero())
}
pub(crate) fn set(&mut self, virtual_addr: usize, value: U256) {
if virtual_addr >= self.content.len() {
self.content.resize(virtual_addr + 1, U256::zero());
}
self.content[virtual_addr] = value;
}
}

View File

@ -230,6 +230,7 @@ fn try_perform_instruction<F: Field>(state: &mut GenerationState<F>) -> Result<(
let op = decode(state.registers, opcode)?; let op = decode(state.registers, opcode)?;
let pc = state.registers.program_counter; let pc = state.registers.program_counter;
log::trace!("Cycle {}", state.traces.clock());
log::trace!("Executing {:?} at {}", op, KERNEL.offset_name(pc)); log::trace!("Executing {:?} at {}", op, KERNEL.offset_name(pc));
log::trace!( log::trace!(
"Stack: {:?}", "Stack: {:?}",