Change compute_permutation_z_polys to batch permutation checks (#492)

* Change `compute_permutation_z_polys` to batch permutation checks

* feedback
This commit is contained in:
Daniel Lubarov 2022-02-16 22:37:20 -08:00 committed by GitHub
parent 72d13d0ded
commit 431faccbdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -67,48 +67,46 @@ where
// Before batching, each permutation pair leads to `num_challenges` permutation arguments, so we
// start with the cartesian product of `permutation_pairs` and `0..num_challenges`. Then we
// chunk these arguments based on our batch size.
let permutation_instances = permutation_pairs
let permutation_batches = permutation_pairs
.iter()
.cartesian_product(0..config.num_challenges)
.chunks(stark.permutation_batch_size())
.into_iter()
.flat_map(|batch| {
batch.enumerate().map(|(i, (pair, chal))| {
let challenge = permutation_challenge_sets[i].challenges[chal];
PermutationInstance { pair, challenge }
})
.map(|batch| {
batch
.enumerate()
.map(|(i, (pair, chal))| {
let challenge = permutation_challenge_sets[i].challenges[chal];
PermutationInstance { pair, challenge }
})
.collect_vec()
})
.collect_vec();
permutation_instances
permutation_batches
.into_par_iter()
.map(|instance| compute_permutation_z_poly(instance, trace_poly_values))
.map(|instances| compute_permutation_z_poly(&instances, trace_poly_values))
.collect()
}
/// Compute a single Z polynomial.
// TODO: Change this to handle a batch of `PermutationInstance`s.
fn compute_permutation_z_poly<F: Field>(
instance: PermutationInstance<F>,
instances: &[PermutationInstance<F>],
trace_poly_values: &[PolynomialValues<F>],
) -> PolynomialValues<F> {
let PermutationInstance { pair, challenge } = instance;
let PermutationPair { column_pairs } = pair;
let PermutationChallenge { beta, gamma } = challenge;
let degree = trace_poly_values[0].len();
let mut reduced_lhs = PolynomialValues::constant(gamma, degree);
let mut reduced_rhs = PolynomialValues::constant(gamma, degree);
let (reduced_lhs_polys, reduced_rhs_polys): (Vec<_>, Vec<_>) = instances
.iter()
.map(|instance| permutation_reduced_polys(instance, trace_poly_values, degree))
.unzip();
for ((lhs, rhs), weight) in column_pairs.iter().zip(beta.powers()) {
reduced_lhs.add_assign_scaled(&trace_poly_values[*lhs], weight);
reduced_rhs.add_assign_scaled(&trace_poly_values[*rhs], weight);
}
let numerator = poly_product_elementwise(reduced_lhs_polys.into_iter());
let denominator = poly_product_elementwise(reduced_rhs_polys.into_iter());
// Compute the quotients.
let reduced_rhs_inverses = F::batch_multiplicative_inverse(&reduced_rhs.values);
let mut quotients = reduced_lhs.values;
batch_multiply_inplace(&mut quotients, &reduced_rhs_inverses);
let denominator_inverses = F::batch_multiplicative_inverse(&denominator.values);
let mut quotients = numerator.values;
batch_multiply_inplace(&mut quotients, &denominator_inverses);
// Compute Z, which contains partial products of the quotients.
let mut partial_products = Vec::with_capacity(degree);
@ -120,6 +118,39 @@ fn compute_permutation_z_poly<F: Field>(
PolynomialValues::new(partial_products)
}
/// Computes the reduced polynomial, `\sum beta^i f_i(x) + gamma`, for both the "left" and "right"
/// sides of a given `PermutationPair`.
fn permutation_reduced_polys<F: Field>(
instance: &PermutationInstance<F>,
trace_poly_values: &[PolynomialValues<F>],
degree: usize,
) -> (PolynomialValues<F>, PolynomialValues<F>) {
let PermutationInstance {
pair: PermutationPair { column_pairs },
challenge: PermutationChallenge { beta, gamma },
} = instance;
let mut reduced_lhs = PolynomialValues::constant(*gamma, degree);
let mut reduced_rhs = PolynomialValues::constant(*gamma, degree);
for ((lhs, rhs), weight) in column_pairs.iter().zip(beta.powers()) {
reduced_lhs.add_assign_scaled(&trace_poly_values[*lhs], weight);
reduced_rhs.add_assign_scaled(&trace_poly_values[*rhs], weight);
}
(reduced_lhs, reduced_rhs)
}
/// Computes the elementwise product of a set of polynomials. Assumes that the set is non-empty and
/// that each polynomial has the same length.
fn poly_product_elementwise<F: Field>(
mut polys: impl Iterator<Item = PolynomialValues<F>>,
) -> PolynomialValues<F> {
let mut product = polys.next().expect("Expected at least one polynomial");
for poly in polys {
batch_multiply_inplace(&mut product.values, &poly.values)
}
product
}
fn get_permutation_challenge<F: RichField, H: Hasher<F>>(
challenger: &mut Challenger<F, H>,
) -> PermutationChallenge<F> {