diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 0448ab52..b5c65fec 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -107,8 +107,8 @@ mod tests { keccak_trace[5].values[..].copy_from_slice(&vs1); let cross_table_lookups = vec![CrossTableLookup { - looking_table: Table::Cpu, - looking_columns: vec![2, 4], + looking_tables: vec![Table::Cpu], + looking_columns: vec![vec![2, 4]], looked_table: Table::Keccak, looked_columns: vec![3, 5], default: vec![F::ONE; 2], diff --git a/evm/src/cross_table_lookup.rs b/evm/src/cross_table_lookup.rs index d29958be..f8d3e6d5 100644 --- a/evm/src/cross_table_lookup.rs +++ b/evm/src/cross_table_lookup.rs @@ -1,4 +1,5 @@ use anyhow::{ensure, Result}; +use itertools::izip; use plonky2::field::extension_field::{Extendable, FieldExtension}; use plonky2::field::field_types::Field; use plonky2::field::packed_field::PackedField; @@ -22,8 +23,8 @@ use crate::vars::{StarkEvaluationTargets, StarkEvaluationVars}; #[derive(Clone)] pub struct CrossTableLookup { - pub looking_table: Table, - pub looking_columns: Vec, + pub looking_tables: Vec, + pub looking_columns: Vec>, pub looked_table: Table, pub looked_columns: Vec, pub default: Vec, @@ -31,15 +32,18 @@ pub struct CrossTableLookup { impl CrossTableLookup { pub fn new( - looking_table: Table, - looking_columns: Vec, + looking_tables: Vec
, + looking_columns: Vec>, looked_table: Table, looked_columns: Vec, default: Vec, ) -> Self { - assert_eq!(looking_columns.len(), looked_columns.len()); + assert_eq!(looking_tables.len(), looking_columns.len()); + assert!(looking_columns + .iter() + .all(|cols| cols.len() == looked_columns.len())); Self { - looking_table, + looking_tables, looking_columns, looked_table, looked_columns, @@ -87,7 +91,7 @@ pub fn cross_table_lookup_data, const D let challenges = get_grand_product_challenge_set(challenger, config.num_challenges); let mut ctl_data_per_table = vec![CtlData::new(challenges.clone()); trace_poly_values.len()]; for CrossTableLookup { - looking_table, + looking_tables, looking_columns, looked_table, looked_columns, @@ -95,11 +99,13 @@ pub fn cross_table_lookup_data, const D } in cross_table_lookups { for &challenge in &challenges.challenges { - let z_looking = partial_products( - &trace_poly_values[*looking_table as usize], - looking_columns, - challenge, - ); + let zs_looking = looking_tables + .iter() + .zip(looking_columns) + .map(|(table, columns)| { + partial_products(&trace_poly_values[*table as usize], columns, challenge) + }) + .collect::>(); let z_looked = partial_products( &trace_poly_values[*looked_table as usize], looked_columns, @@ -107,17 +113,25 @@ pub fn cross_table_lookup_data, const D ); debug_assert_eq!( - *z_looking.values.last().unwrap(), + zs_looking + .iter() + .map(|z| *z.values.last().unwrap()) + .product::(), *z_looked.values.last().unwrap() * challenge.combine(default).exp_u64( - trace_poly_values[*looking_table as usize][0].len() as u64 + looking_tables + .iter() + .map(|table| trace_poly_values[*table as usize][0].len() as u64) + .sum::() - trace_poly_values[*looked_table as usize][0].len() as u64 ) ); - ctl_data_per_table[*looking_table as usize] - .zs_columns - .push((z_looking, looking_columns.clone())); + for (table, columns, z) in izip!(looking_tables, looking_columns, zs_looking) { + ctl_data_per_table[*table as usize] + .zs_columns + .push((z, columns.clone())); + } ctl_data_per_table[*looked_table as usize] .zs_columns .push((z_looked, looked_columns.clone())); @@ -177,7 +191,7 @@ impl<'a, F: RichField + Extendable, const D: usize> let mut ctl_vars_per_table = vec![vec![]; proofs.len()]; for CrossTableLookup { - looking_table, + looking_tables, looking_columns, looked_table, looked_columns, @@ -185,13 +199,15 @@ impl<'a, F: RichField + Extendable, const D: usize> } in cross_table_lookups { for &challenges in &ctl_challenges.challenges { - let (looking_z, looking_z_next) = ctl_zs[*looking_table as usize].next().unwrap(); - ctl_vars_per_table[*looking_table as usize].push(Self { - local_z: *looking_z, - next_z: *looking_z_next, - challenges, - columns: looking_columns, - }); + for (table, columns) in looking_tables.iter().zip(looking_columns) { + let (looking_z, looking_z_next) = ctl_zs[*table as usize].next().unwrap(); + ctl_vars_per_table[*table as usize].push(Self { + local_z: *looking_z, + next_z: *looking_z_next, + challenges, + columns, + }); + } let (looked_z, looked_z_next) = ctl_zs[*looked_table as usize].next().unwrap(); ctl_vars_per_table[*looked_table as usize].push(Self { @@ -262,7 +278,7 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> { let mut ctl_vars_per_table = vec![vec![]; proofs.len()]; for CrossTableLookup { - looking_table, + looking_tables, looking_columns, looked_table, looked_columns, @@ -270,13 +286,15 @@ impl<'a, const D: usize> CtlCheckVarsTarget<'a, D> { } in cross_table_lookups { for &challenges in &ctl_challenges.challenges { - let (looking_z, looking_z_next) = ctl_zs[*looking_table as usize].next().unwrap(); - ctl_vars_per_table[*looking_table as usize].push(Self { - local_z: *looking_z, - next_z: *looking_z_next, - challenges, - columns: looking_columns, - }); + for (table, columns) in looking_tables.iter().zip(looking_columns) { + let (looking_z, looking_z_next) = ctl_zs[*table as usize].next().unwrap(); + ctl_vars_per_table[*table as usize].push(Self { + local_z: *looking_z, + next_z: *looking_z_next, + challenges, + columns, + }); + } let (looked_z, looked_z_next) = ctl_zs[*looked_table as usize].next().unwrap(); ctl_vars_per_table[*looked_table as usize].push(Self { @@ -354,22 +372,29 @@ pub(crate) fn verify_cross_table_lookups< for ( i, CrossTableLookup { - looking_table, + looking_tables, looked_table, default, .. }, ) in cross_table_lookups.into_iter().enumerate() { - let looking_degree = 1 << degrees_bits[looking_table as usize]; + let looking_degrees_sum = looking_tables + .iter() + .map(|&table| 1 << degrees_bits[table as usize]) + .sum::(); let looked_degree = 1 << degrees_bits[looked_table as usize]; - let looking_z = *ctl_zs_openings[looking_table as usize].next().unwrap(); + let looking_zs_prod = looking_tables + .into_iter() + .map(|table| *ctl_zs_openings[table as usize].next().unwrap()) + .product::(); let looked_z = *ctl_zs_openings[looked_table as usize].next().unwrap(); let challenge = challenges.challenges[i % config.num_challenges]; let combined_default = challenge.combine(default.iter()); ensure!( - looking_z == looked_z * combined_default.exp_u64(looking_degree - looked_degree), + looking_zs_prod + == looked_z * combined_default.exp_u64(looking_degrees_sum - looked_degree), "Cross-table lookup verification failed." ); } @@ -399,16 +424,23 @@ pub(crate) fn verify_cross_table_lookups_circuit< for ( i, CrossTableLookup { - looking_table, + looking_tables, looked_table, default, .. }, ) in cross_table_lookups.into_iter().enumerate() { - let looking_degree = 1 << degrees_bits[looking_table as usize]; + let looking_degrees_sum = looking_tables + .iter() + .map(|&table| 1 << degrees_bits[table as usize]) + .sum::(); let looked_degree = 1 << degrees_bits[looked_table as usize]; - let looking_z = *ctl_zs_openings[looking_table as usize].next().unwrap(); + let looking_zs_prod = builder.mul_many( + looking_tables + .into_iter() + .map(|table| *ctl_zs_openings[table as usize].next().unwrap()), + ); let looked_z = *ctl_zs_openings[looked_table as usize].next().unwrap(); let challenge = challenges.challenges[i % inner_config.num_challenges]; let default = default @@ -417,8 +449,8 @@ pub(crate) fn verify_cross_table_lookups_circuit< .collect::>(); let combined_default = challenge.combine_base_circuit(builder, &default); - let pad = builder.exp_u64(combined_default, looking_degree - looked_degree); + let pad = builder.exp_u64(combined_default, looking_degrees_sum - looked_degree); let padded_looked_z = builder.mul(looked_z, pad); - builder.connect(looking_z, padded_looked_z); + builder.connect(looking_zs_prod, padded_looked_z); } } diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index fe6f116d..ad587a49 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -25,7 +25,7 @@ impl, const D: usize> CircuitBuilder { /// Computes `x^3`. pub fn cube(&mut self, x: Target) -> Target { - self.mul_many(&[x, x, x]) + self.mul_many([x, x, x]) } /// Computes `const_0 * multiplicand_0 * multiplicand_1 + const_1 * addend`. @@ -206,12 +206,16 @@ impl, const D: usize> CircuitBuilder { } /// Multiply `n` `Target`s. - pub fn mul_many(&mut self, terms: &[Target]) -> Target { - terms - .iter() - .copied() - .reduce(|acc, t| self.mul(acc, t)) - .unwrap_or_else(|| self.one()) + pub fn mul_many(&mut self, terms: impl IntoIterator) -> Target + where + T: Borrow, + { + let mut iter = terms.into_iter(); + if let Some(first) = iter.next() { + iter.fold(*first.borrow(), |acc, t| self.mul(acc, *t.borrow())) + } else { + self.one() + } } /// Exponentiate `base` to the power of `2^power_log`.