use rayon instead of tokio

This commit is contained in:
Youngjoon Lee 2024-08-17 23:41:37 +09:00
parent 5248c02f2e
commit 1e643ab38f
No known key found for this signature in database
GPG Key ID: 167546E2D1712F8C
2 changed files with 19 additions and 51 deletions

View File

@ -8,10 +8,10 @@ chrono = "0.4.38"
clap = { version = "4.5.16", features = ["derive"] }
csv = "1.3.0"
rand = "0.8.5"
rayon = "1.10.0"
rustc-hash = "2.0.0"
strum = "0.26.3"
strum_macros = "0.26.4"
tokio = { version = "1.39.2", features = ["rt", "rt-multi-thread", "sync"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

View File

@ -1,7 +1,7 @@
use chrono::Utc;
use clap::Parser;
use rayon::prelude::*;
use std::{
collections::HashMap,
error::Error,
path::Path,
time::{Duration, SystemTime},
@ -45,12 +45,6 @@ fn main() {
num_threads,
} = args;
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_threads)
.enable_all()
.build()
.unwrap();
// Create a directory and initialize a CSV file only with a header
assert!(
Path::new(&outdir).is_dir(),
@ -69,60 +63,34 @@ fn main() {
let session_start_time = SystemTime::now();
runtime.block_on(async {
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel::<(u16, u16)>();
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build()
.unwrap();
let mut waiting_iterations: HashMap<u16, (u16, String)> = HashMap::new();
let mut num_completed_paramsets = 0;
let num_paramsets = paramsets.len();
for paramset in paramsets {
pool.install(|| {
paramsets.par_iter().for_each(|paramset| {
let paramset_dir = format!("{outdir}/{subdir}/__WIP__paramset_{}", paramset.id);
std::fs::create_dir_all(paramset_dir.as_str()).unwrap();
save_paramset_info(&paramset, format!("{paramset_dir}/paramset.csv").as_str()).unwrap();
save_paramset_info(paramset, format!("{paramset_dir}/paramset.csv").as_str()).unwrap();
for i in 0..paramset.num_iterations {
let out_csv_path = format!("{paramset_dir}/__WIP__iteration_{i}.csv");
let topology_path = format!("{paramset_dir}/topology_{i}.csv");
let sender = sender.clone();
tokio::task::spawn(async move {
run_iteration(paramset, i as u64, &out_csv_path, &topology_path);
let new_out_csv_path = out_csv_path.replace("__WIP__iteration_", "iteration_");
std::fs::rename(&out_csv_path, &new_out_csv_path)
.expect("Failed to rename: {out_csv_path} -> {new_out_csv_path}");
tracing::info!("ParamSet:{}, Iteration:{} completed.", paramset.id, i);
run_iteration(*paramset, i as u64, &out_csv_path, &topology_path);
sender.send((paramset.id, i)).unwrap();
});
let new_out_csv_path = out_csv_path.replace("__WIP__iteration_", "iteration_");
std::fs::rename(&out_csv_path, &new_out_csv_path)
.expect("Failed to rename: {out_csv_path} -> {new_out_csv_path}");
tracing::info!("ParamSet:{}, Iteration:{} completed.", paramset.id, i);
}
waiting_iterations.insert(paramset.id, (paramset.num_iterations, paramset_dir));
}
let new_paramset_dir = paramset_dir.replace("__WIP__paramset_", "paramset_");
std::fs::rename(&paramset_dir, &new_paramset_dir)
.expect("Failed to rename: {paramset_dir} -> {new_paramset_dir}: {e}");
while let Some((paramset_id, _)) = receiver.recv().await {
let (remaining_iterations, _) = waiting_iterations.get_mut(&paramset_id).unwrap();
*remaining_iterations -= 1;
if *remaining_iterations == 0 {
let paramset_dir = waiting_iterations.remove(&paramset_id).unwrap().1;
let new_paramset_dir = paramset_dir.replace("__WIP__paramset_", "paramset_");
std::fs::rename(&paramset_dir, &new_paramset_dir)
.expect("Failed to rename: {paramset_dir} -> {new_paramset_dir}: {e}");
num_completed_paramsets += 1;
tracing::info!(
"ParamSet:{} completed. {}/{} paramsets have been done so far.",
paramset_id,
num_completed_paramsets,
num_paramsets
);
}
// Exit loop if no more iterations are waiting
if waiting_iterations.is_empty() {
break;
}
}
tracing::info!("ParamSet:{} completed", paramset.id);
});
});
let session_duration = SystemTime::now()