From a37dec9881e2d6ff4d02cf998f1da19c40b0150e Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 22 Aug 2022 11:07:39 -0700 Subject: [PATCH] Support accessing previous row in CTLs --- evm/src/all_stark.rs | 14 ++- evm/src/cpu/cpu_stark.rs | 10 +- evm/src/cross_table_lookup.rs | 186 ++++++++++++++++++++++------------ 3 files changed, 137 insertions(+), 73 deletions(-) diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index e2a11ba2..2c6cb9cd 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -145,6 +145,7 @@ mod tests { use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::aggregator::KERNEL; use crate::cross_table_lookup::testutils::check_ctls; + use crate::cross_table_lookup::Column; use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; use crate::logic::{self, LogicStark, Operation}; use crate::memory::memory_stark::tests::generate_random_memory_ops; @@ -216,8 +217,10 @@ mod tests { .map(|i| { (0..2 * NUM_INPUTS) .map(|j| { - keccak::columns::reg_input_limb(j) - .eval_table(keccak_trace, (i + 1) * NUM_ROUNDS - 1) + // There's an extra -1 because the argument to eval_table is the local row, + // but the inputs/outputs live in the next row. + let local_row = (i + 1) * NUM_ROUNDS - 1 - 1; + keccak::columns::reg_input_limb(j).eval_table(keccak_trace, local_row) }) .collect::>() .try_into() @@ -228,8 +231,11 @@ mod tests { .map(|i| { (0..2 * NUM_INPUTS) .map(|j| { - keccak_trace[keccak::columns::reg_output_limb(j)].values - [(i + 1) * NUM_ROUNDS - 1] + let out_limb = Column::single(keccak::columns::reg_output_limb(j)); + // There's an extra -1 because the argument to eval_table is the local row, + // but the inputs/outputs live in the next row. + let local_row = (i + 1) * NUM_ROUNDS - 1 - 1; + out_limb.eval_table(keccak_trace, local_row) }) .collect::>() .try_into() diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 918f7d9b..9036d163 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -10,7 +10,7 @@ use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::{bootstrap_kernel, control_flow, decode, jumps, simple_logic, syscalls}; -use crate::cross_table_lookup::Column; +use crate::cross_table_lookup::{Column, WeightedColumn}; use crate::memory::NUM_CHANNELS; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -50,10 +50,14 @@ pub fn ctl_data_memory(channel: usize) -> Vec> { .collect_vec(); cols.extend(Column::singles(COL_MAP.mem_value[channel])); - let scalar = F::from_canonical_usize(NUM_CHANNELS); + let weight = F::from_canonical_usize(NUM_CHANNELS); let addend = F::from_canonical_usize(channel); cols.push(Column::linear_combination_with_constant( - vec![(COL_MAP.clock, scalar)], + vec![WeightedColumn { + column: COL_MAP.clock, + next: true, + weight, + }], addend, )); diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 4097df7b..7f55c40a 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -1,5 +1,3 @@ -use std::iter::repeat; - use anyhow::{ensure, Result}; use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; @@ -24,16 +22,30 @@ use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// Represent a linear combination of columns. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Column { - linear_combination: Vec<(usize, F)>, + linear_combination: Vec>, constant: F, } +#[derive(Clone, Debug)] +pub(crate) struct WeightedColumn { + /// The index of the column. + pub(crate) column: usize, + /// True if this column refers to a column in the next row, rather than the local row. + /// Most CTLs consist of only "next" columns. + pub(crate) next: bool, + pub(crate) weight: F, +} + impl Column { - pub fn single(c: usize) -> Self { + pub fn single(column: usize) -> Self { Self { - linear_combination: vec![(c, F::ONE)], + linear_combination: vec![WeightedColumn { + column, + next: true, + weight: F::ONE, + }], constant: F::ZERO, } } @@ -42,14 +54,17 @@ impl Column { cs.into_iter().map(Self::single) } - pub fn linear_combination_with_constant>( + pub(crate) fn linear_combination_with_constant>>( iter: I, constant: F, ) -> Self { let v = iter.into_iter().collect::>(); assert!(!v.is_empty()); debug_assert_eq!( - v.iter().map(|(c, _)| c).unique().count(), + v.iter() + .map(|weighted_col| weighted_col.column) + .unique() + .count(), v.len(), "Duplicate columns." ); @@ -59,35 +74,65 @@ impl Column { } } - pub fn linear_combination>(iter: I) -> Self { + pub(crate) fn linear_combination>>(iter: I) -> Self { Self::linear_combination_with_constant(iter, F::ZERO) } pub fn le_bits>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().zip(F::TWO.powers())) + Self::linear_combination(cs.into_iter().zip(F::TWO.powers()).map(|(column, weight)| { + WeightedColumn { + column, + next: true, + weight, + } + })) } pub fn sum>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().zip(repeat(F::ONE))) + Self::linear_combination(cs.into_iter().map(|column| WeightedColumn { + column, + next: true, + weight: F::ONE, + })) } - pub fn eval(&self, v: &[P]) -> P + pub fn eval(&self, local_values: &[P], next_values: &[P]) -> P where FE: FieldExtension, P: PackedField, { self.linear_combination .iter() - .map(|&(c, f)| v[c] * FE::from_basefield(f)) + .map(|weighted_col| { + let values = if weighted_col.next { + next_values + } else { + local_values + }; + values[weighted_col.column] * FE::from_basefield(weighted_col.weight) + }) .sum::

() + FE::from_basefield(self.constant) } /// Evaluate on an row of a table given in column-major form. - pub fn eval_table(&self, table: &[PolynomialValues], row: usize) -> F { + pub fn eval_table(&self, table: &[PolynomialValues], local_row: usize) -> F { + let mut next_row = local_row + 1; + if next_row == table[0].len() { + next_row = 0; + } + self.linear_combination .iter() - .map(|&(c, f)| table[c].values[row] * f) + .map(|weighted_col| { + let row = if weighted_col.next { + next_row + } else { + local_row + }; + let poly = &table[weighted_col.column]; + poly.values[row] * weighted_col.weight + }) .sum::() + self.constant } @@ -95,7 +140,8 @@ impl Column { pub fn eval_circuit( &self, builder: &mut CircuitBuilder, - v: &[ExtensionTarget], + local_values: &[ExtensionTarget], + next_values: &[ExtensionTarget], ) -> ExtensionTarget where F: RichField + Extendable, @@ -103,10 +149,15 @@ impl Column { let pairs = self .linear_combination .iter() - .map(|&(c, f)| { + .map(|weighted_col| { + let values = if weighted_col.next { + next_values + } else { + local_values + }; ( - v[c], - builder.constant_extension(F::Extension::from_basefield(f)), + values[weighted_col.column], + builder.constant_extension(F::Extension::from_basefield(weighted_col.weight)), ) }) .collect::>(); @@ -115,7 +166,7 @@ impl Column { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct TableWithColumns { table: Table, columns: Vec>, @@ -132,7 +183,7 @@ impl TableWithColumns { } } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct CrossTableLookup { looking_tables: Vec>, looked_table: TableWithColumns, @@ -279,16 +330,22 @@ fn partial_products( let mut partial_prod = F::ONE; let degree = trace[0].len(); let mut res = Vec::with_capacity(degree); - for i in 0..degree { + for next_row in 0..degree { + let local_row = if next_row == 0 { + degree - 1 + } else { + next_row - 1 + }; + let filter = if let Some(column) = filter_column { - column.eval_table(trace, i) + column.eval_table(trace, local_row) } else { F::ONE }; if filter.is_one() { let evals = columns .iter() - .map(|c| c.eval_table(trace, i)) + .map(|c| c.eval_table(trace, local_row)) .collect::>(); partial_prod *= challenge.combine(evals.iter()); } else { @@ -386,27 +443,23 @@ pub(crate) fn eval_cross_table_lookup_checks P { - let evals = columns.iter().map(|c| c.eval(v)).collect::>(); - challenges.combine(evals.iter()) + // TODO: Avoid collecting here. + let evals = columns + .iter() + .map(|c| c.eval(vars.local_values, vars.next_values)) + .collect::>(); + let combined = challenges.combine(evals.iter()); + let filter = if let Some(column) = filter_column { + column.eval(vars.local_values, vars.next_values) + } else { + P::ONES }; - let filter = |v: &[P]| -> P { - if let Some(column) = filter_column { - column.eval(v) - } else { - P::ONES - } - }; - let local_filter = filter(vars.local_values); - let next_filter = filter(vars.next_values); - let select = |filter, x| filter * x + P::ONES - filter; + let multiplier = filter * combined + P::ONES - filter; // Check value of `Z(1)` - consumer.constraint_first_row(*local_z - select(local_filter, combine(vars.local_values))); + consumer.constraint_last_row(*next_z - multiplier); // Check `Z(gw) = combination * Z(w)` - consumer.constraint_transition( - *next_z - *local_z * select(next_filter, combine(vars.next_values)), - ); + consumer.constraint_transition(*next_z - *local_z * multiplier); } } @@ -491,16 +544,12 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< } = lookup_vars; let one = builder.one_extension(); - let local_filter = if let Some(column) = filter_column { - column.eval_circuit(builder, vars.local_values) - } else { - one - }; - let next_filter = if let Some(column) = filter_column { - column.eval_circuit(builder, vars.next_values) + let filter = if let Some(column) = filter_column { + column.eval_circuit(builder, vars.local_values, vars.next_values) } else { one }; + // TODO: Can use builder.select_ext_generalized. fn select, const D: usize>( builder: &mut CircuitBuilder, filter: ExtensionTarget, @@ -512,23 +561,17 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< } // Check value of `Z(1)` - let local_columns_eval = columns + let evals = columns .iter() - .map(|c| c.eval_circuit(builder, vars.local_values)) + .map(|c| c.eval_circuit(builder, vars.local_values, vars.next_values)) .collect::>(); - let combined_local = challenges.combine_circuit(builder, &local_columns_eval); - let selected_local = select(builder, local_filter, combined_local); - let first_row = builder.sub_extension(*local_z, selected_local); - consumer.constraint_first_row(builder, first_row); + let combined = challenges.combine_circuit(builder, &evals); + let multiplier = select(builder, filter, combined); + let first_row = builder.sub_extension(*next_z, multiplier); + consumer.constraint_last_row(builder, first_row); // Check `Z(gw) = combination * Z(w)` - let next_columns_eval = columns - .iter() - .map(|c| c.eval_circuit(builder, vars.next_values)) - .collect::>(); - let combined_next = challenges.combine_circuit(builder, &next_columns_eval); - let selected_next = select(builder, next_filter, combined_next); - let mut transition = builder.mul_extension(*local_z, selected_next); - transition = builder.sub_extension(*next_z, transition); + let product = builder.mul_extension(*local_z, multiplier); + let transition = builder.sub_extension(*next_z, product); consumer.constraint_transition(builder, transition); } } @@ -746,9 +789,17 @@ pub(crate) mod testutils { multiset: &mut MultiSet, ) { let trace = &trace_poly_values[table.table as usize]; - for i in 0..trace[0].len() { + let degree = trace[0].len(); + + for next_row in 0..trace[0].len() { + let local_row = if next_row == 0 { + degree - 1 + } else { + next_row - 1 + }; + let filter = if let Some(column) = &table.filter_column { - column.eval_table(trace, i) + column.eval_table(trace, local_row) } else { F::ONE }; @@ -756,9 +807,12 @@ pub(crate) mod testutils { let row = table .columns .iter() - .map(|c| c.eval_table(trace, i)) + .map(|c| c.eval_table(trace, local_row)) .collect::>(); - multiset.entry(row).or_default().push((table.table, i)); + multiset + .entry(row) + .or_default() + .push((table.table, local_row)); } else { assert_eq!(filter, F::ZERO, "Non-binary filter?") }