Vectorized FFT (#223)

* Vectorized FFT

* Cleanup

* Use updated FieldPacking

* Use to_vec/from_slice (+ typo)

* Cleanup + Daniel's comments
This commit is contained in:
Jakub Nabaglo 2021-09-12 16:54:25 -07:00 committed by GitHub
parent bdd86a306f
commit a8d08aa153
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 107 additions and 13 deletions

View File

@ -1,6 +1,11 @@
use std::cmp::{max, min};
use std::option::Option;
use unroll::unroll_for_loops;
use crate::field::field_types::Field;
use crate::field::packable::Packable;
use crate::field::packed_field::{PackedField, Singleton};
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::util::{log2_strict, reverse_index_bits};
@ -137,6 +142,73 @@ pub fn ifft_with_options<F: Field>(
PolynomialCoeffs { coeffs }
}
/// Generic FFT implementation that works with both scalar and packed inputs.
#[unroll_for_loops]
fn fft_classic_simd<P: PackedField>(
values: &mut [P::FieldType],
r: usize,
lg_n: usize,
root_table: &FftRootTable<P::FieldType>,
) {
let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar.
let packed_values = P::pack_slice_mut(values);
let packed_n = packed_values.len();
debug_assert!(packed_n == 1 << (lg_n - lg_packed_width));
// Want the below for loop to unroll, hence the need for a literal.
// This loop will not run when P is a scalar.
assert!(lg_packed_width <= 4);
for lg_half_m in 0..4 {
if (r..min(lg_n, lg_packed_width)).contains(&lg_half_m) {
// Intuitively, we split values into m slices: subarr[0], ..., subarr[m - 1]. Each of
// those slices is split into two halves: subarr[j].left, subarr[j].right. We do
// (subarr[j].left[k], subarr[j].right[k])
// := f(subarr[j].left[k], subarr[j].right[k], omega[k]),
// where f(u, v, omega) = (u + omega * v, u - omega * v).
let half_m = 1 << lg_half_m;
// Set omega to root_table[lg_half_m][0..half_m] but repeated.
let mut omega_vec = P::zero().to_vec();
for j in 0..omega_vec.len() {
omega_vec[j] = root_table[lg_half_m][j % half_m];
}
let omega = P::from_slice(&omega_vec[..]);
for k in (0..packed_n).step_by(2) {
// We have two vectors and want to do math on pairs of adjacent elements (or for
// lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the
// appropriate shuffling and is its own inverse.
let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m);
let t = omega * v;
(packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m);
}
}
}
// We've already done the first lg_packed_width (if they were required) iterations.
let s = max(r, lg_packed_width);
for lg_half_m in s..lg_n {
let lg_m = lg_half_m + 1;
let m = 1 << lg_m; // Subarray size (in field elements).
let packed_m = m >> lg_packed_width; // Subarray size (in vectors).
let half_packed_m = packed_m / 2;
debug_assert!(half_packed_m != 0);
// omega values for this iteration, as slice of vectors
let omega_table = P::pack_slice(&root_table[lg_half_m][..]);
for k in (0..packed_n).step_by(packed_m) {
for j in 0..half_packed_m {
let omega = omega_table[j];
let t = omega * packed_values[k + half_packed_m + j];
let u = packed_values[k + j];
packed_values[k + j] = u + t;
packed_values[k + half_packed_m + j] = u - t;
}
}
}
}
/// FFT implementation based on Section 32.3 of "Introduction to
/// Algorithms" by Cormen et al.
///
@ -172,19 +244,13 @@ pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: FftRootTa
}
}
let mut m = 1 << (r + 1);
for lg_m in (r + 1)..=lg_n {
let half_m = m / 2;
for k in (0..n).step_by(m) {
for j in 0..half_m {
let omega = root_table[lg_m - 1][j];
let t = omega * values[k + half_m + j];
let u = values[k + j];
values[k + j] = u + t;
values[k + half_m + j] = u - t;
}
}
m *= 2;
let lg_packed_width = <F as Packable>::PackedType::LOG2_WIDTH;
if lg_n <= lg_packed_width {
// Need the slice to be at least the width of two packed vectors for the vectorized version
// to work. Do this tiny problem in scalar.
fft_classic_simd::<Singleton<F>>(&mut values[..], r, lg_n, &root_table);
} else {
fft_classic_simd::<<F as Packable>::PackedType>(&mut values[..], r, lg_n, &root_table);
}
values
}

View File

@ -151,6 +151,22 @@ impl PackedField for PackedCrandallAVX2 {
]
}
#[inline]
fn from_slice(slice: &[Self::FieldType]) -> Self {
assert!(slice.len() == 4);
Self::from_arr([slice[0], slice[1], slice[2], slice[3]])
}
#[inline]
fn to_vec(&self) -> Vec<Self::FieldType> {
vec![
CrandallField(self.0[0]),
CrandallField(self.0[1]),
CrandallField(self.0[2]),
CrandallField(self.0[3]),
]
}
#[inline]
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
let (v0, v1) = (self.get(), other.get());

View File

@ -50,6 +50,9 @@ pub trait PackedField:
fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self;
fn to_arr(&self) -> [Self::FieldType; Self::WIDTH];
fn from_slice(slice: &[Self::FieldType]) -> Self;
fn to_vec(&self) -> Vec<Self::FieldType>;
/// Take interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those
/// chunks. This is best seen with an example. If we have:
/// A = [x0, y0, x1, y1],
@ -183,6 +186,15 @@ impl<F: Field> PackedField for Singleton<F> {
[self.0]
}
fn from_slice(slice: &[Self::FieldType]) -> Self {
assert!(slice.len() == 1);
Self(slice[0])
}
fn to_vec(&self) -> Vec<Self::FieldType> {
vec![self.0]
}
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
match r {
0 => (*self, other), // This is a no-op whenever r == LOG2_WIDTH.