Mirror of https://github.com/logos-storage/plonky2.git, synced 2026-01-06 15:53:10 +00:00
Minor optimizations to AVX2 multiplication (#378)
* Minor optimizations to AVX2 multiplication
* Typos (thx Hamish!)
This commit is contained in:
parent 5eaa1ad529
commit aff71943c3
@@ -3,7 +3,21 @@ use core::arch::x86_64::*;
 use crate::field::field_types::PrimeField;
 
 pub trait ReducibleAVX2: PrimeField {
-    unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i;
+    unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i;
 }
 
+const SIGN_BIT: u64 = 1 << 63;
+
+#[inline]
+unsafe fn sign_bit() -> __m256i {
+    _mm256_set1_epi64x(SIGN_BIT as i64)
+}
+
+/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. in
+/// packed_prime_field.rs).
+#[inline]
+pub unsafe fn shift(x: __m256i) -> __m256i {
+    _mm256_xor_si256(x, sign_bit())
+}
+
 #[inline]
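AVX2 only offers signed 64-bit comparisons (_mm256_cmpgt_epi64), so the shift helper above XORs each lane with 2^63 to map unsigned order onto signed order. A minimal scalar sketch of that equivalence (illustration only, not part of this diff; shift_scalar is a name of this sketch, not of the codebase):

// Scalar illustration of the shift trick: XOR-ing both operands with 1 << 63
// turns an unsigned comparison into a signed one, which is all AVX2 provides
// for 64-bit lanes.
fn shift_scalar(x: u64) -> i64 {
    (x ^ (1 << 63)) as i64
}

fn main() {
    let pairs = [(0u64, u64::MAX), (5, 3), (1 << 63, (1 << 63) - 1)];
    for (a, b) in pairs {
        // The signed comparison of shifted values agrees with the unsigned one.
        assert_eq!(a < b, shift_scalar(a) < shift_scalar(b));
    }
    println!("shifted signed comparison matches unsigned comparison");
}
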
@@ -2,19 +2,21 @@ use core::arch::x86_64::*;
 
 use crate::field::goldilocks_field::GoldilocksField;
 use crate::field::packed_avx2::common::{
-    add_no_canonicalize_64_64s_s, epsilon, sub_no_canonicalize_64s_64_s, ReducibleAVX2,
+    add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAVX2,
};
 
 /// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is
 /// similarly shifted.
 impl ReducibleAVX2 for GoldilocksField {
     #[inline]
-    unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i {
-        let (hi0, lo0_s) = x_s;
+    unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i {
+        let (hi0, lo0) = x;
+        let lo0_s = shift(lo0);
         let hi_hi0 = _mm256_srli_epi64(hi0, 32);
         let lo1_s = sub_no_canonicalize_64s_64_s::<GoldilocksField>(lo0_s, hi_hi0);
         let t1 = _mm256_mul_epu32(hi0, epsilon::<GoldilocksField>());
         let lo2_s = add_no_canonicalize_64_64s_s::<GoldilocksField>(t1, lo1_s);
-        lo2_s
+        let lo2 = shift(lo2_s);
+        lo2
     }
 }
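For the Goldilocks prime P = 2^64 - 2^32 + 1 we have 2^64 = EPSILON = 2^32 - 1 (mod P) and 2^96 = -1 (mod P), so hi*2^64 + lo = lo - hi_hi + hi_lo*EPSILON (mod P), which is the sequence of steps above. A scalar sketch of the same reduction (illustration only, not part of this diff; reduce128_scalar and the variable names are mine, chosen to mirror the vector code):

// Scalar sketch of the reduction performed by reduce128 for the Goldilocks
// field P = 2^64 - 2^32 + 1. The result fits in u64 but is not canonicalized.
const P: u64 = 0xFFFF_FFFF_0000_0001; // 2^64 - 2^32 + 1
const EPSILON: u64 = 0xFFFF_FFFF; // 2^32 - 1, i.e. 2^64 mod P

fn reduce128_scalar(x: u128) -> u64 {
    let lo = x as u64;
    let hi = (x >> 64) as u64;
    let hi_hi = hi >> 32;
    let hi_lo = hi & EPSILON;

    // lo - hi_hi: the hi_hi * 2^96 term contributes -hi_hi since 2^96 = -1 (mod P).
    let (mut t0, borrow) = lo.overflowing_sub(hi_hi);
    if borrow {
        // The wrapped value is 2^64 too large; 2^64 = EPSILON (mod P).
        t0 -= EPSILON;
    }
    // hi_lo * 2^64 = hi_lo * EPSILON (mod P); the product fits in 64 bits.
    let t1 = hi_lo * EPSILON;
    let (mut t2, carry) = t0.overflowing_add(t1);
    if carry {
        // The wrapped value is 2^64 too small; 2^64 = EPSILON (mod P).
        t2 += EPSILON;
    }
    t2
}

fn main() {
    let (a, b) = (0x1234_5678_9abc_def0_u64, 0x0fed_cba9_8765_4321_u64);
    let prod = (a as u128) * (b as u128);
    assert_eq!(reduce128_scalar(prod) % P, (prod % (P as u128)) as u64);
    println!("scalar reduction matches u128 remainder");
}
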
@@ -6,7 +6,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
 
 use crate::field::field_types::PrimeField;
 use crate::field::packed_avx2::common::{
-    add_no_canonicalize_64_64s_s, epsilon, field_order, ReducibleAVX2,
+    add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAVX2,
};
 use crate::field::packed_field::PackedField;
 
@@ -211,13 +211,6 @@ impl<F: ReducibleAVX2> Sum for PackedPrimeField<F> {
     }
 }
 
-const SIGN_BIT: u64 = 1 << 63;
-
-#[inline]
-unsafe fn sign_bit() -> __m256i {
-    _mm256_set1_epi64x(SIGN_BIT as i64)
-}
-
 // Resources:
 // 1. Intel Intrinsics Guide for explanation of each intrinsic:
 //    https://software.intel.com/sites/landingpage/IntrinsicsGuide/
@@ -267,12 +260,6 @@ unsafe fn sign_bit() -> __m256i {
 // Notice that the above 3-value addition still only requires two calls to shift, just like our
 // 2-value addition.
 
-/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. above).
-#[inline]
-unsafe fn shift(x: __m256i) -> __m256i {
-    _mm256_xor_si256(x, sign_bit())
-}
-
 /// Convert to canonical representation.
 /// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field
 /// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63),
@@ -311,67 +298,83 @@ unsafe fn neg<F: PrimeField>(y: __m256i) -> __m256i {
     _mm256_sub_epi64(shift(field_order::<F>()), canonicalize_s::<F>(y_s))
 }
 
-/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the
+/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.33x slower than the
 /// scalar instruction, but may be worth it if we want our data to live in vector registers.
 #[inline]
-unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
-    let x_hi = _mm256_srli_epi64(x, 32);
-    let y_hi = _mm256_srli_epi64(y, 32);
+unsafe fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
+    // We want to move the high 32 bits to the low position. The multiplication instruction ignores
+    // the high 32 bits, so it's ok to just duplicate it into the low position. This duplication can
+    // be done on port 5; bitshifts run on ports 0 and 1, competing with multiplication.
+    // This instruction is only provided for 32-bit floats, not integers. Idk why Intel makes the
+    // distinction; the casts are free and it guarantees that the exact bit pattern is preserved.
+    // Using a swizzle instruction of the wrong domain (float vs int) does not increase latency
+    // since Haswell.
+    let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
+    let y_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y)));
 
     // All four pairwise multiplications
     let mul_ll = _mm256_mul_epu32(x, y);
     let mul_lh = _mm256_mul_epu32(x, y_hi);
     let mul_hl = _mm256_mul_epu32(x_hi, y);
     let mul_hh = _mm256_mul_epu32(x_hi, y_hi);
 
-    let res_lo0_s = shift(mul_ll);
-    let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32));
-    let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32));
+    // Bignum addition
+    // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
+    let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll);
+    let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi);
+    // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
+    // Also, extract high 32 bits of t0 and add to mul_hh.
+    let t0_lo = _mm256_and_si256(t0, _mm256_set1_epi64x(u32::MAX.into()));
+    let t0_hi = _mm256_srli_epi64::<32>(t0);
+    let t1 = _mm256_add_epi64(mul_lh, t0_lo);
+    let t2 = _mm256_add_epi64(mul_hh, t0_hi);
+    // Lastly, extract the high 32 bits of t1 and add to t2.
+    let t1_hi = _mm256_srli_epi64::<32>(t1);
+    let res_hi = _mm256_add_epi64(t2, t1_hi);
 
-    // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
-    // overflow and must be subtracted, not added.
-    let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
-    let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s);
+    // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
+    // position).
+    let t1_lo = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(t1)));
+    let res_lo = _mm256_blend_epi32::<0xaa>(mul_ll, t1_lo);
 
-    let res_hi0 = mul_hh;
-    let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32));
-    let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32));
-    let res_hi3 = _mm256_sub_epi64(res_hi2, carry0);
-    let res_hi4 = _mm256_sub_epi64(res_hi3, carry1);
-
-    (res_hi4, res_lo2_s)
+    (res_hi, res_lo)
 }
 
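The new mul64_64 replaces the shifted comparisons and carry corrections with a carry-free accumulation order: every 32x32 product fits in 64 bits, and folding the high half of mul_ll into mul_hl, then splitting t0 between mul_lh and mul_hh, never overflows a 64-bit lane. A scalar model of that same order (illustration only, not part of this diff; mul64_64_scalar is a name of this sketch):

// Scalar model of the bignum-addition order used by the new mul64_64: four
// 32x32 -> 64 products are combined without any carry detection because each
// intermediate sum stays below 2^64.
fn mul64_64_scalar(x: u64, y: u64) -> (u64, u64) {
    let (x_lo, x_hi) = (x & 0xFFFF_FFFF, x >> 32);
    let (y_lo, y_hi) = (y & 0xFFFF_FFFF, y >> 32);

    let mul_ll = x_lo * y_lo;
    let mul_lh = x_lo * y_hi;
    let mul_hl = x_hi * y_lo;
    let mul_hh = x_hi * y_hi;

    // High 32 bits of mul_ll into mul_hl: at most (2^32 - 1)^2 + (2^32 - 1) < 2^64.
    let t0 = mul_hl + (mul_ll >> 32);
    // Low 32 bits of t0 into mul_lh, high 32 bits of t0 into mul_hh.
    let t1 = mul_lh + (t0 & 0xFFFF_FFFF);
    let t2 = mul_hh + (t0 >> 32);
    // The low half's final carry propagates via the high 32 bits of t1.
    let res_hi = t2 + (t1 >> 32);
    // Low result: low 32 bits of mul_ll, with the low 32 bits of t1 on top.
    let res_lo = (mul_ll & 0xFFFF_FFFF) | (t1 << 32);
    (res_hi, res_lo)
}

fn main() {
    let (x, y) = (0xDEAD_BEEF_1234_5678_u64, 0xFEDC_BA98_7654_3210_u64);
    let (hi, lo) = mul64_64_scalar(x, y);
    assert_eq!(((hi as u128) << 64) | (lo as u128), (x as u128) * (y as u128));
    println!("scalar bignum addition matches u128 multiplication");
}
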
 /// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction.
 #[inline]
-unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) {
-    let x_hi = _mm256_srli_epi64(x, 32);
+unsafe fn square64(x: __m256i) -> (__m256i, __m256i) {
+    // Get high 32 bits of x. See comment in mul64_64_s.
+    let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
 
     // All pairwise multiplications.
     let mul_ll = _mm256_mul_epu32(x, x);
     let mul_lh = _mm256_mul_epu32(x, x_hi);
     let mul_hh = _mm256_mul_epu32(x_hi, x_hi);
 
-    let res_lo0_s = shift(mul_ll);
-    let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33));
+    // Bignum addition, but mul_lh is shifted by 33 bits (not 32).
+    let mul_ll_hi = _mm256_srli_epi64::<33>(mul_ll);
+    let t0 = _mm256_add_epi64(mul_lh, mul_ll_hi);
+    let t0_hi = _mm256_srli_epi64::<31>(t0);
+    let res_hi = _mm256_add_epi64(mul_hh, t0_hi);
 
-    // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
-    // overflow and must be subtracted, not added.
-    let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
+    // Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
+    // position).
+    let mul_lh_lo = _mm256_slli_epi64::<33>(mul_lh);
+    let res_lo = _mm256_add_epi64(mul_ll, mul_lh_lo);
 
-    let res_hi0 = mul_hh;
-    let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31));
-    let res_hi2 = _mm256_sub_epi64(res_hi1, carry);
-
-    (res_hi2, res_lo1_s)
+    (res_hi, res_lo)
 }
 
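In the squaring path the cross product mul_lh = x_lo*x_hi occurs twice in x^2, so it carries weight 2^33 rather than 2^32; that is why square64 shifts by 33 (and by 31 for the high half) where mul64_64 shifts by 32. A scalar model of the same accumulation (illustration only, not part of this diff; square64_scalar is a name of this sketch):

// Scalar model of square64's accumulation: x^2 = mul_hh * 2^64 + mul_lh * 2^33 + mul_ll.
fn square64_scalar(x: u64) -> (u64, u64) {
    let (x_lo, x_hi) = (x & 0xFFFF_FFFF, x >> 32);

    let mul_ll = x_lo * x_lo;
    let mul_lh = x_lo * x_hi;
    let mul_hh = x_hi * x_hi;

    // High half of mul_ll + mul_lh * 2^33, obtained via shifts by 33 and 31.
    let t0 = mul_lh + (mul_ll >> 33);
    let res_hi = mul_hh + (t0 >> 31);
    // Low 64 bits; the wrapping add models _mm256_add_epi64.
    let res_lo = mul_ll.wrapping_add(mul_lh << 33);
    (res_hi, res_lo)
}

fn main() {
    let x = 0xDEAD_BEEF_CAFE_F00D_u64;
    let (hi, lo) = square64_scalar(x);
    assert_eq!(((hi as u128) << 64) | (lo as u128), (x as u128) * (x as u128));
    println!("scalar squaring matches u128 multiplication");
}
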
 /// Multiply two integers modulo FIELD_ORDER.
 #[inline]
 unsafe fn mul<F: ReducibleAVX2>(x: __m256i, y: __m256i) -> __m256i {
-    shift(F::reduce128s_s(mul64_64_s(x, y)))
+    F::reduce128(mul64_64(x, y))
 }
 
 /// Square an integer modulo FIELD_ORDER.
 #[inline]
 unsafe fn square<F: ReducibleAVX2>(x: __m256i) -> __m256i {
-    shift(F::reduce128s_s(square64_s(x)))
+    F::reduce128(square64(x))
 }
 
 #[inline]
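With these changes, mul and square compose a plain 128-bit product with reduce128; the 2^63-shifted representation no longer appears on the multiplication path. A scalar stand-in for what the wrappers compute (illustration only, not part of this diff; mul_mod and square_mod are names of this sketch):

// Scalar model of the new modular multiplication path for the Goldilocks field.
const P: u128 = 0xFFFF_FFFF_0000_0001; // Goldilocks field order

fn mul_mod(x: u64, y: u64) -> u64 {
    // Stand-in for F::reduce128(mul64_64(x, y)).
    (((x as u128) * (y as u128)) % P) as u64
}

fn square_mod(x: u64) -> u64 {
    // Stand-in for F::reduce128(square64(x)).
    mul_mod(x, x)
}

fn main() {
    let x = 0x0123_4567_89ab_cdef_u64;
    assert_eq!(square_mod(x), mul_mod(x, x));
    println!("x^2 mod P = {:#x}", square_mod(x));
}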