diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index 3ebdd802..6d96d355 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -137,7 +137,7 @@ where const FAST_PARTIAL_ROUND_W_HATS: [[u64; WIDTH - 1]; N_PARTIAL_ROUNDS]; const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; WIDTH - 1]; WIDTH - 1]; - #[inline] + #[inline(always)] #[unroll_for_loops] fn mds_row_shf(r: usize, v: &[u64; WIDTH]) -> u128 { debug_assert!(r < WIDTH); @@ -148,9 +148,15 @@ where // NB: Unrolling this, calculating each term independently, and // summing at the end, didn't improve performance for me. let mut res = 0u128; - for i in 0..WIDTH { - res += (v[(i + r) % WIDTH] as u128) << Self::MDS_MATRIX_EXPS[i]; + + // This is a hacky way of fully unrolling the loop. + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + res += (v[(i + r) % WIDTH] as u128) << Self::MDS_MATRIX_EXPS[i]; + } } + res } @@ -164,9 +170,14 @@ where state[r] = state_[r].to_noncanonical_u64(); } - for r in 0..WIDTH { - result[r] = Self::from_noncanonical_u128(Self::mds_row_shf(r, &state)); + // This is a hacky way of fully unrolling the loop. + assert!(WIDTH <= 12); + for r in 0..12 { + if r < WIDTH { + result[r] = Self::from_noncanonical_u128(Self::mds_row_shf(r, &state)); + } } + result }