Parallelize i/fft

This commit is contained in:
danielsanchezq 2024-06-14 12:01:06 +02:00
parent f7126d6a5f
commit 5eeb96271f
2 changed files with 79 additions and 40 deletions

View File

@ -20,10 +20,11 @@ num-bigint = "0.4.4"
thiserror = "1.0.58" thiserror = "1.0.58"
num-traits = "0.2.18" num-traits = "0.2.18"
rand = "0.8.5" rand = "0.8.5"
rayon = { version = "1.10", optional = true }
[dev-dependencies] [dev-dependencies]
divan = "0.1" divan = "0.1"
rayon = "1.10"
[[bench]] [[bench]]
name = "kzg" name = "kzg"
@ -33,6 +34,7 @@ harness = false
default = ["single"] default = ["single"]
single = [] single = []
parallel = [ parallel = [
"rayon",
"ark-ff/parallel", "ark-ff/parallel",
"ark-ff/asm", "ark-ff/asm",
"ark-ff/rayon", "ark-ff/rayon",

View File

@ -2,7 +2,8 @@ use ark_bls12_381::{Bls12_381, Fr, G1Affine};
use ark_ec::pairing::Pairing; use ark_ec::pairing::Pairing;
use ark_ec::{AffineRepr, CurveGroup}; use ark_ec::{AffineRepr, CurveGroup};
use ark_ff::{BigInt, BigInteger, FftField, Field, PrimeField}; use ark_ff::{BigInt, BigInteger, FftField, Field, PrimeField};
use blst::BLS12_381_G1; #[cfg(parallel)]
use rayon::iter::IntoParallelIterator;
pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> { pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
debug_assert_eq!(vals.len(), roots_of_unity.len()); debug_assert_eq!(vals.len(), roots_of_unity.len());
@ -11,45 +12,77 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
} }
let half_roots: Vec<_> = roots_of_unity.iter().step_by(2).copied().collect(); let half_roots: Vec<_> = roots_of_unity.iter().step_by(2).copied().collect();
let l = fft_g1( let l = || {
vals.iter() fft_g1(
.step_by(2) vals.iter()
.copied() .step_by(2)
.collect::<Vec<_>>() .copied()
.as_slice(), .collect::<Vec<_>>()
half_roots.as_slice(), .as_slice(),
); half_roots.as_slice(),
)
};
let r = fft_g1( let r = || {
vals.iter() fft_g1(
.skip(1) vals.iter()
.step_by(2) .skip(1)
.copied() .step_by(2)
.collect::<Vec<_>>() .copied()
.as_slice(), .collect::<Vec<_>>()
half_roots.as_slice(), .as_slice(),
); half_roots.as_slice(),
)
};
let y_times_root = r let [l, r]: [Vec<G1Affine>; 2] = {
.into_iter() #[cfg(parallel)]
.cycle() {
.enumerate() [l, r].into_par_iter().map(|f| f()).collect()
.map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine()); }
#[cfg(not(parallel))]
{
[l(), r()]
}
};
l.into_iter() let y_times_root = {
.cycle() #[cfg(parallel)]
.take(vals.len()) {
.zip(y_times_root) r.into_par_iter()
.enumerate() }
.map(|(i, (x, y_times_root))| { #[cfg(not(parallel))]
if i < vals.len() / 2 { {
x + y_times_root r.into_iter()
} else { }
x - y_times_root }
} .cycle()
.into_affine() .enumerate()
}) .map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine());
.collect()
{
#[cfg(parallel)]
{
l.into_par_iter()
}
#[cfg(not(parallel))]
{
l.into_iter()
}
}
.cycle()
.take(vals.len())
.zip(y_times_root)
.enumerate()
.map(|(i, (x, y_times_root))| {
if i < vals.len() / 2 {
x + y_times_root
} else {
x - y_times_root
}
.into_affine()
})
.collect()
} }
pub fn ifft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> { pub fn ifft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
@ -57,8 +90,12 @@ pub fn ifft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
let mut mod_min_2 = BigInt::new(<Fr as PrimeField>::MODULUS.0); let mut mod_min_2 = BigInt::new(<Fr as PrimeField>::MODULUS.0);
mod_min_2.sub_with_borrow(&BigInt::<4>::from(2u64)); mod_min_2.sub_with_borrow(&BigInt::<4>::from(2u64));
let invlen = Fr::from(vals.len() as u64).pow(mod_min_2).into_bigint(); let invlen = Fr::from(vals.len() as u64).pow(mod_min_2).into_bigint();
fft_g1(vals, roots_of_unity) #[cfg(parallel)]
.into_iter() {
fft_g1(vals, roots_of_unity).into_par_iter()
}
#[cfg(not(parallel))]
{ fft_g1(vals, roots_of_unity).into_iter() }
.map(|g| g.mul_bigint(invlen).into_affine()) .map(|g| g.mul_bigint(invlen).into_affine())
.collect() .collect()
} }