Batched eval_vanishing_poly_base (#317)

* Batched eval_vanishing_poly_base

* Reduce the number of allocations

* Lints

* Delete unused things

* Minor: fix a debug_assert

* Daniel PR comments

* Lints

* Daniel PR comments
Jakub Nabaglo 2021-10-25 13:23:05 -07:00 committed by GitHub
parent f616d6436d
commit bf421314f9
4 changed files with 216 additions and 118 deletions
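
The change evaluates the vanishing polynomial over fixed-size batches of points instead of one point at a time, which lets allocations be amortised across each batch ("Reduce the number of allocations"). A minimal standalone sketch of that chunk-and-flatten pattern, not the crate's actual API (`evaluate_in_batches` and `eval_batch` are illustrative names; `BATCH_SIZE` mirrors the constant introduced in the diff below):

// Illustrative only: map a batch evaluator over fixed-size chunks of points and
// flatten back to one result per point, mirroring the `par_chunks(BATCH_SIZE)`
// + `flatten()` structure of `compute_quotient_polys` below.
const BATCH_SIZE: usize = 32;

fn evaluate_in_batches<P, R>(
    points: &[P],
    mut eval_batch: impl FnMut(&[P]) -> Vec<R>, // one result per point in the chunk
) -> Vec<R> {
    points
        .chunks(BATCH_SIZE)
        .flat_map(|chunk| eval_batch(chunk))
        .collect()
}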


@@ -68,9 +68,21 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
fn eval_filtered_base(&self, mut vars: EvaluationVarsBase<F>, prefix: &[bool]) -> Vec<F> {
let filter = compute_filter(prefix, vars.local_constants);
vars.remove_prefix(prefix);
self.eval_unfiltered_base(vars)
.into_iter()
.map(|c| c * filter)
let mut res = self.eval_unfiltered_base(vars);
res.iter_mut().for_each(|c| {
*c *= filter;
});
res
}
fn eval_filtered_base_batch(
&self,
vars_batch: &[EvaluationVarsBase<F>],
prefix: &[bool],
) -> Vec<Vec<F>> {
vars_batch
.iter()
.map(|&vars| self.eval_filtered_base(vars, prefix))
.collect()
}


@@ -147,12 +147,25 @@ pub(crate) fn eval_l_1_recursively<F: RichField + Extendable<D>, const D: usize>
builder.div_extension(eval_zero_poly, denominator)
}
/// For each alpha in alphas, compute a reduction of the given terms using powers of alpha.
pub(crate) fn reduce_with_powers_multi<F: Field>(terms: &[F], alphas: &[F]) -> Vec<F> {
alphas
.iter()
.map(|&alpha| reduce_with_powers(terms, alpha))
.collect()
/// For each alpha in alphas, compute a reduction of the given terms using powers of alpha. T can
/// be any type convertible to a double-ended iterator.
pub(crate) fn reduce_with_powers_multi<
'a,
F: Field,
I: DoubleEndedIterator<Item = &'a F>,
T: IntoIterator<IntoIter = I>,
>(
terms: T,
alphas: &[F],
) -> Vec<F> {
let mut cumul = vec![F::ZERO; alphas.len()];
for &term in terms.into_iter().rev() {
cumul
.iter_mut()
.zip(alphas)
.for_each(|(c, &alpha)| *c = term.multiply_accumulate(*c, alpha));
}
cumul
}
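
The reversed multiply-accumulate loop above is Horner's rule: assuming `multiply_accumulate(x, y)` returns `self + x * y`, the entry for each alpha works out to `terms[0] + terms[1] * alpha + terms[2] * alpha^2 + ...`, i.e. the same reduction as `reduce_with_powers(terms, alpha)` computed once per challenge. A hedged reference version for a single alpha, using only basic `Field` constants and operators (`ZERO`, `ONE`, `+`, `*`):

// Illustrative only: direct (non-Horner) computation of the same reduction for one alpha,
// so reduce_with_powers_multi(terms, alphas)[a] should equal this applied to alphas[a].
fn reduce_with_powers_reference<F: Field>(terms: &[F], alpha: F) -> F {
    let mut power = F::ONE; // alpha^j, starting from alpha^0
    let mut acc = F::ZERO;
    for &term in terms {
        acc = acc + term * power;
        power = power * alpha;
    }
    acc
}
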
pub(crate) fn reduce_with_powers<F: Field>(terms: &[F], alpha: F) -> F {


@@ -14,7 +14,7 @@ use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData};
use crate::plonk::plonk_common::PlonkPolynomials;
use crate::plonk::plonk_common::ZeroPolyOnCoset;
use crate::plonk::proof::{Proof, ProofWithPublicInputs};
use crate::plonk::vanishing_poly::eval_vanishing_poly_base;
use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch;
use crate::plonk::vars::EvaluationVarsBase;
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
use crate::timed;
@@ -309,6 +309,8 @@ fn compute_z<F: RichField + Extendable<D>, const D: usize>(
plonk_z_points.into()
}
const BATCH_SIZE: usize = 32;
fn compute_quotient_polys<'a, F: RichField + Extendable<D>, const D: usize>(
common_data: &CommonCircuitData<F, D>,
prover_data: &'a ProverOnlyCircuitData<F, D>,
@@ -344,50 +346,77 @@ fn compute_quotient_polys<'a, F: RichField + Extendable<D>, const D: usize>(
let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits, max_degree_bits);
let quotient_values: Vec<Vec<F>> = points
.into_par_iter()
let points_batches = points.par_chunks(BATCH_SIZE);
let quotient_values: Vec<Vec<F>> = points_batches
.enumerate()
.map(|(i, x)| {
let shifted_x = F::coset_shift() * x;
let i_next = (i + next_step) % lde_size;
let local_constants_sigmas = get_at_index(&prover_data.constants_sigmas_commitment, i);
let local_constants = &local_constants_sigmas[common_data.constants_range()];
let s_sigmas = &local_constants_sigmas[common_data.sigmas_range()];
let local_wires = get_at_index(wires_commitment, i);
let local_zs_partial_products = get_at_index(zs_partial_products_commitment, i);
let local_zs = &local_zs_partial_products[common_data.zs_range()];
let next_zs =
&get_at_index(zs_partial_products_commitment, i_next)[common_data.zs_range()];
let partial_products = &local_zs_partial_products[common_data.partial_products_range()];
.map(|(batch_i, xs_batch)| {
assert_eq!(xs_batch.len(), BATCH_SIZE);
let indices_batch: Vec<usize> =
(BATCH_SIZE * batch_i..BATCH_SIZE * (batch_i + 1)).collect();
debug_assert_eq!(local_wires.len(), common_data.config.num_wires);
debug_assert_eq!(local_zs.len(), num_challenges);
let mut shifted_xs_batch = Vec::with_capacity(xs_batch.len());
let mut vars_batch = Vec::with_capacity(xs_batch.len());
let mut local_zs_batch = Vec::with_capacity(xs_batch.len());
let mut next_zs_batch = Vec::with_capacity(xs_batch.len());
let mut partial_products_batch = Vec::with_capacity(xs_batch.len());
let mut s_sigmas_batch = Vec::with_capacity(xs_batch.len());
let vars = EvaluationVarsBase {
local_constants,
local_wires,
public_inputs_hash,
};
let mut quotient_values = eval_vanishing_poly_base(
for (&i, &x) in indices_batch.iter().zip(xs_batch) {
let shifted_x = F::coset_shift() * x;
let i_next = (i + next_step) % lde_size;
let local_constants_sigmas =
get_at_index(&prover_data.constants_sigmas_commitment, i);
let local_constants = &local_constants_sigmas[common_data.constants_range()];
let s_sigmas = &local_constants_sigmas[common_data.sigmas_range()];
let local_wires = get_at_index(wires_commitment, i);
let local_zs_partial_products = get_at_index(zs_partial_products_commitment, i);
let local_zs = &local_zs_partial_products[common_data.zs_range()];
let next_zs =
&get_at_index(zs_partial_products_commitment, i_next)[common_data.zs_range()];
let partial_products =
&local_zs_partial_products[common_data.partial_products_range()];
debug_assert_eq!(local_wires.len(), common_data.config.num_wires);
debug_assert_eq!(local_zs.len(), num_challenges);
let vars = EvaluationVarsBase {
local_constants,
local_wires,
public_inputs_hash,
};
shifted_xs_batch.push(shifted_x);
vars_batch.push(vars);
local_zs_batch.push(local_zs);
next_zs_batch.push(next_zs);
partial_products_batch.push(partial_products);
s_sigmas_batch.push(s_sigmas);
}
let mut quotient_values_batch = eval_vanishing_poly_base_batch(
common_data,
i,
shifted_x,
vars,
local_zs,
next_zs,
partial_products,
s_sigmas,
&indices_batch,
&shifted_xs_batch,
&vars_batch,
&local_zs_batch,
&next_zs_batch,
&partial_products_batch,
&s_sigmas_batch,
betas,
gammas,
alphas,
&z_h_on_coset,
);
let denominator_inv = z_h_on_coset.eval_inverse(i);
quotient_values
.iter_mut()
.for_each(|v| *v *= denominator_inv);
quotient_values
for (&i, quotient_values) in indices_batch.iter().zip(quotient_values_batch.iter_mut())
{
let denominator_inv = z_h_on_coset.eval_inverse(i);
quotient_values
.iter_mut()
.for_each(|v| *v *= denominator_inv);
}
quotient_values_batch
})
.flatten()
.collect();
transpose(&quotient_values)


@@ -101,30 +101,45 @@ pub(crate) fn eval_vanishing_poly<F: RichField + Extendable<D>, const D: usize>(
plonk_common::reduce_with_powers_multi(&vanishing_terms, alphas)
}
/// Like `eval_vanishing_poly`, but specialized for base field points.
pub(crate) fn eval_vanishing_poly_base<F: RichField + Extendable<D>, const D: usize>(
/// Like `eval_vanishing_poly`, but specialized for base field points. Batched.
pub(crate) fn eval_vanishing_poly_base_batch<F: RichField + Extendable<D>, const D: usize>(
common_data: &CommonCircuitData<F, D>,
index: usize,
x: F,
vars: EvaluationVarsBase<F>,
local_zs: &[F],
next_zs: &[F],
partial_products: &[F],
s_sigmas: &[F],
indices_batch: &[usize],
xs_batch: &[F],
vars_batch: &[EvaluationVarsBase<F>],
local_zs_batch: &[&[F]],
next_zs_batch: &[&[F]],
partial_products_batch: &[&[F]],
s_sigmas_batch: &[&[F]],
betas: &[F],
gammas: &[F],
alphas: &[F],
z_h_on_coset: &ZeroPolyOnCoset<F>,
) -> Vec<F> {
) -> Vec<Vec<F>> {
let n = indices_batch.len();
assert_eq!(xs_batch.len(), n);
assert_eq!(vars_batch.len(), n);
assert_eq!(local_zs_batch.len(), n);
assert_eq!(next_zs_batch.len(), n);
assert_eq!(partial_products_batch.len(), n);
assert_eq!(s_sigmas_batch.len(), n);
let max_degree = common_data.quotient_degree_factor;
let (num_prods, final_num_prod) = common_data.num_partial_products;
let constraint_terms =
evaluate_gate_constraints_base(&common_data.gates, common_data.num_gate_constraints, vars);
let num_gate_constraints = common_data.num_gate_constraints;
let constraint_terms_batch =
evaluate_gate_constraints_base_batch(&common_data.gates, num_gate_constraints, vars_batch);
debug_assert!(constraint_terms_batch.len() == n * num_gate_constraints);
let num_challenges = common_data.config.num_challenges;
let num_routed_wires = common_data.config.num_routed_wires;
let mut numerator_values = Vec::with_capacity(num_routed_wires);
let mut denominator_values = Vec::with_capacity(num_routed_wires);
let mut quotient_values = Vec::with_capacity(num_routed_wires);
// The L_1(x) (Z(x) - 1) vanishing terms.
let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges);
// The terms checking the partial products.
@@ -132,67 +147,81 @@ pub(crate) fn eval_vanishing_poly_base<F: RichField + Extendable<D>, const D: us
// The Z(x) f'(x) - g'(x) Z(g x) terms.
let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges);
let l1_x = z_h_on_coset.eval_l1(index, x);
let mut res_batch: Vec<Vec<F>> = Vec::with_capacity(n);
for k in 0..n {
let index = indices_batch[k];
let x = xs_batch[k];
let vars = vars_batch[k];
let local_zs = local_zs_batch[k];
let next_zs = next_zs_batch[k];
let partial_products = partial_products_batch[k];
let s_sigmas = s_sigmas_batch[k];
let mut numerator_values = Vec::with_capacity(num_routed_wires);
let mut denominator_values = Vec::with_capacity(num_routed_wires);
let mut quotient_values = Vec::with_capacity(num_routed_wires);
for i in 0..num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
vanishing_z_1_terms.push(l1_x * z_x.sub_one());
let constraint_terms =
&constraint_terms_batch[k * num_gate_constraints..(k + 1) * num_gate_constraints];
numerator_values.extend((0..num_routed_wires).map(|j| {
let wire_value = vars.local_wires[j];
let k_i = common_data.k_is[j];
let s_id = k_i * x;
wire_value + betas[i] * s_id + gammas[i]
}));
denominator_values.extend((0..num_routed_wires).map(|j| {
let wire_value = vars.local_wires[j];
let s_sigma = s_sigmas[j];
wire_value + betas[i] * s_sigma + gammas[i]
}));
let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values);
quotient_values
.extend((0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]));
let l1_x = z_h_on_coset.eval_l1(index, x);
for i in 0..num_challenges {
let z_x = local_zs[i];
let z_gz = next_zs[i];
vanishing_z_1_terms.push(l1_x * z_x.sub_one());
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
// Check the numerator partial products.
let mut partial_product_check =
check_partial_products(&quotient_values, current_partial_products, max_degree);
// The first checks are of the form `q - n/d`, which is a rational function, not a polynomial.
// We multiply them by `d` to get checks of the form `q*d - n`, which are low-degree polynomials.
denominator_values
.chunks(max_degree)
.zip(partial_product_check.iter_mut())
.for_each(|(d, q)| {
*q *= d.iter().copied().product();
});
vanishing_partial_products_terms.extend(partial_product_check);
numerator_values.extend((0..num_routed_wires).map(|j| {
let wire_value = vars.local_wires[j];
let k_i = common_data.k_is[j];
let s_id = k_i * x;
wire_value + betas[i] * s_id + gammas[i]
}));
denominator_values.extend((0..num_routed_wires).map(|j| {
let wire_value = vars.local_wires[j];
let s_sigma = s_sigmas[j];
wire_value + betas[i] * s_sigma + gammas[i]
}));
let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values);
quotient_values.extend(
(0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]),
);
// The quotient final product is the product of the last `final_num_prod` elements.
let quotient: F = current_partial_products[num_prods - final_num_prod..]
// The partial products considered for this iteration of `i`.
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
// Check the numerator partial products.
let mut partial_product_check =
check_partial_products(&quotient_values, current_partial_products, max_degree);
// The first checks are of the form `q - n/d`, which is a rational function, not a polynomial.
// We multiply them by `d` to get checks of the form `q*d - n`, which are low-degree polynomials.
denominator_values
.chunks(max_degree)
.zip(partial_product_check.iter_mut())
.for_each(|(d, q)| {
*q *= d.iter().copied().product();
});
vanishing_partial_products_terms.extend(partial_product_check);
// The quotient final product is the product of the last `final_num_prod` elements.
let quotient: F = current_partial_products[num_prods - final_num_prod..]
.iter()
.copied()
.product();
vanishing_v_shift_terms.push(quotient * z_x - z_gz);
numerator_values.clear();
denominator_values.clear();
quotient_values.clear();
}
let vanishing_terms = vanishing_z_1_terms
.iter()
.copied()
.product();
vanishing_v_shift_terms.push(quotient * z_x - z_gz);
.chain(vanishing_partial_products_terms.iter())
.chain(vanishing_v_shift_terms.iter())
.chain(constraint_terms);
let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas);
res_batch.push(res);
numerator_values.clear();
denominator_values.clear();
quotient_values.clear();
vanishing_z_1_terms.clear();
vanishing_partial_products_terms.clear();
vanishing_v_shift_terms.clear();
}
let vanishing_terms = [
vanishing_z_1_terms,
vanishing_partial_products_terms,
vanishing_v_shift_terms,
constraint_terms,
]
.concat();
plonk_common::reduce_with_powers_multi(&vanishing_terms, alphas)
res_batch
}
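
To summarise the inner loop above in the plain notation of its comments: for each point x in the batch, the vector pushed into `res_batch` has one entry per challenge alpha, equal to

    sum over j of alpha^j * t_j,

where the terms t_j are taken in order from the L_1(x) (Z(x) - 1) terms, the partial-product check terms, the Z(x) f'(x) - g'(x) Z(g x) shift terms, and finally the gate constraint terms for that point, exactly the chaining handed to `reduce_with_powers_multi`.
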
/// Evaluates all gate constraints.
@@ -219,23 +248,38 @@ pub fn evaluate_gate_constraints<F: RichField + Extendable<D>, const D: usize>(
constraints
}
pub fn evaluate_gate_constraints_base<F: RichField + Extendable<D>, const D: usize>(
/// Evaluate all gate constraints in the base field.
///
/// Returns a vector of num_gate_constraints * vars_batch.len() field elements. The constraints
/// corresponding to vars_batch[i] are found in
/// result[num_gate_constraints * i..num_gate_constraints * (i + 1)].
pub fn evaluate_gate_constraints_base_batch<F: RichField + Extendable<D>, const D: usize>(
gates: &[PrefixedGate<F, D>],
num_gate_constraints: usize,
vars: EvaluationVarsBase<F>,
vars_batch: &[EvaluationVarsBase<F>],
) -> Vec<F> {
let mut constraints = vec![F::ZERO; num_gate_constraints];
let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()];
for gate in gates {
let gate_constraints = gate.gate.0.eval_filtered_base(vars, &gate.prefix);
for (i, c) in gate_constraints.into_iter().enumerate() {
let gate_constraints_batch = gate
.gate
.0
.eval_filtered_base_batch(vars_batch, &gate.prefix);
for (constraints, gate_constraints) in constraints_batch
.chunks_exact_mut(num_gate_constraints)
.zip(gate_constraints_batch.iter())
{
debug_assert!(
i < num_gate_constraints,
gate_constraints.len() <= constraints.len(),
"num_constraints() gave too low of a number"
);
constraints[i] += c;
for (constraint, &gate_constraint) in
constraints.iter_mut().zip(gate_constraints.iter())
{
*constraint += gate_constraint;
}
}
}
constraints
constraints_batch
}
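
The batch-major layout described in the doc comment can be read back with a small helper (hypothetical, not part of this diff); it is the same slicing `eval_vanishing_poly_base_batch` applies to `constraint_terms_batch`:

// Illustrative only: recover the constraints for the i-th element of vars_batch
// from the flattened output of evaluate_gate_constraints_base_batch.
fn constraints_for<F>(constraints_batch: &[F], num_gate_constraints: usize, i: usize) -> &[F] {
    &constraints_batch[num_gate_constraints * i..num_gate_constraints * (i + 1)]
}
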
pub fn evaluate_gate_constraints_recursively<F: RichField + Extendable<D>, const D: usize>(