mirror of
https://github.com/logos-storage/plonky2.git
synced 2026-01-07 08:13:11 +00:00
Batched eval_vanishing_poly_base (#317)
* Batched eval_vanishing_poly_base * Reduce the number of allocations * Lints * Delete unused things * Minor: fix a debug_assert * Daniel PR comments * Lints * Daniel PR comments
This commit is contained in:
parent
f616d6436d
commit
bf421314f9
@ -68,9 +68,21 @@ pub trait Gate<F: RichField + Extendable<D>, const D: usize>: 'static + Send + S
|
||||
fn eval_filtered_base(&self, mut vars: EvaluationVarsBase<F>, prefix: &[bool]) -> Vec<F> {
|
||||
let filter = compute_filter(prefix, vars.local_constants);
|
||||
vars.remove_prefix(prefix);
|
||||
self.eval_unfiltered_base(vars)
|
||||
.into_iter()
|
||||
.map(|c| c * filter)
|
||||
let mut res = self.eval_unfiltered_base(vars);
|
||||
res.iter_mut().for_each(|c| {
|
||||
*c *= filter;
|
||||
});
|
||||
res
|
||||
}
|
||||
|
||||
fn eval_filtered_base_batch(
|
||||
&self,
|
||||
vars_batch: &[EvaluationVarsBase<F>],
|
||||
prefix: &[bool],
|
||||
) -> Vec<Vec<F>> {
|
||||
vars_batch
|
||||
.iter()
|
||||
.map(|&vars| self.eval_filtered_base(vars, prefix))
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
||||
@ -147,12 +147,25 @@ pub(crate) fn eval_l_1_recursively<F: RichField + Extendable<D>, const D: usize>
|
||||
builder.div_extension(eval_zero_poly, denominator)
|
||||
}
|
||||
|
||||
/// For each alpha in alphas, compute a reduction of the given terms using powers of alpha.
|
||||
pub(crate) fn reduce_with_powers_multi<F: Field>(terms: &[F], alphas: &[F]) -> Vec<F> {
|
||||
alphas
|
||||
.iter()
|
||||
.map(|&alpha| reduce_with_powers(terms, alpha))
|
||||
.collect()
|
||||
/// For each alpha in alphas, compute a reduction of the given terms using powers of alpha. T can
|
||||
/// be any type convertible to a double-ended iterator.
|
||||
pub(crate) fn reduce_with_powers_multi<
|
||||
'a,
|
||||
F: Field,
|
||||
I: DoubleEndedIterator<Item = &'a F>,
|
||||
T: IntoIterator<IntoIter = I>,
|
||||
>(
|
||||
terms: T,
|
||||
alphas: &[F],
|
||||
) -> Vec<F> {
|
||||
let mut cumul = vec![F::ZERO; alphas.len()];
|
||||
for &term in terms.into_iter().rev() {
|
||||
cumul
|
||||
.iter_mut()
|
||||
.zip(alphas)
|
||||
.for_each(|(c, &alpha)| *c = term.multiply_accumulate(*c, alpha));
|
||||
}
|
||||
cumul
|
||||
}
|
||||
|
||||
pub(crate) fn reduce_with_powers<F: Field>(terms: &[F], alpha: F) -> F {
|
||||
|
||||
@ -14,7 +14,7 @@ use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData};
|
||||
use crate::plonk::plonk_common::PlonkPolynomials;
|
||||
use crate::plonk::plonk_common::ZeroPolyOnCoset;
|
||||
use crate::plonk::proof::{Proof, ProofWithPublicInputs};
|
||||
use crate::plonk::vanishing_poly::eval_vanishing_poly_base;
|
||||
use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch;
|
||||
use crate::plonk::vars::EvaluationVarsBase;
|
||||
use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues};
|
||||
use crate::timed;
|
||||
@ -309,6 +309,8 @@ fn compute_z<F: RichField + Extendable<D>, const D: usize>(
|
||||
plonk_z_points.into()
|
||||
}
|
||||
|
||||
const BATCH_SIZE: usize = 32;
|
||||
|
||||
fn compute_quotient_polys<'a, F: RichField + Extendable<D>, const D: usize>(
|
||||
common_data: &CommonCircuitData<F, D>,
|
||||
prover_data: &'a ProverOnlyCircuitData<F, D>,
|
||||
@ -344,50 +346,77 @@ fn compute_quotient_polys<'a, F: RichField + Extendable<D>, const D: usize>(
|
||||
|
||||
let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits, max_degree_bits);
|
||||
|
||||
let quotient_values: Vec<Vec<F>> = points
|
||||
.into_par_iter()
|
||||
let points_batches = points.par_chunks(BATCH_SIZE);
|
||||
let quotient_values: Vec<Vec<F>> = points_batches
|
||||
.enumerate()
|
||||
.map(|(i, x)| {
|
||||
let shifted_x = F::coset_shift() * x;
|
||||
let i_next = (i + next_step) % lde_size;
|
||||
let local_constants_sigmas = get_at_index(&prover_data.constants_sigmas_commitment, i);
|
||||
let local_constants = &local_constants_sigmas[common_data.constants_range()];
|
||||
let s_sigmas = &local_constants_sigmas[common_data.sigmas_range()];
|
||||
let local_wires = get_at_index(wires_commitment, i);
|
||||
let local_zs_partial_products = get_at_index(zs_partial_products_commitment, i);
|
||||
let local_zs = &local_zs_partial_products[common_data.zs_range()];
|
||||
let next_zs =
|
||||
&get_at_index(zs_partial_products_commitment, i_next)[common_data.zs_range()];
|
||||
let partial_products = &local_zs_partial_products[common_data.partial_products_range()];
|
||||
.map(|(batch_i, xs_batch)| {
|
||||
assert_eq!(xs_batch.len(), BATCH_SIZE);
|
||||
let indices_batch: Vec<usize> =
|
||||
(BATCH_SIZE * batch_i..BATCH_SIZE * (batch_i + 1)).collect();
|
||||
|
||||
debug_assert_eq!(local_wires.len(), common_data.config.num_wires);
|
||||
debug_assert_eq!(local_zs.len(), num_challenges);
|
||||
let mut shifted_xs_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut vars_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut local_zs_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut next_zs_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut partial_products_batch = Vec::with_capacity(xs_batch.len());
|
||||
let mut s_sigmas_batch = Vec::with_capacity(xs_batch.len());
|
||||
|
||||
let vars = EvaluationVarsBase {
|
||||
local_constants,
|
||||
local_wires,
|
||||
public_inputs_hash,
|
||||
};
|
||||
let mut quotient_values = eval_vanishing_poly_base(
|
||||
for (&i, &x) in indices_batch.iter().zip(xs_batch) {
|
||||
let shifted_x = F::coset_shift() * x;
|
||||
let i_next = (i + next_step) % lde_size;
|
||||
let local_constants_sigmas =
|
||||
get_at_index(&prover_data.constants_sigmas_commitment, i);
|
||||
let local_constants = &local_constants_sigmas[common_data.constants_range()];
|
||||
let s_sigmas = &local_constants_sigmas[common_data.sigmas_range()];
|
||||
let local_wires = get_at_index(wires_commitment, i);
|
||||
let local_zs_partial_products = get_at_index(zs_partial_products_commitment, i);
|
||||
let local_zs = &local_zs_partial_products[common_data.zs_range()];
|
||||
let next_zs =
|
||||
&get_at_index(zs_partial_products_commitment, i_next)[common_data.zs_range()];
|
||||
let partial_products =
|
||||
&local_zs_partial_products[common_data.partial_products_range()];
|
||||
|
||||
debug_assert_eq!(local_wires.len(), common_data.config.num_wires);
|
||||
debug_assert_eq!(local_zs.len(), num_challenges);
|
||||
|
||||
let vars = EvaluationVarsBase {
|
||||
local_constants,
|
||||
local_wires,
|
||||
public_inputs_hash,
|
||||
};
|
||||
|
||||
shifted_xs_batch.push(shifted_x);
|
||||
vars_batch.push(vars);
|
||||
local_zs_batch.push(local_zs);
|
||||
next_zs_batch.push(next_zs);
|
||||
partial_products_batch.push(partial_products);
|
||||
s_sigmas_batch.push(s_sigmas);
|
||||
}
|
||||
let mut quotient_values_batch = eval_vanishing_poly_base_batch(
|
||||
common_data,
|
||||
i,
|
||||
shifted_x,
|
||||
vars,
|
||||
local_zs,
|
||||
next_zs,
|
||||
partial_products,
|
||||
s_sigmas,
|
||||
&indices_batch,
|
||||
&shifted_xs_batch,
|
||||
&vars_batch,
|
||||
&local_zs_batch,
|
||||
&next_zs_batch,
|
||||
&partial_products_batch,
|
||||
&s_sigmas_batch,
|
||||
betas,
|
||||
gammas,
|
||||
alphas,
|
||||
&z_h_on_coset,
|
||||
);
|
||||
let denominator_inv = z_h_on_coset.eval_inverse(i);
|
||||
quotient_values
|
||||
.iter_mut()
|
||||
.for_each(|v| *v *= denominator_inv);
|
||||
quotient_values
|
||||
|
||||
for (&i, quotient_values) in indices_batch.iter().zip(quotient_values_batch.iter_mut())
|
||||
{
|
||||
let denominator_inv = z_h_on_coset.eval_inverse(i);
|
||||
quotient_values
|
||||
.iter_mut()
|
||||
.for_each(|v| *v *= denominator_inv);
|
||||
}
|
||||
quotient_values_batch
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
transpose("ient_values)
|
||||
|
||||
@ -101,30 +101,45 @@ pub(crate) fn eval_vanishing_poly<F: RichField + Extendable<D>, const D: usize>(
|
||||
plonk_common::reduce_with_powers_multi(&vanishing_terms, alphas)
|
||||
}
|
||||
|
||||
/// Like `eval_vanishing_poly`, but specialized for base field points.
|
||||
pub(crate) fn eval_vanishing_poly_base<F: RichField + Extendable<D>, const D: usize>(
|
||||
/// Like `eval_vanishing_poly`, but specialized for base field points. Batched.
|
||||
pub(crate) fn eval_vanishing_poly_base_batch<F: RichField + Extendable<D>, const D: usize>(
|
||||
common_data: &CommonCircuitData<F, D>,
|
||||
index: usize,
|
||||
x: F,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
local_zs: &[F],
|
||||
next_zs: &[F],
|
||||
partial_products: &[F],
|
||||
s_sigmas: &[F],
|
||||
indices_batch: &[usize],
|
||||
xs_batch: &[F],
|
||||
vars_batch: &[EvaluationVarsBase<F>],
|
||||
local_zs_batch: &[&[F]],
|
||||
next_zs_batch: &[&[F]],
|
||||
partial_products_batch: &[&[F]],
|
||||
s_sigmas_batch: &[&[F]],
|
||||
betas: &[F],
|
||||
gammas: &[F],
|
||||
alphas: &[F],
|
||||
z_h_on_coset: &ZeroPolyOnCoset<F>,
|
||||
) -> Vec<F> {
|
||||
) -> Vec<Vec<F>> {
|
||||
let n = indices_batch.len();
|
||||
assert_eq!(xs_batch.len(), n);
|
||||
assert_eq!(vars_batch.len(), n);
|
||||
assert_eq!(local_zs_batch.len(), n);
|
||||
assert_eq!(next_zs_batch.len(), n);
|
||||
assert_eq!(partial_products_batch.len(), n);
|
||||
assert_eq!(s_sigmas_batch.len(), n);
|
||||
|
||||
let max_degree = common_data.quotient_degree_factor;
|
||||
let (num_prods, final_num_prod) = common_data.num_partial_products;
|
||||
|
||||
let constraint_terms =
|
||||
evaluate_gate_constraints_base(&common_data.gates, common_data.num_gate_constraints, vars);
|
||||
let num_gate_constraints = common_data.num_gate_constraints;
|
||||
|
||||
let constraint_terms_batch =
|
||||
evaluate_gate_constraints_base_batch(&common_data.gates, num_gate_constraints, vars_batch);
|
||||
debug_assert!(constraint_terms_batch.len() == n * num_gate_constraints);
|
||||
|
||||
let num_challenges = common_data.config.num_challenges;
|
||||
let num_routed_wires = common_data.config.num_routed_wires;
|
||||
|
||||
let mut numerator_values = Vec::with_capacity(num_routed_wires);
|
||||
let mut denominator_values = Vec::with_capacity(num_routed_wires);
|
||||
let mut quotient_values = Vec::with_capacity(num_routed_wires);
|
||||
|
||||
// The L_1(x) (Z(x) - 1) vanishing terms.
|
||||
let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges);
|
||||
// The terms checking the partial products.
|
||||
@ -132,67 +147,81 @@ pub(crate) fn eval_vanishing_poly_base<F: RichField + Extendable<D>, const D: us
|
||||
// The Z(x) f'(x) - g'(x) Z(g x) terms.
|
||||
let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges);
|
||||
|
||||
let l1_x = z_h_on_coset.eval_l1(index, x);
|
||||
let mut res_batch: Vec<Vec<F>> = Vec::with_capacity(n);
|
||||
for k in 0..n {
|
||||
let index = indices_batch[k];
|
||||
let x = xs_batch[k];
|
||||
let vars = vars_batch[k];
|
||||
let local_zs = local_zs_batch[k];
|
||||
let next_zs = next_zs_batch[k];
|
||||
let partial_products = partial_products_batch[k];
|
||||
let s_sigmas = s_sigmas_batch[k];
|
||||
|
||||
let mut numerator_values = Vec::with_capacity(num_routed_wires);
|
||||
let mut denominator_values = Vec::with_capacity(num_routed_wires);
|
||||
let mut quotient_values = Vec::with_capacity(num_routed_wires);
|
||||
for i in 0..num_challenges {
|
||||
let z_x = local_zs[i];
|
||||
let z_gz = next_zs[i];
|
||||
vanishing_z_1_terms.push(l1_x * z_x.sub_one());
|
||||
let constraint_terms =
|
||||
&constraint_terms_batch[k * num_gate_constraints..(k + 1) * num_gate_constraints];
|
||||
|
||||
numerator_values.extend((0..num_routed_wires).map(|j| {
|
||||
let wire_value = vars.local_wires[j];
|
||||
let k_i = common_data.k_is[j];
|
||||
let s_id = k_i * x;
|
||||
wire_value + betas[i] * s_id + gammas[i]
|
||||
}));
|
||||
denominator_values.extend((0..num_routed_wires).map(|j| {
|
||||
let wire_value = vars.local_wires[j];
|
||||
let s_sigma = s_sigmas[j];
|
||||
wire_value + betas[i] * s_sigma + gammas[i]
|
||||
}));
|
||||
let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values);
|
||||
quotient_values
|
||||
.extend((0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]));
|
||||
let l1_x = z_h_on_coset.eval_l1(index, x);
|
||||
for i in 0..num_challenges {
|
||||
let z_x = local_zs[i];
|
||||
let z_gz = next_zs[i];
|
||||
vanishing_z_1_terms.push(l1_x * z_x.sub_one());
|
||||
|
||||
// The partial products considered for this iteration of `i`.
|
||||
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
|
||||
// Check the numerator partial products.
|
||||
let mut partial_product_check =
|
||||
check_partial_products("ient_values, current_partial_products, max_degree);
|
||||
// The first checks are of the form `q - n/d` which is a rational function not a polynomial.
|
||||
// We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials.
|
||||
denominator_values
|
||||
.chunks(max_degree)
|
||||
.zip(partial_product_check.iter_mut())
|
||||
.for_each(|(d, q)| {
|
||||
*q *= d.iter().copied().product();
|
||||
});
|
||||
vanishing_partial_products_terms.extend(partial_product_check);
|
||||
numerator_values.extend((0..num_routed_wires).map(|j| {
|
||||
let wire_value = vars.local_wires[j];
|
||||
let k_i = common_data.k_is[j];
|
||||
let s_id = k_i * x;
|
||||
wire_value + betas[i] * s_id + gammas[i]
|
||||
}));
|
||||
denominator_values.extend((0..num_routed_wires).map(|j| {
|
||||
let wire_value = vars.local_wires[j];
|
||||
let s_sigma = s_sigmas[j];
|
||||
wire_value + betas[i] * s_sigma + gammas[i]
|
||||
}));
|
||||
let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values);
|
||||
quotient_values.extend(
|
||||
(0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]),
|
||||
);
|
||||
|
||||
// The quotient final product is the product of the last `final_num_prod` elements.
|
||||
let quotient: F = current_partial_products[num_prods - final_num_prod..]
|
||||
// The partial products considered for this iteration of `i`.
|
||||
let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods];
|
||||
// Check the numerator partial products.
|
||||
let mut partial_product_check =
|
||||
check_partial_products("ient_values, current_partial_products, max_degree);
|
||||
// The first checks are of the form `q - n/d` which is a rational function not a polynomial.
|
||||
// We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials.
|
||||
denominator_values
|
||||
.chunks(max_degree)
|
||||
.zip(partial_product_check.iter_mut())
|
||||
.for_each(|(d, q)| {
|
||||
*q *= d.iter().copied().product();
|
||||
});
|
||||
vanishing_partial_products_terms.extend(partial_product_check);
|
||||
|
||||
// The quotient final product is the product of the last `final_num_prod` elements.
|
||||
let quotient: F = current_partial_products[num_prods - final_num_prod..]
|
||||
.iter()
|
||||
.copied()
|
||||
.product();
|
||||
vanishing_v_shift_terms.push(quotient * z_x - z_gz);
|
||||
|
||||
numerator_values.clear();
|
||||
denominator_values.clear();
|
||||
quotient_values.clear();
|
||||
}
|
||||
|
||||
let vanishing_terms = vanishing_z_1_terms
|
||||
.iter()
|
||||
.copied()
|
||||
.product();
|
||||
vanishing_v_shift_terms.push(quotient * z_x - z_gz);
|
||||
.chain(vanishing_partial_products_terms.iter())
|
||||
.chain(vanishing_v_shift_terms.iter())
|
||||
.chain(constraint_terms);
|
||||
let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas);
|
||||
res_batch.push(res);
|
||||
|
||||
numerator_values.clear();
|
||||
denominator_values.clear();
|
||||
quotient_values.clear();
|
||||
vanishing_z_1_terms.clear();
|
||||
vanishing_partial_products_terms.clear();
|
||||
vanishing_v_shift_terms.clear();
|
||||
}
|
||||
|
||||
let vanishing_terms = [
|
||||
vanishing_z_1_terms,
|
||||
vanishing_partial_products_terms,
|
||||
vanishing_v_shift_terms,
|
||||
constraint_terms,
|
||||
]
|
||||
.concat();
|
||||
|
||||
plonk_common::reduce_with_powers_multi(&vanishing_terms, alphas)
|
||||
res_batch
|
||||
}
|
||||
|
||||
/// Evaluates all gate constraints.
|
||||
@ -219,23 +248,38 @@ pub fn evaluate_gate_constraints<F: RichField + Extendable<D>, const D: usize>(
|
||||
constraints
|
||||
}
|
||||
|
||||
pub fn evaluate_gate_constraints_base<F: RichField + Extendable<D>, const D: usize>(
|
||||
/// Evaluate all gate constraints in the base field.
|
||||
///
|
||||
/// Returns a vector of num_gate_constraints * vars_batch.len() field elements. The constraints
|
||||
/// corresponding to vars_batch[i] are found in
|
||||
/// result[num_gate_constraints * i..num_gate_constraints * (i + 1)].
|
||||
pub fn evaluate_gate_constraints_base_batch<F: RichField + Extendable<D>, const D: usize>(
|
||||
gates: &[PrefixedGate<F, D>],
|
||||
num_gate_constraints: usize,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
vars_batch: &[EvaluationVarsBase<F>],
|
||||
) -> Vec<F> {
|
||||
let mut constraints = vec![F::ZERO; num_gate_constraints];
|
||||
let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()];
|
||||
for gate in gates {
|
||||
let gate_constraints = gate.gate.0.eval_filtered_base(vars, &gate.prefix);
|
||||
for (i, c) in gate_constraints.into_iter().enumerate() {
|
||||
let gate_constraints_batch = gate
|
||||
.gate
|
||||
.0
|
||||
.eval_filtered_base_batch(vars_batch, &gate.prefix);
|
||||
for (constraints, gate_constraints) in constraints_batch
|
||||
.chunks_exact_mut(num_gate_constraints)
|
||||
.zip(gate_constraints_batch.iter())
|
||||
{
|
||||
debug_assert!(
|
||||
i < num_gate_constraints,
|
||||
gate_constraints.len() <= constraints.len(),
|
||||
"num_constraints() gave too low of a number"
|
||||
);
|
||||
constraints[i] += c;
|
||||
for (constraint, &gate_constraint) in
|
||||
constraints.iter_mut().zip(gate_constraints.iter())
|
||||
{
|
||||
*constraint += gate_constraint;
|
||||
}
|
||||
}
|
||||
}
|
||||
constraints
|
||||
constraints_batch
|
||||
}
|
||||
|
||||
pub fn evaluate_gate_constraints_recursively<F: RichField + Extendable<D>, const D: usize>(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user