diff --git a/plonky2/src/fri/recursive_verifier.rs b/plonky2/src/fri/recursive_verifier.rs index e7e48f82..ac74f50f 100644 --- a/plonky2/src/fri/recursive_verifier.rs +++ b/plonky2/src/fri/recursive_verifier.rs @@ -10,10 +10,8 @@ use crate::fri::proof::{ }; use crate::fri::structure::{FriBatchInfoTarget, FriInstanceInfoTarget, FriOpeningsTarget}; use crate::fri::{FriConfig, FriParams}; +use crate::gates::coset_interpolation::CosetInterpolationGate; use crate::gates::gate::Gate; -use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; -use crate::gates::interpolation::InterpolationGate; -use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::random_access::RandomAccessGate; use crate::hash::hash_types::{MerkleCapTarget, RichField}; use crate::iop::ext_target::{flatten_target, ExtensionTarget}; @@ -50,23 +48,11 @@ impl, const D: usize> CircuitBuilder { let coset_start = self.mul(start, x); // The answer is gotten by interpolating {(x*g^i, P(x*g^i))} and evaluating at beta. - // `HighDegreeInterpolationGate` has degree `arity`, so we use the low-degree gate if - // the arity is too large. - if arity > self.config.max_quotient_degree_factor { - self.interpolate_coset::>( - arity_bits, - coset_start, - &evals, - beta, - ) - } else { - self.interpolate_coset::>( - arity_bits, - coset_start, - &evals, - beta, - ) - } + let interpolation_gate = >::with_max_degree( + arity_bits, + self.config.max_quotient_degree_factor, + ); + self.interpolate_coset(interpolation_gate, coset_start, &evals, beta) } /// Make sure we have enough wires and routed wires to do the FRI checks efficiently. This check @@ -77,14 +63,13 @@ impl, const D: usize> CircuitBuilder { &self.config, max_fri_arity_bits.max(self.config.fri_config.cap_height), ); - let (interpolation_wires, interpolation_routed_wires) = - if 1 << max_fri_arity_bits > self.config.max_quotient_degree_factor { - let gate = LowDegreeInterpolationGate::::new(max_fri_arity_bits); - (gate.num_wires(), gate.num_routed_wires()) - } else { - let gate = HighDegreeInterpolationGate::::new(max_fri_arity_bits); - (gate.num_wires(), gate.num_routed_wires()) - }; + let interpolation_gate = CosetInterpolationGate::::with_max_degree( + max_fri_arity_bits, + self.config.max_quotient_degree_factor, + ); + + let interpolation_wires = interpolation_gate.num_wires(); + let interpolation_routed_wires = interpolation_gate.num_routed_wires(); let min_wires = random_access.num_wires().max(interpolation_wires); let min_routed_wires = random_access diff --git a/plonky2/src/gadgets/interpolation.rs b/plonky2/src/gadgets/interpolation.rs new file mode 100644 index 00000000..1ab35660 --- /dev/null +++ b/plonky2/src/gadgets/interpolation.rs @@ -0,0 +1,108 @@ +use plonky2_field::extension::Extendable; + +use crate::gates::coset_interpolation::CosetInterpolationGate; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::ExtensionTarget; +use crate::iop::target::Target; +use crate::plonk::circuit_builder::CircuitBuilder; + +impl, const D: usize> CircuitBuilder { + /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the + /// given size, and whose values are given. Returns the evaluation of the interpolant at + /// `evaluation_point`. + pub(crate) fn interpolate_coset( + &mut self, + gate: CosetInterpolationGate, + coset_shift: Target, + values: &[ExtensionTarget], + evaluation_point: ExtensionTarget, + ) -> ExtensionTarget { + let row = self.num_gates(); + self.connect(coset_shift, Target::wire(row, gate.wire_shift())); + for (i, &v) in values.iter().enumerate() { + self.connect_extension(v, ExtensionTarget::from_range(row, gate.wires_value(i))); + } + self.connect_extension( + evaluation_point, + ExtensionTarget::from_range(row, gate.wires_evaluation_point()), + ); + + let eval = ExtensionTarget::from_range(row, gate.wires_evaluation_value()); + self.add_gate(gate, vec![]); + + eval + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + + use crate::field::extension::FieldExtension; + use crate::field::interpolation::interpolant; + use crate::field::types::{Field, Sample}; + use crate::gates::coset_interpolation::CosetInterpolationGate; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use crate::plonk::verifier::verify; + + #[test] + fn test_interpolate() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let subgroup_bits = 2; + let len = 1 << subgroup_bits; + let coset_shift = F::rand(); + let g = F::primitive_root_of_unity(subgroup_bits); + let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len); + let values = FF::rand_vec(len); + + let homogeneous_points = points + .iter() + .zip(values.iter()) + .map(|(&a, &b)| (>::from_basefield(a), b)) + .collect::>(); + + let true_interpolant = interpolant(&homogeneous_points); + + let z = FF::rand(); + let true_eval = true_interpolant.eval(z); + + let coset_shift_target = builder.constant(coset_shift); + + let value_targets = values + .iter() + .map(|&v| (builder.constant_extension(v))) + .collect::>(); + + let zt = builder.constant_extension(z); + + let evals_coset_gates = (2..=4) + .map(|max_degree| { + builder.interpolate_coset( + CosetInterpolationGate::with_max_degree(subgroup_bits, max_degree), + coset_shift_target, + &value_targets, + zt, + ) + }) + .collect::>(); + let true_eval_target = builder.constant_extension(true_eval); + for &eval_coset_gate in evals_coset_gates.iter() { + builder.connect_extension(eval_coset_gate, true_eval_target); + } + + let data = builder.build::(); + let proof = data.prove(pw)?; + + verify(proof, &data.verifier_only, &data.common) + } +} diff --git a/plonky2/src/gadgets/mod.rs b/plonky2/src/gadgets/mod.rs index a3e50c4e..6309eb3d 100644 --- a/plonky2/src/gadgets/mod.rs +++ b/plonky2/src/gadgets/mod.rs @@ -1,6 +1,7 @@ pub mod arithmetic; pub mod arithmetic_extension; pub mod hash; +pub mod interpolation; pub mod polynomial; pub mod random_access; pub mod range_check; diff --git a/plonky2/src/gates/coset_interpolation.rs b/plonky2/src/gates/coset_interpolation.rs new file mode 100644 index 00000000..da94d1c0 --- /dev/null +++ b/plonky2/src/gates/coset_interpolation.rs @@ -0,0 +1,828 @@ +use alloc::boxed::Box; +use alloc::string::String; +use alloc::vec::Vec; +use alloc::{format, vec}; +use core::marker::PhantomData; +use core::ops::Range; + +use crate::field::extension::algebra::ExtensionAlgebra; +use crate::field::extension::{Extendable, FieldExtension, OEF}; +use crate::field::interpolation::barycentric_weights; +use crate::field::types::Field; +use crate::gates::gate::Gate; +use crate::gates::util::StridedConstraintConsumer; +use crate::hash::hash_types::RichField; +use crate::iop::ext_target::{ExtensionAlgebraTarget, ExtensionTarget}; +use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use crate::iop::target::Target; +use crate::iop::wire::Wire; +use crate::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; + +/// One of the instantiations of `InterpolationGate`: allows constraints of variable +/// degree, up to `1<, const D: usize> { + pub subgroup_bits: usize, + pub degree: usize, + pub barycentric_weights: Vec, + _phantom: PhantomData, +} + +impl, const D: usize> CosetInterpolationGate { + pub fn new(subgroup_bits: usize) -> Self { + Self::with_max_degree(subgroup_bits, 1 << subgroup_bits) + } + + pub(crate) fn with_max_degree(subgroup_bits: usize, max_degree: usize) -> Self { + assert!(max_degree > 1, "need at least quadratic constraints"); + + let n_points = 1 << subgroup_bits; + + // Number of intermediate values required to compute interpolation with degree bound + let n_intermediates = (n_points - 2) / (max_degree - 1); + + // Find minimum degree such that (n_points - 2) / (degree - 1) < n_intermediates + 1 + // Minimizing the degree this way allows the gate to be in a larger selector group + let degree = (n_points - 2) / (n_intermediates + 1) + 2; + + let barycentric_weights = barycentric_weights( + &F::two_adic_subgroup(subgroup_bits) + .into_iter() + .map(|x| (x, F::ZERO)) + .collect::>(), + ); + + Self { + subgroup_bits, + degree, + barycentric_weights, + _phantom: PhantomData, + } + } + + fn num_points(&self) -> usize { + 1 << self.subgroup_bits + } + + /// Wire index of the coset shift. + pub(crate) fn wire_shift(&self) -> usize { + 0 + } + + fn start_values(&self) -> usize { + 1 + } + + /// Wire indices of the `i`th interpolant value. + pub(crate) fn wires_value(&self, i: usize) -> Range { + debug_assert!(i < self.num_points()); + let start = self.start_values() + i * D; + start..start + D + } + + fn start_evaluation_point(&self) -> usize { + self.start_values() + self.num_points() * D + } + + /// Wire indices of the point to evaluate the interpolant at. + pub(crate) fn wires_evaluation_point(&self) -> Range { + let start = self.start_evaluation_point(); + start..start + D + } + + fn start_evaluation_value(&self) -> usize { + self.start_evaluation_point() + D + } + + /// Wire indices of the interpolated value. + pub(crate) fn wires_evaluation_value(&self) -> Range { + let start = self.start_evaluation_value(); + start..start + D + } + + fn start_intermediates(&self) -> usize { + self.start_evaluation_value() + D + } + + pub fn num_routed_wires(&self) -> usize { + self.start_intermediates() + } + + fn num_intermediates(&self) -> usize { + (self.num_points() - 2) / (self.degree() - 1) + } + + /// The wires corresponding to the i'th intermediate evaluation. + fn wires_intermediate_eval(&self, i: usize) -> Range { + debug_assert!(i < self.num_intermediates()); + let start = self.start_intermediates() + D * i; + start..start + D + } + + /// The wires corresponding to the i'th intermediate product. + fn wires_intermediate_prod(&self, i: usize) -> Range { + debug_assert!(i < self.num_intermediates()); + let start = self.start_intermediates() + D * (self.num_intermediates() + i); + start..start + D + } + + /// End of wire indices, exclusive. + fn end(&self) -> usize { + self.start_intermediates() + D * (2 * self.num_intermediates() + 1) + } + + /// Wire indices of the shifted point to evaluate the interpolant at. + fn wires_shifted_evaluation_point(&self) -> Range { + let start = self.start_intermediates() + D * 2 * self.num_intermediates(); + start..start + D + } +} + +impl, const D: usize> Gate for CosetInterpolationGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let shift = vars.local_wires[self.wire_shift()]; + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let shifted_evaluation_point = + vars.get_local_ext_algebra(self.wires_shifted_evaluation_point()); + constraints.extend( + (evaluation_point - shifted_evaluation_point.scalar_mul(shift)).to_basefield_array(), + ); + + let domain = F::two_adic_subgroup(self.subgroup_bits); + let values = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_value(i))) + .collect::>(); + let weights = &self.barycentric_weights; + + let (mut computed_eval, mut computed_prod) = partial_interpolate_ext_algebra( + &domain[..self.degree()], + &values[..self.degree()], + &weights[..self.degree()], + shifted_evaluation_point, + ExtensionAlgebra::ZERO, + ExtensionAlgebra::one(), + ); + + for i in 0..self.num_intermediates() { + let intermediate_eval = vars.get_local_ext_algebra(self.wires_intermediate_eval(i)); + let intermediate_prod = vars.get_local_ext_algebra(self.wires_intermediate_prod(i)); + constraints.extend((intermediate_eval - computed_eval).to_basefield_array()); + constraints.extend((intermediate_prod - computed_prod).to_basefield_array()); + + let start_index = 1 + (self.degree() - 1) * (i + 1); + let end_index = (start_index + self.degree() - 1).min(self.num_points()); + (computed_eval, computed_prod) = partial_interpolate_ext_algebra( + &domain[start_index..end_index], + &values[start_index..end_index], + &weights[start_index..end_index], + shifted_evaluation_point, + intermediate_eval, + intermediate_prod, + ); + } + + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + constraints.extend((evaluation_value - computed_eval).to_basefield_array()); + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + let shift = vars.local_wires[self.wire_shift()]; + let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); + let shifted_evaluation_point = vars.get_local_ext(self.wires_shifted_evaluation_point()); + yield_constr.many( + (evaluation_point - shifted_evaluation_point.scalar_mul(shift)).to_basefield_array(), + ); + + let domain = F::two_adic_subgroup(self.subgroup_bits); + let values = (0..self.num_points()) + .map(|i| vars.get_local_ext(self.wires_value(i))) + .collect::>(); + let weights = &self.barycentric_weights; + + let (mut computed_eval, mut computed_prod) = partial_interpolate( + &domain[..self.degree()], + &values[..self.degree()], + &weights[..self.degree()], + shifted_evaluation_point, + F::Extension::ZERO, + F::Extension::ONE, + ); + + for i in 0..self.num_intermediates() { + let intermediate_eval = vars.get_local_ext(self.wires_intermediate_eval(i)); + let intermediate_prod = vars.get_local_ext(self.wires_intermediate_prod(i)); + yield_constr.many((intermediate_eval - computed_eval).to_basefield_array()); + yield_constr.many((intermediate_prod - computed_prod).to_basefield_array()); + + let start_index = 1 + (self.degree() - 1) * (i + 1); + let end_index = (start_index + self.degree() - 1).min(self.num_points()); + (computed_eval, computed_prod) = partial_interpolate( + &domain[start_index..end_index], + &values[start_index..end_index], + &weights[start_index..end_index], + shifted_evaluation_point, + intermediate_eval, + intermediate_prod, + ); + } + + let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + yield_constr.many((evaluation_value - computed_eval).to_basefield_array()); + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let shift = vars.local_wires[self.wire_shift()]; + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let shifted_evaluation_point = + vars.get_local_ext_algebra(self.wires_shifted_evaluation_point()); + + let neg_one = builder.neg_one(); + let neg_shift = builder.scalar_mul_ext(neg_one, shift); + constraints.extend( + builder + .scalar_mul_add_ext_algebra(neg_shift, shifted_evaluation_point, evaluation_point) + .to_ext_target_array(), + ); + + let domain = F::two_adic_subgroup(self.subgroup_bits); + let values = (0..self.num_points()) + .map(|i| vars.get_local_ext_algebra(self.wires_value(i))) + .collect::>(); + let weights = &self.barycentric_weights; + + let initial_eval = builder.zero_ext_algebra(); + let initial_prod = builder.constant_ext_algebra(F::Extension::ONE.into()); + let (mut computed_eval, mut computed_prod) = partial_interpolate_ext_algebra_target( + builder, + &domain[..self.degree()], + &values[..self.degree()], + &weights[..self.degree()], + shifted_evaluation_point, + initial_eval, + initial_prod, + ); + + for i in 0..self.num_intermediates() { + let intermediate_eval = vars.get_local_ext_algebra(self.wires_intermediate_eval(i)); + let intermediate_prod = vars.get_local_ext_algebra(self.wires_intermediate_prod(i)); + constraints.extend( + builder + .sub_ext_algebra(intermediate_eval, computed_eval) + .to_ext_target_array(), + ); + constraints.extend( + builder + .sub_ext_algebra(intermediate_prod, computed_prod) + .to_ext_target_array(), + ); + + let start_index = 1 + (self.degree() - 1) * (i + 1); + let end_index = (start_index + self.degree() - 1).min(self.num_points()); + (computed_eval, computed_prod) = partial_interpolate_ext_algebra_target( + builder, + &domain[start_index..end_index], + &values[start_index..end_index], + &weights[start_index..end_index], + shifted_evaluation_point, + intermediate_eval, + intermediate_prod, + ); + } + + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); + constraints.extend( + builder + .sub_ext_algebra(evaluation_value, computed_eval) + .to_ext_target_array(), + ); + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + let gen = InterpolationGenerator::::new(row, self.clone()); + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.end() + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + self.degree + } + + fn num_constraints(&self) -> usize { + // D constraints to check for consistency of the shifted evaluation point, plus D + // constraints for the evaluation value. + D + D + 2 * D * self.num_intermediates() + } +} + +#[derive(Debug)] +struct InterpolationGenerator, const D: usize> { + row: usize, + gate: CosetInterpolationGate, + interpolation_domain: Vec, + _phantom: PhantomData, +} + +impl, const D: usize> InterpolationGenerator { + fn new(row: usize, gate: CosetInterpolationGate) -> Self { + let interpolation_domain = F::two_adic_subgroup(gate.subgroup_bits); + InterpolationGenerator { + row, + gate, + interpolation_domain, + _phantom: PhantomData, + } + } +} + +impl, const D: usize> SimpleGenerator + for InterpolationGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |column| { + Target::Wire(Wire { + row: self.row, + column, + }) + }; + + let local_targets = |columns: Range| columns.map(local_target); + + let num_points = self.gate.num_points(); + let mut deps = Vec::with_capacity(1 + D + num_points * D); + + deps.push(local_target(self.gate.wire_shift())); + deps.extend(local_targets(self.gate.wires_evaluation_point())); + for i in 0..num_points { + deps.extend(local_targets(self.gate.wires_value(i))); + } + deps + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let get_local_wire = |column| witness.get_wire(local_wire(column)); + + let get_local_ext = |wire_range: Range| { + debug_assert_eq!(wire_range.len(), D); + let values = wire_range.map(get_local_wire).collect::>(); + let arr = values.try_into().unwrap(); + F::Extension::from_basefield_array(arr) + }; + + let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); + let shift = get_local_wire(self.gate.wire_shift()); + let shifted_evaluation_point = evaluation_point.scalar_mul(shift.inverse()); + let degree = self.gate.degree(); + + out_buffer.set_ext_wires( + self.gate.wires_shifted_evaluation_point().map(local_wire), + shifted_evaluation_point, + ); + + let domain = &self.interpolation_domain; + let values = (0..self.gate.num_points()) + .map(|i| get_local_ext(self.gate.wires_value(i))) + .collect::>(); + let weights = &self.gate.barycentric_weights; + + let (mut computed_eval, mut computed_prod) = partial_interpolate( + &domain[..degree], + &values[..degree], + &weights[..degree], + shifted_evaluation_point, + F::Extension::ZERO, + F::Extension::ONE, + ); + + for i in 0..self.gate.num_intermediates() { + let intermediate_eval_wires = self.gate.wires_intermediate_eval(i).map(local_wire); + let intermediate_prod_wires = self.gate.wires_intermediate_prod(i).map(local_wire); + out_buffer.set_ext_wires(intermediate_eval_wires, computed_eval); + out_buffer.set_ext_wires(intermediate_prod_wires, computed_prod); + + let start_index = 1 + (degree - 1) * (i + 1); + let end_index = (start_index + degree - 1).min(self.gate.num_points()); + (computed_eval, computed_prod) = partial_interpolate( + &domain[start_index..end_index], + &values[start_index..end_index], + &weights[start_index..end_index], + shifted_evaluation_point, + computed_eval, + computed_prod, + ); + } + + let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); + out_buffer.set_ext_wires(evaluation_value_wires, computed_eval); + } +} + +/// Interpolate the polynomial defined by its values on an arbitrary domain at the given point `x`. +/// +/// The domain lies in a base field while the values and evaluation point may be from an extension +/// field. The Barycentric weights are precomputed and taken as arguments. +pub fn interpolate_over_base_domain, const D: usize>( + domain: &[F], + values: &[F::Extension], + barycentric_weights: &[F], + x: F::Extension, +) -> F::Extension { + let (result, _) = partial_interpolate( + domain, + values, + barycentric_weights, + x, + F::Extension::ZERO, + F::Extension::ONE, + ); + result +} + +/// Perform a partial interpolation of the polynomial defined by its values on an arbitrary domain. +/// +/// The Barycentric algorithm to interpolate a polynomial at a given point `x` is a linear pass +/// over the sequence of domain points, values, and Barycentric weights which maintains two +/// accumulated values, a partial evaluation and a partial product. This partially updates the +/// accumulated values, so that starting with an initial evaluation of 0 and a partial evaluation +/// of 1 and running over the whole domain is a full interpolation. +fn partial_interpolate, const D: usize>( + domain: &[F], + values: &[F::Extension], + barycentric_weights: &[F], + x: F::Extension, + initial_eval: F::Extension, + initial_partial_prod: F::Extension, +) -> (F::Extension, F::Extension) { + let n = domain.len(); + assert_ne!(n, 0); + assert_eq!(n, values.len()); + assert_eq!(n, barycentric_weights.len()); + + let weighted_values = values + .iter() + .zip(barycentric_weights.iter()) + .map(|(&value, &weight)| value.scalar_mul(weight)); + + weighted_values.zip(domain.iter()).fold( + (initial_eval, initial_partial_prod), + |(eval, terms_partial_prod), (val, &x_i)| { + let term = x - x_i.into(); + let next_eval = eval * term + val * terms_partial_prod; + let next_terms_partial_prod = terms_partial_prod * term; + (next_eval, next_terms_partial_prod) + }, + ) +} + +fn partial_interpolate_ext_algebra, const D: usize>( + domain: &[F::BaseField], + values: &[ExtensionAlgebra], + barycentric_weights: &[F::BaseField], + x: ExtensionAlgebra, + initial_eval: ExtensionAlgebra, + initial_partial_prod: ExtensionAlgebra, +) -> (ExtensionAlgebra, ExtensionAlgebra) { + let n = domain.len(); + assert_ne!(n, 0); + assert_eq!(n, values.len()); + assert_eq!(n, barycentric_weights.len()); + + let weighted_values = values + .iter() + .zip(barycentric_weights.iter()) + .map(|(&value, &weight)| value.scalar_mul(F::from_basefield(weight))); + + weighted_values.zip(domain.iter()).fold( + (initial_eval, initial_partial_prod), + |(eval, terms_partial_prod), (val, &x_i)| { + let term = x - F::from_basefield(x_i).into(); + let next_eval = eval * term + val * terms_partial_prod; + let next_terms_partial_prod = terms_partial_prod * term; + (next_eval, next_terms_partial_prod) + }, + ) +} + +fn partial_interpolate_ext_algebra_target, const D: usize>( + builder: &mut CircuitBuilder, + domain: &[F], + values: &[ExtensionAlgebraTarget], + barycentric_weights: &[F], + point: ExtensionAlgebraTarget, + initial_eval: ExtensionAlgebraTarget, + initial_partial_prod: ExtensionAlgebraTarget, +) -> (ExtensionAlgebraTarget, ExtensionAlgebraTarget) { + let n = values.len(); + debug_assert!(n != 0); + debug_assert!(domain.len() == n); + debug_assert!(barycentric_weights.len() == n); + + values + .iter() + .cloned() + .zip(domain.iter().cloned()) + .zip(barycentric_weights.iter().cloned()) + .fold( + (initial_eval, initial_partial_prod), + |(eval, partial_prod), ((val, x), weight)| { + let x_target = builder.constant_ext_algebra(F::Extension::from(x).into()); + let weight_target = builder.constant_extension(F::Extension::from(weight)); + let term = builder.sub_ext_algebra(point, x_target); + let weighted_val = builder.scalar_mul_ext_algebra(weight_target, val); + let new_eval = builder.mul_ext_algebra(eval, term); + let new_eval = builder.mul_add_ext_algebra(weighted_val, partial_prod, new_eval); + let new_partial_prod = builder.mul_ext_algebra(partial_prod, term); + (new_eval, new_partial_prod) + }, + ) +} + +#[cfg(test)] +mod tests { + use core::iter::repeat_with; + + use anyhow::Result; + use plonky2_field::polynomial::PolynomialValues; + use plonky2_util::log2_strict; + + use super::*; + use crate::field::goldilocks_field::GoldilocksField; + use crate::field::types::{Field, Sample}; + use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; + use crate::hash::hash_types::HashOut; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + #[test] + fn test_degree_and_wires_minimized() { + let gate = >::with_max_degree(3, 2); + assert_eq!(gate.num_intermediates(), 6); + assert_eq!(gate.degree(), 2); + + let gate = >::with_max_degree(3, 3); + assert_eq!(gate.num_intermediates(), 3); + assert_eq!(gate.degree(), 3); + + let gate = >::with_max_degree(3, 4); + assert_eq!(gate.num_intermediates(), 2); + assert_eq!(gate.degree(), 4); + + let gate = >::with_max_degree(3, 5); + assert_eq!(gate.num_intermediates(), 1); + assert_eq!(gate.degree(), 5); + + let gate = >::with_max_degree(3, 6); + assert_eq!(gate.num_intermediates(), 1); + assert_eq!(gate.degree(), 5); + + let gate = >::with_max_degree(3, 7); + assert_eq!(gate.num_intermediates(), 1); + assert_eq!(gate.degree(), 5); + + let gate = >::with_max_degree(4, 3); + assert_eq!(gate.num_intermediates(), 7); + assert_eq!(gate.degree(), 3); + + let gate = >::with_max_degree(4, 6); + assert_eq!(gate.num_intermediates(), 2); + assert_eq!(gate.degree(), 6); + + let gate = >::with_max_degree(4, 8); + assert_eq!(gate.num_intermediates(), 2); + assert_eq!(gate.degree(), 6); + + let gate = >::with_max_degree(4, 9); + assert_eq!(gate.num_intermediates(), 1); + assert_eq!(gate.degree(), 9); + } + + #[test] + fn wire_indices_degree2() { + let gate = CosetInterpolationGate:: { + subgroup_bits: 2, + degree: 2, + barycentric_weights: barycentric_weights( + &GoldilocksField::two_adic_subgroup(2) + .into_iter() + .map(|x| (x, GoldilocksField::ZERO)) + .collect::>(), + ), + _phantom: PhantomData, + }; + + // The exact indices aren't really important, but we want to make sure we don't have any + // overlaps or gaps. + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_value(2), 9..13); + assert_eq!(gate.wires_value(3), 13..17); + assert_eq!(gate.wires_evaluation_point(), 17..21); + assert_eq!(gate.wires_evaluation_value(), 21..25); + assert_eq!(gate.wires_intermediate_eval(0), 25..29); + assert_eq!(gate.wires_intermediate_eval(1), 29..33); + assert_eq!(gate.wires_intermediate_prod(0), 33..37); + assert_eq!(gate.wires_intermediate_prod(1), 37..41); + assert_eq!(gate.wires_shifted_evaluation_point(), 41..45); + assert_eq!(gate.num_wires(), 45); + } + + #[test] + fn wire_indices_degree_3() { + let gate = CosetInterpolationGate:: { + subgroup_bits: 2, + degree: 3, + barycentric_weights: barycentric_weights( + &GoldilocksField::two_adic_subgroup(2) + .into_iter() + .map(|x| (x, GoldilocksField::ZERO)) + .collect::>(), + ), + _phantom: PhantomData, + }; + + // The exact indices aren't really important, but we want to make sure we don't have any + // overlaps or gaps. + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_value(2), 9..13); + assert_eq!(gate.wires_value(3), 13..17); + assert_eq!(gate.wires_evaluation_point(), 17..21); + assert_eq!(gate.wires_evaluation_value(), 21..25); + assert_eq!(gate.wires_intermediate_eval(0), 25..29); + assert_eq!(gate.wires_intermediate_prod(0), 29..33); + assert_eq!(gate.wires_shifted_evaluation_point(), 33..37); + assert_eq!(gate.num_wires(), 37); + } + + #[test] + fn wire_indices_degree_n() { + let gate = CosetInterpolationGate:: { + subgroup_bits: 2, + degree: 4, + barycentric_weights: barycentric_weights( + &GoldilocksField::two_adic_subgroup(2) + .into_iter() + .map(|x| (x, GoldilocksField::ZERO)) + .collect::>(), + ), + _phantom: PhantomData, + }; + + // The exact indices aren't really important, but we want to make sure we don't have any + // overlaps or gaps. + assert_eq!(gate.wire_shift(), 0); + assert_eq!(gate.wires_value(0), 1..5); + assert_eq!(gate.wires_value(1), 5..9); + assert_eq!(gate.wires_value(2), 9..13); + assert_eq!(gate.wires_value(3), 13..17); + assert_eq!(gate.wires_evaluation_point(), 17..21); + assert_eq!(gate.wires_evaluation_value(), 21..25); + assert_eq!(gate.wires_shifted_evaluation_point(), 25..29); + assert_eq!(gate.num_wires(), 29); + } + + #[test] + fn low_degree() { + test_low_degree::(CosetInterpolationGate::new(2)); + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + for degree in 2..=4 { + test_eval_fns::(CosetInterpolationGate::with_max_degree(2, degree))?; + } + Ok(()) + } + + #[test] + fn test_gate_constraint() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + + /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. + fn get_wires(shift: F, values: PolynomialValues, eval_point: FF) -> Vec { + let domain = F::two_adic_subgroup(log2_strict(values.len())); + let shifted_eval_point = + >::scalar_mul(&eval_point, shift.inverse()); + let weights = + barycentric_weights(&domain.iter().map(|&x| (x, F::ZERO)).collect::>()); + let (intermediate_eval, intermediate_prod) = partial_interpolate::<_, D>( + &domain[..3], + &values.values[..3], + &weights[..3], + shifted_eval_point, + FF::ZERO, + FF::ONE, + ); + let eval = interpolate_over_base_domain::<_, D>( + &domain, + &values.values, + &weights, + shifted_eval_point, + ); + let mut v = vec![shift]; + for val in values.values.iter() { + v.extend(val.0); + } + v.extend(eval_point.0); + v.extend(eval.0); + v.extend(intermediate_eval.0); + v.extend(intermediate_prod.0); + v.extend(shifted_eval_point.0); + v.iter().map(|&x| x.into()).collect() + } + + // Get a working row for InterpolationGate. + let shift = F::rand(); + let values = PolynomialValues::new(repeat_with(FF::rand).take(4).collect()); + let eval_point = FF::rand(); + let gate = CosetInterpolationGate::::with_max_degree(2, 3); + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(shift, values, eval_point), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } + + #[test] + fn test_num_wires_constraints() { + let gate = >::with_max_degree(4, 8); + assert_eq!(gate.num_wires(), 47); + assert_eq!(gate.num_constraints(), 12); + + let gate = >::with_max_degree(3, 8); + assert_eq!(gate.num_wires(), 23); + assert_eq!(gate.num_constraints(), 4); + + let gate = >::with_max_degree(4, 16); + assert_eq!(gate.num_wires(), 39); + assert_eq!(gate.num_constraints(), 4); + } +} diff --git a/plonky2/src/gates/high_degree_interpolation.rs b/plonky2/src/gates/high_degree_interpolation.rs deleted file mode 100644 index f7e3be1f..00000000 --- a/plonky2/src/gates/high_degree_interpolation.rs +++ /dev/null @@ -1,360 +0,0 @@ -use alloc::boxed::Box; -use alloc::string::String; -use alloc::vec::Vec; -use alloc::{format, vec}; -use core::marker::PhantomData; -use core::ops::Range; - -use crate::field::extension::algebra::PolynomialCoeffsAlgebra; -use crate::field::extension::{Extendable, FieldExtension}; -use crate::field::interpolation::interpolant; -use crate::field::polynomial::PolynomialCoeffs; -use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; -use crate::gates::gate::Gate; -use crate::gates::interpolation::InterpolationGate; -use crate::gates::util::StridedConstraintConsumer; -use crate::hash::hash_types::RichField; -use crate::iop::ext_target::ExtensionTarget; -use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use crate::iop::target::Target; -use crate::iop::wire::Wire; -use crate::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; - -/// One of the instantiations of `InterpolationGate`: allows constraints of variable -/// degree, up to `1<, const D: usize> { - pub subgroup_bits: usize, - _phantom: PhantomData, -} - -impl, const D: usize> InterpolationGate - for HighDegreeInterpolationGate -{ - fn new(subgroup_bits: usize) -> Self { - Self { - subgroup_bits, - _phantom: PhantomData, - } - } - - fn num_points(&self) -> usize { - 1 << self.subgroup_bits - } -} - -impl, const D: usize> HighDegreeInterpolationGate { - /// End of wire indices, exclusive. - fn end(&self) -> usize { - self.start_coeffs() + self.num_points() * D - } - - /// The domain of the points we're interpolating. - fn coset(&self, shift: F) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. - g.powers().take(size).map(move |x| x * shift) - } - - /// The domain of the points we're interpolating. - fn coset_ext(&self, shift: F::Extension) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - g.powers().take(size).map(move |x| shift.scalar_mul(x)) - } - - /// The domain of the points we're interpolating. - fn coset_ext_circuit( - &self, - builder: &mut CircuitBuilder, - shift: ExtensionTarget, - ) -> Vec> { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - g.powers() - .take(size) - .map(move |x| { - let subgroup_element = builder.constant(x); - builder.scalar_mul_ext(subgroup_element, shift) - }) - .collect() - } -} - -impl, const D: usize> Gate - for HighDegreeInterpolationGate -{ - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - - let coset = self.coset_ext(vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval_base(point); - constraints.extend((value - computed_value).to_basefield_array()); - } - - let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(evaluation_point); - constraints.extend((evaluation_value - computed_evaluation_value).to_basefield_array()); - - constraints - } - - fn eval_unfiltered_base_one( - &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, - ) { - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffs::new(coeffs); - - let coset = self.coset(vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = interpolant.eval_base(point); - yield_constr.many((value - computed_value).to_basefield_array()); - } - - let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(evaluation_point); - yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array()); - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect(); - let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - - let coset = self.coset_ext_circuit(builder, vars.local_wires[self.wire_shift()]); - for (i, point) in coset.into_iter().enumerate() { - let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = interpolant.eval_scalar(builder, point); - constraints.extend( - builder - .sub_ext_algebra(value, computed_value) - .to_ext_target_array(), - ); - } - - let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval(builder, evaluation_point); - constraints.extend( - builder - .sub_ext_algebra(evaluation_value, computed_evaluation_value) - .to_ext_target_array(), - ); - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - let gen = InterpolationGenerator:: { - row, - gate: *self, - _phantom: PhantomData, - }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - self.end() - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - // The highest power of x is `num_points - 1`, and then multiplication by the coefficient - // adds 1. - self.num_points() - } - - fn num_constraints(&self) -> usize { - // num_points * D constraints to check for consistency between the coefficients and the - // point-value pairs, plus D constraints for the evaluation value. - self.num_points() * D + D - } -} - -#[derive(Debug)] -struct InterpolationGenerator, const D: usize> { - row: usize, - gate: HighDegreeInterpolationGate, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for InterpolationGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| { - Target::Wire(Wire { - row: self.row, - column, - }) - }; - - let local_targets = |columns: Range| columns.map(local_target); - - let num_points = self.gate.num_points(); - let mut deps = Vec::with_capacity(1 + D + num_points * D); - - deps.push(local_target(self.gate.wire_shift())); - deps.extend(local_targets(self.gate.wires_evaluation_point())); - for i in 0..num_points { - deps.extend(local_targets(self.gate.wires_value(i))); - } - deps - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let get_local_ext = |wire_range: Range| { - debug_assert_eq!(wire_range.len(), D); - let values = wire_range.map(get_local_wire).collect::>(); - let arr = values.try_into().unwrap(); - F::Extension::from_basefield_array(arr) - }; - - // Compute the interpolant. - let points = self.gate.coset(get_local_wire(self.gate.wire_shift())); - let points = points - .into_iter() - .enumerate() - .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) - .collect::>(); - let interpolant = interpolant(&points); - - for (i, &coeff) in interpolant.coeffs.iter().enumerate() { - let wires = self.gate.wires_coeff(i).map(local_wire); - out_buffer.set_ext_wires(wires, coeff); - } - - let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); - let evaluation_value = interpolant.eval(evaluation_point); - let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); - out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - - use super::*; - use crate::field::goldilocks_field::GoldilocksField; - use crate::field::types::{Field, Sample}; - use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::hash::hash_types::HashOut; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - - #[test] - fn wire_indices() { - let gate = HighDegreeInterpolationGate:: { - subgroup_bits: 1, - _phantom: PhantomData, - }; - - // The exact indices aren't really important, but we want to make sure we don't have any - // overlaps or gaps. - assert_eq!(gate.wire_shift(), 0); - assert_eq!(gate.wires_value(0), 1..5); - assert_eq!(gate.wires_value(1), 5..9); - assert_eq!(gate.wires_evaluation_point(), 9..13); - assert_eq!(gate.wires_evaluation_value(), 13..17); - assert_eq!(gate.wires_coeff(0), 17..21); - assert_eq!(gate.wires_coeff(1), 21..25); - assert_eq!(gate.num_wires(), 25); - } - - #[test] - fn low_degree() { - test_low_degree::(HighDegreeInterpolationGate::new(2)); - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(HighDegreeInterpolationGate::new(2)) - } - - #[test] - fn test_gate_constraint() { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - - /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. - fn get_wires( - gate: &HighDegreeInterpolationGate, - shift: F, - coeffs: PolynomialCoeffs, - eval_point: FF, - ) -> Vec { - let points = gate.coset(shift); - let mut v = vec![shift]; - for x in points { - v.extend(coeffs.eval(x.into()).0); - } - v.extend(eval_point.0); - v.extend(coeffs.eval(eval_point).0); - for i in 0..coeffs.len() { - v.extend(coeffs.coeffs[i].0); - } - v.iter().map(|&x| x.into()).collect() - } - - // Get a working row for InterpolationGate. - let shift = F::rand(); - let coeffs = PolynomialCoeffs::new(vec![FF::rand(), FF::rand()]); - let eval_point = FF::rand(); - let gate = HighDegreeInterpolationGate::::new(1); - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(&gate, shift, coeffs, eval_point), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } -} diff --git a/plonky2/src/gates/interpolation.rs b/plonky2/src/gates/interpolation.rs deleted file mode 100644 index 07179006..00000000 --- a/plonky2/src/gates/interpolation.rs +++ /dev/null @@ -1,178 +0,0 @@ -use alloc::vec; -use core::ops::Range; - -use crate::field::extension::Extendable; -use crate::gates::gate::Gate; -use crate::hash::hash_types::RichField; -use crate::iop::ext_target::ExtensionTarget; -use crate::iop::target::Target; -use crate::plonk::circuit_builder::CircuitBuilder; - -/// Trait for gates which interpolate a polynomial, whose points are a (base field) coset of the multiplicative subgroup -/// with the given size, and whose values are extension field elements, given by input wires. -/// Outputs the evaluation of the interpolant at a given (extension field) evaluation point. -pub(crate) trait InterpolationGate, const D: usize>: - Gate + Copy -{ - fn new(subgroup_bits: usize) -> Self; - - fn num_points(&self) -> usize; - - /// Wire index of the coset shift. - fn wire_shift(&self) -> usize { - 0 - } - - fn start_values(&self) -> usize { - 1 - } - - /// Wire indices of the `i`th interpolant value. - fn wires_value(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_values() + i * D; - start..start + D - } - - fn start_evaluation_point(&self) -> usize { - self.start_values() + self.num_points() * D - } - - /// Wire indices of the point to evaluate the interpolant at. - fn wires_evaluation_point(&self) -> Range { - let start = self.start_evaluation_point(); - start..start + D - } - - fn start_evaluation_value(&self) -> usize { - self.start_evaluation_point() + D - } - - /// Wire indices of the interpolated value. - fn wires_evaluation_value(&self) -> Range { - let start = self.start_evaluation_value(); - start..start + D - } - - fn start_coeffs(&self) -> usize { - self.start_evaluation_value() + D - } - - /// The number of routed wires required in the typical usage of this gate, where the points to - /// interpolate, the evaluation point, and the corresponding value are all routed. - fn num_routed_wires(&self) -> usize { - self.start_coeffs() - } - - /// Wire indices of the interpolant's `i`th coefficient. - fn wires_coeff(&self, i: usize) -> Range { - debug_assert!(i < self.num_points()); - let start = self.start_coeffs() + i * D; - start..start + D - } - - fn end_coeffs(&self) -> usize { - self.start_coeffs() + D * self.num_points() - } -} - -impl, const D: usize> CircuitBuilder { - /// Interpolates a polynomial, whose points are a coset of the multiplicative subgroup with the - /// given size, and whose values are given. Returns the evaluation of the interpolant at - /// `evaluation_point`. - pub(crate) fn interpolate_coset>( - &mut self, - subgroup_bits: usize, - coset_shift: Target, - values: &[ExtensionTarget], - evaluation_point: ExtensionTarget, - ) -> ExtensionTarget { - let gate = G::new(subgroup_bits); - let row = self.add_gate(gate, vec![]); - self.connect(coset_shift, Target::wire(row, gate.wire_shift())); - for (i, &v) in values.iter().enumerate() { - self.connect_extension(v, ExtensionTarget::from_range(row, gate.wires_value(i))); - } - self.connect_extension( - evaluation_point, - ExtensionTarget::from_range(row, gate.wires_evaluation_point()), - ); - - ExtensionTarget::from_range(row, gate.wires_evaluation_value()) - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - - use crate::field::extension::FieldExtension; - use crate::field::interpolation::interpolant; - use crate::field::types::{Field, Sample}; - use crate::gates::high_degree_interpolation::HighDegreeInterpolationGate; - use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; - use crate::iop::witness::PartialWitness; - use crate::plonk::circuit_builder::CircuitBuilder; - use crate::plonk::circuit_data::CircuitConfig; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::verifier::verify; - - #[test] - fn test_interpolate() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - type FF = >::FE; - let config = CircuitConfig::standard_recursion_config(); - let pw = PartialWitness::new(); - let mut builder = CircuitBuilder::::new(config); - - let subgroup_bits = 2; - let len = 1 << subgroup_bits; - let coset_shift = F::rand(); - let g = F::primitive_root_of_unity(subgroup_bits); - let points = F::cyclic_subgroup_coset_known_order(g, coset_shift, len); - let values = FF::rand_vec(len); - - let homogeneous_points = points - .iter() - .zip(values.iter()) - .map(|(&a, &b)| (>::from_basefield(a), b)) - .collect::>(); - - let true_interpolant = interpolant(&homogeneous_points); - - let z = FF::rand(); - let true_eval = true_interpolant.eval(z); - - let coset_shift_target = builder.constant(coset_shift); - - let value_targets = values - .iter() - .map(|&v| (builder.constant_extension(v))) - .collect::>(); - - let zt = builder.constant_extension(z); - - let eval_hd = builder.interpolate_coset::>( - subgroup_bits, - coset_shift_target, - &value_targets, - zt, - ); - let eval_ld = builder.interpolate_coset::>( - subgroup_bits, - coset_shift_target, - &value_targets, - zt, - ); - let true_eval_target = builder.constant_extension(true_eval); - builder.connect_extension(eval_hd, true_eval_target); - builder.connect_extension(eval_ld, true_eval_target); - - let data = builder.build::(); - let proof = data.prove(pw)?; - - verify(proof, &data.verifier_only, &data.common) - } -} diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs deleted file mode 100644 index 15d0e56a..00000000 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ /dev/null @@ -1,466 +0,0 @@ -use alloc::boxed::Box; -use alloc::string::String; -use alloc::vec::Vec; -use alloc::{format, vec}; -use core::marker::PhantomData; -use core::ops::Range; - -use crate::field::extension::algebra::PolynomialCoeffsAlgebra; -use crate::field::extension::{Extendable, FieldExtension}; -use crate::field::interpolation::interpolant; -use crate::field::polynomial::PolynomialCoeffs; -use crate::field::types::Field; -use crate::gadgets::polynomial::PolynomialCoeffsExtAlgebraTarget; -use crate::gates::gate::Gate; -use crate::gates::interpolation::InterpolationGate; -use crate::gates::util::StridedConstraintConsumer; -use crate::hash::hash_types::RichField; -use crate::iop::ext_target::ExtensionTarget; -use crate::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; -use crate::iop::target::Target; -use crate::iop::wire::Wire; -use crate::iop::witness::{PartitionWitness, Witness, WitnessWrite}; -use crate::plonk::circuit_builder::CircuitBuilder; -use crate::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; - -/// One of the instantiations of `InterpolationGate`: all constraints are degree <= 2. -/// The lower degree is a tradeoff for more gates (`eval_unfiltered_recursively` for -/// this version uses more gates than `LowDegreeInterpolationGate`). -#[derive(Copy, Clone, Debug)] -pub struct LowDegreeInterpolationGate, const D: usize> { - pub subgroup_bits: usize, - _phantom: PhantomData, -} - -impl, const D: usize> InterpolationGate - for LowDegreeInterpolationGate -{ - fn new(subgroup_bits: usize) -> Self { - Self { - subgroup_bits, - _phantom: PhantomData, - } - } - - fn num_points(&self) -> usize { - 1 << self.subgroup_bits - } -} - -impl, const D: usize> LowDegreeInterpolationGate { - /// `powers_shift(i)` is the wire index of `wire_shift^i`. - pub fn powers_shift(&self, i: usize) -> usize { - debug_assert!(0 < i && i < self.num_points()); - if i == 1 { - return self.wire_shift(); - } - self.end_coeffs() + i - 2 - } - - /// `powers_evalutation_point(i)` is the wire index of `evalutation_point^i`. - pub fn powers_evaluation_point(&self, i: usize) -> Range { - debug_assert!(0 < i && i < self.num_points()); - if i == 1 { - return self.wires_evaluation_point(); - } - let start = self.end_coeffs() + self.num_points() - 2 + (i - 2) * D; - start..start + D - } - - /// End of wire indices, exclusive. - fn end(&self) -> usize { - self.powers_evaluation_point(self.num_points() - 1).end - } - - /// The domain of the points we're interpolating. - fn coset(&self, shift: F) -> impl Iterator { - let g = F::primitive_root_of_unity(self.subgroup_bits); - let size = 1 << self.subgroup_bits; - // Speed matters here, so we avoid `cyclic_subgroup_coset_known_order` which allocates. - g.powers().take(size).map(move |x| x * shift) - } -} - -impl, const D: usize> Gate for LowDegreeInterpolationGate { - fn id(&self) -> String { - format!("{self:?}") - } - - fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect::>(); - - let mut powers_shift = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_shift(i)]) - .collect::>(); - let shift = powers_shift[0]; - for i in 1..self.num_points() - 1 { - constraints.push(powers_shift[i - 1] * shift - powers_shift[i]); - } - powers_shift.insert(0, F::Extension::ONE); - // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. - // Then, `altered(w^i) = original(shift*w^i)`. - let altered_coeffs = coeffs - .iter() - .zip(powers_shift) - .map(|(&c, p)| c.scalar_mul(p)) - .collect::>(); - let interpolant = PolynomialCoeffsAlgebra::new(coeffs); - let altered_interpolant = PolynomialCoeffsAlgebra::new(altered_coeffs); - - for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) - .into_iter() - .enumerate() - { - let value = vars.get_local_ext_algebra(self.wires_value(i)); - let computed_value = altered_interpolant.eval_base(point); - constraints.extend((value - computed_value).to_basefield_array()); - } - - let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i))) - .collect::>(); - let evaluation_point = evaluation_point_powers[0]; - for i in 1..self.num_points() - 1 { - constraints.extend( - (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) - .to_basefield_array(), - ); - } - let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); - constraints.extend((evaluation_value - computed_evaluation_value).to_basefield_array()); - - constraints - } - - fn eval_unfiltered_base_one( - &self, - vars: EvaluationVarsBase, - mut yield_constr: StridedConstraintConsumer, - ) { - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext(self.wires_coeff(i))) - .collect::>(); - - let mut powers_shift = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_shift(i)]) - .collect::>(); - let shift = powers_shift[0]; - for i in 1..self.num_points() - 1 { - yield_constr.one(powers_shift[i - 1] * shift - powers_shift[i]); - } - powers_shift.insert(0, F::ONE); - // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. - // Then, `altered(w^i) = original(shift*w^i)`. - let altered_coeffs = coeffs - .iter() - .zip(powers_shift) - .map(|(&c, p)| c.scalar_mul(p)) - .collect::>(); - let interpolant = PolynomialCoeffs::new(coeffs); - let altered_interpolant = PolynomialCoeffs::new(altered_coeffs); - - for (i, point) in F::two_adic_subgroup(self.subgroup_bits) - .into_iter() - .enumerate() - { - let value = vars.get_local_ext(self.wires_value(i)); - let computed_value = altered_interpolant.eval_base(point); - yield_constr.many((value - computed_value).to_basefield_array()); - } - - let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext(self.powers_evaluation_point(i))) - .collect::>(); - let evaluation_point = evaluation_point_powers[0]; - for i in 1..self.num_points() - 1 { - yield_constr.many( - (evaluation_point_powers[i - 1] * evaluation_point - evaluation_point_powers[i]) - .to_basefield_array(), - ); - } - let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); - let computed_evaluation_value = interpolant.eval_with_powers(&evaluation_point_powers); - yield_constr.many((evaluation_value - computed_evaluation_value).to_basefield_array()); - } - - fn eval_unfiltered_circuit( - &self, - builder: &mut CircuitBuilder, - vars: EvaluationTargets, - ) -> Vec> { - let mut constraints = Vec::with_capacity(self.num_constraints()); - - let coeffs = (0..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) - .collect::>(); - - let mut powers_shift = (1..self.num_points()) - .map(|i| vars.local_wires[self.powers_shift(i)]) - .collect::>(); - let shift = powers_shift[0]; - for i in 1..self.num_points() - 1 { - constraints.push(builder.mul_sub_extension( - powers_shift[i - 1], - shift, - powers_shift[i], - )); - } - powers_shift.insert(0, builder.one_extension()); - // `altered_coeffs[i] = c_i * shift^i`, where `c_i` is the original coefficient. - // Then, `altered(w^i) = original(shift*w^i)`. - let altered_coeffs = coeffs - .iter() - .zip(powers_shift) - .map(|(&c, p)| builder.scalar_mul_ext_algebra(p, c)) - .collect::>(); - let interpolant = PolynomialCoeffsExtAlgebraTarget(coeffs); - let altered_interpolant = PolynomialCoeffsExtAlgebraTarget(altered_coeffs); - - for (i, point) in F::Extension::two_adic_subgroup(self.subgroup_bits) - .into_iter() - .enumerate() - { - let value = vars.get_local_ext_algebra(self.wires_value(i)); - let point = builder.constant_extension(point); - let computed_value = altered_interpolant.eval_scalar(builder, point); - constraints.extend( - builder - .sub_ext_algebra(value, computed_value) - .to_ext_target_array(), - ); - } - - let evaluation_point_powers = (1..self.num_points()) - .map(|i| vars.get_local_ext_algebra(self.powers_evaluation_point(i))) - .collect::>(); - let evaluation_point = evaluation_point_powers[0]; - for i in 1..self.num_points() - 1 { - let neg_one_ext = builder.neg_one_extension(); - let neg_new_power = - builder.scalar_mul_ext_algebra(neg_one_ext, evaluation_point_powers[i]); - let constraint = builder.mul_add_ext_algebra( - evaluation_point, - evaluation_point_powers[i - 1], - neg_new_power, - ); - constraints.extend(constraint.to_ext_target_array()); - } - let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - let computed_evaluation_value = - interpolant.eval_with_powers(builder, &evaluation_point_powers); - // let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); - // let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); - // let computed_evaluation_value = interpolant.eval(builder, evaluation_point); - constraints.extend( - builder - .sub_ext_algebra(evaluation_value, computed_evaluation_value) - .to_ext_target_array(), - ); - - constraints - } - - fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { - let gen = InterpolationGenerator:: { - row, - gate: *self, - _phantom: PhantomData, - }; - vec![Box::new(gen.adapter())] - } - - fn num_wires(&self) -> usize { - self.end() - } - - fn num_constants(&self) -> usize { - 0 - } - - fn degree(&self) -> usize { - 2 - } - - fn num_constraints(&self) -> usize { - // `num_points * D` constraints to check for consistency between the coefficients and the - // point-value pairs, plus `D` constraints for the evaluation value, plus `(D+1)*(num_points-2)` - // to check power constraints for evaluation point and shift. - self.num_points() * D + D + (D + 1) * (self.num_points() - 2) - } -} - -#[derive(Debug)] -struct InterpolationGenerator, const D: usize> { - row: usize, - gate: LowDegreeInterpolationGate, - _phantom: PhantomData, -} - -impl, const D: usize> SimpleGenerator - for InterpolationGenerator -{ - fn dependencies(&self) -> Vec { - let local_target = |column| { - Target::Wire(Wire { - row: self.row, - column, - }) - }; - - let local_targets = |columns: Range| columns.map(local_target); - - let num_points = self.gate.num_points(); - let mut deps = Vec::with_capacity(1 + D + num_points * D); - - deps.push(local_target(self.gate.wire_shift())); - deps.extend(local_targets(self.gate.wires_evaluation_point())); - for i in 0..num_points { - deps.extend(local_targets(self.gate.wires_value(i))); - } - deps - } - - fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { - let local_wire = |column| Wire { - row: self.row, - column, - }; - - let get_local_wire = |column| witness.get_wire(local_wire(column)); - - let get_local_ext = |wire_range: Range| { - debug_assert_eq!(wire_range.len(), D); - let values = wire_range.map(get_local_wire).collect::>(); - let arr = values.try_into().unwrap(); - F::Extension::from_basefield_array(arr) - }; - - let wire_shift = get_local_wire(self.gate.wire_shift()); - - for (i, power) in wire_shift - .powers() - .take(self.gate.num_points()) - .enumerate() - .skip(2) - { - out_buffer.set_wire(local_wire(self.gate.powers_shift(i)), power); - } - - // Compute the interpolant. - let points = self.gate.coset(wire_shift); - let points = points - .into_iter() - .enumerate() - .map(|(i, point)| (point.into(), get_local_ext(self.gate.wires_value(i)))) - .collect::>(); - let interpolant = interpolant(&points); - - for (i, &coeff) in interpolant.coeffs.iter().enumerate() { - let wires = self.gate.wires_coeff(i).map(local_wire); - out_buffer.set_ext_wires(wires, coeff); - } - - let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); - for (i, power) in evaluation_point - .powers() - .take(self.gate.num_points()) - .enumerate() - .skip(2) - { - out_buffer.set_extension_target( - ExtensionTarget::from_range(self.row, self.gate.powers_evaluation_point(i)), - power, - ); - } - let evaluation_value = interpolant.eval(evaluation_point); - let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); - out_buffer.set_ext_wires(evaluation_value_wires, evaluation_value); - } -} - -#[cfg(test)] -mod tests { - use anyhow::Result; - - use crate::field::extension::quadratic::QuadraticExtension; - use crate::field::goldilocks_field::GoldilocksField; - use crate::field::polynomial::PolynomialCoeffs; - use crate::field::types::{Field, Sample}; - use crate::gates::gate::Gate; - use crate::gates::gate_testing::{test_eval_fns, test_low_degree}; - use crate::gates::interpolation::InterpolationGate; - use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; - use crate::hash::hash_types::HashOut; - use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; - use crate::plonk::vars::EvaluationVars; - - #[test] - fn low_degree() { - test_low_degree::(LowDegreeInterpolationGate::new(4)); - } - - #[test] - fn eval_fns() -> Result<()> { - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - test_eval_fns::(LowDegreeInterpolationGate::new(4)) - } - - #[test] - fn test_gate_constraint() { - type F = GoldilocksField; - type FF = QuadraticExtension; - const D: usize = 2; - - /// Returns the local wires for an interpolation gate for given coeffs, points and eval point. - fn get_wires( - gate: &LowDegreeInterpolationGate, - shift: F, - coeffs: PolynomialCoeffs, - eval_point: FF, - ) -> Vec { - let points = gate.coset(shift); - let mut v = vec![shift]; - for x in points { - v.extend(coeffs.eval(x.into()).0); - } - v.extend(eval_point.0); - v.extend(coeffs.eval(eval_point).0); - for i in 0..coeffs.len() { - v.extend(coeffs.coeffs[i].0); - } - v.extend(shift.powers().skip(2).take(gate.num_points() - 2)); - v.extend( - eval_point - .powers() - .skip(2) - .take(gate.num_points() - 2) - .flat_map(|ff| ff.0), - ); - v.iter().map(|&x| x.into()).collect() - } - - // Get a working row for LowDegreeInterpolationGate. - let subgroup_bits = 4; - let shift = F::rand(); - let coeffs = PolynomialCoeffs::new(FF::rand_vec(1 << subgroup_bits)); - let eval_point = FF::rand(); - let gate = LowDegreeInterpolationGate::::new(subgroup_bits); - let vars = EvaluationVars { - local_constants: &[], - local_wires: &get_wires(&gate, shift, coeffs, eval_point), - public_inputs_hash: &HashOut::rand(), - }; - - assert!( - gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), - "Gate constraints are not satisfied." - ); - } -} diff --git a/plonky2/src/gates/mod.rs b/plonky2/src/gates/mod.rs index 1d2fc058..9df1a535 100644 --- a/plonky2/src/gates/mod.rs +++ b/plonky2/src/gates/mod.rs @@ -5,11 +5,9 @@ pub mod arithmetic_base; pub mod arithmetic_extension; pub mod base_sum; pub mod constant; +pub mod coset_interpolation; pub mod exponentiation; pub mod gate; -pub mod high_degree_interpolation; -pub mod interpolation; -pub mod low_degree_interpolation; pub mod multiplication_extension; pub mod noop; pub mod packed_util;