PR feedback + use only one selector when possible

This commit is contained in:
wborgeaud 2022-03-28 10:15:06 +02:00
parent 283c9350a7
commit e50e668f7e
4 changed files with 66 additions and 29 deletions

View File

@ -1,6 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::{Debug, Error, Formatter}; use std::fmt::{Debug, Error, Formatter};
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::ops::Range;
use std::sync::Arc; use std::sync::Arc;
use plonky2_field::batch_util::batch_multiply_inplace; use plonky2_field::batch_util::batch_multiply_inplace;
@ -85,13 +86,14 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
mut vars: EvaluationVars<F, D>, mut vars: EvaluationVars<F, D>,
gate_index: usize, gate_index: usize,
selector_index: usize, selector_index: usize,
group_range: (usize, usize), group_range: Range<usize>,
num_selectors: usize, num_selectors: usize,
) -> Vec<F::Extension> { ) -> Vec<F::Extension> {
let filter = compute_filter( let filter = compute_filter(
gate_index, gate_index,
group_range, group_range,
vars.local_constants[selector_index], vars.local_constants[selector_index],
num_selectors > 1,
); );
vars.remove_prefix(num_selectors); vars.remove_prefix(num_selectors);
self.eval_unfiltered(vars) self.eval_unfiltered(vars)
@ -107,7 +109,7 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
mut vars_batch: EvaluationVarsBaseBatch<F>, mut vars_batch: EvaluationVarsBaseBatch<F>,
gate_index: usize, gate_index: usize,
selector_index: usize, selector_index: usize,
group_range: (usize, usize), group_range: Range<usize>,
num_selectors: usize, num_selectors: usize,
) -> Vec<F> { ) -> Vec<F> {
let filters: Vec<_> = vars_batch let filters: Vec<_> = vars_batch
@ -115,8 +117,9 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
.map(|vars| { .map(|vars| {
compute_filter( compute_filter(
gate_index, gate_index,
group_range, group_range.clone(),
vars.local_constants[selector_index], vars.local_constants[selector_index],
num_selectors > 1,
) )
}) })
.collect(); .collect();
@ -135,7 +138,7 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
mut vars: EvaluationTargets<D>, mut vars: EvaluationTargets<D>,
gate_index: usize, gate_index: usize,
selector_index: usize, selector_index: usize,
group_range: (usize, usize), group_range: Range<usize>,
num_selectors: usize, num_selectors: usize,
combined_gate_constraints: &mut [ExtensionTarget<D>], combined_gate_constraints: &mut [ExtensionTarget<D>],
) { ) {
@ -144,6 +147,7 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
gate_index, gate_index,
group_range, group_range,
vars.local_constants[selector_index], vars.local_constants[selector_index],
num_selectors > 1,
); );
vars.remove_prefix(num_selectors); vars.remove_prefix(num_selectors);
let my_constraints = self.eval_unfiltered_recursively(builder, vars); let my_constraints = self.eval_unfiltered_recursively(builder, vars);
@ -231,11 +235,16 @@ pub struct PrefixedGate<F: RichField + Extendable<D>, const D: usize> {
} }
/// A gate's filter designed so that it is non-zero if `s = gate_index`. /// A gate's filter designed so that it is non-zero if `s = gate_index`.
fn compute_filter<K: Field>(gate_index: usize, group_range: (usize, usize), s: K) -> K { fn compute_filter<K: Field>(
debug_assert!((group_range.0 <= gate_index) && (gate_index < group_range.1)); gate_index: usize,
(group_range.0..group_range.1) group_range: Range<usize>,
s: K,
many_selector: bool,
) -> K {
debug_assert!(group_range.contains(&gate_index));
group_range
.filter(|&i| i != gate_index) .filter(|&i| i != gate_index)
.chain(Some(UNUSED_SELECTOR)) .chain(many_selector.then(|| UNUSED_SELECTOR))
.map(|i| K::from_canonical_usize(i) - s) .map(|i| K::from_canonical_usize(i) - s)
.product() .product()
} }
@ -243,13 +252,14 @@ fn compute_filter<K: Field>(gate_index: usize, group_range: (usize, usize), s: K
fn compute_filter_recursively<F: RichField + Extendable<D>, const D: usize>( fn compute_filter_recursively<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>, builder: &mut CircuitBuilder<F, D>,
gate_index: usize, gate_index: usize,
group_range: (usize, usize), group_range: Range<usize>,
s: ExtensionTarget<D>, s: ExtensionTarget<D>,
many_selectors: bool,
) -> ExtensionTarget<D> { ) -> ExtensionTarget<D> {
debug_assert!((group_range.0 <= gate_index) && (gate_index < group_range.1)); debug_assert!(group_range.contains(&gate_index));
let v = (group_range.0..group_range.1) let v = group_range
.filter(|&i| i != gate_index) .filter(|&i| i != gate_index)
.chain(Some(UNUSED_SELECTOR)) .chain(many_selectors.then(|| UNUSED_SELECTOR))
.map(|i| { .map(|i| {
let c = builder.constant_extension(F::Extension::from_canonical_usize(i)); let c = builder.constant_extension(F::Extension::from_canonical_usize(i));
builder.sub_extension(c, s) builder.sub_extension(c, s)

View File

@ -1,3 +1,5 @@
use std::ops::Range;
use plonky2_field::extension_field::Extendable; use plonky2_field::extension_field::Extendable;
use plonky2_field::polynomial::PolynomialValues; use plonky2_field::polynomial::PolynomialValues;
@ -10,7 +12,7 @@ pub(crate) const UNUSED_SELECTOR: usize = u32::MAX as usize;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct SelectorsInfo { pub(crate) struct SelectorsInfo {
pub(crate) selector_indices: Vec<usize>, pub(crate) selector_indices: Vec<usize>,
pub(crate) groups: Vec<(usize, usize)>, pub(crate) groups: Vec<Range<usize>>,
pub(crate) num_selectors: usize, pub(crate) num_selectors: usize,
} }
@ -32,27 +34,51 @@ pub(crate) fn selector_polynomials<F: RichField + Extendable<D>, const D: usize>
max_degree: usize, max_degree: usize,
) -> (Vec<PolynomialValues<F>>, SelectorsInfo) { ) -> (Vec<PolynomialValues<F>>, SelectorsInfo) {
let n = instances.len(); let n = instances.len();
let num_gates = gates.len();
let max_gate_degree = gates.last().expect("No gates?").0.degree();
let index = |id| gates.iter().position(|g| g.0.id() == id).unwrap();
// Special case if we can use only one selector polynomial.
if max_gate_degree + num_gates - 1 <= max_degree {
return (
vec![PolynomialValues::new(
instances
.iter()
.map(|g| F::from_canonical_usize(index(g.gate_ref.0.id())))
.collect(),
)],
SelectorsInfo {
selector_indices: vec![0; num_gates],
groups: vec![0..num_gates],
num_selectors: 1,
},
);
}
if max_gate_degree >= max_degree {
panic!(
"{} has too high degree. Consider increasing `quotient_degree_factor`.",
gates.last().unwrap().0.id()
);
}
// Greedily construct the groups. // Greedily construct the groups.
let mut groups = Vec::new(); let mut groups = Vec::new();
let mut pos = 0; let mut start = 0;
while pos < gates.len() { while start < num_gates {
let mut i = 0; let mut size = 0;
while (pos + i < gates.len()) && (i + gates[pos + i].0.degree() < max_degree) { while (start + size < gates.len()) && (size + gates[start + size].0.degree() < max_degree) {
i += 1; size += 1;
} }
groups.push((pos, pos + i)); groups.push(start..start + size);
pos += i; start += size;
} }
let index = |id| gates.iter().position(|g| g.0.id() == id).unwrap(); let group = |i| groups.iter().position(|range| range.contains(&i)).unwrap();
let group = |i| groups.iter().position(|&(a, b)| a <= i && i < b).unwrap();
// `selector_indices[i] = j` iff the `i`-th gate uses the `j`-th selector polynomial. // `selector_indices[i] = j` iff the `i`-th gate uses the `j`-th selector polynomial.
let selector_indices = gates let selector_indices = (0..num_gates).map(group).collect();
.iter()
.map(|g| group(index(g.0.id())))
.collect::<Vec<_>>();
// Placeholder value to indicate that a gate doesn't use a selector polynomial. // Placeholder value to indicate that a gate doesn't use a selector polynomial.
let unused = F::from_canonical_usize(UNUSED_SELECTOR); let unused = F::from_canonical_usize(UNUSED_SELECTOR);

View File

@ -666,6 +666,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let quotient_degree_factor = self.config.max_quotient_degree_factor; let quotient_degree_factor = self.config.max_quotient_degree_factor;
let mut gates = self.gates.iter().cloned().collect::<Vec<_>>(); let mut gates = self.gates.iter().cloned().collect::<Vec<_>>();
// Gates need to be sorted by their degrees to compute the selector polynomials.
gates.sort_unstable_by_key(|g| (g.0.degree(), g.0.id())); gates.sort_unstable_by_key(|g| (g.0.degree(), g.0.id()));
let (mut constant_vecs, selectors_info) = let (mut constant_vecs, selectors_info) =
selector_polynomials(&gates, &self.gate_instances, quotient_degree_factor + 1); selector_polynomials(&gates, &self.gate_instances, quotient_degree_factor + 1);

View File

@ -220,7 +220,7 @@ pub fn evaluate_gate_constraints<
vars, vars,
i, i,
selector_index, selector_index,
common_data.selectors_info.groups[selector_index], common_data.selectors_info.groups[selector_index].clone(),
common_data.selectors_info.num_selectors, common_data.selectors_info.num_selectors,
); );
for (i, c) in gate_constraints.into_iter().enumerate() { for (i, c) in gate_constraints.into_iter().enumerate() {
@ -254,7 +254,7 @@ pub fn evaluate_gate_constraints_base_batch<
vars_batch, vars_batch,
i, i,
selector_index, selector_index,
common_data.selectors_info.groups[selector_index], common_data.selectors_info.groups[selector_index].clone(),
common_data.selectors_info.num_selectors, common_data.selectors_info.num_selectors,
); );
debug_assert!( debug_assert!(
@ -290,7 +290,7 @@ pub fn evaluate_gate_constraints_recursively<
vars, vars,
i, i,
selector_index, selector_index,
common_data.selectors_info.groups[selector_index], common_data.selectors_info.groups[selector_index].clone(),
common_data.selectors_info.num_selectors, common_data.selectors_info.num_selectors,
&mut all_gate_constraints, &mut all_gate_constraints,
) )