diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 2be91b40..ef322c9f 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -28,7 +28,7 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( alphas: &[F], ) -> Vec { let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let constraint_terms = evaluate_gate_constraints(&common_data.gates, common_data.num_gate_constraints, vars); @@ -37,8 +37,6 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::new(); let l1_x = plonk_common::eval_l_1(common_data.degree(), x); @@ -71,24 +69,15 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - - let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); - let final_deno_product = denominator_values[final_num_prod..] - .iter() - .copied() - .product(); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; - vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, - vanishing_v_shift_terms, constraint_terms, ] .concat(); @@ -121,7 +110,7 @@ pub(crate) fn eval_vanishing_poly_base_batch, const assert_eq!(s_sigmas_batch.len(), n); let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let num_gate_constraints = common_data.num_gate_constraints; @@ -139,8 +128,6 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::with_capacity(num_challenges); let mut res_batch: Vec> = Vec::with_capacity(n); for k in 0..n { @@ -181,19 +168,11 @@ pub(crate) fn eval_vanishing_poly_base_batch, const &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - let final_nume_product = numerator_values[final_num_prod..].iter().copied().product(); - let final_deno_product = denominator_values[final_num_prod..] - .iter() - .copied() - .product(); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = last_partial * final_nume_product - z_gz * final_deno_product; - vanishing_v_shift_terms.push(v_shift_term); - numerator_values.clear(); denominator_values.clear(); } @@ -201,14 +180,12 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let vanishing_terms = vanishing_z_1_terms .iter() .chain(vanishing_partial_products_terms.iter()) - .chain(vanishing_v_shift_terms.iter()) .chain(constraint_terms); let res = plonk_common::reduce_with_powers_multi(vanishing_terms, alphas); res_batch.push(res); vanishing_z_1_terms.clear(); vanishing_partial_products_terms.clear(); - vanishing_v_shift_terms.clear(); } res_batch } @@ -314,7 +291,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons alphas: &[Target], ) -> Vec> { let max_degree = common_data.quotient_degree_factor; - let (num_prods, final_num_prod) = common_data.num_partial_products; + let (num_prods, _final_num_prod) = common_data.num_partial_products; let constraint_terms = with_context!( builder, @@ -331,8 +308,6 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let mut vanishing_z_1_terms = Vec::new(); // The terms checking the partial products. let mut vanishing_partial_products_terms = Vec::new(); - // The Z(x) f'(x) - g'(x) Z(g x) terms. - let mut vanishing_v_shift_terms = Vec::new(); let l1_x = eval_l_1_recursively(builder, common_data.degree(), x, x_pow_deg); @@ -377,23 +352,15 @@ pub(crate) fn eval_vanishing_poly_recursively, cons &denominator_values, current_partial_products, z_x, + z_gz, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); - - let final_nume_product = builder.mul_many_extension(&numerator_values[final_num_prod..]); - let final_deno_product = builder.mul_many_extension(&denominator_values[final_num_prod..]); - let z_gz_denominators = builder.mul_extension(z_gz, final_deno_product); - let last_partial = *current_partial_products.last().unwrap(); - let v_shift_term = - builder.mul_sub_extension(last_partial, final_nume_product, z_gz_denominators); - vanishing_v_shift_terms.push(v_shift_term); } let vanishing_terms = [ vanishing_z_1_terms, vanishing_partial_products_terms, - vanishing_v_shift_terms, constraint_terms, ] .concat(); diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index c4133b4d..0f3c9bfa 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,9 +1,12 @@ +use std::iter; + use itertools::Itertools; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; use crate::field::field_types::{Field, RichField}; use crate::plonk::circuit_builder::CircuitBuilder; +use crate::util::ceil_div_usize; pub(crate) fn quotient_chunk_products( quotient_values: &[F], @@ -33,70 +36,74 @@ pub(crate) fn partial_products_and_z_gx(z_x: F, quotient_chunk_product /// Returns a tuple `(a,b)`, where `a` is the length of the output of `partial_products()` on a /// vector of length `n`, and `b` is the number of original elements consumed in `partial_products()`. -pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { +pub(crate) fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); let chunk_size = max_degree; - let num_chunks = n / chunk_size; - + // We'll split the product into `ceil_div_usize(n, chunk_size)` chunks, but the last chunk will + // be associated with Z(gx) itself. Thus we subtract one to get the chunks associated with + // partial products. + let num_chunks = ceil_div_usize(n, chunk_size) - 1; (num_chunks, num_chunks * chunk_size) } -/// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing -/// products of size `max_degree` or less. +/// Checks the relationship between each pair of partial product accumulators. In particular, this +/// sequence of accumulators starts with `Z(x)`, then contains each partial product polynomials +/// `p_i(x)`, and finally `Z(g x)`. See the partial products section of the Plonky2 paper. pub(crate) fn check_partial_products( numerators: &[F], denominators: &[F], partials: &[F], z_x: F, + z_gx: F, max_degree: usize, ) -> Vec { debug_assert!(max_degree > 1); - let mut acc = z_x; - let mut partials = partials.iter(); - let mut res = Vec::new(); + let product_accs = iter::once(&z_x) + .chain(partials.iter()) + .chain(iter::once(&z_gx)); let chunk_size = max_degree; - for (nume_chunk, deno_chunk) in numerators - .chunks_exact(chunk_size) - .zip_eq(denominators.chunks_exact(chunk_size)) - { - let num_chunk_product = nume_chunk.iter().copied().product(); - let den_chunk_product = deno_chunk.iter().copied().product(); - let new_acc = *partials.next().unwrap(); - res.push(acc * num_chunk_product - new_acc * den_chunk_product); - acc = new_acc; - } - debug_assert!(partials.next().is_none()); - - res + numerators + .chunks(chunk_size) + .zip_eq(denominators.chunks(chunk_size)) + .zip_eq(product_accs.tuple_windows()) + .map(|((nume_chunk, deno_chunk), (&prev_acc, &next_acc))| { + let num_chunk_product = nume_chunk.iter().copied().product(); + let den_chunk_product = deno_chunk.iter().copied().product(); + // Assert that next_acc * deno_product = prev_acc * nume_product. + prev_acc * num_chunk_product - next_acc * den_chunk_product + }) + .collect() } +/// Checks the relationship between each pair of partial product accumulators. In particular, this +/// sequence of accumulators starts with `Z(x)`, then contains each partial product polynomials +/// `p_i(x)`, and finally `Z(g x)`. See the partial products section of the Plonky2 paper. pub(crate) fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, numerators: &[ExtensionTarget], denominators: &[ExtensionTarget], partials: &[ExtensionTarget], - mut acc: ExtensionTarget, + z_x: ExtensionTarget, + z_gx: ExtensionTarget, max_degree: usize, ) -> Vec> { debug_assert!(max_degree > 1); - let mut partials = partials.iter(); - let mut res = Vec::new(); + let product_accs = iter::once(&z_x) + .chain(partials.iter()) + .chain(iter::once(&z_gx)); let chunk_size = max_degree; - for (nume_chunk, deno_chunk) in numerators - .chunks_exact(chunk_size) - .zip(denominators.chunks_exact(chunk_size)) - { - let nume_product = builder.mul_many_extension(nume_chunk); - let deno_product = builder.mul_many_extension(deno_chunk); - let new_acc = *partials.next().unwrap(); - let new_acc_deno = builder.mul_extension(new_acc, deno_product); - // Assert that new_acc*deno_product = acc * nume_product. - res.push(builder.mul_sub_extension(acc, nume_product, new_acc_deno)); - acc = new_acc; - } - debug_assert!(partials.next().is_none()); - - res + numerators + .chunks(chunk_size) + .zip_eq(denominators.chunks(chunk_size)) + .zip_eq(product_accs.tuple_windows()) + .map(|((nume_chunk, deno_chunk), (&prev_acc, &next_acc))| { + let nume_product = builder.mul_many_extension(nume_chunk); + let deno_product = builder.mul_many_extension(deno_chunk); + let next_acc_deno = builder.mul_extension(next_acc, deno_product); + // Assert that next_acc * deno_product = prev_acc * nume_product. + builder.mul_sub_extension(prev_acc, nume_product, next_acc_deno) + }) + .collect() } #[cfg(test)] @@ -108,36 +115,31 @@ mod tests { fn test_partial_products() { type F = GoldilocksField; let denominators = vec![F::ONE; 6]; + let z_x = F::ONE; let v = field_vec(&[1, 2, 3, 4, 5, 6]); + let z_gx = F::from_canonical_u64(720); let quotient_chunks_prods = quotient_chunk_products(&v, 2); assert_eq!(quotient_chunks_prods, field_vec(&[2, 12, 30])); - let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); - assert_eq!(p, field_vec(&[2, 24, 720])); + let pps_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunks_prods); + let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; + assert_eq!(pps_and_z_gx, field_vec(&[2, 24, 720])); let nums = num_partial_products(v.len(), 2); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, F::ONE, 2) + assert_eq!(pps.len(), nums.0); + assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 2) .iter() .all(|x| x.is_zero())); - assert_eq!( - *p.last().unwrap() * v[nums.1..].iter().copied().product::(), - v.into_iter().product::(), - ); - let v = field_vec(&[1, 2, 3, 4, 5, 6]); let quotient_chunks_prods = quotient_chunk_products(&v, 3); assert_eq!(quotient_chunks_prods, field_vec(&[6, 120])); - let p = partial_products_and_z_gx(F::ONE, "ient_chunks_prods); - assert_eq!(p, field_vec(&[6, 720])); + let pps_and_z_gx = partial_products_and_z_gx(z_x, "ient_chunks_prods); + let pps = &pps_and_z_gx[..pps_and_z_gx.len() - 1]; + assert_eq!(pps_and_z_gx, field_vec(&[6, 720])); let nums = num_partial_products(v.len(), 3); - assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, F::ONE, 3) + assert_eq!(pps.len(), nums.0); + assert!(check_partial_products(&v, &denominators, pps, z_x, z_gx, 3) .iter() .all(|x| x.is_zero())); - assert_eq!( - *p.last().unwrap() * v[nums.1..].iter().copied().product::(), - v.into_iter().product::(), - ); } fn field_vec(xs: &[usize]) -> Vec {