Merge pull request #524 from mir-protocol/better_selectors

Change selector scheme
This commit is contained in:
wborgeaud 2022-03-31 09:13:19 +02:00 committed by GitHub
commit b4d11c28fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 274 additions and 430 deletions

View File

@ -1,13 +1,14 @@
use std::collections::HashMap;
use std::fmt::{Debug, Error, Formatter};
use std::hash::{Hash, Hasher};
use std::ops::Range;
use std::sync::Arc;
use plonky2_field::batch_util::batch_multiply_inplace;
use plonky2_field::extension_field::{Extendable, FieldExtension};
use plonky2_field::field_types::Field;
use crate::gates::gate_tree::Tree;
use crate::gates::selectors::UNUSED_SELECTOR;
use crate::gates::util::StridedConstraintConsumer;
use crate::hash::hash_types::RichField;
use crate::iop::ext_target::ExtensionTarget;
@ -80,9 +81,21 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>>;
fn eval_filtered(&self, mut vars: EvaluationVars<F, D>, prefix: &[bool]) -> Vec<F::Extension> {
let filter = compute_filter(prefix, vars.local_constants);
vars.remove_prefix(prefix);
fn eval_filtered(
&self,
mut vars: EvaluationVars<F, D>,
gate_index: usize,
selector_index: usize,
group_range: Range<usize>,
num_selectors: usize,
) -> Vec<F::Extension> {
let filter = compute_filter(
gate_index,
group_range,
vars.local_constants[selector_index],
num_selectors > 1,
);
vars.remove_prefix(num_selectors);
self.eval_unfiltered(vars)
.into_iter()
.map(|c| filter * c)
@ -94,13 +107,23 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
fn eval_filtered_base_batch(
&self,
mut vars_batch: EvaluationVarsBaseBatch<F>,
prefix: &[bool],
gate_index: usize,
selector_index: usize,
group_range: Range<usize>,
num_selectors: usize,
) -> Vec<F> {
let filters: Vec<_> = vars_batch
.iter()
.map(|vars| compute_filter(prefix, vars.local_constants))
.map(|vars| {
compute_filter(
gate_index,
group_range.clone(),
vars.local_constants[selector_index],
num_selectors > 1,
)
})
.collect();
vars_batch.remove_prefix(prefix);
vars_batch.remove_prefix(num_selectors);
let mut res_batch = self.eval_unfiltered_base_batch(vars_batch);
for res_chunk in res_batch.chunks_exact_mut(filters.len()) {
batch_multiply_inplace(res_chunk, &filters);
@ -113,11 +136,20 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
&self,
builder: &mut CircuitBuilder<F, D>,
mut vars: EvaluationTargets<D>,
prefix: &[bool],
gate_index: usize,
selector_index: usize,
group_range: Range<usize>,
num_selectors: usize,
combined_gate_constraints: &mut [ExtensionTarget<D>],
) {
let filter = compute_filter_recursively(builder, prefix, vars.local_constants);
vars.remove_prefix(prefix);
let filter = compute_filter_recursively(
builder,
gate_index,
group_range,
vars.local_constants[selector_index],
num_selectors > 1,
);
vars.remove_prefix(num_selectors);
let my_constraints = self.eval_unfiltered_recursively(builder, vars);
for (acc, c) in combined_gate_constraints.iter_mut().zip(my_constraints) {
*acc = builder.mul_add_extension(filter, c, *acc);
@ -202,42 +234,36 @@ pub struct PrefixedGate<F: RichField + Extendable<D>, const D: usize> {
pub prefix: Vec<bool>,
}
impl<F: RichField + Extendable<D>, const D: usize> PrefixedGate<F, D> {
pub fn from_tree(tree: Tree<GateRef<F, D>>) -> Vec<Self> {
tree.traversal()
.into_iter()
.map(|(gate, prefix)| PrefixedGate { gate, prefix })
.collect()
}
}
/// A gate's filter is computed as `prod b_i*c_i + (1-b_i)*(1-c_i)`, with `(b_i)` the prefix and
/// `(c_i)` the local constants, which is one if the prefix of `constants` matches `prefix`.
fn compute_filter<'a, K: Field, T: IntoIterator<Item = &'a K>>(prefix: &[bool], constants: T) -> K {
prefix
.iter()
.zip(constants)
.map(|(&b, &c)| if b { c } else { K::ONE - c })
/// A gate's filter designed so that it is non-zero if `s = gate_index`.
fn compute_filter<K: Field>(
gate_index: usize,
group_range: Range<usize>,
s: K,
many_selector: bool,
) -> K {
debug_assert!(group_range.contains(&gate_index));
group_range
.filter(|&i| i != gate_index)
.chain(many_selector.then(|| UNUSED_SELECTOR))
.map(|i| K::from_canonical_usize(i) - s)
.product()
}
fn compute_filter_recursively<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
prefix: &[bool],
constants: &[ExtensionTarget<D>],
gate_index: usize,
group_range: Range<usize>,
s: ExtensionTarget<D>,
many_selectors: bool,
) -> ExtensionTarget<D> {
let one = builder.one_extension();
let v = prefix
.iter()
.enumerate()
.map(|(i, &b)| {
if b {
constants[i]
} else {
builder.sub_extension(one, constants[i])
}
debug_assert!(group_range.contains(&gate_index));
let v = group_range
.filter(|&i| i != gate_index)
.chain(many_selectors.then(|| UNUSED_SELECTOR))
.map(|i| {
let c = builder.constant_extension(F::Extension::from_canonical_usize(i));
builder.sub_extension(c, s)
})
.collect::<Vec<_>>();
builder.mul_many_extension(&v)
}

View File

@ -1,294 +0,0 @@
use log::debug;
use plonky2_field::extension_field::Extendable;
use crate::gates::gate::GateRef;
use crate::hash::hash_types::RichField;
/// A binary tree where leaves hold some type `T` and other nodes are empty.
#[derive(Debug, Clone)]
pub enum Tree<T> {
Leaf(T),
Bifurcation(Option<Box<Tree<T>>>, Option<Box<Tree<T>>>),
}
impl<T> Default for Tree<T> {
fn default() -> Self {
Self::Bifurcation(None, None)
}
}
impl<T: Clone> Tree<T> {
/// Traverse a tree using a depth-first traversal and collect data and position for each leaf.
/// A leaf's position is represented by its left/right path, where `false` means left and `true` means right.
pub fn traversal(&self) -> Vec<(T, Vec<bool>)> {
let mut res = Vec::new();
let prefix = [];
self.traverse(&prefix, &mut res);
res
}
/// Utility function to traverse the tree.
fn traverse(&self, prefix: &[bool], current: &mut Vec<(T, Vec<bool>)>) {
match &self {
// If node is a leaf, collect the data and position.
Tree::Leaf(t) => {
current.push((t.clone(), prefix.to_vec()));
}
// Otherwise, traverse the left subtree and then the right subtree.
Tree::Bifurcation(left, right) => {
if let Some(l) = left {
let mut left_prefix = prefix.to_vec();
left_prefix.push(false);
l.traverse(&left_prefix, current);
}
if let Some(r) = right {
let mut right_prefix = prefix.to_vec();
right_prefix.push(true);
r.traverse(&right_prefix, current);
}
}
}
}
}
impl<F: RichField + Extendable<D>, const D: usize> Tree<GateRef<F, D>> {
/// The binary gate tree influences the degree `D` of the constraint polynomial and the number `C`
/// of constant wires in the circuit. We want to construct a tree minimizing both values. To do so
/// we iterate over possible values of `(D, C)` and try to construct a tree with these values.
/// For this construction, we use the greedy algorithm in `Self::find_tree`.
/// This latter function greedily adds gates at the depth where
/// `filtered_deg(gate)=D, constant_wires(gate)=C` to ensure no space is wasted.
/// We return the first tree found in this manner, along with it's maximum filtered degree
/// and the number of constant wires needed when using this tree.
pub fn from_gates(mut gates: Vec<GateRef<F, D>>) -> (Self, usize, usize) {
let timer = std::time::Instant::now();
gates.sort_unstable_by_key(|g| (-(g.0.degree() as isize), -(g.0.num_constants() as isize)));
for max_degree_bits in 1..10 {
// The quotient polynomials are padded to the next power of 2 in `compute_quotient_polys`.
// So we can restrict our search space by setting `max_degree` to 1 + a power of 2.
let max_degree = (1 << max_degree_bits) + 1;
for max_constants in 1..100 {
if let Some(mut best_tree) = Self::find_tree(&gates, max_degree, max_constants) {
let mut best_num_constants = best_tree.num_constants();
let mut best_degree = max_degree;
// Iterate backwards from `max_degree` to try to find a tree with a lower degree
// but the same number of constants.
'optdegree: for degree in (0..max_degree).rev() {
if let Some(tree) = Self::find_tree(&gates, degree, max_constants) {
let num_constants = tree.num_constants();
if num_constants > best_num_constants {
break 'optdegree;
} else {
best_degree = degree;
best_num_constants = num_constants;
best_tree = tree;
}
}
}
debug!(
"Found tree with max degree {} and {} constants wires in {:.4}s.",
best_degree,
best_num_constants,
timer.elapsed().as_secs_f32()
);
return (best_tree, best_degree, best_num_constants);
}
}
}
panic!("Can't find a tree.")
}
/// Greedily add gates wherever possible. Returns `None` if this fails.
fn find_tree(gates: &[GateRef<F, D>], max_degree: usize, max_constants: usize) -> Option<Self> {
let mut tree = Tree::default();
for g in gates {
tree.try_add_gate(g, max_degree, max_constants)?;
}
tree.shorten();
Some(tree)
}
/// Try to add a gate in the tree. Returns `None` if this fails.
fn try_add_gate(
&mut self,
g: &GateRef<F, D>,
max_degree: usize,
max_constants: usize,
) -> Option<()> {
// We want `gate.degree + depth <= max_degree` and `gate.num_constants + depth <= max_wires`.
let depth = max_degree
.checked_sub(g.0.degree())?
.min(max_constants.checked_sub(g.0.num_constants())?);
self.try_add_gate_at_depth(g, depth)
}
/// Try to add a gate in the tree at a specified depth. Returns `None` if this fails.
fn try_add_gate_at_depth(&mut self, g: &GateRef<F, D>, depth: usize) -> Option<()> {
// If depth is 0, we have to insert the gate here.
if depth == 0 {
return if let Tree::Bifurcation(None, None) = self {
// Insert the gate as a new leaf.
*self = Tree::Leaf(g.clone());
Some(())
} else {
// A leaf is already here.
None
};
}
// A leaf is already here so we cannot go deeper.
if let Tree::Leaf(_) = self {
return None;
}
if let Tree::Bifurcation(left, right) = self {
if let Some(left) = left {
// Try to add the gate to the left if there's already a left subtree.
if left.try_add_gate_at_depth(g, depth - 1).is_some() {
return Some(());
}
} else {
// Add a new left subtree and try to add the gate to it.
let mut new_left = Tree::default();
if new_left.try_add_gate_at_depth(g, depth - 1).is_some() {
*left = Some(Box::new(new_left));
return Some(());
}
}
if let Some(right) = right {
// Try to add the gate to the right if there's already a right subtree.
if right.try_add_gate_at_depth(g, depth - 1).is_some() {
return Some(());
}
} else {
// Add a new right subtree and try to add the gate to it.
let mut new_right = Tree::default();
if new_right.try_add_gate_at_depth(g, depth - 1).is_some() {
*right = Some(Box::new(new_right));
return Some(());
}
}
}
None
}
/// `Self::find_tree` returns a tree where each gate has `F(gate)=M` (see `Self::from_gates` comment).
/// This can produce subtrees with more nodes than necessary. This function removes useless nodes,
/// i.e., nodes that have a left but no right subtree.
fn shorten(&mut self) {
if let Tree::Bifurcation(left, right) = self {
if let (Some(left), None) = (left, right) {
// If the node has a left but no right subtree, set the node to its (shortened) left subtree.
let mut new = *left.clone();
new.shorten();
*self = new;
}
}
if let Tree::Bifurcation(left, right) = self {
if let Some(left) = left {
// Shorten the left subtree if there is one.
left.shorten();
}
if let Some(right) = right {
// Shorten the right subtree if there is one.
right.shorten();
}
}
}
/// Returns the tree's maximum filtered constraint degree.
pub fn max_filtered_degree(&self) -> usize {
self.traversal()
.into_iter()
.map(|(g, p)| g.0.degree() + p.len())
.max()
.expect("Empty tree.")
}
/// Returns the number of constant wires needed to fit all prefixes and gate constants.
fn num_constants(&self) -> usize {
self.traversal()
.into_iter()
.map(|(g, p)| g.0.num_constants() + p.len())
.max()
.expect("Empty tree.")
}
}
#[cfg(test)]
mod tests {
use log::info;
use super::*;
use crate::gadgets::interpolation::InterpolationGate;
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::base_sum::BaseSumGate;
use crate::gates::constant::ConstantGate;
use crate::gates::interpolation::HighDegreeInterpolationGate;
use crate::gates::noop::NoopGate;
use crate::gates::poseidon::PoseidonGate;
use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
#[test]
fn test_prefix_generation() {
env_logger::init();
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;
let gates = vec![
GateRef::new(NoopGate),
GateRef::new(ConstantGate { num_consts: 4 }),
GateRef::new(ArithmeticExtensionGate { num_ops: 4 }),
GateRef::new(BaseSumGate::<4>::new(4)),
GateRef::new(PoseidonGate::<F, D>::new()),
GateRef::new(HighDegreeInterpolationGate::new(2)),
];
let (tree, _, _) = Tree::from_gates(gates.clone());
let mut gates_with_prefix = tree.traversal();
for (g, p) in &gates_with_prefix {
info!(
"\nGate: {}, prefix: {:?}.\n\
Filtered constraint degree: {}, Num constant wires: {}",
&g.0.id()[..20.min(g.0.id().len())],
p,
g.0.degree() + p.len(),
g.0.num_constants() + p.len()
);
}
assert_eq!(
gates_with_prefix.len(),
gates.len(),
"The tree has too much or too little gates."
);
assert!(
gates
.iter()
.all(|g| gates_with_prefix.iter().map(|(gg, _)| gg).any(|gg| gg == g)),
"Some gates are not in the tree."
);
assert!(
gates_with_prefix
.iter()
.all(|(g, p)| g.0.degree() + g.0.num_constants() + p.len() <= 9),
"Total degree is larger than 8."
);
gates_with_prefix.sort_unstable_by_key(|(_g, p)| p.len());
for i in 0..gates_with_prefix.len() {
for j in i + 1..gates_with_prefix.len() {
assert_ne!(
&gates_with_prefix[i].1,
&gates_with_prefix[j].1[0..gates_with_prefix[i].1.len()],
"Some gates share an overlapping prefix"
);
}
}
}
}

View File

@ -11,7 +11,6 @@ pub mod comparison;
pub mod constant;
pub mod exponentiation;
pub mod gate;
pub mod gate_tree;
pub mod interpolation;
pub mod low_degree_interpolation;
pub mod multiplication_extension;
@ -24,6 +23,7 @@ pub mod random_access;
pub mod range_check_u32;
pub mod reducing;
pub mod reducing_extension;
pub(crate) mod selectors;
pub mod subtraction_u32;
pub mod switch;
pub mod util;

View File

@ -0,0 +1,111 @@
use std::ops::Range;
use plonky2_field::extension_field::Extendable;
use plonky2_field::polynomial::PolynomialValues;
use crate::gates::gate::{GateInstance, GateRef};
use crate::hash::hash_types::RichField;
/// Placeholder value to indicate that a gate doesn't use a selector polynomial.
pub(crate) const UNUSED_SELECTOR: usize = u32::MAX as usize;
#[derive(Debug, Clone)]
pub(crate) struct SelectorsInfo {
pub(crate) selector_indices: Vec<usize>,
pub(crate) groups: Vec<Range<usize>>,
}
impl SelectorsInfo {
pub fn num_selectors(&self) -> usize {
self.groups.len()
}
}
/// Returns the selector polynomials and related information.
///
/// Selector polynomials are computed as follows:
/// Partition the gates into (the smallest amount of) groups `{ G_i }`, such that for each group `G`
/// `|G| + max_{g in G} g.degree() <= max_degree`. These groups are constructed greedily from
/// the list of gates sorted by degree.
/// We build a selector polynomial `S_i` for each group `G_i`, with
/// S_i[j] =
/// if j-th row gate=g_k in G_i
/// k
/// else
/// UNUSED_SELECTOR
pub(crate) fn selector_polynomials<F: RichField + Extendable<D>, const D: usize>(
gates: &[GateRef<F, D>],
instances: &[GateInstance<F, D>],
max_degree: usize,
) -> (Vec<PolynomialValues<F>>, SelectorsInfo) {
let n = instances.len();
let num_gates = gates.len();
let max_gate_degree = gates.last().expect("No gates?").0.degree();
let index = |id| gates.iter().position(|g| g.0.id() == id).unwrap();
// Special case if we can use only one selector polynomial.
if max_gate_degree + num_gates - 1 <= max_degree {
return (
vec![PolynomialValues::new(
instances
.iter()
.map(|g| F::from_canonical_usize(index(g.gate_ref.0.id())))
.collect(),
)],
SelectorsInfo {
selector_indices: vec![0; num_gates],
groups: vec![0..num_gates],
},
);
}
if max_gate_degree >= max_degree {
panic!(
"{} has too high degree. Consider increasing `quotient_degree_factor`.",
gates.last().unwrap().0.id()
);
}
// Greedily construct the groups.
let mut groups = Vec::new();
let mut start = 0;
while start < num_gates {
let mut size = 0;
while (start + size < gates.len()) && (size + gates[start + size].0.degree() < max_degree) {
size += 1;
}
groups.push(start..start + size);
start += size;
}
let group = |i| groups.iter().position(|range| range.contains(&i)).unwrap();
// `selector_indices[i] = j` iff the `i`-th gate uses the `j`-th selector polynomial.
let selector_indices = (0..num_gates).map(group).collect();
// Placeholder value to indicate that a gate doesn't use a selector polynomial.
let unused = F::from_canonical_usize(UNUSED_SELECTOR);
let mut polynomials = vec![PolynomialValues::zero(n); groups.len()];
for (j, g) in instances.iter().enumerate() {
let GateInstance { gate_ref, .. } = g;
let i = index(gate_ref.0.id());
let gr = group(i);
for g in 0..groups.len() {
polynomials[g].values[j] = if g == gr {
F::from_canonical_usize(i)
} else {
unused
};
}
}
(
polynomials,
SelectorsInfo {
selector_indices,
groups,
},
)
}

View File

@ -19,10 +19,10 @@ use crate::gadgets::polynomial::PolynomialCoeffsExtTarget;
use crate::gates::arithmetic_base::ArithmeticGate;
use crate::gates::arithmetic_extension::ArithmeticExtensionGate;
use crate::gates::constant::ConstantGate;
use crate::gates::gate::{CurrentSlot, Gate, GateInstance, GateRef, PrefixedGate};
use crate::gates::gate_tree::Tree;
use crate::gates::gate::{CurrentSlot, Gate, GateInstance, GateRef};
use crate::gates::noop::NoopGate;
use crate::gates::public_input::PublicInputGate;
use crate::gates::selectors::selector_polynomials;
use crate::hash::hash_types::{HashOutTarget, MerkleCapTarget, RichField};
use crate::hash::merkle_proofs::MerkleProofTarget;
use crate::iop::ext_target::ExtensionTarget;
@ -551,32 +551,27 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
}
}
fn constant_polys(
&self,
gates: &[PrefixedGate<F, D>],
num_constants: usize,
) -> Vec<PolynomialValues<F>> {
let constants_per_gate = self
.gate_instances
fn constant_polys(&self) -> Vec<PolynomialValues<F>> {
let max_constants = self
.gates
.iter()
.map(|gate| {
let prefix = &gates
.iter()
.find(|g| g.gate.0.id() == gate.gate_ref.0.id())
.unwrap()
.prefix;
let mut prefixed_constants = Vec::with_capacity(num_constants);
prefixed_constants.extend(prefix.iter().map(|&b| if b { F::ONE } else { F::ZERO }));
prefixed_constants.extend_from_slice(&gate.constants);
prefixed_constants.resize(num_constants, F::ZERO);
prefixed_constants
})
.collect::<Vec<_>>();
transpose(&constants_per_gate)
.into_iter()
.map(PolynomialValues::new)
.collect()
.map(|g| g.0.num_constants())
.max()
.unwrap();
transpose(
&self
.gate_instances
.iter()
.map(|g| {
let mut consts = g.constants.clone();
consts.resize(max_constants, F::ZERO);
consts
})
.collect::<Vec<_>>(),
)
.into_iter()
.map(PolynomialValues::new)
.collect()
}
fn sigma_vecs(&self, k_is: &[F], subgroup: &[F]) -> (Vec<PolynomialValues<F>>, Forest) {
@ -669,27 +664,17 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
"FRI total reduction arity is too large.",
);
let gates = self.gates.iter().cloned().collect();
let (gate_tree, max_filtered_constraint_degree, num_constants) = Tree::from_gates(gates);
let prefixed_gates = PrefixedGate::from_tree(gate_tree);
// `quotient_degree_factor` has to be between `max_filtered_constraint_degree-1` and `1<<rate_bits`.
// We find the value that minimizes `num_partial_product + quotient_degree_factor`.
let min_quotient_degree_factor = (max_filtered_constraint_degree - 1).max(2);
let max_quotient_degree_factor = self.config.max_quotient_degree_factor.min(1 << rate_bits);
let quotient_degree_factor = (min_quotient_degree_factor..=max_quotient_degree_factor)
.min_by_key(|&q| num_partial_products(self.config.num_routed_wires, q) + q)
.unwrap();
debug!("Quotient degree factor set to: {}.", quotient_degree_factor);
let quotient_degree_factor = self.config.max_quotient_degree_factor;
let mut gates = self.gates.iter().cloned().collect::<Vec<_>>();
// Gates need to be sorted by their degrees (and ID to make the ordering deterministic) to compute the selector polynomials.
gates.sort_unstable_by_key(|g| (g.0.degree(), g.0.id()));
let (mut constant_vecs, selectors_info) =
selector_polynomials(&gates, &self.gate_instances, quotient_degree_factor + 1);
constant_vecs.extend(self.constant_polys());
let num_constants = constant_vecs.len();
let subgroup = F::two_adic_subgroup(degree_bits);
let constant_vecs = timed!(
timing,
"generate constant polynomials",
self.constant_polys(&prefixed_gates, num_constants)
);
let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires);
let (sigma_vecs, forest) = timed!(
timing,
@ -768,11 +753,6 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
fft_root_table: Some(fft_root_table),
};
// The HashSet of gates will have a non-deterministic order. When converting to a Vec, we
// sort by ID to make the ordering deterministic.
let mut gates = self.gates.iter().cloned().collect::<Vec<_>>();
gates.sort_unstable_by_key(|gate| gate.0.id());
let num_gate_constraints = gates
.iter()
.map(|gate| gate.0.num_constraints())
@ -793,7 +773,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
config: self.config,
fri_params,
degree_bits,
gates: prefixed_gates,
gates,
selectors_info,
quotient_degree_factor,
num_gate_constraints,
num_constants,

View File

@ -12,7 +12,8 @@ use crate::fri::structure::{
FriBatchInfo, FriBatchInfoTarget, FriInstanceInfo, FriInstanceInfoTarget, FriPolynomialInfo,
};
use crate::fri::{FriConfig, FriParams};
use crate::gates::gate::PrefixedGate;
use crate::gates::gate::GateRef;
use crate::gates::selectors::SelectorsInfo;
use crate::hash::hash_types::{MerkleCapTarget, RichField};
use crate::hash::merkle_tree::MerkleCap;
use crate::iop::ext_target::ExtensionTarget;
@ -253,7 +254,10 @@ pub struct CommonCircuitData<
pub(crate) degree_bits: usize,
/// The types of gates used in this circuit, along with their prefixes.
pub(crate) gates: Vec<PrefixedGate<F, D>>,
pub(crate) gates: Vec<GateRef<F, D>>,
/// Information on the circuit's selector polynomials.
pub(crate) selectors_info: SelectorsInfo,
/// The degree of the PLONK quotient polynomial.
pub(crate) quotient_degree_factor: usize,
@ -297,7 +301,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
pub fn constraint_degree(&self) -> usize {
self.gates
.iter()
.map(|g| g.gate.0.degree())
.map(|g| g.0.degree())
.max()
.expect("No gates?")
}

View File

@ -3,7 +3,6 @@ use plonky2_field::extension_field::{Extendable, FieldExtension};
use plonky2_field::field_types::Field;
use plonky2_field::zero_poly_coset::ZeroPolyOnCoset;
use crate::gates::gate::PrefixedGate;
use crate::hash::hash_types::RichField;
use crate::iop::ext_target::ExtensionTarget;
use crate::iop::target::Target;
@ -40,8 +39,7 @@ pub(crate) fn eval_vanishing_poly<
let max_degree = common_data.quotient_degree_factor;
let num_prods = common_data.num_partial_products;
let constraint_terms =
evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars);
let constraint_terms = evaluate_gate_constraints(common_data, vars);
// The L_1(x) (Z(x) - 1) vanishing terms.
let mut vanishing_z_1_terms = Vec::new();
@ -128,8 +126,7 @@ pub(crate) fn eval_vanishing_poly_base_batch<
let num_gate_constraints = common_data.num_gate_constraints;
let constraint_terms_batch =
evaluate_gate_constraints_base_batch(&common_data.gates, num_gate_constraints, vars_batch);
let constraint_terms_batch = evaluate_gate_constraints_base_batch(common_data, vars_batch);
debug_assert!(constraint_terms_batch.len() == n * num_gate_constraints);
let num_challenges = common_data.config.num_challenges;
@ -208,17 +205,27 @@ pub(crate) fn eval_vanishing_poly_base_batch<
/// `num_gate_constraints` is the largest number of constraints imposed by any gate. It is not
/// strictly necessary, but it helps performance by ensuring that we allocate a vector with exactly
/// the capacity that we need.
pub fn evaluate_gate_constraints<F: RichField + Extendable<D>, const D: usize>(
gates: &[PrefixedGate<F, D>],
num_gate_constraints: usize,
pub fn evaluate_gate_constraints<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
common_data: &CommonCircuitData<F, C, D>,
vars: EvaluationVars<F, D>,
) -> Vec<F::Extension> {
let mut constraints = vec![F::Extension::ZERO; num_gate_constraints];
for gate in gates {
let gate_constraints = gate.gate.0.eval_filtered(vars, &gate.prefix);
let mut constraints = vec![F::Extension::ZERO; common_data.num_gate_constraints];
for (i, gate) in common_data.gates.iter().enumerate() {
let selector_index = common_data.selectors_info.selector_indices[i];
let gate_constraints = gate.0.eval_filtered(
vars,
i,
selector_index,
common_data.selectors_info.groups[selector_index].clone(),
common_data.selectors_info.num_selectors(),
);
for (i, c) in gate_constraints.into_iter().enumerate() {
debug_assert!(
i < num_gate_constraints,
i < common_data.num_gate_constraints,
"num_constraints() gave too low of a number"
);
constraints[i] += c;
@ -232,17 +239,24 @@ pub fn evaluate_gate_constraints<F: RichField + Extendable<D>, const D: usize>(
/// Returns a vector of `num_gate_constraints * vars_batch.len()` field elements. The constraints
/// corresponding to `vars_batch[i]` are found in `result[i], result[vars_batch.len() + i],
/// result[2 * vars_batch.len() + i], ...`.
pub fn evaluate_gate_constraints_base_batch<F: RichField + Extendable<D>, const D: usize>(
gates: &[PrefixedGate<F, D>],
num_gate_constraints: usize,
pub fn evaluate_gate_constraints_base_batch<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
common_data: &CommonCircuitData<F, C, D>,
vars_batch: EvaluationVarsBaseBatch<F>,
) -> Vec<F> {
let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()];
for gate in gates {
let gate_constraints_batch = gate
.gate
.0
.eval_filtered_base_batch(vars_batch, &gate.prefix);
let mut constraints_batch = vec![F::ZERO; common_data.num_gate_constraints * vars_batch.len()];
for (i, gate) in common_data.gates.iter().enumerate() {
let selector_index = common_data.selectors_info.selector_indices[i];
let gate_constraints_batch = gate.0.eval_filtered_base_batch(
vars_batch,
i,
selector_index,
common_data.selectors_info.groups[selector_index].clone(),
common_data.selectors_info.num_selectors(),
);
debug_assert!(
gate_constraints_batch.len() <= constraints_batch.len(),
"num_constraints() gave too low of a number"
@ -256,22 +270,29 @@ pub fn evaluate_gate_constraints_base_batch<F: RichField + Extendable<D>, const
constraints_batch
}
pub fn evaluate_gate_constraints_recursively<F: RichField + Extendable<D>, const D: usize>(
pub fn evaluate_gate_constraints_recursively<
F: RichField + Extendable<D>,
C: GenericConfig<D, F = F>,
const D: usize,
>(
builder: &mut CircuitBuilder<F, D>,
gates: &[PrefixedGate<F, D>],
num_gate_constraints: usize,
common_data: &CommonCircuitData<F, C, D>,
vars: EvaluationTargets<D>,
) -> Vec<ExtensionTarget<D>> {
let mut all_gate_constraints = vec![builder.zero_extension(); num_gate_constraints];
for gate in gates {
let mut all_gate_constraints = vec![builder.zero_extension(); common_data.num_gate_constraints];
for (i, gate) in common_data.gates.iter().enumerate() {
let selector_index = common_data.selectors_info.selector_indices[i];
with_context!(
builder,
&format!("evaluate {} constraints", gate.gate.0.id()),
gate.gate.0.eval_filtered_recursively(
&format!("evaluate {} constraints", gate.0.id()),
gate.0.eval_filtered_recursively(
builder,
vars,
&gate.prefix,
&mut all_gate_constraints
i,
selector_index,
common_data.selectors_info.groups[selector_index].clone(),
common_data.selectors_info.num_selectors(),
&mut all_gate_constraints,
)
);
}
@ -308,12 +329,7 @@ pub(crate) fn eval_vanishing_poly_recursively<
let constraint_terms = with_context!(
builder,
"evaluate gate constraints",
evaluate_gate_constraints_recursively(
builder,
&common_data.gates,
common_data.num_gate_constraints,
vars,
)
evaluate_gate_constraints_recursively(builder, common_data, vars,)
);
// The L_1(x) (Z(x) - 1) vanishing terms.

View File

@ -55,8 +55,8 @@ impl<'a, F: RichField + Extendable<D>, const D: usize> EvaluationVars<'a, F, D>
ExtensionAlgebra::from_basefield_array(arr)
}
pub fn remove_prefix(&mut self, prefix: &[bool]) {
self.local_constants = &self.local_constants[prefix.len()..];
pub fn remove_prefix(&mut self, num_selectors: usize) {
self.local_constants = &self.local_constants[num_selectors..];
}
}
@ -77,8 +77,8 @@ impl<'a, F: Field> EvaluationVarsBaseBatch<'a, F> {
}
}
pub fn remove_prefix(&mut self, prefix: &[bool]) {
self.local_constants = &self.local_constants[prefix.len() * self.len()..];
pub fn remove_prefix(&mut self, num_selectors: usize) {
self.local_constants = &self.local_constants[num_selectors * self.len()..];
}
pub fn len(&self) -> usize {
@ -209,8 +209,8 @@ impl<'a, P: PackedField> ExactSizeIterator for EvaluationVarsBaseBatchIterPacked
}
impl<'a, const D: usize> EvaluationTargets<'a, D> {
pub fn remove_prefix(&mut self, prefix: &[bool]) {
self.local_constants = &self.local_constants[prefix.len()..];
pub fn remove_prefix(&mut self, num_selectors: usize) {
self.local_constants = &self.local_constants[num_selectors..];
}
}