diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 37255650..f150323b 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -9,7 +9,9 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; use rand::Rng; +use rayon::prelude::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; @@ -22,7 +24,6 @@ use crate::memory::columns::{ use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::permutation::PermutationPair; use crate::stark::Stark; -use crate::util::trace_rows_to_poly_values; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub(crate) const NUM_PUBLIC_INPUTS: usize = 0; @@ -56,6 +57,28 @@ pub struct MemoryOp { pub value: [F; 8], } +impl MemoryOp { + /// Generate a row for a given memory operation. Note that this does not generate columns which + /// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later. + /// It also does not generate columns such as `COUNTER`, which are generated later, after the + /// trace has been transposed into column-major form. + fn to_row(&self) -> [F; NUM_COLUMNS] { + let mut row = [F::ZERO; NUM_COLUMNS]; + if let Some(channel) = self.channel_index { + row[is_channel(channel)] = F::ONE; + } + row[TIMESTAMP] = F::from_canonical_usize(self.timestamp); + row[IS_READ] = F::from_bool(self.is_read); + row[ADDR_CONTEXT] = F::from_canonical_usize(self.context); + row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment); + row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt); + for j in 0..VALUE_LIMBS { + row[value_limb(j)] = self.value[j]; + } + row + } +} + pub fn generate_random_memory_ops( num_ops: usize, rng: &mut R, @@ -148,112 +171,76 @@ fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize { .unwrap_or(0) } -pub fn generate_first_change_flags( - context: &[F], - segment: &[F], - virtuals: &[F], -) -> (Vec, Vec, Vec) { - let num_ops = context.len(); - let mut context_first_change = Vec::with_capacity(num_ops); - let mut segment_first_change = Vec::with_capacity(num_ops); - let mut virtual_first_change = Vec::with_capacity(num_ops); +/// Generates the `_FIRST_CHANGE` columns and the `RANGE_CHECK` column in the trace. +pub fn generate_first_change_flags_and_rc(trace_rows: &mut [[F; NUM_COLUMNS]]) { + let num_ops = trace_rows.len(); for idx in 0..num_ops - 1 { - let this_context_first_change = context[idx] != context[idx + 1]; - let this_segment_first_change = - segment[idx] != segment[idx + 1] && !this_context_first_change; - let this_virtual_first_change = virtuals[idx] != virtuals[idx + 1] - && !this_segment_first_change - && !this_context_first_change; + let row = trace_rows[idx].as_slice(); + let next_row = trace_rows[idx + 1].as_slice(); - context_first_change.push(F::from_bool(this_context_first_change)); - segment_first_change.push(F::from_bool(this_segment_first_change)); - virtual_first_change.push(F::from_bool(this_virtual_first_change)); + let context = row[ADDR_CONTEXT]; + let segment = row[ADDR_SEGMENT]; + let virt = row[ADDR_VIRTUAL]; + let timestamp = row[TIMESTAMP]; + let next_context = next_row[ADDR_CONTEXT]; + let next_segment = next_row[ADDR_SEGMENT]; + let next_virt = next_row[ADDR_VIRTUAL]; + let next_timestamp = next_row[TIMESTAMP]; + + let context_changed = context != next_context; + let segment_changed = segment != next_segment; + let virtual_changed = virt != next_virt; + + let context_first_change = context_changed; + let segment_first_change = segment_changed && !context_first_change; + let virtual_first_change = + virtual_changed && !segment_first_change && !context_first_change; + + let row = trace_rows[idx].as_mut_slice(); + row[CONTEXT_FIRST_CHANGE] = F::from_bool(context_first_change); + row[SEGMENT_FIRST_CHANGE] = F::from_bool(segment_first_change); + row[VIRTUAL_FIRST_CHANGE] = F::from_bool(virtual_first_change); + + row[RANGE_CHECK] = if context_first_change { + next_context - context - F::ONE + } else if segment_first_change { + next_segment - segment - F::ONE + } else if virtual_first_change { + next_virt - virt - F::ONE + } else { + next_timestamp - timestamp - F::ONE + }; } - - context_first_change.push(F::ZERO); - segment_first_change.push(F::ZERO); - virtual_first_change.push(F::ZERO); - - ( - context_first_change, - segment_first_change, - virtual_first_change, - ) -} - -pub fn generate_range_check_value( - context: &[F], - segment: &[F], - virtuals: &[F], - timestamp: &[F], - context_first_change: &[F], - segment_first_change: &[F], - virtual_first_change: &[F], -) -> Vec { - let num_ops = context.len(); - let mut range_check = Vec::new(); - - for idx in 0..num_ops - 1 { - let this_address_unchanged = F::ONE - - context_first_change[idx] - - segment_first_change[idx] - - virtual_first_change[idx]; - range_check.push( - context_first_change[idx] * (context[idx + 1] - context[idx] - F::ONE) - + segment_first_change[idx] * (segment[idx + 1] - segment[idx] - F::ONE) - + virtual_first_change[idx] * (virtuals[idx + 1] - virtuals[idx] - F::ONE) - + this_address_unchanged * (timestamp[idx + 1] - timestamp[idx] - F::ONE), - ); - } - range_check.push(F::ZERO); - - range_check } impl, const D: usize> MemoryStark { - pub(crate) fn generate_trace_rows( - &self, - mut memory_ops: Vec>, - ) -> Vec<[F; NUM_COLUMNS]> { + /// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated + /// later, after transposing to column-major form. + fn generate_trace_row_major(&self, mut memory_ops: Vec>) -> Vec<[F; NUM_COLUMNS]> { memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp)); + Self::pad_memory_ops(&mut memory_ops); - let num_ops = memory_ops.len(); - let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]); - for i in 0..num_ops { - let MemoryOp { - channel_index, - timestamp, - is_read, - context, - segment, - virt, - value, - } = memory_ops[i]; - if let Some(channel) = channel_index { - trace_cols[is_channel(channel)][i] = F::ONE; - } - trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp); - trace_cols[IS_READ][i] = F::from_bool(is_read); - trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context); - trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment); - trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt); - for j in 0..VALUE_LIMBS { - trace_cols[value_limb(j)][i] = value[j]; - } - } - - self.generate_memory(&mut trace_cols); - - let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops]; - for (i, col) in trace_cols.iter().enumerate() { - for (j, &val) in col.iter().enumerate() { - trace_rows[j][i] = val; - } - } + let mut trace_rows = memory_ops + .into_par_iter() + .map(|op| op.to_row()) + .collect::>(); + generate_first_change_flags_and_rc(trace_rows.as_mut_slice()); trace_rows } + /// Generates the `COUNTER`, `RANGE_CHECK_PERMUTED` and `COUNTER_PERMUTED` columns, given a + /// trace in column-major form. + fn generate_trace_col_major(trace_col_vecs: &mut [Vec]) { + let height = trace_col_vecs[0].len(); + trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect(); + + let (permuted_inputs, permuted_table) = + permuted_cols(&trace_col_vecs[RANGE_CHECK], &trace_col_vecs[COUNTER]); + trace_col_vecs[RANGE_CHECK_PERMUTED] = permuted_inputs; + trace_col_vecs[COUNTER_PERMUTED] = permuted_table; + } + fn pad_memory_ops(memory_ops: &mut Vec>) { let num_ops = memory_ops.len(); let max_range_check = get_max_range_check(memory_ops); @@ -277,56 +264,27 @@ impl, const D: usize> MemoryStark { } } - fn generate_memory(&self, trace_cols: &mut [Vec]) { - let num_trace_rows = trace_cols[0].len(); - - let timestamp = &trace_cols[TIMESTAMP]; - let context = &trace_cols[ADDR_CONTEXT]; - let segment = &trace_cols[ADDR_SEGMENT]; - let virtuals = &trace_cols[ADDR_VIRTUAL]; - - let (context_first_change, segment_first_change, virtual_first_change) = - generate_first_change_flags(context, segment, virtuals); - - trace_cols[RANGE_CHECK] = generate_range_check_value( - context, - segment, - virtuals, - timestamp, - &context_first_change, - &segment_first_change, - &virtual_first_change, - ); - - trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change; - trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change; - trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change; - - trace_cols[COUNTER] = (0..num_trace_rows) - .map(|i| F::from_canonical_usize(i)) - .collect(); - - let (permuted_inputs, permuted_table) = - permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]); - trace_cols[RANGE_CHECK_PERMUTED] = permuted_inputs; - trace_cols[COUNTER_PERMUTED] = permuted_table; - } - pub fn generate_trace(&self, memory_ops: Vec>) -> Vec> { let mut timing = TimingTree::new("generate trace", log::Level::Debug); - // Generate the witness. + // Generate most of the trace in row-major form. let trace_rows = timed!( &mut timing, "generate trace rows", - self.generate_trace_rows(memory_ops) + self.generate_trace_row_major(memory_ops) ); + let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect(); - let trace_polys = timed!( - &mut timing, - "convert to PolynomialValues", - trace_rows_to_poly_values(trace_rows) - ); + // Transpose to column-major form. + let mut trace_col_vecs = transpose(&trace_row_vecs); + + // A few final generation steps, which work better in column-major form. + Self::generate_trace_col_major(&mut trace_col_vecs); + + let trace_polys = trace_col_vecs + .into_iter() + .map(|column| PolynomialValues::new(column)) + .collect(); timing.print(); trace_polys