Crandall squaring in AVX2 (#233)

This commit is contained in:
Jakub Nabaglo 2021-09-11 17:47:17 -07:00 committed by GitHub
parent c0e8edb899
commit bdd86a306f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 0 deletions

View File

@ -5,6 +5,7 @@ use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use crate::field::crandall_field::CrandallField;
use crate::field::field_types::Field;
use crate::field::packed_field::PackedField;
// PackedCrandallAVX2 wraps an array of four u64s, with the new and get methods to convert that
@ -161,6 +162,11 @@ impl PackedField for PackedCrandallAVX2 {
};
(Self::new(res0), Self::new(res1))
}
#[inline]
fn square(&self) -> Self {
Self::new(unsafe { square(self.get()) })
}
}
impl Sub<Self> for PackedCrandallAVX2 {
@ -349,6 +355,28 @@ unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
(res_hi4, res_lo2_s)
}
/// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction.
#[inline]
unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) {
let x_hi = _mm256_srli_epi64(x, 32);
let mul_ll = _mm256_mul_epu32(x, x);
let mul_lh = _mm256_mul_epu32(x, x_hi);
let mul_hh = _mm256_mul_epu32(x_hi, x_hi);
let res_lo0_s = shift(mul_ll);
let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33));
// cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
// overflow and must be subtracted, not added.
let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
let res_hi0 = mul_hh;
let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31));
let res_hi2 = _mm256_sub_epi64(res_hi1, carry);
(res_hi2, res_lo1_s)
}
/// (u64 << 64) + u64 + u64 -> u128 addition with carry. The third argument is pre-shifted by 2^63.
/// The result is also shifted.
#[inline]
@ -391,6 +419,12 @@ unsafe fn mul(x: __m256i, y: __m256i) -> __m256i {
shift(reduce128s_s(mul64_64_s(x, y)))
}
/// Square an integer modulo FIELD_ORDER.
#[inline]
unsafe fn square(x: __m256i) -> __m256i {
shift(reduce128s_s(square64_s(x)))
}
#[inline]
unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
let a = _mm256_unpacklo_epi64(x, y);
@ -462,6 +496,18 @@ mod tests {
}
}
#[test]
fn test_square() {
let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A);
let packed_res = packed_a.square();
let arr_res = packed_res.to_arr();
let expected = TEST_VALS_A.iter().map(|&a| a.square());
for (exp, res) in expected.zip(arr_res) {
assert_eq!(res, exp);
}
}
#[test]
fn test_neg() {
let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A);

View File

@ -189,6 +189,10 @@ impl<F: Field> PackedField for Singleton<F> {
_ => panic!("r cannot be more than LOG2_WIDTH"),
}
}
fn square(&self) -> Self {
Self(self.0.square())
}
}
impl<F: Field> Sub<Self> for Singleton<F> {