diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index d5cb13a0..5bd40eab 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -519,6 +519,12 @@ impl Poseidon<8> for CrandallField { fn mds_layer(state_: &[CrandallField; 8]) -> [CrandallField; 8] { crate::hash::poseidon_neon::crandall_poseidon8_mds_neon(*state_) } + + #[cfg(target_feature="avx2")] + #[inline(always)] + fn sbox_layer(state: &mut [Self; 8]) { + crate::hash::poseidon_avx2::crandall_poseidon_sbox_avx2::<2>(state); + } } #[rustfmt::skip] @@ -744,6 +750,12 @@ impl Poseidon<12> for CrandallField { fn mds_layer(state_: &[CrandallField; 12]) -> [CrandallField; 12] { crate::hash::poseidon_neon::crandall_poseidon12_mds_neon(*state_) } + + #[cfg(target_feature="avx2")] + #[inline(always)] + fn sbox_layer(state: &mut [Self; 12]) { + crate::hash::poseidon_avx2::crandall_poseidon_sbox_avx2::<3>(state); + } } #[cfg(test)] diff --git a/src/hash/poseidon_avx2.rs b/src/hash/poseidon_avx2.rs index 643317f8..cc4a3d1d 100644 --- a/src/hash/poseidon_avx2.rs +++ b/src/hash/poseidon_avx2.rs @@ -220,3 +220,28 @@ pub unsafe fn crandall_poseidon_const_avx2( packed_state[i] = packed_state[i].add_canonical_u64(packed_round_constants[i]); } } + +#[inline(always)] +pub fn crandall_poseidon_sbox_avx2( + state: &mut [CrandallField; 4 * PACKED_WIDTH], +) { + // This function is manually interleaved to maximize instruction-level parallelism. + + let packed_state = PackedCrandallAVX2::pack_slice_mut(state); + + let mut x2 = [PackedCrandallAVX2::zero(); PACKED_WIDTH]; + for i in 0..PACKED_WIDTH { + x2[i] = packed_state[i].square(); + } + + let mut x3 = [PackedCrandallAVX2::zero(); PACKED_WIDTH]; + let mut x4 = [PackedCrandallAVX2::zero(); PACKED_WIDTH]; + for i in 0..PACKED_WIDTH { + x3[i] = packed_state[i] * x2[i]; + x4[i] = x2[i].square(); + } + + for i in 0..PACKED_WIDTH { + packed_state[i] = x3[i] * x4[i]; + } +}