diff --git a/src/gates/gate.rs b/src/gates/gate.rs index 7369796f..b882c336 100644 --- a/src/gates/gate.rs +++ b/src/gates/gate.rs @@ -68,9 +68,21 @@ pub trait Gate, const D: usize>: 'static + Send + S fn eval_filtered_base(&self, mut vars: EvaluationVarsBase, prefix: &[bool]) -> Vec { let filter = compute_filter(prefix, vars.local_constants); vars.remove_prefix(prefix); - self.eval_unfiltered_base(vars) - .into_iter() - .map(|c| c * filter) + let mut res = self.eval_unfiltered_base(vars); + res.iter_mut().for_each(|c| { + *c *= filter; + }); + res + } + + fn eval_filtered_base_batch( + &self, + vars_batch: &[EvaluationVarsBase], + prefix: &[bool], + ) -> Vec> { + vars_batch + .iter() + .map(|&vars| self.eval_filtered_base(vars, prefix)) .collect() } diff --git a/src/plonk/plonk_common.rs b/src/plonk/plonk_common.rs index dd1bf5d4..6b84886d 100644 --- a/src/plonk/plonk_common.rs +++ b/src/plonk/plonk_common.rs @@ -147,12 +147,25 @@ pub(crate) fn eval_l_1_recursively, const D: usize> builder.div_extension(eval_zero_poly, denominator) } -/// For each alpha in alphas, compute a reduction of the given terms using powers of alpha. -pub(crate) fn reduce_with_powers_multi(terms: &[F], alphas: &[F]) -> Vec { - alphas - .iter() - .map(|&alpha| reduce_with_powers(terms, alpha)) - .collect() +/// For each alpha in alphas, compute a reduction of the given terms using powers of alpha. T can +/// be any type convertible to a double-ended iterator. +pub(crate) fn reduce_with_powers_multi< + 'a, + F: Field, + I: DoubleEndedIterator, + T: IntoIterator, +>( + terms: T, + alphas: &[F], +) -> Vec { + let mut cumul = vec![F::ZERO; alphas.len()]; + for &term in terms.into_iter().rev() { + cumul + .iter_mut() + .zip(alphas) + .for_each(|(c, &alpha)| *c = term.multiply_accumulate(*c, alpha)); + } + cumul } pub(crate) fn reduce_with_powers(terms: &[F], alpha: F) -> F { diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index ef26683c..def9da2c 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -14,7 +14,7 @@ use crate::plonk::circuit_data::{CommonCircuitData, ProverOnlyCircuitData}; use crate::plonk::plonk_common::PlonkPolynomials; use crate::plonk::plonk_common::ZeroPolyOnCoset; use crate::plonk::proof::{Proof, ProofWithPublicInputs}; -use crate::plonk::vanishing_poly::eval_vanishing_poly_base; +use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBase; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; @@ -309,6 +309,8 @@ fn compute_z, const D: usize>( plonk_z_points.into() } +const BATCH_SIZE: usize = 32; + fn compute_quotient_polys<'a, F: RichField + Extendable, const D: usize>( common_data: &CommonCircuitData, prover_data: &'a ProverOnlyCircuitData, @@ -344,50 +346,77 @@ fn compute_quotient_polys<'a, F: RichField + Extendable, const D: usize>( let z_h_on_coset = ZeroPolyOnCoset::new(common_data.degree_bits, max_degree_bits); - let quotient_values: Vec> = points - .into_par_iter() + let points_batches = points.par_chunks(BATCH_SIZE); + let quotient_values: Vec> = points_batches .enumerate() - .map(|(i, x)| { - let shifted_x = F::coset_shift() * x; - let i_next = (i + next_step) % lde_size; - let local_constants_sigmas = get_at_index(&prover_data.constants_sigmas_commitment, i); - let local_constants = &local_constants_sigmas[common_data.constants_range()]; - let s_sigmas = &local_constants_sigmas[common_data.sigmas_range()]; - let local_wires = get_at_index(wires_commitment, i); - let local_zs_partial_products = get_at_index(zs_partial_products_commitment, i); - let local_zs = &local_zs_partial_products[common_data.zs_range()]; - let next_zs = - &get_at_index(zs_partial_products_commitment, i_next)[common_data.zs_range()]; - let partial_products = &local_zs_partial_products[common_data.partial_products_range()]; + .map(|(batch_i, xs_batch)| { + assert_eq!(xs_batch.len(), BATCH_SIZE); + let indices_batch: Vec = + (BATCH_SIZE * batch_i..BATCH_SIZE * (batch_i + 1)).collect(); - debug_assert_eq!(local_wires.len(), common_data.config.num_wires); - debug_assert_eq!(local_zs.len(), num_challenges); + let mut shifted_xs_batch = Vec::with_capacity(xs_batch.len()); + let mut vars_batch = Vec::with_capacity(xs_batch.len()); + let mut local_zs_batch = Vec::with_capacity(xs_batch.len()); + let mut next_zs_batch = Vec::with_capacity(xs_batch.len()); + let mut partial_products_batch = Vec::with_capacity(xs_batch.len()); + let mut s_sigmas_batch = Vec::with_capacity(xs_batch.len()); - let vars = EvaluationVarsBase { - local_constants, - local_wires, - public_inputs_hash, - }; - let mut quotient_values = eval_vanishing_poly_base( + for (&i, &x) in indices_batch.iter().zip(xs_batch) { + let shifted_x = F::coset_shift() * x; + let i_next = (i + next_step) % lde_size; + let local_constants_sigmas = + get_at_index(&prover_data.constants_sigmas_commitment, i); + let local_constants = &local_constants_sigmas[common_data.constants_range()]; + let s_sigmas = &local_constants_sigmas[common_data.sigmas_range()]; + let local_wires = get_at_index(wires_commitment, i); + let local_zs_partial_products = get_at_index(zs_partial_products_commitment, i); + let local_zs = &local_zs_partial_products[common_data.zs_range()]; + let next_zs = + &get_at_index(zs_partial_products_commitment, i_next)[common_data.zs_range()]; + let partial_products = + &local_zs_partial_products[common_data.partial_products_range()]; + + debug_assert_eq!(local_wires.len(), common_data.config.num_wires); + debug_assert_eq!(local_zs.len(), num_challenges); + + let vars = EvaluationVarsBase { + local_constants, + local_wires, + public_inputs_hash, + }; + + shifted_xs_batch.push(shifted_x); + vars_batch.push(vars); + local_zs_batch.push(local_zs); + next_zs_batch.push(next_zs); + partial_products_batch.push(partial_products); + s_sigmas_batch.push(s_sigmas); + } + let mut quotient_values_batch = eval_vanishing_poly_base_batch( common_data, - i, - shifted_x, - vars, - local_zs, - next_zs, - partial_products, - s_sigmas, + &indices_batch, + &shifted_xs_batch, + &vars_batch, + &local_zs_batch, + &next_zs_batch, + &partial_products_batch, + &s_sigmas_batch, betas, gammas, alphas, &z_h_on_coset, ); - let denominator_inv = z_h_on_coset.eval_inverse(i); - quotient_values - .iter_mut() - .for_each(|v| *v *= denominator_inv); - quotient_values + + for (&i, quotient_values) in indices_batch.iter().zip(quotient_values_batch.iter_mut()) + { + let denominator_inv = z_h_on_coset.eval_inverse(i); + quotient_values + .iter_mut() + .for_each(|v| *v *= denominator_inv); + } + quotient_values_batch }) + .flatten() .collect(); transpose("ient_values) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index ae79d2dc..6d0dc982 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -101,30 +101,45 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( plonk_common::reduce_with_powers_multi(&vanishing_terms, alphas) } -/// Like `eval_vanishing_poly`, but specialized for base field points. -pub(crate) fn eval_vanishing_poly_base, const D: usize>( +/// Like `eval_vanishing_poly`, but specialized for base field points. Batched. +pub(crate) fn eval_vanishing_poly_base_batch, const D: usize>( common_data: &CommonCircuitData, - index: usize, - x: F, - vars: EvaluationVarsBase, - local_zs: &[F], - next_zs: &[F], - partial_products: &[F], - s_sigmas: &[F], + indices_batch: &[usize], + xs_batch: &[F], + vars_batch: &[EvaluationVarsBase], + local_zs_batch: &[&[F]], + next_zs_batch: &[&[F]], + partial_products_batch: &[&[F]], + s_sigmas_batch: &[&[F]], betas: &[F], gammas: &[F], alphas: &[F], z_h_on_coset: &ZeroPolyOnCoset, -) -> Vec { +) -> Vec> { + let n = indices_batch.len(); + assert_eq!(xs_batch.len(), n); + assert_eq!(vars_batch.len(), n); + assert_eq!(local_zs_batch.len(), n); + assert_eq!(next_zs_batch.len(), n); + assert_eq!(partial_products_batch.len(), n); + assert_eq!(s_sigmas_batch.len(), n); + let max_degree = common_data.quotient_degree_factor; let (num_prods, final_num_prod) = common_data.num_partial_products; - let constraint_terms = - evaluate_gate_constraints_base(&common_data.gates, common_data.num_gate_constraints, vars); + let num_gate_constraints = common_data.num_gate_constraints; + + let constraint_terms_batch = + evaluate_gate_constraints_base_batch(&common_data.gates, num_gate_constraints, vars_batch); + debug_assert!(constraint_terms_batch.len() == n * num_gate_constraints); let num_challenges = common_data.config.num_challenges; let num_routed_wires = common_data.config.num_routed_wires; + let mut numerator_values = Vec::with_capacity(num_routed_wires); + let mut denominator_values = Vec::with_capacity(num_routed_wires); + let mut quotient_values = Vec::with_capacity(num_routed_wires); + // The L_1(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); // The terms checking the partial products. @@ -132,67 +147,81 @@ pub(crate) fn eval_vanishing_poly_base, const D: us // The Z(x) f'(x) - g'(x) Z(g x) terms. let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges); - let l1_x = z_h_on_coset.eval_l1(index, x); + let mut res_batch: Vec> = Vec::with_capacity(n); + for k in 0..n { + let index = indices_batch[k]; + let x = xs_batch[k]; + let vars = vars_batch[k]; + let local_zs = local_zs_batch[k]; + let next_zs = next_zs_batch[k]; + let partial_products = partial_products_batch[k]; + let s_sigmas = s_sigmas_batch[k]; - let mut numerator_values = Vec::with_capacity(num_routed_wires); - let mut denominator_values = Vec::with_capacity(num_routed_wires); - let mut quotient_values = Vec::with_capacity(num_routed_wires); - for i in 0..num_challenges { - let z_x = local_zs[i]; - let z_gz = next_zs[i]; - vanishing_z_1_terms.push(l1_x * z_x.sub_one()); + let constraint_terms = + &constraint_terms_batch[k * num_gate_constraints..(k + 1) * num_gate_constraints]; - numerator_values.extend((0..num_routed_wires).map(|j| { - let wire_value = vars.local_wires[j]; - let k_i = common_data.k_is[j]; - let s_id = k_i * x; - wire_value + betas[i] * s_id + gammas[i] - })); - denominator_values.extend((0..num_routed_wires).map(|j| { - let wire_value = vars.local_wires[j]; - let s_sigma = s_sigmas[j]; - wire_value + betas[i] * s_sigma + gammas[i] - })); - let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values); - quotient_values - .extend((0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j])); + let l1_x = z_h_on_coset.eval_l1(index, x); + for i in 0..num_challenges { + let z_x = local_zs[i]; + let z_gz = next_zs[i]; + vanishing_z_1_terms.push(l1_x * z_x.sub_one()); - // The partial products considered for this iteration of `i`. - let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; - // Check the numerator partial products. - let mut partial_product_check = - check_partial_products("ient_values, current_partial_products, max_degree); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - *q *= d.iter().copied().product(); - }); - vanishing_partial_products_terms.extend(partial_product_check); + numerator_values.extend((0..num_routed_wires).map(|j| { + let wire_value = vars.local_wires[j]; + let k_i = common_data.k_is[j]; + let s_id = k_i * x; + wire_value + betas[i] * s_id + gammas[i] + })); + denominator_values.extend((0..num_routed_wires).map(|j| { + let wire_value = vars.local_wires[j]; + let s_sigma = s_sigmas[j]; + wire_value + betas[i] * s_sigma + gammas[i] + })); + let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values); + quotient_values.extend( + (0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]), + ); - // The quotient final product is the product of the last `final_num_prod` elements. - let quotient: F = current_partial_products[num_prods - final_num_prod..] + // The partial products considered for this iteration of `i`. + let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; + // Check the numerator partial products. + let mut partial_product_check = + check_partial_products("ient_values, current_partial_products, max_degree); + // The first checks are of the form `q - n/d` which is a rational function not a polynomial. + // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. + denominator_values + .chunks(max_degree) + .zip(partial_product_check.iter_mut()) + .for_each(|(d, q)| { + *q *= d.iter().copied().product(); + }); + vanishing_partial_products_terms.extend(partial_product_check); + + // The quotient final product is the product of the last `final_num_prod` elements. + let quotient: F = current_partial_products[num_prods - final_num_prod..] + .iter() + .copied() + .product(); + vanishing_v_shift_terms.push(quotient * z_x - z_gz); + + numerator_values.clear(); + denominator_values.clear(); + quotient_values.clear(); + } + + let vanishing_terms = vanishing_z_1_terms .iter() - .copied() - .product(); - vanishing_v_shift_terms.push(quotient * z_x - z_gz); + .chain(vanishing_partial_products_terms.iter()) + .chain(vanishing_v_shift_terms.iter()) + .chain(constraint_terms); + let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas); + res_batch.push(res); - numerator_values.clear(); - denominator_values.clear(); - quotient_values.clear(); + vanishing_z_1_terms.clear(); + vanishing_partial_products_terms.clear(); + vanishing_v_shift_terms.clear(); } - - let vanishing_terms = [ - vanishing_z_1_terms, - vanishing_partial_products_terms, - vanishing_v_shift_terms, - constraint_terms, - ] - .concat(); - - plonk_common::reduce_with_powers_multi(&vanishing_terms, alphas) + res_batch } /// Evaluates all gate constraints. @@ -219,23 +248,38 @@ pub fn evaluate_gate_constraints, const D: usize>( constraints } -pub fn evaluate_gate_constraints_base, const D: usize>( +/// Evaluate all gate constraints in the base field. +/// +/// Returns a vector of num_gate_constraints * vars_batch.len() field elements. The constraints +/// corresponding to vars_batch[i] are found in +/// result[num_gate_constraints * i..num_gate_constraints * (i + 1)]. +pub fn evaluate_gate_constraints_base_batch, const D: usize>( gates: &[PrefixedGate], num_gate_constraints: usize, - vars: EvaluationVarsBase, + vars_batch: &[EvaluationVarsBase], ) -> Vec { - let mut constraints = vec![F::ZERO; num_gate_constraints]; + let mut constraints_batch = vec![F::ZERO; num_gate_constraints * vars_batch.len()]; for gate in gates { - let gate_constraints = gate.gate.0.eval_filtered_base(vars, &gate.prefix); - for (i, c) in gate_constraints.into_iter().enumerate() { + let gate_constraints_batch = gate + .gate + .0 + .eval_filtered_base_batch(vars_batch, &gate.prefix); + for (constraints, gate_constraints) in constraints_batch + .chunks_exact_mut(num_gate_constraints) + .zip(gate_constraints_batch.iter()) + { debug_assert!( - i < num_gate_constraints, + gate_constraints.len() <= constraints.len(), "num_constraints() gave too low of a number" ); - constraints[i] += c; + for (constraint, &gate_constraint) in + constraints.iter_mut().zip(gate_constraints.iter()) + { + *constraint += gate_constraint; + } } } - constraints + constraints_batch } pub fn evaluate_gate_constraints_recursively, const D: usize>(