diff --git a/src/plonk/circuit_data.rs b/src/plonk/circuit_data.rs index 9ba05a87..869543af 100644 --- a/src/plonk/circuit_data.rs +++ b/src/plonk/circuit_data.rs @@ -190,8 +190,8 @@ pub struct CommonCircuitData, const D: usize> { /// The `{k_i}` valued used in `S_ID_i` in Plonk's permutation argument. pub(crate) k_is: Vec, - /// The number of partial products needed to compute the `Z` polynomials and the number - /// of partial products needed to compute the final product. + /// The number of partial products needed to compute the `Z` polynomials and + /// the number of original elements consumed in `partial_products()`. pub(crate) num_partial_products: (usize, usize), /// A digest of the "circuit" (i.e. the instance, minus public inputs), which can be used to diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 031354fa..1dd17cb8 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -1,3 +1,5 @@ +use std::mem::swap; + use anyhow::Result; use rayon::prelude::*; @@ -17,7 +19,7 @@ use crate::plonk::vanishing_poly::eval_vanishing_poly_base_batch; use crate::plonk::vars::EvaluationVarsBase; use crate::polynomial::polynomial::{PolynomialCoeffs, PolynomialValues}; use crate::timed; -use crate::util::partial_products::partial_products; +use crate::util::partial_products::{partial_products_and_z_gx, quotient_chunk_products}; use crate::util::timing::TimingTree; use crate::util::{log2_ceil, transpose}; @@ -91,28 +93,22 @@ pub(crate) fn prove, const D: usize>( common_data.quotient_degree_factor < common_data.config.num_routed_wires, "When the number of routed wires is smaller that the degree, we should change the logic to avoid computing partial products." ); - let mut partial_products = timed!( + let mut partial_products_and_zs = timed!( timing, "compute partial products", all_wires_permutation_partial_products(&witness, &betas, &gammas, prover_data, common_data) ); - let plonk_z_vecs = timed!( - timing, - "compute Z's", - compute_zs(&partial_products, common_data) - ); + // Z is expected at the front of our batch; see `zs_range` and `partial_products_range`. + let plonk_z_vecs = partial_products_and_zs + .iter_mut() + .map(|partial_products_and_z| partial_products_and_z.pop().unwrap()) + .collect(); + let zs_partial_products = [plonk_z_vecs, partial_products_and_zs.concat()].concat(); - // The first polynomial in `partial_products` represent the final product used in the - // computation of `Z`. It isn't needed anymore so we discard it. - partial_products.iter_mut().for_each(|part| { - part.remove(0); - }); - - let zs_partial_products = [plonk_z_vecs, partial_products.concat()].concat(); - let zs_partial_products_commitment = timed!( + let partial_products_and_zs_commitment = timed!( timing, - "commit to Z's", + "commit to partial products and Z's", PolynomialBatchCommitment::from_values( zs_partial_products, config.rate_bits, @@ -123,7 +119,7 @@ pub(crate) fn prove, const D: usize>( ) ); - challenger.observe_cap(&zs_partial_products_commitment.merkle_tree.cap); + challenger.observe_cap(&partial_products_and_zs_commitment.merkle_tree.cap); let alphas = challenger.get_n_challenges(num_challenges); @@ -135,7 +131,7 @@ pub(crate) fn prove, const D: usize>( prover_data, &public_inputs_hash, &wires_commitment, - &zs_partial_products_commitment, + &partial_products_and_zs_commitment, &betas, &gammas, &alphas, @@ -184,7 +180,7 @@ pub(crate) fn prove, const D: usize>( &[ &prover_data.constants_sigmas_commitment, &wires_commitment, - &zs_partial_products_commitment, + &partial_products_and_zs_commitment, "ient_polys_commitment, ], zeta, @@ -196,7 +192,7 @@ pub(crate) fn prove, const D: usize>( let proof = Proof { wires_cap: wires_commitment.merkle_tree.cap, - plonk_zs_partial_products_cap: zs_partial_products_commitment.merkle_tree.cap, + plonk_zs_partial_products_cap: partial_products_and_zs_commitment.merkle_tree.cap, quotient_polys_cap: quotient_polys_commitment.merkle_tree.cap, openings, opening_proof, @@ -217,7 +213,7 @@ fn all_wires_permutation_partial_products, const D: ) -> Vec>> { (0..common_data.config.num_challenges) .map(|i| { - wires_permutation_partial_products( + wires_permutation_partial_products_and_zs( witness, betas[i], gammas[i], @@ -231,7 +227,7 @@ fn all_wires_permutation_partial_products, const D: /// Compute the partial products used in the `Z` polynomial. /// Returns the polynomials interpolating `partial_products(f / g)` /// where `f, g` are the products in the definition of `Z`: `Z(g^i) = f / g`. -fn wires_permutation_partial_products, const D: usize>( +fn wires_permutation_partial_products_and_zs, const D: usize>( witness: &MatrixWitness, beta: F, gamma: F, @@ -241,7 +237,8 @@ fn wires_permutation_partial_products, const D: usi let degree = common_data.quotient_degree_factor; let subgroup = &prover_data.subgroup; let k_is = &common_data.k_is; - let values = subgroup + let (num_prods, final_num_prod) = common_data.num_partial_products; + let all_quotient_chunk_products = subgroup .par_iter() .enumerate() .map(|(i, &x)| { @@ -265,49 +262,26 @@ fn wires_permutation_partial_products, const D: usi .map(|(num, den_inv)| num * den_inv) .collect::>(); - let quotient_partials = partial_products("ient_values, degree); - - // This is the final product for the quotient. - let quotient = quotient_partials - [common_data.num_partial_products.0 - common_data.num_partial_products.1..] - .iter() - .copied() - .product(); - - // We add the quotient at the beginning of the vector to reuse them later in the computation of `Z`. - [vec![quotient], quotient_partials].concat() + quotient_chunk_products("ient_values, degree) }) .collect::>(); - transpose(&values) + let mut z_x = F::ONE; + let mut all_partial_products_and_zs = Vec::new(); + for quotient_chunk_products in all_quotient_chunk_products { + let mut partial_products_and_z_gx = + partial_products_and_z_gx(z_x, "ient_chunk_products); + // The last term is Z(gx), but we replace it with Z(x), otherwise Z would end up shifted. + swap(&mut z_x, &mut partial_products_and_z_gx[num_prods]); + all_partial_products_and_zs.push(partial_products_and_z_gx); + } + + transpose(&all_partial_products_and_zs) .into_par_iter() .map(PolynomialValues::new) .collect() } -fn compute_zs, const D: usize>( - partial_products: &[Vec>], - common_data: &CommonCircuitData, -) -> Vec> { - (0..common_data.config.num_challenges) - .map(|i| compute_z(&partial_products[i], common_data)) - .collect() -} - -/// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`. -fn compute_z, const D: usize>( - partial_products: &[PolynomialValues], - common_data: &CommonCircuitData, -) -> PolynomialValues { - let mut plonk_z_points = vec![F::ONE]; - for i in 1..common_data.degree() { - let quotient = partial_products[0].values[i - 1]; - let last = *plonk_z_points.last().unwrap(); - plonk_z_points.push(last * quotient); - } - plonk_z_points.into() -} - const BATCH_SIZE: usize = 32; fn compute_quotient_polys<'a, F: RichField + Extendable, const D: usize>( diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 28c6a287..2be91b40 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -62,31 +62,27 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into() }) .collect::>(); - let quotient_values = (0..common_data.config.num_routed_wires) - .map(|j| numerator_values[j] / denominator_values[j]) - .collect::>(); // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_check = - check_partial_products("ient_values, current_partial_products, max_degree); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - *q *= d.iter().copied().product(); - }); - vanishing_partial_products_terms.extend(partial_product_check); + let partial_product_checks = check_partial_products( + &numerator_values, + &denominator_values, + current_partial_products, + z_x, + max_degree, + ); + vanishing_partial_products_terms.extend(partial_product_checks); - // The quotient final product is the product of the last `final_num_prod` elements. - let quotient: F::Extension = current_partial_products[num_prods - final_num_prod..] + let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); + let final_deno_product = denominator_values[final_num_prod..] .iter() .copied() .product(); - vanishing_v_shift_terms.push(quotient * z_x - z_gz); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; + vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ @@ -138,7 +134,6 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let mut numerator_values = Vec::with_capacity(num_routed_wires); let mut denominator_values = Vec::with_capacity(num_routed_wires); - let mut quotient_values = Vec::with_capacity(num_routed_wires); // The L_1(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); @@ -177,36 +172,30 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let s_sigma = s_sigmas[j]; wire_value + betas[i] * s_sigma + gammas[i] })); - let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values); - quotient_values.extend( - (0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]), - ); // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the numerator partial products. - let mut partial_product_check = - check_partial_products("ient_values, current_partial_products, max_degree); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - *q *= d.iter().copied().product(); - }); - vanishing_partial_products_terms.extend(partial_product_check); + let partial_product_checks = check_partial_products( + &numerator_values, + &denominator_values, + current_partial_products, + z_x, + max_degree, + ); + vanishing_partial_products_terms.extend(partial_product_checks); - // The quotient final product is the product of the last `final_num_prod` elements. - let quotient: F = current_partial_products[num_prods - final_num_prod..] + let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); + let final_deno_product = denominator_values[final_num_prod..] .iter() .copied() .product(); - vanishing_v_shift_terms.push(quotient * z_x - z_gz); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; + vanishing_v_shift_terms.push(v_shift_term); numerator_values.clear(); denominator_values.clear(); - quotient_values.clear(); } let vanishing_terms = vanishing_z_1_terms @@ -363,7 +352,6 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let mut numerator_values = Vec::new(); let mut denominator_values = Vec::new(); - let mut quotient_values = Vec::new(); for j in 0..common_data.config.num_routed_wires { let wire_value = vars.local_wires[j]; @@ -376,38 +364,30 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma); let denominator = builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma); - let quotient = builder.div_extension(numerator, denominator); - numerator_values.push(numerator); denominator_values.push(denominator); - quotient_values.push(quotient); } // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_check = check_partial_products_recursively( + let partial_product_checks = check_partial_products_recursively( builder, - "ient_values, + &numerator_values, + &denominator_values, current_partial_products, + z_x, max_degree, ); - // The first checks are of the form `q - n/d` which is a rational function not a polynomial. - // We multiply them by `d` to get checks of the form `q*d - n` which low-degree polynomials. - denominator_values - .chunks(max_degree) - .zip(partial_product_check.iter_mut()) - .for_each(|(d, q)| { - let mut v = d.to_vec(); - v.push(*q); - *q = builder.mul_many_extension(&v); - }); - vanishing_partial_products_terms.extend(partial_product_check); + vanishing_partial_products_terms.extend(partial_product_checks); - // The quotient final product is the product of the last `final_num_prod` elements. - let quotient = - builder.mul_many_extension(¤t_partial_products[num_prods - final_num_prod..]); - vanishing_v_shift_terms.push(builder.mul_sub_extension(quotient, z_x, z_gz)); + let final_nume_product = builder.mul_many_extension(&numerator_values[final_num_prod..]); + let final_deno_product = builder.mul_many_extension(&denominator_values[final_num_prod..]); + let z_gz_denominators = builder.mul_extension(z_gz, final_deno_product); + let last_partial = *current_partial_products.last().unwrap(); + let v_shift_term = + builder.mul_sub_extension(last_partial, final_nume_product, z_gz_denominators); + vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 633047d0..c4133b4d 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,124 +1,146 @@ -use std::iter::Product; -use std::ops::Sub; +use itertools::Itertools; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; -use crate::field::field_types::RichField; +use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; -use crate::util::ceil_div_usize; + +pub(crate) fn quotient_chunk_products( + quotient_values: &[F], + max_degree: usize, +) -> Vec { + debug_assert!(max_degree > 1); + assert!(quotient_values.len() > 0); + let chunk_size = max_degree; + quotient_values + .chunks(chunk_size) + .map(|chunk| chunk.iter().copied().product()) + .collect() +} /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. -pub fn partial_products(v: &[T], max_degree: usize) -> Vec { +pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_products: &[F]) -> Vec { + assert!(quotient_chunk_products.len() > 0); let mut res = Vec::new(); - let mut remainder = v.to_vec(); - while remainder.len() > max_degree { - let new_partials = remainder - .chunks(max_degree) - // TODO: can filter out chunks of length 1. - .map(|chunk| chunk.iter().copied().product()) - .collect::>(); - res.extend_from_slice(&new_partials); - remainder = new_partials; + let mut acc = z_x; + for "ient_chunk_product in quotient_chunk_products { + acc *= quotient_chunk_product; + res.push(acc); } - res } /// Returns a tuple `(a,b)`, where `a` is the length of the output of `partial_products()` on a -/// vector of length `n`, and `b` is the number of elements needed to compute the final product. +/// vector of length `n`, and `b` is the number of original elements consumed in `partial_products()`. pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); - let mut res = 0; - let mut remainder = n; - while remainder > max_degree { - let new_partials_len = ceil_div_usize(remainder, max_degree); - res += new_partials_len; - remainder = new_partials_len; - } + let chunk_size = max_degree; + let num_chunks = n / chunk_size; - (res, remainder) + (num_chunks, num_chunks * chunk_size) } -/// Checks that the partial products of `v` are coherent with those in `partials` by only computing +/// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products>( - v: &[T], - mut partials: &[T], +pub(crate) fn check_partial_products( + numerators: &[F], + denominators: &[F], + partials: &[F], + z_x: F, max_degree: usize, -) -> Vec { +) -> Vec { + debug_assert!(max_degree > 1); + let mut acc = z_x; + let mut partials = partials.iter(); let mut res = Vec::new(); - let mut remainder = v; - while remainder.len() > max_degree { - let products = remainder - .chunks(max_degree) - .map(|chunk| chunk.iter().copied().product::()); - let products_len = products.len(); - res.extend(products.zip(partials).map(|(a, &b)| a - b)); - (remainder, partials) = partials.split_at(products_len); + let chunk_size = max_degree; + for (nume_chunk, deno_chunk) in numerators + .chunks_exact(chunk_size) + .zip_eq(denominators.chunks_exact(chunk_size)) + { + let num_chunk_product = nume_chunk.iter().copied().product(); + let den_chunk_product = deno_chunk.iter().copied().product(); + let new_acc = *partials.next().unwrap(); + res.push(acc * num_chunk_product - new_acc * den_chunk_product); + acc = new_acc; } + debug_assert!(partials.next().is_none()); res } -pub fn check_partial_products_recursively, const D: usize>( +pub(crate) fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, - v: &[ExtensionTarget], + numerators: &[ExtensionTarget], + denominators: &[ExtensionTarget], partials: &[ExtensionTarget], + mut acc: ExtensionTarget, max_degree: usize, ) -> Vec> { + debug_assert!(max_degree > 1); + let mut partials = partials.iter(); let mut res = Vec::new(); - let mut remainder = v.to_vec(); - let mut partials = partials.to_vec(); - while remainder.len() > max_degree { - let products = remainder - .chunks(max_degree) - .map(|chunk| builder.mul_many_extension(chunk)) - .collect::>(); - res.extend( - products - .iter() - .zip(&partials) - .map(|(&a, &b)| builder.sub_extension(a, b)), - ); - remainder = partials.drain(..products.len()).collect(); + let chunk_size = max_degree; + for (nume_chunk, deno_chunk) in numerators + .chunks_exact(chunk_size) + .zip(denominators.chunks_exact(chunk_size)) + { + let nume_product = builder.mul_many_extension(nume_chunk); + let deno_product = builder.mul_many_extension(deno_chunk); + let new_acc = *partials.next().unwrap(); + let new_acc_deno = builder.mul_extension(new_acc, deno_product); + // Assert that new_acc*deno_product = acc * nume_product. + res.push(builder.mul_sub_extension(acc, nume_product, new_acc_deno)); + acc = new_acc; } + debug_assert!(partials.next().is_none()); res } #[cfg(test)] mod tests { - use num::Zero; - use super::*; + use crate::field::goldilocks_field::GoldilocksField; #[test] fn test_partial_products() { - let v = vec![1, 2, 3, 4, 5, 6]; - let p = partial_products(&v, 2); - assert_eq!(p, vec![2, 12, 30, 24, 30]); + type F = GoldilocksField; + let denominators = vec![F::ONE; 6]; + let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let quotient_chunks_prods = quotient_chunk_products(&v, 2); + assert_eq!(quotient_chunks_prods, field_vec(&[2, 12, 30])); + let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); + assert_eq!(p, field_vec(&[2, 24, 720])); + let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 2) + assert!(check_partial_products(&v, &denominators, &p, F::ONE, 2) .iter() .all(|x| x.is_zero())); assert_eq!( - v.into_iter().product::(), - p[p.len() - nums.1..].iter().copied().product(), + *p.last().unwrap() * v[nums.1..].iter().copied().product::(), + v.into_iter().product::(), ); - let v = vec![1, 2, 3, 4, 5, 6]; - let p = partial_products(&v, 3); - assert_eq!(p, vec![6, 120]); + let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let quotient_chunks_prods = quotient_chunk_products(&v, 3); + assert_eq!(quotient_chunks_prods, field_vec(&[6, 120])); + let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); + assert_eq!(p, field_vec(&[6, 720])); let nums = num_partial_products(v.len(), 3); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 3) + assert!(check_partial_products(&v, &denominators, &p, F::ONE, 3) .iter() .all(|x| x.is_zero())); assert_eq!( - v.into_iter().product::(), - p[p.len() - nums.1..].iter().copied().product(), + *p.last().unwrap() * v[nums.1..].iter().copied().product::(), + v.into_iter().product::(), ); } + + fn field_vec(xs: &[usize]) -> Vec { + xs.iter().map(|&x| F::from_canonical_usize(x)).collect() + } }