Non-vector Poseidon speedups (#230)

This commit is contained in:
Jakub Nabaglo 2021-09-11 11:25:20 -07:00 committed by GitHub
parent ba8b40f0e6
commit c0e8edb899
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -162,7 +162,7 @@ where
res
}
#[inline]
#[inline(always)]
#[unroll_for_loops]
fn mds_layer(state_: &[Self; WIDTH]) -> [Self; WIDTH] {
let mut result = [Self::ZERO; WIDTH];
@ -183,15 +183,18 @@ where
result
}
/// Adds `FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]`, converted with
/// `from_canonical_u64`, to `state[i]` in place for every `i < WIDTH`.
///
/// # Panics
/// Panics if `WIDTH > 12` (see the `assert!` in the body).
#[inline]
#[inline(always)]
#[unroll_for_loops]
fn partial_first_constant_layer(state: &mut [Self; WIDTH]) {
for i in 0..WIDTH {
state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]);
// The loop below runs to the literal bound 12 and skips every index
// `i >= WIDTH`, so it only reaches all elements of `state` when
// `WIDTH <= 12`; the assert enforces that.
// NOTE(review): presumably the literal bound is what allows
// `#[unroll_for_loops]` to unroll this loop — confirm against the
// `unroll` crate documentation.
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]);
}
}
}
#[inline]
#[inline(always)]
#[unroll_for_loops]
fn mds_partial_layer_init(state: &[Self; WIDTH]) -> [Self; WIDTH] {
let mut result = [Self::ZERO; WIDTH];
@ -201,14 +204,21 @@ where
// c = 0
result[0] = state[0];
for c in 1..WIDTH {
for r in 1..WIDTH {
// NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in
// column-major order so that this dot product is cache
// friendly.
let t =
Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1]);
result[c] += state[r] * t;
assert!(WIDTH <= 12);
for c in 1..12 {
if c < WIDTH {
assert!(WIDTH <= 12);
for r in 1..12 {
if r < WIDTH {
// NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in
// column-major order so that this dot product is cache
// friendly.
let t = Self::from_canonical_u64(
Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[c - 1][r - 1],
);
result[c] += state[r] * t;
}
}
}
}
result
@ -222,37 +232,46 @@ where
///
/// M_00 is a scalar, v is 1x(t-1), w_hat is (t-1)x1 and Id is the
/// (t-1)x(t-1) identity matrix.
///
/// `r` selects the row of `FAST_PARTIAL_ROUND_W_HATS` and
/// `FAST_PARTIAL_ROUND_VS` that is used. `state` is only read; the
/// transformed state is returned as a new array.
///
/// # Panics
/// Panics if `WIDTH > 12` (see the `assert!`s in the body).
#[inline]
#[inline(always)]
#[unroll_for_loops]
fn mds_partial_layer_fast(state: &[Self; WIDTH], r: usize) -> [Self; WIDTH] {
// Set d = [M_00 | w^] dot [state]
// `state[0]` is widened to u128 before the shift; shifting left by
// `MDS_MATRIX_EXPS[0]` multiplies it by 2^MDS_MATRIX_EXPS[0].
// NOTE(review): presumably that power of two is the M_00 entry — confirm.
let s0 = state[0].to_noncanonical_u64() as u128;
let mut d = Self::from_noncanonical_u128(s0 << Self::MDS_MATRIX_EXPS[0]);
for i in 1..WIDTH {
let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d += state[i] * t;
// Accumulate state[i] * FAST_PARTIAL_ROUND_W_HATS[r][i - 1] into `d` for
// 1 <= i < WIDTH. The loop bound is the literal 12 with an `i < WIDTH`
// guard, which is only exhaustive when `WIDTH <= 12` (hence the assert).
assert!(WIDTH <= 12);
for i in 1..12 {
if i < WIDTH {
let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1]);
d += state[i] * t;
}
}
// result = [d] concat [state[0] * v + state[shift up by 1]]
// NOTE(review): the code below adds `state[i]` at the same index `i`
// (only the `FAST_PARTIAL_ROUND_VS` row is indexed with `i - 1`); no
// shift of `state` is visible here — confirm the wording above.
let mut result = [Self::ZERO; WIDTH];
result[0] = d;
for i in 1..WIDTH {
let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
result[i] = state[0] * t + state[i];
// Fill result[i] for 1 <= i < WIDTH; same literal-bound-plus-guard loop
// shape as above, so the same `WIDTH <= 12` requirement applies.
assert!(WIDTH <= 12);
for i in 1..12 {
if i < WIDTH {
let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]);
result[i] = state[0] * t + state[i];
}
}
result
}
/// Adds the constants for round `round_ctr` to `state` in place:
/// for every `i < WIDTH`, `state[i]` is increased by
/// `ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]` converted with
/// `from_canonical_u64`.
///
/// # Panics
/// Panics if `WIDTH > 12` (see the `assert!` in the body).
#[inline]
#[inline(always)]
#[unroll_for_loops]
fn constant_layer(state: &mut [Self; WIDTH], round_ctr: usize) {
for i in 0..WIDTH {
state[i] += Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]);
// The loop below iterates to the literal bound 12 and skips indices
// `i >= WIDTH`; it covers the whole state only when `WIDTH <= 12`.
// NOTE(review): presumably written with a literal bound so that
// `#[unroll_for_loops]` can unroll it — confirm against the `unroll`
// crate documentation.
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
// Round `round_ctr` reads the WIDTH consecutive constants that
// start at offset `WIDTH * round_ctr`.
state[i] += Self::from_canonical_u64(ALL_ROUND_CONSTANTS[i + WIDTH * round_ctr]);
}
}
}
#[inline]
#[inline(always)]
fn sbox_monomial(x: Self) -> Self {
// x |--> x^7
let x2 = x * x;
@ -261,16 +280,18 @@ where
x3 * x4
}
/// Applies `Self::sbox_monomial` to every element of `state` in place,
/// i.e. `state[i] = Self::sbox_monomial(state[i])` for every `i < WIDTH`.
///
/// # Panics
/// Panics if `WIDTH > 12` (see the `assert!` in the body).
#[inline]
#[inline(always)]
#[unroll_for_loops]
fn sbox_layer(state: &mut [Self; WIDTH]) {
for i in 0..WIDTH {
state[i] = Self::sbox_monomial(state[i]);
// The loop below iterates to the literal bound 12 and skips indices
// `i >= WIDTH`; it covers the whole state only when `WIDTH <= 12`.
// NOTE(review): presumably written with a literal bound so that
// `#[unroll_for_loops]` can unroll it — confirm against the `unroll`
// crate documentation.
assert!(WIDTH <= 12);
for i in 0..12 {
if i < WIDTH {
state[i] = Self::sbox_monomial(state[i]);
}
}
}
#[inline]
#[unroll_for_loops]
fn full_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) {
for _ in 0..HALF_N_FULL_ROUNDS {
Self::constant_layer(state, *round_ctr);
@ -281,7 +302,6 @@ where
}
#[inline]
#[unroll_for_loops]
fn partial_rounds_fast(state: &mut [Self; WIDTH], round_ctr: &mut usize) {
Self::partial_first_constant_layer(state);
*state = Self::mds_partial_layer_init(state);
@ -299,7 +319,6 @@ where
}
#[inline]
#[unroll_for_loops]
fn partial_rounds(state: &mut [Self; WIDTH], round_ctr: &mut usize) {
for _ in 0..N_PARTIAL_ROUNDS {
Self::constant_layer(state, *round_ctr);