Fix parallelization
This commit is contained in:
parent
c83638dd3a
commit
bf83bc4403
|
@ -3,11 +3,12 @@ use ark_ec::pairing::Pairing;
|
||||||
use ark_ec::{AffineRepr, CurveGroup};
|
use ark_ec::{AffineRepr, CurveGroup};
|
||||||
use ark_ff::{BigInt, BigInteger, FftField, Field, PrimeField};
|
use ark_ff::{BigInt, BigInteger, FftField, Field, PrimeField};
|
||||||
#[cfg(feature = "parallel")]
|
#[cfg(feature = "parallel")]
|
||||||
use rayon::iter::IntoParallelIterator;
|
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
|
||||||
|
|
||||||
pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
debug_assert_eq!(vals.len(), roots_of_unity.len());
|
debug_assert_eq!(vals.len(), roots_of_unity.len());
|
||||||
if vals.len() == 1 {
|
let original_len = vals.len();
|
||||||
|
if original_len == 1 {
|
||||||
return vals.to_vec();
|
return vals.to_vec();
|
||||||
}
|
}
|
||||||
let half_roots: Vec<_> = roots_of_unity.iter().step_by(2).copied().collect();
|
let half_roots: Vec<_> = roots_of_unity.iter().step_by(2).copied().collect();
|
||||||
|
@ -38,13 +39,17 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
let [l, r]: [Vec<G1Affine>; 2] = {
|
let [l, r]: [Vec<G1Affine>; 2] = {
|
||||||
#[cfg(feature = "parallel")]
|
#[cfg(feature = "parallel")]
|
||||||
{
|
{
|
||||||
[l, r].into_par_iter().map(|f| f()).collect()
|
let (l, r) = rayon::join(l, r);
|
||||||
|
[l, r]
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "parallel"))]
|
#[cfg(not(feature = "parallel"))]
|
||||||
{
|
{
|
||||||
[l(), r()]
|
[l(), r()]
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
// Double sized so we can use iterator later on
|
||||||
|
let l: Vec<_> = l.into_iter().cycle().take(original_len).collect();
|
||||||
|
let r: Vec<_> = r.into_iter().cycle().take(original_len).collect();
|
||||||
|
|
||||||
let y_times_root = {
|
let y_times_root = {
|
||||||
#[cfg(feature = "parallel")]
|
#[cfg(feature = "parallel")]
|
||||||
|
@ -56,7 +61,6 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
r.into_iter()
|
r.into_iter()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.cycle()
|
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine());
|
.map(|(i, y)| (y * roots_of_unity[i % vals.len()]).into_affine());
|
||||||
|
|
||||||
|
@ -70,8 +74,6 @@ pub fn fft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
l.into_iter()
|
l.into_iter()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
.cycle()
|
|
||||||
.take(vals.len())
|
|
||||||
.zip(y_times_root)
|
.zip(y_times_root)
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, (x, y_times_root))| {
|
.map(|(i, (x, y_times_root))| {
|
||||||
|
@ -90,12 +92,16 @@ pub fn ifft_g1(vals: &[G1Affine], roots_of_unity: &[Fr]) -> Vec<G1Affine> {
|
||||||
let mut mod_min_2 = BigInt::new(<Fr as PrimeField>::MODULUS.0);
|
let mut mod_min_2 = BigInt::new(<Fr as PrimeField>::MODULUS.0);
|
||||||
mod_min_2.sub_with_borrow(&BigInt::<4>::from(2u64));
|
mod_min_2.sub_with_borrow(&BigInt::<4>::from(2u64));
|
||||||
let invlen = Fr::from(vals.len() as u64).pow(mod_min_2).into_bigint();
|
let invlen = Fr::from(vals.len() as u64).pow(mod_min_2).into_bigint();
|
||||||
|
{
|
||||||
#[cfg(feature = "parallel")]
|
#[cfg(feature = "parallel")]
|
||||||
{
|
{
|
||||||
fft_g1(vals, roots_of_unity).into_par_iter().collect()
|
fft_g1(vals, roots_of_unity).into_par_iter()
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "parallel"))]
|
#[cfg(not(feature = "parallel"))]
|
||||||
{ fft_g1(vals, roots_of_unity).into_iter() }
|
{
|
||||||
|
fft_g1(vals, roots_of_unity).into_iter()
|
||||||
|
}
|
||||||
|
}
|
||||||
.map(|g| g.mul_bigint(invlen).into_affine())
|
.map(|g| g.mul_bigint(invlen).into_affine())
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue