Make all FFTs in-place (#439)

* Make all FFTs in-place

* Delete leftover marker
This commit is contained in:
Jakub Nabaglo 2022-01-21 10:26:43 -08:00 committed by GitHub
parent 2e3a682bde
commit 86dc4c933a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 45 additions and 44 deletions

View File

@ -1,7 +1,7 @@
use std::cmp::{max, min};
use std::option::Option;
use plonky2_util::{log2_strict, reverse_index_bits};
use plonky2_util::{log2_strict, reverse_index_bits_in_place};
use unroll::unroll_for_loops;
use crate::field_types::Field;
@ -34,10 +34,10 @@ pub fn fft_root_table<F: Field>(n: usize) -> FftRootTable<F> {
#[inline]
fn fft_dispatch<F: Field>(
input: &[F],
input: &mut [F],
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> Vec<F> {
) {
let computed_root_table = if root_table.is_some() {
None
} else {
@ -45,33 +45,32 @@ fn fft_dispatch<F: Field>(
};
let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap();
fft_classic(input, zero_factor.unwrap_or(0), used_root_table)
fft_classic(input, zero_factor.unwrap_or(0), used_root_table);
}
#[inline]
pub fn fft<F: Field>(poly: &PolynomialCoeffs<F>) -> PolynomialValues<F> {
pub fn fft<F: Field>(poly: PolynomialCoeffs<F>) -> PolynomialValues<F> {
fft_with_options(poly, None, None)
}
#[inline]
pub fn fft_with_options<F: Field>(
poly: &PolynomialCoeffs<F>,
poly: PolynomialCoeffs<F>,
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
let PolynomialCoeffs { coeffs } = poly;
PolynomialValues {
values: fft_dispatch(coeffs, zero_factor, root_table),
}
let PolynomialCoeffs { coeffs: mut buffer } = poly;
fft_dispatch(&mut buffer, zero_factor, root_table);
PolynomialValues { values: buffer }
}
#[inline]
pub fn ifft<F: Field>(poly: &PolynomialValues<F>) -> PolynomialCoeffs<F> {
pub fn ifft<F: Field>(poly: PolynomialValues<F>) -> PolynomialCoeffs<F> {
ifft_with_options(poly, None, None)
}
pub fn ifft_with_options<F: Field>(
poly: &PolynomialValues<F>,
poly: PolynomialValues<F>,
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialCoeffs<F> {
@ -79,20 +78,20 @@ pub fn ifft_with_options<F: Field>(
let lg_n = log2_strict(n);
let n_inv = F::inverse_2exp(lg_n);
let PolynomialValues { values } = poly;
let mut coeffs = fft_dispatch(values, zero_factor, root_table);
let PolynomialValues { values: mut buffer } = poly;
fft_dispatch(&mut buffer, zero_factor, root_table);
// We reverse all values except the first, and divide each by n.
coeffs[0] *= n_inv;
coeffs[n / 2] *= n_inv;
buffer[0] *= n_inv;
buffer[n / 2] *= n_inv;
for i in 1..(n / 2) {
let j = n - i;
let coeffs_i = coeffs[j] * n_inv;
let coeffs_j = coeffs[i] * n_inv;
coeffs[i] = coeffs_i;
coeffs[j] = coeffs_j;
let coeffs_i = buffer[j] * n_inv;
let coeffs_j = buffer[i] * n_inv;
buffer[i] = coeffs_i;
buffer[j] = coeffs_j;
}
PolynomialCoeffs { coeffs }
PolynomialCoeffs { coeffs: buffer }
}
/// Generic FFT implementation that works with both scalar and packed inputs.
@ -167,8 +166,8 @@ fn fft_classic_simd<P: PackedField>(
/// The parameter r signifies that the first 1/2^r of the entries of
/// the input slice may be non-zero, but the last 1 - 1/2^r entries are
/// definitely zero.
pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: &FftRootTable<F>) -> Vec<F> {
let mut values = reverse_index_bits(input);
pub(crate) fn fft_classic<F: Field>(values: &mut [F], r: usize, root_table: &FftRootTable<F>) {
reverse_index_bits_in_place(values);
let n = values.len();
let lg_n = log2_strict(n);
@ -200,11 +199,10 @@ pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: &FftRootT
if lg_n <= lg_packed_width {
// Need the slice to be at least the width of two packed vectors for the vectorized version
// to work. Do this tiny problem in scalar.
fft_classic_simd::<F>(&mut values[..], r, lg_n, root_table);
fft_classic_simd::<F>(values, r, lg_n, root_table);
} else {
fft_classic_simd::<<F as Packable>::Packing>(&mut values[..], r, lg_n, root_table);
fft_classic_simd::<<F as Packable>::Packing>(values, r, lg_n, root_table);
}
values
}
#[cfg(test)]
@ -231,10 +229,10 @@ mod tests {
assert_eq!(coeffs.len(), degree_padded);
let coefficients = PolynomialCoeffs { coeffs };
let points = fft(&coefficients);
let points = fft(coefficients.clone());
assert_eq!(points, evaluate_naive(&coefficients));
let interpolated_coefficients = ifft(&points);
let interpolated_coefficients = ifft(points);
for i in 0..degree {
assert_eq!(interpolated_coefficients.coeffs[i], coefficients.coeffs[i]);
}
@ -245,7 +243,10 @@ mod tests {
for r in 0..4 {
// expand coefficients by factor 2^r by filling with zeros
let zero_tail = coefficients.lde(r);
assert_eq!(fft(&zero_tail), fft_with_options(&zero_tail, Some(r), None));
assert_eq!(
fft(zero_tail.clone()),
fft_with_options(zero_tail, Some(r), None)
);
}
}

View File

@ -19,7 +19,7 @@ pub fn interpolant<F: Field>(points: &[(F, F)]) -> PolynomialCoeffs<F> {
.map(|x| interpolate(points, x, &barycentric_weights))
.collect();
let mut coeffs = ifft(&PolynomialValues {
let mut coeffs = ifft(PolynomialValues {
values: subgroup_evals,
});
coeffs.trim();

View File

@ -31,12 +31,12 @@ impl<F: Field> PolynomialValues<F> {
self.values.len()
}
pub fn ifft(&self) -> PolynomialCoeffs<F> {
pub fn ifft(self) -> PolynomialCoeffs<F> {
ifft(self)
}
/// Returns the polynomial whose evaluation on the coset `shift*H` is `self`.
pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs<F> {
pub fn coset_ifft(self, shift: F) -> PolynomialCoeffs<F> {
let mut shifted_coeffs = self.ifft();
shifted_coeffs
.coeffs
@ -52,9 +52,9 @@ impl<F: Field> PolynomialValues<F> {
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
}
pub fn lde(&self, rate_bits: usize) -> Self {
pub fn lde(self, rate_bits: usize) -> Self {
let coeffs = ifft(self).lde(rate_bits);
fft_with_options(&coeffs, Some(rate_bits), None)
fft_with_options(coeffs, Some(rate_bits), None)
}
pub fn degree(&self) -> usize {
@ -64,7 +64,7 @@ impl<F: Field> PolynomialValues<F> {
}
pub fn degree_plus_one(&self) -> usize {
self.ifft().degree_plus_one()
self.clone().ifft().degree_plus_one()
}
}
@ -213,12 +213,12 @@ impl<F: Field> PolynomialCoeffs<F> {
Self::new(self.trimmed().coeffs.into_iter().rev().collect())
}
pub fn fft(&self) -> PolynomialValues<F> {
pub fn fft(self) -> PolynomialValues<F> {
fft(self)
}
pub fn fft_with_options(
&self,
self,
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
@ -386,7 +386,7 @@ impl<F: Field> Mul for &PolynomialCoeffs<F> {
.zip(b_evals.values)
.map(|(pa, pb)| pa * pb)
.collect();
ifft(&mul_evals.into())
ifft(mul_evals.into())
}
}
@ -454,7 +454,7 @@ mod tests {
let n = 1 << k;
let evals = PolynomialValues::new(F::rand_vec(n));
let shift = F::rand();
let coeffs = evals.coset_ifft(shift);
let coeffs = evals.clone().coset_ifft(shift);
let generator = F::primitive_root_of_unity(k);
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)

View File

@ -11,7 +11,7 @@ pub(crate) fn bench_ffts<F: Field>(c: &mut Criterion) {
let size = 1 << size_log;
group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| {
let coeffs = PolynomialCoeffs::new(F::rand_vec(size));
b.iter(|| coeffs.fft_with_options(None, None));
b.iter(|| coeffs.clone().fft_with_options(None, None));
});
}
}

View File

@ -47,7 +47,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
let coeffs = timed!(
timing,
"IFFT",
values.par_iter().map(|v| v.ifft()).collect::<Vec<_>>()
values.into_par_iter().map(|v| v.ifft()).collect::<Vec<_>>()
);
Self::from_coeffs(

View File

@ -80,7 +80,7 @@ fn reverse_index_bits_large<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
result
}
pub fn reverse_index_bits_in_place<T>(arr: &mut Vec<T>) {
pub fn reverse_index_bits_in_place<T>(arr: &mut [T]) {
let n = arr.len();
let n_power = log2_strict(n);
@ -101,7 +101,7 @@ pub fn reverse_index_bits_in_place<T>(arr: &mut Vec<T>) {
where reverse_bits(src, n_power) computes the n_power-bit reverse.
*/
fn reverse_index_bits_in_place_small<T>(arr: &mut Vec<T>, n_power: usize) {
fn reverse_index_bits_in_place_small<T>(arr: &mut [T], n_power: usize) {
let n = arr.len();
// BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
let dst_shr_amt = 6 - n_power;
@ -113,7 +113,7 @@ fn reverse_index_bits_in_place_small<T>(arr: &mut Vec<T>, n_power: usize) {
}
}
fn reverse_index_bits_in_place_large<T>(arr: &mut Vec<T>, n_power: usize) {
fn reverse_index_bits_in_place_large<T>(arr: &mut [T], n_power: usize) {
let n = arr.len();
// LLVM does not know that it does not need to reverse src at each iteration (which is expensive
// on x86). We take advantage of the fact that the low bits of dst change rarely and the high