diff --git a/plonky2/src/hash/arch/aarch64/mod.rs b/plonky2/src/hash/arch/aarch64/mod.rs index ba86797d..b8ae14af 100644 --- a/plonky2/src/hash/arch/aarch64/mod.rs +++ b/plonky2/src/hash/arch/aarch64/mod.rs @@ -1,2 +1,2 @@ -// #[cfg(target_feature = "neon")] -// pub(crate) mod poseidon_goldilocks_neon; +#[cfg(target_feature = "neon")] +pub(crate) mod poseidon_goldilocks_neon; diff --git a/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs b/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs index f2276506..352456e7 100644 --- a/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs +++ b/plonky2/src/hash/arch/aarch64/poseidon_goldilocks_neon.rs @@ -2,37 +2,24 @@ use std::arch::aarch64::*; use std::arch::asm; +use std::mem::transmute; -use plonky2_field::field_types::Field64; use plonky2_field::goldilocks_field::GoldilocksField; use plonky2_util::branch_hint; use static_assertions::const_assert; use unroll::unroll_for_loops; -use crate::hash::poseidon::{ - Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, -}; +use crate::hash::poseidon::Poseidon; // ========================================== CONSTANTS =========================================== const WIDTH: usize = 12; -// The order below is arbitrary. Repeated coefficients have been removed so these constants fit in -// two registers. -// TODO: ensure this is aligned to 16 bytes (for vector loads), ideally on the same cacheline -const MDS_CONSTS: [u32; 8] = [ - 0xffffffff, - 1 << 1, - 1 << 3, - 1 << 5, - 1 << 8, - 1 << 10, - 1 << 12, - 1 << 16, -]; +const EPSILON: u64 = 0xffffffff; -// The round constants to be applied by the second set of full rounds. These are just the usual round constants, -// shifted by one round, with zeros shifted in. +// The round constants to be applied by the second set of full rounds. These are just the usual +// round constants, shifted by one round, with zeros shifted in. +/* const fn make_final_round_constants() -> [u64; WIDTH * HALF_N_FULL_ROUNDS] { let mut res = [0; WIDTH * HALF_N_FULL_ROUNDS]; let mut i: usize = 0; @@ -43,6 +30,7 @@ const fn make_final_round_constants() -> [u64; WIDTH * HALF_N_FULL_ROUNDS] { res } const FINAL_ROUND_CONSTANTS: [u64; WIDTH * HALF_N_FULL_ROUNDS] = make_final_round_constants(); +*/ // ===================================== COMPILE-TIME CHECKS ====================================== @@ -52,9 +40,12 @@ const FINAL_ROUND_CONSTANTS: [u64; WIDTH * HALF_N_FULL_ROUNDS] = make_final_roun const fn check_mds_matrix() -> bool { // Can't == two arrays in a const_assert! (: let mut i = 0; - let wanted_matrix_exps = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10]; + let wanted_matrix_circ = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20]; + let wanted_matrix_diag = [8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; while i < WIDTH { - if ::MDS_MATRIX_EXPS[i] != wanted_matrix_exps[i] { + if ::MDS_MATRIX_CIRC[i] != wanted_matrix_circ[i] + || ::MDS_MATRIX_DIAG[i] != wanted_matrix_diag[i] + { return false; } i += 1; @@ -63,37 +54,10 @@ const fn check_mds_matrix() -> bool { } const_assert!(check_mds_matrix()); -/// The maximum amount by which the MDS matrix will multiply the input. -/// i.e. max(MDS(state)) <= mds_matrix_inf_norm() * max(state). -const fn mds_matrix_inf_norm() -> u64 { - let mut cumul = 0; - let mut i = 0; - while i < WIDTH { - cumul += 1 << ::MDS_MATRIX_EXPS[i]; - i += 1; - } - cumul -} - -/// Ensure that adding round constants to the low result of the MDS multiplication can never -/// overflow. 
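For context on the new `EPSILON` constant introduced above: the Goldilocks prime is `p = 2**64 - 2**32 + 1`, so `2**64 == 2**32 - 1 == EPSILON (mod p)` and `2**96 == -1 (mod p)`. Those two identities are what the scalar helpers and the vectorized reductions in this file rely on. A minimal scalar sketch follows; the `reduce128` helper is illustrative only and is not part of this patch.

```rust
/// Goldilocks prime: p = 2**64 - 2**32 + 1.
const ORDER: u64 = 0xffff_ffff_0000_0001;
/// 2**32 - 1; satisfies 2**64 == EPSILON (mod p).
const EPSILON: u64 = 0xffff_ffff;

/// Illustrative scalar reduction of a 128-bit value modulo p, using
/// 2**64 == EPSILON and 2**96 == -1 (mod p). Returns a value < 2**64 that is
/// congruent to `x`; it is not necessarily canonical (it may still be >= p).
fn reduce128(x: u128) -> u64 {
    let lo = x as u64;
    let hi = (x >> 64) as u64;
    let hi_lo = hi & EPSILON; // bits 64..96, weight 2**64 == EPSILON
    let hi_hi = hi >> 32;     // bits 96..128, weight 2**96 == -1
    // t == lo - hi_hi (mod p)
    let (mut t, borrow) = lo.overflowing_sub(hi_hi);
    if borrow {
        // The wrapped value is 2**64 too large; 2**64 == EPSILON, so take it back out.
        t = t.wrapping_sub(EPSILON);
    }
    // res == t + hi_lo * EPSILON (mod p); the product always fits in a u64.
    let (mut res, overflow) = t.overflowing_add(hi_lo * EPSILON);
    if overflow {
        res = res.wrapping_add(EPSILON); // again, 2**64 == EPSILON
    }
    res
}

fn main() {
    let prod = 0xdead_beef_1234_5678_u128 * 0x0123_4567_89ab_cdef_u128;
    // Compare congruence classes, since reduce128 may return a non-canonical value.
    assert_eq!(reduce128(prod) % ORDER, (prod % ORDER as u128) as u64);
    println!("ok");
}
```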
-#[allow(dead_code)] -const fn check_round_const_bounds_mds() -> bool { - let max_mds_res = mds_matrix_inf_norm() * (u32::MAX as u64); - let mut i = WIDTH; // First const layer is handled specially. - while i < WIDTH * N_ROUNDS { - if ALL_ROUND_CONSTANTS[i].overflowing_add(max_mds_res).1 { - return false; - } - i += 1; - } - true -} -const_assert!(check_round_const_bounds_mds()); - /// Ensure that the first WIDTH round constants are in canonical* form. This is required because /// the first constant layer does not handle double overflow. /// *: round_const == GoldilocksField::ORDER is safe. +/* #[allow(dead_code)] const fn check_round_const_bounds_init() -> bool { let mut i = 0; @@ -106,11 +70,9 @@ const fn check_round_const_bounds_init() -> bool { true } const_assert!(check_round_const_bounds_init()); - +*/ // ====================================== SCALAR ARITHMETIC ======================================= -const EPSILON: u64 = 0xffffffff; - /// Addition modulo ORDER accounting for wraparound. Correct only when a + b < 2**64 + ORDER. #[inline(always)] unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 { @@ -133,7 +95,16 @@ unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 { /// Subtraction of a and (b >> 32) modulo ORDER accounting for wraparound. #[inline(always)] unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 { - let b_hi = b >> 32; + let mut b_hi = b >> 32; + // Make sure that LLVM emits two separate instructions for the shift and the subtraction. This + // reduces pressure on the execution units with access to the flags, as they are no longer + // responsible for the shift. The hack is to insert a fake computation between the two + // instructions with an `asm` block to make LLVM think that they can't be merged. + asm!( + "/* {0} */", // Make Rust think we're using the register. + inlateout(reg) b_hi, + options(nomem, nostack, preserves_flags, pure), + ); // This could be done with a.overflowing_add(b_hi), but `checked_sub` signals to the compiler // that overflow is unlikely (note: this is a standard library implementation detail, not part // of the spec). @@ -153,7 +124,8 @@ unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 { unsafe fn mul_epsilon(x: u64) -> u64 { let res; asm!( - // Use UMULL to save one instruction. The compiler emits two: extract the low word and then multiply. + // Use UMULL to save one instruction. The compiler emits two: extract the low word and then + // multiply. "umull {res}, {x:w}, {epsilon:w}", x = in(reg) x, epsilon = in(reg) EPSILON, @@ -179,8 +151,9 @@ unsafe fn multiply(x: u64, y: u64) -> u64 { // ==================================== STANDALONE CONST LAYER ===================================== -/// Standalone const layer. Run only once, at the start of round 1. Remaining const layers are fused with the preceeding -/// MDS matrix multiplication. +/// Standalone const layer. Run only once, at the start of round 1. Remaining const layers are fused +/// with the preceeding MDS matrix multiplication. +/* #[inline(always)] #[unroll_for_loops] unsafe fn const_layer_full( @@ -195,15 +168,15 @@ unsafe fn const_layer_full( } state } - +*/ // ========================================== FULL ROUNDS ========================================== /// Full S-box. #[inline(always)] #[unroll_for_loops] unsafe fn sbox_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] { - // This is done in scalar. S-boxes in vector are only slightly slower throughput-wise but have an insane latency - // (~100 cycles) on the M1. + // This is done in scalar. 
S-boxes in vector are only slightly slower throughput-wise but have + // an insane latency (~100 cycles) on the M1. let mut state2 = [0u64; WIDTH]; assert!(WIDTH == 12); @@ -228,297 +201,227 @@ unsafe fn sbox_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] { state7 } -// Aliases for readability. E.g. MDS[5] can be found in mdsv5[MDSI5]. -const MDSI2: i32 = 1; // MDS[2] == 1 -const MDSI4: i32 = 2; // MDS[4] == 3 -const MDSI5: i32 = 3; // MDS[5] == 5 -const MDSI6: i32 = 1; // MDS[6] == 1 -const MDSI7: i32 = 0; // MDS[7] == 8 -const MDSI8: i32 = 2; // MDS[8] == 12 -const MDSI9: i32 = 2; // MDS[9] == 3 -const MDSI10: i32 = 3; // MDS[10] == 16 -const MDSI11: i32 = 1; // MDS[11] == 10 - #[inline(always)] unsafe fn mds_reduce( - [[cumul0_a, cumul0_b], [cumul1_a, cumul1_b]]: [[uint64x2_t; 2]; 2], + // `cumul_a` and `cumul_b` represent two separate field elements. We take advantage of + // vectorization by reducing them simultaneously. + [cumul_a, cumul_b]: [uint32x4_t; 2], ) -> uint64x2_t { - // mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5] - let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::()); - - // Merge accumulators - let cumul0 = vaddq_u64(cumul0_a, cumul0_b); - let cumul1 = vaddq_u64(cumul1_a, cumul1_b); - - // Swizzle - let res_lo = vzip1q_u64(cumul0, cumul1); - let res_hi = vzip2q_u64(cumul0, cumul1); - - // Reduce from u96 - let res_hi = vsraq_n_u64::<32>(res_hi, res_lo); - let res_lo = vsliq_n_u64::<32>(res_lo, res_hi); - - // Extract high 32-bits. - let res_hi_hi = vget_low_u32(vuzp2q_u32( - vreinterpretq_u32_u64(res_hi), - vreinterpretq_u32_u64(res_hi), - )); - - // Multiply by EPSILON and accumulate. - let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0); - let res_adj = vcgtq_u64(res_lo, res_unadj); - vsraq_n_u64::<32>(res_unadj, res_adj) + // Form: + // `lo = [cumul_a[0] + cumul_a[2] * 2**32, cumul_b[0] + cumul_b[2] * 2**32]` + // `hi = [cumul_a[1] + cumul_a[3] * 2**32, cumul_b[1] + cumul_b[3] * 2**32]` + // Observe that the result `== lo + hi * 2**16 (mod Goldilocks)`. + let mut lo = vreinterpretq_u64_u32(vuzp1q_u32(cumul_a, cumul_b)); + let mut hi = vreinterpretq_u64_u32(vuzp2q_u32(cumul_a, cumul_b)); + // Add the high 48 bits of `lo` to `hi`. This cannot overflow. + hi = vsraq_n_u64::<16>(hi, lo); + // Now, result `== lo.bits[0..16] + hi * 2**16 (mod Goldilocks)`. + // Set the high 48 bits of `lo` to the low 48 bits of `hi`. + lo = vsliq_n_u64::<16>(lo, hi); + // At this point, result `== lo + hi.bits[48..64] * 2**64 (mod Goldilocks)`. + // It remains to fold `hi.bits[48..64]` into `lo`. + let top = { + // Extract the top 16 bits of `hi` as a `u32`. + // Interpret `hi` as a vector of bytes, so we can use a table lookup instruction. + let hi_u8 = vreinterpretq_u8_u64(hi); + // Indices defining the permutation. `0xff` is out of bounds, producing `0`. + let top_idx = + transmute::<[u8; 8], uint8x8_t>([0x06, 0x07, 0xff, 0xff, 0x0e, 0x0f, 0xff, 0xff]); + let top_u8 = vqtbl1_u8(hi_u8, top_idx); + vreinterpret_u32_u8(top_u8) + }; + // result `== lo + top * 2**64 (mod Goldilocks)`. + let adj_lo = vmlal_n_u32(lo, top, EPSILON as u32); + let wraparound_mask = vcgtq_u64(lo, adj_lo); + vsraq_n_u64::<32>(adj_lo, wraparound_mask) // Add epsilon on overflow. 
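A scalar model of the chunk-folding above may help when auditing `mds_reduce`. The helper below is hypothetical and handles one element at a time, whereas the vector code reduces two elements per call; it also relies on the same chunk bounds as the MDS layer so that the intermediate additions cannot overflow.

```rust
/// Scalar model of `mds_reduce`: fold one element given as four 32-bit chunks
/// spaced 16 bits apart, i.e. elem == c[0] + c[1]*2**16 + c[2]*2**32 + c[3]*2**48.
const EPSILON: u64 = 0xffff_ffff; // 2**64 == EPSILON (mod Goldilocks)

fn mds_reduce_scalar(c: [u32; 4]) -> u64 {
    // Mirrors vuzp1q/vuzp2q: lo holds chunks 0 and 2, hi holds chunks 1 and 3.
    let lo = c[0] as u64 | (c[2] as u64) << 32;
    let mut hi = c[1] as u64 | (c[3] as u64) << 32;
    // result == lo + hi * 2**16 (mod Goldilocks)
    hi += lo >> 16; // fold the high 48 bits of lo into hi (vsraq_n_u64::<16>)
    // result == lo.bits[0..16] + hi * 2**16 (mod Goldilocks)
    let lo = (lo & 0xffff) | (hi << 16); // vsliq_n_u64::<16>
    let top = hi >> 48; // weight 2**64 == EPSILON (mod Goldilocks)
    let (res, wraparound) = lo.overflowing_add(top * EPSILON);
    if wraparound {
        res.wrapping_add(EPSILON) // add EPSILON on overflow, as in the vector code
    } else {
        res
    }
}
```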
} #[inline(always)] -unsafe fn mds_const_layers_full( - state: [u64; WIDTH], - round_constants: &[u64; WIDTH], -) -> [u64; WIDTH] { - // mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5] - // mds_consts1 == [1 << 8, 1 << 10, 1 << 12, 1 << 16] - let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::()); - let mds_consts1: uint32x4_t = vld1q_u32((&MDS_CONSTS[4..8]).as_ptr().cast::()); +unsafe fn mds_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] { + // This function performs an MDS multiplication in complex FFT space. + // However, instead of performing a width-12 FFT, we perform three width-4 FFTs, which is + // cheaper. The 12x12 matrix-vector multiplication (a convolution) becomes two 3x3 real + // matrix-vector multiplications and one 3x3 complex matrix-vector multiplication. - // Aliases for readability. E.g. MDS[5] can be found in mdsv5[mdsi5]. MDS[0], MDS[1], and - // MDS[3] are 0, so they are not needed. - let mdsv2 = mds_consts0; // MDS[2] == 1 - let mdsv4 = mds_consts0; // MDS[4] == 3 - let mdsv5 = mds_consts0; // MDS[5] == 5 - let mdsv6 = mds_consts0; // MDS[6] == 1 - let mdsv7 = mds_consts1; // MDS[7] == 8 - let mdsv8 = mds_consts1; // MDS[8] == 12 - let mdsv9 = mds_consts0; // MDS[9] == 3 - let mdsv10 = mds_consts1; // MDS[10] == 16 - let mdsv11 = mds_consts1; // MDS[11] == 10 + // We split each 64-bit into four chunks of 16 bits. To prevent overflow, each chunk is 32 bits + // long. Each NEON vector below represents one field element and consists of four 32-bit chunks: + // `elem == vector[0] + vector[1] * 2**16 + vector[2] * 2**32 + vector[3] * 2**48`. - // For i even, we combine state[i] and state[i + 1] into one vector to save on registers. - // Thus, state1 actually contains state0 and state1 but is only used in the intrinsics that - // access the high high doubleword. - let state1: uint32x4_t = - vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[0]), vcreate_u64(state[1]))); - let state3: uint32x4_t = - vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[2]), vcreate_u64(state[3]))); - let state5: uint32x4_t = - vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[4]), vcreate_u64(state[5]))); - let state7: uint32x4_t = - vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[6]), vcreate_u64(state[7]))); - let state9: uint32x4_t = - vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[8]), vcreate_u64(state[9]))); - let state11: uint32x4_t = - vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[10]), vcreate_u64(state[11]))); - // state0 is an alias to the low doubleword of state1. The compiler should use one register for both. - let state0: uint32x2_t = vget_low_u32(state1); - let state2: uint32x2_t = vget_low_u32(state3); - let state4: uint32x2_t = vget_low_u32(state5); - let state6: uint32x2_t = vget_low_u32(state7); - let state8: uint32x2_t = vget_low_u32(state9); - let state10: uint32x2_t = vget_low_u32(state11); + // Constants that we multiply by. + let mut consts: uint32x4_t = transmute::<[u32; 4], _>([2, 4, 8, 16]); - // Two accumulators per output to hide latency. Each accumulator is a vector of two u64s, - // containing the result for the low 32 bits and the high 32 bits. Thus, the final result at - // index i is (cumuli_a[0] + cumuli_b[0]) + (cumuli_a[1] + cumuli_b[1]) * 2**32. + // Prevent LLVM from turning fused multiply (by power of 2)-add (1 instruction) into shift and + // add (two instructions). This fake `asm` block means that LLVM no longer knows the contents of + // `consts`. 
+ asm!("/* {0:v} */", // Make Rust think the register is being used. + inout(vreg) consts, + options(pure, nomem, nostack, preserves_flags), + ); - // Start by loading the round constants. - let mut cumul0_a = vcombine_u64(vld1_u64(&round_constants[0]), vcreate_u64(0)); - let mut cumul1_a = vcombine_u64(vld1_u64(&round_constants[1]), vcreate_u64(0)); - let mut cumul2_a = vcombine_u64(vld1_u64(&round_constants[2]), vcreate_u64(0)); - let mut cumul3_a = vcombine_u64(vld1_u64(&round_constants[3]), vcreate_u64(0)); - let mut cumul4_a = vcombine_u64(vld1_u64(&round_constants[4]), vcreate_u64(0)); - let mut cumul5_a = vcombine_u64(vld1_u64(&round_constants[5]), vcreate_u64(0)); - let mut cumul6_a = vcombine_u64(vld1_u64(&round_constants[6]), vcreate_u64(0)); - let mut cumul7_a = vcombine_u64(vld1_u64(&round_constants[7]), vcreate_u64(0)); - let mut cumul8_a = vcombine_u64(vld1_u64(&round_constants[8]), vcreate_u64(0)); - let mut cumul9_a = vcombine_u64(vld1_u64(&round_constants[9]), vcreate_u64(0)); - let mut cumul10_a = vcombine_u64(vld1_u64(&round_constants[10]), vcreate_u64(0)); - let mut cumul11_a = vcombine_u64(vld1_u64(&round_constants[11]), vcreate_u64(0)); + // Four length-3 complex FFTs. + let mut state_fft = [vdupq_n_u32(0); 12]; + for i in 0..3 { + // Interpret each field element as a 4-vector of `u16`s. + let x0 = vcreate_u16(state[i]); + let x1 = vcreate_u16(state[i + 3]); + let x2 = vcreate_u16(state[i + 6]); + let x3 = vcreate_u16(state[i + 9]); - // Now the matrix multiplication. - // MDS exps: [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10] - // out[i] += in[j] << mds[j - i] + // `vaddl_u16` and `vsubl_u16` yield 4-vectors of `u32`s. + let y0 = vaddl_u16(x0, x2); + let y1 = vaddl_u16(x1, x3); + let y2 = vsubl_u16(x0, x2); + let y3 = vsubl_u16(x1, x3); - let mut cumul0_b = vshll_n_u32::<0>(state0); // MDS[0] - let mut cumul1_b = vshll_n_u32::<10>(state0); // MDS[11] - let mut cumul2_b = vshll_n_u32::<16>(state0); // MDS[10] - let mut cumul3_b = vshll_n_u32::<3>(state0); // MDS[9] - let mut cumul4_b = vshll_n_u32::<12>(state0); // MDS[8] - let mut cumul5_b = vshll_n_u32::<8>(state0); // MDS[7] - let mut cumul6_b = vshll_n_u32::<1>(state0); // MDS[6] - let mut cumul7_b = vshll_n_u32::<5>(state0); // MDS[5] - let mut cumul8_b = vshll_n_u32::<3>(state0); // MDS[4] - let mut cumul9_b = vshll_n_u32::<0>(state0); // MDS[3] - let mut cumul10_b = vshll_n_u32::<1>(state0); // MDS[2] - let mut cumul11_b = vshll_n_u32::<0>(state0); // MDS[1] + let z0 = vaddq_u32(y0, y1); + let z1 = vsubq_u32(y0, y1); + let z2 = y2; + let z3 = y3; - cumul0_a = vaddw_high_u32(cumul0_a, state1); // MDS[1] - cumul1_a = vaddw_high_u32(cumul1_a, state1); // MDS[0] - cumul2_a = vmlal_high_laneq_u32::(cumul2_a, state1, mdsv11); // MDS[11] - cumul3_a = vmlal_high_laneq_u32::(cumul3_a, state1, mdsv10); // MDS[10] - cumul4_a = vmlal_high_laneq_u32::(cumul4_a, state1, mdsv9); // MDS[9] - cumul5_a = vmlal_high_laneq_u32::(cumul5_a, state1, mdsv8); // MDS[8] - cumul6_a = vmlal_high_laneq_u32::(cumul6_a, state1, mdsv7); // MDS[7] - cumul7_a = vmlal_high_laneq_u32::(cumul7_a, state1, mdsv6); // MDS[6] - cumul8_a = vmlal_high_laneq_u32::(cumul8_a, state1, mdsv5); // MDS[5] - cumul9_a = vmlal_high_laneq_u32::(cumul9_a, state1, mdsv4); // MDS[4] - cumul10_a = vaddw_high_u32(cumul10_a, state1); // MDS[3] - cumul11_a = vmlal_high_laneq_u32::(cumul11_a, state1, mdsv2); // MDS[2] + // The FFT is `[z0, z2 + z3 i, z1, z2 - z3 i]`. 
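In scalar terms, each of the three stride-3 columns `[x0, x1, x2, x3]` of the state gets a length-4 real DFT via the usual radix-2 butterflies (the vector code applies them per 16-bit chunk, widened to 32 bits). A sketch with signed integers for clarity; `dft4_real` is illustrative and not part of the patch.

```rust
/// Length-4 DFT of a real sequence, matching the butterflies above.
/// Returns (X(1), X(-1), Re X(i), Im X(i)); X(-i) is the complex conjugate of
/// X(i), so it does not need to be stored.
fn dft4_real(x: [i64; 4]) -> (i64, i64, i64, i64) {
    let y0 = x[0] + x[2];
    let y1 = x[1] + x[3];
    let y2 = x[0] - x[2]; // real part of X(i)
    let y3 = x[1] - x[3]; // imaginary part of X(i)
    let z0 = y0 + y1; // X(1): sum of all inputs
    let z1 = y0 - y1; // X(-1): alternating sum
    (z0, z1, y2, y3)
}
```

In this basis the 12x12 circulant product reduces to the two 3x3 real products and one 3x3 complex product applied in the blocks below.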
- cumul0_b = vmlal_laneq_u32::(cumul0_b, state2, mdsv2); // MDS[2] - cumul1_b = vaddw_u32(cumul1_b, state2); // MDS[1] - cumul2_b = vaddw_u32(cumul2_b, state2); // MDS[0] - cumul3_b = vmlal_laneq_u32::(cumul3_b, state2, mdsv11); // MDS[11] - cumul4_b = vmlal_laneq_u32::(cumul4_b, state2, mdsv10); // MDS[10] - cumul5_b = vmlal_laneq_u32::(cumul5_b, state2, mdsv9); // MDS[9] - cumul6_b = vmlal_laneq_u32::(cumul6_b, state2, mdsv8); // MDS[8] - cumul7_b = vmlal_laneq_u32::(cumul7_b, state2, mdsv7); // MDS[7] - cumul8_b = vmlal_laneq_u32::(cumul8_b, state2, mdsv6); // MDS[6] - cumul9_b = vmlal_laneq_u32::(cumul9_b, state2, mdsv5); // MDS[5] - cumul10_b = vmlal_laneq_u32::(cumul10_b, state2, mdsv4); // MDS[4] - cumul11_b = vaddw_u32(cumul11_b, state2); // MDS[3] + state_fft[i] = z0; + state_fft[i + 3] = z1; + state_fft[i + 6] = z2; + state_fft[i + 9] = z3; + } - cumul0_a = vaddw_high_u32(cumul0_a, state3); // MDS[3] - cumul1_a = vmlal_high_laneq_u32::(cumul1_a, state3, mdsv2); // MDS[2] - cumul2_a = vaddw_high_u32(cumul2_a, state3); // MDS[1] - cumul3_a = vaddw_high_u32(cumul3_a, state3); // MDS[0] - cumul4_a = vmlal_high_laneq_u32::(cumul4_a, state3, mdsv11); // MDS[11] - cumul5_a = vmlal_high_laneq_u32::(cumul5_a, state3, mdsv10); // MDS[10] - cumul6_a = vmlal_high_laneq_u32::(cumul6_a, state3, mdsv9); // MDS[9] - cumul7_a = vmlal_high_laneq_u32::(cumul7_a, state3, mdsv8); // MDS[8] - cumul8_a = vmlal_high_laneq_u32::(cumul8_a, state3, mdsv7); // MDS[7] - cumul9_a = vmlal_high_laneq_u32::(cumul9_a, state3, mdsv6); // MDS[6] - cumul10_a = vmlal_high_laneq_u32::(cumul10_a, state3, mdsv5); // MDS[5] - cumul11_a = vmlal_high_laneq_u32::(cumul11_a, state3, mdsv4); // MDS[4] + // 3x3 real matrix-vector mul for component 0 of the FFTs. + // Multiply the vector `[x0, x1, x2]` by the matrix + // `[[ 64, 64, 128],` + // ` [128, 64, 64],` + // ` [ 64, 128, 64]]` + // The results are divided by 4 (this ends up cancelling out some later computations). 
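As a sanity check on the block that follows: with the division by 4 folded in, the matrix is `[[16, 16, 32], [32, 16, 16], [16, 32, 16]]`, and the code evaluates it by sharing the subexpressions `16*x0` and `x1 + x2`. A hypothetical scalar equivalent:

```rust
/// Scalar equivalent of the component-0 block: multiply [x0, x1, x2] by
/// [[64, 64, 128], [128, 64, 64], [64, 128, 64]] / 4, using the same shared
/// subexpressions as the vector code.
fn mds_block0(x: [u64; 3]) -> [u64; 3] {
    let t = 16 * x[0];      // vshlq_n_u32::<4>
    let u = x[1] + x[2];
    let y0 = 16 * u;        // 16*x1 + 16*x2
    let y1 = t + 16 * x[2]; // 16*x0 + 16*x2 (vmlaq_laneq_u32::<3>, consts[3] == 16)
    let y2 = t + 16 * x[1]; // 16*x0 + 16*x1
    [y0 + y1, y1 + y2, y0 + y2] // rows [16,16,32], [32,16,16], [16,32,16]
}
```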
+ { + let x0 = state_fft[0]; + let x1 = state_fft[1]; + let x2 = state_fft[2]; - cumul0_b = vmlal_laneq_u32::(cumul0_b, state4, mdsv4); // MDS[4] - cumul1_b = vaddw_u32(cumul1_b, state4); // MDS[3] - cumul2_b = vmlal_laneq_u32::(cumul2_b, state4, mdsv2); // MDS[2] - cumul3_b = vaddw_u32(cumul3_b, state4); // MDS[1] - cumul4_b = vaddw_u32(cumul4_b, state4); // MDS[0] - cumul5_b = vmlal_laneq_u32::(cumul5_b, state4, mdsv11); // MDS[11] - cumul6_b = vmlal_laneq_u32::(cumul6_b, state4, mdsv10); // MDS[10] - cumul7_b = vmlal_laneq_u32::(cumul7_b, state4, mdsv9); // MDS[9] - cumul8_b = vmlal_laneq_u32::(cumul8_b, state4, mdsv8); // MDS[8] - cumul9_b = vmlal_laneq_u32::(cumul9_b, state4, mdsv7); // MDS[7] - cumul10_b = vmlal_laneq_u32::(cumul10_b, state4, mdsv6); // MDS[6] - cumul11_b = vmlal_laneq_u32::(cumul11_b, state4, mdsv5); // MDS[5] + let t = vshlq_n_u32::<4>(x0); + let u = vaddq_u32(x1, x2); - cumul0_a = vmlal_high_laneq_u32::(cumul0_a, state5, mdsv5); // MDS[5] - cumul1_a = vmlal_high_laneq_u32::(cumul1_a, state5, mdsv4); // MDS[4] - cumul2_a = vaddw_high_u32(cumul2_a, state5); // MDS[3] - cumul3_a = vmlal_high_laneq_u32::(cumul3_a, state5, mdsv2); // MDS[2] - cumul4_a = vaddw_high_u32(cumul4_a, state5); // MDS[1] - cumul5_a = vaddw_high_u32(cumul5_a, state5); // MDS[0] - cumul6_a = vmlal_high_laneq_u32::(cumul6_a, state5, mdsv11); // MDS[11] - cumul7_a = vmlal_high_laneq_u32::(cumul7_a, state5, mdsv10); // MDS[10] - cumul8_a = vmlal_high_laneq_u32::(cumul8_a, state5, mdsv9); // MDS[9] - cumul9_a = vmlal_high_laneq_u32::(cumul9_a, state5, mdsv8); // MDS[8] - cumul10_a = vmlal_high_laneq_u32::(cumul10_a, state5, mdsv7); // MDS[7] - cumul11_a = vmlal_high_laneq_u32::(cumul11_a, state5, mdsv6); // MDS[6] + let y0 = vshlq_n_u32::<4>(u); + let y1 = vmlaq_laneq_u32::<3>(t, x2, consts); + let y2 = vmlaq_laneq_u32::<3>(t, x1, consts); - cumul0_b = vmlal_laneq_u32::(cumul0_b, state6, mdsv6); // MDS[6] - cumul1_b = vmlal_laneq_u32::(cumul1_b, state6, mdsv5); // MDS[5] - cumul2_b = vmlal_laneq_u32::(cumul2_b, state6, mdsv4); // MDS[4] - cumul3_b = vaddw_u32(cumul3_b, state6); // MDS[3] - cumul4_b = vmlal_laneq_u32::(cumul4_b, state6, mdsv2); // MDS[2] - cumul5_b = vaddw_u32(cumul5_b, state6); // MDS[1] - cumul6_b = vaddw_u32(cumul6_b, state6); // MDS[0] - cumul7_b = vmlal_laneq_u32::(cumul7_b, state6, mdsv11); // MDS[11] - cumul8_b = vmlal_laneq_u32::(cumul8_b, state6, mdsv10); // MDS[10] - cumul9_b = vmlal_laneq_u32::(cumul9_b, state6, mdsv9); // MDS[9] - cumul10_b = vmlal_laneq_u32::(cumul10_b, state6, mdsv8); // MDS[8] - cumul11_b = vmlal_laneq_u32::(cumul11_b, state6, mdsv7); // MDS[7] + state_fft[0] = vaddq_u32(y0, y1); + state_fft[1] = vaddq_u32(y1, y2); + state_fft[2] = vaddq_u32(y0, y2); + } - cumul0_a = vmlal_high_laneq_u32::(cumul0_a, state7, mdsv7); // MDS[7] - cumul1_a = vmlal_high_laneq_u32::(cumul1_a, state7, mdsv6); // MDS[6] - cumul2_a = vmlal_high_laneq_u32::(cumul2_a, state7, mdsv5); // MDS[5] - cumul3_a = vmlal_high_laneq_u32::(cumul3_a, state7, mdsv4); // MDS[4] - cumul4_a = vaddw_high_u32(cumul4_a, state7); // MDS[3] - cumul5_a = vmlal_high_laneq_u32::(cumul5_a, state7, mdsv2); // MDS[2] - cumul6_a = vaddw_high_u32(cumul6_a, state7); // MDS[1] - cumul7_a = vaddw_high_u32(cumul7_a, state7); // MDS[0] - cumul8_a = vmlal_high_laneq_u32::(cumul8_a, state7, mdsv11); // MDS[11] - cumul9_a = vmlal_high_laneq_u32::(cumul9_a, state7, mdsv10); // MDS[10] - cumul10_a = vmlal_high_laneq_u32::(cumul10_a, state7, mdsv9); // MDS[9] - cumul11_a = vmlal_high_laneq_u32::(cumul11_a, state7, 
mdsv8); // MDS[8] + // 3x3 real matrix-vector mul for component 2 of the FFTs. + // Multiply the vector `[x0, x1, x2]` by the matrix + // `[[ -4, -8, 32],` + // ` [-32, -4, -8],` + // ` [ 8, -32, -4]]` + // The results are divided by 4 (this ends up cancelling out some later computations). + { + let x0 = state_fft[3]; + let x1 = state_fft[4]; + let x2 = state_fft[5]; + state_fft[3] = vmlsq_laneq_u32::<2>(vmlaq_laneq_u32::<0>(x0, x1, consts), x2, consts); + state_fft[4] = vmlaq_laneq_u32::<0>(vmlaq_laneq_u32::<2>(x1, x0, consts), x2, consts); + state_fft[5] = vmlsq_laneq_u32::<0>(x2, vmlsq_laneq_u32::<1>(x0, x1, consts), consts); + } - cumul0_b = vmlal_laneq_u32::(cumul0_b, state8, mdsv8); // MDS[8] - cumul1_b = vmlal_laneq_u32::(cumul1_b, state8, mdsv7); // MDS[7] - cumul2_b = vmlal_laneq_u32::(cumul2_b, state8, mdsv6); // MDS[6] - cumul3_b = vmlal_laneq_u32::(cumul3_b, state8, mdsv5); // MDS[5] - cumul4_b = vmlal_laneq_u32::(cumul4_b, state8, mdsv4); // MDS[4] - cumul5_b = vaddw_u32(cumul5_b, state8); // MDS[3] - cumul6_b = vmlal_laneq_u32::(cumul6_b, state8, mdsv2); // MDS[2] - cumul7_b = vaddw_u32(cumul7_b, state8); // MDS[1] - cumul8_b = vaddw_u32(cumul8_b, state8); // MDS[0] - cumul9_b = vmlal_laneq_u32::(cumul9_b, state8, mdsv11); // MDS[11] - cumul10_b = vmlal_laneq_u32::(cumul10_b, state8, mdsv10); // MDS[10] - cumul11_b = vmlal_laneq_u32::(cumul11_b, state8, mdsv9); // MDS[9] + // 3x3 complex matrix-vector mul for components 1 and 3 of the FFTs. + // Multiply the vector `[x0r + x0i i, x1r + x1i i, x2r + x2i i]` by the matrix + // `[[ 4 + 2i, 2 + 32i, 2 - 8i],` + // ` [-8 - 2i, 4 + 2i, 2 + 32i],` + // ` [32 - 2i, -8 - 2i, 4 + 2i]]` + // The results are divided by 2 (this ends up cancelling out some later computations). + { + let x0r = state_fft[6]; + let x1r = state_fft[7]; + let x2r = state_fft[8]; - cumul0_a = vmlal_high_laneq_u32::(cumul0_a, state9, mdsv9); // MDS[9] - cumul1_a = vmlal_high_laneq_u32::(cumul1_a, state9, mdsv8); // MDS[8] - cumul2_a = vmlal_high_laneq_u32::(cumul2_a, state9, mdsv7); // MDS[7] - cumul3_a = vmlal_high_laneq_u32::(cumul3_a, state9, mdsv6); // MDS[6] - cumul4_a = vmlal_high_laneq_u32::(cumul4_a, state9, mdsv5); // MDS[5] - cumul5_a = vmlal_high_laneq_u32::(cumul5_a, state9, mdsv4); // MDS[4] - cumul6_a = vaddw_high_u32(cumul6_a, state9); // MDS[3] - cumul7_a = vmlal_high_laneq_u32::(cumul7_a, state9, mdsv2); // MDS[2] - cumul8_a = vaddw_high_u32(cumul8_a, state9); // MDS[1] - cumul9_a = vaddw_high_u32(cumul9_a, state9); // MDS[0] - cumul10_a = vmlal_high_laneq_u32::(cumul10_a, state9, mdsv11); // MDS[11] - cumul11_a = vmlal_high_laneq_u32::(cumul11_a, state9, mdsv10); // MDS[10] + let x0i = state_fft[9]; + let x1i = state_fft[10]; + let x2i = state_fft[11]; - cumul0_b = vmlal_laneq_u32::(cumul0_b, state10, mdsv10); // MDS[10] - cumul1_b = vmlal_laneq_u32::(cumul1_b, state10, mdsv9); // MDS[9] - cumul2_b = vmlal_laneq_u32::(cumul2_b, state10, mdsv8); // MDS[8] - cumul3_b = vmlal_laneq_u32::(cumul3_b, state10, mdsv7); // MDS[7] - cumul4_b = vmlal_laneq_u32::(cumul4_b, state10, mdsv6); // MDS[6] - cumul5_b = vmlal_laneq_u32::(cumul5_b, state10, mdsv5); // MDS[5] - cumul6_b = vmlal_laneq_u32::(cumul6_b, state10, mdsv4); // MDS[4] - cumul7_b = vaddw_u32(cumul7_b, state10); // MDS[3] - cumul8_b = vmlal_laneq_u32::(cumul8_b, state10, mdsv2); // MDS[2] - cumul9_b = vaddw_u32(cumul9_b, state10); // MDS[1] - cumul10_b = vaddw_u32(cumul10_b, state10); // MDS[0] - cumul11_b = vmlal_laneq_u32::(cumul11_b, state10, mdsv11); // MDS[11] + // real part of result <- 
real part of input + let r0rr = vaddq_u32(vmlaq_laneq_u32::<0>(x1r, x0r, consts), x2r); + let r1rr = vmlaq_laneq_u32::<0>(x2r, vmlsq_laneq_u32::<0>(x1r, x0r, consts), consts); + let r2rr = vmlsq_laneq_u32::<0>(x2r, vmlsq_laneq_u32::<1>(x1r, x0r, consts), consts); - cumul0_a = vmlal_high_laneq_u32::(cumul0_a, state11, mdsv11); // MDS[11] - cumul1_a = vmlal_high_laneq_u32::(cumul1_a, state11, mdsv10); // MDS[10] - cumul2_a = vmlal_high_laneq_u32::(cumul2_a, state11, mdsv9); // MDS[9] - cumul3_a = vmlal_high_laneq_u32::(cumul3_a, state11, mdsv8); // MDS[8] - cumul4_a = vmlal_high_laneq_u32::(cumul4_a, state11, mdsv7); // MDS[7] - cumul5_a = vmlal_high_laneq_u32::(cumul5_a, state11, mdsv6); // MDS[6] - cumul6_a = vmlal_high_laneq_u32::(cumul6_a, state11, mdsv5); // MDS[5] - cumul7_a = vmlal_high_laneq_u32::(cumul7_a, state11, mdsv4); // MDS[4] - cumul8_a = vaddw_high_u32(cumul8_a, state11); // MDS[3] - cumul9_a = vmlal_high_laneq_u32::(cumul9_a, state11, mdsv2); // MDS[2] - cumul10_a = vaddw_high_u32(cumul10_a, state11); // MDS[1] - cumul11_a = vaddw_high_u32(cumul11_a, state11); // MDS[0] + // real part of result <- imaginary part of input + let r0ri = vmlsq_laneq_u32::<1>(vmlaq_laneq_u32::<3>(x0i, x1i, consts), x2i, consts); + let r1ri = vmlsq_laneq_u32::<3>(vsubq_u32(x0i, x1i), x2i, consts); + let r2ri = vsubq_u32(vaddq_u32(x0i, x1i), x2i); - let reduced = [ - mds_reduce([[cumul0_a, cumul0_b], [cumul1_a, cumul1_b]]), - mds_reduce([[cumul2_a, cumul2_b], [cumul3_a, cumul3_b]]), - mds_reduce([[cumul4_a, cumul4_b], [cumul5_a, cumul5_b]]), - mds_reduce([[cumul6_a, cumul6_b], [cumul7_a, cumul7_b]]), - mds_reduce([[cumul8_a, cumul8_b], [cumul9_a, cumul9_b]]), - mds_reduce([[cumul10_a, cumul10_b], [cumul11_a, cumul11_b]]), - ]; - [ - vgetq_lane_u64::<0>(reduced[0]), - vgetq_lane_u64::<1>(reduced[0]), - vgetq_lane_u64::<0>(reduced[1]), - vgetq_lane_u64::<1>(reduced[1]), - vgetq_lane_u64::<0>(reduced[2]), - vgetq_lane_u64::<1>(reduced[2]), - vgetq_lane_u64::<0>(reduced[3]), - vgetq_lane_u64::<1>(reduced[3]), - vgetq_lane_u64::<0>(reduced[4]), - vgetq_lane_u64::<1>(reduced[4]), - vgetq_lane_u64::<0>(reduced[5]), - vgetq_lane_u64::<1>(reduced[5]), - ] + // real part of result (total) + let r0r = vsubq_u32(r0rr, r0ri); + let r1r = vaddq_u32(r1rr, r1ri); + let r2r = vmlaq_laneq_u32::<0>(r2ri, r2rr, consts); + + // imaginary part of result <- real part of input + let r0ir = vmlsq_laneq_u32::<1>(vmlaq_laneq_u32::<3>(x0r, x1r, consts), x2r, consts); + let r1ir = vmlaq_laneq_u32::<3>(vsubq_u32(x1r, x0r), x2r, consts); + let r2ir = vsubq_u32(x2r, vaddq_u32(x0r, x1r)); + + // imaginary part of result <- imaginary part of input + let r0ii = vaddq_u32(vmlaq_laneq_u32::<0>(x1i, x0i, consts), x2i); + let r1ii = vmlaq_laneq_u32::<0>(x2i, vmlsq_laneq_u32::<0>(x1i, x0i, consts), consts); + let r2ii = vmlsq_laneq_u32::<0>(x2i, vmlsq_laneq_u32::<1>(x1i, x0i, consts), consts); + + // imaginary part of result (total) + let r0i = vaddq_u32(r0ir, r0ii); + let r1i = vaddq_u32(r1ir, r1ii); + let r2i = vmlaq_laneq_u32::<0>(r2ir, r2ii, consts); + + state_fft[6] = r0r; + state_fft[7] = r1r; + state_fft[8] = r2r; + + state_fft[9] = r0i; + state_fft[10] = r1i; + state_fft[11] = r2i; + } + + // Three length-4 inverse FFTs. + // Normally, such IFFT would divide by 4, but we've already taken care of that. 
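The loop that follows applies these per-column inverse butterflies; a scalar, per-lane transcription is below for readability (the 1/4 normalization was already folded into the 3x3 blocks above). `ifft4_butterflies_lane` is illustrative only.

```rust
/// Scalar, per-lane transcription of the inverse butterflies in the loop below.
/// NEON adds and subtracts are lane-wise modulo 2**32, hence the wrapping ops.
fn ifft4_butterflies_lane(z: [u32; 4]) -> [u32; 4] {
    let y0 = z[0].wrapping_sub(z[1]);
    let y1 = z[0].wrapping_add(z[1]);
    let (y2, y3) = (z[2], z[3]);
    [
        y0.wrapping_add(y2),
        y1.wrapping_add(y3),
        y0.wrapping_sub(y2),
        y1.wrapping_sub(y3),
    ]
}
```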
+ for i in 0..3 { + let z0 = state_fft[i]; + let z1 = state_fft[i + 3]; + let z2 = state_fft[i + 6]; + let z3 = state_fft[i + 9]; + + let y0 = vsubq_u32(z0, z1); + let y1 = vaddq_u32(z0, z1); + let y2 = z2; + let y3 = z3; + + let x0 = vaddq_u32(y0, y2); + let x1 = vaddq_u32(y1, y3); + let x2 = vsubq_u32(y0, y2); + let x3 = vsubq_u32(y1, y3); + + state_fft[i] = x0; + state_fft[i + 3] = x1; + state_fft[i + 6] = x2; + state_fft[i + 9] = x3; + } + + // Perform `res[0] += state[0] * 8` for the diagonal component of the MDS matrix. + state_fft[0] = vmlal_laneq_u16::<4>( + state_fft[0], + vcreate_u16(state[0]), // Each 16-bit chunk gets zero-extended. + vreinterpretq_u16_u32(consts), // Hack: these constants fit in `u16s`, so we can bit-cast. + ); + + let mut res_arr = [0; 12]; + for i in 0..6 { + let res = mds_reduce([state_fft[2 * i], state_fft[2 * i + 1]]); + res_arr[2 * i] = vgetq_lane_u64::<0>(res); + res_arr[2 * i + 1] = vgetq_lane_u64::<1>(res); + } + + res_arr } // ======================================== PARTIAL ROUNDS ========================================= +/* #[rustfmt::skip] macro_rules! mds_reduce_asm { ($c0:literal, $c1:literal, $out:literal, $consts:literal) => { @@ -961,13 +864,15 @@ unsafe fn partial_round( [res23, res45, res67, res89, res1011], ) } +*/ // ========================================== GLUE CODE =========================================== +/* #[inline(always)] unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] { let state = sbox_layer_full(state); - mds_const_layers_full(state, round_constants) + mds_layer_full(state, round_constants) } #[inline] @@ -1001,43 +906,19 @@ unsafe fn partial_rounds( } state.0 } +*/ #[inline(always)] fn unwrap_state(state: [GoldilocksField; 12]) -> [u64; 12] { - [ - state[0].0, - state[1].0, - state[2].0, - state[3].0, - state[4].0, - state[5].0, - state[6].0, - state[7].0, - state[8].0, - state[9].0, - state[10].0, - state[11].0, - ] + state.map(|s| s.0) } #[inline(always)] fn wrap_state(state: [u64; 12]) -> [GoldilocksField; 12] { - [ - GoldilocksField(state[0]), - GoldilocksField(state[1]), - GoldilocksField(state[2]), - GoldilocksField(state[3]), - GoldilocksField(state[4]), - GoldilocksField(state[5]), - GoldilocksField(state[6]), - GoldilocksField(state[7]), - GoldilocksField(state[8]), - GoldilocksField(state[9]), - GoldilocksField(state[10]), - GoldilocksField(state[11]), - ] + state.map(GoldilocksField) } +/* #[inline(always)] pub unsafe fn poseidon(state: [GoldilocksField; 12]) -> [GoldilocksField; 12] { let state = unwrap_state(state); @@ -1058,6 +939,7 @@ pub unsafe fn poseidon(state: [GoldilocksField; 12]) -> [GoldilocksField; 12] { let state = full_rounds(state, &FINAL_ROUND_CONSTANTS); wrap_state(state) } +*/ #[inline(always)] pub unsafe fn sbox_layer(state: &mut [GoldilocksField; WIDTH]) { @@ -1067,8 +949,6 @@ pub unsafe fn sbox_layer(state: &mut [GoldilocksField; WIDTH]) { #[inline(always)] pub unsafe fn mds_layer(state: &[GoldilocksField; WIDTH]) -> [GoldilocksField; WIDTH] { let state = unwrap_state(*state); - // We want to do an MDS layer without the constant layer. 
-    let round_consts = [0u64; WIDTH];
-    let state = mds_const_layers_full(state, &round_consts);
+    let state = mds_layer_full(state);
     wrap_state(state)
 }
diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs
index 971fda0f..177b30ff 100644
--- a/plonky2/src/hash/poseidon_goldilocks.rs
+++ b/plonky2/src/hash/poseidon_goldilocks.rs
@@ -252,21 +252,21 @@ impl Poseidon for GoldilocksField {
     // }
     // }
 
-    // #[cfg(all(target_arch="aarch64", target_feature="neon"))]
-    // #[inline(always)]
-    // fn sbox_layer(state: &mut [Self; 12]) {
-    //     unsafe {
-    //         crate::hash::arch::aarch64::poseidon_goldilocks_neon::sbox_layer(state);
-    //     }
-    // }
+    #[cfg(all(target_arch="aarch64", target_feature="neon"))]
+    #[inline(always)]
+    fn sbox_layer(state: &mut [Self; 12]) {
+        unsafe {
+            crate::hash::arch::aarch64::poseidon_goldilocks_neon::sbox_layer(state);
+        }
+    }
 
-    // #[cfg(all(target_arch="aarch64", target_feature="neon"))]
-    // #[inline(always)]
-    // fn mds_layer(state: &[Self; 12]) -> [Self; 12] {
-    //     unsafe {
-    //         crate::hash::arch::aarch64::poseidon_goldilocks_neon::mds_layer(state)
-    //     }
-    // }
+    #[cfg(all(target_arch="aarch64", target_feature="neon"))]
+    #[inline(always)]
+    fn mds_layer(state: &[Self; 12]) -> [Self; 12] {
+        unsafe {
+            crate::hash::arch::aarch64::poseidon_goldilocks_neon::mds_layer(state)
+        }
+    }
 }
 
 #[cfg(test)]
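When wiring the NEON `mds_layer` into the trait as above, a scalar cross-check is handy. The sketch below uses the circulant and diagonal constants asserted by `check_mds_matrix`; the row convention (`state[(i + r) % 12] * MDS_MATRIX_CIRC[i]` plus `state[r] * MDS_MATRIX_DIAG[r]`) is an assumption about the scalar Poseidon code rather than something this patch states, so treat the helper as illustrative.

```rust
/// Hypothetical scalar reference for testing the vectorized MDS layer.
const MDS_CIRC: [u64; 12] = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20];
const MDS_DIAG: [u64; 12] = [8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
const ORDER: u64 = 0xffff_ffff_0000_0001; // Goldilocks prime

fn mds_layer_reference(state: [u64; 12]) -> [u64; 12] {
    let mut out = [0u64; 12];
    for r in 0..12 {
        // Accumulate in u128; 12 products of u64 * (coefficient <= 41) cannot overflow.
        let mut acc: u128 = 0;
        for i in 0..12 {
            acc += state[(i + r) % 12] as u128 * MDS_CIRC[i] as u128;
        }
        acc += state[r] as u128 * MDS_DIAG[r] as u128;
        out[r] = (acc % ORDER as u128) as u64;
    }
    out
}
```

Since the NEON routine may return non-canonical representatives (values in `[ORDER, 2**64)` that are still congruent mod `ORDER`), any comparison against this reference should reduce both sides modulo `ORDER` first.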