diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index f5de3472..e95ec0f8 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::fmt::{Debug, Error, Formatter}; use std::hash::{Hash, Hasher}; +use std::ops::Range; use std::sync::Arc; use plonky2_field::batch_util::batch_multiply_inplace; @@ -85,13 +86,14 @@ pub trait Gate, const D: usize>: 'static + Send + S mut vars: EvaluationVars, gate_index: usize, selector_index: usize, - group_range: (usize, usize), + group_range: Range, num_selectors: usize, ) -> Vec { let filter = compute_filter( gate_index, group_range, vars.local_constants[selector_index], + num_selectors > 1, ); vars.remove_prefix(num_selectors); self.eval_unfiltered(vars) @@ -107,7 +109,7 @@ pub trait Gate, const D: usize>: 'static + Send + S mut vars_batch: EvaluationVarsBaseBatch, gate_index: usize, selector_index: usize, - group_range: (usize, usize), + group_range: Range, num_selectors: usize, ) -> Vec { let filters: Vec<_> = vars_batch @@ -115,8 +117,9 @@ pub trait Gate, const D: usize>: 'static + Send + S .map(|vars| { compute_filter( gate_index, - group_range, + group_range.clone(), vars.local_constants[selector_index], + num_selectors > 1, ) }) .collect(); @@ -135,7 +138,7 @@ pub trait Gate, const D: usize>: 'static + Send + S mut vars: EvaluationTargets, gate_index: usize, selector_index: usize, - group_range: (usize, usize), + group_range: Range, num_selectors: usize, combined_gate_constraints: &mut [ExtensionTarget], ) { @@ -144,6 +147,7 @@ pub trait Gate, const D: usize>: 'static + Send + S gate_index, group_range, vars.local_constants[selector_index], + num_selectors > 1, ); vars.remove_prefix(num_selectors); let my_constraints = self.eval_unfiltered_recursively(builder, vars); @@ -231,11 +235,16 @@ pub struct PrefixedGate, const D: usize> { } /// A gate's filter designed so that it is non-zero if `s = gate_index`. -fn compute_filter(gate_index: usize, group_range: (usize, usize), s: K) -> K { - debug_assert!((group_range.0 <= gate_index) && (gate_index < group_range.1)); - (group_range.0..group_range.1) +fn compute_filter( + gate_index: usize, + group_range: Range, + s: K, + many_selector: bool, +) -> K { + debug_assert!(group_range.contains(&gate_index)); + group_range .filter(|&i| i != gate_index) - .chain(Some(UNUSED_SELECTOR)) + .chain(many_selector.then(|| UNUSED_SELECTOR)) .map(|i| K::from_canonical_usize(i) - s) .product() } @@ -243,13 +252,14 @@ fn compute_filter(gate_index: usize, group_range: (usize, usize), s: K fn compute_filter_recursively, const D: usize>( builder: &mut CircuitBuilder, gate_index: usize, - group_range: (usize, usize), + group_range: Range, s: ExtensionTarget, + many_selectors: bool, ) -> ExtensionTarget { - debug_assert!((group_range.0 <= gate_index) && (gate_index < group_range.1)); - let v = (group_range.0..group_range.1) + debug_assert!(group_range.contains(&gate_index)); + let v = group_range .filter(|&i| i != gate_index) - .chain(Some(UNUSED_SELECTOR)) + .chain(many_selectors.then(|| UNUSED_SELECTOR)) .map(|i| { let c = builder.constant_extension(F::Extension::from_canonical_usize(i)); builder.sub_extension(c, s) diff --git a/plonky2/src/gates/selectors.rs b/plonky2/src/gates/selectors.rs index f51bc8cb..559dcf48 100644 --- a/plonky2/src/gates/selectors.rs +++ b/plonky2/src/gates/selectors.rs @@ -1,3 +1,5 @@ +use std::ops::Range; + use plonky2_field::extension_field::Extendable; use plonky2_field::polynomial::PolynomialValues; @@ -10,7 +12,7 @@ pub(crate) const UNUSED_SELECTOR: usize = u32::MAX as usize; #[derive(Debug, Clone)] pub(crate) struct SelectorsInfo { pub(crate) selector_indices: Vec, - pub(crate) groups: Vec<(usize, usize)>, + pub(crate) groups: Vec>, pub(crate) num_selectors: usize, } @@ -32,27 +34,51 @@ pub(crate) fn selector_polynomials, const D: usize> max_degree: usize, ) -> (Vec>, SelectorsInfo) { let n = instances.len(); + let num_gates = gates.len(); + let max_gate_degree = gates.last().expect("No gates?").0.degree(); + + let index = |id| gates.iter().position(|g| g.0.id() == id).unwrap(); + + // Special case if we can use only one selector polynomial. + if max_gate_degree + num_gates - 1 <= max_degree { + return ( + vec![PolynomialValues::new( + instances + .iter() + .map(|g| F::from_canonical_usize(index(g.gate_ref.0.id()))) + .collect(), + )], + SelectorsInfo { + selector_indices: vec![0; num_gates], + groups: vec![0..num_gates], + num_selectors: 1, + }, + ); + } + + if max_gate_degree >= max_degree { + panic!( + "{} has too high degree. Consider increasing `quotient_degree_factor`.", + gates.last().unwrap().0.id() + ); + } // Greedily construct the groups. let mut groups = Vec::new(); - let mut pos = 0; - while pos < gates.len() { - let mut i = 0; - while (pos + i < gates.len()) && (i + gates[pos + i].0.degree() < max_degree) { - i += 1; + let mut start = 0; + while start < num_gates { + let mut size = 0; + while (start + size < gates.len()) && (size + gates[start + size].0.degree() < max_degree) { + size += 1; } - groups.push((pos, pos + i)); - pos += i; + groups.push(start..start + size); + start += size; } - let index = |id| gates.iter().position(|g| g.0.id() == id).unwrap(); - let group = |i| groups.iter().position(|&(a, b)| a <= i && i < b).unwrap(); + let group = |i| groups.iter().position(|range| range.contains(&i)).unwrap(); // `selector_indices[i] = j` iff the `i`-th gate uses the `j`-th selector polynomial. - let selector_indices = gates - .iter() - .map(|g| group(index(g.0.id()))) - .collect::>(); + let selector_indices = (0..num_gates).map(group).collect(); // Placeholder value to indicate that a gate doesn't use a selector polynomial. let unused = F::from_canonical_usize(UNUSED_SELECTOR); diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 12180600..022e2b12 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -666,6 +666,7 @@ impl, const D: usize> CircuitBuilder { let quotient_degree_factor = self.config.max_quotient_degree_factor; let mut gates = self.gates.iter().cloned().collect::>(); + // Gates need to be sorted by their degrees to compute the selector polynomials. gates.sort_unstable_by_key(|g| (g.0.degree(), g.0.id())); let (mut constant_vecs, selectors_info) = selector_polynomials(&gates, &self.gate_instances, quotient_degree_factor + 1); diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index 4ee797a5..54a5fe98 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -220,7 +220,7 @@ pub fn evaluate_gate_constraints< vars, i, selector_index, - common_data.selectors_info.groups[selector_index], + common_data.selectors_info.groups[selector_index].clone(), common_data.selectors_info.num_selectors, ); for (i, c) in gate_constraints.into_iter().enumerate() { @@ -254,7 +254,7 @@ pub fn evaluate_gate_constraints_base_batch< vars_batch, i, selector_index, - common_data.selectors_info.groups[selector_index], + common_data.selectors_info.groups[selector_index].clone(), common_data.selectors_info.num_selectors, ); debug_assert!( @@ -290,7 +290,7 @@ pub fn evaluate_gate_constraints_recursively< vars, i, selector_index, - common_data.selectors_info.groups[selector_index], + common_data.selectors_info.groups[selector_index].clone(), common_data.selectors_info.num_selectors, &mut all_gate_constraints, )