diff --git a/nomos-da/kzgrs/src/fft.rs b/nomos-da/kzgrs/src/fft.rs index 66edf55b..89bfcce4 100644 --- a/nomos-da/kzgrs/src/fft.rs +++ b/nomos-da/kzgrs/src/fft.rs @@ -3,11 +3,12 @@ use ark_ec::pairing::Pairing; use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::{BigInt, BigInteger, FftField, Field, PrimeField}; #[cfg(feature = "parallel")] -use rayon::iter::IntoParallelIterator; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}; pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { debug_assert_eq!(vals.len(), roots_of_unity.len()); - if vals.len() == 1 { + let original_len = vals.len(); + if original_len == 1 { return vals.to_vec(); } let half_roots: Vec<_> = roots_of_unity.iter().step_by(2).copied().collect(); @@ -38,13 +39,17 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { let [l, r]: [Vec; 2] = { #[cfg(feature = "parallel")] { - [l, r].into_par_iter().map(|f| f()).collect() + let (l, r) = rayon::join(l, r); + [l, r] } #[cfg(not(feature = "parallel"))] { [l(), r()] } }; + // Double sized so we can use iterator later on + let l: Vec<_> = l.into_iter().cycle().take(original_len).collect(); + let r: Vec<_> = r.into_iter().cycle().take(original_len).collect(); let y_times_root = { #[cfg(feature = "parallel")] @@ -56,7 +61,6 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { r.into_iter() } } - .cycle() .enumerate() .map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine()); @@ -70,8 +74,6 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { l.into_iter() } } - .cycle() - .take(vals.len()) .zip(y_times_root) .enumerate() .map(|(i, (x, y_times_root))| { @@ -90,14 +92,18 @@ pub fn ifft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { let mut mod_min_2 = BigInt::new(::MODULUS.0); mod_min_2.sub_with_borrow(&BigInt::<4>::from(2u64)); let invlen = Fr::from(vals.len() as u64).pow(mod_min_2).into_bigint(); - #[cfg(feature = "parallel")] { - fft_g1(vals, roots_of_unity).into_par_iter().collect() + #[cfg(feature = "parallel")] + { + fft_g1(vals, roots_of_unity).into_par_iter() + } + #[cfg(not(feature = "parallel"))] + { + fft_g1(vals, roots_of_unity).into_iter() + } } - #[cfg(not(feature = "parallel"))] - { fft_g1(vals, roots_of_unity).into_iter() } - .map(|g| g.mul_bigint(invlen).into_affine()) - .collect() + .map(|g| g.mul_bigint(invlen).into_affine()) + .collect() } #[cfg(test)]