From 431faccbdbee989300992a1d1d04a42bc2602b7e Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Wed, 16 Feb 2022 22:37:20 -0800 Subject: [PATCH] Change `compute_permutation_z_polys` to batch permutation checks (#492) * Change `compute_permutation_z_polys` to batch permutation checks * feedback --- starky/src/permutation.rs | 77 +++++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 23 deletions(-) diff --git a/starky/src/permutation.rs b/starky/src/permutation.rs index 1f7655b4..01cfa8bf 100644 --- a/starky/src/permutation.rs +++ b/starky/src/permutation.rs @@ -67,48 +67,46 @@ where // Before batching, each permutation pair leads to `num_challenges` permutation arguments, so we // start with the cartesian product of `permutation_pairs` and `0..num_challenges`. Then we // chunk these arguments based on our batch size. - let permutation_instances = permutation_pairs + let permutation_batches = permutation_pairs .iter() .cartesian_product(0..config.num_challenges) .chunks(stark.permutation_batch_size()) .into_iter() - .flat_map(|batch| { - batch.enumerate().map(|(i, (pair, chal))| { - let challenge = permutation_challenge_sets[i].challenges[chal]; - PermutationInstance { pair, challenge } - }) + .map(|batch| { + batch + .enumerate() + .map(|(i, (pair, chal))| { + let challenge = permutation_challenge_sets[i].challenges[chal]; + PermutationInstance { pair, challenge } + }) + .collect_vec() }) .collect_vec(); - permutation_instances + permutation_batches .into_par_iter() - .map(|instance| compute_permutation_z_poly(instance, trace_poly_values)) + .map(|instances| compute_permutation_z_poly(&instances, trace_poly_values)) .collect() } /// Compute a single Z polynomial. -// TODO: Change this to handle a batch of `PermutationInstance`s. fn compute_permutation_z_poly( - instance: PermutationInstance, + instances: &[PermutationInstance], trace_poly_values: &[PolynomialValues], ) -> PolynomialValues { - let PermutationInstance { pair, challenge } = instance; - let PermutationPair { column_pairs } = pair; - let PermutationChallenge { beta, gamma } = challenge; - let degree = trace_poly_values[0].len(); - let mut reduced_lhs = PolynomialValues::constant(gamma, degree); - let mut reduced_rhs = PolynomialValues::constant(gamma, degree); + let (reduced_lhs_polys, reduced_rhs_polys): (Vec<_>, Vec<_>) = instances + .iter() + .map(|instance| permutation_reduced_polys(instance, trace_poly_values, degree)) + .unzip(); - for ((lhs, rhs), weight) in column_pairs.iter().zip(beta.powers()) { - reduced_lhs.add_assign_scaled(&trace_poly_values[*lhs], weight); - reduced_rhs.add_assign_scaled(&trace_poly_values[*rhs], weight); - } + let numerator = poly_product_elementwise(reduced_lhs_polys.into_iter()); + let denominator = poly_product_elementwise(reduced_rhs_polys.into_iter()); // Compute the quotients. - let reduced_rhs_inverses = F::batch_multiplicative_inverse(&reduced_rhs.values); - let mut quotients = reduced_lhs.values; - batch_multiply_inplace(&mut quotients, &reduced_rhs_inverses); + let denominator_inverses = F::batch_multiplicative_inverse(&denominator.values); + let mut quotients = numerator.values; + batch_multiply_inplace(&mut quotients, &denominator_inverses); // Compute Z, which contains partial products of the quotients. let mut partial_products = Vec::with_capacity(degree); @@ -120,6 +118,39 @@ fn compute_permutation_z_poly( PolynomialValues::new(partial_products) } +/// Computes the reduced polynomial, `\sum beta^i f_i(x) + gamma`, for both the "left" and "right" +/// sides of a given `PermutationPair`. +fn permutation_reduced_polys( + instance: &PermutationInstance, + trace_poly_values: &[PolynomialValues], + degree: usize, +) -> (PolynomialValues, PolynomialValues) { + let PermutationInstance { + pair: PermutationPair { column_pairs }, + challenge: PermutationChallenge { beta, gamma }, + } = instance; + + let mut reduced_lhs = PolynomialValues::constant(*gamma, degree); + let mut reduced_rhs = PolynomialValues::constant(*gamma, degree); + for ((lhs, rhs), weight) in column_pairs.iter().zip(beta.powers()) { + reduced_lhs.add_assign_scaled(&trace_poly_values[*lhs], weight); + reduced_rhs.add_assign_scaled(&trace_poly_values[*rhs], weight); + } + (reduced_lhs, reduced_rhs) +} + +/// Computes the elementwise product of a set of polynomials. Assumes that the set is non-empty and +/// that each polynomial has the same length. +fn poly_product_elementwise( + mut polys: impl Iterator>, +) -> PolynomialValues { + let mut product = polys.next().expect("Expected at least one polynomial"); + for poly in polys { + batch_multiply_inplace(&mut product.values, &poly.values) + } + product +} + fn get_permutation_challenge>( challenger: &mut Challenger, ) -> PermutationChallenge {