From 0bf6e6d272fae58ed6f2f1a6f69a645cfe9183b6 Mon Sep 17 00:00:00 2001 From: gusto Date: Mon, 8 May 2023 13:06:39 +0300 Subject: [PATCH] Simulation initialization (#122) * Use enums for different settings types * Enum for overlay settings * Configurable simulation overlay * Use duration type for network behaviour delays * Configurable simulation nodes * Runner for different node types * Seedable rng * Convert settings to required objects * Implement IOStreamSettings deserialization * Use common run method for different node types * Configuration for simapp * Testcase for region distribution * Use unix time if seed is not provided --- .gitignore | 4 + sim_config.json.example | 45 +++++++ simulations/src/bin/app.rs | 157 +++++++++++++++++------ simulations/src/network/behaviour.rs | 16 ++- simulations/src/network/mod.rs | 42 ++++++ simulations/src/network/regions.rs | 118 +++++++++++++++++ simulations/src/node/carnot/mod.rs | 12 +- simulations/src/node/dummy.rs | 74 ++++------- simulations/src/node/mod.rs | 16 ++- simulations/src/overlay/flat.rs | 15 ++- simulations/src/overlay/mod.rs | 80 +++++++++++- simulations/src/overlay/tree.rs | 14 +- simulations/src/runner/async_runner.rs | 8 +- simulations/src/runner/glauber_runner.rs | 8 +- simulations/src/runner/layered_runner.rs | 16 +-- simulations/src/runner/mod.rs | 50 +++----- simulations/src/runner/sync_runner.rs | 49 ++++--- simulations/src/settings.rs | 26 ++-- simulations/src/streaming/io.rs | 86 ++++++++----- simulations/src/streaming/mod.rs | 20 ++- simulations/src/streaming/naive.rs | 29 +++-- simulations/src/streaming/polars.rs | 16 ++- 22 files changed, 649 insertions(+), 252 deletions(-) create mode 100644 sim_config.json.example diff --git a/.gitignore b/.gitignore index 7fe3190a..e80810d3 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,9 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk + +# Files generated by build processes or applications config.yml store.* +sim_config.json +*.txt diff --git a/sim_config.json.example b/sim_config.json.example new file mode 100644 index 00000000..3b39d754 --- /dev/null +++ b/sim_config.json.example @@ -0,0 +1,45 @@ +{ + "wards": [ + { + "max_view": { + "max_view": 10 + } + }, + { + "min_max_view": { + "max_gap": 5 + } + }, + { + "stalled_view": { + "consecutive_viewed_checkpoint": 123456789, + "criterion": 1000, + "threshold": 5 + } + } + ], + "network_settings": { + "network_behaviors": {}, + "regions": { + "Europe": 0.3 + } + }, + "overlay_settings": { + "Tree": { + "tree_type": "FullBinaryTree", + "committee_size": 1, + "depth": 3 + } + }, + "node_settings": "Dummy", + "runner_settings": "Sync", + "stream_settings": { + "Naive": { + "path": "sim_naive_stream.txt" + } + }, + "node_count": 10, + "views_count": 20, + "leaders_count": 2, + "seed": 12345 +} diff --git a/simulations/src/bin/app.rs b/simulations/src/bin/app.rs index 89aa20d5..243ea0e6 100644 --- a/simulations/src/bin/app.rs +++ b/simulations/src/bin/app.rs @@ -1,13 +1,24 @@ // std -use std::collections::HashMap; +use anyhow::Ok; +use serde::Serialize; +use std::collections::BTreeMap; use std::fs::File; use std::path::{Path, PathBuf}; +use std::sync::{Arc, RwLock}; +use std::time::{SystemTime, UNIX_EPOCH}; // crates use clap::Parser; +use crossbeam::channel; +use rand::rngs::SmallRng; +use rand::seq::SliceRandom; +use rand::{Rng, SeedableRng}; use serde::de::DeserializeOwned; -use simulations::network::regions::RegionsData; -use simulations::network::Network; -use simulations::overlay::tree::TreeOverlay; +use simulations::network::behaviour::create_behaviours; +use simulations::network::regions::{create_regions, RegionsData}; +use simulations::network::{InMemoryNetworkInterface, Network}; +use simulations::node::dummy::DummyNode; +use simulations::node::{Node, NodeId, OverlayState, ViewOverlay}; +use simulations::overlay::{create_overlay, Overlay, SimulationOverlay}; use simulations::streaming::StreamType; // internal use simulations::{ @@ -33,57 +44,123 @@ impl SimulationApp { input_settings, stream_type, } = self; + let simulation_settings: SimulationSettings = load_json_from_file(&input_settings)?; - let nodes = vec![]; // TODO: Initialize nodes of different types. - let regions_data = RegionsData::new(HashMap::new(), HashMap::new()); - let network = Network::new(regions_data); + let seed = simulation_settings.seed.unwrap_or_else(|| { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs() + }); + let mut rng = SmallRng::seed_from_u64(seed); + let mut node_ids: Vec = (0..simulation_settings.node_count) + .map(Into::into) + .collect(); + node_ids.shuffle(&mut rng); - // build up series vector - match stream_type { - simulations::streaming::StreamType::Naive => { - let simulation_settings: SimulationSettings<_, _, _> = - load_json_from_file(&input_settings)?; - let simulation_runner: SimulationRunner< - (), - CarnotNode, - TreeOverlay, - NaiveProducer, - > = SimulationRunner::new(network, nodes, simulation_settings); - simulation_runner.simulate()? + let regions = create_regions(&node_ids, &mut rng, &simulation_settings.network_settings); + let behaviours = create_behaviours(&simulation_settings.network_settings); + let regions_data = RegionsData::new(regions, behaviours); + let overlay = create_overlay(&simulation_settings.overlay_settings); + let overlays = generate_overlays( + &node_ids, + &overlay, + simulation_settings.views_count, + simulation_settings.leaders_count, + &mut rng, + ); + + let overlay_state = Arc::new(RwLock::new(OverlayState { + all_nodes: node_ids.clone(), + overlay, + overlays, + })); + + let mut network = Network::new(regions_data); + + match &simulation_settings.node_settings { + simulations::settings::NodeSettings::Carnot => { + let nodes = node_ids + .iter() + .map(|node_id| CarnotNode::new(*node_id)) + .collect(); + run(network, nodes, simulation_settings, stream_type)?; } - simulations::streaming::StreamType::Polars => { - let simulation_settings: SimulationSettings<_, _, _> = - load_json_from_file(&input_settings)?; - let simulation_runner: SimulationRunner< - (), - CarnotNode, - TreeOverlay, - PolarsProducer, - > = SimulationRunner::new(network, nodes, simulation_settings); - simulation_runner.simulate()? - } - simulations::streaming::StreamType::IO => { - let simulation_settings: SimulationSettings<_, _, _> = - load_json_from_file(&input_settings)?; - let simulation_runner: SimulationRunner< - (), - CarnotNode, - TreeOverlay, - IOProducer, - > = SimulationRunner::new(network, nodes, simulation_settings); - simulation_runner.simulate()? + simulations::settings::NodeSettings::Dummy => { + let nodes = node_ids + .iter() + .map(|node_id| { + let (node_message_sender, node_message_receiver) = channel::unbounded(); + let network_message_receiver = + network.connect(*node_id, node_message_receiver); + let network_interface = InMemoryNetworkInterface::new( + *node_id, + node_message_sender, + network_message_receiver, + ); + DummyNode::new(*node_id, 0, overlay_state.clone(), network_interface) + }) + .collect(); + run(network, nodes, simulation_settings, stream_type)?; } }; Ok(()) } } +fn run( + network: Network, + nodes: Vec, + settings: SimulationSettings, + stream_type: StreamType, +) -> anyhow::Result<()> +where + M: Clone + Send + Sync + 'static, + N: Send + Sync + 'static, + N::Settings: Clone + Send, + N::State: Serialize, +{ + let runner = SimulationRunner::new(network, nodes, settings); + match stream_type { + simulations::streaming::StreamType::Naive => runner.simulate::>()?, + simulations::streaming::StreamType::Polars => { + runner.simulate::>()? + } + simulations::streaming::StreamType::IO => { + runner.simulate::>()? + } + }; + Ok(()) +} + /// Generically load a json file fn load_json_from_file(path: &Path) -> anyhow::Result { let f = File::open(path).map_err(Box::new)?; Ok(serde_json::from_reader(f)?) } +// Helper method to pregenerate views. +// TODO: Remove once shared overlay can generate new views on demand. +fn generate_overlays( + node_ids: &[NodeId], + overlay: &SimulationOverlay, + overlay_count: usize, + leader_count: usize, + rng: &mut R, +) -> BTreeMap { + (0..overlay_count) + .map(|view_id| { + ( + view_id, + ViewOverlay { + leaders: overlay.leaders(node_ids, leader_count, rng).collect(), + layout: overlay.layout(node_ids, rng), + }, + ) + }) + .collect() +} + fn main() -> anyhow::Result<()> { let app: SimulationApp = SimulationApp::parse(); app.run()?; diff --git a/simulations/src/network/behaviour.rs b/simulations/src/network/behaviour.rs index 577b50f6..a0606564 100644 --- a/simulations/src/network/behaviour.rs +++ b/simulations/src/network/behaviour.rs @@ -1,8 +1,10 @@ // std -use std::time::Duration; +use std::{collections::HashMap, time::Duration}; // crates use rand::Rng; use serde::{Deserialize, Serialize}; + +use super::{regions::Region, NetworkSettings}; // internal #[derive(Default, Debug, Clone, Serialize, Deserialize)] @@ -24,3 +26,15 @@ impl NetworkBehaviour { rng.gen_bool(self.drop) } } + +// Takes a reference to the simulation_settings and returns a HashMap representing the +// network behaviors for pairs of NodeIds. +pub fn create_behaviours( + network_settings: &NetworkSettings, +) -> HashMap<(Region, Region), NetworkBehaviour> { + network_settings + .network_behaviors + .iter() + .map(|((a, b), d)| ((*a, *b), NetworkBehaviour::new(*d, 0.0))) + .collect() +} diff --git a/simulations/src/network/mod.rs b/simulations/src/network/mod.rs index f0236858..943f9c26 100644 --- a/simulations/src/network/mod.rs +++ b/simulations/src/network/mod.rs @@ -8,6 +8,7 @@ use std::{ use crossbeam::channel::{self, Receiver, Sender}; use rand::{rngs::ThreadRng, Rng}; use rayon::prelude::*; +use serde::Deserialize; // internal use crate::node::NodeId; @@ -16,6 +17,14 @@ pub mod regions; type NetworkTime = Instant; +#[derive(Clone, Debug, Deserialize, Default)] +pub struct NetworkSettings { + pub network_behaviors: HashMap<(regions::Region, regions::Region), Duration>, + /// Represents node distribution in the simulated regions. + /// The sum of distributions should be 1. + pub regions: HashMap, +} + pub struct Network { pub regions: regions::RegionsData, network_time: NetworkTime, @@ -144,6 +153,39 @@ pub trait NetworkInterface { fn receive_messages(&self) -> Vec>; } +pub struct InMemoryNetworkInterface { + id: NodeId, + sender: Sender>, + receiver: Receiver>, +} + +impl InMemoryNetworkInterface { + pub fn new( + id: NodeId, + sender: Sender>, + receiver: Receiver>, + ) -> Self { + Self { + id, + sender, + receiver, + } + } +} + +impl NetworkInterface for InMemoryNetworkInterface { + type Payload = M; + + fn send_message(&self, address: NodeId, message: Self::Payload) { + let message = NetworkMessage::new(self.id, address, message); + self.sender.send(message).unwrap(); + } + + fn receive_messages(&self) -> Vec> { + self.receiver.try_iter().collect() + } +} + #[cfg(test)] mod tests { use super::{ diff --git a/simulations/src/network/regions.rs b/simulations/src/network/regions.rs index 54e093f6..887f93a9 100644 --- a/simulations/src/network/regions.rs +++ b/simulations/src/network/regions.rs @@ -1,10 +1,13 @@ // std +use rand::{seq::SliceRandom, Rng}; use std::collections::HashMap; // crates use serde::{Deserialize, Serialize}; // internal use crate::{network::behaviour::NetworkBehaviour, node::NodeId}; +use super::NetworkSettings; + #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] pub enum Region { NorthAmerica, @@ -56,3 +59,118 @@ impl RegionsData { &self.regions[®ion] } } + +// Takes a reference to the node_ids and simulation_settings and returns a HashMap +// representing the regions and their associated node IDs. +pub fn create_regions( + node_ids: &[NodeId], + rng: &mut R, + network_settings: &NetworkSettings, +) -> HashMap> { + let mut region_nodes = node_ids.to_vec(); + region_nodes.shuffle(rng); + + let regions = network_settings + .regions + .clone() + .into_iter() + .collect::>(); + + let last_region_index = regions.len() - 1; + + regions + .iter() + .enumerate() + .map(|(i, (region, distribution))| { + if i < last_region_index { + let node_count = (node_ids.len() as f32 * distribution).round() as usize; + let nodes = region_nodes.drain(..node_count).collect::>(); + (*region, nodes) + } else { + // Assign the remaining nodes to the last region. + (*region, region_nodes.clone()) + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use rand::rngs::mock::StepRng; + + use crate::{ + network::{ + regions::{create_regions, Region}, + NetworkSettings, + }, + node::NodeId, + }; + + #[test] + fn create_regions_precision() { + struct TestCase { + node_count: usize, + distributions: Vec, + } + + let test_cases = vec![ + TestCase { + node_count: 10, + distributions: vec![0.5, 0.3, 0.2], + }, + TestCase { + node_count: 7, + distributions: vec![0.6, 0.4], + }, + TestCase { + node_count: 20, + distributions: vec![0.4, 0.3, 0.2, 0.1], + }, + TestCase { + node_count: 23, + distributions: vec![0.4, 0.3, 0.3], + }, + TestCase { + node_count: 111, + distributions: vec![0.3, 0.3, 0.3, 0.1], + }, + TestCase { + node_count: 73, + distributions: vec![0.3, 0.2, 0.2, 0.2, 0.1], + }, + ]; + let mut rng = StepRng::new(1, 0); + + for tcase in test_cases.iter() { + let nodes = (0..tcase.node_count) + .map(Into::into) + .collect::>(); + + let available_regions = vec![ + Region::NorthAmerica, + Region::Europe, + Region::Asia, + Region::Africa, + Region::SouthAmerica, + Region::Australia, + ]; + + let mut region_distribution = HashMap::new(); + for (region, &dist) in available_regions.iter().zip(&tcase.distributions) { + region_distribution.insert(*region, dist); + } + + let settings = NetworkSettings { + network_behaviors: HashMap::new(), + regions: region_distribution, + }; + + let regions = create_regions(&nodes, &mut rng, &settings); + + let total_nodes_in_regions = regions.values().map(|v| v.len()).sum::(); + assert_eq!(total_nodes_in_regions, nodes.len()); + } + } +} diff --git a/simulations/src/node/carnot/mod.rs b/simulations/src/node/carnot/mod.rs index c44e4f95..c0078228 100644 --- a/simulations/src/node/carnot/mod.rs +++ b/simulations/src/node/carnot/mod.rs @@ -7,7 +7,7 @@ use super::{Node, NodeId}; #[derive(Default, Serialize)] pub struct CarnotState {} -#[derive(Clone, Deserialize)] +#[derive(Clone, Default, Deserialize)] pub struct CarnotSettings {} #[allow(dead_code)] // TODO: remove when handling settings @@ -17,6 +17,16 @@ pub struct CarnotNode { settings: CarnotSettings, } +impl CarnotNode { + pub fn new(id: NodeId) -> Self { + Self { + id, + state: Default::default(), + settings: Default::default(), + } + } +} + impl Node for CarnotNode { type Settings = CarnotSettings; type State = CarnotState; diff --git a/simulations/src/node/dummy.rs b/simulations/src/node/dummy.rs index ae7ec480..b6dcf2c4 100644 --- a/simulations/src/node/dummy.rs +++ b/simulations/src/node/dummy.rs @@ -1,11 +1,10 @@ // std use std::collections::{BTreeMap, BTreeSet}; // crates -use crossbeam::channel::{Receiver, Sender}; use serde::{Deserialize, Serialize}; // internal use crate::{ - network::{NetworkInterface, NetworkMessage}, + network::{InMemoryNetworkInterface, NetworkInterface, NetworkMessage}, node::{Node, NodeId}, }; @@ -131,7 +130,7 @@ pub struct DummyNode { state: DummyState, _settings: DummySettings, overlay_state: SharedState, - network_interface: DummyNetworkInterface, + network_interface: InMemoryNetworkInterface, local_view: LocalView, // Node in current view might be a leader in the next view. @@ -154,7 +153,7 @@ impl DummyNode { node_id: NodeId, view_id: usize, overlay_state: SharedState, - network_interface: DummyNetworkInterface, + network_interface: InMemoryNetworkInterface, ) -> Self { Self { node_id, @@ -373,39 +372,6 @@ impl Node for DummyNode { } } -pub struct DummyNetworkInterface { - id: NodeId, - sender: Sender>, - receiver: Receiver>, -} - -impl DummyNetworkInterface { - pub fn new( - id: NodeId, - sender: Sender>, - receiver: Receiver>, - ) -> Self { - Self { - id, - sender, - receiver, - } - } -} - -impl NetworkInterface for DummyNetworkInterface { - type Payload = DummyMessage; - - fn send_message(&self, address: NodeId, message: Self::Payload) { - let message = NetworkMessage::new(self.id, address, message); - self.sender.send(message).unwrap(); - } - - fn receive_messages(&self) -> Vec> { - self.receiver.try_iter().collect() - } -} - fn get_parent_nodes(node_id: NodeId, view: &ViewOverlay) -> Option> { let committee_id = view.layout.committee(node_id)?; view.layout.parent_nodes(committee_id).map(|c| c.nodes) @@ -468,11 +434,11 @@ mod tests { network::{ behaviour::NetworkBehaviour, regions::{Region, RegionsData}, - Network, + InMemoryNetworkInterface, Network, }, node::{ dummy::{get_child_nodes, get_parent_nodes, get_roles, DummyRole}, - Node, NodeId, OverlayState, SharedState, ViewOverlay, + Node, NodeId, OverlayState, SharedState, SimulationOverlay, ViewOverlay, }, overlay::{ tree::{TreeOverlay, TreeSettings}, @@ -480,7 +446,7 @@ mod tests { }, }; - use super::{DummyMessage, DummyNetworkInterface, DummyNode, Intent, Vote}; + use super::{DummyMessage, DummyNode, Intent, Vote}; fn init_network(node_ids: &[NodeId]) -> Network { let regions = HashMap::from([(Region::Europe, node_ids.to_vec())]); @@ -502,7 +468,7 @@ mod tests { .map(|node_id| { let (node_message_sender, node_message_receiver) = channel::unbounded(); let network_message_receiver = network.connect(*node_id, node_message_receiver); - let network_interface = DummyNetworkInterface::new( + let network_interface = InMemoryNetworkInterface::new( *node_id, node_message_sender, network_message_receiver, @@ -515,9 +481,9 @@ mod tests { .collect() } - fn generate_overlays( + fn generate_overlays( node_ids: &[NodeId], - overlay: O, + overlay: &SimulationOverlay, overlay_count: usize, leader_count: usize, rng: &mut R, @@ -576,6 +542,7 @@ mod tests { }; let overlay_state = Arc::new(RwLock::new(OverlayState { all_nodes: node_ids.clone(), + overlay: SimulationOverlay::Tree(overlay), overlays: BTreeMap::from([ (0, view.clone()), (1, view.clone()), @@ -710,19 +677,20 @@ mod tests { let mut rng = SmallRng::seed_from_u64(timestamp); let committee_size = 1; - let overlay = TreeOverlay::new(TreeSettings { + let overlay = SimulationOverlay::Tree(TreeOverlay::new(TreeSettings { tree_type: Default::default(), depth: 3, committee_size, - }); + })); // There are more nodes in the network than in a tree overlay. let node_ids: Vec = (0..100).map(Into::into).collect(); let mut network = init_network(&node_ids); - let overlays = generate_overlays(&node_ids, overlay, 4, 3, &mut rng); + let overlays = generate_overlays(&node_ids, &overlay, 4, 3, &mut rng); let overlay_state = Arc::new(RwLock::new(OverlayState { all_nodes: node_ids.clone(), + overlay, overlays: overlays.clone(), })); @@ -759,19 +727,20 @@ mod tests { let mut rng = SmallRng::seed_from_u64(timestamp); let committee_size = 100; - let overlay = TreeOverlay::new(TreeSettings { + let overlay = SimulationOverlay::Tree(TreeOverlay::new(TreeSettings { tree_type: Default::default(), depth: 3, committee_size, - }); + })); // There are more nodes in the network than in a tree overlay. let node_ids: Vec = (0..10000).map(Into::into).collect(); let mut network = init_network(&node_ids); - let overlays = generate_overlays(&node_ids, overlay, 4, 100, &mut rng); + let overlays = generate_overlays(&node_ids, &overlay, 4, 100, &mut rng); let overlay_state = Arc::new(RwLock::new(OverlayState { all_nodes: node_ids.clone(), + overlay, overlays: overlays.clone(), })); @@ -808,19 +777,20 @@ mod tests { let mut rng = SmallRng::seed_from_u64(timestamp); let committee_size = 500; - let overlay = TreeOverlay::new(TreeSettings { + let overlay = SimulationOverlay::Tree(TreeOverlay::new(TreeSettings { tree_type: Default::default(), depth: 5, committee_size, - }); + })); // There are more nodes in the network than in a tree overlay. let node_ids: Vec = (0..100000).map(Into::into).collect(); let mut network = init_network(&node_ids); - let overlays = generate_overlays(&node_ids, overlay, 4, 1000, &mut rng); + let overlays = generate_overlays(&node_ids, &overlay, 4, 1000, &mut rng); let overlay_state = Arc::new(RwLock::new(OverlayState { all_nodes: node_ids.clone(), + overlay, overlays: overlays.clone(), })); diff --git a/simulations/src/node/mod.rs b/simulations/src/node/mod.rs index 5964d9b2..83fcc3d6 100644 --- a/simulations/src/node/mod.rs +++ b/simulations/src/node/mod.rs @@ -14,7 +14,7 @@ use std::{ // crates use serde::{Deserialize, Serialize}; // internal -use crate::overlay::Layout; +use crate::overlay::{Layout, OverlaySettings, SimulationOverlay}; #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] #[serde(transparent)] @@ -126,11 +126,25 @@ pub struct ViewOverlay { pub layout: Layout, } +impl From for ViewOverlay { + fn from(value: OverlaySettings) -> Self { + match value { + OverlaySettings::Flat => { + todo!() + } + OverlaySettings::Tree(_) => { + todo!() + } + } + } +} + pub type SharedState = Arc>; /// A state that represents how nodes are interconnected in the network. pub struct OverlayState { pub all_nodes: Vec, + pub overlay: SimulationOverlay, pub overlays: BTreeMap, } diff --git a/simulations/src/overlay/flat.rs b/simulations/src/overlay/flat.rs index 7cd2e27a..52387b67 100644 --- a/simulations/src/overlay/flat.rs +++ b/simulations/src/overlay/flat.rs @@ -8,14 +8,19 @@ use crate::node::NodeId; use crate::overlay::{Committee, Layout}; pub struct FlatOverlay; - -impl Overlay for FlatOverlay { - type Settings = (); - - fn new(_settings: Self::Settings) -> Self { +impl FlatOverlay { + pub fn new() -> Self { Self } +} +impl Default for FlatOverlay { + fn default() -> Self { + Self::new() + } +} + +impl Overlay for FlatOverlay { fn nodes(&self) -> Vec { (0..10).map(NodeId::from).collect() } diff --git a/simulations/src/overlay/mod.rs b/simulations/src/overlay/mod.rs index e8b0c60b..11abd72c 100644 --- a/simulations/src/overlay/mod.rs +++ b/simulations/src/overlay/mod.rs @@ -5,9 +5,12 @@ pub mod tree; use std::collections::{BTreeSet, HashMap}; // crates use rand::Rng; +use serde::Deserialize; // internal use crate::node::{CommitteeId, NodeId}; +use self::tree::TreeSettings; + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Committee { pub nodes: BTreeSet, @@ -89,10 +92,70 @@ impl Layout { } } -pub trait Overlay { - type Settings; +pub enum SimulationOverlay { + Flat(flat::FlatOverlay), + Tree(tree::TreeOverlay), +} - fn new(settings: Self::Settings) -> Self; +#[derive(Clone, Debug, Deserialize)] +pub enum OverlaySettings { + Flat, + Tree(TreeSettings), +} + +impl Default for OverlaySettings { + fn default() -> Self { + Self::Tree(Default::default()) + } +} + +impl From for OverlaySettings { + fn from(settings: TreeSettings) -> OverlaySettings { + OverlaySettings::Tree(settings) + } +} + +impl TryInto for OverlaySettings { + type Error = String; + + fn try_into(self) -> Result { + if let Self::Tree(settings) = self { + Ok(settings) + } else { + Err("unable to convert to tree settings".into()) + } + } +} + +impl Overlay for SimulationOverlay { + fn nodes(&self) -> Vec { + match self { + SimulationOverlay::Flat(overlay) => overlay.nodes(), + SimulationOverlay::Tree(overlay) => overlay.nodes(), + } + } + + fn leaders( + &self, + nodes: &[NodeId], + size: usize, + rng: &mut R, + ) -> Box> { + match self { + SimulationOverlay::Flat(overlay) => overlay.leaders(nodes, size, rng), + SimulationOverlay::Tree(overlay) => overlay.leaders(nodes, size, rng), + } + } + + fn layout(&self, nodes: &[NodeId], rng: &mut R) -> Layout { + match self { + SimulationOverlay::Flat(overlay) => overlay.layout(nodes, rng), + SimulationOverlay::Tree(overlay) => overlay.layout(nodes, rng), + } + } +} + +pub trait Overlay { fn nodes(&self) -> Vec; fn leaders( &self, @@ -102,3 +165,14 @@ pub trait Overlay { ) -> Box>; fn layout(&self, nodes: &[NodeId], rng: &mut R) -> Layout; } + +// Takes a reference to the simulation_settings and returns a SimulationOverlay instance based +// on the overlay settings specified in simulation_settings. +pub fn create_overlay(overlay_settings: &OverlaySettings) -> SimulationOverlay { + match &overlay_settings { + OverlaySettings::Flat => SimulationOverlay::Flat(flat::FlatOverlay::new()), + OverlaySettings::Tree(settings) => { + SimulationOverlay::Tree(tree::TreeOverlay::new(settings.clone())) + } + } +} diff --git a/simulations/src/overlay/tree.rs b/simulations/src/overlay/tree.rs index 7dc03400..31af6eb3 100644 --- a/simulations/src/overlay/tree.rs +++ b/simulations/src/overlay/tree.rs @@ -7,13 +7,13 @@ use serde::Deserialize; use super::{Committee, Layout, Overlay}; use crate::node::{CommitteeId, NodeId}; -#[derive(Clone, Default, Deserialize)] +#[derive(Clone, Debug, Default, Deserialize)] pub enum TreeType { #[default] FullBinaryTree, } -#[derive(Clone, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct TreeSettings { pub tree_type: TreeType, pub committee_size: usize, @@ -40,6 +40,10 @@ struct TreeProperties { } impl TreeOverlay { + pub fn new(settings: TreeSettings) -> Self { + Self { settings } + } + fn build_full_binary_tree( node_ids: &[NodeId], rng: &mut R, @@ -99,12 +103,6 @@ impl TreeOverlay { } impl Overlay for TreeOverlay { - type Settings = TreeSettings; - - fn new(settings: Self::Settings) -> Self { - Self { settings } - } - fn nodes(&self) -> Vec { let properties = get_tree_properties(&self.settings); (0..properties.node_count).map(From::from).collect() diff --git a/simulations/src/runner/async_runner.rs b/simulations/src/runner/async_runner.rs index 749a235b..b2f83294 100644 --- a/simulations/src/runner/async_runner.rs +++ b/simulations/src/runner/async_runner.rs @@ -1,5 +1,4 @@ use crate::node::{Node, NodeId}; -use crate::overlay::Overlay; use crate::runner::{SimulationRunner, SimulationRunnerHandle}; use crate::streaming::{Producer, Subscriber}; use crate::warding::SimulationState; @@ -12,8 +11,8 @@ use std::collections::HashSet; use std::sync::Arc; /// Simulate with sending the network state to any subscriber -pub fn simulate( - runner: SimulationRunner, +pub fn simulate( + runner: SimulationRunner, chunk_size: usize, ) -> anyhow::Result where @@ -21,7 +20,6 @@ where N: Send + Sync + 'static, N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone + Send, P::Subscriber: Send + Sync + 'static, ::Record: Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, @@ -44,7 +42,7 @@ where let handle = SimulationRunnerHandle { stop_tx, handle: std::thread::spawn(move || { - let p = P::new(runner.stream_settings.settings)?; + let p = P::new(runner.stream_settings)?; scopeguard::defer!(if let Err(e) = p.stop() { eprintln!("Error stopping producer: {e}"); }); diff --git a/simulations/src/runner/glauber_runner.rs b/simulations/src/runner/glauber_runner.rs index 4e9a37fe..0c7fa3aa 100644 --- a/simulations/src/runner/glauber_runner.rs +++ b/simulations/src/runner/glauber_runner.rs @@ -1,5 +1,4 @@ use crate::node::{Node, NodeId}; -use crate::overlay::Overlay; use crate::runner::{SimulationRunner, SimulationRunnerHandle}; use crate::streaming::{Producer, Subscriber}; use crate::warding::SimulationState; @@ -13,8 +12,8 @@ use std::sync::Arc; /// Simulate with sending the network state to any subscriber. /// /// [Glauber dynamics simulation](https://en.wikipedia.org/wiki/Glauber_dynamics) -pub fn simulate( - runner: SimulationRunner, +pub fn simulate( + runner: SimulationRunner, update_rate: usize, maximum_iterations: usize, ) -> anyhow::Result @@ -23,7 +22,6 @@ where N: Send + Sync + 'static, N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone + Send, P::Subscriber: Send + Sync + 'static, ::Record: for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, @@ -42,7 +40,7 @@ where let (stop_tx, stop_rx) = bounded(1); let handle = SimulationRunnerHandle { handle: std::thread::spawn(move || { - let p = P::new(runner.stream_settings.settings)?; + let p = P::new(runner.stream_settings)?; scopeguard::defer!(if let Err(e) = p.stop() { eprintln!("Error stopping producer: {e}"); }); diff --git a/simulations/src/runner/layered_runner.rs b/simulations/src/runner/layered_runner.rs index 8820e538..e9248b6d 100644 --- a/simulations/src/runner/layered_runner.rs +++ b/simulations/src/runner/layered_runner.rs @@ -40,7 +40,6 @@ use rand::rngs::SmallRng; use serde::Serialize; // internal use crate::node::{Node, NodeId}; -use crate::overlay::Overlay; use crate::runner::SimulationRunner; use crate::streaming::{Producer, Subscriber}; use crate::warding::SimulationState; @@ -48,8 +47,8 @@ use crate::warding::SimulationState; use super::SimulationRunnerHandle; /// Simulate with sending the network state to any subscriber -pub fn simulate( - runner: SimulationRunner, +pub fn simulate( + runner: SimulationRunner, gap: usize, distribution: Option>, ) -> anyhow::Result @@ -58,7 +57,6 @@ where N: Send + Sync + 'static, N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone + Send, P::Subscriber: Send + Sync + 'static, ::Record: for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, @@ -68,7 +66,7 @@ where let layers: Vec = (0..gap).collect(); - let mut deque = build_node_ids_deque(gap, &runner); + let mut deque = build_node_ids_deque::(gap, &runner); let simulation_state = SimulationState { nodes: Arc::clone(&runner.nodes), @@ -80,7 +78,7 @@ where let handle = SimulationRunnerHandle { stop_tx, handle: std::thread::spawn(move || { - let p = P::new(runner.stream_settings.settings)?; + let p = P::new(runner.stream_settings)?; scopeguard::defer!(if let Err(e) = p.stop() { eprintln!("Error stopping producer: {e}"); }); @@ -172,14 +170,12 @@ fn choose_random_layer_and_node_id( (i, *node_id) } -fn build_node_ids_deque( +fn build_node_ids_deque( gap: usize, - runner: &SimulationRunner, + runner: &SimulationRunner, ) -> FixedSliceDeque> where N: Node, - O: Overlay, - P: Producer, { // add a +1 so we always have let mut deque = FixedSliceDeque::new(gap + 1); diff --git a/simulations/src/runner/mod.rs b/simulations/src/runner/mod.rs index d4169be4..dbffa71e 100644 --- a/simulations/src/runner/mod.rs +++ b/simulations/src/runner/mod.rs @@ -4,12 +4,11 @@ mod layered_runner; mod sync_runner; // std -use std::marker::PhantomData; use std::sync::{Arc, RwLock}; use std::time::Duration; // crates -use crate::streaming::{Producer, Subscriber}; +use crate::streaming::{Producer, StreamSettings, Subscriber}; use crossbeam::channel::Sender; use rand::rngs::SmallRng; use rand::{RngCore, SeedableRng}; @@ -19,9 +18,7 @@ use serde::Serialize; // internal use crate::network::Network; use crate::node::Node; -use crate::overlay::Overlay; use crate::settings::{RunnerSettings, SimulationSettings}; -use crate::streaming::StreamSettings; use crate::warding::{SimulationState, SimulationWard, Ward}; pub struct SimulationRunnerHandle { @@ -81,35 +78,24 @@ where /// Encapsulation solution for the simulations runner /// Holds the network state, the simulating nodes and the simulation settings. -pub struct SimulationRunner +pub struct SimulationRunner where N: Node, - O: Overlay, - P: Producer, { inner: Arc>>, nodes: Arc>>, runner_settings: RunnerSettings, - stream_settings: StreamSettings, - _overlay: PhantomData, + stream_settings: StreamSettings, } -impl SimulationRunner +impl SimulationRunner where M: Clone + Send + Sync + 'static, N: Send + Sync + 'static, N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone + Send, - P::Subscriber: Send + Sync + 'static, - ::Record: - Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, { - pub fn new( - network: Network, - nodes: Vec, - settings: SimulationSettings, - ) -> Self { + pub fn new(network: Network, nodes: Vec, settings: SimulationSettings) -> Self { let seed = settings .seed .unwrap_or_else(|| rand::thread_rng().next_u64()); @@ -119,42 +105,46 @@ where let rng = SmallRng::seed_from_u64(seed); let nodes = Arc::new(RwLock::new(nodes)); let SimulationSettings { - network_behaviors: _, - regions: _, wards, overlay_settings: _, node_settings: _, runner_settings, stream_settings, node_count: _, - committee_size: _, seed: _, + views_count: _, + leaders_count: _, + network_settings: _, } = settings; Self { - stream_settings, - runner_settings, inner: Arc::new(RwLock::new(SimulationRunnerInner { network, rng, wards, })), nodes, - _overlay: PhantomData, + runner_settings, + stream_settings, } } - pub fn simulate(self) -> anyhow::Result { + pub fn simulate(self) -> anyhow::Result + where + P::Subscriber: Send + Sync + 'static, + ::Record: + Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, + { match self.runner_settings.clone() { - RunnerSettings::Sync => sync_runner::simulate::<_, _, _, P>(self), - RunnerSettings::Async { chunks } => async_runner::simulate::<_, _, _, P>(self, chunks), + RunnerSettings::Sync => sync_runner::simulate::<_, _, P>(self), + RunnerSettings::Async { chunks } => async_runner::simulate::<_, _, P>(self, chunks), RunnerSettings::Glauber { maximum_iterations, update_rate, - } => glauber_runner::simulate::<_, _, _, P>(self, update_rate, maximum_iterations), + } => glauber_runner::simulate::<_, _, P>(self, update_rate, maximum_iterations), RunnerSettings::Layered { rounds_gap, distribution, - } => layered_runner::simulate::<_, _, _, P>(self, rounds_gap, distribution), + } => layered_runner::simulate::<_, _, P>(self, rounds_gap, distribution), } } } diff --git a/simulations/src/runner/sync_runner.rs b/simulations/src/runner/sync_runner.rs index 86bcc53a..f5810509 100644 --- a/simulations/src/runner/sync_runner.rs +++ b/simulations/src/runner/sync_runner.rs @@ -2,22 +2,20 @@ use serde::Serialize; use super::{SimulationRunner, SimulationRunnerHandle}; use crate::node::Node; -use crate::overlay::Overlay; use crate::streaming::{Producer, Subscriber}; use crate::warding::SimulationState; use crossbeam::channel::{bounded, select}; use std::sync::Arc; /// Simulate with sending the network state to any subscriber -pub fn simulate( - runner: SimulationRunner, +pub fn simulate( + runner: SimulationRunner, ) -> anyhow::Result where M: Send + Sync + Clone + 'static, N: Send + Sync + 'static, N::Settings: Clone + Send, N::State: Serialize, - O::Settings: Clone, P::Subscriber: Send + Sync + 'static, ::Record: Send + Sync + 'static + for<'a> TryFrom<&'a SimulationState, Error = anyhow::Error>, @@ -33,7 +31,7 @@ where let handle = SimulationRunnerHandle { stop_tx, handle: std::thread::spawn(move || { - let p = P::new(runner.stream_settings.settings)?; + let p = P::new(runner.stream_settings)?; scopeguard::defer!(if let Err(e) = p.stop() { eprintln!("Error stopping producer: {e}"); }); @@ -79,20 +77,18 @@ mod tests { network::{ behaviour::NetworkBehaviour, regions::{Region, RegionsData}, - Network, + InMemoryNetworkInterface, Network, }, node::{ - dummy::{DummyMessage, DummyNetworkInterface, DummyNode, DummySettings}, + dummy::{DummyMessage, DummyNode}, Node, NodeId, OverlayState, SharedState, ViewOverlay, }, - output_processors::OutData, overlay::{ tree::{TreeOverlay, TreeSettings}, - Overlay, + Overlay, SimulationOverlay, }, runner::SimulationRunner, settings::SimulationSettings, - streaming::naive::{NaiveProducer, NaiveSettings}, }; use crossbeam::channel; use rand::rngs::mock::StepRng; @@ -122,7 +118,7 @@ mod tests { .map(|node_id| { let (node_message_sender, node_message_receiver) = channel::unbounded(); let network_message_receiver = network.connect(*node_id, node_message_receiver); - let network_interface = DummyNetworkInterface::new( + let network_interface = InMemoryNetworkInterface::new( *node_id, node_message_sender, network_message_receiver, @@ -134,16 +130,15 @@ mod tests { #[test] fn runner_one_step() { - let settings: SimulationSettings = - SimulationSettings { - node_count: 10, - committee_size: 1, - ..Default::default() - }; + let settings: SimulationSettings = SimulationSettings { + node_count: 10, + overlay_settings: TreeSettings::default().into(), + ..Default::default() + }; let mut rng = StepRng::new(1, 0); let node_ids: Vec = (0..settings.node_count).map(Into::into).collect(); - let overlay = TreeOverlay::new(settings.overlay_settings.clone()); + let overlay = TreeOverlay::new(settings.overlay_settings.clone().try_into().unwrap()); let mut network = init_network(&node_ids); let view = ViewOverlay { leaders: overlay.leaders(&node_ids, 1, &mut rng).collect(), @@ -151,11 +146,12 @@ mod tests { }; let overlay_state = Arc::new(RwLock::new(OverlayState { all_nodes: node_ids.clone(), + overlay: SimulationOverlay::Tree(overlay), overlays: BTreeMap::from([(0, view.clone()), (1, view)]), })); let nodes = init_dummy_nodes(&node_ids, &mut network, overlay_state); - let runner: SimulationRunner> = + let runner: SimulationRunner = SimulationRunner::new(network, nodes, settings); let mut nodes = runner.nodes.write().unwrap(); runner.inner.write().unwrap().step(&mut nodes); @@ -169,16 +165,14 @@ mod tests { #[test] fn runner_send_receive() { - let settings: SimulationSettings = - SimulationSettings { - node_count: 10, - committee_size: 1, - ..Default::default() - }; + let settings: SimulationSettings = SimulationSettings { + node_count: 10, + ..Default::default() + }; let mut rng = StepRng::new(1, 0); let node_ids: Vec = (0..settings.node_count).map(Into::into).collect(); - let overlay = TreeOverlay::new(settings.overlay_settings.clone()); + let overlay = TreeOverlay::new(settings.overlay_settings.clone().try_into().unwrap()); let mut network = init_network(&node_ids); let view = ViewOverlay { leaders: overlay.leaders(&node_ids, 1, &mut rng).collect(), @@ -186,6 +180,7 @@ mod tests { }; let overlay_state = Arc::new(RwLock::new(OverlayState { all_nodes: node_ids.clone(), + overlay: SimulationOverlay::Tree(overlay), overlays: BTreeMap::from([ (0, view.clone()), (1, view.clone()), @@ -202,7 +197,7 @@ mod tests { } network.collect_messages(); - let runner: SimulationRunner> = + let runner: SimulationRunner = SimulationRunner::new(network, nodes, settings); let mut nodes = runner.nodes.write().unwrap(); diff --git a/simulations/src/settings.rs b/simulations/src/settings.rs index 3dba264b..d94d39cc 100644 --- a/simulations/src/settings.rs +++ b/simulations/src/settings.rs @@ -1,9 +1,8 @@ -use crate::network::regions::Region; -use crate::node::StepTime; +use crate::network::NetworkSettings; +use crate::overlay::OverlaySettings; use crate::streaming::StreamSettings; use crate::warding::Ward; use serde::Deserialize; -use std::collections::HashMap; #[derive(Clone, Debug, Deserialize, Default)] pub enum RunnerSettings { @@ -22,17 +21,24 @@ pub enum RunnerSettings { }, } +#[derive(Clone, Debug, Deserialize, Default)] +pub enum NodeSettings { + Carnot, + #[default] + Dummy, +} + #[derive(Default, Deserialize)] -pub struct SimulationSettings { - pub network_behaviors: HashMap<(Region, Region), StepTime>, - pub regions: Vec, +pub struct SimulationSettings { #[serde(default)] pub wards: Vec, - pub overlay_settings: O, - pub node_settings: N, + pub network_settings: NetworkSettings, + pub overlay_settings: OverlaySettings, + pub node_settings: NodeSettings, pub runner_settings: RunnerSettings, - pub stream_settings: StreamSettings

, + pub stream_settings: StreamSettings, pub node_count: usize, - pub committee_size: usize, + pub views_count: usize, + pub leaders_count: usize, pub seed: Option, } diff --git a/simulations/src/streaming/io.rs b/simulations/src/streaming/io.rs index 5d3b27cc..ee1344c6 100644 --- a/simulations/src/streaming/io.rs +++ b/simulations/src/streaming/io.rs @@ -1,31 +1,52 @@ -use std::sync::{Arc, Mutex}; +use std::{ + any::Any, + io::stdout, + sync::{Arc, Mutex}, +}; -use super::{Producer, Receivers, Subscriber}; +use super::{Producer, Receivers, StreamSettings, Subscriber}; use arc_swap::ArcSwapOption; use crossbeam::channel::{bounded, unbounded, Sender}; use serde::{Deserialize, Serialize}; -#[derive(Debug)] -pub struct IOStreamSettings { - pub writer: W, +#[derive(Debug, Default, Deserialize)] +pub struct IOStreamSettings { + pub writer_type: WriteType, } -impl Default for IOStreamSettings { - fn default() -> Self { - Self { - writer: std::io::stdout(), +#[derive(Debug, Default, Deserialize)] +pub enum WriteType { + #[default] + Stdout, +} + +pub trait ToWriter { + fn to_writer(&self) -> Result; +} + +impl ToWriter for WriteType { + fn to_writer(&self) -> Result { + match self { + WriteType::Stdout => { + let stdout = Box::new(stdout()); + let boxed_any = Box::new(stdout) as Box; + boxed_any + .downcast::() + .map(|boxed| *boxed) + .map_err(|_| "Writer type mismatch".to_string()) + } } } } -impl<'de> Deserialize<'de> for IOStreamSettings { - fn deserialize(_deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - Ok(Self { - writer: std::io::stdout(), - }) +impl TryFrom for IOStreamSettings { + type Error = String; + + fn try_from(settings: StreamSettings) -> Result { + match settings { + StreamSettings::IO(settings) => Ok(settings), + _ => Err("io settings can't be created".into()), + } } } @@ -42,21 +63,28 @@ where W: std::io::Write + Send + Sync + 'static, R: Serialize + Send + Sync + 'static, { - type Settings = IOStreamSettings; + type Settings = IOStreamSettings; type Subscriber = IOSubscriber; - fn new(settings: Self::Settings) -> anyhow::Result + fn new(settings: StreamSettings) -> anyhow::Result where Self: Sized, { + let settings: IOStreamSettings = settings + .try_into() + .expect("io settings from stream settings"); + let writer = settings + .writer_type + .to_writer() + .expect("writer from writer type"); let (sender, recv) = unbounded(); let (stop_tx, stop_rx) = bounded(1); Ok(Self { sender, recvs: ArcSwapOption::from(Some(Arc::new(Receivers { stop_rx, recv }))), stop_tx, - writer: ArcSwapOption::from(Some(Arc::new(Mutex::new(settings.writer)))), + writer: ArcSwapOption::from(Some(Arc::new(Mutex::new(writer)))), }) } @@ -139,9 +167,7 @@ mod tests { Network, }, node::{dummy_streaming::DummyStreamingNode, Node, NodeId}, - overlay::tree::TreeOverlay, runner::SimulationRunner, - streaming::{StreamSettings, StreamType}, warding::SimulationState, }; @@ -169,12 +195,6 @@ mod tests { fn test_streaming() { let simulation_settings = crate::settings::SimulationSettings { seed: Some(1), - stream_settings: StreamSettings { - ty: StreamType::IO, - settings: IOStreamSettings { - writer: std::io::stdout(), - }, - }, ..Default::default() }; @@ -231,14 +251,10 @@ mod tests { }) .collect(), }); - let simulation_runner: SimulationRunner< - (), - DummyStreamingNode<()>, - TreeOverlay, - IOProducer, - > = SimulationRunner::new(network, nodes, simulation_settings); + let simulation_runner: SimulationRunner<(), DummyStreamingNode<()>> = + SimulationRunner::new(network, nodes, simulation_settings); simulation_runner - .simulate() + .simulate::>() .unwrap() .stop_after(Duration::from_millis(100)) .unwrap(); diff --git a/simulations/src/streaming/mod.rs b/simulations/src/streaming/mod.rs index ad1c6b5a..ab55f686 100644 --- a/simulations/src/streaming/mod.rs +++ b/simulations/src/streaming/mod.rs @@ -1,7 +1,7 @@ use std::str::FromStr; use crossbeam::channel::Receiver; -use serde::Serialize; +use serde::{Deserialize, Serialize}; pub mod io; pub mod naive; @@ -45,18 +45,24 @@ impl<'de> serde::Deserialize<'de> for StreamType { } } -#[derive(Debug, Default, Clone, Serialize, serde::Deserialize)] -pub struct StreamSettings { - #[serde(rename = "type")] - pub ty: StreamType, - pub settings: S, +#[derive(Debug, Deserialize)] +pub enum StreamSettings { + Naive(naive::NaiveSettings), + IO(io::IOStreamSettings), + Polars(polars::PolarsSettings), +} + +impl Default for StreamSettings { + fn default() -> Self { + Self::IO(Default::default()) + } } pub trait Producer: Send + Sync + 'static { type Settings: Send; type Subscriber: Subscriber; - fn new(settings: Self::Settings) -> anyhow::Result + fn new(settings: StreamSettings) -> anyhow::Result where Self: Sized; diff --git a/simulations/src/streaming/naive.rs b/simulations/src/streaming/naive.rs index 8b1c8ea3..65619fa9 100644 --- a/simulations/src/streaming/naive.rs +++ b/simulations/src/streaming/naive.rs @@ -5,7 +5,7 @@ use std::{ sync::{Arc, Mutex}, }; -use super::{Producer, Receivers, Subscriber}; +use super::{Producer, Receivers, StreamSettings, Subscriber}; use arc_swap::ArcSwapOption; use crossbeam::channel::{bounded, unbounded, Sender}; use serde::{Deserialize, Serialize}; @@ -15,6 +15,17 @@ pub struct NaiveSettings { pub path: PathBuf, } +impl TryFrom for NaiveSettings { + type Error = String; + + fn try_from(settings: StreamSettings) -> Result { + match settings { + StreamSettings::Naive(settings) => Ok(settings), + _ => Err("naive settings can't be created".into()), + } + } +} + impl Default for NaiveSettings { fn default() -> Self { let mut tmp = std::env::temp_dir(); @@ -40,10 +51,11 @@ where type Subscriber = NaiveSubscriber; - fn new(settings: Self::Settings) -> anyhow::Result + fn new(settings: StreamSettings) -> anyhow::Result where Self: Sized, { + let settings = settings.try_into().expect("naive settings"); let (sender, recv) = unbounded(); let (stop_tx, stop_rx) = bounded(1); Ok(Self { @@ -138,7 +150,6 @@ mod tests { Network, }, node::{dummy_streaming::DummyStreamingNode, Node, NodeId}, - overlay::tree::TreeOverlay, runner::SimulationRunner, warding::SimulationState, }; @@ -225,13 +236,11 @@ mod tests { }) .collect(), }); - let simulation_runner: SimulationRunner< - (), - DummyStreamingNode<()>, - TreeOverlay, - NaiveProducer, - > = SimulationRunner::new(network, nodes, simulation_settings); + let simulation_runner: SimulationRunner<(), DummyStreamingNode<()>> = + SimulationRunner::new(network, nodes, simulation_settings); - simulation_runner.simulate().unwrap(); + simulation_runner + .simulate::>() + .unwrap(); } } diff --git a/simulations/src/streaming/polars.rs b/simulations/src/streaming/polars.rs index 889a26b0..78527718 100644 --- a/simulations/src/streaming/polars.rs +++ b/simulations/src/streaming/polars.rs @@ -10,7 +10,7 @@ use std::{ sync::Mutex, }; -use super::{Producer, Receivers, Subscriber}; +use super::{Producer, Receivers, StreamSettings, Subscriber}; #[derive(Debug, Clone, Copy, Serialize)] pub enum PolarsFormat { @@ -50,6 +50,17 @@ pub struct PolarsSettings { pub path: PathBuf, } +impl TryFrom for PolarsSettings { + type Error = String; + + fn try_from(settings: StreamSettings) -> Result { + match settings { + StreamSettings::Polars(settings) => Ok(settings), + _ => Err("polars settings can't be created".into()), + } + } +} + #[derive(Debug)] pub struct PolarsProducer { sender: Sender, @@ -66,10 +77,11 @@ where type Subscriber = PolarsSubscriber; - fn new(settings: Self::Settings) -> anyhow::Result + fn new(settings: StreamSettings) -> anyhow::Result where Self: Sized, { + let settings = settings.try_into().expect("polars settings"); let (sender, recv) = unbounded(); let (stop_tx, stop_rx) = bounded(1); Ok(Self {