mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-04-18 09:33:34 +00:00
AVX2 Poseidon S-box optimizations (#421)
This commit is contained in:
parent
bf30fed701
commit
4e532f04fa
@ -4,6 +4,7 @@ use std::mem::size_of;
|
|||||||
|
|
||||||
use plonky2_field::field_types::Field;
|
use plonky2_field::field_types::Field;
|
||||||
use plonky2_field::goldilocks_field::GoldilocksField;
|
use plonky2_field::goldilocks_field::GoldilocksField;
|
||||||
|
use plonky2_util::branch_hint;
|
||||||
use static_assertions::const_assert;
|
use static_assertions::const_assert;
|
||||||
|
|
||||||
use crate::hash::poseidon::{
|
use crate::hash::poseidon::{
|
||||||
@ -141,6 +142,16 @@ macro_rules! map3 {
|
|||||||
($f:ident::<$l:literal>, $v:ident) => {
|
($f:ident::<$l:literal>, $v:ident) => {
|
||||||
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
|
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
|
||||||
};
|
};
|
||||||
|
($f:ident::<$l:literal>, $v1:ident, $v2:ident) => {
|
||||||
|
(
|
||||||
|
$f::<$l>($v1.0, $v2.0),
|
||||||
|
$f::<$l>($v1.1, $v2.1),
|
||||||
|
$f::<$l>($v1.2, $v2.2),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
($f:ident, $v:ident) => {
|
||||||
|
($f($v.0), $f($v.1), $f($v.2))
|
||||||
|
};
|
||||||
($f:ident, $v0:ident, $v1:ident) => {
|
($f:ident, $v0:ident, $v1:ident) => {
|
||||||
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
|
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
|
||||||
};
|
};
|
||||||
@ -188,19 +199,32 @@ unsafe fn const_layer(
|
|||||||
unsafe fn square3(
|
unsafe fn square3(
|
||||||
x: (__m256i, __m256i, __m256i),
|
x: (__m256i, __m256i, __m256i),
|
||||||
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||||
let sign_bit = _mm256_set1_epi64x(i64::MIN);
|
let x_hi = {
|
||||||
let x_hi = map3!(_mm256_srli_epi64::<32>, x);
|
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||||
|
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||||
|
// This is safe and free.
|
||||||
|
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||||
|
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||||
|
map3!(_mm256_castps_si256, x_hi_ps)
|
||||||
|
};
|
||||||
|
|
||||||
|
// All pairwise multiplications.
|
||||||
let mul_ll = map3!(_mm256_mul_epu32, x, x);
|
let mul_ll = map3!(_mm256_mul_epu32, x, x);
|
||||||
let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
|
let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
|
||||||
let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
|
let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
|
||||||
let res_lo0_s = map3!(_mm256_xor_si256, mul_ll, rep sign_bit);
|
|
||||||
|
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
|
||||||
|
let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
|
||||||
|
let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
|
||||||
|
let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
|
||||||
|
let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);
|
||||||
|
|
||||||
|
// Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
|
||||||
|
// position).
|
||||||
let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
|
let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
|
||||||
let res_lo1_s = map3!(_mm256_add_epi64, res_lo0_s, mul_lh_lo);
|
let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);
|
||||||
let carry = map3!(_mm256_cmpgt_epi64, res_lo0_s, res_lo1_s);
|
|
||||||
let mul_lh_hi = map3!(_mm256_srli_epi64::<31>, mul_lh);
|
(res_lo, res_hi)
|
||||||
let res_hi0 = map3!(_mm256_add_epi64, mul_hh, mul_lh_hi);
|
|
||||||
let res_hi1 = map3!(_mm256_sub_epi64, res_hi0, carry);
|
|
||||||
(res_lo1_s, res_hi1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
@ -208,49 +232,110 @@ unsafe fn mul3(
|
|||||||
x: (__m256i, __m256i, __m256i),
|
x: (__m256i, __m256i, __m256i),
|
||||||
y: (__m256i, __m256i, __m256i),
|
y: (__m256i, __m256i, __m256i),
|
||||||
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
|
||||||
let sign_bit = _mm256_set1_epi64x(i64::MIN);
|
let epsilon = _mm256_set1_epi64x(0xffffffff);
|
||||||
let y_hi = map3!(_mm256_srli_epi64::<32>, y);
|
let x_hi = {
|
||||||
let x_hi = map3!(_mm256_srli_epi64::<32>, x);
|
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
|
||||||
|
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
|
||||||
|
// This is safe and free.
|
||||||
|
let x_ps = map3!(_mm256_castsi256_ps, x);
|
||||||
|
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
|
||||||
|
map3!(_mm256_castps_si256, x_hi_ps)
|
||||||
|
};
|
||||||
|
let y_hi = {
|
||||||
|
let y_ps = map3!(_mm256_castsi256_ps, y);
|
||||||
|
let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
|
||||||
|
map3!(_mm256_castps_si256, y_hi_ps)
|
||||||
|
};
|
||||||
|
|
||||||
|
// All four pairwise multiplications
|
||||||
let mul_ll = map3!(_mm256_mul_epu32, x, y);
|
let mul_ll = map3!(_mm256_mul_epu32, x, y);
|
||||||
let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
|
let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
|
||||||
let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
|
let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
|
||||||
let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);
|
let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);
|
||||||
let mul_lh_lo = map3!(_mm256_slli_epi64::<32>, mul_lh);
|
|
||||||
let res_lo0_s = map3!(_mm256_xor_si256, mul_ll, rep sign_bit);
|
// Bignum addition
|
||||||
let mul_hl_lo = map3!(_mm256_slli_epi64::<32>, mul_hl);
|
// Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
|
||||||
let res_lo1_s = map3!(_mm256_add_epi64, res_lo0_s, mul_lh_lo);
|
let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
|
||||||
let carry0 = map3!(_mm256_cmpgt_epi64, res_lo0_s, res_lo1_s);
|
let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
|
||||||
let mul_lh_hi = map3!(_mm256_srli_epi64::<32>, mul_lh);
|
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
|
||||||
let res_lo2_s = map3!(_mm256_add_epi64, res_lo1_s, mul_hl_lo);
|
// Also, extract high 32 bits of t0 and add to mul_hh.
|
||||||
let carry1 = map3!(_mm256_cmpgt_epi64, res_lo1_s, res_lo2_s);
|
let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
|
||||||
let mul_hl_hi = map3!(_mm256_srli_epi64::<32>, mul_hl);
|
let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
|
||||||
let res_hi0 = map3!(_mm256_add_epi64, mul_hh, mul_lh_hi);
|
let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
|
||||||
let res_hi1 = map3!(_mm256_add_epi64, res_hi0, mul_hl_hi);
|
let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
|
||||||
let res_hi2 = map3!(_mm256_sub_epi64, res_hi1, carry0);
|
// Lastly, extract the high 32 bits of t1 and add to t2.
|
||||||
let res_hi3 = map3!(_mm256_sub_epi64, res_hi2, carry1);
|
let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
|
||||||
(res_lo2_s, res_hi3)
|
let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);
|
||||||
|
|
||||||
|
// Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
|
||||||
|
// position).
|
||||||
|
let t1_lo = {
|
||||||
|
let t1_ps = map3!(_mm256_castsi256_ps, t1);
|
||||||
|
let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
|
||||||
|
map3!(_mm256_castps_si256, t1_lo_ps)
|
||||||
|
};
|
||||||
|
let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);
|
||||||
|
|
||||||
|
(res_lo, res_hi)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
|
||||||
|
#[inline(always)]
|
||||||
|
unsafe fn add_small(
|
||||||
|
x_s: (__m256i, __m256i, __m256i),
|
||||||
|
y: (__m256i, __m256i, __m256i),
|
||||||
|
) -> (__m256i, __m256i, __m256i) {
|
||||||
|
let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
|
||||||
|
let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
|
||||||
|
let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
|
||||||
|
let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
|
||||||
|
res_s
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
|
||||||
|
// The subtraction is very unlikely to overflow so we're best off branching.
|
||||||
|
// The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
|
||||||
|
// branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
|
||||||
|
// floating-point (this is free).
|
||||||
|
let mask_pd = _mm256_castsi256_pd(mask);
|
||||||
|
// `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
|
||||||
|
// did not occur for any of the vector elements.
|
||||||
|
if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
|
||||||
|
res_wrapped_s
|
||||||
|
} else {
|
||||||
|
branch_hint();
|
||||||
|
// Highly unlikely: underflow did occur. Find adjustment per element and apply it.
|
||||||
|
let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
|
||||||
|
_mm256_sub_epi64(res_wrapped_s, adj_amount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Subtraction, where the second operand is much smaller than `0xffffffff00000001`.
|
||||||
|
#[inline(always)]
|
||||||
|
unsafe fn sub_tiny(
|
||||||
|
x_s: (__m256i, __m256i, __m256i),
|
||||||
|
y: (__m256i, __m256i, __m256i),
|
||||||
|
) -> (__m256i, __m256i, __m256i) {
|
||||||
|
let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
|
||||||
|
let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
|
||||||
|
let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
|
||||||
|
res_s
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
unsafe fn reduce3(
|
unsafe fn reduce3(
|
||||||
(x_lo_s, x_hi): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
|
(lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
|
||||||
) -> (__m256i, __m256i, __m256i) {
|
) -> (__m256i, __m256i, __m256i) {
|
||||||
let epsilon = _mm256_set1_epi64x(0xffffffff);
|
|
||||||
let sign_bit = _mm256_set1_epi64x(i64::MIN);
|
let sign_bit = _mm256_set1_epi64x(i64::MIN);
|
||||||
let x_hi_hi = map3!(_mm256_srli_epi64::<32>, x_hi);
|
let epsilon = _mm256_set1_epi64x(0xffffffff);
|
||||||
let res0_s = map3!(_mm256_sub_epi64, x_lo_s, x_hi_hi);
|
let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
|
||||||
let wraparound_mask0 = map3!(_mm256_cmpgt_epi32, res0_s, x_lo_s);
|
let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
|
||||||
let wraparound_adj0 = map3!(_mm256_srli_epi64::<32>, wraparound_mask0);
|
let lo1_s = sub_tiny(lo0_s, hi_hi0);
|
||||||
let x_hi_lo = map3!(_mm256_and_si256, x_hi, rep epsilon);
|
let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
|
||||||
let x_hi_lo_shifted = map3!(_mm256_slli_epi64::<32>, x_hi);
|
let lo2_s = add_small(lo1_s, t1);
|
||||||
let res1_s = map3!(_mm256_sub_epi64, res0_s, wraparound_adj0);
|
let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
|
||||||
let x_hi_lo_mul_epsilon = map3!(_mm256_sub_epi64, x_hi_lo_shifted, x_hi_lo);
|
lo2
|
||||||
let res2_s = map3!(_mm256_add_epi64, res1_s, x_hi_lo_mul_epsilon);
|
|
||||||
let wraparound_mask2 = map3!(_mm256_cmpgt_epi32, res1_s, res2_s);
|
|
||||||
let wraparound_adj2 = map3!(_mm256_srli_epi64::<32>, wraparound_mask2);
|
|
||||||
let res3_s = map3!(_mm256_add_epi64, res2_s, wraparound_adj2);
|
|
||||||
let res3 = map3!(_mm256_xor_si256, res3_s, rep sign_bit);
|
|
||||||
res3
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user