diff --git a/src/field/fft.rs b/src/field/fft.rs index d8376ce4..30197ff9 100644 --- a/src/field/fft.rs +++ b/src/field/fft.rs @@ -11,7 +11,7 @@ use crate::util::{log2_strict, reverse_index_bits}; pub(crate) type FftRootTable = Vec>; -fn fft_classic_root_table(n: usize) -> FftRootTable { +pub fn fft_root_table(n: usize) -> FftRootTable { let lg_n = log2_strict(n); // bases[i] = g^2^i, for i = 0, ..., lg_n - 1 let mut bases = Vec::with_capacity(lg_n); @@ -36,14 +36,16 @@ fn fft_classic_root_table(n: usize) -> FftRootTable { fn fft_dispatch( input: &[F], zero_factor: Option, - root_table: Option>, + root_table: Option<&FftRootTable>, ) -> Vec { - let n = input.len(); - fft_classic( - input, - zero_factor.unwrap_or(0), - root_table.unwrap_or_else(|| fft_classic_root_table(n)), - ) + let computed_root_table = if let Some(_) = root_table { + None + } else { + Some(fft_root_table(input.len())) + }; + let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap(); + + fft_classic(input, zero_factor.unwrap_or(0), used_root_table) } #[inline] @@ -55,7 +57,7 @@ pub fn fft(poly: &PolynomialCoeffs) -> PolynomialValues { pub fn fft_with_options( poly: &PolynomialCoeffs, zero_factor: Option, - root_table: Option>, + root_table: Option<&FftRootTable>, ) -> PolynomialValues { let PolynomialCoeffs { coeffs } = poly; PolynomialValues { @@ -71,7 +73,7 @@ pub fn ifft(poly: &PolynomialValues) -> PolynomialCoeffs { pub fn ifft_with_options( poly: &PolynomialValues, zero_factor: Option, - root_table: Option>, + root_table: Option<&FftRootTable>, ) -> PolynomialCoeffs { let n = poly.len(); let lg_n = log2_strict(n); @@ -166,7 +168,7 @@ fn fft_classic_simd( /// The parameter r signifies that the first 1/2^r of the entries of /// input may be non-zero, but the last 1 - 1/2^r entries are /// definitely zero. -pub(crate) fn fft_classic(input: &[F], r: usize, root_table: FftRootTable) -> Vec { +pub(crate) fn fft_classic(input: &[F], r: usize, root_table: &FftRootTable) -> Vec { let mut values = reverse_index_bits(input); let n = values.len(); diff --git a/src/fri/commitment.rs b/src/fri/commitment.rs index 9160f1a4..f9c96a55 100644 --- a/src/fri/commitment.rs +++ b/src/fri/commitment.rs @@ -1,6 +1,7 @@ use rayon::prelude::*; use crate::field::extension_field::Extendable; +use crate::field::fft::FftRootTable; use crate::field::field_types::{Field, RichField}; use crate::fri::proof::FriProof; use crate::fri::prover::fri_proof; @@ -35,6 +36,7 @@ impl PolynomialBatchCommitment { blinding: bool, cap_height: usize, timing: &mut TimingTree, + fft_root_table: Option<&FftRootTable>, ) -> Self { let coeffs = timed!( timing, @@ -42,7 +44,14 @@ impl PolynomialBatchCommitment { values.par_iter().map(|v| v.ifft()).collect::>() ); - Self::from_coeffs(coeffs, rate_bits, blinding, cap_height, timing) + Self::from_coeffs( + coeffs, + rate_bits, + blinding, + cap_height, + timing, + fft_root_table, + ) } /// Creates a list polynomial commitment for the polynomials `polynomials`. @@ -52,12 +61,13 @@ impl PolynomialBatchCommitment { blinding: bool, cap_height: usize, timing: &mut TimingTree, + fft_root_table: Option<&FftRootTable>, ) -> Self { let degree = polynomials[0].len(); let lde_values = timed!( timing, "FFT + blinding", - Self::lde_values(&polynomials, rate_bits, blinding) + Self::lde_values(&polynomials, rate_bits, blinding, fft_root_table) ); let mut leaves = timed!(timing, "transpose LDEs", transpose(&lde_values)); @@ -81,6 +91,7 @@ impl PolynomialBatchCommitment { polynomials: &[PolynomialCoeffs], rate_bits: usize, blinding: bool, + fft_root_table: Option<&FftRootTable>, ) -> Vec> { let degree = polynomials[0].len(); @@ -92,7 +103,7 @@ impl PolynomialBatchCommitment { .map(|p| { assert_eq!(p.len(), degree, "Polynomial degrees inconsistent"); p.lde(rate_bits) - .coset_fft_with_options(F::coset_shift(), Some(rate_bits), None) + .coset_fft_with_options(F::coset_shift(), Some(rate_bits), fft_root_table) .values }) .chain( @@ -309,6 +320,7 @@ mod tests { common_data.config.zero_knowledge && PlonkPolynomials::polynomials(i).blinding, common_data.config.cap_height, &mut TimingTree::default(), + None, ) }) .collect::>(); diff --git a/src/plonk/circuit_builder.rs b/src/plonk/circuit_builder.rs index f8b5f603..415c011e 100644 --- a/src/plonk/circuit_builder.rs +++ b/src/plonk/circuit_builder.rs @@ -1,3 +1,4 @@ +use std::cmp::max; use std::collections::{BTreeMap, HashMap, HashSet}; use std::convert::TryInto; use std::time::Instant; @@ -7,6 +8,7 @@ use log::{info, Level}; use crate::field::cosets::get_unique_coset_shifts; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::{Extendable, FieldExtension}; +use crate::field::fft::fft_root_table; use crate::field::field_types::RichField; use crate::fri::commitment::PolynomialBatchCommitment; use crate::gates::arithmetic::{ArithmeticExtensionGate, NUM_ARITHMETIC_OPS}; @@ -605,6 +607,11 @@ impl, const D: usize> CircuitBuilder { let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires); let (sigma_vecs, partition_witness) = self.sigma_vecs(&k_is, &subgroup); + // Precompute FFT roots. + let max_fft_points = + 1 << degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor)); + let fft_root_table = fft_root_table(max_fft_points); + let constants_sigmas_vecs = [constant_vecs, sigma_vecs.clone()].concat(); let constants_sigmas_commitment = PolynomialBatchCommitment::from_values( constants_sigmas_vecs, @@ -612,6 +619,7 @@ impl, const D: usize> CircuitBuilder { self.config.zero_knowledge & PlonkPolynomials::CONSTANTS_SIGMAS.blinding, self.config.cap_height, &mut timing, + Some(&fft_root_table), ); let constants_sigmas_cap = constants_sigmas_commitment.merkle_tree.cap.clone(); @@ -645,6 +653,7 @@ impl, const D: usize> CircuitBuilder { public_inputs: self.public_inputs, marked_targets: self.marked_targets, partition_witness, + fft_root_table: Some(fft_root_table), }; // The HashSet of gates will have a non-deterministic order. When converting to a Vec, we diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 540e3e84..82f1c1cd 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -4,6 +4,7 @@ use std::ops::{Range, RangeFrom}; use anyhow::Result; use crate::field::extension_field::Extendable; +use crate::field::fft::FftRootTable; use crate::field::field_types::{Field, RichField}; use crate::fri::commitment::PolynomialBatchCommitment; use crate::fri::FriConfig; @@ -157,6 +158,8 @@ pub(crate) struct ProverOnlyCircuitData, const D: u pub marked_targets: Vec>, /// Partial witness holding the copy constraints information. pub partition_witness: PartitionWitness, + /// Pre-computed roots for faster FFT. + pub fft_root_table: Option>, } /// Circuit data required by the verifier, but not the prover. diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 7f609826..9f4303ae 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -74,6 +74,7 @@ pub(crate) fn prove, const D: usize>( config.zero_knowledge & PlonkPolynomials::WIRES.blinding, config.cap_height, &mut timing, + prover_data.fft_root_table.as_ref(), ) ); @@ -119,6 +120,7 @@ pub(crate) fn prove, const D: usize>( config.zero_knowledge & PlonkPolynomials::ZS_PARTIAL_PRODUCTS.blinding, config.cap_height, &mut timing, + prover_data.fft_root_table.as_ref(), ) ); @@ -167,7 +169,8 @@ pub(crate) fn prove, const D: usize>( config.rate_bits, config.zero_knowledge & PlonkPolynomials::QUOTIENT.blinding, config.cap_height, - &mut timing + &mut timing, + prover_data.fft_root_table.as_ref(), ) ); diff --git a/src/polynomial/polynomial.rs b/src/polynomial/polynomial.rs index 674504b5..eb1529aa 100644 --- a/src/polynomial/polynomial.rs +++ b/src/polynomial/polynomial.rs @@ -211,7 +211,7 @@ impl PolynomialCoeffs { pub fn fft_with_options( &self, zero_factor: Option, - root_table: Option>, + root_table: Option<&FftRootTable>, ) -> PolynomialValues { fft_with_options(self, zero_factor, root_table) } @@ -226,7 +226,7 @@ impl PolynomialCoeffs { &self, shift: F, zero_factor: Option, - root_table: Option>, + root_table: Option<&FftRootTable>, ) -> PolynomialValues { let modified_poly: Self = shift .powers()