AVX2: Fold the constant layer into MDS matrix multiplication (#302)
* Fuse constant layer with MDS matrix multiplication
* Warnings and lints
* Minor documentation
parent 7d45c80c03
commit 001c979599
@@ -1,12 +1,32 @@
 use core::arch::x86_64::*;
+use std::convert::TryInto;
 use std::mem::size_of;
 
 use crate::field::field_types::Field;
 use crate::field::goldilocks_field::GoldilocksField;
-use crate::hash::poseidon::{ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS};
+use crate::hash::poseidon::{ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS};
 
+// WARNING: This code contains tricks that work for the current MDS matrix and round constants, but
+// are not guaranteed to work if those are changed.
+
 const WIDTH: usize = 12;
 
+// These transformed round constants are used where the constant layer is fused with the preceding
+// MDS layer. The FUSED_ROUND_CONSTANTS for round i are the ALL_ROUND_CONSTANTS for round i + 1.
+// The FUSED_ROUND_CONSTANTS for the very last round are 0, as it is not followed by a constant
+// layer. On top of that, all FUSED_ROUND_CONSTANTS are shifted by 2 ** 63 to save a few XORs per
+// round.
+const fn make_fused_round_constants() -> [u64; WIDTH * N_ROUNDS] {
+    let mut res = [0x8000000000000000u64; WIDTH * N_ROUNDS];
+    let mut i: usize = WIDTH;
+    while i < WIDTH * N_ROUNDS {
+        res[i - WIDTH] ^= ALL_ROUND_CONSTANTS[i];
+        i += 1;
+    }
+    res
+}
+const FUSED_ROUND_CONSTANTS: [u64; WIDTH * N_ROUNDS] = make_fused_round_constants();
+
 // This is the top row of the MDS matrix. Concretely, it's the MDS exps vector at the following
 // indices: [0, 11, ..., 1].
 static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];
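To see concretely what make_fused_round_constants guarantees, here is a standalone sketch of the same transform. It is not part of the commit: WIDTH and N_ROUNDS are inlined, and an arbitrary stand-in table replaces the real ALL_ROUND_CONSTANTS (which the commit builds at compile time in a const fn).

```rust
// Standalone sketch of the FUSED_ROUND_CONSTANTS invariant. `all_rc` is a stand-in table,
// not the real ALL_ROUND_CONSTANTS.
const WIDTH: usize = 12;
const N_ROUNDS: usize = 30; // 8 full + 22 partial rounds
const SHIFT: u64 = 1 << 63;

fn fused(all_rc: &[u64]) -> Vec<u64> {
    let mut res = vec![SHIFT; WIDTH * N_ROUNDS];
    for i in WIDTH..WIDTH * N_ROUNDS {
        res[i - WIDTH] ^= all_rc[i];
    }
    res
}

fn main() {
    // Arbitrary test values; wrapping_mul just spreads the bits around.
    let all_rc: Vec<u64> = (0..WIDTH * N_ROUNDS)
        .map(|i| (i as u64).wrapping_mul(0x9e3779b97f4a7c15))
        .collect();
    let f = fused(&all_rc);
    // Constant j of fused round i is constant j of plain round i + 1, shifted by 2^63
    // (XOR with the top bit is the same as adding 2^63 mod 2^64)...
    for i in 0..N_ROUNDS - 1 {
        for j in 0..WIDTH {
            assert_eq!(f[WIDTH * i + j], all_rc[WIDTH * (i + 1) + j] ^ SHIFT);
        }
    }
    // ...and the last fused round adds only the shift, i.e. the constant 0.
    for j in 0..WIDTH {
        assert_eq!(f[WIDTH * (N_ROUNDS - 1) + j], SHIFT);
    }
}
```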
@@ -49,50 +69,6 @@ static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];
 // Notice that the above 3-value addition still only requires two calls to shift, just like our
 // 2-value addition.
-
-#[inline(always)]
-unsafe fn const_layer(
-    (state0_s, state1_s, state2_s): (__m256i, __m256i, __m256i),
-    (base, index): (*const GoldilocksField, usize),
-) -> (__m256i, __m256i, __m256i) {
-    // TODO: We can make this entire layer effectively free by folding it into MDS multiplication.
-    let (state0, state1, state2): (__m256i, __m256i, __m256i);
-    let sign_bit = _mm256_set1_epi64x(i64::MIN);
-    asm!(
-        // Below is optimized for latency. In particular, we avoid pcmpgtq because it has latency
-        // of 3 cycles and can only run on port 5. pcmpgtd is much faster.
-        "vpaddq {t0}, {state0}, [{base:r} + {index:r}]",
-        "vpaddq {t1}, {state1}, [{base:r} + {index:r} + 32]",
-        "vpaddq {t2}, {state2}, [{base:r} + {index:r} + 64]",
-        // It's okay to do vpcmpgtd (instead of vpcmpgtq) because all the round
-        // constants are >= 1 << 32 and < field order.
-        "vpcmpgtd {u0}, {state0}, {t0}",
-        "vpcmpgtd {u1}, {state1}, {t1}",
-        "vpcmpgtd {u2}, {state2}, {t2}",
-        // Unshift by 1 << 63.
-        "vpxor {t0}, {sign_bit}, {t0}",
-        "vpxor {t1}, {sign_bit}, {t1}",
-        "vpxor {t2}, {sign_bit}, {t2}",
-        // Add epsilon if t >> 32 > state >> 32.
-        "vpsrlq {u0}, {u0}, 32",
-        "vpsrlq {u1}, {u1}, 32",
-        "vpsrlq {u2}, {u2}, 32",
-        "vpaddq {state0}, {u0}, {t0}",
-        "vpaddq {state1}, {u1}, {t1}",
-        "vpaddq {state2}, {u2}, {t2}",
-
-        state0 = inout(ymm_reg) state0_s => state0,
-        state1 = inout(ymm_reg) state1_s => state1,
-        state2 = inout(ymm_reg) state2_s => state2,
-        t0 = out(ymm_reg) _, t1 = out(ymm_reg) _, t2 = out(ymm_reg) _,
-        u0 = out(ymm_reg) _, u1 = out(ymm_reg) _, u2 = out(ymm_reg) _,
-        sign_bit = in(ymm_reg) sign_bit,
-        base = in(reg) base,
-        index = in(reg) index,
-        options(pure, readonly, preserves_flags, nostack),
-    );
-    (state0, state1, state2)
-}
 
 macro_rules! map3 {
     ($f:ident::<$l:literal>, $v:ident) => {
         ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
@@ -105,6 +81,41 @@ macro_rules! map3 {
     };
 }
 
+#[inline(always)]
+unsafe fn const_layer(
+    state: (__m256i, __m256i, __m256i),
+    round_const_arr: &[u64; 12],
+) -> (__m256i, __m256i, __m256i) {
+    let sign_bit = _mm256_set1_epi64x(i64::MIN);
+    let round_const = (
+        _mm256_loadu_si256((&round_const_arr[0..4]).as_ptr().cast::<__m256i>()),
+        _mm256_loadu_si256((&round_const_arr[4..8]).as_ptr().cast::<__m256i>()),
+        _mm256_loadu_si256((&round_const_arr[8..12]).as_ptr().cast::<__m256i>()),
+    );
+    let state_s = map3!(_mm256_xor_si256, state, rep sign_bit); // Shift by 2**63.
+    let res_maybe_wrapped_s = map3!(_mm256_add_epi64, state_s, round_const);
+    // 32-bit compare is much faster than 64-bit compare on Intel. We can use 32-bit compare here
+    // as long as we can guarantee that state > res_maybe_wrapped iff state >> 32 >
+    // res_maybe_wrapped >> 32. Clearly, if state >> 32 > res_maybe_wrapped >> 32, then state >
+    // res_maybe_wrapped, and similarly for <.
+    // It remains to show that we can't have state >> 32 == res_maybe_wrapped >> 32 with state >
+    // res_maybe_wrapped. If state >> 32 == res_maybe_wrapped >> 32, then round_const >> 32 =
+    // 0xffffffff and the addition of the low doubleword generated a carry bit. This can never
+    // occur if all round constants are < 0xffffffff00000001 = ORDER: if the high bits are
+    // 0xffffffff, then the low bits are 0, so the carry bit cannot occur. So this trick is valid
+    // as long as all the round constants are in canonical form.
+    // The mask contains 0xffffffff in the high doubleword if wraparound occurred and 0 otherwise.
+    // We will ignore the low doubleword.
+    let wraparound_mask = map3!(_mm256_cmpgt_epi32, state_s, res_maybe_wrapped_s);
+    // wraparound_adjustment contains 0xffffffff = EPSILON if wraparound occurred and 0 otherwise.
+    let wraparound_adjustment = map3!(_mm256_srli_epi64::<32>, wraparound_mask);
+    // XOR commutes with the addition below. Placing it here helps mask latency.
+    let res_maybe_wrapped = map3!(_mm256_xor_si256, res_maybe_wrapped_s, rep sign_bit);
+    // Add EPSILON = subtract ORDER.
+    let res = map3!(_mm256_add_epi64, res_maybe_wrapped, wraparound_adjustment);
+    res
+}
+
 #[inline(always)]
 unsafe fn square3(
     x: (__m256i, __m256i, __m256i),
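The shifted-operand trick above is easiest to see one lane at a time. Below is a scalar sketch (not from the diff) of the same branchless addition: operate on the state shifted by 2^63, detect 2^64 wraparound with a signed compare, and compensate by adding EPSILON = 2^64 - ORDER. It uses a full 64-bit compare where the vector code gets away with a 32-bit one.

```rust
// Scalar analogue of const_layer's add-with-reduction. `rc` must be canonical (< ORDER),
// mirroring the requirement on the round constants.
const ORDER: u64 = 0xffffffff00000001;
const EPSILON: u64 = 0xffffffff; // 2^64 - ORDER
const SIGN_BIT: u64 = 1 << 63;

fn add_rc_shifted(x_s: u64, rc: u64) -> u64 {
    debug_assert!(rc < ORDER);
    // x_s is the state shifted by 2^63 (x ^ SIGN_BIT), matching the vector representation.
    let res_s = x_s.wrapping_add(rc);
    // Signed compare of shifted values equals unsigned compare of unshifted values, which is
    // exactly the wraparound test the vpcmpgtd above approximates 32 bits at a time.
    let wrapped = (x_s as i64) > (res_s as i64);
    // Unshift, then compensate a 2^64 wraparound by adding EPSILON (i.e., subtracting ORDER).
    (res_s ^ SIGN_BIT).wrapping_add(if wrapped { EPSILON } else { 0 })
}

fn main() {
    // Spot-check against wide arithmetic; inputs may be non-canonical, like the real state.
    for &(x, rc) in &[(0u64, 0u64), (ORDER - 1, ORDER - 1), (u64::MAX, 12345)] {
        let got = add_rc_shifted(x ^ SIGN_BIT, rc);
        assert_eq!(got as u128 % ORDER as u128, (x as u128 + rc as u128) % ORDER as u128);
    }
}
```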
@@ -188,17 +199,18 @@ unsafe fn sbox_layer_full(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
 }
 
 #[inline(always)]
-unsafe fn mds_layer_reduce_s(
+unsafe fn mds_layer_reduce(
     lo_s: (__m256i, __m256i, __m256i),
     hi: (__m256i, __m256i, __m256i),
 ) -> (__m256i, __m256i, __m256i) {
     // This is done in assembly because, frankly, it's cleaner than intrinsics. We also don't have
     // to worry about whether the compiler is doing weird things. This entire routine needs proper
     // pipelining so there's no point rewriting this, only to have to rewrite it again.
-    let res0_s: __m256i;
-    let res1_s: __m256i;
-    let res2_s: __m256i;
+    let res0: __m256i;
+    let res1: __m256i;
+    let res2: __m256i;
     let epsilon = _mm256_set1_epi64x(0xffffffff);
+    let sign_bit = _mm256_set1_epi64x(i64::MIN);
     asm!(
         // The high results are in ymm3, ymm4, ymm5.
         // The low results (shifted by 2**63) are in ymm0, ymm1, ymm2
@@ -249,25 +261,29 @@ unsafe fn mds_layer_reduce_s(
         "vpsrlq ymm6, ymm0, 32",
         "vpsrlq ymm7, ymm1, 32",
         "vpsrlq ymm8, ymm2, 32",
+        "vpxor ymm3, ymm15, ymm3",
+        "vpxor ymm4, ymm15, ymm4",
+        "vpxor ymm5, ymm15, ymm5",
         "vpaddq ymm0, ymm6, ymm3",
         "vpaddq ymm1, ymm7, ymm4",
         "vpaddq ymm2, ymm8, ymm5",
-        inout("ymm0") lo_s.0 => res0_s,
-        inout("ymm1") lo_s.1 => res1_s,
-        inout("ymm2") lo_s.2 => res2_s,
+        inout("ymm0") lo_s.0 => res0,
+        inout("ymm1") lo_s.1 => res1,
+        inout("ymm2") lo_s.2 => res2,
         inout("ymm3") hi.0 => _,
         inout("ymm4") hi.1 => _,
         inout("ymm5") hi.2 => _,
         out("ymm6") _, out("ymm7") _, out("ymm8") _, out("ymm9") _, out("ymm10") _, out("ymm11") _,
-        in("ymm14") epsilon,
+        in("ymm14") epsilon, in("ymm15") sign_bit,
         options(pure, nomem, preserves_flags, nostack),
     );
-    (res0_s, res1_s, res2_s)
+    (res0, res1, res2)
 }
 
 #[inline(always)]
-unsafe fn mds_layer_multiply_s(
+unsafe fn mds_multiply_and_add_round_const_s(
     state: (__m256i, __m256i, __m256i),
+    (base, index): (*const u64, usize),
 ) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
     // TODO: Would it be faster to save the input to memory and do unaligned
     // loads instead of swizzling? It would reduce pressure on port 5 but it
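For orientation, the reduction being performed here is the standard Goldilocks one. A scalar sketch (not the diff's asm, and ignoring the 2^63 shift the vector code carries): each lane holds an unreduced value split, as far as the surrounding comments indicate, as lo + 2^32 * hi, and the identities 2^64 ≡ EPSILON and 2^96 ≡ -1 (mod ORDER) bring it back under 64 bits.

```rust
// Scalar sketch of the Goldilocks reduction mds_layer_reduce implements with AVX2.
const ORDER: u64 = 0xffffffff00000001;
const EPSILON: u64 = 0xffffffff;

fn reduce128(x: u128) -> u64 {
    let x_lo = x as u64;
    let x_hi = (x >> 64) as u64;
    let x_hi_hi = x_hi >> 32;      // coefficient of 2^96, contributes -x_hi_hi
    let x_hi_lo = x_hi & EPSILON;  // coefficient of 2^64, contributes x_hi_lo * EPSILON
    let (mut t0, borrow) = x_lo.overflowing_sub(x_hi_hi);
    if borrow {
        t0 = t0.wrapping_sub(EPSILON); // compensate the 2^64 borrow: -2^64 ≡ -EPSILON
    }
    let t1 = x_hi_lo * EPSILON;
    let (mut t2, carry) = t0.overflowing_add(t1);
    if carry {
        t2 = t2.wrapping_add(EPSILON); // compensate the 2^64 carry: +2^64 ≡ +EPSILON
    }
    t2 // possibly non-canonical, i.e. in [0, 2^64), like the vector code's output
}

fn main() {
    let (lo, hi) = (0xdead_beef_0123_4567u64, 0x1_2345_6789u64); // hi well under 2^50
    let value = lo as u128 + ((hi as u128) << 32);
    assert_eq!(reduce128(value) as u128 % ORDER as u128, value % ORDER as u128);
}
```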
@@ -308,9 +324,11 @@ unsafe fn mds_layer_multiply_s(
     // ymm5[0:2] += ymm7[2:4]
     // ymm5[2:4] += ymm8[0:2]
     // Thus, the final result resides in ymm3, ymm4, ymm5.
+
+    // WARNING: This code assumes that sum(1 << exp for exp in MDS_EXPS) * 0xffffffff fits in a
+    // u64. If this guarantee ceases to hold, then it will no longer be correct.
     let (unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s): (__m256i, __m256i, __m256i);
     let (unreduced_hi0, unreduced_hi1, unreduced_hi2): (__m256i, __m256i, __m256i);
-    let sign_bit = _mm256_set1_epi64x(i64::MIN);
     let epsilon = _mm256_set1_epi64x(0xffffffff);
     asm!(
         // Extract low 32 bits of the word
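Since the MDS matrix is a circulant of powers of two, each output row is just shifts and adds of the state words, done once on the low 32-bit halves and once on the high halves. An illustrative scalar model follows; the row/column indexing convention is an assumption here (the diff only states that TOP_ROW_EXPS is the matrix's top row), so treat it as a sketch rather than the code's exact layout.

```rust
// Illustrative scalar model of one unreduced MDS row for a circulant matrix of powers of two.
const WIDTH: usize = 12;
static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];

fn mds_row_unreduced(state: &[u64; WIDTH], r: usize) -> (u64, u64) {
    // Accumulate the 32-bit halves separately so each sum stays below
    // 0x11537 * 0xffffffff and never overflows a u64.
    let mut acc_lo: u64 = 0;
    let mut acc_hi: u64 = 0;
    for c in 0..WIDTH {
        // Circulant matrix: the entry at (r, c) depends only on c - r (mod WIDTH).
        let exp = TOP_ROW_EXPS[(c + WIDTH - r) % WIDTH];
        acc_lo += (state[c] & 0xffffffff) << exp;
        acc_hi += (state[c] >> 32) << exp;
    }
    // The full unreduced value is acc_lo + 2^32 * acc_hi; reduction is a separate pass.
    (acc_lo, acc_hi)
}

fn main() {
    let state = [0x0123_4567_89ab_cdefu64; WIDTH];
    let (lo, hi) = mds_row_unreduced(&state, 0);
    println!("unreduced row 0: lo = {lo:#x}, hi = {hi:#x}");
}
```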
@@ -500,11 +518,17 @@ unsafe fn mds_layer_multiply_s(
         "vpsrlq ymm11, ymm2, 32",
 
         // Need to move the low result from ymm3-ymm5 to ymm0-2 so it is not
-        // overwritten. Save three instructions by combining the move with xor ymm15,
-        // which would otherwise be done in 3:.
-        "vpxor ymm0, ymm15, ymm3",
-        "vpxor ymm1, ymm15, ymm4",
-        "vpxor ymm2, ymm15, ymm5",
+        // overwritten. Save three instructions by combining the move with the constant layer,
+        // which would otherwise be done in 3:. The round constants include the shift by 2**63, so
+        // the resulting ymm0,1,2 are also shifted by 2**63.
+        // It is safe to add the round constants here without checking for overflow. The values in
+        // ymm3,4,5 are guaranteed to be <= 0x11536fffeeac9. All round constants are < 2**64
+        // - 0x11536fffeeac9.
+        // WARNING: If this guarantee ceases to hold due to a change in the MDS matrix or round
+        // constants, then this code will no longer be correct.
+        "vpaddq ymm0, ymm3, [{base} + {index}]",
+        "vpaddq ymm1, ymm4, [{base} + {index} + 32]",
+        "vpaddq ymm2, ymm5, [{base} + {index} + 64]",
 
         // MDS matrix multiplication, again. This time on high 32 bits.
         // Jump to the _local label_ (see above) `2:`. `b` for _backward_ specifies the direction.
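The headroom bound cited in that comment follows directly from TOP_ROW_EXPS. A tiny sanity check, runnable on its own with values taken from this diff:

```rust
// Check of the no-overflow argument: with TOP_ROW_EXPS as in this diff, the unreduced
// per-lane sums leave room for any round constant below 0xfffeeac900011537 (the bound
// poseidon.rs warns about further down).
static TOP_ROW_EXPS: [usize; 12] = [0, 10, 16, 3, 12, 8, 1, 5, 3, 0, 1, 0];

fn main() {
    // Sum of the MDS powers of two, i.e. the worst-case weight on a 32-bit half.
    let weight_sum: u64 = TOP_ROW_EXPS.iter().map(|&exp| 1u64 << exp).sum();
    assert_eq!(weight_sum, 0x11537);
    let max_unreduced = weight_sum * 0xffffffff;
    assert_eq!(max_unreduced, 0x11536fffeeac9);
    // Adding any round constant < 2^64 - 0x11536fffeeac9 = 0xfffeeac900011537 cannot wrap:
    assert_eq!(max_unreduced.wrapping_add(0xfffeeac900011537), 0);
}
```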
@@ -514,7 +538,10 @@ unsafe fn mds_layer_multiply_s(
         "3:",
         // Just done the MDS matrix multiplication on high 32 bits.
         // The high results are in ymm3, ymm4, ymm5.
-        // The low results (shifted by 2**63) are in ymm0, ymm1, ymm2
+        // The low results (shifted by 2**63 and including the following constant layer) are in
+        // ymm0, ymm1, ymm2.
+        base = in(reg) base,
+        index = in(reg) index,
         inout("ymm0") state.0 => unreduced_lo0_s,
         inout("ymm1") state.1 => unreduced_lo1_s,
         inout("ymm2") state.2 => unreduced_lo2_s,
@@ -523,7 +550,7 @@ unsafe fn mds_layer_multiply_s(
         out("ymm5") unreduced_hi2,
         out("ymm6") _, out("ymm7") _, out("ymm8") _, out("ymm9") _,
         out("ymm10") _, out("ymm11") _, out("ymm12") _, out("ymm13") _,
-        in("ymm14") epsilon, in("ymm15") sign_bit,
+        in("ymm14") epsilon,
        out("rax") _,
         options(pure, nomem, nostack),
     );
@@ -534,9 +561,12 @@ unsafe fn mds_layer_multiply_s(
 }
 
 #[inline(always)]
-unsafe fn mds_layer_full_s(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
-    let (unreduced_lo_s, unreduced_hi) = mds_layer_multiply_s(state);
-    mds_layer_reduce_s(unreduced_lo_s, unreduced_hi)
+unsafe fn mds_const_layers_full(
+    state: (__m256i, __m256i, __m256i),
+    round_constants: (*const u64, usize),
+) -> (__m256i, __m256i, __m256i) {
+    let (unreduced_lo_s, unreduced_hi) = mds_multiply_and_add_round_const_s(state, round_constants);
+    mds_layer_reduce(unreduced_lo_s, unreduced_hi)
 }
 
 /// Compute x ** 7
@@ -620,8 +650,9 @@ unsafe fn sbox_partial(mut x: u64) -> u64 {
 }
 
 #[inline(always)]
-unsafe fn sbox_mds_layers_partial_s(
+unsafe fn partial_round(
     (state0, state1, state2): (__m256i, __m256i, __m256i),
+    round_constants: (*const u64, usize),
 ) -> (__m256i, __m256i, __m256i) {
     // Extract the low quadword
     let state0ab: __m128i = _mm256_castsi256_si128(state0);
@@ -638,7 +669,7 @@ unsafe fn sbox_mds_layers_partial_s(
     let (
         (mut unreduced_lo0_s, mut unreduced_lo1_s, mut unreduced_lo2_s),
         (mut unreduced_hi0, mut unreduced_hi1, mut unreduced_hi2),
-    ) = mds_layer_multiply_s((state0bcd, state1, state2));
+    ) = mds_multiply_and_add_round_const_s((state0bcd, state1, state2), round_constants);
     asm!(
         // Just done the MDS matrix multiplication on high 32 bits.
         // The high results are in ymm3, ymm4, ymm5.
@@ -688,88 +719,72 @@ unsafe fn sbox_mds_layers_partial_s(
         in("ymm14") epsilon,
         options(pure, nomem, preserves_flags, nostack),
     );
-    mds_layer_reduce_s(
+    mds_layer_reduce(
         (unreduced_lo0_s, unreduced_lo1_s, unreduced_lo2_s),
         (unreduced_hi0, unreduced_hi1, unreduced_hi2),
     )
 }
 
 #[inline(always)]
-unsafe fn full_round_s(
-    state_s: (__m256i, __m256i, __m256i),
-    round_constants: (*const GoldilocksField, usize),
+unsafe fn full_round(
+    state: (__m256i, __m256i, __m256i),
+    round_constants: (*const u64, usize),
 ) -> (__m256i, __m256i, __m256i) {
-    let state = const_layer(state_s, round_constants);
     let state = sbox_layer_full(state);
-    let state_s = mds_layer_full_s(state);
-    state_s
-}
-
-#[inline(always)]
-unsafe fn partial_round_s(
-    state_s: (__m256i, __m256i, __m256i),
-    round_constants: (*const GoldilocksField, usize),
-) -> (__m256i, __m256i, __m256i) {
-    let state = const_layer(state_s, round_constants);
-    let state_s = sbox_mds_layers_partial_s(state);
-    state_s
+    let state = mds_const_layers_full(state, round_constants);
+    state
 }
 
 #[inline] // Called twice; permit inlining but don't _require_ it
-unsafe fn half_full_rounds_s(
-    mut state_s: (__m256i, __m256i, __m256i),
+unsafe fn half_full_rounds(
+    mut state: (__m256i, __m256i, __m256i),
     start_round: usize,
 ) -> (__m256i, __m256i, __m256i) {
-    let base = (&ALL_ROUND_CONSTANTS
+    let base = (&FUSED_ROUND_CONSTANTS
         [WIDTH * start_round..WIDTH * start_round + WIDTH * HALF_N_FULL_ROUNDS])
-        .as_ptr()
-        .cast::<GoldilocksField>();
+        .as_ptr();
 
     for i in 0..HALF_N_FULL_ROUNDS {
-        state_s = full_round_s(state_s, (base, i * WIDTH * size_of::<u64>()));
+        state = full_round(state, (base, i * WIDTH * size_of::<u64>()));
     }
-    state_s
+    state
 }
 
 #[inline(always)]
-unsafe fn all_partial_rounds_s(
-    mut state_s: (__m256i, __m256i, __m256i),
+unsafe fn all_partial_rounds(
+    mut state: (__m256i, __m256i, __m256i),
     start_round: usize,
 ) -> (__m256i, __m256i, __m256i) {
-    let base = (&ALL_ROUND_CONSTANTS
+    let base = (&FUSED_ROUND_CONSTANTS
         [WIDTH * start_round..WIDTH * start_round + WIDTH * N_PARTIAL_ROUNDS])
-        .as_ptr()
-        .cast::<GoldilocksField>();
+        .as_ptr();
 
     for i in 0..N_PARTIAL_ROUNDS {
-        state_s = partial_round_s(state_s, (base, i * WIDTH * size_of::<u64>()));
+        state = partial_round(state, (base, i * WIDTH * size_of::<u64>()));
    }
-    state_s
+    state
 }
 
 #[inline]
 pub unsafe fn poseidon(state: &[GoldilocksField; 12]) -> [GoldilocksField; 12] {
-    let sign_bit = _mm256_set1_epi64x(i64::MIN);
+    let state = (
+        _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
+        _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
+        _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
+    );
 
-    let mut s0 = _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>());
-    let mut s1 = _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>());
-    let mut s2 = _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>());
-    s0 = _mm256_xor_si256(s0, sign_bit);
-    s1 = _mm256_xor_si256(s1, sign_bit);
-    s2 = _mm256_xor_si256(s2, sign_bit);
+    // The first constant layer must be done explicitly. The remaining constant layers are fused
+    // with the preceding MDS layer.
+    let state = const_layer(state, &ALL_ROUND_CONSTANTS[0..WIDTH].try_into().unwrap());
 
-    (s0, s1, s2) = half_full_rounds_s((s0, s1, s2), 0);
-    (s0, s1, s2) = all_partial_rounds_s((s0, s1, s2), HALF_N_FULL_ROUNDS);
-    (s0, s1, s2) = half_full_rounds_s((s0, s1, s2), HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS);
-
-    s0 = _mm256_xor_si256(s0, sign_bit);
-    s1 = _mm256_xor_si256(s1, sign_bit);
-    s2 = _mm256_xor_si256(s2, sign_bit);
+    let state = half_full_rounds(state, 0);
+    let state = all_partial_rounds(state, HALF_N_FULL_ROUNDS);
+    let state = half_full_rounds(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS);
 
     let mut res = [GoldilocksField::ZERO; 12];
-    _mm256_storeu_si256((&mut res[0..4]).as_mut_ptr().cast::<__m256i>(), s0);
-    _mm256_storeu_si256((&mut res[4..8]).as_mut_ptr().cast::<__m256i>(), s1);
-    _mm256_storeu_si256((&mut res[8..12]).as_mut_ptr().cast::<__m256i>(), s2);
+    _mm256_storeu_si256((&mut res[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
+    _mm256_storeu_si256((&mut res[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
+    _mm256_storeu_si256((&mut res[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
 
     res
 }
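Why this reshuffling leaves the permutation unchanged is pure algebra: pushing each round's constant layer through to the tail of the previous round's MDS multiply is sound as long as the first constant layer is applied explicitly once. A toy-field sketch (not plonky2 code; Z_p with p = 97, width 3, made-up constants, all names local to the sketch) that checks the two schedules agree:

```rust
// Toy demonstration of the fusion identity behind FUSED_ROUND_CONSTANTS.
const P: u64 = 97;
const W: usize = 3;
const N_R: usize = 4;
const RC: [[u64; W]; N_R] = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]];
const MDS: [[u64; W]; W] = [[2, 1, 1], [1, 2, 1], [1, 1, 2]];

fn sbox(x: [u64; W]) -> [u64; W] {
    x.map(|v| v.pow(7) % P) // x^7, like Poseidon's S-box
}

fn mds(x: [u64; W]) -> [u64; W] {
    let mut y = [0u64; W];
    for r in 0..W {
        for c in 0..W {
            y[r] = (y[r] + MDS[r][c] * x[c]) % P;
        }
    }
    y
}

fn add_rc(x: [u64; W], rc: [u64; W]) -> [u64; W] {
    let mut y = x;
    for i in 0..W {
        y[i] = (y[i] + rc[i]) % P;
    }
    y
}

fn main() {
    let input = [13u64, 42, 7];

    // Original schedule: every round starts with its own constant layer.
    let mut a = input;
    for i in 0..N_R {
        a = mds(sbox(add_rc(a, RC[i])));
    }

    // Fused schedule: apply RC[0] once up front; round i then adds RC[i + 1]
    // (nothing for the last round) immediately after its MDS multiply.
    let mut b = add_rc(input, RC[0]);
    for i in 0..N_R {
        b = mds(sbox(b));
        if i + 1 < N_R {
            b = add_rc(b, RC[i + 1]);
        }
    }

    assert_eq!(a, b);
}
```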
@@ -17,7 +17,7 @@ use crate::plonk::circuit_builder::CircuitBuilder;
 pub(crate) const HALF_N_FULL_ROUNDS: usize = 4;
 pub(crate) const N_FULL_ROUNDS_TOTAL: usize = 2 * HALF_N_FULL_ROUNDS;
 pub(crate) const N_PARTIAL_ROUNDS: usize = 22;
-const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS;
+pub(crate) const N_ROUNDS: usize = N_FULL_ROUNDS_TOTAL + N_PARTIAL_ROUNDS;
 const MAX_WIDTH: usize = 12; // we only have width 8 and 12, and 12 is bigger. :)
 
 #[inline(always)]
@@ -45,6 +45,10 @@ pub const ALL_ROUND_CONSTANTS: [u64; MAX_WIDTH * N_ROUNDS] = [
 // not met, some platform-specific implementation of constant_layer may return incorrect
 // results.
 //
+// WARNING: The AVX2 Goldilocks specialization relies on all round constants being in
+// 0..0xfffeeac900011537. If these constants are randomly regenerated, there is a ~.6% chance
+// that this condition will no longer hold.
+//
 // WARNING: If these are changed in any way, then all the
 // implementations of Poseidon must be regenerated. See comments
 // in `poseidon_goldilocks.rs` and `poseidon_crandall.rs` for