diff --git a/src/field/packed_avx2/common.rs b/src/field/packed_avx2/common.rs
index 97674a17..c100e6dc 100644
--- a/src/field/packed_avx2/common.rs
+++ b/src/field/packed_avx2/common.rs
@@ -3,7 +3,21 @@ use core::arch::x86_64::*;
 use crate::field::field_types::PrimeField;
 
 pub trait ReducibleAVX2: PrimeField {
-    unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i;
+    unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i;
+}
+
+const SIGN_BIT: u64 = 1 << 63;
+
+#[inline]
+unsafe fn sign_bit() -> __m256i {
+    _mm256_set1_epi64x(SIGN_BIT as i64)
+}
+
+/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. in
+/// packed_prime_field.rs).
+#[inline]
+pub unsafe fn shift(x: __m256i) -> __m256i {
+    _mm256_xor_si256(x, sign_bit())
 }
 
 #[inline]
diff --git a/src/field/packed_avx2/goldilocks.rs b/src/field/packed_avx2/goldilocks.rs
index 2cea1767..186c8e0c 100644
--- a/src/field/packed_avx2/goldilocks.rs
+++ b/src/field/packed_avx2/goldilocks.rs
@@ -2,19 +2,21 @@ use core::arch::x86_64::*;
 
 use crate::field::goldilocks_field::GoldilocksField;
 use crate::field::packed_avx2::common::{
-    add_no_canonicalize_64_64s_s, epsilon, sub_no_canonicalize_64s_64_s, ReducibleAVX2,
+    add_no_canonicalize_64_64s_s, epsilon, shift, sub_no_canonicalize_64s_64_s, ReducibleAVX2,
 };
 
-/// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is
-/// similarly shifted.
+/// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64). The shift by 2^63 is handled
+/// internally, so the input and result are unshifted.
 impl ReducibleAVX2 for GoldilocksField {
     #[inline]
-    unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i {
-        let (hi0, lo0_s) = x_s;
+    unsafe fn reduce128(x: (__m256i, __m256i)) -> __m256i {
+        let (hi0, lo0) = x;
+        let lo0_s = shift(lo0);
         let hi_hi0 = _mm256_srli_epi64(hi0, 32);
         let lo1_s = sub_no_canonicalize_64s_64_s::<GoldilocksField>(lo0_s, hi_hi0);
         let t1 = _mm256_mul_epu32(hi0, epsilon::<GoldilocksField>());
         let lo2_s = add_no_canonicalize_64_64s_s::<GoldilocksField>(t1, lo1_s);
-        lo2_s
+        let lo2 = shift(lo2_s);
+        lo2
     }
 }
diff --git a/src/field/packed_avx2/packed_prime_field.rs b/src/field/packed_avx2/packed_prime_field.rs
index b892da4a..5800d0bd 100644
--- a/src/field/packed_avx2/packed_prime_field.rs
+++ b/src/field/packed_avx2/packed_prime_field.rs
@@ -6,7 +6,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
 
 use crate::field::field_types::PrimeField;
 use crate::field::packed_avx2::common::{
-    add_no_canonicalize_64_64s_s, epsilon, field_order, ReducibleAVX2,
+    add_no_canonicalize_64_64s_s, epsilon, field_order, shift, ReducibleAVX2,
 };
 use crate::field::packed_field::PackedField;
 
@@ -211,13 +211,6 @@ impl<F: ReducibleAVX2> Sum for PackedPrimeField<F> {
     }
 }
 
-const SIGN_BIT: u64 = 1 << 63;
-
-#[inline]
-unsafe fn sign_bit() -> __m256i {
-    _mm256_set1_epi64x(SIGN_BIT as i64)
-}
-
 // Resources:
 // 1. Intel Intrinsics Guide for explanation of each intrinsic:
 //    https://software.intel.com/sites/landingpage/IntrinsicsGuide/
@@ -267,12 +260,6 @@ unsafe fn sign_bit() -> __m256i {
 // Notice that the above 3-value addition still only requires two calls to shift, just like our
 // 2-value addition.
 
-/// Add 2^63 with overflow. Needed to emulate unsigned comparisons (see point 3. above).
-#[inline]
-unsafe fn shift(x: __m256i) -> __m256i {
-    _mm256_xor_si256(x, sign_bit())
-}
-
 /// Convert to canonical representation.
 /// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field
 /// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63),
@@ -311,67 +298,83 @@ unsafe fn neg<F: PrimeField>(y: __m256i) -> __m256i {
     _mm256_sub_epi64(shift(field_order::<F>()), canonicalize_s::<F>(y_s))
 }
 
-/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the
+/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.33x slower than the
 /// scalar instruction, but may be worth it if we want our data to live in vector registers.
 #[inline]
-unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
-    let x_hi = _mm256_srli_epi64(x, 32);
-    let y_hi = _mm256_srli_epi64(y, 32);
+unsafe fn mul64_64(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
+    // We want to move the high 32 bits to the low position. The multiplication instruction ignores
+    // the high 32 bits, so it's ok to just duplicate it into the low position. This duplication can
+    // be done on port 5; bitshifts run on ports 0 and 1, competing with multiplication.
+    // This instruction is only provided for 32-bit floats, not integers. It's unclear why Intel
+    // makes the distinction; the casts are free and guarantee that the exact bit pattern is
+    // preserved. Using a swizzle instruction of the wrong domain (float vs int) does not increase
+    // latency since Haswell.
+    let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
+    let y_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(y)));
+
+    // All four pairwise multiplications
     let mul_ll = _mm256_mul_epu32(x, y);
     let mul_lh = _mm256_mul_epu32(x, y_hi);
     let mul_hl = _mm256_mul_epu32(x_hi, y);
     let mul_hh = _mm256_mul_epu32(x_hi, y_hi);
 
-    let res_lo0_s = shift(mul_ll);
-    let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32));
-    let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32));
+    // Bignum addition
+    // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
+    let mul_ll_hi = _mm256_srli_epi64::<32>(mul_ll);
+    let t0 = _mm256_add_epi64(mul_hl, mul_ll_hi);
+    // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
+    // Also, extract high 32 bits of t0 and add to mul_hh.
+    let t0_lo = _mm256_and_si256(t0, _mm256_set1_epi64x(u32::MAX.into()));
+    let t0_hi = _mm256_srli_epi64::<32>(t0);
+    let t1 = _mm256_add_epi64(mul_lh, t0_lo);
+    let t2 = _mm256_add_epi64(mul_hh, t0_hi);
+    // Lastly, extract the high 32 bits of t1 and add to t2.
+    let t1_hi = _mm256_srli_epi64::<32>(t1);
+    let res_hi = _mm256_add_epi64(t2, t1_hi);
 
-    // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
-    // overflow and must be subtracted, not added.
-    let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
-    let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s);
+    // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
+    // position).
+    let t1_lo = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(t1)));
+    let res_lo = _mm256_blend_epi32::<0xaa>(mul_ll, t1_lo);
 
-    let res_hi0 = mul_hh;
-    let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32));
-    let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32));
-    let res_hi3 = _mm256_sub_epi64(res_hi2, carry0);
-    let res_hi4 = _mm256_sub_epi64(res_hi3, carry1);
-
-    (res_hi4, res_lo2_s)
+    (res_hi, res_lo)
 }
 
 /// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction.
 #[inline]
-unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) {
-    let x_hi = _mm256_srli_epi64(x, 32);
+unsafe fn square64(x: __m256i) -> (__m256i, __m256i) {
+    // Get high 32 bits of x. See comment in mul64_64.
+    let x_hi = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x)));
+
+    // All pairwise multiplications.
     let mul_ll = _mm256_mul_epu32(x, x);
     let mul_lh = _mm256_mul_epu32(x, x_hi);
     let mul_hh = _mm256_mul_epu32(x_hi, x_hi);
 
-    let res_lo0_s = shift(mul_ll);
-    let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33));
+    // Bignum addition, but mul_lh is shifted by 33 bits (not 32).
+    let mul_ll_hi = _mm256_srli_epi64::<33>(mul_ll);
+    let t0 = _mm256_add_epi64(mul_lh, mul_ll_hi);
+    let t0_hi = _mm256_srli_epi64::<31>(t0);
+    let res_hi = _mm256_add_epi64(mul_hh, t0_hi);
 
-    // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
-    // overflow and must be subtracted, not added.
-    let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
+    // Form low result by adding mul_ll and the low 31 bits of mul_lh (shifted to the high
+    // position).
+    let mul_lh_lo = _mm256_slli_epi64::<33>(mul_lh);
+    let res_lo = _mm256_add_epi64(mul_ll, mul_lh_lo);
 
-    let res_hi0 = mul_hh;
-    let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31));
-    let res_hi2 = _mm256_sub_epi64(res_hi1, carry);
-
-    (res_hi2, res_lo1_s)
+    (res_hi, res_lo)
 }
 
 /// Multiply two integers modulo FIELD_ORDER.
 #[inline]
 unsafe fn mul<F: ReducibleAVX2>(x: __m256i, y: __m256i) -> __m256i {
-    shift(F::reduce128s_s(mul64_64_s(x, y)))
+    F::reduce128(mul64_64(x, y))
 }
 
 /// Square an integer modulo FIELD_ORDER.
 #[inline]
 unsafe fn square<F: ReducibleAVX2>(x: __m256i) -> __m256i {
-    shift(F::reduce128s_s(square64_s(x)))
+    F::reduce128(square64(x))
 }
 
 #[inline]
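For reference, a scalar sketch of the reduction that the new reduce128 performs in each lane, for the Goldilocks prime p = 2^64 - 2^32 + 1. This sketch is not part of the diff and its names are illustrative. Since 2^64 ≡ 2^32 - 1 and 2^96 ≡ -1 (mod p), a 128-bit value hi*2^64 + lo is congruent to lo - hi_hi + hi_lo*(2^32 - 1), which is the same subtract/multiply/add sequence the vector code performs with sub_no_canonicalize_64s_64_s and add_no_canonicalize_64_64s_s.

    const EPSILON: u64 = 0xffff_ffff; // 2^32 - 1, i.e. 2^64 mod p

    // Scalar model of the per-lane reduction. The result is congruent to x mod p but is not
    // necessarily canonical, matching the vector code.
    fn reduce128_scalar(x: u128) -> u64 {
        let lo = x as u64;
        let hi = (x >> 64) as u64;
        let hi_lo = hi & EPSILON; // low 32 bits of hi
        let hi_hi = hi >> 32; // high 32 bits of hi

        // x ≡ lo - hi_hi + hi_lo * EPSILON (mod p).
        let (mut t, borrow) = lo.overflowing_sub(hi_hi);
        if borrow {
            // The wrapped result is 2^64 too large; 2^64 ≡ EPSILON (mod p), so subtract EPSILON.
            t = t.wrapping_sub(EPSILON);
        }
        let t1 = hi_lo * EPSILON; // both factors are below 2^32, so this cannot overflow
        let (mut res, carry) = t.overflowing_add(t1);
        if carry {
            // The wrapped result is 2^64 too small; add the lost 2^64 back as EPSILON.
            res = res.wrapping_add(EPSILON);
        }
        res
    }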
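The "Bignum addition" comments in mul64_64 describe a schoolbook 64x64 -> 128-bit multiply built from 32-bit limbs, since _mm256_mul_epu32 only provides a 32x32 -> 64 product per lane. A scalar sketch of the same carry chain (illustrative, not part of the diff) makes the no-overflow claims easy to check:

    // Scalar model of mul64_64's carry handling, mirroring mul_ll/mul_lh/mul_hl/mul_hh and t0/t1/t2.
    fn mul64_64_scalar(x: u64, y: u64) -> (u64, u64) {
        const LO_32: u64 = u32::MAX as u64;
        let (x_lo, x_hi) = (x & LO_32, x >> 32);
        let (y_lo, y_hi) = (y & LO_32, y >> 32);

        // The four partial products; each is at most (2^32 - 1)^2, so each fits in a u64.
        let mul_ll = x_lo * y_lo;
        let mul_lh = x_lo * y_hi;
        let mul_hl = x_hi * y_lo;
        let mul_hh = x_hi * y_hi;

        // Carry chain; no intermediate sum can exceed u64::MAX.
        let t0 = mul_hl + (mul_ll >> 32);
        let t1 = mul_lh + (t0 & LO_32);
        let t2 = mul_hh + (t0 >> 32);
        let res_hi = t2 + (t1 >> 32);
        // The vector code builds this word with moveldup + blend; here it is a mask and an OR.
        let res_lo = (mul_ll & LO_32) | (t1 << 32);
        (res_hi, res_lo)
    }

As a sanity check, ((res_hi as u128) << 64) | (res_lo as u128) equals (x as u128) * (y as u128) for all inputs.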
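Everything with an "_s" suffix lives in the shifted representation produced by shift(): AVX2 has no unsigned 64-bit comparison, so values are XORed with 2^63 and compared with the signed _mm256_cmpgt_epi64 (point 3 of the comment block in packed_prime_field.rs). A scalar sketch of why the trick is sound (illustrative, not part of the diff):

    // x ^ (1 << 63), reinterpreted as signed, equals x - 2^63, which is monotone in x.
    // Hence a signed comparison of shifted values agrees with an unsigned comparison of the originals.
    fn unsigned_gt_via_signed(x: u64, y: u64) -> bool {
        const SIGN_BIT: u64 = 1 << 63;
        let x_s = (x ^ SIGN_BIT) as i64; // scalar analogue of shift(x)
        let y_s = (y ^ SIGN_BIT) as i64;
        debug_assert_eq!(x_s > y_s, x > y);
        x_s > y_s
    }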