From 01d3d6f120e5f4bb76000cacb63af55a7789426e Mon Sep 17 00:00:00 2001 From: Youngjoon Lee <5462944+youngjoon-lee@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:00:59 +0900 Subject: [PATCH] parallelize coeff bin --- mixnet-rs/ordering/src/bin/coeff.rs | 139 ++++++++++++++--------- mixnet-rs/ordering/src/bin/coeff_aggr.rs | 2 +- 2 files changed, 85 insertions(+), 56 deletions(-) diff --git a/mixnet-rs/ordering/src/bin/coeff.rs b/mixnet-rs/ordering/src/bin/coeff.rs index 4cb0770..7a17796 100644 --- a/mixnet-rs/ordering/src/bin/coeff.rs +++ b/mixnet-rs/ordering/src/bin/coeff.rs @@ -1,5 +1,6 @@ -use std::{env, fs::File, path::Path}; +use std::{fs::File, path::PathBuf}; +use clap::Parser; use glob::glob; use polars::prelude::*; use walkdir::WalkDir; @@ -221,33 +222,31 @@ fn skip_noise(seq: &[Entry], mut index: usize) -> usize { index } +#[derive(Debug, Parser)] +#[command(name = "Calculating ordering coefficients")] +struct Args { + #[arg(short, long)] + path: String, + #[arg(short, long)] + num_threads: usize, +} + fn main() { tracing_subscriber::fmt::init(); - let args: Vec = env::args().collect(); - if args.len() < 2 { - eprintln!("Usage: {} ", args[0]); - std::process::exit(1); - } - let path = &args[1]; - - calculate_coeffs(path); + let args = Args::parse(); + calculate_coeffs(&args); } -fn calculate_coeffs(path: &str) { - for entry in WalkDir::new(path) +fn calculate_coeffs(args: &Args) { + let mut tasks: Vec = Vec::new(); + for entry in WalkDir::new(args.path.as_str()) .into_iter() .filter_map(|e| e.ok()) .filter(|e| e.file_type().is_dir()) { let dir_name = entry.path().file_name().unwrap().to_string_lossy(); if dir_name.starts_with("iteration_") { - let mut senders: Vec = Vec::new(); - let mut receivers: Vec = Vec::new(); - let mut strongs: Vec = Vec::new(); - let mut casuals: Vec = Vec::new(); - let mut weaks: Vec = Vec::new(); - for sent_seq_file in glob(&format!("{}/sent_seq_*.csv", entry.path().display())) .unwrap() .filter_map(Result::ok) @@ -262,50 +261,43 @@ fn calculate_coeffs(path: &str) { let receiver = extract_id(&recv_seq_file.file_name().unwrap().to_string_lossy()).unwrap(); - tracing::info!("Processing:"); - tracing::info!(" {}", sent_seq_file.display()); - tracing::info!(" {}", recv_seq_file.display()); - - let sent_seq = load_sequence(sent_seq_file.to_str().unwrap()); - let recv_seq = load_sequence(recv_seq_file.to_str().unwrap()); - let (strong, casual) = strong_and_casual_coeff(&sent_seq, &recv_seq); - let weak = weak_coeff(&sent_seq, &recv_seq); - - senders.push(sender as u64); - receivers.push(receiver as u64); - strongs.push(strong); - casuals.push(casual); - weaks.push(weak); - - tracing::info!( - "Processed: sender:{}, receiver:{}, strong:{}, casual:{}, weak:{}", + let task = Task { + sent_seq_file: sent_seq_file.clone(), + recv_seq_file: recv_seq_file.clone(), sender, receiver, - strong, - casual, - weak - ); + outpath: entry + .path() + .join(format!("coeffs_{}_{}.csv", sender, receiver)), + }; + tasks.push(task); } } - - // Create a Polars DataFrame - let mut df = DataFrame::new(vec![ - Series::new("sender", &senders), - Series::new("receiver", &receivers), - Series::new("strong", &strongs), - Series::new("casual", &casuals), - Series::new("weak", &weaks), - ]) - .unwrap() - .sort(["sender", "receiver"], SortMultipleOptions::default()) - .unwrap(); - // Write the sorted DataFrame to a CSV file - let outpath = Path::new(entry.path()).join("coeffs.csv"); - let mut file = File::create(&outpath).unwrap(); - CsvWriter::new(&mut file).finish(&mut df).unwrap(); - tracing::info!("Saved {}", outpath.display()); } } + + let (task_tx, task_rx) = crossbeam::channel::unbounded::(); + let mut threads = Vec::with_capacity(args.num_threads); + for _ in 0..args.num_threads { + let task_rx = task_rx.clone(); + + let thread = std::thread::spawn(move || { + while let Ok(task) = task_rx.recv() { + task.run(); + } + }); + threads.push(thread); + } + + for task in tasks { + task_tx.send(task).unwrap(); + } + // Close the task sender channel, so that the threads can know that there's no task remains. + drop(task_tx); + + for thread in threads { + thread.join().unwrap(); + } } fn extract_id(filename: &str) -> Option { @@ -319,6 +311,43 @@ fn extract_id(filename: &str) -> Option { None } +struct Task { + sent_seq_file: PathBuf, + recv_seq_file: PathBuf, + sender: u8, + receiver: u8, + outpath: PathBuf, +} + +impl Task { + fn run(&self) { + tracing::info!( + "Processing:\n {}\n {}", + self.sent_seq_file.display(), + self.recv_seq_file.display() + ); + + let sent_seq = load_sequence(self.sent_seq_file.to_str().unwrap()); + let recv_seq = load_sequence(self.recv_seq_file.to_str().unwrap()); + let (strong, casual) = strong_and_casual_coeff(&sent_seq, &recv_seq); + let weak = weak_coeff(&sent_seq, &recv_seq); + + let mut df = DataFrame::new(vec![ + Series::new("sender", &[self.sender as u64]), + Series::new("receiver", &[self.receiver as u64]), + Series::new("strong", &[strong]), + Series::new("casual", &[casual]), + Series::new("weak", &[weak]), + ]) + .unwrap() + .sort(["sender", "receiver"], SortMultipleOptions::default()) + .unwrap(); + let mut file = File::create(&self.outpath).unwrap(); + CsvWriter::new(&mut file).finish(&mut df).unwrap(); + tracing::info!("Saved {}", self.outpath.display()); + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/mixnet-rs/ordering/src/bin/coeff_aggr.rs b/mixnet-rs/ordering/src/bin/coeff_aggr.rs index 2524ba6..c2fb11e 100644 --- a/mixnet-rs/ordering/src/bin/coeff_aggr.rs +++ b/mixnet-rs/ordering/src/bin/coeff_aggr.rs @@ -16,7 +16,7 @@ fn aggregate(path: &str) { let mut casuals = Series::new_empty("", &DataType::Int64); let mut weaks = Series::new_empty("", &DataType::Int64); - let pattern = format!("{}/**/coeffs.csv", entry.path().display()); + let pattern = format!("{}/**/coeffs_*.csv", entry.path().display()); for file in glob(&pattern).unwrap().filter_map(Result::ok) { let df = CsvReadOptions::default() .with_has_header(true)