From ea7896f06c19608367ecf1916bae22bd1c1f2329 Mon Sep 17 00:00:00 2001 From: Al Liu Date: Tue, 25 Apr 2023 19:14:30 +0800 Subject: [PATCH] Simulation streaming and gracefully shutdown (#119) * add stream supports * add test case * add polars stream, and force to use stream for the runner * using arcswap instead RefCell for producers * finish gracefully shutdown * - add IOProducer and IOSubscriber - fix deadlock in sync runner - fix testcases --- simulations/Cargo.toml | 5 + simulations/src/bin/app.rs | 162 ++++----------- simulations/src/lib.rs | 1 + simulations/src/node/dummy_streaming.rs | 48 +++++ simulations/src/node/mod.rs | 3 + simulations/src/output_processors/mod.rs | 7 +- simulations/src/runner/async_runner.rs | 92 ++++++--- simulations/src/runner/glauber_runner.rs | 119 +++++++---- simulations/src/runner/layered_runner.rs | 139 ++++++++----- simulations/src/runner/mod.rs | 153 +++++++++----- simulations/src/runner/sync_runner.rs | 113 +++++++---- simulations/src/settings.rs | 4 +- simulations/src/streaming/io.rs | 246 +++++++++++++++++++++++ simulations/src/streaming/mod.rs | 88 ++++++++ simulations/src/streaming/naive.rs | 237 ++++++++++++++++++++++ simulations/src/streaming/polars.rs | 194 ++++++++++++++++++ 16 files changed, 1277 insertions(+), 334 deletions(-) create mode 100644 simulations/src/node/dummy_streaming.rs create mode 100644 simulations/src/streaming/io.rs create mode 100644 simulations/src/streaming/mod.rs create mode 100644 simulations/src/streaming/naive.rs create mode 100644 simulations/src/streaming/polars.rs diff --git a/simulations/Cargo.toml b/simulations/Cargo.toml index 3104c597..83ccf86b 100644 --- a/simulations/Cargo.toml +++ b/simulations/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] anyhow = "1" +arc-swap = "1.6" clap = { version = "4", features = ["derive"] } crc32fast = "1.3" crossbeam = { version = "0.8.2", features = ["crossbeam-channel"] } @@ -15,9 +16,13 @@ nomos-core = { path = "../nomos-core" } polars = { version = "0.27", features = ["serde", "object", "json", "csv-file", "parquet", "dtype-struct"] } rand = { version = "0.8", features = ["small_rng"] } rayon = "1.7" +scopeguard = "1" serde = { version = "1.0", features = ["derive", "rc"] } serde_with = "2.3" serde_json = "1.0" [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2", features = ["js"] } + +[dev-dependencies] +tempfile = "3.4" \ No newline at end of file diff --git a/simulations/src/bin/app.rs b/simulations/src/bin/app.rs index 246d03f1..89aa20d5 100644 --- a/simulations/src/bin/app.rs +++ b/simulations/src/bin/app.rs @@ -1,78 +1,21 @@ // std use std::collections::HashMap; -use std::fmt::{Display, Formatter}; use std::fs::File; -use std::io::Cursor; use std::path::{Path, PathBuf}; -use std::str::FromStr; // crates use clap::Parser; -use polars::io::SerWriter; -use polars::prelude::{DataFrame, JsonReader, SerReader}; use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; use simulations::network::regions::RegionsData; use simulations::network::Network; use simulations::overlay::tree::TreeOverlay; +use simulations::streaming::StreamType; // internal use simulations::{ node::carnot::CarnotNode, output_processors::OutData, runner::SimulationRunner, - settings::SimulationSettings, + settings::SimulationSettings, streaming::io::IOProducer, streaming::naive::NaiveProducer, + streaming::polars::PolarsProducer, }; -#[derive(Debug, Clone, Serialize, Deserialize)] -enum OutputType { - File(PathBuf), - StdOut, - StdErr, -} - -impl core::fmt::Display for OutputType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - OutputType::File(path) => write!(f, "{}", path.display()), - OutputType::StdOut => write!(f, "stdout"), - OutputType::StdErr => write!(f, "stderr"), - } - } -} - -/// Output format selector enum -#[derive(Clone, Debug, Default)] -enum OutputFormat { - Json, - Csv, - #[default] - Parquet, -} - -impl Display for OutputFormat { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let tag = match self { - OutputFormat::Json => "json", - OutputFormat::Csv => "csv", - OutputFormat::Parquet => "parquet", - }; - write!(f, "{tag}") - } -} - -impl FromStr for OutputFormat { - type Err = std::io::Error; - - fn from_str(s: &str) -> Result { - match s.to_ascii_lowercase().as_str() { - "json" => Ok(Self::Json), - "csv" => Ok(Self::Csv), - "parquet" => Ok(Self::Parquet), - tag => Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Invalid {tag} tag, only [json, csv, polars] are supported",), - )), - } - } -} - /// Main simulation wrapper /// Pipes together the cli arguments with the execution #[derive(Parser)] @@ -80,88 +23,67 @@ pub struct SimulationApp { /// Json file path, on `SimulationSettings` format #[clap(long, short)] input_settings: PathBuf, - /// Output file path - #[clap(long, short)] - output_file: PathBuf, - /// Output format selector - #[clap(long, short = 'f', default_value_t)] - output_format: OutputFormat, + #[clap(long)] + stream_type: StreamType, } impl SimulationApp { pub fn run(self) -> anyhow::Result<()> { let Self { input_settings, - output_file, - output_format, + stream_type, } = self; - let simulation_settings: SimulationSettings<_, _> = load_json_from_file(&input_settings)?; + let nodes = vec![]; // TODO: Initialize nodes of different types. let regions_data = RegionsData::new(HashMap::new(), HashMap::new()); let network = Network::new(regions_data); - let mut simulation_runner: SimulationRunner<(), CarnotNode, TreeOverlay> = - SimulationRunner::new(network, nodes, simulation_settings); // build up series vector - let mut out_data: Vec = Vec::new(); - simulation_runner.simulate(Some(&mut out_data))?; - let mut dataframe: DataFrame = out_data_to_dataframe(out_data); - dump_dataframe_to(output_format, &mut dataframe, &output_file)?; + match stream_type { + simulations::streaming::StreamType::Naive => { + let simulation_settings: SimulationSettings<_, _, _> = + load_json_from_file(&input_settings)?; + let simulation_runner: SimulationRunner< + (), + CarnotNode, + TreeOverlay, + NaiveProducer, + > = SimulationRunner::new(network, nodes, simulation_settings); + simulation_runner.simulate()? + } + simulations::streaming::StreamType::Polars => { + let simulation_settings: SimulationSettings<_, _, _> = + load_json_from_file(&input_settings)?; + let simulation_runner: SimulationRunner< + (), + CarnotNode, + TreeOverlay, + PolarsProducer, + > = SimulationRunner::new(network, nodes, simulation_settings); + simulation_runner.simulate()? + } + simulations::streaming::StreamType::IO => { + let simulation_settings: SimulationSettings<_, _, _> = + load_json_from_file(&input_settings)?; + let simulation_runner: SimulationRunner< + (), + CarnotNode, + TreeOverlay, + IOProducer, + > = SimulationRunner::new(network, nodes, simulation_settings); + simulation_runner.simulate()? + } + }; Ok(()) } } -fn out_data_to_dataframe(out_data: Vec) -> DataFrame { - let mut cursor = Cursor::new(Vec::new()); - serde_json::to_writer(&mut cursor, &out_data).expect("Dump data to json "); - let dataframe = JsonReader::new(cursor) - .finish() - .expect("Load dataframe from intermediary json"); - - dataframe - .unnest(["state"]) - .expect("Node state should be unnest") -} - /// Generically load a json file fn load_json_from_file(path: &Path) -> anyhow::Result { let f = File::open(path).map_err(Box::new)?; Ok(serde_json::from_reader(f)?) } -fn dump_dataframe_to_json(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> { - let out_path = out_path.with_extension("json"); - let f = File::create(out_path)?; - let mut writer = polars::prelude::JsonWriter::new(f); - Ok(writer.finish(data)?) -} - -fn dump_dataframe_to_csv(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> { - let out_path = out_path.with_extension("csv"); - let f = File::create(out_path)?; - let mut writer = polars::prelude::CsvWriter::new(f); - Ok(writer.finish(data)?) -} - -fn dump_dataframe_to_parquet(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> { - let out_path = out_path.with_extension("parquet"); - let f = File::create(out_path)?; - let writer = polars::prelude::ParquetWriter::new(f); - Ok(writer.finish(data).map(|_| ())?) -} - -fn dump_dataframe_to( - output_format: OutputFormat, - data: &mut DataFrame, - out_path: &Path, -) -> anyhow::Result<()> { - match output_format { - OutputFormat::Json => dump_dataframe_to_json(data, out_path), - OutputFormat::Csv => dump_dataframe_to_csv(data, out_path), - OutputFormat::Parquet => dump_dataframe_to_parquet(data, out_path), - } -} - fn main() -> anyhow::Result<()> { let app: SimulationApp = SimulationApp::parse(); app.run()?; diff --git a/simulations/src/lib.rs b/simulations/src/lib.rs index 78e22ee9..59052263 100644 --- a/simulations/src/lib.rs +++ b/simulations/src/lib.rs @@ -4,4 +4,5 @@ pub mod output_processors; pub mod overlay; pub mod runner; pub mod settings; +pub mod streaming; pub mod warding; diff --git a/simulations/src/node/dummy_streaming.rs b/simulations/src/node/dummy_streaming.rs new file mode 100644 index 00000000..ba8bcbe5 --- /dev/null +++ b/simulations/src/node/dummy_streaming.rs @@ -0,0 +1,48 @@ +use serde::{Deserialize, Serialize}; + +use super::{Node, NodeId}; + +#[derive(Debug, Default, Copy, Clone, Serialize, Deserialize)] +pub struct DummyStreamingState { + pub current_view: usize, +} + +/// This node implementation only used for testing different streaming implementation purposes. +pub struct DummyStreamingNode { + id: NodeId, + state: DummyStreamingState, + #[allow(dead_code)] + settings: S, +} + +impl DummyStreamingNode { + pub fn new(id: NodeId, settings: S) -> Self { + Self { + id, + state: DummyStreamingState::default(), + settings, + } + } +} + +impl Node for DummyStreamingNode { + type Settings = S; + + type State = DummyStreamingState; + + fn id(&self) -> NodeId { + self.id + } + + fn current_view(&self) -> usize { + self.state.current_view + } + + fn state(&self) -> &Self::State { + &self.state + } + + fn step(&mut self) { + self.state.current_view += 1; + } +} diff --git a/simulations/src/node/mod.rs b/simulations/src/node/mod.rs index 95c68031..5964d9b2 100644 --- a/simulations/src/node/mod.rs +++ b/simulations/src/node/mod.rs @@ -1,6 +1,9 @@ pub mod carnot; pub mod dummy; +#[cfg(test)] +pub mod dummy_streaming; + // std use std::{ collections::BTreeMap, diff --git a/simulations/src/output_processors/mod.rs b/simulations/src/output_processors/mod.rs index b823c108..6956723c 100644 --- a/simulations/src/output_processors/mod.rs +++ b/simulations/src/output_processors/mod.rs @@ -1,5 +1,7 @@ use serde::Serialize; +use crate::warding::SimulationState; + pub type SerializedNodeState = serde_json::Value; #[derive(Serialize)] @@ -12,12 +14,12 @@ impl OutData { } } -impl TryFrom<&crate::warding::SimulationState> for OutData +impl TryFrom<&SimulationState> for OutData where N: crate::node::Node, N::State: Serialize, { - type Error = serde_json::Error; + type Error = anyhow::Error; fn try_from(state: &crate::warding::SimulationState) -> Result { serde_json::to_value( @@ -30,6 +32,7 @@ where .collect::>(), ) .map(OutData::new) + .map_err(From::from) } } diff --git a/simulations/src/runner/async_runner.rs b/simulations/src/runner/async_runner.rs index fe4a3d8b..749a235b 100644 --- a/simulations/src/runner/async_runner.rs +++ b/simulations/src/runner/async_runner.rs @@ -1,25 +1,30 @@ use crate::node::{Node, NodeId}; -use crate::output_processors::OutData; use crate::overlay::Overlay; -use crate::runner::SimulationRunner; +use crate::runner::{SimulationRunner, SimulationRunnerHandle}; +use crate::streaming::{Producer, Subscriber}; use crate::warding::SimulationState; +use crossbeam::channel::bounded; +use crossbeam::select; use rand::prelude::SliceRandom; use rayon::prelude::*; use serde::Serialize; use std::collections::HashSet; use std::sync::Arc; -pub fn simulate( - runner: &mut SimulationRunner, +/// Simulate with sending the network state to any subscriber +pub fn simulate( + runner: SimulationRunner, chunk_size: usize, - mut out_data: Option<&mut Vec>, -) -> anyhow::Result<()> +) -> anyhow::Result where - M: Send + Sync + Clone, - N::Settings: Clone, - N: Send + Sync, + M: Clone + Send + Sync + 'static, + N: Send + Sync + 'static, + N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone, + O::Settings: Clone + Send, + P::Subscriber: Send + Sync + 'static, + ::Record: + Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, { let simulation_state = SimulationState:: { nodes: Arc::clone(&runner.nodes), @@ -33,26 +38,51 @@ where .map(N::id) .collect(); - runner.dump_state_to_out_data(&simulation_state, &mut out_data)?; + let inner = runner.inner.clone(); + let nodes = runner.nodes.clone(); + let (stop_tx, stop_rx) = bounded(1); + let handle = SimulationRunnerHandle { + stop_tx, + handle: std::thread::spawn(move || { + let p = P::new(runner.stream_settings.settings)?; + scopeguard::defer!(if let Err(e) = p.stop() { + eprintln!("Error stopping producer: {e}"); + }); + let subscriber = p.subscribe()?; + std::thread::spawn(move || { + if let Err(e) = subscriber.run() { + eprintln!("Error in subscriber: {e}"); + } + }); + loop { + select! { + recv(stop_rx) -> _ => { + return Ok(()); + } + default => { + let mut inner = inner.write().expect("Write access to inner in async runner"); + node_ids.shuffle(&mut inner.rng); + for ids_chunk in node_ids.chunks(chunk_size) { + let ids: HashSet = ids_chunk.iter().copied().collect(); + nodes + .write() + .expect("Write access to nodes vector") + .par_iter_mut() + .filter(|n| ids.contains(&n.id())) + .for_each(N::step); - loop { - node_ids.shuffle(&mut runner.rng); - for ids_chunk in node_ids.chunks(chunk_size) { - let ids: HashSet = ids_chunk.iter().copied().collect(); - runner - .nodes - .write() - .expect("Write access to nodes vector") - .par_iter_mut() - .filter(|n| ids.contains(&n.id())) - .for_each(N::step); - - runner.dump_state_to_out_data(&simulation_state, &mut out_data)?; - } - // check if any condition makes the simulation stop - if runner.check_wards(&simulation_state) { - break; - } - } - Ok(()) + p.send(::Record::try_from( + &simulation_state, + )?)?; + } + // check if any condition makes the simulation stop + if inner.check_wards(&simulation_state) { + return Ok(()); + } + } + } + } + }), + }; + Ok(handle) } diff --git a/simulations/src/runner/glauber_runner.rs b/simulations/src/runner/glauber_runner.rs index ba7a7fb1..4e9a37fe 100644 --- a/simulations/src/runner/glauber_runner.rs +++ b/simulations/src/runner/glauber_runner.rs @@ -1,65 +1,100 @@ use crate::node::{Node, NodeId}; -use crate::output_processors::OutData; use crate::overlay::Overlay; -use crate::runner::SimulationRunner; +use crate::runner::{SimulationRunner, SimulationRunnerHandle}; +use crate::streaming::{Producer, Subscriber}; use crate::warding::SimulationState; +use crossbeam::channel::bounded; +use crossbeam::select; use rand::prelude::IteratorRandom; use serde::Serialize; use std::collections::BTreeSet; use std::sync::Arc; +/// Simulate with sending the network state to any subscriber. +/// /// [Glauber dynamics simulation](https://en.wikipedia.org/wiki/Glauber_dynamics) -pub fn simulate( - runner: &mut SimulationRunner, +pub fn simulate( + runner: SimulationRunner, update_rate: usize, maximum_iterations: usize, - mut out_data: Option<&mut Vec>, -) -> anyhow::Result<()> +) -> anyhow::Result where - M: Send + Sync + Clone, - N: Send + Sync, - N::Settings: Clone, + M: Send + Sync + Clone + 'static, + N: Send + Sync + 'static, + N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone, + O::Settings: Clone + Send, + P::Subscriber: Send + Sync + 'static, + ::Record: + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, { let simulation_state = SimulationState { nodes: Arc::clone(&runner.nodes), }; - let nodes_remaining: BTreeSet = (0..runner - .nodes - .read() - .expect("Read access to nodes vector") - .len()) - .map(From::from) - .collect(); + + let inner = runner.inner.clone(); + let nodes = runner.nodes.clone(); + let nodes_remaining: BTreeSet = + (0..nodes.read().expect("Read access to nodes vector").len()) + .map(From::from) + .collect(); let iterations: Vec<_> = (0..maximum_iterations).collect(); - 'main: for chunk in iterations.chunks(update_rate) { - for _ in chunk { - if nodes_remaining.is_empty() { - break 'main; - } + let (stop_tx, stop_rx) = bounded(1); + let handle = SimulationRunnerHandle { + handle: std::thread::spawn(move || { + let p = P::new(runner.stream_settings.settings)?; + scopeguard::defer!(if let Err(e) = p.stop() { + eprintln!("Error stopping producer: {e}"); + }); + let subscriber = p.subscribe()?; + std::thread::spawn(move || { + if let Err(e) = subscriber.run() { + eprintln!("Error in subscriber: {e}"); + } + }); - let node_id = *nodes_remaining.iter().choose(&mut runner.rng).expect( - "Some id to be selected as it should be impossible for the set to be empty here", - ); + let mut inner = inner.write().expect("Locking inner"); - { - let mut shared_nodes = runner.nodes.write().expect("Write access to nodes vector"); - let node: &mut N = shared_nodes - .get_mut(node_id.inner()) - .expect("Node should be present"); - node.step(); - } + 'main: for chunk in iterations.chunks(update_rate) { + select! { + recv(stop_rx) -> _ => break 'main, + default => { + for _ in chunk { + if nodes_remaining.is_empty() { + break 'main; + } - // check if any condition makes the simulation stop - if runner.check_wards(&simulation_state) { - // we break the outer main loop, so we need to dump it before the breaking - runner.dump_state_to_out_data(&simulation_state, &mut out_data)?; - break 'main; + let node_id = *nodes_remaining.iter().choose(&mut inner.rng).expect( + "Some id to be selected as it should be impossible for the set to be empty here", + ); + + { + let mut shared_nodes = nodes.write().expect("Write access to nodes vector"); + let node: &mut N = shared_nodes + .get_mut(node_id.inner()) + .expect("Node should be present"); + node.step(); + } + + // check if any condition makes the simulation stop + if inner.check_wards(&simulation_state) { + // we break the outer main loop, so we need to dump it before the breaking + p.send(::Record::try_from( + &simulation_state, + )?)?; + break 'main; + } + } + // update_rate iterations reached, so dump state + p.send(::Record::try_from( + &simulation_state, + )?)?; + } + } } - } - // update_rate iterations reached, so dump state - runner.dump_state_to_out_data(&simulation_state, &mut out_data)?; - } - Ok(()) + Ok(()) + }), + stop_tx, + }; + Ok(handle) } diff --git a/simulations/src/runner/layered_runner.rs b/simulations/src/runner/layered_runner.rs index 9f28ee5e..8820e538 100644 --- a/simulations/src/runner/layered_runner.rs +++ b/simulations/src/runner/layered_runner.rs @@ -28,6 +28,8 @@ //! the data of that step simulation. // std +use crossbeam::channel::bounded; +use crossbeam::select; use std::collections::BTreeSet; use std::ops::Not; use std::sync::Arc; @@ -38,76 +40,112 @@ use rand::rngs::SmallRng; use serde::Serialize; // internal use crate::node::{Node, NodeId}; -use crate::output_processors::OutData; use crate::overlay::Overlay; use crate::runner::SimulationRunner; +use crate::streaming::{Producer, Subscriber}; use crate::warding::SimulationState; -pub fn simulate( - runner: &mut SimulationRunner, +use super::SimulationRunnerHandle; + +/// Simulate with sending the network state to any subscriber +pub fn simulate( + runner: SimulationRunner, gap: usize, distribution: Option>, - mut out_data: Option<&mut Vec>, -) -> anyhow::Result<()> +) -> anyhow::Result where - M: Send + Sync + Clone, - N: Send + Sync, - N::Settings: Clone, + M: Send + Sync + Clone + 'static, + N: Send + Sync + 'static, + N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone, + O::Settings: Clone + Send, + P::Subscriber: Send + Sync + 'static, + ::Record: + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, { let distribution = distribution.unwrap_or_else(|| std::iter::repeat(1.0f32).take(gap).collect()); let layers: Vec = (0..gap).collect(); - let mut deque = build_node_ids_deque(gap, runner); + let mut deque = build_node_ids_deque(gap, &runner); let simulation_state = SimulationState { nodes: Arc::clone(&runner.nodes), }; - loop { - let (group_index, node_id) = - choose_random_layer_and_node_id(&mut runner.rng, &distribution, &layers, &mut deque); + let inner = runner.inner.clone(); + let nodes = runner.nodes.clone(); + let (stop_tx, stop_rx) = bounded(1); + let handle = SimulationRunnerHandle { + stop_tx, + handle: std::thread::spawn(move || { + let p = P::new(runner.stream_settings.settings)?; + scopeguard::defer!(if let Err(e) = p.stop() { + eprintln!("Error stopping producer: {e}"); + }); + let sub = p.subscribe()?; + std::thread::spawn(move || { + if let Err(e) = sub.run() { + eprintln!("Error running subscriber: {e}"); + } + }); + loop { + select! { + recv(stop_rx) -> _ => { + break; + } + default => { + let mut inner = inner.write().expect("Lock inner"); + let (group_index, node_id) = + choose_random_layer_and_node_id(&mut inner.rng, &distribution, &layers, &mut deque); - // remove node_id from group - deque.get_mut(group_index).unwrap().remove(&node_id); + // remove node_id from group + deque.get_mut(group_index).unwrap().remove(&node_id); - { - let mut shared_nodes = runner.nodes.write().expect("Write access to nodes vector"); - let node: &mut N = shared_nodes - .get_mut(node_id.inner()) - .expect("Node should be present"); - let prev_view = node.current_view(); - node.step(); - let after_view = node.current_view(); - if after_view > prev_view { - // pass node to next step group - deque.get_mut(group_index + 1).unwrap().insert(node_id); + { + let mut shared_nodes = nodes.write().expect("Write access to nodes vector"); + let node: &mut N = shared_nodes + .get_mut(node_id.inner()) + .expect("Node should be present"); + let prev_view = node.current_view(); + node.step(); + let after_view = node.current_view(); + if after_view > prev_view { + // pass node to next step group + deque.get_mut(group_index + 1).unwrap().insert(node_id); + } + } + + // check if any condition makes the simulation stop + if inner.check_wards(&simulation_state) { + break; + } + + // if initial is empty then we finished a full round, append a new set to the end so we can + // compute the most advanced nodes again + if deque.first().unwrap().is_empty() { + let _ = deque.push_back(BTreeSet::default()); + p.send(::Record::try_from( + &simulation_state, + )?)?; + } + + // if no more nodes to compute + if deque.iter().all(BTreeSet::is_empty) { + break; + } + } + } } - } - - // check if any condition makes the simulation stop - if runner.check_wards(&simulation_state) { - break; - } - - // if initial is empty then we finished a full round, append a new set to the end so we can - // compute the most advanced nodes again - if deque.first().unwrap().is_empty() { - let _ = deque.push_back(BTreeSet::default()); - runner.dump_state_to_out_data(&simulation_state, &mut out_data)?; - } - - // if no more nodes to compute - if deque.iter().all(BTreeSet::is_empty) { - break; - } - } - // write latest state - runner.dump_state_to_out_data(&simulation_state, &mut out_data)?; - Ok(()) + // write latest state + p.send(::Record::try_from( + &simulation_state, + )?)?; + Ok(()) + }), + }; + Ok(handle) } fn choose_random_layer_and_node_id( @@ -134,13 +172,14 @@ fn choose_random_layer_and_node_id( (i, *node_id) } -fn build_node_ids_deque( +fn build_node_ids_deque( gap: usize, - runner: &SimulationRunner, + runner: &SimulationRunner, ) -> FixedSliceDeque> where N: Node, O: Overlay, + P: Producer, { // add a +1 so we always have let mut deque = FixedSliceDeque::new(gap + 1); diff --git a/simulations/src/runner/mod.rs b/simulations/src/runner/mod.rs index 5ee8297b..d4169be4 100644 --- a/simulations/src/runner/mod.rs +++ b/simulations/src/runner/mod.rs @@ -9,6 +9,8 @@ use std::sync::{Arc, RwLock}; use std::time::Duration; // crates +use crate::streaming::{Producer, Subscriber}; +use crossbeam::channel::Sender; use rand::rngs::SmallRng; use rand::{RngCore, SeedableRng}; use rayon::prelude::*; @@ -17,37 +19,96 @@ use serde::Serialize; // internal use crate::network::Network; use crate::node::Node; -use crate::output_processors::OutData; use crate::overlay::Overlay; use crate::settings::{RunnerSettings, SimulationSettings}; -use crate::warding::{SimulationState, SimulationWard}; +use crate::streaming::StreamSettings; +use crate::warding::{SimulationState, SimulationWard, Ward}; + +pub struct SimulationRunnerHandle { + handle: std::thread::JoinHandle>, + stop_tx: Sender<()>, +} + +impl SimulationRunnerHandle { + pub fn stop_after(self, duration: Duration) -> anyhow::Result<()> { + std::thread::sleep(duration); + self.stop() + } + + pub fn stop(self) -> anyhow::Result<()> { + if !self.handle.is_finished() { + self.stop_tx.send(())?; + } + Ok(()) + } +} + +pub(crate) struct SimulationRunnerInner { + network: Network, + wards: Vec, + rng: SmallRng, +} + +impl SimulationRunnerInner +where + M: Send + Sync + Clone, +{ + fn check_wards(&mut self, state: &SimulationState) -> bool + where + N: Node + Send + Sync, + N::Settings: Clone + Send, + N::State: Serialize, + { + self.wards + .par_iter_mut() + .map(|ward| ward.analyze(state)) + .any(|x| x) + } + + fn step(&mut self, nodes: &mut Vec) + where + N: Node + Send + Sync, + N::Settings: Clone + Send, + N::State: Serialize, + { + self.network.dispatch_after(Duration::from_millis(100)); + nodes.par_iter_mut().for_each(|node| { + node.step(); + }); + self.network.collect_messages(); + } +} /// Encapsulation solution for the simulations runner /// Holds the network state, the simulating nodes and the simulation settings. -pub struct SimulationRunner +pub struct SimulationRunner where N: Node, O: Overlay, + P: Producer, { + inner: Arc>>, nodes: Arc>>, - network: Network, - settings: SimulationSettings, - rng: SmallRng, + runner_settings: RunnerSettings, + stream_settings: StreamSettings, _overlay: PhantomData, } -impl SimulationRunner +impl SimulationRunner where - M: Send + Sync + Clone, - N: Send + Sync, - N::Settings: Clone, + M: Clone + Send + Sync + 'static, + N: Send + Sync + 'static, + N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone, + O::Settings: Clone + Send, + P::Subscriber: Send + Sync + 'static, + ::Record: + Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, { pub fn new( network: Network, nodes: Vec, - settings: SimulationSettings, + settings: SimulationSettings, ) -> Self { let seed = settings .seed @@ -57,59 +118,43 @@ where let rng = SmallRng::seed_from_u64(seed); let nodes = Arc::new(RwLock::new(nodes)); - + let SimulationSettings { + network_behaviors: _, + regions: _, + wards, + overlay_settings: _, + node_settings: _, + runner_settings, + stream_settings, + node_count: _, + committee_size: _, + seed: _, + } = settings; Self { + stream_settings, + runner_settings, + inner: Arc::new(RwLock::new(SimulationRunnerInner { + network, + rng, + wards, + })), nodes, - network, - settings, - rng, - _overlay: Default::default(), + _overlay: PhantomData, } } - pub fn simulate(&mut self, out_data: Option<&mut Vec>) -> anyhow::Result<()> { - match self.settings.runner_settings.clone() { - RunnerSettings::Sync => sync_runner::simulate(self, out_data), - RunnerSettings::Async { chunks } => async_runner::simulate(self, chunks, out_data), + pub fn simulate(self) -> anyhow::Result { + match self.runner_settings.clone() { + RunnerSettings::Sync => sync_runner::simulate::<_, _, _, P>(self), + RunnerSettings::Async { chunks } => async_runner::simulate::<_, _, _, P>(self, chunks), RunnerSettings::Glauber { maximum_iterations, update_rate, - } => glauber_runner::simulate(self, update_rate, maximum_iterations, out_data), + } => glauber_runner::simulate::<_, _, _, P>(self, update_rate, maximum_iterations), RunnerSettings::Layered { rounds_gap, distribution, - } => layered_runner::simulate(self, rounds_gap, distribution, out_data), + } => layered_runner::simulate::<_, _, _, P>(self, rounds_gap, distribution), } } - - fn dump_state_to_out_data( - &self, - simulation_state: &SimulationState, - out_data: &mut Option<&mut Vec>, - ) -> anyhow::Result<()> { - if let Some(out_data) = out_data { - out_data.push(OutData::try_from(simulation_state)?); - } - Ok(()) - } - - fn check_wards(&mut self, state: &SimulationState) -> bool { - self.settings - .wards - .par_iter_mut() - .map(|ward| ward.analyze(state)) - .any(|x| x) - } - - fn step(&mut self) { - self.network.dispatch_after(Duration::from_millis(100)); - self.nodes - .write() - .expect("Single access to nodes vector") - .par_iter_mut() - .for_each(|node| { - node.step(); - }); - self.network.collect_messages(); - } } diff --git a/simulations/src/runner/sync_runner.rs b/simulations/src/runner/sync_runner.rs index 22387606..86bcc53a 100644 --- a/simulations/src/runner/sync_runner.rs +++ b/simulations/src/runner/sync_runner.rs @@ -1,39 +1,76 @@ use serde::Serialize; -use super::SimulationRunner; +use super::{SimulationRunner, SimulationRunnerHandle}; use crate::node::Node; -use crate::output_processors::OutData; use crate::overlay::Overlay; +use crate::streaming::{Producer, Subscriber}; use crate::warding::SimulationState; +use crossbeam::channel::{bounded, select}; use std::sync::Arc; -/// Simulate with option of dumping the network state as a `::polars::Series` -pub fn simulate( - runner: &mut SimulationRunner, - mut out_data: Option<&mut Vec>, -) -> anyhow::Result<()> +/// Simulate with sending the network state to any subscriber +pub fn simulate( + runner: SimulationRunner, +) -> anyhow::Result where - M: Send + Sync + Clone, - N: Send + Sync, - N::Settings: Clone, + M: Send + Sync + Clone + 'static, + N: Send + Sync + 'static, + N::Settings: Clone + Send, N::State: Serialize, O::Settings: Clone, + P::Subscriber: Send + Sync + 'static, + ::Record: + Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, { let state = SimulationState { nodes: Arc::clone(&runner.nodes), }; - runner.dump_state_to_out_data(&state, &mut out_data)?; + let inner = runner.inner.clone(); + let nodes = runner.nodes.clone(); - for _ in 1.. { - runner.step(); - runner.dump_state_to_out_data(&state, &mut out_data)?; - // check if any condition makes the simulation stop - if runner.check_wards(&state) { - break; - } - } - Ok(()) + let (stop_tx, stop_rx) = bounded(1); + let handle = SimulationRunnerHandle { + stop_tx, + handle: std::thread::spawn(move || { + let p = P::new(runner.stream_settings.settings)?; + scopeguard::defer!(if let Err(e) = p.stop() { + eprintln!("Error stopping producer: {e}"); + }); + let subscriber = p.subscribe()?; + std::thread::spawn(move || { + if let Err(e) = subscriber.run() { + eprintln!("Error in subscriber: {e}"); + } + }); + p.send(::Record::try_from(&state)?)?; + loop { + select! { + recv(stop_rx) -> _ => { + return Ok(()); + } + default => { + let mut inner = inner.write().expect("Write access to inner simulation state"); + + // we must use a code block to make sure once the step call is finished then the write lock will be released, because in Record::try_from(&state), + // we need to call the read lock, if we do not release the write lock, + // then dead lock will occur + { + let mut nodes = nodes.write().expect("Write access to nodes vector"); + inner.step(&mut nodes); + } + + p.send(::Record::try_from(&state).unwrap()).unwrap(); + // check if any condition makes the simulation stop + if inner.check_wards(&state) { + return Ok(()); + } + } + } + } + }), + }; + Ok(handle) } #[cfg(test)] @@ -48,12 +85,14 @@ mod tests { dummy::{DummyMessage, DummyNetworkInterface, DummyNode, DummySettings}, Node, NodeId, OverlayState, SharedState, ViewOverlay, }, + output_processors::OutData, overlay::{ tree::{TreeOverlay, TreeSettings}, Overlay, }, runner::SimulationRunner, settings::SimulationSettings, + streaming::naive::{NaiveProducer, NaiveSettings}, }; use crossbeam::channel; use rand::rngs::mock::StepRng; @@ -95,11 +134,12 @@ mod tests { #[test] fn runner_one_step() { - let settings: SimulationSettings = SimulationSettings { - node_count: 10, - committee_size: 1, - ..Default::default() - }; + let settings: SimulationSettings = + SimulationSettings { + node_count: 10, + committee_size: 1, + ..Default::default() + }; let mut rng = StepRng::new(1, 0); let node_ids: Vec = (0..settings.node_count).map(Into::into).collect(); @@ -115,9 +155,11 @@ mod tests { })); let nodes = init_dummy_nodes(&node_ids, &mut network, overlay_state); - let mut runner: SimulationRunner = + let runner: SimulationRunner> = SimulationRunner::new(network, nodes, settings); - runner.step(); + let mut nodes = runner.nodes.write().unwrap(); + runner.inner.write().unwrap().step(&mut nodes); + drop(nodes); let nodes = runner.nodes.read().unwrap(); for node in nodes.iter() { @@ -127,11 +169,12 @@ mod tests { #[test] fn runner_send_receive() { - let settings: SimulationSettings = SimulationSettings { - node_count: 10, - committee_size: 1, - ..Default::default() - }; + let settings: SimulationSettings = + SimulationSettings { + node_count: 10, + committee_size: 1, + ..Default::default() + }; let mut rng = StepRng::new(1, 0); let node_ids: Vec = (0..settings.node_count).map(Into::into).collect(); @@ -159,10 +202,12 @@ mod tests { } network.collect_messages(); - let mut runner: SimulationRunner = + let runner: SimulationRunner> = SimulationRunner::new(network, nodes, settings); - runner.step(); + let mut nodes = runner.nodes.write().unwrap(); + runner.inner.write().unwrap().step(&mut nodes); + drop(nodes); let nodes = runner.nodes.read().unwrap(); let state = nodes[1].state(); diff --git a/simulations/src/settings.rs b/simulations/src/settings.rs index ea673b84..3dba264b 100644 --- a/simulations/src/settings.rs +++ b/simulations/src/settings.rs @@ -1,5 +1,6 @@ use crate::network::regions::Region; use crate::node::StepTime; +use crate::streaming::StreamSettings; use crate::warding::Ward; use serde::Deserialize; use std::collections::HashMap; @@ -22,7 +23,7 @@ pub enum RunnerSettings { } #[derive(Default, Deserialize)] -pub struct SimulationSettings { +pub struct SimulationSettings { pub network_behaviors: HashMap<(Region, Region), StepTime>, pub regions: Vec, #[serde(default)] @@ -30,6 +31,7 @@ pub struct SimulationSettings { pub overlay_settings: O, pub node_settings: N, pub runner_settings: RunnerSettings, + pub stream_settings: StreamSettings

, pub node_count: usize, pub committee_size: usize, pub seed: Option, diff --git a/simulations/src/streaming/io.rs b/simulations/src/streaming/io.rs new file mode 100644 index 00000000..5d3b27cc --- /dev/null +++ b/simulations/src/streaming/io.rs @@ -0,0 +1,246 @@ +use std::sync::{Arc, Mutex}; + +use super::{Producer, Receivers, Subscriber}; +use arc_swap::ArcSwapOption; +use crossbeam::channel::{bounded, unbounded, Sender}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug)] +pub struct IOStreamSettings { + pub writer: W, +} + +impl Default for IOStreamSettings { + fn default() -> Self { + Self { + writer: std::io::stdout(), + } + } +} + +impl<'de> Deserialize<'de> for IOStreamSettings { + fn deserialize(_deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(Self { + writer: std::io::stdout(), + }) + } +} + +#[derive(Debug)] +pub struct IOProducer { + sender: Sender, + stop_tx: Sender<()>, + recvs: ArcSwapOption>, + writer: ArcSwapOption>, +} + +impl Producer for IOProducer +where + W: std::io::Write + Send + Sync + 'static, + R: Serialize + Send + Sync + 'static, +{ + type Settings = IOStreamSettings; + + type Subscriber = IOSubscriber; + + fn new(settings: Self::Settings) -> anyhow::Result + where + Self: Sized, + { + let (sender, recv) = unbounded(); + let (stop_tx, stop_rx) = bounded(1); + Ok(Self { + sender, + recvs: ArcSwapOption::from(Some(Arc::new(Receivers { stop_rx, recv }))), + stop_tx, + writer: ArcSwapOption::from(Some(Arc::new(Mutex::new(settings.writer)))), + }) + } + + fn send(&self, state: ::Record) -> anyhow::Result<()> { + self.sender.send(state).map_err(From::from) + } + + fn subscribe(&self) -> anyhow::Result + where + Self::Subscriber: Sized, + { + let recvs = self.recvs.load(); + if recvs.is_none() { + return Err(anyhow::anyhow!("Producer has been subscribed")); + } + + let recvs = self.recvs.swap(None).unwrap(); + let writer = self.writer.swap(None).unwrap(); + let this = IOSubscriber { recvs, writer }; + Ok(this) + } + + fn stop(&self) -> anyhow::Result<()> { + Ok(self.stop_tx.send(())?) + } +} + +#[derive(Debug)] +pub struct IOSubscriber { + recvs: Arc>, + writer: Arc>, +} + +impl Subscriber for IOSubscriber +where + W: std::io::Write + Send + Sync + 'static, + R: Serialize + Send + Sync + 'static, +{ + type Record = R; + + fn next(&self) -> Option> { + Some(self.recvs.recv.recv().map_err(From::from)) + } + + fn run(self) -> anyhow::Result<()> { + loop { + crossbeam::select! { + recv(self.recvs.stop_rx) -> _ => { + break; + } + recv(self.recvs.recv) -> msg => { + self.sink(msg?)?; + } + } + } + + Ok(()) + } + + fn sink(&self, state: Self::Record) -> anyhow::Result<()> { + serde_json::to_writer( + &mut *self + .writer + .lock() + .expect("fail to lock writer in io subscriber"), + &state, + )?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, time::Duration}; + + use crate::{ + network::{ + behaviour::NetworkBehaviour, + regions::{Region, RegionsData}, + Network, + }, + node::{dummy_streaming::DummyStreamingNode, Node, NodeId}, + overlay::tree::TreeOverlay, + runner::SimulationRunner, + streaming::{StreamSettings, StreamType}, + warding::SimulationState, + }; + + use super::*; + #[derive(Debug, Clone, Serialize)] + struct IORecord { + states: HashMap, + } + + impl TryFrom<&SimulationState>> for IORecord { + type Error = anyhow::Error; + + fn try_from(value: &SimulationState>) -> Result { + let nodes = value.nodes.read().expect("failed to read nodes"); + Ok(Self { + states: nodes + .iter() + .map(|node| (node.id(), node.current_view())) + .collect(), + }) + } + } + + #[test] + fn test_streaming() { + let simulation_settings = crate::settings::SimulationSettings { + seed: Some(1), + stream_settings: StreamSettings { + ty: StreamType::IO, + settings: IOStreamSettings { + writer: std::io::stdout(), + }, + }, + ..Default::default() + }; + + let nodes = (0..6) + .map(|idx| DummyStreamingNode::new(NodeId::from(idx), ())) + .collect::>(); + let network = Network::new(RegionsData { + regions: (0..6) + .map(|idx| { + let region = match idx % 6 { + 0 => Region::Europe, + 1 => Region::NorthAmerica, + 2 => Region::SouthAmerica, + 3 => Region::Asia, + 4 => Region::Africa, + 5 => Region::Australia, + _ => unreachable!(), + }; + (region, vec![idx.into()]) + }) + .collect(), + node_region: (0..6) + .map(|idx| { + let region = match idx % 6 { + 0 => Region::Europe, + 1 => Region::NorthAmerica, + 2 => Region::SouthAmerica, + 3 => Region::Asia, + 4 => Region::Africa, + 5 => Region::Australia, + _ => unreachable!(), + }; + (idx.into(), region) + }) + .collect(), + region_network_behaviour: (0..6) + .map(|idx| { + let region = match idx % 6 { + 0 => Region::Europe, + 1 => Region::NorthAmerica, + 2 => Region::SouthAmerica, + 3 => Region::Asia, + 4 => Region::Africa, + 5 => Region::Australia, + _ => unreachable!(), + }; + ( + (region, region), + NetworkBehaviour { + delay: Duration::from_millis(100), + drop: 0.0, + }, + ) + }) + .collect(), + }); + let simulation_runner: SimulationRunner< + (), + DummyStreamingNode<()>, + TreeOverlay, + IOProducer, + > = SimulationRunner::new(network, nodes, simulation_settings); + simulation_runner + .simulate() + .unwrap() + .stop_after(Duration::from_millis(100)) + .unwrap(); + } +} diff --git a/simulations/src/streaming/mod.rs b/simulations/src/streaming/mod.rs new file mode 100644 index 00000000..ad1c6b5a --- /dev/null +++ b/simulations/src/streaming/mod.rs @@ -0,0 +1,88 @@ +use std::str::FromStr; + +use crossbeam::channel::Receiver; +use serde::Serialize; + +pub mod io; +pub mod naive; +pub mod polars; + +#[derive(Debug)] +struct Receivers { + stop_rx: Receiver<()>, + recv: Receiver, +} + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize)] +pub enum StreamType { + #[default] + IO, + Naive, + Polars, +} + +impl FromStr for StreamType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.trim().to_ascii_lowercase().as_str() { + "naive" => Ok(Self::Naive), + "polars" => Ok(Self::Polars), + tag => Err(format!( + "Invalid {tag} streaming type, only [naive, polars] are supported", + )), + } + } +} + +impl<'de> serde::Deserialize<'de> for StreamType { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + StreamType::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[derive(Debug, Default, Clone, Serialize, serde::Deserialize)] +pub struct StreamSettings { + #[serde(rename = "type")] + pub ty: StreamType, + pub settings: S, +} + +pub trait Producer: Send + Sync + 'static { + type Settings: Send; + type Subscriber: Subscriber; + + fn new(settings: Self::Settings) -> anyhow::Result + where + Self: Sized; + + fn send(&self, state: ::Record) -> anyhow::Result<()>; + + fn subscribe(&self) -> anyhow::Result + where + Self::Subscriber: Sized; + + fn stop(&self) -> anyhow::Result<()>; +} + +pub trait Subscriber { + type Record: Serialize + Send + Sync + 'static; + + fn next(&self) -> Option>; + + fn run(self) -> anyhow::Result<()> + where + Self: Sized, + { + while let Some(state) = self.next() { + self.sink(state?)?; + } + Ok(()) + } + + fn sink(&self, state: Self::Record) -> anyhow::Result<()>; +} diff --git a/simulations/src/streaming/naive.rs b/simulations/src/streaming/naive.rs new file mode 100644 index 00000000..8b1c8ea3 --- /dev/null +++ b/simulations/src/streaming/naive.rs @@ -0,0 +1,237 @@ +use std::{ + fs::{File, OpenOptions}, + io::Write, + path::PathBuf, + sync::{Arc, Mutex}, +}; + +use super::{Producer, Receivers, Subscriber}; +use arc_swap::ArcSwapOption; +use crossbeam::channel::{bounded, unbounded, Sender}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NaiveSettings { + pub path: PathBuf, +} + +impl Default for NaiveSettings { + fn default() -> Self { + let mut tmp = std::env::temp_dir(); + tmp.push("simulation"); + tmp.set_extension("data"); + Self { path: tmp } + } +} + +#[derive(Debug)] +pub struct NaiveProducer { + sender: Sender, + stop_tx: Sender<()>, + recvs: ArcSwapOption>, + settings: NaiveSettings, +} + +impl Producer for NaiveProducer +where + R: Serialize + Send + Sync + 'static, +{ + type Settings = NaiveSettings; + + type Subscriber = NaiveSubscriber; + + fn new(settings: Self::Settings) -> anyhow::Result + where + Self: Sized, + { + let (sender, recv) = unbounded(); + let (stop_tx, stop_rx) = bounded(1); + Ok(Self { + sender, + recvs: ArcSwapOption::from(Some(Arc::new(Receivers { stop_rx, recv }))), + stop_tx, + settings, + }) + } + + fn send(&self, state: ::Record) -> anyhow::Result<()> { + self.sender.send(state).map_err(From::from) + } + + fn subscribe(&self) -> anyhow::Result + where + Self::Subscriber: Sized, + { + let recvs = self.recvs.load(); + if recvs.is_none() { + return Err(anyhow::anyhow!("Producer has been subscribed")); + } + + let mut opts = OpenOptions::new(); + let recvs = self.recvs.swap(None).unwrap(); + let this = NaiveSubscriber { + file: Arc::new(Mutex::new( + opts.truncate(true) + .create(true) + .read(true) + .write(true) + .open(&self.settings.path)?, + )), + recvs, + }; + eprintln!("Subscribed to {}", self.settings.path.display()); + Ok(this) + } + + fn stop(&self) -> anyhow::Result<()> { + Ok(self.stop_tx.send(())?) + } +} + +#[derive(Debug)] +pub struct NaiveSubscriber { + file: Arc>, + recvs: Arc>, +} + +impl Subscriber for NaiveSubscriber +where + R: Serialize + Send + Sync + 'static, +{ + type Record = R; + + fn next(&self) -> Option> { + Some(self.recvs.recv.recv().map_err(From::from)) + } + + fn run(self) -> anyhow::Result<()> { + loop { + crossbeam::select! { + recv(self.recvs.stop_rx) -> _ => { + break; + } + recv(self.recvs.recv) -> msg => { + self.sink(msg?)?; + } + } + } + + Ok(()) + } + + fn sink(&self, state: Self::Record) -> anyhow::Result<()> { + let mut file = self.file.lock().expect("failed to lock file"); + serde_json::to_writer(&mut *file, &state)?; + file.write_all(b",\n")?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, time::Duration}; + + use crate::{ + network::{ + behaviour::NetworkBehaviour, + regions::{Region, RegionsData}, + Network, + }, + node::{dummy_streaming::DummyStreamingNode, Node, NodeId}, + overlay::tree::TreeOverlay, + runner::SimulationRunner, + warding::SimulationState, + }; + + use super::*; + #[derive(Debug, Clone, Serialize)] + struct NaiveRecord { + states: HashMap, + } + + impl TryFrom<&SimulationState>> for NaiveRecord { + type Error = anyhow::Error; + + fn try_from(value: &SimulationState>) -> Result { + Ok(Self { + states: value + .nodes + .read() + .expect("failed to read nodes") + .iter() + .map(|node| (node.id(), node.current_view())) + .collect(), + }) + } + } + + #[test] + fn test_streaming() { + let simulation_settings = crate::settings::SimulationSettings { + seed: Some(1), + ..Default::default() + }; + + let nodes = (0..6) + .map(|idx| DummyStreamingNode::new(NodeId::from(idx), ())) + .collect::>(); + let network = Network::new(RegionsData { + regions: (0..6) + .map(|idx| { + let region = match idx % 6 { + 0 => Region::Europe, + 1 => Region::NorthAmerica, + 2 => Region::SouthAmerica, + 3 => Region::Asia, + 4 => Region::Africa, + 5 => Region::Australia, + _ => unreachable!(), + }; + (region, vec![idx.into()]) + }) + .collect(), + node_region: (0..6) + .map(|idx| { + let region = match idx % 6 { + 0 => Region::Europe, + 1 => Region::NorthAmerica, + 2 => Region::SouthAmerica, + 3 => Region::Asia, + 4 => Region::Africa, + 5 => Region::Australia, + _ => unreachable!(), + }; + (idx.into(), region) + }) + .collect(), + region_network_behaviour: (0..6) + .map(|idx| { + let region = match idx % 6 { + 0 => Region::Europe, + 1 => Region::NorthAmerica, + 2 => Region::SouthAmerica, + 3 => Region::Asia, + 4 => Region::Africa, + 5 => Region::Australia, + _ => unreachable!(), + }; + ( + (region, region), + NetworkBehaviour { + delay: Duration::from_millis(100), + drop: 0.0, + }, + ) + }) + .collect(), + }); + let simulation_runner: SimulationRunner< + (), + DummyStreamingNode<()>, + TreeOverlay, + NaiveProducer, + > = SimulationRunner::new(network, nodes, simulation_settings); + + simulation_runner.simulate().unwrap(); + } +} diff --git a/simulations/src/streaming/polars.rs b/simulations/src/streaming/polars.rs new file mode 100644 index 00000000..889a26b0 --- /dev/null +++ b/simulations/src/streaming/polars.rs @@ -0,0 +1,194 @@ +use arc_swap::ArcSwapOption; +use crossbeam::channel::{bounded, unbounded, Sender}; +use polars::prelude::*; +use serde::{Deserialize, Serialize}; +use std::{ + fs::File, + io::Cursor, + path::{Path, PathBuf}, + str::FromStr, + sync::Mutex, +}; + +use super::{Producer, Receivers, Subscriber}; + +#[derive(Debug, Clone, Copy, Serialize)] +pub enum PolarsFormat { + Json, + Csv, + Parquet, +} + +impl FromStr for PolarsFormat { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.trim().to_ascii_lowercase().as_str() { + "json" => Ok(Self::Json), + "csv" => Ok(Self::Csv), + "parquet" => Ok(Self::Parquet), + tag => Err(format!( + "Invalid {tag} format, only [json, csv, parquet] are supported", + )), + } + } +} + +impl<'de> Deserialize<'de> for PolarsFormat { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + PolarsFormat::from_str(&s).map_err(serde::de::Error::custom) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolarsSettings { + pub format: PolarsFormat, + pub path: PathBuf, +} + +#[derive(Debug)] +pub struct PolarsProducer { + sender: Sender, + stop_tx: Sender<()>, + recvs: ArcSwapOption>, + settings: PolarsSettings, +} + +impl Producer for PolarsProducer +where + R: Serialize + Send + Sync + 'static, +{ + type Settings = PolarsSettings; + + type Subscriber = PolarsSubscriber; + + fn new(settings: Self::Settings) -> anyhow::Result + where + Self: Sized, + { + let (sender, recv) = unbounded(); + let (stop_tx, stop_rx) = bounded(1); + Ok(Self { + sender, + recvs: ArcSwapOption::from(Some(Arc::new(Receivers { stop_rx, recv }))), + stop_tx, + settings, + }) + } + + fn send(&self, state: ::Record) -> anyhow::Result<()> { + self.sender.send(state).map_err(From::from) + } + + fn subscribe(&self) -> anyhow::Result + where + Self::Subscriber: Sized, + { + let recvs = self.recvs.load(); + if recvs.is_none() { + return Err(anyhow::anyhow!("Producer has been subscribed")); + } + + let recvs = self.recvs.swap(None).unwrap(); + let this = PolarsSubscriber { + data: Arc::new(Mutex::new(Vec::new())), + recvs, + path: self.settings.path.clone(), + format: self.settings.format, + }; + Ok(this) + } + + fn stop(&self) -> anyhow::Result<()> { + Ok(self.stop_tx.send(())?) + } +} + +#[derive(Debug)] +pub struct PolarsSubscriber { + data: Arc>>, + path: PathBuf, + format: PolarsFormat, + recvs: Arc>, +} + +impl PolarsSubscriber +where + R: Serialize, +{ + fn persist(&self) -> anyhow::Result<()> { + let data = self + .data + .lock() + .expect("failed to lock data in PolarsSubscriber pesist"); + let mut cursor = Cursor::new(Vec::new()); + serde_json::to_writer(&mut cursor, &*data).expect("Dump data to json "); + let mut data = JsonReader::new(cursor) + .finish() + .expect("Load dataframe from intermediary json"); + + data.unnest(["state"])?; + match self.format { + PolarsFormat::Json => dump_dataframe_to_json(&mut data, self.path.as_path()), + PolarsFormat::Csv => dump_dataframe_to_csv(&mut data, self.path.as_path()), + PolarsFormat::Parquet => dump_dataframe_to_parquet(&mut data, self.path.as_path()), + } + } +} + +impl super::Subscriber for PolarsSubscriber +where + R: Serialize + Send + Sync + 'static, +{ + type Record = R; + + fn next(&self) -> Option> { + Some(self.recvs.recv.recv().map_err(From::from)) + } + + fn run(self) -> anyhow::Result<()> { + loop { + crossbeam::select! { + recv(self.recvs.stop_rx) -> _ => { + return self.persist(); + } + recv(self.recvs.recv) -> msg => { + self.sink(msg?)?; + } + } + } + } + + fn sink(&self, state: Self::Record) -> anyhow::Result<()> { + self.data + .lock() + .expect("failed to lock data in PolarsSubscriber") + .push(state); + Ok(()) + } +} + +fn dump_dataframe_to_json(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> { + let out_path = out_path.with_extension("json"); + let f = File::create(out_path)?; + let mut writer = polars::prelude::JsonWriter::new(f); + Ok(writer.finish(data)?) +} + +fn dump_dataframe_to_csv(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> { + let out_path = out_path.with_extension("csv"); + let f = File::create(out_path)?; + let mut writer = polars::prelude::CsvWriter::new(f); + Ok(writer.finish(data)?) +} + +fn dump_dataframe_to_parquet(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> { + let out_path = out_path.with_extension("parquet"); + let f = File::create(out_path)?; + let writer = polars::prelude::ParquetWriter::new(f); + Ok(writer.finish(data).map(|_| ())?) +}