Merge pull request #629 from proxima-one/maybe-rayon

add rayon shim
This commit is contained in:
Daniel Lubarov 2022-07-28 11:38:56 -07:00 committed by GitHub
commit bb45c8c850
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 304 additions and 19 deletions

View File

@ -1,5 +1,5 @@
[workspace]
members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman", "ecdsa", "u32", "evm"]
members = ["field", "insertion", "plonky2", "starky", "system_zero", "util", "waksman", "ecdsa", "u32", "evm", "maybe_rayon"]
[profile.release]
opt-level = 3

View File

@ -17,7 +17,7 @@ log = "0.4.14"
once_cell = "1.13.0"
pest = "2.1.3"
pest_derive = "2.1.0"
rayon = "1.5.1"
maybe_rayon = { path = "../maybe_rayon" }
rand = "0.8.5"
rand_chacha = "0.3.1"
rlp = "0.5.1"
@ -28,7 +28,9 @@ keccak-hash = "0.9.0"
hex = "0.4.3"
[features]
default = ["parallel"]
asmtools = ["hex"]
parallel = ["maybe_rayon/parallel"]
[[bin]]
name = "assemble"

View File

@ -2,6 +2,7 @@ use std::marker::PhantomData;
use ethereum_types::U256;
use itertools::Itertools;
use maybe_rayon::*;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
@ -10,7 +11,6 @@ use plonky2::hash::hash_types::RichField;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2::util::transpose;
use rayon::prelude::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::Column;

View File

@ -1,6 +1,7 @@
//! Permutation arguments.
use itertools::Itertools;
use maybe_rayon::*;
use plonky2::field::batch_util::batch_multiply_inplace;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
@ -16,7 +17,6 @@ use plonky2::plonk::plonk_common::{
reduce_with_powers, reduce_with_powers_circuit, reduce_with_powers_ext_circuit,
};
use plonky2::util::reducing::{ReducingFactor, ReducingFactorTarget};
use rayon::prelude::*;
use crate::config::StarkConfig;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

View File

@ -1,4 +1,5 @@
use itertools::Itertools;
use maybe_rayon::*;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::fri::oracle::PolynomialBatch;
use plonky2::fri::proof::{
@ -12,7 +13,6 @@ use plonky2::hash::merkle_tree::MerkleCap;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::iop::target::Target;
use plonky2::plonk::config::GenericConfig;
use rayon::prelude::*;
use crate::config::StarkConfig;
use crate::permutation::GrandProductChallengeSet;

View File

@ -1,6 +1,7 @@
use std::any::type_name;
use anyhow::{ensure, Result};
use maybe_rayon::*;
use plonky2::field::extension::Extendable;
use plonky2::field::packable::Packable;
use plonky2::field::packed::PackedField;
@ -15,7 +16,6 @@ use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2::util::transpose;
use plonky2_util::{log2_ceil, log2_strict};
use rayon::prelude::*;
use crate::all_stark::{AllStark, Table};
use crate::config::StarkConfig;

11
maybe_rayon/Cargo.toml Normal file
View File

@ -0,0 +1,11 @@
[package]
name = "maybe_rayon"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
parallel = ["rayon"]
[dependencies]
rayon = { version = "1.5.3", optional = true }

262
maybe_rayon/src/lib.rs Normal file
View File

@ -0,0 +1,262 @@
#[cfg(not(feature = "parallel"))]
use std::{
iter::{IntoIterator, Iterator},
slice::{Chunks, ChunksExact, ChunksExactMut, ChunksMut},
};
#[cfg(feature = "parallel")]
pub use rayon::prelude::{
IndexedParallelIterator, ParallelDrainFull, ParallelDrainRange, ParallelExtend,
ParallelIterator,
};
#[cfg(feature = "parallel")]
use rayon::{
prelude::*,
slice::{
Chunks as ParChunks, ChunksExact as ParChunksExact, ChunksExactMut as ParChunksExactMut,
ChunksMut as ParChunksMut, ParallelSlice, ParallelSliceMut,
},
};
pub trait MaybeParIter<'data> {
#[cfg(feature = "parallel")]
type Item: Send + 'data;
#[cfg(feature = "parallel")]
type Iter: ParallelIterator<Item = Self::Item>;
#[cfg(not(feature = "parallel"))]
type Item;
#[cfg(not(feature = "parallel"))]
type Iter: Iterator<Item = Self::Item>;
fn par_iter(&'data self) -> Self::Iter;
}
#[cfg(feature = "parallel")]
impl<'data, T> MaybeParIter<'data> for T
where
T: ?Sized + IntoParallelRefIterator<'data>,
{
type Item = T::Item;
type Iter = T::Iter;
fn par_iter(&'data self) -> Self::Iter {
self.par_iter()
}
}
#[cfg(not(feature = "parallel"))]
impl<'data, T: 'data> MaybeParIter<'data> for Vec<T> {
type Item = &'data T;
type Iter = std::slice::Iter<'data, T>;
fn par_iter(&'data self) -> Self::Iter {
self.iter()
}
}
#[cfg(not(feature = "parallel"))]
impl<'data, T: 'data> MaybeParIter<'data> for [T] {
type Item = &'data T;
type Iter = std::slice::Iter<'data, T>;
fn par_iter(&'data self) -> Self::Iter {
self.iter()
}
}
pub trait MaybeParIterMut<'data> {
#[cfg(feature = "parallel")]
type Item: Send + 'data;
#[cfg(feature = "parallel")]
type Iter: ParallelIterator<Item = Self::Item>;
#[cfg(not(feature = "parallel"))]
type Item;
#[cfg(not(feature = "parallel"))]
type Iter: Iterator<Item = Self::Item>;
fn par_iter_mut(&'data mut self) -> Self::Iter;
}
#[cfg(feature = "parallel")]
impl<'data, T> MaybeParIterMut<'data> for T
where
T: ?Sized + IntoParallelRefMutIterator<'data>,
{
type Item = T::Item;
type Iter = T::Iter;
fn par_iter_mut(&'data mut self) -> Self::Iter {
self.par_iter_mut()
}
}
#[cfg(not(feature = "parallel"))]
impl<'data, T: 'data> MaybeParIterMut<'data> for Vec<T> {
type Item = &'data mut T;
type Iter = std::slice::IterMut<'data, T>;
fn par_iter_mut(&'data mut self) -> Self::Iter {
self.iter_mut()
}
}
#[cfg(not(feature = "parallel"))]
impl<'data, T: 'data> MaybeParIterMut<'data> for [T] {
type Item = &'data mut T;
type Iter = std::slice::IterMut<'data, T>;
fn par_iter_mut(&'data mut self) -> Self::Iter {
self.iter_mut()
}
}
pub trait MaybeIntoParIter {
#[cfg(feature = "parallel")]
type Item: Send;
#[cfg(feature = "parallel")]
type Iter: ParallelIterator<Item = Self::Item>;
#[cfg(not(feature = "parallel"))]
type Item;
#[cfg(not(feature = "parallel"))]
type Iter: Iterator<Item = Self::Item>;
fn into_par_iter(self) -> Self::Iter;
}
#[cfg(feature = "parallel")]
impl<T> MaybeIntoParIter for T
where
T: IntoParallelIterator,
{
type Item = T::Item;
type Iter = T::Iter;
fn into_par_iter(self) -> Self::Iter {
self.into_par_iter()
}
}
#[cfg(not(feature = "parallel"))]
impl<T> MaybeIntoParIter for T
where
T: IntoIterator,
{
type Item = T::Item;
type Iter = T::IntoIter;
fn into_par_iter(self) -> Self::Iter {
self.into_iter()
}
}
#[cfg(feature = "parallel")]
pub trait MaybeParChunks<T: Sync> {
fn par_chunks(&self, chunk_size: usize) -> ParChunks<'_, T>;
fn par_chunks_exact(&self, chunk_size: usize) -> ParChunksExact<'_, T>;
}
#[cfg(not(feature = "parallel"))]
pub trait MaybeParChunks<T> {
fn par_chunks(&self, chunk_size: usize) -> Chunks<'_, T>;
fn par_chunks_exact(&self, chunk_size: usize) -> ChunksExact<'_, T>;
}
#[cfg(feature = "parallel")]
impl<T: ParallelSlice<U> + ?Sized, U: Sync> MaybeParChunks<U> for T {
fn par_chunks(&self, chunk_size: usize) -> ParChunks<'_, U> {
self.par_chunks(chunk_size)
}
fn par_chunks_exact(&self, chunk_size: usize) -> ParChunksExact<'_, U> {
self.par_chunks_exact(chunk_size)
}
}
#[cfg(not(feature = "parallel"))]
impl<T> MaybeParChunks<T> for [T] {
fn par_chunks(&self, chunk_size: usize) -> Chunks<'_, T> {
self.chunks(chunk_size)
}
fn par_chunks_exact(&self, chunk_size: usize) -> ChunksExact<'_, T> {
self.chunks_exact(chunk_size)
}
}
#[cfg(feature = "parallel")]
pub trait MaybeParChunksMut<T: Send> {
fn par_chunks_mut(&mut self, chunk_size: usize) -> ParChunksMut<'_, T>;
fn par_chunks_exact_mut(&mut self, chunk_size: usize) -> ParChunksExactMut<'_, T>;
}
#[cfg(not(feature = "parallel"))]
pub trait MaybeParChunksMut<T: Send> {
fn par_chunks_mut(&mut self, chunk_size: usize) -> ChunksMut<'_, T>;
fn par_chunks_exact_mut(&mut self, chunk_size: usize) -> ChunksExactMut<'_, T>;
}
#[cfg(feature = "parallel")]
impl<T: ?Sized + ParallelSliceMut<U>, U: Send> MaybeParChunksMut<U> for T {
fn par_chunks_mut(&mut self, chunk_size: usize) -> ParChunksMut<'_, U> {
self.par_chunks_mut(chunk_size)
}
fn par_chunks_exact_mut(&mut self, chunk_size: usize) -> ParChunksExactMut<'_, U> {
self.par_chunks_exact_mut(chunk_size)
}
}
#[cfg(not(feature = "parallel"))]
impl<T: Send> MaybeParChunksMut<T> for [T] {
fn par_chunks_mut(&mut self, chunk_size: usize) -> ChunksMut<'_, T> {
self.chunks_mut(chunk_size)
}
fn par_chunks_exact_mut(&mut self, chunk_size: usize) -> ChunksExactMut<'_, T> {
self.chunks_exact_mut(chunk_size)
}
}
pub trait ParallelIteratorMock {
type Item;
fn find_any<P>(self, predicate: P) -> Option<Self::Item>
where
P: Fn(&Self::Item) -> bool + Sync + Send;
}
impl<T: Iterator> ParallelIteratorMock for T {
type Item = T::Item;
fn find_any<P>(mut self, predicate: P) -> Option<Self::Item>
where
P: Fn(&Self::Item) -> bool + Sync + Send,
{
self.find(predicate)
}
}
#[cfg(feature = "parallel")]
pub fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
where
A: FnOnce() -> RA + Send,
B: FnOnce() -> RB + Send,
RA: Send,
RB: Send,
{
rayon::join(oper_a, oper_b)
}
#[cfg(not(feature = "parallel"))]
pub fn join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
where
A: FnOnce() -> RA,
B: FnOnce() -> RB,
{
(oper_a(), oper_b())
}

View File

@ -10,6 +10,10 @@ categories = ["cryptography"]
edition = "2021"
default-run = "generate_constants"
[features]
default = ["parallel"]
parallel = ["maybe_rayon/parallel"]
[dependencies]
plonky2_field = { path = "../field" }
plonky2_util = { path = "../util" }
@ -19,7 +23,7 @@ itertools = "0.10.0"
num = { version = "0.4", features = [ "rand" ] }
rand = "0.8.4"
rand_chacha = "0.3.1"
rayon = "1.5.1"
maybe_rayon = { path = "../maybe_rayon" }
unroll = "0.1.5"
anyhow = "1.0.40"
serde = { version = "1.0", features = ["derive"] }
@ -32,6 +36,7 @@ criterion = "0.3.5"
tynm = "0.1.6"
structopt = "0.3.26"
num_cpus = "1.13.1"
rayon = "1.5.1"
[target.'cfg(not(target_env = "msvc"))'.dev-dependencies]
jemallocator = "0.3.2"

View File

@ -2,7 +2,6 @@
// custom CLI argument parsing (even with harness disabled). We could also have
// put it in `src/bin/`, but then we wouldn't have access to
// `[dev-dependencies]`.
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]

View File

@ -1,11 +1,11 @@
use itertools::Itertools;
use maybe_rayon::*;
use plonky2_field::extension::Extendable;
use plonky2_field::fft::FftRootTable;
use plonky2_field::packed::PackedField;
use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues};
use plonky2_field::types::Field;
use plonky2_util::{log2_strict, reverse_index_bits_in_place};
use rayon::prelude::*;
use crate::fri::proof::FriProof;
use crate::fri::prover::fri_proof;

View File

@ -1,8 +1,8 @@
use itertools::Itertools;
use maybe_rayon::*;
use plonky2_field::extension::{flatten, unflatten, Extendable};
use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues};
use plonky2_util::reverse_index_bits_in_place;
use rayon::prelude::*;
use crate::fri::proof::{FriInitialTreeProof, FriProof, FriQueryRound, FriQueryStep};
use crate::fri::{FriConfig, FriParams};

View File

@ -1,8 +1,8 @@
use std::mem::MaybeUninit;
use std::slice;
use maybe_rayon::*;
use plonky2_util::log2_strict;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::hash::hash_types::RichField;
@ -77,10 +77,12 @@ where
let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap();
// Split `leaves` between both children.
let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2);
let (left_digest, right_digest) = rayon::join(
let (left_digest, right_digest) = maybe_rayon::join(
|| fill_subtree::<F, H>(left_digests_buf, left_leaves),
|| fill_subtree::<F, H>(right_digests_buf, right_leaves),
);
left_digest_mem.write(left_digest);
right_digest_mem.write(right_digest);
H::two_to_one(left_digest, right_digest)

View File

@ -1,8 +1,8 @@
use std::collections::HashMap;
use maybe_rayon::*;
use plonky2_field::polynomial::PolynomialValues;
use plonky2_field::types::Field;
use rayon::prelude::*;
use crate::iop::target::Target;
use crate::iop::wire::Wire;

View File

@ -1,6 +1,6 @@
use anyhow::ensure;
use maybe_rayon::*;
use plonky2_field::extension::Extendable;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::fri::oracle::PolynomialBatch;

View File

@ -2,11 +2,11 @@ use std::mem::swap;
use anyhow::ensure;
use anyhow::Result;
use maybe_rayon::*;
use plonky2_field::extension::Extendable;
use plonky2_field::polynomial::{PolynomialCoeffs, PolynomialValues};
use plonky2_field::zero_poly_coset::ZeroPolyOnCoset;
use plonky2_util::{ceil_div_usize, log2_ceil};
use rayon::prelude::*;
use crate::field::types::Field;
use crate::fri::oracle::PolynomialBatch;

View File

@ -4,6 +4,10 @@ description = "Implementation of STARKs"
version = "0.1.0"
edition = "2021"
[features]
default = ["parallel"]
parallel = ["maybe_rayon/parallel"]
[dependencies]
plonky2 = { path = "../plonky2" }
plonky2_util = { path = "../util" }
@ -11,4 +15,4 @@ anyhow = "1.0.40"
env_logger = "0.9.0"
itertools = "0.10.0"
log = "0.4.14"
rayon = "1.5.1"
maybe_rayon = { path = "../maybe_rayon"}

View File

@ -1,6 +1,7 @@
//! Permutation arguments.
use itertools::Itertools;
use maybe_rayon::*;
use plonky2::field::batch_util::batch_multiply_inplace;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
@ -13,7 +14,6 @@ use plonky2::iop::target::Target;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::config::{AlgebraicHasher, GenericConfig, Hasher};
use plonky2::util::reducing::{ReducingFactor, ReducingFactorTarget};
use rayon::prelude::*;
use crate::config::StarkConfig;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};

View File

@ -1,4 +1,5 @@
use itertools::Itertools;
use maybe_rayon::*;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::fri::oracle::PolynomialBatch;
use plonky2::fri::proof::{
@ -12,7 +13,6 @@ use plonky2::hash::merkle_tree::MerkleCap;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::iop::target::Target;
use plonky2::plonk::config::GenericConfig;
use rayon::prelude::*;
use crate::config::StarkConfig;
use crate::permutation::PermutationChallengeSet;

View File

@ -2,6 +2,7 @@ use std::iter::once;
use anyhow::{ensure, Result};
use itertools::Itertools;
use maybe_rayon::*;
use plonky2::field::extension::Extendable;
use plonky2::field::packable::Packable;
use plonky2::field::packed::PackedField;
@ -16,7 +17,6 @@ use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2::util::transpose;
use plonky2_util::{log2_ceil, log2_strict};
use rayon::prelude::*;
use crate::config::StarkConfig;
use crate::constraint_consumer::ConstraintConsumer;