diff --git a/src/prover.rs b/src/prover.rs index 67b5a185..76b69e10 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -1,5 +1,6 @@ use std::time::Instant; +use itertools::Itertools; use log::info; use rayon::prelude::*; @@ -12,7 +13,7 @@ use crate::polynomial::commitment::ListPolynomialCommitment; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::proof::Proof; use crate::timed; -use crate::util::{log2_ceil, transpose}; +use crate::util::{ceil_div_usize, log2_ceil, transpose}; use crate::vars::EvaluationVarsBase; use crate::witness::{PartialWitness, Witness}; @@ -78,14 +79,20 @@ pub(crate) fn prove, const D: usize>( let betas = challenger.get_n_challenges(num_challenges); let gammas = challenger.get_n_challenges(num_challenges); + let partial_products = timed!( + all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data), + "to compute partial products" + ); + let plonk_z_vecs = timed!( - compute_zs(&witness, &betas, &gammas, prover_data, common_data), + compute_zs(&partial_products, prover_data, common_data), "to compute Z's" ); + let zs_partial_products = [partial_products.concat(), plonk_z_vecs].concat(); let plonk_zs_commitment = timed!( ListPolynomialCommitment::new( - plonk_z_vecs, + zs_partial_products, fri_config.rate_bits, PlonkPolynomials::ZS.blinding ), @@ -168,41 +175,105 @@ pub(crate) fn prove, const D: usize>( } } -fn compute_zs, const D: usize>( +fn all_wires_permutation_partial_products, const D: usize>( witness: &Witness, betas: &[F], gammas: &[F], prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, -) -> Vec> { +) -> Vec>> { (0..common_data.config.num_challenges) - .map(|i| compute_z(witness, betas[i], gammas[i], prover_data, common_data)) + .map(|i| { + wires_permutation_partial_products( + witness, + betas[i], + gammas[i], + prover_data, + common_data, + ) + }) .collect() } -fn compute_z, const D: usize>( +fn wires_permutation_partial_products, const D: usize>( witness: &Witness, beta: F, gamma: F, prover_data: &ProverOnlyCircuitData, common_data: &CommonCircuitData, +) -> Vec> { + let vanish_degree = common_data + .max_filtered_constraint_degree + .next_power_of_two(); + let num_polys = ceil_div_usize(common_data.config.num_routed_wires, vanish_degree); + assert!( + num_polys <= vanish_degree, + "Not supported yet. would need to add partial products of partial products for this." + ); + let subgroup = &prover_data.subgroup; + let mut values = vec![vec![F::ONE; 2 * num_polys]]; + let k_is = &common_data.k_is; + for i in 1..common_data.degree() { + let x = subgroup[i - 1]; + let s_sigmas = &prover_data.sigmas[i - 1]; + let mut partials_numerator = Vec::with_capacity(2 * num_polys); + let mut partials_denominator = Vec::with_capacity(num_polys); + for chunk in (0..common_data.config.num_routed_wires) + .collect::>() + .chunks(vanish_degree) + { + let (numerator, denominator) = chunk.iter().fold((F::ONE, F::ONE), |acc, &j| { + let wire_value = witness.get_wire(i - 1, j); + let k_i = k_is[j]; + let s_id = k_i * x; + let s_sigma = s_sigmas[j]; + ( + acc.0 * wire_value + beta * s_id + gamma, + acc.1 * wire_value + beta * s_sigma + gamma, + ) + }); + partials_numerator.push(numerator); + partials_denominator.push(denominator); + } + partials_numerator.append(&mut partials_denominator); + values.push(partials_numerator); + } + + transpose(&values) + .into_par_iter() + .map(PolynomialValues::new) + .collect() +} + +fn compute_zs, const D: usize>( + partial_products: &[Vec>], + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, +) -> Vec> { + (0..common_data.config.num_challenges) + .map(|i| compute_z(&partial_products[i], prover_data, common_data)) + .collect() +} + +fn compute_z, const D: usize>( + partial_products: &[PolynomialValues], + prover_data: &ProverOnlyCircuitData, + common_data: &CommonCircuitData, ) -> PolynomialValues { + let num_partials = partial_products.len() / 2; let subgroup = &prover_data.subgroup; let mut plonk_z_points = vec![F::ONE]; let k_is = &common_data.k_is; for i in 1..common_data.degree() { let x = subgroup[i - 1]; - let mut numerator = F::ONE; - let mut denominator = F::ONE; - let s_sigmas = &prover_data.sigmas[i - 1]; - for j in 0..common_data.config.num_routed_wires { - let wire_value = witness.get_wire(i - 1, j); - let k_i = k_is[j]; - let s_id = k_i * x; - let s_sigma = s_sigmas[j]; - numerator *= wire_value + beta * s_id + gamma; - denominator *= wire_value + beta * s_sigma + gamma; - } + let numerator = partial_products[..num_partials] + .iter() + .map(|vs| vs.values[i]) + .product(); + let denominator = partial_products[num_partials..] + .iter() + .map(|vs| vs.values[i]) + .product(); let last = *plonk_z_points.last().unwrap(); plonk_z_points.push(last * numerator / denominator); }