mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-03 06:13:07 +00:00
Make all FFTs in-place (#439)
* Make all FFTs in-place * Delete leftover marker
This commit is contained in:
parent
2e3a682bde
commit
86dc4c933a
@ -1,7 +1,7 @@
|
||||
use std::cmp::{max, min};
|
||||
use std::option::Option;
|
||||
|
||||
use plonky2_util::{log2_strict, reverse_index_bits};
|
||||
use plonky2_util::{log2_strict, reverse_index_bits_in_place};
|
||||
use unroll::unroll_for_loops;
|
||||
|
||||
use crate::field_types::Field;
|
||||
@ -34,10 +34,10 @@ pub fn fft_root_table<F: Field>(n: usize) -> FftRootTable<F> {
|
||||
|
||||
#[inline]
|
||||
fn fft_dispatch<F: Field>(
|
||||
input: &[F],
|
||||
input: &mut [F],
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> Vec<F> {
|
||||
) {
|
||||
let computed_root_table = if root_table.is_some() {
|
||||
None
|
||||
} else {
|
||||
@ -45,33 +45,32 @@ fn fft_dispatch<F: Field>(
|
||||
};
|
||||
let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap();
|
||||
|
||||
fft_classic(input, zero_factor.unwrap_or(0), used_root_table)
|
||||
fft_classic(input, zero_factor.unwrap_or(0), used_root_table);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn fft<F: Field>(poly: &PolynomialCoeffs<F>) -> PolynomialValues<F> {
|
||||
pub fn fft<F: Field>(poly: PolynomialCoeffs<F>) -> PolynomialValues<F> {
|
||||
fft_with_options(poly, None, None)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn fft_with_options<F: Field>(
|
||||
poly: &PolynomialCoeffs<F>,
|
||||
poly: PolynomialCoeffs<F>,
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> PolynomialValues<F> {
|
||||
let PolynomialCoeffs { coeffs } = poly;
|
||||
PolynomialValues {
|
||||
values: fft_dispatch(coeffs, zero_factor, root_table),
|
||||
}
|
||||
let PolynomialCoeffs { coeffs: mut buffer } = poly;
|
||||
fft_dispatch(&mut buffer, zero_factor, root_table);
|
||||
PolynomialValues { values: buffer }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn ifft<F: Field>(poly: &PolynomialValues<F>) -> PolynomialCoeffs<F> {
|
||||
pub fn ifft<F: Field>(poly: PolynomialValues<F>) -> PolynomialCoeffs<F> {
|
||||
ifft_with_options(poly, None, None)
|
||||
}
|
||||
|
||||
pub fn ifft_with_options<F: Field>(
|
||||
poly: &PolynomialValues<F>,
|
||||
poly: PolynomialValues<F>,
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> PolynomialCoeffs<F> {
|
||||
@ -79,20 +78,20 @@ pub fn ifft_with_options<F: Field>(
|
||||
let lg_n = log2_strict(n);
|
||||
let n_inv = F::inverse_2exp(lg_n);
|
||||
|
||||
let PolynomialValues { values } = poly;
|
||||
let mut coeffs = fft_dispatch(values, zero_factor, root_table);
|
||||
let PolynomialValues { values: mut buffer } = poly;
|
||||
fft_dispatch(&mut buffer, zero_factor, root_table);
|
||||
|
||||
// We reverse all values except the first, and divide each by n.
|
||||
coeffs[0] *= n_inv;
|
||||
coeffs[n / 2] *= n_inv;
|
||||
buffer[0] *= n_inv;
|
||||
buffer[n / 2] *= n_inv;
|
||||
for i in 1..(n / 2) {
|
||||
let j = n - i;
|
||||
let coeffs_i = coeffs[j] * n_inv;
|
||||
let coeffs_j = coeffs[i] * n_inv;
|
||||
coeffs[i] = coeffs_i;
|
||||
coeffs[j] = coeffs_j;
|
||||
let coeffs_i = buffer[j] * n_inv;
|
||||
let coeffs_j = buffer[i] * n_inv;
|
||||
buffer[i] = coeffs_i;
|
||||
buffer[j] = coeffs_j;
|
||||
}
|
||||
PolynomialCoeffs { coeffs }
|
||||
PolynomialCoeffs { coeffs: buffer }
|
||||
}
|
||||
|
||||
/// Generic FFT implementation that works with both scalar and packed inputs.
|
||||
@ -167,8 +166,8 @@ fn fft_classic_simd<P: PackedField>(
|
||||
/// The parameter r signifies that the first 1/2^r of the entries of
|
||||
/// input may be non-zero, but the last 1 - 1/2^r entries are
|
||||
/// definitely zero.
|
||||
pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: &FftRootTable<F>) -> Vec<F> {
|
||||
let mut values = reverse_index_bits(input);
|
||||
pub(crate) fn fft_classic<F: Field>(values: &mut [F], r: usize, root_table: &FftRootTable<F>) {
|
||||
reverse_index_bits_in_place(values);
|
||||
|
||||
let n = values.len();
|
||||
let lg_n = log2_strict(n);
|
||||
@ -200,11 +199,10 @@ pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: &FftRootT
|
||||
if lg_n <= lg_packed_width {
|
||||
// Need the slice to be at least the width of two packed vectors for the vectorized version
|
||||
// to work. Do this tiny problem in scalar.
|
||||
fft_classic_simd::<F>(&mut values[..], r, lg_n, root_table);
|
||||
fft_classic_simd::<F>(values, r, lg_n, root_table);
|
||||
} else {
|
||||
fft_classic_simd::<<F as Packable>::Packing>(&mut values[..], r, lg_n, root_table);
|
||||
fft_classic_simd::<<F as Packable>::Packing>(values, r, lg_n, root_table);
|
||||
}
|
||||
values
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -231,10 +229,10 @@ mod tests {
|
||||
assert_eq!(coeffs.len(), degree_padded);
|
||||
let coefficients = PolynomialCoeffs { coeffs };
|
||||
|
||||
let points = fft(&coefficients);
|
||||
let points = fft(coefficients.clone());
|
||||
assert_eq!(points, evaluate_naive(&coefficients));
|
||||
|
||||
let interpolated_coefficients = ifft(&points);
|
||||
let interpolated_coefficients = ifft(points);
|
||||
for i in 0..degree {
|
||||
assert_eq!(interpolated_coefficients.coeffs[i], coefficients.coeffs[i]);
|
||||
}
|
||||
@ -245,7 +243,10 @@ mod tests {
|
||||
for r in 0..4 {
|
||||
// expand coefficients by factor 2^r by filling with zeros
|
||||
let zero_tail = coefficients.lde(r);
|
||||
assert_eq!(fft(&zero_tail), fft_with_options(&zero_tail, Some(r), None));
|
||||
assert_eq!(
|
||||
fft(zero_tail.clone()),
|
||||
fft_with_options(zero_tail, Some(r), None)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ pub fn interpolant<F: Field>(points: &[(F, F)]) -> PolynomialCoeffs<F> {
|
||||
.map(|x| interpolate(points, x, &barycentric_weights))
|
||||
.collect();
|
||||
|
||||
let mut coeffs = ifft(&PolynomialValues {
|
||||
let mut coeffs = ifft(PolynomialValues {
|
||||
values: subgroup_evals,
|
||||
});
|
||||
coeffs.trim();
|
||||
|
||||
@ -31,12 +31,12 @@ impl<F: Field> PolynomialValues<F> {
|
||||
self.values.len()
|
||||
}
|
||||
|
||||
pub fn ifft(&self) -> PolynomialCoeffs<F> {
|
||||
pub fn ifft(self) -> PolynomialCoeffs<F> {
|
||||
ifft(self)
|
||||
}
|
||||
|
||||
/// Returns the polynomial whose evaluation on the coset `shift*H` is `self`.
|
||||
pub fn coset_ifft(&self, shift: F) -> PolynomialCoeffs<F> {
|
||||
pub fn coset_ifft(self, shift: F) -> PolynomialCoeffs<F> {
|
||||
let mut shifted_coeffs = self.ifft();
|
||||
shifted_coeffs
|
||||
.coeffs
|
||||
@ -52,9 +52,9 @@ impl<F: Field> PolynomialValues<F> {
|
||||
polys.into_iter().map(|p| p.lde(rate_bits)).collect()
|
||||
}
|
||||
|
||||
pub fn lde(&self, rate_bits: usize) -> Self {
|
||||
pub fn lde(self, rate_bits: usize) -> Self {
|
||||
let coeffs = ifft(self).lde(rate_bits);
|
||||
fft_with_options(&coeffs, Some(rate_bits), None)
|
||||
fft_with_options(coeffs, Some(rate_bits), None)
|
||||
}
|
||||
|
||||
pub fn degree(&self) -> usize {
|
||||
@ -64,7 +64,7 @@ impl<F: Field> PolynomialValues<F> {
|
||||
}
|
||||
|
||||
pub fn degree_plus_one(&self) -> usize {
|
||||
self.ifft().degree_plus_one()
|
||||
self.clone().ifft().degree_plus_one()
|
||||
}
|
||||
}
|
||||
|
||||
@ -213,12 +213,12 @@ impl<F: Field> PolynomialCoeffs<F> {
|
||||
Self::new(self.trimmed().coeffs.into_iter().rev().collect())
|
||||
}
|
||||
|
||||
pub fn fft(&self) -> PolynomialValues<F> {
|
||||
pub fn fft(self) -> PolynomialValues<F> {
|
||||
fft(self)
|
||||
}
|
||||
|
||||
pub fn fft_with_options(
|
||||
&self,
|
||||
self,
|
||||
zero_factor: Option<usize>,
|
||||
root_table: Option<&FftRootTable<F>>,
|
||||
) -> PolynomialValues<F> {
|
||||
@ -386,7 +386,7 @@ impl<F: Field> Mul for &PolynomialCoeffs<F> {
|
||||
.zip(b_evals.values)
|
||||
.map(|(pa, pb)| pa * pb)
|
||||
.collect();
|
||||
ifft(&mul_evals.into())
|
||||
ifft(mul_evals.into())
|
||||
}
|
||||
}
|
||||
|
||||
@ -454,7 +454,7 @@ mod tests {
|
||||
let n = 1 << k;
|
||||
let evals = PolynomialValues::new(F::rand_vec(n));
|
||||
let shift = F::rand();
|
||||
let coeffs = evals.coset_ifft(shift);
|
||||
let coeffs = evals.clone().coset_ifft(shift);
|
||||
|
||||
let generator = F::primitive_root_of_unity(k);
|
||||
let naive_coset_evals = F::cyclic_subgroup_coset_known_order(generator, shift, n)
|
||||
|
||||
@ -11,7 +11,7 @@ pub(crate) fn bench_ffts<F: Field>(c: &mut Criterion) {
|
||||
let size = 1 << size_log;
|
||||
group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| {
|
||||
let coeffs = PolynomialCoeffs::new(F::rand_vec(size));
|
||||
b.iter(|| coeffs.fft_with_options(None, None));
|
||||
b.iter(|| coeffs.clone().fft_with_options(None, None));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,7 +47,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
|
||||
let coeffs = timed!(
|
||||
timing,
|
||||
"IFFT",
|
||||
values.par_iter().map(|v| v.ifft()).collect::<Vec<_>>()
|
||||
values.into_par_iter().map(|v| v.ifft()).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
Self::from_coeffs(
|
||||
|
||||
@ -80,7 +80,7 @@ fn reverse_index_bits_large<T: Copy>(arr: &[T], n_power: usize) -> Vec<T> {
|
||||
result
|
||||
}
|
||||
|
||||
pub fn reverse_index_bits_in_place<T>(arr: &mut Vec<T>) {
|
||||
pub fn reverse_index_bits_in_place<T>(arr: &mut [T]) {
|
||||
let n = arr.len();
|
||||
let n_power = log2_strict(n);
|
||||
|
||||
@ -101,7 +101,7 @@ pub fn reverse_index_bits_in_place<T>(arr: &mut Vec<T>) {
|
||||
where reverse_bits(src, n_power) computes the n_power-bit reverse.
|
||||
*/
|
||||
|
||||
fn reverse_index_bits_in_place_small<T>(arr: &mut Vec<T>, n_power: usize) {
|
||||
fn reverse_index_bits_in_place_small<T>(arr: &mut [T], n_power: usize) {
|
||||
let n = arr.len();
|
||||
// BIT_REVERSE_6BIT holds 6-bit reverses. This shift makes them n_power-bit reverses.
|
||||
let dst_shr_amt = 6 - n_power;
|
||||
@ -113,7 +113,7 @@ fn reverse_index_bits_in_place_small<T>(arr: &mut Vec<T>, n_power: usize) {
|
||||
}
|
||||
}
|
||||
|
||||
fn reverse_index_bits_in_place_large<T>(arr: &mut Vec<T>, n_power: usize) {
|
||||
fn reverse_index_bits_in_place_large<T>(arr: &mut [T], n_power: usize) {
|
||||
let n = arr.len();
|
||||
// LLVM does not know that it does not need to reverse src at each iteration (which is expensive
|
||||
// on x86). We take advantage of the fact that the low bits of dst change rarely and the high
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user