diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index d1f993cd..b69104eb 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -119,7 +119,7 @@ mod tests { use anyhow::Result; use itertools::{izip, Itertools}; use plonky2::field::polynomial::PolynomialValues; - use plonky2::field::types::Field; + use plonky2::field::types::{Field, PrimeField64}; use plonky2::iop::witness::PartialWitness; use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::CircuitConfig; @@ -196,9 +196,11 @@ mod tests { num_memory_ops: usize, memory_stark: &MemoryStark, rng: &mut R, - ) -> Vec> { + ) -> (Vec>, usize) { let memory_ops = generate_random_memory_ops(num_memory_ops, rng); - memory_stark.generate_trace(memory_ops) + let trace = memory_stark.generate_trace(memory_ops); + let num_ops = trace[0].values.len(); + (trace, num_ops) } fn make_cpu_trace( @@ -282,32 +284,34 @@ mod tests { cpu_stark.generate(row.borrow_mut()); cpu_trace_rows.push(row.into()); } - - let mut current_cpu_index = 0; - let mut last_timestamp = memory_trace[memory::columns::TIMESTAMP].values[0]; for i in 0..num_memory_ops { - let mem_timestamp = memory_trace[memory::columns::TIMESTAMP].values[i]; - let clock = mem_timestamp; - let op = (0..NUM_CHANNELS) - .filter(|&o| memory_trace[memory::columns::is_channel(o)].values[i] == F::ONE) - .collect_vec()[0]; + let mem_timestamp: usize = memory_trace[memory::columns::TIMESTAMP].values[i] + .to_canonical_u64() + .try_into() + .unwrap(); + let clock = mem_timestamp / NUM_CHANNELS; + let channel = mem_timestamp % NUM_CHANNELS; - if mem_timestamp != last_timestamp { - current_cpu_index += 1; - last_timestamp = mem_timestamp; - } + let is_padding_row = (0..NUM_CHANNELS) + .map(|c| memory_trace[memory::columns::is_channel(c)].values[i]) + .all(|x| x == F::ZERO); - let row: &mut cpu::columns::CpuColumnsView = - cpu_trace_rows[current_cpu_index].borrow_mut(); + if !is_padding_row { + let row: &mut cpu::columns::CpuColumnsView = cpu_trace_rows[clock].borrow_mut(); - row.mem_channel_used[op] = F::ONE; - row.clock = clock; - row.mem_is_read[op] = memory_trace[memory::columns::IS_READ].values[i]; - row.mem_addr_context[op] = memory_trace[memory::columns::ADDR_CONTEXT].values[i]; - row.mem_addr_segment[op] = memory_trace[memory::columns::ADDR_SEGMENT].values[i]; - row.mem_addr_virtual[op] = memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; - for j in 0..8 { - row.mem_value[op][j] = memory_trace[memory::columns::value_limb(j)].values[i]; + row.mem_channel_used[channel] = F::ONE; + row.clock = F::from_canonical_usize(clock); + row.mem_is_read[channel] = memory_trace[memory::columns::IS_READ].values[i]; + row.mem_addr_context[channel] = + memory_trace[memory::columns::ADDR_CONTEXT].values[i]; + row.mem_addr_segment[channel] = + memory_trace[memory::columns::ADDR_SEGMENT].values[i]; + row.mem_addr_virtual[channel] = + memory_trace[memory::columns::ADDR_VIRTUAL].values[i]; + for j in 0..8 { + row.mem_value[channel][j] = + memory_trace[memory::columns::value_limb(j)].values[i]; + } } } trace_rows_to_poly_values(cpu_trace_rows) @@ -337,7 +341,9 @@ mod tests { let keccak_trace = make_keccak_trace(num_keccak_perms, &keccak_stark, &mut rng); let logic_trace = make_logic_trace(num_logic_rows, &logic_stark, &mut rng); - let mut memory_trace = make_memory_trace(num_memory_ops, &memory_stark, &mut rng); + let mem_trace = make_memory_trace(num_memory_ops, &memory_stark, &mut rng); + let mut memory_trace = mem_trace.0; + let num_memory_ops = mem_trace.1; let cpu_trace = make_cpu_trace( num_keccak_perms, num_logic_rows, diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index ee0cf98e..ad32dd98 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -40,7 +40,6 @@ pub fn ctl_filter_logic() -> Column { pub fn ctl_data_memory(channel: usize) -> Vec> { debug_assert!(channel < NUM_CHANNELS); let mut cols: Vec> = Column::singles([ - COL_MAP.clock, COL_MAP.mem_is_read[channel], COL_MAP.mem_addr_context[channel], COL_MAP.mem_addr_segment[channel], @@ -48,6 +47,14 @@ pub fn ctl_data_memory(channel: usize) -> Vec> { ]) .collect_vec(); cols.extend(Column::singles(COL_MAP.mem_value[channel])); + + let scalar = F::from_canonical_usize(NUM_CHANNELS); + let addend = F::from_canonical_usize(channel); + cols.push(Column::linear_combination_with_constant( + vec![(COL_MAP.clock, scalar)], + addend, + )); + cols } diff --git a/evm/src/memory/columns.rs b/evm/src/memory/columns.rs index d9fa927f..214a7e4b 100644 --- a/evm/src/memory/columns.rs +++ b/evm/src/memory/columns.rs @@ -1,5 +1,7 @@ //! Memory registers. +use std::ops::Range; + use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; pub(crate) const TIMESTAMP: usize = 0; @@ -36,20 +38,22 @@ pub(crate) const CONTEXT_FIRST_CHANGE: usize = SORTED_VALUE_START + VALUE_LIMBS; pub(crate) const SEGMENT_FIRST_CHANGE: usize = CONTEXT_FIRST_CHANGE + 1; pub(crate) const VIRTUAL_FIRST_CHANGE: usize = SEGMENT_FIRST_CHANGE + 1; +// Flags to indicate if this operation came from the `i`th channel of the memory bus. +const IS_CHANNEL_START: usize = VIRTUAL_FIRST_CHANGE + 1; +pub(crate) const fn is_channel(channel: usize) -> usize { + debug_assert!(channel < NUM_CHANNELS); + IS_CHANNEL_START + channel +} + // We use a range check to ensure sorting. -pub(crate) const RANGE_CHECK: usize = VIRTUAL_FIRST_CHANGE + 1; +pub(crate) const RANGE_CHECK: usize = IS_CHANNEL_START + NUM_CHANNELS; // The counter column (used for the range check) starts from 0 and increments. pub(crate) const COUNTER: usize = RANGE_CHECK + 1; // Helper columns for the permutation argument used to enforce the range check. pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1; pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1; -// Flags to indicate if this operation came from the `i`th channel of the memory bus. -const IS_CHANNEL_START: usize = COUNTER_PERMUTED + 1; -#[allow(dead_code)] -pub(crate) const fn is_channel(channel: usize) -> usize { - debug_assert!(channel < NUM_CHANNELS); - IS_CHANNEL_START + channel -} +// Columns to be padded at the top with zeroes, before the permutation argument takes place. +pub(crate) const COLUMNS_TO_PAD: Range = TIMESTAMP..RANGE_CHECK + 1; -pub(crate) const NUM_COLUMNS: usize = IS_CHANNEL_START + NUM_CHANNELS; +pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1; diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 3f7c26fc..49ee1ee2 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -11,17 +11,17 @@ use plonky2::timed; use plonky2::util::timing::TimingTree; use rand::Rng; -use super::columns::is_channel; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::memory::columns::{ - sorted_value_limb, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, - COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED, - SEGMENT_FIRST_CHANGE, SORTED_ADDR_CONTEXT, SORTED_ADDR_SEGMENT, SORTED_ADDR_VIRTUAL, - SORTED_IS_READ, SORTED_TIMESTAMP, TIMESTAMP, VIRTUAL_FIRST_CHANGE, + is_channel, sorted_value_limb, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, + COLUMNS_TO_PAD, CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, + RANGE_CHECK, RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, SORTED_ADDR_CONTEXT, + SORTED_ADDR_SEGMENT, SORTED_ADDR_VIRTUAL, SORTED_IS_READ, SORTED_TIMESTAMP, TIMESTAMP, + VIRTUAL_FIRST_CHANGE, }; -use crate::memory::NUM_CHANNELS; +use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::permutation::PermutationPair; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; @@ -30,9 +30,10 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub(crate) const NUM_PUBLIC_INPUTS: usize = 0; pub fn ctl_data() -> Vec> { - let mut res = Column::singles([TIMESTAMP, IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]) - .collect_vec(); + let mut res = + Column::singles([IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]).collect_vec(); res.extend(Column::singles((0..8).map(value_limb))); + res.push(Column::single(TIMESTAMP)); res } @@ -63,8 +64,7 @@ pub fn generate_random_memory_ops( let mut current_memory_values: HashMap<(F, F, F), [F; 8]> = HashMap::new(); let num_cycles = num_ops / 2; - for i in 0..num_cycles { - let timestamp = F::from_canonical_usize(i); + for clock in 0..num_cycles { let mut used_indices = HashSet::new(); let mut new_writes_this_cycle = HashMap::new(); let mut has_read = false; @@ -75,7 +75,7 @@ pub fn generate_random_memory_ops( } used_indices.insert(channel_index); - let is_read = if i == 0 { + let is_read = if clock == 0 { false } else { !has_read && rng.gen() @@ -111,6 +111,7 @@ pub fn generate_random_memory_ops( (context, segment, virt, vals) }; + let timestamp = F::from_canonical_usize(clock * NUM_CHANNELS + channel_index); memory_ops.push(MemoryOp { channel_index, timestamp, @@ -200,7 +201,7 @@ pub fn generate_range_check_value( context_first_change: &[F], segment_first_change: &[F], virtual_first_change: &[F], -) -> Vec { +) -> (Vec, usize) { let num_ops = context.len(); let mut range_check = Vec::new(); @@ -209,7 +210,6 @@ pub fn generate_range_check_value( - context_first_change[idx] - segment_first_change[idx] - virtual_first_change[idx]; - range_check.push( context_first_change[idx] * (context[idx + 1] - context[idx] - F::ONE) + segment_first_change[idx] * (segment[idx + 1] - segment[idx] - F::ONE) @@ -217,10 +217,11 @@ pub fn generate_range_check_value( + this_address_unchanged * (timestamp[idx + 1] - timestamp[idx] - F::ONE), ); } - range_check.push(F::ZERO); - range_check + let max_diff = range_check.iter().map(F::to_canonical_u64).max().unwrap() as usize; + + (range_check, max_diff) } impl, const D: usize> MemoryStark { @@ -254,6 +255,9 @@ impl, const D: usize> MemoryStark { self.generate_memory(&mut trace_cols); + // The number of rows may have changed, if the range check required padding. + let num_ops = trace_cols[0].len(); + let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops]; for (i, col) in trace_cols.iter().enumerate() { for (j, &val) in col.iter().enumerate() { @@ -295,7 +299,7 @@ impl, const D: usize> MemoryStark { let (context_first_change, segment_first_change, virtual_first_change) = generate_first_change_flags(&sorted_context, &sorted_segment, &sorted_virtual); - let range_check_value = generate_range_check_value( + let (range_check_value, max_diff) = generate_range_check_value( &sorted_context, &sorted_segment, &sorted_virtual, @@ -304,6 +308,8 @@ impl, const D: usize> MemoryStark { &segment_first_change, &virtual_first_change, ); + let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two(); + let to_pad = to_pad_to - num_trace_rows; trace_cols[SORTED_TIMESTAMP] = sorted_timestamp; trace_cols[SORTED_IS_READ] = sorted_is_read; @@ -311,7 +317,7 @@ impl, const D: usize> MemoryStark { trace_cols[SORTED_ADDR_SEGMENT] = sorted_segment; trace_cols[SORTED_ADDR_VIRTUAL] = sorted_virtual; for i in 0..num_trace_rows { - for j in 0..8 { + for j in 0..VALUE_LIMBS { trace_cols[sorted_value_limb(j)][i] = sorted_values[i][j]; } } @@ -321,9 +327,12 @@ impl, const D: usize> MemoryStark { trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change; trace_cols[RANGE_CHECK] = range_check_value; - trace_cols[COUNTER] = (0..num_trace_rows) - .map(|i| F::from_canonical_usize(i)) - .collect(); + + for col in COLUMNS_TO_PAD { + trace_cols[col].splice(0..0, vec![F::ZERO; to_pad]); + } + + trace_cols[COUNTER] = (0..to_pad_to).map(|i| F::from_canonical_usize(i)).collect(); let (permuted_inputs, permuted_table) = permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]); @@ -383,6 +392,12 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark Vec { + let mut unsorted_cols = vec![TIMESTAMP, IS_READ, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL]; + unsorted_cols.extend((0..VALUE_LIMBS).map(value_limb)); + let mut sorted_cols = vec![ + SORTED_TIMESTAMP, + SORTED_IS_READ, + SORTED_ADDR_CONTEXT, + SORTED_ADDR_SEGMENT, + SORTED_ADDR_VIRTUAL, + ]; + sorted_cols.extend((0..VALUE_LIMBS).map(sorted_value_limb)); + let column_pairs: Vec<_> = unsorted_cols + .into_iter() + .zip(sorted_cols.iter().cloned()) + .collect(); + vec![ + PermutationPair { column_pairs }, PermutationPair::singletons(RANGE_CHECK, RANGE_CHECK_PERMUTED), PermutationPair::singletons(COUNTER, COUNTER_PERMUTED), ]