From 574a3d4847fc3ca8506849c611bde55e639f9adf Mon Sep 17 00:00:00 2001 From: Hamish Ivey-Law <426294+unzvfu@users.noreply.github.com> Date: Thu, 1 Jul 2021 14:55:41 +1000 Subject: [PATCH] FFT improvements (#81) * Use built-in `reverse_bits`; remove duplicate `reverse_index_bits`. * Reduce precomputation time/space complexity from quadratic to linear. * Several working cache-friendly FFTs. * Fix to allow FFT of constant polynomial. * Simplify FFT strategy choice. * Add PrimeField and CHARACTERISTIC properties to Fields. * Add faster method for inverse of 2^m. * Pre-compute some of the roots; tidy up loop iteration. * Precomputation for both FFT variants. * Refactor precomputation; add optional parameters; rename some things. * Unrolled version with zero tail. * Iterative version of Unrolled precomputation. * Test zero tail algo. * Restore default degree. * Address comments from @dlubarov and @wborgeaud. --- src/field/crandall_field.rs | 3 + src/field/extension_field/quadratic.rs | 3 + src/field/extension_field/quartic.rs | 3 + src/field/fft.rs | 358 ++++++++++++++++++------- src/field/field.rs | 28 ++ src/field/field_testing.rs | 14 + src/polynomial/polynomial.rs | 6 +- src/util/mod.rs | 15 +- 8 files changed, 325 insertions(+), 105 deletions(-) diff --git a/src/field/crandall_field.rs b/src/field/crandall_field.rs index dbd29cb2..7a1d18d6 100644 --- a/src/field/crandall_field.rs +++ b/src/field/crandall_field.rs @@ -136,6 +136,8 @@ impl Debug for CrandallField { } impl Field for CrandallField { + type PrimeField = Self; + const ZERO: Self = Self(0); const ONE: Self = Self(1); const TWO: Self = Self(2); @@ -143,6 +145,7 @@ impl Field for CrandallField { const ORDER: u64 = 18446744071293632513; const TWO_ADICITY: usize = 28; + const CHARACTERISTIC: u64 = Self::ORDER; const MULTIPLICATIVE_GROUP_GENERATOR: Self = Self(5); const POWER_OF_TWO_GENERATOR: Self = Self(10281950781551402419); diff --git a/src/field/extension_field/quadratic.rs b/src/field/extension_field/quadratic.rs index af21ad60..ede2ef26 100644 --- a/src/field/extension_field/quadratic.rs +++ b/src/field/extension_field/quadratic.rs @@ -43,11 +43,14 @@ impl From<>::BaseField> for QuadraticCrandallField { } impl Field for QuadraticCrandallField { + type PrimeField = CrandallField; + const ZERO: Self = Self([CrandallField::ZERO; 2]); const ONE: Self = Self([CrandallField::ONE, CrandallField::ZERO]); const TWO: Self = Self([CrandallField::TWO, CrandallField::ZERO]); const NEG_ONE: Self = Self([CrandallField::NEG_ONE, CrandallField::ZERO]); + const CHARACTERISTIC: u64 = CrandallField::ORDER; // Does not fit in 64-bits. const ORDER: u64 = 0; const TWO_ADICITY: usize = 29; diff --git a/src/field/extension_field/quartic.rs b/src/field/extension_field/quartic.rs index b93cbb56..f609eeb7 100644 --- a/src/field/extension_field/quartic.rs +++ b/src/field/extension_field/quartic.rs @@ -50,6 +50,8 @@ impl From<>::BaseField> for QuarticCrandallField { } impl Field for QuarticCrandallField { + type PrimeField = CrandallField; + const ZERO: Self = Self([CrandallField::ZERO; 4]); const ONE: Self = Self([ CrandallField::ONE, @@ -70,6 +72,7 @@ impl Field for QuarticCrandallField { CrandallField::ZERO, ]); + const CHARACTERISTIC: u64 = CrandallField::ORDER; // Does not fit in 64-bits. const ORDER: u64 = 0; const TWO_ADICITY: usize = 30; diff --git a/src/field/fft.rs b/src/field/fft.rs index 56764b47..af5c05a7 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -1,142 +1,304 @@ +use std::option::Option; + use crate::field::field::Field; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; -use crate::util::{log2_ceil, log2_strict}; +use crate::util::{log2_strict, reverse_index_bits}; -/// Permutes `arr` such that each index is mapped to its reverse in binary. -fn reverse_index_bits(arr: Vec) -> Vec { - let n = arr.len(); - let n_power = log2_strict(n); +// TODO: Should really do some "dynamic" dispatch to handle the +// different FFT algos rather than C-style enum dispatch. +enum FftStrategy { Classic, Unrolled } - let mut result = Vec::with_capacity(n); - for i in 0..n { - result.push(arr[reverse_bits(i, n_power)]); +const FFT_STRATEGY: FftStrategy = FftStrategy::Classic; + +type FftRootTable = Vec>; + +fn fft_classic_root_table(n: usize) -> FftRootTable { + let lg_n = log2_strict(n); + // bases[i] = g^2^i, for i = 0, ..., lg_n - 1 + let mut bases = Vec::with_capacity(lg_n); + let mut base = F::primitive_root_of_unity(lg_n); + bases.push(base); + for _ in 1..lg_n { + base = base.square(); // base = g^2^_ + bases.push(base); } - result -} -fn reverse_bits(n: usize, num_bits: usize) -> usize { - let mut result = 0; - for i in 0..num_bits { - let i_rev = num_bits - i - 1; - result |= (n >> i & 1) << i_rev; + let mut root_table = Vec::with_capacity(lg_n); + for lg_m in 1..=lg_n { + let half_m = 1 << (lg_m - 1); + let base = bases[lg_n - lg_m]; + let root_row = base.powers().take(half_m.max(2)).collect(); + root_table.push(root_row); } - result + root_table } -pub(crate) struct FftPrecomputation { - /// For each layer index i, stores the cyclic subgroup corresponding to the evaluation domain of - /// layer i. The indices within these subgroup vectors are bit-reversed. - subgroups_rev: Vec>, + +fn fft_unrolled_root_table(n: usize) -> FftRootTable { + // Precompute a table of the roots of unity used in the main + // loops. + + // Suppose n is the size of the outer vector and g is a primitive nth + // root of unity. Then the [lg(m) - 1][j] element of the table is + // g^{ n/2m * j } for j = 0..m-1 + + let lg_n = log2_strict(n); + // bases[i] = g^2^i, for i = 0, ..., lg_n - 2 + let mut bases = Vec::with_capacity(lg_n); + let mut base = F::primitive_root_of_unity(lg_n); + bases.push(base); + // NB: If n = 1, then lg_n is zero, so we can't do 1..(lg_n-1) here + for _ in 2..lg_n { + base = base.square(); // base = g^2^(_-1) + bases.push(base); + } + + let mut root_table = Vec::with_capacity(lg_n); + for lg_m in 1..lg_n { + let m = 1 << lg_m; + let base = bases[lg_n - lg_m - 1]; + let root_row = base.powers().take(m.max(2)).collect(); + root_table.push(root_row); + } + root_table } -impl FftPrecomputation { - pub fn size(&self) -> usize { - self.subgroups_rev.last().unwrap().len() +#[inline] +fn fft_dispatch( + input: Vec, + zero_factor: Option, + root_table: Option> +) -> Vec { + let n = input.len(); + match FFT_STRATEGY { + FftStrategy::Classic + => fft_classic(input, + zero_factor.unwrap_or(0), + root_table.unwrap_or_else(|| fft_classic_root_table(n))), + FftStrategy::Unrolled + => fft_unrolled(input, + zero_factor.unwrap_or(0), + root_table.unwrap_or_else(|| fft_unrolled_root_table(n))) } } +#[inline] pub fn fft(poly: PolynomialCoeffs) -> PolynomialValues { - let precomputation = fft_precompute(poly.len()); - fft_with_precomputation_power_of_2(poly, &precomputation) + fft_with_options(poly, None, None) } -pub(crate) fn fft_precompute(degree: usize) -> FftPrecomputation { - let degree_log = log2_ceil(degree); - - let mut subgroups_rev = Vec::new(); - let mut subgroup = F::two_adic_subgroup(degree_log); - for _i in 0..=degree_log { - let subsubgroup = subgroup.iter().step_by(2).copied().collect(); - let subgroup_rev = reverse_index_bits(subgroup); - subgroups_rev.push(subgroup_rev); - subgroup = subsubgroup; - } - subgroups_rev.reverse(); - - FftPrecomputation { subgroups_rev } +#[inline] +pub fn fft_with_options( + poly: PolynomialCoeffs, + zero_factor: Option, + root_table: Option> +) -> PolynomialValues { + let PolynomialCoeffs { coeffs } = poly; + PolynomialValues { values: fft_dispatch(coeffs, zero_factor, root_table) } } -pub(crate) fn ifft_with_precomputation_power_of_2( +#[inline] +pub fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { + ifft_with_options(poly, None, None) +} + +pub fn ifft_with_options( poly: PolynomialValues, - precomputation: &FftPrecomputation, + zero_factor: Option, + root_table: Option> ) -> PolynomialCoeffs { let n = poly.len(); - let n_inv = F::from_canonical_usize(n).try_inverse().unwrap(); + let lg_n = log2_strict(n); + let n_inv = F::inverse_2exp(lg_n); let PolynomialValues { values } = poly; - let PolynomialValues { values: mut result } = - fft_with_precomputation_power_of_2(PolynomialCoeffs { coeffs: values }, precomputation); + let mut coeffs = fft_dispatch(values, zero_factor, root_table); // We reverse all values except the first, and divide each by n. - result[0] *= n_inv; - result[n / 2] *= n_inv; + coeffs[0] *= n_inv; + coeffs[n / 2] *= n_inv; for i in 1..(n / 2) { let j = n - i; - let result_i = result[j] * n_inv; - let result_j = result[i] * n_inv; - result[i] = result_i; - result[j] = result_j; + let coeffs_i = coeffs[j] * n_inv; + let coeffs_j = coeffs[i] * n_inv; + coeffs[i] = coeffs_i; + coeffs[j] = coeffs_j; } - PolynomialCoeffs { coeffs: result } + PolynomialCoeffs { coeffs } } -pub(crate) fn fft_with_precomputation_power_of_2( - poly: PolynomialCoeffs, - precomputation: &FftPrecomputation, -) -> PolynomialValues { - debug_assert_eq!( - poly.len(), - precomputation.subgroups_rev.last().unwrap().len(), - "Number of coefficients does not match size of subgroup in precomputation" - ); +/// FFT implementation based on Section 32.3 of "Introduction to +/// Algorithms" by Cormen et al. +/// +/// The parameter r signifies that the first 1/2^r of the entries of +/// input may be non-zero, but the last 1 - 1/2^r entries are +/// definitely zero. +pub(crate) fn fft_classic( + input: Vec, + r: usize, + root_table: FftRootTable +) -> Vec { + let mut values = reverse_index_bits(input); - let half_degree = poly.len() >> 1; - let degree_log = poly.log_len(); + let n = values.len(); + let lg_n = log2_strict(n); - // In the base layer, we're just evaluating "degree 0 polynomials", i.e. the coefficients - // themselves. - let PolynomialCoeffs { coeffs } = poly; - let mut evaluations = reverse_index_bits(coeffs); + if root_table.len() != lg_n { + panic!("Expected root table of length {}, but it was {}.", lg_n, root_table.len()); + } - for i in 1..=degree_log { - // In layer i, we're evaluating a series of polynomials, each at 2^i points. In practice - // we evaluate a pair of points together, so we have 2^(i - 1) pairs. - let points_per_poly = 1 << i; - let pairs_per_poly = 1 << (i - 1); - - let mut new_evaluations = Vec::new(); - for pair_index in 0..half_degree { - let poly_index = pair_index / pairs_per_poly; - let pair_index_within_poly = pair_index % pairs_per_poly; - - let child_index_0 = poly_index * points_per_poly + pair_index_within_poly; - let child_index_1 = child_index_0 + pairs_per_poly; - - let even = evaluations[child_index_0]; - let odd = evaluations[child_index_1]; - - let point_0 = precomputation.subgroups_rev[i][pair_index_within_poly * 2]; - let product = point_0 * odd; - new_evaluations.push(even + product); - new_evaluations.push(even - product); + // After reverse_index_bits, the only non-zero elements of values + // are at indices i*2^r for i = 0..n/2^r. The loop below copies + // the value at i*2^r to the positions [i*2^r + 1, i*2^r + 2, ..., + // (i+1)*2^r - 1]; i.e. it replaces the 2^r - 1 zeros following + // element i*2^r with the value at i*2^r. This corresponds to the + // first r rounds of the FFT when there are 2^r zeros at the end + // of the original input. + if r > 0 { // if r == 0 then this loop is a noop. + let mask = !((1 << r) - 1); + for i in 0..n { + values[i] = values[i & mask]; } - evaluations = new_evaluations; } - // Reorder so that evaluations' indices correspond to (g_0, g_1, g_2, ...) - let values = reverse_index_bits(evaluations); - PolynomialValues { values } + let mut m = 1 << (r + 1); + for lg_m in (r+1)..=lg_n { + let half_m = m / 2; + for k in (0..n).step_by(m) { + for j in 0..half_m { + let omega = root_table[lg_m - 1][j]; + let t = omega * values[k + half_m + j]; + let u = values[k + j]; + values[k + j] = u + t; + values[k + half_m + j] = u - t; + } + } + m *= 2; + } + values } -pub(crate) fn ifft(poly: PolynomialValues) -> PolynomialCoeffs { - let precomputation = fft_precompute(poly.len()); - ifft_with_precomputation_power_of_2(poly, &precomputation) +/// FFT implementation inspired by Barretenberg's (but with extra unrolling): +/// https://github.com/AztecProtocol/barretenberg/blob/master/barretenberg/src/aztec/polynomials/polynomial_arithmetic.cpp#L58 +/// https://github.com/AztecProtocol/barretenberg/blob/master/barretenberg/src/aztec/polynomials/evaluation_domain.cpp#L30 +/// +/// The parameter r signifies that the first 1/2^r of the entries of +/// input may be non-zero, but the last 1 - 1/2^r entries are +/// definitely zero. +fn fft_unrolled( + input: Vec, + r_orig: usize, + root_table: FftRootTable +) -> Vec { + let n = input.len(); + let lg_n = log2_strict(input.len()); + + let mut values = reverse_index_bits(input); + + // FFT of a constant polynomial (including zero) is itself. + if n < 2 { + return values + } + + // The 'm' corresponds to the specialisation from the 'm' in the + // main loop (m >= 4) below. + + // (See comment in fft_classic near same code.) + let mut r = r_orig; + let mut m = 1 << r; + if r > 0 { // if r == 0 then this loop is a noop. + let mask = !((1 << r) - 1); + for i in 0..n { + values[i] = values[i & mask]; + } + } + + // m = 1 + if m == 1 { + for k in (0..n).step_by(2) { + let t = values[k + 1]; + values[k + 1] = values[k] - t; + values[k] += t; + } + r += 1; + m *= 2; + } + + if n == 2 { + return values + } + + if root_table.len() != (lg_n - 1) { + panic!("Expected root table of length {}, but it was {}.", lg_n, root_table.len()); + } + + // m = 2 + if m <= 2 { + for k in (0..n).step_by(4) { + // NB: Grouping statements as is done in the main loop below + // does not seem to help here (worse by a few millis). + let omega_0 = root_table[0][0]; + let tmp_0 = omega_0 * values[k + 2 + 0]; + values[k + 2 + 0] = values[k + 0] - tmp_0; + values[k + 0] += tmp_0; + + let omega_1 = root_table[0][1]; + let tmp_1 = omega_1 * values[k + 2 + 1]; + values[k + 2 + 1] = values[k + 1] - tmp_1; + values[k + 1] += tmp_1; + } + r += 1; + m *= 2; + } + + // m >= 4 + for lg_m in r..lg_n { + for k in (0..n).step_by(2*m) { + // Unrolled the commented loop by groups of 4 and + // rearranged the lines. Improves runtime by about + // 10%. + /* + for j in (0..m) { + let omega = root_table[lg_m - 1][j]; + let tmp = omega * values[k + m + j]; + values[k + m + j] = values[k + j] - tmp; + values[k + j] += tmp; + } + */ + for j in (0..m).step_by(4) { + let off1 = k + j; + let off2 = k + m + j; + + let omega_0 = root_table[lg_m - 1][j + 0]; + let omega_1 = root_table[lg_m - 1][j + 1]; + let omega_2 = root_table[lg_m - 1][j + 2]; + let omega_3 = root_table[lg_m - 1][j + 3]; + + let tmp_0 = omega_0 * values[off2 + 0]; + let tmp_1 = omega_1 * values[off2 + 1]; + let tmp_2 = omega_2 * values[off2 + 2]; + let tmp_3 = omega_3 * values[off2 + 3]; + + values[off2 + 0] = values[off1 + 0] - tmp_0; + values[off2 + 1] = values[off1 + 1] - tmp_1; + values[off2 + 2] = values[off1 + 2] - tmp_2; + values[off2 + 3] = values[off1 + 3] - tmp_3; + values[off1 + 0] += tmp_0; + values[off1 + 1] += tmp_1; + values[off1 + 2] += tmp_2; + values[off1 + 3] += tmp_3; + } + } + m *= 2; + } + values } + #[cfg(test)] mod tests { use crate::field::crandall_field::CrandallField; - use crate::field::fft::{fft, ifft}; + use crate::field::fft::{fft, ifft, fft_with_options}; use crate::field::field::Field; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::util::{log2_ceil, log2_strict}; @@ -162,6 +324,12 @@ mod tests { for i in degree..degree_padded { assert_eq!(interpolated_coefficients.coeffs[i], F::ZERO); } + + for r in 0..4 { + // expand ceofficients by factor 2^r by filling with zeros + let zero_tail = coefficients.clone().lde(r); + assert_eq!(fft(zero_tail.clone()), fft_with_options(zero_tail, Some(r), None)); + } } fn evaluate_naive(coefficients: &PolynomialCoeffs) -> PolynomialValues { diff --git a/src/field/field.rs b/src/field/field.rs index 516012d2..f4ef5990 100644 --- a/src/field/field.rs +++ b/src/field/field.rs @@ -32,11 +32,14 @@ pub trait Field: + Send + Sync { + type PrimeField: Field; + const ZERO: Self; const ONE: Self; const TWO: Self; const NEG_ONE: Self; + const CHARACTERISTIC: u64; const ORDER: u64; const TWO_ADICITY: usize; @@ -101,6 +104,31 @@ pub trait Field: x_inv } + /// Compute the inverse of 2^exp in this field. + #[inline] + fn inverse_2exp(exp: usize) -> Self { + let p = Self::CHARACTERISTIC; + + if exp <= Self::PrimeField::TWO_ADICITY { + // The inverse of 2^exp is p-(p-1)/2^exp when char(F) = p and exp is + // at most the TWO_ADICITY of the prime field. + // + // NB: PrimeFields fit in 64 bits => TWO_ADICITY < 64 => + // exp < 64 => this shift amount is legal. + Self::from_canonical_u64(p - ((p - 1) >> exp)) + } else { + // In the general case we compute 1/2 = (p+1)/2 and then exponentiate + // by exp to get 1/2^exp. Costs about log_2(exp) operations. + let half = Self::from_canonical_u64((p + 1) >> 1); + half.exp(exp as u64) + + // TODO: Faster to combine several high powers of 1/2 using multiple + // applications of the trick above. E.g. if the 2-adicity is v, then + // compute 1/2^(v^2 + v + 13) with 1/2^((v + 1) * v + 13), etc. + // (using the v-adic expansion of m). Costs about log_v(exp) operations. + } + } + fn primitive_root_of_unity(n_log: usize) -> Self { assert!(n_log <= Self::TWO_ADICITY); let mut base = Self::POWER_OF_TWO_GENERATOR; diff --git a/src/field/field_testing.rs b/src/field/field_testing.rs index 53e9c63c..7190684f 100644 --- a/src/field/field_testing.rs +++ b/src/field/field_testing.rs @@ -315,6 +315,20 @@ macro_rules! test_arithmetic { assert_eq!(x, F::ONE); assert_eq!(F::ZERO - x, F::NEG_ONE); } + + #[test] + fn inverse_2exp() { + // Just check consistency with try_inverse() + type F = $field; + + let v = ::PrimeField::TWO_ADICITY; + + for e in [0, 1, 2, 3, 4, v - 2, v - 1, v, v + 1, v + 2, 123*v] { + let x = F::TWO.exp(e as u64).inverse(); + let y = F::inverse_2exp(e); + assert_eq!(x, y); + } + } } }; } diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 888d7af0..5f295030 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -1,3 +1,5 @@ +use std::time::Instant; + use std::cmp::max; use std::iter::Sum; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; @@ -5,7 +7,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; use anyhow::{ensure, Result}; use crate::field::extension_field::Extendable; -use crate::field::fft::{fft, ifft}; +use crate::field::fft::{fft, ifft, fft_with_options}; use crate::field::field::Field; use crate::util::log2_strict; @@ -55,7 +57,7 @@ impl PolynomialValues { pub fn lde(self, rate_bits: usize) -> Self { let coeffs = ifft(self).lde(rate_bits); - fft(coeffs) + fft_with_options(coeffs, Some(rate_bits), None) } pub fn degree(&self) -> usize { diff --git a/src/util/mod.rs b/src/util/mod.rs index f901b0af..8fd60d53 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -49,13 +49,13 @@ pub(crate) fn transpose(matrix: &[Vec]) -> Vec> { } /// Permutes `arr` such that each index is mapped to its reverse in binary. -pub(crate) fn reverse_index_bits(arr: Vec) -> Vec { +pub(crate) fn reverse_index_bits(arr: Vec) -> Vec { let n = arr.len(); let n_power = log2_strict(n); let mut result = Vec::with_capacity(n); for i in 0..n { - result.push(arr[reverse_bits(i, n_power)].clone()); + result.push(arr[reverse_bits(i, n_power)]); } result } @@ -73,12 +73,11 @@ pub(crate) fn reverse_index_bits_in_place(arr: &mut Vec) { } pub(crate) fn reverse_bits(n: usize, num_bits: usize) -> usize { - let mut result = 0; - for i in 0..num_bits { - let i_rev = num_bits - i - 1; - result |= (n >> i & 1) << i_rev; - } - result + // NB: The only reason we need overflowing_shr() here as opposed + // to plain '>>' is to accommodate the case n == num_bits == 0, + // which would become `0 >> 64`. Rust thinks that any shift of 64 + // bits causes overflow, even when the argument is zero. + n.reverse_bits().overflowing_shr(usize::BITS - num_bits as u32).0 } #[cfg(test)]