From 86dc4c933ae28e4ba7a3da3c595051bc8b31de7a Mon Sep 17 00:00:00 2001 From: Jakub Nabaglo Date: Fri, 21 Jan 2022 10:26:43 -0800 Subject: [PATCH] Make all FFTs in-place (#439) * Make all FFTs in-place * Delete leftover marker --- field/src/fft.rs | 59 +++++++++++++++++++------------------ field/src/interpolation.rs | 2 +- field/src/polynomial/mod.rs | 18 +++++------ plonky2/benches/ffts.rs | 2 +- plonky2/src/fri/oracle.rs | 2 +- util/src/lib.rs | 6 ++-- 6 files changed, 45 insertions(+), 44 deletions(-) diff --git a/field/src/fft.rs b/field/src/fft.rs index 8428d3fb..c548d51e 100644 --- a/field/src/fft.rs +++ b/field/src/fft.rs @@ -1,7 +1,7 @@ use std::cmp::{max, min}; use std::option::Option; -use plonky2_util::{log2_strict, reverse_index_bits}; +use plonky2_util::{log2_strict, reverse_index_bits_in_place}; use unroll::unroll_for_loops; use crate::field_types::Field; @@ -34,10 +34,10 @@ pub fn fft_root_table(n: usize) -> FftRootTable { #[inline] fn fft_dispatch( - input: &[F], + input: &mut [F], zero_factor: Option, root_table: Option<&FftRootTable>, -) -> Vec { +) { let computed_root_table = if root_table.is_some() { None } else { @@ -45,33 +45,32 @@ fn fft_dispatch( }; let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap(); - fft_classic(input, zero_factor.unwrap_or(0), used_root_table) + fft_classic(input, zero_factor.unwrap_or(0), used_root_table); } #[inline] -pub fn fft(poly: &PolynomialCoeffs) -> PolynomialValues { +pub fn fft(poly: PolynomialCoeffs) -> PolynomialValues { fft_with_options(poly, None, None) } #[inline] pub fn fft_with_options( - poly: &PolynomialCoeffs, + poly: PolynomialCoeffs, zero_factor: Option, root_table: Option<&FftRootTable>, ) -> PolynomialValues { - let PolynomialCoeffs { coeffs } = poly; - PolynomialValues { - values: fft_dispatch(coeffs, zero_factor, root_table), - } + let PolynomialCoeffs { coeffs: mut buffer } = poly; + fft_dispatch(&mut buffer, zero_factor, root_table); + PolynomialValues { values: buffer } } #[inline] -pub fn ifft(poly: &PolynomialValues) -> PolynomialCoeffs { +pub fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { ifft_with_options(poly, None, None) } pub fn ifft_with_options( - poly: &PolynomialValues, + poly: PolynomialValues, zero_factor: Option, root_table: Option<&FftRootTable>, ) -> PolynomialCoeffs { @@ -79,20 +78,20 @@ pub fn ifft_with_options( let lg_n = log2_strict(n); let n_inv = F::inverse_2exp(lg_n); - let PolynomialValues { values } = poly; - let mut coeffs = fft_dispatch(values, zero_factor, root_table); + let PolynomialValues { values: mut buffer } = poly; + fft_dispatch(&mut buffer, zero_factor, root_table); // We reverse all values except the first, and divide each by n. - coeffs[0] *= n_inv; - coeffs[n / 2] *= n_inv; + buffer[0] *= n_inv; + buffer[n / 2] *= n_inv; for i in 1..(n / 2) { let j = n - i; - let coeffs_i = coeffs[j] * n_inv; - let coeffs_j = coeffs[i] * n_inv; - coeffs[i] = coeffs_i; - coeffs[j] = coeffs_j; + let coeffs_i = buffer[j] * n_inv; + let coeffs_j = buffer[i] * n_inv; + buffer[i] = coeffs_i; + buffer[j] = coeffs_j; } - PolynomialCoeffs { coeffs } + PolynomialCoeffs { coeffs: buffer } } /// Generic FFT implementation that works with both scalar and packed inputs. @@ -167,8 +166,8 @@ fn fft_classic_simd( /// The parameter r signifies that the first 1/2^r of the entries of /// input may be non-zero, but the last 1 - 1/2^r entries are /// definitely zero. -pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootTable) -> Vec { - let mut values = reverse_index_bits(input); +pub(crate) fn fft_classic(values: &mut [F], r: usize, root_table: &FftRootTable) { + reverse_index_bits_in_place(values); let n = values.len(); let lg_n = log2_strict(n); @@ -200,11 +199,10 @@ pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootT if lg_n <= lg_packed_width { // Need the slice to be at least the width of two packed vectors for the vectorized version // to work. Do this tiny problem in scalar. - fft_classic_simd::(&mut values[..], r, lg_n, root_table); + fft_classic_simd::(values, r, lg_n, root_table); } else { - fft_classic_simd::<::Packing>(&mut values[..], r, lg_n, root_table); + fft_classic_simd::<::Packing>(values, r, lg_n, root_table); } - values } #[cfg(test)] @@ -231,10 +229,10 @@ mod tests { assert_eq!(coeffs.len(), degree_padded); let coefficients = PolynomialCoeffs { coeffs }; - let points = fft(&coefficients); + let points = fft(coefficients.clone()); assert_eq!(points, evaluate_naive(&coefficients)); - let interpolated_coefficients = ifft(&points); + let interpolated_coefficients = ifft(points); for i in 0..degree { assert_eq!(interpolated_coefficients.coeffs[i], coefficients.coeffs[i]); } @@ -245,7 +243,10 @@ mod tests { for r in 0..4 { // expand coefficients by factor 2^r by filling with zeros let zero_tail = coefficients.lde(r); - assert_eq!(fft(&zero_tail), fft_with_options(&zero_tail, Some(r), None)); + assert_eq!( + fft(zero_tail.clone()), + fft_with_options(zero_tail, Some(r), None) + ); } } diff --git a/field/src/interpolation.rs b/field/src/interpolation.rs index ac6f6437..1a2e37df 100644 --- a/field/src/interpolation.rs +++ b/field/src/interpolation.rs @@ -19,7 +19,7 @@ pub fn interpolant(points: &[(F, F)]) -> PolynomialCoeffs { .map(|x| interpolate(points, x, &barycentric_weights)) .collect(); - let mut coeffs = ifft(&PolynomialValues { + let mut coeffs = ifft(PolynomialValues { values: subgroup_evals, }); coeffs.trim(); diff --git a/field/src/polynomial/mod.rs b/field/src/polynomial/mod.rs index 624e8212..4264c914 100644 --- a/field/src/polynomial/mod.rs +++ b/field/src/polynomial/mod.rs @@ -31,12 +31,12 @@ impl PolynomialValues { self.values.len() } - pub fn ifft(&self) -> PolynomialCoeffs { + pub fn ifft(self) -> PolynomialCoeffs { ifft(self) } /// Returns the polynomial whose evaluation on the coset `shift*H` is `self`. - pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs { + pub fn coset_ifft(self, shift: F) -> PolynomialCoeffs { let mut shifted_coeffs = self.ifft(); shifted_coeffs .coeffs @@ -52,9 +52,9 @@ impl PolynomialValues { polys.into_iter().map(|p| p.lde(rate_bits)).collect() } - pub fn lde(&self, rate_bits: usize) -> Self { + pub fn lde(self, rate_bits: usize) -> Self { let coeffs = ifft(self).lde(rate_bits); - fft_with_options(&coeffs, Some(rate_bits), None) + fft_with_options(coeffs, Some(rate_bits), None) } pub fn degree(&self) -> usize { @@ -64,7 +64,7 @@ impl PolynomialValues { } pub fn degree_plus_one(&self) -> usize { - self.ifft().degree_plus_one() + self.clone().ifft().degree_plus_one() } } @@ -213,12 +213,12 @@ impl PolynomialCoeffs { Self::new(self.trimmed().coeffs.into_iter().rev().collect()) } - pub fn fft(&self) -> PolynomialValues { + pub fn fft(self) -> PolynomialValues { fft(self) } pub fn fft_with_options( - &self, + self, zero_factor: Option, root_table: Option<&FftRootTable>, ) -> PolynomialValues { @@ -386,7 +386,7 @@ impl Mul for &PolynomialCoeffs { .zip(b_evals.values) .map(|(pa, pb)| pa * pb) .collect(); - ifft(&mul_evals.into()) + ifft(mul_evals.into()) } } @@ -454,7 +454,7 @@ mod tests { let n = 1 << k; let evals = PolynomialValues::new(F::rand_vec(n)); let shift = F::rand(); - let coeffs = evals.coset_ifft(shift); + let coeffs = evals.clone().coset_ifft(shift); let generator = F::primitive_root_of_unity(k); let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n) diff --git a/plonky2/benches/ffts.rs b/plonky2/benches/ffts.rs index cfa02a25..63ac9c85 100644 --- a/plonky2/benches/ffts.rs +++ b/plonky2/benches/ffts.rs @@ -11,7 +11,7 @@ pub(crate) fn bench_ffts(c: &mut Criterion) { let size = 1 << size_log; group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { let coeffs = PolynomialCoeffs::new(F::rand_vec(size)); - b.iter(|| coeffs.fft_with_options(None, None)); + b.iter(|| coeffs.clone().fft_with_options(None, None)); }); } } diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index c705e125..02db3140 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -47,7 +47,7 @@ impl, C: GenericConfig, const D: usize> let coeffs = timed!( timing, "IFFT", - values.par_iter().map(|v| v.ifft()).collect::>() + values.into_par_iter().map(|v| v.ifft()).collect::>() ); Self::from_coeffs( diff --git a/util/src/lib.rs b/util/src/lib.rs index 6dc32cb5..5c683a50 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -80,7 +80,7 @@ fn reverse_index_bits_large(arr: &[T], n_power: usize) -> Vec { result } -pub fn reverse_index_bits_in_place(arr: &mut Vec) { +pub fn reverse_index_bits_in_place(arr: &mut [T]) { let n = arr.len(); let n_power = log2_strict(n); @@ -101,7 +101,7 @@ pub fn reverse_index_bits_in_place(arr: &mut Vec) { where reverse_bits(src, n_power) computes the n_power-bit reverse. */ -fn reverse_index_bits_in_place_small(arr: &mut Vec, n_power: usize) { +fn reverse_index_bits_in_place_small(arr: &mut [T], n_power: usize) { let n = arr.len(); // BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses. let dst_shr_amt = 6 - n_power; @@ -113,7 +113,7 @@ fn reverse_index_bits_in_place_small(arr: &mut Vec, n_power: usize) { } } -fn reverse_index_bits_in_place_large(arr: &mut Vec, n_power: usize) { +fn reverse_index_bits_in_place_large(arr: &mut [T], n_power: usize) { let n = arr.len(); // LLVM does not know that it does not need to reverse src at each iteration (which is expensive // on x86). We take advantage of the fact that the low bits of dst change rarely and the high