diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index fd789bb1..b5e825fd 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -62,7 +62,7 @@ mod tests { use crate::config::StarkConfig; use crate::cpu::columns::{KECCAK_INPUT_LIMBS, KECCAK_OUTPUT_LIMBS}; use crate::cpu::cpu_stark::CpuStark; - use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns}; + use crate::cross_table_lookup::{Column, CrossTableLookup, TableWithColumns}; use crate::keccak::keccak_stark::{KeccakStark, NUM_INPUTS, NUM_ROUNDS}; use crate::proof::AllProof; use crate::prover::prove; @@ -148,13 +148,13 @@ mod tests { let cross_table_lookups = vec![CrossTableLookup::new( vec![TableWithColumns::new( Table::Cpu, - cpu_keccak_input_output, - vec![cpu::columns::IS_KECCAK], + Column::singles(cpu_keccak_input_output), + Column::single(cpu::columns::IS_KECCAK), )], TableWithColumns::new( Table::Keccak, - keccak_keccak_input_output, - vec![keccak::registers::reg_step(NUM_ROUNDS - 1)], + Column::singles(keccak_keccak_input_output), + Column::single(keccak::registers::reg_step(NUM_ROUNDS - 1)), ), None, )]; diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index 24d0878e..97610c5f 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -1,3 +1,5 @@ +use std::iter::repeat; + use anyhow::{ensure, Result}; use itertools::Itertools; use plonky2::field::extension_field::{Extendable, FieldExtension}; @@ -21,42 +23,129 @@ use crate::proof::{StarkProofWithPublicInputs, StarkProofWithPublicInputsTarget} use crate::stark::Stark; use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; +/// Represent a column or a linear combination of columns. #[derive(Clone)] -pub struct TableWithColumns { - table: Table, - columns: Vec, - /// Vector of columns `[c_1,...,c_k]` used as a filter using the sum `c_1 + ... + c_k`. - /// An empty vector corresponds to no filter. - filter_columns: Vec, +pub enum Column { + Single(usize), + LinearCombination(Vec<(usize, F)>), + Empty, } -impl TableWithColumns { - pub fn new(table: Table, columns: Vec, filter_columns: Vec) -> Self { +impl Column { + pub fn single(c: usize) -> Self { + Self::Single(c) + } + + pub fn singles(cs: Vec) -> Vec { + cs.into_iter().map(Self::single).collect() + } + + pub fn linear_combination>(iter: I) -> Self { + let v = iter.into_iter().collect::>(); + assert!(!v.is_empty()); debug_assert_eq!( - filter_columns.iter().unique().count(), - filter_columns.len(), + v.iter().map(|(c, _)| c).unique().count(), + v.len(), "Duplicate filter columns." ); + Self::LinearCombination(v) + } + + pub fn le_bits(cs: &[usize]) -> Self { + Self::linear_combination(cs.iter().copied().zip(F::TWO.powers())) + } + + pub fn sum(cs: &[usize]) -> Self { + Self::linear_combination(cs.iter().copied().zip(repeat(F::ONE))) + } + + pub fn is_empty(&self) -> bool { + matches!(self, Self::Empty) + } + + pub fn eval(&self, v: &[P]) -> P + where + FE: FieldExtension, + P: PackedField, + { + match self { + Column::Single(c) => v[*c], + Column::LinearCombination(cs) => { + cs.iter().map(|&(c, f)| v[c] * FE::from_basefield(f)).sum() + } + Column::Empty => panic!("Cannot eval with empty column."), + } + } + + /// Evaluate on an row of a table given in column-major form. + pub fn eval_table(&self, table: &[PolynomialValues], row: usize) -> F { + match self { + Column::Single(c) => table[*c].values[row], + Column::LinearCombination(cs) => { + cs.iter().map(|&(c, f)| table[c].values[row] * f).sum() + } + Column::Empty => panic!("Cannot eval with empty column."), + } + } + + pub fn eval_circuit( + &self, + builder: &mut CircuitBuilder, + v: &[ExtensionTarget], + ) -> ExtensionTarget + where + F: RichField + Extendable, + { + match self { + Column::Single(c) => v[*c], + Column::LinearCombination(cs) => { + let pairs = cs + .iter() + .map(|&(c, f)| { + ( + v[c], + builder.constant_extension(F::Extension::from_basefield(f)), + ) + }) + .collect::>(); + let zero = builder.zero_extension(); + builder.inner_product_extension(F::ONE, zero, pairs) + } + Column::Empty => panic!("Cannot eval with empty column."), + } + } +} + +#[derive(Clone)] +pub struct TableWithColumns { + table: Table, + columns: Vec>, + filter_column: Column, +} + +impl TableWithColumns { + pub fn new(table: Table, columns: Vec>, filter_column: Column) -> Self { + assert!(columns.iter().all(|c| !c.is_empty())); Self { table, columns, - filter_columns, + filter_column, } } } #[derive(Clone)] pub struct CrossTableLookup { - looking_tables: Vec, - looked_table: TableWithColumns, + looking_tables: Vec>, + looked_table: TableWithColumns, /// Default value if filters are not used. default: Option>, } impl CrossTableLookup { pub fn new( - looking_tables: Vec, - looked_table: TableWithColumns, + looking_tables: Vec>, + looked_table: TableWithColumns, default: Option>, ) -> Self { assert!(looking_tables @@ -65,8 +154,8 @@ impl CrossTableLookup { assert!( looking_tables .iter() - .all(|twc| twc.filter_columns.is_empty() == default.is_some()) - && default.is_some() == looked_table.filter_columns.is_empty(), + .all(|twc| twc.filter_column.is_empty() == default.is_some()) + && default.is_some() == looked_table.filter_column.is_empty(), "Default values should be provided iff there are no filter columns." ); if let Some(default) = &default { @@ -87,7 +176,7 @@ pub struct CtlData { pub(crate) challenges: GrandProductChallengeSet, /// Vector of `(Z, columns, filter_columns)` where `Z` is a Z-polynomial for a lookup /// on columns `columns` with filter columns `filter_columns`. - pub zs_columns: Vec<(PolynomialValues, Vec, Vec)>, + pub zs_columns: Vec<(PolynomialValues, Vec>, Column)>, } impl CtlData { @@ -130,14 +219,14 @@ pub fn cross_table_lookup_data, const D partial_products( &trace_poly_values[table.table as usize], &table.columns, - &table.filter_columns, + &table.filter_column, challenge, ) }); let z_looked = partial_products( &trace_poly_values[looked_table.table as usize], &looked_table.columns, - &looked_table.filter_columns, + &looked_table.filter_column, challenge, ); @@ -168,7 +257,7 @@ pub fn cross_table_lookup_data, const D ctl_data_per_table[table.table as usize].zs_columns.push(( z, table.columns.clone(), - table.filter_columns.clone(), + table.filter_column.clone(), )); } ctl_data_per_table[looked_table.table as usize] @@ -176,7 +265,7 @@ pub fn cross_table_lookup_data, const D .push(( z_looked, looked_table.columns.clone(), - looked_table.filter_columns.clone(), + looked_table.filter_column.clone(), )); } } @@ -185,21 +274,25 @@ pub fn cross_table_lookup_data, const D fn partial_products( trace: &[PolynomialValues], - columns: &[usize], - filter_columns: &[usize], + columns: &[Column], + filter_column: &Column, challenge: GrandProductChallenge, ) -> PolynomialValues { let mut partial_prod = F::ONE; let degree = trace[0].len(); let mut res = Vec::with_capacity(degree); for i in 0..degree { - let filter = if filter_columns.is_empty() { + let filter = if filter_column.is_empty() { F::ONE } else { - filter_columns.iter().map(|&j| trace[j].values[i]).sum() + filter_column.eval_table(trace, i) }; if filter.is_one() { - partial_prod *= challenge.combine(columns.iter().map(|&j| &trace[j].values[i])); + let evals = columns + .iter() + .map(|c| c.eval_table(trace, i)) + .collect::>(); + partial_prod *= challenge.combine(evals.iter()); } else { assert_eq!(filter, F::ZERO, "Non-binary filter?") }; @@ -218,8 +311,8 @@ where pub(crate) local_z: P, pub(crate) next_z: P, pub(crate) challenges: GrandProductChallenge, - pub(crate) columns: &'a [usize], - pub(crate) filter_columns: &'a [usize], + pub(crate) columns: &'a [Column], + pub(crate) filter_column: &'a Column, } impl<'a, F: RichField + Extendable, const D: usize> @@ -258,7 +351,7 @@ impl<'a, F: RichField + Extendable, const D: usize> next_z: *looking_z_next, challenges, columns: &table.columns, - filter_columns: &table.filter_columns, + filter_column: &table.filter_column, }); } @@ -268,7 +361,7 @@ impl<'a, F: RichField + Extendable, const D: usize> next_z: *looked_z_next, challenges, columns: &looked_table.columns, - filter_columns: &looked_table.filter_columns, + filter_column: &looked_table.filter_column, }); } } @@ -293,14 +386,17 @@ pub(crate) fn eval_cross_table_lookup_checks P { challenges.combine(columns.iter().map(|&i| &v[i])) }; + let combine = |v: &[P]| -> P { + let evals = columns.iter().map(|c| c.eval(v)).collect::>(); + challenges.combine(evals.iter()) + }; let filter = |v: &[P]| -> P { - if filter_columns.is_empty() { + if filter_column.is_empty() { P::ONES } else { - filter_columns.iter().map(|&i| v[i]).sum() + filter_column.eval(v) } }; let local_filter = filter(vars.local_values); @@ -317,16 +413,16 @@ pub(crate) fn eval_cross_table_lookup_checks { +pub struct CtlCheckVarsTarget<'a, F: Field, const D: usize> { pub(crate) local_z: ExtensionTarget, pub(crate) next_z: ExtensionTarget, pub(crate) challenges: GrandProductChallenge, - pub(crate) columns: &'a [usize], - pub(crate) filter_columns: &'a [usize], + pub(crate) columns: &'a [Column], + pub(crate) filter_column: &'a Column, } -impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> { - pub(crate) fn from_proofs( +impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<'a, F, D> { + pub(crate) fn from_proofs( proofs: &[StarkProofWithPublicInputsTarget], cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, @@ -359,7 +455,7 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> { next_z: *looking_z_next, challenges, columns: &table.columns, - filter_columns: &table.filter_columns, + filter_column: &table.filter_column, }); } @@ -369,7 +465,7 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> { next_z: *looked_z_next, challenges, columns: &looked_table.columns, - filter_columns: &looked_table.filter_columns, + filter_column: &looked_table.filter_column, }); } } @@ -384,7 +480,7 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< >( builder: &mut CircuitBuilder, vars: StarkEvaluationTargets, - ctl_vars: &[CtlCheckVarsTarget], + ctl_vars: &[CtlCheckVarsTarget], consumer: &mut RecursiveConstraintConsumer, ) { for lookup_vars in ctl_vars { @@ -393,19 +489,19 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< next_z, challenges, columns, - filter_columns, + filter_column, } = lookup_vars; let one = builder.one_extension(); - let local_filter = if filter_columns.is_empty() { + let local_filter = if filter_column.is_empty() { one } else { - builder.add_many_extension(filter_columns.iter().map(|&i| vars.local_values[i])) + filter_column.eval_circuit(builder, vars.local_values) }; - let next_filter = if filter_columns.is_empty() { + let next_filter = if filter_column.is_empty() { one } else { - builder.add_many_extension(filter_columns.iter().map(|&i| vars.next_values[i])) + filter_column.eval_circuit(builder, vars.next_values) }; fn select, const D: usize>( builder: &mut CircuitBuilder, @@ -418,24 +514,20 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit< } // Check value of `Z(1)` - let combined_local = challenges.combine_circuit( - builder, - &columns - .iter() - .map(|&i| vars.local_values[i]) - .collect::>(), - ); + let local_columns_eval = columns + .iter() + .map(|c| c.eval_circuit(builder, vars.local_values)) + .collect::>(); + let combined_local = challenges.combine_circuit(builder, &local_columns_eval); let selected_local = select(builder, local_filter, combined_local); let first_row = builder.sub_extension(*local_z, selected_local); consumer.constraint_first_row(builder, first_row); // Check `Z(gw) = combination * Z(w)` - let combined_next = challenges.combine_circuit( - builder, - &columns - .iter() - .map(|&i| vars.next_values[i]) - .collect::>(), - ); + let next_columns_eval = columns + .iter() + .map(|c| c.eval_circuit(builder, vars.next_values)) + .collect::>(); + let combined_next = challenges.combine_circuit(builder, &next_columns_eval); let selected_next = select(builder, next_filter, combined_next); let mut transition = builder.mul_extension(*local_z, selected_next); transition = builder.sub_extension(*next_z, transition); diff --git a/evm/src/prover.rs b/evm/src/prover.rs index cf0b9fae..facf22f1 100644 --- a/evm/src/prover.rs +++ b/evm/src/prover.rs @@ -392,14 +392,14 @@ where .iter() .enumerate() .map( - |(i, (_, columns, filter_columns))| CtlCheckVars:: { + |(i, (_, columns, filter_column))| CtlCheckVars:: { local_z: permutation_ctl_zs_commitment.get_lde_values_packed(i_start, step) [num_permutation_zs + i], next_z: permutation_ctl_zs_commitment .get_lde_values_packed(i_next_start, step)[num_permutation_zs + i], challenges: ctl_data.challenges.challenges[i % config.num_challenges], columns, - filter_columns, + filter_column, }, ) .collect::>(); @@ -510,14 +510,14 @@ fn check_constraints<'a, F, C, S, const D: usize>( .iter() .enumerate() .map( - |(iii, (_, columns, filter_columns))| CtlCheckVars:: { + |(iii, (_, columns, filter_column))| CtlCheckVars:: { local_z: get_comm_values(permutation_ctl_zs_commitment, i) [num_permutation_zs + iii], next_z: get_comm_values(permutation_ctl_zs_commitment, i_next) [num_permutation_zs + iii], challenges: ctl_data.challenges.challenges[iii % config.num_challenges], columns, - filter_columns, + filter_column, }, ) .collect::>(); diff --git a/evm/src/recursive_verifier.rs b/evm/src/recursive_verifier.rs index be5016d5..ae12aa3a 100644 --- a/evm/src/recursive_verifier.rs +++ b/evm/src/recursive_verifier.rs @@ -100,7 +100,7 @@ fn verify_stark_proof_with_challenges_circuit< stark: S, proof_with_pis: &StarkProofWithPublicInputsTarget, challenges: &StarkProofChallengesTarget, - ctl_vars: &[CtlCheckVarsTarget], + ctl_vars: &[CtlCheckVarsTarget], inner_config: &StarkConfig, ) where C::Hasher: AlgebraicHasher, diff --git a/evm/src/vanishing_poly.rs b/evm/src/vanishing_poly.rs index 31e1ffa9..16417918 100644 --- a/evm/src/vanishing_poly.rs +++ b/evm/src/vanishing_poly.rs @@ -50,7 +50,7 @@ pub(crate) fn eval_vanishing_poly_circuit( config: &StarkConfig, vars: StarkEvaluationTargets, permutation_data: Option>, - ctl_vars: &[CtlCheckVarsTarget], + ctl_vars: &[CtlCheckVarsTarget], consumer: &mut RecursiveConstraintConsumer, ) where F: RichField + Extendable, diff --git a/plonky2/src/gadgets/split_base.rs b/plonky2/src/gadgets/split_base.rs index 5ceede66..8cb113c4 100644 --- a/plonky2/src/gadgets/split_base.rs +++ b/plonky2/src/gadgets/split_base.rs @@ -30,7 +30,7 @@ impl, const D: usize> CircuitBuilder { /// Takes an iterator of bits `(b_i)` and returns `sum b_i * 2^i`, i.e., /// the number with little-endian bit representation given by `bits`. - pub(crate) fn le_sum(&mut self, bits: impl Iterator>) -> Target { + pub fn le_sum(&mut self, bits: impl Iterator>) -> Target { let bits = bits.map(|b| *b.borrow()).collect_vec(); let num_bits = bits.len(); if num_bits == 0 {