Added partial products

This commit is contained in:
wborgeaud 2021-06-30 15:05:40 +02:00
parent 69fff573fe
commit a0298a61f4

View File

@ -1,5 +1,6 @@
use std::time::Instant;
use itertools::Itertools;
use log::info;
use rayon::prelude::*;
@ -12,7 +13,7 @@ use crate::polynomial::commitment::ListPolynomialCommitment;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::proof::Proof;
use crate::timed;
use crate::util::{log2_ceil, transpose};
use crate::util::{ceil_div_usize, log2_ceil, transpose};
use crate::vars::EvaluationVarsBase;
use crate::witness::{PartialWitness, Witness};
@ -78,14 +79,20 @@ pub(crate) fn prove<F: Extendable<D>, const D: usize>(
let betas = challenger.get_n_challenges(num_challenges);
let gammas = challenger.get_n_challenges(num_challenges);
let partial_products = timed!(
all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data),
"to compute partial products"
);
let plonk_z_vecs = timed!(
compute_zs(&witness, &betas, &gammas, prover_data, common_data),
compute_zs(&partial_products, prover_data, common_data),
"to compute Z's"
);
let zs_partial_products = [partial_products.concat(), plonk_z_vecs].concat();
let plonk_zs_commitment = timed!(
ListPolynomialCommitment::new(
plonk_z_vecs,
zs_partial_products,
fri_config.rate_bits,
PlonkPolynomials::ZS.blinding
),
@ -168,41 +175,105 @@ pub(crate) fn prove<F: Extendable<D>, const D: usize>(
}
}
fn compute_zs<F: Extendable<D>, const D: usize>(
fn all_wires_permutation_partial_products<F: Extendable<D>, const D: usize>(
witness: &Witness<F>,
betas: &[F],
gammas: &[F],
prover_data: &ProverOnlyCircuitData<F, D>,
common_data: &CommonCircuitData<F, D>,
) -> Vec<PolynomialValues<F>> {
) -> Vec<Vec<PolynomialValues<F>>> {
(0..common_data.config.num_challenges)
.map(|i| compute_z(witness, betas[i], gammas[i], prover_data, common_data))
.map(|i| {
wires_permutation_partial_products(
witness,
betas[i],
gammas[i],
prover_data,
common_data,
)
})
.collect()
}
fn compute_z<F: Extendable<D>, const D: usize>(
fn wires_permutation_partial_products<F: Extendable<D>, const D: usize>(
witness: &Witness<F>,
beta: F,
gamma: F,
prover_data: &ProverOnlyCircuitData<F, D>,
common_data: &CommonCircuitData<F, D>,
) -> Vec<PolynomialValues<F>> {
let vanish_degree = common_data
.max_filtered_constraint_degree
.next_power_of_two();
let num_polys = ceil_div_usize(common_data.config.num_routed_wires, vanish_degree);
assert!(
num_polys <= vanish_degree,
"Not supported yet. would need to add partial products of partial products for this."
);
let subgroup = &prover_data.subgroup;
let mut values = vec![vec![F::ONE; 2 * num_polys]];
let k_is = &common_data.k_is;
for i in 1..common_data.degree() {
let x = subgroup[i - 1];
let s_sigmas = &prover_data.sigmas[i - 1];
let mut partials_numerator = Vec::with_capacity(2 * num_polys);
let mut partials_denominator = Vec::with_capacity(num_polys);
for chunk in (0..common_data.config.num_routed_wires)
.collect::<Vec<_>>()
.chunks(vanish_degree)
{
let (numerator, denominator) = chunk.iter().fold((F::ONE, F::ONE), |acc, &j| {
let wire_value = witness.get_wire(i - 1, j);
let k_i = k_is[j];
let s_id = k_i * x;
let s_sigma = s_sigmas[j];
(
acc.0 * wire_value + beta * s_id + gamma,
acc.1 * wire_value + beta * s_sigma + gamma,
)
});
partials_numerator.push(numerator);
partials_denominator.push(denominator);
}
partials_numerator.append(&mut partials_denominator);
values.push(partials_numerator);
}
transpose(&values)
.into_par_iter()
.map(PolynomialValues::new)
.collect()
}
fn compute_zs<F: Extendable<D>, const D: usize>(
partial_products: &[Vec<PolynomialValues<F>>],
prover_data: &ProverOnlyCircuitData<F, D>,
common_data: &CommonCircuitData<F, D>,
) -> Vec<PolynomialValues<F>> {
(0..common_data.config.num_challenges)
.map(|i| compute_z(&partial_products[i], prover_data, common_data))
.collect()
}
fn compute_z<F: Extendable<D>, const D: usize>(
partial_products: &[PolynomialValues<F>],
prover_data: &ProverOnlyCircuitData<F, D>,
common_data: &CommonCircuitData<F, D>,
) -> PolynomialValues<F> {
let num_partials = partial_products.len() / 2;
let subgroup = &prover_data.subgroup;
let mut plonk_z_points = vec![F::ONE];
let k_is = &common_data.k_is;
for i in 1..common_data.degree() {
let x = subgroup[i - 1];
let mut numerator = F::ONE;
let mut denominator = F::ONE;
let s_sigmas = &prover_data.sigmas[i - 1];
for j in 0..common_data.config.num_routed_wires {
let wire_value = witness.get_wire(i - 1, j);
let k_i = k_is[j];
let s_id = k_i * x;
let s_sigma = s_sigmas[j];
numerator *= wire_value + beta * s_id + gamma;
denominator *= wire_value + beta * s_sigma + gamma;
}
let numerator = partial_products[..num_partials]
.iter()
.map(|vs| vs.values[i])
.product();
let denominator = partial_products[num_partials..]
.iter()
.map(|vs| vs.values[i])
.product();
let last = *plonk_z_points.last().unwrap();
plonk_z_points.push(last * numerator / denominator);
}