diff --git a/Cargo.toml b/Cargo.toml index 8d14c3d0..a78d0a96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman", "ecdsa", "u32", "evm"] +members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman", "ecdsa", "u32", "evm", "maybe_rayon"] [profile.release] opt-level = 3 diff --git a/maybe_rayon/Cargo.toml b/maybe_rayon/Cargo.toml new file mode 100644 index 00000000..f8cc95fb --- /dev/null +++ b/maybe_rayon/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "maybe_rayon" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +parallel = ["rayon"] + +[dependencies] +rayon = { version = "1.5.3", optional = true } diff --git a/maybe_rayon/src/lib.rs b/maybe_rayon/src/lib.rs new file mode 100644 index 00000000..6f8dc7a1 --- /dev/null +++ b/maybe_rayon/src/lib.rs @@ -0,0 +1,250 @@ +#[cfg(not(feature = "parallel"))] +use std::{ + iter::{IntoIterator, Iterator}, + slice::{Chunks, ChunksExact, ChunksMut, ChunksExactMut}, +}; + +#[cfg(feature = "parallel")] +use rayon::{ + prelude::*, + slice::{Chunks as ParChunks, ChunksMut as ParChunksMut, ChunksExact as ParChunksExact, ChunksExactMut as ParChunksExactMut, ParallelSlice, ParallelSliceMut} +}; + +#[cfg(feature = "parallel")] +pub use rayon::prelude::{ + ParallelIterator, + IndexedParallelIterator, + ParallelExtend, + ParallelDrainFull, + ParallelDrainRange +}; + +pub trait MaybeParIter<'data> { + #[cfg(feature = "parallel")] + type Item: Send + 'data; + + #[cfg(feature = "parallel")] + type Iter: ParallelIterator; + + #[cfg(not(feature = "parallel"))] + type Item; + + #[cfg(not(feature = "parallel"))] + type Iter: Iterator; + + fn par_iter(&'data self) -> Self::Iter; +} + +#[cfg(feature = "parallel")] +impl<'data, T> MaybeParIter<'data> for T where T: ?Sized + IntoParallelRefIterator<'data> { + type Item = T::Item; + type Iter = T::Iter; + + fn par_iter(&'data self) -> Self::Iter { + self.par_iter() + } +} + +#[cfg(not(feature = "parallel"))] +impl<'data, T: 'data> MaybeParIter<'data> for Vec { + type Item = &'data T; + type Iter = std::slice::Iter<'data, T>; + + fn par_iter(&'data self) -> Self::Iter { + self.iter() + } +} + +#[cfg(not(feature = "parallel"))] +impl<'data, T: 'data> MaybeParIter<'data> for [T] { + type Item = &'data T; + type Iter = std::slice::Iter<'data, T>; + + fn par_iter(&'data self) -> Self::Iter { + self.iter() + } +} + +pub trait MaybeParIterMut<'data> { + #[cfg(feature = "parallel")] + type Item: Send + 'data; + + #[cfg(feature = "parallel")] + type Iter: ParallelIterator; + + #[cfg(not(feature = "parallel"))] + type Item; + + #[cfg(not(feature = "parallel"))] + type Iter: Iterator; + + fn par_iter_mut(&'data mut self) -> Self::Iter; +} + +#[cfg(feature = "parallel")] +impl<'data, T> MaybeParIterMut<'data> for T where T: ?Sized + IntoParallelRefMutIterator<'data> { + type Item = T::Item; + type Iter = T::Iter; + + fn par_iter_mut(&'data mut self) -> Self::Iter { + self.par_iter_mut() + } +} + +#[cfg(not(feature = "parallel"))] +impl<'data, T: 'data> MaybeParIterMut<'data> for Vec { + type Item = &'data mut T; + type Iter = std::slice::IterMut<'data, T>; + + fn par_iter_mut(&'data mut self) -> Self::Iter { + self.iter_mut() + } +} + +#[cfg(not(feature = "parallel"))] +impl<'data, T: 'data> MaybeParIterMut<'data> for [T] { + type Item = &'data mut T; + type Iter = std::slice::IterMut<'data, T>; + + fn par_iter_mut(&'data mut self) -> Self::Iter { + self.iter_mut() + } +} + +pub trait MaybeIntoParIter { + #[cfg(feature = "parallel")] + type Item: Send; + + #[cfg(feature = "parallel")] + type Iter: ParallelIterator; + + #[cfg(not(feature = "parallel"))] + type Item; + + #[cfg(not(feature = "parallel"))] + type Iter: Iterator; + + fn maybe_into_par_iter(self) -> Self::Iter; +} + +#[cfg(feature = "parallel")] +impl MaybeIntoParIter for T where T: IntoParallelIterator { + type Item = T::Item; + type Iter = T::Iter; + + fn maybe_into_par_iter(self) -> Self::Iter { + self.into_par_iter() + } +} + +#[cfg(not(feature = "parallel"))] +impl MaybeIntoParIter for T where T: IntoIterator { + type Item = T::Item; + type Iter = T::IntoIter; + + fn maybe_into_par_iter(self) -> Self::Iter { + self.into_iter() + } +} + +#[cfg(feature = "parallel")] +pub trait MaybeParChunks { + fn par_chunks(&self, chunk_size: usize) -> ParChunks<'_, T>; + fn par_chunks_exact(&self, chunk_size: usize) -> ParChunksExact<'_, T>; +} + +#[cfg(not(feature = "parallel"))] +pub trait MaybeParChunks { + fn par_chunks(&self, chunk_size: usize) -> Chunks<'_, T>; + fn par_chunks_exact(&self, chunk_size: usize) -> ChunksExact<'_, T>; +} + +#[cfg(feature = "parallel")] +impl + ?Sized, U: Sync> MaybeParChunks for T { + fn par_chunks(&self, chunk_size: usize) -> ParChunks<'_, U> { + self.par_chunks(chunk_size) + } + fn par_chunks_exact(&self, chunk_size: usize) -> ParChunksExact<'_, U> { + self.par_chunks_exact(chunk_size) + } +} + +#[cfg(not(feature = "parallel"))] +impl MaybeParChunks for [T] { + fn par_chunks(&self, chunk_size: usize) -> Chunks<'_, T> { + self.chunks(chunk_size) + } + + fn par_chunks_exact(&self, chunk_size: usize) -> ChunksExact<'_, T> { + self.chunks_exact(chunk_size) + } +} + +#[cfg(feature = "parallel")] +pub trait MaybeParChunksMut { + fn par_chunks_mut(&mut self, chunk_size: usize) -> ParChunksMut<'_, T>; + fn par_chunks_exact_mut(&mut self, chunk_size: usize) -> ParChunksExactMut<'_, T>; +} + +#[cfg(not(feature = "parallel"))] +pub trait MaybeParChunksMut { + fn par_chunks_mut(&mut self, chunk_size: usize) -> ChunksMut<'_, T>; + fn par_chunks_exact_mut(&mut self, chunk_size: usize) -> ChunksExactMut<'_, T>; +} + + +#[cfg(feature = "parallel")] +impl, U: Send> MaybeParChunksMut for T { + fn par_chunks_mut(&mut self, chunk_size: usize) -> ParChunksMut<'_, U> { + self.par_chunks_mut(chunk_size) + } + fn par_chunks_exact_mut(&mut self, chunk_size: usize) -> ParChunksExactMut<'_, U> { + self.par_chunks_exact_mut(chunk_size) + } +} + +#[cfg(not(feature = "parallel"))] +impl MaybeParChunksMut for [T] { + fn par_chunks_mut(&mut self, chunk_size: usize) -> ChunksMut<'_, T> { + self.chunks_mut(chunk_size) + } + fn par_chunks_exact_mut(&mut self, chunk_size: usize) -> ChunksExactMut<'_, T> { + self.chunks_exact_mut(chunk_size) + } +} + +pub trait ParallelIteratorMock { + type Item; + fn find_any

(self, predicate: P) -> Option + where + P: Fn(&Self::Item) -> bool + Sync + Send; +} + +impl ParallelIteratorMock for T { + type Item = T::Item; + + fn find_any

(mut self, predicate: P) -> Option + where + P: Fn(&Self::Item) -> bool + Sync + Send + { + self.find(predicate) + } +} + +#[cfg(feature = "parallel")] +pub fn join(oper_a: A, oper_b: B) -> (RA, RB) + where A: FnOnce() -> RA + Send, + B: FnOnce() -> RB + Send, + RA: Send, + RB: Send +{ + rayon::join(oper_a, oper_b) +} + +#[cfg(not(feature = "parallel"))] +pub fn join(oper_a: A, oper_b: B) -> (RA, RB) + where A: FnOnce() -> RA, + B: FnOnce() -> RB, +{ + (oper_a(), oper_b()) +} diff --git a/plonky2/Cargo.toml b/plonky2/Cargo.toml index 9c019640..9ee89344 100644 --- a/plonky2/Cargo.toml +++ b/plonky2/Cargo.toml @@ -10,6 +10,10 @@ categories = ["cryptography"] edition = "2021" default-run = "generate_constants" +[features] +default = ["parallel"] +parallel = ["maybe_rayon/parallel"] + [dependencies] plonky2_field = { path = "../field" } plonky2_util = { path = "../util" } @@ -19,7 +23,7 @@ itertools = "0.10.0" num = { version = "0.4", features = [ "rand" ] } rand = "0.8.4" rand_chacha = "0.3.1" -rayon = "1.5.1" +maybe_rayon = { path = "../maybe_rayon" } unroll = "0.1.5" anyhow = "1.0.40" serde = { version = "1.0", features = ["derive"] } @@ -32,6 +36,7 @@ criterion = "0.3.5" tynm = "0.1.6" structopt = "0.3.26" num_cpus = "1.13.1" +rayon = "1.5.1" [target.'cfg(not(target_env = "msvc"))'.dev-dependencies] jemallocator = "0.3.2" diff --git a/plonky2/examples/bench_recursion.rs b/plonky2/examples/bench_recursion.rs index 1f2d127f..8073c9dc 100644 --- a/plonky2/examples/bench_recursion.rs +++ b/plonky2/examples/bench_recursion.rs @@ -2,7 +2,6 @@ // custom CLI argument parsing (even with harness disabled). We could also have // put it in `src/bin/`, but then we wouldn't have access to // `[dev-dependencies]`. - #![allow(incomplete_features)] #![feature(generic_const_exprs)] diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index 312b458b..da4e9e80 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -5,7 +5,7 @@ use plonky2_field::packed::PackedField; use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues}; use plonky2_field::types::Field; use plonky2_util::{log2_strict, reverse_index_bits_in_place}; -use rayon::prelude::*; +use maybe_rayon::*; use crate::fri::proof::FriProof; use crate::fri::prover::fri_proof; @@ -52,7 +52,7 @@ impl, C: GenericConfig, const D: usize> let coeffs = timed!( timing, "IFFT", - values.into_par_iter().map(|v| v.ifft()).collect::>() + values.maybe_into_par_iter().map(|v| v.ifft()).collect::>() ); Self::from_coeffs( @@ -122,7 +122,7 @@ impl, C: GenericConfig, const D: usize> }) .chain( (0..salt_size) - .into_par_iter() + .maybe_into_par_iter() .map(|_| F::rand_vec(degree << rate_bits)), ) .collect() diff --git a/plonky2/src/fri/prover.rs b/plonky2/src/fri/prover.rs index 6136a9a1..0f3215a8 100644 --- a/plonky2/src/fri/prover.rs +++ b/plonky2/src/fri/prover.rs @@ -2,7 +2,7 @@ use itertools::Itertools; use plonky2_field::extension::{flatten, unflatten, Extendable}; use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues}; use plonky2_util::reverse_index_bits_in_place; -use rayon::prelude::*; +use maybe_rayon::*; use crate::fri::proof::{FriInitialTreeProof, FriProof, FriQueryRound, FriQueryStep}; use crate::fri::{FriConfig, FriParams}; @@ -119,7 +119,7 @@ fn fri_proof_of_work, C: GenericConfig, c config: &FriConfig, ) -> F { (0..=F::NEG_ONE.to_canonical_u64()) - .into_par_iter() + .maybe_into_par_iter() .find_any(|&i| { C::InnerHasher::hash_no_pad( ¤t_hash diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 69cf2ef9..f7b6d4a2 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -2,9 +2,9 @@ use std::mem::MaybeUninit; use std::slice; use plonky2_util::log2_strict; -use rayon::prelude::*; use serde::{Deserialize, Serialize}; +use maybe_rayon::*; use crate::hash::hash_types::RichField; use crate::hash::merkle_proofs::MerkleProof; use crate::plonk::config::GenericHashOut; @@ -77,10 +77,12 @@ where let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap(); // Split `leaves` between both children. let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2); - let (left_digest, right_digest) = rayon::join( + + let (left_digest, right_digest) = maybe_rayon::join( || fill_subtree::(left_digests_buf, left_leaves), || fill_subtree::(right_digests_buf, right_leaves), ); + left_digest_mem.write(left_digest); right_digest_mem.write(right_digest); H::two_to_one(left_digest, right_digest) diff --git a/plonky2/src/plonk/permutation_argument.rs b/plonky2/src/plonk/permutation_argument.rs index 076c2a7a..f9b23796 100644 --- a/plonky2/src/plonk/permutation_argument.rs +++ b/plonky2/src/plonk/permutation_argument.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use plonky2_field::polynomial::PolynomialValues; use plonky2_field::types::Field; -use rayon::prelude::*; +use maybe_rayon::*; use crate::iop::target::Target; use crate::iop::wire::Wire; diff --git a/plonky2/src/plonk/proof.rs b/plonky2/src/plonk/proof.rs index 18af1f73..1cb83b14 100644 --- a/plonky2/src/plonk/proof.rs +++ b/plonky2/src/plonk/proof.rs @@ -1,7 +1,7 @@ use anyhow::ensure; use plonky2_field::extension::Extendable; -use rayon::prelude::*; use serde::{Deserialize, Serialize}; +use maybe_rayon::*; use crate::fri::oracle::PolynomialBatch; use crate::fri::proof::{ diff --git a/plonky2/src/plonk/prover.rs b/plonky2/src/plonk/prover.rs index 26626208..526721f0 100644 --- a/plonky2/src/plonk/prover.rs +++ b/plonky2/src/plonk/prover.rs @@ -6,7 +6,7 @@ use plonky2_field::extension::Extendable; use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues}; use plonky2_field::zero_poly_coset::ZeroPolyOnCoset; use plonky2_util::{ceil_div_usize, log2_ceil}; -use rayon::prelude::*; +use maybe_rayon::*; use crate::field::types::Field; use crate::fri::oracle::PolynomialBatch; @@ -142,7 +142,7 @@ where timing, "split up quotient polys", quotient_polys - .into_par_iter() + .maybe_into_par_iter() .flat_map(|mut quotient_poly| { quotient_poly.trim_to_len(quotient_degree).expect( "Quotient has failed, the vanishing polynomial is not divisible by Z_H", @@ -305,7 +305,7 @@ fn wires_permutation_partial_products_and_zs< } transpose(&all_partial_products_and_zs) - .into_par_iter() + .maybe_into_par_iter() .map(PolynomialValues::new) .collect() } @@ -452,7 +452,7 @@ fn compute_quotient_polys< .collect(); transpose("ient_values) - .into_par_iter() + .maybe_into_par_iter() .map(PolynomialValues::new) .map(|values| values.coset_ifft(F::coset_shift())) .collect()