Cache FFT roots (#261)

This commit is contained in:
Jakub Nabaglo 2021-09-22 10:56:09 -07:00 committed by GitHub
parent 46cc27571d
commit 7360391515
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 46 additions and 17 deletions

View File

@ -11,7 +11,7 @@ use crate::util::{log2_strict, reverse_index_bits};
pub(crate) type FftRootTable<F> = Vec<Vec<F>>;
fn fft_classic_root_table<F: Field>(n: usize) -> FftRootTable<F> {
pub fn fft_root_table<F: Field>(n: usize) -> FftRootTable<F> {
let lg_n = log2_strict(n);
// bases[i] = g^2^i, for i = 0, ..., lg_n - 1
let mut bases = Vec::with_capacity(lg_n);
@ -36,14 +36,16 @@ fn fft_classic_root_table<F: Field>(n: usize) -> FftRootTable<F> {
fn fft_dispatch<F: Field>(
input: &[F],
zero_factor: Option<usize>,
root_table: Option<FftRootTable<F>>,
root_table: Option<&FftRootTable<F>>,
) -> Vec<F> {
let n = input.len();
fft_classic(
input,
zero_factor.unwrap_or(0),
root_table.unwrap_or_else(|| fft_classic_root_table(n)),
)
let computed_root_table = if let Some(_) = root_table {
None
} else {
Some(fft_root_table(input.len()))
};
let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap();
fft_classic(input, zero_factor.unwrap_or(0), used_root_table)
}
#[inline]
@ -55,7 +57,7 @@ pub fn fft<F: Field>(poly: &PolynomialCoeffs<F>) -> PolynomialValues<F> {
pub fn fft_with_options<F: Field>(
poly: &PolynomialCoeffs<F>,
zero_factor: Option<usize>,
root_table: Option<FftRootTable<F>>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
let PolynomialCoeffs { coeffs } = poly;
PolynomialValues {
@ -71,7 +73,7 @@ pub fn ifft<F: Field>(poly: &PolynomialValues<F>) -> PolynomialCoeffs<F> {
pub fn ifft_with_options<F: Field>(
poly: &PolynomialValues<F>,
zero_factor: Option<usize>,
root_table: Option<FftRootTable<F>>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialCoeffs<F> {
let n = poly.len();
let lg_n = log2_strict(n);
@ -166,7 +168,7 @@ fn fft_classic_simd<P: PackedField>(
/// The parameter r signifies that the first 1/2^r of the entries of
/// input may be non-zero, but the last 1 - 1/2^r entries are
/// definitely zero.
pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: FftRootTable<F>) -> Vec<F> {
pub(crate) fn fft_classic<F: Field>(input: &[F], r: usize, root_table: &FftRootTable<F>) -> Vec<F> {
let mut values = reverse_index_bits(input);
let n = values.len();

View File

@ -1,6 +1,7 @@
use rayon::prelude::*;
use crate::field::extension_field::Extendable;
use crate::field::fft::FftRootTable;
use crate::field::field_types::{Field, RichField};
use crate::fri::proof::FriProof;
use crate::fri::prover::fri_proof;
@ -35,6 +36,7 @@ impl<F: RichField> PolynomialBatchCommitment<F> {
blinding: bool,
cap_height: usize,
timing: &mut TimingTree,
fft_root_table: Option<&FftRootTable<F>>,
) -> Self {
let coeffs = timed!(
timing,
@ -42,7 +44,14 @@ impl<F: RichField> PolynomialBatchCommitment<F> {
values.par_iter().map(|v| v.ifft()).collect::<Vec<_>>()
);
Self::from_coeffs(coeffs, rate_bits, blinding, cap_height, timing)
Self::from_coeffs(
coeffs,
rate_bits,
blinding,
cap_height,
timing,
fft_root_table,
)
}
/// Creates a list polynomial commitment for the polynomials `polynomials`.
@ -52,12 +61,13 @@ impl<F: RichField> PolynomialBatchCommitment<F> {
blinding: bool,
cap_height: usize,
timing: &mut TimingTree,
fft_root_table: Option<&FftRootTable<F>>,
) -> Self {
let degree = polynomials[0].len();
let lde_values = timed!(
timing,
"FFT + blinding",
Self::lde_values(&polynomials, rate_bits, blinding)
Self::lde_values(&polynomials, rate_bits, blinding, fft_root_table)
);
let mut leaves = timed!(timing, "transpose LDEs", transpose(&lde_values));
@ -81,6 +91,7 @@ impl<F: RichField> PolynomialBatchCommitment<F> {
polynomials: &[PolynomialCoeffs<F>],
rate_bits: usize,
blinding: bool,
fft_root_table: Option<&FftRootTable<F>>,
) -> Vec<Vec<F>> {
let degree = polynomials[0].len();
@ -92,7 +103,7 @@ impl<F: RichField> PolynomialBatchCommitment<F> {
.map(|p| {
assert_eq!(p.len(), degree, "Polynomial degrees inconsistent");
p.lde(rate_bits)
.coset_fft_with_options(F::coset_shift(), Some(rate_bits), None)
.coset_fft_with_options(F::coset_shift(), Some(rate_bits), fft_root_table)
.values
})
.chain(
@ -309,6 +320,7 @@ mod tests {
common_data.config.zero_knowledge && PlonkPolynomials::polynomials(i).blinding,
common_data.config.cap_height,
&mut TimingTree::default(),
None,
)
})
.collect::<Vec<_>>();

View File

@ -1,3 +1,4 @@
use std::cmp::max;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::convert::TryInto;
use std::time::Instant;
@ -7,6 +8,7 @@ use log::{info, Level};
use crate::field::cosets::get_unique_coset_shifts;
use crate::field::extension_field::target::ExtensionTarget;
use crate::field::extension_field::{Extendable, FieldExtension};
use crate::field::fft::fft_root_table;
use crate::field::field_types::RichField;
use crate::fri::commitment::PolynomialBatchCommitment;
use crate::gates::arithmetic::{ArithmeticExtensionGate, NUM_ARITHMETIC_OPS};
@ -605,6 +607,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let k_is = get_unique_coset_shifts(degree, self.config.num_routed_wires);
let (sigma_vecs, partition_witness) = self.sigma_vecs(&k_is, &subgroup);
// Precompute FFT roots.
let max_fft_points =
1 << degree_bits + max(self.config.rate_bits, log2_ceil(quotient_degree_factor));
let fft_root_table = fft_root_table(max_fft_points);
let constants_sigmas_vecs = [constant_vecs, sigma_vecs.clone()].concat();
let constants_sigmas_commitment = PolynomialBatchCommitment::from_values(
constants_sigmas_vecs,
@ -612,6 +619,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
self.config.zero_knowledge & PlonkPolynomials::CONSTANTS_SIGMAS.blinding,
self.config.cap_height,
&mut timing,
Some(&fft_root_table),
);
let constants_sigmas_cap = constants_sigmas_commitment.merkle_tree.cap.clone();
@ -645,6 +653,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
public_inputs: self.public_inputs,
marked_targets: self.marked_targets,
partition_witness,
fft_root_table: Some(fft_root_table),
};
// The HashSet of gates will have a non-deterministic order. When converting to a Vec, we

View File

@ -4,6 +4,7 @@ use std::ops::{Range, RangeFrom};
use anyhow::Result;
use crate::field::extension_field::Extendable;
use crate::field::fft::FftRootTable;
use crate::field::field_types::{Field, RichField};
use crate::fri::commitment::PolynomialBatchCommitment;
use crate::fri::FriConfig;
@ -157,6 +158,8 @@ pub(crate) struct ProverOnlyCircuitData<F: RichField + Extendable<D>, const D: u
pub marked_targets: Vec<MarkedTargets<D>>,
/// Partial witness holding the copy constraints information.
pub partition_witness: PartitionWitness<F>,
/// Pre-computed roots for faster FFT.
pub fft_root_table: Option<FftRootTable<F>>,
}
/// Circuit data required by the verifier, but not the prover.

View File

@ -74,6 +74,7 @@ pub(crate) fn prove<F: RichField + Extendable<D>, const D: usize>(
config.zero_knowledge & PlonkPolynomials::WIRES.blinding,
config.cap_height,
&mut timing,
prover_data.fft_root_table.as_ref(),
)
);
@ -119,6 +120,7 @@ pub(crate) fn prove<F: RichField + Extendable<D>, const D: usize>(
config.zero_knowledge & PlonkPolynomials::ZS_PARTIAL_PRODUCTS.blinding,
config.cap_height,
&mut timing,
prover_data.fft_root_table.as_ref(),
)
);
@ -167,7 +169,8 @@ pub(crate) fn prove<F: RichField + Extendable<D>, const D: usize>(
config.rate_bits,
config.zero_knowledge & PlonkPolynomials::QUOTIENT.blinding,
config.cap_height,
&mut timing
&mut timing,
prover_data.fft_root_table.as_ref(),
)
);

View File

@ -211,7 +211,7 @@ impl<F: Field> PolynomialCoeffs<F> {
pub fn fft_with_options(
&self,
zero_factor: Option<usize>,
root_table: Option<FftRootTable<F>>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
fft_with_options(self, zero_factor, root_table)
}
@ -226,7 +226,7 @@ impl<F: Field> PolynomialCoeffs<F> {
&self,
shift: F,
zero_factor: Option<usize>,
root_table: Option<FftRootTable<F>>,
root_table: Option<&FftRootTable<F>>,
) -> PolynomialValues<F> {
let modified_poly: Self = shift
.powers()