parallelize coeff bin

This commit is contained in:
Youngjoon Lee 2024-09-09 19:00:59 +09:00
parent 1e26f26d49
commit 01d3d6f120
No known key found for this signature in database
GPG Key ID: 167546E2D1712F8C
2 changed files with 85 additions and 56 deletions

View File

@ -1,5 +1,6 @@
use std::{env, fs::File, path::Path};
use std::{fs::File, path::PathBuf};
use clap::Parser;
use glob::glob;
use polars::prelude::*;
use walkdir::WalkDir;
@ -221,33 +222,31 @@ fn skip_noise(seq: &[Entry], mut index: usize) -> usize {
index
}
#[derive(Debug, Parser)]
#[command(name = "Calculating ordering coefficients")]
struct Args {
#[arg(short, long)]
path: String,
#[arg(short, long)]
num_threads: usize,
}
fn main() {
tracing_subscriber::fmt::init();
let args: Vec<String> = env::args().collect();
if args.len() < 2 {
eprintln!("Usage: {} <path>", args[0]);
std::process::exit(1);
}
let path = &args[1];
calculate_coeffs(path);
let args = Args::parse();
calculate_coeffs(&args);
}
fn calculate_coeffs(path: &str) {
for entry in WalkDir::new(path)
fn calculate_coeffs(args: &Args) {
let mut tasks: Vec<Task> = Vec::new();
for entry in WalkDir::new(args.path.as_str())
.into_iter()
.filter_map(|e| e.ok())
.filter(|e| e.file_type().is_dir())
{
let dir_name = entry.path().file_name().unwrap().to_string_lossy();
if dir_name.starts_with("iteration_") {
let mut senders: Vec<u64> = Vec::new();
let mut receivers: Vec<u64> = Vec::new();
let mut strongs: Vec<u64> = Vec::new();
let mut casuals: Vec<u64> = Vec::new();
let mut weaks: Vec<u64> = Vec::new();
for sent_seq_file in glob(&format!("{}/sent_seq_*.csv", entry.path().display()))
.unwrap()
.filter_map(Result::ok)
@ -262,50 +261,43 @@ fn calculate_coeffs(path: &str) {
let receiver =
extract_id(&recv_seq_file.file_name().unwrap().to_string_lossy()).unwrap();
tracing::info!("Processing:");
tracing::info!(" {}", sent_seq_file.display());
tracing::info!(" {}", recv_seq_file.display());
let sent_seq = load_sequence(sent_seq_file.to_str().unwrap());
let recv_seq = load_sequence(recv_seq_file.to_str().unwrap());
let (strong, casual) = strong_and_casual_coeff(&sent_seq, &recv_seq);
let weak = weak_coeff(&sent_seq, &recv_seq);
senders.push(sender as u64);
receivers.push(receiver as u64);
strongs.push(strong);
casuals.push(casual);
weaks.push(weak);
tracing::info!(
"Processed: sender:{}, receiver:{}, strong:{}, casual:{}, weak:{}",
let task = Task {
sent_seq_file: sent_seq_file.clone(),
recv_seq_file: recv_seq_file.clone(),
sender,
receiver,
strong,
casual,
weak
);
outpath: entry
.path()
.join(format!("coeffs_{}_{}.csv", sender, receiver)),
};
tasks.push(task);
}
}
// Create a Polars DataFrame
let mut df = DataFrame::new(vec![
Series::new("sender", &senders),
Series::new("receiver", &receivers),
Series::new("strong", &strongs),
Series::new("casual", &casuals),
Series::new("weak", &weaks),
])
.unwrap()
.sort(["sender", "receiver"], SortMultipleOptions::default())
.unwrap();
// Write the sorted DataFrame to a CSV file
let outpath = Path::new(entry.path()).join("coeffs.csv");
let mut file = File::create(&outpath).unwrap();
CsvWriter::new(&mut file).finish(&mut df).unwrap();
tracing::info!("Saved {}", outpath.display());
}
}
let (task_tx, task_rx) = crossbeam::channel::unbounded::<Task>();
let mut threads = Vec::with_capacity(args.num_threads);
for _ in 0..args.num_threads {
let task_rx = task_rx.clone();
let thread = std::thread::spawn(move || {
while let Ok(task) = task_rx.recv() {
task.run();
}
});
threads.push(thread);
}
for task in tasks {
task_tx.send(task).unwrap();
}
// Close the task sender channel, so that the threads can know that there's no task remains.
drop(task_tx);
for thread in threads {
thread.join().unwrap();
}
}
fn extract_id(filename: &str) -> Option<u8> {
@ -319,6 +311,43 @@ fn extract_id(filename: &str) -> Option<u8> {
None
}
struct Task {
sent_seq_file: PathBuf,
recv_seq_file: PathBuf,
sender: u8,
receiver: u8,
outpath: PathBuf,
}
impl Task {
fn run(&self) {
tracing::info!(
"Processing:\n {}\n {}",
self.sent_seq_file.display(),
self.recv_seq_file.display()
);
let sent_seq = load_sequence(self.sent_seq_file.to_str().unwrap());
let recv_seq = load_sequence(self.recv_seq_file.to_str().unwrap());
let (strong, casual) = strong_and_casual_coeff(&sent_seq, &recv_seq);
let weak = weak_coeff(&sent_seq, &recv_seq);
let mut df = DataFrame::new(vec![
Series::new("sender", &[self.sender as u64]),
Series::new("receiver", &[self.receiver as u64]),
Series::new("strong", &[strong]),
Series::new("casual", &[casual]),
Series::new("weak", &[weak]),
])
.unwrap()
.sort(["sender", "receiver"], SortMultipleOptions::default())
.unwrap();
let mut file = File::create(&self.outpath).unwrap();
CsvWriter::new(&mut file).finish(&mut df).unwrap();
tracing::info!("Saved {}", self.outpath.display());
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -16,7 +16,7 @@ fn aggregate(path: &str) {
let mut casuals = Series::new_empty("", &DataType::Int64);
let mut weaks = Series::new_empty("", &DataType::Int64);
let pattern = format!("{}/**/coeffs.csv", entry.path().display());
let pattern = format!("{}/**/coeffs_*.csv", entry.path().display());
for file in glob(&pattern).unwrap().filter_map(Result::ok) {
let df = CsvReadOptions::default()
.with_has_header(true)