diff --git a/src/field/extension_field/algebra.rs b/src/field/extension_field/algebra.rs index 004e8dc1..97c3c12d 100644 --- a/src/field/extension_field/algebra.rs +++ b/src/field/extension_field/algebra.rs @@ -151,3 +151,101 @@ impl, const D: usize> PolynomialCoeffsAlgebra { .fold(ExtensionAlgebra::ZERO, |acc, &c| acc * x + c) } } + +#[cfg(test)] +mod tests { + use crate::field::crandall_field::CrandallField; + use crate::field::extension_field::algebra::ExtensionAlgebra; + use crate::field::extension_field::{Extendable, FieldExtension}; + use crate::field::field::Field; + use itertools::Itertools; + + /// Tests that the multiplication on the extension algebra lifts that of the field extension. + fn test_extension_algebra, const D: usize>() { + #[derive(Copy, Clone, Debug)] + enum ZeroOne { + Zero, + One, + } + + let to_field = |zo: &ZeroOne| match zo { + ZeroOne::Zero => F::ZERO, + ZeroOne::One => F::ONE, + }; + let to_fields = |x: &[ZeroOne], y: &[ZeroOne]| -> (F::Extension, F::Extension) { + let mut arr0 = [F::ZERO; D]; + let mut arr1 = [F::ZERO; D]; + arr0.copy_from_slice(&x.iter().map(to_field).collect::>()); + arr1.copy_from_slice(&y.iter().map(to_field).collect::>()); + ( + >::Extension::from_basefield_array(arr0), + >::Extension::from_basefield_array(arr1), + ) + }; + + // Standard MLE formula. + let selector = |xs: Vec, ts: &[F::Extension]| -> F::Extension { + (0..2 * D) + .map(|i| match xs[i] { + ZeroOne::Zero => F::Extension::ONE - ts[i], + ZeroOne::One => ts[i], + }) + .product() + }; + + let mul_mle = |ts: Vec| -> [F::Extension; D] { + let mut ans = [F::Extension::ZERO; D]; + for xs in (0..2 * D) + .map(|_| vec![ZeroOne::Zero, ZeroOne::One]) + .multi_cartesian_product() + { + let (a, b) = to_fields(&xs[..D], &xs[D..]); + let c = a * b; + let res = selector(xs, &ts); + for i in 0..D { + ans[i] += res * c.to_basefield_array()[i].into(); + } + } + ans + }; + + let ts = F::Extension::rand_vec(2 * D); + let mut arr0 = [F::Extension::ZERO; D]; + let mut arr1 = [F::Extension::ZERO; D]; + arr0.copy_from_slice(&ts[..D]); + arr1.copy_from_slice(&ts[D..]); + let x = ExtensionAlgebra::from_basefield_array(arr0); + let y = ExtensionAlgebra::from_basefield_array(arr1); + let z = x * y; + + dbg!(z.0, mul_mle(ts.clone())); + assert_eq!(z.0, mul_mle(ts)); + } + + mod base { + use super::*; + + #[test] + fn test_algebra() { + test_extension_algebra::(); + } + } + + mod quadratic { + use super::*; + + #[test] + fn test_algebra() { + test_extension_algebra::(); + } + } + + mod quartic { + use super::*; + + #[test] + fn test_algebra() { + test_extension_algebra::(); + } + } +}