Mirror of https://github.com/logos-storage/plonky2.git (synced 2026-01-08 08:43:06 +00:00)

Restore vectorization to full Poseidon rounds on Aarch64 (#498)

* Restore vectorization to full Poseidon layers on Aarch64
* Typos

parent 6072fab077
commit c7af639579
@@ -1,2 +1,2 @@
// #[cfg(target_feature = "neon")]
// pub(crate) mod poseidon_goldilocks_neon;
#[cfg(target_feature = "neon")]
pub(crate) mod poseidon_goldilocks_neon;

@@ -2,37 +2,24 @@

use std::arch::aarch64::*;
use std::arch::asm;
use std::mem::transmute;

use plonky2_field::field_types::Field64;
use plonky2_field::goldilocks_field::GoldilocksField;
use plonky2_util::branch_hint;
use static_assertions::const_assert;
use unroll::unroll_for_loops;

use crate::hash::poseidon::{
    Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS,
};
use crate::hash::poseidon::Poseidon;

// ========================================== CONSTANTS ===========================================

const WIDTH: usize = 12;

// The order below is arbitrary. Repeated coefficients have been removed so these constants fit in
// two registers.
// TODO: ensure this is aligned to 16 bytes (for vector loads), ideally on the same cacheline
const MDS_CONSTS: [u32; 8] = [
    0xffffffff,
    1 << 1,
    1 << 3,
    1 << 5,
    1 << 8,
    1 << 10,
    1 << 12,
    1 << 16,
];
const EPSILON: u64 = 0xffffffff;
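// Worked congruence behind this constant (illustrative note, not part of this diff): with
// ORDER = 2**64 - 2**32 + 1, we have 2**64 - ORDER = 2**32 - 1 = EPSILON, i.e.
// 2**64 ≡ EPSILON (mod ORDER). This is why the reductions below fold bits at position 64 and
// above back in by multiplying them by EPSILON and adding, and why an overflowing add is repaired
// by adding EPSILON once more.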

// The round constants to be applied by the second set of full rounds. These are just the usual round constants,
// shifted by one round, with zeros shifted in.
// The round constants to be applied by the second set of full rounds. These are just the usual
// round constants, shifted by one round, with zeros shifted in.
/*
const fn make_final_round_constants() -> [u64; WIDTH * HALF_N_FULL_ROUNDS] {
    let mut res = [0; WIDTH * HALF_N_FULL_ROUNDS];
    let mut i: usize = 0;
@@ -43,6 +30,7 @@ const fn make_final_round_constants() -> [u64; WIDTH * HALF_N_FULL_ROUNDS] {
    res
}
const FINAL_ROUND_CONSTANTS: [u64; WIDTH * HALF_N_FULL_ROUNDS] = make_final_round_constants();
*/

// ===================================== COMPILE-TIME CHECKS ======================================

@@ -52,9 +40,12 @@ const FINAL_ROUND_CONSTANTS: [u64; WIDTH * HALF_N_FULL_ROUNDS] = make_final_roun
const fn check_mds_matrix() -> bool {
    // Can't == two arrays in a const_assert! (:
    let mut i = 0;
    let wanted_matrix_exps = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10];
    let wanted_matrix_circ = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20];
    let wanted_matrix_diag = [8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
    while i < WIDTH {
        if <GoldilocksField as Poseidon>::MDS_MATRIX_EXPS[i] != wanted_matrix_exps[i] {
        if <GoldilocksField as Poseidon>::MDS_MATRIX_CIRC[i] != wanted_matrix_circ[i]
            || <GoldilocksField as Poseidon>::MDS_MATRIX_DIAG[i] != wanted_matrix_diag[i]
        {
            return false;
        }
        i += 1;
@@ -63,37 +54,10 @@ const fn check_mds_matrix() -> bool {
}
const_assert!(check_mds_matrix());

/// The maximum amount by which the MDS matrix will multiply the input.
/// i.e. max(MDS(state)) <= mds_matrix_inf_norm() * max(state).
const fn mds_matrix_inf_norm() -> u64 {
    let mut cumul = 0;
    let mut i = 0;
    while i < WIDTH {
        cumul += 1 << <GoldilocksField as Poseidon>::MDS_MATRIX_EXPS[i];
        i += 1;
    }
    cumul
}

/// Ensure that adding round constants to the low result of the MDS multiplication can never
/// overflow.
#[allow(dead_code)]
const fn check_round_const_bounds_mds() -> bool {
    let max_mds_res = mds_matrix_inf_norm() * (u32::MAX as u64);
    let mut i = WIDTH; // First const layer is handled specially.
    while i < WIDTH * N_ROUNDS {
        if ALL_ROUND_CONSTANTS[i].overflowing_add(max_mds_res).1 {
            return false;
        }
        i += 1;
    }
    true
}
const_assert!(check_round_const_bounds_mds());
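// A worked instance of the two checks above (illustrative note, not part of this diff): with
// MDS_MATRIX_EXPS = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10],
//     mds_matrix_inf_norm() == 1 + 1 + 2 + 1 + 8 + 32 + 2 + 256 + 4096 + 8 + 65536 + 1024 == 70967,
// so an MDS output computed from u32-sized inputs is at most 70967 * (2**32 - 1) < 2**49. That
// leaves ample headroom below 2**64 for adding a round constant, which is exactly what
// check_round_const_bounds_mds verifies for every constant after the first layer.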

/// Ensure that the first WIDTH round constants are in canonical* form. This is required because
/// the first constant layer does not handle double overflow.
/// *: round_const == GoldilocksField::ORDER is safe.
/*
#[allow(dead_code)]
const fn check_round_const_bounds_init() -> bool {
    let mut i = 0;
@@ -106,11 +70,9 @@ const fn check_round_const_bounds_init() -> bool {
    true
}
const_assert!(check_round_const_bounds_init());

*/
// ====================================== SCALAR ARITHMETIC =======================================

const EPSILON: u64 = 0xffffffff;

/// Addition modulo ORDER accounting for wraparound. Correct only when a + b < 2**64 + ORDER.
#[inline(always)]
unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 {
@@ -133,7 +95,16 @@ unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 {
/// Subtraction of a and (b >> 32) modulo ORDER accounting for wraparound.
#[inline(always)]
unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 {
    let b_hi = b >> 32;
    let mut b_hi = b >> 32;
    // Make sure that LLVM emits two separate instructions for the shift and the subtraction. This
    // reduces pressure on the execution units with access to the flags, as they are no longer
    // responsible for the shift. The hack is to insert a fake computation between the two
    // instructions with an `asm` block to make LLVM think that they can't be merged.
    asm!(
        "/* {0} */", // Make Rust think we're using the register.
        inlateout(reg) b_hi,
        options(nomem, nostack, preserves_flags, pure),
    );
    // This could be done with a.overflowing_add(b_hi), but `checked_sub` signals to the compiler
    // that overflow is unlikely (note: this is a standard library implementation detail, not part
    // of the spec).
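// Illustrative scalar sketch (not part of this commit) of what the wraparound helpers above
// compute, assuming ORDER = 2**64 - 2**32 + 1 and EPSILON = 2**32 - 1 as defined earlier: a sum
// that wraps past 2**64 is short by exactly 2**64 ≡ EPSILON (mod ORDER), and a difference that
// borrows is too large by the same amount.
fn add_with_wraparound_sketch(a: u64, b: u64) -> u64 {
    let (sum, wrapped) = a.overflowing_add(b);
    // Valid when a + b < 2**64 + ORDER, so the corrected sum cannot wrap a second time.
    sum + if wrapped { 0xffff_ffff } else { 0 }
}

fn sub_with_wraparound_lsr32_sketch(a: u64, b: u64) -> u64 {
    let b_hi = b >> 32; // b_hi < 2**32
    let (diff, borrowed) = a.overflowing_sub(b_hi);
    // On a borrow the wrapped difference exceeds 2**64 - 2**32, so subtracting EPSILON is safe.
    if borrowed { diff - 0xffff_ffff } else { diff }
}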
@@ -153,7 +124,8 @@ unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 {
unsafe fn mul_epsilon(x: u64) -> u64 {
    let res;
    asm!(
        // Use UMULL to save one instruction. The compiler emits two: extract the low word and then multiply.
        // Use UMULL to save one instruction. The compiler emits two: extract the low word and then
        // multiply.
        "umull {res}, {x:w}, {epsilon:w}",
        x = in(reg) x,
        epsilon = in(reg) EPSILON,
@@ -179,8 +151,9 @@ unsafe fn multiply(x: u64, y: u64) -> u64 {

// ==================================== STANDALONE CONST LAYER =====================================

/// Standalone const layer. Run only once, at the start of round 1. Remaining const layers are fused with the preceeding
/// MDS matrix multiplication.
/// Standalone const layer. Run only once, at the start of round 1. Remaining const layers are fused
/// with the preceeding MDS matrix multiplication.
/*
#[inline(always)]
#[unroll_for_loops]
unsafe fn const_layer_full(
@@ -195,15 +168,15 @@ unsafe fn const_layer_full(
    }
    state
}

*/
// ========================================== FULL ROUNDS ==========================================

/// Full S-box.
#[inline(always)]
#[unroll_for_loops]
unsafe fn sbox_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] {
    // This is done in scalar. S-boxes in vector are only slightly slower throughput-wise but have an insane latency
    // (~100 cycles) on the M1.
    // This is done in scalar. S-boxes in vector are only slightly slower throughput-wise but have
    // an insane latency (~100 cycles) on the M1.

    let mut state2 = [0u64; WIDTH];
    assert!(WIDTH == 12);
@@ -228,297 +201,227 @@ unsafe fn sbox_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] {
    state7
}
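// Illustrative sketch (not part of this commit): per element, the full S-box above is the
// Goldilocks Poseidon power map x -> x^7, which needs only four modular multiplications. `mul`
// stands in for a field multiplication such as the `multiply` helper above.
fn sbox_monomial_sketch(x: u64, mul: impl Fn(u64, u64) -> u64) -> u64 {
    let x2 = mul(x, x); // x^2
    let x3 = mul(x2, x); // x^3
    let x4 = mul(x2, x2); // x^4
    mul(x3, x4) // x^7
}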

// Aliases for readability. E.g. MDS[5] can be found in mdsv5[MDSI5].
const MDSI2: i32 = 1; // MDS[2] == 1
const MDSI4: i32 = 2; // MDS[4] == 3
const MDSI5: i32 = 3; // MDS[5] == 5
const MDSI6: i32 = 1; // MDS[6] == 1
const MDSI7: i32 = 0; // MDS[7] == 8
const MDSI8: i32 = 2; // MDS[8] == 12
const MDSI9: i32 = 2; // MDS[9] == 3
const MDSI10: i32 = 3; // MDS[10] == 16
const MDSI11: i32 = 1; // MDS[11] == 10

#[inline(always)]
unsafe fn mds_reduce(
    [[cumul0_a, cumul0_b], [cumul1_a, cumul1_b]]: [[uint64x2_t; 2]; 2],
    // `cumul_a` and `cumul_b` represent two separate field elements. We take advantage of
    // vectorization by reducing them simultaneously.
    [cumul_a, cumul_b]: [uint32x4_t; 2],
) -> uint64x2_t {
    // mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5]
    let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::<u32>());

    // Merge accumulators
    let cumul0 = vaddq_u64(cumul0_a, cumul0_b);
    let cumul1 = vaddq_u64(cumul1_a, cumul1_b);

    // Swizzle
    let res_lo = vzip1q_u64(cumul0, cumul1);
    let res_hi = vzip2q_u64(cumul0, cumul1);

    // Reduce from u96
    let res_hi = vsraq_n_u64::<32>(res_hi, res_lo);
    let res_lo = vsliq_n_u64::<32>(res_lo, res_hi);

    // Extract high 32-bits.
    let res_hi_hi = vget_low_u32(vuzp2q_u32(
        vreinterpretq_u32_u64(res_hi),
        vreinterpretq_u32_u64(res_hi),
    ));

    // Multiply by EPSILON and accumulate.
    let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0);
    let res_adj = vcgtq_u64(res_lo, res_unadj);
    vsraq_n_u64::<32>(res_unadj, res_adj)
    // Form:
    // `lo = [cumul_a[0] + cumul_a[2] * 2**32, cumul_b[0] + cumul_b[2] * 2**32]`
    // `hi = [cumul_a[1] + cumul_a[3] * 2**32, cumul_b[1] + cumul_b[3] * 2**32]`
    // Observe that the result `== lo + hi * 2**16 (mod Goldilocks)`.
    let mut lo = vreinterpretq_u64_u32(vuzp1q_u32(cumul_a, cumul_b));
    let mut hi = vreinterpretq_u64_u32(vuzp2q_u32(cumul_a, cumul_b));
    // Add the high 48 bits of `lo` to `hi`. This cannot overflow.
    hi = vsraq_n_u64::<16>(hi, lo);
    // Now, result `== lo.bits[0..16] + hi * 2**16 (mod Goldilocks)`.
    // Set the high 48 bits of `lo` to the low 48 bits of `hi`.
    lo = vsliq_n_u64::<16>(lo, hi);
    // At this point, result `== lo + hi.bits[48..64] * 2**64 (mod Goldilocks)`.
    // It remains to fold `hi.bits[48..64]` into `lo`.
    let top = {
        // Extract the top 16 bits of `hi` as a `u32`.
        // Interpret `hi` as a vector of bytes, so we can use a table lookup instruction.
        let hi_u8 = vreinterpretq_u8_u64(hi);
        // Indices defining the permutation. `0xff` is out of bounds, producing `0`.
        let top_idx =
            transmute::<[u8; 8], uint8x8_t>([0x06, 0x07, 0xff, 0xff, 0x0e, 0x0f, 0xff, 0xff]);
        let top_u8 = vqtbl1_u8(hi_u8, top_idx);
        vreinterpret_u32_u8(top_u8)
    };
    // result `== lo + top * 2**64 (mod Goldilocks)`.
    let adj_lo = vmlal_n_u32(lo, top, EPSILON as u32);
    let wraparound_mask = vcgtq_u64(lo, adj_lo);
    vsraq_n_u64::<32>(adj_lo, wraparound_mask) // Add epsilon on overflow.
}
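// Illustrative scalar sketch (not part of this commit) of the reduction performed above, assuming
// the chunked representation used by this file: elem == c[0] + c[1] * 2**16 + c[2] * 2**32 +
// c[3] * 2**48, with each chunk held in 32 bits. Like the vector code, the result may be a
// non-canonical representative (it fits in u64 but may exceed ORDER).
fn mds_reduce_sketch(c: [u32; 4]) -> u64 {
    const EPSILON: u64 = 0xffff_ffff; // 2**64 mod ORDER
    let wide = (c[0] as u128)
        + ((c[1] as u128) << 16)
        + ((c[2] as u128) << 32)
        + ((c[3] as u128) << 48);
    let lo = wide as u64;
    let hi = (wide >> 64) as u64; // under 2**17, since wide < 2**81
    let (adj, wrapped) = lo.overflowing_add(hi * EPSILON);
    adj + if wrapped { EPSILON } else { 0 } // fold 2**64 back in as EPSILON on overflow
}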

#[inline(always)]
unsafe fn mds_const_layers_full(
    state: [u64; WIDTH],
    round_constants: &[u64; WIDTH],
) -> [u64; WIDTH] {
    // mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5]
    // mds_consts1 == [1 << 8, 1 << 10, 1 << 12, 1 << 16]
    let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::<u32>());
    let mds_consts1: uint32x4_t = vld1q_u32((&MDS_CONSTS[4..8]).as_ptr().cast::<u32>());
unsafe fn mds_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] {
    // This function performs an MDS multiplication in complex FFT space.
    // However, instead of performing a width-12 FFT, we perform three width-4 FFTs, which is
    // cheaper. The 12x12 matrix-vector multiplication (a convolution) becomes two 3x3 real
    // matrix-vector multiplications and one 3x3 complex matrix-vector multiplication.

    // Aliases for readability. E.g. MDS[5] can be found in mdsv5[mdsi5]. MDS[0], MDS[1], and
    // MDS[3] are 0, so they are not needed.
    let mdsv2 = mds_consts0; // MDS[2] == 1
    let mdsv4 = mds_consts0; // MDS[4] == 3
    let mdsv5 = mds_consts0; // MDS[5] == 5
    let mdsv6 = mds_consts0; // MDS[6] == 1
    let mdsv7 = mds_consts1; // MDS[7] == 8
    let mdsv8 = mds_consts1; // MDS[8] == 12
    let mdsv9 = mds_consts0; // MDS[9] == 3
    let mdsv10 = mds_consts1; // MDS[10] == 16
    let mdsv11 = mds_consts1; // MDS[11] == 10
    // We split each 64-bit into four chunks of 16 bits. To prevent overflow, each chunk is 32 bits
    // long. Each NEON vector below represents one field element and consists of four 32-bit chunks:
    // `elem == vector[0] + vector[1] * 2**16 + vector[2] * 2**32 + vector[3] * 2**48`.
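    // For instance (illustrative note, not part of this diff): before any accumulation, the
    // element 0x0123_4567_89ab_cdef would be held as the chunks [0xcdef, 0x89ab, 0x4567, 0x0123],
    // each zero-extended to 32 bits so the additions and small-constant multiplies below have
    // headroom before mds_reduce folds the chunks back into a single u64.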

    // For i even, we combine state[i] and state[i + 1] into one vector to save on registers.
    // Thus, state1 actually contains state0 and state1 but is only used in the intrinsics that
    // access the high high doubleword.
    let state1: uint32x4_t =
        vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[0]), vcreate_u64(state[1])));
    let state3: uint32x4_t =
        vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[2]), vcreate_u64(state[3])));
    let state5: uint32x4_t =
        vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[4]), vcreate_u64(state[5])));
    let state7: uint32x4_t =
        vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[6]), vcreate_u64(state[7])));
    let state9: uint32x4_t =
        vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[8]), vcreate_u64(state[9])));
    let state11: uint32x4_t =
        vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[10]), vcreate_u64(state[11])));
    // state0 is an alias to the low doubleword of state1. The compiler should use one register for both.
    let state0: uint32x2_t = vget_low_u32(state1);
    let state2: uint32x2_t = vget_low_u32(state3);
    let state4: uint32x2_t = vget_low_u32(state5);
    let state6: uint32x2_t = vget_low_u32(state7);
    let state8: uint32x2_t = vget_low_u32(state9);
    let state10: uint32x2_t = vget_low_u32(state11);
    // Constants that we multiply by.
    let mut consts: uint32x4_t = transmute::<[u32; 4], _>([2, 4, 8, 16]);

    // Two accumulators per output to hide latency. Each accumulator is a vector of two u64s,
    // containing the result for the low 32 bits and the high 32 bits. Thus, the final result at
    // index i is (cumuli_a[0] + cumuli_b[0]) + (cumuli_a[1] + cumuli_b[1]) * 2**32.
    // Prevent LLVM from turning fused multiply (by power of 2)-add (1 instruction) into shift and
    // add (two instructions). This fake `asm` block means that LLVM no longer knows the contents of
    // `consts`.
    asm!("/* {0:v} */", // Make Rust think the register is being used.
        inout(vreg) consts,
        options(pure, nomem, nostack, preserves_flags),
    );

    // Start by loading the round constants.
    let mut cumul0_a = vcombine_u64(vld1_u64(&round_constants[0]), vcreate_u64(0));
    let mut cumul1_a = vcombine_u64(vld1_u64(&round_constants[1]), vcreate_u64(0));
    let mut cumul2_a = vcombine_u64(vld1_u64(&round_constants[2]), vcreate_u64(0));
    let mut cumul3_a = vcombine_u64(vld1_u64(&round_constants[3]), vcreate_u64(0));
    let mut cumul4_a = vcombine_u64(vld1_u64(&round_constants[4]), vcreate_u64(0));
    let mut cumul5_a = vcombine_u64(vld1_u64(&round_constants[5]), vcreate_u64(0));
    let mut cumul6_a = vcombine_u64(vld1_u64(&round_constants[6]), vcreate_u64(0));
    let mut cumul7_a = vcombine_u64(vld1_u64(&round_constants[7]), vcreate_u64(0));
    let mut cumul8_a = vcombine_u64(vld1_u64(&round_constants[8]), vcreate_u64(0));
    let mut cumul9_a = vcombine_u64(vld1_u64(&round_constants[9]), vcreate_u64(0));
    let mut cumul10_a = vcombine_u64(vld1_u64(&round_constants[10]), vcreate_u64(0));
    let mut cumul11_a = vcombine_u64(vld1_u64(&round_constants[11]), vcreate_u64(0));
    // Four length-3 complex FFTs.
    let mut state_fft = [vdupq_n_u32(0); 12];
    for i in 0..3 {
        // Interpret each field element as a 4-vector of `u16`s.
        let x0 = vcreate_u16(state[i]);
        let x1 = vcreate_u16(state[i + 3]);
        let x2 = vcreate_u16(state[i + 6]);
        let x3 = vcreate_u16(state[i + 9]);

        // Now the matrix multiplication.
        // MDS exps: [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10]
        // out[i] += in[j] << mds[j - i]
        // `vaddl_u16` and `vsubl_u16` yield 4-vectors of `u32`s.
        let y0 = vaddl_u16(x0, x2);
        let y1 = vaddl_u16(x1, x3);
        let y2 = vsubl_u16(x0, x2);
        let y3 = vsubl_u16(x1, x3);

    let mut cumul0_b = vshll_n_u32::<0>(state0); // MDS[0]
    let mut cumul1_b = vshll_n_u32::<10>(state0); // MDS[11]
    let mut cumul2_b = vshll_n_u32::<16>(state0); // MDS[10]
    let mut cumul3_b = vshll_n_u32::<3>(state0); // MDS[9]
    let mut cumul4_b = vshll_n_u32::<12>(state0); // MDS[8]
    let mut cumul5_b = vshll_n_u32::<8>(state0); // MDS[7]
    let mut cumul6_b = vshll_n_u32::<1>(state0); // MDS[6]
    let mut cumul7_b = vshll_n_u32::<5>(state0); // MDS[5]
    let mut cumul8_b = vshll_n_u32::<3>(state0); // MDS[4]
    let mut cumul9_b = vshll_n_u32::<0>(state0); // MDS[3]
    let mut cumul10_b = vshll_n_u32::<1>(state0); // MDS[2]
    let mut cumul11_b = vshll_n_u32::<0>(state0); // MDS[1]
        let z0 = vaddq_u32(y0, y1);
        let z1 = vsubq_u32(y0, y1);
        let z2 = y2;
        let z3 = y3;

    cumul0_a = vaddw_high_u32(cumul0_a, state1); // MDS[1]
    cumul1_a = vaddw_high_u32(cumul1_a, state1); // MDS[0]
    cumul2_a = vmlal_high_laneq_u32::<MDSI11>(cumul2_a, state1, mdsv11); // MDS[11]
    cumul3_a = vmlal_high_laneq_u32::<MDSI10>(cumul3_a, state1, mdsv10); // MDS[10]
    cumul4_a = vmlal_high_laneq_u32::<MDSI9>(cumul4_a, state1, mdsv9); // MDS[9]
    cumul5_a = vmlal_high_laneq_u32::<MDSI8>(cumul5_a, state1, mdsv8); // MDS[8]
    cumul6_a = vmlal_high_laneq_u32::<MDSI7>(cumul6_a, state1, mdsv7); // MDS[7]
    cumul7_a = vmlal_high_laneq_u32::<MDSI6>(cumul7_a, state1, mdsv6); // MDS[6]
    cumul8_a = vmlal_high_laneq_u32::<MDSI5>(cumul8_a, state1, mdsv5); // MDS[5]
    cumul9_a = vmlal_high_laneq_u32::<MDSI4>(cumul9_a, state1, mdsv4); // MDS[4]
    cumul10_a = vaddw_high_u32(cumul10_a, state1); // MDS[3]
    cumul11_a = vmlal_high_laneq_u32::<MDSI2>(cumul11_a, state1, mdsv2); // MDS[2]
        // The FFT is `[z0, z2 + z3 i, z1, z2 - z3 i]`.

    cumul0_b = vmlal_laneq_u32::<MDSI2>(cumul0_b, state2, mdsv2); // MDS[2]
    cumul1_b = vaddw_u32(cumul1_b, state2); // MDS[1]
    cumul2_b = vaddw_u32(cumul2_b, state2); // MDS[0]
    cumul3_b = vmlal_laneq_u32::<MDSI11>(cumul3_b, state2, mdsv11); // MDS[11]
    cumul4_b = vmlal_laneq_u32::<MDSI10>(cumul4_b, state2, mdsv10); // MDS[10]
    cumul5_b = vmlal_laneq_u32::<MDSI9>(cumul5_b, state2, mdsv9); // MDS[9]
    cumul6_b = vmlal_laneq_u32::<MDSI8>(cumul6_b, state2, mdsv8); // MDS[8]
    cumul7_b = vmlal_laneq_u32::<MDSI7>(cumul7_b, state2, mdsv7); // MDS[7]
    cumul8_b = vmlal_laneq_u32::<MDSI6>(cumul8_b, state2, mdsv6); // MDS[6]
    cumul9_b = vmlal_laneq_u32::<MDSI5>(cumul9_b, state2, mdsv5); // MDS[5]
    cumul10_b = vmlal_laneq_u32::<MDSI4>(cumul10_b, state2, mdsv4); // MDS[4]
    cumul11_b = vaddw_u32(cumul11_b, state2); // MDS[3]
        state_fft[i] = z0;
        state_fft[i + 3] = z1;
        state_fft[i + 6] = z2;
        state_fft[i + 9] = z3;
    }

    cumul0_a = vaddw_high_u32(cumul0_a, state3); // MDS[3]
    cumul1_a = vmlal_high_laneq_u32::<MDSI2>(cumul1_a, state3, mdsv2); // MDS[2]
    cumul2_a = vaddw_high_u32(cumul2_a, state3); // MDS[1]
    cumul3_a = vaddw_high_u32(cumul3_a, state3); // MDS[0]
    cumul4_a = vmlal_high_laneq_u32::<MDSI11>(cumul4_a, state3, mdsv11); // MDS[11]
    cumul5_a = vmlal_high_laneq_u32::<MDSI10>(cumul5_a, state3, mdsv10); // MDS[10]
    cumul6_a = vmlal_high_laneq_u32::<MDSI9>(cumul6_a, state3, mdsv9); // MDS[9]
    cumul7_a = vmlal_high_laneq_u32::<MDSI8>(cumul7_a, state3, mdsv8); // MDS[8]
    cumul8_a = vmlal_high_laneq_u32::<MDSI7>(cumul8_a, state3, mdsv7); // MDS[7]
    cumul9_a = vmlal_high_laneq_u32::<MDSI6>(cumul9_a, state3, mdsv6); // MDS[6]
    cumul10_a = vmlal_high_laneq_u32::<MDSI5>(cumul10_a, state3, mdsv5); // MDS[5]
    cumul11_a = vmlal_high_laneq_u32::<MDSI4>(cumul11_a, state3, mdsv4); // MDS[4]
    // 3x3 real matrix-vector mul for component 0 of the FFTs.
    // Multiply the vector `[x0, x1, x2]` by the matrix
    // `[[ 64, 64, 128],`
    // ` [128, 64, 64],`
    // ` [ 64, 128, 64]]`
    // The results are divided by 4 (this ends up cancelling out some later computations).
    {
        let x0 = state_fft[0];
        let x1 = state_fft[1];
        let x2 = state_fft[2];

    cumul0_b = vmlal_laneq_u32::<MDSI4>(cumul0_b, state4, mdsv4); // MDS[4]
    cumul1_b = vaddw_u32(cumul1_b, state4); // MDS[3]
    cumul2_b = vmlal_laneq_u32::<MDSI2>(cumul2_b, state4, mdsv2); // MDS[2]
    cumul3_b = vaddw_u32(cumul3_b, state4); // MDS[1]
    cumul4_b = vaddw_u32(cumul4_b, state4); // MDS[0]
    cumul5_b = vmlal_laneq_u32::<MDSI11>(cumul5_b, state4, mdsv11); // MDS[11]
    cumul6_b = vmlal_laneq_u32::<MDSI10>(cumul6_b, state4, mdsv10); // MDS[10]
    cumul7_b = vmlal_laneq_u32::<MDSI9>(cumul7_b, state4, mdsv9); // MDS[9]
    cumul8_b = vmlal_laneq_u32::<MDSI8>(cumul8_b, state4, mdsv8); // MDS[8]
    cumul9_b = vmlal_laneq_u32::<MDSI7>(cumul9_b, state4, mdsv7); // MDS[7]
    cumul10_b = vmlal_laneq_u32::<MDSI6>(cumul10_b, state4, mdsv6); // MDS[6]
    cumul11_b = vmlal_laneq_u32::<MDSI5>(cumul11_b, state4, mdsv5); // MDS[5]
        let t = vshlq_n_u32::<4>(x0);
        let u = vaddq_u32(x1, x2);

    cumul0_a = vmlal_high_laneq_u32::<MDSI5>(cumul0_a, state5, mdsv5); // MDS[5]
    cumul1_a = vmlal_high_laneq_u32::<MDSI4>(cumul1_a, state5, mdsv4); // MDS[4]
    cumul2_a = vaddw_high_u32(cumul2_a, state5); // MDS[3]
    cumul3_a = vmlal_high_laneq_u32::<MDSI2>(cumul3_a, state5, mdsv2); // MDS[2]
    cumul4_a = vaddw_high_u32(cumul4_a, state5); // MDS[1]
    cumul5_a = vaddw_high_u32(cumul5_a, state5); // MDS[0]
    cumul6_a = vmlal_high_laneq_u32::<MDSI11>(cumul6_a, state5, mdsv11); // MDS[11]
    cumul7_a = vmlal_high_laneq_u32::<MDSI10>(cumul7_a, state5, mdsv10); // MDS[10]
    cumul8_a = vmlal_high_laneq_u32::<MDSI9>(cumul8_a, state5, mdsv9); // MDS[9]
    cumul9_a = vmlal_high_laneq_u32::<MDSI8>(cumul9_a, state5, mdsv8); // MDS[8]
    cumul10_a = vmlal_high_laneq_u32::<MDSI7>(cumul10_a, state5, mdsv7); // MDS[7]
    cumul11_a = vmlal_high_laneq_u32::<MDSI6>(cumul11_a, state5, mdsv6); // MDS[6]
        let y0 = vshlq_n_u32::<4>(u);
        let y1 = vmlaq_laneq_u32::<3>(t, x2, consts);
        let y2 = vmlaq_laneq_u32::<3>(t, x1, consts);

    cumul0_b = vmlal_laneq_u32::<MDSI6>(cumul0_b, state6, mdsv6); // MDS[6]
    cumul1_b = vmlal_laneq_u32::<MDSI5>(cumul1_b, state6, mdsv5); // MDS[5]
    cumul2_b = vmlal_laneq_u32::<MDSI4>(cumul2_b, state6, mdsv4); // MDS[4]
    cumul3_b = vaddw_u32(cumul3_b, state6); // MDS[3]
    cumul4_b = vmlal_laneq_u32::<MDSI2>(cumul4_b, state6, mdsv2); // MDS[2]
    cumul5_b = vaddw_u32(cumul5_b, state6); // MDS[1]
    cumul6_b = vaddw_u32(cumul6_b, state6); // MDS[0]
    cumul7_b = vmlal_laneq_u32::<MDSI11>(cumul7_b, state6, mdsv11); // MDS[11]
    cumul8_b = vmlal_laneq_u32::<MDSI10>(cumul8_b, state6, mdsv10); // MDS[10]
    cumul9_b = vmlal_laneq_u32::<MDSI9>(cumul9_b, state6, mdsv9); // MDS[9]
    cumul10_b = vmlal_laneq_u32::<MDSI8>(cumul10_b, state6, mdsv8); // MDS[8]
    cumul11_b = vmlal_laneq_u32::<MDSI7>(cumul11_b, state6, mdsv7); // MDS[7]
        state_fft[0] = vaddq_u32(y0, y1);
        state_fft[1] = vaddq_u32(y1, y2);
        state_fft[2] = vaddq_u32(y0, y2);
    }

    cumul0_a = vmlal_high_laneq_u32::<MDSI7>(cumul0_a, state7, mdsv7); // MDS[7]
    cumul1_a = vmlal_high_laneq_u32::<MDSI6>(cumul1_a, state7, mdsv6); // MDS[6]
    cumul2_a = vmlal_high_laneq_u32::<MDSI5>(cumul2_a, state7, mdsv5); // MDS[5]
    cumul3_a = vmlal_high_laneq_u32::<MDSI4>(cumul3_a, state7, mdsv4); // MDS[4]
    cumul4_a = vaddw_high_u32(cumul4_a, state7); // MDS[3]
    cumul5_a = vmlal_high_laneq_u32::<MDSI2>(cumul5_a, state7, mdsv2); // MDS[2]
    cumul6_a = vaddw_high_u32(cumul6_a, state7); // MDS[1]
    cumul7_a = vaddw_high_u32(cumul7_a, state7); // MDS[0]
    cumul8_a = vmlal_high_laneq_u32::<MDSI11>(cumul8_a, state7, mdsv11); // MDS[11]
    cumul9_a = vmlal_high_laneq_u32::<MDSI10>(cumul9_a, state7, mdsv10); // MDS[10]
    cumul10_a = vmlal_high_laneq_u32::<MDSI9>(cumul10_a, state7, mdsv9); // MDS[9]
    cumul11_a = vmlal_high_laneq_u32::<MDSI8>(cumul11_a, state7, mdsv8); // MDS[8]
    // 3x3 real matrix-vector mul for component 2 of the FFTs.
    // Multiply the vector `[x0, x1, x2]` by the matrix
    // `[[ -4, -8, 32],`
    // ` [-32, -4, -8],`
    // ` [ 8, -32, -4]]`
    // The results are divided by 4 (this ends up cancelling out some later computations).
    {
        let x0 = state_fft[3];
        let x1 = state_fft[4];
        let x2 = state_fft[5];
        state_fft[3] = vmlsq_laneq_u32::<2>(vmlaq_laneq_u32::<0>(x0, x1, consts), x2, consts);
        state_fft[4] = vmlaq_laneq_u32::<0>(vmlaq_laneq_u32::<2>(x1, x0, consts), x2, consts);
        state_fft[5] = vmlsq_laneq_u32::<0>(x2, vmlsq_laneq_u32::<1>(x0, x1, consts), consts);
    }

    cumul0_b = vmlal_laneq_u32::<MDSI8>(cumul0_b, state8, mdsv8); // MDS[8]
    cumul1_b = vmlal_laneq_u32::<MDSI7>(cumul1_b, state8, mdsv7); // MDS[7]
    cumul2_b = vmlal_laneq_u32::<MDSI6>(cumul2_b, state8, mdsv6); // MDS[6]
    cumul3_b = vmlal_laneq_u32::<MDSI5>(cumul3_b, state8, mdsv5); // MDS[5]
    cumul4_b = vmlal_laneq_u32::<MDSI4>(cumul4_b, state8, mdsv4); // MDS[4]
    cumul5_b = vaddw_u32(cumul5_b, state8); // MDS[3]
    cumul6_b = vmlal_laneq_u32::<MDSI2>(cumul6_b, state8, mdsv2); // MDS[2]
    cumul7_b = vaddw_u32(cumul7_b, state8); // MDS[1]
    cumul8_b = vaddw_u32(cumul8_b, state8); // MDS[0]
    cumul9_b = vmlal_laneq_u32::<MDSI11>(cumul9_b, state8, mdsv11); // MDS[11]
    cumul10_b = vmlal_laneq_u32::<MDSI10>(cumul10_b, state8, mdsv10); // MDS[10]
    cumul11_b = vmlal_laneq_u32::<MDSI9>(cumul11_b, state8, mdsv9); // MDS[9]
    // 3x3 complex matrix-vector mul for components 1 and 3 of the FFTs.
    // Multiply the vector `[x0r + x0i i, x1r + x1i i, x2r + x2i i]` by the matrix
    // `[[ 4 + 2i, 2 + 32i, 2 - 8i],`
    // ` [-8 - 2i, 4 + 2i, 2 + 32i],`
    // ` [32 - 2i, -8 - 2i, 4 + 2i]]`
    // The results are divided by 2 (this ends up cancelling out some later computations).
    {
        let x0r = state_fft[6];
        let x1r = state_fft[7];
        let x2r = state_fft[8];

    cumul0_a = vmlal_high_laneq_u32::<MDSI9>(cumul0_a, state9, mdsv9); // MDS[9]
    cumul1_a = vmlal_high_laneq_u32::<MDSI8>(cumul1_a, state9, mdsv8); // MDS[8]
    cumul2_a = vmlal_high_laneq_u32::<MDSI7>(cumul2_a, state9, mdsv7); // MDS[7]
    cumul3_a = vmlal_high_laneq_u32::<MDSI6>(cumul3_a, state9, mdsv6); // MDS[6]
    cumul4_a = vmlal_high_laneq_u32::<MDSI5>(cumul4_a, state9, mdsv5); // MDS[5]
    cumul5_a = vmlal_high_laneq_u32::<MDSI4>(cumul5_a, state9, mdsv4); // MDS[4]
    cumul6_a = vaddw_high_u32(cumul6_a, state9); // MDS[3]
    cumul7_a = vmlal_high_laneq_u32::<MDSI2>(cumul7_a, state9, mdsv2); // MDS[2]
    cumul8_a = vaddw_high_u32(cumul8_a, state9); // MDS[1]
    cumul9_a = vaddw_high_u32(cumul9_a, state9); // MDS[0]
    cumul10_a = vmlal_high_laneq_u32::<MDSI11>(cumul10_a, state9, mdsv11); // MDS[11]
    cumul11_a = vmlal_high_laneq_u32::<MDSI10>(cumul11_a, state9, mdsv10); // MDS[10]
        let x0i = state_fft[9];
        let x1i = state_fft[10];
        let x2i = state_fft[11];

    cumul0_b = vmlal_laneq_u32::<MDSI10>(cumul0_b, state10, mdsv10); // MDS[10]
    cumul1_b = vmlal_laneq_u32::<MDSI9>(cumul1_b, state10, mdsv9); // MDS[9]
    cumul2_b = vmlal_laneq_u32::<MDSI8>(cumul2_b, state10, mdsv8); // MDS[8]
    cumul3_b = vmlal_laneq_u32::<MDSI7>(cumul3_b, state10, mdsv7); // MDS[7]
    cumul4_b = vmlal_laneq_u32::<MDSI6>(cumul4_b, state10, mdsv6); // MDS[6]
    cumul5_b = vmlal_laneq_u32::<MDSI5>(cumul5_b, state10, mdsv5); // MDS[5]
    cumul6_b = vmlal_laneq_u32::<MDSI4>(cumul6_b, state10, mdsv4); // MDS[4]
    cumul7_b = vaddw_u32(cumul7_b, state10); // MDS[3]
    cumul8_b = vmlal_laneq_u32::<MDSI2>(cumul8_b, state10, mdsv2); // MDS[2]
    cumul9_b = vaddw_u32(cumul9_b, state10); // MDS[1]
    cumul10_b = vaddw_u32(cumul10_b, state10); // MDS[0]
    cumul11_b = vmlal_laneq_u32::<MDSI11>(cumul11_b, state10, mdsv11); // MDS[11]
        // real part of result <- real part of input
        let r0rr = vaddq_u32(vmlaq_laneq_u32::<0>(x1r, x0r, consts), x2r);
        let r1rr = vmlaq_laneq_u32::<0>(x2r, vmlsq_laneq_u32::<0>(x1r, x0r, consts), consts);
        let r2rr = vmlsq_laneq_u32::<0>(x2r, vmlsq_laneq_u32::<1>(x1r, x0r, consts), consts);

    cumul0_a = vmlal_high_laneq_u32::<MDSI11>(cumul0_a, state11, mdsv11); // MDS[11]
    cumul1_a = vmlal_high_laneq_u32::<MDSI10>(cumul1_a, state11, mdsv10); // MDS[10]
    cumul2_a = vmlal_high_laneq_u32::<MDSI9>(cumul2_a, state11, mdsv9); // MDS[9]
    cumul3_a = vmlal_high_laneq_u32::<MDSI8>(cumul3_a, state11, mdsv8); // MDS[8]
    cumul4_a = vmlal_high_laneq_u32::<MDSI7>(cumul4_a, state11, mdsv7); // MDS[7]
    cumul5_a = vmlal_high_laneq_u32::<MDSI6>(cumul5_a, state11, mdsv6); // MDS[6]
    cumul6_a = vmlal_high_laneq_u32::<MDSI5>(cumul6_a, state11, mdsv5); // MDS[5]
    cumul7_a = vmlal_high_laneq_u32::<MDSI4>(cumul7_a, state11, mdsv4); // MDS[4]
    cumul8_a = vaddw_high_u32(cumul8_a, state11); // MDS[3]
    cumul9_a = vmlal_high_laneq_u32::<MDSI2>(cumul9_a, state11, mdsv2); // MDS[2]
    cumul10_a = vaddw_high_u32(cumul10_a, state11); // MDS[1]
    cumul11_a = vaddw_high_u32(cumul11_a, state11); // MDS[0]
        // real part of result <- imaginary part of input
        let r0ri = vmlsq_laneq_u32::<1>(vmlaq_laneq_u32::<3>(x0i, x1i, consts), x2i, consts);
        let r1ri = vmlsq_laneq_u32::<3>(vsubq_u32(x0i, x1i), x2i, consts);
        let r2ri = vsubq_u32(vaddq_u32(x0i, x1i), x2i);

    let reduced = [
        mds_reduce([[cumul0_a, cumul0_b], [cumul1_a, cumul1_b]]),
        mds_reduce([[cumul2_a, cumul2_b], [cumul3_a, cumul3_b]]),
        mds_reduce([[cumul4_a, cumul4_b], [cumul5_a, cumul5_b]]),
        mds_reduce([[cumul6_a, cumul6_b], [cumul7_a, cumul7_b]]),
        mds_reduce([[cumul8_a, cumul8_b], [cumul9_a, cumul9_b]]),
        mds_reduce([[cumul10_a, cumul10_b], [cumul11_a, cumul11_b]]),
    ];
    [
        vgetq_lane_u64::<0>(reduced[0]),
        vgetq_lane_u64::<1>(reduced[0]),
        vgetq_lane_u64::<0>(reduced[1]),
        vgetq_lane_u64::<1>(reduced[1]),
        vgetq_lane_u64::<0>(reduced[2]),
        vgetq_lane_u64::<1>(reduced[2]),
        vgetq_lane_u64::<0>(reduced[3]),
        vgetq_lane_u64::<1>(reduced[3]),
        vgetq_lane_u64::<0>(reduced[4]),
        vgetq_lane_u64::<1>(reduced[4]),
        vgetq_lane_u64::<0>(reduced[5]),
        vgetq_lane_u64::<1>(reduced[5]),
    ]
        // real part of result (total)
        let r0r = vsubq_u32(r0rr, r0ri);
        let r1r = vaddq_u32(r1rr, r1ri);
        let r2r = vmlaq_laneq_u32::<0>(r2ri, r2rr, consts);

        // imaginary part of result <- real part of input
        let r0ir = vmlsq_laneq_u32::<1>(vmlaq_laneq_u32::<3>(x0r, x1r, consts), x2r, consts);
        let r1ir = vmlaq_laneq_u32::<3>(vsubq_u32(x1r, x0r), x2r, consts);
        let r2ir = vsubq_u32(x2r, vaddq_u32(x0r, x1r));

        // imaginary part of result <- imaginary part of input
        let r0ii = vaddq_u32(vmlaq_laneq_u32::<0>(x1i, x0i, consts), x2i);
        let r1ii = vmlaq_laneq_u32::<0>(x2i, vmlsq_laneq_u32::<0>(x1i, x0i, consts), consts);
        let r2ii = vmlsq_laneq_u32::<0>(x2i, vmlsq_laneq_u32::<1>(x1i, x0i, consts), consts);

        // imaginary part of result (total)
        let r0i = vaddq_u32(r0ir, r0ii);
        let r1i = vaddq_u32(r1ir, r1ii);
        let r2i = vmlaq_laneq_u32::<0>(r2ir, r2ii, consts);

        state_fft[6] = r0r;
        state_fft[7] = r1r;
        state_fft[8] = r2r;

        state_fft[9] = r0i;
        state_fft[10] = r1i;
        state_fft[11] = r2i;
    }

    // Three length-4 inverse FFTs.
    // Normally, such IFFT would divide by 4, but we've already taken care of that.
    for i in 0..3 {
        let z0 = state_fft[i];
        let z1 = state_fft[i + 3];
        let z2 = state_fft[i + 6];
        let z3 = state_fft[i + 9];

        let y0 = vsubq_u32(z0, z1);
        let y1 = vaddq_u32(z0, z1);
        let y2 = z2;
        let y3 = z3;

        let x0 = vaddq_u32(y0, y2);
        let x1 = vaddq_u32(y1, y3);
        let x2 = vsubq_u32(y0, y2);
        let x3 = vsubq_u32(y1, y3);

        state_fft[i] = x0;
        state_fft[i + 3] = x1;
        state_fft[i + 6] = x2;
        state_fft[i + 9] = x3;
    }

    // Perform `res[0] += state[0] * 8` for the diagonal component of the MDS matrix.
    state_fft[0] = vmlal_laneq_u16::<4>(
        state_fft[0],
        vcreate_u16(state[0]), // Each 16-bit chunk gets zero-extended.
        vreinterpretq_u16_u32(consts), // Hack: these constants fit in `u16s`, so we can bit-cast.
    );

    let mut res_arr = [0; 12];
    for i in 0..6 {
        let res = mds_reduce([state_fft[2 * i], state_fft[2 * i + 1]]);
        res_arr[2 * i] = vgetq_lane_u64::<0>(res);
        res_arr[2 * i + 1] = vgetq_lane_u64::<1>(res);
    }

    res_arr
}
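// Illustrative reference sketch (not part of this commit) of the map that mds_layer_full
// vectorizes, assuming plonky2's circulant-plus-diagonal convention checked by check_mds_matrix
// above: out[r] = sum_i CIRC[i] * state[(i + r) % 12] + DIAG[r] * state[r], with each u128
// accumulator still needing a modular reduction (e.g. as in mds_reduce).
fn mds_layer_reference_sketch(state: [u64; 12], circ: [u64; 12], diag: [u64; 12]) -> [u128; 12] {
    let mut res = [0u128; 12];
    for r in 0..12 {
        for i in 0..12 {
            res[r] += (circ[i] as u128) * (state[(i + r) % 12] as u128);
        }
        res[r] += (diag[r] as u128) * (state[r] as u128);
    }
    res
}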

// ======================================== PARTIAL ROUNDS =========================================

/*
#[rustfmt::skip]
macro_rules! mds_reduce_asm {
    ($c0:literal, $c1:literal, $out:literal, $consts:literal) => {
@@ -961,13 +864,15 @@ unsafe fn partial_round(
        [res23, res45, res67, res89, res1011],
    )
}
*/

// ========================================== GLUE CODE ===========================================

/*
#[inline(always)]
unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] {
    let state = sbox_layer_full(state);
    mds_const_layers_full(state, round_constants)
    mds_layer_full(state, round_constants)
}

#[inline]
@@ -1001,43 +906,19 @@ unsafe fn partial_rounds(
    }
    state.0
}
*/

#[inline(always)]
fn unwrap_state(state: [GoldilocksField; 12]) -> [u64; 12] {
    [
        state[0].0,
        state[1].0,
        state[2].0,
        state[3].0,
        state[4].0,
        state[5].0,
        state[6].0,
        state[7].0,
        state[8].0,
        state[9].0,
        state[10].0,
        state[11].0,
    ]
    state.map(|s| s.0)
}

#[inline(always)]
fn wrap_state(state: [u64; 12]) -> [GoldilocksField; 12] {
    [
        GoldilocksField(state[0]),
        GoldilocksField(state[1]),
        GoldilocksField(state[2]),
        GoldilocksField(state[3]),
        GoldilocksField(state[4]),
        GoldilocksField(state[5]),
        GoldilocksField(state[6]),
        GoldilocksField(state[7]),
        GoldilocksField(state[8]),
        GoldilocksField(state[9]),
        GoldilocksField(state[10]),
        GoldilocksField(state[11]),
    ]
    state.map(GoldilocksField)
}

/*
#[inline(always)]
pub unsafe fn poseidon(state: [GoldilocksField; 12]) -> [GoldilocksField; 12] {
    let state = unwrap_state(state);
@@ -1058,6 +939,7 @@ pub unsafe fn poseidon(state: [GoldilocksField; 12]) -> [GoldilocksField; 12] {
    let state = full_rounds(state, &FINAL_ROUND_CONSTANTS);
    wrap_state(state)
}
*/

#[inline(always)]
pub unsafe fn sbox_layer(state: &mut [GoldilocksField; WIDTH]) {
@@ -1067,8 +949,6 @@ pub unsafe fn sbox_layer(state: &mut [GoldilocksField; WIDTH]) {
#[inline(always)]
pub unsafe fn mds_layer(state: &[GoldilocksField; WIDTH]) -> [GoldilocksField; WIDTH] {
    let state = unwrap_state(*state);
    // We want to do an MDS layer without the constant layer.
    let round_consts = [0u64; WIDTH];
    let state = mds_const_layers_full(state, &round_consts);
    let state = mds_layer_full(state);
    wrap_state(state)
}

@@ -252,21 +252,21 @@ impl Poseidon for GoldilocksField {
    // }
    // }

    // #[cfg(all(target_arch="aarch64", target_feature="neon"))]
    // #[inline(always)]
    // fn sbox_layer(state: &mut [Self; 12]) {
    //     unsafe {
    //         crate::hash::arch::aarch64::poseidon_goldilocks_neon::sbox_layer(state);
    //     }
    // }
    #[cfg(all(target_arch="aarch64", target_feature="neon"))]
    #[inline(always)]
    fn sbox_layer(state: &mut [Self; 12]) {
        unsafe {
            crate::hash::arch::aarch64::poseidon_goldilocks_neon::sbox_layer(state);
        }
    }

    // #[cfg(all(target_arch="aarch64", target_feature="neon"))]
    // #[inline(always)]
    // fn mds_layer(state: &[Self; 12]) -> [Self; 12] {
    //     unsafe {
    //         crate::hash::arch::aarch64::poseidon_goldilocks_neon::mds_layer(state)
    //     }
    // }
    #[cfg(all(target_arch="aarch64", target_feature="neon"))]
    #[inline(always)]
    fn mds_layer(state: &[Self; 12]) -> [Self; 12] {
        unsafe {
            crate::hash::arch::aarch64::poseidon_goldilocks_neon::mds_layer(state)
        }
    }
}

#[cfg(test)]