mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-08 08:43:06 +00:00
Vectorized FFT (#223)
* Vectorized FFT * Cleanup * Use updated FieldPacking * Use to_vec/from_slice (+ typo) * Cleanup + Daniel's comments
This commit is contained in:
parent
bdd86a306f
commit
a8d08aa153
@ -1,6 +1,11 @@
|
||||
use std::cmp::{max, min};
|
||||
use std::option::Option;
|
||||
|
||||
use unroll::unroll_for_loops;
|
||||
|
||||
use crate::field::field_types::Field;
|
||||
use crate::field::packable::Packable;
|
||||
use crate::field::packed_field::{PackedField, Singleton};
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::util::{log2_strict, reverse_index_bits};
|
||||
|
||||
@ -137,6 +142,73 @@ pub fn ifft_with_options<F: Field>(
|
||||
PolynomialCoeffs { coeffs }
|
||||
}
|
||||
|
||||
/// Generic FFT implementation that works with both scalar and packed inputs.
|
||||
#[unroll_for_loops]
|
||||
fn fft_classic_simd<P: PackedField>(
|
||||
values: &mut [P::FieldType],
|
||||
r: usize,
|
||||
lg_n: usize,
|
||||
root_table: &FftRootTable<P::FieldType>,
|
||||
) {
|
||||
let lg_packed_width = P::LOG2_WIDTH; // 0 when P is a scalar.
|
||||
let packed_values = P::pack_slice_mut(values);
|
||||
let packed_n = packed_values.len();
|
||||
debug_assert!(packed_n == 1 << (lg_n - lg_packed_width));
|
||||
|
||||
// Want the below for loop to unroll, hence the need for a literal.
|
||||
// This loop will not run when P is a scalar.
|
||||
assert!(lg_packed_width <= 4);
|
||||
for lg_half_m in 0..4 {
|
||||
if (r..min(lg_n, lg_packed_width)).contains(&lg_half_m) {
|
||||
// Intuitively, we split values into m slices: subarr[0], ..., subarr[m - 1]. Each of
|
||||
// those slices is split into two halves: subarr[j].left, subarr[j].right. We do
|
||||
// (subarr[j].left[k], subarr[j].right[k])
|
||||
// := f(subarr[j].left[k], subarr[j].right[k], omega[k]),
|
||||
// where f(u, v, omega) = (u + omega * v, u - omega * v).
|
||||
let half_m = 1 << lg_half_m;
|
||||
|
||||
// Set omega to root_table[lg_half_m][0..half_m] but repeated.
|
||||
let mut omega_vec = P::zero().to_vec();
|
||||
for j in 0..omega_vec.len() {
|
||||
omega_vec[j] = root_table[lg_half_m][j % half_m];
|
||||
}
|
||||
let omega = P::from_slice(&omega_vec[..]);
|
||||
|
||||
for k in (0..packed_n).step_by(2) {
|
||||
// We have two vectors and want to do math on pairs of adjacent elements (or for
|
||||
// lg_half_m > 0, pairs of adjacent blocks of elements). .interleave does the
|
||||
// appropriate shuffling and is its own inverse.
|
||||
let (u, v) = packed_values[k].interleave(packed_values[k + 1], lg_half_m);
|
||||
let t = omega * v;
|
||||
(packed_values[k], packed_values[k + 1]) = (u + t).interleave(u - t, lg_half_m);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We've already done the first lg_packed_width (if they were required) iterations.
|
||||
let s = max(r, lg_packed_width);
|
||||
|
||||
for lg_half_m in s..lg_n {
|
||||
let lg_m = lg_half_m + 1;
|
||||
let m = 1 << lg_m; // Subarray size (in field elements).
|
||||
let packed_m = m >> lg_packed_width; // Subarray size (in vectors).
|
||||
let half_packed_m = packed_m / 2;
|
||||
debug_assert!(half_packed_m != 0);
|
||||
|
||||
// omega values for this iteration, as slice of vectors
|
||||
let omega_table = P::pack_slice(&root_table[lg_half_m][..]);
|
||||
for k in (0..packed_n).step_by(packed_m) {
|
||||
for j in 0..half_packed_m {
|
||||
let omega = omega_table[j];
|
||||
let t = omega * packed_values[k + half_packed_m + j];
|
||||
let u = packed_values[k + j];
|
||||
packed_values[k + j] = u + t;
|
||||
packed_values[k + half_packed_m + j] = u - t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// FFT implementation based on Section 32.3 of "Introduction to
|
||||
/// Algorithms" by Cormen et al.
|
||||
///
|
||||
@ -172,19 +244,13 @@ pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: FftRootTa
|
||||
}
|
||||
}
|
||||
|
||||
let mut m = 1 << (r + 1);
|
||||
for lg_m in (r + 1)..=lg_n {
|
||||
let half_m = m / 2;
|
||||
for k in (0..n).step_by(m) {
|
||||
for j in 0..half_m {
|
||||
let omega = root_table[lg_m - 1][j];
|
||||
let t = omega * values[k + half_m + j];
|
||||
let u = values[k + j];
|
||||
values[k + j] = u + t;
|
||||
values[k + half_m + j] = u - t;
|
||||
}
|
||||
}
|
||||
m *= 2;
|
||||
let lg_packed_width = <F as Packable>::PackedType::LOG2_WIDTH;
|
||||
if lg_n <= lg_packed_width {
|
||||
// Need the slice to be at least the width of two packed vectors for the vectorized version
|
||||
// to work. Do this tiny problem in scalar.
|
||||
fft_classic_simd::<Singleton<F>>(&mut values[..], r, lg_n, &root_table);
|
||||
} else {
|
||||
fft_classic_simd::<<F as Packable>::PackedType>(&mut values[..], r, lg_n, &root_table);
|
||||
}
|
||||
values
|
||||
}
|
||||
|
||||
@ -151,6 +151,22 @@ impl PackedField for PackedCrandallAVX2 {
|
||||
]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_slice(slice: &[Self::FieldType]) -> Self {
|
||||
assert!(slice.len() == 4);
|
||||
Self::from_arr([slice[0], slice[1], slice[2], slice[3]])
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn to_vec(&self) -> Vec<Self::FieldType> {
|
||||
vec![
|
||||
CrandallField(self.0[0]),
|
||||
CrandallField(self.0[1]),
|
||||
CrandallField(self.0[2]),
|
||||
CrandallField(self.0[3]),
|
||||
]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
|
||||
let (v0, v1) = (self.get(), other.get());
|
||||
|
||||
@ -50,6 +50,9 @@ pub trait PackedField:
|
||||
fn from_arr(arr: [Self::FieldType; Self::WIDTH]) -> Self;
|
||||
fn to_arr(&self) -> [Self::FieldType; Self::WIDTH];
|
||||
|
||||
fn from_slice(slice: &[Self::FieldType]) -> Self;
|
||||
fn to_vec(&self) -> Vec<Self::FieldType>;
|
||||
|
||||
/// Take interpret two vectors as chunks of (1 << r) elements. Unpack and interleave those
|
||||
/// chunks. This is best seen with an example. If we have:
|
||||
/// A = [x0, y0, x1, y1],
|
||||
@ -183,6 +186,15 @@ impl<F: Field> PackedField for Singleton<F> {
|
||||
[self.0]
|
||||
}
|
||||
|
||||
fn from_slice(slice: &[Self::FieldType]) -> Self {
|
||||
assert!(slice.len() == 1);
|
||||
Self(slice[0])
|
||||
}
|
||||
|
||||
fn to_vec(&self) -> Vec<Self::FieldType> {
|
||||
vec![self.0]
|
||||
}
|
||||
|
||||
fn interleave(&self, other: Self, r: usize) -> (Self, Self) {
|
||||
match r {
|
||||
0 => (*self, other), // This is a no-op whenever r == LOG2_WIDTH.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user