diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs
index fdb39241..c3627713 100644
--- a/src/field/crandall_field.rs
+++ b/src/field/crandall_field.rs
@@ -192,6 +192,11 @@ impl PrimeField for CrandallField {
     fn to_noncanonical_u64(&self) -> u64 {
         self.0
     }
+
+    #[inline]
+    fn from_noncanonical_u64(n: u64) -> Self {
+        Self(n)
+    }
 }
 
 impl Neg for CrandallField {
diff --git a/src/field/field_types.rs b/src/field/field_types.rs
index 0ea24509..43f6abda 100644
--- a/src/field/field_types.rs
+++ b/src/field/field_types.rs
@@ -308,6 +308,8 @@ pub trait PrimeField: Field {
     fn to_canonical_u64(&self) -> u64;
 
     fn to_noncanonical_u64(&self) -> u64;
+
+    fn from_noncanonical_u64(n: u64) -> Self;
 }
 
 /// An iterator over the powers of a certain base element `b`: `b^0, b^1, b^2, ...`.
diff --git a/src/field/goldilocks_field.rs b/src/field/goldilocks_field.rs
index a4e052da..3b98d78c 100644
--- a/src/field/goldilocks_field.rs
+++ b/src/field/goldilocks_field.rs
@@ -115,6 +115,11 @@ impl PrimeField for GoldilocksField {
     fn to_noncanonical_u64(&self) -> u64 {
         self.0
     }
+
+    #[inline]
+    fn from_noncanonical_u64(n: u64) -> Self {
+        Self(n)
+    }
 }
 
 impl Neg for GoldilocksField {
diff --git a/src/field/mod.rs b/src/field/mod.rs
index 75775ed1..2b81774f 100644
--- a/src/field/mod.rs
+++ b/src/field/mod.rs
@@ -10,7 +10,7 @@ pub(crate) mod packable;
 pub(crate) mod packed_field;
 
 #[cfg(target_feature = "avx2")]
-pub(crate) mod packed_crandall_avx2;
+pub(crate) mod packed_avx2;
 
 #[cfg(test)]
 mod field_testing;
diff --git a/src/field/packable.rs b/src/field/packable.rs
index e73e73bd..6e3fccb0 100644
--- a/src/field/packable.rs
+++ b/src/field/packable.rs
@@ -14,5 +14,10 @@ impl<F: Field> Packable for F {
 
 #[cfg(target_feature = "avx2")]
 impl Packable for crate::field::crandall_field::CrandallField {
-    type PackedType = crate::field::packed_crandall_avx2::PackedCrandallAVX2;
+    type PackedType = crate::field::packed_avx2::PackedCrandallAVX2;
+}
+
+#[cfg(target_feature = "avx2")]
+impl Packable for crate::field::goldilocks_field::GoldilocksField {
+    type PackedType = crate::field::packed_avx2::PackedGoldilocksAVX2;
 }
diff --git a/src/field/packed_avx2/common.rs b/src/field/packed_avx2/common.rs
new file mode 100644
index 00000000..97674a17
--- /dev/null
+++ b/src/field/packed_avx2/common.rs
@@ -0,0 +1,39 @@
+use core::arch::x86_64::*;
+
+use crate::field::field_types::PrimeField;
+
+pub trait ReducibleAVX2: PrimeField {
+    unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i;
+}
+
+#[inline]
+pub unsafe fn field_order<F: PrimeField>() -> __m256i {
+    _mm256_set1_epi64x(F::ORDER as i64)
+}
+
+#[inline]
+pub unsafe fn epsilon<F: PrimeField>() -> __m256i {
+    _mm256_set1_epi64x(0u64.wrapping_sub(F::ORDER) as i64)
+}
+
+/// Addition u64 + u64 -> u64. Assumes that x + y < 2^64 + FIELD_ORDER. The second argument is
+/// pre-shifted by 1 << 63. The result is similarly shifted.
+#[inline]
+pub unsafe fn add_no_canonicalize_64_64s_s<F: PrimeField>(x: __m256i, y_s: __m256i) -> __m256i {
+    let res_wrapped_s = _mm256_add_epi64(x, y_s);
+    let mask = _mm256_cmpgt_epi64(y_s, res_wrapped_s); // -1 if overflowed else 0.
+    let wrapback_amt = _mm256_and_si256(mask, epsilon::<F>()); // -FIELD_ORDER if overflowed else 0.
+    let res_s = _mm256_add_epi64(res_wrapped_s, wrapback_amt);
+    res_s
+}
+
+/// Subtraction u64 - u64 -> u64. Assumes that double overflow cannot occur. The first argument is
+/// pre-shifted by 1 << 63 and the result is similarly shifted.
+#[inline]
+pub unsafe fn sub_no_canonicalize_64s_64_s<F: PrimeField>(x_s: __m256i, y: __m256i) -> __m256i {
+    let res_wrapped_s = _mm256_sub_epi64(x_s, y);
+    let mask = _mm256_cmpgt_epi64(res_wrapped_s, x_s); // -1 if overflowed else 0.
+    let wrapback_amt = _mm256_and_si256(mask, epsilon::<F>()); // -FIELD_ORDER if overflowed else 0.
+    let res_s = _mm256_sub_epi64(res_wrapped_s, wrapback_amt);
+    res_s
+}
diff --git a/src/field/packed_avx2/crandall.rs b/src/field/packed_avx2/crandall.rs
new file mode 100644
index 00000000..0f267f3f
--- /dev/null
+++ b/src/field/packed_avx2/crandall.rs
@@ -0,0 +1,42 @@
+use core::arch::x86_64::*;
+
+use crate::field::crandall_field::CrandallField;
+use crate::field::packed_avx2::common::{add_no_canonicalize_64_64s_s, epsilon, ReducibleAVX2};
+
+/// (u64 << 64) + u64 + u64 -> u128 addition with carry. The third argument is pre-shifted by 2^63.
+/// The result is also shifted.
+#[inline]
+unsafe fn add_with_carry_hi_lo_los_s(
+    hi: __m256i,
+    lo0: __m256i,
+    lo1_s: __m256i,
+) -> (__m256i, __m256i) {
+    let res_lo_s = _mm256_add_epi64(lo0, lo1_s);
+    // carry is -1 if overflow (res_lo < lo1) because cmpgt returns -1 on true and 0 on false.
+    let carry = _mm256_cmpgt_epi64(lo1_s, res_lo_s);
+    let res_hi = _mm256_sub_epi64(hi, carry);
+    (res_hi, res_lo_s)
+}
+
+/// u64 * u32 + u64 fused multiply-add. The result is given as a tuple (u64, u64). The third
+/// argument is assumed to be pre-shifted by 2^63. The result is similarly shifted.
+#[inline]
+unsafe fn fmadd_64_32_64s_s(x: __m256i, y: __m256i, z_s: __m256i) -> (__m256i, __m256i) {
+    let x_hi = _mm256_srli_epi64(x, 32);
+    let mul_lo = _mm256_mul_epu32(x, y);
+    let mul_hi = _mm256_mul_epu32(x_hi, y);
+    let (tmp_hi, tmp_lo_s) = add_with_carry_hi_lo_los_s(_mm256_srli_epi64(mul_hi, 32), mul_lo, z_s);
+    add_with_carry_hi_lo_los_s(tmp_hi, _mm256_slli_epi64(mul_hi, 32), tmp_lo_s)
+}
+
+/// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is
+/// similarly shifted.
+impl ReducibleAVX2 for CrandallField {
+    #[inline]
+    unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i {
+        let (hi0, lo0_s) = x_s;
+        let (hi1, lo1_s) = fmadd_64_32_64s_s(hi0, epsilon::<Self>(), lo0_s);
+        let lo2 = _mm256_mul_epu32(hi1, epsilon::<Self>());
+        add_no_canonicalize_64_64s_s::<Self>(lo2, lo1_s)
+    }
+}
diff --git a/src/field/packed_avx2/goldilocks.rs b/src/field/packed_avx2/goldilocks.rs
new file mode 100644
index 00000000..2cea1767
--- /dev/null
+++ b/src/field/packed_avx2/goldilocks.rs
@@ -0,0 +1,20 @@
+use core::arch::x86_64::*;
+
+use crate::field::goldilocks_field::GoldilocksField;
+use crate::field::packed_avx2::common::{
+    add_no_canonicalize_64_64s_s, epsilon, sub_no_canonicalize_64s_64_s, ReducibleAVX2,
+};
+
+/// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is
+/// similarly shifted.
+impl ReducibleAVX2 for GoldilocksField {
+    #[inline]
+    unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i {
+        let (hi0, lo0_s) = x_s;
+        let hi_hi0 = _mm256_srli_epi64(hi0, 32);
+        let lo1_s = sub_no_canonicalize_64s_64_s::<Self>(lo0_s, hi_hi0);
+        let t1 = _mm256_mul_epu32(hi0, epsilon::<Self>());
+        let lo2_s = add_no_canonicalize_64_64s_s::<Self>(t1, lo1_s);
+        lo2_s
+    }
+}
diff --git a/src/field/packed_avx2/mod.rs b/src/field/packed_avx2/mod.rs
new file mode 100644
index 00000000..1aa7c870
--- /dev/null
+++ b/src/field/packed_avx2/mod.rs
@@ -0,0 +1,261 @@
+mod common;
+mod crandall;
+mod goldilocks;
+mod packed_prime_field;
+
+use packed_prime_field::PackedPrimeField;
+
+use crate::field::crandall_field::CrandallField;
+use crate::field::goldilocks_field::GoldilocksField;
+
+pub type PackedCrandallAVX2 = PackedPrimeField<CrandallField>;
+pub type PackedGoldilocksAVX2 = PackedPrimeField<GoldilocksField>;
+
+#[cfg(test)]
+mod tests {
+    use crate::field::crandall_field::CrandallField;
+    use crate::field::goldilocks_field::GoldilocksField;
+    use crate::field::packed_avx2::common::ReducibleAVX2;
+    use crate::field::packed_avx2::packed_prime_field::PackedPrimeField;
+    use crate::field::packed_field::PackedField;
+
+    fn test_vals_a<F: ReducibleAVX2>() -> [F; 4] {
+        [
+            F::from_noncanonical_u64(14479013849828404771),
+            F::from_noncanonical_u64(9087029921428221768),
+            F::from_noncanonical_u64(2441288194761790662),
+            F::from_noncanonical_u64(5646033492608483824),
+        ]
+    }
+    fn test_vals_b<F: ReducibleAVX2>() -> [F; 4] {
+        [
+            F::from_noncanonical_u64(17891926589593242302),
+            F::from_noncanonical_u64(11009798273260028228),
+            F::from_noncanonical_u64(2028722748960791447),
+            F::from_noncanonical_u64(7929433601095175579),
+        ]
+    }
+
+    fn test_add<F: ReducibleAVX2>()
+    where
+        [(); PackedPrimeField::<F>::WIDTH]: ,
+    {
+        let a_arr = test_vals_a::<F>();
+        let b_arr = test_vals_b::<F>();
+
+        let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
+        let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
+        let packed_res = packed_a + packed_b;
+        let arr_res = packed_res.to_arr();
+
+        let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a + b);
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    fn test_mul<F: ReducibleAVX2>()
+    where
+        [(); PackedPrimeField::<F>::WIDTH]: ,
+    {
+        let a_arr = test_vals_a::<F>();
+        let b_arr = test_vals_b::<F>();
+
+        let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
+        let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
+        let packed_res = packed_a * packed_b;
+        let arr_res = packed_res.to_arr();
+
+        let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a * b);
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    fn test_square<F: ReducibleAVX2>()
+    where
+        [(); PackedPrimeField::<F>::WIDTH]: ,
+    {
+        let a_arr = test_vals_a::<F>();
+
+        let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
+        let packed_res = packed_a.square();
+        let arr_res = packed_res.to_arr();
+
+        let expected = a_arr.iter().map(|&a| a.square());
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    fn test_neg<F: ReducibleAVX2>()
+    where
+        [(); PackedPrimeField::<F>::WIDTH]: ,
+    {
+        let a_arr = test_vals_a::<F>();
+
+        let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
+        let packed_res = -packed_a;
+        let arr_res = packed_res.to_arr();
+
+        let expected = a_arr.iter().map(|&a| -a);
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    fn test_sub<F: ReducibleAVX2>()
+    where
+        [(); PackedPrimeField::<F>::WIDTH]: ,
+    {
+        let a_arr = test_vals_a::<F>();
+        let b_arr = test_vals_b::<F>();
+
+        let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
+        let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
+        let packed_res = packed_a - packed_b;
+        let arr_res = packed_res.to_arr();
+
+        let expected = a_arr.iter().zip(b_arr).map(|(&a, b)| a - b);
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    fn test_interleave_is_involution<F: ReducibleAVX2>()
+    where
+        [(); PackedPrimeField::<F>::WIDTH]: ,
+    {
+        let a_arr = test_vals_a::<F>();
+        let b_arr = test_vals_b::<F>();
+
+        let packed_a = PackedPrimeField::<F>::from_arr(a_arr);
+        let packed_b = PackedPrimeField::<F>::from_arr(b_arr);
+        {
+            // Interleave, then deinterleave.
+            let (x, y) = packed_a.interleave(packed_b, 0);
+            let (res_a, res_b) = x.interleave(y, 0);
+            assert_eq!(res_a.to_arr(), a_arr);
+            assert_eq!(res_b.to_arr(), b_arr);
+        }
+        {
+            let (x, y) = packed_a.interleave(packed_b, 1);
+            let (res_a, res_b) = x.interleave(y, 1);
+            assert_eq!(res_a.to_arr(), a_arr);
+            assert_eq!(res_b.to_arr(), b_arr);
+        }
+    }
+
+    fn test_interleave<F: ReducibleAVX2>()
+    where
+        [(); PackedPrimeField::<F>::WIDTH]: ,
+    {
+        let in_a: [F; 4] = [
+            F::from_noncanonical_u64(00),
+            F::from_noncanonical_u64(01),
+            F::from_noncanonical_u64(02),
+            F::from_noncanonical_u64(03),
+        ];
+        let in_b: [F; 4] = [
+            F::from_noncanonical_u64(10),
+            F::from_noncanonical_u64(11),
+            F::from_noncanonical_u64(12),
+            F::from_noncanonical_u64(13),
+        ];
+        let int0_a: [F; 4] = [
+            F::from_noncanonical_u64(00),
+            F::from_noncanonical_u64(10),
+            F::from_noncanonical_u64(02),
+            F::from_noncanonical_u64(12),
+        ];
+        let int0_b: [F; 4] = [
+            F::from_noncanonical_u64(01),
+            F::from_noncanonical_u64(11),
+            F::from_noncanonical_u64(03),
+            F::from_noncanonical_u64(13),
+        ];
+        let int1_a: [F; 4] = [
+            F::from_noncanonical_u64(00),
+            F::from_noncanonical_u64(01),
+            F::from_noncanonical_u64(10),
+            F::from_noncanonical_u64(11),
+        ];
+        let int1_b: [F; 4] = [
+            F::from_noncanonical_u64(02),
+            F::from_noncanonical_u64(03),
+            F::from_noncanonical_u64(12),
+            F::from_noncanonical_u64(13),
+        ];
+
+        let packed_a = PackedPrimeField::<F>::from_arr(in_a);
+        let packed_b = PackedPrimeField::<F>::from_arr(in_b);
+        {
+            let (x0, y0) = packed_a.interleave(packed_b, 0);
+            assert_eq!(x0.to_arr(), int0_a);
+            assert_eq!(y0.to_arr(), int0_b);
+        }
+        {
+            let (x1, y1) = packed_a.interleave(packed_b, 1);
+            assert_eq!(x1.to_arr(), int1_a);
+            assert_eq!(y1.to_arr(), int1_b);
+        }
+    }
+
+    #[test]
+    fn test_add_crandall() {
+        test_add::<CrandallField>();
+    }
+    #[test]
+    fn test_mul_crandall() {
+        test_mul::<CrandallField>();
+    }
+    #[test]
+    fn test_square_crandall() {
+        test_square::<CrandallField>();
+    }
+    #[test]
+    fn test_neg_crandall() {
+        test_neg::<CrandallField>();
+    }
+    #[test]
+    fn test_sub_crandall() {
+        test_sub::<CrandallField>();
+    }
+    #[test]
+    fn test_interleave_is_involution_crandall() {
+        test_interleave_is_involution::<CrandallField>();
+    }
+    #[test]
+    fn test_interleave_crandall() {
+        test_interleave::<CrandallField>();
+    }
+
+    #[test]
+    fn test_add_goldilocks() {
+        test_add::<GoldilocksField>();
+    }
+    #[test]
+    fn test_mul_goldilocks() {
+        test_mul::<GoldilocksField>();
+    }
+    #[test]
+    fn test_square_goldilocks() {
+        test_square::<GoldilocksField>();
+    }
+    #[test]
+    fn test_neg_goldilocks() {
+        test_neg::<GoldilocksField>();
+    }
+    #[test]
+    fn test_sub_goldilocks() {
+        test_sub::<GoldilocksField>();
+    }
+    #[test]
+    fn test_interleave_is_involution_goldilocks() {
+        test_interleave_is_involution::<GoldilocksField>();
+    }
+    #[test]
+    fn test_interleave_goldilocks() {
+        test_interleave::<GoldilocksField>();
+    }
+}
diff --git a/src/field/packed_avx2/packed_prime_field.rs b/src/field/packed_avx2/packed_prime_field.rs
new file mode 100644
index 00000000..b892da4a
--- /dev/null
+++ b/src/field/packed_avx2/packed_prime_field.rs
@@ -0,0 +1,402 @@
+use core::arch::x86_64::*;
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::iter::{Product, Sum};
+use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
+
+use crate::field::field_types::PrimeField;
+use crate::field::packed_avx2::common::{
+    add_no_canonicalize_64_64s_s, epsilon, field_order, ReducibleAVX2,
+};
+use crate::field::packed_field::PackedField;
+
+// PackedPrimeField wraps an array of four u64s, with the new and get methods to convert that
+// array to and from __m256i, which is the type we actually operate on. This indirection is a
+// terrible trick to change PackedPrimeField's alignment.
+// We'd like to be able to cast slices of PrimeField to slices of PackedPrimeField. Rust
+// aligns __m256i to 32 bytes but PrimeField has a lower alignment. That alignment extends to
+// PackedPrimeField and it appears that it cannot be lowered with #[repr(C, blah)]. It is
+// important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly.
+// There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and
+// friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and
+// stores are permitted to segfault on unaligned accesses. Historically, the aligned instructions
+// were faster, and although this is no longer the case, compilers prefer the aligned versions if
+// they know that the address is aligned. Using aligned instructions on unaligned addresses leads to
+// bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and
+// therefore PackedPrimeField wraps [F; 4] and not __m256i.
+#[derive(Copy, Clone)]
+#[repr(transparent)]
+pub struct PackedPrimeField<F: ReducibleAVX2>(pub [F; 4]);
+
+impl<F: ReducibleAVX2> PackedPrimeField<F> {
+    #[inline]
+    fn new(x: __m256i) -> Self {
+        let mut obj = Self([F::ZERO; 4]);
+        let ptr = (&mut obj.0).as_mut_ptr().cast::<__m256i>();
+        unsafe {
+            _mm256_storeu_si256(ptr, x);
+        }
+        obj
+    }
+    #[inline]
+    fn get(&self) -> __m256i {
+        let ptr = (&self.0).as_ptr().cast::<__m256i>();
+        unsafe { _mm256_loadu_si256(ptr) }
+    }
+}
+
+impl<F: ReducibleAVX2> Add for PackedPrimeField<F> {
+    type Output = Self;
+    #[inline]
+    fn add(self, rhs: Self) -> Self {
+        Self::new(unsafe { add::<F>(self.get(), rhs.get()) })
+    }
+}
+impl<F: ReducibleAVX2> Add<F> for PackedPrimeField<F> {
+    type Output = Self;
+    #[inline]
+    fn add(self, rhs: F) -> Self {
+        self + Self::broadcast(rhs)
+    }
+}
+impl<F: ReducibleAVX2> AddAssign for PackedPrimeField<F> {
+    #[inline]
+    fn add_assign(&mut self, rhs: Self) {
+        *self = *self + rhs;
+    }
+}
+impl<F: ReducibleAVX2> AddAssign<F> for PackedPrimeField<F> {
+    #[inline]
+    fn add_assign(&mut self, rhs: F) {
+        *self = *self + rhs;
+    }
+}
+
+impl<F: ReducibleAVX2> Debug for PackedPrimeField<F> {
+    #[inline]
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(f, "({:?})", self.get())
+    }
+}
+
+impl<F: ReducibleAVX2> Default for PackedPrimeField<F> {
+    #[inline]
+    fn default() -> Self {
+        Self::zero()
+    }
+}
+
+impl<F: ReducibleAVX2> Mul for PackedPrimeField<F> {
+    type Output = Self;
+    #[inline]
+    fn mul(self, rhs: Self) -> Self {
+        Self::new(unsafe { mul::<F>(self.get(), rhs.get()) })
+    }
+}
+impl<F: ReducibleAVX2> Mul<F> for PackedPrimeField<F> {
+    type Output = Self;
+    #[inline]
+    fn mul(self, rhs: F) -> Self {
+        self * Self::broadcast(rhs)
+    }
+}
+impl<F: ReducibleAVX2> MulAssign for PackedPrimeField<F> {
+    #[inline]
+    fn mul_assign(&mut self, rhs: Self) {
+        *self = *self * rhs;
+    }
+}
+impl<F: ReducibleAVX2> MulAssign<F> for PackedPrimeField<F> {
+    #[inline]
+    fn mul_assign(&mut self, rhs: F) {
+        *self = *self * rhs;
+    }
+}
+
+impl<F: ReducibleAVX2> Neg for PackedPrimeField<F> {
+    type Output = Self;
+    #[inline]
+    fn neg(self) -> Self {
+        Self::new(unsafe { neg::<F>(self.get()) })
+    }
+}
+
+impl<F: ReducibleAVX2> Product for PackedPrimeField<F> {
+    #[inline]
+    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
+        iter.reduce(|x, y| x * y).unwrap_or(Self::one())
+    }
+}
+
+impl<F: ReducibleAVX2> PackedField for PackedPrimeField<F> {
+    const LOG2_WIDTH: usize = 2;
+
+    type FieldType = F;
+
+    #[inline]
+    fn broadcast(x: F) -> Self {
+        Self([x; 4])
+    }
+
+    #[inline]
+    fn from_arr(arr: [F; Self::WIDTH]) -> Self {
+        Self(arr)
+    }
+
+    #[inline]
+    fn to_arr(&self) -> [F; Self::WIDTH] {
+        self.0
+    }
+
+    #[inline]
+    fn from_slice(slice: &[F]) -> Self {
+        assert!(slice.len() == 4);
+        Self([slice[0], slice[1], slice[2], slice[3]])
+    }
+
+    #[inline]
+    fn to_vec(&self) -> Vec<F> {
+        self.0.into()
+    }
+
+    #[inline]
+    fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
+        let (v0, v1) = (self.get(), other.get());
+        let (res0, res1) = match r {
+            0 => unsafe { interleave0(v0, v1) },
+            1 => unsafe { interleave1(v0, v1) },
+            2 => (v0, v1),
+            _ => panic!("r cannot be more than LOG2_WIDTH"),
+        };
+        (Self::new(res0), Self::new(res1))
+    }
+
+    #[inline]
+    fn square(&self) -> Self {
+        Self::new(unsafe { square::<F>(self.get()) })
+    }
+}
+
+impl<F: ReducibleAVX2> Sub for PackedPrimeField<F> {
+    type Output = Self;
+    #[inline]
+    fn sub(self, rhs: Self) -> Self {
+        Self::new(unsafe { sub::<F>(self.get(), rhs.get()) })
+    }
+}
+impl<F: ReducibleAVX2> Sub<F> for PackedPrimeField<F> {
+    type Output = Self;
+    #[inline]
+    fn sub(self, rhs: F) -> Self {
+        self - Self::broadcast(rhs)
+    }
+}
+impl<F: ReducibleAVX2> SubAssign for PackedPrimeField<F> {
+    #[inline]
+    fn sub_assign(&mut self, rhs: Self) {
+        *self = *self - rhs;
+    }
+}
+impl<F: ReducibleAVX2> SubAssign<F> for PackedPrimeField<F> {
+    #[inline]
+    fn sub_assign(&mut self, rhs: F) {
+        *self = *self - rhs;
+    }
+}
+
+impl<F: ReducibleAVX2> Sum for PackedPrimeField<F> {
+    #[inline]
+    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
+        iter.reduce(|x, y| x + y).unwrap_or(Self::zero())
+    }
+}
+
+const SIGN_BIT: u64 = 1 << 63;
+
+#[inline]
+unsafe fn sign_bit() -> __m256i {
+    _mm256_set1_epi64x(SIGN_BIT as i64)
+}
+
+// Resources:
+// 1. Intel Intrinsics Guide for explanation of each intrinsic:
+//    https://software.intel.com/sites/landingpage/IntrinsicsGuide/
+// 2. uops.info lists micro-ops for each instruction: https://uops.info/table.html
+// 3. Intel optimization manual for introduction to x86 vector extensions and best practices:
+//    https://software.intel.com/content/www/us/en/develop/download/intel-64-and-ia-32-architectures-optimization-reference-manual.html
+
+// Preliminary knowledge:
+// 1. Vector code usually avoids branching. Instead of branches, we can do input selection with
+//    _mm256_blendv_epi8 or similar instruction. If all we're doing is conditionally zeroing a
+//    vector element then _mm256_and_si256 or _mm256_andnot_si256 may be used and are cheaper.
+//
+// 2. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
+//    emulated. The method recognizes that a + b overflowed iff (a + b) < a:
+//      i. res_lo = a_lo + b_lo
+//      ii. carry_mask = res_lo < a_lo
+//      iii. res_hi = a_hi + b_hi - carry_mask
+//    Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
+//    return -1 (all bits 1) for true and 0 for false.
+//
+// 3. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
+//    by recognizing that a <u b iff a ^ (1 << 63) <s b ^ (1 << 63), where <u and <s denote
+//    unsigned and signed comparison, respectively. shift() below applies that xor; variables
+//    holding shifted values are given the suffix _s.
+
+#[inline]
+unsafe fn shift(x: __m256i) -> __m256i {
+    _mm256_xor_si256(x, sign_bit())
+}
+
+/// Convert to canonical representation.
+/// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the field
+/// value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s = y + (1<<63),
+/// where 0 <= y < FIELD_ORDER).
+#[inline]
+unsafe fn canonicalize_s<F: PrimeField>(x_s: __m256i) -> __m256i {
+    // If x >= FIELD_ORDER then corresponding mask bits are all 0; otherwise all 1.
+    let mask = _mm256_cmpgt_epi64(shift(field_order::<F>()), x_s);
+    // wrapback_amt is -FIELD_ORDER if mask is 0; otherwise 0.
+    let wrapback_amt = _mm256_andnot_si256(mask, epsilon::<F>());
+    _mm256_add_epi64(x_s, wrapback_amt)
+}
+
+#[inline]
+unsafe fn add<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i {
+    let y_s = shift(y);
+    let res_s = add_no_canonicalize_64_64s_s::<F>(x, canonicalize_s::<F>(y_s));
+    shift(res_s)
+}
+
+#[inline]
+unsafe fn sub<F: PrimeField>(x: __m256i, y: __m256i) -> __m256i {
+    let mut y_s = shift(y);
+    y_s = canonicalize_s::<F>(y_s);
+    let x_s = shift(x);
+    let mask = _mm256_cmpgt_epi64(y_s, x_s); // -1 if sub will underflow (y > x) else 0.
+    let wrapback_amt = _mm256_and_si256(mask, epsilon::<F>()); // -FIELD_ORDER if underflow else 0.
+    let res_wrapped = _mm256_sub_epi64(x_s, y_s);
+    let res = _mm256_sub_epi64(res_wrapped, wrapback_amt);
+    res
+}
+
+#[inline]
+unsafe fn neg<F: PrimeField>(y: __m256i) -> __m256i {
+    let y_s = shift(y);
+    _mm256_sub_epi64(shift(field_order::<F>()), canonicalize_s::<F>(y_s))
+}
+
+/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the
+/// scalar instruction, but may be worth it if we want our data to live in vector registers.
+#[inline]
+unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
+    let x_hi = _mm256_srli_epi64(x, 32);
+    let y_hi = _mm256_srli_epi64(y, 32);
+    let mul_ll = _mm256_mul_epu32(x, y);
+    let mul_lh = _mm256_mul_epu32(x, y_hi);
+    let mul_hl = _mm256_mul_epu32(x_hi, y);
+    let mul_hh = _mm256_mul_epu32(x_hi, y_hi);
+
+    let res_lo0_s = shift(mul_ll);
+    let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32));
+    let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32));
+
+    // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
+    // overflow and must be subtracted, not added.
+    let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
+    let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s);
+
+    let res_hi0 = mul_hh;
+    let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32));
+    let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32));
+    let res_hi3 = _mm256_sub_epi64(res_hi2, carry0);
+    let res_hi4 = _mm256_sub_epi64(res_hi3, carry1);
+
+    (res_hi4, res_lo2_s)
+}
+
+/// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction.
+#[inline]
+unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) {
+    let x_hi = _mm256_srli_epi64(x, 32);
+    let mul_ll = _mm256_mul_epu32(x, x);
+    let mul_lh = _mm256_mul_epu32(x, x_hi);
+    let mul_hh = _mm256_mul_epu32(x_hi, x_hi);
+
+    let res_lo0_s = shift(mul_ll);
+    let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33));
+
+    // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on
+    // overflow and must be subtracted, not added.
+    let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s);
+
+    let res_hi0 = mul_hh;
+    let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31));
+    let res_hi2 = _mm256_sub_epi64(res_hi1, carry);
+
+    (res_hi2, res_lo1_s)
+}
+
+/// Multiply two integers modulo FIELD_ORDER.
+#[inline]
+unsafe fn mul<F: ReducibleAVX2>(x: __m256i, y: __m256i) -> __m256i {
+    shift(F::reduce128s_s(mul64_64_s(x, y)))
+}
+
+/// Square an integer modulo FIELD_ORDER.
+#[inline]
+unsafe fn square<F: ReducibleAVX2>(x: __m256i) -> __m256i {
+    shift(F::reduce128s_s(square64_s(x)))
+}
+
+#[inline]
+unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
+    let a = _mm256_unpacklo_epi64(x, y);
+    let b = _mm256_unpackhi_epi64(x, y);
+    (a, b)
+}
+
+#[inline]
+unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
+    let y_lo = _mm256_castsi256_si128(y); // This has 0 cost.
+
+    // 1 places y_lo in the high half of x; 0 would place it in the lower half.
+    let a = _mm256_inserti128_si256::<1>(x, y_lo);
+    // NB: _mm256_permute2x128_si256 could be used here as well but _mm256_inserti128_si256 has
+    // lower latency on Zen 3 processors.
+
+    // Each nibble of the constant has the following semantics:
+    //   0 => src1[low 128 bits]
+    //   1 => src1[high 128 bits]
+    //   2 => src2[low 128 bits]
+    //   3 => src2[high 128 bits]
+    // The low (resp. high) nibble chooses the low (resp. high) 128 bits of the result.
+    let b = _mm256_permute2x128_si256::<0x31>(x, y);
+
+    (a, b)
+}
diff --git a/src/field/packed_crandall_avx2.rs b/src/field/packed_crandall_avx2.rs
deleted file mode 100644
index a5336126..00000000
--- a/src/field/packed_crandall_avx2.rs
+++ /dev/null
@@ -1,622 +0,0 @@
-use core::arch::x86_64::*;
-use std::fmt;
-use std::fmt::{Debug, Formatter};
-use std::iter::{Product, Sum};
-use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
-
-use crate::field::crandall_field::CrandallField;
-use crate::field::packed_field::PackedField;
-
-// PackedCrandallAVX2 wraps an array of four u64s, with the new and get methods to convert that
-// array to and from __m256i, which is the type we actually operate on. This indirection is a
-// terrible trick to change PackedCrandallAVX2's alignment.
-// We'd like to be able to cast slices of CrandallField to slices of PackedCrandallAVX2. Rust
-// aligns __m256i to 32 bytes but CrandallField has a lower alignment. That alignment extends to
-// PackedCrandallAVX2 and it appears that it cannot be lowered with #[repr(C, blah)]. It is
-// important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly.
-// There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and
-// friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and
-// stores are permitted to segfault on unaligned accesses. Historically, the aligned instructions
-// were faster, and although this is no longer the case, compilers prefer the aligned versions if
-// they know that the address is aligned. Using aligned instructions on unaligned addresses leads to
-// bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and
-// therefore PackedCrandallAVX2 wraps [u64; 4] and not __m256i.
-#[derive(Copy, Clone)] -#[repr(transparent)] -pub struct PackedCrandallAVX2(pub [u64; 4]); - -impl PackedCrandallAVX2 { - #[inline] - fn new(x: __m256i) -> Self { - let mut obj = Self([0, 0, 0, 0]); - let ptr = (&mut obj.0).as_mut_ptr().cast::<__m256i>(); - unsafe { - _mm256_storeu_si256(ptr, x); - } - obj - } - #[inline] - fn get(&self) -> __m256i { - let ptr = (&self.0).as_ptr().cast::<__m256i>(); - unsafe { _mm256_loadu_si256(ptr) } - } -} - -impl Add for PackedCrandallAVX2 { - type Output = Self; - #[inline] - fn add(self, rhs: Self) -> Self { - Self::new(unsafe { add(self.get(), rhs.get()) }) - } -} -impl Add for PackedCrandallAVX2 { - type Output = Self; - #[inline] - fn add(self, rhs: CrandallField) -> Self { - self + Self::broadcast(rhs) - } -} -impl AddAssign for PackedCrandallAVX2 { - #[inline] - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } -} -impl AddAssign for PackedCrandallAVX2 { - #[inline] - fn add_assign(&mut self, rhs: CrandallField) { - *self = *self + rhs; - } -} - -impl Debug for PackedCrandallAVX2 { - #[inline] - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "({:?})", self.get()) - } -} - -impl Default for PackedCrandallAVX2 { - #[inline] - fn default() -> Self { - Self::zero() - } -} - -impl Mul for PackedCrandallAVX2 { - type Output = Self; - #[inline] - fn mul(self, rhs: Self) -> Self { - Self::new(unsafe { mul(self.get(), rhs.get()) }) - } -} -impl Mul for PackedCrandallAVX2 { - type Output = Self; - #[inline] - fn mul(self, rhs: CrandallField) -> Self { - self * Self::broadcast(rhs) - } -} -impl MulAssign for PackedCrandallAVX2 { - #[inline] - fn mul_assign(&mut self, rhs: Self) { - *self = *self * rhs; - } -} -impl MulAssign for PackedCrandallAVX2 { - #[inline] - fn mul_assign(&mut self, rhs: CrandallField) { - *self = *self * rhs; - } -} - -impl Neg for PackedCrandallAVX2 { - type Output = Self; - #[inline] - fn neg(self) -> Self { - Self::new(unsafe { neg(self.get()) }) - } -} - -impl Product for PackedCrandallAVX2 { - #[inline] - fn product>(iter: I) -> Self { - iter.reduce(|x, y| x * y).unwrap_or(Self::one()) - } -} - -impl PackedField for PackedCrandallAVX2 { - const LOG2_WIDTH: usize = 2; - - type FieldType = CrandallField; - - #[inline] - fn broadcast(x: CrandallField) -> Self { - Self::new(unsafe { _mm256_set1_epi64x(x.0 as i64) }) - } - - #[inline] - fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self { - Self([arr[0].0, arr[1].0, arr[2].0, arr[3].0]) - } - - #[inline] - fn to_arr(&self) -> [Self::FieldType; Self::WIDTH] { - [ - CrandallField(self.0[0]), - CrandallField(self.0[1]), - CrandallField(self.0[2]), - CrandallField(self.0[3]), - ] - } - - #[inline] - fn from_slice(slice: &[Self::FieldType]) -> Self { - assert!(slice.len() == 4); - Self::from_arr([slice[0], slice[1], slice[2], slice[3]]) - } - - #[inline] - fn to_vec(&self) -> Vec { - vec![ - CrandallField(self.0[0]), - CrandallField(self.0[1]), - CrandallField(self.0[2]), - CrandallField(self.0[3]), - ] - } - - #[inline] - fn interleave(&self, other: Self, r: usize) -> (Self, Self) { - let (v0, v1) = (self.get(), other.get()); - let (res0, res1) = match r { - 0 => unsafe { interleave0(v0, v1) }, - 1 => unsafe { interleave1(v0, v1) }, - 2 => (v0, v1), - _ => panic!("r cannot be more than LOG2_WIDTH"), - }; - (Self::new(res0), Self::new(res1)) - } - - #[inline] - fn square(&self) -> Self { - Self::new(unsafe { square(self.get()) }) - } -} - -impl Sub for PackedCrandallAVX2 { - type Output = Self; - #[inline] - fn sub(self, rhs: Self) -> Self { - 
Self::new(unsafe { sub(self.get(), rhs.get()) }) - } -} -impl Sub for PackedCrandallAVX2 { - type Output = Self; - #[inline] - fn sub(self, rhs: CrandallField) -> Self { - self - Self::broadcast(rhs) - } -} -impl SubAssign for PackedCrandallAVX2 { - #[inline] - fn sub_assign(&mut self, rhs: Self) { - *self = *self - rhs; - } -} -impl SubAssign for PackedCrandallAVX2 { - #[inline] - fn sub_assign(&mut self, rhs: CrandallField) { - *self = *self - rhs; - } -} - -impl Sum for PackedCrandallAVX2 { - #[inline] - fn sum>(iter: I) -> Self { - iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) - } -} - -const EPSILON: u64 = (1 << 31) + (1 << 28) - 1; -const FIELD_ORDER: u64 = 0u64.wrapping_sub(EPSILON); -const SIGN_BIT: u64 = 1 << 63; - -#[inline] -unsafe fn field_order() -> __m256i { - _mm256_set1_epi64x(FIELD_ORDER as i64) -} - -#[inline] -unsafe fn epsilon() -> __m256i { - _mm256_set1_epi64x(EPSILON as i64) -} - -#[inline] -unsafe fn sign_bit() -> __m256i { - _mm256_set1_epi64x(SIGN_BIT as i64) -} - -// Resources: -// 1. Intel Intrinsics Guide for explanation of each intrinsic: -// https://software.intel.com/sites/landingpage/IntrinsicsGuide/ -// 2. uops.info lists micro-ops for each instruction: https://uops.info/table.html -// 3. Intel optimization manual for introduction to x86 vector extensions and best practices: -// https://software.intel.com/content/www/us/en/develop/download/intel-64-and-ia-32-architectures-optimization-reference-manual.html - -// Preliminary knowledge: -// 1. Vector code usually avoids branching. Instead of branches, we can do input selection with -// _mm256_blendv_epi8 or similar instruction. If all we're doing is conditionally zeroing a -// vector element then _mm256_and_si256 or _mm256_andnot_si256 may be used and are cheaper. -// -// 2. AVX does not support addition with carry but 128-bit (2-word) addition can be easily -// emulated. The method recognizes that for a + b overflowed iff (a + b) < a: -// i. res_lo = a_lo + b_lo -// ii. carry_mask = res_lo < a_lo -// iii. res_hi = a_hi + b_hi - carry_mask -// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions -// return -1 (all bits 1) for true and 0 for false. -// -// 3. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons -// by recognizing that a __m256i { - _mm256_xor_si256(x, sign_bit()) -} - -/// Convert to canonical representation. -/// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the -/// Crandall field value). The returned value is similarly shifted by 1 << 63 (i.e. we return y`_s -/// = y + 1<<63, where 0 <= y < FIELD_ORDER). -#[inline] -unsafe fn canonicalize_s(x_s: __m256i) -> __m256i { - // If x >= FIELD_ORDER then corresponding mask bits are all 0; otherwise all 1. - let mask = _mm256_cmpgt_epi64(shift(field_order()), x_s); - // wrapback_amt is -FIELD_ORDER if mask is 0; otherwise 0. - let wrapback_amt = _mm256_andnot_si256(mask, epsilon()); - _mm256_add_epi64(x_s, wrapback_amt) -} - -/// Addition u64 + u64 -> u64. Assumes that x + y < 2^64 + FIELD_ORDER. The second argument is -/// pre-shifted by 1 << 63. The result is similarly shifted. -#[inline] -unsafe fn add_no_canonicalize_64_64s_s(x: __m256i, y_s: __m256i) -> __m256i { - let res_wrapped_s = _mm256_add_epi64(x, y_s); - let mask = _mm256_cmpgt_epi64(y_s, res_wrapped_s); // -1 if overflowed else 0. - let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if overflowed else 0. 
- let res_s = _mm256_add_epi64(res_wrapped_s, wrapback_amt); - res_s -} - -#[inline] -unsafe fn add(x: __m256i, y: __m256i) -> __m256i { - let y_s = shift(y); - let res_s = add_no_canonicalize_64_64s_s(x, canonicalize_s(y_s)); - shift(res_s) -} - -#[inline] -unsafe fn sub(x: __m256i, y: __m256i) -> __m256i { - let mut y_s = shift(y); - y_s = canonicalize_s(y_s); - let x_s = shift(x); - let mask = _mm256_cmpgt_epi64(y_s, x_s); // -1 if sub will underflow (y > x) else 0. - let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if underflow else 0. - let res_wrapped = _mm256_sub_epi64(x_s, y_s); - let res = _mm256_sub_epi64(res_wrapped, wrapback_amt); - res -} - -#[inline] -unsafe fn neg(y: __m256i) -> __m256i { - let y_s = shift(y); - _mm256_sub_epi64(shift(field_order()), canonicalize_s(y_s)) -} - -/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the -/// scalar instruction, but may be worth it if we want our data to live in vector registers. -#[inline] -unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) { - let x_hi = _mm256_srli_epi64(x, 32); - let y_hi = _mm256_srli_epi64(y, 32); - let mul_ll = _mm256_mul_epu32(x, y); - let mul_lh = _mm256_mul_epu32(x, y_hi); - let mul_hl = _mm256_mul_epu32(x_hi, y); - let mul_hh = _mm256_mul_epu32(x_hi, y_hi); - - let res_lo0_s = shift(mul_ll); - let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32)); - let res_lo2_s = _mm256_add_epi32(res_lo1_s, _mm256_slli_epi64(mul_hl, 32)); - - // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on - // overflow and must be subtracted, not added. - let carry0 = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); - let carry1 = _mm256_cmpgt_epi64(res_lo1_s, res_lo2_s); - - let res_hi0 = mul_hh; - let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32)); - let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32)); - let res_hi3 = _mm256_sub_epi64(res_hi2, carry0); - let res_hi4 = _mm256_sub_epi64(res_hi3, carry1); - - (res_hi4, res_lo2_s) -} - -/// Full 64-bit squaring. This routine is 1.2x faster than the scalar instruction. -#[inline] -unsafe fn square64_s(x: __m256i) -> (__m256i, __m256i) { - let x_hi = _mm256_srli_epi64(x, 32); - let mul_ll = _mm256_mul_epu32(x, x); - let mul_lh = _mm256_mul_epu32(x, x_hi); - let mul_hh = _mm256_mul_epu32(x_hi, x_hi); - - let res_lo0_s = shift(mul_ll); - let res_lo1_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 33)); - - // cmpgt returns -1 on true and 0 on false. Hence, the carry values below are set to -1 on - // overflow and must be subtracted, not added. - let carry = _mm256_cmpgt_epi64(res_lo0_s, res_lo1_s); - - let res_hi0 = mul_hh; - let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 31)); - let res_hi2 = _mm256_sub_epi64(res_hi1, carry); - - (res_hi2, res_lo1_s) -} - -/// (u64 << 64) + u64 + u64 -> u128 addition with carry. The third argument is pre-shifted by 2^63. -/// The result is also shifted. -#[inline] -unsafe fn add_with_carry_hi_lo_los_s( - hi: __m256i, - lo0: __m256i, - lo1_s: __m256i, -) -> (__m256i, __m256i) { - let res_lo_s = _mm256_add_epi64(lo0, lo1_s); - // carry is -1 if overflow (res_lo < lo1) because cmpgt returns -1 on true and 0 on false. - let carry = _mm256_cmpgt_epi64(lo1_s, res_lo_s); - let res_hi = _mm256_sub_epi64(hi, carry); - (res_hi, res_lo_s) -} - -/// u64 * u32 + u64 fused multiply-add. The result is given as a tuple (u64, u64). 
The third -/// argument is assumed to be pre-shifted by 2^63. The result is similarly shifted. -#[inline] -unsafe fn fmadd_64_32_64s_s(x: __m256i, y: __m256i, z_s: __m256i) -> (__m256i, __m256i) { - let x_hi = _mm256_srli_epi64(x, 32); - let mul_lo = _mm256_mul_epu32(x, y); - let mul_hi = _mm256_mul_epu32(x_hi, y); - let (tmp_hi, tmp_lo_s) = add_with_carry_hi_lo_los_s(_mm256_srli_epi64(mul_hi, 32), mul_lo, z_s); - add_with_carry_hi_lo_los_s(tmp_hi, _mm256_slli_epi64(mul_hi, 32), tmp_lo_s) -} - -/// Reduce a u128 modulo FIELD_ORDER. The input is (u64, u64), pre-shifted by 2^63. The result is -/// similarly shifted. -#[inline] -unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i { - let (hi0, lo0_s) = x_s; - let (hi1, lo1_s) = fmadd_64_32_64s_s(hi0, epsilon(), lo0_s); - let lo2 = _mm256_mul_epu32(hi1, epsilon()); - add_no_canonicalize_64_64s_s(lo2, lo1_s) -} - -/// Multiply two integers modulo FIELD_ORDER. -#[inline] -unsafe fn mul(x: __m256i, y: __m256i) -> __m256i { - shift(reduce128s_s(mul64_64_s(x, y))) -} - -/// Square an integer modulo FIELD_ORDER. -#[inline] -unsafe fn square(x: __m256i) -> __m256i { - shift(reduce128s_s(square64_s(x))) -} - -#[inline] -unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) { - let a = _mm256_unpacklo_epi64(x, y); - let b = _mm256_unpackhi_epi64(x, y); - (a, b) -} - -#[inline] -unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) { - let y_lo = _mm256_castsi256_si128(y); // This has 0 cost. - - // 1 places y_lo in the high half of x; 0 would place it in the lower half. - let a = _mm256_inserti128_si256::<1>(x, y_lo); - // NB: _mm256_permute2x128_si256 could be used here as well but _mm256_inserti128_si256 has - // lower latency on Zen 3 processors. - - // Each nibble of the constant has the following semantics: - // 0 => src1[low 128 bits] - // 1 => src1[high 128 bits] - // 2 => src2[low 128 bits] - // 3 => src2[high 128 bits] - // The low (resp. high) nibble chooses the low (resp. high) 128 bits of the result. 
- let b = _mm256_permute2x128_si256::<0x31>(x, y); - - (a, b) -} - -#[cfg(test)] -mod tests { - use crate::field::field_types::Field; - use crate::field::packed_crandall_avx2::*; - - const TEST_VALS_A: [CrandallField; 4] = [ - CrandallField(14479013849828404771), - CrandallField(9087029921428221768), - CrandallField(2441288194761790662), - CrandallField(5646033492608483824), - ]; - const TEST_VALS_B: [CrandallField; 4] = [ - CrandallField(17891926589593242302), - CrandallField(11009798273260028228), - CrandallField(2028722748960791447), - CrandallField(7929433601095175579), - ]; - - #[test] - fn test_add() { - let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A); - let packed_b = PackedCrandallAVX2::from_arr(TEST_VALS_B); - let packed_res = packed_a + packed_b; - let arr_res = packed_res.to_arr(); - - let expected = TEST_VALS_A.iter().zip(TEST_VALS_B).map(|(&a, b)| a + b); - for (exp, res) in expected.zip(arr_res) { - assert_eq!(res, exp); - } - } - - #[test] - fn test_mul() { - let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A); - let packed_b = PackedCrandallAVX2::from_arr(TEST_VALS_B); - let packed_res = packed_a * packed_b; - let arr_res = packed_res.to_arr(); - - let expected = TEST_VALS_A.iter().zip(TEST_VALS_B).map(|(&a, b)| a * b); - for (exp, res) in expected.zip(arr_res) { - assert_eq!(res, exp); - } - } - - #[test] - fn test_square() { - let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A); - let packed_res = packed_a.square(); - let arr_res = packed_res.to_arr(); - - let expected = TEST_VALS_A.iter().map(|&a| a.square()); - for (exp, res) in expected.zip(arr_res) { - assert_eq!(res, exp); - } - } - - #[test] - fn test_neg() { - let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A); - let packed_res = -packed_a; - let arr_res = packed_res.to_arr(); - - let expected = TEST_VALS_A.iter().map(|&a| -a); - for (exp, res) in expected.zip(arr_res) { - assert_eq!(res, exp); - } - } - - #[test] - fn test_sub() { - let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A); - let packed_b = PackedCrandallAVX2::from_arr(TEST_VALS_B); - let packed_res = packed_a - packed_b; - let arr_res = packed_res.to_arr(); - - let expected = TEST_VALS_A.iter().zip(TEST_VALS_B).map(|(&a, b)| a - b); - for (exp, res) in expected.zip(arr_res) { - assert_eq!(res, exp); - } - } - - #[test] - fn test_interleave_is_involution() { - let packed_a = PackedCrandallAVX2::from_arr(TEST_VALS_A); - let packed_b = PackedCrandallAVX2::from_arr(TEST_VALS_B); - { - // Interleave, then deinterleave. 
- let (x, y) = packed_a.interleave(packed_b, 0); - let (res_a, res_b) = x.interleave(y, 0); - assert_eq!(res_a.to_arr(), TEST_VALS_A); - assert_eq!(res_b.to_arr(), TEST_VALS_B); - } - { - let (x, y) = packed_a.interleave(packed_b, 1); - let (res_a, res_b) = x.interleave(y, 1); - assert_eq!(res_a.to_arr(), TEST_VALS_A); - assert_eq!(res_b.to_arr(), TEST_VALS_B); - } - } - - #[test] - fn test_interleave() { - let in_a: [CrandallField; 4] = [ - CrandallField(00), - CrandallField(01), - CrandallField(02), - CrandallField(03), - ]; - let in_b: [CrandallField; 4] = [ - CrandallField(10), - CrandallField(11), - CrandallField(12), - CrandallField(13), - ]; - let int0_a: [CrandallField; 4] = [ - CrandallField(00), - CrandallField(10), - CrandallField(02), - CrandallField(12), - ]; - let int0_b: [CrandallField; 4] = [ - CrandallField(01), - CrandallField(11), - CrandallField(03), - CrandallField(13), - ]; - let int1_a: [CrandallField; 4] = [ - CrandallField(00), - CrandallField(01), - CrandallField(10), - CrandallField(11), - ]; - let int1_b: [CrandallField; 4] = [ - CrandallField(02), - CrandallField(03), - CrandallField(12), - CrandallField(13), - ]; - - let packed_a = PackedCrandallAVX2::from_arr(in_a); - let packed_b = PackedCrandallAVX2::from_arr(in_b); - { - let (x0, y0) = packed_a.interleave(packed_b, 0); - assert_eq!(x0.to_arr(), int0_a); - assert_eq!(y0.to_arr(), int0_b); - } - { - let (x1, y1) = packed_a.interleave(packed_b, 1); - assert_eq!(x1.to_arr(), int1_a); - assert_eq!(y1.to_arr(), int1_b); - } - } -}
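Note on the arithmetic tricks used throughout this patch: the vector routines rely on (1) emulating an unsigned 64-bit comparison with a signed one by XOR-ing both operands with 1 << 63, and (2) detecting the carry in a two-word addition by checking whether the low-word sum wrapped. The scalar Rust sketch below is not part of the patch and its names are illustrative only; it simply shows the two identities in isolation.

const SIGN_BIT: u64 = 1 << 63;

/// Unsigned a < b, computed with a signed comparison on sign-shifted operands.
fn lt_unsigned_via_signed(a: u64, b: u64) -> bool {
    ((a ^ SIGN_BIT) as i64) < ((b ^ SIGN_BIT) as i64)
}

/// 128-bit addition of (hi, lo) pairs using only 64-bit adds plus a carry check.
/// The AVX code gets the carry from _mm256_cmpgt_epi64 as an all-ones mask (-1),
/// which is why the vector version subtracts the mask instead of adding 1.
fn add128(a: (u64, u64), b: (u64, u64)) -> (u64, u64) {
    let (a_hi, a_lo) = a;
    let (b_hi, b_lo) = b;
    let res_lo = a_lo.wrapping_add(b_lo);
    let carry = (res_lo < a_lo) as u64; // the low-word sum wrapped iff it is below either addend
    (a_hi.wrapping_add(b_hi).wrapping_add(carry), res_lo)
}

fn main() {
    assert!(lt_unsigned_via_signed(1, u64::MAX));
    assert!(!lt_unsigned_via_signed(u64::MAX, 1));
    assert_eq!(add128((0, u64::MAX), (0, 1)), (1, 0)); // carry propagates into the high word
}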