diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 03bd0a7b..e95ec0f8 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -1,13 +1,14 @@ use std::collections::HashMap; use std::fmt::{Debug, Error, Formatter}; use std::hash::{Hash, Hasher}; +use std::ops::Range; use std::sync::Arc; use plonky2_field::batch_util::batch_multiply_inplace; use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::field_types::Field; -use crate::gates::gate_tree::Tree; +use crate::gates::selectors::UNUSED_SELECTOR; use crate::gates::util::StridedConstraintConsumer; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; @@ -80,9 +81,21 @@ pub trait Gate, const D: usize>: 'static + Send + S vars: EvaluationTargets, ) -> Vec>; - fn eval_filtered(&self, mut vars: EvaluationVars, prefix: &[bool]) -> Vec { - let filter = compute_filter(prefix, vars.local_constants); - vars.remove_prefix(prefix); + fn eval_filtered( + &self, + mut vars: EvaluationVars, + gate_index: usize, + selector_index: usize, + group_range: Range, + num_selectors: usize, + ) -> Vec { + let filter = compute_filter( + gate_index, + group_range, + vars.local_constants[selector_index], + num_selectors > 1, + ); + vars.remove_prefix(num_selectors); self.eval_unfiltered(vars) .into_iter() .map(|c| filter * c) @@ -94,13 +107,23 @@ pub trait Gate, const D: usize>: 'static + Send + S fn eval_filtered_base_batch( &self, mut vars_batch: EvaluationVarsBaseBatch, - prefix: &[bool], + gate_index: usize, + selector_index: usize, + group_range: Range, + num_selectors: usize, ) -> Vec { let filters: Vec<_> = vars_batch .iter() - .map(|vars| compute_filter(prefix, vars.local_constants)) + .map(|vars| { + compute_filter( + gate_index, + group_range.clone(), + vars.local_constants[selector_index], + num_selectors > 1, + ) + }) .collect(); - vars_batch.remove_prefix(prefix); + vars_batch.remove_prefix(num_selectors); let mut res_batch = self.eval_unfiltered_base_batch(vars_batch); for res_chunk in res_batch.chunks_exact_mut(filters.len()) { batch_multiply_inplace(res_chunk, &filters); @@ -113,11 +136,20 @@ pub trait Gate, const D: usize>: 'static + Send + S &self, builder: &mut CircuitBuilder, mut vars: EvaluationTargets, - prefix: &[bool], + gate_index: usize, + selector_index: usize, + group_range: Range, + num_selectors: usize, combined_gate_constraints: &mut [ExtensionTarget], ) { - let filter = compute_filter_recursively(builder, prefix, vars.local_constants); - vars.remove_prefix(prefix); + let filter = compute_filter_recursively( + builder, + gate_index, + group_range, + vars.local_constants[selector_index], + num_selectors > 1, + ); + vars.remove_prefix(num_selectors); let my_constraints = self.eval_unfiltered_recursively(builder, vars); for (acc, c) in combined_gate_constraints.iter_mut().zip(my_constraints) { *acc = builder.mul_add_extension(filter, c, *acc); @@ -202,42 +234,36 @@ pub struct PrefixedGate, const D: usize> { pub prefix: Vec, } -impl, const D: usize> PrefixedGate { - pub fn from_tree(tree: Tree>) -> Vec { - tree.traversal() - .into_iter() - .map(|(gate, prefix)| PrefixedGate { gate, prefix }) - .collect() - } -} - -/// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and -/// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`. -fn compute_filter<'a, K: Field, T: IntoIterator>(prefix: &[bool], constants: T) -> K { - prefix - .iter() - .zip(constants) - .map(|(&b, &c)| if b { c } else { K::ONE - c }) +/// A gate's filter designed so that it is non-zero if `s = gate_index`. +fn compute_filter( + gate_index: usize, + group_range: Range, + s: K, + many_selector: bool, +) -> K { + debug_assert!(group_range.contains(&gate_index)); + group_range + .filter(|&i| i != gate_index) + .chain(many_selector.then(|| UNUSED_SELECTOR)) + .map(|i| K::from_canonical_usize(i) - s) .product() } fn compute_filter_recursively, const D: usize>( builder: &mut CircuitBuilder, - prefix: &[bool], - constants: &[ExtensionTarget], + gate_index: usize, + group_range: Range, + s: ExtensionTarget, + many_selectors: bool, ) -> ExtensionTarget { - let one = builder.one_extension(); - let v = prefix - .iter() - .enumerate() - .map(|(i, &b)| { - if b { - constants[i] - } else { - builder.sub_extension(one, constants[i]) - } + debug_assert!(group_range.contains(&gate_index)); + let v = group_range + .filter(|&i| i != gate_index) + .chain(many_selectors.then(|| UNUSED_SELECTOR)) + .map(|i| { + let c = builder.constant_extension(F::Extension::from_canonical_usize(i)); + builder.sub_extension(c, s) }) .collect::>(); - builder.mul_many_extension(&v) } diff --git a/plonky2/src/gates/gate_tree.rs b/plonky2/src/gates/gate_tree.rs deleted file mode 100644 index 2f670337..00000000 --- a/plonky2/src/gates/gate_tree.rs +++ /dev/null @@ -1,294 +0,0 @@ -use log::debug; -use plonky2_field::extension_field::Extendable; - -use crate::gates::gate::GateRef; -use crate::hash::hash_types::RichField; - -/// A binary tree where leaves hold some type `T` and other nodes are empty. -#[derive(Debug, Clone)] -pub enum Tree { - Leaf(T), - Bifurcation(Option>>, Option>>), -} - -impl Default for Tree { - fn default() -> Self { - Self::Bifurcation(None, None) - } -} - -impl Tree { - /// Traverse a tree using a depth-first traversal and collect data and position for each leaf. - /// A leaf's position is represented by its left/right path, where `false` means left and `true` means right. - pub fn traversal(&self) -> Vec<(T, Vec)> { - let mut res = Vec::new(); - let prefix = []; - self.traverse(&prefix, &mut res); - res - } - - /// Utility function to traverse the tree. - fn traverse(&self, prefix: &[bool], current: &mut Vec<(T, Vec)>) { - match &self { - // If node is a leaf, collect the data and position. - Tree::Leaf(t) => { - current.push((t.clone(), prefix.to_vec())); - } - // Otherwise, traverse the left subtree and then the right subtree. - Tree::Bifurcation(left, right) => { - if let Some(l) = left { - let mut left_prefix = prefix.to_vec(); - left_prefix.push(false); - l.traverse(&left_prefix, current); - } - if let Some(r) = right { - let mut right_prefix = prefix.to_vec(); - right_prefix.push(true); - r.traverse(&right_prefix, current); - } - } - } - } -} - -impl, const D: usize> Tree> { - /// The binary gate tree influences the degree `D` of the constraint polynomial and the number `C` - /// of constant wires in the circuit. We want to construct a tree minimizing both values. To do so - /// we iterate over possible values of `(D, C)` and try to construct a tree with these values. - /// For this construction, we use the greedy algorithm in `Self::find_tree`. - /// This latter function greedily adds gates at the depth where - /// `filtered_deg(gate)=D, constant_wires(gate)=C` to ensure no space is wasted. - /// We return the first tree found in this manner, along with it's maximum filtered degree - /// and the number of constant wires needed when using this tree. - pub fn from_gates(mut gates: Vec>) -> (Self, usize, usize) { - let timer = std::time::Instant::now(); - gates.sort_unstable_by_key(|g| (-(g.0.degree() as isize), -(g.0.num_constants() as isize))); - - for max_degree_bits in 1..10 { - // The quotient polynomials are padded to the next power of 2 in `compute_quotient_polys`. - // So we can restrict our search space by setting `max_degree` to 1 + a power of 2. - let max_degree = (1 << max_degree_bits) + 1; - for max_constants in 1..100 { - if let Some(mut best_tree) = Self::find_tree(&gates, max_degree, max_constants) { - let mut best_num_constants = best_tree.num_constants(); - let mut best_degree = max_degree; - // Iterate backwards from `max_degree` to try to find a tree with a lower degree - // but the same number of constants. - 'optdegree: for degree in (0..max_degree).rev() { - if let Some(tree) = Self::find_tree(&gates, degree, max_constants) { - let num_constants = tree.num_constants(); - if num_constants > best_num_constants { - break 'optdegree; - } else { - best_degree = degree; - best_num_constants = num_constants; - best_tree = tree; - } - } - } - debug!( - "Found tree with max degree {} and {} constants wires in {:.4}s.", - best_degree, - best_num_constants, - timer.elapsed().as_secs_f32() - ); - return (best_tree, best_degree, best_num_constants); - } - } - } - - panic!("Can't find a tree.") - } - - /// Greedily add gates wherever possible. Returns `None` if this fails. - fn find_tree(gates: &[GateRef], max_degree: usize, max_constants: usize) -> Option { - let mut tree = Tree::default(); - - for g in gates { - tree.try_add_gate(g, max_degree, max_constants)?; - } - tree.shorten(); - Some(tree) - } - - /// Try to add a gate in the tree. Returns `None` if this fails. - fn try_add_gate( - &mut self, - g: &GateRef, - max_degree: usize, - max_constants: usize, - ) -> Option<()> { - // We want `gate.degree + depth <= max_degree` and `gate.num_constants + depth <= max_wires`. - let depth = max_degree - .checked_sub(g.0.degree())? - .min(max_constants.checked_sub(g.0.num_constants())?); - self.try_add_gate_at_depth(g, depth) - } - - /// Try to add a gate in the tree at a specified depth. Returns `None` if this fails. - fn try_add_gate_at_depth(&mut self, g: &GateRef, depth: usize) -> Option<()> { - // If depth is 0, we have to insert the gate here. - if depth == 0 { - return if let Tree::Bifurcation(None, None) = self { - // Insert the gate as a new leaf. - *self = Tree::Leaf(g.clone()); - Some(()) - } else { - // A leaf is already here. - None - }; - } - - // A leaf is already here so we cannot go deeper. - if let Tree::Leaf(_) = self { - return None; - } - - if let Tree::Bifurcation(left, right) = self { - if let Some(left) = left { - // Try to add the gate to the left if there's already a left subtree. - if left.try_add_gate_at_depth(g, depth - 1).is_some() { - return Some(()); - } - } else { - // Add a new left subtree and try to add the gate to it. - let mut new_left = Tree::default(); - if new_left.try_add_gate_at_depth(g, depth - 1).is_some() { - *left = Some(Box::new(new_left)); - return Some(()); - } - } - if let Some(right) = right { - // Try to add the gate to the right if there's already a right subtree. - if right.try_add_gate_at_depth(g, depth - 1).is_some() { - return Some(()); - } - } else { - // Add a new right subtree and try to add the gate to it. - let mut new_right = Tree::default(); - if new_right.try_add_gate_at_depth(g, depth - 1).is_some() { - *right = Some(Box::new(new_right)); - return Some(()); - } - } - } - - None - } - - /// `Self::find_tree` returns a tree where each gate has `F(gate)=M` (see `Self::from_gates` comment). - /// This can produce subtrees with more nodes than necessary. This function removes useless nodes, - /// i.e., nodes that have a left but no right subtree. - fn shorten(&mut self) { - if let Tree::Bifurcation(left, right) = self { - if let (Some(left), None) = (left, right) { - // If the node has a left but no right subtree, set the node to its (shortened) left subtree. - let mut new = *left.clone(); - new.shorten(); - *self = new; - } - } - if let Tree::Bifurcation(left, right) = self { - if let Some(left) = left { - // Shorten the left subtree if there is one. - left.shorten(); - } - if let Some(right) = right { - // Shorten the right subtree if there is one. - right.shorten(); - } - } - } - - /// Returns the tree's maximum filtered constraint degree. - pub fn max_filtered_degree(&self) -> usize { - self.traversal() - .into_iter() - .map(|(g, p)| g.0.degree() + p.len()) - .max() - .expect("Empty tree.") - } - - /// Returns the number of constant wires needed to fit all prefixes and gate constants. - fn num_constants(&self) -> usize { - self.traversal() - .into_iter() - .map(|(g, p)| g.0.num_constants() + p.len()) - .max() - .expect("Empty tree.") - } -} - -#[cfg(test)] -mod tests { - use log::info; - - use super::*; - use crate::gadgets::interpolation::InterpolationGate; - use crate::gates::arithmetic_extension::ArithmeticExtensionGate; - use crate::gates::base_sum::BaseSumGate; - use crate::gates::constant::ConstantGate; - use crate::gates::interpolation::HighDegreeInterpolationGate; - use crate::gates::noop::NoopGate; - use crate::gates::poseidon::PoseidonGate; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - #[test] - fn test_prefix_generation() { - env_logger::init(); - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let gates = vec![ - GateRef::new(NoopGate), - GateRef::new(ConstantGate { num_consts: 4 }), - GateRef::new(ArithmeticExtensionGate { num_ops: 4 }), - GateRef::new(BaseSumGate::<4>::new(4)), - GateRef::new(PoseidonGate::::new()), - GateRef::new(HighDegreeInterpolationGate::new(2)), - ]; - - let (tree, _, _) = Tree::from_gates(gates.clone()); - let mut gates_with_prefix = tree.traversal(); - for (g, p) in &gates_with_prefix { - info!( - "\nGate: {}, prefix: {:?}.\n\ - Filtered constraint degree: {}, Num constant wires: {}", - &g.0.id()[..20.min(g.0.id().len())], - p, - g.0.degree() + p.len(), - g.0.num_constants() + p.len() - ); - } - - assert_eq!( - gates_with_prefix.len(), - gates.len(), - "The tree has too much or too little gates." - ); - assert!( - gates - .iter() - .all(|g| gates_with_prefix.iter().map(|(gg, _)| gg).any(|gg| gg == g)), - "Some gates are not in the tree." - ); - assert!( - gates_with_prefix - .iter() - .all(|(g, p)| g.0.degree() + g.0.num_constants() + p.len() <= 9), - "Total degree is larger than 8." - ); - - gates_with_prefix.sort_unstable_by_key(|(_g, p)| p.len()); - for i in 0..gates_with_prefix.len() { - for j in i + 1..gates_with_prefix.len() { - assert_ne!( - &gates_with_prefix[i].1, - &gates_with_prefix[j].1[0..gates_with_prefix[i].1.len()], - "Some gates share an overlapping prefix" - ); - } - } - } -} diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index 18e3e99b..7a08e709 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -11,7 +11,6 @@ pub mod comparison; pub mod constant; pub mod exponentiation; pub mod gate; -pub mod gate_tree; pub mod interpolation; pub mod low_degree_interpolation; pub mod multiplication_extension; @@ -24,6 +23,7 @@ pub mod random_access; pub mod range_check_u32; pub mod reducing; pub mod reducing_extension; +pub(crate) mod selectors; pub mod subtraction_u32; pub mod switch; pub mod util; diff --git a/plonky2/src/gates/selectors.rs b/plonky2/src/gates/selectors.rs new file mode 100644 index 00000000..dffeb987 --- /dev/null +++ b/plonky2/src/gates/selectors.rs @@ -0,0 +1,111 @@ +use std::ops::Range; + +use plonky2_field::extension_field::Extendable; +use plonky2_field::polynomial::PolynomialValues; + +use crate::gates::gate::{GateInstance, GateRef}; +use crate::hash::hash_types::RichField; + +/// Placeholder value to indicate that a gate doesn't use a selector polynomial. +pub(crate) const UNUSED_SELECTOR: usize = u32::MAX as usize; + +#[derive(Debug, Clone)] +pub(crate) struct SelectorsInfo { + pub(crate) selector_indices: Vec, + pub(crate) groups: Vec>, +} + +impl SelectorsInfo { + pub fn num_selectors(&self) -> usize { + self.groups.len() + } +} + +/// Returns the selector polynomials and related information. +/// +/// Selector polynomials are computed as follows: +/// Partition the gates into (the smallest amount of) groups `{ G_i }`, such that for each group `G` +/// `|G| + max_{g in G} g.degree() <= max_degree`. These groups are constructed greedily from +/// the list of gates sorted by degree. +/// We build a selector polynomial `S_i` for each group `G_i`, with +/// S_i[j] = +/// if j-th row gate=g_k in G_i +/// k +/// else +/// UNUSED_SELECTOR +pub(crate) fn selector_polynomials, const D: usize>( + gates: &[GateRef], + instances: &[GateInstance], + max_degree: usize, +) -> (Vec>, SelectorsInfo) { + let n = instances.len(); + let num_gates = gates.len(); + let max_gate_degree = gates.last().expect("No gates?").0.degree(); + + let index = |id| gates.iter().position(|g| g.0.id() == id).unwrap(); + + // Special case if we can use only one selector polynomial. + if max_gate_degree + num_gates - 1 <= max_degree { + return ( + vec![PolynomialValues::new( + instances + .iter() + .map(|g| F::from_canonical_usize(index(g.gate_ref.0.id()))) + .collect(), + )], + SelectorsInfo { + selector_indices: vec![0; num_gates], + groups: vec![0..num_gates], + }, + ); + } + + if max_gate_degree >= max_degree { + panic!( + "{} has too high degree. Consider increasing `quotient_degree_factor`.", + gates.last().unwrap().0.id() + ); + } + + // Greedily construct the groups. + let mut groups = Vec::new(); + let mut start = 0; + while start < num_gates { + let mut size = 0; + while (start + size < gates.len()) && (size + gates[start + size].0.degree() < max_degree) { + size += 1; + } + groups.push(start..start + size); + start += size; + } + + let group = |i| groups.iter().position(|range| range.contains(&i)).unwrap(); + + // `selector_indices[i] = j` iff the `i`-th gate uses the `j`-th selector polynomial. + let selector_indices = (0..num_gates).map(group).collect(); + + // Placeholder value to indicate that a gate doesn't use a selector polynomial. + let unused = F::from_canonical_usize(UNUSED_SELECTOR); + + let mut polynomials = vec![PolynomialValues::zero(n); groups.len()]; + for (j, g) in instances.iter().enumerate() { + let GateInstance { gate_ref, .. } = g; + let i = index(gate_ref.0.id()); + let gr = group(i); + for g in 0..groups.len() { + polynomials[g].values[j] = if g == gr { + F::from_canonical_usize(i) + } else { + unused + }; + } + } + + ( + polynomials, + SelectorsInfo { + selector_indices, + groups, + }, + ) +} diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 8e2f2e10..924ca553 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -19,10 +19,10 @@ use crate::gadgets::polynomial::PolynomialCoeffsExtTarget; use crate::gates::arithmetic_base::ArithmeticGate; use crate::gates::arithmetic_extension::ArithmeticExtensionGate; use crate::gates::constant::ConstantGate; -use crate::gates::gate::{CurrentSlot, Gate, GateInstance, GateRef, PrefixedGate}; -use crate::gates::gate_tree::Tree; +use crate::gates::gate::{CurrentSlot, Gate, GateInstance, GateRef}; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; +use crate::gates::selectors::selector_polynomials; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::ext_target::ExtensionTarget; @@ -551,32 +551,27 @@ impl, const D: usize> CircuitBuilder { } } - fn constant_polys( - &self, - gates: &[PrefixedGate], - num_constants: usize, - ) -> Vec> { - let constants_per_gate = self - .gate_instances + fn constant_polys(&self) -> Vec> { + let max_constants = self + .gates .iter() - .map(|gate| { - let prefix = &gates - .iter() - .find(|g| g.gate.0.id() == gate.gate_ref.0.id()) - .unwrap() - .prefix; - let mut prefixed_constants = Vec::with_capacity(num_constants); - prefixed_constants.extend(prefix.iter().map(|&b| if b { F::ONE } else { F::ZERO })); - prefixed_constants.extend_from_slice(&gate.constants); - prefixed_constants.resize(num_constants, F::ZERO); - prefixed_constants - }) - .collect::>(); - - transpose(&constants_per_gate) - .into_iter() - .map(PolynomialValues::new) - .collect() + .map(|g| g.0.num_constants()) + .max() + .unwrap(); + transpose( + &self + .gate_instances + .iter() + .map(|g| { + let mut consts = g.constants.clone(); + consts.resize(max_constants, F::ZERO); + consts + }) + .collect::>(), + ) + .into_iter() + .map(PolynomialValues::new) + .collect() } fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> (Vec>, Forest) { @@ -669,27 +664,17 @@ impl, const D: usize> CircuitBuilder { "FRI total reduction arity is too large.", ); - let gates = self.gates.iter().cloned().collect(); - let (gate_tree, max_filtered_constraint_degree, num_constants) = Tree::from_gates(gates); - let prefixed_gates = PrefixedGate::from_tree(gate_tree); - - // `quotient_degree_factor` has to be between `max_filtered_constraint_degree-1` and `1<>(); + // Gates need to be sorted by their degrees (and ID to make the ordering deterministic) to compute the selector polynomials. + gates.sort_unstable_by_key(|g| (g.0.degree(), g.0.id())); + let (mut constant_vecs, selectors_info) = + selector_polynomials(&gates, &self.gate_instances, quotient_degree_factor + 1); + constant_vecs.extend(self.constant_polys()); + let num_constants = constant_vecs.len(); let subgroup = F::two_adic_subgroup(degree_bits); - let constant_vecs = timed!( - timing, - "generate constant polynomials", - self.constant_polys(&prefixed_gates, num_constants) - ); - let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires); let (sigma_vecs, forest) = timed!( timing, @@ -768,11 +753,6 @@ impl, const D: usize> CircuitBuilder { fft_root_table: Some(fft_root_table), }; - // The HashSet of gates will have a non-deterministic order. When converting to a Vec, we - // sort by ID to make the ordering deterministic. - let mut gates = self.gates.iter().cloned().collect::>(); - gates.sort_unstable_by_key(|gate| gate.0.id()); - let num_gate_constraints = gates .iter() .map(|gate| gate.0.num_constraints()) @@ -793,7 +773,8 @@ impl, const D: usize> CircuitBuilder { config: self.config, fri_params, degree_bits, - gates: prefixed_gates, + gates, + selectors_info, quotient_degree_factor, num_gate_constraints, num_constants, diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index 34b38fcf..e836014b 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -12,7 +12,8 @@ use crate::fri::structure::{ FriBatchInfo, FriBatchInfoTarget, FriInstanceInfo, FriInstanceInfoTarget, FriPolynomialInfo, }; use crate::fri::{FriConfig, FriParams}; -use crate::gates::gate::PrefixedGate; +use crate::gates::gate::GateRef; +use crate::gates::selectors::SelectorsInfo; use crate::hash::hash_types::{MerkleCapTarget, RichField}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::ext_target::ExtensionTarget; @@ -253,7 +254,10 @@ pub struct CommonCircuitData< pub(crate) degree_bits: usize, /// The types of gates used in this circuit, along with their prefixes. - pub(crate) gates: Vec>, + pub(crate) gates: Vec>, + + /// Information on the circuit's selector polynomials. + pub(crate) selectors_info: SelectorsInfo, /// The degree of the PLONK quotient polynomial. pub(crate) quotient_degree_factor: usize, @@ -297,7 +301,7 @@ impl, C: GenericConfig, const D: usize> pub fn constraint_degree(&self) -> usize { self.gates .iter() - .map(|g| g.gate.0.degree()) + .map(|g| g.0.degree()) .max() .expect("No gates?") } diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index 70de5833..ab23aa45 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -3,7 +3,6 @@ use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::field_types::Field; use plonky2_field::zero_poly_coset::ZeroPolyOnCoset; -use crate::gates::gate::PrefixedGate; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; use crate::iop::target::Target; @@ -40,8 +39,7 @@ pub(crate) fn eval_vanishing_poly< let max_degree = common_data.quotient_degree_factor; let num_prods = common_data.num_partial_products; - let constraint_terms = - evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); + let constraint_terms = evaluate_gate_constraints(common_data, vars); // The L_1(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::new(); @@ -128,8 +126,7 @@ pub(crate) fn eval_vanishing_poly_base_batch< let num_gate_constraints = common_data.num_gate_constraints; - let constraint_terms_batch = - evaluate_gate_constraints_base_batch(&common_data.gates, num_gate_constraints, vars_batch); + let constraint_terms_batch = evaluate_gate_constraints_base_batch(common_data, vars_batch); debug_assert!(constraint_terms_batch.len() == n * num_gate_constraints); let num_challenges = common_data.config.num_challenges; @@ -208,17 +205,27 @@ pub(crate) fn eval_vanishing_poly_base_batch< /// `num_gate_constraints` is the largest number of constraints imposed by any gate. It is not /// strictly necessary, but it helps performance by ensuring that we allocate a vector with exactly /// the capacity that we need. -pub fn evaluate_gate_constraints, const D: usize>( - gates: &[PrefixedGate], - num_gate_constraints: usize, +pub fn evaluate_gate_constraints< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + common_data: &CommonCircuitData, vars: EvaluationVars, ) -> Vec { - let mut constraints = vec![F::Extension::ZERO; num_gate_constraints]; - for gate in gates { - let gate_constraints = gate.gate.0.eval_filtered(vars, &gate.prefix); + let mut constraints = vec![F::Extension::ZERO; common_data.num_gate_constraints]; + for (i, gate) in common_data.gates.iter().enumerate() { + let selector_index = common_data.selectors_info.selector_indices[i]; + let gate_constraints = gate.0.eval_filtered( + vars, + i, + selector_index, + common_data.selectors_info.groups[selector_index].clone(), + common_data.selectors_info.num_selectors(), + ); for (i, c) in gate_constraints.into_iter().enumerate() { debug_assert!( - i < num_gate_constraints, + i < common_data.num_gate_constraints, "num_constraints() gave too low of a number" ); constraints[i] += c; @@ -232,17 +239,24 @@ pub fn evaluate_gate_constraints, const D: usize>( /// Returns a vector of `num_gate_constraints * vars_batch.len()` field elements. The constraints /// corresponding to `vars_batch[i]` are found in `result[i], result[vars_batch.len() + i], /// result[2 * vars_batch.len() + i], ...`. -pub fn evaluate_gate_constraints_base_batch, const D: usize>( - gates: &[PrefixedGate], - num_gate_constraints: usize, +pub fn evaluate_gate_constraints_base_batch< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + common_data: &CommonCircuitData, vars_batch: EvaluationVarsBaseBatch, ) -> Vec { - let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()]; - for gate in gates { - let gate_constraints_batch = gate - .gate - .0 - .eval_filtered_base_batch(vars_batch, &gate.prefix); + let mut constraints_batch = vec![F::ZERO; common_data.num_gate_constraints * vars_batch.len()]; + for (i, gate) in common_data.gates.iter().enumerate() { + let selector_index = common_data.selectors_info.selector_indices[i]; + let gate_constraints_batch = gate.0.eval_filtered_base_batch( + vars_batch, + i, + selector_index, + common_data.selectors_info.groups[selector_index].clone(), + common_data.selectors_info.num_selectors(), + ); debug_assert!( gate_constraints_batch.len() <= constraints_batch.len(), "num_constraints() gave too low of a number" @@ -256,22 +270,29 @@ pub fn evaluate_gate_constraints_base_batch, const constraints_batch } -pub fn evaluate_gate_constraints_recursively, const D: usize>( +pub fn evaluate_gate_constraints_recursively< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( builder: &mut CircuitBuilder, - gates: &[PrefixedGate], - num_gate_constraints: usize, + common_data: &CommonCircuitData, vars: EvaluationTargets, ) -> Vec> { - let mut all_gate_constraints = vec![builder.zero_extension(); num_gate_constraints]; - for gate in gates { + let mut all_gate_constraints = vec![builder.zero_extension(); common_data.num_gate_constraints]; + for (i, gate) in common_data.gates.iter().enumerate() { + let selector_index = common_data.selectors_info.selector_indices[i]; with_context!( builder, - &format!("evaluate {} constraints", gate.gate.0.id()), - gate.gate.0.eval_filtered_recursively( + &format!("evaluate {} constraints", gate.0.id()), + gate.0.eval_filtered_recursively( builder, vars, - &gate.prefix, - &mut all_gate_constraints + i, + selector_index, + common_data.selectors_info.groups[selector_index].clone(), + common_data.selectors_info.num_selectors(), + &mut all_gate_constraints, ) ); } @@ -308,12 +329,7 @@ pub(crate) fn eval_vanishing_poly_recursively< let constraint_terms = with_context!( builder, "evaluate gate constraints", - evaluate_gate_constraints_recursively( - builder, - &common_data.gates, - common_data.num_gate_constraints, - vars, - ) + evaluate_gate_constraints_recursively(builder, common_data, vars,) ); // The L_1(x) (Z(x) - 1) vanishing terms. diff --git a/plonky2/src/plonk/vars.rs b/plonky2/src/plonk/vars.rs index e2e52cfb..8c7b71e3 100644 --- a/plonky2/src/plonk/vars.rs +++ b/plonky2/src/plonk/vars.rs @@ -55,8 +55,8 @@ impl<'a, F: RichField + Extendable, const D: usize> EvaluationVars<'a, F, D> ExtensionAlgebra::from_basefield_array(arr) } - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len()..]; + pub fn remove_prefix(&mut self, num_selectors: usize) { + self.local_constants = &self.local_constants[num_selectors..]; } } @@ -77,8 +77,8 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> { } } - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len() * self.len()..]; + pub fn remove_prefix(&mut self, num_selectors: usize) { + self.local_constants = &self.local_constants[num_selectors * self.len()..]; } pub fn len(&self) -> usize { @@ -209,8 +209,8 @@ impl<'a, P: PackedField> ExactSizeIterator for EvaluationVarsBaseBatchIterPacked } impl<'a, const D: usize> EvaluationTargets<'a, D> { - pub fn remove_prefix(&mut self, prefix: &[bool]) { - self.local_constants = &self.local_constants[prefix.len()..]; + pub fn remove_prefix(&mut self, num_selectors: usize) { + self.local_constants = &self.local_constants[num_selectors..]; } }