diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index bc56b2e2..f5de3472 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -7,6 +7,7 @@ use plonky2_field::batch_util::batch_multiply_inplace; use plonky2_field::extension_field::{Extendable, FieldExtension}; use plonky2_field::field_types::Field; +use crate::gates::selectors::UNUSED_SELECTOR; use crate::gates::util::StridedConstraintConsumer; use crate::hash::hash_types::RichField; use crate::iop::ext_target::ExtensionTarget; @@ -84,12 +85,12 @@ pub trait Gate, const D: usize>: 'static + Send + S mut vars: EvaluationVars, gate_index: usize, selector_index: usize, - combination_range: (usize, usize), + group_range: (usize, usize), num_selectors: usize, ) -> Vec { let filter = compute_filter( gate_index, - combination_range, + group_range, vars.local_constants[selector_index], ); vars.remove_prefix(num_selectors); @@ -106,7 +107,7 @@ pub trait Gate, const D: usize>: 'static + Send + S mut vars_batch: EvaluationVarsBaseBatch, gate_index: usize, selector_index: usize, - combination_range: (usize, usize), + group_range: (usize, usize), num_selectors: usize, ) -> Vec { let filters: Vec<_> = vars_batch @@ -114,7 +115,7 @@ pub trait Gate, const D: usize>: 'static + Send + S .map(|vars| { compute_filter( gate_index, - combination_range, + group_range, vars.local_constants[selector_index], ) }) @@ -134,14 +135,14 @@ pub trait Gate, const D: usize>: 'static + Send + S mut vars: EvaluationTargets, gate_index: usize, selector_index: usize, - combination_range: (usize, usize), - combined_gate_constraints: &mut [ExtensionTarget], + group_range: (usize, usize), num_selectors: usize, + combined_gate_constraints: &mut [ExtensionTarget], ) { let filter = compute_filter_recursively( builder, gate_index, - combination_range, + group_range, vars.local_constants[selector_index], ); vars.remove_prefix(num_selectors); @@ -229,34 +230,30 @@ pub struct PrefixedGate, const D: usize> { pub prefix: Vec, } -/// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and -/// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`. -fn compute_filter( - gate_index: usize, - combination_range: (usize, usize), - constant: K, -) -> K { - (combination_range.0..combination_range.1) +/// A gate's filter designed so that it is non-zero if `s = gate_index`. +fn compute_filter(gate_index: usize, group_range: (usize, usize), s: K) -> K { + debug_assert!((group_range.0 <= gate_index) && (gate_index < group_range.1)); + (group_range.0..group_range.1) .filter(|&i| i != gate_index) - .chain(Some(u32::MAX as usize)) - .map(|i| K::from_canonical_usize(i) - constant) + .chain(Some(UNUSED_SELECTOR)) + .map(|i| K::from_canonical_usize(i) - s) .product() } fn compute_filter_recursively, const D: usize>( builder: &mut CircuitBuilder, gate_index: usize, - combination_range: (usize, usize), - constant: ExtensionTarget, + group_range: (usize, usize), + s: ExtensionTarget, ) -> ExtensionTarget { - let v = (combination_range.0..combination_range.1) + debug_assert!((group_range.0 <= gate_index) && (gate_index < group_range.1)); + let v = (group_range.0..group_range.1) .filter(|&i| i != gate_index) - .chain(Some(u32::MAX as usize)) - .map(|i| builder.constant_extension(F::Extension::from_canonical_usize(i))) - .collect::>(); - let v = v - .into_iter() - .map(|x| builder.sub_extension(x, constant)) + .chain(Some(UNUSED_SELECTOR)) + .map(|i| { + let c = builder.constant_extension(F::Extension::from_canonical_usize(i)); + builder.sub_extension(c, s) + }) .collect::>(); builder.mul_many_extension(&v) } diff --git a/plonky2/src/gates/selectors.rs b/plonky2/src/gates/selectors.rs index a014c2d4..ed251237 100644 --- a/plonky2/src/gates/selectors.rs +++ b/plonky2/src/gates/selectors.rs @@ -4,74 +4,73 @@ use plonky2_field::polynomial::PolynomialValues; use crate::gates::gate::{GateInstance, GateRef}; use crate::hash::hash_types::RichField; -pub(crate) fn compute_selectors, const D: usize>( +/// Placeholder value to indicate that a gate doesn't use a selector polynomial. +pub(crate) const UNUSED_SELECTOR: usize = u32::MAX as usize; + +#[derive(Debug, Clone)] +pub(crate) struct SelectorsInfo { + pub(crate) selector_indices: Vec, + pub(crate) groups: Vec<(usize, usize)>, + pub(crate) num_selectors: usize, +} + +/// Returns the selector polynomials and related information. +/// +/// Selector polynomials are computed as follows: +/// Partition the gates into (the smallest amount of) groups `{ G_i }`, such that for each group `G` +/// `|G| + max_{g in G} g.degree() <= max_degree`. These groups are constructed greedily from +/// the list of gates sorted by degree. +pub(crate) fn selector_polynomials, const D: usize>( gates: Vec>, instances: &[GateInstance], max_degree: usize, -) -> ( - Vec>, - Vec, - Vec<(usize, usize)>, - usize, -) { +) -> (Vec>, SelectorsInfo) { let n = instances.len(); - let mut combinations = Vec::new(); + // Greedily construct the groups. + let mut groups = Vec::new(); let mut pos = 0; - while pos < gates.len() { let mut i = 0; while (pos + i < gates.len()) && (i + gates[pos + i].0.degree() < max_degree) { i += 1; } - combinations.push((pos, pos + i)); + groups.push((pos, pos + i)); pos += i; } - let bad = F::from_canonical_usize(u32::MAX as usize); - - let num_constants_polynomials = gates.iter().map(|g| g.0.num_constants()).max().unwrap(); - let mut polynomials = - vec![PolynomialValues::zero(n); combinations.len() + num_constants_polynomials]; let index = |id| gates.iter().position(|g| g.0.id() == id).unwrap(); - let combination = |i| { - combinations - .iter() - .position(|&(a, b)| a <= i && i < b) - .unwrap() - }; + let group = |i| groups.iter().position(|&(a, b)| a <= i && i < b).unwrap(); + // `selector_indices[i] = j` iff the `i`-th gate uses the `j`-th selector polynomial. let selector_indices = gates .iter() - .map(|g| combination(index(g.0.id()))) + .map(|g| group(index(g.0.id()))) .collect::>(); - let combination_ranges = selector_indices - .iter() - .map(|&i| (combinations[i].0, combinations[i].1)) - .collect(); + // Placeholder value to indicate that a gate doesn't use a selector polynomial. + let unused = F::from_canonical_usize(UNUSED_SELECTOR); + + let mut polynomials = vec![PolynomialValues::zero(n); groups.len()]; for (j, g) in instances.iter().enumerate() { - let GateInstance { - gate_ref, - constants, - } = g; + let GateInstance { gate_ref, .. } = g; let i = index(gate_ref.0.id()); - let comb = combination(i); - polynomials[comb].values[j] = F::from_canonical_usize(i); - - for combis in (0..combinations.len()).filter(|&combis| combis != comb) { - polynomials[combis].values[j] = bad; - } - - for k in 0..constants.len() { - polynomials[combinations.len() + k].values[j] = constants[k]; + let gr = group(i); + for g in 0..groups.len() { + polynomials[g].values[j] = if g == gr { + F::from_canonical_usize(i) + } else { + unused + }; } } ( polynomials, - selector_indices, - combination_ranges, - combinations.len(), + SelectorsInfo { + selector_indices, + num_selectors: groups.len(), + groups, + }, ) } diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index cd039da6..7fbc67f4 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -22,7 +22,7 @@ use crate::gates::constant::ConstantGate; use crate::gates::gate::{CurrentSlot, Gate, GateInstance, GateRef}; use crate::gates::noop::NoopGate; use crate::gates::public_input::PublicInputGate; -use crate::gates::selectors::compute_selectors; +use crate::gates::selectors::selector_polynomials; use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField}; use crate::hash::merkle_proofs::MerkleProofTarget; use crate::iop::ext_target::ExtensionTarget; @@ -44,7 +44,7 @@ use crate::util::context_tree::ContextTree; use crate::util::marking::{Markable, MarkedTargets}; use crate::util::partial_products::num_partial_products; use crate::util::timing::TimingTree; -use crate::util::transpose_poly_values; +use crate::util::{transpose, transpose_poly_values}; pub struct CircuitBuilder, const D: usize> { pub config: CircuitConfig, @@ -551,6 +551,29 @@ impl, const D: usize> CircuitBuilder { } } + fn constant_polys(&self) -> Vec> { + let max_constants = self + .gates + .iter() + .map(|g| g.0.num_constants()) + .max() + .unwrap(); + transpose( + &self + .gate_instances + .iter() + .map(|g| { + let mut consts = g.constants.clone(); + consts.resize(max_constants, F::ZERO); + consts + }) + .collect::>(), + ) + .into_iter() + .map(PolynomialValues::new) + .collect() + } + fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> (Vec>, Forest) { let degree = self.gate_instances.len(); let degree_log = log2_strict(degree); @@ -641,17 +664,16 @@ impl, const D: usize> CircuitBuilder { "FRI total reduction arity is too large.", ); - let mut gates = self.gates.iter().cloned().collect::>(); - gates.sort_unstable_by_key(|g| g.0.degree()); - let (constant_vecs, selector_indices, combination_ranges, num_selectors) = - compute_selectors( - gates.clone(), - &self.gate_instances, - self.config.max_quotient_degree_factor + 1, - ); - let num_constants = constant_vecs.len(); let quotient_degree_factor = self.config.max_quotient_degree_factor; - debug!("Quotient degree factor set to: {}.", quotient_degree_factor); + let mut gates = self.gates.iter().cloned().collect::>(); + gates.sort_unstable_by_key(|g| (g.0.degree(), g.0.id())); + let (mut constant_vecs, selectors_info) = selector_polynomials( + gates.clone(), + &self.gate_instances, + quotient_degree_factor + 1, + ); + constant_vecs.extend(self.constant_polys()); + let num_constants = constant_vecs.len(); let subgroup = F::two_adic_subgroup(degree_bits); @@ -754,9 +776,7 @@ impl, const D: usize> CircuitBuilder { fri_params, degree_bits, gates, - selector_indices, - combination_ranges, - num_selectors, + selectors_info, quotient_degree_factor, num_gate_constraints, num_constants, diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs index f67024e3..8c33ae98 100644 --- a/plonky2/src/plonk/circuit_data.rs +++ b/plonky2/src/plonk/circuit_data.rs @@ -13,6 +13,7 @@ use crate::fri::structure::{ }; use crate::fri::{FriConfig, FriParams}; use crate::gates::gate::GateRef; +use crate::gates::selectors::SelectorsInfo; use crate::hash::hash_types::{MerkleCapTarget, RichField}; use crate::hash::merkle_tree::MerkleCap; use crate::iop::ext_target::ExtensionTarget; @@ -248,9 +249,8 @@ pub struct CommonCircuitData< /// The types of gates used in this circuit, along with their prefixes. pub(crate) gates: Vec>, - pub(crate) selector_indices: Vec, - pub(crate) combination_ranges: Vec<(usize, usize)>, - pub(crate) num_selectors: usize, + /// Information on the circuit's selector polynomials. + pub(crate) selectors_info: SelectorsInfo, /// The degree of the PLONK quotient polynomial. pub(crate) quotient_degree_factor: usize, diff --git a/plonky2/src/plonk/vanishing_poly.rs b/plonky2/src/plonk/vanishing_poly.rs index 44e48843..4ee797a5 100644 --- a/plonky2/src/plonk/vanishing_poly.rs +++ b/plonky2/src/plonk/vanishing_poly.rs @@ -215,12 +215,13 @@ pub fn evaluate_gate_constraints< ) -> Vec { let mut constraints = vec![F::Extension::ZERO; common_data.num_gate_constraints]; for (i, gate) in common_data.gates.iter().enumerate() { + let selector_index = common_data.selectors_info.selector_indices[i]; let gate_constraints = gate.0.eval_filtered( vars, i, - common_data.selector_indices[i], - common_data.combination_ranges[i], - common_data.num_selectors, + selector_index, + common_data.selectors_info.groups[selector_index], + common_data.selectors_info.num_selectors, ); for (i, c) in gate_constraints.into_iter().enumerate() { debug_assert!( @@ -248,12 +249,13 @@ pub fn evaluate_gate_constraints_base_batch< ) -> Vec { let mut constraints_batch = vec![F::ZERO; common_data.num_gate_constraints * vars_batch.len()]; for (i, gate) in common_data.gates.iter().enumerate() { + let selector_index = common_data.selectors_info.selector_indices[i]; let gate_constraints_batch = gate.0.eval_filtered_base_batch( vars_batch, i, - common_data.selector_indices[i], - common_data.combination_ranges[i], - common_data.num_selectors, + selector_index, + common_data.selectors_info.groups[selector_index], + common_data.selectors_info.num_selectors, ); debug_assert!( gate_constraints_batch.len() <= constraints_batch.len(), @@ -279,6 +281,7 @@ pub fn evaluate_gate_constraints_recursively< ) -> Vec> { let mut all_gate_constraints = vec![builder.zero_extension(); common_data.num_gate_constraints]; for (i, gate) in common_data.gates.iter().enumerate() { + let selector_index = common_data.selectors_info.selector_indices[i]; with_context!( builder, &format!("evaluate {} constraints", gate.0.id()), @@ -286,10 +289,10 @@ pub fn evaluate_gate_constraints_recursively< builder, vars, i, - common_data.selector_indices[i], - common_data.combination_ranges[i], + selector_index, + common_data.selectors_info.groups[selector_index], + common_data.selectors_info.num_selectors, &mut all_gate_constraints, - common_data.num_selectors ) ); }