From 1858a869a71039bb5b5301c96997ec7b6f22b92e Mon Sep 17 00:00:00 2001 From: wborgeaud Date: Thu, 1 Jul 2021 15:57:55 +0200 Subject: [PATCH] Optimize products of 1 element --- src/util/partial_products.rs | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/util/partial_products.rs b/src/util/partial_products.rs index 2d9d33e2..f22ba01e 100644 --- a/src/util/partial_products.rs +++ b/src/util/partial_products.rs @@ -9,11 +9,17 @@ pub fn partial_products(v: &[T], max_degree: usize) -> Vec while remainder.len() >= max_degree { let new_partials = remainder .chunks(max_degree) - // TODO: If `chunk.len()=1`, there's some redundant data. + .filter(|chunk| chunk.len() != 1) .map(|chunk| chunk.iter().copied().product()) .collect::>(); res.extend_from_slice(&new_partials); + let addendum = if remainder.len() % max_degree == 1 { + vec![*remainder.last().unwrap()] + } else { + vec![] + }; remainder = new_partials; + remainder.extend(addendum); } res @@ -24,7 +30,8 @@ pub fn num_partial_products(n: usize, max_degree: usize) -> (usize, usize) { let mut remainder = n; while remainder >= max_degree { let new_partials_len = ceil_div_usize(remainder, max_degree); - res += new_partials_len; + let addendum = if remainder % max_degree == 1 { 1 } else { 0 }; + res += new_partials_len - addendum; remainder = new_partials_len; } @@ -42,10 +49,17 @@ pub fn check_partial_products>( while remainder.len() >= max_degree { let products = remainder .chunks(max_degree) + .filter(|chunk| chunk.len() != 1) .map(|chunk| chunk.iter().copied().product()) .collect::>(); res.extend(products.iter().zip(&partials).map(|(&a, &b)| a - b)); + let addendum = if remainder.len() % max_degree == 1 { + vec![*remainder.last().unwrap()] + } else { + vec![] + }; remainder = partials.drain(..products.len()).collect(); + remainder.extend(addendum) } res @@ -61,7 +75,7 @@ mod tests { fn test_partial_products() { let v = vec![1, 2, 3, 4, 5, 6]; let p = partial_products(&v, 2); - assert_eq!(p, vec![2, 12, 30, 24, 30, 720]); + assert_eq!(p, vec![2, 12, 30, 24, 720]); assert_eq!(p.len(), num_partial_products(v.len(), 2).0); assert!(check_partial_products(&v, &p, 2) .iter()