diff --git a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
index df0e9397..1f0978f0 100644
--- a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
+++ b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
@@ -1,12 +1,32 @@
 use core::arch::x86_64::*;
+use std::convert::TryInto;
 use std::mem::size_of;
 
 use crate::field::field_types::Field;
 use crate::field::goldilocks_field::GoldilocksField;
-use crate::hash::poseidon::{ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS};
+use crate::hash::poseidon::{ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS};
+
+// WARNING: This code contains tricks that work for the current MDS matrix and round constants, but
+// are not guaranteed to work if those are changed.
 
 const WIDTH: usize = 12;
 
+// These transformed round constants are used where the constant layer is fused with the preceding
+// MDS layer. The FUSED_ROUND_CONSTANTS for round i are the ALL_ROUND_CONSTANTS for round i + 1.
+// The FUSED_ROUND_CONSTANTS for the very last round are 0, as it is not followed by a constant
+// layer. On top of that, all FUSED_ROUND_CONSTANTS are shifted by 2 ** 63 to save a few XORs per
+// round.
+const fn make_fused_round_constants() -> [u64; WIDTH * N_ROUNDS] {
+    let mut res = [0x8000000000000000u64; WIDTH * N_ROUNDS];
+    let mut i: usize = WIDTH;
+    while i < WIDTH * N_ROUNDS {
+        res[i - WIDTH] ^= ALL_ROUND_CONSTANTS[i];
+        i += 1;
+    }
+    res
+}
+const FUSED_ROUND_CONSTANTS: [u64; WIDTH * N_ROUNDS] = make_fused_round_constants();
+
 // This is the top row of the MDS matrix. Concretely, it's the MDS exps vector at the following
 // indices: [0, 11, ..., 1].
 static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];
@@ -49,50 +69,6 @@ static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];
 // Notice that the above 3-value addition still only requires two calls to shift, just like our
 // 2-value addition.
 
-#[inline(always)]
-unsafe fn const_layer(
-    (state0_s, state1_s, state2_s): (__m256i, __m256i, __m256i),
-    (base, index): (*const GoldilocksField, usize),
-) -> (__m256i, __m256i, __m256i) {
-    // TODO: We can make this entire layer effectively free by folding it into MDS multiplication.
-    let (state0, state1, state2): (__m256i, __m256i, __m256i);
-    let sign_bit = _mm256_set1_epi64x(i64::MIN);
-    asm!(
-        // Below is optimized for latency. In particular, we avoid pcmpgtq because it has latency
-        // of 3 cycles and can only run on port 5. pcmpgtd is much faster.
-        "vpaddq {t0}, {state0}, [{base:r} + {index:r}]",
-        "vpaddq {t1}, {state1}, [{base:r} + {index:r} + 32]",
-        "vpaddq {t2}, {state2}, [{base:r} + {index:r} + 64]",
-        // It's okay to do vpcmpgtd (instead of vpcmpgtq) because all the round
-        // constants are >= 1 << 32 and < field order.
-        "vpcmpgtd {u0}, {state0}, {t0}",
-        "vpcmpgtd {u1}, {state1}, {t1}",
-        "vpcmpgtd {u2}, {state2}, {t2}",
-        // Unshift by 1 << 63.
-        "vpxor {t0}, {sign_bit}, {t0}",
-        "vpxor {t1}, {sign_bit}, {t1}",
-        "vpxor {t2}, {sign_bit}, {t2}",
-        // Add epsilon if t >> 32 > state >> 32.
-        "vpsrlq {u0}, {u0}, 32",
-        "vpsrlq {u1}, {u1}, 32",
-        "vpsrlq {u2}, {u2}, 32",
-        "vpaddq {state0}, {u0}, {t0}",
-        "vpaddq {state1}, {u1}, {t1}",
-        "vpaddq {state2}, {u2}, {t2}",
-
-        state0 = inout(ymm_reg) state0_s => state0,
-        state1 = inout(ymm_reg) state1_s => state1,
-        state2 = inout(ymm_reg) state2_s => state2,
-        t0 = out(ymm_reg) _, t1 = out(ymm_reg) _, t2 = out(ymm_reg) _,
-        u0 = out(ymm_reg) _, u1 = out(ymm_reg) _, u2 = out(ymm_reg) _,
-        sign_bit = in(ymm_reg) sign_bit,
-        base = in(reg) base,
-        index = in(reg) index,
-        options(pure, readonly, preserves_flags, nostack),
-    );
-    (state0, state1, state2)
-}
-
 macro_rules! map3 {
     ($f:ident::<$l:literal>, $v:ident) => {
         ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
@@ -105,6 +81,41 @@ macro_rules! map3 {
     };
 }
 
+#[inline(always)]
+unsafe fn const_layer(
+    state: (__m256i, __m256i, __m256i),
+    round_const_arr: &[u64; 12],
+) -> (__m256i, __m256i, __m256i) {
+    let sign_bit = _mm256_set1_epi64x(i64::MIN);
+    let round_const = (
+        _mm256_loadu_si256((&round_const_arr[0..4]).as_ptr().cast::<__m256i>()),
+        _mm256_loadu_si256((&round_const_arr[4..8]).as_ptr().cast::<__m256i>()),
+        _mm256_loadu_si256((&round_const_arr[8..12]).as_ptr().cast::<__m256i>()),
+    );
+    let state_s = map3!(_mm256_xor_si256, state, rep sign_bit); // Shift by 2**63.
+    let res_maybe_wrapped_s = map3!(_mm256_add_epi64, state_s, round_const);
+    // 32-bit compare is much faster than 64-bit compare on Intel. We can use 32-bit compare here
+    // as long as we can guarantee that state > res_maybe_wrapped iff state >> 32 >
+    // res_maybe_wrapped >> 32. Clearly, if state >> 32 > res_maybe_wrapped >> 32, then state >
+    // res_maybe_wrapped, and similarly for <.
+    // It remains to show that we can't have state >> 32 == res_maybe_wrapped >> 32 with state >
+    // res_maybe_wrapped. If state >> 32 == res_maybe_wrapped >> 32, then round_const >> 32 =
+    // 0xffffffff and the addition of the low doubleword generated a carry bit. This can never
+    // occur if all round constants are < 0xffffffff00000001 = ORDER: if the high bits are
+    // 0xffffffff, then the low bits are 0, so the carry bit cannot occur. So this trick is valid
+    // as long as all the round constants are in canonical form.
+    // The mask contains 0xffffffff in the high doubleword if wraparound occurred and 0 otherwise.
+    // We will ignore the low doubleword.
+    let wraparound_mask = map3!(_mm256_cmpgt_epi32, state_s, res_maybe_wrapped_s);
+    // wraparound_adjustment contains 0xffffffff = EPSILON if wraparound occurred and 0 otherwise.
+    let wraparound_adjustment = map3!(_mm256_srli_epi64::<32>, wraparound_mask);
+    // XOR commutes with the addition below. Placing it here helps mask latency.
+    let res_maybe_wrapped = map3!(_mm256_xor_si256, res_maybe_wrapped_s, rep sign_bit);
+    // Add EPSILON = subtract ORDER.
+    let res = map3!(_mm256_add_epi64, res_maybe_wrapped, wraparound_adjustment);
+    res
+}
+
 #[inline(always)]
 unsafe fn square3(
     x: (__m256i, __m256i, __m256i),
@@ -188,17 +199,18 @@ unsafe fn sbox_layer_full(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m25
 }
 
 #[inline(always)]
-unsafe fn mds_layer_reduce_s(
+unsafe fn mds_layer_reduce(
     lo_s: (__m256i, __m256i, __m256i),
     hi: (__m256i, __m256i, __m256i),
 ) -> (__m256i, __m256i, __m256i) {
     // This is done in assembly because, frankly, it's cleaner than intrinsics. We also don't have
     // to worry about whether the compiler is doing weird things. This entire routine needs proper
     // pipelining so there's no point rewriting this, only to have to rewrite it again.
- let res0_s: __m256i; - let res1_s: __m256i; - let res2_s: __m256i; + let res0: __m256i; + let res1: __m256i; + let res2: __m256i; let epsilon = _mm256_set1_epi64x(0xffffffff); + let sign_bit = _mm256_set1_epi64x(i64::MIN); asm!( // The high results are in ymm3, ymm4, ymm5. // The low results (shifted by 2**63) are in ymm0, ymm1, ymm2 @@ -249,25 +261,29 @@ unsafe fn mds_layer_reduce_s( "vpsrlq ymm6, ymm0, 32", "vpsrlq ymm7, ymm1, 32", "vpsrlq ymm8, ymm2, 32", + "vpxor ymm3, ymm15, ymm3", + "vpxor ymm4, ymm15, ymm4", + "vpxor ymm5, ymm15, ymm5", "vpaddq ymm0, ymm6, ymm3", "vpaddq ymm1, ymm7, ymm4", "vpaddq ymm2, ymm8, ymm5", - inout("ymm0") lo_s.0 => res0_s, - inout("ymm1") lo_s.1 => res1_s, - inout("ymm2") lo_s.2 => res2_s, + inout("ymm0") lo_s.0 => res0, + inout("ymm1") lo_s.1 => res1, + inout("ymm2") lo_s.2 => res2, inout("ymm3") hi.0 => _, inout("ymm4") hi.1 => _, inout("ymm5") hi.2 => _, out("ymm6") _, out("ymm7") _, out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _, - in("ymm14") epsilon, + in("ymm14") epsilon, in("ymm15") sign_bit, options(pure, nomem, preserves_flags, nostack), ); - (res0_s, res1_s, res2_s) + (res0, res1, res2) } #[inline(always)] -unsafe fn mds_layer_multiply_s( +unsafe fn mds_multiply_and_add_round_const_s( state: (__m256i, __m256i, __m256i), + (base, index): (*const u64, usize), ) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) { // TODO: Would it be faster to save the input to memory and do unaligned // loads instead of swizzling? It would reduce pressure on port 5 but it @@ -308,9 +324,11 @@ unsafe fn mds_layer_multiply_s( // ymm5[0:2] += ymm7[2:4] // ymm5[2:4] += ymm8[0:2] // Thus, the final result resides in ymm3, ymm4, ymm5. + + // WARNING: This code assumes that sum(1 << exp for exp in MDS_EXPS) * 0xffffffff fits in a + // u64. If this guarantee ceases to hold, then it will no longer be correct. let (unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s): (__m256i, __m256i, __m256i); let (unreduced_hi0, unreduced_hi1, unreduced_hi2): (__m256i, __m256i, __m256i); - let sign_bit = _mm256_set1_epi64x(i64::MIN); let epsilon = _mm256_set1_epi64x(0xffffffff); asm!( // Extract low 32 bits of the word @@ -500,11 +518,17 @@ unsafe fn mds_layer_multiply_s( "vpsrlq ymm11, ymm2, 32", // Need to move the low result from ymm3-ymm5 to ymm0-13 so it is not - // overwritten. Save three instructions by combining the move with xor ymm15, - // which would otherwise be done in 3:. - "vpxor ymm0, ymm15, ymm3", - "vpxor ymm1, ymm15, ymm4", - "vpxor ymm2, ymm15, ymm5", + // overwritten. Save three instructions by combining the move with the constant layer, + // which would otherwise be done in 3:. The round constants include the shift by 2**63, so + // the resulting ymm0,1,2 are also shifted by 2**63. + // It is safe to add the round constants here without checking for overflow. The values in + // ymm3,4,5 are guaranteed to be <= 0x11536fffeeac9. All round constants are < 2**64 + // - 0x11536fffeeac9. + // WARNING: If this guarantee ceases to hold due to a change in the MDS matrix or round + // constants, then this code will no longer be correct. + "vpaddq ymm0, ymm3, [{base} + {index}]", + "vpaddq ymm1, ymm4, [{base} + {index} + 32]", + "vpaddq ymm2, ymm5, [{base} + {index} + 64]", // MDS matrix multiplication, again. This time on high 32 bits. // Jump to the _local label_ (see above) `2:`. `b` for _backward_ specifies the direction. @@ -514,7 +538,10 @@ unsafe fn mds_layer_multiply_s( "3:", // Just done the MDS matrix multiplication on high 32 bits. 
// The high results are in ymm3, ymm4, ymm5. - // The low results (shifted by 2**63) are in ymm0, ymm1, ymm2 + // The low results (shifted by 2**63 and including the following constant layer) are in + // ymm0, ymm1, ymm2. + base = in(reg) base, + index = in(reg) index, inout("ymm0") state.0 => unreduced_lo0_s, inout("ymm1") state.1 => unreduced_lo1_s, inout("ymm2") state.2 => unreduced_lo2_s, @@ -523,7 +550,7 @@ unsafe fn mds_layer_multiply_s( out("ymm5") unreduced_hi2, out("ymm6") _,out("ymm7") _, out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _, out("ymm12") _, out("ymm13") _, - in("ymm14") epsilon, in("ymm15") sign_bit, + in("ymm14") epsilon, out("rax") _, options(pure, nomem, nostack), ); @@ -534,9 +561,12 @@ unsafe fn mds_layer_multiply_s( } #[inline(always)] -unsafe fn mds_layer_full_s(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) { - let (unreduced_lo_s, unreduced_hi) = mds_layer_multiply_s(state); - mds_layer_reduce_s(unreduced_lo_s, unreduced_hi) +unsafe fn mds_const_layers_full( + state: (__m256i, __m256i, __m256i), + round_constants: (*const u64, usize), +) -> (__m256i, __m256i, __m256i) { + let (unreduced_lo_s, unreduced_hi) = mds_multiply_and_add_round_const_s(state, round_constants); + mds_layer_reduce(unreduced_lo_s, unreduced_hi) } /// Compute x ** 7 @@ -620,8 +650,9 @@ unsafe fn sbox_partial(mut x: u64) -> u64 { } #[inline(always)] -unsafe fn sbox_mds_layers_partial_s( +unsafe fn partial_round( (state0, state1, state2): (__m256i, __m256i, __m256i), + round_constants: (*const u64, usize), ) -> (__m256i, __m256i, __m256i) { // Extract the low quadword let state0ab: __m128i = _mm256_castsi256_si128(state0); @@ -638,7 +669,7 @@ unsafe fn sbox_mds_layers_partial_s( let ( (mut unreduced_lo0_s, mut unreduced_lo1_s, mut unreduced_lo2_s), (mut unreduced_hi0, mut unreduced_hi1, mut unreduced_hi2), - ) = mds_layer_multiply_s((state0bcd, state1, state2)); + ) = mds_multiply_and_add_round_const_s((state0bcd, state1, state2), round_constants); asm!( // Just done the MDS matrix multiplication on high 32 bits. // The high results are in ymm3, ymm4, ymm5. 
@@ -688,88 +719,72 @@
         in("ymm14") epsilon,
         options(pure, nomem, preserves_flags, nostack),
     );
-    mds_layer_reduce_s(
+    mds_layer_reduce(
         (unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s),
         (unreduced_hi0, unreduced_hi1, unreduced_hi2),
     )
 }
 
 #[inline(always)]
-unsafe fn full_round_s(
-    state_s: (__m256i, __m256i, __m256i),
-    round_constants: (*const GoldilocksField, usize),
+unsafe fn full_round(
+    state: (__m256i, __m256i, __m256i),
+    round_constants: (*const u64, usize),
 ) -> (__m256i, __m256i, __m256i) {
-    let state = const_layer(state_s, round_constants);
     let state = sbox_layer_full(state);
-    let state_s = mds_layer_full_s(state);
-    state_s
-}
-
-#[inline(always)]
-unsafe fn partial_round_s(
-    state_s: (__m256i, __m256i, __m256i),
-    round_constants: (*const GoldilocksField, usize),
-) -> (__m256i, __m256i, __m256i) {
-    let state = const_layer(state_s, round_constants);
-    let state_s = sbox_mds_layers_partial_s(state);
-    state_s
+    let state = mds_const_layers_full(state, round_constants);
+    state
 }
 
 #[inline] // Called twice; permit inlining but don't _require_ it
-unsafe fn half_full_rounds_s(
-    mut state_s: (__m256i, __m256i, __m256i),
+unsafe fn half_full_rounds(
+    mut state: (__m256i, __m256i, __m256i),
     start_round: usize,
 ) -> (__m256i, __m256i, __m256i) {
-    let base = (&ALL_ROUND_CONSTANTS
+    let base = (&FUSED_ROUND_CONSTANTS
         [WIDTH * start_round..WIDTH * start_round + WIDTH * HALF_N_FULL_ROUNDS])
-        .as_ptr()
-        .cast::<GoldilocksField>();
+        .as_ptr();
     for i in 0..HALF_N_FULL_ROUNDS {
-        state_s = full_round_s(state_s, (base, i * WIDTH * size_of::<u64>()));
+        state = full_round(state, (base, i * WIDTH * size_of::<u64>()));
     }
-    state_s
+    state
 }
 
 #[inline(always)]
-unsafe fn all_partial_rounds_s(
-    mut state_s: (__m256i, __m256i, __m256i),
+unsafe fn all_partial_rounds(
+    mut state: (__m256i, __m256i, __m256i),
     start_round: usize,
 ) -> (__m256i, __m256i, __m256i) {
-    let base = (&ALL_ROUND_CONSTANTS
+    let base = (&FUSED_ROUND_CONSTANTS
         [WIDTH * start_round..WIDTH * start_round + WIDTH * N_PARTIAL_ROUNDS])
-        .as_ptr()
-        .cast::<GoldilocksField>();
+        .as_ptr();
     for i in 0..N_PARTIAL_ROUNDS {
-        state_s = partial_round_s(state_s, (base, i * WIDTH * size_of::<u64>()));
+        state = partial_round(state, (base, i * WIDTH * size_of::<u64>()));
     }
-    state_s
+    state
 }
 
 #[inline]
 pub unsafe fn poseidon(state: &[GoldilocksField; 12]) -> [GoldilocksField; 12] {
-    let sign_bit = _mm256_set1_epi64x(i64::MIN);
+    let state = (
+        _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
+        _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
+        _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
+    );
 
-    let mut s0 = _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>());
-    let mut s1 = _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>());
-    let mut s2 = _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>());
-    s0 = _mm256_xor_si256(s0, sign_bit);
-    s1 = _mm256_xor_si256(s1, sign_bit);
-    s2 = _mm256_xor_si256(s2, sign_bit);
+    // The first constant layer must be done explicitly. The remaining constant layers are fused
+    // with the preceding MDS layer.
+ let state = const_layer(state, &ALL_ROUND_CONSTANTS[0..WIDTH].try_into().unwrap()); - (s0, s1, s2) = half_full_rounds_s((s0, s1, s2), 0); - (s0, s1, s2) = all_partial_rounds_s((s0, s1, s2), HALF_N_FULL_ROUNDS); - (s0, s1, s2) = half_full_rounds_s((s0, s1, s2), HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS); - - s0 = _mm256_xor_si256(s0, sign_bit); - s1 = _mm256_xor_si256(s1, sign_bit); - s2 = _mm256_xor_si256(s2, sign_bit); + let state = half_full_rounds(state, 0); + let state = all_partial_rounds(state, HALF_N_FULL_ROUNDS); + let state = half_full_rounds(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS); let mut res = [GoldilocksField::ZERO; 12]; - _mm256_storeu_si256((&mut res[0..4]).as_mut_ptr().cast::<__m256i>(), s0); - _mm256_storeu_si256((&mut res[4..8]).as_mut_ptr().cast::<__m256i>(), s1); - _mm256_storeu_si256((&mut res[8..12]).as_mut_ptr().cast::<__m256i>(), s2); + _mm256_storeu_si256((&mut res[0..4]).as_mut_ptr().cast::<__m256i>(), state.0); + _mm256_storeu_si256((&mut res[4..8]).as_mut_ptr().cast::<__m256i>(), state.1); + _mm256_storeu_si256((&mut res[8..12]).as_mut_ptr().cast::<__m256i>(), state.2); res } diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index ae0b132c..5d34c64d 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -17,7 +17,7 @@ use crate::plonk::circuit_builder::CircuitBuilder; pub(crate) const HALF_N_FULL_ROUNDS: usize = 4; pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS; pub(crate) const N_PARTIAL_ROUNDS: usize = 22; -const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; +pub(crate) const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS; const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :) #[inline(always)] @@ -45,6 +45,10 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [ // not met, some platform-specific implementation of constant_layer may return incorrect // results. // + // WARNING: The AVX2 Goldilocks specialization relies on all round constants being in + // 0..0xfffeeac900011537. If these constants are randomly regenerated, there is a ~.6% chance + // that this condition will no longer hold. + // // WARNING: If these are changed in any way, then all the // implementations of Poseidon must be regenerated. See comments // in `poseidon_goldilocks.rs` and `poseidon_crandall.rs` for
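Reviewer note (not part of the patch): the scalar sketch below models the wraparound correction that const_layer and the fused MDS-plus-round-constant path above both rely on. ORDER and EPSILON are the Goldilocks parameters this file already assumes; add_round_constant is a hypothetical helper for illustration only, not an API in the crate.

    // Scalar model of the branch-free "add a canonical round constant, then correct
    // wraparound by adding EPSILON" step used throughout this patch.
    const ORDER: u64 = 0xffff_ffff_0000_0001; // Goldilocks prime
    const EPSILON: u64 = 0xffff_ffff; // 2^64 - ORDER

    // Returns a value < 2^64 congruent to x + rc mod ORDER, assuming rc < ORDER.
    // If the 64-bit addition wraps, adding EPSILON is the same as subtracting ORDER.
    // The AVX2 code does this per lane; it keeps values shifted by 2^63 so the wraparound
    // test can use a signed 32-bit compare (vpcmpgtd) in place of the unsigned 64-bit
    // compare AVX2 lacks, and FUSED_ROUND_CONSTANTS bake that shift into the constants.
    fn add_round_constant(x: u64, rc: u64) -> u64 {
        debug_assert!(rc < ORDER);
        let (sum, wrapped) = x.overflowing_add(rc);
        sum.wrapping_add(if wrapped { EPSILON } else { 0 })
    }

    fn main() {
        // Check the claim on a wrapping case.
        let x = ORDER - 1;
        let rc = 0x1234_5678_9abc_def0;
        let expected = ((x as u128 + rc as u128) % ORDER as u128) as u64;
        assert_eq!(add_round_constant(x, rc) % ORDER, expected);
    }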