use thread and crossbeam instead of rayon

This commit is contained in:
Youngjoon Lee 2024-08-25 03:38:38 +09:00
parent 4b6a0d6164
commit 8e11dec005
No known key found for this signature in database
GPG Key ID: 167546E2D1712F8C
3 changed files with 108 additions and 42 deletions

View File

@ -6,10 +6,10 @@ edition = "2021"
[dependencies]
chrono = "0.4.38"
clap = { version = "4.5.16", features = ["derive"] }
crossbeam = "0.8.4"
csv = "1.3.0"
protocol = { version = "0.1.0", path = "../protocol" }
rand = "0.8.5"
rayon = "1.10.0"
rustc-hash = "2.0.0"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

View File

@ -18,14 +18,14 @@ use crate::{
pub struct Iteration {
pub paramset: ParamSet,
pub iteration_idx: usize,
pub rootdir: String,
pub paramset_dir: String,
}
impl Iteration {
pub fn start(&mut self) {
let dir = format!(
"{}/iteration_{}__WIP_DUR__",
self.rootdir, self.iteration_idx
self.paramset_dir, self.iteration_idx
);
std::fs::create_dir_all(dir.as_str()).unwrap();

View File

@ -6,6 +6,7 @@ mod paramset;
mod topology;
use std::{
collections::{hash_map::Entry, HashMap},
error::Error,
path::Path,
time::{Duration, SystemTime},
@ -16,7 +17,6 @@ use clap::Parser;
use iteration::Iteration;
use paramset::{ExperimentId, ParamSet, SessionId, PARAMSET_CSV_COLUMNS};
use protocol::queue::QueueType;
use rayon::prelude::*;
#[derive(Debug, Parser)]
#[command(name = "Ordering Measurement")]
@ -67,49 +67,21 @@ fn main() {
queue_type,
Utc::now().to_rfc3339()
);
std::fs::create_dir_all(&format!("{outdir}/{subdir}")).unwrap();
let rootdir = format!("{outdir}/{subdir}");
std::fs::create_dir_all(&rootdir).unwrap();
let paramsets = ParamSet::new_all_paramsets(exp_id, session_id, queue_type);
let session_start_time = SystemTime::now();
let mut iterations: Vec<Iteration> = Vec::new();
for paramset in paramsets {
if paramset.id < from_paramset.unwrap_or(0) {
tracing::info!("ParamSet:{} skipped", paramset.id);
continue;
} else if paramset.id > to_paramset.unwrap_or(u16::MAX) {
tracing::info!("ParamSets:{}~ skipped", paramset.id);
break;
}
let paramset_dir = format!("{outdir}/{subdir}/paramset_{}", paramset.id);
std::fs::create_dir_all(paramset_dir.as_str()).unwrap();
save_paramset_info(&paramset, format!("{paramset_dir}/paramset.csv").as_str()).unwrap();
for i in 0..paramset.num_iterations {
iterations.push(Iteration {
paramset: paramset.clone(),
iteration_idx: i,
rootdir: paramset_dir.clone(),
});
}
}
if reverse_order {
iterations.reverse();
}
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build()
.unwrap();
pool.install(|| {
iterations.par_iter_mut().for_each(|iteration| {
iteration.start();
});
});
let iterations = prepare_all_iterations(
&paramsets,
from_paramset,
to_paramset,
reverse_order,
&rootdir,
);
run_all_iterations(iterations, num_threads);
let session_duration = SystemTime::now()
.duration_since(session_start_time)
@ -131,6 +103,100 @@ fn main() {
tracing::info!("Session completed.");
}
fn prepare_all_iterations(
paramsets: &[ParamSet],
from_paramset: Option<u16>,
to_paramset: Option<u16>,
reverse_order: bool,
rootdir: &str,
) -> Vec<Iteration> {
let mut iterations: Vec<Iteration> = Vec::new();
for paramset in paramsets.iter() {
if paramset.id < from_paramset.unwrap_or(0) {
tracing::info!("ParamSet:{} skipped", paramset.id);
continue;
} else if paramset.id > to_paramset.unwrap_or(u16::MAX) {
tracing::info!("ParamSets:{}~ skipped", paramset.id);
break;
}
let paramset_dir = format!("{rootdir}/__WIP__paramset_{}", paramset.id);
std::fs::create_dir_all(paramset_dir.as_str()).unwrap();
save_paramset_info(paramset, format!("{paramset_dir}/paramset.csv").as_str()).unwrap();
for i in 0..paramset.num_iterations {
iterations.push(Iteration {
paramset: paramset.clone(),
iteration_idx: i,
paramset_dir: paramset_dir.clone(),
});
}
}
if reverse_order {
iterations.reverse();
}
iterations
}
fn run_all_iterations(iterations: Vec<Iteration>, num_threads: usize) {
let (task_tx, task_rx) = crossbeam::channel::unbounded::<Iteration>();
let (noti_tx, noti_rx) = crossbeam::channel::unbounded::<Iteration>();
let mut threads = Vec::with_capacity(num_threads);
for _ in 0..num_threads {
let task_rx = task_rx.clone();
let noti_tx = noti_tx.clone();
let thread = std::thread::spawn(move || {
while let Ok(mut iteration) = task_rx.recv() {
iteration.start();
noti_tx.send(iteration).unwrap();
}
});
threads.push(thread);
}
let num_all_iterations = iterations.len();
for iteration in iterations {
task_tx.send(iteration).unwrap();
}
// Close the task sender channel, so that the threads can know that there's no task remains.
drop(task_tx);
let mut paramset_progresses: HashMap<u16, usize> = HashMap::new();
for _ in 0..num_all_iterations {
let iteration = noti_rx.recv().unwrap();
match paramset_progresses.entry(iteration.paramset.id) {
Entry::Occupied(mut e) => {
*e.get_mut() += 1;
}
Entry::Vacant(e) => {
e.insert(1);
}
}
if *paramset_progresses.get(&iteration.paramset.id).unwrap()
== iteration.paramset.num_iterations
{
let new_paramset_dir = iteration
.paramset_dir
.replace("__WIP__paramset", "paramset");
std::fs::rename(iteration.paramset_dir, new_paramset_dir).unwrap();
tracing::info!(
"ParamSet:{} is done ({} iterations)",
iteration.paramset.id,
iteration.paramset.num_iterations
);
}
}
for thread in threads {
thread.join().unwrap();
}
}
fn save_paramset_info(paramset: &ParamSet, path: &str) -> Result<(), Box<dyn Error>> {
// Assert that the file does not already exist
assert!(