diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 6c57217c..be356d9f 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -100,7 +100,7 @@ pub(crate) fn prove, const D: usize>( let plonk_z_vecs = timed!( timing, "compute Z's", - compute_zs(&partial_products, common_data) + compute_zs(&mut partial_products, common_data) ); // The first polynomial in `partial_products` represent the final product used in the @@ -286,24 +286,26 @@ fn wires_permutation_partial_products, const D: usi } fn compute_zs, const D: usize>( - partial_products: &[Vec>], + partial_products: &mut [Vec>], common_data: &CommonCircuitData, ) -> Vec> { (0..common_data.config.num_challenges) - .map(|i| compute_z(&partial_products[i], common_data)) + .map(|i| compute_z(&mut partial_products[i], common_data)) .collect() } /// Compute the `Z` polynomial by reusing the computations done in `wires_permutation_partial_products`. fn compute_z, const D: usize>( - partial_products: &[PolynomialValues], + partial_products: &mut [PolynomialValues], common_data: &CommonCircuitData, ) -> PolynomialValues { let mut plonk_z_points = vec![F::ONE]; for i in 1..common_data.degree() { - let quotient = partial_products[0].values[i - 1]; let last = *plonk_z_points.last().unwrap(); - plonk_z_points.push(last * quotient); + for q in partial_products.iter_mut() { + q.values[i - 1] *= last; + } + plonk_z_points.push(partial_products[0].values[i - 1]); } plonk_z_points.into() } diff --git a/src/plonk/vanishing_poly.rs b/src/plonk/vanishing_poly.rs index 43e52994..899d69a6 100644 --- a/src/plonk/vanishing_poly.rs +++ b/src/plonk/vanishing_poly.rs @@ -70,13 +70,13 @@ pub(crate) fn eval_vanishing_poly, const D: usize>( &numerator_values, &denominator_values, current_partial_products, + z_x, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); let v_shift_term = *current_partial_products.last().unwrap() * numerator_values[final_num_prod..].iter().copied().product() - * z_x - z_gz * denominator_values[final_num_prod..] .iter() @@ -180,13 +180,13 @@ pub(crate) fn eval_vanishing_poly_base_batch, const &numerator_values, &denominator_values, current_partial_products, + z_x, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); let v_shift_term = *current_partial_products.last().unwrap() * numerator_values[final_num_prod..].iter().copied().product() - * z_x - z_gz * denominator_values[final_num_prod..] .iter() @@ -376,6 +376,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons &numerator_values, &denominator_values, current_partial_products, + z_x, max_degree, ); vanishing_partial_products_terms.extend(partial_product_checks); @@ -390,7 +391,7 @@ pub(crate) fn eval_vanishing_poly_recursively, cons v.push(z_gz); v }); - let v_shift_term = builder.mul_sub_extension(nume_acc, z_x, z_gz_denominators); + let v_shift_term = builder.sub_extension(nume_acc, z_gz_denominators); vanishing_v_shift_terms.push(v_shift_term); } diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index c3c4659a..92419a56 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -34,12 +34,12 @@ pub fn check_partial_products( numerators: &[F], denominators: &[F], partials: &[F], + mut acc: F, max_degree: usize, ) -> Vec { debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); - let mut acc = F::ONE; let chunk_size = max_degree; for (nume_chunk, deno_chunk) in numerators .chunks_exact(chunk_size) @@ -60,12 +60,12 @@ pub fn check_partial_products_recursively, const D: numerators: &[ExtensionTarget], denominators: &[ExtensionTarget], partials: &[ExtensionTarget], + mut acc: ExtensionTarget, max_degree: usize, ) -> Vec> { debug_assert!(max_degree > 1); let mut partials = partials.iter(); let mut res = Vec::new(); - let mut acc = builder.one_extension(); let chunk_size = max_degree; for (nume_chunk, deno_chunk) in numerators .chunks_exact(chunk_size) @@ -108,7 +108,7 @@ mod tests { let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, 2) + assert!(check_partial_products(&v, &denominators, &p, F::ONE, 2) .iter() .all(|x| x.is_zero())); assert_eq!( @@ -130,7 +130,7 @@ mod tests { ); let nums = num_partial_products(v.len(), 3); assert_eq!(p.len(), nums.0); - assert!(check_partial_products(&v, &denominators, &p, 3) + assert!(check_partial_products(&v, &denominators, &p, F::ONE, 3) .iter() .all(|x| x.is_zero())); assert_eq!(