From 4e532f04faf665a7dbc2c14965dd950856ec979a Mon Sep 17 00:00:00 2001
From: Jakub Nabaglo
Date: Thu, 6 Jan 2022 15:50:56 -0800
Subject: [PATCH] AVX2 Poseidon S-box optimizations (#421)

---
 .../x86_64/poseidon_goldilocks_avx2_bmi2.rs | 169 +++++++++++++-----
 1 file changed, 127 insertions(+), 42 deletions(-)

diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
index 934583d6..804524ee 100644
--- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
+++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
@@ -4,6 +4,7 @@ use std::mem::size_of;
 
 use plonky2_field::field_types::Field;
 use plonky2_field::goldilocks_field::GoldilocksField;
+use plonky2_util::branch_hint;
 use static_assertions::const_assert;
 
 use crate::hash::poseidon::{
@@ -141,6 +142,16 @@ macro_rules! map3 {
     ($f:ident::<$l:literal>, $v:ident) => {
         ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
     };
+    ($f:ident::<$l:literal>, $v1:ident, $v2:ident) => {
+        (
+            $f::<$l>($v1.0, $v2.0),
+            $f::<$l>($v1.1, $v2.1),
+            $f::<$l>($v1.2, $v2.2),
+        )
+    };
+    ($f:ident, $v:ident) => {
+        ($f($v.0), $f($v.1), $f($v.2))
+    };
     ($f:ident, $v0:ident, $v1:ident) => {
         ($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
     };
@@ -188,19 +199,32 @@ unsafe fn const_layer(
 unsafe fn square3(
     x: (__m256i, __m256i, __m256i),
 ) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
-    let sign_bit = _mm256_set1_epi64x(i64::MIN);
-    let x_hi = map3!(_mm256_srli_epi64::<32>, x);
+    let x_hi = {
+        // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
+        // bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
+        // This is safe and free.
+        let x_ps = map3!(_mm256_castsi256_ps, x);
+        let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
+        map3!(_mm256_castps_si256, x_hi_ps)
+    };
+
+    // All pairwise multiplications.
     let mul_ll = map3!(_mm256_mul_epu32, x, x);
    let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
    let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
-    let res_lo0_s = map3!(_mm256_xor_si256, mul_ll, rep sign_bit);
+
+    // Bignum addition, but mul_lh is shifted by 33 bits (not 32), since the cross term x_lo * x_hi counts twice.
+    let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
+    let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
+    let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
+    let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);
+
+    // Form the low result by adding mul_ll and the low 31 bits of mul_lh (shifted to the high
+    // position).
     let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
-    let res_lo1_s = map3!(_mm256_add_epi64, res_lo0_s, mul_lh_lo);
-    let carry = map3!(_mm256_cmpgt_epi64, res_lo0_s, res_lo1_s);
-    let mul_lh_hi = map3!(_mm256_srli_epi64::<31>, mul_lh);
-    let res_hi0 = map3!(_mm256_add_epi64, mul_hh, mul_lh_hi);
-    let res_hi1 = map3!(_mm256_sub_epi64, res_hi0, carry);
-    (res_lo1_s, res_hi1)
+    let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);
+
+    (res_lo, res_hi)
 }
 
 #[inline(always)]
@@ -208,49 +232,110 @@ unsafe fn mul3(
     x: (__m256i, __m256i, __m256i),
     y: (__m256i, __m256i, __m256i),
 ) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
-    let sign_bit = _mm256_set1_epi64x(i64::MIN);
-    let y_hi = map3!(_mm256_srli_epi64::<32>, y);
-    let x_hi = map3!(_mm256_srli_epi64::<32>, x);
+    let epsilon = _mm256_set1_epi64x(0xffffffff);
+    let x_hi = {
+        // Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
+        // bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
+        // This is safe and free.
+        let x_ps = map3!(_mm256_castsi256_ps, x);
+        let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
+        map3!(_mm256_castps_si256, x_hi_ps)
+    };
+    let y_hi = {
+        let y_ps = map3!(_mm256_castsi256_ps, y);
+        let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
+        map3!(_mm256_castps_si256, y_hi_ps)
+    };
+
+    // All four pairwise multiplications.
     let mul_ll = map3!(_mm256_mul_epu32, x, y);
     let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
     let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
     let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);
-    let mul_lh_lo = map3!(_mm256_slli_epi64::<32>, mul_lh);
-    let res_lo0_s = map3!(_mm256_xor_si256, mul_ll, rep sign_bit);
-    let mul_hl_lo = map3!(_mm256_slli_epi64::<32>, mul_hl);
-    let res_lo1_s = map3!(_mm256_add_epi64, res_lo0_s, mul_lh_lo);
-    let carry0 = map3!(_mm256_cmpgt_epi64, res_lo0_s, res_lo1_s);
-    let mul_lh_hi = map3!(_mm256_srli_epi64::<32>, mul_lh);
-    let res_lo2_s = map3!(_mm256_add_epi64, res_lo1_s, mul_hl_lo);
-    let carry1 = map3!(_mm256_cmpgt_epi64, res_lo1_s, res_lo2_s);
-    let mul_hl_hi = map3!(_mm256_srli_epi64::<32>, mul_hl);
-    let res_hi0 = map3!(_mm256_add_epi64, mul_hh, mul_lh_hi);
-    let res_hi1 = map3!(_mm256_add_epi64, res_hi0, mul_hl_hi);
-    let res_hi2 = map3!(_mm256_sub_epi64, res_hi1, carry0);
-    let res_hi3 = map3!(_mm256_sub_epi64, res_hi2, carry1);
-    (res_lo2_s, res_hi3)
+
+    // Bignum addition
+    // Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
+    let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
+    let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
+    // Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
+    // Also, extract high 32 bits of t0 and add to mul_hh.
+    let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
+    let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
+    let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
+    let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
+    // Lastly, extract the high 32 bits of t1 and add to t2.
+    let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
+    let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);
+
+    // Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
+    // position).
+    let t1_lo = {
+        let t1_ps = map3!(_mm256_castsi256_ps, t1);
+        let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
+        map3!(_mm256_castps_si256, t1_lo_ps)
+    };
+    let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);
+
+    (res_lo, res_hi)
+}
+
+/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
+#[inline(always)]
+unsafe fn add_small(
+    x_s: (__m256i, __m256i, __m256i),
+    y: (__m256i, __m256i, __m256i),
+) -> (__m256i, __m256i, __m256i) {
+    let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
+    let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
+    let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
+    let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
+    res_s
+}
+
+#[inline(always)]
+unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
+    // The subtraction is very unlikely to underflow so we're best off branching.
+    // The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
+    // branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
+    // floating-point (this is free).
+    let mask_pd = _mm256_castsi256_pd(mask);
+    // `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
+    // did not occur for any of the vector elements.
+    if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
+        res_wrapped_s
+    } else {
+        branch_hint();
+        // Highly unlikely: underflow did occur. Find adjustment per element and apply it.
+        let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
+        _mm256_sub_epi64(res_wrapped_s, adj_amount)
+    }
+}
+
+/// Subtraction, where the second operand is much smaller than `0xffffffff00000001`.
+#[inline(always)]
+unsafe fn sub_tiny(
+    x_s: (__m256i, __m256i, __m256i),
+    y: (__m256i, __m256i, __m256i),
+) -> (__m256i, __m256i, __m256i) {
+    let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
+    let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
+    let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
+    res_s
 }
 
 #[inline(always)]
 unsafe fn reduce3(
-    (x_lo_s, x_hi): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
+    (lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
 ) -> (__m256i, __m256i, __m256i) {
-    let epsilon = _mm256_set1_epi64x(0xffffffff);
     let sign_bit = _mm256_set1_epi64x(i64::MIN);
-    let x_hi_hi = map3!(_mm256_srli_epi64::<32>, x_hi);
-    let res0_s = map3!(_mm256_sub_epi64, x_lo_s, x_hi_hi);
-    let wraparound_mask0 = map3!(_mm256_cmpgt_epi32, res0_s, x_lo_s);
-    let wraparound_adj0 = map3!(_mm256_srli_epi64::<32>, wraparound_mask0);
-    let x_hi_lo = map3!(_mm256_and_si256, x_hi, rep epsilon);
-    let x_hi_lo_shifted = map3!(_mm256_slli_epi64::<32>, x_hi);
-    let res1_s = map3!(_mm256_sub_epi64, res0_s, wraparound_adj0);
-    let x_hi_lo_mul_epsilon = map3!(_mm256_sub_epi64, x_hi_lo_shifted, x_hi_lo);
-    let res2_s = map3!(_mm256_add_epi64, res1_s, x_hi_lo_mul_epsilon);
-    let wraparound_mask2 = map3!(_mm256_cmpgt_epi32, res1_s, res2_s);
-    let wraparound_adj2 = map3!(_mm256_srli_epi64::<32>, wraparound_mask2);
-    let res3_s = map3!(_mm256_add_epi64, res2_s, wraparound_adj2);
-    let res3 = map3!(_mm256_xor_si256, res3_s, rep sign_bit);
-    res3
+    let epsilon = _mm256_set1_epi64x(0xffffffff);
+    let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
+    let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
+    let lo1_s = sub_tiny(lo0_s, hi_hi0);
+    let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
+    let lo2_s = add_small(lo1_s, t1);
+    let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
+    lo2
 }
 
 #[inline(always)]
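For reference while reviewing mul3's carry handling, here is a minimal scalar sketch of the same recombination: one 64x64 -> 128-bit product assembled from four 32x32 -> 64-bit products such that no intermediate sum can overflow 64 bits. The name widening_mul_sketch is hypothetical and not part of the patch; it only mirrors, per lane, what the vector code does with mul_ll, mul_lh, mul_hl, mul_hh, t0 and t1.

// Hypothetical scalar analogue of `mul3` (not part of the patch).
fn widening_mul_sketch(x: u64, y: u64) -> (u64, u64) {
    let (x_lo, x_hi) = (x & 0xffff_ffff, x >> 32);
    let (y_lo, y_hi) = (y & 0xffff_ffff, y >> 32);

    // All four pairwise 32x32 -> 64-bit products; each is at most (2^32 - 1)^2.
    let mul_ll = x_lo * y_lo;
    let mul_lh = x_lo * y_hi;
    let mul_hl = x_hi * y_lo;
    let mul_hh = x_hi * y_hi;

    // mul_hl + (mul_ll >> 32) <= (2^32 - 1)^2 + (2^32 - 1) < 2^64: no overflow.
    let t0 = mul_hl + (mul_ll >> 32);
    // Same bound applies here.
    let t1 = mul_lh + (t0 & 0xffff_ffff);
    // High 64 bits: mul_hh plus the carries out of t0 and t1.
    let hi = mul_hh + (t0 >> 32) + (t1 >> 32);
    // Low 64 bits: low word of mul_ll in the low half, low word of t1 in the high half
    // (the vector code forms this with _mm256_blend_epi32).
    let lo = (mul_ll & 0xffff_ffff) | (t1 << 32);
    (lo, hi)
}

A quick property check against u128 arithmetic, e.g. asserting that lo equals (x as u128 * y as u128) as u64 and hi equals the top 64 bits, is an easy way to validate the lane-wise logic.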
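The _s suffix throughout add_small, sub_tiny and reduce3 marks values kept in "shifted" form, i.e. the true value XOR 2^63 (the sign_bit constant), which lets the signed comparisons AVX2 provides stand in for unsigned ones. A scalar sketch of add_small's wrap-back step under that convention; add_small_sketch and the u64 arguments are assumptions for illustration, not identifiers from the patch.

const EPSILON: u64 = 0xffff_ffff; // 2^64 mod 0xffffffff00000001
const SIGN_BIT: u64 = 1 << 63;

// Hypothetical scalar analogue of `add_small`: `x_s` is in shifted form (true value ^ SIGN_BIT),
// `y` is an ordinary value below the field order; the result is returned in shifted form.
fn add_small_sketch(x_s: u64, y: u64) -> u64 {
    let res_wrapped_s = x_s.wrapping_add(y);
    // In shifted form, unsigned overflow of the true sum shows up as a signed comparison,
    // which is what the _mm256_cmpgt_epi32 in add_small gives the vector code.
    let wrapback_amt = if (res_wrapped_s as i64) < (x_s as i64) { EPSILON } else { 0 };
    // The wraparound dropped 2^64; adding back EPSILON = 2^64 - p leaves the result unchanged mod p.
    res_wrapped_s.wrapping_add(wrapback_amt)
}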
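reduce3 folds a 128-bit product (lo, hi) back into the Goldilocks field using 2^64 == 2^32 - 1 and 2^96 == -1 (mod p = 0xffffffff00000001). A scalar sketch of the same reduction, minus the shifted-form bookkeeping; reduce128_sketch is a hypothetical name, not an identifier from the patch.

const P: u64 = 0xffff_ffff_0000_0001; // the Goldilocks prime, 2^64 - 2^32 + 1
const EPSILON: u64 = 0xffff_ffff; // 2^64 mod P

// Hypothetical scalar version of the per-lane reduction in `reduce3`.
fn reduce128_sketch(lo: u64, hi: u64) -> u64 {
    let hi_hi = hi >> 32; // weight 2^96 == -1 (mod P)
    let hi_lo = hi & EPSILON; // weight 2^64 == EPSILON (mod P)

    // lo - hi_hi, mirroring sub_tiny: on borrow the wraparound added 2^64, but we only
    // wanted to add P, so remove the difference EPSILON = 2^64 - P.
    let (mut t0, borrow) = lo.overflowing_sub(hi_hi);
    if borrow {
        t0 = t0.wrapping_sub(EPSILON);
    }
    // hi_lo * EPSILON fits in 64 bits since both factors are below 2^32.
    let t1 = hi_lo * EPSILON;
    // t0 + t1, mirroring add_small: on carry, add EPSILON back.
    let (mut res, carry) = t0.overflowing_add(t1);
    if carry {
        res = res.wrapping_add(EPSILON);
    }
    res
}

In the vector code, the subtraction's adjustment goes through maybe_adj_sub's predicted-not-taken branch (hence the branch_hint import), while the addition's wrap-back in add_small is branchless.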