diff --git a/src/field/extension_field/algebra.rs b/src/field/extension_field/algebra.rs new file mode 100644 index 00000000..db80ce75 --- /dev/null +++ b/src/field/extension_field/algebra.rs @@ -0,0 +1,153 @@ +use crate::field::extension_field::{FieldExtension, OEF}; +use std::fmt::{Debug, Display, Formatter}; +use std::iter::{Product, Sum}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +/// Let `F_D` be the extension `F[X]/(X^D-W)`. Then `ExtensionAlgebra` is the quotient `F_D[X]/(X^D-W)`. +/// It's a `D`-dimensional algebra over `F_D` useful to lift the multiplication over `F_D` to a multiplication over `(F_D)^D`. +#[derive(Copy, Clone)] +pub struct ExtensionAlgebra, const D: usize>([F; D]); + +impl, const D: usize> ExtensionAlgebra { + pub const ZERO: Self = Self([F::ZERO; D]); + + pub fn one() -> Self { + F::ONE.into() + } + + pub fn from_basefield_array(arr: [F; D]) -> Self { + Self(arr) + } + + pub fn to_basefield_array(self) -> [F; D] { + self.0 + } +} + +impl, const D: usize> From for ExtensionAlgebra { + fn from(x: F) -> Self { + let mut arr = [F::ZERO; D]; + arr[0] = x; + Self(arr) + } +} + +impl, const D: usize> Display for ExtensionAlgebra { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) + ", self.0[0])?; + for i in 1..D - 1 { + write!(f, "({})*b^{} + ", self.0[i], i)?; + } + write!(f, "{}*b^{}", self.0[D - 1], D - 1) + } +} + +impl, const D: usize> Debug for ExtensionAlgebra { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } +} + +impl, const D: usize> Neg for ExtensionAlgebra { + type Output = Self; + + #[inline] + fn neg(self) -> Self { + let mut arr = self.0; + arr.iter_mut().for_each(|x| *x = -*x); + Self(arr) + } +} + +impl, const D: usize> Add for ExtensionAlgebra { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self { + let mut arr = self.0; + arr.iter_mut().zip(&rhs.0).for_each(|(x, &y)| *x += y); + Self(arr) + } +} + +impl, const D: usize> AddAssign for ExtensionAlgebra { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl, const D: usize> Sum for ExtensionAlgebra { + fn sum>(iter: I) -> Self { + iter.fold(Self::ZERO, |acc, x| acc + x) + } +} + +impl, const D: usize> Sub for ExtensionAlgebra { + type Output = Self; + + #[inline] + fn sub(self, rhs: Self) -> Self { + let mut arr = self.0; + arr.iter_mut().zip(&rhs.0).for_each(|(x, &y)| *x -= y); + Self(arr) + } +} + +impl, const D: usize> SubAssign for ExtensionAlgebra { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl, const D: usize> Mul for ExtensionAlgebra { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self { + let mut res = [F::ZERO; D]; + let w = F::from_basefield(F::W); + for i in 0..D { + for j in 0..D { + res[(i + j) % D] += if i + j < D { + self.0[i] * rhs.0[j] + } else { + w * self.0[i] * rhs.0[i] + } + } + } + Self(res) + } +} + +impl, const D: usize> MulAssign for ExtensionAlgebra { + #[inline] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl, const D: usize> Product for ExtensionAlgebra { + fn product>(iter: I) -> Self { + iter.fold(Self::one(), |acc, x| acc * x) + } +} + +/// A polynomial in coefficient form. +#[derive(Clone, Debug)] +pub struct PolynomialCoeffsAlgebra, const D: usize> { + pub(crate) coeffs: Vec>, +} + +impl, const D: usize> PolynomialCoeffsAlgebra { + pub fn new(coeffs: Vec>) -> Self { + PolynomialCoeffsAlgebra { coeffs } + } + + pub fn eval(&self, x: ExtensionAlgebra) -> ExtensionAlgebra { + self.coeffs + .iter() + .rev() + .fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c) + } +} diff --git a/src/field/extension_field/mod.rs b/src/field/extension_field/mod.rs index 7e74be78..c5510ea1 100644 --- a/src/field/extension_field/mod.rs +++ b/src/field/extension_field/mod.rs @@ -1,5 +1,6 @@ use crate::field::field::Field; +pub mod algebra; pub mod quadratic; pub mod quartic; mod quartic_quartic; diff --git a/src/gates/interpolation.rs b/src/gates/interpolation.rs index 53e0b7e7..3872b9ee 100644 --- a/src/gates/interpolation.rs +++ b/src/gates/interpolation.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use std::ops::Range; use crate::circuit_builder::CircuitBuilder; +use crate::field::extension_field::algebra::PolynomialCoeffsAlgebra; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::lagrange::interpolant; @@ -111,19 +112,19 @@ where let mut constraints = Vec::with_capacity(self.num_constraints()); let coeffs = (0..self.num_points) - .map(|i| vars.get_local_ext_ext(self.wires_coeff(i))) + .map(|i| vars.get_local_ext_algebra(self.wires_coeff(i))) .collect(); - let interpolant = PolynomialCoeffs::new(coeffs); + let interpolant = PolynomialCoeffsAlgebra::new(coeffs); for i in 0..self.num_points { let point = vars.local_wires[self.wire_point(i)]; - let value = vars.get_local_ext_ext(self.wires_value(i)); + let value = vars.get_local_ext_algebra(self.wires_value(i)); let computed_value = interpolant.eval(point.into()); constraints.extend(&(value - computed_value).to_basefield_array()); } - let evaluation_point = vars.get_local_ext_ext(self.wires_evaluation_point()); - let evaluation_value = vars.get_local_ext_ext(self.wires_evaluation_value()); + let evaluation_point = vars.get_local_ext_algebra(self.wires_evaluation_point()); + let evaluation_value = vars.get_local_ext_algebra(self.wires_evaluation_value()); let computed_evaluation_value = interpolant.eval(evaluation_point); constraints.extend(&(evaluation_value - computed_evaluation_value).to_basefield_array()); diff --git a/src/vars.rs b/src/vars.rs index 88d0759a..1ac7935c 100644 --- a/src/vars.rs +++ b/src/vars.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use std::ops::Range; +use crate::field::extension_field::algebra::ExtensionAlgebra; use crate::field::extension_field::target::{ExtensionExtensionTarget, ExtensionTarget}; use crate::field::extension_field::{Extendable, FieldExtension}; use crate::field::field::Field; @@ -18,16 +19,13 @@ pub struct EvaluationVarsBase<'a, F: Field> { } impl<'a, F: Extendable, const D: usize> EvaluationVars<'a, F, D> { - pub fn get_local_ext_ext( + pub fn get_local_ext_algebra( &self, wire_range: Range, - ) -> <>::Extension as Extendable>::Extension - where - F::Extension: Extendable, - { + ) -> ExtensionAlgebra { debug_assert_eq!(wire_range.len(), D); let arr = self.local_wires[wire_range].try_into().unwrap(); - <>::Extension as Extendable>::Extension::from_basefield_array(arr) + ExtensionAlgebra::from_basefield_array(arr) } }