diff --git a/src/field/packed_crandall_avx2.rs b/src/field/packed_crandall_avx2.rs index 2d427af6..59315e32 100644 --- a/src/field/packed_crandall_avx2.rs +++ b/src/field/packed_crandall_avx2.rs @@ -5,6 +5,7 @@ use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use crate::field::crandall_field::CrandallField; +use crate::field::field_types::Field; use crate::field::packed_field::PackedField; // PackedCrandallAVX2 wraps an array of four u64s, with the new and get methods to convert that @@ -161,6 +162,11 @@ impl PackedField for PackedCrandallAVX2 { }; (Self::new(res0), Self::new(res1)) } + + #[inline] + fn square(&self) -> Self { + Self::new(unsafe { square(self.get()) }) + } } impl Sub for PackedCrandallAVX2 { @@ -349,6 +355,28 @@ unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) { (res_hi4, res_lo2_s) } +/// Full 64-bit squaring, returned as a (hi, lo_s) pair. 1.2x faster than the scalar instruction. +#[inline] +unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) { + let x_hi = _mm256_srli_epi64(x, 32); + let mul_ll = _mm256_mul_epu32(x, x); + let mul_lh = _mm256_mul_epu32(x, x_hi); + let mul_hh = _mm256_mul_epu32(x_hi, x_hi); + + let res_lo0_s = shift(mul_ll); + let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33)); + + // cmpgt returns -1 on true and 0 on false. Hence, the carry value below is set to -1 on + // overflow and must be subtracted, not added. + let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); + + let res_hi0 = mul_hh; + let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31)); + let res_hi2 = _mm256_sub_epi64(res_hi1, carry); + + (res_hi2, res_lo1_s) +} + /// (u64 << 64) + u64 + u64 -> u128 addition with carry. The third argument is pre-shifted by 2^63. /// The result is also shifted. #[inline] @@ -391,6 +419,12 @@ unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { shift(reduce128s_s(mul64_64_s(x, y))) } +/// Square an integer modulo FIELD_ORDER. 
+#[inline] +unsafe fn square(x: __m256i) -> __m256i { + shift(reduce128s_s(square64_s(x))) +} + #[inline] unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) { let a = _mm256_unpacklo_epi64(x, y); @@ -462,6 +496,18 @@ mod tests { } } + #[test] + fn test_square() { + let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A); + let packed_res = packed_a.square(); + let arr_res = packed_res.to_arr(); + + let expected = TEST_VALS_A.iter().map(|&a| a.square()); + for (exp, res) in expected.zip(arr_res) { + assert_eq!(res, exp); + } + } + #[test] fn test_neg() { let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A); diff --git a/src/field/packed_field.rs b/src/field/packed_field.rs index b41e23a7..15ee5a9e 100644 --- a/src/field/packed_field.rs +++ b/src/field/packed_field.rs @@ -189,6 +189,10 @@ impl PackedField for Singleton { _ => panic!("r cannot be more than LOG2_WIDTH"), } } + + fn square(&self) -> Self { + Self(self.0.square()) + } } impl Sub for Singleton {