Simulation streaming and gracefully shutdown (#119)
* add stream supports * add test case * add polars stream, and force to use stream for the runner * using arcswap instead RefCell for producers * finish gracefully shutdown * - add IOProducer and IOSubscriber - fix deadlock in sync runner - fix testcases
This commit is contained in:
parent
b1381d727f
commit
ea7896f06c
@ -7,6 +7,7 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1"
|
||||
arc-swap = "1.6"
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
crc32fast = "1.3"
|
||||
crossbeam = { version = "0.8.2", features = ["crossbeam-channel"] }
|
||||
@ -15,9 +16,13 @@ nomos-core = { path = "../nomos-core" }
|
||||
polars = { version = "0.27", features = ["serde", "object", "json", "csv-file", "parquet", "dtype-struct"] }
|
||||
rand = { version = "0.8", features = ["small_rng"] }
|
||||
rayon = "1.7"
|
||||
scopeguard = "1"
|
||||
serde = { version = "1.0", features = ["derive", "rc"] }
|
||||
serde_with = "2.3"
|
||||
serde_json = "1.0"
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.4"
|
@ -1,78 +1,21 @@
|
||||
// std
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::fs::File;
|
||||
use std::io::Cursor;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::str::FromStr;
|
||||
// crates
|
||||
use clap::Parser;
|
||||
use polars::io::SerWriter;
|
||||
use polars::prelude::{DataFrame, JsonReader, SerReader};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use simulations::network::regions::RegionsData;
|
||||
use simulations::network::Network;
|
||||
use simulations::overlay::tree::TreeOverlay;
|
||||
use simulations::streaming::StreamType;
|
||||
// internal
|
||||
use simulations::{
|
||||
node::carnot::CarnotNode, output_processors::OutData, runner::SimulationRunner,
|
||||
settings::SimulationSettings,
|
||||
settings::SimulationSettings, streaming::io::IOProducer, streaming::naive::NaiveProducer,
|
||||
streaming::polars::PolarsProducer,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
enum OutputType {
|
||||
File(PathBuf),
|
||||
StdOut,
|
||||
StdErr,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for OutputType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
OutputType::File(path) => write!(f, "{}", path.display()),
|
||||
OutputType::StdOut => write!(f, "stdout"),
|
||||
OutputType::StdErr => write!(f, "stderr"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Output format selector enum
|
||||
#[derive(Clone, Debug, Default)]
|
||||
enum OutputFormat {
|
||||
Json,
|
||||
Csv,
|
||||
#[default]
|
||||
Parquet,
|
||||
}
|
||||
|
||||
impl Display for OutputFormat {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
let tag = match self {
|
||||
OutputFormat::Json => "json",
|
||||
OutputFormat::Csv => "csv",
|
||||
OutputFormat::Parquet => "parquet",
|
||||
};
|
||||
write!(f, "{tag}")
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for OutputFormat {
|
||||
type Err = std::io::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"json" => Ok(Self::Json),
|
||||
"csv" => Ok(Self::Csv),
|
||||
"parquet" => Ok(Self::Parquet),
|
||||
tag => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Invalid {tag} tag, only [json, csv, polars] are supported",),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Main simulation wrapper
|
||||
/// Pipes together the cli arguments with the execution
|
||||
#[derive(Parser)]
|
||||
@ -80,88 +23,67 @@ pub struct SimulationApp {
|
||||
/// Json file path, on `SimulationSettings` format
|
||||
#[clap(long, short)]
|
||||
input_settings: PathBuf,
|
||||
/// Output file path
|
||||
#[clap(long, short)]
|
||||
output_file: PathBuf,
|
||||
/// Output format selector
|
||||
#[clap(long, short = 'f', default_value_t)]
|
||||
output_format: OutputFormat,
|
||||
#[clap(long)]
|
||||
stream_type: StreamType,
|
||||
}
|
||||
|
||||
impl SimulationApp {
|
||||
pub fn run(self) -> anyhow::Result<()> {
|
||||
let Self {
|
||||
input_settings,
|
||||
output_file,
|
||||
output_format,
|
||||
stream_type,
|
||||
} = self;
|
||||
let simulation_settings: SimulationSettings<_, _> = load_json_from_file(&input_settings)?;
|
||||
|
||||
let nodes = vec![]; // TODO: Initialize nodes of different types.
|
||||
let regions_data = RegionsData::new(HashMap::new(), HashMap::new());
|
||||
let network = Network::new(regions_data);
|
||||
|
||||
let mut simulation_runner: SimulationRunner<(), CarnotNode, TreeOverlay> =
|
||||
SimulationRunner::new(network, nodes, simulation_settings);
|
||||
// build up series vector
|
||||
let mut out_data: Vec<OutData> = Vec::new();
|
||||
simulation_runner.simulate(Some(&mut out_data))?;
|
||||
let mut dataframe: DataFrame = out_data_to_dataframe(out_data);
|
||||
dump_dataframe_to(output_format, &mut dataframe, &output_file)?;
|
||||
match stream_type {
|
||||
simulations::streaming::StreamType::Naive => {
|
||||
let simulation_settings: SimulationSettings<_, _, _> =
|
||||
load_json_from_file(&input_settings)?;
|
||||
let simulation_runner: SimulationRunner<
|
||||
(),
|
||||
CarnotNode,
|
||||
TreeOverlay,
|
||||
NaiveProducer<OutData>,
|
||||
> = SimulationRunner::new(network, nodes, simulation_settings);
|
||||
simulation_runner.simulate()?
|
||||
}
|
||||
simulations::streaming::StreamType::Polars => {
|
||||
let simulation_settings: SimulationSettings<_, _, _> =
|
||||
load_json_from_file(&input_settings)?;
|
||||
let simulation_runner: SimulationRunner<
|
||||
(),
|
||||
CarnotNode,
|
||||
TreeOverlay,
|
||||
PolarsProducer<OutData>,
|
||||
> = SimulationRunner::new(network, nodes, simulation_settings);
|
||||
simulation_runner.simulate()?
|
||||
}
|
||||
simulations::streaming::StreamType::IO => {
|
||||
let simulation_settings: SimulationSettings<_, _, _> =
|
||||
load_json_from_file(&input_settings)?;
|
||||
let simulation_runner: SimulationRunner<
|
||||
(),
|
||||
CarnotNode,
|
||||
TreeOverlay,
|
||||
IOProducer<std::io::Stdout, OutData>,
|
||||
> = SimulationRunner::new(network, nodes, simulation_settings);
|
||||
simulation_runner.simulate()?
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn out_data_to_dataframe(out_data: Vec<OutData>) -> DataFrame {
|
||||
let mut cursor = Cursor::new(Vec::new());
|
||||
serde_json::to_writer(&mut cursor, &out_data).expect("Dump data to json ");
|
||||
let dataframe = JsonReader::new(cursor)
|
||||
.finish()
|
||||
.expect("Load dataframe from intermediary json");
|
||||
|
||||
dataframe
|
||||
.unnest(["state"])
|
||||
.expect("Node state should be unnest")
|
||||
}
|
||||
|
||||
/// Generically load a json file
|
||||
fn load_json_from_file<T: DeserializeOwned>(path: &Path) -> anyhow::Result<T> {
|
||||
let f = File::open(path).map_err(Box::new)?;
|
||||
Ok(serde_json::from_reader(f)?)
|
||||
}
|
||||
|
||||
fn dump_dataframe_to_json(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> {
|
||||
let out_path = out_path.with_extension("json");
|
||||
let f = File::create(out_path)?;
|
||||
let mut writer = polars::prelude::JsonWriter::new(f);
|
||||
Ok(writer.finish(data)?)
|
||||
}
|
||||
|
||||
fn dump_dataframe_to_csv(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> {
|
||||
let out_path = out_path.with_extension("csv");
|
||||
let f = File::create(out_path)?;
|
||||
let mut writer = polars::prelude::CsvWriter::new(f);
|
||||
Ok(writer.finish(data)?)
|
||||
}
|
||||
|
||||
fn dump_dataframe_to_parquet(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> {
|
||||
let out_path = out_path.with_extension("parquet");
|
||||
let f = File::create(out_path)?;
|
||||
let writer = polars::prelude::ParquetWriter::new(f);
|
||||
Ok(writer.finish(data).map(|_| ())?)
|
||||
}
|
||||
|
||||
fn dump_dataframe_to(
|
||||
output_format: OutputFormat,
|
||||
data: &mut DataFrame,
|
||||
out_path: &Path,
|
||||
) -> anyhow::Result<()> {
|
||||
match output_format {
|
||||
OutputFormat::Json => dump_dataframe_to_json(data, out_path),
|
||||
OutputFormat::Csv => dump_dataframe_to_csv(data, out_path),
|
||||
OutputFormat::Parquet => dump_dataframe_to_parquet(data, out_path),
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let app: SimulationApp = SimulationApp::parse();
|
||||
app.run()?;
|
||||
|
@ -4,4 +4,5 @@ pub mod output_processors;
|
||||
pub mod overlay;
|
||||
pub mod runner;
|
||||
pub mod settings;
|
||||
pub mod streaming;
|
||||
pub mod warding;
|
||||
|
48
simulations/src/node/dummy_streaming.rs
Normal file
48
simulations/src/node/dummy_streaming.rs
Normal file
@ -0,0 +1,48 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{Node, NodeId};
|
||||
|
||||
#[derive(Debug, Default, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct DummyStreamingState {
|
||||
pub current_view: usize,
|
||||
}
|
||||
|
||||
/// This node implementation only used for testing different streaming implementation purposes.
|
||||
pub struct DummyStreamingNode<S> {
|
||||
id: NodeId,
|
||||
state: DummyStreamingState,
|
||||
#[allow(dead_code)]
|
||||
settings: S,
|
||||
}
|
||||
|
||||
impl<S> DummyStreamingNode<S> {
|
||||
pub fn new(id: NodeId, settings: S) -> Self {
|
||||
Self {
|
||||
id,
|
||||
state: DummyStreamingState::default(),
|
||||
settings,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Node for DummyStreamingNode<S> {
|
||||
type Settings = S;
|
||||
|
||||
type State = DummyStreamingState;
|
||||
|
||||
fn id(&self) -> NodeId {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn current_view(&self) -> usize {
|
||||
self.state.current_view
|
||||
}
|
||||
|
||||
fn state(&self) -> &Self::State {
|
||||
&self.state
|
||||
}
|
||||
|
||||
fn step(&mut self) {
|
||||
self.state.current_view += 1;
|
||||
}
|
||||
}
|
@ -1,6 +1,9 @@
|
||||
pub mod carnot;
|
||||
pub mod dummy;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod dummy_streaming;
|
||||
|
||||
// std
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
|
@ -1,5 +1,7 @@
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::warding::SimulationState;
|
||||
|
||||
pub type SerializedNodeState = serde_json::Value;
|
||||
|
||||
#[derive(Serialize)]
|
||||
@ -12,12 +14,12 @@ impl OutData {
|
||||
}
|
||||
}
|
||||
|
||||
impl<N> TryFrom<&crate::warding::SimulationState<N>> for OutData
|
||||
impl<N> TryFrom<&SimulationState<N>> for OutData
|
||||
where
|
||||
N: crate::node::Node,
|
||||
N::State: Serialize,
|
||||
{
|
||||
type Error = serde_json::Error;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(state: &crate::warding::SimulationState<N>) -> Result<Self, Self::Error> {
|
||||
serde_json::to_value(
|
||||
@ -30,6 +32,7 @@ where
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.map(OutData::new)
|
||||
.map_err(From::from)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,25 +1,30 @@
|
||||
use crate::node::{Node, NodeId};
|
||||
use crate::output_processors::OutData;
|
||||
use crate::overlay::Overlay;
|
||||
use crate::runner::SimulationRunner;
|
||||
use crate::runner::{SimulationRunner, SimulationRunnerHandle};
|
||||
use crate::streaming::{Producer, Subscriber};
|
||||
use crate::warding::SimulationState;
|
||||
use crossbeam::channel::bounded;
|
||||
use crossbeam::select;
|
||||
use rand::prelude::SliceRandom;
|
||||
use rayon::prelude::*;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn simulate<M, N: Node, O: Overlay>(
|
||||
runner: &mut SimulationRunner<M, N, O>,
|
||||
/// Simulate with sending the network state to any subscriber
|
||||
pub fn simulate<M, N: Node, O: Overlay, P: Producer>(
|
||||
runner: SimulationRunner<M, N, O, P>,
|
||||
chunk_size: usize,
|
||||
mut out_data: Option<&mut Vec<OutData>>,
|
||||
) -> anyhow::Result<()>
|
||||
) -> anyhow::Result<SimulationRunnerHandle>
|
||||
where
|
||||
M: Send + Sync + Clone,
|
||||
N::Settings: Clone,
|
||||
N: Send + Sync,
|
||||
M: Clone + Send + Sync + 'static,
|
||||
N: Send + Sync + 'static,
|
||||
N::Settings: Clone + Send,
|
||||
N::State: Serialize,
|
||||
O::Settings: Clone,
|
||||
O::Settings: Clone + Send,
|
||||
P::Subscriber: Send + Sync + 'static,
|
||||
<P::Subscriber as Subscriber>::Record:
|
||||
Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState<N>, Error = anyhow::Error>,
|
||||
{
|
||||
let simulation_state = SimulationState::<N> {
|
||||
nodes: Arc::clone(&runner.nodes),
|
||||
@ -33,26 +38,51 @@ where
|
||||
.map(N::id)
|
||||
.collect();
|
||||
|
||||
runner.dump_state_to_out_data(&simulation_state, &mut out_data)?;
|
||||
let inner = runner.inner.clone();
|
||||
let nodes = runner.nodes.clone();
|
||||
let (stop_tx, stop_rx) = bounded(1);
|
||||
let handle = SimulationRunnerHandle {
|
||||
stop_tx,
|
||||
handle: std::thread::spawn(move || {
|
||||
let p = P::new(runner.stream_settings.settings)?;
|
||||
scopeguard::defer!(if let Err(e) = p.stop() {
|
||||
eprintln!("Error stopping producer: {e}");
|
||||
});
|
||||
let subscriber = p.subscribe()?;
|
||||
std::thread::spawn(move || {
|
||||
if let Err(e) = subscriber.run() {
|
||||
eprintln!("Error in subscriber: {e}");
|
||||
}
|
||||
});
|
||||
loop {
|
||||
select! {
|
||||
recv(stop_rx) -> _ => {
|
||||
return Ok(());
|
||||
}
|
||||
default => {
|
||||
let mut inner = inner.write().expect("Write access to inner in async runner");
|
||||
node_ids.shuffle(&mut inner.rng);
|
||||
for ids_chunk in node_ids.chunks(chunk_size) {
|
||||
let ids: HashSet<NodeId> = ids_chunk.iter().copied().collect();
|
||||
nodes
|
||||
.write()
|
||||
.expect("Write access to nodes vector")
|
||||
.par_iter_mut()
|
||||
.filter(|n| ids.contains(&n.id()))
|
||||
.for_each(N::step);
|
||||
|
||||
loop {
|
||||
node_ids.shuffle(&mut runner.rng);
|
||||
for ids_chunk in node_ids.chunks(chunk_size) {
|
||||
let ids: HashSet<NodeId> = ids_chunk.iter().copied().collect();
|
||||
runner
|
||||
.nodes
|
||||
.write()
|
||||
.expect("Write access to nodes vector")
|
||||
.par_iter_mut()
|
||||
.filter(|n| ids.contains(&n.id()))
|
||||
.for_each(N::step);
|
||||
|
||||
runner.dump_state_to_out_data(&simulation_state, &mut out_data)?;
|
||||
}
|
||||
// check if any condition makes the simulation stop
|
||||
if runner.check_wards(&simulation_state) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
p.send(<P::Subscriber as Subscriber>::Record::try_from(
|
||||
&simulation_state,
|
||||
)?)?;
|
||||
}
|
||||
// check if any condition makes the simulation stop
|
||||
if inner.check_wards(&simulation_state) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
};
|
||||
Ok(handle)
|
||||
}
|
||||
|
@ -1,65 +1,100 @@
|
||||
use crate::node::{Node, NodeId};
|
||||
use crate::output_processors::OutData;
|
||||
use crate::overlay::Overlay;
|
||||
use crate::runner::SimulationRunner;
|
||||
use crate::runner::{SimulationRunner, SimulationRunnerHandle};
|
||||
use crate::streaming::{Producer, Subscriber};
|
||||
use crate::warding::SimulationState;
|
||||
use crossbeam::channel::bounded;
|
||||
use crossbeam::select;
|
||||
use rand::prelude::IteratorRandom;
|
||||
use serde::Serialize;
|
||||
use std::collections::BTreeSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Simulate with sending the network state to any subscriber.
|
||||
///
|
||||
/// [Glauber dynamics simulation](https://en.wikipedia.org/wiki/Glauber_dynamics)
|
||||
pub fn simulate<M, N: Node, O: Overlay>(
|
||||
runner: &mut SimulationRunner<M, N, O>,
|
||||
pub fn simulate<M, N: Node, O: Overlay, P: Producer>(
|
||||
runner: SimulationRunner<M, N, O, P>,
|
||||
update_rate: usize,
|
||||
maximum_iterations: usize,
|
||||
mut out_data: Option<&mut Vec<OutData>>,
|
||||
) -> anyhow::Result<()>
|
||||
) -> anyhow::Result<SimulationRunnerHandle>
|
||||
where
|
||||
M: Send + Sync + Clone,
|
||||
N: Send + Sync,
|
||||
N::Settings: Clone,
|
||||
M: Send + Sync + Clone + 'static,
|
||||
N: Send + Sync + 'static,
|
||||
N::Settings: Clone + Send,
|
||||
N::State: Serialize,
|
||||
O::Settings: Clone,
|
||||
O::Settings: Clone + Send,
|
||||
P::Subscriber: Send + Sync + 'static,
|
||||
<P::Subscriber as Subscriber>::Record:
|
||||
for<'a> TryFrom<&'a SimulationState<N>, Error = anyhow::Error>,
|
||||
{
|
||||
let simulation_state = SimulationState {
|
||||
nodes: Arc::clone(&runner.nodes),
|
||||
};
|
||||
let nodes_remaining: BTreeSet<NodeId> = (0..runner
|
||||
.nodes
|
||||
.read()
|
||||
.expect("Read access to nodes vector")
|
||||
.len())
|
||||
.map(From::from)
|
||||
.collect();
|
||||
|
||||
let inner = runner.inner.clone();
|
||||
let nodes = runner.nodes.clone();
|
||||
let nodes_remaining: BTreeSet<NodeId> =
|
||||
(0..nodes.read().expect("Read access to nodes vector").len())
|
||||
.map(From::from)
|
||||
.collect();
|
||||
let iterations: Vec<_> = (0..maximum_iterations).collect();
|
||||
'main: for chunk in iterations.chunks(update_rate) {
|
||||
for _ in chunk {
|
||||
if nodes_remaining.is_empty() {
|
||||
break 'main;
|
||||
}
|
||||
let (stop_tx, stop_rx) = bounded(1);
|
||||
let handle = SimulationRunnerHandle {
|
||||
handle: std::thread::spawn(move || {
|
||||
let p = P::new(runner.stream_settings.settings)?;
|
||||
scopeguard::defer!(if let Err(e) = p.stop() {
|
||||
eprintln!("Error stopping producer: {e}");
|
||||
});
|
||||
let subscriber = p.subscribe()?;
|
||||
std::thread::spawn(move || {
|
||||
if let Err(e) = subscriber.run() {
|
||||
eprintln!("Error in subscriber: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
let node_id = *nodes_remaining.iter().choose(&mut runner.rng).expect(
|
||||
"Some id to be selected as it should be impossible for the set to be empty here",
|
||||
);
|
||||
let mut inner = inner.write().expect("Locking inner");
|
||||
|
||||
{
|
||||
let mut shared_nodes = runner.nodes.write().expect("Write access to nodes vector");
|
||||
let node: &mut N = shared_nodes
|
||||
.get_mut(node_id.inner())
|
||||
.expect("Node should be present");
|
||||
node.step();
|
||||
}
|
||||
'main: for chunk in iterations.chunks(update_rate) {
|
||||
select! {
|
||||
recv(stop_rx) -> _ => break 'main,
|
||||
default => {
|
||||
for _ in chunk {
|
||||
if nodes_remaining.is_empty() {
|
||||
break 'main;
|
||||
}
|
||||
|
||||
// check if any condition makes the simulation stop
|
||||
if runner.check_wards(&simulation_state) {
|
||||
// we break the outer main loop, so we need to dump it before the breaking
|
||||
runner.dump_state_to_out_data(&simulation_state, &mut out_data)?;
|
||||
break 'main;
|
||||
let node_id = *nodes_remaining.iter().choose(&mut inner.rng).expect(
|
||||
"Some id to be selected as it should be impossible for the set to be empty here",
|
||||
);
|
||||
|
||||
{
|
||||
let mut shared_nodes = nodes.write().expect("Write access to nodes vector");
|
||||
let node: &mut N = shared_nodes
|
||||
.get_mut(node_id.inner())
|
||||
.expect("Node should be present");
|
||||
node.step();
|
||||
}
|
||||
|
||||
// check if any condition makes the simulation stop
|
||||
if inner.check_wards(&simulation_state) {
|
||||
// we break the outer main loop, so we need to dump it before the breaking
|
||||
p.send(<P::Subscriber as Subscriber>::Record::try_from(
|
||||
&simulation_state,
|
||||
)?)?;
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
// update_rate iterations reached, so dump state
|
||||
p.send(<P::Subscriber as Subscriber>::Record::try_from(
|
||||
&simulation_state,
|
||||
)?)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// update_rate iterations reached, so dump state
|
||||
runner.dump_state_to_out_data(&simulation_state, &mut out_data)?;
|
||||
}
|
||||
Ok(())
|
||||
Ok(())
|
||||
}),
|
||||
stop_tx,
|
||||
};
|
||||
Ok(handle)
|
||||
}
|
||||
|
@ -28,6 +28,8 @@
|
||||
//! the data of that step simulation.
|
||||
|
||||
// std
|
||||
use crossbeam::channel::bounded;
|
||||
use crossbeam::select;
|
||||
use std::collections::BTreeSet;
|
||||
use std::ops::Not;
|
||||
use std::sync::Arc;
|
||||
@ -38,76 +40,112 @@ use rand::rngs::SmallRng;
|
||||
use serde::Serialize;
|
||||
// internal
|
||||
use crate::node::{Node, NodeId};
|
||||
use crate::output_processors::OutData;
|
||||
use crate::overlay::Overlay;
|
||||
use crate::runner::SimulationRunner;
|
||||
use crate::streaming::{Producer, Subscriber};
|
||||
use crate::warding::SimulationState;
|
||||
|
||||
pub fn simulate<M, N: Node, O: Overlay>(
|
||||
runner: &mut SimulationRunner<M, N, O>,
|
||||
use super::SimulationRunnerHandle;
|
||||
|
||||
/// Simulate with sending the network state to any subscriber
|
||||
pub fn simulate<M, N: Node, O: Overlay, P: Producer>(
|
||||
runner: SimulationRunner<M, N, O, P>,
|
||||
gap: usize,
|
||||
distribution: Option<Vec<f32>>,
|
||||
mut out_data: Option<&mut Vec<OutData>>,
|
||||
) -> anyhow::Result<()>
|
||||
) -> anyhow::Result<SimulationRunnerHandle>
|
||||
where
|
||||
M: Send + Sync + Clone,
|
||||
N: Send + Sync,
|
||||
N::Settings: Clone,
|
||||
M: Send + Sync + Clone + 'static,
|
||||
N: Send + Sync + 'static,
|
||||
N::Settings: Clone + Send,
|
||||
N::State: Serialize,
|
||||
O::Settings: Clone,
|
||||
O::Settings: Clone + Send,
|
||||
P::Subscriber: Send + Sync + 'static,
|
||||
<P::Subscriber as Subscriber>::Record:
|
||||
for<'a> TryFrom<&'a SimulationState<N>, Error = anyhow::Error>,
|
||||
{
|
||||
let distribution =
|
||||
distribution.unwrap_or_else(|| std::iter::repeat(1.0f32).take(gap).collect());
|
||||
|
||||
let layers: Vec<usize> = (0..gap).collect();
|
||||
|
||||
let mut deque = build_node_ids_deque(gap, runner);
|
||||
let mut deque = build_node_ids_deque(gap, &runner);
|
||||
|
||||
let simulation_state = SimulationState {
|
||||
nodes: Arc::clone(&runner.nodes),
|
||||
};
|
||||
|
||||
loop {
|
||||
let (group_index, node_id) =
|
||||
choose_random_layer_and_node_id(&mut runner.rng, &distribution, &layers, &mut deque);
|
||||
let inner = runner.inner.clone();
|
||||
let nodes = runner.nodes.clone();
|
||||
let (stop_tx, stop_rx) = bounded(1);
|
||||
let handle = SimulationRunnerHandle {
|
||||
stop_tx,
|
||||
handle: std::thread::spawn(move || {
|
||||
let p = P::new(runner.stream_settings.settings)?;
|
||||
scopeguard::defer!(if let Err(e) = p.stop() {
|
||||
eprintln!("Error stopping producer: {e}");
|
||||
});
|
||||
let sub = p.subscribe()?;
|
||||
std::thread::spawn(move || {
|
||||
if let Err(e) = sub.run() {
|
||||
eprintln!("Error running subscriber: {e}");
|
||||
}
|
||||
});
|
||||
loop {
|
||||
select! {
|
||||
recv(stop_rx) -> _ => {
|
||||
break;
|
||||
}
|
||||
default => {
|
||||
let mut inner = inner.write().expect("Lock inner");
|
||||
let (group_index, node_id) =
|
||||
choose_random_layer_and_node_id(&mut inner.rng, &distribution, &layers, &mut deque);
|
||||
|
||||
// remove node_id from group
|
||||
deque.get_mut(group_index).unwrap().remove(&node_id);
|
||||
// remove node_id from group
|
||||
deque.get_mut(group_index).unwrap().remove(&node_id);
|
||||
|
||||
{
|
||||
let mut shared_nodes = runner.nodes.write().expect("Write access to nodes vector");
|
||||
let node: &mut N = shared_nodes
|
||||
.get_mut(node_id.inner())
|
||||
.expect("Node should be present");
|
||||
let prev_view = node.current_view();
|
||||
node.step();
|
||||
let after_view = node.current_view();
|
||||
if after_view > prev_view {
|
||||
// pass node to next step group
|
||||
deque.get_mut(group_index + 1).unwrap().insert(node_id);
|
||||
{
|
||||
let mut shared_nodes = nodes.write().expect("Write access to nodes vector");
|
||||
let node: &mut N = shared_nodes
|
||||
.get_mut(node_id.inner())
|
||||
.expect("Node should be present");
|
||||
let prev_view = node.current_view();
|
||||
node.step();
|
||||
let after_view = node.current_view();
|
||||
if after_view > prev_view {
|
||||
// pass node to next step group
|
||||
deque.get_mut(group_index + 1).unwrap().insert(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
// check if any condition makes the simulation stop
|
||||
if inner.check_wards(&simulation_state) {
|
||||
break;
|
||||
}
|
||||
|
||||
// if initial is empty then we finished a full round, append a new set to the end so we can
|
||||
// compute the most advanced nodes again
|
||||
if deque.first().unwrap().is_empty() {
|
||||
let _ = deque.push_back(BTreeSet::default());
|
||||
p.send(<P::Subscriber as Subscriber>::Record::try_from(
|
||||
&simulation_state,
|
||||
)?)?;
|
||||
}
|
||||
|
||||
// if no more nodes to compute
|
||||
if deque.iter().all(BTreeSet::is_empty) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if any condition makes the simulation stop
|
||||
if runner.check_wards(&simulation_state) {
|
||||
break;
|
||||
}
|
||||
|
||||
// if initial is empty then we finished a full round, append a new set to the end so we can
|
||||
// compute the most advanced nodes again
|
||||
if deque.first().unwrap().is_empty() {
|
||||
let _ = deque.push_back(BTreeSet::default());
|
||||
runner.dump_state_to_out_data(&simulation_state, &mut out_data)?;
|
||||
}
|
||||
|
||||
// if no more nodes to compute
|
||||
if deque.iter().all(BTreeSet::is_empty) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// write latest state
|
||||
runner.dump_state_to_out_data(&simulation_state, &mut out_data)?;
|
||||
Ok(())
|
||||
// write latest state
|
||||
p.send(<P::Subscriber as Subscriber>::Record::try_from(
|
||||
&simulation_state,
|
||||
)?)?;
|
||||
Ok(())
|
||||
}),
|
||||
};
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
fn choose_random_layer_and_node_id(
|
||||
@ -134,13 +172,14 @@ fn choose_random_layer_and_node_id(
|
||||
(i, *node_id)
|
||||
}
|
||||
|
||||
fn build_node_ids_deque<M, N, O>(
|
||||
fn build_node_ids_deque<M, N, O, P>(
|
||||
gap: usize,
|
||||
runner: &SimulationRunner<M, N, O>,
|
||||
runner: &SimulationRunner<M, N, O, P>,
|
||||
) -> FixedSliceDeque<BTreeSet<NodeId>>
|
||||
where
|
||||
N: Node,
|
||||
O: Overlay,
|
||||
P: Producer,
|
||||
{
|
||||
// add a +1 so we always have
|
||||
let mut deque = FixedSliceDeque::new(gap + 1);
|
||||
|
@ -9,6 +9,8 @@ use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
// crates
|
||||
use crate::streaming::{Producer, Subscriber};
|
||||
use crossbeam::channel::Sender;
|
||||
use rand::rngs::SmallRng;
|
||||
use rand::{RngCore, SeedableRng};
|
||||
use rayon::prelude::*;
|
||||
@ -17,37 +19,96 @@ use serde::Serialize;
|
||||
// internal
|
||||
use crate::network::Network;
|
||||
use crate::node::Node;
|
||||
use crate::output_processors::OutData;
|
||||
use crate::overlay::Overlay;
|
||||
use crate::settings::{RunnerSettings, SimulationSettings};
|
||||
use crate::warding::{SimulationState, SimulationWard};
|
||||
use crate::streaming::StreamSettings;
|
||||
use crate::warding::{SimulationState, SimulationWard, Ward};
|
||||
|
||||
pub struct SimulationRunnerHandle {
|
||||
handle: std::thread::JoinHandle<anyhow::Result<()>>,
|
||||
stop_tx: Sender<()>,
|
||||
}
|
||||
|
||||
impl SimulationRunnerHandle {
|
||||
pub fn stop_after(self, duration: Duration) -> anyhow::Result<()> {
|
||||
std::thread::sleep(duration);
|
||||
self.stop()
|
||||
}
|
||||
|
||||
pub fn stop(self) -> anyhow::Result<()> {
|
||||
if !self.handle.is_finished() {
|
||||
self.stop_tx.send(())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct SimulationRunnerInner<M> {
|
||||
network: Network<M>,
|
||||
wards: Vec<Ward>,
|
||||
rng: SmallRng,
|
||||
}
|
||||
|
||||
impl<M> SimulationRunnerInner<M>
|
||||
where
|
||||
M: Send + Sync + Clone,
|
||||
{
|
||||
fn check_wards<N>(&mut self, state: &SimulationState<N>) -> bool
|
||||
where
|
||||
N: Node + Send + Sync,
|
||||
N::Settings: Clone + Send,
|
||||
N::State: Serialize,
|
||||
{
|
||||
self.wards
|
||||
.par_iter_mut()
|
||||
.map(|ward| ward.analyze(state))
|
||||
.any(|x| x)
|
||||
}
|
||||
|
||||
fn step<N>(&mut self, nodes: &mut Vec<N>)
|
||||
where
|
||||
N: Node + Send + Sync,
|
||||
N::Settings: Clone + Send,
|
||||
N::State: Serialize,
|
||||
{
|
||||
self.network.dispatch_after(Duration::from_millis(100));
|
||||
nodes.par_iter_mut().for_each(|node| {
|
||||
node.step();
|
||||
});
|
||||
self.network.collect_messages();
|
||||
}
|
||||
}
|
||||
|
||||
/// Encapsulation solution for the simulations runner
|
||||
/// Holds the network state, the simulating nodes and the simulation settings.
|
||||
pub struct SimulationRunner<M, N, O>
|
||||
pub struct SimulationRunner<M, N, O, P>
|
||||
where
|
||||
N: Node,
|
||||
O: Overlay,
|
||||
P: Producer,
|
||||
{
|
||||
inner: Arc<RwLock<SimulationRunnerInner<M>>>,
|
||||
nodes: Arc<RwLock<Vec<N>>>,
|
||||
network: Network<M>,
|
||||
settings: SimulationSettings<N::Settings, O::Settings>,
|
||||
rng: SmallRng,
|
||||
runner_settings: RunnerSettings,
|
||||
stream_settings: StreamSettings<P::Settings>,
|
||||
_overlay: PhantomData<O>,
|
||||
}
|
||||
|
||||
impl<M, N: Node, O: Overlay> SimulationRunner<M, N, O>
|
||||
impl<M, N: Node, O: Overlay, P: Producer> SimulationRunner<M, N, O, P>
|
||||
where
|
||||
M: Send + Sync + Clone,
|
||||
N: Send + Sync,
|
||||
N::Settings: Clone,
|
||||
M: Clone + Send + Sync + 'static,
|
||||
N: Send + Sync + 'static,
|
||||
N::Settings: Clone + Send,
|
||||
N::State: Serialize,
|
||||
O::Settings: Clone,
|
||||
O::Settings: Clone + Send,
|
||||
P::Subscriber: Send + Sync + 'static,
|
||||
<P::Subscriber as Subscriber>::Record:
|
||||
Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState<N>, Error = anyhow::Error>,
|
||||
{
|
||||
pub fn new(
|
||||
network: Network<M>,
|
||||
nodes: Vec<N>,
|
||||
settings: SimulationSettings<N::Settings, O::Settings>,
|
||||
settings: SimulationSettings<N::Settings, O::Settings, P::Settings>,
|
||||
) -> Self {
|
||||
let seed = settings
|
||||
.seed
|
||||
@ -57,59 +118,43 @@ where
|
||||
|
||||
let rng = SmallRng::seed_from_u64(seed);
|
||||
let nodes = Arc::new(RwLock::new(nodes));
|
||||
|
||||
let SimulationSettings {
|
||||
network_behaviors: _,
|
||||
regions: _,
|
||||
wards,
|
||||
overlay_settings: _,
|
||||
node_settings: _,
|
||||
runner_settings,
|
||||
stream_settings,
|
||||
node_count: _,
|
||||
committee_size: _,
|
||||
seed: _,
|
||||
} = settings;
|
||||
Self {
|
||||
stream_settings,
|
||||
runner_settings,
|
||||
inner: Arc::new(RwLock::new(SimulationRunnerInner {
|
||||
network,
|
||||
rng,
|
||||
wards,
|
||||
})),
|
||||
nodes,
|
||||
network,
|
||||
settings,
|
||||
rng,
|
||||
_overlay: Default::default(),
|
||||
_overlay: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn simulate(&mut self, out_data: Option<&mut Vec<OutData>>) -> anyhow::Result<()> {
|
||||
match self.settings.runner_settings.clone() {
|
||||
RunnerSettings::Sync => sync_runner::simulate(self, out_data),
|
||||
RunnerSettings::Async { chunks } => async_runner::simulate(self, chunks, out_data),
|
||||
pub fn simulate(self) -> anyhow::Result<SimulationRunnerHandle> {
|
||||
match self.runner_settings.clone() {
|
||||
RunnerSettings::Sync => sync_runner::simulate::<_, _, _, P>(self),
|
||||
RunnerSettings::Async { chunks } => async_runner::simulate::<_, _, _, P>(self, chunks),
|
||||
RunnerSettings::Glauber {
|
||||
maximum_iterations,
|
||||
update_rate,
|
||||
} => glauber_runner::simulate(self, update_rate, maximum_iterations, out_data),
|
||||
} => glauber_runner::simulate::<_, _, _, P>(self, update_rate, maximum_iterations),
|
||||
RunnerSettings::Layered {
|
||||
rounds_gap,
|
||||
distribution,
|
||||
} => layered_runner::simulate(self, rounds_gap, distribution, out_data),
|
||||
} => layered_runner::simulate::<_, _, _, P>(self, rounds_gap, distribution),
|
||||
}
|
||||
}
|
||||
|
||||
fn dump_state_to_out_data(
|
||||
&self,
|
||||
simulation_state: &SimulationState<N>,
|
||||
out_data: &mut Option<&mut Vec<OutData>>,
|
||||
) -> anyhow::Result<()> {
|
||||
if let Some(out_data) = out_data {
|
||||
out_data.push(OutData::try_from(simulation_state)?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_wards(&mut self, state: &SimulationState<N>) -> bool {
|
||||
self.settings
|
||||
.wards
|
||||
.par_iter_mut()
|
||||
.map(|ward| ward.analyze(state))
|
||||
.any(|x| x)
|
||||
}
|
||||
|
||||
fn step(&mut self) {
|
||||
self.network.dispatch_after(Duration::from_millis(100));
|
||||
self.nodes
|
||||
.write()
|
||||
.expect("Single access to nodes vector")
|
||||
.par_iter_mut()
|
||||
.for_each(|node| {
|
||||
node.step();
|
||||
});
|
||||
self.network.collect_messages();
|
||||
}
|
||||
}
|
||||
|
@ -1,39 +1,76 @@
|
||||
use serde::Serialize;
|
||||
|
||||
use super::SimulationRunner;
|
||||
use super::{SimulationRunner, SimulationRunnerHandle};
|
||||
use crate::node::Node;
|
||||
use crate::output_processors::OutData;
|
||||
use crate::overlay::Overlay;
|
||||
use crate::streaming::{Producer, Subscriber};
|
||||
use crate::warding::SimulationState;
|
||||
use crossbeam::channel::{bounded, select};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Simulate with option of dumping the network state as a `::polars::Series`
|
||||
pub fn simulate<M, N: Node, O: Overlay>(
|
||||
runner: &mut SimulationRunner<M, N, O>,
|
||||
mut out_data: Option<&mut Vec<OutData>>,
|
||||
) -> anyhow::Result<()>
|
||||
/// Simulate with sending the network state to any subscriber
|
||||
pub fn simulate<M, N: Node, O: Overlay, P: Producer>(
|
||||
runner: SimulationRunner<M, N, O, P>,
|
||||
) -> anyhow::Result<SimulationRunnerHandle>
|
||||
where
|
||||
M: Send + Sync + Clone,
|
||||
N: Send + Sync,
|
||||
N::Settings: Clone,
|
||||
M: Send + Sync + Clone + 'static,
|
||||
N: Send + Sync + 'static,
|
||||
N::Settings: Clone + Send,
|
||||
N::State: Serialize,
|
||||
O::Settings: Clone,
|
||||
P::Subscriber: Send + Sync + 'static,
|
||||
<P::Subscriber as Subscriber>::Record:
|
||||
Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState<N>, Error = anyhow::Error>,
|
||||
{
|
||||
let state = SimulationState {
|
||||
nodes: Arc::clone(&runner.nodes),
|
||||
};
|
||||
|
||||
runner.dump_state_to_out_data(&state, &mut out_data)?;
|
||||
let inner = runner.inner.clone();
|
||||
let nodes = runner.nodes.clone();
|
||||
|
||||
for _ in 1.. {
|
||||
runner.step();
|
||||
runner.dump_state_to_out_data(&state, &mut out_data)?;
|
||||
// check if any condition makes the simulation stop
|
||||
if runner.check_wards(&state) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
let (stop_tx, stop_rx) = bounded(1);
|
||||
let handle = SimulationRunnerHandle {
|
||||
stop_tx,
|
||||
handle: std::thread::spawn(move || {
|
||||
let p = P::new(runner.stream_settings.settings)?;
|
||||
scopeguard::defer!(if let Err(e) = p.stop() {
|
||||
eprintln!("Error stopping producer: {e}");
|
||||
});
|
||||
let subscriber = p.subscribe()?;
|
||||
std::thread::spawn(move || {
|
||||
if let Err(e) = subscriber.run() {
|
||||
eprintln!("Error in subscriber: {e}");
|
||||
}
|
||||
});
|
||||
p.send(<P::Subscriber as Subscriber>::Record::try_from(&state)?)?;
|
||||
loop {
|
||||
select! {
|
||||
recv(stop_rx) -> _ => {
|
||||
return Ok(());
|
||||
}
|
||||
default => {
|
||||
let mut inner = inner.write().expect("Write access to inner simulation state");
|
||||
|
||||
// we must use a code block to make sure once the step call is finished then the write lock will be released, because in Record::try_from(&state),
|
||||
// we need to call the read lock, if we do not release the write lock,
|
||||
// then dead lock will occur
|
||||
{
|
||||
let mut nodes = nodes.write().expect("Write access to nodes vector");
|
||||
inner.step(&mut nodes);
|
||||
}
|
||||
|
||||
p.send(<P::Subscriber as Subscriber>::Record::try_from(&state).unwrap()).unwrap();
|
||||
// check if any condition makes the simulation stop
|
||||
if inner.check_wards(&state) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
};
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -48,12 +85,14 @@ mod tests {
|
||||
dummy::{DummyMessage, DummyNetworkInterface, DummyNode, DummySettings},
|
||||
Node, NodeId, OverlayState, SharedState, ViewOverlay,
|
||||
},
|
||||
output_processors::OutData,
|
||||
overlay::{
|
||||
tree::{TreeOverlay, TreeSettings},
|
||||
Overlay,
|
||||
},
|
||||
runner::SimulationRunner,
|
||||
settings::SimulationSettings,
|
||||
streaming::naive::{NaiveProducer, NaiveSettings},
|
||||
};
|
||||
use crossbeam::channel;
|
||||
use rand::rngs::mock::StepRng;
|
||||
@ -95,11 +134,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn runner_one_step() {
|
||||
let settings: SimulationSettings<DummySettings, TreeSettings> = SimulationSettings {
|
||||
node_count: 10,
|
||||
committee_size: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let settings: SimulationSettings<DummySettings, TreeSettings, NaiveSettings> =
|
||||
SimulationSettings {
|
||||
node_count: 10,
|
||||
committee_size: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut rng = StepRng::new(1, 0);
|
||||
let node_ids: Vec<NodeId> = (0..settings.node_count).map(Into::into).collect();
|
||||
@ -115,9 +155,11 @@ mod tests {
|
||||
}));
|
||||
let nodes = init_dummy_nodes(&node_ids, &mut network, overlay_state);
|
||||
|
||||
let mut runner: SimulationRunner<DummyMessage, DummyNode, TreeOverlay> =
|
||||
let runner: SimulationRunner<DummyMessage, DummyNode, TreeOverlay, NaiveProducer<OutData>> =
|
||||
SimulationRunner::new(network, nodes, settings);
|
||||
runner.step();
|
||||
let mut nodes = runner.nodes.write().unwrap();
|
||||
runner.inner.write().unwrap().step(&mut nodes);
|
||||
drop(nodes);
|
||||
|
||||
let nodes = runner.nodes.read().unwrap();
|
||||
for node in nodes.iter() {
|
||||
@ -127,11 +169,12 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn runner_send_receive() {
|
||||
let settings: SimulationSettings<DummySettings, TreeSettings> = SimulationSettings {
|
||||
node_count: 10,
|
||||
committee_size: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let settings: SimulationSettings<DummySettings, TreeSettings, NaiveSettings> =
|
||||
SimulationSettings {
|
||||
node_count: 10,
|
||||
committee_size: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut rng = StepRng::new(1, 0);
|
||||
let node_ids: Vec<NodeId> = (0..settings.node_count).map(Into::into).collect();
|
||||
@ -159,10 +202,12 @@ mod tests {
|
||||
}
|
||||
network.collect_messages();
|
||||
|
||||
let mut runner: SimulationRunner<DummyMessage, DummyNode, TreeOverlay> =
|
||||
let runner: SimulationRunner<DummyMessage, DummyNode, TreeOverlay, NaiveProducer<OutData>> =
|
||||
SimulationRunner::new(network, nodes, settings);
|
||||
|
||||
runner.step();
|
||||
let mut nodes = runner.nodes.write().unwrap();
|
||||
runner.inner.write().unwrap().step(&mut nodes);
|
||||
drop(nodes);
|
||||
|
||||
let nodes = runner.nodes.read().unwrap();
|
||||
let state = nodes[1].state();
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::network::regions::Region;
|
||||
use crate::node::StepTime;
|
||||
use crate::streaming::StreamSettings;
|
||||
use crate::warding::Ward;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
@ -22,7 +23,7 @@ pub enum RunnerSettings {
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
pub struct SimulationSettings<N, O> {
|
||||
pub struct SimulationSettings<N, O, P> {
|
||||
pub network_behaviors: HashMap<(Region, Region), StepTime>,
|
||||
pub regions: Vec<Region>,
|
||||
#[serde(default)]
|
||||
@ -30,6 +31,7 @@ pub struct SimulationSettings<N, O> {
|
||||
pub overlay_settings: O,
|
||||
pub node_settings: N,
|
||||
pub runner_settings: RunnerSettings,
|
||||
pub stream_settings: StreamSettings<P>,
|
||||
pub node_count: usize,
|
||||
pub committee_size: usize,
|
||||
pub seed: Option<u64>,
|
||||
|
246
simulations/src/streaming/io.rs
Normal file
246
simulations/src/streaming/io.rs
Normal file
@ -0,0 +1,246 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::{Producer, Receivers, Subscriber};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use crossbeam::channel::{bounded, unbounded, Sender};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Settings for the IO streaming backend: just the sink to write to.
#[derive(Debug)]
pub struct IOStreamSettings<W = std::io::Stdout> {
    pub writer: W,
}

impl Default for IOStreamSettings {
    /// Defaults to writing to stdout.
    fn default() -> Self {
        Self {
            writer: std::io::stdout(),
        }
    }
}

impl<'de> Deserialize<'de> for IOStreamSettings {
    /// A live writer cannot be described in a config file, so
    /// deserialization ignores the input entirely and always yields stdout.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        Ok(Self {
            writer: std::io::stdout(),
        })
    }
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IOProducer<W, R> {
|
||||
sender: Sender<R>,
|
||||
stop_tx: Sender<()>,
|
||||
recvs: ArcSwapOption<Receivers<R>>,
|
||||
writer: ArcSwapOption<Mutex<W>>,
|
||||
}
|
||||
|
||||
impl<W, R> Producer for IOProducer<W, R>
|
||||
where
|
||||
W: std::io::Write + Send + Sync + 'static,
|
||||
R: Serialize + Send + Sync + 'static,
|
||||
{
|
||||
type Settings = IOStreamSettings<W>;
|
||||
|
||||
type Subscriber = IOSubscriber<W, R>;
|
||||
|
||||
fn new(settings: Self::Settings) -> anyhow::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let (sender, recv) = unbounded();
|
||||
let (stop_tx, stop_rx) = bounded(1);
|
||||
Ok(Self {
|
||||
sender,
|
||||
recvs: ArcSwapOption::from(Some(Arc::new(Receivers { stop_rx, recv }))),
|
||||
stop_tx,
|
||||
writer: ArcSwapOption::from(Some(Arc::new(Mutex::new(settings.writer)))),
|
||||
})
|
||||
}
|
||||
|
||||
fn send(&self, state: <Self::Subscriber as Subscriber>::Record) -> anyhow::Result<()> {
|
||||
self.sender.send(state).map_err(From::from)
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> anyhow::Result<Self::Subscriber>
|
||||
where
|
||||
Self::Subscriber: Sized,
|
||||
{
|
||||
let recvs = self.recvs.load();
|
||||
if recvs.is_none() {
|
||||
return Err(anyhow::anyhow!("Producer has been subscribed"));
|
||||
}
|
||||
|
||||
let recvs = self.recvs.swap(None).unwrap();
|
||||
let writer = self.writer.swap(None).unwrap();
|
||||
let this = IOSubscriber { recvs, writer };
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
fn stop(&self) -> anyhow::Result<()> {
|
||||
Ok(self.stop_tx.send(())?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IOSubscriber<W, R> {
|
||||
recvs: Arc<Receivers<R>>,
|
||||
writer: Arc<Mutex<W>>,
|
||||
}
|
||||
|
||||
impl<W, R> Subscriber for IOSubscriber<W, R>
|
||||
where
|
||||
W: std::io::Write + Send + Sync + 'static,
|
||||
R: Serialize + Send + Sync + 'static,
|
||||
{
|
||||
type Record = R;
|
||||
|
||||
fn next(&self) -> Option<anyhow::Result<Self::Record>> {
|
||||
Some(self.recvs.recv.recv().map_err(From::from))
|
||||
}
|
||||
|
||||
fn run(self) -> anyhow::Result<()> {
|
||||
loop {
|
||||
crossbeam::select! {
|
||||
recv(self.recvs.stop_rx) -> _ => {
|
||||
break;
|
||||
}
|
||||
recv(self.recvs.recv) -> msg => {
|
||||
self.sink(msg?)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sink(&self, state: Self::Record) -> anyhow::Result<()> {
|
||||
serde_json::to_writer(
|
||||
&mut *self
|
||||
.writer
|
||||
.lock()
|
||||
.expect("fail to lock writer in io subscriber"),
|
||||
&state,
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use crate::{
        network::{
            behaviour::NetworkBehaviour,
            regions::{Region, RegionsData},
            Network,
        },
        node::{dummy_streaming::DummyStreamingNode, Node, NodeId},
        overlay::tree::TreeOverlay,
        runner::SimulationRunner,
        streaming::{StreamSettings, StreamType},
        warding::SimulationState,
    };

    use super::*;

    /// Minimal record type: node id -> current view.
    #[derive(Debug, Clone, Serialize)]
    struct IORecord {
        states: HashMap<NodeId, usize>,
    }

    impl TryFrom<&SimulationState<DummyStreamingNode<()>>> for IORecord {
        type Error = anyhow::Error;

        fn try_from(value: &SimulationState<DummyStreamingNode<()>>) -> Result<Self, Self::Error> {
            let nodes = value.nodes.read().expect("failed to read nodes");
            Ok(Self {
                states: nodes
                    .iter()
                    .map(|node| (node.id(), node.current_view()))
                    .collect(),
            })
        }
    }

    /// Maps a node index onto one of the six regions, round-robin.
    /// (Previously this match was copy-pasted three times below.)
    fn region_of(idx: usize) -> Region {
        match idx % 6 {
            0 => Region::Europe,
            1 => Region::NorthAmerica,
            2 => Region::SouthAmerica,
            3 => Region::Asia,
            4 => Region::Africa,
            5 => Region::Australia,
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_streaming() {
        let simulation_settings = crate::settings::SimulationSettings {
            seed: Some(1),
            stream_settings: StreamSettings {
                ty: StreamType::IO,
                settings: IOStreamSettings {
                    writer: std::io::stdout(),
                },
            },
            ..Default::default()
        };

        let nodes = (0..6)
            .map(|idx| DummyStreamingNode::new(NodeId::from(idx), ()))
            .collect::<Vec<_>>();
        let network = Network::new(RegionsData {
            regions: (0..6)
                .map(|idx| (region_of(idx), vec![idx.into()]))
                .collect(),
            node_region: (0..6).map(|idx| (idx.into(), region_of(idx))).collect(),
            region_network_behaviour: (0..6)
                .map(|idx| {
                    (
                        (region_of(idx), region_of(idx)),
                        NetworkBehaviour {
                            delay: Duration::from_millis(100),
                            drop: 0.0,
                        },
                    )
                })
                .collect(),
        });
        let simulation_runner: SimulationRunner<
            (),
            DummyStreamingNode<()>,
            TreeOverlay,
            IOProducer<std::io::Stdout, IORecord>,
        > = SimulationRunner::new(network, nodes, simulation_settings);
        simulation_runner
            .simulate()
            .unwrap()
            .stop_after(Duration::from_millis(100))
            .unwrap();
    }
}
|
88
simulations/src/streaming/mod.rs
Normal file
88
simulations/src/streaming/mod.rs
Normal file
@ -0,0 +1,88 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use crossbeam::channel::Receiver;
|
||||
use serde::Serialize;
|
||||
|
||||
pub mod io;
|
||||
pub mod naive;
|
||||
pub mod polars;
|
||||
|
||||
/// Channel pair handed from a `Producer` to its single `Subscriber`:
/// `recv` carries simulation records, `stop_rx` signals graceful shutdown.
#[derive(Debug)]
struct Receivers<R> {
    // Fires once when the producer's `stop()` is called.
    stop_rx: Receiver<()>,
    // Stream of records pushed via `Producer::send`.
    recv: Receiver<R>,
}
|
||||
|
||||
/// Output backend used to stream simulation state.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum StreamType {
    /// Write records to an `std::io::Write` sink (stdout by default).
    #[default]
    IO,
    /// Append records as JSON documents to a plain file.
    Naive,
    /// Buffer records and dump them as a polars `DataFrame` on shutdown.
    Polars,
}
|
||||
|
||||
impl FromStr for StreamType {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.trim().to_ascii_lowercase().as_str() {
|
||||
"naive" => Ok(Self::Naive),
|
||||
"polars" => Ok(Self::Polars),
|
||||
tag => Err(format!(
|
||||
"Invalid {tag} streaming type, only [naive, polars] are supported",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> serde::Deserialize<'de> for StreamType {
    /// Deserializes the stream type from its string tag (e.g. "naive"),
    /// delegating validation to the `FromStr` implementation.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        StreamType::from_str(&s).map_err(serde::de::Error::custom)
    }
}
|
||||
|
||||
/// Pairs a [`StreamType`] tag with the backend-specific settings `S`
/// (e.g. `NaiveSettings`, `PolarsSettings`).
#[derive(Debug, Default, Clone, Serialize, serde::Deserialize)]
pub struct StreamSettings<S> {
    // Serialized as "type" to keep configuration files readable.
    #[serde(rename = "type")]
    pub ty: StreamType,
    pub settings: S,
}
|
||||
|
||||
/// A streaming backend: owns the channel that records are pushed into and
/// hands out exactly one `Subscriber` that drains it.
pub trait Producer: Send + Sync + 'static {
    /// Backend-specific configuration (paths, formats, writers, ...).
    type Settings: Send;
    /// The consumer side created by [`Producer::subscribe`].
    type Subscriber: Subscriber;

    /// Builds the producer (and its internal channels) from settings.
    fn new(settings: Self::Settings) -> anyhow::Result<Self>
    where
        Self: Sized;

    /// Pushes one record into the stream; fails if the receiving side
    /// has been dropped.
    fn send(&self, state: <Self::Subscriber as Subscriber>::Record) -> anyhow::Result<()>;

    /// Takes the receiving end of the stream. Implementations in this
    /// module return an error on the second and subsequent calls —
    /// there is a single subscriber per producer.
    fn subscribe(&self) -> anyhow::Result<Self::Subscriber>
    where
        Self::Subscriber: Sized;

    /// Signals the subscriber to finish (flush/persist) and exit its loop.
    fn stop(&self) -> anyhow::Result<()>;
}

/// Consumer side of a stream: receives records and sinks them somewhere.
pub trait Subscriber {
    /// The record type flowing through the stream.
    type Record: Serialize + Send + Sync + 'static;

    /// Blocks for the next record; `None` means the stream has ended.
    fn next(&self) -> Option<anyhow::Result<Self::Record>>;

    /// Default drain loop: sink every record until the channel closes.
    // NOTE(review): implementations in this module override this to also
    // honor the stop channel.
    fn run(self) -> anyhow::Result<()>
    where
        Self: Sized,
    {
        while let Some(state) = self.next() {
            self.sink(state?)?;
        }
        Ok(())
    }

    /// Persists/writes a single record.
    fn sink(&self, state: Self::Record) -> anyhow::Result<()>;
}
|
237
simulations/src/streaming/naive.rs
Normal file
237
simulations/src/streaming/naive.rs
Normal file
@ -0,0 +1,237 @@
|
||||
use std::{
|
||||
fs::{File, OpenOptions},
|
||||
io::Write,
|
||||
path::PathBuf,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use super::{Producer, Receivers, Subscriber};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use crossbeam::channel::{bounded, unbounded, Sender};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Settings for the naive file-streaming backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NaiveSettings {
    /// File that records are written to as JSON documents.
    pub path: PathBuf,
}

impl Default for NaiveSettings {
    /// Defaults to `<temp dir>/simulation.data`.
    fn default() -> Self {
        let mut tmp = std::env::temp_dir();
        tmp.push("simulation");
        tmp.set_extension("data");
        Self { path: tmp }
    }
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NaiveProducer<R> {
|
||||
sender: Sender<R>,
|
||||
stop_tx: Sender<()>,
|
||||
recvs: ArcSwapOption<Receivers<R>>,
|
||||
settings: NaiveSettings,
|
||||
}
|
||||
|
||||
impl<R> Producer for NaiveProducer<R>
|
||||
where
|
||||
R: Serialize + Send + Sync + 'static,
|
||||
{
|
||||
type Settings = NaiveSettings;
|
||||
|
||||
type Subscriber = NaiveSubscriber<R>;
|
||||
|
||||
fn new(settings: Self::Settings) -> anyhow::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let (sender, recv) = unbounded();
|
||||
let (stop_tx, stop_rx) = bounded(1);
|
||||
Ok(Self {
|
||||
sender,
|
||||
recvs: ArcSwapOption::from(Some(Arc::new(Receivers { stop_rx, recv }))),
|
||||
stop_tx,
|
||||
settings,
|
||||
})
|
||||
}
|
||||
|
||||
fn send(&self, state: <Self::Subscriber as Subscriber>::Record) -> anyhow::Result<()> {
|
||||
self.sender.send(state).map_err(From::from)
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> anyhow::Result<Self::Subscriber>
|
||||
where
|
||||
Self::Subscriber: Sized,
|
||||
{
|
||||
let recvs = self.recvs.load();
|
||||
if recvs.is_none() {
|
||||
return Err(anyhow::anyhow!("Producer has been subscribed"));
|
||||
}
|
||||
|
||||
let mut opts = OpenOptions::new();
|
||||
let recvs = self.recvs.swap(None).unwrap();
|
||||
let this = NaiveSubscriber {
|
||||
file: Arc::new(Mutex::new(
|
||||
opts.truncate(true)
|
||||
.create(true)
|
||||
.read(true)
|
||||
.write(true)
|
||||
.open(&self.settings.path)?,
|
||||
)),
|
||||
recvs,
|
||||
};
|
||||
eprintln!("Subscribed to {}", self.settings.path.display());
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
fn stop(&self) -> anyhow::Result<()> {
|
||||
Ok(self.stop_tx.send(())?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Drains records from the channel and appends them to the target file.
#[derive(Debug)]
pub struct NaiveSubscriber<R> {
    // Output file, truncated when the subscription is created.
    file: Arc<Mutex<File>>,
    recvs: Arc<Receivers<R>>,
}

impl<R> Subscriber for NaiveSubscriber<R>
where
    R: Serialize + Send + Sync + 'static,
{
    type Record = R;

    /// Blocks for the next record from the producer.
    fn next(&self) -> Option<anyhow::Result<Self::Record>> {
        Some(self.recvs.recv.recv().map_err(From::from))
    }

    /// Drains records until the producer signals stop.
    fn run(self) -> anyhow::Result<()> {
        loop {
            crossbeam::select! {
                recv(self.recvs.stop_rx) -> _ => {
                    break;
                }
                recv(self.recvs.recv) -> msg => {
                    self.sink(msg?)?;
                }
            }
        }

        Ok(())
    }

    /// Appends one record as JSON followed by ",\n".
    // NOTE(review): the resulting file is a ",\n"-separated sequence of
    // JSON documents (with a trailing comma), not a single valid JSON
    // array — confirm consumers expect this format.
    fn sink(&self, state: Self::Record) -> anyhow::Result<()> {
        let mut file = self.file.lock().expect("failed to lock file");
        serde_json::to_writer(&mut *file, &state)?;
        file.write_all(b",\n")?;
        Ok(())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use crate::{
        network::{
            behaviour::NetworkBehaviour,
            regions::{Region, RegionsData},
            Network,
        },
        node::{dummy_streaming::DummyStreamingNode, Node, NodeId},
        overlay::tree::TreeOverlay,
        runner::SimulationRunner,
        warding::SimulationState,
    };

    use super::*;

    /// Minimal record type: node id -> current view.
    #[derive(Debug, Clone, Serialize)]
    struct NaiveRecord {
        states: HashMap<NodeId, usize>,
    }

    impl TryFrom<&SimulationState<DummyStreamingNode<()>>> for NaiveRecord {
        type Error = anyhow::Error;

        fn try_from(value: &SimulationState<DummyStreamingNode<()>>) -> Result<Self, Self::Error> {
            Ok(Self {
                states: value
                    .nodes
                    .read()
                    .expect("failed to read nodes")
                    .iter()
                    .map(|node| (node.id(), node.current_view()))
                    .collect(),
            })
        }
    }

    /// Maps a node index onto one of the six regions, round-robin.
    /// (Previously this match was copy-pasted three times below.)
    fn region_of(idx: usize) -> Region {
        match idx % 6 {
            0 => Region::Europe,
            1 => Region::NorthAmerica,
            2 => Region::SouthAmerica,
            3 => Region::Asia,
            4 => Region::Africa,
            5 => Region::Australia,
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_streaming() {
        let simulation_settings = crate::settings::SimulationSettings {
            seed: Some(1),
            ..Default::default()
        };

        let nodes = (0..6)
            .map(|idx| DummyStreamingNode::new(NodeId::from(idx), ()))
            .collect::<Vec<_>>();
        let network = Network::new(RegionsData {
            regions: (0..6)
                .map(|idx| (region_of(idx), vec![idx.into()]))
                .collect(),
            node_region: (0..6).map(|idx| (idx.into(), region_of(idx))).collect(),
            region_network_behaviour: (0..6)
                .map(|idx| {
                    (
                        (region_of(idx), region_of(idx)),
                        NetworkBehaviour {
                            delay: Duration::from_millis(100),
                            drop: 0.0,
                        },
                    )
                })
                .collect(),
        });
        let simulation_runner: SimulationRunner<
            (),
            DummyStreamingNode<()>,
            TreeOverlay,
            NaiveProducer<NaiveRecord>,
        > = SimulationRunner::new(network, nodes, simulation_settings);

        simulation_runner.simulate().unwrap();
    }
}
|
194
simulations/src/streaming/polars.rs
Normal file
194
simulations/src/streaming/polars.rs
Normal file
@ -0,0 +1,194 @@
|
||||
use arc_swap::ArcSwapOption;
|
||||
use crossbeam::channel::{bounded, unbounded, Sender};
|
||||
use polars::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fs::File,
|
||||
io::Cursor,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
use super::{Producer, Receivers, Subscriber};
|
||||
|
||||
/// On-disk format for the polars dump performed at shutdown.
#[derive(Debug, Clone, Copy, Serialize)]
pub enum PolarsFormat {
    Json,
    Csv,
    Parquet,
}

impl FromStr for PolarsFormat {
    type Err = String;

    /// Parses a format tag, case-insensitively and ignoring surrounding
    /// whitespace.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.trim().to_ascii_lowercase().as_str() {
            "json" => Ok(Self::Json),
            "csv" => Ok(Self::Csv),
            "parquet" => Ok(Self::Parquet),
            tag => Err(format!(
                "Invalid {tag} format, only [json, csv, parquet] are supported",
            )),
        }
    }
}

impl<'de> Deserialize<'de> for PolarsFormat {
    /// Deserializes the format from its string tag, delegating to `FromStr`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        PolarsFormat::from_str(&s).map_err(serde::de::Error::custom)
    }
}
|
||||
|
||||
/// Settings for the polars streaming backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolarsSettings {
    /// Dump format; the matching extension is appended to `path`.
    pub format: PolarsFormat,
    /// Output path (extension is replaced per format).
    pub path: PathBuf,
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PolarsProducer<R> {
|
||||
sender: Sender<R>,
|
||||
stop_tx: Sender<()>,
|
||||
recvs: ArcSwapOption<Receivers<R>>,
|
||||
settings: PolarsSettings,
|
||||
}
|
||||
|
||||
impl<R> Producer for PolarsProducer<R>
|
||||
where
|
||||
R: Serialize + Send + Sync + 'static,
|
||||
{
|
||||
type Settings = PolarsSettings;
|
||||
|
||||
type Subscriber = PolarsSubscriber<R>;
|
||||
|
||||
fn new(settings: Self::Settings) -> anyhow::Result<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let (sender, recv) = unbounded();
|
||||
let (stop_tx, stop_rx) = bounded(1);
|
||||
Ok(Self {
|
||||
sender,
|
||||
recvs: ArcSwapOption::from(Some(Arc::new(Receivers { stop_rx, recv }))),
|
||||
stop_tx,
|
||||
settings,
|
||||
})
|
||||
}
|
||||
|
||||
fn send(&self, state: <Self::Subscriber as Subscriber>::Record) -> anyhow::Result<()> {
|
||||
self.sender.send(state).map_err(From::from)
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> anyhow::Result<Self::Subscriber>
|
||||
where
|
||||
Self::Subscriber: Sized,
|
||||
{
|
||||
let recvs = self.recvs.load();
|
||||
if recvs.is_none() {
|
||||
return Err(anyhow::anyhow!("Producer has been subscribed"));
|
||||
}
|
||||
|
||||
let recvs = self.recvs.swap(None).unwrap();
|
||||
let this = PolarsSubscriber {
|
||||
data: Arc::new(Mutex::new(Vec::new())),
|
||||
recvs,
|
||||
path: self.settings.path.clone(),
|
||||
format: self.settings.format,
|
||||
};
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
fn stop(&self) -> anyhow::Result<()> {
|
||||
Ok(self.stop_tx.send(())?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PolarsSubscriber<R> {
|
||||
data: Arc<Mutex<Vec<R>>>,
|
||||
path: PathBuf,
|
||||
format: PolarsFormat,
|
||||
recvs: Arc<Receivers<R>>,
|
||||
}
|
||||
|
||||
impl<R> PolarsSubscriber<R>
|
||||
where
|
||||
R: Serialize,
|
||||
{
|
||||
fn persist(&self) -> anyhow::Result<()> {
|
||||
let data = self
|
||||
.data
|
||||
.lock()
|
||||
.expect("failed to lock data in PolarsSubscriber pesist");
|
||||
let mut cursor = Cursor::new(Vec::new());
|
||||
serde_json::to_writer(&mut cursor, &*data).expect("Dump data to json ");
|
||||
let mut data = JsonReader::new(cursor)
|
||||
.finish()
|
||||
.expect("Load dataframe from intermediary json");
|
||||
|
||||
data.unnest(["state"])?;
|
||||
match self.format {
|
||||
PolarsFormat::Json => dump_dataframe_to_json(&mut data, self.path.as_path()),
|
||||
PolarsFormat::Csv => dump_dataframe_to_csv(&mut data, self.path.as_path()),
|
||||
PolarsFormat::Parquet => dump_dataframe_to_parquet(&mut data, self.path.as_path()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> super::Subscriber for PolarsSubscriber<R>
|
||||
where
|
||||
R: Serialize + Send + Sync + 'static,
|
||||
{
|
||||
type Record = R;
|
||||
|
||||
fn next(&self) -> Option<anyhow::Result<Self::Record>> {
|
||||
Some(self.recvs.recv.recv().map_err(From::from))
|
||||
}
|
||||
|
||||
fn run(self) -> anyhow::Result<()> {
|
||||
loop {
|
||||
crossbeam::select! {
|
||||
recv(self.recvs.stop_rx) -> _ => {
|
||||
return self.persist();
|
||||
}
|
||||
recv(self.recvs.recv) -> msg => {
|
||||
self.sink(msg?)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sink(&self, state: Self::Record) -> anyhow::Result<()> {
|
||||
self.data
|
||||
.lock()
|
||||
.expect("failed to lock data in PolarsSubscriber")
|
||||
.push(state);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn dump_dataframe_to_json(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> {
|
||||
let out_path = out_path.with_extension("json");
|
||||
let f = File::create(out_path)?;
|
||||
let mut writer = polars::prelude::JsonWriter::new(f);
|
||||
Ok(writer.finish(data)?)
|
||||
}
|
||||
|
||||
fn dump_dataframe_to_csv(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> {
|
||||
let out_path = out_path.with_extension("csv");
|
||||
let f = File::create(out_path)?;
|
||||
let mut writer = polars::prelude::CsvWriter::new(f);
|
||||
Ok(writer.finish(data)?)
|
||||
}
|
||||
|
||||
fn dump_dataframe_to_parquet(data: &mut DataFrame, out_path: &Path) -> anyhow::Result<()> {
|
||||
let out_path = out_path.with_extension("parquet");
|
||||
let f = File::create(out_path)?;
|
||||
let writer = polars::prelude::ParquetWriter::new(f);
|
||||
Ok(writer.finish(data).map(|_| ())?)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user