From 9139d1350a0fb464af3ad729660ddfbafc4409ba Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Thu, 11 Nov 2021 07:16:16 -0800 Subject: [PATCH] Minor refactor of partial product code (#351) --- src/plonk/prover.rs | 89 ++++++++++++------------------------ src/plonk/vanishing_poly.rs | 38 ++++++--------- src/util/partial_products.rs | 79 +++++++++++++++++--------------- 3 files changed, 83 insertions(+), 123 deletions(-) diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index be356d9f..1c247998 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -17,9 +17,10 @@ use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBase; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; -use crate::util::partial_products::partial_products; +use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products}; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, transpose}; +use std::mem::swap; pub(crate) fn prove, const D: usize>( prover_data: &ProverOnlyCircuitData, @@ -91,28 +92,21 @@ pub(crate) fn prove, const D: usize>( common_data.quotient_degree_factor < common_data.config.num_routed_wires, "When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products." ); - let mut partial_products = timed!( + let mut partial_products_and_zs = timed!( timing, "compute partial products", all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data) ); - let plonk_z_vecs = timed!( - timing, - "compute Z's", - compute_zs(&mut partial_products, common_data) - ); + // Z is expected at the front of our batch; see `zs_range` and `partial_products_range`. + let plonk_z_vecs = partial_products_and_zs.iter_mut() + .map(|partial_products_and_z| partial_products_and_z.pop().unwrap()) + .collect(); + let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat(); - // The first polynomial in `partial_products` represent the final product used in the - // computation of `Z`. It isn't needed anymore so we discard it. - partial_products.iter_mut().for_each(|part| { - part.remove(0); - }); - - let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat(); - let zs_partial_products_commitment = timed!( + let partial_products_and_zs_commitment = timed!( timing, - "commit to Z's", + "commit to partial products and Z's", PolynomialBatchCommitment::from_values( zs_partial_products, config.rate_bits, @@ -123,7 +117,7 @@ pub(crate) fn prove, const D: usize>( ) ); - challenger.observe_cap(&zs_partial_products_commitment.merkle_tree.cap); + challenger.observe_cap(&partial_products_and_zs_commitment.merkle_tree.cap); let alphas = challenger.get_n_challenges(num_challenges); @@ -135,7 +129,7 @@ pub(crate) fn prove, const D: usize>( prover_data, &public_inputs_hash, &wires_commitment, - &zs_partial_products_commitment, + &partial_products_and_zs_commitment, &betas, &gammas, &alphas, @@ -184,7 +178,7 @@ pub(crate) fn prove, const D: usize>( &[ &prover_data.constants_sigmas_commitment, &wires_commitment, - &zs_partial_products_commitment, + &partial_products_and_zs_commitment, "ient_polys_commitment, ], zeta, @@ -196,7 +190,7 @@ pub(crate) fn prove, const D: usize>( let proof = Proof { wires_cap: wires_commitment.merkle_tree.cap, - plonk_zs_partial_products_cap: zs_partial_products_commitment.merkle_tree.cap, + plonk_zs_partial_products_cap: partial_products_and_zs_commitment.merkle_tree.cap, quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap, openings, opening_proof, @@ -217,7 +211,7 @@ fn all_wires_permutation_partial_products, const D: ) -> Vec>> { (0..common_data.config.num_challenges) .map(|i| { - wires_permutation_partial_products( + wires_permutation_partial_products_and_zs( witness, betas[i], gammas[i], @@ -231,7 +225,7 @@ fn all_wires_permutation_partial_products, const D: /// Compute the partial products used in the `Z` polynomial. /// Returns the polynomials interpolating `partial_products(f / g)` /// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`. -fn wires_permutation_partial_products, const D: usize>( +fn wires_permutation_partial_products_and_zs, const D: usize>( witness: &MatrixWitness, beta: F, gamma: F, @@ -241,7 +235,8 @@ fn wires_permutation_partial_products, const D: usi let degree = common_data.quotient_degree_factor; let subgroup = &prover_data.subgroup; let k_is = &common_data.k_is; - let values = subgroup + let (num_prods, final_num_prod) = common_data.num_partial_products; + let all_quotient_chunk_products = subgroup .par_iter() .enumerate() .map(|(i, &x)| { @@ -265,51 +260,25 @@ fn wires_permutation_partial_products, const D: usi .map(|(num, den_inv)| num * den_inv) .collect::>(); - let quotient_partials = partial_products("ient_values, degree); - - // This is the final product for the quotient. - let quotient = *quotient_partials.last().unwrap() - * quotient_values[common_data.num_partial_products.1..] - .iter() - .copied() - .product(); - - // We add the quotient at the beginning of the vector to reuse them later in the computation of `Z`. - [vec![quotient], quotient_partials].concat() + quotient_chunk_products("ient_values, degree) }) .collect::>(); - transpose(&values) + let mut z_x = F::ONE; + let mut all_partial_products_and_zs = Vec::new(); + for quotient_chunk_products in all_quotient_chunk_products { + let mut partial_products_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunk_products); + // The last term is Z(gx), but we replace it with Z(x), otherwise Z would end up shifted. + swap(&mut z_x, &mut partial_products_and_z_gx[num_prods]); + all_partial_products_and_zs.push(partial_products_and_z_gx); + } + + transpose(&all_partial_products_and_zs) .into_par_iter() .map(PolynomialValues::new) .collect() } -fn compute_zs, const D: usize>( - partial_products: &mut [Vec>], - common_data: &CommonCircuitData, -) -> Vec> { - (0..common_data.config.num_challenges) - .map(|i| compute_z(&mut partial_products[i], common_data)) - .collect() -} - -/// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`. -fn compute_z, const D: usize>( - partial_products: &mut [PolynomialValues], - common_data: &CommonCircuitData, -) -> PolynomialValues { - let mut plonk_z_points = vec![F::ONE]; - for i in 1..common_data.degree() { - let last = *plonk_z_points.last().unwrap(); - for q in partial_products.iter_mut() { - q.values[i - 1] *= last; - } - plonk_z_points.push(partial_products[0].values[i - 1]); - } - plonk_z_points.into() -} - const BATCH_SIZE: usize = 32; fn compute_quotient_polys<'a, F: RichField + Extendable, const D: usize>( diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 899d69a6..f91e027b 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -75,13 +75,10 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( ); vanishing_partial_products_terms.extend(partial_product_checks); - let v_shift_term = *current_partial_products.last().unwrap() - * numerator_values[final_num_prod..].iter().copied().product() - - z_gz - * denominator_values[final_num_prod..] - .iter() - .copied() - .product(); + let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); + let final_deno_product = denominator_values[final_num_prod..].iter().copied().product(); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; vanishing_v_shift_terms.push(v_shift_term); } @@ -185,13 +182,10 @@ pub(crate) fn eval_vanishing_poly_base_batch, const ); vanishing_partial_products_terms.extend(partial_product_checks); - let v_shift_term = *current_partial_products.last().unwrap() - * numerator_values[final_num_prod..].iter().copied().product() - - z_gz - * denominator_values[final_num_prod..] - .iter() - .copied() - .product(); + let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); + let final_deno_product = denominator_values[final_num_prod..].iter().copied().product(); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; vanishing_v_shift_terms.push(v_shift_term); numerator_values.clear(); @@ -381,17 +375,11 @@ pub(crate) fn eval_vanishing_poly_recursively, cons ); vanishing_partial_products_terms.extend(partial_product_checks); - let nume_acc = builder.mul_many_extension(&{ - let mut v = numerator_values[final_num_prod..].to_vec(); - v.push(*current_partial_products.last().unwrap()); - v - }); - let z_gz_denominators = builder.mul_many_extension(&{ - let mut v = denominator_values[final_num_prod..].to_vec(); - v.push(z_gz); - v - }); - let v_shift_term = builder.sub_extension(nume_acc, z_gz_denominators); + let final_nume_product = builder.mul_many_extension(&numerator_values[final_num_prod..]); + let final_deno_product = builder.mul_many_extension(&denominator_values[final_num_prod..]); + let z_gz_denominators = builder.mul_extension(z_gz, final_deno_product); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = builder.mul_sub_extension(last_partial, final_nume_product, z_gz_denominators); vanishing_v_shift_terms.push(v_shift_term); } diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 92419a56..1e7a4f4b 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -2,19 +2,30 @@ use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; +use itertools::Itertools; + +pub(crate) fn quotient_chunk_products( + quotient_values: &[F], + max_degree: usize, +) -> Vec { + debug_assert!(max_degree > 1); + assert!(quotient_values.len() > 0); + let chunk_size = max_degree; + quotient_values.chunks(chunk_size) + .map(|chunk| chunk.iter().copied().product()) + .collect() +} /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. -pub fn partial_products(v: &[F], max_degree: usize) -> Vec { - debug_assert!(max_degree > 1); +pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_products: &[F]) -> Vec { + assert!(quotient_chunk_products.len() > 0); let mut res = Vec::new(); - let mut acc = F::ONE; - let chunk_size = max_degree; - for chunk in v.chunks_exact(chunk_size) { - acc *= chunk.iter().copied().product(); + let mut acc = z_x; + for "ient_chunk_product in quotient_chunk_products { + acc *= quotient_chunk_product; res.push(acc); } - res } @@ -30,24 +41,26 @@ pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { /// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products( +pub(crate) fn check_partial_products( numerators: &[F], denominators: &[F], partials: &[F], - mut acc: F, + z_x: F, max_degree: usize, ) -> Vec { debug_assert!(max_degree > 1); + let mut acc = z_x; let mut partials = partials.iter(); let mut res = Vec::new(); let chunk_size = max_degree; for (nume_chunk, deno_chunk) in numerators .chunks_exact(chunk_size) - .zip(denominators.chunks_exact(chunk_size)) + .zip_eq(denominators.chunks_exact(chunk_size)) { - acc *= nume_chunk.iter().copied().product(); - let mut new_acc = *partials.next().unwrap(); - res.push(acc - new_acc * deno_chunk.iter().copied().product()); + let num_chunk_product = nume_chunk.iter().copied().product(); + let den_chunk_product = deno_chunk.iter().copied().product(); + let new_acc = *partials.next().unwrap(); + res.push(acc * num_chunk_product - new_acc * den_chunk_product); acc = new_acc; } debug_assert!(partials.next().is_none()); @@ -55,7 +68,7 @@ pub fn check_partial_products( res } -pub fn check_partial_products_recursively, const D: usize>( +pub(crate) fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, numerators: &[ExtensionTarget], denominators: &[ExtensionTarget], @@ -93,18 +106,11 @@ mod tests { fn test_partial_products() { type F = GoldilocksField; let denominators = vec![F::ONE; 6]; - let v = [1, 2, 3, 4, 5, 6] - .into_iter() - .map(|&i| F::from_canonical_u64(i)) - .collect::>(); - let p = partial_products(&v, 2); - assert_eq!( - p, - [2, 24, 720] - .into_iter() - .map(|&i| F::from_canonical_u64(i)) - .collect::>() - ); + let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let quotient_chunks_prods = quotient_chunk_products(&v, 2); + assert_eq!(quotient_chunks_prods, field_vec(&[2, 12, 30])); + let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); + assert_eq!(p, field_vec(&[2, 24, 720])); let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); @@ -116,18 +122,11 @@ mod tests { v.into_iter().product::(), ); - let v = [1, 2, 3, 4, 5, 6] - .into_iter() - .map(|&i| F::from_canonical_u64(i)) - .collect::>(); - let p = partial_products(&v, 3); - assert_eq!( - p, - [6, 720] - .into_iter() - .map(|&i| F::from_canonical_u64(i)) - .collect::>() - ); + let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let quotient_chunks_prods = quotient_chunk_products(&v, 3); + assert_eq!(quotient_chunks_prods, field_vec(&[6, 120])); + let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); + assert_eq!(p, field_vec(&[6, 720])); let nums = num_partial_products(v.len(), 3); assert_eq!(p.len(), nums.0); assert!(check_partial_products(&v, &denominators, &p, F::ONE, 3) @@ -138,4 +137,8 @@ mod tests { v.into_iter().product::(), ); } + + fn field_vec(xs: &[usize]) -> Vec { + xs.iter().map(|&x| F::from_canonical_usize(x)).collect() + } }