diff --git a/src/plonk/prover.rs b/src/plonk/prover.rs index 031354fa..427880c3 100644 --- a/src/plonk/prover.rs +++ b/src/plonk/prover.rs @@ -268,8 +268,7 @@ fn wires_permutation_partial_products, const D: usi let quotient_partials = partial_products("ient_values, degree); // This is the final product for the quotient. - let quotient = quotient_partials - [common_data.num_partial_products.0 - common_data.num_partial_products.1..] + let quotient = quotient_partials[common_data.num_partial_products.1..] .iter() .copied() .product(); diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 633047d0..83b0e396 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -1,5 +1,5 @@ use std::iter::Product; -use std::ops::Sub; +use std::ops::{MulAssign, Sub}; use crate::field::extension_field::target::ExtensionTarget; use crate::field::extension_field::Extendable; @@ -9,17 +9,18 @@ use crate::util::ceil_div_usize; /// Compute partial products of the original vector `v` such that all products consist of `max_degree` /// or less elements. This is done until we've computed the product `P` of all elements in the vector. -pub fn partial_products(v: &[T], max_degree: usize) -> Vec { +pub fn partial_products(v: &[T], max_degree: usize) -> Vec { + debug_assert!(max_degree > 1); let mut res = Vec::new(); - let mut remainder = v.to_vec(); - while remainder.len() > max_degree { - let new_partials = remainder - .chunks(max_degree) - // TODO: can filter out chunks of length 1. - .map(|chunk| chunk.iter().copied().product()) - .collect::>(); - res.extend_from_slice(&new_partials); - remainder = new_partials; + let mut acc = v[0]; + let chunk_size = max_degree - 1; + let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + for i in 0..num_chunks { + acc *= v[1 + i * chunk_size..1 + (i + 1) * chunk_size] + .iter() + .copied() + .product(); + res.push(acc); } res @@ -29,34 +30,33 @@ pub fn partial_products(v: &[T], max_degree: usize) -> Vec /// vector of length `n`, and `b` is the number of elements needed to compute the final product. pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { debug_assert!(max_degree > 1); - let mut res = 0; - let mut remainder = n; - while remainder > max_degree { - let new_partials_len = ceil_div_usize(remainder, max_degree); - res += new_partials_len; - remainder = new_partials_len; - } + let chunk_size = max_degree - 1; + let num_chunks = ceil_div_usize(n - 1, chunk_size) - 1; - (res, remainder) + (num_chunks, 1 + num_chunks * chunk_size) } /// Checks that the partial products of `v` are coherent with those in `partials` by only computing /// products of size `max_degree` or less. -pub fn check_partial_products>( +pub fn check_partial_products>( v: &[T], mut partials: &[T], max_degree: usize, ) -> Vec { + debug_assert!(max_degree > 1); + let mut partials = partials.iter(); let mut res = Vec::new(); - let mut remainder = v; - while remainder.len() > max_degree { - let products = remainder - .chunks(max_degree) - .map(|chunk| chunk.iter().copied().product::()); - let products_len = products.len(); - res.extend(products.zip(partials).map(|(a, &b)| a - b)); - (remainder, partials) = partials.split_at(products_len); + let mut acc = v[0]; + let chunk_size = max_degree - 1; + let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + for i in 0..num_chunks { + acc *= v[1 + i * chunk_size..1 + (i + 1) * chunk_size] + .iter() + .copied() + .product(); + res.push(acc - *partials.next().unwrap()); } + debug_assert!(partials.next().is_none()); res } @@ -67,22 +67,20 @@ pub fn check_partial_products_recursively, const D: partials: &[ExtensionTarget], max_degree: usize, ) -> Vec> { + debug_assert!(max_degree > 1); + let mut partials = partials.iter(); let mut res = Vec::new(); - let mut remainder = v.to_vec(); - let mut partials = partials.to_vec(); - while remainder.len() > max_degree { - let products = remainder - .chunks(max_degree) - .map(|chunk| builder.mul_many_extension(chunk)) - .collect::>(); - res.extend( - products - .iter() - .zip(&partials) - .map(|(&a, &b)| builder.sub_extension(a, b)), - ); - remainder = partials.drain(..products.len()).collect(); + let mut acc = v[0]; + let chunk_size = max_degree - 1; + let num_chunks = ceil_div_usize(v.len() - 1, chunk_size) - 1; + for i in 0..num_chunks { + let mut chunk = v[1 + i * chunk_size..1 + (i + 1) * chunk_size].to_vec(); + chunk.push(acc); + acc = builder.mul_many_extension(&chunk); + + res.push(builder.sub_extension(acc, *partials.next().unwrap())); } + debug_assert!(partials.next().is_none()); res } @@ -97,15 +95,15 @@ mod tests { fn test_partial_products() { let v = vec![1, 2, 3, 4, 5, 6]; let p = partial_products(&v, 2); - assert_eq!(p, vec![2, 12, 30, 24, 30]); + assert_eq!(p, vec![2, 6, 24, 120]); let nums = num_partial_products(v.len(), 2); assert_eq!(p.len(), nums.0); assert!(check_partial_products(&v, &p, 2) .iter() .all(|x| x.is_zero())); assert_eq!( + *p.last().unwrap() * v[nums.1..].iter().copied().product::(), v.into_iter().product::(), - p[p.len() - nums.1..].iter().copied().product(), ); let v = vec![1, 2, 3, 4, 5, 6]; @@ -117,8 +115,8 @@ mod tests { .iter() .all(|x| x.is_zero())); assert_eq!( + *p.last().unwrap() * v[nums.1..].iter().copied().product::(), v.into_iter().product::(), - p[p.len() - nums.1..].iter().copied().product(), ); } }