diff --git a/field/src/arch/x86_64/avx2_goldilocks_field.rs b/field/src/arch/x86_64/avx2_goldilocks_field.rs index 61eb26ac..b9336cee 100644 --- a/field/src/arch/x86_64/avx2_goldilocks_field.rs +++ b/field/src/arch/x86_64/avx2_goldilocks_field.rs @@ -7,6 +7,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi use crate::field_types::{Field, PrimeField}; use crate::goldilocks_field::GoldilocksField; +use crate::ops::Square; use crate::packed_field::PackedField; // Ideally `Avx2GoldilocksField` would wrap `__m256i`. Unfortunately, `__m256i` has an alignment of @@ -194,7 +195,9 @@ unsafe impl PackedField for Avx2GoldilocksField { }; (Self::new(res0), Self::new(res1)) } +} +impl Square for Avx2GoldilocksField { #[inline] fn square(&self) -> Self { Self::new(unsafe { square(self.get()) }) @@ -509,6 +512,7 @@ mod tests { use crate::arch::x86_64::avx2_goldilocks_field::Avx2GoldilocksField; use crate::field_types::PrimeField; use crate::goldilocks_field::GoldilocksField; + use crate::ops::Square; use crate::packed_field::PackedField; fn test_vals_a() -> [GoldilocksField; 4] { diff --git a/field/src/extension_field/quadratic.rs b/field/src/extension_field/quadratic.rs index 3b2651cf..e072d323 100644 --- a/field/src/extension_field/quadratic.rs +++ b/field/src/extension_field/quadratic.rs @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize}; use crate::extension_field::{Extendable, FieldExtension, Frobenius, OEF}; use crate::field_types::Field; +use crate::ops::Square; #[derive(Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(bound = "")] @@ -73,19 +74,6 @@ impl> Field for QuadraticExtension { F::characteristic() } - #[inline(always)] - fn square(&self) -> Self { - // Specialising mul reduces the computation of c1 from 2 muls - // and one add to one mul and a shift - - let Self([a0, a1]) = *self; - - let c0 = a0.square() + >::W * a1.square(); - let c1 = a0 * a1.double(); - - Self([c0, c1]) - } - // Algorithm 11.3.4 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. fn try_inverse(&self) -> Option { if self.is_zero() { @@ -204,6 +192,21 @@ impl> MulAssign for QuadraticExtension { } } +impl> Square for QuadraticExtension { + #[inline(always)] + fn square(&self) -> Self { + // Specialising mul reduces the computation of c1 from 2 muls + // and one add to one mul and a shift + + let Self([a0, a1]) = *self; + + let c0 = a0.square() + >::W * a1.square(); + let c1 = a0 * a1.double(); + + Self([c0, c1]) + } +} + impl> Product for QuadraticExtension { fn product>(iter: I) -> Self { iter.fold(Self::ONE, |acc, x| acc * x) diff --git a/field/src/extension_field/quartic.rs b/field/src/extension_field/quartic.rs index b060b778..4e9cebf9 100644 --- a/field/src/extension_field/quartic.rs +++ b/field/src/extension_field/quartic.rs @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize}; use crate::extension_field::{Extendable, FieldExtension, Frobenius, OEF}; use crate::field_types::Field; +use crate::ops::Square; #[derive(Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] #[serde(bound = "")] @@ -75,19 +76,6 @@ impl> Field for QuarticExtension { F::characteristic() } - #[inline(always)] - fn square(&self) -> Self { - let Self([a0, a1, a2, a3]) = *self; - let w = >::W; - - let c0 = a0.square() + w * (a1 * a3.double() + a2.square()); - let c1 = (a0 * a1 + w * a2 * a3).double(); - let c2 = a0 * a2.double() + a1.square() + w * a3.square(); - let c3 = (a0 * a3 + a1 * a2).double(); - - Self([c0, c1, c2, c3]) - } - // Algorithm 11.3.4 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. fn try_inverse(&self) -> Option { if self.is_zero() { @@ -241,6 +229,21 @@ impl> MulAssign for QuarticExtension { } } +impl> Square for QuarticExtension { + #[inline(always)] + fn square(&self) -> Self { + let Self([a0, a1, a2, a3]) = *self; + let w = >::W; + + let c0 = a0.square() + w * (a1 * a3.double() + a2.square()); + let c1 = (a0 * a1 + w * a2 * a3).double(); + let c2 = a0 * a2.double() + a1.square() + w * a3.square(); + let c3 = (a0 * a3 + a1 * a2).double(); + + Self([c0, c1, c2, c3]) + } +} + impl> Product for QuarticExtension { fn product>(iter: I) -> Self { iter.fold(Self::ONE, |acc, x| acc * x) diff --git a/field/src/field_testing.rs b/field/src/field_testing.rs index 672fb554..171d0f12 100644 --- a/field/src/field_testing.rs +++ b/field/src/field_testing.rs @@ -1,6 +1,7 @@ use crate::extension_field::Extendable; use crate::extension_field::Frobenius; use crate::field_types::Field; +use crate::ops::Square; #[macro_export] macro_rules! test_field_arithmetic { diff --git a/field/src/field_types.rs b/field/src/field_types.rs index 4ac6ab75..0d7b314f 100644 --- a/field/src/field_types.rs +++ b/field/src/field_types.rs @@ -11,6 +11,7 @@ use serde::de::DeserializeOwned; use serde::Serialize; use crate::extension_field::Frobenius; +use crate::ops::Square; /// A finite field. pub trait Field: @@ -26,6 +27,7 @@ pub trait Field: + SubAssign + Mul + MulAssign + + Square + Product + Div + DivAssign @@ -80,11 +82,6 @@ pub trait Field: *self + *self } - #[inline] - fn square(&self) -> Self { - *self * *self - } - #[inline] fn cube(&self) -> Self { self.square() * *self diff --git a/field/src/lib.rs b/field/src/lib.rs index b17bdc5e..47dd9ccb 100644 --- a/field/src/lib.rs +++ b/field/src/lib.rs @@ -17,6 +17,7 @@ pub mod field_types; pub mod goldilocks_field; pub mod interpolation; mod inversion; +pub mod ops; pub mod packable; pub mod packed_field; pub mod polynomial; diff --git a/field/src/ops.rs b/field/src/ops.rs new file mode 100644 index 00000000..bf8ff8a9 --- /dev/null +++ b/field/src/ops.rs @@ -0,0 +1,11 @@ +use std::ops::Mul; + +pub trait Square { + fn square(&self) -> Self; +} + +impl + Copy> Square for F { + default fn square(&self) -> Self { + *self * *self + } +} diff --git a/field/src/packed_field.rs b/field/src/packed_field.rs index 813cbeba..4b3336d9 100644 --- a/field/src/packed_field.rs +++ b/field/src/packed_field.rs @@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; use std::slice; use crate::field_types::Field; +use crate::ops::Square; /// # Safety /// - WIDTH is assumed to be a power of 2. @@ -24,6 +25,7 @@ pub unsafe trait PackedField: + Mul + MulAssign + MulAssign + + Square + Neg + Product + Send @@ -44,10 +46,6 @@ where const ZEROS: Self; const ONES: Self; - fn square(&self) -> Self { - *self * *self - } - fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self; fn as_arr(&self) -> [Self::Scalar; Self::WIDTH]; @@ -106,10 +104,6 @@ unsafe impl PackedField for F { const ZEROS: Self = F::ZERO; const ONES: Self = F::ONE; - fn square(&self) -> Self { - ::square(self) - } - fn from_arr(arr: [Self::Scalar; Self::WIDTH]) -> Self { arr[0] } diff --git a/field/src/prime_field_testing.rs b/field/src/prime_field_testing.rs index 42a7bee7..4aec6712 100644 --- a/field/src/prime_field_testing.rs +++ b/field/src/prime_field_testing.rs @@ -71,6 +71,7 @@ macro_rules! test_prime_field_arithmetic { use std::ops::{Add, Mul, Neg, Sub}; use crate::field_types::{Field, PrimeField}; + use crate::ops::Square; #[test] fn arithmetic_addition() { diff --git a/plonky2/src/curve/curve_adds.rs b/plonky2/src/curve/curve_adds.rs index 3638d450..98dbc697 100644 --- a/plonky2/src/curve/curve_adds.rs +++ b/plonky2/src/curve/curve_adds.rs @@ -1,6 +1,7 @@ use std::ops::Add; use plonky2_field::field_types::Field; +use plonky2_field::ops::Square; use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; diff --git a/plonky2/src/curve/curve_summation.rs b/plonky2/src/curve/curve_summation.rs index 67f8023c..7ea01524 100644 --- a/plonky2/src/curve/curve_summation.rs +++ b/plonky2/src/curve/curve_summation.rs @@ -1,6 +1,7 @@ use std::iter::Sum; use plonky2_field::field_types::Field; +use plonky2_field::ops::Square; use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; diff --git a/plonky2/src/curve/curve_types.rs b/plonky2/src/curve/curve_types.rs index c02db73f..b7ee34e6 100644 --- a/plonky2/src/curve/curve_types.rs +++ b/plonky2/src/curve/curve_types.rs @@ -2,6 +2,7 @@ use std::fmt::Debug; use std::ops::Neg; use plonky2_field::field_types::Field; +use plonky2_field::ops::Square; // To avoid implementation conflicts from associated types, // see https://github.com/rust-lang/rust/issues/20400 diff --git a/plonky2/src/gates/exponentiation.rs b/plonky2/src/gates/exponentiation.rs index d9628c4f..51558a21 100644 --- a/plonky2/src/gates/exponentiation.rs +++ b/plonky2/src/gates/exponentiation.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use plonky2_field::extension_field::Extendable; use plonky2_field::field_types::Field; +use plonky2_field::ops::Square; use plonky2_field::packed_field::PackedField; use crate::gates::gate::Gate; @@ -90,7 +91,7 @@ impl, const D: usize> Gate for Exponentiation let prev_intermediate_value = if i == 0 { F::Extension::ONE } else { - ::square(&intermediate_values[i - 1]) + intermediate_values[i - 1].square() }; // power_bits is in LE order, but we accumulate in BE order.