Unroll a couple loops in Poseidon code (#215)

* Unroll a couple loops in Poseidon code (super hacky)

* Comments
This commit is contained in:
Daniel Lubarov 2021-09-03 21:42:40 -07:00 committed by GitHub
parent 032e2feeb4
commit ba4b03e487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -137,7 +137,7 @@ where
const FAST_PARTIAL_ROUND_W_HATS: [[u64; WIDTH - 1]; N_PARTIAL_ROUNDS];
const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; WIDTH - 1]; WIDTH - 1];
#[inline]
#[inline(always)]
#[unroll_for_loops]
fn mds_row_shf(r: usize, v: &[u64; WIDTH]) -> u128 {
debug_assert!(r < WIDTH);
@ -148,9 +148,15 @@ where
// NB: Unrolling this, calculating each term independently, and
// summing at the end, didn't improve performance for me.
let mut res = 0u128;
for i in 0..WIDTH {
res += (v[(i + r) % WIDTH] as u128) << Self::MDS_MATRIX_EXPS[i];
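// This is a hacky way of fully unrolling the loop: iterate up to the fixed
// literal bound 12 (instead of `WIDTH`) and skip the indices `i >= WIDTH`.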
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
res += (v[(i + r) % WIDTH] as u128) << Self::MDS_MATRIX_EXPS[i];
}
}
res
}
@ -164,9 +170,14 @@ where
state[r] = state_[r].to_noncanonical_u64();
}
for r in 0..WIDTH {
result[r] = Self::from_noncanonical_u128(Self::mds_row_shf(r, &state));
// This is a hacky way of fully unrolling the loop.
assert!(WIDTH <= 12);
for r in 0..12 {
if r < WIDTH {
result[r] = Self::from_noncanonical_u128(Self::mds_row_shf(r, &state));
}
}
result
}