From 7ee7d8bf8a2adfba83874d88b28fed406c37efd0 Mon Sep 17 00:00:00 2001
From: Jakub Nabaglo
Date: Fri, 3 Sep 2021 19:55:16 -0700
Subject: [PATCH] Crandall arithmetic in AVX2

---
 src/field/crandall_field.rs       |   1 +
 src/field/mod.rs                  |   3 +
 src/field/packable.rs             |   5 +
 src/field/packed_crandall_avx2.rs | 597 ++++++++++++++++++++++++++++++
 4 files changed, 606 insertions(+)
 create mode 100644 src/field/packed_crandall_avx2.rs

diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs
index 05b17318..fa63fcb6 100644
--- a/src/field/crandall_field.rs
+++ b/src/field/crandall_field.rs
@@ -110,6 +110,7 @@ const CAUCHY_MDS_8: [[CrandallField; 8]; 8] = [
 /// = 2**28 * (2**36 - 9) + 1
 /// ```
 #[derive(Copy, Clone, Serialize, Deserialize)]
+#[repr(transparent)] // Must be compatible with PackedCrandallAVX2
 pub struct CrandallField(pub u64);
 
 impl Default for CrandallField {
diff --git a/src/field/mod.rs b/src/field/mod.rs
index 6566d430..ac624d39 100644
--- a/src/field/mod.rs
+++ b/src/field/mod.rs
@@ -7,6 +7,9 @@ pub(crate) mod interpolation;
 pub(crate) mod packable;
 pub(crate) mod packed_field;
 
+#[cfg(target_feature = "avx2")]
+pub(crate) mod packed_crandall_avx2;
+
 #[cfg(test)]
 mod field_testing;
 #[cfg(test)]
diff --git a/src/field/packable.rs b/src/field/packable.rs
index 27b67428..f650e803 100644
--- a/src/field/packable.rs
+++ b/src/field/packable.rs
@@ -11,3 +11,8 @@ pub trait Packable: Field {
 impl<F: Field> Packable for F {
     default type PackedType = Singleton<Self>;
 }
+
+#[cfg(target_feature = "avx2")]
+impl Packable for CrandallField {
+    type PackedType = crate::field::packed_crandall_avx2::PackedCrandallAVX2;
+}
diff --git a/src/field/packed_crandall_avx2.rs b/src/field/packed_crandall_avx2.rs
new file mode 100644
index 00000000..3d90132b
--- /dev/null
+++ b/src/field/packed_crandall_avx2.rs
@@ -0,0 +1,597 @@
+use core::arch::x86_64::*;
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::iter::{Product, Sum};
+use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
+
+use crate::field::crandall_field::CrandallField;
+use crate::field::field_types::Field;
+use crate::field::packed_field::PackedField;
+
+// PackedCrandallAVX2 wraps an array of four u64s, with the new and get methods to convert that
+// array to and from __m256i, which is the type we actually operate on. This indirection is a
+// terrible trick to change PackedCrandallAVX2's alignment.
+// We'd like to be able to cast slices of CrandallField to slices of PackedCrandallAVX2. Rust
+// aligns __m256i to 32 bytes but CrandallField has a lower alignment. That alignment extends to
+// PackedCrandallAVX2 and it appears that it cannot be lowered with #[repr(C, blah)]. It is
+// important for Rust not to assume 32-byte alignment, so we cannot wrap __m256i directly.
+// There are two versions of vectorized load/store instructions on x86: aligned (vmovaps and
+// friends) and unaligned (vmovups etc.). The difference between them is that aligned loads and
+// stores are permitted to segfault on unaligned accesses. Historically, the aligned instructions
+// were faster, and although this is no longer the case, compilers prefer the aligned versions if
+// they know that the address is aligned. Using aligned instructions on unaligned addresses leads
+// to bugs that can be frustrating to diagnose. Hence, we can't have Rust assuming alignment, and
+// therefore PackedCrandallAVX2 wraps [u64; 4] and not __m256i.
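+//
+// Illustrative sketch of the slice cast this enables (the helper name cast_packed is hypothetical
+// and not part of this patch). It is sound because both types are #[repr(transparent)] wrappers
+// over plain u64 data with 8-byte alignment:
+//
+//     fn cast_packed(slice: &[CrandallField]) -> &[PackedCrandallAVX2] {
+//         assert_eq!(slice.len() % 4, 0);
+//         unsafe {
+//             std::slice::from_raw_parts(slice.as_ptr().cast::<PackedCrandallAVX2>(), slice.len() / 4)
+//         }
+//     }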
+#[derive(Copy, Clone)]
+#[repr(transparent)]
+pub struct PackedCrandallAVX2(pub [u64; 4]);
+
+impl PackedCrandallAVX2 {
+    #[inline]
+    fn new(x: __m256i) -> Self {
+        let mut obj = Self([0, 0, 0, 0]);
+        let ptr = (&mut obj.0).as_mut_ptr().cast::<__m256i>();
+        unsafe {
+            _mm256_storeu_si256(ptr, x);
+        }
+        obj
+    }
+    #[inline]
+    fn get(&self) -> __m256i {
+        let ptr = (&self.0).as_ptr().cast::<__m256i>();
+        unsafe { _mm256_loadu_si256(ptr) }
+    }
+}
+
+impl Add for PackedCrandallAVX2 {
+    type Output = Self;
+    #[inline]
+    fn add(self, rhs: Self) -> Self {
+        Self::new(unsafe { add(self.get(), rhs.get()) })
+    }
+}
+impl Add<CrandallField> for PackedCrandallAVX2 {
+    type Output = Self;
+    #[inline]
+    fn add(self, rhs: CrandallField) -> Self {
+        self + Self::broadcast(rhs)
+    }
+}
+impl AddAssign for PackedCrandallAVX2 {
+    #[inline]
+    fn add_assign(&mut self, rhs: Self) {
+        *self = *self + rhs;
+    }
+}
+impl AddAssign<CrandallField> for PackedCrandallAVX2 {
+    #[inline]
+    fn add_assign(&mut self, rhs: CrandallField) {
+        *self = *self + rhs;
+    }
+}
+
+impl Debug for PackedCrandallAVX2 {
+    #[inline]
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(f, "({:?})", self.get())
+    }
+}
+
+impl Default for PackedCrandallAVX2 {
+    #[inline]
+    fn default() -> Self {
+        Self::zero()
+    }
+}
+
+impl Mul for PackedCrandallAVX2 {
+    type Output = Self;
+    #[inline]
+    fn mul(self, rhs: Self) -> Self {
+        Self::new(unsafe { mul(self.get(), rhs.get()) })
+    }
+}
+impl Mul<CrandallField> for PackedCrandallAVX2 {
+    type Output = Self;
+    #[inline]
+    fn mul(self, rhs: CrandallField) -> Self {
+        self * Self::broadcast(rhs)
+    }
+}
+impl MulAssign for PackedCrandallAVX2 {
+    #[inline]
+    fn mul_assign(&mut self, rhs: Self) {
+        *self = *self * rhs;
+    }
+}
+impl MulAssign<CrandallField> for PackedCrandallAVX2 {
+    #[inline]
+    fn mul_assign(&mut self, rhs: CrandallField) {
+        *self = *self * rhs;
+    }
+}
+
+impl Neg for PackedCrandallAVX2 {
+    type Output = Self;
+    #[inline]
+    fn neg(self) -> Self {
+        Self::new(unsafe { neg(self.get()) })
+    }
+}
+
+impl Product for PackedCrandallAVX2 {
+    #[inline]
+    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
+        iter.reduce(|x, y| x * y).unwrap_or(Self::one())
+    }
+}
+
+impl PackedField for PackedCrandallAVX2 {
+    const LOG2_WIDTH: usize = 2;
+
+    type FieldType = CrandallField;
+
+    #[inline]
+    fn broadcast(x: CrandallField) -> Self {
+        Self::new(unsafe { _mm256_set1_epi64x(x.0 as i64) })
+    }
+
+    #[inline]
+    fn new_from_slice(arr: &[Self::FieldType]) -> Self {
+        if let [a, b, c, d] = arr {
+            let v = unsafe { _mm256_setr_epi64x(a.0 as i64, b.0 as i64, c.0 as i64, d.0 as i64) };
+            Self::new(v)
+        } else {
+            panic!();
+        }
+    }
+    #[inline]
+    fn to_vec(&self) -> Vec<Self::FieldType> {
+        let a = unsafe { _mm256_extract_epi64(self.get(), 0) } as u64;
+        let b = unsafe { _mm256_extract_epi64(self.get(), 1) } as u64;
+        let c = unsafe { _mm256_extract_epi64(self.get(), 2) } as u64;
+        let d = unsafe { _mm256_extract_epi64(self.get(), 3) } as u64;
+        vec![
+            CrandallField(a),
+            CrandallField(b),
+            CrandallField(c),
+            CrandallField(d),
+        ]
+    }
+
+    #[inline]
+    fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
+        let (v0, v1) = (self.get(), other.get());
+        let (res0, res1) = match r {
+            0 => unsafe { interleave0(v0, v1) },
+            1 => unsafe { interleave1(v0, v1) },
+            2 => (v0, v1),
+            _ => panic!("r cannot be more than LOG2_WIDTH"),
+        };
+        (Self::new(res0), Self::new(res1))
+    }
+}
+
+impl Sub for PackedCrandallAVX2 {
+    type Output = Self;
+    #[inline]
+    fn sub(self, rhs: Self) -> Self {
+        Self::new(unsafe { sub(self.get(), rhs.get()) })
+    }
+}
+impl Sub<CrandallField> for PackedCrandallAVX2 {
+    type Output = Self;
+    #[inline]
+    fn sub(self, rhs: CrandallField) -> Self {
+        self - Self::broadcast(rhs)
+    }
+}
+impl SubAssign for PackedCrandallAVX2 {
+    #[inline]
+    fn sub_assign(&mut self, rhs: Self) {
+        *self = *self - rhs;
+    }
+}
+impl SubAssign<CrandallField> for PackedCrandallAVX2 {
+    #[inline]
+    fn sub_assign(&mut self, rhs: CrandallField) {
+        *self = *self - rhs;
+    }
+}
+
+impl Sum for PackedCrandallAVX2 {
+    #[inline]
+    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
+        iter.reduce(|x, y| x + y).unwrap_or(Self::zero())
+    }
+}
+
+const EPSILON: u64 = (1 << 31) + (1 << 28) - 1;
+const FIELD_ORDER: u64 = 0u64.overflowing_sub(EPSILON).0;
+const SIGN_BIT: u64 = 1 << 63;
+
+#[inline]
+unsafe fn field_order() -> __m256i {
+    _mm256_set1_epi64x(FIELD_ORDER as i64)
+}
+
+#[inline]
+unsafe fn epsilon() -> __m256i {
+    _mm256_set1_epi64x(EPSILON as i64)
+}
+
+#[inline]
+unsafe fn sign_bit() -> __m256i {
+    _mm256_set1_epi64x(SIGN_BIT as i64)
+}
+
+// Resources:
+// 1. Intel Intrinsics Guide for explanation of each intrinsic:
+//    https://software.intel.com/sites/landingpage/IntrinsicsGuide/
+// 2. uops.info lists micro-ops for each instruction: https://uops.info/table.html
+// 3. Intel optimization manual for introduction to x86 vector extensions and best practices:
+//    https://software.intel.com/content/www/us/en/develop/download/intel-64-and-ia-32-architectures-optimization-reference-manual.html
+
+// Preliminary knowledge:
+// 1. Vector code usually avoids branching. Instead of branches, we can do input selection with
+//    _mm256_blendv_epi8 or similar instruction. If all we're doing is conditionally zeroing a
+//    vector element then _mm256_and_si256 or _mm256_andnot_si256 may be used and are cheaper.
+//
+// 2. AVX does not support addition with carry but 128-bit (2-word) addition can be easily
+//    emulated. The method relies on the fact that a + b overflows iff (a + b) < a:
+//    i. res_lo = a_lo + b_lo
+//    ii. carry_mask = res_lo < a_lo
+//    iii. res_hi = a_hi + b_hi - carry_mask
+//    Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
+//    return -1 (all bits 1) for true and 0 for false.
+//
+// 3. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
+//    by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps
+//    around and <u, <s denote unsigned and signed comparison respectively. The shift function
+//    below performs that addition (an XOR with 1 << 63); values carrying it are marked with the
+//    _s suffix.
+
+#[inline]
+unsafe fn shift(x: __m256i) -> __m256i {
+    _mm256_xor_si256(x, sign_bit())
+}
+
+/// Convert to canonical representation.
+/// The argument is assumed to be shifted by 1 << 63 (i.e. x_s = x + 1<<63, where x is the
+/// Crandall field value). The returned value is similarly shifted by 1 << 63 (i.e. we return y_s
+/// = y + 1<<63, where 0 <= y < FIELD_ORDER).
+#[inline]
+unsafe fn canonicalize_s(x_s: __m256i) -> __m256i {
+    // If x >= FIELD_ORDER then corresponding mask bits are all 0; otherwise all 1.
+    let mask = _mm256_cmpgt_epi64(shift(field_order()), x_s);
+    // wrapback_amt is -FIELD_ORDER if mask is 0; otherwise 0.
+    let wrapback_amt = _mm256_andnot_si256(mask, epsilon());
+    _mm256_add_epi64(x_s, wrapback_amt)
+}
+
+// Theoretical throughput (Skylake)
+// Scalar version (compiled): 1.75 cycles/(op * word)
+// Scalar version (optimized asm): 1 cycle/(op * word)
+// Below (256-bit vectors): .75 cycles/(op * word)
+#[inline]
+unsafe fn add(x: __m256i, y: __m256i) -> __m256i {
+    let mut y_s = shift(y);
+    y_s = canonicalize_s(y_s);
+    let res_wrapped_s = _mm256_add_epi64(x, y_s);
+    let mask = _mm256_cmpgt_epi64(y_s, res_wrapped_s); // -1 if overflowed else 0.
+    let res_wrapped = shift(res_wrapped_s);
+    let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if overflowed else 0.
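+    // If the addition wrapped past 2^64, adding EPSILON == 2^64 - FIELD_ORDER is the same as
+    // subtracting FIELD_ORDER mod 2^64. Since y_s was canonicalized above, this cannot overflow
+    // again.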
+    let res = _mm256_add_epi64(res_wrapped, wrapback_amt);
+    res
+}
+
+// Theoretical throughput (Skylake)
+// Scalar version (compiled): 1.75 cycles/(op * word)
+// Scalar version (optimized asm): 1 cycle/(op * word)
+// Below (256-bit vectors): .75 cycles/(op * word)
+#[inline]
+unsafe fn sub(x: __m256i, y: __m256i) -> __m256i {
+    let mut y_s = shift(y);
+    y_s = canonicalize_s(y_s);
+    let x_s = shift(x);
+    let mask = _mm256_cmpgt_epi64(y_s, x_s); // -1 if sub will underflow (y > x) else 0.
+    let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if underflow else 0.
+    let res_wrapped = _mm256_sub_epi64(x_s, y_s);
+    let res = _mm256_sub_epi64(res_wrapped, wrapback_amt);
+    res
+}
+
+// Theoretical throughput (Skylake)
+// Scalar version (compiled): 1 cycle/(op * word)
+// Scalar version (optimized asm): .5 cycles/(op * word)
+// Below (256-bit vectors): .42 cycles/(op * word)
+#[inline]
+unsafe fn neg(y: __m256i) -> __m256i {
+    let y_s = shift(y);
+    let field_order_s = shift(field_order());
+    let mask = _mm256_cmpgt_epi64(y_s, field_order_s); // -1 if sub will underflow (y > FIELD_ORDER) else 0.
+    let wrapback_amt = _mm256_and_si256(mask, epsilon()); // -FIELD_ORDER if underflow else 0.
+    let res_wrapped = _mm256_sub_epi64(field_order_s, y_s);
+    let res = _mm256_sub_epi64(res_wrapped, wrapback_amt);
+    res
+}
+
+/// Full 64-bit by 64-bit multiplication. This emulated multiplication is 1.5x slower than the
+/// scalar instruction, but may be worth it if we want our data to live in vector registers.
+#[inline]
+unsafe fn mul64_64_s(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
+    let x_hi = _mm256_srli_epi64(x, 32);
+    let y_hi = _mm256_srli_epi64(y, 32);
+    let mul_ll = _mm256_mul_epu32(x, y);
+    let mul_lh = _mm256_mul_epu32(x, y_hi);
+    let mul_hl = _mm256_mul_epu32(x_hi, y);
+    let mul_hh = _mm256_mul_epu32(x_hi, y_hi);
+
+    let res_lo0_s = shift(mul_ll);
+    let res_hi0 = mul_hh;
+
+    let res_hi1 = _mm256_add_epi64(res_hi0, _mm256_srli_epi64(mul_lh, 32));
+    let res_hi2 = _mm256_add_epi64(res_hi1, _mm256_srli_epi64(mul_hl, 32));
+
+    let res_lo3_s = _mm256_add_epi32(res_lo0_s, _mm256_slli_epi64(mul_lh, 32));
+    let res_hi3 = _mm256_sub_epi64(res_hi2, _mm256_cmpgt_epi64(res_lo0_s, res_lo3_s)); // Carry.
+
+    let res_lo4_s = _mm256_add_epi32(res_lo3_s, _mm256_slli_epi64(mul_hl, 32));
+    let res_hi4 = _mm256_sub_epi64(res_hi3, _mm256_cmpgt_epi64(res_lo3_s, res_lo4_s)); // Carry.
+
+    (res_hi4, res_lo4_s)
+}
+
+/// u128 + u64 addition with carry. The second argument is pre-shifted by 2^63. The low word of
+/// the result is similarly shifted.
+#[inline]
+unsafe fn add_with_carry128_64s_s(x: (__m256i, __m256i), y_s: __m256i) -> (__m256i, __m256i) {
+    let (x_hi, x_lo) = x;
+    let res_lo_s = _mm256_add_epi64(x_lo, y_s);
+    let carry = _mm256_cmpgt_epi64(y_s, res_lo_s);
+    let res_hi = _mm256_sub_epi64(x_hi, carry);
+    (res_hi, res_lo_s)
+}
+
+/// u128 + u64 addition with carry. The low word of the first argument is pre-shifted by 2^63. The
+/// low word of the result is similarly shifted.
+#[inline]
+unsafe fn add_with_carry128s_64_s(x_s: (__m256i, __m256i), y: __m256i) -> (__m256i, __m256i) {
+    let (x_hi, x_lo_s) = x_s;
+    let res_lo_s = _mm256_add_epi64(x_lo_s, y);
+    let carry = _mm256_cmpgt_epi64(x_lo_s, res_lo_s);
+    let res_hi = _mm256_sub_epi64(x_hi, carry);
+    (res_hi, res_lo_s)
+}
+
+/// u64 * u32 + u64 fused multiply-add. The result is given as a tuple (u64, u64). The third
+/// argument is assumed to be pre-shifted by 2^63. The low word of the result is similarly shifted.
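+/// Only the low 32 bits of each lane of y are used (_mm256_mul_epu32 ignores the high halves),
+/// which is why y is described as a u32.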
+#[inline]
+unsafe fn fmadd_64_32_64s_s(x: __m256i, y: __m256i, z_s: __m256i) -> (__m256i, __m256i) {
+    let x_hi = _mm256_srli_epi64(x, 32);
+    let mul_lo = _mm256_mul_epu32(x, y);
+    let mul_hi = _mm256_mul_epu32(x_hi, y);
+    let tmp_s = add_with_carry128_64s_s((_mm256_srli_epi64(mul_hi, 32), mul_lo), z_s);
+    add_with_carry128s_64_s(tmp_s, _mm256_slli_epi64(mul_hi, 32))
+}
+
+/// Reduce a u128 modulo FIELD_ORDER. The input is given as (hi, lo_s), with the low word
+/// pre-shifted by 2^63. The result is similarly shifted.
+#[inline]
+unsafe fn reduce128s_s(x_s: (__m256i, __m256i)) -> __m256i {
+    let (hi0, lo0_s) = x_s;
+    let (hi1, lo1_s) = fmadd_64_32_64s_s(hi0, epsilon(), lo0_s);
+    let lo2 = _mm256_mul_epu32(hi1, epsilon());
+    let res_wrapped_s = _mm256_add_epi64(lo1_s, lo2);
+    let carry_mask = _mm256_cmpgt_epi64(lo1_s, res_wrapped_s); // all 1 if overflow
+    let res_s = _mm256_add_epi64(res_wrapped_s, _mm256_and_si256(carry_mask, epsilon()));
+    res_s
+}
+
+/// Multiply two integers modulo FIELD_ORDER.
+#[inline]
+unsafe fn mul(x: __m256i, y: __m256i) -> __m256i {
+    shift(reduce128s_s(mul64_64_s(x, y)))
+}
+
+#[inline]
+unsafe fn interleave0(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
+    let a = _mm256_unpacklo_epi64(x, y);
+    let b = _mm256_unpackhi_epi64(x, y);
+    (a, b)
+}
+
+#[inline]
+unsafe fn interleave1(x: __m256i, y: __m256i) -> (__m256i, __m256i) {
+    let y_lo = _mm256_castsi256_si128(y); // This has 0 cost.
+
+    // 1 places y_lo in the high half of x; 0 would place it in the lower half.
+    let a = _mm256_inserti128_si256(x, y_lo, 1);
+    // NB: _mm256_permute2x128_si256 could be used here as well but _mm256_inserti128_si256 has
+    // lower latency on Zen 3 processors.
+
+    // Each nibble of the constant has the following semantics:
+    // 0 => src1[low 128 bits]
+    // 1 => src1[high 128 bits]
+    // 2 => src2[low 128 bits]
+    // 3 => src2[high 128 bits]
+    // The low (resp. high) nibble chooses the low (resp. high) 128 bits of the result.
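+    // Hence 0x31 selects the high 128 bits of x for the low half of the result and the high 128
+    // bits of y for its high half.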
+    let b = _mm256_permute2x128_si256(x, y, 0x31);
+
+    (a, b)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::field::crandall_field::CrandallField;
+    use crate::field::packed_crandall_avx2::*;
+    use crate::field::packed_field::PackedField;
+
+    const TEST_VALS_A: &[CrandallField] = &[
+        CrandallField(14479013849828404771),
+        CrandallField(9087029921428221768),
+        CrandallField(2441288194761790662),
+        CrandallField(5646033492608483824),
+    ];
+    const TEST_VALS_B: &[CrandallField] = &[
+        CrandallField(17891926589593242302),
+        CrandallField(11009798273260028228),
+        CrandallField(2028722748960791447),
+        CrandallField(7929433601095175579),
+    ];
+
+    #[test]
+    fn test_add() {
+        let packed_a = PackedCrandallAVX2::new_from_slice(TEST_VALS_A);
+        let packed_b = PackedCrandallAVX2::new_from_slice(TEST_VALS_B);
+        let packed_res = packed_a + packed_b;
+        let arr_res = packed_res.to_vec();
+
+        let expected = TEST_VALS_A
+            .iter()
+            .zip(TEST_VALS_B.iter())
+            .map(|(&a, &b)| a + b);
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    #[test]
+    fn test_mul() {
+        let packed_a = PackedCrandallAVX2::new_from_slice(TEST_VALS_A);
+        let packed_b = PackedCrandallAVX2::new_from_slice(TEST_VALS_B);
+        let packed_res = packed_a * packed_b;
+        let arr_res = packed_res.to_vec();
+
+        let expected = TEST_VALS_A
+            .iter()
+            .zip(TEST_VALS_B.iter())
+            .map(|(&a, &b)| a * b);
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    #[test]
+    fn test_neg() {
+        let packed_a = PackedCrandallAVX2::new_from_slice(TEST_VALS_A);
+        let packed_res = -packed_a;
+        let arr_res = packed_res.to_vec();
+
+        let expected = TEST_VALS_A.iter().map(|&a| -a);
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    #[test]
+    fn test_sub() {
+        let packed_a = PackedCrandallAVX2::new_from_slice(TEST_VALS_A);
+        let packed_b = PackedCrandallAVX2::new_from_slice(TEST_VALS_B);
+        let packed_res = packed_a - packed_b;
+        let arr_res = packed_res.to_vec();
+
+        let expected = TEST_VALS_A
+            .iter()
+            .zip(TEST_VALS_B.iter())
+            .map(|(&a, &b)| a - b);
+        for (exp, res) in expected.zip(arr_res) {
+            assert_eq!(res, exp);
+        }
+    }
+
+    #[test]
+    fn test_interleave_is_bijection() {
+        let packed_a = PackedCrandallAVX2::new_from_slice(TEST_VALS_A);
+        let packed_b = PackedCrandallAVX2::new_from_slice(TEST_VALS_B);
+        {
+            // Interleave, then deinterleave.
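+            // For a fixed r, interleave is an involution, so applying it twice must return the
+            // original operands.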
+            let (x, y) = packed_a.interleave(packed_b, 0);
+            let (res_a, res_b) = x.interleave(y, 0);
+            assert_eq!(res_a.to_vec(), TEST_VALS_A);
+            assert_eq!(res_b.to_vec(), TEST_VALS_B);
+        }
+        {
+            let (x, y) = packed_a.interleave(packed_b, 1);
+            let (res_a, res_b) = x.interleave(y, 1);
+            assert_eq!(res_a.to_vec(), TEST_VALS_A);
+            assert_eq!(res_b.to_vec(), TEST_VALS_B);
+        }
+    }
+
+    #[test]
+    fn test_interleave() {
+        let arr_a: [CrandallField; 4] = [
+            CrandallField(00),
+            CrandallField(01),
+            CrandallField(02),
+            CrandallField(03),
+        ];
+        let arr_b: [CrandallField; 4] = [
+            CrandallField(10),
+            CrandallField(11),
+            CrandallField(12),
+            CrandallField(13),
+        ];
+        let arr_x0: [CrandallField; 4] = [
+            CrandallField(00),
+            CrandallField(10),
+            CrandallField(02),
+            CrandallField(12),
+        ];
+        let arr_y0: [CrandallField; 4] = [
+            CrandallField(01),
+            CrandallField(11),
+            CrandallField(03),
+            CrandallField(13),
+        ];
+        let arr_x1: [CrandallField; 4] = [
+            CrandallField(00),
+            CrandallField(01),
+            CrandallField(10),
+            CrandallField(11),
+        ];
+        let arr_y1: [CrandallField; 4] = [
+            CrandallField(02),
+            CrandallField(03),
+            CrandallField(12),
+            CrandallField(13),
+        ];
+
+        let packed_a = PackedCrandallAVX2::new_from_slice(&arr_a);
+        let packed_b = PackedCrandallAVX2::new_from_slice(&arr_b);
+        {
+            let (x0, y0) = packed_a.interleave(packed_b, 0);
+            assert_eq!(x0.to_vec()[..], arr_x0);
+            assert_eq!(y0.to_vec()[..], arr_y0);
+        }
+        {
+            let (x1, y1) = packed_a.interleave(packed_b, 1);
+            assert_eq!(x1.to_vec()[..], arr_x1);
+            assert_eq!(y1.to_vec()[..], arr_y1);
+        }
+    }
+}
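+
+// Usage sketch (illustrative only; `vals` stands for any `[CrandallField; 4]` the caller already
+// has, and PackedField must be in scope):
+//
+//     let x = PackedCrandallAVX2::new_from_slice(&vals);
+//     let y = PackedCrandallAVX2::broadcast(CrandallField(3));
+//     let z = (x * y + y - x).to_vec(); // four field results from one set of vector ops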