Recursive evaluation for interpolation gate.

This commit is contained in:
wborgeaud 2021-05-24 17:31:55 +02:00
parent b64a5fab46
commit 4f6e9fb2e0
5 changed files with 117 additions and 2 deletions

View File

@ -4,9 +4,22 @@ use crate::field::field::Field;
use crate::target::Target;
#[derive(Copy, Clone, Debug)]
pub struct ExtensionTarget<const D: usize>([Target; D]);
pub struct ExtensionTarget<const D: usize>(pub [Target; D]);
impl<const D: usize> ExtensionTarget<D> {
pub fn to_target_array(&self) -> [Target; D] {
self.0
}
}
impl<F: Field> CircuitBuilder<F> {
pub fn zero_ext<const D: usize>(&mut self) -> ExtensionTarget<D>
where
F: Extendable<D>,
{
ExtensionTarget([self.zero(); D])
}
pub fn add_extension<const D: usize>(
&mut self,
mut a: ExtensionTarget<D>,
@ -21,6 +34,20 @@ impl<F: Field> CircuitBuilder<F> {
a
}
pub fn sub_extension<const D: usize>(
&mut self,
mut a: ExtensionTarget<D>,
b: ExtensionTarget<D>,
) -> ExtensionTarget<D>
where
F: Extendable<D>,
{
for i in 0..D {
a.0[i] = self.sub(a.0[i], b.0[i]);
}
a
}
pub fn mul_extension<const D: usize>(
&mut self,
a: ExtensionTarget<D>,
@ -43,4 +70,19 @@ impl<F: Field> CircuitBuilder<F> {
}
ExtensionTarget(res)
}
/// Returns a*b where `b` is in the extension field and `a` is in the base field.
pub fn scalar_mul<const D: usize>(
&mut self,
a: Target,
mut b: ExtensionTarget<D>,
) -> ExtensionTarget<D>
where
F: Extendable<D>,
{
for i in 0..D {
b.0[i] = self.mul(a, b.0[i]);
}
b
}
}

View File

@ -1,3 +1,4 @@
pub mod arithmetic;
pub mod hash;
pub mod polynomial;
pub(crate) mod split_join;

35
src/gadgets/polynomial.rs Normal file
View File

@ -0,0 +1,35 @@
use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::Extendable;
use crate::field::field::Field;
use crate::target::Target;
pub struct PolynomialCoeffsTarget<const D: usize>(pub Vec<ExtensionTarget<D>>);
impl<const D: usize> PolynomialCoeffsTarget<D> {
pub fn eval_scalar<F: Field + Extendable<D>>(
&self,
builder: &mut CircuitBuilder<F>,
point: Target,
) -> ExtensionTarget<D> {
let mut acc = builder.zero_ext();
for &c in self.0.iter().rev() {
let tmp = builder.scalar_mul(point, acc);
acc = builder.add_extension(tmp, c);
}
acc
}
pub fn eval<F: Field + Extendable<D>>(
&self,
builder: &mut CircuitBuilder<F>,
point: ExtensionTarget<D>,
) -> ExtensionTarget<D> {
let mut acc = builder.zero_ext();
for &c in self.0.iter().rev() {
let tmp = builder.mul_extension(point, acc);
acc = builder.add_extension(tmp, c);
}
acc
}
}

View File

@ -6,6 +6,7 @@ use crate::circuit_builder::CircuitBuilder;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field;
use crate::field::lagrange::interpolant;
use crate::gadgets::polynomial::PolynomialCoeffsTarget;
use crate::gates::gate::{Gate, GateRef};
use crate::generator::{SimpleGenerator, WitnessGenerator};
use crate::polynomial::polynomial::PolynomialCoeffs;
@ -125,7 +126,34 @@ impl<F: Field + Extendable<D>, const D: usize> Gate<F> for InterpolationGate<F,
builder: &mut CircuitBuilder<F>,
vars: EvaluationTargets,
) -> Vec<Target> {
todo!()
let mut constraints = Vec::with_capacity(self.num_constraints());
let coeffs = (0..self.num_points)
.map(|i| vars.get_local_ext(self.wires_coeff(i)))
.collect();
let interpolant = PolynomialCoeffsTarget(coeffs);
for i in 0..self.num_points {
let point = vars.local_wires[self.wire_point(i)];
let value = vars.get_local_ext(self.wires_value(i));
let computed_value = interpolant.eval_scalar(builder, point);
constraints.extend(
&builder
.sub_extension(value, computed_value)
.to_target_array(),
);
}
let evaluation_point = vars.get_local_ext(self.wires_evaluation_point());
let evaluation_value = vars.get_local_ext(self.wires_evaluation_value());
let computed_evaluation_value = interpolant.eval(builder, evaluation_point);
constraints.extend(
&builder
.sub_extension(evaluation_value, computed_evaluation_value)
.to_target_array(),
);
constraints
}
fn generators(

View File

@ -1,6 +1,7 @@
use std::convert::TryInto;
use std::ops::Range;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::field::Field;
use crate::target::Target;
@ -27,3 +28,11 @@ pub struct EvaluationTargets<'a> {
pub(crate) local_constants: &'a [Target],
pub(crate) local_wires: &'a [Target],
}
impl<'a> EvaluationTargets<'a> {
pub fn get_local_ext<const D: usize>(&self, wire_range: Range<usize>) -> ExtensionTarget<D> {
debug_assert_eq!(wire_range.len(), D);
let arr = self.local_wires[wire_range].try_into().unwrap();
ExtensionTarget(arr)
}
}