mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-04 23:03:08 +00:00
ARM-optimized Goldilocks Poseidon (#294)
* Optimized Poseidon for ARM * Hamish comments * Unused things
This commit is contained in:
parent
8a5419d461
commit
06e48d0b61
2
src/hash/arch/aarch64/mod.rs
Normal file
2
src/hash/arch/aarch64/mod.rs
Normal file
@ -0,0 +1,2 @@
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub(crate) mod poseidon_goldilocks_neon;
|
||||
984
src/hash/arch/aarch64/poseidon_goldilocks_neon.rs
Normal file
984
src/hash/arch/aarch64/poseidon_goldilocks_neon.rs
Normal file
@ -0,0 +1,984 @@
|
||||
use std::arch::aarch64::*;
|
||||
use std::convert::TryInto;
|
||||
|
||||
use unroll::unroll_for_loops;
|
||||
|
||||
use crate::field::goldilocks_field::GoldilocksField;
|
||||
use crate::hash::poseidon::{ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS};
|
||||
|
||||
const WIDTH: usize = 12;
|
||||
|
||||
// The order below is arbitrary. Repeated coefficients have been removed so these constants fit in
|
||||
// two registers.
|
||||
// TODO: ensure this is aligned to 16 bytes (for vector loads), ideally on the same cacheline
|
||||
const MDS_CONSTS: [u32; 8] = [
|
||||
0xffffffff,
|
||||
1 << 1,
|
||||
1 << 3,
|
||||
1 << 5,
|
||||
1 << 8,
|
||||
1 << 10,
|
||||
1 << 12,
|
||||
1 << 16,
|
||||
];
|
||||
|
||||
// The round constants to be applied by the second set of full rounds. These are just the usual round constants,
|
||||
// shifted by one round, with zeros shifted in.
|
||||
const fn make_final_round_constants() -> [u64; WIDTH * HALF_N_FULL_ROUNDS] {
|
||||
let mut res = [0; WIDTH * HALF_N_FULL_ROUNDS];
|
||||
let mut i: usize = 0;
|
||||
while i < WIDTH * (HALF_N_FULL_ROUNDS - 1) {
|
||||
res[i] = ALL_ROUND_CONSTANTS[i + WIDTH * (HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + 1)];
|
||||
i += 1;
|
||||
}
|
||||
res
|
||||
}
|
||||
const FINAL_ROUND_CONSTANTS: [u64; WIDTH * HALF_N_FULL_ROUNDS] = make_final_round_constants();
|
||||
|
||||
// ====================================== SCALAR ARITHMETIC =======================================
|
||||
|
||||
/// Addition modulo ORDER accounting for wraparound. Correct only when a + b < 2**64 + ORDER.
|
||||
#[inline(always)]
|
||||
unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 {
|
||||
let res: u64;
|
||||
let adj: u64;
|
||||
asm!(
|
||||
"adds {res}, {a}, {b}",
|
||||
// Set adj to 0xffffffff if addition overflowed and 0 otherwise.
|
||||
// 'cs' for 'carry set'.
|
||||
"csetm {adj:w}, cs",
|
||||
a = in(reg) a,
|
||||
b = in(reg) b,
|
||||
res = lateout(reg) res,
|
||||
adj = lateout(reg) adj,
|
||||
options(pure, nomem, nostack),
|
||||
);
|
||||
res.wrapping_add(adj) // adj is EPSILON if wraparound occured and 0 otherwise
|
||||
}
|
||||
|
||||
/// Addition of a and (b >> 32) modulo ORDER accounting for wraparound.
|
||||
#[inline(always)]
|
||||
unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 {
|
||||
let res: u64;
|
||||
let adj: u64;
|
||||
asm!(
|
||||
"subs {res}, {a}, {b}, lsr #32",
|
||||
// Set adj to 0xffffffff if subtraction underflowed and 0 otherwise.
|
||||
// 'cc' for 'carry clear'.
|
||||
// NB: The CF in ARM subtraction is the opposite of x86: CF set == underflow did not occur.
|
||||
"csetm {adj:w}, cc",
|
||||
a = in(reg) a,
|
||||
b = in(reg) b,
|
||||
res = lateout(reg) res,
|
||||
adj = lateout(reg) adj,
|
||||
options(pure, nomem, nostack),
|
||||
);
|
||||
res.wrapping_sub(adj) // adj is EPSILON if underflow occured and 0 otherwise.
|
||||
}
|
||||
|
||||
/// Multiplication of the low word (i.e., x as u32) by EPSILON.
|
||||
#[inline(always)]
|
||||
unsafe fn mul_epsilon(x: u64) -> u64 {
|
||||
let res;
|
||||
let epsilon: u64 = 0xffffffff;
|
||||
asm!(
|
||||
// Use UMULL to save one instruction. The compiler emits two: extract the low word and then multiply.
|
||||
"umull {res}, {x:w}, {epsilon:w}",
|
||||
x = in(reg) x,
|
||||
epsilon = in(reg) epsilon,
|
||||
res = lateout(reg) res,
|
||||
options(pure, nomem, nostack, preserves_flags),
|
||||
);
|
||||
res
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn multiply(x: u64, y: u64) -> u64 {
|
||||
let xy = (x as u128) * (y as u128);
|
||||
let xy_lo = xy as u64;
|
||||
let xy_hi = (xy >> 64) as u64;
|
||||
|
||||
let res0 = sub_with_wraparound_lsr32(xy_lo, xy_hi);
|
||||
|
||||
let xy_hi_lo_mul_epsilon = mul_epsilon(xy_hi);
|
||||
|
||||
// add_with_wraparound is safe, as xy_hi_lo_mul_epsilon <= 0xfffffffe00000001 <= ORDER.
|
||||
let res1 = add_with_wraparound(res0, xy_hi_lo_mul_epsilon);
|
||||
|
||||
res1
|
||||
}
|
||||
|
||||
// ==================================== STANDALONE CONST LAYER =====================================
|
||||
|
||||
/// Standalone const layer. Run only once, at the start of round 1. Remaining const layers are fused with the preceeding
|
||||
/// MDS matrix multiplication.
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
unsafe fn const_layer_full(
|
||||
mut state: [u64; WIDTH],
|
||||
round_constants: &[u64; WIDTH],
|
||||
) -> [u64; WIDTH] {
|
||||
assert!(WIDTH == 12);
|
||||
for i in 0..12 {
|
||||
let rc = round_constants[i];
|
||||
// add_with_wraparound is safe, because rc is in canonical form.
|
||||
state[i] = add_with_wraparound(state[i], rc);
|
||||
}
|
||||
state
|
||||
}
|
||||
|
||||
// ========================================== FULL ROUNDS ==========================================
|
||||
|
||||
/// Full S-box.
|
||||
#[inline(always)]
|
||||
#[unroll_for_loops]
|
||||
unsafe fn sbox_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] {
|
||||
// This is done in scalar. S-boxes in vector are only slightly slower throughput-wise but have an insane latency
|
||||
// (~100 cycles) on the M1.
|
||||
|
||||
let mut state2 = [0u64; WIDTH];
|
||||
assert!(WIDTH == 12);
|
||||
for i in 0..12 {
|
||||
state2[i] = multiply(state[i], state[i]);
|
||||
}
|
||||
|
||||
let mut state3 = [0u64; WIDTH];
|
||||
let mut state4 = [0u64; WIDTH];
|
||||
assert!(WIDTH == 12);
|
||||
for i in 0..12 {
|
||||
state3[i] = multiply(state[i], state2[i]);
|
||||
state4[i] = multiply(state2[i], state2[i]);
|
||||
}
|
||||
|
||||
let mut state7 = [0u64; WIDTH];
|
||||
assert!(WIDTH == 12);
|
||||
for i in 0..12 {
|
||||
state7[i] = multiply(state3[i], state4[i]);
|
||||
}
|
||||
|
||||
state7
|
||||
}
|
||||
|
||||
// Aliases for readability. E.g. MDS[5] can be found in mdsv5[MDSI5].
|
||||
const MDSI2: i32 = 1; // MDS[2] == 1
|
||||
const MDSI4: i32 = 2; // MDS[4] == 3
|
||||
const MDSI5: i32 = 3; // MDS[5] == 5
|
||||
const MDSI6: i32 = 1; // MDS[6] == 1
|
||||
const MDSI7: i32 = 0; // MDS[7] == 8
|
||||
const MDSI8: i32 = 2; // MDS[8] == 12
|
||||
const MDSI9: i32 = 2; // MDS[9] == 3
|
||||
const MDSI10: i32 = 3; // MDS[10] == 16
|
||||
const MDSI11: i32 = 1; // MDS[11] == 10
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mds_reduce(
|
||||
[[cumul0_a, cumul0_b], [cumul1_a, cumul1_b]]: [[uint64x2_t; 2]; 2],
|
||||
) -> uint64x2_t {
|
||||
// mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5]
|
||||
let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::<u32>());
|
||||
|
||||
// Merge accumulators
|
||||
let cumul0 = vaddq_u64(cumul0_a, cumul0_b);
|
||||
let cumul1 = vaddq_u64(cumul1_a, cumul1_b);
|
||||
|
||||
// Swizzle
|
||||
let res_lo = vzip1q_u64(cumul0, cumul1);
|
||||
let res_hi = vzip2q_u64(cumul0, cumul1);
|
||||
|
||||
// Reduce from u96
|
||||
let res_hi = vsraq_n_u64::<32>(res_hi, res_lo);
|
||||
let res_lo = vsliq_n_u64::<32>(res_lo, res_hi);
|
||||
|
||||
// Extract high 32-bits.
|
||||
let res_hi_hi = vget_low_u32(vuzp2q_u32(
|
||||
vreinterpretq_u32_u64(res_hi),
|
||||
vreinterpretq_u32_u64(res_hi),
|
||||
));
|
||||
|
||||
// Multiply by EPSILON and accumulate.
|
||||
let res_unadj = vmlal_laneq_u32::<0>(res_lo, res_hi_hi, mds_consts0);
|
||||
let res_adj = vcgtq_u64(res_lo, res_unadj);
|
||||
let res = vsraq_n_u64::<32>(res_unadj, res_adj);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn mds_const_layers_full(
|
||||
state: [u64; WIDTH],
|
||||
round_constants: &[u64; WIDTH],
|
||||
) -> [u64; WIDTH] {
|
||||
// mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5]
|
||||
// mds_consts1 == [1 << 8, 1 << 10, 1 << 12, 1 << 16]
|
||||
let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::<u32>());
|
||||
let mds_consts1: uint32x4_t = vld1q_u32((&MDS_CONSTS[4..8]).as_ptr().cast::<u32>());
|
||||
|
||||
// Aliases for readability. E.g. MDS[5] can be found in mdsv5[mdsi5]. MDS[0], MDS[1], and
|
||||
// MDS[3] are 0, so they are not needed.
|
||||
let mdsv2 = mds_consts0; // MDS[2] == 1
|
||||
let mdsv4 = mds_consts0; // MDS[4] == 3
|
||||
let mdsv5 = mds_consts0; // MDS[5] == 5
|
||||
let mdsv6 = mds_consts0; // MDS[6] == 1
|
||||
let mdsv7 = mds_consts1; // MDS[7] == 8
|
||||
let mdsv8 = mds_consts1; // MDS[8] == 12
|
||||
let mdsv9 = mds_consts0; // MDS[9] == 3
|
||||
let mdsv10 = mds_consts1; // MDS[10] == 16
|
||||
let mdsv11 = mds_consts1; // MDS[11] == 10
|
||||
|
||||
// For i even, we combine state[i] and state[i + 1] into one vector to save on registers.
|
||||
// Thus, state1 actually contains state0 and state1 but is only used in the intrinsics that
|
||||
// access the high high doubleword.
|
||||
let state1: uint32x4_t =
|
||||
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[0]), vcreate_u64(state[1])));
|
||||
let state3: uint32x4_t =
|
||||
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[2]), vcreate_u64(state[3])));
|
||||
let state5: uint32x4_t =
|
||||
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[4]), vcreate_u64(state[5])));
|
||||
let state7: uint32x4_t =
|
||||
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[6]), vcreate_u64(state[7])));
|
||||
let state9: uint32x4_t =
|
||||
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[8]), vcreate_u64(state[9])));
|
||||
let state11: uint32x4_t =
|
||||
vreinterpretq_u32_u64(vcombine_u64(vcreate_u64(state[10]), vcreate_u64(state[11])));
|
||||
// state0 is an alias to the low doubleword of state1. The compiler should use one register for both.
|
||||
let state0: uint32x2_t = vget_low_u32(state1);
|
||||
let state2: uint32x2_t = vget_low_u32(state3);
|
||||
let state4: uint32x2_t = vget_low_u32(state5);
|
||||
let state6: uint32x2_t = vget_low_u32(state7);
|
||||
let state8: uint32x2_t = vget_low_u32(state9);
|
||||
let state10: uint32x2_t = vget_low_u32(state11);
|
||||
|
||||
// Two accumulators per output to hide latency. Each accumulator is a vector of two u64s,
|
||||
// containing the result for the low 32 bits and the high 32 bits. Thus, the final result at
|
||||
// index i is (cumuli_a[0] + cumuli_b[0]) + (cumuli_a[1] + cumuli_b[1]) * 2**32.
|
||||
|
||||
// Start by loading the round constants.
|
||||
let mut cumul0_a = vcombine_u64(vld1_u64(&round_constants[0]), vcreate_u64(0));
|
||||
let mut cumul1_a = vcombine_u64(vld1_u64(&round_constants[1]), vcreate_u64(0));
|
||||
let mut cumul2_a = vcombine_u64(vld1_u64(&round_constants[2]), vcreate_u64(0));
|
||||
let mut cumul3_a = vcombine_u64(vld1_u64(&round_constants[3]), vcreate_u64(0));
|
||||
let mut cumul4_a = vcombine_u64(vld1_u64(&round_constants[4]), vcreate_u64(0));
|
||||
let mut cumul5_a = vcombine_u64(vld1_u64(&round_constants[5]), vcreate_u64(0));
|
||||
let mut cumul6_a = vcombine_u64(vld1_u64(&round_constants[6]), vcreate_u64(0));
|
||||
let mut cumul7_a = vcombine_u64(vld1_u64(&round_constants[7]), vcreate_u64(0));
|
||||
let mut cumul8_a = vcombine_u64(vld1_u64(&round_constants[8]), vcreate_u64(0));
|
||||
let mut cumul9_a = vcombine_u64(vld1_u64(&round_constants[9]), vcreate_u64(0));
|
||||
let mut cumul10_a = vcombine_u64(vld1_u64(&round_constants[10]), vcreate_u64(0));
|
||||
let mut cumul11_a = vcombine_u64(vld1_u64(&round_constants[11]), vcreate_u64(0));
|
||||
|
||||
// Now the matrix multiplication.
|
||||
// MDS exps: [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10]
|
||||
// out[i] += in[j] << mds[j - i]
|
||||
|
||||
let mut cumul0_b = vshll_n_u32::<0>(state0); // MDS[0]
|
||||
let mut cumul1_b = vshll_n_u32::<10>(state0); // MDS[11]
|
||||
let mut cumul2_b = vshll_n_u32::<16>(state0); // MDS[10]
|
||||
let mut cumul3_b = vshll_n_u32::<3>(state0); // MDS[9]
|
||||
let mut cumul4_b = vshll_n_u32::<12>(state0); // MDS[8]
|
||||
let mut cumul5_b = vshll_n_u32::<8>(state0); // MDS[7]
|
||||
let mut cumul6_b = vshll_n_u32::<1>(state0); // MDS[6]
|
||||
let mut cumul7_b = vshll_n_u32::<5>(state0); // MDS[5]
|
||||
let mut cumul8_b = vshll_n_u32::<3>(state0); // MDS[4]
|
||||
let mut cumul9_b = vshll_n_u32::<0>(state0); // MDS[3]
|
||||
let mut cumul10_b = vshll_n_u32::<1>(state0); // MDS[2]
|
||||
let mut cumul11_b = vshll_n_u32::<0>(state0); // MDS[1]
|
||||
|
||||
cumul0_a = vaddw_high_u32(cumul0_a, state1); // MDS[1]
|
||||
cumul1_a = vaddw_high_u32(cumul1_a, state1); // MDS[0]
|
||||
cumul2_a = vmlal_high_laneq_u32::<MDSI11>(cumul2_a, state1, mdsv11); // MDS[11]
|
||||
cumul3_a = vmlal_high_laneq_u32::<MDSI10>(cumul3_a, state1, mdsv10); // MDS[10]
|
||||
cumul4_a = vmlal_high_laneq_u32::<MDSI9>(cumul4_a, state1, mdsv9); // MDS[9]
|
||||
cumul5_a = vmlal_high_laneq_u32::<MDSI8>(cumul5_a, state1, mdsv8); // MDS[8]
|
||||
cumul6_a = vmlal_high_laneq_u32::<MDSI7>(cumul6_a, state1, mdsv7); // MDS[7]
|
||||
cumul7_a = vmlal_high_laneq_u32::<MDSI6>(cumul7_a, state1, mdsv6); // MDS[6]
|
||||
cumul8_a = vmlal_high_laneq_u32::<MDSI5>(cumul8_a, state1, mdsv5); // MDS[5]
|
||||
cumul9_a = vmlal_high_laneq_u32::<MDSI4>(cumul9_a, state1, mdsv4); // MDS[4]
|
||||
cumul10_a = vaddw_high_u32(cumul10_a, state1); // MDS[3]
|
||||
cumul11_a = vmlal_high_laneq_u32::<MDSI2>(cumul11_a, state1, mdsv2); // MDS[2]
|
||||
|
||||
cumul0_b = vmlal_laneq_u32::<MDSI2>(cumul0_b, state2, mdsv2); // MDS[2]
|
||||
cumul1_b = vaddw_u32(cumul1_b, state2); // MDS[1]
|
||||
cumul2_b = vaddw_u32(cumul2_b, state2); // MDS[0]
|
||||
cumul3_b = vmlal_laneq_u32::<MDSI11>(cumul3_b, state2, mdsv11); // MDS[11]
|
||||
cumul4_b = vmlal_laneq_u32::<MDSI10>(cumul4_b, state2, mdsv10); // MDS[10]
|
||||
cumul5_b = vmlal_laneq_u32::<MDSI9>(cumul5_b, state2, mdsv9); // MDS[9]
|
||||
cumul6_b = vmlal_laneq_u32::<MDSI8>(cumul6_b, state2, mdsv8); // MDS[8]
|
||||
cumul7_b = vmlal_laneq_u32::<MDSI7>(cumul7_b, state2, mdsv7); // MDS[7]
|
||||
cumul8_b = vmlal_laneq_u32::<MDSI6>(cumul8_b, state2, mdsv6); // MDS[6]
|
||||
cumul9_b = vmlal_laneq_u32::<MDSI5>(cumul9_b, state2, mdsv5); // MDS[5]
|
||||
cumul10_b = vmlal_laneq_u32::<MDSI4>(cumul10_b, state2, mdsv4); // MDS[4]
|
||||
cumul11_b = vaddw_u32(cumul11_b, state2); // MDS[3]
|
||||
|
||||
cumul0_a = vaddw_high_u32(cumul0_a, state3); // MDS[3]
|
||||
cumul1_a = vmlal_high_laneq_u32::<MDSI2>(cumul1_a, state3, mdsv2); // MDS[2]
|
||||
cumul2_a = vaddw_high_u32(cumul2_a, state3); // MDS[1]
|
||||
cumul3_a = vaddw_high_u32(cumul3_a, state3); // MDS[0]
|
||||
cumul4_a = vmlal_high_laneq_u32::<MDSI11>(cumul4_a, state3, mdsv11); // MDS[11]
|
||||
cumul5_a = vmlal_high_laneq_u32::<MDSI10>(cumul5_a, state3, mdsv10); // MDS[10]
|
||||
cumul6_a = vmlal_high_laneq_u32::<MDSI9>(cumul6_a, state3, mdsv9); // MDS[9]
|
||||
cumul7_a = vmlal_high_laneq_u32::<MDSI8>(cumul7_a, state3, mdsv8); // MDS[8]
|
||||
cumul8_a = vmlal_high_laneq_u32::<MDSI7>(cumul8_a, state3, mdsv7); // MDS[7]
|
||||
cumul9_a = vmlal_high_laneq_u32::<MDSI6>(cumul9_a, state3, mdsv6); // MDS[6]
|
||||
cumul10_a = vmlal_high_laneq_u32::<MDSI5>(cumul10_a, state3, mdsv5); // MDS[5]
|
||||
cumul11_a = vmlal_high_laneq_u32::<MDSI4>(cumul11_a, state3, mdsv4); // MDS[4]
|
||||
|
||||
cumul0_b = vmlal_laneq_u32::<MDSI4>(cumul0_b, state4, mdsv4); // MDS[4]
|
||||
cumul1_b = vaddw_u32(cumul1_b, state4); // MDS[3]
|
||||
cumul2_b = vmlal_laneq_u32::<MDSI2>(cumul2_b, state4, mdsv2); // MDS[2]
|
||||
cumul3_b = vaddw_u32(cumul3_b, state4); // MDS[1]
|
||||
cumul4_b = vaddw_u32(cumul4_b, state4); // MDS[0]
|
||||
cumul5_b = vmlal_laneq_u32::<MDSI11>(cumul5_b, state4, mdsv11); // MDS[11]
|
||||
cumul6_b = vmlal_laneq_u32::<MDSI10>(cumul6_b, state4, mdsv10); // MDS[10]
|
||||
cumul7_b = vmlal_laneq_u32::<MDSI9>(cumul7_b, state4, mdsv9); // MDS[9]
|
||||
cumul8_b = vmlal_laneq_u32::<MDSI8>(cumul8_b, state4, mdsv8); // MDS[8]
|
||||
cumul9_b = vmlal_laneq_u32::<MDSI7>(cumul9_b, state4, mdsv7); // MDS[7]
|
||||
cumul10_b = vmlal_laneq_u32::<MDSI6>(cumul10_b, state4, mdsv6); // MDS[6]
|
||||
cumul11_b = vmlal_laneq_u32::<MDSI5>(cumul11_b, state4, mdsv5); // MDS[5]
|
||||
|
||||
cumul0_a = vmlal_high_laneq_u32::<MDSI5>(cumul0_a, state5, mdsv5); // MDS[5]
|
||||
cumul1_a = vmlal_high_laneq_u32::<MDSI4>(cumul1_a, state5, mdsv4); // MDS[4]
|
||||
cumul2_a = vaddw_high_u32(cumul2_a, state5); // MDS[3]
|
||||
cumul3_a = vmlal_high_laneq_u32::<MDSI2>(cumul3_a, state5, mdsv2); // MDS[2]
|
||||
cumul4_a = vaddw_high_u32(cumul4_a, state5); // MDS[1]
|
||||
cumul5_a = vaddw_high_u32(cumul5_a, state5); // MDS[0]
|
||||
cumul6_a = vmlal_high_laneq_u32::<MDSI11>(cumul6_a, state5, mdsv11); // MDS[11]
|
||||
cumul7_a = vmlal_high_laneq_u32::<MDSI10>(cumul7_a, state5, mdsv10); // MDS[10]
|
||||
cumul8_a = vmlal_high_laneq_u32::<MDSI9>(cumul8_a, state5, mdsv9); // MDS[9]
|
||||
cumul9_a = vmlal_high_laneq_u32::<MDSI8>(cumul9_a, state5, mdsv8); // MDS[8]
|
||||
cumul10_a = vmlal_high_laneq_u32::<MDSI7>(cumul10_a, state5, mdsv7); // MDS[7]
|
||||
cumul11_a = vmlal_high_laneq_u32::<MDSI6>(cumul11_a, state5, mdsv6); // MDS[6]
|
||||
|
||||
cumul0_b = vmlal_laneq_u32::<MDSI6>(cumul0_b, state6, mdsv6); // MDS[6]
|
||||
cumul1_b = vmlal_laneq_u32::<MDSI5>(cumul1_b, state6, mdsv5); // MDS[5]
|
||||
cumul2_b = vmlal_laneq_u32::<MDSI4>(cumul2_b, state6, mdsv4); // MDS[4]
|
||||
cumul3_b = vaddw_u32(cumul3_b, state6); // MDS[3]
|
||||
cumul4_b = vmlal_laneq_u32::<MDSI2>(cumul4_b, state6, mdsv2); // MDS[2]
|
||||
cumul5_b = vaddw_u32(cumul5_b, state6); // MDS[1]
|
||||
cumul6_b = vaddw_u32(cumul6_b, state6); // MDS[0]
|
||||
cumul7_b = vmlal_laneq_u32::<MDSI11>(cumul7_b, state6, mdsv11); // MDS[11]
|
||||
cumul8_b = vmlal_laneq_u32::<MDSI10>(cumul8_b, state6, mdsv10); // MDS[10]
|
||||
cumul9_b = vmlal_laneq_u32::<MDSI9>(cumul9_b, state6, mdsv9); // MDS[9]
|
||||
cumul10_b = vmlal_laneq_u32::<MDSI8>(cumul10_b, state6, mdsv8); // MDS[8]
|
||||
cumul11_b = vmlal_laneq_u32::<MDSI7>(cumul11_b, state6, mdsv7); // MDS[7]
|
||||
|
||||
cumul0_a = vmlal_high_laneq_u32::<MDSI7>(cumul0_a, state7, mdsv7); // MDS[7]
|
||||
cumul1_a = vmlal_high_laneq_u32::<MDSI6>(cumul1_a, state7, mdsv6); // MDS[6]
|
||||
cumul2_a = vmlal_high_laneq_u32::<MDSI5>(cumul2_a, state7, mdsv5); // MDS[5]
|
||||
cumul3_a = vmlal_high_laneq_u32::<MDSI4>(cumul3_a, state7, mdsv4); // MDS[4]
|
||||
cumul4_a = vaddw_high_u32(cumul4_a, state7); // MDS[3]
|
||||
cumul5_a = vmlal_high_laneq_u32::<MDSI2>(cumul5_a, state7, mdsv2); // MDS[2]
|
||||
cumul6_a = vaddw_high_u32(cumul6_a, state7); // MDS[1]
|
||||
cumul7_a = vaddw_high_u32(cumul7_a, state7); // MDS[0]
|
||||
cumul8_a = vmlal_high_laneq_u32::<MDSI11>(cumul8_a, state7, mdsv11); // MDS[11]
|
||||
cumul9_a = vmlal_high_laneq_u32::<MDSI10>(cumul9_a, state7, mdsv10); // MDS[10]
|
||||
cumul10_a = vmlal_high_laneq_u32::<MDSI9>(cumul10_a, state7, mdsv9); // MDS[9]
|
||||
cumul11_a = vmlal_high_laneq_u32::<MDSI8>(cumul11_a, state7, mdsv8); // MDS[8]
|
||||
|
||||
cumul0_b = vmlal_laneq_u32::<MDSI8>(cumul0_b, state8, mdsv8); // MDS[8]
|
||||
cumul1_b = vmlal_laneq_u32::<MDSI7>(cumul1_b, state8, mdsv7); // MDS[7]
|
||||
cumul2_b = vmlal_laneq_u32::<MDSI6>(cumul2_b, state8, mdsv6); // MDS[6]
|
||||
cumul3_b = vmlal_laneq_u32::<MDSI5>(cumul3_b, state8, mdsv5); // MDS[5]
|
||||
cumul4_b = vmlal_laneq_u32::<MDSI4>(cumul4_b, state8, mdsv4); // MDS[4]
|
||||
cumul5_b = vaddw_u32(cumul5_b, state8); // MDS[3]
|
||||
cumul6_b = vmlal_laneq_u32::<MDSI2>(cumul6_b, state8, mdsv2); // MDS[2]
|
||||
cumul7_b = vaddw_u32(cumul7_b, state8); // MDS[1]
|
||||
cumul8_b = vaddw_u32(cumul8_b, state8); // MDS[0]
|
||||
cumul9_b = vmlal_laneq_u32::<MDSI11>(cumul9_b, state8, mdsv11); // MDS[11]
|
||||
cumul10_b = vmlal_laneq_u32::<MDSI10>(cumul10_b, state8, mdsv10); // MDS[10]
|
||||
cumul11_b = vmlal_laneq_u32::<MDSI9>(cumul11_b, state8, mdsv9); // MDS[9]
|
||||
|
||||
cumul0_a = vmlal_high_laneq_u32::<MDSI9>(cumul0_a, state9, mdsv9); // MDS[9]
|
||||
cumul1_a = vmlal_high_laneq_u32::<MDSI8>(cumul1_a, state9, mdsv8); // MDS[8]
|
||||
cumul2_a = vmlal_high_laneq_u32::<MDSI7>(cumul2_a, state9, mdsv7); // MDS[7]
|
||||
cumul3_a = vmlal_high_laneq_u32::<MDSI6>(cumul3_a, state9, mdsv6); // MDS[6]
|
||||
cumul4_a = vmlal_high_laneq_u32::<MDSI5>(cumul4_a, state9, mdsv5); // MDS[5]
|
||||
cumul5_a = vmlal_high_laneq_u32::<MDSI4>(cumul5_a, state9, mdsv4); // MDS[4]
|
||||
cumul6_a = vaddw_high_u32(cumul6_a, state9); // MDS[3]
|
||||
cumul7_a = vmlal_high_laneq_u32::<MDSI2>(cumul7_a, state9, mdsv2); // MDS[2]
|
||||
cumul8_a = vaddw_high_u32(cumul8_a, state9); // MDS[1]
|
||||
cumul9_a = vaddw_high_u32(cumul9_a, state9); // MDS[0]
|
||||
cumul10_a = vmlal_high_laneq_u32::<MDSI11>(cumul10_a, state9, mdsv11); // MDS[11]
|
||||
cumul11_a = vmlal_high_laneq_u32::<MDSI10>(cumul11_a, state9, mdsv10); // MDS[10]
|
||||
|
||||
cumul0_b = vmlal_laneq_u32::<MDSI10>(cumul0_b, state10, mdsv10); // MDS[10]
|
||||
cumul1_b = vmlal_laneq_u32::<MDSI9>(cumul1_b, state10, mdsv9); // MDS[9]
|
||||
cumul2_b = vmlal_laneq_u32::<MDSI8>(cumul2_b, state10, mdsv8); // MDS[8]
|
||||
cumul3_b = vmlal_laneq_u32::<MDSI7>(cumul3_b, state10, mdsv7); // MDS[7]
|
||||
cumul4_b = vmlal_laneq_u32::<MDSI6>(cumul4_b, state10, mdsv6); // MDS[6]
|
||||
cumul5_b = vmlal_laneq_u32::<MDSI5>(cumul5_b, state10, mdsv5); // MDS[5]
|
||||
cumul6_b = vmlal_laneq_u32::<MDSI4>(cumul6_b, state10, mdsv4); // MDS[4]
|
||||
cumul7_b = vaddw_u32(cumul7_b, state10); // MDS[3]
|
||||
cumul8_b = vmlal_laneq_u32::<MDSI2>(cumul8_b, state10, mdsv2); // MDS[2]
|
||||
cumul9_b = vaddw_u32(cumul9_b, state10); // MDS[1]
|
||||
cumul10_b = vaddw_u32(cumul10_b, state10); // MDS[0]
|
||||
cumul11_b = vmlal_laneq_u32::<MDSI11>(cumul11_b, state10, mdsv11); // MDS[11]
|
||||
|
||||
cumul0_a = vmlal_high_laneq_u32::<MDSI11>(cumul0_a, state11, mdsv11); // MDS[11]
|
||||
cumul1_a = vmlal_high_laneq_u32::<MDSI10>(cumul1_a, state11, mdsv10); // MDS[10]
|
||||
cumul2_a = vmlal_high_laneq_u32::<MDSI9>(cumul2_a, state11, mdsv9); // MDS[9]
|
||||
cumul3_a = vmlal_high_laneq_u32::<MDSI8>(cumul3_a, state11, mdsv8); // MDS[8]
|
||||
cumul4_a = vmlal_high_laneq_u32::<MDSI7>(cumul4_a, state11, mdsv7); // MDS[7]
|
||||
cumul5_a = vmlal_high_laneq_u32::<MDSI6>(cumul5_a, state11, mdsv6); // MDS[6]
|
||||
cumul6_a = vmlal_high_laneq_u32::<MDSI5>(cumul6_a, state11, mdsv5); // MDS[5]
|
||||
cumul7_a = vmlal_high_laneq_u32::<MDSI4>(cumul7_a, state11, mdsv4); // MDS[4]
|
||||
cumul8_a = vaddw_high_u32(cumul8_a, state11); // MDS[3]
|
||||
cumul9_a = vmlal_high_laneq_u32::<MDSI2>(cumul9_a, state11, mdsv2); // MDS[2]
|
||||
cumul10_a = vaddw_high_u32(cumul10_a, state11); // MDS[1]
|
||||
cumul11_a = vaddw_high_u32(cumul11_a, state11); // MDS[0]
|
||||
|
||||
let reduced = [
|
||||
mds_reduce([[cumul0_a, cumul0_b], [cumul1_a, cumul1_b]]),
|
||||
mds_reduce([[cumul2_a, cumul2_b], [cumul3_a, cumul3_b]]),
|
||||
mds_reduce([[cumul4_a, cumul4_b], [cumul5_a, cumul5_b]]),
|
||||
mds_reduce([[cumul6_a, cumul6_b], [cumul7_a, cumul7_b]]),
|
||||
mds_reduce([[cumul8_a, cumul8_b], [cumul9_a, cumul9_b]]),
|
||||
mds_reduce([[cumul10_a, cumul10_b], [cumul11_a, cumul11_b]]),
|
||||
];
|
||||
[
|
||||
vgetq_lane_u64::<0>(reduced[0]),
|
||||
vgetq_lane_u64::<1>(reduced[0]),
|
||||
vgetq_lane_u64::<0>(reduced[1]),
|
||||
vgetq_lane_u64::<1>(reduced[1]),
|
||||
vgetq_lane_u64::<0>(reduced[2]),
|
||||
vgetq_lane_u64::<1>(reduced[2]),
|
||||
vgetq_lane_u64::<0>(reduced[3]),
|
||||
vgetq_lane_u64::<1>(reduced[3]),
|
||||
vgetq_lane_u64::<0>(reduced[4]),
|
||||
vgetq_lane_u64::<1>(reduced[4]),
|
||||
vgetq_lane_u64::<0>(reduced[5]),
|
||||
vgetq_lane_u64::<1>(reduced[5]),
|
||||
]
|
||||
}
|
||||
|
||||
// ======================================== PARTIAL ROUNDS =========================================
|
||||
|
||||
#[rustfmt::skip]
|
||||
macro_rules! mds_reduce_asm {
|
||||
($c0:literal, $c1:literal, $out:literal, $consts:literal) => {
|
||||
concat!(
|
||||
// Swizzle
|
||||
"zip1.2d ", $out, ",", $c0, ",", $c1, "\n", // lo
|
||||
"zip2.2d ", $c0, ",", $c0, ",", $c1, "\n", // hi
|
||||
|
||||
// Reduction from u96
|
||||
"usra.2d ", $c0, ",", $out, ", #32\n", "sli.2d ", $out, ",", $c0, ", #32\n",
|
||||
// Extract high 32-bits.
|
||||
"uzp2.4s ", $c0, ",", $c0, ",", $c0, "\n",
|
||||
// Multiply by EPSILON and accumulate.
|
||||
"mov.16b ", $c1, ",", $out, "\n",
|
||||
"umlal.2d ", $out, ",", $c0, ", ", $consts, "[0]\n",
|
||||
"cmhi.2d ", $c1, ",", $c1, ",", $out, "\n",
|
||||
"usra.2d ", $out, ",", $c1, ", #32",
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn partial_round(
|
||||
(state_scalar, state_vector): ([u64; WIDTH], [uint64x2_t; 5]),
|
||||
round_constants: &[u64; WIDTH],
|
||||
) -> ([u64; WIDTH], [uint64x2_t; 5]) {
|
||||
// see readme-asm.md
|
||||
|
||||
// mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5]
|
||||
// mds_consts1 == [1 << 8, 1 << 10, 1 << 12, 1 << 16]
|
||||
let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::<u32>());
|
||||
let mds_consts1: uint32x4_t = vld1q_u32((&MDS_CONSTS[4..8]).as_ptr().cast::<u32>());
|
||||
|
||||
let res0: u64;
|
||||
let res1: u64;
|
||||
let res23: uint64x2_t;
|
||||
let res45: uint64x2_t;
|
||||
let res67: uint64x2_t;
|
||||
let res89: uint64x2_t;
|
||||
let res1011: uint64x2_t;
|
||||
|
||||
let res2_scalar: u64;
|
||||
let res3_scalar: u64;
|
||||
let res4_scalar: u64;
|
||||
let res5_scalar: u64;
|
||||
let res6_scalar: u64;
|
||||
let res7_scalar: u64;
|
||||
let res8_scalar: u64;
|
||||
let res9_scalar: u64;
|
||||
let res10_scalar: u64;
|
||||
let res11_scalar: u64;
|
||||
|
||||
asm!(
|
||||
"ldp d0, d1, [{rc_ptr}, #16]",
|
||||
"fmov d21, {s1}",
|
||||
"ldp {lo0}, {lo1}, [{rc_ptr}]",
|
||||
"umulh {t0}, {s0}, {s0}",
|
||||
"mul {t1}, {s0}, {s0}",
|
||||
"subs {t1}, {t1}, {t0}, lsr #32",
|
||||
"csetm {t2:w}, cc",
|
||||
"lsl {t3}, {t0}, #32",
|
||||
"sub {t1}, {t1}, {t2}",
|
||||
"mov {t0:w}, {t0:w}",
|
||||
"sub {t0}, {t3}, {t0}",
|
||||
"adds {t0}, {t1}, {t0}",
|
||||
"csetm {t1:w}, cs",
|
||||
"add {t0}, {t0}, {t1}",
|
||||
"umulh {t1}, {s0}, {t0}",
|
||||
"umulh {t2}, {t0}, {t0}",
|
||||
"mul {s0}, {s0}, {t0}",
|
||||
"mul {t0}, {t0}, {t0}",
|
||||
"subs {s0}, {s0}, {t1}, lsr #32",
|
||||
"csetm {t3:w}, cc",
|
||||
"subs {t0}, {t0}, {t2}, lsr #32",
|
||||
"csetm {t4:w}, cc",
|
||||
"lsl {t5}, {t1}, #32",
|
||||
"lsl {t6}, {t2}, #32",
|
||||
"sub {s0}, {s0}, {t3}",
|
||||
"sub {t0}, {t0}, {t4}",
|
||||
"mov {t1:w}, {t1:w}",
|
||||
"mov {t2:w}, {t2:w}",
|
||||
"sub {t1}, {t5}, {t1}",
|
||||
"ushll.2d v10, v21, #10",
|
||||
"sub {t2}, {t6}, {t2}",
|
||||
"ushll.2d v11, v21, #16",
|
||||
"adds {t1}, {s0}, {t1}",
|
||||
"uaddw.2d v0, v0, v22",
|
||||
"csetm {s0:w}, cs",
|
||||
"umlal.2d v1, v22, v31[1]",
|
||||
"adds {t2}, {t0}, {t2}",
|
||||
"uaddw2.2d v10, v10, v22",
|
||||
"csetm {t0:w}, cs",
|
||||
"uaddw2.2d v11, v11, v22",
|
||||
"add {t1}, {t1}, {s0}",
|
||||
"ldp d2, d3, [{rc_ptr}, #32]",
|
||||
"add {t2}, {t2}, {t0}",
|
||||
"ushll.2d v12, v21, #3",
|
||||
"umulh {s0}, {t1}, {t2}",
|
||||
"ushll.2d v13, v21, #12",
|
||||
"mul {t0}, {t1}, {t2}",
|
||||
"umlal.2d v0, v23, v30[1]",
|
||||
"add {lo1}, {lo1}, {s1:w}, uxtw",
|
||||
"uaddw2.2d v10, v10, v23",
|
||||
"add {lo0}, {lo0}, {s1:w}, uxtw",
|
||||
"uaddw.2d v11, v11, v23",
|
||||
"lsr {hi0}, {s1}, #32",
|
||||
"umlal2.2d v1, v23, v30[1]",
|
||||
"lsr {t3}, {s2}, #32",
|
||||
"umlal.2d v2, v22, v31[3]",
|
||||
"lsr {t4}, {s3}, #32",
|
||||
"umlal2.2d v12, v22, v31[1]",
|
||||
"add {hi1}, {hi0}, {t3}",
|
||||
"umlal.2d v3, v22, v30[2]",
|
||||
"add {hi0}, {hi0}, {t3}, lsl #1",
|
||||
"umlal2.2d v13, v22, v31[3]",
|
||||
"add {lo1}, {lo1}, {s2:w}, uxtw",
|
||||
"ldp d4, d5, [{rc_ptr}, #48]",
|
||||
"add {lo0}, {lo0}, {s2:w}, uxtw #1",
|
||||
"ushll.2d v14, v21, #8",
|
||||
"lsr {t3}, {s4}, #32",
|
||||
"ushll.2d v15, v21, #1",
|
||||
"lsr {t5}, {s5}, #32",
|
||||
"umlal.2d v0, v24, v30[2]",
|
||||
"subs {t0}, {t0}, {s0}, lsr #32",
|
||||
"umlal2.2d v10, v24, v30[3]",
|
||||
"add {hi1}, {hi1}, {t4}, lsl #1",
|
||||
"umlal2.2d v11, v24, v30[2]",
|
||||
"add {t6}, {t3}, {t5}, lsl #3",
|
||||
"uaddw.2d v1, v1, v24",
|
||||
"add {t5}, {t3}, {t5}, lsl #2",
|
||||
"uaddw.2d v2, v2, v23",
|
||||
"lsr {t3}, {s6}, #32",
|
||||
"umlal.2d v3, v23, v31[1]",
|
||||
"lsr {s1}, {s7}, #32",
|
||||
"uaddw2.2d v12, v12, v23",
|
||||
"mov {s2:w}, {s4:w}",
|
||||
"uaddw2.2d v13, v13, v23",
|
||||
"add {hi0}, {hi0}, {t4}",
|
||||
"umlal.2d v4, v22, v31[2]",
|
||||
"add {lo1}, {lo1}, {s3:w}, uxtw #1",
|
||||
"umlal2.2d v14, v22, v30[2]",
|
||||
"add {lo0}, {lo0}, {s3:w}, uxtw",
|
||||
"umlal.2d v5, v22, v31[0]",
|
||||
"add {t4}, {s2}, {s5:w}, uxtw #3",
|
||||
"umlal2.2d v15, v22, v31[2]",
|
||||
"add {s2}, {s2}, {s5:w}, uxtw #2",
|
||||
"ldp d6, d7, [{rc_ptr}, #64]",
|
||||
"add {s3}, {s1}, {t3}, lsl #4",
|
||||
"ushll.2d v16, v21, #5",
|
||||
"csetm {t1:w}, cc",
|
||||
"ushll.2d v17, v21, #3",
|
||||
"add {hi1}, {hi1}, {t6}",
|
||||
"umlal.2d v0, v25, v30[1]",
|
||||
"add {hi0}, {hi0}, {t5}, lsl #3",
|
||||
"umlal2.2d v10, v25, v31[0]",
|
||||
"mov {t5:w}, {s6:w}",
|
||||
"umlal.2d v1, v25, v30[3]",
|
||||
"mov {t6:w}, {s7:w}",
|
||||
"umlal2.2d v11, v25, v30[1]",
|
||||
"add {s4}, {t6}, {t5}, lsl #4",
|
||||
"umlal.2d v2, v24, v30[1]",
|
||||
"add {t3}, {t3}, {s1}, lsl #7",
|
||||
"uaddw2.2d v12, v12, v24",
|
||||
"lsr {s1}, {s8}, #32",
|
||||
"uaddw.2d v13, v13, v24",
|
||||
"lsr {s5}, {s9}, #32",
|
||||
"umlal2.2d v3, v24, v30[1]",
|
||||
"lsl {t2}, {s0}, #32",
|
||||
"umlal.2d v4, v23, v31[3]",
|
||||
"sub {t0}, {t0}, {t1}",
|
||||
"umlal2.2d v14, v23, v31[1]",
|
||||
"add {lo1}, {lo1}, {t4}",
|
||||
"umlal.2d v5, v23, v30[2]",
|
||||
"add {lo0}, {lo0}, {s2}, lsl #3",
|
||||
"umlal2.2d v15, v23, v31[3]",
|
||||
"add {t4}, {t5}, {t6}, lsl #7",
|
||||
"umlal.2d v6, v22, v30[1]",
|
||||
"add {hi1}, {hi1}, {s3}, lsl #1",
|
||||
"umlal2.2d v16, v22, v31[0]",
|
||||
"add {t5}, {s1}, {s5}, lsl #4",
|
||||
"umlal.2d v7, v22, v30[3]",
|
||||
"mov {s0:w}, {s0:w}",
|
||||
"umlal2.2d v17, v22, v30[1]",
|
||||
"sub {s0}, {t2}, {s0}",
|
||||
"ldp d8, d9, [{rc_ptr}, #80]",
|
||||
"add {lo1}, {lo1}, {s4}, lsl #1",
|
||||
"ushll.2d v18, v21, #0",
|
||||
"add {hi0}, {hi0}, {t3}, lsl #1",
|
||||
"ushll.2d v19, v21, #1",
|
||||
"mov {t3:w}, {s9:w}",
|
||||
"umlal.2d v0, v26, v31[2]",
|
||||
"mov {t6:w}, {s8:w}",
|
||||
"umlal2.2d v10, v26, v30[2]",
|
||||
"add {s2}, {t6}, {t3}, lsl #4",
|
||||
"umlal.2d v1, v26, v31[0]",
|
||||
"add {s1}, {s5}, {s1}, lsl #9",
|
||||
"umlal2.2d v11, v26, v31[2]",
|
||||
"lsr {s3}, {s10}, #32",
|
||||
"umlal.2d v2, v25, v30[2]",
|
||||
"lsr {s4}, {s11}, #32",
|
||||
"umlal2.2d v12, v25, v30[3]",
|
||||
"adds {s0}, {t0}, {s0}",
|
||||
"umlal2.2d v13, v25, v30[2]",
|
||||
"add {lo0}, {lo0}, {t4}, lsl #1",
|
||||
"uaddw.2d v3, v3, v25",
|
||||
"add {t3}, {t3}, {t6}, lsl #9",
|
||||
"uaddw.2d v4, v4, v24",
|
||||
"add {hi1}, {hi1}, {t5}, lsl #8",
|
||||
"umlal.2d v5, v24, v31[1]",
|
||||
"add {t4}, {s3}, {s4}, lsl #13",
|
||||
"uaddw2.2d v14, v14, v24",
|
||||
"csetm {t0:w}, cs",
|
||||
"uaddw2.2d v15, v15, v24",
|
||||
"add {lo1}, {lo1}, {s2}, lsl #8",
|
||||
"umlal.2d v6, v23, v31[2]",
|
||||
"add {hi0}, {hi0}, {s1}, lsl #3",
|
||||
"umlal2.2d v16, v23, v30[2]",
|
||||
"mov {t5:w}, {s10:w}",
|
||||
"umlal.2d v7, v23, v31[0]",
|
||||
"mov {t6:w}, {s11:w}",
|
||||
"umlal2.2d v17, v23, v31[2]",
|
||||
"add {s1}, {t5}, {t6}, lsl #13",
|
||||
"umlal.2d v8, v22, v30[2]",
|
||||
"add {s2}, {s4}, {s3}, lsl #6",
|
||||
"umlal2.2d v18, v22, v30[3]",
|
||||
"add {s0}, {s0}, {t0}",
|
||||
"uaddw.2d v9, v9, v22",
|
||||
"add {lo0}, {lo0}, {t3}, lsl #3",
|
||||
"umlal2.2d v19, v22, v30[2]",
|
||||
"add {t3}, {t6}, {t5}, lsl #6",
|
||||
"add.2d v0, v0, v10",
|
||||
"add {hi1}, {hi1}, {t4}, lsl #3",
|
||||
"add.2d v1, v1, v11",
|
||||
"fmov d20, {s0}",
|
||||
"umlal.2d v0, v20, v31[3]",
|
||||
"add {lo1}, {lo1}, {s1}, lsl #3",
|
||||
"umlal.2d v1, v20, v30[2]",
|
||||
"add {hi0}, {hi0}, {s2}, lsl #10",
|
||||
"zip1.2d v22, v0, v1",
|
||||
"lsr {t4}, {s0}, #32",
|
||||
"zip2.2d v0, v0, v1",
|
||||
"add {lo0}, {lo0}, {t3}, lsl #10",
|
||||
"usra.2d v0, v22, #32",
|
||||
"add {hi1}, {hi1}, {t4}, lsl #10",
|
||||
"sli.2d v22, v0, #32",
|
||||
"mov {t3:w}, {s0:w}",
|
||||
"uzp2.4s v0, v0, v0",
|
||||
"add {lo1}, {lo1}, {t3}, lsl #10",
|
||||
"mov.16b v1, v22",
|
||||
"add {hi0}, {hi0}, {t4}",
|
||||
"umlal.2d v22, v0, v30[0]",
|
||||
"add {lo0}, {lo0}, {t3}",
|
||||
"cmhi.2d v1, v1, v22",
|
||||
"lsl {t0}, {hi0}, #32",
|
||||
"usra.2d v22, v1, #32",
|
||||
"lsl {t1}, {hi1}, #32",
|
||||
"fmov {s2}, d22",
|
||||
"adds {lo0}, {lo0}, {t0}",
|
||||
"fmov.d {s3}, v22[1]",
|
||||
"csetm {t0:w}, cs",
|
||||
"umlal.2d v2, v26, v30[1]",
|
||||
"adds {lo1}, {lo1}, {t1}",
|
||||
"umlal2.2d v12, v26, v31[0]",
|
||||
"csetm {t1:w}, cs",
|
||||
"umlal.2d v3, v26, v30[3]",
|
||||
"and {t2}, {hi0}, #0xffffffff00000000",
|
||||
"umlal2.2d v13, v26, v30[1]",
|
||||
"and {t3}, {hi1}, #0xffffffff00000000",
|
||||
"umlal.2d v4, v25, v30[1]",
|
||||
"lsr {hi0}, {hi0}, #32",
|
||||
"uaddw2.2d v14, v14, v25",
|
||||
"lsr {hi1}, {hi1}, #32",
|
||||
"uaddw.2d v15, v15, v25",
|
||||
"sub {hi0}, {t2}, {hi0}",
|
||||
"umlal2.2d v5, v25, v30[1]",
|
||||
"sub {hi1}, {t3}, {hi1}",
|
||||
"umlal.2d v6, v24, v31[3]",
|
||||
"add {lo0}, {lo0}, {t0}",
|
||||
"umlal2.2d v16, v24, v31[1]",
|
||||
"add {lo1}, {lo1}, {t1}",
|
||||
"umlal.2d v7, v24, v30[2]",
|
||||
"adds {lo0}, {lo0}, {hi0}",
|
||||
"umlal2.2d v17, v24, v31[3]",
|
||||
"csetm {t0:w}, cs",
|
||||
"umlal.2d v8, v23, v30[1]",
|
||||
"adds {lo1}, {lo1}, {hi1}",
|
||||
"umlal2.2d v18, v23, v31[0]",
|
||||
"csetm {t1:w}, cs",
|
||||
"umlal.2d v9, v23, v30[3]",
|
||||
"add {s0}, {lo0}, {t0}",
|
||||
"umlal2.2d v19, v23, v30[1]",
|
||||
"add {s1}, {lo1}, {t1}",
|
||||
"add.2d v2, v2, v12",
|
||||
"add.2d v3, v3, v13",
|
||||
"umlal.2d v2, v20, v31[2]",
|
||||
"umlal.2d v3, v20, v31[0]",
|
||||
mds_reduce_asm!("v2", "v3", "v23", "v30"),
|
||||
"fmov {s4}, d23",
|
||||
"fmov.d {s5}, v23[1]",
|
||||
"umlal.2d v4, v26, v30[2]",
|
||||
"umlal2.2d v14, v26, v30[3]",
|
||||
"umlal2.2d v15, v26, v30[2]",
|
||||
"uaddw.2d v5, v5, v26",
|
||||
"uaddw.2d v6, v6, v25",
|
||||
"uaddw2.2d v16, v16, v25",
|
||||
"uaddw2.2d v17, v17, v25",
|
||||
"umlal.2d v7, v25, v31[1]",
|
||||
"umlal.2d v8, v24, v31[2]",
|
||||
"umlal2.2d v18, v24, v30[2]",
|
||||
"umlal.2d v9, v24, v31[0]",
|
||||
"umlal2.2d v19, v24, v31[2]",
|
||||
"add.2d v4, v4, v14",
|
||||
"add.2d v5, v5, v15",
|
||||
"umlal.2d v4, v20, v30[1]",
|
||||
"umlal.2d v5, v20, v30[3]",
|
||||
mds_reduce_asm!("v4", "v5", "v24", "v30"),
|
||||
"fmov {s6}, d24",
|
||||
"fmov.d {s7}, v24[1]",
|
||||
"umlal.2d v6, v26, v30[1]",
|
||||
"uaddw2.2d v16, v16, v26",
|
||||
"umlal2.2d v17, v26, v30[1]",
|
||||
"uaddw.2d v7, v7, v26",
|
||||
"umlal.2d v8, v25, v31[3]",
|
||||
"umlal2.2d v18, v25, v31[1]",
|
||||
"umlal.2d v9, v25, v30[2]",
|
||||
"umlal2.2d v19, v25, v31[3]",
|
||||
"add.2d v6, v6, v16",
|
||||
"add.2d v7, v7, v17",
|
||||
"umlal.2d v6, v20, v30[2]",
|
||||
"uaddw.2d v7, v7, v20",
|
||||
mds_reduce_asm!("v6", "v7", "v25", "v30"),
|
||||
"fmov {s8}, d25",
|
||||
"fmov.d {s9}, v25[1]",
|
||||
"uaddw.2d v8, v8, v26",
|
||||
"uaddw2.2d v18, v18, v26",
|
||||
"umlal.2d v9, v26, v31[1]",
|
||||
"uaddw2.2d v19, v19, v26",
|
||||
"add.2d v8, v8, v18",
|
||||
"add.2d v9, v9, v19",
|
||||
"umlal.2d v8, v20, v30[1]",
|
||||
"uaddw.2d v9, v9, v20",
|
||||
mds_reduce_asm!("v8", "v9", "v26", "v30"),
|
||||
"fmov {s10}, d26",
|
||||
"fmov.d {s11}, v26[1]",
|
||||
|
||||
// Scalar inputs/outputs
|
||||
// s0 is transformed by the S-box
|
||||
s0 = inout(reg) state_scalar[0] => res0,
|
||||
// s1-s6 double as scratch in the MDS matrix multiplication
|
||||
s1 = inout(reg) state_scalar[1] => res1,
|
||||
// s2-s11 are copied from the vector inputs/outputs
|
||||
s2 = inout(reg) state_scalar[2] => res2_scalar,
|
||||
s3 = inout(reg) state_scalar[3] => res3_scalar,
|
||||
s4 = inout(reg) state_scalar[4] => res4_scalar,
|
||||
s5 = inout(reg) state_scalar[5] => res5_scalar,
|
||||
s6 = inout(reg) state_scalar[6] => res6_scalar,
|
||||
s7 = inout(reg) state_scalar[7] => res7_scalar,
|
||||
s8 = inout(reg) state_scalar[8] => res8_scalar,
|
||||
s9 = inout(reg) state_scalar[9] => res9_scalar,
|
||||
s10 = inout(reg) state_scalar[10] => res10_scalar,
|
||||
s11 = inout(reg) state_scalar[11] => res11_scalar,
|
||||
|
||||
// Pointer to the round constants
|
||||
rc_ptr = in(reg) round_constants.as_ptr(),
|
||||
|
||||
// Scalar MDS multiplication accumulators
|
||||
lo1 = out(reg) _,
|
||||
hi1 = out(reg) _,
|
||||
lo0 = out(reg) _,
|
||||
hi0 = out(reg) _,
|
||||
|
||||
// Scalar scratch registers
|
||||
// All are used in the scalar S-box
|
||||
t0 = out(reg) _,
|
||||
t1 = out(reg) _,
|
||||
t2 = out(reg) _,
|
||||
// t3-t6 are used in the scalar MDS matrix multiplication
|
||||
t3 = out(reg) _,
|
||||
t4 = out(reg) _,
|
||||
t5 = out(reg) _,
|
||||
t6 = out(reg) _,
|
||||
|
||||
// Vector MDS multiplication accumulators
|
||||
// v{n} and v1{n} are accumulators for res[n + 2] (we need two to mask latency)
|
||||
// The low and high 64-bits are accumulators for the low and high results, respectively
|
||||
out("v0") _,
|
||||
out("v1") _,
|
||||
out("v2") _,
|
||||
out("v3") _,
|
||||
out("v4") _,
|
||||
out("v5") _,
|
||||
out("v6") _,
|
||||
out("v7") _,
|
||||
out("v8") _,
|
||||
out("v9") _,
|
||||
out("v10") _,
|
||||
out("v11") _,
|
||||
out("v12") _,
|
||||
out("v13") _,
|
||||
out("v14") _,
|
||||
out("v15") _,
|
||||
out("v16") _,
|
||||
out("v17") _,
|
||||
out("v18") _,
|
||||
out("v19") _,
|
||||
|
||||
// Inputs into vector MDS matrix multiplication
|
||||
// v20 and v21 are sbox(state0) and state1, respectively. They are copied from the scalar
|
||||
// registers.
|
||||
out("v20") _,
|
||||
out("v21") _,
|
||||
// v22, ..., v26 hold state[2,3], ..., state[10,11]
|
||||
inout("v22") state_vector[0] => res23,
|
||||
inout("v23") state_vector[1] => res45,
|
||||
inout("v24") state_vector[2] => res67,
|
||||
inout("v25") state_vector[3] => res89,
|
||||
inout("v26") state_vector[4] => res1011,
|
||||
|
||||
// Useful constants
|
||||
in("v30") mds_consts0,
|
||||
in("v31") mds_consts1,
|
||||
|
||||
options(nostack, pure, readonly),
|
||||
);
|
||||
(
|
||||
[
|
||||
res0,
|
||||
res1,
|
||||
res2_scalar,
|
||||
res3_scalar,
|
||||
res4_scalar,
|
||||
res5_scalar,
|
||||
res6_scalar,
|
||||
res7_scalar,
|
||||
res8_scalar,
|
||||
res9_scalar,
|
||||
res10_scalar,
|
||||
res11_scalar,
|
||||
],
|
||||
[res23, res45, res67, res89, res1011],
|
||||
)
|
||||
}
|
||||
|
||||
// ========================================== GLUE CODE ===========================================
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] {
|
||||
let state = sbox_layer_full(state);
|
||||
let state = mds_const_layers_full(state, round_constants);
|
||||
state
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn full_rounds(
|
||||
mut state: [u64; 12],
|
||||
round_constants: &[u64; WIDTH * HALF_N_FULL_ROUNDS],
|
||||
) -> [u64; 12] {
|
||||
for round_constants_chunk in round_constants.chunks_exact(WIDTH) {
|
||||
state = full_round(state, round_constants_chunk.try_into().unwrap());
|
||||
}
|
||||
state
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
unsafe fn partial_rounds(
|
||||
state: [u64; 12],
|
||||
round_constants: &[u64; WIDTH * N_PARTIAL_ROUNDS],
|
||||
) -> [u64; 12] {
|
||||
let mut state = (
|
||||
state,
|
||||
[
|
||||
vcombine_u64(vcreate_u64(state[2]), vcreate_u64(state[3])),
|
||||
vcombine_u64(vcreate_u64(state[4]), vcreate_u64(state[5])),
|
||||
vcombine_u64(vcreate_u64(state[6]), vcreate_u64(state[7])),
|
||||
vcombine_u64(vcreate_u64(state[8]), vcreate_u64(state[9])),
|
||||
vcombine_u64(vcreate_u64(state[10]), vcreate_u64(state[11])),
|
||||
],
|
||||
);
|
||||
for round_constants_chunk in round_constants.chunks_exact(WIDTH) {
|
||||
state = partial_round(state, round_constants_chunk.try_into().unwrap());
|
||||
}
|
||||
state.0
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub unsafe fn poseidon(state: [GoldilocksField; 12]) -> [GoldilocksField; 12] {
|
||||
let state = [
|
||||
state[0].0,
|
||||
state[1].0,
|
||||
state[2].0,
|
||||
state[3].0,
|
||||
state[4].0,
|
||||
state[5].0,
|
||||
state[6].0,
|
||||
state[7].0,
|
||||
state[8].0,
|
||||
state[9].0,
|
||||
state[10].0,
|
||||
state[11].0,
|
||||
];
|
||||
let state = const_layer_full(state, ALL_ROUND_CONSTANTS[0..WIDTH].try_into().unwrap());
|
||||
let state = full_rounds(
|
||||
state,
|
||||
ALL_ROUND_CONSTANTS[WIDTH..WIDTH * (HALF_N_FULL_ROUNDS + 1)]
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
let state = partial_rounds(
|
||||
state,
|
||||
ALL_ROUND_CONSTANTS
|
||||
[WIDTH * (HALF_N_FULL_ROUNDS + 1)..WIDTH * (HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + 1)]
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
let state = full_rounds(state, &FINAL_ROUND_CONSTANTS);
|
||||
[
|
||||
GoldilocksField(state[0]),
|
||||
GoldilocksField(state[1]),
|
||||
GoldilocksField(state[2]),
|
||||
GoldilocksField(state[3]),
|
||||
GoldilocksField(state[4]),
|
||||
GoldilocksField(state[5]),
|
||||
GoldilocksField(state[6]),
|
||||
GoldilocksField(state[7]),
|
||||
GoldilocksField(state[8]),
|
||||
GoldilocksField(state[9]),
|
||||
GoldilocksField(state[10]),
|
||||
GoldilocksField(state[11]),
|
||||
]
|
||||
}
|
||||
495
src/hash/arch/aarch64/readme-asm.md
Normal file
495
src/hash/arch/aarch64/readme-asm.md
Normal file
@ -0,0 +1,495 @@
|
||||
Partial rounds ASM
|
||||
==================
|
||||
|
||||
The partial rounds are written in hand-rolled ASM. This was necessary to ensure proper pipelining. Indeed, the ASM shaves 40% off the execution time of the original vector intrinsics-based partial round.
|
||||
|
||||
The partial layer performs two operations:
|
||||
1. Apply the S-box to state[0]
|
||||
2. Apply an affine transform (MDS matrix + constant layer) to the entire state vector.
|
||||
|
||||
The S-box must be performed in scalar to minimize latency. The MDS matrix is done mostly in vector to maximize throughput. To take advantage of the otherwise idle scalar execution units, MDS matrix multiplication for result[0..2] is done in scalar. Clearly, this necessitates some data movement, as the input state must be available to both scalar and vector execution units.
|
||||
|
||||
This task has plentiful opportunities for pipelining and parallelism. Most immediately, the S-box—with its long latency chain—can be performed simultaneously with most of the MDS matrix multiplication, with the permuted input only available right before the reduction. In addition, the MDS matrix multiplication can be scheduled in a way that interleaves different kinds of operations, masking the latency of the reduction step.
|
||||
|
||||
There are three chains of ASM:
|
||||
1. the S-box,
|
||||
2. the scalar part of MDS multiplication (for result[0..2]),
|
||||
3. the vector part of MDS multiplication (for result[2..12]).
|
||||
Those chains are explained individually below. They interact sporadically to exchange results. In the compiled file, they have been interleaved.
|
||||
|
||||
|
||||
S-box
|
||||
-----
|
||||
|
||||
The ASM for the S-box is as follows:
|
||||
```assembly
|
||||
umulh {t0}, {s0}, {s0}
|
||||
mul {t1}, {s0}, {s0}
|
||||
subs {t1}, {t1}, {t0}, lsr #32
|
||||
csetm {t2:w}, cc
|
||||
lsl {t3}, {t0}, #32
|
||||
sub {t1}, {t1}, {t2}
|
||||
mov {t0:w}, {t0:w}
|
||||
sub {t0}, {t3}, {t0}
|
||||
adds {t0}, {t1}, {t0}
|
||||
csetm {t1:w}, cs
|
||||
add {t0}, {t0}, {t1}
|
||||
|
||||
// t0 now contains state ** 2
|
||||
umulh {t1}, {s0}, {t0}
|
||||
umulh {t2}, {t0}, {t0}
|
||||
mul {s0}, {s0}, {t0}
|
||||
mul {t0}, {t0}, {t0}
|
||||
subs {s0}, {s0}, {t1}, lsr #32
|
||||
csetm {t3:w}, cc
|
||||
subs {t0}, {t0}, {t2}, lsr #32
|
||||
csetm {t4:w}, cc
|
||||
lsl {t5}, {t1}, #32
|
||||
lsl {t6}, {t2}, #32
|
||||
sub {s0}, {s0}, {t3}
|
||||
sub {t0}, {t0}, {t4}
|
||||
mov {t1:w}, {t1:w}
|
||||
mov {t2:w}, {t2:w}
|
||||
sub {t1}, {t5}, {t1}
|
||||
sub {t2}, {t6}, {t2}
|
||||
adds {t1}, {s0}, {t1}
|
||||
csetm {s0:w}, cs
|
||||
adds {t2}, {t0}, {t2}
|
||||
csetm {t0:w}, cs
|
||||
add {t1}, {t1}, {s0}
|
||||
add {t2}, {t2}, {t0}
|
||||
|
||||
// t1 now contains state ** 3
|
||||
// t2 now contains state ** 4
|
||||
umulh {s0}, {t1}, {t2}
|
||||
mul {t0}, {t1}, {t2}
|
||||
subs {t0}, {t0}, {s0}, lsr #32
|
||||
csetm {t1:w}, cc
|
||||
lsl {t2}, {s0}, #32
|
||||
sub {t0}, {t0}, {t1}
|
||||
mov {s0:w}, {s0:w}
|
||||
sub {s0}, {t2}, {s0}
|
||||
adds {s0}, {t0}, {s0}
|
||||
csetm {t0:w}, cs
|
||||
add {s0}, {s0}, {t0}
|
||||
|
||||
// s0 now contains state **7
|
||||
fmov d20, {s0}
|
||||
```
|
||||
|
||||
It is merely four repetitions of a block of 11 instructions (the middle two repetitions are interleaved). The input and output are in `s0`. `t0` through `t6` are scratch registers. The `fmov` copies the result to the bottom 64 bits of the vector register v20.
|
||||
|
||||
Trick: `csetm` sets its destination to all 1s if the condition is met. In our case the destination is 32-bits and the condition is overflow/underflow of the previous instruction, so we get EPSILON on over/underflow and 0 otherwise.
|
||||
|
||||
Note: the last multiplication does not use `t3` through `t6`, making them available to scalar MDS multiplication.
|
||||
|
||||
|
||||
Scalar MDS multiplication
|
||||
-------------------------
|
||||
|
||||
The ASM for the scalar MDS multiplication is
|
||||
```assembly
|
||||
ldp {lo0}, {lo1}, [{rc_ptr}]
|
||||
add {lo1}, {lo1}, {s1:w}, uxtw
|
||||
add {lo0}, {lo0}, {s1:w}, uxtw
|
||||
lsr {hi0}, {s1}, #32
|
||||
lsr {t3}, {s2}, #32
|
||||
lsr {t4}, {s3}, #32
|
||||
add {hi1}, {hi0}, {t3}
|
||||
add {hi0}, {hi0}, {t3}, lsl #1
|
||||
add {lo1}, {lo1}, {s2:w}, uxtw
|
||||
add {lo0}, {lo0}, {s2:w}, uxtw #1
|
||||
lsr {t3}, {s4}, #32
|
||||
lsr {t5}, {s5}, #32
|
||||
add {hi1}, {hi1}, {t4}, lsl #1
|
||||
add {t6}, {t3}, {t5}, lsl #3
|
||||
add {t5}, {t3}, {t5}, lsl #2
|
||||
lsr {t3}, {s6}, #32
|
||||
lsr {s1}, {s7}, #32
|
||||
mov {s2:w}, {s4:w}
|
||||
add {hi0}, {hi0}, {t4}
|
||||
add {lo1}, {lo1}, {s3:w}, uxtw #1
|
||||
add {lo0}, {lo0}, {s3:w}, uxtw
|
||||
add {t4}, {s2}, {s5:w}, uxtw #3
|
||||
add {s2}, {s2}, {s5:w}, uxtw #2
|
||||
add {s3}, {s1}, {t3}, lsl #4
|
||||
add {hi1}, {hi1}, {t6}
|
||||
add {hi0}, {hi0}, {t5}, lsl #3
|
||||
mov {t5:w}, {s6:w}
|
||||
mov {t6:w}, {s7:w}
|
||||
add {s4}, {t6}, {t5}, lsl #4
|
||||
add {t3}, {t3}, {s1}, lsl #7
|
||||
lsr {s1}, {s8}, #32
|
||||
lsr {s5}, {s9}, #32
|
||||
add {lo1}, {lo1}, {t4}
|
||||
add {lo0}, {lo0}, {s2}, lsl #3
|
||||
add {t4}, {t5}, {t6}, lsl #7
|
||||
add {hi1}, {hi1}, {s3}, lsl #1
|
||||
add {t5}, {s1}, {s5}, lsl #4
|
||||
add {lo1}, {lo1}, {s4}, lsl #1
|
||||
add {hi0}, {hi0}, {t3}, lsl #1
|
||||
mov {t3:w}, {s9:w}
|
||||
mov {t6:w}, {s8:w}
|
||||
add {s2}, {t6}, {t3}, lsl #4
|
||||
add {s1}, {s5}, {s1}, lsl #9
|
||||
lsr {s3}, {s10}, #32
|
||||
lsr {s4}, {s11}, #32
|
||||
add {lo0}, {lo0}, {t4}, lsl #1
|
||||
add {t3}, {t3}, {t6}, lsl #9
|
||||
add {hi1}, {hi1}, {t5}, lsl #8
|
||||
add {t4}, {s3}, {s4}, lsl #13
|
||||
add {lo1}, {lo1}, {s2}, lsl #8
|
||||
add {hi0}, {hi0}, {s1}, lsl #3
|
||||
mov {t5:w}, {s10:w}
|
||||
mov {t6:w}, {s11:w}
|
||||
add {s1}, {t5}, {t6}, lsl #13
|
||||
add {s2}, {s4}, {s3}, lsl #6
|
||||
add {lo0}, {lo0}, {t3}, lsl #3
|
||||
add {t3}, {t6}, {t5}, lsl #6
|
||||
add {hi1}, {hi1}, {t4}, lsl #3
|
||||
add {lo1}, {lo1}, {s1}, lsl #3
|
||||
add {hi0}, {hi0}, {s2}, lsl #10
|
||||
lsr {t4}, {s0}, #32
|
||||
add {lo0}, {lo0}, {t3}, lsl #10
|
||||
add {hi1}, {hi1}, {t4}, lsl #10
|
||||
mov {t3:w}, {s0:w}
|
||||
add {lo1}, {lo1}, {t3}, lsl #10
|
||||
add {hi0}, {hi0}, {t4}
|
||||
add {lo0}, {lo0}, {t3}
|
||||
|
||||
// Reduction
|
||||
lsl {t0}, {hi0}, #32
|
||||
lsl {t1}, {hi1}, #32
|
||||
adds {lo0}, {lo0}, {t0}
|
||||
csetm {t0:w}, cs
|
||||
adds {lo1}, {lo1}, {t1}
|
||||
csetm {t1:w}, cs
|
||||
and {t2}, {hi0}, #0xffffffff00000000
|
||||
and {t3}, {hi1}, #0xffffffff00000000
|
||||
lsr {hi0}, {hi0}, #32
|
||||
lsr {hi1}, {hi1}, #32
|
||||
sub {hi0}, {t2}, {hi0}
|
||||
sub {hi1}, {t3}, {hi1}
|
||||
add {lo0}, {lo0}, {t0}
|
||||
add {lo1}, {lo1}, {t1}
|
||||
adds {lo0}, {lo0}, {hi0}
|
||||
csetm {t0:w}, cs
|
||||
adds {lo1}, {lo1}, {hi1}
|
||||
csetm {t1:w}, cs
|
||||
add {s0}, {lo0}, {t0}
|
||||
add {s1}, {lo1}, {t1}
|
||||
```
|
||||
|
||||
The MDS multiplication is done separately on the low 32 bits and the high 32 bits of the input, and combined by linearity. Each input is split into the low part and the high part. There are separate accumulators for the low and high parts of the result `lo0`/`lo1`, for result[0] and result[1] respectively, and `hi0`/`hi1`.
|
||||
|
||||
The pointer to the round constants is given in `rc_ptr.` Registers `s0`-`s11` constain the state vector at the start, and are later used as scratch. `t3`-`t6` are temporaries.
|
||||
|
||||
`s1` is assumed to be available first, as it is computed in scalar. `s2`-`s11` are used next. `s0` is assumed to be available last, as it must be transformed by the S-box.
|
||||
|
||||
The reduction is
|
||||
```assembly
|
||||
lsl {t0}, {hi0}, #32
|
||||
adds {lo0}, {lo0}, {t0}
|
||||
csetm {t0:w}, cs
|
||||
and {t2}, {hi0}, #0xffffffff00000000
|
||||
lsr {hi0}, {hi0}, #32
|
||||
sub {hi0}, {t2}, {hi0}
|
||||
add {lo0}, {lo0}, {t0}
|
||||
adds {lo0}, {lo0}, {hi0}
|
||||
csetm {t0:w}, cs
|
||||
add {s0}, {lo0}, {t0}
|
||||
```
|
||||
repeated and interleaved. `cset` sets its destination to EPSILON if the previous instruction overflowed.
|
||||
|
||||
|
||||
Vector MDS multiplication
|
||||
-------------------------
|
||||
|
||||
The ASM for the vector MDS multiplication is
|
||||
```assembly
|
||||
fmov d21, {s1}
|
||||
|
||||
// res2,3 <- consts,state1
|
||||
ldp d0, d1, [{rc_ptr}, #16]
|
||||
ushll.2d v10, v21, #10 // MDS[11] == 10
|
||||
ushll.2d v11, v21, #16 // MDS[10] == 16
|
||||
|
||||
// res2,3 <- state2,3
|
||||
uaddw.2d v0, v0, v22 // MDS[0] == 0
|
||||
umlal.2d v1, v22, v31[1] // MDS[11] == 10
|
||||
uaddw2.2d v10, v10, v22 // MDS[1] == 0
|
||||
uaddw2.2d v11, v11, v22 // MDS[0] == 0
|
||||
|
||||
// res4,5 <- consts,state1
|
||||
ldp d2, d3, [{rc_ptr}, #32]
|
||||
ushll.2d v12, v21, #3 // MDS[9] == 3
|
||||
ushll.2d v13, v21, #12 // MDS[8] == 12
|
||||
|
||||
// res2,3 <- state4,5
|
||||
umlal.2d v0, v23, v30[1] // MDS[2] == 1
|
||||
uaddw2.2d v10, v10, v23 // MDS[3] == 0
|
||||
uaddw.2d v11, v11, v23 // MDS[1] == 0
|
||||
umlal2.2d v1, v23, v30[1] // MDS[2] == 1
|
||||
|
||||
// res4,5 <- state2,3
|
||||
umlal.2d v2, v22, v31[3] // MDS[10] == 16
|
||||
umlal2.2d v12, v22, v31[1] // MDS[11] == 10
|
||||
umlal.2d v3, v22, v30[2] // MDS[9] == 3
|
||||
umlal2.2d v13, v22, v31[3] // MDS[10] == 16
|
||||
|
||||
// res6,7 <- consts,state1
|
||||
ldp d4, d5, [{rc_ptr}, #48]
|
||||
ushll.2d v14, v21, #8 // MDS[7] == 8
|
||||
ushll.2d v15, v21, #1 // MDS[6] == 1
|
||||
|
||||
// res2,3 <- state6,7
|
||||
umlal.2d v0, v24, v30[2] // MDS[4] == 3
|
||||
umlal2.2d v10, v24, v30[3] // MDS[5] == 5
|
||||
umlal2.2d v11, v24, v30[2] // MDS[4] == 3
|
||||
uaddw.2d v1, v1, v24 // MDS[3] == 0
|
||||
|
||||
// res4,5 <- state4,5
|
||||
uaddw.2d v2, v2, v23 // MDS[0] == 0
|
||||
umlal.2d v3, v23, v31[1] // MDS[11] == 10
|
||||
uaddw2.2d v12, v12, v23 // MDS[1] == 0
|
||||
uaddw2.2d v13, v13, v23 // MDS[0] == 0
|
||||
|
||||
// res6,7 <- state2,3
|
||||
umlal.2d v4, v22, v31[2] // MDS[8] == 12
|
||||
umlal2.2d v14, v22, v30[2] // MDS[9] == 3
|
||||
umlal.2d v5, v22, v31[0] // MDS[7] == 8
|
||||
umlal2.2d v15, v22, v31[2] // MDS[8] == 12
|
||||
|
||||
// res8,9 <- consts,state1
|
||||
ldp d6, d7, [{rc_ptr}, #64]
|
||||
ushll.2d v16, v21, #5 // MDS[5] == 5
|
||||
ushll.2d v17, v21, #3 // MDS[4] == 3
|
||||
|
||||
// res2,3 <- state8,9
|
||||
umlal.2d v0, v25, v30[1] // MDS[6] == 1
|
||||
umlal2.2d v10, v25, v31[0] // MDS[7] == 8
|
||||
umlal.2d v1, v25, v30[3] // MDS[5] == 5
|
||||
umlal2.2d v11, v25, v30[1] // MDS[6] == 1
|
||||
|
||||
// res4,5 <- state6,7
|
||||
umlal.2d v2, v24, v30[1] // MDS[2] == 1
|
||||
uaddw2.2d v12, v12, v24 // MDS[3] == 0
|
||||
uaddw.2d v13, v13, v24 // MDS[1] == 0
|
||||
umlal2.2d v3, v24, v30[1] // MDS[2] == 1
|
||||
|
||||
// res6,7 <- state4,5
|
||||
umlal.2d v4, v23, v31[3] // MDS[10] == 16
|
||||
umlal2.2d v14, v23, v31[1] // MDS[11] == 10
|
||||
umlal.2d v5, v23, v30[2] // MDS[9] == 3
|
||||
umlal2.2d v15, v23, v31[3] // MDS[10] == 16
|
||||
|
||||
// res8,9 <- state2,3
|
||||
umlal.2d v6, v22, v30[1] // MDS[6] == 1
|
||||
umlal2.2d v16, v22, v31[0] // MDS[7] == 8
|
||||
umlal.2d v7, v22, v30[3] // MDS[5] == 5
|
||||
umlal2.2d v17, v22, v30[1] // MDS[6] == 1
|
||||
|
||||
// res10,11 <- consts,state1
|
||||
ldp d8, d9, [{rc_ptr}, #80]
|
||||
ushll.2d v18, v21, #0 // MDS[3] == 0
|
||||
ushll.2d v19, v21, #1 // MDS[2] == 1
|
||||
|
||||
// res2,3 <- state10,11
|
||||
umlal.2d v0, v26, v31[2] // MDS[8] == 12
|
||||
umlal2.2d v10, v26, v30[2] // MDS[9] == 3
|
||||
umlal.2d v1, v26, v31[0] // MDS[7] == 8
|
||||
umlal2.2d v11, v26, v31[2] // MDS[8] == 12
|
||||
|
||||
// res4,5 <- state8,9
|
||||
umlal.2d v2, v25, v30[2] // MDS[4] == 3
|
||||
umlal2.2d v12, v25, v30[3] // MDS[5] == 5
|
||||
umlal2.2d v13, v25, v30[2] // MDS[4] == 3
|
||||
uaddw.2d v3, v3, v25 // MDS[3] == 0
|
||||
|
||||
// res6,7 <- state6,7
|
||||
uaddw.2d v4, v4, v24 // MDS[0] == 0
|
||||
umlal.2d v5, v24, v31[1] // MDS[11] == 10
|
||||
uaddw2.2d v14, v14, v24 // MDS[1] == 0
|
||||
uaddw2.2d v15, v15, v24 // MDS[0] == 0
|
||||
|
||||
// res8,9 <- state4,5
|
||||
umlal.2d v6, v23, v31[2] // MDS[8] == 12
|
||||
umlal2.2d v16, v23, v30[2] // MDS[9] == 3
|
||||
umlal.2d v7, v23, v31[0] // MDS[7] == 8
|
||||
umlal2.2d v17, v23, v31[2] // MDS[8] == 12
|
||||
|
||||
// res10,11 <- state2,3
|
||||
umlal.2d v8, v22, v30[2] // MDS[4] == 3
|
||||
umlal2.2d v18, v22, v30[3] // MDS[5] == 5
|
||||
uaddw.2d v9, v9, v22 // MDS[3] == 0
|
||||
umlal2.2d v19, v22, v30[2] // MDS[4] == 3
|
||||
|
||||
// merge accumulators, res2,3 <- state0, and reduce
|
||||
add.2d v0, v0, v10
|
||||
add.2d v1, v1, v11
|
||||
|
||||
umlal.2d v0, v20, v31[3] // MDS[10] == 16
|
||||
umlal.2d v1, v20, v30[2] // MDS[9] == 3
|
||||
mds_reduce_asm(v0, v1, v22)
|
||||
fmov {s2}, d22
|
||||
fmov.d {s3}, v22[1]
|
||||
|
||||
// res4,5 <- state10,11
|
||||
umlal.2d v2, v26, v30[1] // MDS[6] == 1
|
||||
umlal2.2d v12, v26, v31[0] // MDS[7] == 8
|
||||
umlal.2d v3, v26, v30[3] // MDS[5] == 5
|
||||
umlal2.2d v13, v26, v30[1] // MDS[6] == 1
|
||||
|
||||
// res6,7 <- state8,9
|
||||
umlal.2d v4, v25, v30[1] // MDS[2] == 1
|
||||
uaddw2.2d v14, v14, v25 // MDS[3] == 0
|
||||
uaddw.2d v15, v15, v25 // MDS[1] == 0
|
||||
umlal2.2d v5, v25, v30[1] // MDS[2] == 1
|
||||
|
||||
// res8,9 <- state6,7
|
||||
umlal.2d v6, v24, v31[3] // MDS[10] == 16
|
||||
umlal2.2d v16, v24, v31[1] // MDS[11] == 10
|
||||
umlal.2d v7, v24, v30[2] // MDS[9] == 3
|
||||
umlal2.2d v17, v24, v31[3] // MDS[10] == 16
|
||||
|
||||
// res10,11 <- state4,5
|
||||
umlal.2d v8, v23, v30[1] // MDS[6] == 1
|
||||
umlal2.2d v18, v23, v31[0] // MDS[7] == 8
|
||||
umlal.2d v9, v23, v30[3] // MDS[5] == 5
|
||||
umlal2.2d v19, v23, v30[1] // MDS[6] == 1
|
||||
|
||||
// merge accumulators, res4,5 <- state0, and reduce
|
||||
add.2d v2, v2, v12
|
||||
add.2d v3, v3, v13
|
||||
|
||||
umlal.2d v2, v20, v31[2] // MDS[8] == 12
|
||||
umlal.2d v3, v20, v31[0] // MDS[7] == 8
|
||||
mds_reduce_asm(v2, v3, v23)
|
||||
fmov {s4}, d23
|
||||
fmov.d {s5}, v23[1]
|
||||
|
||||
// res6,7 <- state10,11
|
||||
umlal.2d v4, v26, v30[2] // MDS[4] == 3
|
||||
umlal2.2d v14, v26, v30[3] // MDS[5] == 5
|
||||
umlal2.2d v15, v26, v30[2] // MDS[4] == 3
|
||||
uaddw.2d v5, v5, v26 // MDS[3] == 0
|
||||
|
||||
// res8,9 <- state8,9
|
||||
uaddw.2d v6, v6, v25 // MDS[0] == 0
|
||||
uaddw2.2d v16, v16, v25 // MDS[1] == 0
|
||||
uaddw2.2d v17, v17, v25 // MDS[0] == 0
|
||||
umlal.2d v7, v25, v31[1] // MDS[11] == 10
|
||||
|
||||
// res10,11 <- state6,7
|
||||
umlal.2d v8, v24, v31[2] // MDS[8] == 12
|
||||
umlal2.2d v18, v24, v30[2] // MDS[9] == 3
|
||||
umlal.2d v9, v24, v31[0] // MDS[7] == 8
|
||||
umlal2.2d v19, v24, v31[2] // MDS[8] == 12
|
||||
|
||||
// merge accumulators, res6,7 <- state0, and reduce
|
||||
add.2d v4, v4, v14
|
||||
add.2d v5, v5, v15
|
||||
|
||||
umlal.2d v4, v20, v30[1] // MDS[6] == 1
|
||||
umlal.2d v5, v20, v30[3] // MDS[5] == 5
|
||||
mds_reduce_asm(v4, v5, v24)
|
||||
fmov {s6}, d24
|
||||
fmov.d {s7}, v24[1]
|
||||
|
||||
// res8,9 <- state10,11
|
||||
umlal.2d v6, v26, v30[1] // MDS[2] == 1
|
||||
uaddw2.2d v16, v16, v26 // MDS[3] == 0
|
||||
umlal2.2d v17, v26, v30[1] // MDS[2] == 1
|
||||
uaddw.2d v7, v7, v26 // MDS[1] == 0
|
||||
|
||||
// res10,11 <- state8,9
|
||||
umlal.2d v8, v25, v31[3] // MDS[10] == 16
|
||||
umlal2.2d v18, v25, v31[1] // MDS[11] == 10
|
||||
umlal.2d v9, v25, v30[2] // MDS[9] == 3
|
||||
umlal2.2d v19, v25, v31[3] // MDS[10] == 16
|
||||
|
||||
// merge accumulators, res8,9 <- state0, and reduce
|
||||
add.2d v6, v6, v16
|
||||
add.2d v7, v7, v17
|
||||
|
||||
umlal.2d v6, v20, v30[2] // MDS[4] == 3
|
||||
uaddw.2d v7, v7, v20 // MDS[3] == 0
|
||||
mds_reduce_asm(v6, v7, v25)
|
||||
fmov {s8}, d25
|
||||
fmov.d {s9}, v25[1]
|
||||
|
||||
// res10,11 <- state10,11
|
||||
uaddw.2d v8, v8, v26 // MDS[0] == 0
|
||||
uaddw2.2d v18, v18, v26 // MDS[1] == 0
|
||||
umlal.2d v9, v26, v31[1] // MDS[11] == 10
|
||||
uaddw2.2d v19, v19, v26 // MDS[0] == 0
|
||||
|
||||
// merge accumulators, res10,11 <- state0, and reduce
|
||||
add.2d v8, v8, v18
|
||||
add.2d v9, v9, v19
|
||||
|
||||
umlal.2d v8, v20, v30[1] // MDS[2] == 1
|
||||
uaddw.2d v9, v9, v20 // MDS[1] == 0
|
||||
mds_reduce_asm(v8, v9, v26)
|
||||
fmov {s10}, d26
|
||||
fmov.d {s11}, v26[1]
|
||||
```
|
||||
where the macro `mds_reduce_asm` is defined as
|
||||
```assembly
|
||||
($c0, $c1, $out) => {
|
||||
// Swizzle
|
||||
zip1.2d $out, $c0, $c1 // lo
|
||||
zip2.2d $c0, $c0, $c1 // hi
|
||||
|
||||
// Reduction from u96
|
||||
usra.2d $c0, $out, #32
|
||||
sli.2d $out, $c0, #32
|
||||
// Extract high 32-bits.
|
||||
uzp2.4s $c0, $c0, $c0
|
||||
// Multiply by EPSILON and accumulate.
|
||||
mov.16b $c1, $out
|
||||
umlal.2d $out, $c0, v30[0]
|
||||
cmhi.2d $c1, $c1, $out
|
||||
usra.2d $out, $c1, #32
|
||||
}
|
||||
```
|
||||
|
||||
The order in which inputs are assumed to be available is:
|
||||
- state[1]
|
||||
- state[2] and state[3]
|
||||
- state[4] and state[5]
|
||||
- state[6] and state[7]
|
||||
- state[8] and state[9]
|
||||
- state[10] and state[11]
|
||||
- state[0]
|
||||
|
||||
The order in which the results are produced is:
|
||||
- state[2] and state[3]
|
||||
- state[4] and state[5]
|
||||
- state[6] and state[7]
|
||||
- state[8] and state[9]
|
||||
- state[10] and state[11]
|
||||
|
||||
The order of the instructions in the assembly should be thought of as a setting the relative priority of each instruction; because of CPU reordering, it does not correspond exactly to execution order in time. Ideally, we'd like the MDS matrix multiplication to happen in the following order:
|
||||
s[1] s[2..4] s[4..6] s[6..8] s[8..10] s[10..12] s[0]
|
||||
res[2..4] 1 2 4 7 11 16 21
|
||||
res[4..6] 3 5 8 12 17 22 26
|
||||
output res[6..8] 6 9 13 18 23 27 30
|
||||
res[8..10] 10 14 19 24 28 31 33
|
||||
res[10..12] 15 20 25 29 32 34 35
|
||||
|
||||
This is the order in which the operations are ordered in the ASM. It permits the start of one iteration to be interleaved with the end of the previous iteration (CPU reordering means we don't have to do it manually). Reductions, which have high latency, are executed as soon as the unreduced product is available; the pipelining permits them to be executed simultaneously with multiplication/accumulation, masking the latency.
|
||||
|
||||
The registers `v0`-`v19` are used for scratch. `v0` and `v10` are accumulators for res[2], `v1` and `v11` are accumulators for res[3], and so on. The accumulators hold the low result in the low 64 bits and the high result in the high 64 bits (this is convenient as both low and high are always multiplied by the same constant). They must be added before reduction.
|
||||
|
||||
The inputs for state[0] and state[1] are in the low 64 bits of `v20` and `v21`, respectively. The inputs and outputs for state[2..4], ..., state[10..12] are in `v22`, ..., `v26`, respectively.
|
||||
|
||||
`v30` and `v31` contains the constants [EPSILON, 1, 1 << 3, 1 << 5], [1 << 8, 1 << 10, 1 << 12, 1 << 16]. EPSILON is used in the reduction. The remaining constants are MDS matrix elements (except 0) and are used to form the dot products.
|
||||
|
||||
The instruction `umlal.2d v4, v20, v30[1]` can be read as:
|
||||
1. take the low 64 bits (`umlal2` for high 64 bits) of `v20` (state[0]),
|
||||
2. multiply the low and high 32 bits thereof by `v30[1]` (1),
|
||||
3. add the low and high product to the low and high 64-bits of `v4` respectively,
|
||||
4. save to `v4`.
|
||||
|
||||
We do not use `umlal` when the MDS coefficient is 1; instead, we use `uaddw` ("widening add") to reduce latency.
|
||||
@ -1,2 +1,5 @@
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub(crate) mod x86_64;
|
||||
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
pub(crate) mod aarch64;
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
mod arch;
|
||||
pub mod gmimc;
|
||||
pub mod hash_types;
|
||||
pub mod hashing;
|
||||
@ -8,5 +9,3 @@ pub mod poseidon;
|
||||
pub mod poseidon_crandall;
|
||||
pub mod poseidon_goldilocks;
|
||||
pub mod rescue;
|
||||
|
||||
mod arch;
|
||||
|
||||
@ -357,6 +357,14 @@ impl Poseidon<12> for GoldilocksField {
|
||||
crate::hash::arch::x86_64::poseidon_goldilocks_avx2_bmi2::poseidon(&input)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(target_arch="aarch64", target_feature="neon"))]
|
||||
#[inline]
|
||||
fn poseidon(input: [Self; 12]) -> [Self; 12] {
|
||||
unsafe {
|
||||
crate::hash::arch::aarch64::poseidon_goldilocks_neon::poseidon(input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user