diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs index e3a8cc5d..ad2be350 100644 --- a/src/hash/poseidon.rs +++ b/src/hash/poseidon.rs @@ -162,7 +162,7 @@ where res } - #[inline] + #[inline(always)] #[unroll_for_loops] fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] { let mut result = [Self::ZERO; WIDTH]; @@ -183,15 +183,18 @@ where result } - #[inline] + #[inline(always)] #[unroll_for_loops] fn partial_first_constant_layer(state: &mut [Self; WIDTH]) { - for i in 0..WIDTH { - state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]); + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]); + } } } - #[inline] + #[inline(always)] #[unroll_for_loops] fn mds_partial_layer_init(state: &[Self; WIDTH]) -> [Self; WIDTH] { let mut result = [Self::ZERO; WIDTH]; @@ -201,14 +204,21 @@ where // c = 0 result[0] = state[0]; - for c in 1..WIDTH { - for r in 1..WIDTH { - // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in - // column-major order so that this dot product is cache - // friendly. - let t = - Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1]); - result[c] += state[r] * t; + assert!(WIDTH <= 12); + for c in 1..12 { + if c < WIDTH { + assert!(WIDTH <= 12); + for r in 1..12 { + if r < WIDTH { + // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in + // column-major order so that this dot product is cache + // friendly. + let t = Self::from_canonical_u64( + Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1], + ); + result[c] += state[r] * t; + } + } } } result @@ -222,37 +232,46 @@ where /// /// M_00 is a scalar, v is 1x(t-1), w_hat is (t-1)x1 and Id is the /// (t-1)x(t-1) identity matrix. - #[inline] + #[inline(always)] #[unroll_for_loops] fn mds_partial_layer_fast(state: &[Self; WIDTH], r: usize) -> [Self; WIDTH] { // Set d = [M_00 | w^] dot [state] let s0 = state[0].to_noncanonical_u64() as u128; let mut d = Self::from_noncanonical_u128(s0 << Self::MDS_MATRIX_EXPS[0]); - for i in 1..WIDTH { - let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); - d += state[i] * t; + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]); + d += state[i] * t; + } } // result = [d] concat [state[0] * v + state[shift up by 1]] let mut result = [Self::ZERO; WIDTH]; result[0] = d; - for i in 1..WIDTH { - let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); - result[i] = state[0] * t + state[i]; + assert!(WIDTH <= 12); + for i in 1..12 { + if i < WIDTH { + let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); + result[i] = state[0] * t + state[i]; + } } result } - #[inline] + #[inline(always)] #[unroll_for_loops] fn constant_layer(state: &mut [Self; WIDTH], round_ctr: usize) { - for i in 0..WIDTH { - state[i] += Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]); + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] += Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]); + } } } - #[inline] + #[inline(always)] fn sbox_monomial(x: Self) -> Self { // x |--> x^7 let x2 = x * x; @@ -261,16 +280,18 @@ where x3 * x4 } - #[inline] + #[inline(always)] #[unroll_for_loops] fn sbox_layer(state: &mut [Self; WIDTH]) { - for i in 0..WIDTH { - state[i] = Self::sbox_monomial(state[i]); + assert!(WIDTH <= 12); + for i in 0..12 { + if i < WIDTH { + state[i] = Self::sbox_monomial(state[i]); + } } } #[inline] - #[unroll_for_loops] fn full_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) { for _ in 0..HALF_N_FULL_ROUNDS { Self::constant_layer(state, *round_ctr); @@ -281,7 +302,6 @@ where } #[inline] - #[unroll_for_loops] fn partial_rounds_fast(state: &mut [Self; WIDTH], round_ctr: &mut usize) { Self::partial_first_constant_layer(state); *state = Self::mds_partial_layer_init(state); @@ -299,7 +319,6 @@ where } #[inline] - #[unroll_for_loops] fn partial_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) { for _ in 0..N_PARTIAL_ROUNDS { Self::constant_layer(state, *round_ctr);