Restore vectorization to full Poseidon rounds on Aarch64 (#498)

* Restore vectorization to full Poseidon layers on Aarch64

* Typos
Jakub Nabaglo 2022-02-21 17:45:01 -08:00 committed by GitHub
parent 6072fab077
commit c7af639579
3 changed files with 250 additions and 370 deletions


@@ -1,2 +1,2 @@
// #[cfg(target_feature = "neon")]
// pub(crate) mod poseidon_goldilocks_neon;
#[cfg(target_feature = "neon")]
pub(crate) mod poseidon_goldilocks_neon;


@@ -2,37 +2,24 @@
use std::arch::aarch64::*;
use std::arch::asm;
use std::mem::transmute;
use plonky2_field::field_types::Field64;
use plonky2_field::goldilocks_field::GoldilocksField;
use plonky2_util::branch_hint;
use static_assertions::const_assert;
use unroll::unroll_for_loops;
use crate::hash::poseidon::{
Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS,
};
use crate::hash::poseidon::Poseidon;
// ========================================== CONSTANTS ===========================================
const WIDTH: usize = 12;
// The order below is arbitrary. Repeated coefficients have been removed so these constants fit in
// two registers.
// TODO: ensure this is aligned to 16 bytes (for vector loads), ideally on the same cacheline
const MDS_CONSTS: [u32; 8] = [
0xffffffff,
1 << 1,
1 << 3,
1 << 5,
1 << 8,
1 << 10,
1 << 12,
1 << 16,
];
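// (These are EPSILON plus 2**e for each distinct nonzero exponent e in MDS_MATRIX_EXPS
// = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10]; the exponent-0 entries are handled with plain adds.)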
const EPSILON: u64 = 0xffffffff;
// The round constants to be applied by the second set of full rounds. These are just the usual round constants,
// shifted by one round, with zeros shifted in.
// The round constants to be applied by the second set of full rounds. These are just the usual
// round constants, shifted by one round, with zeros shifted in.
/*
const fn make_final_round_constants() -> [u64; WIDTH * HALF_N_FULL_ROUNDS] {
let mut res = [0; WIDTH * HALF_N_FULL_ROUNDS];
let mut i: usize = 0;
@@ -43,6 +30,7 @@ const fn make_final_round_constants() -> [u64; WIDTH * HALF_N_FULL_ROUNDS] {
res
}
const FINAL_ROUND_CONSTANTS: [u64; WIDTH * HALF_N_FULL_ROUNDS] = make_final_round_constants();
*/
// ===================================== COMPILE-TIME CHECKS ======================================
@@ -52,9 +40,12 @@ const FINAL_ROUND_CONSTANTS: [u64; WIDTH * HALF_N_FULL_ROUNDS] = make_final_roun
const fn check_mds_matrix() -> bool {
// Can't == two arrays in a const_assert! (:
let mut i = 0;
let wanted_matrix_exps = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10];
let wanted_matrix_circ = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20];
let wanted_matrix_diag = [8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
while i < WIDTH {
if <GoldilocksField as Poseidon>::MDS_MATRIX_EXPS[i] != wanted_matrix_exps[i] {
if <GoldilocksField as Poseidon>::MDS_MATRIX_CIRC[i] != wanted_matrix_circ[i]
|| <GoldilocksField as Poseidon>::MDS_MATRIX_DIAG[i] != wanted_matrix_diag[i]
{
return false;
}
i += 1;
@@ -63,37 +54,10 @@ const fn check_mds_matrix() -> bool {
}
const_assert!(check_mds_matrix());
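// For intuition, the MDS layer this module vectorizes is a circulant matrix-vector product plus
// a diagonal term:
//   out[r] = MDS_MATRIX_DIAG[r] * state[r] + sum_i MDS_MATRIX_CIRC[i] * state[(i + r) % WIDTH]
// reduced modulo the Goldilocks order. A hypothetical scalar reference (illustrative sketch only,
// not the code used below):
#[allow(dead_code)]
fn mds_layer_scalar_reference(state: [u64; WIDTH]) -> [u64; WIDTH] {
    const GOLDILOCKS_ORDER: u128 = 0xffff_ffff_0000_0001; // 2**64 - 2**32 + 1
    let circ = <GoldilocksField as Poseidon>::MDS_MATRIX_CIRC;
    let diag = <GoldilocksField as Poseidon>::MDS_MATRIX_DIAG;
    let mut out = [0u64; WIDTH];
    for r in 0..WIDTH {
        // Thirteen products of small constants by u64 values cannot overflow a u128 accumulator.
        let mut acc = diag[r] as u128 * state[r] as u128;
        for i in 0..WIDTH {
            acc += circ[i] as u128 * state[(i + r) % WIDTH] as u128;
        }
        out[r] = (acc % GOLDILOCKS_ORDER) as u64;
    }
    out
}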
/// The maximum amount by which the MDS matrix will multiply the input.
/// i.e. max(MDS(state)) <= mds_matrix_inf_norm() * max(state).
const fn mds_matrix_inf_norm() -> u64 {
let mut cumul = 0;
let mut i = 0;
while i < WIDTH {
cumul += 1 << <GoldilocksField as Poseidon>::MDS_MATRIX_EXPS[i];
i += 1;
}
cumul
}
/// Ensure that adding round constants to the low result of the MDS multiplication can never
/// overflow.
#[allow(dead_code)]
const fn check_round_const_bounds_mds() -> bool {
let max_mds_res = mds_matrix_inf_norm() * (u32::MAX as u64);
let mut i = WIDTH; // First const layer is handled specially.
while i < WIDTH * N_ROUNDS {
if ALL_ROUND_CONSTANTS[i].overflowing_add(max_mds_res).1 {
return false;
}
i += 1;
}
true
}
const_assert!(check_round_const_bounds_mds());
/// Ensure that the first WIDTH round constants are in canonical* form. This is required because
/// the first constant layer does not handle double overflow.
/// *: round_const == GoldilocksField::ORDER is safe.
/*
#[allow(dead_code)]
const fn check_round_const_bounds_init() -> bool {
let mut i = 0;
@@ -106,11 +70,9 @@ const fn check_round_const_bounds_init() -> bool {
true
}
const_assert!(check_round_const_bounds_init());
*/
// ====================================== SCALAR ARITHMETIC =======================================
const EPSILON: u64 = 0xffffffff;
/// Addition modulo ORDER accounting for wraparound. Correct only when a + b < 2**64 + ORDER.
#[inline(always)]
unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 {
@@ -133,7 +95,16 @@ unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 {
/// Subtraction of a and (b >> 32) modulo ORDER accounting for wraparound.
#[inline(always)]
unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 {
let b_hi = b >> 32;
let mut b_hi = b >> 32;
// Make sure that LLVM emits two separate instructions for the shift and the subtraction. This
// reduces pressure on the execution units with access to the flags, as they are no longer
// responsible for the shift. The hack is to insert a fake computation between the two
// instructions with an `asm` block to make LLVM think that they can't be merged.
asm!(
"/* {0} */", // Make Rust think we're using the register.
inlateout(reg) b_hi,
options(nomem, nostack, preserves_flags, pure),
);
// This could be done with a.overflowing_sub(b_hi), but `checked_sub` signals to the compiler
// that overflow is unlikely (note: this is a standard library implementation detail, not part
// of the spec).
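// Illustrative scalar models of the two wraparound helpers above (a sketch of the intended
// semantics under their stated preconditions; results are congruent mod ORDER but not
// necessarily canonical, and the real functions use asm hints the models omit):
#[allow(dead_code)]
fn add_with_wraparound_model(a: u64, b: u64) -> u64 {
    let (sum, wrapped) = a.overflowing_add(b);
    // Wrapping drops 2**64 == EPSILON (mod ORDER); a + b < 2**64 + ORDER keeps the correction
    // from overflowing a second time.
    sum + if wrapped { EPSILON } else { 0 }
}
#[allow(dead_code)]
fn sub_with_wraparound_lsr32_model(a: u64, b: u64) -> u64 {
    let (diff, borrowed) = a.overflowing_sub(b >> 32);
    // Borrowing adds 2**64 == EPSILON (mod ORDER), so take it back out; the wrapped value is
    // always large enough that this cannot underflow.
    diff - if borrowed { EPSILON } else { 0 }
}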
@@ -153,7 +124,8 @@ unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 {
unsafe fn mul_epsilon(x: u64) -> u64 {
let res;
asm!(
// Use UMULL to save one instruction. The compiler emits two: extract the low word and then multiply.
// Use UMULL to save one instruction. The compiler emits two: extract the low word and then
// multiply.
"umull {res}, {x:w}, {epsilon:w}",
x = in(reg) x,
epsilon = in(reg) EPSILON,
@@ -179,8 +151,9 @@ unsafe fn multiply(x: u64, y: u64) -> u64 {
// ==================================== STANDALONE CONST LAYER =====================================
/// Standalone const layer. Run only once, at the start of round 1. Remaining const layers are fused with the preceding
/// MDS matrix multiplication.
/// Standalone const layer. Run only once, at the start of round 1. Remaining const layers are fused
/// with the preceding MDS matrix multiplication.
/*
#[inline(always)]
#[unroll_for_loops]
unsafe fn const_layer_full(
@@ -195,15 +168,15 @@ unsafe fn const_layer_full(
}
state
}
*/
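// For reference, a const layer is just a modular addition of the round constants into the state.
// A minimal scalar sketch (hypothetical helper, assuming the inputs satisfy the
// add_with_wraparound precondition):
#[allow(dead_code)]
#[inline(always)]
unsafe fn const_layer_scalar(
    mut state: [u64; WIDTH],
    round_constants: &[u64; WIDTH],
) -> [u64; WIDTH] {
    for i in 0..WIDTH {
        state[i] = add_with_wraparound(state[i], round_constants[i]);
    }
    state
}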
// ========================================== FULL ROUNDS ==========================================
/// Full S-box.
#[inline(always)]
#[unroll_for_loops]
unsafe fn sbox_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] {
// This is done in scalar. S-boxes in vector are only slightly slower throughput-wise but have an insane latency
// (~100 cycles) on the M1.
// This is done in scalar. S-boxes in vector are only slightly slower throughput-wise but have
// an insane latency (~100 cycles) on the M1.
let mut state2 = [0u64; WIDTH];
assert!(WIDTH == 12);
@@ -228,297 +201,227 @@ unsafe fn sbox_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] {
state7
}
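// The Poseidon S-box used here is x -> x**7. An illustrative scalar sketch of one S-box in terms
// of the `multiply` helper above (sbox_layer_full applies this to all twelve state elements):
#[allow(dead_code)]
#[inline(always)]
unsafe fn sbox_monomial_model(x: u64) -> u64 {
    let x2 = multiply(x, x);
    let x3 = multiply(x2, x);
    let x4 = multiply(x2, x2);
    multiply(x3, x4) // == x**7 (mod ORDER)
}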
// Aliases for readability. E.g. MDS[5] can be found in mdsv5[MDSI5].
const MDSI2: i32 = 1; // MDS[2] == 1
const MDSI4: i32 = 2; // MDS[4] == 3
const MDSI5: i32 = 3; // MDS[5] == 5
const MDSI6: i32 = 1; // MDS[6] == 1
const MDSI7: i32 = 0; // MDS[7] == 8
const MDSI8: i32 = 2; // MDS[8] == 12
const MDSI9: i32 = 2; // MDS[9] == 3
const MDSI10: i32 = 3; // MDS[10] == 16
const MDSI11: i32 = 1; // MDS[11] == 10
#[inline(always)]
unsafe fn mds_reduce(
[[cumul0_a, cumul0_b], [cumul1_a, cumul1_b]]: [[uint64x2_t; 2]; 2],
// `cumul_a` and `cumul_b` represent two separate field elements. We take advantage of
// vectorization by reducing them simultaneously.
[cumul_a, cumul_b]: [uint32x4_t; 2],
) -> uint64x2_t {
// mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5]
let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::<u32>());
// Merge accumulators
let cumul0 = vaddq_u64(cumul0_a, cumul0_b);
let cumul1 = vaddq_u64(cumul1_a, cumul1_b);
// Swizzle
let res_lo = vzip1q_u64(cumul0, cumul1);
let res_hi = vzip2q_u64(cumul0, cumul1);
// Reduce from u96
let res_hi = vsraq_n_u64::<32>(res_hi, res_lo);
let res_lo = vsliq_n_u64::<32>(res_lo, res_hi);
// Extract high 32-bits.
let res_hi_hi = vget_low_u32(vuzp2q_u32(
vreinterpretq_u32_u64(res_hi),
vreinterpretq_u32_u64(res_hi),
));
// Multiply by EPSILON and accumulate.
let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0);
let res_adj = vcgtq_u64(res_lo, res_unadj);
vsraq_n_u64::<32>(res_unadj, res_adj)
// Form:
// `lo = [cumul_a[0] + cumul_a[2] * 2**32, cumul_b[0] + cumul_b[2] * 2**32]`
// `hi = [cumul_a[1] + cumul_a[3] * 2**32, cumul_b[1] + cumul_b[3] * 2**32]`
// Observe that the result `== lo + hi * 2**16 (mod Goldilocks)`.
let mut lo = vreinterpretq_u64_u32(vuzp1q_u32(cumul_a, cumul_b));
let mut hi = vreinterpretq_u64_u32(vuzp2q_u32(cumul_a, cumul_b));
// Add the high 48 bits of `lo` to `hi`. This cannot overflow.
hi = vsraq_n_u64::<16>(hi, lo);
// Now, result `== lo.bits[0..16] + hi * 2**16 (mod Goldilocks)`.
// Set the high 48 bits of `lo` to the low 48 bits of `hi`.
lo = vsliq_n_u64::<16>(lo, hi);
// At this point, result `== lo + hi.bits[48..64] * 2**64 (mod Goldilocks)`.
// It remains to fold `hi.bits[48..64]` into `lo`.
let top = {
// Extract the top 16 bits of `hi` as a `u32`.
// Interpret `hi` as a vector of bytes, so we can use a table lookup instruction.
let hi_u8 = vreinterpretq_u8_u64(hi);
// Indices defining the permutation. `0xff` is out of bounds, producing `0`.
let top_idx =
transmute::<[u8; 8], uint8x8_t>([0x06, 0x07, 0xff, 0xff, 0x0e, 0x0f, 0xff, 0xff]);
let top_u8 = vqtbl1_u8(hi_u8, top_idx);
vreinterpret_u32_u8(top_u8)
};
// result `== lo + top * 2**64 (mod Goldilocks)`.
let adj_lo = vmlal_n_u32(lo, top, EPSILON as u32);
let wraparound_mask = vcgtq_u64(lo, adj_lo);
vsraq_n_u64::<32>(adj_lo, wraparound_mask) // Add epsilon on overflow.
}
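// Scalar model of the reduction above (illustrative, using u128 instead of the chunked NEON
// trick): reassemble value = c0 + c1*2**16 + c2*2**32 + c3*2**48, then fold the high 64 bits
// back in using 2**64 == EPSILON (mod ORDER).
#[allow(dead_code)]
fn mds_reduce_model(c: [u32; 4]) -> u64 {
    let full = (c[0] as u128)
        + ((c[1] as u128) << 16)
        + ((c[2] as u128) << 32)
        + ((c[3] as u128) << 48);
    let lo = full as u64;
    let hi = (full >> 64) as u64; // Small: with u32 chunks this is under 2**17.
    let (res, wrapped) = lo.overflowing_add(hi * EPSILON);
    // As in the vector code, add EPSILON once more on wraparound; the result is congruent to
    // `full` mod ORDER but not necessarily canonical.
    res.wrapping_add(if wrapped { EPSILON } else { 0 })
}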
#[inline(always)]
unsafe fn mds_const_layers_full(
state: [u64; WIDTH],
round_constants: &[u64; WIDTH],
) -> [u64; WIDTH] {
// mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5]
// mds_consts1 == [1 << 8, 1 << 10, 1 << 12, 1 << 16]
let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::<u32>());
let mds_consts1: uint32x4_t = vld1q_u32((&MDS_CONSTS[4..8]).as_ptr().cast::<u32>());
unsafe fn mds_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] {
// This function performs an MDS multiplication in complex FFT space.
// However, instead of performing a width-12 FFT, we perform three width-4 FFTs, which is
// cheaper. The 12x12 matrix-vector multiplication (a convolution) becomes two 3x3 real
// matrix-vector multiplications and one 3x3 complex matrix-vector multiplication.
// Aliases for readability. E.g. MDS[5] can be found in mdsv5[mdsi5]. MDS[0], MDS[1], and
// MDS[3] are 0, so they are not needed.
let mdsv2 = mds_consts0; // MDS[2] == 1
let mdsv4 = mds_consts0; // MDS[4] == 3
let mdsv5 = mds_consts0; // MDS[5] == 5
let mdsv6 = mds_consts0; // MDS[6] == 1
let mdsv7 = mds_consts1; // MDS[7] == 8
let mdsv8 = mds_consts1; // MDS[8] == 12
let mdsv9 = mds_consts0; // MDS[9] == 3
let mdsv10 = mds_consts1; // MDS[10] == 16
let mdsv11 = mds_consts1; // MDS[11] == 10
// We split each 64-bit element into four 16-bit chunks. To prevent overflow, each chunk is
// stored in a 32-bit lane. Each NEON vector below represents one field element and consists of
// four such chunks:
// `elem == vector[0] + vector[1] * 2**16 + vector[2] * 2**32 + vector[3] * 2**48`.
// For i even, we combine state[i] and state[i + 1] into one vector to save on registers.
// Thus, state1 actually contains state0 and state1 but is only used in the intrinsics that
// access the high high doubleword.
let state1: uint32x4_t =
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[0]), vcreate_u64(state[1])));
let state3: uint32x4_t =
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[2]), vcreate_u64(state[3])));
let state5: uint32x4_t =
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[4]), vcreate_u64(state[5])));
let state7: uint32x4_t =
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[6]), vcreate_u64(state[7])));
let state9: uint32x4_t =
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[8]), vcreate_u64(state[9])));
let state11: uint32x4_t =
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[10]), vcreate_u64(state[11])));
// state0 is an alias to the low doubleword of state1. The compiler should use one register for both.
let state0: uint32x2_t = vget_low_u32(state1);
let state2: uint32x2_t = vget_low_u32(state3);
let state4: uint32x2_t = vget_low_u32(state5);
let state6: uint32x2_t = vget_low_u32(state7);
let state8: uint32x2_t = vget_low_u32(state9);
let state10: uint32x2_t = vget_low_u32(state11);
// Constants that we multiply by.
let mut consts: uint32x4_t = transmute::<[u32; 4], _>([2, 4, 8, 16]);
// Two accumulators per output to hide latency. Each accumulator is a vector of two u64s,
// containing the result for the low 32 bits and the high 32 bits. Thus, the final result at
// index i is (cumuli_a[0] + cumuli_b[0]) + (cumuli_a[1] + cumuli_b[1]) * 2**32.
// Prevent LLVM from turning fused multiply (by power of 2)-add (1 instruction) into shift and
// add (two instructions). This fake `asm` block means that LLVM no longer knows the contents of
// `consts`.
asm!("/* {0:v} */", // Make Rust think the register is being used.
inout(vreg) consts,
options(pure, nomem, nostack, preserves_flags),
);
// Start by loading the round constants.
let mut cumul0_a = vcombine_u64(vld1_u64(&round_constants[0]), vcreate_u64(0));
let mut cumul1_a = vcombine_u64(vld1_u64(&round_constants[1]), vcreate_u64(0));
let mut cumul2_a = vcombine_u64(vld1_u64(&round_constants[2]), vcreate_u64(0));
let mut cumul3_a = vcombine_u64(vld1_u64(&round_constants[3]), vcreate_u64(0));
let mut cumul4_a = vcombine_u64(vld1_u64(&round_constants[4]), vcreate_u64(0));
let mut cumul5_a = vcombine_u64(vld1_u64(&round_constants[5]), vcreate_u64(0));
let mut cumul6_a = vcombine_u64(vld1_u64(&round_constants[6]), vcreate_u64(0));
let mut cumul7_a = vcombine_u64(vld1_u64(&round_constants[7]), vcreate_u64(0));
let mut cumul8_a = vcombine_u64(vld1_u64(&round_constants[8]), vcreate_u64(0));
let mut cumul9_a = vcombine_u64(vld1_u64(&round_constants[9]), vcreate_u64(0));
let mut cumul10_a = vcombine_u64(vld1_u64(&round_constants[10]), vcreate_u64(0));
let mut cumul11_a = vcombine_u64(vld1_u64(&round_constants[11]), vcreate_u64(0));
// Four length-3 complex FFTs.
let mut state_fft = [vdupq_n_u32(0); 12];
for i in 0..3 {
// Interpret each field element as a 4-vector of `u16`s.
let x0 = vcreate_u16(state[i]);
let x1 = vcreate_u16(state[i + 3]);
let x2 = vcreate_u16(state[i + 6]);
let x3 = vcreate_u16(state[i + 9]);
// Now the matrix multiplication.
// MDS exps: [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10]
// out[i] += in[j] << mds[j - i]
// `vaddl_u16` and `vsubl_u16` yield 4-vectors of `u32`s.
let y0 = vaddl_u16(x0, x2);
let y1 = vaddl_u16(x1, x3);
let y2 = vsubl_u16(x0, x2);
let y3 = vsubl_u16(x1, x3);
let mut cumul0_b = vshll_n_u32::<0>(state0); // MDS[0]
let mut cumul1_b = vshll_n_u32::<10>(state0); // MDS[11]
let mut cumul2_b = vshll_n_u32::<16>(state0); // MDS[10]
let mut cumul3_b = vshll_n_u32::<3>(state0); // MDS[9]
let mut cumul4_b = vshll_n_u32::<12>(state0); // MDS[8]
let mut cumul5_b = vshll_n_u32::<8>(state0); // MDS[7]
let mut cumul6_b = vshll_n_u32::<1>(state0); // MDS[6]
let mut cumul7_b = vshll_n_u32::<5>(state0); // MDS[5]
let mut cumul8_b = vshll_n_u32::<3>(state0); // MDS[4]
let mut cumul9_b = vshll_n_u32::<0>(state0); // MDS[3]
let mut cumul10_b = vshll_n_u32::<1>(state0); // MDS[2]
let mut cumul11_b = vshll_n_u32::<0>(state0); // MDS[1]
let z0 = vaddq_u32(y0, y1);
let z1 = vsubq_u32(y0, y1);
let z2 = y2;
let z3 = y3;
cumul0_a = vaddw_high_u32(cumul0_a, state1); // MDS[1]
cumul1_a = vaddw_high_u32(cumul1_a, state1); // MDS[0]
cumul2_a = vmlal_high_laneq_u32::<MDSI11>(cumul2_a, state1, mdsv11); // MDS[11]
cumul3_a = vmlal_high_laneq_u32::<MDSI10>(cumul3_a, state1, mdsv10); // MDS[10]
cumul4_a = vmlal_high_laneq_u32::<MDSI9>(cumul4_a, state1, mdsv9); // MDS[9]
cumul5_a = vmlal_high_laneq_u32::<MDSI8>(cumul5_a, state1, mdsv8); // MDS[8]
cumul6_a = vmlal_high_laneq_u32::<MDSI7>(cumul6_a, state1, mdsv7); // MDS[7]
cumul7_a = vmlal_high_laneq_u32::<MDSI6>(cumul7_a, state1, mdsv6); // MDS[6]
cumul8_a = vmlal_high_laneq_u32::<MDSI5>(cumul8_a, state1, mdsv5); // MDS[5]
cumul9_a = vmlal_high_laneq_u32::<MDSI4>(cumul9_a, state1, mdsv4); // MDS[4]
cumul10_a = vaddw_high_u32(cumul10_a, state1); // MDS[3]
cumul11_a = vmlal_high_laneq_u32::<MDSI2>(cumul11_a, state1, mdsv2); // MDS[2]
// The FFT is `[z0, z2 + z3 i, z1, z2 - z3 i]`.
cumul0_b = vmlal_laneq_u32::<MDSI2>(cumul0_b, state2, mdsv2); // MDS[2]
cumul1_b = vaddw_u32(cumul1_b, state2); // MDS[1]
cumul2_b = vaddw_u32(cumul2_b, state2); // MDS[0]
cumul3_b = vmlal_laneq_u32::<MDSI11>(cumul3_b, state2, mdsv11); // MDS[11]
cumul4_b = vmlal_laneq_u32::<MDSI10>(cumul4_b, state2, mdsv10); // MDS[10]
cumul5_b = vmlal_laneq_u32::<MDSI9>(cumul5_b, state2, mdsv9); // MDS[9]
cumul6_b = vmlal_laneq_u32::<MDSI8>(cumul6_b, state2, mdsv8); // MDS[8]
cumul7_b = vmlal_laneq_u32::<MDSI7>(cumul7_b, state2, mdsv7); // MDS[7]
cumul8_b = vmlal_laneq_u32::<MDSI6>(cumul8_b, state2, mdsv6); // MDS[6]
cumul9_b = vmlal_laneq_u32::<MDSI5>(cumul9_b, state2, mdsv5); // MDS[5]
cumul10_b = vmlal_laneq_u32::<MDSI4>(cumul10_b, state2, mdsv4); // MDS[4]
cumul11_b = vaddw_u32(cumul11_b, state2); // MDS[3]
state_fft[i] = z0;
state_fft[i + 3] = z1;
state_fft[i + 6] = z2;
state_fft[i + 9] = z3;
}
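// Per 16-bit chunk (i.e. per lane), the loop above is a length-4 DFT of real inputs, packed as
// [z0, z2 + z3*i, z1, z2 - z3*i]. An illustrative scalar model (hypothetical helper; lane
// arithmetic wraps mod 2**32, matching the u32 lanes used above):
#[allow(dead_code)]
fn dft4_real_model(x: [u16; 4]) -> [u32; 4] {
    let [x0, x1, x2, x3] = x.map(|v| v as u32);
    let (y0, y1) = (x0.wrapping_add(x2), x1.wrapping_add(x3));
    let (y2, y3) = (x0.wrapping_sub(x2), x1.wrapping_sub(x3));
    [y0.wrapping_add(y1), y0.wrapping_sub(y1), y2, y3]
}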
cumul0_a = vaddw_high_u32(cumul0_a, state3); // MDS[3]
cumul1_a = vmlal_high_laneq_u32::<MDSI2>(cumul1_a, state3, mdsv2); // MDS[2]
cumul2_a = vaddw_high_u32(cumul2_a, state3); // MDS[1]
cumul3_a = vaddw_high_u32(cumul3_a, state3); // MDS[0]
cumul4_a = vmlal_high_laneq_u32::<MDSI11>(cumul4_a, state3, mdsv11); // MDS[11]
cumul5_a = vmlal_high_laneq_u32::<MDSI10>(cumul5_a, state3, mdsv10); // MDS[10]
cumul6_a = vmlal_high_laneq_u32::<MDSI9>(cumul6_a, state3, mdsv9); // MDS[9]
cumul7_a = vmlal_high_laneq_u32::<MDSI8>(cumul7_a, state3, mdsv8); // MDS[8]
cumul8_a = vmlal_high_laneq_u32::<MDSI7>(cumul8_a, state3, mdsv7); // MDS[7]
cumul9_a = vmlal_high_laneq_u32::<MDSI6>(cumul9_a, state3, mdsv6); // MDS[6]
cumul10_a = vmlal_high_laneq_u32::<MDSI5>(cumul10_a, state3, mdsv5); // MDS[5]
cumul11_a = vmlal_high_laneq_u32::<MDSI4>(cumul11_a, state3, mdsv4); // MDS[4]
// 3x3 real matrix-vector mul for component 0 of the FFTs.
// Multiply the vector `[x0, x1, x2]` by the matrix
// `[[ 64, 64, 128],`
// ` [128, 64, 64],`
// ` [ 64, 128, 64]]`
// The results are divided by 4 (this ends up cancelling out some later computations).
{
let x0 = state_fft[0];
let x1 = state_fft[1];
let x2 = state_fft[2];
cumul0_b = vmlal_laneq_u32::<MDSI4>(cumul0_b, state4, mdsv4); // MDS[4]
cumul1_b = vaddw_u32(cumul1_b, state4); // MDS[3]
cumul2_b = vmlal_laneq_u32::<MDSI2>(cumul2_b, state4, mdsv2); // MDS[2]
cumul3_b = vaddw_u32(cumul3_b, state4); // MDS[1]
cumul4_b = vaddw_u32(cumul4_b, state4); // MDS[0]
cumul5_b = vmlal_laneq_u32::<MDSI11>(cumul5_b, state4, mdsv11); // MDS[11]
cumul6_b = vmlal_laneq_u32::<MDSI10>(cumul6_b, state4, mdsv10); // MDS[10]
cumul7_b = vmlal_laneq_u32::<MDSI9>(cumul7_b, state4, mdsv9); // MDS[9]
cumul8_b = vmlal_laneq_u32::<MDSI8>(cumul8_b, state4, mdsv8); // MDS[8]
cumul9_b = vmlal_laneq_u32::<MDSI7>(cumul9_b, state4, mdsv7); // MDS[7]
cumul10_b = vmlal_laneq_u32::<MDSI6>(cumul10_b, state4, mdsv6); // MDS[6]
cumul11_b = vmlal_laneq_u32::<MDSI5>(cumul11_b, state4, mdsv5); // MDS[5]
let t = vshlq_n_u32::<4>(x0);
let u = vaddq_u32(x1, x2);
cumul0_a = vmlal_high_laneq_u32::<MDSI5>(cumul0_a, state5, mdsv5); // MDS[5]
cumul1_a = vmlal_high_laneq_u32::<MDSI4>(cumul1_a, state5, mdsv4); // MDS[4]
cumul2_a = vaddw_high_u32(cumul2_a, state5); // MDS[3]
cumul3_a = vmlal_high_laneq_u32::<MDSI2>(cumul3_a, state5, mdsv2); // MDS[2]
cumul4_a = vaddw_high_u32(cumul4_a, state5); // MDS[1]
cumul5_a = vaddw_high_u32(cumul5_a, state5); // MDS[0]
cumul6_a = vmlal_high_laneq_u32::<MDSI11>(cumul6_a, state5, mdsv11); // MDS[11]
cumul7_a = vmlal_high_laneq_u32::<MDSI10>(cumul7_a, state5, mdsv10); // MDS[10]
cumul8_a = vmlal_high_laneq_u32::<MDSI9>(cumul8_a, state5, mdsv9); // MDS[9]
cumul9_a = vmlal_high_laneq_u32::<MDSI8>(cumul9_a, state5, mdsv8); // MDS[8]
cumul10_a = vmlal_high_laneq_u32::<MDSI7>(cumul10_a, state5, mdsv7); // MDS[7]
cumul11_a = vmlal_high_laneq_u32::<MDSI6>(cumul11_a, state5, mdsv6); // MDS[6]
let y0 = vshlq_n_u32::<4>(u);
let y1 = vmlaq_laneq_u32::<3>(t, x2, consts);
let y2 = vmlaq_laneq_u32::<3>(t, x1, consts);
cumul0_b = vmlal_laneq_u32::<MDSI6>(cumul0_b, state6, mdsv6); // MDS[6]
cumul1_b = vmlal_laneq_u32::<MDSI5>(cumul1_b, state6, mdsv5); // MDS[5]
cumul2_b = vmlal_laneq_u32::<MDSI4>(cumul2_b, state6, mdsv4); // MDS[4]
cumul3_b = vaddw_u32(cumul3_b, state6); // MDS[3]
cumul4_b = vmlal_laneq_u32::<MDSI2>(cumul4_b, state6, mdsv2); // MDS[2]
cumul5_b = vaddw_u32(cumul5_b, state6); // MDS[1]
cumul6_b = vaddw_u32(cumul6_b, state6); // MDS[0]
cumul7_b = vmlal_laneq_u32::<MDSI11>(cumul7_b, state6, mdsv11); // MDS[11]
cumul8_b = vmlal_laneq_u32::<MDSI10>(cumul8_b, state6, mdsv10); // MDS[10]
cumul9_b = vmlal_laneq_u32::<MDSI9>(cumul9_b, state6, mdsv9); // MDS[9]
cumul10_b = vmlal_laneq_u32::<MDSI8>(cumul10_b, state6, mdsv8); // MDS[8]
cumul11_b = vmlal_laneq_u32::<MDSI7>(cumul11_b, state6, mdsv7); // MDS[7]
state_fft[0] = vaddq_u32(y0, y1);
state_fft[1] = vaddq_u32(y1, y2);
state_fft[2] = vaddq_u32(y0, y2);
}
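// Illustrative scalar model of the component-0 block above (hypothetical helper): per lane it
// multiplies [x0, x1, x2] by [[64, 64, 128], [128, 64, 64], [64, 128, 64]] / 4, with wrapping
// u32 arithmetic:
#[allow(dead_code)]
fn mds_fft_real0_model(x0: u32, x1: u32, x2: u32) -> [u32; 3] {
    let m16 = |a: u32| a.wrapping_mul(16);
    let m32 = |a: u32| a.wrapping_mul(32);
    [
        m16(x0).wrapping_add(m16(x1)).wrapping_add(m32(x2)),
        m32(x0).wrapping_add(m16(x1)).wrapping_add(m16(x2)),
        m16(x0).wrapping_add(m32(x1)).wrapping_add(m16(x2)),
    ]
}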
cumul0_a = vmlal_high_laneq_u32::<MDSI7>(cumul0_a, state7, mdsv7); // MDS[7]
cumul1_a = vmlal_high_laneq_u32::<MDSI6>(cumul1_a, state7, mdsv6); // MDS[6]
cumul2_a = vmlal_high_laneq_u32::<MDSI5>(cumul2_a, state7, mdsv5); // MDS[5]
cumul3_a = vmlal_high_laneq_u32::<MDSI4>(cumul3_a, state7, mdsv4); // MDS[4]
cumul4_a = vaddw_high_u32(cumul4_a, state7); // MDS[3]
cumul5_a = vmlal_high_laneq_u32::<MDSI2>(cumul5_a, state7, mdsv2); // MDS[2]
cumul6_a = vaddw_high_u32(cumul6_a, state7); // MDS[1]
cumul7_a = vaddw_high_u32(cumul7_a, state7); // MDS[0]
cumul8_a = vmlal_high_laneq_u32::<MDSI11>(cumul8_a, state7, mdsv11); // MDS[11]
cumul9_a = vmlal_high_laneq_u32::<MDSI10>(cumul9_a, state7, mdsv10); // MDS[10]
cumul10_a = vmlal_high_laneq_u32::<MDSI9>(cumul10_a, state7, mdsv9); // MDS[9]
cumul11_a = vmlal_high_laneq_u32::<MDSI8>(cumul11_a, state7, mdsv8); // MDS[8]
// 3x3 real matrix-vector mul for component 2 of the FFTs.
// Multiply the vector `[x0, x1, x2]` by the matrix
// `[[ -4, -8, 32],`
// ` [-32, -4, -8],`
// ` [ 8, -32, -4]]`
// The results are divided by 4 (this ends up cancelling out some later computations).
{
let x0 = state_fft[3];
let x1 = state_fft[4];
let x2 = state_fft[5];
state_fft[3] = vmlsq_laneq_u32::<2>(vmlaq_laneq_u32::<0>(x0, x1, consts), x2, consts);
state_fft[4] = vmlaq_laneq_u32::<0>(vmlaq_laneq_u32::<2>(x1, x0, consts), x2, consts);
state_fft[5] = vmlsq_laneq_u32::<0>(x2, vmlsq_laneq_u32::<1>(x0, x1, consts), consts);
}
cumul0_b = vmlal_laneq_u32::<MDSI8>(cumul0_b, state8, mdsv8); // MDS[8]
cumul1_b = vmlal_laneq_u32::<MDSI7>(cumul1_b, state8, mdsv7); // MDS[7]
cumul2_b = vmlal_laneq_u32::<MDSI6>(cumul2_b, state8, mdsv6); // MDS[6]
cumul3_b = vmlal_laneq_u32::<MDSI5>(cumul3_b, state8, mdsv5); // MDS[5]
cumul4_b = vmlal_laneq_u32::<MDSI4>(cumul4_b, state8, mdsv4); // MDS[4]
cumul5_b = vaddw_u32(cumul5_b, state8); // MDS[3]
cumul6_b = vmlal_laneq_u32::<MDSI2>(cumul6_b, state8, mdsv2); // MDS[2]
cumul7_b = vaddw_u32(cumul7_b, state8); // MDS[1]
cumul8_b = vaddw_u32(cumul8_b, state8); // MDS[0]
cumul9_b = vmlal_laneq_u32::<MDSI11>(cumul9_b, state8, mdsv11); // MDS[11]
cumul10_b = vmlal_laneq_u32::<MDSI10>(cumul10_b, state8, mdsv10); // MDS[10]
cumul11_b = vmlal_laneq_u32::<MDSI9>(cumul11_b, state8, mdsv9); // MDS[9]
// 3x3 complex matrix-vector mul for components 1 and 3 of the FFTs.
// Multiply the vector `[x0r + x0i i, x1r + x1i i, x2r + x2i i]` by the matrix
// `[[ 4 + 2i, 2 + 32i, 2 - 8i],`
// ` [-8 - 2i, 4 + 2i, 2 + 32i],`
// ` [32 - 2i, -8 - 2i, 4 + 2i]]`
// The results are divided by 2 (this ends up cancelling out some later computations).
{
let x0r = state_fft[6];
let x1r = state_fft[7];
let x2r = state_fft[8];
cumul0_a = vmlal_high_laneq_u32::<MDSI9>(cumul0_a, state9, mdsv9); // MDS[9]
cumul1_a = vmlal_high_laneq_u32::<MDSI8>(cumul1_a, state9, mdsv8); // MDS[8]
cumul2_a = vmlal_high_laneq_u32::<MDSI7>(cumul2_a, state9, mdsv7); // MDS[7]
cumul3_a = vmlal_high_laneq_u32::<MDSI6>(cumul3_a, state9, mdsv6); // MDS[6]
cumul4_a = vmlal_high_laneq_u32::<MDSI5>(cumul4_a, state9, mdsv5); // MDS[5]
cumul5_a = vmlal_high_laneq_u32::<MDSI4>(cumul5_a, state9, mdsv4); // MDS[4]
cumul6_a = vaddw_high_u32(cumul6_a, state9); // MDS[3]
cumul7_a = vmlal_high_laneq_u32::<MDSI2>(cumul7_a, state9, mdsv2); // MDS[2]
cumul8_a = vaddw_high_u32(cumul8_a, state9); // MDS[1]
cumul9_a = vaddw_high_u32(cumul9_a, state9); // MDS[0]
cumul10_a = vmlal_high_laneq_u32::<MDSI11>(cumul10_a, state9, mdsv11); // MDS[11]
cumul11_a = vmlal_high_laneq_u32::<MDSI10>(cumul11_a, state9, mdsv10); // MDS[10]
let x0i = state_fft[9];
let x1i = state_fft[10];
let x2i = state_fft[11];
cumul0_b = vmlal_laneq_u32::<MDSI10>(cumul0_b, state10, mdsv10); // MDS[10]
cumul1_b = vmlal_laneq_u32::<MDSI9>(cumul1_b, state10, mdsv9); // MDS[9]
cumul2_b = vmlal_laneq_u32::<MDSI8>(cumul2_b, state10, mdsv8); // MDS[8]
cumul3_b = vmlal_laneq_u32::<MDSI7>(cumul3_b, state10, mdsv7); // MDS[7]
cumul4_b = vmlal_laneq_u32::<MDSI6>(cumul4_b, state10, mdsv6); // MDS[6]
cumul5_b = vmlal_laneq_u32::<MDSI5>(cumul5_b, state10, mdsv5); // MDS[5]
cumul6_b = vmlal_laneq_u32::<MDSI4>(cumul6_b, state10, mdsv4); // MDS[4]
cumul7_b = vaddw_u32(cumul7_b, state10); // MDS[3]
cumul8_b = vmlal_laneq_u32::<MDSI2>(cumul8_b, state10, mdsv2); // MDS[2]
cumul9_b = vaddw_u32(cumul9_b, state10); // MDS[1]
cumul10_b = vaddw_u32(cumul10_b, state10); // MDS[0]
cumul11_b = vmlal_laneq_u32::<MDSI11>(cumul11_b, state10, mdsv11); // MDS[11]
// real part of result <- real part of input
let r0rr = vaddq_u32(vmlaq_laneq_u32::<0>(x1r, x0r, consts), x2r);
let r1rr = vmlaq_laneq_u32::<0>(x2r, vmlsq_laneq_u32::<0>(x1r, x0r, consts), consts);
let r2rr = vmlsq_laneq_u32::<0>(x2r, vmlsq_laneq_u32::<1>(x1r, x0r, consts), consts);
cumul0_a = vmlal_high_laneq_u32::<MDSI11>(cumul0_a, state11, mdsv11); // MDS[11]
cumul1_a = vmlal_high_laneq_u32::<MDSI10>(cumul1_a, state11, mdsv10); // MDS[10]
cumul2_a = vmlal_high_laneq_u32::<MDSI9>(cumul2_a, state11, mdsv9); // MDS[9]
cumul3_a = vmlal_high_laneq_u32::<MDSI8>(cumul3_a, state11, mdsv8); // MDS[8]
cumul4_a = vmlal_high_laneq_u32::<MDSI7>(cumul4_a, state11, mdsv7); // MDS[7]
cumul5_a = vmlal_high_laneq_u32::<MDSI6>(cumul5_a, state11, mdsv6); // MDS[6]
cumul6_a = vmlal_high_laneq_u32::<MDSI5>(cumul6_a, state11, mdsv5); // MDS[5]
cumul7_a = vmlal_high_laneq_u32::<MDSI4>(cumul7_a, state11, mdsv4); // MDS[4]
cumul8_a = vaddw_high_u32(cumul8_a, state11); // MDS[3]
cumul9_a = vmlal_high_laneq_u32::<MDSI2>(cumul9_a, state11, mdsv2); // MDS[2]
cumul10_a = vaddw_high_u32(cumul10_a, state11); // MDS[1]
cumul11_a = vaddw_high_u32(cumul11_a, state11); // MDS[0]
// real part of result <- imaginary part of input
let r0ri = vmlsq_laneq_u32::<1>(vmlaq_laneq_u32::<3>(x0i, x1i, consts), x2i, consts);
let r1ri = vmlsq_laneq_u32::<3>(vsubq_u32(x0i, x1i), x2i, consts);
let r2ri = vsubq_u32(vaddq_u32(x0i, x1i), x2i);
let reduced = [
mds_reduce([[cumul0_a, cumul0_b], [cumul1_a, cumul1_b]]),
mds_reduce([[cumul2_a, cumul2_b], [cumul3_a, cumul3_b]]),
mds_reduce([[cumul4_a, cumul4_b], [cumul5_a, cumul5_b]]),
mds_reduce([[cumul6_a, cumul6_b], [cumul7_a, cumul7_b]]),
mds_reduce([[cumul8_a, cumul8_b], [cumul9_a, cumul9_b]]),
mds_reduce([[cumul10_a, cumul10_b], [cumul11_a, cumul11_b]]),
];
[
vgetq_lane_u64::<0>(reduced[0]),
vgetq_lane_u64::<1>(reduced[0]),
vgetq_lane_u64::<0>(reduced[1]),
vgetq_lane_u64::<1>(reduced[1]),
vgetq_lane_u64::<0>(reduced[2]),
vgetq_lane_u64::<1>(reduced[2]),
vgetq_lane_u64::<0>(reduced[3]),
vgetq_lane_u64::<1>(reduced[3]),
vgetq_lane_u64::<0>(reduced[4]),
vgetq_lane_u64::<1>(reduced[4]),
vgetq_lane_u64::<0>(reduced[5]),
vgetq_lane_u64::<1>(reduced[5]),
]
// real part of result (total)
let r0r = vsubq_u32(r0rr, r0ri);
let r1r = vaddq_u32(r1rr, r1ri);
let r2r = vmlaq_laneq_u32::<0>(r2ri, r2rr, consts);
// imaginary part of result <- real part of input
let r0ir = vmlsq_laneq_u32::<1>(vmlaq_laneq_u32::<3>(x0r, x1r, consts), x2r, consts);
let r1ir = vmlaq_laneq_u32::<3>(vsubq_u32(x1r, x0r), x2r, consts);
let r2ir = vsubq_u32(x2r, vaddq_u32(x0r, x1r));
// imaginary part of result <- imaginary part of input
let r0ii = vaddq_u32(vmlaq_laneq_u32::<0>(x1i, x0i, consts), x2i);
let r1ii = vmlaq_laneq_u32::<0>(x2i, vmlsq_laneq_u32::<0>(x1i, x0i, consts), consts);
let r2ii = vmlsq_laneq_u32::<0>(x2i, vmlsq_laneq_u32::<1>(x1i, x0i, consts), consts);
// imaginary part of result (total)
let r0i = vaddq_u32(r0ir, r0ii);
let r1i = vaddq_u32(r1ir, r1ii);
let r2i = vmlaq_laneq_u32::<0>(r2ir, r2ii, consts);
state_fft[6] = r0r;
state_fft[7] = r1r;
state_fft[8] = r2r;
state_fft[9] = r0i;
state_fft[10] = r1i;
state_fft[11] = r2i;
}
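// In scalar complex terms (with xk = xkr + xki*i, and the division by 2 already applied), the
// block above computes per lane, wrapping mod 2**32:
//   r0 = ( 2 + i)*x0 + ( 1 + 16i)*x1 + (1 -  4i)*x2
//   r1 = (-4 - i)*x0 + ( 2 +  i)*x1 + (1 + 16i)*x2
//   r2 = (16 - i)*x0 + (-4 -  i)*x1 + (2 +  i)*x2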
// Three length-4 inverse FFTs.
// Normally, such IFFT would divide by 4, but we've already taken care of that.
for i in 0..3 {
let z0 = state_fft[i];
let z1 = state_fft[i + 3];
let z2 = state_fft[i + 6];
let z3 = state_fft[i + 9];
let y0 = vsubq_u32(z0, z1);
let y1 = vaddq_u32(z0, z1);
let y2 = z2;
let y3 = z3;
let x0 = vaddq_u32(y0, y2);
let x1 = vaddq_u32(y1, y3);
let x2 = vsubq_u32(y0, y2);
let x3 = vsubq_u32(y1, y3);
state_fft[i] = x0;
state_fft[i + 3] = x1;
state_fft[i + 6] = x2;
state_fft[i + 9] = x3;
}
// Perform `res[0] += state[0] * 8` for the diagonal component of the MDS matrix.
state_fft[0] = vmlal_laneq_u16::<4>(
state_fft[0],
vcreate_u16(state[0]), // Each 16-bit chunk gets zero-extended.
vreinterpretq_u16_u32(consts), // Hack: these constants fit in `u16s`, so we can bit-cast.
);
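// (Lane 4 of `consts` reinterpreted as u16s equals 8, so this widening multiply-accumulate adds
// 8 * state[0], chunk by chunk, into the accumulator that becomes res[0].)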
let mut res_arr = [0; 12];
for i in 0..6 {
let res = mds_reduce([state_fft[2 * i], state_fft[2 * i + 1]]);
res_arr[2 * i] = vgetq_lane_u64::<0>(res);
res_arr[2 * i + 1] = vgetq_lane_u64::<1>(res);
}
res_arr
}
// ======================================== PARTIAL ROUNDS =========================================
/*
#[rustfmt::skip]
macro_rules! mds_reduce_asm {
($c0:literal, $c1:literal, $out:literal, $consts:literal) => {
@@ -961,13 +864,15 @@ unsafe fn partial_round(
[res23, res45, res67, res89, res1011],
)
}
*/
// ========================================== GLUE CODE ===========================================
/*
#[inline(always)]
unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] {
let state = sbox_layer_full(state);
mds_const_layers_full(state, round_constants)
mds_layer_full(state, round_constants)
}
#[inline]
@@ -1001,43 +906,19 @@ unsafe fn partial_rounds(
}
state.0
}
*/
#[inline(always)]
fn unwrap_state(state: [GoldilocksField; 12]) -> [u64; 12] {
[
state[0].0,
state[1].0,
state[2].0,
state[3].0,
state[4].0,
state[5].0,
state[6].0,
state[7].0,
state[8].0,
state[9].0,
state[10].0,
state[11].0,
]
state.map(|s| s.0)
}
#[inline(always)]
fn wrap_state(state: [u64; 12]) -> [GoldilocksField; 12] {
[
GoldilocksField(state[0]),
GoldilocksField(state[1]),
GoldilocksField(state[2]),
GoldilocksField(state[3]),
GoldilocksField(state[4]),
GoldilocksField(state[5]),
GoldilocksField(state[6]),
GoldilocksField(state[7]),
GoldilocksField(state[8]),
GoldilocksField(state[9]),
GoldilocksField(state[10]),
GoldilocksField(state[11]),
]
state.map(GoldilocksField)
}
/*
#[inline(always)]
pub unsafe fn poseidon(state: [GoldilocksField; 12]) -> [GoldilocksField; 12] {
let state = unwrap_state(state);
@@ -1058,6 +939,7 @@ pub unsafe fn poseidon(state: [GoldilocksField; 12]) -> [GoldilocksField; 12] {
let state = full_rounds(state, &FINAL_ROUND_CONSTANTS);
wrap_state(state)
}
*/
#[inline(always)]
pub unsafe fn sbox_layer(state: &mut [GoldilocksField; WIDTH]) {
@@ -1067,8 +949,6 @@ pub unsafe fn sbox_layer(state: &mut [GoldilocksField; WIDTH]) {
#[inline(always)]
pub unsafe fn mds_layer(state: &[GoldilocksField; WIDTH]) -> [GoldilocksField; WIDTH] {
let state = unwrap_state(*state);
// We want to do an MDS layer without the constant layer.
let round_consts = [0u64; WIDTH];
let state = mds_const_layers_full(state, &round_consts);
let state = mds_layer_full(state);
wrap_state(state)
}


@@ -252,21 +252,21 @@ impl Poseidon for GoldilocksField {
// }
// }
// #[cfg(all(target_arch="aarch64", target_feature="neon"))]
// #[inline(always)]
// fn sbox_layer(state: &mut [Self; 12]) {
// unsafe {
// crate::hash::arch::aarch64::poseidon_goldilocks_neon::sbox_layer(state);
// }
// }
#[cfg(all(target_arch="aarch64", target_feature="neon"))]
#[inline(always)]
fn sbox_layer(state: &mut [Self; 12]) {
unsafe {
crate::hash::arch::aarch64::poseidon_goldilocks_neon::sbox_layer(state);
}
}
// #[cfg(all(target_arch="aarch64", target_feature="neon"))]
// #[inline(always)]
// fn mds_layer(state: &[Self; 12]) -> [Self; 12] {
// unsafe {
// crate::hash::arch::aarch64::poseidon_goldilocks_neon::mds_layer(state)
// }
// }
#[cfg(all(target_arch="aarch64", target_feature="neon"))]
#[inline(always)]
fn mds_layer(state: &[Self; 12]) -> [Self; 12] {
unsafe {
crate::hash::arch::aarch64::poseidon_goldilocks_neon::mds_layer(state)
}
}
}
#[cfg(test)]