diff --git a/src/field/fft.rs b/src/field/fft.rs
index 4a60f888..46aa3981 100644
--- a/src/field/fft.rs
+++ b/src/field/fft.rs
@@ -1,6 +1,11 @@
+use std::cmp::{max, min};
 use std::option::Option;
 
+use unroll::unroll_for_loops;
+
 use crate::field::field_types::Field;
+use crate::field::packable::Packable;
+use crate::field::packed_field::{PackedField, Singleton};
 use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
 use crate::util::{log2_strict, reverse_index_bits};
 
@@ -137,6 +142,73 @@ pub fn ifft_with_options<F: Field>(
     PolynomialCoeffs { coeffs }
 }
 
+/// Generic FFT implementation that works with both scalar and packed inputs.
+#[unroll_for_loops]
+fn fft_classic_simd<P: PackedField>(
+    values: &mut [P::FieldType],
+    r: usize,
+    lg_n: usize,
+    root_table: &FftRootTable<P::FieldType>,
+) {
+    let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar.
+    let packed_values = P::pack_slice_mut(values);
+    let packed_n = packed_values.len();
+    debug_assert!(packed_n == 1 << (lg_n - lg_packed_width));
+
+    // Want the below for loop to unroll, hence the need for a literal.
+    // This loop will not run when P is a scalar.
+    assert!(lg_packed_width <= 4);
+    for lg_half_m in 0..4 {
+        if (r..min(lg_n, lg_packed_width)).contains(&lg_half_m) {
+            // Intuitively, we split values into m slices: subarr[0], ..., subarr[m - 1]. Each of
+            // those slices is split into two halves: subarr[j].left, subarr[j].right. We do
+            //     (subarr[j].left[k], subarr[j].right[k])
+            //         := f(subarr[j].left[k], subarr[j].right[k], omega[k]),
+            // where f(u, v, omega) = (u + omega * v, u - omega * v).
+            let half_m = 1 << lg_half_m;
+
+            // Set omega to root_table[lg_half_m][0..half_m] but repeated.
+            let mut omega_vec = P::zero().to_vec();
+            for j in 0..omega_vec.len() {
+                omega_vec[j] = root_table[lg_half_m][j % half_m];
+            }
+            let omega = P::from_slice(&omega_vec[..]);
+
+            for k in (0..packed_n).step_by(2) {
+                // We have two vectors and want to do math on pairs of adjacent elements (or for
+                // lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the
+                // appropriate shuffling and is its own inverse.
+                let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m);
+                let t = omega * v;
+                (packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m);
+            }
+        }
+    }
+
+    // We've already done the first lg_packed_width (if they were required) iterations.
+    let s = max(r, lg_packed_width);
+
+    for lg_half_m in s..lg_n {
+        let lg_m = lg_half_m + 1;
+        let m = 1 << lg_m; // Subarray size (in field elements).
+        let packed_m = m >> lg_packed_width; // Subarray size (in vectors).
+        let half_packed_m = packed_m / 2;
+        debug_assert!(half_packed_m != 0);
+
+        // omega values for this iteration, as slice of vectors
+        let omega_table = P::pack_slice(&root_table[lg_half_m][..]);
+        for k in (0..packed_n).step_by(packed_m) {
+            for j in 0..half_packed_m {
+                let omega = omega_table[j];
+                let t = omega * packed_values[k + half_packed_m + j];
+                let u = packed_values[k + j];
+                packed_values[k + j] = u + t;
+                packed_values[k + half_packed_m + j] = u - t;
+            }
+        }
+    }
+}
+
 /// FFT implementation based on Section 32.3 of "Introduction to
 /// Algorithms" by Cormen et al.
 ///
@@ -172,19 +244,13 @@ pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: FftRootTa
         }
     }
 
-    let mut m = 1 << (r + 1);
-    for lg_m in (r + 1)..=lg_n {
-        let half_m = m / 2;
-        for k in (0..n).step_by(m) {
-            for j in 0..half_m {
-                let omega = root_table[lg_m - 1][j];
-                let t = omega * values[k + half_m + j];
-                let u = values[k + j];
-                values[k + j] = u + t;
-                values[k + half_m + j] = u - t;
-            }
-        }
-        m *= 2;
+    let lg_packed_width = <F as Packable>::PackedType::LOG2_WIDTH;
+    if lg_n <= lg_packed_width {
+        // Need the slice to be at least the width of two packed vectors for the vectorized version
+        // to work. Do this tiny problem in scalar.
+        fft_classic_simd::<Singleton<F>>(&mut values[..], r, lg_n, &root_table);
+    } else {
+        fft_classic_simd::<<F as Packable>::PackedType>(&mut values[..], r, lg_n, &root_table);
     }
     values
 }
diff --git a/src/field/packed_crandall_avx2.rs b/src/field/packed_crandall_avx2.rs
index 59315e32..4eb900e7 100644
--- a/src/field/packed_crandall_avx2.rs
+++ b/src/field/packed_crandall_avx2.rs
@@ -151,6 +151,22 @@ impl PackedField for PackedCrandallAVX2 {
         ]
     }
 
+    #[inline]
+    fn from_slice(slice: &[Self::FieldType]) -> Self {
+        assert!(slice.len() == 4);
+        Self::from_arr([slice[0], slice[1], slice[2], slice[3]])
+    }
+
+    #[inline]
+    fn to_vec(&self) -> Vec<Self::FieldType> {
+        vec![
+            CrandallField(self.0[0]),
+            CrandallField(self.0[1]),
+            CrandallField(self.0[2]),
+            CrandallField(self.0[3]),
+        ]
+    }
+
     #[inline]
     fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
         let (v0, v1) = (self.get(), other.get());
diff --git a/src/field/packed_field.rs b/src/field/packed_field.rs
index 15ee5a9e..a4b1945a 100644
--- a/src/field/packed_field.rs
+++ b/src/field/packed_field.rs
@@ -50,6 +50,9 @@ pub trait PackedField:
     fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self;
     fn to_arr(&self) -> [Self::FieldType; Self::WIDTH];
+    fn from_slice(slice: &[Self::FieldType]) -> Self;
+    fn to_vec(&self) -> Vec<Self::FieldType>;
+
     /// Take interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those
     /// chunks. This is best seen with an example. If we have:
     ///     A = [x0, y0, x1, y1],
     ///     B = [x2, y2, x3, y3],
@@ -183,6 +186,15 @@ impl<F: Field> PackedField for Singleton<F> {
         [self.0]
     }
 
+    fn from_slice(slice: &[Self::FieldType]) -> Self {
+        assert!(slice.len() == 1);
+        Self(slice[0])
+    }
+
+    fn to_vec(&self) -> Vec<Self::FieldType> {
+        vec![self.0]
+    }
+
     fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
         match r {
             0 => (*self, other), // This is a no-op whenever r == LOG2_WIDTH.
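
For reviewers who want to sanity-check the butterfly structure outside the crate, here is a minimal, self-contained sketch of the scalar radix-2 pass that the removed loop in `fft_classic` performed, and that `fft_classic_simd` now applies to packed vectors. It is not part of the diff: `butterfly_pass` and `MODULUS` are invented names, the field is replaced by `u64` arithmetic modulo the toy prime 17, and the round offset `r` is assumed to be 0.

const MODULUS: u64 = 17; // Toy prime; 2 is a primitive 8th root of unity mod 17.

/// One FFT round over subarrays of size m = 2 * half_m: combine the two halves
/// of each subarray with f(u, v, omega) = (u + omega * v, u - omega * v).
fn butterfly_pass(values: &mut [u64], half_m: usize, omegas: &[u64]) {
    let m = 2 * half_m;
    for k in (0..values.len()).step_by(m) {
        for j in 0..half_m {
            let omega = omegas[j];
            let t = omega * values[k + half_m + j] % MODULUS;
            let u = values[k + j];
            values[k + j] = (u + t) % MODULUS;
            values[k + half_m + j] = (u + MODULUS - t) % MODULUS;
        }
    }
}

fn main() {
    // Length-8 input, already in bit-reversed order (the decimation-in-time FFT
    // consumes its input that way and emits the transform in natural order).
    let mut values = [1u64, 5, 3, 7, 2, 6, 4, 8];
    // root_table[lg_half_m][j] = w^(j * 8 / (2 * half_m)) with w = 2, mirroring
    // the `root_table[lg_m - 1][j]` lookup in the removed scalar loop.
    let root_table: [Vec<u64>; 3] = [vec![1], vec![1, 4], vec![1, 2, 4, 8]];
    for lg_half_m in 0..3 {
        butterfly_pass(&mut values, 1 << lg_half_m, &root_table[lg_half_m]);
    }
    println!("{:?}", values); // Transform of the bit-reversed input, mod 17.
}

The vectorized path performs the same operation: the second loop in `fft_classic_simd` is this pass with field elements handled a packed vector at a time, and the first (unrolled) loop covers the rounds where `half_m` is smaller than a vector by shuffling lanes with `interleave` before and after the butterfly.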