diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs index e7e48f82..d1b19a7d 100644 --- a/plonky2/src/fri/recursive_verifier.rs +++ b/plonky2/src/fri/recursive_verifier.rs @@ -10,10 +10,9 @@ use crate::fri::proof::{ }; use crate::fri::structure::{FriBatchInfoTarget, FriInstanceInfoTarget, FriOpeningsTarget}; use crate::fri::{FriConfig, FriParams}; +use crate::gates::coset_interpolation::CosetInterpolationGate; use crate::gates::gate::Gate; -use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; use crate::gates::interpolation::InterpolationGate; -use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::random_access::RandomAccessGate; use crate::hash::hash_types::{MerkleCapTarget, RichField}; use crate::iop::ext_target::{flatten_target, ExtensionTarget}; @@ -50,23 +49,11 @@ impl, const D: usize> CircuitBuilder { let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - // `HighDegreeInterpolationGate` has degree `arity`, so we use the low-degree gate if - // the arity is too large. - if arity > self.config.max_quotient_degree_factor { - self.interpolate_coset::>( - arity_bits, - coset_start, - &evals, - beta, - ) - } else { - self.interpolate_coset::>( - arity_bits, - coset_start, - &evals, - beta, - ) - } + let interpolation_gate = >::with_max_degree( + arity_bits, + self.config.max_quotient_degree_factor, + ); + self.interpolate_coset(interpolation_gate, coset_start, &evals, beta) } /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check @@ -77,14 +64,13 @@ impl, const D: usize> CircuitBuilder { &self.config, max_fri_arity_bits.max(self.config.fri_config.cap_height), ); - let (interpolation_wires, interpolation_routed_wires) = - if 1 << max_fri_arity_bits > self.config.max_quotient_degree_factor { - let gate = LowDegreeInterpolationGate::::new(max_fri_arity_bits); - (gate.num_wires(), gate.num_routed_wires()) - } else { - let gate = HighDegreeInterpolationGate::::new(max_fri_arity_bits); - (gate.num_wires(), gate.num_routed_wires()) - }; + let interpolation_gate = CosetInterpolationGate::::with_max_degree( + max_fri_arity_bits, + self.config.max_quotient_degree_factor, + ); + + let interpolation_wires = interpolation_gate.num_wires(); + let interpolation_routed_wires = interpolation_gate.num_routed_wires(); let min_wires = random_access.num_wires().max(interpolation_wires); let min_routed_wires = random_access diff --git a/plonky2/src/gates/coset_interpolation.rs b/plonky2/src/gates/coset_interpolation.rs new file mode 100644 index 00000000..3663a92c --- /dev/null +++ b/plonky2/src/gates/coset_interpolation.rs @@ -0,0 +1,783 @@ +use alloc::boxed::Box; +use alloc::string::String; +use alloc::vec::Vec; +use alloc::{format, vec}; +use core::marker::PhantomData; +use core::ops::Range; + +use crate::field::extension::algebra::ExtensionAlgebra; +use crate::field::extension::{Extendable, FieldExtension, OEF}; +use crate::field::interpolation::barycentric_weights; +use crate::field::types::Field; +use crate::gates::gate::Gate; +use crate::gates::interpolation::InterpolationGate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::{ExtensionAlgebraTarget, ExtensionTarget}; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// One of the instantiations of `InterpolationGate`: allows constraints of variable +/// degree, up to `1<, const D: usize> { + pub subgroup_bits: usize, + pub degree: usize, + _phantom: PhantomData, +} + +impl, const D: usize> InterpolationGate + for CosetInterpolationGate +{ + fn new(subgroup_bits: usize) -> Self { + Self::with_max_degree(subgroup_bits, 1 << subgroup_bits) + } + + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } + + /// Wire indices of the interpolant's `i`th coefficient. + fn wires_coeff(&self, _i: usize) -> Range { + panic!("No coefficient wires"); + } + + fn end_coeffs(&self) -> usize { + self.start_coeffs() + } +} + +impl, const D: usize> CosetInterpolationGate { + pub(crate) fn with_max_degree(subgroup_bits: usize, max_degree: usize) -> Self { + assert!(max_degree > 1, "need at least quadratic constraints"); + + let n_points = 1 << subgroup_bits; + + // Number of intermediate values required to compute interpolation with degree bound + let n_intermediates = (n_points - 2) / (max_degree - 1); + + // Find minimum degree such that (n_points - 2) / (degree - 1) < n_intermediates + 1 + // Minimizing the degree this way allows the gate to be in a larger selector group + let degree = (n_points - 2) / (n_intermediates + 1) + 2; + + Self { + subgroup_bits, + degree, + _phantom: PhantomData, + } + } + + fn num_intermediates(&self) -> usize { + (self.num_points() - 2) / (self.degree() - 1) + } + + /// The wires corresponding to the i'th intermediate evaluation. + fn wires_intermediate_eval(&self, i: usize) -> Range { + debug_assert!(i < self.num_intermediates()); + let start = self.end_coeffs() + D * i; + start..start + D + } + + /// The wires corresponding to the i'th intermediate product. + fn wires_intermediate_prod(&self, i: usize) -> Range { + debug_assert!(i < self.num_intermediates()); + let start = self.end_coeffs() + D * (self.num_intermediates() + i); + start..start + D + } + + fn barycentric_weights(&self) -> Vec { + barycentric_weights( + &F::two_adic_subgroup(self.subgroup_bits) + .into_iter() + .map(|x| (x, F::ZERO)) + .collect::>(), + ) + } + + /// End of wire indices, exclusive. + fn end(&self) -> usize { + self.end_coeffs() + D * (2 * self.num_intermediates() + 1) + } + + /// Wire indices of the shifted point to evaluate the interpolant at. + fn wires_shifted_evaluation_point(&self) -> Range { + let start = self.end_coeffs() + D * 2 * self.num_intermediates(); + start..start + D + } +} + +impl, const D: usize> Gate for CosetInterpolationGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let shift = vars.local_wires[self.wire_shift()]; + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let shifted_evaluation_point = + vars.get_local_ext_algebra(self.wires_shifted_evaluation_point()); + constraints.extend( + (evaluation_point - shifted_evaluation_point.scalar_mul(shift)).to_basefield_array(), + ); + + let domain = F::two_adic_subgroup(self.subgroup_bits); + let values = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_value(i))) + .collect::>(); + let weights = self.barycentric_weights(); + + let (mut computed_eval, mut computed_prod) = partial_interpolate_ext_algebra( + &domain[..self.degree()], + &values[..self.degree()], + &weights[..self.degree()], + shifted_evaluation_point, + ExtensionAlgebra::ZERO, + ExtensionAlgebra::one(), + ); + + for i in 0..self.num_intermediates() { + let intermediate_eval = vars.get_local_ext_algebra(self.wires_intermediate_eval(i)); + let intermediate_prod = vars.get_local_ext_algebra(self.wires_intermediate_prod(i)); + constraints.extend((intermediate_eval - computed_eval).to_basefield_array()); + constraints.extend((intermediate_prod - computed_prod).to_basefield_array()); + + let start_index = 1 + (self.degree() - 1) * (i + 1); + let end_index = (start_index + self.degree() - 1).min(self.num_points()); + (computed_eval, computed_prod) = partial_interpolate_ext_algebra( + &domain[start_index..end_index], + &values[start_index..end_index], + &weights[start_index..end_index], + shifted_evaluation_point, + intermediate_eval, + intermediate_prod, + ); + } + + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + constraints.extend((evaluation_value - computed_eval).to_basefield_array()); + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + let shift = vars.local_wires[self.wire_shift()]; + let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); + let shifted_evaluation_point = vars.get_local_ext(self.wires_shifted_evaluation_point()); + yield_constr.many( + (evaluation_point - shifted_evaluation_point.scalar_mul(shift)).to_basefield_array(), + ); + + let domain = F::two_adic_subgroup(self.subgroup_bits); + let values = (0..self.num_points()) + .map(|i| vars.get_local_ext(self.wires_value(i))) + .collect::>(); + let weights = self.barycentric_weights(); + + let (mut computed_eval, mut computed_prod) = partial_interpolate( + &domain[..self.degree()], + &values[..self.degree()], + &weights[..self.degree()], + shifted_evaluation_point, + F::Extension::ZERO, + F::Extension::ONE, + ); + + for i in 0..self.num_intermediates() { + let intermediate_eval = vars.get_local_ext(self.wires_intermediate_eval(i)); + let intermediate_prod = vars.get_local_ext(self.wires_intermediate_prod(i)); + yield_constr.many((intermediate_eval - computed_eval).to_basefield_array()); + yield_constr.many((intermediate_prod - computed_prod).to_basefield_array()); + + let start_index = 1 + (self.degree() - 1) * (i + 1); + let end_index = (start_index + self.degree() - 1).min(self.num_points()); + (computed_eval, computed_prod) = partial_interpolate( + &domain[start_index..end_index], + &values[start_index..end_index], + &weights[start_index..end_index], + shifted_evaluation_point, + intermediate_eval, + intermediate_prod, + ); + } + + let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + yield_constr.many((evaluation_value - computed_eval).to_basefield_array()); + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let shift = vars.local_wires[self.wire_shift()]; + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let shifted_evaluation_point = + vars.get_local_ext_algebra(self.wires_shifted_evaluation_point()); + + let neg_one = builder.neg_one(); + let neg_shift = builder.scalar_mul_ext(neg_one, shift); + constraints.extend( + builder + .scalar_mul_add_ext_algebra(neg_shift, shifted_evaluation_point, evaluation_point) + .to_ext_target_array(), + ); + + let domain = F::two_adic_subgroup(self.subgroup_bits); + let values = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_value(i))) + .collect::>(); + let weights = self.barycentric_weights(); + + let initial_eval = builder.zero_ext_algebra(); + let initial_prod = builder.constant_ext_algebra(F::Extension::ONE.into()); + let (mut computed_eval, mut computed_prod) = partial_interpolate_ext_algebra_target( + builder, + &domain[..self.degree()], + &values[..self.degree()], + &weights[..self.degree()], + shifted_evaluation_point, + initial_eval, + initial_prod, + ); + + for i in 0..self.num_intermediates() { + let intermediate_eval = vars.get_local_ext_algebra(self.wires_intermediate_eval(i)); + let intermediate_prod = vars.get_local_ext_algebra(self.wires_intermediate_prod(i)); + constraints.extend( + builder + .sub_ext_algebra(intermediate_eval, computed_eval) + .to_ext_target_array(), + ); + constraints.extend( + builder + .sub_ext_algebra(intermediate_prod, computed_prod) + .to_ext_target_array(), + ); + + let start_index = 1 + (self.degree() - 1) * (i + 1); + let end_index = (start_index + self.degree() - 1).min(self.num_points()); + (computed_eval, computed_prod) = partial_interpolate_ext_algebra_target( + builder, + &domain[start_index..end_index], + &values[start_index..end_index], + &weights[start_index..end_index], + shifted_evaluation_point, + intermediate_eval, + intermediate_prod, + ); + } + + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + constraints.extend( + builder + .sub_ext_algebra(evaluation_value, computed_eval) + .to_ext_target_array(), + ); + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + let gen = InterpolationGenerator::::new(row, *self); + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + self.degree + } + + fn num_constraints(&self) -> usize { + // D constraints to check for consistency of the shifted evaluation point, plus D + // constraints for the evaluation value. + D + D + 2 * D * self.num_intermediates() + } +} + +#[derive(Debug)] +struct InterpolationGenerator, const D: usize> { + row: usize, + gate: CosetInterpolationGate, + interpolation_domain: Vec, + interpolation_weights: Vec, + _phantom: PhantomData, +} + +impl, const D: usize> InterpolationGenerator { + fn new(row: usize, gate: CosetInterpolationGate) -> Self { + let interpolation_domain = F::two_adic_subgroup(gate.subgroup_bits); + let interpolation_weights = gate.barycentric_weights(); + InterpolationGenerator { + row, + gate, + interpolation_domain, + interpolation_weights, + _phantom: PhantomData, + } + } +} + +impl, const D: usize> SimpleGenerator + for InterpolationGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |column| { + Target::Wire(Wire { + row: self.row, + column, + }) + }; + + let local_targets = |columns: Range| columns.map(local_target); + + let num_points = self.gate.num_points(); + let mut deps = Vec::with_capacity(1 + D + num_points * D); + + deps.push(local_target(self.gate.wire_shift())); + deps.extend(local_targets(self.gate.wires_evaluation_point())); + for i in 0..num_points { + deps.extend(local_targets(self.gate.wires_value(i))); + } + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let get_local_wire = |column| witness.get_wire(local_wire(column)); + + let get_local_ext = |wire_range: Range| { + debug_assert_eq!(wire_range.len(), D); + let values = wire_range.map(get_local_wire).collect::>(); + let arr = values.try_into().unwrap(); + F::Extension::from_basefield_array(arr) + }; + + let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); + let shift = get_local_wire(self.gate.wire_shift()); + let shifted_evaluation_point = evaluation_point.scalar_mul(shift.inverse()); + let degree = self.gate.degree(); + + out_buffer.set_ext_wires( + self.gate.wires_shifted_evaluation_point().map(local_wire), + shifted_evaluation_point, + ); + + let domain = &self.interpolation_domain; + let values = (0..self.gate.num_points()) + .map(|i| get_local_ext(self.gate.wires_value(i))) + .collect::>(); + let weights = &self.interpolation_weights; + + let (mut computed_eval, mut computed_prod) = partial_interpolate( + &domain[..degree], + &values[..degree], + &weights[..degree], + shifted_evaluation_point, + F::Extension::ZERO, + F::Extension::ONE, + ); + + for i in 0..self.gate.num_intermediates() { + let intermediate_eval_wires = self.gate.wires_intermediate_eval(i).map(local_wire); + let intermediate_prod_wires = self.gate.wires_intermediate_prod(i).map(local_wire); + out_buffer.set_ext_wires(intermediate_eval_wires, computed_eval); + out_buffer.set_ext_wires(intermediate_prod_wires, computed_prod); + + let start_index = 1 + (degree - 1) * (i + 1); + let end_index = (start_index + degree - 1).min(self.gate.num_points()); + (computed_eval, computed_prod) = partial_interpolate( + &domain[start_index..end_index], + &values[start_index..end_index], + &weights[start_index..end_index], + shifted_evaluation_point, + computed_eval, + computed_prod, + ); + } + + let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); + out_buffer.set_ext_wires(evaluation_value_wires, computed_eval); + } +} + +/// Interpolate the polynomial defined by its values on an arbitrary domain at the given point `x`. +/// +/// The domain lies in a base field while the values and evaluation point may be from an extension +/// field. The Barycentric weights are precomputed and taken as arguments. +pub fn interpolate_over_base_domain, const D: usize>( + domain: &[F], + values: &[F::Extension], + barycentric_weights: &[F], + x: F::Extension, +) -> F::Extension { + let (result, _) = partial_interpolate( + domain, + values, + barycentric_weights, + x, + F::Extension::ZERO, + F::Extension::ONE, + ); + result +} + +/// Perform a partial interpolation of the polynomial defined by its values on an arbitrary domain. +/// +/// The Barycentric algorithm to interpolate a polynomial at a given point `x` is a linear pass +/// over the sequence of domain points, values, and Barycentric weights which maintains two +/// accumulated values, a partial evaluation and a partial product. This partially updates the +/// accumulated values, so that starting with an initial evaluation of 0 and a partial evaluation +/// of 1 and running over the whole domain is a full interpolation. +fn partial_interpolate, const D: usize>( + domain: &[F], + values: &[F::Extension], + barycentric_weights: &[F], + x: F::Extension, + initial_eval: F::Extension, + initial_partial_prod: F::Extension, +) -> (F::Extension, F::Extension) { + let n = domain.len(); + assert_ne!(n, 0); + assert_eq!(n, values.len()); + assert_eq!(n, barycentric_weights.len()); + + let weighted_values = values + .iter() + .zip(barycentric_weights.iter()) + .map(|(&value, &weight)| value.scalar_mul(weight)); + + weighted_values.zip(domain.iter()).fold( + (initial_eval, initial_partial_prod), + |(eval, terms_partial_prod), (val, &x_i)| { + let term = x - x_i.into(); + let next_eval = eval * term + val * terms_partial_prod; + let next_terms_partial_prod = terms_partial_prod * term; + (next_eval, next_terms_partial_prod) + }, + ) +} + +fn partial_interpolate_ext_algebra, const D: usize>( + domain: &[F::BaseField], + values: &[ExtensionAlgebra], + barycentric_weights: &[F::BaseField], + x: ExtensionAlgebra, + initial_eval: ExtensionAlgebra, + initial_partial_prod: ExtensionAlgebra, +) -> (ExtensionAlgebra, ExtensionAlgebra) { + let n = domain.len(); + assert_ne!(n, 0); + assert_eq!(n, values.len()); + assert_eq!(n, barycentric_weights.len()); + + let weighted_values = values + .iter() + .zip(barycentric_weights.iter()) + .map(|(&value, &weight)| value.scalar_mul(F::from_basefield(weight))); + + weighted_values.zip(domain.iter()).fold( + (initial_eval, initial_partial_prod), + |(eval, terms_partial_prod), (val, &x_i)| { + let term = x - F::from_basefield(x_i).into(); + let next_eval = eval * term + val * terms_partial_prod; + let next_terms_partial_prod = terms_partial_prod * term; + (next_eval, next_terms_partial_prod) + }, + ) +} + +fn partial_interpolate_ext_algebra_target, const D: usize>( + builder: &mut CircuitBuilder, + domain: &[F], + values: &[ExtensionAlgebraTarget], + barycentric_weights: &[F], + point: ExtensionAlgebraTarget, + initial_eval: ExtensionAlgebraTarget, + initial_partial_prod: ExtensionAlgebraTarget, +) -> (ExtensionAlgebraTarget, ExtensionAlgebraTarget) { + let n = values.len(); + debug_assert!(n != 0); + debug_assert!(domain.len() == n); + debug_assert!(barycentric_weights.len() == n); + + values + .iter() + .cloned() + .zip(domain.iter().cloned()) + .zip(barycentric_weights.iter().cloned()) + .fold( + (initial_eval, initial_partial_prod), + |(eval, partial_prod), ((val, x), weight)| { + let x_target = builder.constant_ext_algebra(F::Extension::from(x).into()); + let weight_target = builder.constant_extension(F::Extension::from(weight)); + let term = builder.sub_ext_algebra(point, x_target); + let weighted_val = builder.scalar_mul_ext_algebra(weight_target, val); + let new_eval = builder.mul_ext_algebra(eval, term); + let new_eval = builder.mul_add_ext_algebra(weighted_val, partial_prod, new_eval); + let new_partial_prod = builder.mul_ext_algebra(partial_prod, term); + (new_eval, new_partial_prod) + }, + ) +} + +#[cfg(test)] +mod tests { + use core::iter::repeat_with; + + use anyhow::Result; + use plonky2_field::polynomial::PolynomialValues; + use plonky2_util::log2_strict; + + use super::*; + use crate::field::goldilocks_field::GoldilocksField; + use crate::field::types::{Field, Sample}; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + #[test] + fn test_degree_and_wires_minimized() { + let gate = >::with_max_degree(3, 2); + assert_eq!(gate.num_intermediates(), 6); + assert_eq!(gate.degree(), 2); + + let gate = >::with_max_degree(3, 3); + assert_eq!(gate.num_intermediates(), 3); + assert_eq!(gate.degree(), 3); + + let gate = >::with_max_degree(3, 4); + assert_eq!(gate.num_intermediates(), 2); + assert_eq!(gate.degree(), 4); + + let gate = >::with_max_degree(3, 5); + assert_eq!(gate.num_intermediates(), 1); + assert_eq!(gate.degree(), 5); + + let gate = >::with_max_degree(3, 6); + assert_eq!(gate.num_intermediates(), 1); + assert_eq!(gate.degree(), 5); + + let gate = >::with_max_degree(3, 7); + assert_eq!(gate.num_intermediates(), 1); + assert_eq!(gate.degree(), 5); + + let gate = >::with_max_degree(4, 3); + assert_eq!(gate.num_intermediates(), 7); + assert_eq!(gate.degree(), 3); + + let gate = >::with_max_degree(4, 6); + assert_eq!(gate.num_intermediates(), 2); + assert_eq!(gate.degree(), 6); + + let gate = >::with_max_degree(4, 8); + assert_eq!(gate.num_intermediates(), 2); + assert_eq!(gate.degree(), 6); + + let gate = >::with_max_degree(4, 9); + assert_eq!(gate.num_intermediates(), 1); + assert_eq!(gate.degree(), 9); + } + + #[test] + fn wire_indices_degree2() { + let gate = CosetInterpolationGate:: { + subgroup_bits: 2, + degree: 2, + _phantom: PhantomData, + }; + + // The exact indices aren't really important, but we want to make sure we don't have any + // overlaps or gaps. + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_value(2), 9..13); + assert_eq!(gate.wires_value(3), 13..17); + assert_eq!(gate.wires_evaluation_point(), 17..21); + assert_eq!(gate.wires_evaluation_value(), 21..25); + assert_eq!(gate.wires_intermediate_eval(0), 25..29); + assert_eq!(gate.wires_intermediate_eval(1), 29..33); + assert_eq!(gate.wires_intermediate_prod(0), 33..37); + assert_eq!(gate.wires_intermediate_prod(1), 37..41); + assert_eq!(gate.wires_shifted_evaluation_point(), 41..45); + assert_eq!(gate.num_wires(), 45); + } + + #[test] + fn wire_indices_degree_3() { + let gate = CosetInterpolationGate:: { + subgroup_bits: 2, + degree: 3, + _phantom: PhantomData, + }; + + // The exact indices aren't really important, but we want to make sure we don't have any + // overlaps or gaps. + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_value(2), 9..13); + assert_eq!(gate.wires_value(3), 13..17); + assert_eq!(gate.wires_evaluation_point(), 17..21); + assert_eq!(gate.wires_evaluation_value(), 21..25); + assert_eq!(gate.wires_intermediate_eval(0), 25..29); + assert_eq!(gate.wires_intermediate_prod(0), 29..33); + assert_eq!(gate.wires_shifted_evaluation_point(), 33..37); + assert_eq!(gate.num_wires(), 37); + } + + #[test] + fn wire_indices_degree_n() { + let gate = CosetInterpolationGate:: { + subgroup_bits: 2, + degree: 4, + _phantom: PhantomData, + }; + + // The exact indices aren't really important, but we want to make sure we don't have any + // overlaps or gaps. + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_value(2), 9..13); + assert_eq!(gate.wires_value(3), 13..17); + assert_eq!(gate.wires_evaluation_point(), 17..21); + assert_eq!(gate.wires_evaluation_value(), 21..25); + assert_eq!(gate.wires_shifted_evaluation_point(), 25..29); + assert_eq!(gate.num_wires(), 29); + } + + #[test] + fn low_degree() { + test_low_degree::(CosetInterpolationGate::new(2)); + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + for degree in 2..=4 { + test_eval_fns::(CosetInterpolationGate::with_max_degree(2, degree))?; + } + Ok(()) + } + + #[test] + fn test_gate_constraint() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + + /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. + fn get_wires(shift: F, values: PolynomialValues, eval_point: FF) -> Vec { + let domain = F::two_adic_subgroup(log2_strict(values.len())); + let shifted_eval_point = + >::scalar_mul(&eval_point, shift.inverse()); + let weights = + barycentric_weights(&domain.iter().map(|&x| (x, F::ZERO)).collect::>()); + let (intermediate_eval, intermediate_prod) = partial_interpolate::<_, D>( + &domain[..3], + &values.values[..3], + &weights[..3], + shifted_eval_point, + FF::ZERO, + FF::ONE, + ); + let eval = interpolate_over_base_domain::<_, D>( + &domain, + &values.values, + &weights, + shifted_eval_point, + ); + let mut v = vec![shift]; + for val in values.values.iter() { + v.extend(val.0); + } + v.extend(eval_point.0); + v.extend(eval.0); + v.extend(intermediate_eval.0); + v.extend(intermediate_prod.0); + v.extend(shifted_eval_point.0); + v.iter().map(|&x| x.into()).collect() + } + + // Get a working row for InterpolationGate. + let shift = F::rand(); + let values = PolynomialValues::new(repeat_with(FF::rand).take(4).collect()); + let eval_point = FF::rand(); + let gate = CosetInterpolationGate::::with_max_degree(2, 3); + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(shift, values, eval_point), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } + + #[test] + fn test_num_wires_constraints() { + let gate = >::with_max_degree(4, 8); + assert_eq!(gate.num_wires(), 47); + assert_eq!(gate.num_constraints(), 12); + + let gate = >::with_max_degree(3, 8); + assert_eq!(gate.num_wires(), 23); + assert_eq!(gate.num_constraints(), 4); + + let gate = >::with_max_degree(4, 16); + assert_eq!(gate.num_wires(), 39); + assert_eq!(gate.num_constraints(), 4); + } +} diff --git a/plonky2/src/gates/high_degree_interpolation.rs b/plonky2/src/gates/high_degree_interpolation.rs index f7e3be1f..814ce4e1 100644 --- a/plonky2/src/gates/high_degree_interpolation.rs +++ b/plonky2/src/gates/high_degree_interpolation.rs @@ -357,4 +357,15 @@ mod tests { "Gate constraints are not satisfied." ); } + + #[test] + fn test_num_wires_constraints() { + let gate = >::new(3); + assert_eq!(gate.num_wires(), 37); + assert_eq!(gate.num_constraints(), 18); + + let gate = >::new(4); + assert_eq!(gate.num_wires(), 69); + assert_eq!(gate.num_constraints(), 34); + } } diff --git a/plonky2/src/gates/interpolation.rs b/plonky2/src/gates/interpolation.rs index 07179006..5732e488 100644 --- a/plonky2/src/gates/interpolation.rs +++ b/plonky2/src/gates/interpolation.rs @@ -82,12 +82,11 @@ impl, const D: usize> CircuitBuilder { /// `evaluation_point`. pub(crate) fn interpolate_coset>( &mut self, - subgroup_bits: usize, + gate: G, coset_shift: Target, values: &[ExtensionTarget], evaluation_point: ExtensionTarget, ) -> ExtensionTarget { - let gate = G::new(subgroup_bits); let row = self.add_gate(gate, vec![]); self.connect(coset_shift, Target::wire(row, gate.wire_shift())); for (i, &v) in values.iter().enumerate() { @@ -109,7 +108,9 @@ mod tests { use crate::field::extension::FieldExtension; use crate::field::interpolation::interpolant; use crate::field::types::{Field, Sample}; + use crate::gates::coset_interpolation::CosetInterpolationGate; use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; + use crate::gates::interpolation::InterpolationGate; use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::iop::witness::PartialWitness; use crate::plonk::circuit_builder::CircuitBuilder; @@ -154,21 +155,34 @@ mod tests { let zt = builder.constant_extension(z); - let eval_hd = builder.interpolate_coset::>( - subgroup_bits, + let eval_hd = builder.interpolate_coset( + HighDegreeInterpolationGate::new(subgroup_bits), coset_shift_target, &value_targets, zt, ); - let eval_ld = builder.interpolate_coset::>( - subgroup_bits, + let eval_ld = builder.interpolate_coset( + LowDegreeInterpolationGate::new(subgroup_bits), coset_shift_target, &value_targets, zt, ); + let evals_coset_gates = (2..=4) + .map(|max_degree| { + builder.interpolate_coset( + CosetInterpolationGate::with_max_degree(subgroup_bits, max_degree), + coset_shift_target, + &value_targets, + zt, + ) + }) + .collect::>(); let true_eval_target = builder.constant_extension(true_eval); builder.connect_extension(eval_hd, true_eval_target); builder.connect_extension(eval_ld, true_eval_target); + for &eval_coset_gate in evals_coset_gates.iter() { + builder.connect_extension(eval_coset_gate, true_eval_target); + } let data = builder.build::(); let proof = data.prove(pw)?; diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index 15d0e56a..1d5950a7 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -463,4 +463,11 @@ mod tests { "Gate constraints are not satisfied." ); } + + #[test] + fn test_num_wires_constraints() { + let gate = >::new(4); + assert_eq!(gate.num_wires(), 111); + assert_eq!(gate.num_constraints(), 76); + } } diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index 1d2fc058..e53cab86 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -5,6 +5,7 @@ pub mod arithmetic_base; pub mod arithmetic_extension; pub mod base_sum; pub mod constant; +pub mod coset_interpolation; pub mod exponentiation; pub mod gate; pub mod high_degree_interpolation;