diff --git a/src/hash/mod.rs b/src/hash/mod.rs
index dfb4d06c..2ba1f6be 100644
--- a/src/hash/mod.rs
+++ b/src/hash/mod.rs
@@ -5,3 +5,6 @@ pub mod merkle_proofs;
 pub mod merkle_tree;
 pub mod poseidon;
 pub mod rescue;
+
+#[cfg(target_feature = "avx2")]
+mod poseidon_avx2;
diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs
index ad2be350..677a101f 100644
--- a/src/hash/poseidon.rs
+++ b/src/hash/poseidon.rs
@@ -493,6 +493,12 @@ impl Poseidon<8> for CrandallField {
         [0xbc75b7bb6f92fb6b, 0x1d46b66c2ad3ef0c, 0x44ae739518db1d10, 0x3864e0e53027baf7,
          0x800fc4e2c9f585d8, 0xda6cfb436cf6973e, 0x3fc702a71c42c8df, ],
     ];
+
+    #[cfg(target_feature = "avx2")]
+    #[inline(always)]
+    fn mds_layer(state_: &[CrandallField; 8]) -> [CrandallField; 8] {
+        crate::hash::poseidon_avx2::crandall_poseidon8_mds_avx2(*state_)
+    }
 }
 
 #[rustfmt::skip]
@@ -698,6 +704,12 @@ impl Poseidon<12> for CrandallField {
         [0x857f31827fb3fe60, 0xfdb6ca0a6d5cc865, 0x7e60116e98d5e20c, 0x685ef5a6b9e241d3,
          0xe7ad8152c5d50bed, 0xb5d5efb12203ef9a, 0x8a041eb885fb24f5, ],
     ];
+
+    #[cfg(target_feature = "avx2")]
+    #[inline(always)]
+    fn mds_layer(state_: &[CrandallField; 12]) -> [CrandallField; 12] {
+        crate::hash::poseidon_avx2::crandall_poseidon12_mds_avx2(*state_)
+    }
 }
 
 #[cfg(test)]
diff --git a/src/hash/poseidon_avx2.rs b/src/hash/poseidon_avx2.rs
new file mode 100644
index 00000000..c2a617b5
--- /dev/null
+++ b/src/hash/poseidon_avx2.rs
@@ -0,0 +1,204 @@
+use core::arch::x86_64::*;
+
+use crate::field::crandall_field::CrandallField;
+use crate::field::field_types::PrimeField;
+
+const EPSILON: u64 = 0u64.wrapping_sub(CrandallField::ORDER);
+const SIGN_BIT: u64 = 1 << 63;
+
+const MDS_MATRIX_EXPS8: [i32; 8] = [2, 0, 1, 8, 4, 3, 0, 0];
+const MDS_MATRIX_EXPS12: [i32; 12] = [10, 13, 2, 0, 4, 1, 8, 7, 15, 5, 0, 0];
+
+/// Pair of vectors (hi, lo) representing four u128s, one per 64-bit lane.
+type Vecs128 = (__m256i, __m256i);
+
+/// Takes cumul (a u128 per lane) and x (a u64 per lane). Returns
+/// cumul + (x << SHIFT) as u128. Assumes the low half of cumul is offset by
+/// 1 << 63; the result is offset the same way.
+#[inline(always)]
+unsafe fn shift_and_accumulate<const SHIFT: i32>(
+    x: __m256i,
+    (hi_cumul, lo_cumul_s): Vecs128,
+) -> Vecs128
+where
+    [(); (64 - SHIFT) as usize]: ,
+{
+    let x_shifted_lo = _mm256_slli_epi64::<SHIFT>(x);
+    let x_shifted_hi = _mm256_srli_epi64::<{ 64 - SHIFT }>(x);
+    let res_lo_s = _mm256_add_epi64(lo_cumul_s, x_shifted_lo);
+    // AVX2 has no unsigned 64-bit compare; the 1 << 63 offset makes this
+    // signed compare behave as an unsigned one, detecting low-word wraparound.
+    let carry = _mm256_cmpgt_epi64(lo_cumul_s, res_lo_s);
+    // carry is -1 (all ones) in lanes that wrapped, so subtracting it adds 1.
+    let res_hi = _mm256_sub_epi64(_mm256_add_epi64(hi_cumul, x_shifted_hi), carry);
+    (res_hi, res_lo_s)
+}
+
+/// Extract state[OFFSET..OFFSET + 4] as a vector. Wraps around the boundary.
+#[inline(always)]
+unsafe fn get_vector_with_offset<const WIDTH: usize, const OFFSET: usize>(
+    state: [CrandallField; WIDTH],
+) -> __m256i {
+    _mm256_setr_epi64x(
+        state[OFFSET % WIDTH].0 as i64,
+        state[(OFFSET + 1) % WIDTH].0 as i64,
+        state[(OFFSET + 2) % WIDTH].0 as i64,
+        state[(OFFSET + 3) % WIDTH].0 as i64,
+    )
+}
+
+/// Extract one CrandallField element from a vector.
+#[inline(always)]
+unsafe fn extract<const INDEX: i32>(v: __m256i) -> CrandallField {
+    CrandallField(_mm256_extract_epi64::<INDEX>(v) as u64)
+}
+
+#[inline(always)]
+unsafe fn iteration8<const INDEX: usize, const SHIFT: i32>(
+    [cumul0_s, cumul1_s]: [Vecs128; 2],
+    state: [CrandallField; 8],
+) -> [Vecs128; 2]
+// 2 vectors of 4 lanes are needed to represent the entire state.
+where
+    [(); { INDEX + 4 }]: ,
+    [(); (64 - SHIFT) as usize]: ,
+{
+    // Entire state, rotated by INDEX.
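+    // get_vector_with_offset wraps indices modulo WIDTH, so state0 and
+    // state1 together hold all eight elements, rotated left by INDEX.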
+    let state0 = get_vector_with_offset::<8, INDEX>(state);
+    let state1 = get_vector_with_offset::<8, { INDEX + 4 }>(state);
+    [
+        shift_and_accumulate::<SHIFT>(state0, cumul0_s),
+        shift_and_accumulate::<SHIFT>(state1, cumul1_s),
+    ]
+}
+
+#[inline(always)]
+pub fn crandall_poseidon8_mds_avx2(state: [CrandallField; 8]) -> [CrandallField; 8] {
+    unsafe {
+        let mut res_s = [(_mm256_setzero_si256(), _mm256_set1_epi64x(SIGN_BIT as i64)); 2];
+
+        // The scalar loop goes:
+        //     for r in 0..WIDTH {
+        //         let mut res = 0u128;
+        //         for i in 0..WIDTH {
+        //             res += (state[(i + r) % WIDTH] as u128) << MDS_MATRIX_EXPS[i];
+        //         }
+        //         result[r] = reduce(res);
+        //     }
+        //
+        // Here, we swap the loops. Equivalent to:
+        //     let mut res = [0u128; WIDTH];
+        //     for i in 0..WIDTH {
+        //         let mds_matrix_exp = MDS_MATRIX_EXPS[i];
+        //         for r in 0..WIDTH {
+        //             res[r] += (state[(i + r) % WIDTH] as u128) << mds_matrix_exp;
+        //         }
+        //     }
+        //     for r in 0..WIDTH {
+        //         result[r] = reduce(res[r]);
+        //     }
+        //
+        // Notice that in the second version, every iteration of the inner loop
+        // shifts by the same amount. We vectorize that inner loop, performing
+        // four iterations at once; vector shifts are cheaper when all lanes
+        // are shifted by the same amount.
+
+        res_s = iteration8::<0, { MDS_MATRIX_EXPS8[0] }>(res_s, state);
+        res_s = iteration8::<1, { MDS_MATRIX_EXPS8[1] }>(res_s, state);
+        res_s = iteration8::<2, { MDS_MATRIX_EXPS8[2] }>(res_s, state);
+        res_s = iteration8::<3, { MDS_MATRIX_EXPS8[3] }>(res_s, state);
+        res_s = iteration8::<4, { MDS_MATRIX_EXPS8[4] }>(res_s, state);
+        res_s = iteration8::<5, { MDS_MATRIX_EXPS8[5] }>(res_s, state);
+        res_s = iteration8::<6, { MDS_MATRIX_EXPS8[6] }>(res_s, state);
+        res_s = iteration8::<7, { MDS_MATRIX_EXPS8[7] }>(res_s, state);
+
+        let [res0_s, res1_s] = res_s;
+        let reduced0 = reduce96s(res0_s);
+        let reduced1 = reduce96s(res1_s);
+        [
+            extract::<0>(reduced0),
+            extract::<1>(reduced0),
+            extract::<2>(reduced0),
+            extract::<3>(reduced0),
+            extract::<0>(reduced1),
+            extract::<1>(reduced1),
+            extract::<2>(reduced1),
+            extract::<3>(reduced1),
+        ]
+    }
+}
+
+#[inline(always)]
+unsafe fn iteration12<const INDEX: usize, const SHIFT: i32>(
+    [cumul0_s, cumul1_s, cumul2_s]: [Vecs128; 3],
+    state: [CrandallField; 12],
+) -> [Vecs128; 3]
+// 3 vectors of 4 lanes are needed to represent the entire state.
+where
+    [(); { INDEX + 4 }]: ,
+    [(); { INDEX + 8 }]: ,
+    [(); (64 - SHIFT) as usize]: ,
+{
+    // Entire state, rotated by INDEX.
+    let state0 = get_vector_with_offset::<12, INDEX>(state);
+    let state1 = get_vector_with_offset::<12, { INDEX + 4 }>(state);
+    let state2 = get_vector_with_offset::<12, { INDEX + 8 }>(state);
+    [
+        shift_and_accumulate::<SHIFT>(state0, cumul0_s),
+        shift_and_accumulate::<SHIFT>(state1, cumul1_s),
+        shift_and_accumulate::<SHIFT>(state2, cumul2_s),
+    ]
+}
+
+#[inline(always)]
+pub fn crandall_poseidon12_mds_avx2(state: [CrandallField; 12]) -> [CrandallField; 12] {
+    unsafe {
+        let mut res_s = [(_mm256_setzero_si256(), _mm256_set1_epi64x(SIGN_BIT as i64)); 3];
+
+        // See the width-8 version above for an explanation.
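+        // The only difference is that the width-12 state occupies three
+        // four-lane vectors instead of two.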
+
+        res_s = iteration12::<0, { MDS_MATRIX_EXPS12[0] }>(res_s, state);
+        res_s = iteration12::<1, { MDS_MATRIX_EXPS12[1] }>(res_s, state);
+        res_s = iteration12::<2, { MDS_MATRIX_EXPS12[2] }>(res_s, state);
+        res_s = iteration12::<3, { MDS_MATRIX_EXPS12[3] }>(res_s, state);
+        res_s = iteration12::<4, { MDS_MATRIX_EXPS12[4] }>(res_s, state);
+        res_s = iteration12::<5, { MDS_MATRIX_EXPS12[5] }>(res_s, state);
+        res_s = iteration12::<6, { MDS_MATRIX_EXPS12[6] }>(res_s, state);
+        res_s = iteration12::<7, { MDS_MATRIX_EXPS12[7] }>(res_s, state);
+        res_s = iteration12::<8, { MDS_MATRIX_EXPS12[8] }>(res_s, state);
+        res_s = iteration12::<9, { MDS_MATRIX_EXPS12[9] }>(res_s, state);
+        res_s = iteration12::<10, { MDS_MATRIX_EXPS12[10] }>(res_s, state);
+        res_s = iteration12::<11, { MDS_MATRIX_EXPS12[11] }>(res_s, state);
+
+        let [res0_s, res1_s, res2_s] = res_s;
+        let reduced0 = reduce96s(res0_s);
+        let reduced1 = reduce96s(res1_s);
+        let reduced2 = reduce96s(res2_s);
+        [
+            extract::<0>(reduced0),
+            extract::<1>(reduced0),
+            extract::<2>(reduced0),
+            extract::<3>(reduced0),
+            extract::<0>(reduced1),
+            extract::<1>(reduced1),
+            extract::<2>(reduced1),
+            extract::<3>(reduced1),
+            extract::<0>(reduced2),
+            extract::<1>(reduced2),
+            extract::<2>(reduced2),
+            extract::<3>(reduced2),
+        ]
+    }
+}
+
+/// Reduce a 96-bit value (hi < 2^32, lo offset by 1 << 63) to a 64-bit
+/// representative, which is not necessarily canonical. Uses the fact that
+/// hi * 2^64 is congruent to hi * EPSILON, which fits in 64 bits.
+#[inline(always)]
+unsafe fn reduce96s(x_s: Vecs128) -> __m256i {
+    let (hi0, lo0_s) = x_s;
+    let lo1 = _mm256_mul_epu32(hi0, _mm256_set1_epi64x(EPSILON as i64));
+    add_no_canonicalize_64_64s(lo1, lo0_s)
+}
+
+/// Add x (u64) and y (u64, offset by 1 << 63). If the sum wraps past 2^64,
+/// fold the loss back in by adding EPSILON (= 2^64 mod the field order).
+#[inline(always)]
+unsafe fn add_no_canonicalize_64_64s(x: __m256i, y_s: __m256i) -> __m256i {
+    let res_wrapped_s = _mm256_add_epi64(x, y_s);
+    // Signed compare acts as unsigned thanks to the 1 << 63 offset on y_s.
+    let mask = _mm256_cmpgt_epi64(y_s, res_wrapped_s);
+    // Remove the offset from the result.
+    let res_wrapped = _mm256_xor_si256(res_wrapped_s, _mm256_set1_epi64x(SIGN_BIT as i64));
+    let wrapback_amt = _mm256_and_si256(mask, _mm256_set1_epi64x(EPSILON as i64));
+    _mm256_add_epi64(res_wrapped, wrapback_amt)
+}
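
Note (not part of the patch): for reference, below is a scalar model of the 96-bit reduction that reduce96s and add_no_canonicalize_64_64s vectorize. It is an illustrative sketch; the name reduce96 is hypothetical, and it assumes CrandallField::ORDER is the Crandall prime 2^64 - 9 * 2^28 + 1.

// Scalar sketch of the vectorized 96-bit reduction (hypothetical names).
const ORDER: u64 = 18446744071293632513; // assumed: 2^64 - 9 * 2^28 + 1
const EPSILON: u64 = 0u64.wrapping_sub(ORDER); // 2^64 mod ORDER

/// Reduce (hi, lo) with hi < 2^32, representing hi * 2^64 + lo, to a 64-bit
/// representative modulo ORDER. Like the vector code, the result may be
/// non-canonical (i.e. it may exceed ORDER, but never 2^64).
fn reduce96(hi: u64, lo: u64) -> u64 {
    // hi * 2^64 ≡ hi * EPSILON (mod ORDER); both factors are below 2^32,
    // so the product fits in a u64 without overflow.
    let lo1 = hi * EPSILON;
    // Add, and if the sum wraps past 2^64, compensate by adding EPSILON,
    // since 2^64 ≡ EPSILON (mod ORDER).
    let (res, wrapped) = lo.overflowing_add(lo1);
    res.wrapping_add(if wrapped { EPSILON } else { 0 })
}

fn main() {
    // Sanity check against direct u128 arithmetic.
    let (hi, lo) = (0x1234u64, 0xfedc_ba98_7654_3210u64);
    let expected = ((((hi as u128) << 64) | lo as u128) % ORDER as u128) as u64;
    assert_eq!(reduce96(hi, lo) % ORDER, expected);
}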