diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index f374781a..068b0bcb 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -128,19 +128,16 @@ fn ctl_byte_packing() -> CrossTableLookup { let cpu_packing_looking = TableWithColumns::new( Table::Cpu, cpu_stark::ctl_data_byte_packing(), - vec![], Some(cpu_stark::ctl_filter_byte_packing()), ); let cpu_unpacking_looking = TableWithColumns::new( Table::Cpu, cpu_stark::ctl_data_byte_unpacking(), - vec![], Some(cpu_stark::ctl_filter_byte_unpacking()), ); let byte_packing_looked = TableWithColumns::new( Table::BytePacking, byte_packing_stark::ctl_looked_data(), - vec![], Some(byte_packing_stark::ctl_looked_filter()), ); CrossTableLookup::new( @@ -153,13 +150,11 @@ fn ctl_keccak() -> CrossTableLookup { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, keccak_sponge_stark::ctl_looking_keccak(), - vec![], Some(keccak_sponge_stark::ctl_looking_keccak_filter()), ); let keccak_looked = TableWithColumns::new( Table::Keccak, keccak_stark::ctl_data(), - vec![], Some(keccak_stark::ctl_filter()), ); CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) @@ -169,13 +164,11 @@ fn ctl_keccak_sponge() -> CrossTableLookup { let cpu_looking = TableWithColumns::new( Table::Cpu, cpu_stark::ctl_data_keccak_sponge(), - vec![], Some(cpu_stark::ctl_filter_keccak_sponge()), ); let keccak_sponge_looked = TableWithColumns::new( Table::KeccakSponge, keccak_sponge_stark::ctl_looked_data(), - vec![], Some(keccak_sponge_stark::ctl_looked_filter()), ); CrossTableLookup::new(vec![cpu_looking], keccak_sponge_looked) @@ -185,7 +178,6 @@ fn ctl_logic() -> CrossTableLookup { let cpu_looking = TableWithColumns::new( Table::Cpu, cpu_stark::ctl_data_logic(), - vec![], Some(cpu_stark::ctl_filter_logic()), ); let mut all_lookers = vec![cpu_looking]; @@ -193,17 +185,12 @@ fn ctl_logic() -> CrossTableLookup { let keccak_sponge_looking = TableWithColumns::new( Table::KeccakSponge, keccak_sponge_stark::ctl_looking_logic(i), - vec![], Some(keccak_sponge_stark::ctl_looking_logic_filter()), ); all_lookers.push(keccak_sponge_looking); } - let logic_looked = TableWithColumns::new( - Table::Logic, - logic::ctl_data(), - vec![], - Some(logic::ctl_filter()), - ); + let logic_looked = + TableWithColumns::new(Table::Logic, logic::ctl_data(), Some(logic::ctl_filter())); CrossTableLookup::new(all_lookers, logic_looked) } @@ -211,14 +198,12 @@ fn ctl_memory() -> CrossTableLookup { let cpu_memory_code_read = TableWithColumns::new( Table::Cpu, cpu_stark::ctl_data_code_memory(), - vec![], Some(cpu_stark::ctl_filter_code_memory()), ); let cpu_memory_gp_ops = (0..NUM_GP_CHANNELS).map(|channel| { TableWithColumns::new( Table::Cpu, cpu_stark::ctl_data_gp_memory(channel), - vec![], Some(cpu_stark::ctl_filter_gp_memory(channel)), ) }); @@ -226,7 +211,6 @@ fn ctl_memory() -> CrossTableLookup { TableWithColumns::new( Table::KeccakSponge, keccak_sponge_stark::ctl_looking_memory(i), - vec![], Some(keccak_sponge_stark::ctl_looking_memory_filter(i)), ) }); @@ -234,7 +218,6 @@ fn ctl_memory() -> CrossTableLookup { TableWithColumns::new( Table::BytePacking, byte_packing_stark::ctl_looking_memory(i), - vec![], Some(byte_packing_stark::ctl_looking_memory_filter(i)), ) }); @@ -246,7 +229,6 @@ fn ctl_memory() -> CrossTableLookup { let memory_looked = TableWithColumns::new( Table::Memory, memory_stark::ctl_data(), - vec![], Some(memory_stark::ctl_filter()), ); CrossTableLookup::new(all_lookers, memory_looked) diff --git a/evm/src/arithmetic/arithmetic_stark.rs b/evm/src/arithmetic/arithmetic_stark.rs index 034620a4..5441cf27 100644 --- a/evm/src/arithmetic/arithmetic_stark.rs +++ b/evm/src/arithmetic/arithmetic_stark.rs @@ -96,7 +96,6 @@ pub fn ctl_arithmetic_rows() -> TableWithColumns { TableWithColumns::new( Table::Arithmetic, cpu_arith_data_link(&COMBINED_OPS, ®ISTER_MAP), - vec![], filter_column, ) } diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 6623b67e..1579b3a1 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -103,7 +103,6 @@ pub fn ctl_arithmetic_base_rows() -> TableWithColumns { TableWithColumns::new( Table::Cpu, columns, - vec![], Some(Column::sum([ COL_MAP.op.binary_op, COL_MAP.op.fp254_op, @@ -121,12 +120,7 @@ pub fn ctl_arithmetic_shift_rows() -> TableWithColumns { // (also `ops` is used as the operation filter). The list of // operations includes binary operations which will simply ignore // the third input. - TableWithColumns::new( - Table::Cpu, - columns, - vec![], - Some(Column::single(COL_MAP.op.shift)), - ) + TableWithColumns::new(Table::Cpu, columns, Some(Column::single(COL_MAP.op.shift))) } pub fn ctl_data_byte_packing() -> Vec> { diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index ecc6a34c..315bf42f 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -25,6 +25,7 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; #[derive(Clone, Debug)] pub struct Column { linear_combination: Vec<(usize, F)>, + next_row_linear_combination: Vec<(usize, F)>, constant: F, } @@ -32,6 +33,7 @@ impl Column { pub fn single(c: usize) -> Self { Self { linear_combination: vec![(c, F::ONE)], + next_row_linear_combination: vec![], constant: F::ZERO, } } @@ -42,9 +44,24 @@ impl Column { cs.into_iter().map(|c| Self::single(*c.borrow())) } + pub fn single_next_row(c: usize) -> Self { + Self { + linear_combination: vec![], + next_row_linear_combination: vec![(c, F::ONE)], + constant: F::ZERO, + } + } + + pub fn singles_next_row>>( + cs: I, + ) -> impl Iterator { + cs.into_iter().map(|c| Self::single_next_row(*c.borrow())) + } + pub fn constant(constant: F) -> Self { Self { linear_combination: vec![], + next_row_linear_combination: vec![], constant, } } @@ -70,6 +87,34 @@ impl Column { ); Self { linear_combination: v, + next_row_linear_combination: vec![], + constant, + } + } + + pub fn linear_combination_and_next_row_with_constant>( + iter: I, + next_row_iter: I, + constant: F, + ) -> Self { + let v = iter.into_iter().collect::>(); + let next_row_v = next_row_iter.into_iter().collect::>(); + + assert!(!v.is_empty() || !next_row_v.is_empty()); + debug_assert_eq!( + v.iter().map(|(c, _)| c).unique().count(), + v.len(), + "Duplicate columns." + ); + debug_assert_eq!( + next_row_v.iter().map(|(c, _)| c).unique().count(), + next_row_v.len(), + "Duplicate columns." + ); + + Self { + linear_combination: v, + next_row_linear_combination: next_row_v, constant, } } @@ -106,13 +151,43 @@ impl Column { + FE::from_basefield(self.constant) } + pub fn eval_with_next(&self, v: &[P], next_v: &[P]) -> P + where + FE: FieldExtension, + P: PackedField, + { + self.linear_combination + .iter() + .map(|&(c, f)| v[c] * FE::from_basefield(f)) + .sum::

() + + self + .next_row_linear_combination + .iter() + .map(|&(c, f)| next_v[c] * FE::from_basefield(f)) + .sum::

() + + FE::from_basefield(self.constant) + } + /// Evaluate on an row of a table given in column-major form. pub fn eval_table(&self, table: &[PolynomialValues], row: usize) -> F { - self.linear_combination + let mut res = self + .linear_combination .iter() .map(|&(c, f)| table[c].values[row] * f) .sum::() - + self.constant + + self.constant; + + // If we access the next row at the last row, for sanity, we consider the next row's values to be 0. + // If CTLs are correctly written, the filter should be 0 in that case anyway. + if !self.next_row_linear_combination.is_empty() && row < table.len() - 1 { + res += self + .next_row_linear_combination + .iter() + .map(|&(c, f)| table[c].values[row + 1] * f) + .sum::(); + } + + res } pub fn eval_circuit( @@ -136,27 +211,50 @@ impl Column { let constant = builder.constant_extension(F::Extension::from_basefield(self.constant)); builder.inner_product_extension(F::ONE, constant, pairs) } + + pub fn eval_with_next_circuit( + &self, + builder: &mut CircuitBuilder, + v: &[ExtensionTarget], + next_v: &[ExtensionTarget], + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + let mut pairs = self + .linear_combination + .iter() + .map(|&(c, f)| { + ( + v[c], + builder.constant_extension(F::Extension::from_basefield(f)), + ) + }) + .collect::>(); + let next_row_pairs = self.next_row_linear_combination.iter().map(|&(c, f)| { + ( + next_v[c], + builder.constant_extension(F::Extension::from_basefield(f)), + ) + }); + pairs.extend(next_row_pairs); + let constant = builder.constant_extension(F::Extension::from_basefield(self.constant)); + builder.inner_product_extension(F::ONE, constant, pairs) + } } #[derive(Clone, Debug)] pub struct TableWithColumns { table: Table, - local_columns: Vec>, - next_columns: Vec>, + columns: Vec>, pub(crate) filter_column: Option>, } impl TableWithColumns { - pub fn new( - table: Table, - local_columns: Vec>, - next_columns: Vec>, - filter_column: Option>, - ) -> Self { + pub fn new(table: Table, columns: Vec>, filter_column: Option>) -> Self { Self { table, - local_columns, - next_columns, + columns, filter_column, } } @@ -175,8 +273,7 @@ impl CrossTableLookup { ) -> Self { assert!(looking_tables .iter() - .all(|twc| (twc.local_columns.len() + twc.next_columns.len()) - == (looked_table.local_columns.len() + looked_table.next_columns.len()))); + .all(|twc| twc.columns.len() == looked_table.columns.len())); Self { looking_tables, looked_table, @@ -204,8 +301,7 @@ pub struct CtlData { pub(crate) struct CtlZData { pub(crate) z: PolynomialValues, pub(crate) challenge: GrandProductChallenge, - pub(crate) local_columns: Vec>, - pub(crate) next_columns: Vec>, + pub(crate) columns: Vec>, pub(crate) filter_column: Option>, } @@ -242,16 +338,14 @@ pub(crate) fn cross_table_lookup_data( let zs_looking = looking_tables.iter().map(|table| { partial_products( &trace_poly_values[table.table as usize], - &table.local_columns, - &table.next_columns, + &table.columns, &table.filter_column, challenge, ) }); let z_looked = partial_products( &trace_poly_values[looked_table.table as usize], - &looked_table.local_columns, - &looked_table.next_columns, + &looked_table.columns, &looked_table.filter_column, challenge, ); @@ -261,8 +355,7 @@ pub(crate) fn cross_table_lookup_data( .push(CtlZData { z, challenge, - local_columns: table.local_columns.clone(), - next_columns: table.next_columns.clone(), + columns: table.columns.clone(), filter_column: table.filter_column.clone(), }); } @@ -271,8 +364,7 @@ pub(crate) fn cross_table_lookup_data( .push(CtlZData { z: z_looked, challenge, - local_columns: looked_table.local_columns.clone(), - next_columns: looked_table.next_columns.clone(), + columns: looked_table.columns.clone(), filter_column: looked_table.filter_column.clone(), }); } @@ -282,8 +374,7 @@ pub(crate) fn cross_table_lookup_data( fn partial_products( trace: &[PolynomialValues], - local_columns: &[Column], - next_columns: &[Column], + columns: &[Column], filter_column: &Option>, challenge: GrandProductChallenge, ) -> PolynomialValues { @@ -297,16 +388,9 @@ fn partial_products( F::ONE }; if filter.is_one() { - let evals = local_columns + let evals = columns .iter() .map(|c| c.eval_table(trace, i)) - .chain( - next_columns - .iter() - // The modulo is there to avoid out of bounds. For any CTL using next row - // values, we expect the filter to be 0 at the last row. - .map(|c| c.eval_table(trace, (i + 1) % degree)), - ) .collect::>(); partial_prod *= challenge.combine(evals.iter()); } else { @@ -328,8 +412,7 @@ where pub(crate) local_z: P, pub(crate) next_z: P, pub(crate) challenges: GrandProductChallenge, - pub(crate) local_columns: &'a [Column], - pub(crate) next_columns: &'a [Column], + pub(crate) columns: &'a [Column], pub(crate) filter_column: &'a Option>, } @@ -366,8 +449,7 @@ impl<'a, F: RichField + Extendable, const D: usize> local_z: *looking_z, next_z: *looking_z_next, challenges, - local_columns: &table.local_columns, - next_columns: &table.next_columns, + columns: &table.columns, filter_column: &table.filter_column, }); } @@ -377,8 +459,7 @@ impl<'a, F: RichField + Extendable, const D: usize> local_z: *looked_z, next_z: *looked_z_next, challenges, - local_columns: &looked_table.local_columns, - next_columns: &looked_table.next_columns, + columns: &looked_table.columns, filter_column: &looked_table.filter_column, }); } @@ -406,16 +487,14 @@ pub(crate) fn eval_cross_table_lookup_checks>(); - evals.extend(next_columns.iter().map(|c| c.eval(vars.next_values))); let combined = challenges.combine(evals.iter()); let local_filter = if let Some(column) = filter_column { column.eval(vars.local_values) @@ -436,8 +515,7 @@ pub struct CtlCheckVarsTarget<'a, F: Field, const D: usize> { pub(crate) local_z: ExtensionTarget, pub(crate) next_z: ExtensionTarget, pub(crate) challenges: GrandProductChallenge, - pub(crate) local_columns: &'a [Column], - pub(crate) next_columns: &'a [Column], + pub(crate) columns: &'a [Column], pub(crate) filter_column: &'a Option>, } @@ -473,8 +551,7 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { local_z: *looking_z, next_z: *looking_z_next, challenges, - local_columns: &looking_table.local_columns, - next_columns: &looking_table.next_columns, + columns: &looking_table.columns, filter_column: &looking_table.filter_column, }); } @@ -486,8 +563,7 @@ impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { local_z: *looked_z, next_z: *looked_z_next, challenges, - local_columns: &looked_table.local_columns, - next_columns: &looked_table.next_columns, + columns: &looked_table.columns, filter_column: &looked_table.filter_column, }); } @@ -513,8 +589,7 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< local_z, next_z, challenges, - local_columns, - next_columns, + columns, filter_column, } = lookup_vars; @@ -534,15 +609,10 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< builder.mul_add_extension(filter, x, tmp) // filter * x + 1 - filter } - let mut evals = local_columns + let evals = columns .iter() - .map(|c| c.eval_circuit(builder, vars.local_values)) + .map(|c| c.eval_with_next_circuit(builder, vars.local_values, vars.next_values)) .collect::>(); - evals.extend( - next_columns - .iter() - .map(|c| c.eval_circuit(builder, vars.next_values)), - ); let combined = challenges.combine_circuit(builder, &evals); let select = select(builder, local_filter, combined); @@ -692,15 +762,9 @@ pub(crate) mod testutils { }; if filter.is_one() { let row = table - .local_columns + .columns .iter() .map(|c| c.eval_table(trace, i)) - .chain( - table - .next_columns - .iter() - .map(|c| c.eval_table(trace, (i + 1) % trace[0].len())), - ) .collect::>(); multiset.entry(row).or_default().push((table.table, i)); } else { diff --git a/evm/src/keccak/keccak_stark.rs b/evm/src/keccak/keccak_stark.rs index 28e4e939..74f92622 100644 --- a/evm/src/keccak/keccak_stark.rs +++ b/evm/src/keccak/keccak_stark.rs @@ -763,8 +763,7 @@ mod tests { beta: F::ZERO, gamma: F::ZERO, }, - local_columns: vec![], - next_columns: vec![], + columns: vec![], filter_column: None, }; let ctl_data = CtlData { diff --git a/evm/src/prover.rs b/evm/src/prover.rs index 330dfcd1..0426da8e 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -589,8 +589,7 @@ where next_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_next_start, step) [num_permutation_zs + i], challenges: zs_columns.challenge, - local_columns: &zs_columns.local_columns, - next_columns: &zs_columns.next_columns, + columns: &zs_columns.columns, filter_column: &zs_columns.filter_column, }) .collect::>(); @@ -708,8 +707,7 @@ fn check_constraints<'a, F, C, S, const D: usize>( local_z: permutation_ctl_zs_subgroup_evals[i][num_permutation_zs + iii], next_z: permutation_ctl_zs_subgroup_evals[i_next][num_permutation_zs + iii], challenges: zs_columns.challenge, - local_columns: &zs_columns.local_columns, - next_columns: &zs_columns.next_columns, + columns: &zs_columns.columns, filter_column: &zs_columns.filter_column, }) .collect::>();