diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index aba35828..6517e89d 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -38,12 +38,23 @@ pub(crate) fn combined_kernel() -> Kernel { #[cfg(test)] mod tests { + use std::str::FromStr; + + use anyhow::Result; + use ethereum_types::U256; + use log::debug; + use rand::{thread_rng, Rng}; + use crate::cpu::kernel::aggregator::combined_kernel; #[test] fn make_kernel() { + let _ = env_logger::Builder::from_default_env() + .format_timestamp(None) + .try_init(); + // Make sure we can parse and assemble the entire kernel. let kernel = combined_kernel(); - println!("Kernel size: {} bytes", kernel.code.len()); + debug!("Total kernel size: {} bytes", kernel.code.len()); } } diff --git a/evm/src/cpu/kernel/assembler.rs b/evm/src/cpu/kernel/assembler.rs index bef01d85..bdc8ded4 100644 --- a/evm/src/cpu/kernel/assembler.rs +++ b/evm/src/cpu/kernel/assembler.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use ethereum_types::U256; use itertools::izip; +use log::debug; use super::ast::PushTarget; use crate::cpu::kernel::ast::Literal; @@ -49,7 +50,10 @@ pub(crate) fn assemble(files: Vec, constants: HashMap) -> Ke } let mut code = vec![]; for (file, locals) in izip!(expanded_files, local_labels) { + let prev_len = code.len(); assemble_file(file, &mut code, locals, &global_labels); + let file_len = code.len() - prev_len; + debug!("Assembled file size: {} bytes", file_len); } assert_eq!(code.len(), offset, "Code length doesn't match offset."); Kernel { diff --git a/evm/src/generation/state.rs b/evm/src/generation/state.rs index 4dbd90fe..1b26a3c4 100644 --- a/evm/src/generation/state.rs +++ b/evm/src/generation/state.rs @@ -62,7 +62,7 @@ impl GenerationState { let context = self.current_context; let value = self.memory.contexts[context].segments[segment].get(virt); self.memory.log.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read: true, context, @@ -84,7 +84,7 @@ impl GenerationState { let timestamp = self.cpu_rows.len(); let context = self.current_context; self.memory.log.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read: false, context, diff --git a/evm/src/memory/columns.rs b/evm/src/memory/columns.rs index 65e75891..5f6c3911 100644 --- a/evm/src/memory/columns.rs +++ b/evm/src/memory/columns.rs @@ -1,7 +1,5 @@ //! Memory registers. -use std::ops::Range; - use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; // Columns for memory operations, ordered by (addr, timestamp). @@ -41,7 +39,4 @@ pub(crate) const COUNTER: usize = RANGE_CHECK + 1; pub(crate) const RANGE_CHECK_PERMUTED: usize = COUNTER + 1; pub(crate) const COUNTER_PERMUTED: usize = RANGE_CHECK_PERMUTED + 1; -// Columns to be padded at the top with zeroes, before the permutation argument takes place. -pub(crate) const COLUMNS_TO_PAD: Range = TIMESTAMP..RANGE_CHECK + 1; - pub(crate) const NUM_COLUMNS: usize = COUNTER_PERMUTED + 1; diff --git a/evm/src/memory/memory_stark.rs b/evm/src/memory/memory_stark.rs index 1653ac16..f150323b 100644 --- a/evm/src/memory/memory_stark.rs +++ b/evm/src/memory/memory_stark.rs @@ -9,20 +9,21 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::timed; use plonky2::util::timing::TimingTree; +use plonky2::util::transpose; use rand::Rng; +use rayon::prelude::*; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::Column; use crate::lookup::{eval_lookups, eval_lookups_circuit, permuted_cols}; use crate::memory::columns::{ - is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, COLUMNS_TO_PAD, - CONTEXT_FIRST_CHANGE, COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, - RANGE_CHECK_PERMUTED, SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, + is_channel, value_limb, ADDR_CONTEXT, ADDR_SEGMENT, ADDR_VIRTUAL, CONTEXT_FIRST_CHANGE, + COUNTER, COUNTER_PERMUTED, IS_READ, NUM_COLUMNS, RANGE_CHECK, RANGE_CHECK_PERMUTED, + SEGMENT_FIRST_CHANGE, TIMESTAMP, VIRTUAL_FIRST_CHANGE, }; -use crate::memory::NUM_CHANNELS; +use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::permutation::PermutationPair; use crate::stark::Stark; -use crate::util::trace_rows_to_poly_values; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; pub(crate) const NUM_PUBLIC_INPUTS: usize = 0; @@ -44,9 +45,10 @@ pub struct MemoryStark { pub(crate) f: PhantomData, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct MemoryOp { - pub channel_index: usize, + /// The channel this operation came from, or `None` if it's a dummy operation for padding. + pub channel_index: Option, pub timestamp: usize, pub is_read: bool, pub context: usize, @@ -55,6 +57,28 @@ pub struct MemoryOp { pub value: [F; 8], } +impl MemoryOp { + /// Generate a row for a given memory operation. Note that this does not generate columns which + /// depend on the next operation, such as `CONTEXT_FIRST_CHANGE`; those are generated later. + /// It also does not generate columns such as `COUNTER`, which are generated later, after the + /// trace has been transposed into column-major form. + fn to_row(&self) -> [F; NUM_COLUMNS] { + let mut row = [F::ZERO; NUM_COLUMNS]; + if let Some(channel) = self.channel_index { + row[is_channel(channel)] = F::ONE; + } + row[TIMESTAMP] = F::from_canonical_usize(self.timestamp); + row[IS_READ] = F::from_bool(self.is_read); + row[ADDR_CONTEXT] = F::from_canonical_usize(self.context); + row[ADDR_SEGMENT] = F::from_canonical_usize(self.segment); + row[ADDR_VIRTUAL] = F::from_canonical_usize(self.virt); + for j in 0..VALUE_LIMBS { + row[value_limb(j)] = self.value[j]; + } + row + } +} + pub fn generate_random_memory_ops( num_ops: usize, rng: &mut R, @@ -111,7 +135,7 @@ pub fn generate_random_memory_ops( let timestamp = clock * NUM_CHANNELS + channel_index; memory_ops.push(MemoryOp { - channel_index, + channel_index: Some(channel_index), timestamp, is_read, context, @@ -128,171 +152,139 @@ pub fn generate_random_memory_ops( memory_ops } -pub fn generate_first_change_flags( - context: &[F], - segment: &[F], - virtuals: &[F], -) -> (Vec, Vec, Vec) { - let num_ops = context.len(); - let mut context_first_change = Vec::with_capacity(num_ops); - let mut segment_first_change = Vec::with_capacity(num_ops); - let mut virtual_first_change = Vec::with_capacity(num_ops); - for idx in 0..num_ops - 1 { - let this_context_first_change = context[idx] != context[idx + 1]; - let this_segment_first_change = - segment[idx] != segment[idx + 1] && !this_context_first_change; - let this_virtual_first_change = virtuals[idx] != virtuals[idx + 1] - && !this_segment_first_change - && !this_context_first_change; - - context_first_change.push(F::from_bool(this_context_first_change)); - segment_first_change.push(F::from_bool(this_segment_first_change)); - virtual_first_change.push(F::from_bool(this_virtual_first_change)); - } - - context_first_change.push(F::ZERO); - segment_first_change.push(F::ZERO); - virtual_first_change.push(F::ZERO); - - ( - context_first_change, - segment_first_change, - virtual_first_change, - ) +fn get_max_range_check(memory_ops: &[MemoryOp]) -> usize { + memory_ops + .iter() + .tuple_windows() + .map(|(curr, next)| { + if curr.context != next.context { + next.context - curr.context - 1 + } else if curr.segment != next.segment { + next.segment - curr.segment - 1 + } else if curr.virt != next.virt { + next.virt - curr.virt - 1 + } else { + next.timestamp - curr.timestamp - 1 + } + }) + .max() + .unwrap_or(0) } -pub fn generate_range_check_value( - context: &[F], - segment: &[F], - virtuals: &[F], - timestamp: &[F], - context_first_change: &[F], - segment_first_change: &[F], - virtual_first_change: &[F], -) -> (Vec, usize) { - let num_ops = context.len(); - let mut range_check = Vec::new(); - +/// Generates the `_FIRST_CHANGE` columns and the `RANGE_CHECK` column in the trace. +pub fn generate_first_change_flags_and_rc(trace_rows: &mut [[F; NUM_COLUMNS]]) { + let num_ops = trace_rows.len(); for idx in 0..num_ops - 1 { - let this_address_unchanged = F::ONE - - context_first_change[idx] - - segment_first_change[idx] - - virtual_first_change[idx]; - range_check.push( - context_first_change[idx] * (context[idx + 1] - context[idx] - F::ONE) - + segment_first_change[idx] * (segment[idx + 1] - segment[idx] - F::ONE) - + virtual_first_change[idx] * (virtuals[idx + 1] - virtuals[idx] - F::ONE) - + this_address_unchanged * (timestamp[idx + 1] - timestamp[idx] - F::ONE), - ); + let row = trace_rows[idx].as_slice(); + let next_row = trace_rows[idx + 1].as_slice(); + + let context = row[ADDR_CONTEXT]; + let segment = row[ADDR_SEGMENT]; + let virt = row[ADDR_VIRTUAL]; + let timestamp = row[TIMESTAMP]; + let next_context = next_row[ADDR_CONTEXT]; + let next_segment = next_row[ADDR_SEGMENT]; + let next_virt = next_row[ADDR_VIRTUAL]; + let next_timestamp = next_row[TIMESTAMP]; + + let context_changed = context != next_context; + let segment_changed = segment != next_segment; + let virtual_changed = virt != next_virt; + + let context_first_change = context_changed; + let segment_first_change = segment_changed && !context_first_change; + let virtual_first_change = + virtual_changed && !segment_first_change && !context_first_change; + + let row = trace_rows[idx].as_mut_slice(); + row[CONTEXT_FIRST_CHANGE] = F::from_bool(context_first_change); + row[SEGMENT_FIRST_CHANGE] = F::from_bool(segment_first_change); + row[VIRTUAL_FIRST_CHANGE] = F::from_bool(virtual_first_change); + + row[RANGE_CHECK] = if context_first_change { + next_context - context - F::ONE + } else if segment_first_change { + next_segment - segment - F::ONE + } else if virtual_first_change { + next_virt - virt - F::ONE + } else { + next_timestamp - timestamp - F::ONE + }; } - range_check.push(F::ZERO); - - let max_diff = range_check.iter().map(F::to_canonical_u64).max().unwrap() as usize; - - (range_check, max_diff) } impl, const D: usize> MemoryStark { - pub(crate) fn generate_trace_rows( - &self, - mut memory_ops: Vec>, - ) -> Vec<[F; NUM_COLUMNS]> { + /// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated + /// later, after transposing to column-major form. + fn generate_trace_row_major(&self, mut memory_ops: Vec>) -> Vec<[F; NUM_COLUMNS]> { memory_ops.sort_by_key(|op| (op.context, op.segment, op.virt, op.timestamp)); - let num_ops = memory_ops.len(); + Self::pad_memory_ops(&mut memory_ops); - let mut trace_cols = [(); NUM_COLUMNS].map(|_| vec![F::ZERO; num_ops]); - for i in 0..num_ops { - let MemoryOp { - channel_index, - timestamp, - is_read, - context, - segment, - virt, - value, - } = memory_ops[i]; - trace_cols[is_channel(channel_index)][i] = F::ONE; - trace_cols[TIMESTAMP][i] = F::from_canonical_usize(timestamp); - trace_cols[IS_READ][i] = F::from_bool(is_read); - trace_cols[ADDR_CONTEXT][i] = F::from_canonical_usize(context); - trace_cols[ADDR_SEGMENT][i] = F::from_canonical_usize(segment); - trace_cols[ADDR_VIRTUAL][i] = F::from_canonical_usize(virt); - for j in 0..8 { - trace_cols[value_limb(j)][i] = value[j]; - } - } - - self.generate_memory(&mut trace_cols); - - // The number of rows may have changed, if the range check required padding. - let num_ops = trace_cols[0].len(); - - let mut trace_rows = vec![[F::ZERO; NUM_COLUMNS]; num_ops]; - for (i, col) in trace_cols.iter().enumerate() { - for (j, &val) in col.iter().enumerate() { - trace_rows[j][i] = val; - } - } + let mut trace_rows = memory_ops + .into_par_iter() + .map(|op| op.to_row()) + .collect::>(); + generate_first_change_flags_and_rc(trace_rows.as_mut_slice()); trace_rows } - fn generate_memory(&self, trace_cols: &mut [Vec]) { - let num_trace_rows = trace_cols[0].len(); - - let timestamp = &trace_cols[TIMESTAMP]; - let context = &trace_cols[ADDR_CONTEXT]; - let segment = &trace_cols[ADDR_SEGMENT]; - let virtuals = &trace_cols[ADDR_VIRTUAL]; - - let (context_first_change, segment_first_change, virtual_first_change) = - generate_first_change_flags(context, segment, virtuals); - - let (range_check_value, max_diff) = generate_range_check_value( - context, - segment, - virtuals, - timestamp, - &context_first_change, - &segment_first_change, - &virtual_first_change, - ); - let to_pad_to = (max_diff + 1).max(num_trace_rows).next_power_of_two(); - let to_pad = to_pad_to - num_trace_rows; - - trace_cols[CONTEXT_FIRST_CHANGE] = context_first_change; - trace_cols[SEGMENT_FIRST_CHANGE] = segment_first_change; - trace_cols[VIRTUAL_FIRST_CHANGE] = virtual_first_change; - - trace_cols[RANGE_CHECK] = range_check_value; - - for col in COLUMNS_TO_PAD { - trace_cols[col].splice(0..0, vec![F::ZERO; to_pad]); - } - - trace_cols[COUNTER] = (0..to_pad_to).map(|i| F::from_canonical_usize(i)).collect(); + /// Generates the `COUNTER`, `RANGE_CHECK_PERMUTED` and `COUNTER_PERMUTED` columns, given a + /// trace in column-major form. + fn generate_trace_col_major(trace_col_vecs: &mut [Vec]) { + let height = trace_col_vecs[0].len(); + trace_col_vecs[COUNTER] = (0..height).map(|i| F::from_canonical_usize(i)).collect(); let (permuted_inputs, permuted_table) = - permuted_cols(&trace_cols[RANGE_CHECK], &trace_cols[COUNTER]); - trace_cols[RANGE_CHECK_PERMUTED] = permuted_inputs; - trace_cols[COUNTER_PERMUTED] = permuted_table; + permuted_cols(&trace_col_vecs[RANGE_CHECK], &trace_col_vecs[COUNTER]); + trace_col_vecs[RANGE_CHECK_PERMUTED] = permuted_inputs; + trace_col_vecs[COUNTER_PERMUTED] = permuted_table; + } + + fn pad_memory_ops(memory_ops: &mut Vec>) { + let num_ops = memory_ops.len(); + let max_range_check = get_max_range_check(memory_ops); + let num_ops_padded = num_ops.max(max_range_check + 1).next_power_of_two(); + let to_pad = num_ops_padded - num_ops; + + let last_op = memory_ops.last().expect("No memory ops?").clone(); + + // We essentially repeat the last operation until our operation list has the desired size, + // with a few changes: + // - We change its channel to `None` to indicate that this is a dummy operation. + // - We increment its timestamp in order to pass the ordering check. + // - We make sure it's a read, sine dummy operations must be reads. + for i in 0..to_pad { + memory_ops.push(MemoryOp { + channel_index: None, + timestamp: last_op.timestamp + i + 1, + is_read: true, + ..last_op + }); + } } pub fn generate_trace(&self, memory_ops: Vec>) -> Vec> { let mut timing = TimingTree::new("generate trace", log::Level::Debug); - // Generate the witness. + // Generate most of the trace in row-major form. let trace_rows = timed!( &mut timing, "generate trace rows", - self.generate_trace_rows(memory_ops) + self.generate_trace_row_major(memory_ops) ); + let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect(); - let trace_polys = timed!( - &mut timing, - "convert to PolynomialValues", - trace_rows_to_poly_values(trace_rows) - ); + // Transpose to column-major form. + let mut trace_col_vecs = transpose(&trace_row_vecs); + + // A few final generation steps, which work better in column-major form. + Self::generate_trace_col_major(&mut trace_col_vecs); + + let trace_polys = trace_col_vecs + .into_iter() + .map(|column| PolynomialValues::new(column)) + .collect(); timing.print(); trace_polys @@ -326,11 +318,23 @@ impl, const D: usize> Stark for MemoryStark = (0..8).map(|i| vars.next_values[value_limb(i)]).collect(); - // Indicator that this is a real row, not a row of padding. - // TODO: enforce that all padding is at the beginning. - let valid_row: P = (0..NUM_CHANNELS) + // Each `is_channel` value must be 0 or 1. + for c in 0..NUM_CHANNELS { + let is_channel = vars.local_values[is_channel(c)]; + yield_constr.constraint(is_channel * (is_channel - P::ONES)); + } + + // The sum of `is_channel` flags, `has_channel`, must also be 0 or 1. + let has_channel: P = (0..NUM_CHANNELS) .map(|c| vars.local_values[is_channel(c)]) .sum(); + yield_constr.constraint(has_channel * (has_channel - P::ONES)); + + // If this is a dummy row (with no channel), it must be a read. This means the prover can + // insert reads which never appear in the CPU trace (which are harmless), but not writes. + let is_dummy = P::ONES - has_channel; + let is_write = P::ONES - vars.local_values[IS_READ]; + yield_constr.constraint(is_dummy * is_write); let context_first_change = vars.local_values[CONTEXT_FIRST_CHANGE]; let segment_first_change = vars.local_values[SEGMENT_FIRST_CHANGE]; @@ -358,21 +362,15 @@ impl, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark, const D: usize> Stark for MemoryStark( let subgroup = F::two_adic_subgroup(degree_bits + rate_bits); - // Retrieve the polynomials values at index `i`. - let get_comm_values = |comm: &PolynomialBatch, i| -> Vec { - comm.polynomials - .iter() - .map(|poly| poly.eval(subgroup[i])) // O(n^2) FTW - .collect() + // Get the evaluations of a batch of polynomials over our subgroup. + let get_subgroup_evals = |comm: &PolynomialBatch| -> Vec> { + let values = comm + .polynomials + .par_iter() + .map(|coeffs| coeffs.clone().fft().values) + .collect::>(); + transpose(&values) }; + let trace_subgroup_evals = get_subgroup_evals(trace_commitment); + let permutation_ctl_zs_subgroup_evals = get_subgroup_evals(permutation_ctl_zs_commitment); + // Last element of the subgroup. let last = F::primitive_root_of_unity(degree_bits).inverse(); @@ -519,19 +524,14 @@ fn check_constraints<'a, F, C, S, const D: usize>( lagrange_basis_last, ); let vars = StarkEvaluationVars { - local_values: &get_comm_values(trace_commitment, i).try_into().unwrap(), - next_values: &get_comm_values(trace_commitment, i_next) - .try_into() - .unwrap(), + local_values: trace_subgroup_evals[i].as_slice().try_into().unwrap(), + next_values: trace_subgroup_evals[i_next].as_slice().try_into().unwrap(), public_inputs: &public_inputs, }; let permutation_check_vars = permutation_challenges.map(|permutation_challenge_sets| PermutationCheckVars { - local_zs: get_comm_values(permutation_ctl_zs_commitment, i) - [..num_permutation_zs] - .to_vec(), - next_zs: get_comm_values(permutation_ctl_zs_commitment, i_next) - [..num_permutation_zs] + local_zs: permutation_ctl_zs_subgroup_evals[i][..num_permutation_zs].to_vec(), + next_zs: permutation_ctl_zs_subgroup_evals[i_next][..num_permutation_zs] .to_vec(), permutation_challenge_sets: permutation_challenge_sets.to_vec(), }); @@ -542,10 +542,8 @@ fn check_constraints<'a, F, C, S, const D: usize>( .enumerate() .map( |(iii, (_, columns, filter_column))| CtlCheckVars:: { - local_z: get_comm_values(permutation_ctl_zs_commitment, i) - [num_permutation_zs + iii], - next_z: get_comm_values(permutation_ctl_zs_commitment, i_next) - [num_permutation_zs + iii], + local_z: permutation_ctl_zs_subgroup_evals[i][num_permutation_zs + iii], + next_z: permutation_ctl_zs_subgroup_evals[i_next][num_permutation_zs + iii], challenges: ctl_data.challenges.challenges[iii % config.num_challenges], columns, filter_column,