Mirror of https://github.com/logos-storage/plonky2.git (synced 2026-01-07 08:13:11 +00:00)
Vectorize Poseidon constant layer with NEON (#245)
* Start work on Crandall arithmetic in NEON
* Poseidon constant layer in NEON
* it's alive

Co-authored-by: Jakub Nabaglo <jakub@mirprotocol.org>
This commit is contained in:
parent b411a275f9
commit a7cd1ef40b
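The constant layer this commit vectorizes simply adds one round constant to each element of the Poseidon state, modulo the Crandall field order. As a scalar sketch of what the NEON paths below compute (illustrative only, not the commit's code; WIDTH is the Poseidon state width, 8 or 12, and since each NEON vector packs two field elements, the diff instantiates its helper with PACKED_WIDTH of 4 and 6 respectively):

// Illustrative scalar model of the Poseidon constant layer (not part of the
// commit). Each state element gets one round constant added, mod the field order.
fn constant_layer_scalar<const WIDTH: usize>(
    state: &mut [CrandallField; WIDTH],
    round_ctr: usize,
) {
    for i in 0..WIDTH {
        // The round constants are canonical u64s (in 0..CrandallField::ORDER),
        // per the comments in the diff, so constructing a CrandallField directly is fine.
        state[i] = state[i] + CrandallField(ALL_ROUND_CONSTANTS[WIDTH * round_ctr + i]);
    }
}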
@@ -12,6 +12,9 @@ pub(crate) mod packed_field;
 #[cfg(target_feature = "avx2")]
 pub(crate) mod packed_avx2;
 
+#[cfg(target_feature = "neon")]
+pub(crate) mod packed_crandall_neon;
+
 #[cfg(test)]
 mod field_testing;
 #[cfg(test)]

src/field/packed_crandall_neon.rs (new file, 251 lines)
@@ -0,0 +1,251 @@
use core::arch::aarch64::*;
use std::convert::TryInto;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use crate::field::crandall_field::CrandallField;
use crate::field::field_types::PrimeField;
use crate::field::packed_field::PackedField;

/// PackedCrandallNeon wraps a [CrandallField; 2] array to ensure that Rust does not assume
/// 16-byte alignment. Similar to AVX2's PackedPrimeField. I don't think it matters as much on
/// ARM, but incorrectly-aligned pointers are undefined behavior in Rust, so let's avoid them.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct PackedCrandallNeon(pub [CrandallField; 2]);

impl PackedCrandallNeon {
    #[inline]
    fn new(x: uint64x2_t) -> Self {
        let x0 = unsafe { vgetq_lane_u64::<0>(x) };
        let x1 = unsafe { vgetq_lane_u64::<1>(x) };
        Self([CrandallField(x0), CrandallField(x1)])
    }

    #[inline]
    fn get(&self) -> uint64x2_t {
        let x0 = self.0[0].0;
        let x1 = self.0[1].0;
        unsafe { vcombine_u64(vmov_n_u64(x0), vmov_n_u64(x1)) }
    }

    /// Addition that assumes x + y < 2^64 + F::ORDER. May return incorrect results if this
    /// condition is not met, hence it is marked unsafe.
    #[inline]
    pub unsafe fn add_canonical_u64(&self, rhs: uint64x2_t) -> Self {
        Self::new(add_canonical_u64(self.get(), rhs))
    }
}

impl Add<Self> for PackedCrandallNeon {
    type Output = Self;
    #[inline]
    fn add(self, rhs: Self) -> Self {
        Self::new(unsafe { add(self.get(), rhs.get()) })
    }
}
impl Add<CrandallField> for PackedCrandallNeon {
    type Output = Self;
    #[inline]
    fn add(self, rhs: CrandallField) -> Self {
        self + Self::broadcast(rhs)
    }
}
impl AddAssign<Self> for PackedCrandallNeon {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}
impl AddAssign<CrandallField> for PackedCrandallNeon {
    #[inline]
    fn add_assign(&mut self, rhs: CrandallField) {
        *self = *self + rhs;
    }
}

impl Debug for PackedCrandallNeon {
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "({:?})", self.get())
    }
}

impl Default for PackedCrandallNeon {
    #[inline]
    fn default() -> Self {
        Self::zero()
    }
}

impl Mul<Self> for PackedCrandallNeon {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        // TODO: Implement.
        // Do this in scalar for now.
        Self([self.0[0] * rhs.0[0], self.0[1] * rhs.0[1]])
    }
}
impl Mul<CrandallField> for PackedCrandallNeon {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: CrandallField) -> Self {
        self * Self::broadcast(rhs)
    }
}
impl MulAssign<Self> for PackedCrandallNeon {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}
impl MulAssign<CrandallField> for PackedCrandallNeon {
    #[inline]
    fn mul_assign(&mut self, rhs: CrandallField) {
        *self = *self * rhs;
    }
}

impl Neg for PackedCrandallNeon {
    type Output = Self;
    #[inline]
    fn neg(self) -> Self {
        Self::new(unsafe { neg(self.get()) })
    }
}

impl Product for PackedCrandallNeon {
    #[inline]
    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.reduce(|x, y| x * y).unwrap_or(Self::one())
    }
}

impl PackedField for PackedCrandallNeon {
    const LOG2_WIDTH: usize = 1;

    type FieldType = CrandallField;

    #[inline]
    fn broadcast(x: CrandallField) -> Self {
        Self([x; 2])
    }

    #[inline]
    fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self {
        Self(arr)
    }

    #[inline]
    fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] {
        self.0
    }

    #[inline]
    fn from_slice(slice: &[CrandallField]) -> Self {
        Self(slice.try_into().unwrap())
    }

    #[inline]
    fn to_vec(&self) -> Vec<CrandallField> {
        self.0.into()
    }

    #[inline]
    fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
        let (v0, v1) = (self.get(), other.get());
        let (res0, res1) = match r {
            0 => unsafe { interleave0(v0, v1) },
            1 => (v0, v1),
            _ => panic!("r cannot be more than LOG2_WIDTH"),
        };
        (Self::new(res0), Self::new(res1))
    }
}

impl Sub<Self> for PackedCrandallNeon {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        Self::new(unsafe { sub(self.get(), rhs.get()) })
    }
}
impl Sub<CrandallField> for PackedCrandallNeon {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: CrandallField) -> Self {
        self - Self::broadcast(rhs)
    }
}
impl SubAssign<Self> for PackedCrandallNeon {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}
impl SubAssign<CrandallField> for PackedCrandallNeon {
    #[inline]
    fn sub_assign(&mut self, rhs: CrandallField) {
        *self = *self - rhs;
    }
}

impl Sum for PackedCrandallNeon {
    #[inline]
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.reduce(|x, y| x + y).unwrap_or(Self::zero())
    }
}

const FIELD_ORDER: u64 = CrandallField::ORDER;

#[inline]
unsafe fn field_order() -> uint64x2_t {
    vmovq_n_u64(FIELD_ORDER)
}

#[inline]
unsafe fn canonicalize(x: uint64x2_t) -> uint64x2_t {
    let mask = vcgeq_u64(x, field_order()); // Mask is -1 if x >= FIELD_ORDER.
    let x_maybe_unwrapped = vsubq_u64(x, field_order());
    vbslq_u64(mask, x_maybe_unwrapped, x) // Bitwise select
}

#[inline]
unsafe fn add_no_canonicalize_64_64(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t {
    let res_wrapped = vaddq_u64(x, y);
    let mask = vcgtq_u64(y, res_wrapped); // Mask is -1 if overflow.
    let res_maybe_unwrapped = vsubq_u64(res_wrapped, field_order());
    vbslq_u64(mask, res_maybe_unwrapped, res_wrapped) // Bitwise select
}

#[inline]
unsafe fn add_canonical_u64(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t {
    add_no_canonicalize_64_64(x, y)
}

#[inline]
unsafe fn add(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t {
    add_no_canonicalize_64_64(x, canonicalize(y))
}

#[inline]
unsafe fn sub(x: uint64x2_t, y: uint64x2_t) -> uint64x2_t {
    let y = canonicalize(y);
    let mask = vcgtq_u64(y, x); // Mask is -1 if underflow.
    let res_wrapped = vsubq_u64(x, y);
    let res_maybe_unwrapped = vaddq_u64(res_wrapped, field_order());
    vbslq_u64(mask, res_maybe_unwrapped, res_wrapped) // Bitwise select
}

#[inline]
unsafe fn neg(y: uint64x2_t) -> uint64x2_t {
    vsubq_u64(field_order(), canonicalize(y))
}

#[inline]
unsafe fn interleave0(x: uint64x2_t, y: uint64x2_t) -> (uint64x2_t, uint64x2_t) {
    (vtrn1q_u64(x, y), vtrn2q_u64(x, y))
}

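Two tricks in the helpers above are easier to see in scalar form. canonicalize conditionally subtracts the field order without branching, and add_no_canonicalize_64_64 relies on the fact that when a u64 addition wraps, the true sum is res_wrapped + 2^64, so subtracting FIELD_ORDER with wrapping yields x + y - FIELD_ORDER, the same value mod FIELD_ORDER (and it fits in a u64 thanks to the x + y < 2^64 + FIELD_ORDER precondition). Scalar models of both (illustrative only, not part of the commit):

// Scalar models of the NEON helpers above (illustrative only).

// Conditionally subtract the field order, branch-free, mirroring the
// vcgeq_u64 + vsubq_u64 + vbslq_u64 sequence in canonicalize.
fn canonicalize_scalar(x: u64, order: u64) -> u64 {
    let mask = 0u64.wrapping_sub((x >= order) as u64); // all ones iff x >= order
    x - (order & mask)
}

// Mirrors add_no_canonicalize_64_64. Assumes x + y < 2^64 + order.
fn add_no_canonicalize_scalar(x: u64, y: u64, order: u64) -> u64 {
    let (res_wrapped, overflow) = x.overflowing_add(y);
    if overflow {
        // The true sum is res_wrapped + 2^64; subtracting order (wrapping)
        // gives x + y - order, the same value mod order, and it fits in a u64.
        res_wrapped.wrapping_sub(order)
    } else {
        res_wrapped
    }
}

These helpers also lend themselves to quick equivalence tests against the scalar field type; a sketch of one, assuming it is dropped into packed_crandall_neon.rs and that CrandallField implements PartialEq and Debug as field types typically do (hypothetical test, not in the commit):

#[test]
fn neon_add_matches_scalar() {
    // Canonical inputs only; add() canonicalizes its right operand itself.
    let inputs = [0u64, 1, 0x1234_5678_9abc_def0, FIELD_ORDER - 1];
    for &a in &inputs {
        for &b in &inputs {
            let x = PackedCrandallNeon([CrandallField(a); 2]);
            let y = PackedCrandallNeon([CrandallField(b); 2]);
            let got = (x + y).to_arr();
            let want = CrandallField(a) + CrandallField(b);
            assert_eq!(got, [want; 2]);
        }
    }
}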
@@ -1,7 +1,7 @@
 //! Implementation of the Poseidon hash function, as described in
 //! https://eprint.iacr.org/2019/458.pdf
 
-#[cfg(target_feature = "avx2")]
+#[cfg(any(target_feature = "avx2", target_feature = "neon"))]
 use std::convert::TryInto;
 
 use unroll::unroll_for_loops;

@@ -508,6 +508,14 @@ impl Poseidon<8> for CrandallField
         ALL_ROUND_CONSTANTS[8 * round_ctr..8 * round_ctr + 8].try_into().unwrap()); }
     }
 
+    #[cfg(target_feature = "neon")]
+    #[inline(always)]
+    fn constant_layer(state: &mut [Self; 8], round_ctr: usize) {
+        // This assumes that every element of ALL_ROUND_CONSTANTS is in 0..CrandallField::ORDER.
+        unsafe { crate::hash::poseidon_neon::crandall_poseidon_const_neon::<4>(state,
+            ALL_ROUND_CONSTANTS[8 * round_ctr..8 * round_ctr + 8].try_into().unwrap()); }
+    }
+
     #[cfg(target_feature = "avx2")]
     #[inline(always)]
     fn mds_layer(state_: &[CrandallField; 8]) -> [CrandallField; 8] {

@@ -739,6 +747,14 @@ impl Poseidon<12> for CrandallField
         ALL_ROUND_CONSTANTS[12 * round_ctr..12 * round_ctr + 12].try_into().unwrap()); }
     }
 
+    #[cfg(target_feature = "neon")]
+    #[inline(always)]
+    fn constant_layer(state: &mut [Self; 12], round_ctr: usize) {
+        // This assumes that every element of ALL_ROUND_CONSTANTS is in 0..CrandallField::ORDER.
+        unsafe { crate::hash::poseidon_neon::crandall_poseidon_const_neon::<6>(state,
+            ALL_ROUND_CONSTANTS[12 * round_ctr..12 * round_ctr + 12].try_into().unwrap()); }
+    }
+
     #[cfg(target_feature = "avx2")]
     #[inline(always)]
     fn mds_layer(state_: &[CrandallField; 12]) -> [CrandallField; 12] {

@@ -2,6 +2,8 @@ use core::arch::aarch64::*;
 
 use crate::field::crandall_field::CrandallField;
 use crate::field::field_types::PrimeField;
+use crate::field::packed_crandall_neon::PackedCrandallNeon;
+use crate::field::packed_field::PackedField;
 
 const EPSILON: u64 = 0u64.wrapping_sub(CrandallField::ORDER);
 
@@ -228,3 +230,18 @@ unsafe fn mul_add_32_32_64(x: uint32x2_t, y: uint32x2_t, z: uint64x2_t) -> uint64x2_t
     let res_unwrapped = vaddq_u64(res_wrapped, vmovq_n_u64(EPSILON));
     vbslq_u64(mask, res_unwrapped, res_wrapped)
 }
+
+/// Poseidon constant layer for Crandall. Assumes that every element in round_constants is in
+/// 0..CrandallField::ORDER; when this is not true it may return garbage. It's marked unsafe for
+/// this reason.
+#[inline(always)]
+pub unsafe fn crandall_poseidon_const_neon<const PACKED_WIDTH: usize>(
+    state: &mut [CrandallField; 2 * PACKED_WIDTH],
+    round_constants: [u64; 2 * PACKED_WIDTH],
+) {
+    let packed_state = PackedCrandallNeon::pack_slice_mut(state);
+    for i in 0..PACKED_WIDTH {
+        let packed_round_const = vld1q_u64(round_constants[2 * i..2 * i + 2].as_ptr());
+        packed_state[i] = packed_state[i].add_canonical_u64(packed_round_const);
+    }
+}
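pack_slice_mut (presumably provided by the PackedField trait) reinterprets the state array as packed values in place, so the constant layer can load and store through NEON registers without copying. The #[repr(transparent)] wrapper from the new file is what makes this reinterpretation sound: PackedCrandallNeon has the same layout and 8-byte alignment as [CrandallField; 2]. A sketch of how such a helper could look (illustrative, not the commit's code):

// Hypothetical sketch of a pack_slice_mut helper (assumption: the real one
// lives on the PackedField trait). Sound because PackedCrandallNeon is
// #[repr(transparent)] over [CrandallField; 2], so it has the same layout
// and 8-byte alignment as the elements it packs.
fn pack_slice_mut(slice: &mut [CrandallField]) -> &mut [PackedCrandallNeon] {
    assert_eq!(slice.len() % 2, 0, "slice length must be a multiple of the packed width");
    let len = slice.len() / 2;
    let ptr = slice.as_mut_ptr() as *mut PackedCrandallNeon;
    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}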