diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 4dbd90fe..1b26a3c4 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -62,7 +62,7 @@ impl GenerationState { let context = self.current_context; let value = self.memory.contexts[context].segments[segment].get(virt); self.memory.log.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read: true, context, @@ -84,7 +84,7 @@ impl GenerationState { let timestamp = self.cpu_rows.len(); let context = self.current_context; self.memory.log.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read: false, context, diff --git a/evm/src/memory/columns.rs b/evm/src/memory/columns.rs index 65e75891..5f6c3911 100644 --- a/evm/src/memory/columns.rs +++ b/evm/src/memory/columns.rs @@ -1,7 +1,5 @@ //! Memory registers. -use std::ops::Range; - use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; // Columns for memory operations, ordered by (addr, timestamp). @@ -41,7 +39,4 @@ pub(crate) const COUNTER: usize = RANGE_CHECK + 1; pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1; pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1; -// Columns to be padded at the top with zeroes, before the permutation argument takes place. -pub(crate) const COLUMNS_TO_PAD: Range = TIMESTAMP..RANGE_CHECK + 1; - pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1; diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 1653ac16..14a75810 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -15,11 +15,11 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cross_table_lookup::Column; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::memory::columns::{ - is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, COLUMNS_TO_PAD, - CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, - RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, + is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, + COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED, + SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, }; -use crate::memory::NUM_CHANNELS; +use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::permutation::PermutationPair; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; @@ -44,9 +44,10 @@ pub struct MemoryStark { pub(crate) f: PhantomData, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct MemoryOp { - pub channel_index: usize, + /// The channel this operation came from, or `None` if it's a dummy operation for padding. + pub channel_index: Option, pub timestamp: usize, pub is_read: bool, pub context: usize, @@ -111,7 +112,7 @@ pub fn generate_random_memory_ops( let timestamp = clock * NUM_CHANNELS + channel_index; memory_ops.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read, context, @@ -128,6 +129,25 @@ pub fn generate_random_memory_ops( memory_ops } +fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize { + memory_ops + .iter() + .tuple_windows() + .map(|(curr, next)| { + if curr.context != next.context { + next.context - curr.context - 1 + } else if curr.segment != next.segment { + next.segment - curr.segment - 1 + } else if curr.virt != next.virt { + next.virt - curr.virt - 1 + } else { + next.timestamp - curr.timestamp - 1 + } + }) + .max() + .unwrap_or(0) +} + pub fn generate_first_change_flags( context: &[F], segment: &[F], @@ -169,7 +189,7 @@ pub fn generate_range_check_value( context_first_change: &[F], segment_first_change: &[F], virtual_first_change: &[F], -) -> (Vec, usize) { +) -> Vec { let num_ops = context.len(); let mut range_check = Vec::new(); @@ -187,9 +207,7 @@ pub fn generate_range_check_value( } range_check.push(F::ZERO); - let max_diff = range_check.iter().map(F::to_canonical_u64).max().unwrap() as usize; - - (range_check, max_diff) + range_check } impl, const D: usize> MemoryStark { @@ -198,7 +216,7 @@ impl, const D: usize> MemoryStark { mut memory_ops: Vec>, ) -> Vec<[F; NUM_COLUMNS]> { memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp)); - + Self::pad_memory_ops(&mut memory_ops); let num_ops = memory_ops.len(); let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]); @@ -212,22 +230,21 @@ impl, const D: usize> MemoryStark { virt, value, } = memory_ops[i]; - trace_cols[is_channel(channel_index)][i] = F::ONE; + if let Some(channel) = channel_index { + trace_cols[is_channel(channel)][i] = F::ONE; + } trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp); trace_cols[IS_READ][i] = F::from_bool(is_read); trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context); trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment); trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt); - for j in 0..8 { + for j in 0..VALUE_LIMBS { trace_cols[value_limb(j)][i] = value[j]; } } self.generate_memory(&mut trace_cols); - // The number of rows may have changed, if the range check required padding. - let num_ops = trace_cols[0].len(); - let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops]; for (i, col) in trace_cols.iter().enumerate() { for (j, &val) in col.iter().enumerate() { @@ -237,6 +254,29 @@ impl, const D: usize> MemoryStark { trace_rows } + fn pad_memory_ops(memory_ops: &mut Vec>) { + let num_ops = memory_ops.len(); + let max_range_check = get_max_range_check(&memory_ops); + let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two(); + let to_pad = num_ops_padded - num_ops; + + let last_op = memory_ops.last().expect("No memory ops?").clone(); + + // We essentially repeat the last operation until our operation list has the desired size, + // with a few changes: + // - We change its channel to `None` to indicate that this is a dummy operation. + // - We increment its timestamp in order to pass the ordering check. + // - We make sure it's a read, sine dummy operations must be reads. + for i in 0..to_pad { + memory_ops.push(MemoryOp { + channel_index: None, + timestamp: last_op.timestamp + i + 1, + is_read: true, + ..last_op + }); + } + } + fn generate_memory(&self, trace_cols: &mut [Vec]) { let num_trace_rows = trace_cols[0].len(); @@ -248,7 +288,7 @@ impl, const D: usize> MemoryStark { let (context_first_change, segment_first_change, virtual_first_change) = generate_first_change_flags(context, segment, virtuals); - let (range_check_value, max_diff) = generate_range_check_value( + trace_cols[RANGE_CHECK] = generate_range_check_value( context, segment, virtuals, @@ -257,20 +297,14 @@ impl, const D: usize> MemoryStark { &segment_first_change, &virtual_first_change, ); - let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two(); - let to_pad = to_pad_to - num_trace_rows; trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change; trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change; trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change; - trace_cols[RANGE_CHECK] = range_check_value; - - for col in COLUMNS_TO_PAD { - trace_cols[col].splice(0..0, vec![F::ZERO; to_pad]); - } - - trace_cols[COUNTER] = (0..to_pad_to).map(|i| F::from_canonical_usize(i)).collect(); + trace_cols[COUNTER] = (0..num_trace_rows) + .map(|i| F::from_canonical_usize(i)) + .collect(); let (permuted_inputs, permuted_table) = permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]); @@ -326,11 +360,23 @@ impl, const D: usize> Stark for MemoryStark = (0..8).map(|i| vars.next_values[value_limb(i)]).collect(); - // Indicator that this is a real row, not a row of padding. - // TODO: enforce that all padding is at the beginning. - let valid_row: P = (0..NUM_CHANNELS) + // Each `is_channel` value must be 0 or 1. + for c in 0..NUM_CHANNELS { + let is_channel = vars.local_values[is_channel(c)]; + yield_constr.constraint(is_channel * (is_channel - P::ONES)); + } + + // The sum of `is_channel` flags, `has_channel`, must also be 0 or 1. + let has_channel: P = (0..NUM_CHANNELS) .map(|c| vars.local_values[is_channel(c)]) .sum(); + yield_constr.constraint(has_channel * (has_channel - P::ONES)); + + // If this is a dummy row (with no channel), it must be a read. This means the prover can + // insert reads which never appear in the CPU trace (which are harmless), but not writes. + let is_dummy = P::ONES - has_channel; + let is_write = P::ONES - vars.local_values[IS_READ]; + yield_constr.constraint(is_dummy * is_write); let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE]; let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE]; @@ -358,21 +404,15 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark