use std::convert::TryInto; use std::marker::PhantomData; use std::ops::Range; use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::field::lagrange::interpolant; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::polynomial::polynomial::PolynomialCoeffs; use crate::target::Target; use crate::vars::{EvaluationTargets, EvaluationVars}; use crate::wire::Wire; use crate::witness::PartialWitness; /// Evaluates the interpolant of some given elements from a field extension. /// /// In particular, this gate takes as inputs `num_points` points, `num_points` values, and the point /// to evaluate the interpolant at. It computes the interpolant and outputs its evaluation at the /// given point. #[derive(Clone, Debug)] pub(crate) struct InterpolationGate, const D: usize> { num_points: usize, _phantom: PhantomData, } impl, const D: usize> InterpolationGate { pub fn new(num_points: usize) -> GateRef { let gate = Self { num_points, _phantom: PhantomData, }; GateRef::new(gate) } fn start_points(&self) -> usize { 0 } /// Wire indices of the `i`th interpolant point. pub fn wire_point(&self, i: usize) -> usize { debug_assert!(i < self.num_points); self.start_points() + i } fn start_values(&self) -> usize { self.start_points() + self.num_points } /// Wire indices of the `i`th interpolant value. pub fn wires_value(&self, i: usize) -> Range { debug_assert!(i < self.num_points); let start = self.start_values() + i * D; start..start + D } fn start_evaluation_point(&self) -> usize { self.start_values() + self.num_points * D } /// Wire indices of the point to evaluate the interpolant at. pub fn wires_evaluation_point(&self) -> Range { let start = self.start_evaluation_point(); start..start + D } fn start_evaluation_value(&self) -> usize { self.start_evaluation_point() + D } /// Wire indices of the interpolated value. pub fn wires_evaluation_value(&self) -> Range { let start = self.start_evaluation_value(); start..start + D } fn start_coeffs(&self) -> usize { self.start_evaluation_value() + D } /// Wire indices of the interpolant's `i`th coefficient. pub fn wires_coeff(&self, i: usize) -> Range { debug_assert!(i < self.num_points); let start = self.start_coeffs() + i * D; start..start + D } /// End of wire indices, exclusive. fn end(&self) -> usize { self.start_coeffs() + self.num_points * D } } impl, const D: usize> Gate for InterpolationGate { fn id(&self) -> String { format!("{:?}", self, D) } fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { let mut constraints = Vec::with_capacity(self.num_constraints()); let coeffs = (0..self.num_points) .map(|i| vars.get_local_ext(self.wires_coeff(i))) .collect(); let interpolant = PolynomialCoeffs::new(coeffs); for i in 0..self.num_points { let point = F::Extension::from_basefield(vars.local_wires[self.wire_point(i)]); let value = vars.get_local_ext(self.wires_value(i)); let computed_value = interpolant.eval(point); constraints.extend(&(value - computed_value).to_basefield_array()); } let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval(evaluation_point); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); constraints } fn eval_unfiltered_recursively( &self, builder: &mut CircuitBuilder, vars: EvaluationTargets, ) -> Vec { todo!() } fn generators( &self, gate_index: usize, _local_constants: &[F], ) -> Vec>> { let gen = InterpolationGenerator:: { gate_index, gate: self.clone(), _phantom: PhantomData, }; vec![Box::new(gen)] } fn num_wires(&self) -> usize { self.end() } fn num_constants(&self) -> usize { 0 } fn degree(&self) -> usize { // The highest power of x is `num_points - 1`, and then multiplication by the coefficient // adds 1. self.num_points } fn num_constraints(&self) -> usize { // num_points * D constraints to check for consistency between the coefficients and the // point-value pairs, plus D constraints for the evaluation value. self.num_points * D + D } } struct InterpolationGenerator, const D: usize> { gate_index: usize, gate: InterpolationGate, _phantom: PhantomData, } impl, const D: usize> SimpleGenerator for InterpolationGenerator { fn dependencies(&self) -> Vec { let local_target = |input| { Target::Wire(Wire { gate: self.gate_index, input, }) }; let local_targets = |inputs: Range| inputs.map(local_target); let mut deps = Vec::new(); deps.extend(local_targets(self.gate.wires_evaluation_point())); deps.extend(local_targets(self.gate.wires_evaluation_value())); for i in 0..self.gate.num_points { deps.push(local_target(self.gate.wire_point(i))); deps.extend(local_targets(self.gate.wires_value(i))); deps.extend(local_targets(self.gate.wires_coeff(i))); } deps } fn run_once(&self, witness: &PartialWitness) -> PartialWitness { let n = self.gate.num_points; let local_wire = |input| Wire { gate: self.gate_index, input, }; let get_local_wire = |input| witness.get_wire(local_wire(input)); let get_local_ext = |wire_range: Range| { debug_assert_eq!(wire_range.len(), D); let values = wire_range.map(get_local_wire).collect::>(); let arr = values.try_into().unwrap(); F::Extension::from_basefield_array(arr) }; // Compute the interpolant. let points = (0..n) .map(|i| { ( F::Extension::from_basefield(get_local_wire(self.gate.wire_point(i))), get_local_ext(self.gate.wires_value(i)), ) }) .collect::>(); let interpolant = interpolant(&points); let mut result = PartialWitness::::new(); for (i, &coeff) in interpolant.coeffs.iter().enumerate() { let wires = self.gate.wires_coeff(i).map(local_wire); result.set_ext_wires(wires, coeff); } let evaluation_point = get_local_ext(self.gate.wires_evaluation_point()); let evaluation_value = interpolant.eval(evaluation_point); let evaluation_value_wires = self.gate.wires_evaluation_value().map(local_wire); result.set_ext_wires(evaluation_value_wires, evaluation_value); result } } #[cfg(test)] mod tests { use std::marker::PhantomData; use crate::field::crandall_field::CrandallField; use crate::gates::gate::Gate; use crate::gates::gate_testing::test_low_degree; use crate::gates::interpolation::InterpolationGate; #[test] fn wire_indices() { let gate = InterpolationGate:: { num_points: 2, _phantom: PhantomData, }; // The exact indices aren't really important, but we want to make sure we don't have any // overlaps or gaps. assert_eq!(gate.wire_point(0), 0); assert_eq!(gate.wire_point(1), 1); assert_eq!(gate.wires_value(0), 2..6); assert_eq!(gate.wires_value(1), 6..10); assert_eq!(gate.wires_evaluation_point(), 10..14); assert_eq!(gate.wires_evaluation_value(), 14..18); assert_eq!(gate.wires_coeff(0), 18..22); assert_eq!(gate.wires_coeff(1), 22..26); assert_eq!(gate.num_wires(), 26); } #[test] fn low_degree() { type F = CrandallField; test_low_degree(InterpolationGate::::new(4)); test_low_degree(InterpolationGate::::new(4)); } }