From 7221c96440d4084cafc566483cd6a42da9b02f22 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 12 Jul 2022 15:29:27 -0700 Subject: [PATCH 1/6] Use FFTs to get subgroup evaluations in `check_constraints` Instead of quadratic evaluation. Should speed up `test_all_stark`. --- evm/src/prover.rs | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 2d398d1a..346224a5 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -492,14 +492,19 @@ fn check_constraints<'a, F, C, S, const D: usize>( let subgroup = F::two_adic_subgroup(degree_bits + rate_bits); - // Retrieve the polynomials values at index `i`. - let get_comm_values = |comm: &PolynomialBatch, i| -> Vec { - comm.polynomials - .iter() - .map(|poly| poly.eval(subgroup[i])) // O(n^2) FTW - .collect() + // Get the evaluations of a batch of polynomials over our subgroup. + let get_subgroup_evals = |comm: &PolynomialBatch| -> Vec> { + let values = comm + .polynomials + .par_iter() + .map(|coeffs| coeffs.clone().fft().values) + .collect::>(); + transpose(&values) }; + let trace_subgroup_evals = get_subgroup_evals(trace_commitment); + let permutation_ctl_zs_subgroup_evals = get_subgroup_evals(permutation_ctl_zs_commitment); + // Last element of the subgroup. let last = F::primitive_root_of_unity(degree_bits).inverse(); @@ -519,19 +524,14 @@ fn check_constraints<'a, F, C, S, const D: usize>( lagrange_basis_last, ); let vars = StarkEvaluationVars { - local_values: &get_comm_values(trace_commitment, i).try_into().unwrap(), - next_values: &get_comm_values(trace_commitment, i_next) - .try_into() - .unwrap(), + local_values: trace_subgroup_evals[i].as_slice().try_into().unwrap(), + next_values: trace_subgroup_evals[i_next].as_slice().try_into().unwrap(), public_inputs: &public_inputs, }; let permutation_check_vars = permutation_challenges.map(|permutation_challenge_sets| PermutationCheckVars { - local_zs: get_comm_values(permutation_ctl_zs_commitment, i) - [..num_permutation_zs] - .to_vec(), - next_zs: get_comm_values(permutation_ctl_zs_commitment, i_next) - [..num_permutation_zs] + local_zs: permutation_ctl_zs_subgroup_evals[i][..num_permutation_zs].to_vec(), + next_zs: permutation_ctl_zs_subgroup_evals[i_next][..num_permutation_zs] .to_vec(), permutation_challenge_sets: permutation_challenge_sets.to_vec(), }); @@ -542,10 +542,8 @@ fn check_constraints<'a, F, C, S, const D: usize>( .enumerate() .map( |(iii, (_, columns, filter_column))| CtlCheckVars:: { - local_z: get_comm_values(permutation_ctl_zs_commitment, i) - [num_permutation_zs + iii], - next_z: get_comm_values(permutation_ctl_zs_commitment, i_next) - [num_permutation_zs + iii], + local_z: permutation_ctl_zs_subgroup_evals[i][num_permutation_zs + iii], + next_z: permutation_ctl_zs_subgroup_evals[i_next][num_permutation_zs + iii], challenges: ctl_data.challenges.challenges[iii % config.num_challenges], columns, filter_column, From d1afe8129cb328b870ccb1ea8342f709aee23880 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 12 Jul 2022 17:25:46 -0700 Subject: [PATCH 2/6] More realistic padding rows in memory This adds padding rows which satisfy the ordering checks. To ensure that they also satisfy the value consistency checks, I just copied the address and value from the last operation. I think this method of padding feels more natural, though it is a bit more code since we need to calculate the max range check in a different way. But on the plus side, the constraints are a bit smaller and simpler. Also added a few constraints that I think we need for soundness: - Each `is_channel` flag is bool. - Sum of `is_channel` flags is bool. - Dummy operations must be reads (otherwise the prover could put writes in the memory table which aren't in the CPU table). --- evm/src/generation/state.rs | 4 +- evm/src/memory/columns.rs | 5 -- evm/src/memory/memory_stark.rs | 155 +++++++++++++++++++++------------ 3 files changed, 103 insertions(+), 61 deletions(-) diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 4dbd90fe..1b26a3c4 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -62,7 +62,7 @@ impl GenerationState { let context = self.current_context; let value = self.memory.contexts[context].segments[segment].get(virt); self.memory.log.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read: true, context, @@ -84,7 +84,7 @@ impl GenerationState { let timestamp = self.cpu_rows.len(); let context = self.current_context; self.memory.log.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read: false, context, diff --git a/evm/src/memory/columns.rs b/evm/src/memory/columns.rs index 65e75891..5f6c3911 100644 --- a/evm/src/memory/columns.rs +++ b/evm/src/memory/columns.rs @@ -1,7 +1,5 @@ //! Memory registers. -use std::ops::Range; - use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; // Columns for memory operations, ordered by (addr, timestamp). @@ -41,7 +39,4 @@ pub(crate) const COUNTER: usize = RANGE_CHECK + 1; pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1; pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1; -// Columns to be padded at the top with zeroes, before the permutation argument takes place. -pub(crate) const COLUMNS_TO_PAD: Range = TIMESTAMP..RANGE_CHECK + 1; - pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1; diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 1653ac16..14a75810 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -15,11 +15,11 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cross_table_lookup::Column; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::memory::columns::{ - is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, COLUMNS_TO_PAD, - CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, - RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, + is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, + COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED, + SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, }; -use crate::memory::NUM_CHANNELS; +use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::permutation::PermutationPair; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; @@ -44,9 +44,10 @@ pub struct MemoryStark { pub(crate) f: PhantomData, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct MemoryOp { - pub channel_index: usize, + /// The channel this operation came from, or `None` if it's a dummy operation for padding. + pub channel_index: Option, pub timestamp: usize, pub is_read: bool, pub context: usize, @@ -111,7 +112,7 @@ pub fn generate_random_memory_ops( let timestamp = clock * NUM_CHANNELS + channel_index; memory_ops.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read, context, @@ -128,6 +129,25 @@ pub fn generate_random_memory_ops( memory_ops } +fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize { + memory_ops + .iter() + .tuple_windows() + .map(|(curr, next)| { + if curr.context != next.context { + next.context - curr.context - 1 + } else if curr.segment != next.segment { + next.segment - curr.segment - 1 + } else if curr.virt != next.virt { + next.virt - curr.virt - 1 + } else { + next.timestamp - curr.timestamp - 1 + } + }) + .max() + .unwrap_or(0) +} + pub fn generate_first_change_flags( context: &[F], segment: &[F], @@ -169,7 +189,7 @@ pub fn generate_range_check_value( context_first_change: &[F], segment_first_change: &[F], virtual_first_change: &[F], -) -> (Vec, usize) { +) -> Vec { let num_ops = context.len(); let mut range_check = Vec::new(); @@ -187,9 +207,7 @@ pub fn generate_range_check_value( } range_check.push(F::ZERO); - let max_diff = range_check.iter().map(F::to_canonical_u64).max().unwrap() as usize; - - (range_check, max_diff) + range_check } impl, const D: usize> MemoryStark { @@ -198,7 +216,7 @@ impl, const D: usize> MemoryStark { mut memory_ops: Vec>, ) -> Vec<[F; NUM_COLUMNS]> { memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp)); - + Self::pad_memory_ops(&mut memory_ops); let num_ops = memory_ops.len(); let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]); @@ -212,22 +230,21 @@ impl, const D: usize> MemoryStark { virt, value, } = memory_ops[i]; - trace_cols[is_channel(channel_index)][i] = F::ONE; + if let Some(channel) = channel_index { + trace_cols[is_channel(channel)][i] = F::ONE; + } trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp); trace_cols[IS_READ][i] = F::from_bool(is_read); trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context); trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment); trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt); - for j in 0..8 { + for j in 0..VALUE_LIMBS { trace_cols[value_limb(j)][i] = value[j]; } } self.generate_memory(&mut trace_cols); - // The number of rows may have changed, if the range check required padding. - let num_ops = trace_cols[0].len(); - let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops]; for (i, col) in trace_cols.iter().enumerate() { for (j, &val) in col.iter().enumerate() { @@ -237,6 +254,29 @@ impl, const D: usize> MemoryStark { trace_rows } + fn pad_memory_ops(memory_ops: &mut Vec>) { + let num_ops = memory_ops.len(); + let max_range_check = get_max_range_check(&memory_ops); + let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two(); + let to_pad = num_ops_padded - num_ops; + + let last_op = memory_ops.last().expect("No memory ops?").clone(); + + // We essentially repeat the last operation until our operation list has the desired size, + // with a few changes: + // - We change its channel to `None` to indicate that this is a dummy operation. + // - We increment its timestamp in order to pass the ordering check. + // - We make sure it's a read, sine dummy operations must be reads. + for i in 0..to_pad { + memory_ops.push(MemoryOp { + channel_index: None, + timestamp: last_op.timestamp + i + 1, + is_read: true, + ..last_op + }); + } + } + fn generate_memory(&self, trace_cols: &mut [Vec]) { let num_trace_rows = trace_cols[0].len(); @@ -248,7 +288,7 @@ impl, const D: usize> MemoryStark { let (context_first_change, segment_first_change, virtual_first_change) = generate_first_change_flags(context, segment, virtuals); - let (range_check_value, max_diff) = generate_range_check_value( + trace_cols[RANGE_CHECK] = generate_range_check_value( context, segment, virtuals, @@ -257,20 +297,14 @@ impl, const D: usize> MemoryStark { &segment_first_change, &virtual_first_change, ); - let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two(); - let to_pad = to_pad_to - num_trace_rows; trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change; trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change; trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change; - trace_cols[RANGE_CHECK] = range_check_value; - - for col in COLUMNS_TO_PAD { - trace_cols[col].splice(0..0, vec![F::ZERO; to_pad]); - } - - trace_cols[COUNTER] = (0..to_pad_to).map(|i| F::from_canonical_usize(i)).collect(); + trace_cols[COUNTER] = (0..num_trace_rows) + .map(|i| F::from_canonical_usize(i)) + .collect(); let (permuted_inputs, permuted_table) = permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]); @@ -326,11 +360,23 @@ impl, const D: usize> Stark for MemoryStark = (0..8).map(|i| vars.next_values[value_limb(i)]).collect(); - // Indicator that this is a real row, not a row of padding. - // TODO: enforce that all padding is at the beginning. - let valid_row: P = (0..NUM_CHANNELS) + // Each `is_channel` value must be 0 or 1. + for c in 0..NUM_CHANNELS { + let is_channel = vars.local_values[is_channel(c)]; + yield_constr.constraint(is_channel * (is_channel - P::ONES)); + } + + // The sum of `is_channel` flags, `has_channel`, must also be 0 or 1. + let has_channel: P = (0..NUM_CHANNELS) .map(|c| vars.local_values[is_channel(c)]) .sum(); + yield_constr.constraint(has_channel * (has_channel - P::ONES)); + + // If this is a dummy row (with no channel), it must be a read. This means the prover can + // insert reads which never appear in the CPU trace (which are harmless), but not writes. + let is_dummy = P::ONES - has_channel; + let is_write = P::ONES - vars.local_values[IS_READ]; + yield_constr.constraint(is_dummy * is_write); let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE]; let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE]; @@ -358,21 +404,15 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark Date: Tue, 12 Jul 2022 17:52:49 -0700 Subject: [PATCH 3/6] fix --- evm/src/memory/memory_stark.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 14a75810..37255650 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -256,7 +256,7 @@ impl, const D: usize> MemoryStark { fn pad_memory_ops(memory_ops: &mut Vec>) { let num_ops = memory_ops.len(); - let max_range_check = get_max_range_check(&memory_ops); + let max_range_check = get_max_range_check(memory_ops); let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two(); let to_pad = num_ops_padded - num_ops; From a68d8ff586209df2d2492c7de19d3d6e787abc4d Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 13 Jul 2022 18:54:43 +0200 Subject: [PATCH 4/6] Avoid duplicate macros --- evm/src/cpu/kernel/assembler.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index 179e9367..bef01d85 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -67,7 +67,8 @@ fn find_macros(files: &[File]) -> HashMap { params: params.clone(), items: items.clone(), }; - macros.insert(name.clone(), _macro); + let old = macros.insert(name.clone(), _macro); + assert!(old.is_none(), "Duplicate macro: {name}"); } } } From a8852946b3e966227109d088c18636c2a6724e2b Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 13 Jul 2022 09:53:44 -0700 Subject: [PATCH 5/6] Have `make_kernel` log the size of each (assembled) file For now it doesn't log filenames, but we can compare against the list of filenames in `combined_kernel`. Current output: ``` [DEBUG plonky2_evm::cpu::kernel::assembler] Assembled file size: 0 bytes [DEBUG plonky2_evm::cpu::kernel::assembler] Assembled file size: 49 bytes [DEBUG plonky2_evm::cpu::kernel::assembler] Assembled file size: 387 bytes [DEBUG plonky2_evm::cpu::kernel::assembler] Assembled file size: 27365 bytes [DEBUG plonky2_evm::cpu::kernel::assembler] Assembled file size: 0 bytes [DEBUG plonky2_evm::cpu::kernel::assembler] Assembled file size: 11 bytes [DEBUG plonky2_evm::cpu::kernel::assembler] Assembled file size: 7 bytes [DEBUG plonky2_evm::cpu::kernel::aggregator::tests] Total kernel size: 27819 bytes ``` This shows that most of our kernel code is from `curve_add.asm`, which makes sense since it invovles a couple uses of the large `inverse` macro. Thankfully that will be replaced at some point. --- evm/src/cpu/kernel/aggregator.rs | 7 ++++++- evm/src/cpu/kernel/assembler.rs | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 2b96aaf3..04d57e75 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -36,6 +36,7 @@ mod tests { use anyhow::Result; use ethereum_types::U256; + use log::debug; use rand::{thread_rng, Rng}; use crate::cpu::kernel::aggregator::combined_kernel; @@ -43,9 +44,13 @@ mod tests { #[test] fn make_kernel() { + let _ = env_logger::Builder::from_default_env() + .format_timestamp(None) + .try_init(); + // Make sure we can parse and assemble the entire kernel. let kernel = combined_kernel(); - println!("Kernel size: {} bytes", kernel.code.len()); + debug!("Total kernel size: {} bytes", kernel.code.len()); } fn u256ify<'a>(hexes: impl IntoIterator) -> Result> { diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index 179e9367..11bff6a1 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use ethereum_types::U256; use itertools::izip; +use log::debug; use super::ast::PushTarget; use crate::cpu::kernel::ast::Literal; @@ -49,7 +50,10 @@ pub(crate) fn assemble(files: Vec, constants: HashMap) -> Ke } let mut code = vec![]; for (file, locals) in izip!(expanded_files, local_labels) { + let prev_len = code.len(); assemble_file(file, &mut code, locals, &global_labels); + let file_len = code.len() - prev_len; + debug!("Assembled file size: {} bytes", file_len); } assert_eq!(code.len(), offset, "Code length doesn't match offset."); Kernel { From bfd924870f9f5956616fa8487d40cb3ad830fb74 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 13 Jul 2022 12:57:27 -0700 Subject: [PATCH 6/6] Generate most of the memory table while it's in row-wise form This should improve cache locality - since we generally access several values at a time in a given row, we want themt to be close together in memory. There are a few steps that make more sense column-wise, though, such as generating the `COUNTER` column. I put those after the transpose. --- evm/src/memory/memory_stark.rs | 232 ++++++++++++++------------------- 1 file changed, 95 insertions(+), 137 deletions(-) diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 37255650..f150323b 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -9,7 +9,9 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; use rand::Rng; +use rayon::prelude::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; @@ -22,7 +24,6 @@ use crate::memory::columns::{ use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::permutation::PermutationPair; use crate::stark::Stark; -use crate::util::trace_rows_to_poly_values; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub(crate) const NUM_PUBLIC_INPUTS: usize = 0; @@ -56,6 +57,28 @@ pub struct MemoryOp { pub value: [F; 8], } +impl MemoryOp { + /// Generate a row for a given memory operation. Note that this does not generate columns which + /// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later. + /// It also does not generate columns such as `COUNTER`, which are generated later, after the + /// trace has been transposed into column-major form. + fn to_row(&self) -> [F; NUM_COLUMNS] { + let mut row = [F::ZERO; NUM_COLUMNS]; + if let Some(channel) = self.channel_index { + row[is_channel(channel)] = F::ONE; + } + row[TIMESTAMP] = F::from_canonical_usize(self.timestamp); + row[IS_READ] = F::from_bool(self.is_read); + row[ADDR_CONTEXT] = F::from_canonical_usize(self.context); + row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment); + row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt); + for j in 0..VALUE_LIMBS { + row[value_limb(j)] = self.value[j]; + } + row + } +} + pub fn generate_random_memory_ops( num_ops: usize, rng: &mut R, @@ -148,112 +171,76 @@ fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize { .unwrap_or(0) } -pub fn generate_first_change_flags( - context: &[F], - segment: &[F], - virtuals: &[F], -) -> (Vec, Vec, Vec) { - let num_ops = context.len(); - let mut context_first_change = Vec::with_capacity(num_ops); - let mut segment_first_change = Vec::with_capacity(num_ops); - let mut virtual_first_change = Vec::with_capacity(num_ops); +/// Generates the `_FIRST_CHANGE` columns and the `RANGE_CHECK` column in the trace. +pub fn generate_first_change_flags_and_rc(trace_rows: &mut [[F; NUM_COLUMNS]]) { + let num_ops = trace_rows.len(); for idx in 0..num_ops - 1 { - let this_context_first_change = context[idx] != context[idx + 1]; - let this_segment_first_change = - segment[idx] != segment[idx + 1] && !this_context_first_change; - let this_virtual_first_change = virtuals[idx] != virtuals[idx + 1] - && !this_segment_first_change - && !this_context_first_change; + let row = trace_rows[idx].as_slice(); + let next_row = trace_rows[idx + 1].as_slice(); - context_first_change.push(F::from_bool(this_context_first_change)); - segment_first_change.push(F::from_bool(this_segment_first_change)); - virtual_first_change.push(F::from_bool(this_virtual_first_change)); + let context = row[ADDR_CONTEXT]; + let segment = row[ADDR_SEGMENT]; + let virt = row[ADDR_VIRTUAL]; + let timestamp = row[TIMESTAMP]; + let next_context = next_row[ADDR_CONTEXT]; + let next_segment = next_row[ADDR_SEGMENT]; + let next_virt = next_row[ADDR_VIRTUAL]; + let next_timestamp = next_row[TIMESTAMP]; + + let context_changed = context != next_context; + let segment_changed = segment != next_segment; + let virtual_changed = virt != next_virt; + + let context_first_change = context_changed; + let segment_first_change = segment_changed && !context_first_change; + let virtual_first_change = + virtual_changed && !segment_first_change && !context_first_change; + + let row = trace_rows[idx].as_mut_slice(); + row[CONTEXT_FIRST_CHANGE] = F::from_bool(context_first_change); + row[SEGMENT_FIRST_CHANGE] = F::from_bool(segment_first_change); + row[VIRTUAL_FIRST_CHANGE] = F::from_bool(virtual_first_change); + + row[RANGE_CHECK] = if context_first_change { + next_context - context - F::ONE + } else if segment_first_change { + next_segment - segment - F::ONE + } else if virtual_first_change { + next_virt - virt - F::ONE + } else { + next_timestamp - timestamp - F::ONE + }; } - - context_first_change.push(F::ZERO); - segment_first_change.push(F::ZERO); - virtual_first_change.push(F::ZERO); - - ( - context_first_change, - segment_first_change, - virtual_first_change, - ) -} - -pub fn generate_range_check_value( - context: &[F], - segment: &[F], - virtuals: &[F], - timestamp: &[F], - context_first_change: &[F], - segment_first_change: &[F], - virtual_first_change: &[F], -) -> Vec { - let num_ops = context.len(); - let mut range_check = Vec::new(); - - for idx in 0..num_ops - 1 { - let this_address_unchanged = F::ONE - - context_first_change[idx] - - segment_first_change[idx] - - virtual_first_change[idx]; - range_check.push( - context_first_change[idx] * (context[idx + 1] - context[idx] - F::ONE) - + segment_first_change[idx] * (segment[idx + 1] - segment[idx] - F::ONE) - + virtual_first_change[idx] * (virtuals[idx + 1] - virtuals[idx] - F::ONE) - + this_address_unchanged * (timestamp[idx + 1] - timestamp[idx] - F::ONE), - ); - } - range_check.push(F::ZERO); - - range_check } impl, const D: usize> MemoryStark { - pub(crate) fn generate_trace_rows( - &self, - mut memory_ops: Vec>, - ) -> Vec<[F; NUM_COLUMNS]> { + /// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated + /// later, after transposing to column-major form. + fn generate_trace_row_major(&self, mut memory_ops: Vec>) -> Vec<[F; NUM_COLUMNS]> { memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp)); + Self::pad_memory_ops(&mut memory_ops); - let num_ops = memory_ops.len(); - let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]); - for i in 0..num_ops { - let MemoryOp { - channel_index, - timestamp, - is_read, - context, - segment, - virt, - value, - } = memory_ops[i]; - if let Some(channel) = channel_index { - trace_cols[is_channel(channel)][i] = F::ONE; - } - trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp); - trace_cols[IS_READ][i] = F::from_bool(is_read); - trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context); - trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment); - trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt); - for j in 0..VALUE_LIMBS { - trace_cols[value_limb(j)][i] = value[j]; - } - } - - self.generate_memory(&mut trace_cols); - - let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops]; - for (i, col) in trace_cols.iter().enumerate() { - for (j, &val) in col.iter().enumerate() { - trace_rows[j][i] = val; - } - } + let mut trace_rows = memory_ops + .into_par_iter() + .map(|op| op.to_row()) + .collect::>(); + generate_first_change_flags_and_rc(trace_rows.as_mut_slice()); trace_rows } + /// Generates the `COUNTER`, `RANGE_CHECK_PERMUTED` and `COUNTER_PERMUTED` columns, given a + /// trace in column-major form. + fn generate_trace_col_major(trace_col_vecs: &mut [Vec]) { + let height = trace_col_vecs[0].len(); + trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect(); + + let (permuted_inputs, permuted_table) = + permuted_cols(&trace_col_vecs[RANGE_CHECK], &trace_col_vecs[COUNTER]); + trace_col_vecs[RANGE_CHECK_PERMUTED] = permuted_inputs; + trace_col_vecs[COUNTER_PERMUTED] = permuted_table; + } + fn pad_memory_ops(memory_ops: &mut Vec>) { let num_ops = memory_ops.len(); let max_range_check = get_max_range_check(memory_ops); @@ -277,56 +264,27 @@ impl, const D: usize> MemoryStark { } } - fn generate_memory(&self, trace_cols: &mut [Vec]) { - let num_trace_rows = trace_cols[0].len(); - - let timestamp = &trace_cols[TIMESTAMP]; - let context = &trace_cols[ADDR_CONTEXT]; - let segment = &trace_cols[ADDR_SEGMENT]; - let virtuals = &trace_cols[ADDR_VIRTUAL]; - - let (context_first_change, segment_first_change, virtual_first_change) = - generate_first_change_flags(context, segment, virtuals); - - trace_cols[RANGE_CHECK] = generate_range_check_value( - context, - segment, - virtuals, - timestamp, - &context_first_change, - &segment_first_change, - &virtual_first_change, - ); - - trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change; - trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change; - trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change; - - trace_cols[COUNTER] = (0..num_trace_rows) - .map(|i| F::from_canonical_usize(i)) - .collect(); - - let (permuted_inputs, permuted_table) = - permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]); - trace_cols[RANGE_CHECK_PERMUTED] = permuted_inputs; - trace_cols[COUNTER_PERMUTED] = permuted_table; - } - pub fn generate_trace(&self, memory_ops: Vec>) -> Vec> { let mut timing = TimingTree::new("generate trace", log::Level::Debug); - // Generate the witness. + // Generate most of the trace in row-major form. let trace_rows = timed!( &mut timing, "generate trace rows", - self.generate_trace_rows(memory_ops) + self.generate_trace_row_major(memory_ops) ); + let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect(); - let trace_polys = timed!( - &mut timing, - "convert to PolynomialValues", - trace_rows_to_poly_values(trace_rows) - ); + // Transpose to column-major form. + let mut trace_col_vecs = transpose(&trace_row_vecs); + + // A few final generation steps, which work better in column-major form. + Self::generate_trace_col_major(&mut trace_col_vecs); + + let trace_polys = trace_col_vecs + .into_iter() + .map(|column| PolynomialValues::new(column)) + .collect(); timing.print(); trace_polys