// Patch overview (commentary placed before the first `diff --git` header; the hunks below are
// left byte-for-byte unchanged).
//
// File 1: src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
//   * Moves the three `_mm256_loadu_si256` loads out of `poseidon` into a private
//     `load_state`, and the three `_mm256_storeu_si256` stores into a private `store_state`.
//     Both are `#[inline(always)] unsafe fn`; `poseidon` now calls them instead.
//   * Adds three `pub unsafe fn` entry points built on those helpers:
//       - `constant_layer(state_arr, round_ctr)`: loads the state, takes `WIDTH` entries of
//         `ALL_ROUND_CONSTANTS` starting at `WIDTH * round_ctr`, applies `const_layer`, and
//         stores the result back into `state_arr`. The slice indexing and `.unwrap()` panic
//         if that range is not available.
//       - `sbox_layer(state_arr)`: applies `sbox_layer_full` and stores back into `state_arr`.
//       - `mds_layer(state)`: calls `mds_const_layers_full` with a pointer into
//         `FUSED_ROUND_CONSTANTS` at `WIDTH * (N_ROUNDS - 1)` paired with `0`, and returns the
//         result as a new array. Per the in-code comment, those last-round fused constants are
//         all zero, which is what makes this an MDS layer without a constant layer.
//
// File 2: src/hash/poseidon_goldilocks.rs
//   * Adds `constant_layer`, `sbox_layer` and `mds_layer` to `impl Poseidon<12> for
//     GoldilocksField`. Each is gated on
//     `cfg(all(target_arch="x86_64", target_feature="avx2", target_feature="bmi2"))`, marked
//     `#[inline(always)]`, and forwards to the matching function above inside an `unsafe` block.
//
// NOTE(review): the context comment in the first hunk spells "preceeding". It is deliberately
// not corrected here, because context lines must match the target file for the patch to apply;
// fix the spelling in the source file in a separate change.
diff --git a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs index 8503d5b2..1df21550 100644 --- a/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs +++ b/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs @@ -833,13 +833,25 @@ unsafe fn all_partial_rounds( state } -#[inline] -pub unsafe fn poseidon(state: &[GoldilocksField; 12]) -> [GoldilocksField; 12] { - let state = ( +#[inline(always)] +unsafe fn load_state(state: &[GoldilocksField; 12]) -> (__m256i, __m256i, __m256i) { + ( _mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()), _mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()), _mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()), - ); + ) +} + +#[inline(always)] +unsafe fn store_state(buf: &mut [GoldilocksField; 12], state: (__m256i, __m256i, __m256i)) { + _mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0); + _mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1); + _mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2); +} + +#[inline] +pub unsafe fn poseidon(state: &[GoldilocksField; 12]) -> [GoldilocksField; 12] { + let state = load_state(state); // The first constant layer must be done explicitly. The remaining constant layers are fused // with the preceeding MDS layer. 
@@ -850,9 +862,35 @@ pub unsafe fn poseidon(state: &[GoldilocksField; 12]) -> [GoldilocksField; 12] { let state = half_full_rounds(state, HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS); let mut res = [GoldilocksField::ZERO; 12]; - _mm256_storeu_si256((&mut res[0..4]).as_mut_ptr().cast::<__m256i>(), state.0); - _mm256_storeu_si256((&mut res[4..8]).as_mut_ptr().cast::<__m256i>(), state.1); - _mm256_storeu_si256((&mut res[8..12]).as_mut_ptr().cast::<__m256i>(), state.2); - + store_state(&mut res, state); + res +} + +#[inline(always)] +pub unsafe fn constant_layer(state_arr: &mut [GoldilocksField; WIDTH], round_ctr: usize) { + let state = load_state(state_arr); + let round_consts = &ALL_ROUND_CONSTANTS[WIDTH * round_ctr..][..WIDTH] + .try_into() + .unwrap(); + let state = const_layer(state, round_consts); + store_state(state_arr, state); +} + +#[inline(always)] +pub unsafe fn sbox_layer(state_arr: &mut [GoldilocksField; WIDTH]) { + let state = load_state(state_arr); + let state = sbox_layer_full(state); + store_state(state_arr, state); +} + +#[inline(always)] +pub unsafe fn mds_layer(state: &[GoldilocksField; WIDTH]) -> [GoldilocksField; WIDTH] { + let state = load_state(state); + // We want to do an MDS layer without the constant layer. + // The FUSED_ROUND_CONSTANTS for the last round are all 0 (shifted by 2**63 as required). 
+ let round_consts = FUSED_ROUND_CONSTANTS[WIDTH * (N_ROUNDS - 1)..].as_ptr(); + let state = mds_const_layers_full(state, (round_consts, 0)); + let mut res = [GoldilocksField::ZERO; 12]; + store_state(&mut res, state); res } diff --git a/src/hash/poseidon_goldilocks.rs b/src/hash/poseidon_goldilocks.rs index 6f01a15b..32d5e237 100644 --- a/src/hash/poseidon_goldilocks.rs +++ b/src/hash/poseidon_goldilocks.rs @@ -358,6 +358,30 @@ impl Poseidon<12> for GoldilocksField { } } + #[cfg(all(target_arch="x86_64", target_feature="avx2", target_feature="bmi2"))] + #[inline(always)] + fn constant_layer(state: &mut [Self; 12], round_ctr: usize) { + unsafe { + crate::hash::arch::x86_64::poseidon_goldilocks_avx2_bmi2::constant_layer(state, round_ctr); + } + } + + #[cfg(all(target_arch="x86_64", target_feature="avx2", target_feature="bmi2"))] + #[inline(always)] + fn sbox_layer(state: &mut [Self; 12]) { + unsafe { + crate::hash::arch::x86_64::poseidon_goldilocks_avx2_bmi2::sbox_layer(state); + } + } + + #[cfg(all(target_arch="x86_64", target_feature="avx2", target_feature="bmi2"))] + #[inline(always)] + fn mds_layer(state: &[Self; 12]) -> [Self; 12] { + unsafe { + crate::hash::arch::x86_64::poseidon_goldilocks_avx2_bmi2::mds_layer(state) + } + } + #[cfg(all(target_arch="aarch64", target_feature="neon"))] #[inline] fn poseidon(input: [Self; 12]) -> [Self; 12] {