mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-02-25 08:13:07 +00:00
Crandall squaring in AVX2 (#233)
This commit is contained in:
parent
c0e8edb899
commit
bdd86a306f
@ -5,6 +5,7 @@ use std::iter::{Product, Sum};
|
||||
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
|
||||
|
||||
use crate::field::crandall_field::CrandallField;
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::packed_field::PackedField;
|
||||
|
||||
// PackedCrandallAVX2 wraps an array of four u64s, with the new and get methods to convert that
|
||||
@ -161,6 +162,11 @@ impl PackedField for PackedCrandallAVX2 {
|
||||
};
|
||||
(Self::new(res0), Self::new(res1))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn square(&self) -> Self {
|
||||
Self::new(unsafe { square(self.get()) })
|
||||
}
|
||||
}
|
||||
|
||||
impl Sub<Self> for PackedCrandallAVX2 {
|
||||
@ -349,6 +355,28 @@ unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
|
||||
(res_hi4, res_lo2_s)
|
||||
}
|
||||
|
||||
/// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction.
|
||||
#[inline]
|
||||
unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) {
|
||||
let x_hi = _mm256_srli_epi64(x, 32);
|
||||
let mul_ll = _mm256_mul_epu32(x, x);
|
||||
let mul_lh = _mm256_mul_epu32(x, x_hi);
|
||||
let mul_hh = _mm256_mul_epu32(x_hi, x_hi);
|
||||
|
||||
let res_lo0_s = shift(mul_ll);
|
||||
let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33));
|
||||
|
||||
// cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
|
||||
// overflow and must be subtracted, not added.
|
||||
let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
|
||||
|
||||
let res_hi0 = mul_hh;
|
||||
let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31));
|
||||
let res_hi2 = _mm256_sub_epi64(res_hi1, carry);
|
||||
|
||||
(res_hi2, res_lo1_s)
|
||||
}
|
||||
|
||||
/// (u64 << 64) + u64 + u64 -> u128 addition with carry. The third argument is pre-shifted by 2^63.
|
||||
/// The result is also shifted.
|
||||
#[inline]
|
||||
@ -391,6 +419,12 @@ unsafe fn mul(x: __m256i, y: __m256i) -> __m256i {
|
||||
shift(reduce128s_s(mul64_64_s(x, y)))
|
||||
}
|
||||
|
||||
/// Square an integer modulo FIELD_ORDER.
|
||||
#[inline]
|
||||
unsafe fn square(x: __m256i) -> __m256i {
|
||||
shift(reduce128s_s(square64_s(x)))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
|
||||
let a = _mm256_unpacklo_epi64(x, y);
|
||||
@ -462,6 +496,18 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_square() {
|
||||
let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A);
|
||||
let packed_res = packed_a.square();
|
||||
let arr_res = packed_res.to_arr();
|
||||
|
||||
let expected = TEST_VALS_A.iter().map(|&a| a.square());
|
||||
for (exp, res) in expected.zip(arr_res) {
|
||||
assert_eq!(res, exp);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_neg() {
|
||||
let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A);
|
||||
|
||||
@ -189,6 +189,10 @@ impl<F: Field> PackedField for Singleton<F> {
|
||||
_ => panic!("r cannot be more than LOG2_WIDTH"),
|
||||
}
|
||||
}
|
||||
|
||||
fn square(&self) -> Self {
|
||||
Self(self.0.square())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field> Sub<Self> for Singleton<F> {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user