From 4f6e9fb2e031d4147e24e85da5cf9b35032e9ce8 Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Mon, 24 May 2021 17:31:55 +0200 Subject: [PATCH] Recursive evaluation for interpolation gate. --- src/field/extension_field/target.rs | 44 ++++++++++++++++++++++++++++- src/gadgets/mod.rs | 1 + src/gadgets/polynomial.rs | 35 +++++++++++++++++++++++ src/gates/interpolation.rs | 30 +++++++++++++++++++- src/vars.rs | 9 ++++++ 5 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 src/gadgets/polynomial.rs diff --git a/src/field/extension_field/target.rs b/src/field/extension_field/target.rs index 6ca84a24..28429c1a 100644 --- a/src/field/extension_field/target.rs +++ b/src/field/extension_field/target.rs @@ -4,9 +4,22 @@ use crate::field::field::Field; use crate::target::Target; #[derive(Copy, Clone, Debug)] -pub struct ExtensionTarget([Target; D]); +pub struct ExtensionTarget(pub [Target; D]); + +impl ExtensionTarget { + pub fn to_target_array(&self) -> [Target; D] { + self.0 + } +} impl CircuitBuilder { + pub fn zero_ext(&mut self) -> ExtensionTarget + where + F: Extendable, + { + ExtensionTarget([self.zero(); D]) + } + pub fn add_extension( &mut self, mut a: ExtensionTarget, @@ -21,6 +34,20 @@ impl CircuitBuilder { a } + pub fn sub_extension( + &mut self, + mut a: ExtensionTarget, + b: ExtensionTarget, + ) -> ExtensionTarget + where + F: Extendable, + { + for i in 0..D { + a.0[i] = self.sub(a.0[i], b.0[i]); + } + a + } + pub fn mul_extension( &mut self, a: ExtensionTarget, @@ -43,4 +70,19 @@ impl CircuitBuilder { } ExtensionTarget(res) } + + /// Returns a*b where `b` is in the extension field and `a` is in the base field. + pub fn scalar_mul( + &mut self, + a: Target, + mut b: ExtensionTarget, + ) -> ExtensionTarget + where + F: Extendable, + { + for i in 0..D { + b.0[i] = self.mul(a, b.0[i]); + } + b + } } diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs index ed84207e..9a6a728e 100644 --- a/src/gadgets/mod.rs +++ b/src/gadgets/mod.rs @@ -1,3 +1,4 @@ pub mod arithmetic; pub mod hash; +pub mod polynomial; pub(crate) mod split_join; diff --git a/src/gadgets/polynomial.rs b/src/gadgets/polynomial.rs new file mode 100644 index 00000000..05427038 --- /dev/null +++ b/src/gadgets/polynomial.rs @@ -0,0 +1,35 @@ +use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::target::ExtensionTarget; +use crate::field::extension_field::Extendable; +use crate::field::field::Field; +use crate::target::Target; + +pub struct PolynomialCoeffsTarget(pub Vec>); + +impl PolynomialCoeffsTarget { + pub fn eval_scalar>( + &self, + builder: &mut CircuitBuilder, + point: Target, + ) -> ExtensionTarget { + let mut acc = builder.zero_ext(); + for &c in self.0.iter().rev() { + let tmp = builder.scalar_mul(point, acc); + acc = builder.add_extension(tmp, c); + } + acc + } + + pub fn eval>( + &self, + builder: &mut CircuitBuilder, + point: ExtensionTarget, + ) -> ExtensionTarget { + let mut acc = builder.zero_ext(); + for &c in self.0.iter().rev() { + let tmp = builder.mul_extension(point, acc); + acc = builder.add_extension(tmp, c); + } + acc + } +} diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index adc7055d..a1b9ddd5 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -6,6 +6,7 @@ use crate::circuit_builder::CircuitBuilder; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::field::lagrange::interpolant; +use crate::gadgets::polynomial::PolynomialCoeffsTarget; use crate::gates::gate::{Gate, GateRef}; use crate::generator::{SimpleGenerator, WitnessGenerator}; use crate::polynomial::polynomial::PolynomialCoeffs; @@ -125,7 +126,34 @@ impl, const D: usize> Gate for InterpolationGate, vars: EvaluationTargets, ) -> Vec { - todo!() + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let coeffs = (0..self.num_points) + .map(|i| vars.get_local_ext(self.wires_coeff(i))) + .collect(); + let interpolant = PolynomialCoeffsTarget(coeffs); + + for i in 0..self.num_points { + let point = vars.local_wires[self.wire_point(i)]; + let value = vars.get_local_ext(self.wires_value(i)); + let computed_value = interpolant.eval_scalar(builder, point); + constraints.extend( + &builder + .sub_extension(value, computed_value) + .to_target_array(), + ); + } + + let evaluation_point = vars.get_local_ext(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext(self.wires_evaluation_value()); + let computed_evaluation_value = interpolant.eval(builder, evaluation_point); + constraints.extend( + &builder + .sub_extension(evaluation_value, computed_evaluation_value) + .to_target_array(), + ); + + constraints } fn generators( diff --git a/src/vars.rs b/src/vars.rs index 7c0afab9..aa8a3561 100644 --- a/src/vars.rs +++ b/src/vars.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use std::ops::Range; +use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; use crate::target::Target; @@ -27,3 +28,11 @@ pub struct EvaluationTargets<'a> { pub(crate) local_constants: &'a [Target], pub(crate) local_wires: &'a [Target], } + +impl<'a> EvaluationTargets<'a> { + pub fn get_local_ext(&self, wire_range: Range) -> ExtensionTarget { + debug_assert_eq!(wire_range.len(), D); + let arr = self.local_wires[wire_range].try_into().unwrap(); + ExtensionTarget(arr) + } +}