From a7cd1ef40bb7b49399370e3f84c968170993b600 Mon Sep 17 00:00:00 2001
From: Jakub Nabaglo
Date: Tue, 14 Sep 2021 21:37:07 -0700
Subject: [PATCH] Vectorize Poseidon constant layer with NEON (#245)

* Start work on Crandall arithmetic in NEON

* Poseidon constant layer in NEON

* its alive

Co-authored-by: Jakub Nabaglo
---
 src/field/mod.rs                  |   3 +
 src/field/packed_crandall_neon.rs | 251 ++++++++++++++++++++++++++++++
 src/hash/poseidon.rs              |  18 ++-
 src/hash/poseidon_neon.rs         |  17 ++
 4 files changed, 288 insertions(+), 1 deletion(-)
 create mode 100644 src/field/packed_crandall_neon.rs

diff --git a/src/field/mod.rs b/src/field/mod.rs
index 2b81774f..60948d48 100644
--- a/src/field/mod.rs
+++ b/src/field/mod.rs
@@ -12,6 +12,9 @@ pub(crate) mod packed_field;
 #[cfg(target_feature = "avx2")]
 pub(crate) mod packed_avx2;
 
+#[cfg(target_feature = "neon")]
+pub(crate) mod packed_crandall_neon;
+
 #[cfg(test)]
 mod field_testing;
 #[cfg(test)]
diff --git a/src/field/packed_crandall_neon.rs b/src/field/packed_crandall_neon.rs
new file mode 100644
index 00000000..4df9acaa
--- /dev/null
+++ b/src/field/packed_crandall_neon.rs
@@ -0,0 +1,251 @@
+use core::arch::aarch64::*;
+use std::convert::TryInto;
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::iter::{Product, Sum};
+use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
+
+use crate::field::crandall_field::CrandallField;
+use crate::field::field_types::PrimeField;
+use crate::field::packed_field::PackedField;
+
+/// PackedCrandallNeon wraps [CrandallField; 2] (rather than a uint64x2_t) to ensure that Rust
+/// does not assume 16-byte alignment. Similar to AVX2's PackedPrimeField. I don't think it
+/// matters as much on ARM, but incorrectly-aligned pointers are undefined behavior in Rust, so
+/// let's avoid them.
+#[derive(Copy, Clone)]
+#[repr(transparent)]
+pub struct PackedCrandallNeon(pub [CrandallField; 2]);
+
+impl PackedCrandallNeon {
+    #[inline]
+    fn new(x: uint64x2_t) -> Self {
+        let x0 = unsafe { vgetq_lane_u64::<0>(x) };
+        let x1 = unsafe { vgetq_lane_u64::<1>(x) };
+        Self([CrandallField(x0), CrandallField(x1)])
+    }
+    #[inline]
+    fn get(&self) -> uint64x2_t {
+        let x0 = self.0[0].0;
+        let x1 = self.0[1].0;
+        unsafe { vcombine_u64(vmov_n_u64(x0), vmov_n_u64(x1)) }
+    }
+
+    /// Addition that assumes x + y < 2^64 + F::ORDER. May return incorrect results if this
+    /// condition is not met, hence it is marked unsafe.
+    #[inline]
+    pub unsafe fn add_canonical_u64(&self, rhs: uint64x2_t) -> Self {
+        Self::new(add_canonical_u64(self.get(), rhs))
+    }
+}
+
+impl Add<Self> for PackedCrandallNeon {
+    type Output = Self;
+    #[inline]
+    fn add(self, rhs: Self) -> Self {
+        Self::new(unsafe { add(self.get(), rhs.get()) })
+    }
+}
+impl Add<CrandallField> for PackedCrandallNeon {
+    type Output = Self;
+    #[inline]
+    fn add(self, rhs: CrandallField) -> Self {
+        self + Self::broadcast(rhs)
+    }
+}
+impl AddAssign<Self> for PackedCrandallNeon {
+    #[inline]
+    fn add_assign(&mut self, rhs: Self) {
+        *self = *self + rhs;
+    }
+}
+impl AddAssign<CrandallField> for PackedCrandallNeon {
+    #[inline]
+    fn add_assign(&mut self, rhs: CrandallField) {
+        *self = *self + rhs;
+    }
+}
+
+impl Debug for PackedCrandallNeon {
+    #[inline]
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(f, "({:?})", self.get())
+    }
+}
+
+impl Default for PackedCrandallNeon {
+    #[inline]
+    fn default() -> Self {
+        Self::zero()
+    }
+}
+
+impl Mul<Self> for PackedCrandallNeon {
+    type Output = Self;
+    #[inline]
+    fn mul(self, rhs: Self) -> Self {
+        // TODO: Implement.
+        // Do this in scalar for now.
+        Self([self.0[0] * rhs.0[0], self.0[1] * rhs.0[1]])
+    }
+}
+impl Mul<CrandallField> for PackedCrandallNeon {
+    type Output = Self;
+    #[inline]
+    fn mul(self, rhs: CrandallField) -> Self {
+        self * Self::broadcast(rhs)
+    }
+}
+impl MulAssign<Self> for PackedCrandallNeon {
+    #[inline]
+    fn mul_assign(&mut self, rhs: Self) {
+        *self = *self * rhs;
+    }
+}
+impl MulAssign<CrandallField> for PackedCrandallNeon {
+    #[inline]
+    fn mul_assign(&mut self, rhs: CrandallField) {
+        *self = *self * rhs;
+    }
+}
+
+impl Neg for PackedCrandallNeon {
+    type Output = Self;
+    #[inline]
+    fn neg(self) -> Self {
+        Self::new(unsafe { neg(self.get()) })
+    }
+}
+
+impl Product for PackedCrandallNeon {
+    #[inline]
+    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
+        iter.reduce(|x, y| x * y).unwrap_or(Self::one())
+    }
+}
+
+impl PackedField for PackedCrandallNeon {
+    const LOG2_WIDTH: usize = 1;
+
+    type FieldType = CrandallField;
+
+    #[inline]
+    fn broadcast(x: CrandallField) -> Self {
+        Self([x; 2])
+    }
+
+    #[inline]
+    fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self {
+        Self(arr)
+    }
+
+    #[inline]
+    fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] {
+        self.0
+    }
+
+    #[inline]
+    fn from_slice(slice: &[CrandallField]) -> Self {
+        Self(slice.try_into().unwrap())
+    }
+
+    #[inline]
+    fn to_vec(&self) -> Vec<Self::FieldType> {
+        self.0.into()
+    }
+
+    #[inline]
+    fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
+        let (v0, v1) = (self.get(), other.get());
+        let (res0, res1) = match r {
+            0 => unsafe { interleave0(v0, v1) },
+            1 => (v0, v1),
+            _ => panic!("r cannot be more than LOG2_WIDTH"),
+        };
+        (Self::new(res0), Self::new(res1))
+    }
+}
+
+impl Sub<Self> for PackedCrandallNeon {
+    type Output = Self;
+    #[inline]
+    fn sub(self, rhs: Self) -> Self {
+        Self::new(unsafe { sub(self.get(), rhs.get()) })
+    }
+}
+impl Sub<CrandallField> for PackedCrandallNeon {
+    type Output = Self;
+    #[inline]
+    fn sub(self, rhs: CrandallField) -> Self {
+        self - Self::broadcast(rhs)
+    }
+}
+impl SubAssign<Self> for PackedCrandallNeon {
+    #[inline]
+    fn sub_assign(&mut self, rhs: Self) {
+        *self = *self - rhs;
+    }
+}
+impl SubAssign<CrandallField> for PackedCrandallNeon {
+    #[inline]
+    fn sub_assign(&mut self, rhs: CrandallField) {
+        *self = *self - rhs;
+    }
+}
+
+impl Sum for PackedCrandallNeon {
+    #[inline]
+    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
+        iter.reduce(|x, y| x + y).unwrap_or(Self::zero())
+    }
+}
+
+const FIELD_ORDER: u64 = CrandallField::ORDER;
+
+#[inline]
+unsafe fn field_order() -> uint64x2_t {
+    vmovq_n_u64(FIELD_ORDER)
+}
+
+#[inline]
+unsafe fn canonicalize(x: uint64x2_t) -> uint64x2_t {
+    let mask = vcgeq_u64(x, field_order()); // Mask is -1 if x >= FIELD_ORDER.
+    let x_maybe_unwrapped = vsubq_u64(x, field_order());
+    vbslq_u64(mask, x_maybe_unwrapped, x) // Bitwise select
+}
+
+#[inline]
+unsafe fn add_no_canonicalize_64_64(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t {
+    let res_wrapped = vaddq_u64(x, y);
+    let mask = vcgtq_u64(y, res_wrapped); // Mask is -1 if overflow.
+    let res_maybe_unwrapped = vsubq_u64(res_wrapped, field_order());
+    vbslq_u64(mask, res_maybe_unwrapped, res_wrapped) // Bitwise select
+}
+
+#[inline]
+unsafe fn add_canonical_u64(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t {
+    add_no_canonicalize_64_64(x, y)
+}
+
+#[inline]
+unsafe fn add(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t {
+    add_no_canonicalize_64_64(x, canonicalize(y))
+}
+
+#[inline]
+unsafe fn sub(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t {
+    let y = canonicalize(y);
+    let mask = vcgtq_u64(y, x); // Mask is -1 if overflow.
+    let res_wrapped = vsubq_u64(x, y);
+    let res_maybe_unwrapped = vaddq_u64(res_wrapped, field_order());
+    vbslq_u64(mask, res_maybe_unwrapped, res_wrapped) // Bitwise select
+}
+
+#[inline]
+unsafe fn neg(y: uint64x2_t) -> uint64x2_t {
+    vsubq_u64(field_order(), canonicalize(y))
+}
+
+#[inline]
+unsafe fn interleave0(x: uint64x2_t, y: uint64x2_t) -> (uint64x2_t, uint64x2_t) {
+    (vtrn1q_u64(x, y), vtrn2q_u64(x, y))
+}
diff --git a/src/hash/poseidon.rs b/src/hash/poseidon.rs
index 5bd40eab..b3d809f2 100644
--- a/src/hash/poseidon.rs
+++ b/src/hash/poseidon.rs
@@ -1,7 +1,7 @@
 //! Implementation of the Poseidon hash function, as described in
 //! https://eprint.iacr.org/2019/458.pdf
 
-#[cfg(target_feature = "avx2")]
+#[cfg(any(target_feature = "avx2", target_feature = "neon"))]
 use std::convert::TryInto;
 
 use unroll::unroll_for_loops;
@@ -508,6 +508,14 @@ impl Poseidon<8> for CrandallField {
             ALL_ROUND_CONSTANTS[8 * round_ctr..8 * round_ctr + 8].try_into().unwrap()); }
     }
 
+    #[cfg(target_feature="neon")]
+    #[inline(always)]
+    fn constant_layer(state: &mut [Self; 8], round_ctr: usize) {
+        // This assumes that every element of ALL_ROUND_CONSTANTS is in 0..CrandallField::ORDER.
+        unsafe { crate::hash::poseidon_neon::crandall_poseidon_const_neon::<4>(state,
+            ALL_ROUND_CONSTANTS[8 * round_ctr..8 * round_ctr + 8].try_into().unwrap()); }
+    }
+
     #[cfg(target_feature="avx2")]
     #[inline(always)]
     fn mds_layer(state_: &[CrandallField; 8]) -> [CrandallField; 8] {
@@ -739,6 +747,14 @@ impl Poseidon<12> for CrandallField {
             ALL_ROUND_CONSTANTS[12 * round_ctr..12 * round_ctr + 12].try_into().unwrap()); }
     }
 
+    #[cfg(target_feature="neon")]
+    #[inline(always)]
+    fn constant_layer(state: &mut [Self; 12], round_ctr: usize) {
+        // This assumes that every element of ALL_ROUND_CONSTANTS is in 0..CrandallField::ORDER.
+        unsafe { crate::hash::poseidon_neon::crandall_poseidon_const_neon::<6>(state,
+            ALL_ROUND_CONSTANTS[12 * round_ctr..12 * round_ctr + 12].try_into().unwrap()); }
+    }
+
     #[cfg(target_feature="avx2")]
     #[inline(always)]
     fn mds_layer(state_: &[CrandallField; 12]) -> [CrandallField; 12] {
diff --git a/src/hash/poseidon_neon.rs b/src/hash/poseidon_neon.rs
index fc70364a..de34a239 100644
--- a/src/hash/poseidon_neon.rs
+++ b/src/hash/poseidon_neon.rs
@@ -2,6 +2,8 @@ use core::arch::aarch64::*;
 
 use crate::field::crandall_field::CrandallField;
 use crate::field::field_types::PrimeField;
+use crate::field::packed_crandall_neon::PackedCrandallNeon;
+use crate::field::packed_field::PackedField;
 
 const EPSILON: u64 = 0u64.wrapping_sub(CrandallField::ORDER);
 
@@ -228,3 +230,18 @@ unsafe fn mul_add_32_32_64(x: uint32x2_t, y: uint32x2_t, z: uint64x2_t) -> uint6
     let res_unwrapped = vaddq_u64(res_wrapped, vmovq_n_u64(EPSILON));
     vbslq_u64(mask, res_unwrapped, res_wrapped)
 }
+
+/// Poseidon constant layer for Crandall. Assumes that every element in round_constants is in
+/// 0..CrandallField::ORDER; when this is not true it may return garbage. It's marked unsafe for
+/// this reason.
+#[inline(always)]
+pub unsafe fn crandall_poseidon_const_neon<const PACKED_WIDTH: usize>(
+    state: &mut [CrandallField; 2 * PACKED_WIDTH],
+    round_constants: [u64; 2 * PACKED_WIDTH],
+) {
+    let packed_state = PackedCrandallNeon::pack_slice_mut(state);
+    for i in 0..PACKED_WIDTH {
+        let packed_round_const = vld1q_u64(round_constants[2 * i..2 * i + 2].as_ptr());
+        packed_state[i] = packed_state[i].add_canonical_u64(packed_round_const);
+    }
+}
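
For reference, below is a minimal sketch (not part of the patch) of how the new NEON constant layer could be checked against plain scalar field addition. It assumes the crate is built with target_feature = "neon", that ALL_ROUND_CONSTANTS is reachable from the test module, and that the Poseidon<8> impl and CrandallField's tuple constructor are available as they appear in this diff; the module and test names are hypothetical.

// Sketch of a consistency test for the NEON constant layer (hypothetical test module; assumes
// ALL_ROUND_CONSTANTS is visible here and that this builds with target_feature = "neon").
#[cfg(all(test, target_feature = "neon"))]
mod const_layer_sketch {
    use crate::field::crandall_field::CrandallField;
    use crate::hash::poseidon::{Poseidon, ALL_ROUND_CONSTANTS};

    #[test]
    fn neon_constant_layer_matches_scalar_addition() {
        // Arbitrary canonical starting state.
        let mut state = [CrandallField(0); 8];
        for (i, x) in state.iter_mut().enumerate() {
            *x = CrandallField(i as u64 + 1);
        }
        let round_ctr = 3;

        // Expected result: ordinary field addition of the (canonical) round constants.
        let expected: Vec<CrandallField> = state
            .iter()
            .enumerate()
            .map(|(i, &x)| x + CrandallField(ALL_ROUND_CONSTANTS[8 * round_ctr + i]))
            .collect();

        // The NEON-specialized constant_layer added by this patch.
        <CrandallField as Poseidon<8>>::constant_layer(&mut state, round_ctr);

        assert_eq!(state.to_vec(), expected);
    }
}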