From 782d7d0e18e4d429ffb1da8bc4b962e8490bbaa4 Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Tue, 23 Aug 2022 12:22:54 -0700 Subject: [PATCH] Revert "Support accessing local row in CTLs" --- evm/src/all_stark.rs | 14 +-- evm/src/cpu/cpu_stark.rs | 10 +- evm/src/cross_table_lookup.rs | 179 +++++++++++++--------------------- 3 files changed, 73 insertions(+), 130 deletions(-) diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 2c6cb9cd..e2a11ba2 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -145,7 +145,6 @@ mod tests { use crate::cpu::cpu_stark::CpuStark; use crate::cpu::kernel::aggregator::KERNEL; use crate::cross_table_lookup::testutils::check_ctls; - use crate::cross_table_lookup::Column; use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; use crate::logic::{self, LogicStark, Operation}; use crate::memory::memory_stark::tests::generate_random_memory_ops; @@ -217,10 +216,8 @@ mod tests { .map(|i| { (0..2 * NUM_INPUTS) .map(|j| { - // There's an extra -1 because the argument to eval_table is the local row, - // but the inputs/outputs live in the next row. - let local_row = (i + 1) * NUM_ROUNDS - 1 - 1; - keccak::columns::reg_input_limb(j).eval_table(keccak_trace, local_row) + keccak::columns::reg_input_limb(j) + .eval_table(keccak_trace, (i + 1) * NUM_ROUNDS - 1) }) .collect::>() .try_into() @@ -231,11 +228,8 @@ mod tests { .map(|i| { (0..2 * NUM_INPUTS) .map(|j| { - let out_limb = Column::single(keccak::columns::reg_output_limb(j)); - // There's an extra -1 because the argument to eval_table is the local row, - // but the inputs/outputs live in the next row. - let local_row = (i + 1) * NUM_ROUNDS - 1 - 1; - out_limb.eval_table(keccak_trace, local_row) + keccak_trace[keccak::columns::reg_output_limb(j)].values + [(i + 1) * NUM_ROUNDS - 1] }) .collect::>() .try_into() diff --git a/evm/src/cpu/cpu_stark.rs b/evm/src/cpu/cpu_stark.rs index 9036d163..918f7d9b 100644 --- a/evm/src/cpu/cpu_stark.rs +++ b/evm/src/cpu/cpu_stark.rs @@ -10,7 +10,7 @@ use plonky2::hash::hash_types::RichField; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{CpuColumnsView, COL_MAP, NUM_CPU_COLUMNS}; use crate::cpu::{bootstrap_kernel, control_flow, decode, jumps, simple_logic, syscalls}; -use crate::cross_table_lookup::{Column, WeightedColumn}; +use crate::cross_table_lookup::Column; use crate::memory::NUM_CHANNELS; use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; @@ -50,14 +50,10 @@ pub fn ctl_data_memory(channel: usize) -> Vec> { .collect_vec(); cols.extend(Column::singles(COL_MAP.mem_value[channel])); - let weight = F::from_canonical_usize(NUM_CHANNELS); + let scalar = F::from_canonical_usize(NUM_CHANNELS); let addend = F::from_canonical_usize(channel); cols.push(Column::linear_combination_with_constant( - vec![WeightedColumn { - column: COL_MAP.clock, - next: true, - weight, - }], + vec![(COL_MAP.clock, scalar)], addend, )); diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index ef781848..4097df7b 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -1,3 +1,5 @@ +use std::iter::repeat; + use anyhow::{ensure, Result}; use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; @@ -22,30 +24,16 @@ use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; /// Represent a linear combination of columns. -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Column { - linear_combination: Vec>, + linear_combination: Vec<(usize, F)>, constant: F, } -#[derive(Clone, Debug)] -pub(crate) struct WeightedColumn { - /// The index of the column. - pub(crate) column: usize, - /// True if this column refers to a column in the next row, rather than the local row. - /// Most CTLs consist of only "next" columns. - pub(crate) next: bool, - pub(crate) weight: F, -} - impl Column { - pub fn single(column: usize) -> Self { + pub fn single(c: usize) -> Self { Self { - linear_combination: vec![WeightedColumn { - column, - next: true, - weight: F::ONE, - }], + linear_combination: vec![(c, F::ONE)], constant: F::ZERO, } } @@ -54,17 +42,14 @@ impl Column { cs.into_iter().map(Self::single) } - pub(crate) fn linear_combination_with_constant>>( + pub fn linear_combination_with_constant>( iter: I, constant: F, ) -> Self { let v = iter.into_iter().collect::>(); assert!(!v.is_empty()); debug_assert_eq!( - v.iter() - .map(|weighted_col| weighted_col.column) - .unique() - .count(), + v.iter().map(|(c, _)| c).unique().count(), v.len(), "Duplicate columns." ); @@ -74,64 +59,35 @@ impl Column { } } - pub(crate) fn linear_combination>>(iter: I) -> Self { + pub fn linear_combination>(iter: I) -> Self { Self::linear_combination_with_constant(iter, F::ZERO) } pub fn le_bits>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().zip(F::TWO.powers()).map(|(column, weight)| { - WeightedColumn { - column, - next: true, - weight, - } - })) + Self::linear_combination(cs.into_iter().zip(F::TWO.powers())) } pub fn sum>(cs: I) -> Self { - Self::linear_combination(cs.into_iter().map(|column| WeightedColumn { - column, - next: true, - weight: F::ONE, - })) + Self::linear_combination(cs.into_iter().zip(repeat(F::ONE))) } - pub fn eval(&self, local_values: &[P], next_values: &[P]) -> P + pub fn eval(&self, v: &[P]) -> P where FE: FieldExtension, P: PackedField, { self.linear_combination .iter() - .map(|weighted_col| { - let values = if weighted_col.next { - next_values - } else { - local_values - }; - values[weighted_col.column] * FE::from_basefield(weighted_col.weight) - }) + .map(|&(c, f)| v[c] * FE::from_basefield(f)) .sum::

() + FE::from_basefield(self.constant) } /// Evaluate on an row of a table given in column-major form. - pub fn eval_table(&self, table: &[PolynomialValues], local_row: usize) -> F { - let degree = table[0].len(); - debug_assert!(degree.is_power_of_two()); - let next_row = (local_row + 1) & (degree - 1); // Equivalent to % degree. - + pub fn eval_table(&self, table: &[PolynomialValues], row: usize) -> F { self.linear_combination .iter() - .map(|weighted_col| { - let row = if weighted_col.next { - next_row - } else { - local_row - }; - let poly = &table[weighted_col.column]; - poly.values[row] * weighted_col.weight - }) + .map(|&(c, f)| table[c].values[row] * f) .sum::() + self.constant } @@ -139,8 +95,7 @@ impl Column { pub fn eval_circuit( &self, builder: &mut CircuitBuilder, - local_values: &[ExtensionTarget], - next_values: &[ExtensionTarget], + v: &[ExtensionTarget], ) -> ExtensionTarget where F: RichField + Extendable, @@ -148,15 +103,10 @@ impl Column { let pairs = self .linear_combination .iter() - .map(|weighted_col| { - let values = if weighted_col.next { - next_values - } else { - local_values - }; + .map(|&(c, f)| { ( - values[weighted_col.column], - builder.constant_extension(F::Extension::from_basefield(weighted_col.weight)), + v[c], + builder.constant_extension(F::Extension::from_basefield(f)), ) }) .collect::>(); @@ -165,7 +115,7 @@ impl Column { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct TableWithColumns { table: Table, columns: Vec>, @@ -182,7 +132,7 @@ impl TableWithColumns { } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct CrossTableLookup { looking_tables: Vec>, looked_table: TableWithColumns, @@ -329,19 +279,16 @@ fn partial_products( let mut partial_prod = F::ONE; let degree = trace[0].len(); let mut res = Vec::with_capacity(degree); - for next_row in 0..degree { - debug_assert!(degree.is_power_of_two()); - let local_row = (next_row + degree - 1) & (degree - 1); // Equivalent to % degree. - + for i in 0..degree { let filter = if let Some(column) = filter_column { - column.eval_table(trace, local_row) + column.eval_table(trace, i) } else { F::ONE }; if filter.is_one() { let evals = columns .iter() - .map(|c| c.eval_table(trace, local_row)) + .map(|c| c.eval_table(trace, i)) .collect::>(); partial_prod *= challenge.combine(evals.iter()); } else { @@ -439,23 +386,27 @@ pub(crate) fn eval_cross_table_lookup_checks>(); - let combined = challenges.combine(evals.iter()); - let filter = if let Some(column) = filter_column { - column.eval(vars.local_values, vars.next_values) - } else { - P::ONES + let combine = |v: &[P]| -> P { + let evals = columns.iter().map(|c| c.eval(v)).collect::>(); + challenges.combine(evals.iter()) }; - let multiplier = filter * combined + P::ONES - filter; + let filter = |v: &[P]| -> P { + if let Some(column) = filter_column { + column.eval(v) + } else { + P::ONES + } + }; + let local_filter = filter(vars.local_values); + let next_filter = filter(vars.next_values); + let select = |filter, x| filter * x + P::ONES - filter; // Check value of `Z(1)` - consumer.constraint_last_row(*next_z - multiplier); + consumer.constraint_first_row(*local_z - select(local_filter, combine(vars.local_values))); // Check `Z(gw) = combination * Z(w)` - consumer.constraint_transition(*next_z - *local_z * multiplier); + consumer.constraint_transition( + *next_z - *local_z * select(next_filter, combine(vars.next_values)), + ); } } @@ -540,12 +491,16 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< } = lookup_vars; let one = builder.one_extension(); - let filter = if let Some(column) = filter_column { - column.eval_circuit(builder, vars.local_values, vars.next_values) + let local_filter = if let Some(column) = filter_column { + column.eval_circuit(builder, vars.local_values) + } else { + one + }; + let next_filter = if let Some(column) = filter_column { + column.eval_circuit(builder, vars.next_values) } else { one }; - // TODO: Can use builder.select_ext_generalized. fn select, const D: usize>( builder: &mut CircuitBuilder, filter: ExtensionTarget, @@ -557,17 +512,23 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< } // Check value of `Z(1)` - let evals = columns + let local_columns_eval = columns .iter() - .map(|c| c.eval_circuit(builder, vars.local_values, vars.next_values)) + .map(|c| c.eval_circuit(builder, vars.local_values)) .collect::>(); - let combined = challenges.combine_circuit(builder, &evals); - let multiplier = select(builder, filter, combined); - let first_row = builder.sub_extension(*next_z, multiplier); - consumer.constraint_last_row(builder, first_row); + let combined_local = challenges.combine_circuit(builder, &local_columns_eval); + let selected_local = select(builder, local_filter, combined_local); + let first_row = builder.sub_extension(*local_z, selected_local); + consumer.constraint_first_row(builder, first_row); // Check `Z(gw) = combination * Z(w)` - let product = builder.mul_extension(*local_z, multiplier); - let transition = builder.sub_extension(*next_z, product); + let next_columns_eval = columns + .iter() + .map(|c| c.eval_circuit(builder, vars.next_values)) + .collect::>(); + let combined_next = challenges.combine_circuit(builder, &next_columns_eval); + let selected_next = select(builder, next_filter, combined_next); + let mut transition = builder.mul_extension(*local_z, selected_next); + transition = builder.sub_extension(*next_z, transition); consumer.constraint_transition(builder, transition); } } @@ -785,14 +746,9 @@ pub(crate) mod testutils { multiset: &mut MultiSet, ) { let trace = &trace_poly_values[table.table as usize]; - let degree = trace[0].len(); - - for next_row in 0..trace[0].len() { - debug_assert!(degree.is_power_of_two()); - let local_row = (next_row + degree - 1) & (degree - 1); // Equivalent to % degree. - + for i in 0..trace[0].len() { let filter = if let Some(column) = &table.filter_column { - column.eval_table(trace, local_row) + column.eval_table(trace, i) } else { F::ONE }; @@ -800,12 +756,9 @@ pub(crate) mod testutils { let row = table .columns .iter() - .map(|c| c.eval_table(trace, local_row)) + .map(|c| c.eval_table(trace, i)) .collect::>(); - multiset - .entry(row) - .or_default() - .push((table.table, local_row)); + multiset.entry(row).or_default().push((table.table, i)); } else { assert_eq!(filter, F::ZERO, "Non-binary filter?") }