Parallelize i/fft
This commit is contained in:
parent
f7126d6a5f
commit
5eeb96271f
|
@ -20,10 +20,11 @@ num-bigint = "0.4.4"
|
||||||
thiserror = "1.0.58"
|
thiserror = "1.0.58"
|
||||||
num-traits = "0.2.18"
|
num-traits = "0.2.18"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
|
rayon = { version = "1.10", optional = true }
|
||||||
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
divan = "0.1"
|
divan = "0.1"
|
||||||
rayon = "1.10"
|
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "kzg"
|
name = "kzg"
|
||||||
|
@ -33,6 +34,7 @@ harness = false
|
||||||
default = ["single"]
|
default = ["single"]
|
||||||
single = []
|
single = []
|
||||||
parallel = [
|
parallel = [
|
||||||
|
"rayon",
|
||||||
"ark-ff/parallel",
|
"ark-ff/parallel",
|
||||||
"ark-ff/asm",
|
"ark-ff/asm",
|
||||||
"ark-ff/rayon",
|
"ark-ff/rayon",
|
||||||
|
|
|
@ -2,7 +2,8 @@ use ark_bls12_381::{Bls12_381, Fr, G1Affine};
|
||||||
use ark_ec::pairing::Pairing;
|
use ark_ec::pairing::Pairing;
|
||||||
use ark_ec::{AffineRepr, CurveGroup};
|
use ark_ec::{AffineRepr, CurveGroup};
|
||||||
use ark_ff::{BigInt, BigInteger, FftField, Field, PrimeField};
|
use ark_ff::{BigInt, BigInteger, FftField, Field, PrimeField};
|
||||||
use blst::BLS12_381_G1;
|
#[cfg(parallel)]
|
||||||
|
use rayon::iter::IntoParallelIterator;
|
||||||
|
|
||||||
pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
debug_assert_eq!(vals.len(), roots_of_unity.len());
|
debug_assert_eq!(vals.len(), roots_of_unity.len());
|
||||||
|
@ -11,16 +12,19 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
}
|
}
|
||||||
let half_roots: Vec<_> = roots_of_unity.iter().step_by(2).copied().collect();
|
let half_roots: Vec<_> = roots_of_unity.iter().step_by(2).copied().collect();
|
||||||
|
|
||||||
let l = fft_g1(
|
let l = || {
|
||||||
|
fft_g1(
|
||||||
vals.iter()
|
vals.iter()
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.copied()
|
.copied()
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.as_slice(),
|
.as_slice(),
|
||||||
half_roots.as_slice(),
|
half_roots.as_slice(),
|
||||||
);
|
)
|
||||||
|
};
|
||||||
|
|
||||||
let r = fft_g1(
|
let r = || {
|
||||||
|
fft_g1(
|
||||||
vals.iter()
|
vals.iter()
|
||||||
.skip(1)
|
.skip(1)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
|
@ -28,15 +32,44 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.as_slice(),
|
.as_slice(),
|
||||||
half_roots.as_slice(),
|
half_roots.as_slice(),
|
||||||
);
|
)
|
||||||
|
};
|
||||||
|
|
||||||
let y_times_root = r
|
let [l, r]: [Vec<G1Affine>; 2] = {
|
||||||
.into_iter()
|
#[cfg(parallel)]
|
||||||
|
{
|
||||||
|
[l, r].into_par_iter().map(|f| f()).collect()
|
||||||
|
}
|
||||||
|
#[cfg(not(parallel))]
|
||||||
|
{
|
||||||
|
[l(), r()]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let y_times_root = {
|
||||||
|
#[cfg(parallel)]
|
||||||
|
{
|
||||||
|
r.into_par_iter()
|
||||||
|
}
|
||||||
|
#[cfg(not(parallel))]
|
||||||
|
{
|
||||||
|
r.into_iter()
|
||||||
|
}
|
||||||
|
}
|
||||||
.cycle()
|
.cycle()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine());
|
.map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine());
|
||||||
|
|
||||||
|
{
|
||||||
|
#[cfg(parallel)]
|
||||||
|
{
|
||||||
|
l.into_par_iter()
|
||||||
|
}
|
||||||
|
#[cfg(not(parallel))]
|
||||||
|
{
|
||||||
l.into_iter()
|
l.into_iter()
|
||||||
|
}
|
||||||
|
}
|
||||||
.cycle()
|
.cycle()
|
||||||
.take(vals.len())
|
.take(vals.len())
|
||||||
.zip(y_times_root)
|
.zip(y_times_root)
|
||||||
|
@ -57,8 +90,12 @@ pub fn ifft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
let mut mod_min_2 = BigInt::new(<Fr as PrimeField>::MODULUS.0);
|
let mut mod_min_2 = BigInt::new(<Fr as PrimeField>::MODULUS.0);
|
||||||
mod_min_2.sub_with_borrow(&BigInt::<4>::from(2u64));
|
mod_min_2.sub_with_borrow(&BigInt::<4>::from(2u64));
|
||||||
let invlen = Fr::from(vals.len() as u64).pow(mod_min_2).into_bigint();
|
let invlen = Fr::from(vals.len() as u64).pow(mod_min_2).into_bigint();
|
||||||
fft_g1(vals, roots_of_unity)
|
#[cfg(parallel)]
|
||||||
.into_iter()
|
{
|
||||||
|
fft_g1(vals, roots_of_unity).into_par_iter()
|
||||||
|
}
|
||||||
|
#[cfg(not(parallel))]
|
||||||
|
{ fft_g1(vals, roots_of_unity).into_iter() }
|
||||||
.map(|g| g.mul_bigint(invlen).into_affine())
|
.map(|g| g.mul_bigint(invlen).into_affine())
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue