From 7959bd22cedea23277c21b6564cadb54d6521ad1 Mon Sep 17 00:00:00 2001 From: Sai <135601871+sai-deng@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:29:34 +0000 Subject: [PATCH] Refactor CTL Handling (#1629) * refactor * fmt * fmt * sync target version * fix * fix clippy * fix clippy --- plonky2/src/iop/witness.rs | 4 +- plonky2/src/plonk/vars.rs | 10 +- plonky2/src/util/serialization/mod.rs | 4 +- plonky2/src/util/strided_view.rs | 18 +-- starky/src/cross_table_lookup.rs | 212 +++++++++----------------- starky/src/proof.rs | 25 --- 6 files changed, 87 insertions(+), 186 deletions(-) diff --git a/plonky2/src/iop/witness.rs b/plonky2/src/iop/witness.rs index abf1a779..b3b13fc0 100644 --- a/plonky2/src/iop/witness.rs +++ b/plonky2/src/iop/witness.rs @@ -388,13 +388,13 @@ impl<'a, F: Field> PartitionWitness<'a, F> { } } -impl<'a, F: Field> WitnessWrite for PartitionWitness<'a, F> { +impl WitnessWrite for PartitionWitness<'_, F> { fn set_target(&mut self, target: Target, value: F) -> Result<()> { self.set_target_returning_rep(target, value).map(|_| ()) } } -impl<'a, F: Field> Witness for PartitionWitness<'a, F> { +impl Witness for PartitionWitness<'_, F> { fn try_get_target(&self, target: Target) -> Option { let rep_index = self.representative_map[self.target_index(target)]; self.values[rep_index] diff --git a/plonky2/src/plonk/vars.rs b/plonky2/src/plonk/vars.rs index 6cffc45c..d70cb2c3 100644 --- a/plonky2/src/plonk/vars.rs +++ b/plonky2/src/plonk/vars.rs @@ -46,7 +46,7 @@ pub struct EvaluationVarsBasePacked<'a, P: PackedField> { pub public_inputs_hash: &'a HashOut, } -impl<'a, F: RichField + Extendable, const D: usize> EvaluationVars<'a, F, D> { +impl, const D: usize> EvaluationVars<'_, F, D> { pub fn get_local_ext_algebra( &self, wire_range: Range, @@ -120,7 +120,7 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { } } -impl<'a, F: Field> EvaluationVarsBase<'a, F> { +impl EvaluationVarsBase<'_, F> { pub fn get_local_ext(&self, wire_range: Range) -> F::Extension where F: RichField + Extendable, @@ -209,13 +209,13 @@ impl<'a, P: PackedField> Iterator for EvaluationVarsBaseBatchIterPacked<'a, P> { } } -impl<'a, P: PackedField> ExactSizeIterator for EvaluationVarsBaseBatchIterPacked<'a, P> { +impl ExactSizeIterator for EvaluationVarsBaseBatchIterPacked<'_, P> { fn len(&self) -> usize { (self.vars_batch.len() - self.i) / P::WIDTH } } -impl<'a, const D: usize> EvaluationTargets<'a, D> { +impl EvaluationTargets<'_, D> { pub fn remove_prefix(&mut self, num_selectors: usize) { self.local_constants = &self.local_constants[num_selectors..]; } @@ -228,7 +228,7 @@ pub struct EvaluationTargets<'a, const D: usize> { pub public_inputs_hash: &'a HashOutTarget, } -impl<'a, const D: usize> EvaluationTargets<'a, D> { +impl EvaluationTargets<'_, D> { pub fn get_local_ext_algebra(&self, wire_range: Range) -> ExtensionAlgebraTarget { debug_assert_eq!(wire_range.len(), D); let arr = self.local_wires[wire_range].try_into().unwrap(); diff --git a/plonky2/src/util/serialization/mod.rs b/plonky2/src/util/serialization/mod.rs index 393db6c6..90e150ea 100644 --- a/plonky2/src/util/serialization/mod.rs +++ b/plonky2/src/util/serialization/mod.rs @@ -2196,13 +2196,13 @@ impl<'a> Buffer<'a> { } } -impl<'a> Remaining for Buffer<'a> { +impl Remaining for Buffer<'_> { fn remaining(&self) -> usize { self.bytes.len() - self.pos() } } -impl<'a> Read for Buffer<'a> { +impl Read for Buffer<'_> { #[inline] fn read_exact(&mut self, bytes: &mut [u8]) -> IoResult<()> { let n = bytes.len(); diff --git a/plonky2/src/util/strided_view.rs b/plonky2/src/util/strided_view.rs index bab978a7..55836657 100644 --- a/plonky2/src/util/strided_view.rs +++ b/plonky2/src/util/strided_view.rs @@ -130,7 +130,7 @@ impl<'a, P: PackedField> PackedStridedView<'a, P> { } } -impl<'a, P: PackedField> Index for PackedStridedView<'a, P> { +impl Index for PackedStridedView<'_, P> { type Output = P; #[inline] fn index(&self, index: usize) -> &Self::Output { @@ -182,7 +182,7 @@ pub struct PackedStridedViewIter<'a, P: PackedField> { _phantom: PhantomData<&'a [P::Scalar]>, } -impl<'a, P: PackedField> PackedStridedViewIter<'a, P> { +impl PackedStridedViewIter<'_, P> { pub(self) const fn new(start: *const P::Scalar, end: *const P::Scalar, stride: usize) -> Self { Self { start, @@ -215,7 +215,7 @@ impl<'a, P: PackedField> Iterator for PackedStridedViewIter<'a, P> { } } -impl<'a, P: PackedField> DoubleEndedIterator for PackedStridedViewIter<'a, P> { +impl DoubleEndedIterator for PackedStridedViewIter<'_, P> { fn next_back(&mut self) -> Option { debug_assert_eq!( (self.end as usize).wrapping_sub(self.start as usize) @@ -241,7 +241,7 @@ pub trait Viewable { fn view(&self, index: F) -> Self::View; } -impl<'a, P: PackedField> Viewable> for PackedStridedView<'a, P> { +impl Viewable> for PackedStridedView<'_, P> { type View = Self; fn view(&self, range: Range) -> Self::View { assert!(range.start <= self.len(), "Invalid access"); @@ -257,7 +257,7 @@ impl<'a, P: PackedField> Viewable> for PackedStridedView<'a, P> { } } -impl<'a, P: PackedField> Viewable> for PackedStridedView<'a, P> { +impl Viewable> for PackedStridedView<'_, P> { type View = Self; fn view(&self, range: RangeFrom) -> Self::View { assert!(range.start <= self.len(), "Invalid access"); @@ -272,14 +272,14 @@ impl<'a, P: PackedField> Viewable> for PackedStridedView<'a, P> } } -impl<'a, P: PackedField> Viewable for PackedStridedView<'a, P> { +impl Viewable for PackedStridedView<'_, P> { type View = Self; fn view(&self, _range: RangeFull) -> Self::View { *self } } -impl<'a, P: PackedField> Viewable> for PackedStridedView<'a, P> { +impl Viewable> for PackedStridedView<'_, P> { type View = Self; fn view(&self, range: RangeInclusive) -> Self::View { assert!(*range.start() <= self.len(), "Invalid access"); @@ -295,7 +295,7 @@ impl<'a, P: PackedField> Viewable> for PackedStridedView<' } } -impl<'a, P: PackedField> Viewable> for PackedStridedView<'a, P> { +impl Viewable> for PackedStridedView<'_, P> { type View = Self; fn view(&self, range: RangeTo) -> Self::View { assert!(range.end <= self.len(), "Invalid access"); @@ -308,7 +308,7 @@ impl<'a, P: PackedField> Viewable> for PackedStridedView<'a, P> { } } -impl<'a, P: PackedField> Viewable> for PackedStridedView<'a, P> { +impl Viewable> for PackedStridedView<'_, P> { type View = Self; fn view(&self, range: RangeToInclusive) -> Self::View { assert!(range.end < self.len(), "Invalid access"); diff --git a/starky/src/cross_table_lookup.rs b/starky/src/cross_table_lookup.rs index dd814431..3fef3f15 100644 --- a/starky/src/cross_table_lookup.rs +++ b/starky/src/cross_table_lookup.rs @@ -30,7 +30,6 @@ #[cfg(not(feature = "std"))] use alloc::{vec, vec::Vec}; -use core::cmp::min; use core::fmt::Debug; use core::iter::once; @@ -55,7 +54,7 @@ use crate::lookup::{ eval_helper_columns, eval_helper_columns_circuit, get_grand_product_challenge_set, get_helper_cols, Column, ColumnFilter, Filter, GrandProductChallenge, GrandProductChallengeSet, }; -use crate::proof::{MultiProof, StarkProofTarget, StarkProofWithMetadata}; +use crate::proof::{StarkProof, StarkProofTarget}; use crate::stark::Stark; /// An alias for `usize`, to represent the index of a STARK table in a multi-STARK setting. @@ -186,7 +185,7 @@ impl<'a, F: Field> CtlZData<'a, F> { } } -impl<'a, F: Field> CtlData<'a, F> { +impl CtlData<'_, F> { /// Returns all the cross-table lookup helper polynomials. pub(crate) fn ctl_helper_polys(&self) -> Vec> { let num_polys = self @@ -250,58 +249,6 @@ where (ctl_challenges, ctl_data) } -/// Outputs all the CTL data necessary to prove a multi-STARK system. -pub fn get_ctl_vars_from_proofs<'a, F, C, const D: usize, const N: usize>( - multi_proof: &MultiProof, - all_cross_table_lookups: &'a [CrossTableLookup], - ctl_challenges: &'a GrandProductChallengeSet, - num_lookup_columns: &'a [usize; N], - max_constraint_degree: usize, -) -> [Vec>::Extension, >::Extension, D>>; - N] -where - F: RichField + Extendable, - C: GenericConfig, -{ - let num_ctl_helper_cols = - num_ctl_helper_columns_by_table(all_cross_table_lookups, max_constraint_degree); - - CtlCheckVars::from_proofs( - &multi_proof.stark_proofs, - all_cross_table_lookups, - ctl_challenges, - num_lookup_columns, - &num_ctl_helper_cols, - ) -} -/// Returns the number of helper columns for each `Table`. -pub(crate) fn num_ctl_helper_columns_by_table( - ctls: &[CrossTableLookup], - constraint_degree: usize, -) -> Vec<[usize; N]> { - let mut res = vec![[0; N]; ctls.len()]; - for (i, ctl) in ctls.iter().enumerate() { - let CrossTableLookup { - looking_tables, - looked_table: _, - } = ctl; - let mut num_by_table = [0; N]; - - let grouped_lookups = looking_tables.iter().group_by(|&a| a.table); - - for (table, group) in grouped_lookups.into_iter() { - let sum = group.count(); - if sum > 1 { - // We only need helper columns if there are at least 2 columns. - num_by_table[table] = sum.div_ceil(constraint_degree - 1); - } - } - - res[i] = num_by_table; - } - res -} - /// Gets the auxiliary polynomials associated to these CTL data. pub(crate) fn get_ctl_auxiliary_polys( ctl_data: Option<&CtlData>, @@ -492,104 +439,82 @@ where impl<'a, F: RichField + Extendable, const D: usize> CtlCheckVars<'a, F, F::Extension, F::Extension, D> { - /// Extracts the `CtlCheckVars` for each STARK. - pub fn from_proofs, const N: usize>( - proofs: &[StarkProofWithMetadata; N], + /// Extracts the `CtlCheckVars` from a single proof. + pub fn from_proof>( + table_idx: TableIdx, + proof: &StarkProof, cross_table_lookups: &'a [CrossTableLookup], ctl_challenges: &'a GrandProductChallengeSet, - num_lookup_columns: &[usize; N], - num_helper_ctl_columns: &Vec<[usize; N]>, - ) -> [Vec; N] { - let mut ctl_vars_per_table = [0; N].map(|_| vec![]); - // If there are no auxiliary polys in the proofs `openings`, - // return early. The verifier will reject the proofs when - // calling `validate_proof_shape`. - if proofs - .iter() - .any(|p| p.proof.openings.auxiliary_polys.is_none()) - { - return ctl_vars_per_table; - } + num_lookup_columns: usize, + total_num_helper_columns: usize, + num_helper_ctl_columns: &[usize], + ) -> Vec { + // Get all cross-table lookup polynomial openings for the provided STARK proof. + let ctl_zs = { + let auxiliary_polys = proof + .openings + .auxiliary_polys + .as_ref() + .expect("We cannot have CTLs without auxiliary polynomials."); + let auxiliary_polys_next = proof + .openings + .auxiliary_polys_next + .as_ref() + .expect("We cannot have CTLs without auxiliary polynomials."); - let mut total_num_helper_cols_by_table = [0; N]; - for p_ctls in num_helper_ctl_columns { - for j in 0..N { - total_num_helper_cols_by_table[j] += p_ctls[j] * ctl_challenges.challenges.len(); - } - } + auxiliary_polys + .iter() + .skip(num_lookup_columns) + .zip(auxiliary_polys_next.iter().skip(num_lookup_columns)) + .collect::>() + }; - // Get all cross-table lookup polynomial openings for each STARK proof. - let ctl_zs = proofs - .iter() - .zip(num_lookup_columns) - .map(|(p, &num_lookup)| { - let openings = &p.proof.openings; + let mut z_index = 0; + let mut start_index = 0; + let mut ctl_vars = vec![]; - let ctl_zs = &openings - .auxiliary_polys - .as_ref() - .expect("We cannot have CTls without auxiliary polynomials.")[num_lookup..]; - let ctl_zs_next = &openings - .auxiliary_polys_next - .as_ref() - .expect("We cannot have CTls without auxiliary polynomials.")[num_lookup..]; - ctl_zs.iter().zip(ctl_zs_next).collect::>() - }) - .collect::>(); - - // Put each cross-table lookup polynomial into the correct table data: if a CTL polynomial is extracted from looking/looked table t, then we add it to the `CtlCheckVars` of table t. - let mut start_indices = [0; N]; - let mut z_indices = [0; N]; for ( + i, CrossTableLookup { looking_tables, looked_table, }, - num_ctls, - ) in cross_table_lookups.iter().zip(num_helper_ctl_columns) + ) in cross_table_lookups.iter().enumerate() { for &challenges in &ctl_challenges.challenges { - // Group looking tables by `Table`, since we bundle the looking tables taken from the same `Table` together thanks to helper columns. - // We want to only iterate on each `Table` once. - let mut filtered_looking_tables = Vec::with_capacity(min(looking_tables.len(), N)); - for table in looking_tables { - if !filtered_looking_tables.contains(&(table.table)) { - filtered_looking_tables.push(table.table); + // Group the looking tables by `Table` to process them together. + let count = looking_tables + .iter() + .filter(|looking_table| looking_table.table == table_idx) + .count(); + + let cols_filts = looking_tables.iter().filter_map(|looking_table| { + if looking_table.table == table_idx { + Some((&looking_table.columns, &looking_table.filter)) + } else { + None } - } + }); - for &table in filtered_looking_tables.iter() { - // We have first all the helper polynomials, then all the z polynomials. - let (looking_z, looking_z_next) = - ctl_zs[table][total_num_helper_cols_by_table[table] + z_indices[table]]; - - let count = looking_tables - .iter() - .filter(|looking_table| looking_table.table == table) - .count(); - let cols_filts = looking_tables.iter().filter_map(|looking_table| { - if looking_table.table == table { - Some((&looking_table.columns, &looking_table.filter)) - } else { - None - } - }); + if count > 0 { let mut columns = Vec::with_capacity(count); let mut filter = Vec::with_capacity(count); for (col, filt) in cols_filts { columns.push(&col[..]); filter.push(filt.clone()); } - let helper_columns = ctl_zs[table] - [start_indices[table]..start_indices[table] + num_ctls[table]] + + let (looking_z, looking_z_next) = ctl_zs[total_num_helper_columns + z_index]; + let helper_columns = ctl_zs + [start_index..start_index + num_helper_ctl_columns[i]] .iter() .map(|&(h, _)| *h) .collect::>(); - start_indices[table] += num_ctls[table]; + start_index += num_helper_ctl_columns[i]; + z_index += 1; - z_indices[table] += 1; - ctl_vars_per_table[table].push(Self { + ctl_vars.push(Self { helper_columns, local_z: *looking_z, next_z: *looking_z_next, @@ -599,25 +524,26 @@ impl<'a, F: RichField + Extendable, const D: usize> }); } - let (looked_z, looked_z_next) = ctl_zs[looked_table.table] - [total_num_helper_cols_by_table[looked_table.table] - + z_indices[looked_table.table]]; + if looked_table.table == table_idx { + let (looked_z, looked_z_next) = ctl_zs[total_num_helper_columns + z_index]; + z_index += 1; - z_indices[looked_table.table] += 1; + let columns = vec![&looked_table.columns[..]]; + let filter = vec![looked_table.filter.clone()]; - let columns = vec![&looked_table.columns[..]]; - let filter = vec![looked_table.filter.clone()]; - ctl_vars_per_table[looked_table.table].push(Self { - helper_columns: vec![], - local_z: *looked_z, - next_z: *looked_z_next, - challenges, - columns, - filter, - }); + ctl_vars.push(Self { + helper_columns: vec![], + local_z: *looked_z, + next_z: *looked_z_next, + challenges, + columns, + filter, + }); + } } } - ctl_vars_per_table + + ctl_vars } } diff --git a/starky/src/proof.rs b/starky/src/proof.rs index 31151f1c..cad1f1e7 100644 --- a/starky/src/proof.rs +++ b/starky/src/proof.rs @@ -159,31 +159,6 @@ where pub proof: StarkProof, } -/// A combination of STARK proofs for independent statements operating on possibly shared variables, -/// along with Cross-Table Lookup (CTL) challenges to assert consistency of common variables across tables. -#[derive(Debug, Clone)] -pub struct MultiProof< - F: RichField + Extendable, - C: GenericConfig, - const D: usize, - const N: usize, -> { - /// Proofs for all the different STARK modules. - pub stark_proofs: [StarkProofWithMetadata; N], - /// Cross-table lookup challenges. - pub ctl_challenges: GrandProductChallengeSet, -} - -impl, C: GenericConfig, const D: usize, const N: usize> - MultiProof -{ - /// Returns the degree (i.e. the trace length) of each STARK proof, - /// from their common [`StarkConfig`]. - pub fn recover_degree_bits(&self, config: &StarkConfig) -> [usize; N] { - core::array::from_fn(|i| self.stark_proofs[i].proof.recover_degree_bits(config)) - } -} - /// Randomness used for a STARK proof. #[derive(Debug)] pub struct StarkProofChallenges, const D: usize> {