diff --git a/starky/src/lookup.rs b/starky/src/lookup.rs index 2a90e6b9..f260400b 100644 --- a/starky/src/lookup.rs +++ b/starky/src/lookup.rs @@ -9,7 +9,7 @@ use core::iter::repeat; use itertools::Itertools; use num_bigint::BigUint; -use plonky2::field::batch_util::batch_add_inplace; +use plonky2::field::batch_util::{batch_add_inplace, batch_multiply_inplace}; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; @@ -761,79 +761,35 @@ pub(crate) fn get_helper_cols( .len() .div_ceil(constraint_degree.checked_sub(1).unwrap_or(1)); - let mut helper_columns = Vec::with_capacity(num_helper_columns); - - for mut cols_filts in &columns_filters - .iter() - .chunks(constraint_degree.checked_sub(1).unwrap_or(1)) - { - let (first_col, first_filter) = cols_filts.next().unwrap(); - - let mut filter_col = Vec::with_capacity(degree); - let first_combined = (0..degree) - .map(|d| { - let f = { - let f = first_filter.eval_table(trace, d); - filter_col.push(f); - f - }; - if f.is_one() { - let evals = first_col - .iter() - .map(|c| c.eval_table(trace, d)) + let chunks = columns_filters.chunks(constraint_degree.checked_sub(1).unwrap_or(1)); + let helper_columns: Vec<_> = chunks + .filter_map(|cols_filts| { + cols_filts + .iter() + .map(|(col, filter)| { + let combined = (0..degree) + .map(|d| { + let evals = col + .iter() + .map(|c| c.eval_table(trace, d)) + .collect::>(); + challenge.combine(&evals) + }) .collect::>(); - challenge.combine(evals.iter()) - } else { - assert_eq!(f, F::ZERO, "Non-binary filter?"); - // Dummy value. Cannot be zero since it will be batch-inverted. - F::ONE - } - }) - .collect::>(); - let mut acc = F::batch_multiplicative_inverse(&first_combined); - for d in 0..degree { - if filter_col[d].is_zero() { - acc[d] = F::ZERO; - } - } - - for (col, filt) in cols_filts { - let mut filter_col = Vec::with_capacity(degree); - let mut combined = (0..degree) - .map(|d| { - let f = { - let f = filt.eval_table(trace, d); - filter_col.push(f); - f - }; - if f.is_one() { - let evals = col - .iter() - .map(|c| c.eval_table(trace, d)) - .collect::>(); - challenge.combine(evals.iter()) - } else { - assert_eq!(f, F::ZERO, "Non-binary filter?"); - // Dummy value. Cannot be zero since it will be batch-inverted. - F::ONE - } + let mut combined = F::batch_multiplicative_inverse(&combined); + let filter_col: Vec<_> = + (0..degree).map(|d| filter.eval_table(trace, d)).collect(); + batch_multiply_inplace(&mut combined, &filter_col); + combined }) - .collect::>(); - - combined = F::batch_multiplicative_inverse(&combined); - - for d in 0..degree { - if filter_col[d].is_zero() { - combined[d] = F::ZERO; - } - } - - batch_add_inplace(&mut acc, &combined); - } - - helper_columns.push(acc.into()); - } + .reduce(|mut acc, combined| { + batch_add_inplace(&mut acc, &combined); + acc + }) + .map(PolynomialValues::from) + }) + .collect(); assert_eq!(helper_columns.len(), num_helper_columns); helper_columns