AVX2: Fold the constant layer into MDS matrix multiplication (#302)

* Fuse constant layer with MDS matrix multiplication

* Warnings and lints

* Minor documentation
Jakub Nabaglo 2021-10-21 16:51:06 -07:00 committed by GitHub
parent 7d45c80c03
commit 001c979599
2 changed files with 136 additions and 117 deletions


@ -1,12 +1,32 @@
use core::arch::x86_64::*;
use std::convert::TryInto;
use std::mem::size_of;
use crate::field::field_types::Field;
use crate::field::goldilocks_field::GoldilocksField;
use crate::hash::poseidon::{ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS};
use crate::hash::poseidon::{ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS};
// WARNING: This code contains tricks that work for the current MDS matrix and round constants, but
// are not guaranteed to work if those are changed.
const WIDTH: usize = 12;
// These transformed round constants are used where the constant layer is fused with the preceding
// MDS layer. The FUSED_ROUND_CONSTANTS for round i are the ALL_ROUND_CONSTANTS for round i + 1.
// The FUSED_ROUND_CONSTANTS for the very last round are 0, as it is not followed by a constant
// layer. On top of that, all FUSED_ROUND_CONSTANTS are shifted by 2 ** 63 to save a few XORs per
// round.
const fn make_fused_round_constants() -> [u64; WIDTH * N_ROUNDS] {
let mut res = [0x8000000000000000u64; WIDTH * N_ROUNDS];
let mut i: usize = WIDTH;
while i < WIDTH * N_ROUNDS {
res[i - WIDTH] ^= ALL_ROUND_CONSTANTS[i];
i += 1;
}
res
}
const FUSED_ROUND_CONSTANTS: [u64; WIDTH * N_ROUNDS] = make_fused_round_constants();
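// Illustrative sanity check (hypothetical test added for exposition, not part of this diff):
// each fused constant for round i equals the plain constant for round i + 1 with the 2 ** 63
// shift applied, and the final round's fused constants are the shift alone.
#[cfg(test)]
#[test]
fn fused_round_constants_match_shifted_originals() {
    for i in 0..WIDTH * N_ROUNDS {
        let expected = if i < WIDTH * (N_ROUNDS - 1) {
            ALL_ROUND_CONSTANTS[i + WIDTH] ^ (1u64 << 63)
        } else {
            1u64 << 63
        };
        assert_eq!(FUSED_ROUND_CONSTANTS[i], expected);
    }
}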
// This is the top row of the MDS matrix. Concretely, it's the MDS exps vector at the following
// indices: [0, 11, ..., 1].
static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];
@ -49,50 +69,6 @@ static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];
// Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition.
#[inline(always)]
unsafe fn const_layer(
(state0_s, state1_s, state2_s): (__m256i, __m256i, __m256i),
(base, index): (*const GoldilocksField, usize),
) -> (__m256i, __m256i, __m256i) {
// TODO: We can make this entire layer effectively free by folding it into MDS multiplication.
let (state0, state1, state2): (__m256i, __m256i, __m256i);
let sign_bit = _mm256_set1_epi64x(i64::MIN);
asm!(
// Below is optimized for latency. In particular, we avoid pcmpgtq because it has latency
// of 3 cycles and can only run on port 5. pcmpgtd is much faster.
"vpaddq {t0}, {state0}, [{base:r} + {index:r}]",
"vpaddq {t1}, {state1}, [{base:r} + {index:r} + 32]",
"vpaddq {t2}, {state2}, [{base:r} + {index:r} + 64]",
// It's okay to do vpcmpgtd (instead of vpcmpgtq) because all the round
// constants are >= 1 << 32 and < field order.
"vpcmpgtd {u0}, {state0}, {t0}",
"vpcmpgtd {u1}, {state1}, {t1}",
"vpcmpgtd {u2}, {state2}, {t2}",
// Unshift by 1 << 63.
"vpxor {t0}, {sign_bit}, {t0}",
"vpxor {t1}, {sign_bit}, {t1}",
"vpxor {t2}, {sign_bit}, {t2}",
// Add epsilon if t >> 32 > state >> 32.
"vpsrlq {u0}, {u0}, 32",
"vpsrlq {u1}, {u1}, 32",
"vpsrlq {u2}, {u2}, 32",
"vpaddq {state0}, {u0}, {t0}",
"vpaddq {state1}, {u1}, {t1}",
"vpaddq {state2}, {u2}, {t2}",
state0 = inout(ymm_reg) state0_s => state0,
state1 = inout(ymm_reg) state1_s => state1,
state2 = inout(ymm_reg) state2_s => state2,
t0 = out(ymm_reg) _, t1 = out(ymm_reg) _, t2 = out(ymm_reg) _,
u0 = out(ymm_reg) _, u1 = out(ymm_reg) _, u2 = out(ymm_reg) _,
sign_bit = in(ymm_reg) sign_bit,
base = in(reg) base,
index = in(reg) index,
options(pure, readonly, preserves_flags, nostack),
);
(state0, state1, state2)
}
macro_rules! map3 {
($f:ident::<$l:literal>, $v:ident) => {
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
@ -105,6 +81,41 @@ macro_rules! map3 {
};
}
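// For illustration: with the literal-parameter arm shown above,
// `map3!(_mm256_srli_epi64::<32>, x)` expands to
// `(_mm256_srli_epi64::<32>(x.0), _mm256_srli_epi64::<32>(x.1), _mm256_srli_epi64::<32>(x.2))`.
// The `rep` form used below (whose arm falls outside this hunk) presumably broadcasts its second
// argument to all three lanes, e.g. `map3!(_mm256_xor_si256, state, rep sign_bit)`.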
#[inline(always)]
unsafe fn const_layer(
state: (__m256i, __m256i, __m256i),
round_const_arr: &[u64; 12],
) -> (__m256i, __m256i, __m256i) {
let sign_bit = _mm256_set1_epi64x(i64::MIN);
let round_const = (
_mm256_loadu_si256((&round_const_arr[0..4]).as_ptr().cast::<__m256i>()),
_mm256_loadu_si256((&round_const_arr[4..8]).as_ptr().cast::<__m256i>()),
_mm256_loadu_si256((&round_const_arr[8..12]).as_ptr().cast::<__m256i>()),
);
let state_s = map3!(_mm256_xor_si256, state, rep sign_bit); // Shift by 2**63.
let res_maybe_wrapped_s = map3!(_mm256_add_epi64, state_s, round_const);
// 32-bit compare is much faster than 64-bit compare on Intel. We can use 32-bit compare here
// as long as we can guarantee that state > res_maybe_wrapped iff state >> 32 >
// res_maybe_wrapped >> 32. Clearly, if state >> 32 > res_maybe_wrapped >> 32, then state >
// res_maybe_wrapped, and similarly for <.
// It remains to show that we can't have state >> 32 == res_maybe_wrapped >> 32 with state >
// res_maybe_wrapped. If state >> 32 == res_maybe_wrapped >> 32, then round_const >> 32 =
// 0xffffffff and the addition of the low doubleword generated a carry bit. This can never
// occur if all round constants are < 0xffffffff00000001 = ORDER: if the high bits are
// 0xffffffff, then the low bits are 0, so the carry bit cannot occur. So this trick is valid
// as long as all the round constants are in canonical form.
// The mask contains 0xffffffff in the high doubleword if wraparound occurred and 0 otherwise.
// We will ignore the low doubleword.
let wraparound_mask = map3!(_mm256_cmpgt_epi32, state_s, res_maybe_wrapped_s);
// wraparound_adjustment contains 0xffffffff = EPSILON if wraparound occurred and 0 otherwise.
let wraparound_adjustment = map3!(_mm256_srli_epi64::<32>, wraparound_mask);
// XOR commutes with the addition below. Placing it here helps mask latency.
let res_maybe_wrapped = map3!(_mm256_xor_si256, res_maybe_wrapped_s, rep sign_bit);
// Add EPSILON = subtract ORDER.
let res = map3!(_mm256_add_epi64, res_maybe_wrapped, wraparound_adjustment);
res
}
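// Scalar sketch (assumed equivalent, for exposition only): per lane, this layer is a modular
// addition of a canonical round constant. The 2 ** 63 shift in the vector code exists only
// because AVX2 lacks unsigned comparisons; flipping the sign bit lets the signed compare detect
// unsigned wraparound.
#[allow(dead_code)]
fn const_layer_scalar(x: u64, round_const: u64) -> u64 {
    const EPSILON: u64 = 0xffff_ffff; // 2 ** 64 mod ORDER, where ORDER = 2 ** 64 - 2 ** 32 + 1.
    let (res, wrapped) = x.overflowing_add(round_const);
    // If the addition wrapped past 2 ** 64, we implicitly subtracted 2 ** 64 ≡ EPSILON
    // (mod ORDER), so add EPSILON back. The result is correct mod ORDER but, as in the vector
    // code, may be non-canonical.
    if wrapped {
        res.wrapping_add(EPSILON)
    } else {
        res
    }
}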
#[inline(always)]
unsafe fn square3(
x: (__m256i, __m256i, __m256i),
@ -188,17 +199,18 @@ unsafe fn sbox_layer_full(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m25
}
#[inline(always)]
unsafe fn mds_layer_reduce_s(
unsafe fn mds_layer_reduce(
lo_s: (__m256i, __m256i, __m256i),
hi: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
// This is done in assembly because, frankly, it's cleaner than intrinsics. We also don't have
// to worry about whether the compiler is doing weird things. This entire routine needs proper
// pipelining so there's no point rewriting this, only to have to rewrite it again.
let res0_s: __m256i;
let res1_s: __m256i;
let res2_s: __m256i;
let res0: __m256i;
let res1: __m256i;
let res2: __m256i;
let epsilon = _mm256_set1_epi64x(0xffffffff);
let sign_bit = _mm256_set1_epi64x(i64::MIN);
asm!(
// The high results are in ymm3, ymm4, ymm5.
// The low results (shifted by 2**63) are in ymm0, ymm1, ymm2
@ -249,25 +261,29 @@ unsafe fn mds_layer_reduce_s(
"vpsrlq ymm6, ymm0, 32",
"vpsrlq ymm7, ymm1, 32",
"vpsrlq ymm8, ymm2, 32",
"vpxor ymm3, ymm15, ymm3",
"vpxor ymm4, ymm15, ymm4",
"vpxor ymm5, ymm15, ymm5",
"vpaddq ymm0, ymm6, ymm3",
"vpaddq ymm1, ymm7, ymm4",
"vpaddq ymm2, ymm8, ymm5",
inout("ymm0") lo_s.0 => res0_s,
inout("ymm1") lo_s.1 => res1_s,
inout("ymm2") lo_s.2 => res2_s,
inout("ymm0") lo_s.0 => res0,
inout("ymm1") lo_s.1 => res1,
inout("ymm2") lo_s.2 => res2,
inout("ymm3") hi.0 => _,
inout("ymm4") hi.1 => _,
inout("ymm5") hi.2 => _,
out("ymm6") _, out("ymm7") _, out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
in("ymm14") epsilon,
in("ymm14") epsilon, in("ymm15") sign_bit,
options(pure, nomem, preserves_flags, nostack),
);
(res0_s, res1_s, res2_s)
(res0, res1, res2)
}
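// Scalar sketch (for exposition; assumed equivalent up to the 2 ** 63 bookkeeping): each lane's
// unreduced MDS output can be viewed as hi * 2 ** 32 + lo for a 64-bit lo and a roughly 49-bit
// hi, so a general 128-bit Goldilocks reduction covers it, using 2 ** 64 ≡ 2 ** 32 - 1 = EPSILON
// and 2 ** 96 ≡ -1 (mod ORDER).
#[allow(dead_code)]
fn reduce_u128_scalar(x: u128) -> u64 {
    const EPSILON: u64 = 0xffff_ffff;
    let x_lo = x as u64;
    let x_hi = (x >> 64) as u64;
    let x_hi_hi = x_hi >> 32;
    let x_hi_lo = x_hi & EPSILON;
    // x ≡ x_lo - x_hi_hi + x_hi_lo * EPSILON (mod ORDER).
    let (mut t0, borrow) = x_lo.overflowing_sub(x_hi_hi);
    if borrow {
        // The wrapping subtraction added 2 ** 64 ≡ EPSILON; cancel it.
        t0 = t0.wrapping_sub(EPSILON);
    }
    let t1 = x_hi_lo * EPSILON; // Fits in u64: both factors are below 2 ** 32.
    let (res, carry) = t0.overflowing_add(t1);
    // The wrapping addition dropped 2 ** 64 ≡ EPSILON; restore it.
    if carry {
        res.wrapping_add(EPSILON)
    } else {
        res
    }
}
// E.g. one lane would be reduce_u128_scalar(((hi as u128) << 32) + lo as u128).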
#[inline(always)]
unsafe fn mds_layer_multiply_s(
unsafe fn mds_multiply_and_add_round_const_s(
state: (__m256i, __m256i, __m256i),
(base, index): (*const u64, usize),
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
// TODO: Would it be faster to save the input to memory and do unaligned
// loads instead of swizzling? It would reduce pressure on port 5 but it
@ -308,9 +324,11 @@ unsafe fn mds_layer_multiply_s(
// ymm5[0:2] += ymm7[2:4]
// ymm5[2:4] += ymm8[0:2]
// Thus, the final result resides in ymm3, ymm4, ymm5.
// WARNING: This code assumes that sum(1 << exp for exp in MDS_EXPS) * 0xffffffff fits in a
// u64. If this guarantee ceases to hold, then it will no longer be correct.
let (unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s): (__m256i, __m256i, __m256i);
let (unreduced_hi0, unreduced_hi1, unreduced_hi2): (__m256i, __m256i, __m256i);
let sign_bit = _mm256_set1_epi64x(i64::MIN);
let epsilon = _mm256_set1_epi64x(0xffffffff);
asm!(
// Extract low 32 bits of the word
@ -500,11 +518,17 @@ unsafe fn mds_layer_multiply_s(
"vpsrlq ymm11, ymm2, 32",
// Need to move the low result from ymm3-ymm5 to ymm0-ymm2 so it is not
// overwritten. Save three instructions by combining the move with xor ymm15,
// which would otherwise be done in 3:.
"vpxor ymm0, ymm15, ymm3",
"vpxor ymm1, ymm15, ymm4",
"vpxor ymm2, ymm15, ymm5",
// overwritten. Save three instructions by combining the move with the constant layer,
// which would otherwise be done in 3:. The round constants include the shift by 2**63, so
// the resulting ymm0,1,2 are also shifted by 2**63.
// It is safe to add the round constants here without checking for overflow. The values in
// ymm3,4,5 are guaranteed to be <= 0x11536fffeeac9. All round constants are < 2**64
// - 0x11536fffeeac9.
// WARNING: If this guarantee ceases to hold due to a change in the MDS matrix or round
// constants, then this code will no longer be correct.
"vpaddq ymm0, ymm3, [{base} + {index}]",
"vpaddq ymm1, ymm4, [{base} + {index} + 32]",
"vpaddq ymm2, ymm5, [{base} + {index} + 64]",
// MDS matrix multiplication, again. This time on high 32 bits.
// Jump to the _local label_ (see above) `2:`. `b` for _backward_ specifies the direction.
@ -514,7 +538,10 @@ unsafe fn mds_layer_multiply_s(
"3:",
// Just done the MDS matrix multiplication on high 32 bits.
// The high results are in ymm3, ymm4, ymm5.
// The low results (shifted by 2**63) are in ymm0, ymm1, ymm2
// The low results (shifted by 2**63 and including the following constant layer) are in
// ymm0, ymm1, ymm2.
base = in(reg) base,
index = in(reg) index,
inout("ymm0") state.0 => unreduced_lo0_s,
inout("ymm1") state.1 => unreduced_lo1_s,
inout("ymm2") state.2 => unreduced_lo2_s,
@ -523,7 +550,7 @@ unsafe fn mds_layer_multiply_s(
out("ymm5") unreduced_hi2,
out("ymm6") _,out("ymm7") _, out("ymm8") _, out("ymm9") _,
out("ymm10") _, out("ymm11") _, out("ymm12") _, out("ymm13") _,
in("ymm14") epsilon, in("ymm15") sign_bit,
in("ymm14") epsilon,
out("rax") _,
options(pure, nomem, nostack),
);
@ -534,9 +561,12 @@ unsafe fn mds_layer_multiply_s(
}
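// Illustrative check (hypothetical, for exposition) of the overflow guarantee the warnings above
// rely on: with TOP_ROW_EXPS as given, sum(1 << exp) = 0x11537, so a column of all-ones 32-bit
// inputs yields at most 0x11537 * 0xffffffff = 0x11536fffeeac9, and
// 2 ** 64 - 0x11536fffeeac9 = 0xfffeeac900011537 is exactly the round-constant bound quoted in
// the warning in the second file of this diff.
#[cfg(test)]
#[test]
fn mds_row_sum_bound() {
    let shift_sum: u64 = TOP_ROW_EXPS.iter().map(|&exp| 1u64 << exp).sum();
    assert_eq!(shift_sum, 0x11537);
    assert_eq!(shift_sum * 0xffff_ffff, 0x11536fffeeac9);
    assert_eq!(u64::MAX - 0x11536fffeeac9 + 1, 0xfffeeac900011537);
    assert!(ALL_ROUND_CONSTANTS.iter().all(|&c| c < 0xfffeeac900011537));
}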
#[inline(always)]
unsafe fn mds_layer_full_s(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
let (unreduced_lo_s, unreduced_hi) = mds_layer_multiply_s(state);
mds_layer_reduce_s(unreduced_lo_s, unreduced_hi)
unsafe fn mds_const_layers_full(
state: (__m256i, __m256i, __m256i),
round_constants: (*const u64, usize),
) -> (__m256i, __m256i, __m256i) {
let (unreduced_lo_s, unreduced_hi) = mds_multiply_and_add_round_const_s(state, round_constants);
mds_layer_reduce(unreduced_lo_s, unreduced_hi)
}
/// Compute x ** 7
@ -620,8 +650,9 @@ unsafe fn sbox_partial(mut x: u64) -> u64 {
}
#[inline(always)]
unsafe fn sbox_mds_layers_partial_s(
unsafe fn partial_round(
(state0, state1, state2): (__m256i, __m256i, __m256i),
round_constants: (*const u64, usize),
) -> (__m256i, __m256i, __m256i) {
// Extract the low quadword
let state0ab: __m128i = _mm256_castsi256_si128(state0);
@ -638,7 +669,7 @@ unsafe fn sbox_mds_layers_partial_s(
let (
(mut unreduced_lo0_s, mut unreduced_lo1_s, mut unreduced_lo2_s),
(mut unreduced_hi0, mut unreduced_hi1, mut unreduced_hi2),
) = mds_layer_multiply_s((state0bcd, state1, state2));
) = mds_multiply_and_add_round_const_s((state0bcd, state1, state2), round_constants);
asm!(
// Just done the MDS matrix multiplication on high 32 bits.
// The high results are in ymm3, ymm4, ymm5.
@ -688,88 +719,72 @@ unsafe fn sbox_mds_layers_partial_s(
in("ymm14") epsilon,
options(pure, nomem, preserves_flags, nostack),
);
mds_layer_reduce_s(
mds_layer_reduce(
(unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s),
(unreduced_hi0, unreduced_hi1, unreduced_hi2),
)
}
#[inline(always)]
unsafe fn full_round_s(
state_s: (__m256i, __m256i, __m256i),
round_constants: (*const GoldilocksField, usize),
unsafe fn full_round(
state: (__m256i, __m256i, __m256i),
round_constants: (*const u64, usize),
) -> (__m256i, __m256i, __m256i) {
let state = const_layer(state_s, round_constants);
let state = sbox_layer_full(state);
let state_s = mds_layer_full_s(state);
state_s
}
#[inline(always)]
unsafe fn partial_round_s(
state_s: (__m256i, __m256i, __m256i),
round_constants: (*const GoldilocksField, usize),
) -> (__m256i, __m256i, __m256i) {
let state = const_layer(state_s, round_constants);
let state_s = sbox_mds_layers_partial_s(state);
state_s
let state = mds_const_layers_full(state, round_constants);
state
}
#[inline] // Called twice; permit inlining but don't _require_ it
unsafe fn half_full_rounds_s(
mut state_s: (__m256i, __m256i, __m256i),
unsafe fn half_full_rounds(
mut state: (__m256i, __m256i, __m256i),
start_round: usize,
) -> (__m256i, __m256i, __m256i) {
let base = (&ALL_ROUND_CONSTANTS
let base = (&FUSED_ROUND_CONSTANTS
[WIDTH * start_round..WIDTH * start_round + WIDTH * HALF_N_FULL_ROUNDS])
.as_ptr()
.cast::<GoldilocksField>();
.as_ptr();
for i in 0..HALF_N_FULL_ROUNDS {
state_s = full_round_s(state_s, (base, i * WIDTH * size_of::<u64>()));
state = full_round(state, (base, i * WIDTH * size_of::<u64>()));
}
state_s
state
}
#[inline(always)]
unsafe fn all_partial_rounds_s(
mut state_s: (__m256i, __m256i, __m256i),
unsafe fn all_partial_rounds(
mut state: (__m256i, __m256i, __m256i),
start_round: usize,
) -> (__m256i, __m256i, __m256i) {
let base = (&ALL_ROUND_CONSTANTS
let base = (&FUSED_ROUND_CONSTANTS
[WIDTH * start_round..WIDTH * start_round + WIDTH * N_PARTIAL_ROUNDS])
.as_ptr()
.cast::<GoldilocksField>();
.as_ptr();
for i in 0..N_PARTIAL_ROUNDS {
state_s = partial_round_s(state_s, (base, i * WIDTH * size_of::<u64>()));
state = partial_round(state, (base, i * WIDTH * size_of::<u64>()));
}
state_s
state
}
#[inline]
pub unsafe fn poseidon(state: &[GoldilocksField; 12]) -> [GoldilocksField; 12] {
let sign_bit = _mm256_set1_epi64x(i64::MIN);
let state = (
_mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
_mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
_mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
);
let mut s0 = _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>());
let mut s1 = _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>());
let mut s2 = _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>());
s0 = _mm256_xor_si256(s0, sign_bit);
s1 = _mm256_xor_si256(s1, sign_bit);
s2 = _mm256_xor_si256(s2, sign_bit);
// The first constant layer must be done explicitly. The remaining constant layers are fused
// with the preceding MDS layer.
let state = const_layer(state, &ALL_ROUND_CONSTANTS[0..WIDTH].try_into().unwrap());
(s0, s1, s2) = half_full_rounds_s((s0, s1, s2), 0);
(s0, s1, s2) = all_partial_rounds_s((s0, s1, s2), HALF_N_FULL_ROUNDS);
(s0, s1, s2) = half_full_rounds_s((s0, s1, s2), HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS);
s0 = _mm256_xor_si256(s0, sign_bit);
s1 = _mm256_xor_si256(s1, sign_bit);
s2 = _mm256_xor_si256(s2, sign_bit);
let state = half_full_rounds(state, 0);
let state = all_partial_rounds(state, HALF_N_FULL_ROUNDS);
let state = half_full_rounds(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS);
let mut res = [GoldilocksField::ZERO; 12];
_mm256_storeu_si256((&mut res[0..4]).as_mut_ptr().cast::<__m256i>(), s0);
_mm256_storeu_si256((&mut res[4..8]).as_mut_ptr().cast::<__m256i>(), s1);
_mm256_storeu_si256((&mut res[8..12]).as_mut_ptr().cast::<__m256i>(), s2);
_mm256_storeu_si256((&mut res[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
_mm256_storeu_si256((&mut res[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
_mm256_storeu_si256((&mut res[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
res
}
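// Hypothetical call site (for exposition; the dispatch function and fallback are not from this
// codebase): the caller must ensure AVX2 is actually available before entering this unsafe
// implementation.
#[allow(dead_code)]
fn poseidon_dispatch(input: &[GoldilocksField; 12]) -> [GoldilocksField; 12] {
    if is_x86_feature_detected!("avx2") {
        // Safety: the AVX2 feature check above is the only precondition assumed here.
        unsafe { poseidon(input) }
    } else {
        unimplemented!("dispatch to a scalar Poseidon implementation instead")
    }
}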


@ -17,7 +17,7 @@ use crate::plonk::circuit_builder::CircuitBuilder;
pub(crate) const HALF_N_FULL_ROUNDS: usize = 4;
pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
pub(crate) const N_PARTIAL_ROUNDS: usize = 22;
const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS;
pub(crate) const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS;
const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :)
#[inline(always)]
@ -45,6 +45,10 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [
// not met, some platform-specific implementation of constant_layer may return incorrect
// results.
//
// WARNING: The AVX2 Goldilocks specialization relies on all round constants being in
// 0..0xfffeeac900011537. If these constants are randomly regenerated, there is a ~.6% chance
// that this condition will no longer hold.
//
// WARNING: If these are changed in any way, then all the
// implementations of Poseidon must be regenerated. See comments
// in `poseidon_goldilocks.rs` and `poseidon_crandall.rs` for