From 32f09ac2dfc026f4c3e9771cdaa293f59e21de0c Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Wed, 10 Nov 2021 18:13:27 +0100 Subject: [PATCH] Remove quotients and work directly with numerators and denominators in partial products check --- src/plonk/vanishing_poly.rs | 99 +++++++++++++----------------------- src/util/partial_products.rs | 41 ++++++++++----- 2 files changed, 64 insertions(+), 76 deletions(-) diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 4976eaba..43e52994 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -62,31 +62,26 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( wire_value + s_sigma.scalar_mul(betas[i]) + gammas[i].into() }) .collect::>(); - let quotient_values = (0..common_data.config.num_routed_wires) - .map(|j| numerator_values[j] / denominator_values[j]) - .collect::>(); // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_checks = - check_partial_products("ient_values, current_partial_products, max_degree); - // The partial products are products of quotients, so we multiply them by the product of the - // corresponding denominators to make sure they are polynomials. - for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { - let range = j * max_degree..(j + 1) * max_degree; - *partial_product_check *= denominator_values[range].iter().copied().product(); - } + let partial_product_checks = check_partial_products( + &numerator_values, + &denominator_values, + current_partial_products, + max_degree, + ); vanishing_partial_products_terms.extend(partial_product_checks); - let quotient: F::Extension = *current_partial_products.last().unwrap() - * quotient_values[final_num_prod..].iter().copied().product(); - let mut v_shift_term = quotient * z_x - z_gz; - // Need to multiply by the denominators to make sure we get a polynomial. - v_shift_term *= denominator_values[final_num_prod..] - .iter() - .copied() - .product(); + let v_shift_term = *current_partial_products.last().unwrap() + * numerator_values[final_num_prod..].iter().copied().product() + * z_x + - z_gz + * denominator_values[final_num_prod..] + .iter() + .copied() + .product(); vanishing_v_shift_terms.push(v_shift_term); } @@ -139,7 +134,6 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let mut numerator_values = Vec::with_capacity(num_routed_wires); let mut denominator_values = Vec::with_capacity(num_routed_wires); - let mut quotient_values = Vec::with_capacity(num_routed_wires); // The L_1(x) (Z(x) - 1) vanishing terms. let mut vanishing_z_1_terms = Vec::with_capacity(num_challenges); @@ -178,37 +172,30 @@ pub(crate) fn eval_vanishing_poly_base_batch, const let s_sigma = s_sigmas[j]; wire_value + betas[i] * s_sigma + gammas[i] })); - let denominator_inverses = F::batch_multiplicative_inverse(&denominator_values); - quotient_values.extend( - (0..num_routed_wires).map(|j| numerator_values[j] * denominator_inverses[j]), - ); // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the numerator partial products. - let mut partial_product_checks = - check_partial_products("ient_values, current_partial_products, max_degree); - // The partial products are products of quotients, so we multiply them by the product of the - // corresponding denominators to make sure they are polynomials. - for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { - let range = j * max_degree..(j + 1) * max_degree; - *partial_product_check *= denominator_values[range].iter().copied().product(); - } + let partial_product_checks = check_partial_products( + &numerator_values, + &denominator_values, + current_partial_products, + max_degree, + ); vanishing_partial_products_terms.extend(partial_product_checks); - let quotient: F = *current_partial_products.last().unwrap() - * quotient_values[final_num_prod..].iter().copied().product(); - let mut v_shift_term = quotient * z_x - z_gz; - // Need to multiply by the denominators to make sure we get a polynomial. - v_shift_term *= denominator_values[final_num_prod..] - .iter() - .copied() - .product(); + let v_shift_term = *current_partial_products.last().unwrap() + * numerator_values[final_num_prod..].iter().copied().product() + * z_x + - z_gz + * denominator_values[final_num_prod..] + .iter() + .copied() + .product(); vanishing_v_shift_terms.push(v_shift_term); numerator_values.clear(); denominator_values.clear(); - quotient_values.clear(); } let vanishing_terms = vanishing_z_1_terms @@ -365,7 +352,6 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let mut numerator_values = Vec::new(); let mut denominator_values = Vec::new(); - let mut quotient_values = Vec::new(); for j in 0..common_data.config.num_routed_wires { let wire_value = vars.local_wires[j]; @@ -378,46 +364,33 @@ pub(crate) fn eval_vanishing_poly_recursively, cons let numerator = builder.mul_add_extension(beta_ext, s_ids[j], wire_value_plus_gamma); let denominator = builder.mul_add_extension(beta_ext, s_sigmas[j], wire_value_plus_gamma); - let quotient = builder.div_extension(numerator, denominator); - numerator_values.push(numerator); denominator_values.push(denominator); - quotient_values.push(quotient); } // The partial products considered for this iteration of `i`. let current_partial_products = &partial_products[i * num_prods..(i + 1) * num_prods]; // Check the quotient partial products. - let mut partial_product_checks = check_partial_products_recursively( + let partial_product_checks = check_partial_products_recursively( builder, - "ient_values, + &numerator_values, + &denominator_values, current_partial_products, max_degree, ); - // The partial products are products of quotients, so we multiply them by the product of the - // corresponding denominators to make sure they are polynomials. - for (j, partial_product_check) in partial_product_checks.iter_mut().enumerate() { - let range = j * max_degree..(j + 1) * max_degree; - *partial_product_check = builder.mul_many_extension(&{ - let mut v = denominator_values[range].to_vec(); - v.push(*partial_product_check); - v - }); - } vanishing_partial_products_terms.extend(partial_product_checks); - let quotient = builder.mul_many_extension(&{ - let mut v = quotient_values[final_num_prod..].to_vec(); + let nume_acc = builder.mul_many_extension(&{ + let mut v = numerator_values[final_num_prod..].to_vec(); v.push(*current_partial_products.last().unwrap()); v }); - let mut v_shift_term = builder.mul_sub_extension(quotient, z_x, z_gz); - // Need to multiply by the denominators to make sure we get a polynomial. - v_shift_term = builder.mul_many_extension(&{ + let z_gz_denominators = builder.mul_many_extension(&{ let mut v = denominator_values[final_num_prod..].to_vec(); - v.push(v_shift_term); + v.push(z_gz); v }); + let v_shift_term = builder.mul_sub_extension(nume_acc, z_x, z_gz_denominators); vanishing_v_shift_terms.push(v_shift_term); } diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 37b51825..c3c4659a 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -28,18 +28,26 @@ pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { (num_chunks, num_chunks * chunk_size) } -/// Checks that the partial products of `v` are coherent with those in `partials` by only computing +/// Checks that the partial products of `numerators/denominators` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products(v: &[F], partials: &[F], max_degree: usize) -> Vec { +pub fn check_partial_products( + numerators: &[F], + denominators: &[F], + partials: &[F], + max_degree: usize, +) -> Vec { debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); let mut acc = F::ONE; let chunk_size = max_degree; - for chunk in v.chunks_exact(chunk_size) { - acc *= chunk.iter().copied().product(); - let new_acc = *partials.next().unwrap(); - res.push(acc - new_acc); + for (nume_chunk, deno_chunk) in numerators + .chunks_exact(chunk_size) + .zip(denominators.chunks_exact(chunk_size)) + { + acc *= nume_chunk.iter().copied().product(); + let mut new_acc = *partials.next().unwrap(); + res.push(acc - new_acc * deno_chunk.iter().copied().product()); acc = new_acc; } debug_assert!(partials.next().is_none()); @@ -49,7 +57,8 @@ pub fn check_partial_products(v: &[F], partials: &[F], max_degree: usi pub fn check_partial_products_recursively, const D: usize>( builder: &mut CircuitBuilder, - v: &[ExtensionTarget], + numerators: &[ExtensionTarget], + denominators: &[ExtensionTarget], partials: &[ExtensionTarget], max_degree: usize, ) -> Vec> { @@ -58,11 +67,16 @@ pub fn check_partial_products_recursively, const D: let mut res = Vec::new(); let mut acc = builder.one_extension(); let chunk_size = max_degree; - for chunk in v.chunks_exact(chunk_size) { - let chunk_product = builder.mul_many_extension(chunk); + for (nume_chunk, deno_chunk) in numerators + .chunks_exact(chunk_size) + .zip(denominators.chunks_exact(chunk_size)) + { + let nume_product = builder.mul_many_extension(nume_chunk); + let deno_product = builder.mul_many_extension(deno_chunk); let new_acc = *partials.next().unwrap(); - // Assert that new_acc = acc * chunk_product. - res.push(builder.mul_sub_extension(acc, chunk_product, new_acc)); + let new_acc_deno = builder.mul_extension(new_acc, deno_product); + // Assert that new_acc*deno_product = acc * nume_product. + res.push(builder.mul_sub_extension(acc, nume_product, new_acc_deno)); acc = new_acc; } debug_assert!(partials.next().is_none()); @@ -78,6 +92,7 @@ mod tests { #[test] fn test_partial_products() { type F = GoldilocksField; + let denominators = vec![F::ONE; 6]; let v = [1, 2, 3, 4, 5, 6] .into_iter() .map(|&i| F::from_canonical_u64(i)) @@ -93,7 +108,7 @@ mod tests { let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 2) + assert!(check_partial_products(&v, &denominators, &p, 2) .iter() .all(|x| x.is_zero())); assert_eq!( @@ -115,7 +130,7 @@ mod tests { ); let nums = num_partial_products(v.len(), 3); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &p, 3) + assert!(check_partial_products(&v, &denominators, &p, 3) .iter() .all(|x| x.is_zero())); assert_eq!(