diff --git a/nomos-da/kzgrs/Cargo.toml b/nomos-da/kzgrs/Cargo.toml index e15fcfd5..06a36c1d 100644 --- a/nomos-da/kzgrs/Cargo.toml +++ b/nomos-da/kzgrs/Cargo.toml @@ -20,10 +20,11 @@ num-bigint = "0.4.4" thiserror = "1.0.58" num-traits = "0.2.18" rand = "0.8.5" +rayon = { version = "1.10", optional = true } + [dev-dependencies] divan = "0.1" -rayon = "1.10" [[bench]] name = "kzg" @@ -33,6 +34,7 @@ harness = false default = ["single"] single = [] parallel = [ + "rayon", "ark-ff/parallel", "ark-ff/asm", "ark-ff/rayon", diff --git a/nomos-da/kzgrs/src/fft.rs b/nomos-da/kzgrs/src/fft.rs index 9aec9c92..84a410ac 100644 --- a/nomos-da/kzgrs/src/fft.rs +++ b/nomos-da/kzgrs/src/fft.rs @@ -2,7 +2,8 @@ use ark_bls12_381::{Bls12_381, Fr, G1Affine}; use ark_ec::pairing::Pairing; use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::{BigInt, BigInteger, FftField, Field, PrimeField}; -use blst::BLS12_381_G1; +#[cfg(parallel)] +use rayon::iter::IntoParallelIterator; pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { debug_assert_eq!(vals.len(), roots_of_unity.len()); @@ -11,45 +12,77 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { } let half_roots: Vec<_> = roots_of_unity.iter().step_by(2).copied().collect(); - let l = fft_g1( - vals.iter() - .step_by(2) - .copied() - .collect::>() - .as_slice(), - half_roots.as_slice(), - ); + let l = || { + fft_g1( + vals.iter() + .step_by(2) + .copied() + .collect::>() + .as_slice(), + half_roots.as_slice(), + ) + }; - let r = fft_g1( - vals.iter() - .skip(1) - .step_by(2) - .copied() - .collect::>() - .as_slice(), - half_roots.as_slice(), - ); + let r = || { + fft_g1( + vals.iter() + .skip(1) + .step_by(2) + .copied() + .collect::>() + .as_slice(), + half_roots.as_slice(), + ) + }; - let y_times_root = r - .into_iter() - .cycle() - .enumerate() - .map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine()); + let [l, r]: [Vec; 2] = { + #[cfg(parallel)] + { + [l, r].into_par_iter().map(|f| f()).collect() + } + #[cfg(not(parallel))] + { + [l(), r()] + } + }; - l.into_iter() - .cycle() - .take(vals.len()) - .zip(y_times_root) - .enumerate() - .map(|(i, (x, y_times_root))| { - if i < vals.len() / 2 { - x + y_times_root - } else { - x - y_times_root - } - .into_affine() - }) - .collect() + let y_times_root = { + #[cfg(parallel)] + { + r.into_par_iter() + } + #[cfg(not(parallel))] + { + r.into_iter() + } + } + .cycle() + .enumerate() + .map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine()); + + { + #[cfg(parallel)] + { + l.into_par_iter() + } + #[cfg(not(parallel))] + { + l.into_iter() + } + } + .cycle() + .take(vals.len()) + .zip(y_times_root) + .enumerate() + .map(|(i, (x, y_times_root))| { + if i < vals.len() / 2 { + x + y_times_root + } else { + x - y_times_root + } + .into_affine() + }) + .collect() } pub fn ifft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { @@ -57,8 +90,12 @@ pub fn ifft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec { let mut mod_min_2 = BigInt::new(::MODULUS.0); mod_min_2.sub_with_borrow(&BigInt::<4>::from(2u64)); let invlen = Fr::from(vals.len() as u64).pow(mod_min_2).into_bigint(); - fft_g1(vals, roots_of_unity) - .into_iter() + #[cfg(parallel)] + { + fft_g1(vals, roots_of_unity).into_par_iter() + } + #[cfg(not(parallel))] + { fft_g1(vals, roots_of_unity).into_iter() } .map(|g| g.mul_bigint(invlen).into_affine()) .collect() }