From 53913829281e595554204f546a7d95148bd61b32 Mon Sep 17 00:00:00 2001 From: Al Liu Date: Mon, 21 Aug 2023 12:02:10 +0800 Subject: [PATCH] Using CSV format by default for simulations (#304) --- consensus-engine/src/types/block_id.rs | 6 + consensus-engine/src/types/node_id.rs | 6 + simulations/Cargo.toml | 1 + simulations/config/carnot.json | 2 +- simulations/config/carnot_dev.json | 62 ++++ simulations/src/bin/app/main.rs | 22 +- simulations/src/bin/app/overlay_node.rs | 10 + simulations/src/node/carnot/mod.rs | 199 ++---------- simulations/src/node/carnot/serde_util.rs | 299 ++++++++---------- simulations/src/node/carnot/serde_util/csv.rs | 165 ++++++++++ .../src/node/carnot/serde_util/json.rs | 137 ++++++++ simulations/src/node/carnot/state.rs | 162 ++++++++++ simulations/src/node/dummy.rs | 4 +- simulations/src/output_processors/mod.rs | 15 +- simulations/src/runner/async_runner.rs | 2 +- simulations/src/runner/mod.rs | 4 +- simulations/src/settings.rs | 4 +- simulations/src/streaming/mod.rs | 51 +++ simulations/src/streaming/naive.rs | 99 +++++- simulations/src/streaming/polars.rs | 59 +--- 20 files changed, 900 insertions(+), 409 deletions(-) create mode 100644 simulations/config/carnot_dev.json create mode 100644 simulations/src/node/carnot/serde_util/csv.rs create mode 100644 simulations/src/node/carnot/serde_util/json.rs create mode 100644 simulations/src/node/carnot/state.rs diff --git a/consensus-engine/src/types/block_id.rs b/consensus-engine/src/types/block_id.rs index b054f935..1d5a697f 100644 --- a/consensus-engine/src/types/block_id.rs +++ b/consensus-engine/src/types/block_id.rs @@ -39,6 +39,12 @@ impl From for [u8; 32] { } } +impl<'a> From<&'a BlockId> for &'a [u8; 32] { + fn from(id: &'a BlockId) -> Self { + &id.0 + } +} + impl core::fmt::Display for BlockId { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "0x")?; diff --git a/consensus-engine/src/types/node_id.rs b/consensus-engine/src/types/node_id.rs index e58c115b..e2c8c48d 100644 --- a/consensus-engine/src/types/node_id.rs +++ b/consensus-engine/src/types/node_id.rs @@ -35,6 +35,12 @@ impl From for [u8; 32] { } } +impl<'a> From<&'a NodeId> for &'a [u8; 32] { + fn from(id: &'a NodeId) -> Self { + &id.0 + } +} + impl core::fmt::Display for NodeId { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "0x")?; diff --git a/simulations/Cargo.toml b/simulations/Cargo.toml index 79799864..88ecdf10 100644 --- a/simulations/Cargo.toml +++ b/simulations/Cargo.toml @@ -13,6 +13,7 @@ path = "src/bin/app/main.rs" anyhow = "1" arc-swap = "1.6" bls-signatures = "0.14" +csv = "1" clap = { version = "4", features = ["derive"] } ctrlc = "3.4" chrono = { version = "0.4", features = ["serde"] } diff --git a/simulations/config/carnot.json b/simulations/config/carnot.json index 074f8009..83bd6fde 100644 --- a/simulations/config/carnot.json +++ b/simulations/config/carnot.json @@ -26,7 +26,7 @@ "step_time": "10ms", "runner_settings": "Sync", "stream_settings": { - "path": "test.json" + "path": "test.csv" }, "node_count": 3000, "views_count": 3, diff --git a/simulations/config/carnot_dev.json b/simulations/config/carnot_dev.json new file mode 100644 index 00000000..b98ced2a --- /dev/null +++ b/simulations/config/carnot_dev.json @@ -0,0 +1,62 @@ +{ + "network_settings": { + "network_behaviors": { + "north america:north america": "10ms", + "north america:europe": "150ms", + "north america:asia": "250ms", + "europe:europe": "10ms", + "europe:asia": "200ms", + "europe:north america": "150ms", + "asia:north america": "250ms", + "asia:europe": "200ms", + "asia:asia": "10ms" + }, + "regions": { + "north america": 0.4, + "europe": 0.3, + "asia": 0.3 + } + }, + "overlay_settings": { + "number_of_committees": 7 + }, + "node_settings": { + "network_capacity_kbps": 10000024, + "timeout": "10000ms" + }, + "step_time": "100ms", + "runner_settings": "Sync", + "stream_settings": { + "path": "tree_500_7_view_1_default.csv", + "format": "csv" + }, + "node_count": 500, + "views_count": 10, + "leaders_count": 1, + "seed": 0, + "wards": [ + { + "max_view": 1 + }, + { + "stalled_view": { + "consecutive_viewed_checkpoint": null, + "criterion": 0, + "threshold": 100 + } + } + ], + "record_settings": { + "current_view": true, + "highest_voted_view": true, + "local_high_qc": true, + "safe_blocks": false, + "last_view_timeout_qc": true, + "latest_committed_block": true, + "latest_committed_view": true, + "root_committee": false, + "parent_committee": false, + "child_committees": false, + "committed_blocks": false + } +} \ No newline at end of file diff --git a/simulations/src/bin/app/main.rs b/simulations/src/bin/app/main.rs index c7298556..302b83ec 100644 --- a/simulations/src/bin/app/main.rs +++ b/simulations/src/bin/app/main.rs @@ -19,7 +19,7 @@ use serde::Serialize; use simulations::network::behaviour::create_behaviours; use simulations::network::regions::{create_regions, RegionsData}; use simulations::network::{InMemoryNetworkInterface, Network}; -use simulations::node::carnot::{CarnotSettings, CarnotState}; +use simulations::node::carnot::{CarnotRecord, CarnotSettings, CarnotState}; use simulations::node::{NodeId, NodeIdExt}; use simulations::output_processors::Record; use simulations::runner::{BoxedNode, SimulationRunnerHandle}; @@ -27,9 +27,7 @@ use simulations::streaming::{ io::IOSubscriber, naive::NaiveSubscriber, polars::PolarsSubscriber, StreamType, }; // internal -use simulations::{ - output_processors::OutData, runner::SimulationRunner, settings::SimulationSettings, -}; +use simulations::{runner::SimulationRunner, settings::SimulationSettings}; mod log; mod overlay_node; @@ -146,24 +144,28 @@ fn run( where M: Clone + Send + Sync + 'static, S: 'static, - T: Serialize + 'static, + T: Serialize + Clone + 'static, { let stream_settings = settings.stream_settings.clone(); - let runner = - SimulationRunner::<_, OutData, S, T>::new(network, nodes, Default::default(), settings)?; + let runner = SimulationRunner::<_, CarnotRecord, S, T>::new( + network, + nodes, + Default::default(), + settings, + )?; let handle = match stream_type { Some(StreamType::Naive) => { let settings = stream_settings.unwrap_naive(); - runner.simulate_and_subscribe::>(settings)? + runner.simulate_and_subscribe::>(settings)? } Some(StreamType::IO) => { let settings = stream_settings.unwrap_io(); - runner.simulate_and_subscribe::>(settings)? + runner.simulate_and_subscribe::>(settings)? } Some(StreamType::Polars) => { let settings = stream_settings.unwrap_polars(); - runner.simulate_and_subscribe::>(settings)? + runner.simulate_and_subscribe::>(settings)? } None => runner.simulate()?, }; diff --git a/simulations/src/bin/app/overlay_node.rs b/simulations/src/bin/app/overlay_node.rs index 00d878e0..37901847 100644 --- a/simulations/src/bin/app/overlay_node.rs +++ b/simulations/src/bin/app/overlay_node.rs @@ -20,6 +20,13 @@ pub fn to_overlay_node( mut rng: R, settings: &SimulationSettings, ) -> BoxedNode { + let fmt = match &settings.stream_settings { + simulations::streaming::StreamSettings::Naive(n) => n.format, + simulations::streaming::StreamSettings::IO(_) => { + simulations::streaming::SubscriberFormat::Csv + } + simulations::streaming::StreamSettings::Polars(p) => p.format, + }; match &settings.overlay_settings { simulations::settings::OverlaySettings::Flat => { let overlay_settings = consensus_engine::overlay::FlatOverlaySettings { @@ -33,6 +40,7 @@ pub fn to_overlay_node( CarnotSettings::new( settings.node_settings.timeout, settings.record_settings.clone(), + fmt, ), overlay_settings, genesis, @@ -55,6 +63,7 @@ pub fn to_overlay_node( CarnotSettings::new( settings.node_settings.timeout, settings.record_settings.clone(), + fmt, ), overlay_settings, genesis, @@ -77,6 +86,7 @@ pub fn to_overlay_node( CarnotSettings::new( settings.node_settings.timeout, settings.record_settings.clone(), + fmt, ), overlay_settings, genesis, diff --git a/simulations/src/node/carnot/mod.rs b/simulations/src/node/carnot/mod.rs index 551ff3a4..e2df06fd 100644 --- a/simulations/src/node/carnot/mod.rs +++ b/simulations/src/node/carnot/mod.rs @@ -3,10 +3,14 @@ mod event_builder; mod message_cache; pub mod messages; +mod state; +pub use state::*; mod serde_util; mod tally; mod timeout; +use std::any::Any; +use std::collections::BTreeMap; // std use std::hash::Hash; use std::time::Instant; @@ -21,6 +25,10 @@ use super::{Node, NodeId}; use crate::network::{InMemoryNetworkInterface, NetworkInterface, NetworkMessage}; use crate::node::carnot::event_builder::{CarnotTx, Event}; use crate::node::carnot::message_cache::MessageCache; +use crate::output_processors::{Record, RecordType, Runtime}; +use crate::settings::SimulationSettings; +use crate::streaming::SubscriberFormat; +use crate::warding::SimulationState; use consensus_engine::overlay::RandomBeaconState; use consensus_engine::{ Block, BlockId, Carnot, Committee, Overlay, Payload, Qc, StandardQc, TimeoutQc, View, Vote, @@ -32,183 +40,27 @@ use nomos_consensus::{ network::messages::{NewViewMsg, TimeoutMsg, VoteMsg}, }; -const NODE_ID: &str = "node_id"; -const CURRENT_VIEW: &str = "current_view"; -const HIGHEST_VOTED_VIEW: &str = "highest_voted_view"; -const LOCAL_HIGH_QC: &str = "local_high_qc"; -const SAFE_BLOCKS: &str = "safe_blocks"; -const LAST_VIEW_TIMEOUT_QC: &str = "last_view_timeout_qc"; -const LATEST_COMMITTED_BLOCK: &str = "latest_committed_block"; -const LATEST_COMMITTED_VIEW: &str = "latest_committed_view"; -const ROOT_COMMITTEE: &str = "root_committee"; -const PARENT_COMMITTEE: &str = "parent_committee"; -const CHILD_COMMITTEES: &str = "child_committees"; -const COMMITTED_BLOCKS: &str = "committed_blocks"; -const STEP_DURATION: &str = "step_duration"; - -pub const CARNOT_RECORD_KEYS: &[&str] = &[ - NODE_ID, - CURRENT_VIEW, - HIGHEST_VOTED_VIEW, - LOCAL_HIGH_QC, - SAFE_BLOCKS, - LAST_VIEW_TIMEOUT_QC, - LATEST_COMMITTED_BLOCK, - LATEST_COMMITTED_VIEW, - ROOT_COMMITTEE, - PARENT_COMMITTEE, - CHILD_COMMITTEES, - COMMITTED_BLOCKS, - STEP_DURATION, -]; - -static RECORD_SETTINGS: std::sync::OnceLock> = std::sync::OnceLock::new(); - -#[derive(Debug)] -pub struct CarnotState { - node_id: NodeId, - current_view: View, - highest_voted_view: View, - local_high_qc: StandardQc, - safe_blocks: HashMap, - last_view_timeout_qc: Option, - latest_committed_block: Block, - latest_committed_view: View, - root_committee: Committee, - parent_committee: Option, - child_committees: Vec, - committed_blocks: Vec, - step_duration: Duration, -} - -impl serde::Serialize for CarnotState { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - if let Some(rs) = RECORD_SETTINGS.get() { - let keys = rs - .iter() - .filter_map(|(k, v)| { - if CARNOT_RECORD_KEYS.contains(&k.trim()) && *v { - Some(k) - } else { - None - } - }) - .collect::>(); - - let mut state = serde_util::CarnotState::default(); - for k in keys { - match k.trim() { - NODE_ID => { - state.node_id = Some(self.node_id.into()); - } - CURRENT_VIEW => { - state.current_view = Some(self.current_view); - } - HIGHEST_VOTED_VIEW => { - state.highest_voted_view = Some(self.highest_voted_view); - } - LOCAL_HIGH_QC => { - state.local_high_qc = Some((&self.local_high_qc).into()); - } - SAFE_BLOCKS => { - state.safe_blocks = Some((&self.safe_blocks).into()); - } - LAST_VIEW_TIMEOUT_QC => { - state.last_view_timeout_qc = - Some(self.last_view_timeout_qc.as_ref().map(From::from)); - } - LATEST_COMMITTED_BLOCK => { - state.latest_committed_block = Some((&self.latest_committed_block).into()); - } - LATEST_COMMITTED_VIEW => { - state.latest_committed_view = Some(self.latest_committed_view); - } - ROOT_COMMITTEE => { - state.root_committee = Some((&self.root_committee).into()); - } - PARENT_COMMITTEE => { - state.parent_committee = - Some(self.parent_committee.as_ref().map(From::from)); - } - CHILD_COMMITTEES => { - state.child_committees = Some(self.child_committees.as_slice().into()); - } - COMMITTED_BLOCKS => { - state.committed_blocks = Some(self.committed_blocks.as_slice().into()); - } - STEP_DURATION => { - state.step_duration = Some(self.step_duration); - } - _ => {} - } - } - state.serialize(serializer) - } else { - serializer.serialize_none() - } - } -} - -impl CarnotState { - const fn keys() -> &'static [&'static str] { - CARNOT_RECORD_KEYS - } -} - -/// Have to implement this manually because of the `serde_json` will panic if the key of map -/// is not a string. -fn serialize_blocks(blocks: &HashMap, serializer: S) -> Result -where - S: serde::Serializer, -{ - use serde::ser::SerializeMap; - let mut ser = serializer.serialize_map(Some(blocks.len()))?; - for (k, v) in blocks { - ser.serialize_entry(&format!("{k:?}"), v)?; - } - ser.end() -} - -impl From<&Carnot> for CarnotState { - fn from(value: &Carnot) -> Self { - let node_id = value.id(); - let current_view = value.current_view(); - Self { - node_id, - current_view, - local_high_qc: value.high_qc(), - parent_committee: value.parent_committee(), - root_committee: value.root_committee(), - child_committees: value.child_committees(), - latest_committed_block: value.latest_committed_block(), - latest_committed_view: value.latest_committed_view(), - safe_blocks: value - .blocks_in_view(current_view) - .into_iter() - .map(|b| (b.id, b)) - .collect(), - last_view_timeout_qc: value.last_view_timeout_qc(), - committed_blocks: value.latest_committed_blocks(), - highest_voted_view: Default::default(), - step_duration: Default::default(), - } - } -} +static RECORD_SETTINGS: std::sync::OnceLock> = std::sync::OnceLock::new(); #[derive(Clone, Default, Deserialize)] pub struct CarnotSettings { timeout: Duration, - record_settings: HashMap, + record_settings: BTreeMap, + + #[serde(default)] + format: SubscriberFormat, } impl CarnotSettings { - pub fn new(timeout: Duration, record_settings: HashMap) -> Self { + pub fn new( + timeout: Duration, + record_settings: BTreeMap, + format: SubscriberFormat, + ) -> Self { Self { timeout, record_settings, + format, } } } @@ -217,6 +69,8 @@ impl CarnotSettings { pub struct CarnotNode { id: consensus_engine::NodeId, state: CarnotState, + /// A step counter + current_step: usize, settings: CarnotSettings, network_interface: InMemoryNetworkInterface, message_cache: MessageCache, @@ -259,8 +113,10 @@ impl< engine, random_beacon_pk, step_duration: Duration::ZERO, + current_step: 0, }; this.state = CarnotState::from(&this.engine); + this.state.format = this.settings.format; this } @@ -570,8 +426,13 @@ impl< } // update state - self.state = CarnotState::from(&self.engine); - self.state.step_duration = step_duration.elapsed(); + self.state = CarnotState::new( + self.current_step, + step_duration.elapsed(), + self.settings.format, + &self.engine, + ); + self.current_step += 1; } } diff --git a/simulations/src/node/carnot/serde_util.rs b/simulations/src/node/carnot/serde_util.rs index 8b5cf654..d032acba 100644 --- a/simulations/src/node/carnot/serde_util.rs +++ b/simulations/src/node/carnot/serde_util.rs @@ -6,155 +6,146 @@ use serde::{ }; use self::{ - serde_block::BlockHelper, serde_id::{BlockIdHelper, NodeIdHelper}, standard_qc::StandardQcHelper, timeout_qc::TimeoutQcHelper, }; use consensus_engine::{AggregateQc, Block, BlockId, Committee, Qc, StandardQc, TimeoutQc, View}; -#[serde_with::skip_serializing_none] -#[serde_with::serde_as] -#[derive(Serialize, Default)] -pub(crate) struct CarnotState<'a> { - pub(crate) node_id: Option, - pub(crate) current_view: Option, - pub(crate) highest_voted_view: Option, - pub(crate) local_high_qc: Option, - pub(crate) safe_blocks: Option>, - pub(crate) last_view_timeout_qc: Option>>, - pub(crate) latest_committed_block: Option, - pub(crate) latest_committed_view: Option, - pub(crate) root_committee: Option>, - pub(crate) parent_committee: Option>>, - pub(crate) child_committees: Option>, - pub(crate) committed_blocks: Option>, - #[serde_as(as = "Option")] - pub(crate) step_duration: Option, -} +const NODE_ID: &str = "node_id"; +const CURRENT_VIEW: &str = "current_view"; +const HIGHEST_VOTED_VIEW: &str = "highest_voted_view"; +const LOCAL_HIGH_QC: &str = "local_high_qc"; +const SAFE_BLOCKS: &str = "safe_blocks"; +const LAST_VIEW_TIMEOUT_QC: &str = "last_view_timeout_qc"; +const LATEST_COMMITTED_BLOCK: &str = "latest_committed_block"; +const LATEST_COMMITTED_VIEW: &str = "latest_committed_view"; +const ROOT_COMMITTEE: &str = "root_committee"; +const PARENT_COMMITTEE: &str = "parent_committee"; +const CHILD_COMMITTEES: &str = "child_committees"; +const COMMITTED_BLOCKS: &str = "committed_blocks"; +const STEP_DURATION: &str = "step_duration"; -impl<'a> From<&'a super::CarnotState> for CarnotState<'a> { - fn from(value: &'a super::CarnotState) -> Self { - Self { - node_id: Some(value.node_id.into()), - current_view: Some(value.current_view), - highest_voted_view: Some(value.highest_voted_view), - local_high_qc: Some(StandardQcHelper::from(&value.local_high_qc)), - safe_blocks: Some(SafeBlocksHelper::from(&value.safe_blocks)), - last_view_timeout_qc: Some(value.last_view_timeout_qc.as_ref().map(From::from)), - latest_committed_block: Some(BlockHelper::from(&value.latest_committed_block)), - latest_committed_view: Some(value.latest_committed_view), - root_committee: Some(CommitteeHelper::from(&value.root_committee)), - parent_committee: Some(value.parent_committee.as_ref().map(From::from)), - child_committees: Some(CommitteesHelper::from(value.child_committees.as_slice())), - committed_blocks: Some(CommittedBlockHelper::from( - value.committed_blocks.as_slice(), - )), - step_duration: Some(value.step_duration), +pub const CARNOT_RECORD_KEYS: &[&str] = &[ + CHILD_COMMITTEES, + COMMITTED_BLOCKS, + CURRENT_VIEW, + HIGHEST_VOTED_VIEW, + LAST_VIEW_TIMEOUT_QC, + LATEST_COMMITTED_BLOCK, + LATEST_COMMITTED_VIEW, + LOCAL_HIGH_QC, + NODE_ID, + PARENT_COMMITTEE, + ROOT_COMMITTEE, + SAFE_BLOCKS, + STEP_DURATION, +]; + +macro_rules! serializer { + ($name: ident) => { + #[serde_with::skip_serializing_none] + #[serde_with::serde_as] + #[derive(Serialize, Default)] + pub(crate) struct $name<'a> { + step_id: usize, + child_committees: Option>, + committed_blocks: Option>, + current_view: Option, + highest_voted_view: Option, + last_view_timeout_qc: Option>>, + latest_committed_block: Option>, + latest_committed_view: Option, + local_high_qc: Option>, + node_id: Option>, + parent_committee: Option>>, + root_committee: Option>, + safe_blocks: Option>, + #[serde_as(as = "Option")] + step_duration: Option, } - } -} -pub(crate) struct SafeBlocksHelper<'a>(&'a HashMap); - -impl<'a> From<&'a HashMap> for SafeBlocksHelper<'a> { - fn from(val: &'a HashMap) -> Self { - Self(val) - } -} - -impl<'a> Serialize for SafeBlocksHelper<'a> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let iter = self.0.values(); - let mut s = serializer.serialize_seq(Some(iter.size_hint().0))?; - for b in iter { - s.serialize_element(&BlockHelper::from(b))?; + impl<'a> $name<'a> { + pub(crate) fn serialize_state( + &mut self, + keys: Vec<&String>, + state: &'a super::super::CarnotState, + serializer: S, + ) -> Result { + self.step_id = state.step_id; + for k in keys { + match k.trim() { + NODE_ID => { + self.node_id = Some((&state.node_id).into()); + } + CURRENT_VIEW => { + self.current_view = Some(state.current_view); + } + HIGHEST_VOTED_VIEW => { + self.highest_voted_view = Some(state.highest_voted_view); + } + LOCAL_HIGH_QC => { + self.local_high_qc = Some((&state.local_high_qc).into()); + } + SAFE_BLOCKS => { + self.safe_blocks = Some((&state.safe_blocks).into()); + } + LAST_VIEW_TIMEOUT_QC => { + self.last_view_timeout_qc = + Some(state.last_view_timeout_qc.as_ref().map(From::from)); + } + LATEST_COMMITTED_BLOCK => { + self.latest_committed_block = + Some((&state.latest_committed_block).into()); + } + LATEST_COMMITTED_VIEW => { + self.latest_committed_view = Some(state.latest_committed_view); + } + ROOT_COMMITTEE => { + self.root_committee = Some((&state.root_committee).into()); + } + PARENT_COMMITTEE => { + self.parent_committee = + Some(state.parent_committee.as_ref().map(From::from)); + } + CHILD_COMMITTEES => { + self.child_committees = Some(state.child_committees.as_slice().into()); + } + COMMITTED_BLOCKS => { + self.committed_blocks = Some(state.committed_blocks.as_slice().into()); + } + STEP_DURATION => { + self.step_duration = Some(state.step_duration); + } + _ => {} + } + } + self.serialize(serializer) + } } - s.end() - } + }; } -pub(crate) struct CommitteeHelper<'a>(&'a Committee); +mod csv; +mod json; -impl<'a> From<&'a Committee> for CommitteeHelper<'a> { - fn from(val: &'a Committee) -> Self { - Self(val) - } -} - -impl<'a> Serialize for CommitteeHelper<'a> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let iter = self.0.iter(); - let mut s = serializer.serialize_seq(Some(iter.size_hint().0))?; - for id in iter { - s.serialize_element(&NodeIdHelper::from(*id))?; - } - s.end() - } -} - -pub(crate) struct CommitteesHelper<'a>(&'a [Committee]); - -impl<'a> From<&'a [Committee]> for CommitteesHelper<'a> { - fn from(val: &'a [Committee]) -> Self { - Self(val) - } -} - -impl<'a> Serialize for CommitteesHelper<'a> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut s = serializer.serialize_seq(Some(self.0.len()))?; - for c in self.0 { - s.serialize_element(&CommitteeHelper::from(c))?; - } - s.end() - } -} - -pub(crate) struct CommittedBlockHelper<'a>(&'a [BlockId]); - -impl<'a> From<&'a [BlockId]> for CommittedBlockHelper<'a> { - fn from(val: &'a [BlockId]) -> Self { - Self(val) - } -} - -impl<'a> Serialize for CommittedBlockHelper<'a> { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut s = serializer.serialize_seq(Some(self.0.len()))?; - for c in self.0 { - s.serialize_element(&BlockIdHelper::from(*c))?; - } - s.end() - } -} +pub(super) use self::csv::CarnotStateCsvSerializer; +pub(super) use json::CarnotStateJsonSerializer; pub(crate) mod standard_qc { use super::*; #[derive(Serialize)] - pub(crate) struct StandardQcHelper { + pub(crate) struct StandardQcHelper<'a> { view: View, - id: serde_id::BlockIdHelper, + id: serde_id::BlockIdHelper<'a>, } - impl From<&StandardQc> for StandardQcHelper { - fn from(val: &StandardQc) -> Self { + impl<'a> From<&'a StandardQc> for StandardQcHelper<'a> { + fn from(val: &'a StandardQc) -> Self { Self { view: val.view, - id: val.id.into(), + id: (&val.id).into(), } } } @@ -201,12 +192,17 @@ pub(crate) mod qc { Aggregate(aggregate_qc::AggregateQcHelper<'a>), } + impl<'a> From<&'a Qc> for QcHelper<'a> { + fn from(value: &'a Qc) -> Self { + match value { + Qc::Standard(s) => Self::Standard(s), + Qc::Aggregated(a) => Self::Aggregate(a.into()), + } + } + } + pub fn serialize(t: &Qc, serializer: S) -> Result { - let qc = match t { - Qc::Standard(s) => QcHelper::Standard(s), - Qc::Aggregated(a) => QcHelper::Aggregate(aggregate_qc::AggregateQcHelper::from(a)), - }; - qc.serialize(serializer) + QcHelper::from(t).serialize(serializer) } } @@ -241,48 +237,25 @@ pub(crate) mod timeout_qc { } } -pub(crate) mod serde_block { - use super::*; - - #[derive(Serialize)] - pub(crate) struct BlockHelper { - view: View, - id: BlockIdHelper, - } - - impl From<&Block> for BlockHelper { - fn from(val: &Block) -> Self { - Self { - view: val.view, - id: val.id.into(), - } - } - } - - pub fn serialize(t: &Block, serializer: S) -> Result { - BlockHelper::from(t).serialize(serializer) - } -} - pub(crate) mod serde_id { use consensus_engine::{BlockId, NodeId}; use super::*; - #[derive(Serialize, Deserialize)] - pub(crate) struct BlockIdHelper(#[serde(with = "serde_array32")] [u8; 32]); + #[derive(Serialize)] + pub(crate) struct BlockIdHelper<'a>(#[serde(with = "serde_array32")] &'a [u8; 32]); - impl From for BlockIdHelper { - fn from(val: BlockId) -> Self { + impl<'a> From<&'a BlockId> for BlockIdHelper<'a> { + fn from(val: &'a BlockId) -> Self { Self(val.into()) } } - #[derive(Serialize, Deserialize)] - pub(crate) struct NodeIdHelper(#[serde(with = "serde_array32")] [u8; 32]); + #[derive(Serialize)] + pub(crate) struct NodeIdHelper<'a>(#[serde(with = "serde_array32")] &'a [u8; 32]); - impl From for NodeIdHelper { - fn from(val: NodeId) -> Self { + impl<'a> From<&'a NodeId> for NodeIdHelper<'a> { + fn from(val: &'a NodeId) -> Self { Self(val.into()) } } @@ -291,7 +264,7 @@ pub(crate) mod serde_id { t: &NodeId, serializer: S, ) -> Result { - NodeIdHelper::from(*t).serialize(serializer) + NodeIdHelper::from(t).serialize(serializer) } pub(crate) mod serde_array32 { diff --git a/simulations/src/node/carnot/serde_util/csv.rs b/simulations/src/node/carnot/serde_util/csv.rs new file mode 100644 index 00000000..7964eb3b --- /dev/null +++ b/simulations/src/node/carnot/serde_util/csv.rs @@ -0,0 +1,165 @@ +use super::*; +use serde_block::BlockHelper; + +serializer!(CarnotStateCsvSerializer); + +pub(crate) mod serde_block { + use consensus_engine::LeaderProof; + + use super::{qc::QcHelper, *}; + + #[derive(Serialize)] + #[serde(untagged)] + enum LeaderProofHelper<'a> { + LeaderId { leader_id: NodeIdHelper<'a> }, + } + + impl<'a> From<&'a LeaderProof> for LeaderProofHelper<'a> { + fn from(value: &'a LeaderProof) -> Self { + match value { + LeaderProof::LeaderId { leader_id } => Self::LeaderId { + leader_id: leader_id.into(), + }, + } + } + } + + pub(super) struct BlockHelper<'a>(BlockHelperInner<'a>); + + #[derive(Serialize)] + struct BlockHelperInner<'a> { + view: View, + id: BlockIdHelper<'a>, + parent_qc: QcHelper<'a>, + leader_proof: LeaderProofHelper<'a>, + } + + impl<'a> From<&'a Block> for BlockHelper<'a> { + fn from(val: &'a Block) -> Self { + Self(BlockHelperInner { + view: val.view, + id: (&val.id).into(), + parent_qc: (&val.parent_qc).into(), + leader_proof: (&val.leader_proof).into(), + }) + } + } + + impl<'a> serde::Serialize for BlockHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serde_json::to_string(&self.0) + .map_err(::custom) + .and_then(|s| serializer.serialize_str(s.as_str())) + } + } +} + +pub(super) struct LocalHighQcHelper<'a>(StandardQcHelper<'a>); + +impl<'a> Serialize for LocalHighQcHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serde_json::to_string(&self.0) + .map_err(::custom) + .and_then(|s| serializer.serialize_str(s.as_str())) + } +} + +impl<'a> From<&'a StandardQc> for LocalHighQcHelper<'a> { + fn from(value: &'a StandardQc) -> Self { + Self(From::from(value)) + } +} + +struct SafeBlocksHelper<'a>(&'a HashMap); + +impl<'a> From<&'a HashMap> for SafeBlocksHelper<'a> { + fn from(val: &'a HashMap) -> Self { + Self(val) + } +} + +impl<'a> Serialize for SafeBlocksHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.0 + .values() + .map(|b| serde_json::to_string(&BlockHelper::from(b))) + .collect::, _>>() + .map_err(::custom) + .and_then(|val| serializer.serialize_str(&format!("[{}]", val.join(",")))) + } +} + +struct CommitteeHelper<'a>(&'a Committee); + +impl<'a> From<&'a Committee> for CommitteeHelper<'a> { + fn from(val: &'a Committee) -> Self { + Self(val) + } +} + +impl<'a> Serialize for CommitteeHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.0 + .iter() + .map(|b| serde_json::to_string(&NodeIdHelper::from(b))) + .collect::, _>>() + .map_err(::custom) + .and_then(|val| serializer.serialize_str(&format!("[{}]", val.join(",")))) + } +} + +struct CommitteesHelper<'a>(&'a [Committee]); + +impl<'a> From<&'a [Committee]> for CommitteesHelper<'a> { + fn from(val: &'a [Committee]) -> Self { + Self(val) + } +} + +impl<'a> Serialize for CommitteesHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.0 + .iter() + .map(|b| serde_json::to_string(&CommitteeHelper::from(b))) + .collect::, _>>() + .map_err(::custom) + .and_then(|val| serializer.serialize_str(&format!("[{}]", val.join(",")))) + } +} + +struct CommittedBlockHelper<'a>(&'a [BlockId]); + +impl<'a> From<&'a [BlockId]> for CommittedBlockHelper<'a> { + fn from(val: &'a [BlockId]) -> Self { + Self(val) + } +} + +impl<'a> Serialize for CommittedBlockHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.0 + .iter() + .map(|b| serde_json::to_string(&BlockIdHelper::from(b))) + .collect::, _>>() + .map_err(::custom) + .and_then(|val| serializer.serialize_str(&format!("[{}]", val.join(",")))) + } +} diff --git a/simulations/src/node/carnot/serde_util/json.rs b/simulations/src/node/carnot/serde_util/json.rs new file mode 100644 index 00000000..69d361ba --- /dev/null +++ b/simulations/src/node/carnot/serde_util/json.rs @@ -0,0 +1,137 @@ +use super::*; +use serde_block::BlockHelper; + +serializer!(CarnotStateJsonSerializer); + +pub(super) type LocalHighQcHelper<'a> = super::standard_qc::StandardQcHelper<'a>; + +pub(crate) mod serde_block { + use consensus_engine::LeaderProof; + + use super::{qc::QcHelper, *}; + + #[derive(Serialize)] + #[serde(untagged)] + enum LeaderProofHelper<'a> { + LeaderId { leader_id: NodeIdHelper<'a> }, + } + + impl<'a> From<&'a LeaderProof> for LeaderProofHelper<'a> { + fn from(value: &'a LeaderProof) -> Self { + match value { + LeaderProof::LeaderId { leader_id } => Self::LeaderId { + leader_id: leader_id.into(), + }, + } + } + } + + #[derive(Serialize)] + pub(crate) struct BlockHelper<'a> { + view: View, + id: BlockIdHelper<'a>, + parent_qc: QcHelper<'a>, + leader_proof: LeaderProofHelper<'a>, + } + + impl<'a> From<&'a Block> for BlockHelper<'a> { + fn from(val: &'a Block) -> Self { + Self { + view: val.view, + id: (&val.id).into(), + parent_qc: (&val.parent_qc).into(), + leader_proof: (&val.leader_proof).into(), + } + } + } + + pub fn serialize(t: &Block, serializer: S) -> Result { + BlockHelper::from(t).serialize(serializer) + } +} + +struct SafeBlocksHelper<'a>(&'a HashMap); + +impl<'a> From<&'a HashMap> for SafeBlocksHelper<'a> { + fn from(val: &'a HashMap) -> Self { + Self(val) + } +} + +impl<'a> Serialize for SafeBlocksHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let iter = self.0.values(); + let mut s = serializer.serialize_seq(Some(iter.size_hint().0))?; + for b in iter { + s.serialize_element(&BlockHelper::from(b))?; + } + s.end() + } +} + +struct CommitteeHelper<'a>(&'a Committee); + +impl<'a> From<&'a Committee> for CommitteeHelper<'a> { + fn from(val: &'a Committee) -> Self { + Self(val) + } +} + +impl<'a> Serialize for CommitteeHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let iter = self.0.iter(); + let mut s = serializer.serialize_seq(Some(iter.size_hint().0))?; + for id in iter { + s.serialize_element(&NodeIdHelper::from(id))?; + } + s.end() + } +} + +struct CommitteesHelper<'a>(&'a [Committee]); + +impl<'a> From<&'a [Committee]> for CommitteesHelper<'a> { + fn from(val: &'a [Committee]) -> Self { + Self(val) + } +} + +impl<'a> Serialize for CommitteesHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut s = serializer.serialize_seq(Some(self.0.len()))?; + for c in self.0 { + s.serialize_element(&CommitteeHelper::from(c))?; + } + s.end() + } +} + +struct CommittedBlockHelper<'a>(&'a [BlockId]); + +impl<'a> From<&'a [BlockId]> for CommittedBlockHelper<'a> { + fn from(val: &'a [BlockId]) -> Self { + Self(val) + } +} + +impl<'a> Serialize for CommittedBlockHelper<'a> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut s = serializer.serialize_seq(Some(self.0.len()))?; + for c in self.0 { + s.serialize_element(&BlockIdHelper::from(c))?; + } + s.end() + } +} diff --git a/simulations/src/node/carnot/state.rs b/simulations/src/node/carnot/state.rs new file mode 100644 index 00000000..23ba9384 --- /dev/null +++ b/simulations/src/node/carnot/state.rs @@ -0,0 +1,162 @@ +use serde::Serialize; + +use super::*; + +#[derive(Debug, Clone)] +pub struct CarnotState { + pub(crate) node_id: NodeId, + pub(crate) current_view: View, + pub(crate) highest_voted_view: View, + pub(crate) local_high_qc: StandardQc, + pub(crate) safe_blocks: HashMap, + pub(crate) last_view_timeout_qc: Option, + pub(crate) latest_committed_block: Block, + pub(crate) latest_committed_view: View, + pub(crate) root_committee: Committee, + pub(crate) parent_committee: Option, + pub(crate) child_committees: Vec, + pub(crate) committed_blocks: Vec, + pub(super) step_duration: Duration, + + /// Step id for this state + pub(super) step_id: usize, + /// does not serialize this field, this field is used to check + /// how to serialize other fields because csv format does not support + /// nested map or struct, we have to do some customize. + pub(super) format: SubscriberFormat, +} + +impl CarnotState { + pub(super) fn new( + step_id: usize, + step_duration: Duration, + fmt: SubscriberFormat, + engine: &Carnot, + ) -> Self { + let mut this = Self::from(engine); + this.step_id = step_id; + this.step_duration = step_duration; + this.format = fmt; + this + } +} + +#[derive(Serialize)] +#[serde(untagged)] +pub enum CarnotRecord { + Runtime(Runtime), + Settings(Box), + Data(Vec>), +} + +impl From for CarnotRecord { + fn from(value: Runtime) -> Self { + Self::Runtime(value) + } +} + +impl From for CarnotRecord { + fn from(value: SimulationSettings) -> Self { + Self::Settings(Box::new(value)) + } +} + +impl Record for CarnotRecord { + type Data = CarnotState; + + fn record_type(&self) -> RecordType { + match self { + CarnotRecord::Runtime(_) => RecordType::Meta, + CarnotRecord::Settings(_) => RecordType::Settings, + CarnotRecord::Data(_) => RecordType::Data, + } + } + + fn data(&self) -> Vec<&CarnotState> { + match self { + CarnotRecord::Data(d) => d.iter().map(AsRef::as_ref).collect(), + _ => vec![], + } + } +} + +impl TryFrom<&SimulationState> for CarnotRecord { + type Error = anyhow::Error; + + fn try_from(state: &SimulationState) -> Result { + let Ok(states) = state + .nodes + .read() + .iter() + .map(|n| Box::::downcast(Box::new(n.state().clone()))) + .collect::, _>>() + else { + return Err(anyhow::anyhow!("use carnot record on other node")); + }; + Ok(Self::Data(states)) + } +} + +impl serde::Serialize for CarnotState { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + if let Some(rs) = RECORD_SETTINGS.get() { + let keys = rs + .iter() + .filter_map(|(k, v)| { + if serde_util::CARNOT_RECORD_KEYS.contains(&k.trim()) && *v { + Some(k) + } else { + None + } + }) + .collect::>(); + + match self.format { + SubscriberFormat::Json => serde_util::CarnotStateJsonSerializer::default() + .serialize_state(keys, self, serializer), + SubscriberFormat::Csv => serde_util::CarnotStateCsvSerializer::default() + .serialize_state(keys, self, serializer), + SubscriberFormat::Parquet => unreachable!(), + } + } else { + serializer.serialize_none() + } + } +} + +impl CarnotState { + const fn keys() -> &'static [&'static str] { + serde_util::CARNOT_RECORD_KEYS + } +} + +impl From<&Carnot> for CarnotState { + fn from(value: &Carnot) -> Self { + let node_id = value.id(); + let current_view = value.current_view(); + Self { + node_id, + current_view, + local_high_qc: value.high_qc(), + parent_committee: value.parent_committee(), + root_committee: value.root_committee(), + child_committees: value.child_committees(), + latest_committed_block: value.latest_committed_block(), + latest_committed_view: value.latest_committed_view(), + safe_blocks: value + .blocks_in_view(current_view) + .into_iter() + .map(|b| (b.id, b)) + .collect(), + last_view_timeout_qc: value.last_view_timeout_qc(), + committed_blocks: value.latest_committed_blocks(), + highest_voted_view: Default::default(), + step_duration: Default::default(), + format: SubscriberFormat::Csv, + step_id: 0, + } + } +} diff --git a/simulations/src/node/dummy.rs b/simulations/src/node/dummy.rs index 1f88a1de..24a01482 100644 --- a/simulations/src/node/dummy.rs +++ b/simulations/src/node/dummy.rs @@ -13,14 +13,14 @@ use crate::{ use super::{CommitteeId, OverlayGetter, OverlayState, SharedState, ViewOverlay}; -#[derive(Debug, Default, Serialize)] +#[derive(Debug, Default, Clone, Serialize)] pub struct DummyState { pub current_view: View, pub message_count: usize, pub view_state: BTreeMap, } -#[derive(Debug, Default, Serialize)] +#[derive(Debug, Default, Copy, Clone, Serialize)] pub struct DummyViewState { proposal_received: bool, vote_received_count: usize, diff --git a/simulations/src/output_processors/mod.rs b/simulations/src/output_processors/mod.rs index 8fb4d8c2..aa3cf1d6 100644 --- a/simulations/src/output_processors/mod.rs +++ b/simulations/src/output_processors/mod.rs @@ -14,6 +14,8 @@ pub enum RecordType { } pub trait Record: From + From + Send + Sync + 'static { + type Data: serde::Serialize; + fn record_type(&self) -> RecordType; fn is_settings(&self) -> bool { @@ -27,6 +29,8 @@ pub trait Record: From + From + Send + Sync + 'stat fn is_data(&self) -> bool { self.record_type() == RecordType::Data } + + fn data(&self) -> Vec<&Self::Data>; } pub type SerializedNodeState = serde_json::Value; @@ -79,6 +83,8 @@ impl From for OutData { } impl Record for OutData { + type Data = SerializedNodeState; + fn record_type(&self) -> RecordType { match self { Self::Runtime(_) => RecordType::Meta, @@ -86,6 +92,13 @@ impl Record for OutData { Self::Data(_) => RecordType::Data, } } + + fn data(&self) -> Vec<&SerializedNodeState> { + match self { + Self::Data(d) => vec![d], + _ => unreachable!(), + } + } } impl OutData { @@ -95,7 +108,7 @@ impl OutData { } } -impl TryFrom<&SimulationState> for OutData { +impl TryFrom<&SimulationState> for OutData { type Error = anyhow::Error; fn try_from(state: &SimulationState) -> Result { diff --git a/simulations/src/runner/async_runner.rs b/simulations/src/runner/async_runner.rs index 50bcc022..4cc5ddcb 100644 --- a/simulations/src/runner/async_runner.rs +++ b/simulations/src/runner/async_runner.rs @@ -53,7 +53,7 @@ where .write() .par_iter_mut() .filter(|n| ids.contains(&n.id())) - .for_each(|node|node.step(step_time)); + .for_each(|node| node.step(step_time)); p.send(R::try_from( &simulation_state, diff --git a/simulations/src/runner/mod.rs b/simulations/src/runner/mod.rs index c37f1312..1a7f564e 100644 --- a/simulations/src/runner/mod.rs +++ b/simulations/src/runner/mod.rs @@ -113,7 +113,7 @@ where + Sync + 'static, S: 'static, - T: Serialize + 'static, + T: Serialize + Clone + 'static, { pub fn new( network: Network, @@ -191,7 +191,7 @@ where + Sync + 'static, S: 'static, - T: Serialize + 'static, + T: Serialize + Clone + 'static, { pub fn simulate_and_subscribe( self, diff --git a/simulations/src/settings.rs b/simulations/src/settings.rs index 58fedfc9..8f98aa3b 100644 --- a/simulations/src/settings.rs +++ b/simulations/src/settings.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::BTreeMap; use crate::network::NetworkSettings; use crate::streaming::StreamSettings; @@ -53,7 +53,7 @@ pub struct SimulationSettings { #[serde(default)] pub wards: Vec, #[serde(default)] - pub record_settings: HashMap, + pub record_settings: BTreeMap, pub network_settings: NetworkSettings, pub overlay_settings: OverlaySettings, pub node_settings: NodeSettings, diff --git a/simulations/src/streaming/mod.rs b/simulations/src/streaming/mod.rs index 1acd7c1a..2ca6d79f 100644 --- a/simulations/src/streaming/mod.rs +++ b/simulations/src/streaming/mod.rs @@ -15,6 +15,57 @@ pub mod polars; pub mod runtime_subscriber; pub mod settings_subscriber; +#[derive(Debug, Default, Clone, Copy, Serialize, PartialEq, Eq)] +pub enum SubscriberFormat { + Json, + #[default] + Csv, + Parquet, +} + +impl SubscriberFormat { + pub const fn csv() -> Self { + Self::Csv + } + + pub const fn json() -> Self { + Self::Json + } + + pub const fn parquet() -> Self { + Self::Parquet + } + + pub fn is_csv(&self) -> bool { + matches!(self, Self::Csv) + } +} + +impl FromStr for SubscriberFormat { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.trim().to_ascii_lowercase().as_str() { + "json" => Ok(Self::Json), + "csv" => Ok(Self::Csv), + "parquet" => Ok(Self::Parquet), + tag => Err(format!( + "Invalid {tag} format, only [json, csv, parquet] are supported", + )), + } + } +} + +impl<'de> Deserialize<'de> for SubscriberFormat { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + SubscriberFormat::from_str(&s).map_err(serde::de::Error::custom) + } +} + pub enum SubscriberType { Meta, Settings, diff --git a/simulations/src/streaming/naive.rs b/simulations/src/streaming/naive.rs index 5ff33efc..6b118ce6 100644 --- a/simulations/src/streaming/naive.rs +++ b/simulations/src/streaming/naive.rs @@ -1,18 +1,23 @@ -use super::{Receivers, StreamSettings, Subscriber}; -use crate::output_processors::{RecordType, Runtime}; +use super::{Receivers, StreamSettings, Subscriber, SubscriberFormat}; +use crate::output_processors::{Record, RecordType, Runtime}; use crossbeam::channel::{Receiver, Sender}; use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::{ fs::{File, OpenOptions}, - io::Write, + io::{Seek, Write}, path::PathBuf, - sync::Arc, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, }; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NaiveSettings { pub path: PathBuf, + #[serde(default = "SubscriberFormat::csv")] + pub format: SubscriberFormat, } impl TryFrom for NaiveSettings { @@ -31,14 +36,19 @@ impl Default for NaiveSettings { let mut tmp = std::env::temp_dir(); tmp.push("simulation"); tmp.set_extension("data"); - Self { path: tmp } + Self { + path: tmp, + format: SubscriberFormat::Csv, + } } } #[derive(Debug)] pub struct NaiveSubscriber { - file: Arc>, - recvs: Arc>, + file: Mutex, + recvs: Receivers, + initialized: AtomicBool, + format: SubscriberFormat, } impl Subscriber for NaiveSubscriber @@ -63,14 +73,16 @@ where recv: record_recv, }; let this = NaiveSubscriber { - file: Arc::new(Mutex::new( + file: Mutex::new( opts.truncate(true) .create(true) .read(true) .write(true) .open(&settings.path)?, - )), - recvs: Arc::new(recvs), + ), + recvs, + initialized: AtomicBool::new(false), + format: settings.format, }; tracing::info!( target = "simulation", @@ -107,8 +119,18 @@ where fn sink(&self, state: Arc) -> anyhow::Result<()> { let mut file = self.file.lock(); - serde_json::to_writer(&mut *file, &state)?; - file.write_all(b",\n")?; + match self.format { + SubscriberFormat::Json => { + write_json_record(&mut *file, &self.initialized, &*state)?; + } + SubscriberFormat::Csv => { + write_csv_record(&mut *file, &self.initialized, &*state)?; + } + SubscriberFormat::Parquet => { + panic!("native subscriber does not support parquet format") + } + } + Ok(()) } @@ -117,6 +139,59 @@ where } } +impl Drop for NaiveSubscriber { + fn drop(&mut self) { + if SubscriberFormat::Json == self.format { + let mut file = self.file.lock(); + // To construct a valid json format, we need to overwrite the last comma + if let Err(e) = file + .seek(std::io::SeekFrom::End(-1)) + .and_then(|_| file.write_all(b"]}")) + { + tracing::error!(target="simulations", err=%e, "fail to close json format"); + } + } + } +} + +fn write_json_record( + mut w: W, + initialized: &AtomicBool, + record: &R, +) -> std::io::Result<()> { + if !initialized.load(Ordering::Acquire) { + w.write_all(b"{\"records\": [")?; + initialized.store(true, Ordering::Release); + } + for data in record.data() { + serde_json::to_writer(&mut w, data)?; + w.write_all(b",")?; + } + Ok(()) +} + +fn write_csv_record( + w: &mut W, + initialized: &AtomicBool, + record: &R, +) -> csv::Result<()> { + // If have not write csv header, then write it + let mut w = if !initialized.load(Ordering::Acquire) { + initialized.store(true, Ordering::Release); + csv::WriterBuilder::new().has_headers(true).from_writer(w) + } else { + csv::WriterBuilder::new().has_headers(false).from_writer(w) + }; + for data in record.data() { + w.serialize(data).map_err(|e| { + tracing::error!(target = "simulations", err = %e, "fail to write CSV record"); + e + })?; + w.flush()?; + } + Ok(()) +} + #[cfg(test)] mod tests { use std::{collections::HashMap, time::Duration}; diff --git a/simulations/src/streaming/polars.rs b/simulations/src/streaming/polars.rs index 249e1778..20e2c26a 100644 --- a/simulations/src/streaming/polars.rs +++ b/simulations/src/streaming/polars.rs @@ -1,4 +1,4 @@ -use super::{Receivers, StreamSettings}; +use super::{Receivers, StreamSettings, SubscriberFormat}; use crate::output_processors::{RecordType, Runtime}; use crossbeam::channel::{Receiver, Sender}; use parking_lot::Mutex; @@ -8,44 +8,11 @@ use std::{ fs::File, io::Cursor, path::{Path, PathBuf}, - str::FromStr, }; -#[derive(Debug, Clone, Copy, Serialize)] -pub enum PolarsFormat { - Json, - Csv, - Parquet, -} - -impl FromStr for PolarsFormat { - type Err = String; - - fn from_str(s: &str) -> Result { - match s.trim().to_ascii_lowercase().as_str() { - "json" => Ok(Self::Json), - "csv" => Ok(Self::Csv), - "parquet" => Ok(Self::Parquet), - tag => Err(format!( - "Invalid {tag} format, only [json, csv, parquet] are supported", - )), - } - } -} - -impl<'de> Deserialize<'de> for PolarsFormat { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - PolarsFormat::from_str(&s).map_err(serde::de::Error::custom) - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PolarsSettings { - pub format: PolarsFormat, + pub format: SubscriberFormat, #[serde(skip_serializing_if = "Option::is_none")] pub path: Option, } @@ -63,10 +30,10 @@ impl TryFrom for PolarsSettings { #[derive(Debug)] pub struct PolarsSubscriber { - data: Arc>>>, + data: Mutex>>, path: PathBuf, - format: PolarsFormat, - recvs: Arc>, + format: SubscriberFormat, + recvs: Receivers, } impl PolarsSubscriber @@ -83,9 +50,9 @@ where data.unnest(["state"])?; match self.format { - PolarsFormat::Json => dump_dataframe_to_json(&mut data, self.path.as_path()), - PolarsFormat::Csv => dump_dataframe_to_csv(&mut data, self.path.as_path()), - PolarsFormat::Parquet => dump_dataframe_to_parquet(&mut data, self.path.as_path()), + SubscriberFormat::Json => dump_dataframe_to_json(&mut data, self.path.as_path()), + SubscriberFormat::Csv => dump_dataframe_to_csv(&mut data, self.path.as_path()), + SubscriberFormat::Parquet => dump_dataframe_to_parquet(&mut data, self.path.as_path()), } } } @@ -110,14 +77,14 @@ where recv: record_recv, }; let this = PolarsSubscriber { - data: Arc::new(Mutex::new(Vec::new())), - recvs: Arc::new(recvs), + data: Mutex::new(Vec::new()), + recvs, path: settings.path.clone().unwrap_or_else(|| { let mut p = std::env::temp_dir().join("polars"); match settings.format { - PolarsFormat::Json => p.set_extension("json"), - PolarsFormat::Csv => p.set_extension("csv"), - PolarsFormat::Parquet => p.set_extension("parquet"), + SubscriberFormat::Json => p.set_extension("json"), + SubscriberFormat::Csv => p.set_extension("csv"), + SubscriberFormat::Parquet => p.set_extension("parquet"), }; p }),