diff --git a/src/hash/hashing.rs b/src/hash/hashing.rs index c60098be..d031ebbb 100644 --- a/src/hash/hashing.rs +++ b/src/hash/hashing.rs @@ -141,6 +141,6 @@ pub fn hash_n_to_1(inputs: Vec, pad: bool) -> F { pub(crate) fn permute(inputs: [F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] { match HASH_FAMILY { HashFamily::GMiMC => F::gmimc_permute(inputs), - HashFamily::Poseidon => F::poseidon_naive(inputs), + HashFamily::Poseidon => F::poseidon(inputs), } } diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index cb117907..8da878aa 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -549,7 +549,7 @@ where } #[inline] - fn partial_rounds_fast(state: &mut [Self; WIDTH], round_ctr: &mut usize) { + fn partial_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) { Self::partial_first_constant_layer(state); *state = Self::mds_partial_layer_init(state); @@ -564,7 +564,21 @@ where } #[inline] - fn partial_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) { + fn poseidon(input: [Self; WIDTH]) -> [Self; WIDTH] { + let mut state = input; + let mut round_ctr = 0; + + Self::full_rounds(&mut state, &mut round_ctr); + Self::partial_rounds(&mut state, &mut round_ctr); + Self::full_rounds(&mut state, &mut round_ctr); + debug_assert_eq!(round_ctr, N_ROUNDS); + + state + } + + // For testing only, to ensure that various tricks are correct. + #[inline] + fn partial_rounds_naive(state: &mut [Self; WIDTH], round_ctr: &mut usize) { for _ in 0..N_PARTIAL_ROUNDS { Self::constant_layer(state, *round_ctr); state[0] = Self::sbox_monomial(state[0]); @@ -573,26 +587,13 @@ where } } - #[inline] - fn poseidon(input: [Self; WIDTH]) -> [Self; WIDTH] { - let mut state = input; - let mut round_ctr = 0; - - Self::full_rounds(&mut state, &mut round_ctr); - Self::partial_rounds_fast(&mut state, &mut round_ctr); - Self::full_rounds(&mut state, &mut round_ctr); - debug_assert_eq!(round_ctr, N_ROUNDS); - - state - } - #[inline] fn poseidon_naive(input: [Self; WIDTH]) -> [Self; WIDTH] { let mut state = input; let mut round_ctr = 0; Self::full_rounds(&mut state, &mut round_ctr); - Self::partial_rounds(&mut state, &mut round_ctr); + Self::partial_rounds_naive(&mut state, &mut round_ctr); Self::full_rounds(&mut state, &mut round_ctr); debug_assert_eq!(round_ctr, N_ROUNDS); diff --git a/src/hash/poseidon_goldilocks.rs b/src/hash/poseidon_goldilocks.rs index 04e349fe..8fd01655 100644 --- a/src/hash/poseidon_goldilocks.rs +++ b/src/hash/poseidon_goldilocks.rs @@ -352,7 +352,7 @@ impl Poseidon<12> for GoldilocksField { #[cfg(all(target_arch="x86_64", target_feature="avx2", target_feature="bmi2"))] #[inline] - fn poseidon_naive(input: [Self; 12]) -> [Self; 12] { + fn poseidon(input: [Self; 12]) -> [Self; 12] { unsafe { crate::hash::arch::x86_64::poseidon_goldilocks_avx2_bmi2::poseidon(&input) }