From ab5abc391d308cada7f634b50aa201451bde034d Mon Sep 17 00:00:00 2001
From: Daniel Lubarov
Date: Sat, 16 Jul 2022 09:59:23 -0700
Subject: [PATCH] Organize segments in an enum

It's a bit more type-safe (can't mix up segment with context or virtual addr),
and this way uniqueness of ordinals is enforced, partially addressing a concern
raised in #591.

To avoid making `Segment` public (which I don't think would be appropriate), I
had to make some other visibility changes, and had to move
`generate_random_memory_ops` into the test module.
---
 evm/src/all_stark.rs             |   3 +-
 evm/src/cpu/bootstrap_kernel.rs  |   5 +-
 evm/src/cpu/kernel/aggregator.rs |   5 +-
 evm/src/generation/memory.rs     |   4 +-
 evm/src/generation/state.rs      |   9 +-
 evm/src/memory/memory_stark.rs   | 175 ++++++++++++++++----------
 evm/src/memory/segments.rs       |  77 ++++++++++----
 7 files changed, 165 insertions(+), 113 deletions(-)

diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs
index f02e0202..ba157fc0 100644
--- a/evm/src/all_stark.rs
+++ b/evm/src/all_stark.rs
@@ -146,7 +146,8 @@ mod tests {
     use crate::cross_table_lookup::testutils::check_ctls;
     use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS};
     use crate::logic::{self, LogicStark, Operation};
-    use crate::memory::memory_stark::{generate_random_memory_ops, MemoryStark};
+    use crate::memory::memory_stark::tests::generate_random_memory_ops;
+    use crate::memory::memory_stark::MemoryStark;
     use crate::memory::NUM_CHANNELS;
     use crate::proof::AllProof;
     use crate::prover::prove;
diff --git a/evm/src/cpu/bootstrap_kernel.rs b/evm/src/cpu/bootstrap_kernel.rs
index 7b11bf60..bb0e2be9 100644
--- a/evm/src/cpu/bootstrap_kernel.rs
+++ b/evm/src/cpu/bootstrap_kernel.rs
@@ -18,7 +18,8 @@ use crate::cpu::kernel::keccak_util::keccakf_u32s;
 use crate::cpu::public_inputs::NUM_PUBLIC_INPUTS;
 use crate::generation::state::GenerationState;
 use crate::memory;
-use crate::memory::{segments, NUM_CHANNELS};
+use crate::memory::segments::Segment;
+use crate::memory::NUM_CHANNELS;
 use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars};

 /// The Keccak rate (1088 bits), measured in bytes.
@@ -53,7 +54,7 @@ pub(crate) fn generate_bootstrap_kernel<F: Field>(state: &mut GenerationState<F>) {
             value[0] = F::from_canonical_u8(byte);
             let channel = addr % NUM_CHANNELS;
-            state.set_mem_current(channel, segments::CODE, addr, value);
+            state.set_mem_current(channel, Segment::Code, addr, value);
             packed_bytes = (packed_bytes << 8) | byte as u32;
         }
diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs
index 6ca88ba1..28a9c597 100644
--- a/evm/src/cpu/kernel/aggregator.rs
+++ b/evm/src/cpu/kernel/aggregator.rs
@@ -8,12 +8,15 @@
 use once_cell::sync::Lazy;

 use super::assembler::{assemble, Kernel};
 use crate::cpu::kernel::parser::parse;
+use crate::memory::segments::Segment;

 pub static KERNEL: Lazy<Kernel> = Lazy::new(combined_kernel);

 pub fn evm_constants() -> HashMap<String, U256> {
     let mut c = HashMap::new();
-    c.insert("SEGMENT_ID_TXN_DATA".into(), 0.into()); // TODO: Replace with actual segment ID.
+    for segment in Segment::all() {
+        c.insert(segment.var_name().into(), (segment as u32).into());
+    }
     c
 }
diff --git a/evm/src/generation/memory.rs b/evm/src/generation/memory.rs
index dfff4388..2ef46d15 100644
--- a/evm/src/generation/memory.rs
+++ b/evm/src/generation/memory.rs
@@ -1,7 +1,7 @@
 use plonky2::field::types::Field;

 use crate::memory::memory_stark::MemoryOp;
-use crate::memory::segments::NUM_SEGMENTS;
+use crate::memory::segments::Segment;
 use crate::memory::VALUE_LIMBS;

 #[allow(unused)] // TODO: Should be used soon.
@@ -26,7 +26,7 @@ impl<F: Field> Default for MemoryState<F> {
 #[derive(Default, Debug)]
 pub(crate) struct MemoryContextState<F: Field> {
     /// The content of each memory segment.
-    pub segments: [MemorySegmentState<F>; NUM_SEGMENTS],
+    pub segments: [MemorySegmentState<F>; Segment::COUNT],
 }

 #[derive(Default, Debug)]
diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs
index 7a95f7e4..c5a6bbc4 100644
--- a/evm/src/generation/state.rs
+++ b/evm/src/generation/state.rs
@@ -6,6 +6,7 @@ use plonky2::field::types::Field;
 use crate::cpu::columns::{CpuColumnsView, NUM_CPU_COLUMNS};
 use crate::generation::memory::MemoryState;
 use crate::memory::memory_stark::MemoryOp;
+use crate::memory::segments::Segment;
 use crate::{keccak, logic};

 #[derive(Debug)]
@@ -52,12 +53,12 @@ impl<F: Field> GenerationState<F> {
     pub(crate) fn get_mem_current(
         &mut self,
         channel_index: usize,
-        segment: usize,
+        segment: Segment,
         virt: usize,
     ) -> [F; crate::memory::VALUE_LIMBS] {
         let timestamp = self.cpu_rows.len();
         let context = self.current_context;
-        let value = self.memory.contexts[context].segments[segment].get(virt);
+        let value = self.memory.contexts[context].segments[segment as usize].get(virt);
         self.memory.log.push(MemoryOp {
             channel_index: Some(channel_index),
             timestamp,
@@ -74,7 +75,7 @@ impl<F: Field> GenerationState<F> {
     pub(crate) fn set_mem_current(
         &mut self,
         channel_index: usize,
-        segment: usize,
+        segment: Segment,
         virt: usize,
         value: [F; crate::memory::VALUE_LIMBS],
     ) {
@@ -89,7 +90,7 @@ impl<F: Field> GenerationState<F> {
             virt,
             value,
         });
-        self.memory.contexts[context].segments[segment].set(virt, value)
+        self.memory.contexts[context].segments[segment as usize].set(virt, value)
     }

     pub(crate) fn commit_cpu_row(&mut self) {
diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs
index f150323b..843dfc2f 100644
--- a/evm/src/memory/memory_stark.rs
+++ b/evm/src/memory/memory_stark.rs
@@ -1,4 +1,3 @@
-use std::collections::{HashMap, HashSet};
 use std::marker::PhantomData;

 use itertools::Itertools;
@@ -10,7 +9,6 @@ use plonky2::hash::hash_types::RichField;
 use plonky2::timed;
 use plonky2::util::timing::TimingTree;
 use plonky2::util::transpose;
-use rand::Rng;
 use rayon::prelude::*;

 use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
@@ -21,6 +19,7 @@ use crate::memory::columns::{
     COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED,
     SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE,
 };
+use crate::memory::segments::Segment;
 use crate::memory::{NUM_CHANNELS, VALUE_LIMBS};
 use crate::permutation::PermutationPair;
 use crate::stark::Stark;
@@ -46,13 +45,13 @@ pub struct MemoryStark<F, const D: usize> {
 }

 #[derive(Clone, Debug)]
-pub struct MemoryOp<F> {
+pub(crate) struct MemoryOp<F> {
     /// The channel this operation came from, or `None` if it's a dummy operation for padding.
     pub channel_index: Option<usize>,
     pub timestamp: usize,
     pub is_read: bool,
     pub context: usize,
-    pub segment: usize,
+    pub segment: Segment,
     pub virt: usize,
     pub value: [F; 8],
 }
@@ -70,7 +69,7 @@ impl<F: Field> MemoryOp<F> {
         row[TIMESTAMP] = F::from_canonical_usize(self.timestamp);
         row[IS_READ] = F::from_bool(self.is_read);
         row[ADDR_CONTEXT] = F::from_canonical_usize(self.context);
-        row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment);
+        row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment as usize);
         row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt);
         for j in 0..VALUE_LIMBS {
             row[value_limb(j)] = self.value[j];
         }
         row
     }
 }
-pub fn generate_random_memory_ops<F: Field, R: Rng>(
-    num_ops: usize,
-    rng: &mut R,
-) -> Vec<MemoryOp<F>> {
-    let mut memory_ops = Vec::new();
-
-    let mut current_memory_values: HashMap<(usize, usize, usize), [F; 8]> = HashMap::new();
-    let num_cycles = num_ops / 2;
-    for clock in 0..num_cycles {
-        let mut used_indices = HashSet::new();
-        let mut new_writes_this_cycle = HashMap::new();
-        let mut has_read = false;
-        for _ in 0..2 {
-            let mut channel_index = rng.gen_range(0..NUM_CHANNELS);
-            while used_indices.contains(&channel_index) {
-                channel_index = rng.gen_range(0..NUM_CHANNELS);
-            }
-            used_indices.insert(channel_index);
-
-            let is_read = if clock == 0 {
-                false
-            } else {
-                !has_read && rng.gen()
-            };
-            has_read = has_read || is_read;
-
-            let (context, segment, virt, vals) = if is_read {
-                let written: Vec<_> = current_memory_values.keys().collect();
-                let &(context, segment, virt) = written[rng.gen_range(0..written.len())];
-                let &vals = current_memory_values
-                    .get(&(context, segment, virt))
-                    .unwrap();
-
-                (context, segment, virt, vals)
-            } else {
-                // TODO: with taller memory table or more padding (to enable range-checking bigger diffs),
-                // test larger address values.
-                let mut context = rng.gen_range(0..40);
-                let mut segment = rng.gen_range(0..8);
-                let mut virt = rng.gen_range(0..20);
-                while new_writes_this_cycle.contains_key(&(context, segment, virt)) {
-                    context = rng.gen_range(0..40);
-                    segment = rng.gen_range(0..8);
-                    virt = rng.gen_range(0..20);
-                }
-
-                let val: [u32; 8] = rng.gen();
-                let vals: [F; 8] = val.map(F::from_canonical_u32);
-
-                new_writes_this_cycle.insert((context, segment, virt), vals);
-
-                (context, segment, virt, vals)
-            };
-
-            let timestamp = clock * NUM_CHANNELS + channel_index;
-            memory_ops.push(MemoryOp {
-                channel_index: Some(channel_index),
-                timestamp,
-                is_read,
-                context,
-                segment,
-                virt,
-                value: vals,
-            });
-        }
-        for (k, v) in new_writes_this_cycle {
-            current_memory_values.insert(k, v);
-        }
-    }
-
-    memory_ops
-}
-
 fn get_max_range_check<F: Field>(memory_ops: &[MemoryOp<F>]) -> usize {
     memory_ops
         .iter()
         .tuple_windows()
         .map(|(curr, next)| {
             if curr.context != next.context {
                 next.context - curr.context - 1
             } else if curr.segment != next.segment {
-                next.segment - curr.segment - 1
+                next.segment as usize - curr.segment as usize - 1
             } else if curr.virt != next.virt {
                 next.virt - curr.virt - 1
             } else {
@@ -264,7 +190,7 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
         }
     }

-    pub fn generate_trace(&self, memory_ops: Vec<MemoryOp<F>>) -> Vec<PolynomialValues<F>> {
+    pub(crate) fn generate_trace(&self, memory_ops: Vec<MemoryOp<F>>) -> Vec<PolynomialValues<F>> {
         let mut timing = TimingTree::new("generate trace", log::Level::Debug);

         // Generate most of the trace in row-major form.
@@ -533,13 +459,94 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for MemoryStark<F, D> {
+    pub(crate) fn generate_random_memory_ops<F: Field, R: Rng>(
+        num_ops: usize,
+        rng: &mut R,
+    ) -> Vec<MemoryOp<F>> {
+        let mut memory_ops = Vec::new();
+
+        let mut current_memory_values: HashMap<(usize, Segment, usize), [F; 8]> = HashMap::new();
+        let num_cycles = num_ops / 2;
+        for clock in 0..num_cycles {
+            let mut used_indices = HashSet::new();
+            let mut new_writes_this_cycle = HashMap::new();
+            let mut has_read = false;
+            for _ in 0..2 {
+                let mut channel_index = rng.gen_range(0..NUM_CHANNELS);
+                while used_indices.contains(&channel_index) {
+                    channel_index = rng.gen_range(0..NUM_CHANNELS);
+                }
+                used_indices.insert(channel_index);
+
+                let is_read = if clock == 0 {
+                    false
+                } else {
+                    !has_read && rng.gen()
+                };
+                has_read = has_read || is_read;
+
+                let (context, segment, virt, vals) = if is_read {
+                    let written: Vec<_> = current_memory_values.keys().collect();
+                    let &(context, segment, virt) = written[rng.gen_range(0..written.len())];
+                    let &vals = current_memory_values
+                        .get(&(context, segment, virt))
+                        .unwrap();
+
+                    (context, segment, virt, vals)
+                } else {
+                    // TODO: with taller memory table or more padding (to enable range-checking bigger diffs),
+                    // test larger address values.
+                    let mut context = rng.gen_range(0..40);
+                    let segments = [Segment::Code, Segment::Stack, Segment::MainMemory];
+                    let mut segment = *segments.choose(rng).unwrap();
+                    let mut virt = rng.gen_range(0..20);
+                    while new_writes_this_cycle.contains_key(&(context, segment, virt)) {
+                        context = rng.gen_range(0..40);
+                        segment = *segments.choose(rng).unwrap();
+                        virt = rng.gen_range(0..20);
+                    }
+
+                    let val: [u32; 8] = rng.gen();
+                    let vals: [F; 8] = val.map(F::from_canonical_u32);
+
+                    new_writes_this_cycle.insert((context, segment, virt), vals);
+
+                    (context, segment, virt, vals)
+                };
+
+                let timestamp = clock * NUM_CHANNELS + channel_index;
+                memory_ops.push(MemoryOp {
+                    channel_index: Some(channel_index),
+                    timestamp,
+                    is_read,
+                    context,
+                    segment,
+                    virt,
+                    value: vals,
+                });
+            }
+            for (k, v) in new_writes_this_cycle {
+                current_memory_values.insert(k, v);
+            }
+        }
+
+        memory_ops
+    }
+
     #[test]
     fn test_stark_degree() -> Result<()> {
         const D: usize = 2;
diff --git a/evm/src/memory/segments.rs b/evm/src/memory/segments.rs
index d20cb037..f1b92dfc 100644
--- a/evm/src/memory/segments.rs
+++ b/evm/src/memory/segments.rs
@@ -1,22 +1,61 @@
-/// Contains EVM bytecode.
-pub const CODE: usize = 0;
+#[allow(dead_code)] // TODO: Not all segments are used yet.
+#[derive(Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd, Debug)]
+pub(crate) enum Segment {
+    /// Contains EVM bytecode.
+    Code = 0,
+    /// The program stack.
+    Stack = 1,
+    /// Main memory, owned by the contract code.
+    MainMemory = 2,
+    /// Data passed to the current context by its caller.
+    Calldata = 3,
+    /// Data returned to the current context by its latest callee.
+    Returndata = 4,
+    /// A segment which contains a few fixed-size metadata fields, such as the caller's context, or the
+    /// size of `CALLDATA` and `RETURNDATA`.
+    Metadata = 5,
+    /// General purpose kernel memory, used by various kernel functions.
+    /// In general, calling a helper function can result in this memory being clobbered.
+    KernelGeneral = 6,
+    /// Contains transaction data (after it's parsed and converted to a standard format).
+    TxnData = 7,
+    /// Raw RLP data.
+    RlpRaw = 8,
+    /// RLP data that has been parsed and converted to a more "friendly" format.
+    RlpParsed = 9,
+}
-
-pub const STACK: usize = 1;
+
+impl Segment {
+    pub(crate) const COUNT: usize = 10;
-
-/// Main memory, owned by the contract code.
-pub const MAIN_MEM: usize = 2;
+
+    pub(crate) fn all() -> [Self; Self::COUNT] {
+        [
+            Self::Code,
+            Self::Stack,
+            Self::MainMemory,
+            Self::Calldata,
+            Self::Returndata,
+            Self::Metadata,
+            Self::KernelGeneral,
+            Self::TxnData,
+            Self::RlpRaw,
+            Self::RlpParsed,
+        ]
+    }
-
-/// Memory owned by the kernel.
-pub const KERNEL_MEM: usize = 3;
-
-/// Data passed to the current context by its caller.
-pub const CALLDATA: usize = 4;
-
-/// Data returned to the current context by its latest callee.
-pub const RETURNDATA: usize = 5;
-
-/// A segment which contains a few fixed-size metadata fields, such as the caller's context, or the
-/// size of `CALLDATA` and `RETURNDATA`.
-pub const METADATA: usize = 6;
-
-pub const NUM_SEGMENTS: usize = 7;
+
+    /// The variable name that gets passed into kernel assembly code.
+    pub(crate) fn var_name(&self) -> &'static str {
+        match self {
+            Segment::Code => "SEGMENT_CODE",
+            Segment::Stack => "SEGMENT_STACK",
+            Segment::MainMemory => "SEGMENT_MAIN_MEMORY",
+            Segment::Calldata => "SEGMENT_CALLDATA",
+            Segment::Returndata => "SEGMENT_RETURNDATA",
+            Segment::Metadata => "SEGMENT_METADATA",
+            Segment::KernelGeneral => "SEGMENT_KERNEL_GENERAL",
+            Segment::TxnData => "SEGMENT_TXN_DATA",
+            Segment::RlpRaw => "SEGMENT_RLP_RAW",
+            Segment::RlpParsed => "SEGMENT_RLP_PARSED",
+        }
+    }
+}
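
The following standalone sketch is editorial and not part of the patch: the three-variant `Segment`, `ContextState`, and the `constants` map are invented stand-ins, not the crate's own definitions. It illustrates the two properties the commit message relies on: with a fieldless enum carrying explicit discriminants, a segment value can no longer be confused with a context or virtual address, and the ordinals are unique by construction. It also mirrors the `evm_constants` loop over `Segment::all()` and the `segments[segment as usize]` indexing used in `GenerationState`.

// A minimal sketch of the enum-based segment pattern; assumed names, not the crate's own.
// Explicit discriminants make the ordinals unique by construction, and `as u32` /
// `as usize` recovers the numeric ID where one is still needed.
use std::collections::HashMap;

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
enum Segment {
    Code = 0,
    Stack = 1,
    MainMemory = 2,
}

impl Segment {
    const COUNT: usize = 3;

    fn all() -> [Self; Self::COUNT] {
        [Self::Code, Self::Stack, Self::MainMemory]
    }

    /// The name under which the ordinal is exposed to kernel assembly.
    fn var_name(self) -> &'static str {
        match self {
            Segment::Code => "SEGMENT_CODE",
            Segment::Stack => "SEGMENT_STACK",
            Segment::MainMemory => "SEGMENT_MAIN_MEMORY",
        }
    }
}

/// Per-context memory with one slot per segment, indexed by the enum's discriminant
/// (a hypothetical stand-in for `MemoryContextState`).
#[derive(Default)]
struct ContextState {
    segments: [Vec<u8>; Segment::COUNT],
}

impl ContextState {
    // Taking `Segment` (not `usize`) means a context or virtual address can no longer
    // be passed where a segment is expected.
    fn set(&mut self, segment: Segment, virt: usize, byte: u8) {
        let seg = &mut self.segments[segment as usize];
        if seg.len() <= virt {
            seg.resize(virt + 1, 0);
        }
        seg[virt] = byte;
    }
}

fn main() {
    // Constants map in the spirit of `evm_constants`: one entry per segment, keyed by
    // its assembly-level name, valued by its ordinal.
    let constants: HashMap<&'static str, u32> = Segment::all()
        .into_iter()
        .map(|s| (s.var_name(), s as u32))
        .collect();
    assert_eq!(constants["SEGMENT_STACK"], 1);

    let mut ctx = ContextState::default();
    ctx.set(Segment::MainMemory, 0, 0xff); // `ctx.set(2, 0, 0xff)` would not compile.
    assert_eq!(ctx.segments[Segment::MainMemory as usize][0], 0xff);
}

Keeping `COUNT`, `all()`, and `var_name()` next to the variants is what lets the fixed-size segment arrays and the constants map stay in sync when a segment is added.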
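
One more hedged note on the moved test, with a sketch below (again assumed example code, not lines from the patch): the rewritten generator draws a random `Segment` with `segments.choose(rng)`, and in the rand 0.8 API `choose` is provided by the `SliceRandom` trait, so it must be in scope alongside `Rng`. The top-level `use rand::Rng;` and `use std::collections::{HashMap, HashSet};` removed at the start of memory_stark.rs are still used by the moved body, so they presumably now live inside the test module; those import lines are not visible in the patch text above.

// A small, self-contained sketch (assuming the rand 0.8 crate) of drawing a random
// `Segment` the way the rewritten test does: `choose` comes from `SliceRandom`.
use rand::seq::SliceRandom;
use rand::Rng;

#[derive(Copy, Clone, Debug)]
enum Segment {
    Code = 0,
    Stack = 1,
    MainMemory = 2,
}

// Hypothetical helper mirroring the test's address generation: small context and
// virtual-address ranges keep the range-checked differences small.
fn random_address<R: Rng>(rng: &mut R) -> (usize, Segment, usize) {
    let segments = [Segment::Code, Segment::Stack, Segment::MainMemory];
    let context = rng.gen_range(0..40);
    let segment = *segments.choose(rng).unwrap();
    let virt = rng.gen_range(0..20);
    (context, segment, virt)
}

fn main() {
    let mut rng = rand::thread_rng();
    let (context, segment, virt) = random_address(&mut rng);
    println!("context={context} segment={segment:?} virt={virt}");
}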