diff --git a/Cargo.toml b/Cargo.toml
index 5afe29a..f15e6fb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,5 @@
 [workspace]
-
+resolver = "2"
 members = [
     "overwatch-rs",
     "overwatch-derive",
diff --git a/overwatch-derive/src/lib.rs b/overwatch-derive/src/lib.rs
index 3858db0..8ec3430 100644
--- a/overwatch-derive/src/lib.rs
+++ b/overwatch-derive/src/lib.rs
@@ -180,16 +180,14 @@ fn generate_start_all_impl(fields: &Punctuated<Field, Comma>) -> proc_macro2::To
     let call_start = fields.iter().map(|field| {
         let field_identifier = field.ident.as_ref().expect("A struct attribute identifier");
         quote! {
-            self.#field_identifier.service_runner().run()?;
+            self.#field_identifier.service_runner().run()?
         }
     });
 
     quote! {
         #[::tracing::instrument(skip(self), err)]
-        fn start_all(&mut self) -> Result<(), ::overwatch_rs::overwatch::Error> {
-            #( #call_start )*
-
-            ::std::result::Result::Ok(())
+        fn start_all(&mut self) -> Result<::overwatch_rs::overwatch::ServicesLifeCycleHandle, ::overwatch_rs::overwatch::Error> {
+            ::std::result::Result::Ok([#( #call_start ),*].try_into()?)
         }
     }
 }
diff --git a/overwatch-rs/Cargo.toml b/overwatch-rs/Cargo.toml
index 49eb858..8a58460 100644
--- a/overwatch-rs/Cargo.toml
+++ b/overwatch-rs/Cargo.toml
@@ -29,7 +29,7 @@ color-eyre = "0.6"
 async-trait = "0.1"
 futures = "0.3"
 thiserror = "1.0"
-tokio = { version = "1.17", features = ["rt-multi-thread", "sync", "time"] }
+tokio = { version = "1.32", features = ["rt-multi-thread", "sync", "time"] }
 tokio-stream = {version ="0.1", features = ["sync"] }
 tokio-util = "0.7"
 tracing = "0.1"
diff --git a/overwatch-rs/src/overwatch/commands.rs b/overwatch-rs/src/overwatch/commands.rs
index a5989f8..d2b11f8 100644
--- a/overwatch-rs/src/overwatch/commands.rs
+++ b/overwatch-rs/src/overwatch/commands.rs
@@ -2,6 +2,7 @@
 
 // crates
 use crate::overwatch::AnySettings;
+use crate::services::life_cycle::LifecycleMessage;
 use tokio::sync::oneshot;
 
 // internal
@@ -33,18 +34,9 @@ pub struct RelayCommand {
 /// Command for managing [`ServiceCore`](crate::services::ServiceCore) lifecycle
 #[allow(unused)]
 #[derive(Debug)]
-pub struct ServiceLifeCycle<T> {
-    service_id: ServiceId,
-    reply_channel: ReplyChannel<T>,
-}
-
-/// [`ServiceCore`](crate::services::ServiceCore) lifecycle related commands
-#[derive(Debug)]
-pub enum ServiceLifeCycleCommand {
-    Shutdown(ServiceLifeCycle<()>),
-    Kill(ServiceLifeCycle<()>),
-    Start(ServiceLifeCycle<()>),
-    Stop(ServiceLifeCycle<()>),
+pub struct ServiceLifeCycleCommand {
+    pub service_id: ServiceId,
+    pub msg: LifecycleMessage,
 }
 
 /// [`Overwatch`](crate::overwatch::Overwatch) lifecycle related commands
diff --git a/overwatch-rs/src/overwatch/life_cycle.rs b/overwatch-rs/src/overwatch/life_cycle.rs
new file mode 100644
index 0000000..00d0240
--- /dev/null
+++ b/overwatch-rs/src/overwatch/life_cycle.rs
@@ -0,0 +1,87 @@
+// std
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::default::Default;
+use std::error::Error;
+// crates
+use tokio::sync::broadcast::Sender;
+// internal
+use crate::services::life_cycle::{FinishedSignal, LifecycleHandle, LifecycleMessage};
+use crate::services::ServiceId;
+use crate::DynError;
+
+/// Grouping handle for the `LifecycleHandle` of each spawned service.
+#[derive(Clone)]
+pub struct ServicesLifeCycleHandle {
+    handlers: HashMap<ServiceId, LifecycleHandle>,
+}
+
+impl ServicesLifeCycleHandle {
+    pub fn empty() -> Self {
+        Self {
+            handlers: Default::default(),
+        }
+    }
+
+    /// Send a `Shutdown` message to the specified service
+    ///
+    /// # Arguments
+    ///
+    /// `service` - The `ServiceId` of the target service
+    /// `sender` - The sender side of a broadcast channel. A signal is sent through it once the
+    /// message has been handled.
+    pub fn shutdown(
+        &self,
+        service: ServiceId,
+        sender: Sender<FinishedSignal>,
+    ) -> Result<(), DynError> {
+        self.handlers
+            .get(service)
+            .unwrap()
+            .send(LifecycleMessage::Shutdown(sender))?;
+        Ok(())
+    }
+
+    /// Send a `Kill` message to the specified service (`ServiceId`)
+    ///
+    /// # Arguments
+    ///
+    /// `service` - The `ServiceId` of the target service
+    pub fn kill(&self, service: ServiceId) -> Result<(), DynError> {
+        self.handlers
+            .get(service)
+            .unwrap()
+            .send(LifecycleMessage::Kill)
+    }
+
+    /// Send a `Kill` message to all services registered in this handle
+    pub fn kill_all(&self) -> Result<(), DynError> {
+        for service_id in self.services_ids() {
+            self.kill(service_id)?;
+        }
+        Ok(())
+    }
+
+    /// Get all service ids registered in this handle
+    pub fn services_ids(&self) -> impl Iterator<Item = ServiceId> + '_ {
+        self.handlers.keys().copied()
+    }
+}
+
+impl<const N: usize> TryFrom<[(ServiceId, LifecycleHandle); N]> for ServicesLifeCycleHandle {
+    // TODO: Refactor errors into a concrete error type with `thiserror`
+    type Error = Box<dyn Error + Send + Sync>;
+
+    fn try_from(value: [(ServiceId, LifecycleHandle); N]) -> Result<Self, Self::Error> {
+        let mut handlers = HashMap::new();
+        for (service_id, handle) in value {
+            if handlers.contains_key(service_id) {
+                return Err(Box::<dyn Error + Send + Sync>::from(Cow::Owned(format!(
+                    "Duplicated serviceId: {service_id}"
+                ))));
+            }
+            handlers.insert(service_id, handle);
+        }
+        Ok(Self { handlers })
+    }
+}
diff --git a/overwatch-rs/src/overwatch/mod.rs b/overwatch-rs/src/overwatch/mod.rs
index b1e9555..02c1134 100644
--- a/overwatch-rs/src/overwatch/mod.rs
+++ b/overwatch-rs/src/overwatch/mod.rs
@@ -1,5 +1,6 @@
 pub mod commands;
 pub mod handle;
+pub mod life_cycle;
 
 // std
 use std::any::Any;
@@ -14,14 +15,16 @@ use tokio::runtime::{Handle, Runtime};
 use tokio::sync::mpsc::Receiver;
 use tokio::sync::oneshot;
 use tokio::task::JoinHandle;
-use tracing::{info, instrument};
+use tracing::{error, info, instrument};
 
 // internal
-
 use crate::overwatch::commands::{
-    OverwatchCommand, OverwatchLifeCycleCommand, RelayCommand, SettingsCommand,
+    OverwatchCommand, OverwatchLifeCycleCommand, RelayCommand, ServiceLifeCycleCommand,
+    SettingsCommand,
 };
 use crate::overwatch::handle::OverwatchHandle;
+pub use crate::overwatch::life_cycle::ServicesLifeCycleHandle;
+use crate::services::life_cycle::LifecycleMessage;
 use crate::services::relay::RelayResult;
 use crate::services::{ServiceError, ServiceId};
 use crate::utils::runtime::default_multithread_runtime;
@@ -79,7 +82,7 @@ pub trait Services: Sized {
 
     // TODO: this probably will be removed once the services lifecycle is implemented
     /// Start all services attached to the trait implementer
-    fn start_all(&mut self) -> Result<(), Error>;
+    fn start_all(&mut self) -> Result<ServicesLifeCycleHandle, Error>;
 
     /// Stop a service attached to the trait implementer
     fn stop(&mut self, service_id: ServiceId) -> Result<(), Error>;
@@ -124,12 +127,20 @@ where
         let (commands_sender, commands_receiver) = tokio::sync::mpsc::channel(16);
         let handle = OverwatchHandle::new(runtime.handle().clone(), commands_sender);
         let services = S::new(settings, handle.clone())?;
-        let runner = OverwatchRunner {
+        let mut runner = OverwatchRunner {
             services,
             handle: handle.clone(),
             finish_signal_sender,
         };
-        runtime.spawn(async move { runner.run_(commands_receiver).await });
+
+        let lifecycle_handlers = runner.services.start_all()?;
+
+        runtime.spawn(async move {
+            runner
+                .run_(commands_receiver, lifecycle_handlers.clone())
+                .await
+        });
+
         Ok(Overwatch {
             runtime,
             handle,
@@ -138,28 +149,48 @@ where
     }
 
     #[instrument(name = "overwatch-run", skip_all)]
-    async fn run_(self, mut receiver: Receiver<OverwatchCommand>) {
+    async fn run_(
+        self,
+        mut receiver: Receiver<OverwatchCommand>,
+        lifecycle_handlers: ServicesLifeCycleHandle,
+    ) {
         let Self {
             mut services,
             handle: _,
             finish_signal_sender,
         } = self;
-        // TODO: this probably need to be manually done, or at least handled by a flag
-        services.start_all().expect("Services to start running");
         while let Some(command) = receiver.recv().await {
             info!(command = ?command, "Overwatch command received");
             match command {
                 OverwatchCommand::Relay(relay_command) => {
                     Self::handle_relay(&mut services, relay_command).await;
                 }
-                OverwatchCommand::ServiceLifeCycle(_) => {
-                    unimplemented!("Services life cycle is still not supported!");
-                }
+                OverwatchCommand::ServiceLifeCycle(msg) => match msg {
+                    ServiceLifeCycleCommand {
+                        service_id,
+                        msg: LifecycleMessage::Shutdown(channel),
+                    } => {
+                        if let Err(e) = lifecycle_handlers.shutdown(service_id, channel) {
+                            error!(e);
+                        }
+                    }
+                    ServiceLifeCycleCommand {
+                        service_id,
+                        msg: LifecycleMessage::Kill,
+                    } => {
+                        if let Err(e) = lifecycle_handlers.kill(service_id) {
+                            error!(e);
+                        }
+                    }
+                },
                 OverwatchCommand::OverwatchLifeCycle(command) => {
                     if matches!(
                         command,
                         OverwatchLifeCycleCommand::Kill | OverwatchLifeCycleCommand::Shutdown
                     ) {
+                        if let Err(e) = lifecycle_handlers.kill_all() {
+                            error!(e);
+                        }
                         break;
                     }
                 }
@@ -216,7 +247,7 @@ impl Overwatch {
         &self.handle
     }
 
-    /// Get the underllaying tokio runtime handle
+    /// Get the underlying tokio runtime handle
     pub fn runtime(&self) -> &Handle {
         self.runtime.handle()
     }
@@ -247,7 +278,7 @@ impl Overwatch {
 #[cfg(test)]
 mod test {
     use crate::overwatch::handle::OverwatchHandle;
-    use crate::overwatch::{Error, OverwatchRunner, Services};
+    use crate::overwatch::{Error, OverwatchRunner, Services, ServicesLifeCycleHandle};
     use crate::services::relay::{RelayError, RelayResult};
     use crate::services::ServiceId;
     use std::time::Duration;
@@ -269,8 +300,8 @@ mod test {
             Err(Error::Unavailable { service_id })
         }
 
-        fn start_all(&mut self) -> Result<(), Error> {
-            Ok(())
+        fn start_all(&mut self) -> Result<ServicesLifeCycleHandle, Error> {
+            Ok(ServicesLifeCycleHandle::empty())
         }
 
         fn stop(&mut self, service_id: ServiceId) -> Result<(), Error> {
diff --git a/overwatch-rs/src/services/handle.rs b/overwatch-rs/src/services/handle.rs
index a5085dc..851bbe8 100644
--- a/overwatch-rs/src/services/handle.rs
+++ b/overwatch-rs/src/services/handle.rs
@@ -1,14 +1,14 @@
 // crates
-use futures::future::{abortable, AbortHandle};
 use tokio::runtime::Handle;
 // internal
 use crate::overwatch::handle::OverwatchHandle;
+use crate::services::life_cycle::LifecycleHandle;
 use crate::services::relay::{relay, InboundRelay, OutboundRelay};
 use crate::services::settings::{SettingsNotifier, SettingsUpdater};
 use crate::services::state::{StateHandle, StateOperator, StateUpdater};
 use crate::services::{ServiceCore, ServiceData, ServiceId, ServiceState};
 
-// TODO: Abstract handle over state, to diferentiate when the service is running and when it is not
+// TODO: Abstract handle over state, to differentiate when the service is running and when it is not
 // that way we can expose a better API depending on what is happenning. Would get rid of the probably
 // unnecessary Option and cloning.
 /// Service handle
@@ -33,7 +33,7 @@ pub struct ServiceStateHandle<S: ServiceData> {
     pub overwatch_handle: OverwatchHandle,
     pub settings_reader: SettingsNotifier<S::Settings>,
     pub state_updater: StateUpdater<S::State>,
-    pub _lifecycle_handler: (),
+    pub lifecycle_handle: LifecycleHandle,
 }
 
 /// Main service executor
@@ -41,6 +41,7 @@ pub struct ServiceStateHandle<S: ServiceData> {
 pub struct ServiceRunner<S: ServiceData> {
     service_state: ServiceStateHandle<S>,
     state_handle: StateHandle<S::State, S::StateOperator>,
+    lifecycle_handle: LifecycleHandle,
 }
 
 impl<S: ServiceData> ServiceHandle<S> {
@@ -94,17 +95,20 @@ impl<S: ServiceData> ServiceHandle<S> {
         let (state_handle, state_updater) =
             StateHandle::<S::State, S::StateOperator>::new(self.initial_state.clone(), operator);
 
+        let lifecycle_handle = LifecycleHandle::new();
+
         let service_state = ServiceStateHandle {
             inbound_relay,
             overwatch_handle: self.overwatch_handle.clone(),
             state_updater,
             settings_reader,
-            _lifecycle_handler: (),
+            lifecycle_handle: lifecycle_handle.clone(),
         };
 
         ServiceRunner {
             service_state,
             state_handle,
+            lifecycle_handle,
         }
     }
 }
@@ -124,22 +128,19 @@ where
     /// Spawn the service main loop and handle it lifecycle
     /// Return a handle to abort execution manually
-    pub fn run(self) -> Result<AbortHandle, crate::DynError> {
+    pub fn run(self) -> Result<(ServiceId, LifecycleHandle), crate::DynError> {
         let ServiceRunner {
             service_state,
             state_handle,
-            ..
+            lifecycle_handle,
         } = self;
 
         let runtime = service_state.overwatch_handle.runtime().clone();
         let service = S::init(service_state)?;
 
-        let (runner, abortable_handle) = abortable(service.run());
-        runtime.spawn(runner);
+        runtime.spawn(service.run());
         runtime.spawn(state_handle.run());
 
-        // TODO: Handle service lifecycle
-        // TODO: this handle should not scape this scope, it should actually be handled in the lifecycle part mentioned above
-        Ok(abortable_handle)
+        Ok((S::SERVICE_ID, lifecycle_handle))
     }
 }
diff --git a/overwatch-rs/src/services/life_cycle.rs b/overwatch-rs/src/services/life_cycle.rs
index 8b13789..12a306d 100644
--- a/overwatch-rs/src/services/life_cycle.rs
+++ b/overwatch-rs/src/services/life_cycle.rs
@@ -1 +1,73 @@
+use crate::DynError;
+use futures::Stream;
+use std::default::Default;
+use std::error::Error;
+use tokio::sync::broadcast::{channel, Receiver, Sender};
+use tokio_stream::StreamExt;
 
+/// Type alias for an empty signal
+pub type FinishedSignal = ();
+
+/// Supported lifecycle messages
+#[derive(Clone, Debug)]
+pub enum LifecycleMessage {
+    /// Shutdown
+    /// Holds a sender from a broadcast channel. It is intended to signal when the shutdown
+    /// process has finished.
+    Shutdown(Sender<FinishedSignal>),
+    /// Kill
+    /// Well, nothing much to explain here, everything should be about to be nuked.
+    Kill,
+}
+
+/// Handle for lifecycle communications with a `Service`
+pub struct LifecycleHandle {
+    message_channel: Receiver<LifecycleMessage>,
+    notifier: Sender<LifecycleMessage>,
+}
+
+impl Clone for LifecycleHandle {
+    fn clone(&self) -> Self {
+        Self {
+            // `resubscribe` gives us access just to newly produced events, not already enqueued
+            // ones. That is fine: missing a signal means whoever holds the handle was not
+            // interested at the moment it was produced, and most probably was not even alive.
+            message_channel: self.message_channel.resubscribe(),
+            notifier: self.notifier.clone(),
+        }
+    }
+}
+
+impl LifecycleHandle {
+    pub fn new() -> Self {
+        // Use a single lifecycle message at a time. The idea is that all lifecycle computations
+        // should stack, so waiting is effective even if it is somehow reversed later on (for
+        // example for start/stop events).
+        let (notifier, message_channel) = channel(1);
+        Self {
+            notifier,
+            message_channel,
+        }
+    }
+
+    /// Incoming lifecycle message stream
+    /// Notice that messages are not buffered, so different calls to this method could yield
+    /// different incoming messages depending on the timing of the call.
+    pub fn message_stream(&self) -> impl Stream<Item = LifecycleMessage> {
+        tokio_stream::wrappers::BroadcastStream::new(self.message_channel.resubscribe())
+            .filter_map(Result::ok)
+    }
+
+    /// Send a `LifecycleMessage` to the service
+    pub fn send(&self, msg: LifecycleMessage) -> Result<(), DynError> {
+        self.notifier
+            .send(msg)
+            .map(|_| ())
+            .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)
+    }
+}
+
+impl Default for LifecycleHandle {
+    fn default() -> Self {
+        Self::new()
+    }
+}
diff --git a/overwatch-rs/src/services/relay.rs b/overwatch-rs/src/services/relay.rs
index 19e7ab9..bbe2a83 100644
--- a/overwatch-rs/src/services/relay.rs
+++ b/overwatch-rs/src/services/relay.rs
@@ -189,7 +189,7 @@ impl<S: ServiceData> Relay<S> {
             Ok(Ok(message)) => match message.downcast::<OutboundRelay<S::Message>>() {
                 Ok(channel) => Ok(*channel),
                 Err(m) => Err(RelayError::InvalidMessage {
-                    type_id: format!("{:?}", m.type_id()),
+                    type_id: format!("{:?}", (*m).type_id()),
                     service_id: S::SERVICE_ID,
                 }),
             },
diff --git a/overwatch-rs/tests/cancelable_service.rs b/overwatch-rs/tests/cancelable_service.rs
new file mode 100644
index 0000000..e0ab96f
--- /dev/null
+++ b/overwatch-rs/tests/cancelable_service.rs
@@ -0,0 +1,86 @@
+use overwatch_derive::Services;
+use overwatch_rs::overwatch::commands::{OverwatchCommand, ServiceLifeCycleCommand};
+use overwatch_rs::overwatch::OverwatchRunner;
+use overwatch_rs::services::handle::{ServiceHandle, ServiceStateHandle};
+use overwatch_rs::services::life_cycle::LifecycleMessage;
+use overwatch_rs::services::relay::NoMessage;
+use overwatch_rs::services::state::{NoOperator, NoState};
+use overwatch_rs::services::{ServiceCore, ServiceData, ServiceId};
+use overwatch_rs::DynError;
+use std::time::Duration;
+use tokio::time::sleep;
+use tokio_stream::StreamExt;
+
+pub struct CancellableService {
+    service_state: ServiceStateHandle<Self>,
+}
+
+impl ServiceData for CancellableService {
+    const SERVICE_ID: ServiceId = "cancel-me-please";
+    type Settings = ();
+    type State = NoState<Self::Settings>;
+    type StateOperator = NoOperator<Self::State>;
+    type Message = NoMessage;
+}
+
+#[async_trait::async_trait]
+impl ServiceCore for CancellableService {
+    fn init(service_state: ServiceStateHandle<Self>) -> Result<Self, DynError> {
+        Ok(Self { service_state })
+    }
+
+    async fn run(self) -> Result<(), DynError> {
+        let mut lifecycle_stream = self.service_state.lifecycle_handle.message_stream();
+        let mut interval = tokio::time::interval(Duration::from_millis(200));
+        loop {
+            tokio::select! {
+                msg = lifecycle_stream.next() => {
+                    match msg {
+                        Some(LifecycleMessage::Shutdown(reply)) => {
+                            reply.send(()).unwrap();
+                            break;
+                        }
+                        Some(LifecycleMessage::Kill) => {
+                            break;
+                        }
+                        _ => {
+                            unimplemented!();
+                        }
+                    }
+                }
+                _ = interval.tick() => {
+                    println!("Waiting to be killed 💀");
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+#[derive(Services)]
+struct CancelableServices {
+    cancelable: ServiceHandle<CancellableService>,
+}
+
+#[test]
+fn run_overwatch_then_shutdown_service_and_kill() {
+    let settings = CancelableServicesServiceSettings { cancelable: () };
+    let overwatch = OverwatchRunner::<CancelableServices>::run(settings, None).unwrap();
+    let handle = overwatch.handle().clone();
+    let (sender, mut receiver) = tokio::sync::broadcast::channel(1);
+    overwatch.spawn(async move {
+        sleep(Duration::from_millis(500)).await;
+        handle
+            .send(OverwatchCommand::ServiceLifeCycle(
+                ServiceLifeCycleCommand {
+                    service_id: <CancellableService as ServiceData>::SERVICE_ID,
+                    msg: LifecycleMessage::Shutdown(sender),
+                },
+            ))
+            .await;
+        // wait until the service has finished handling the shutdown
+        receiver.recv().await.unwrap();
+        handle.kill().await;
+    });
+    overwatch.wait_finished();
+}
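
For reference (not part of the patch): after the overwatch-derive change above, the generated `start_all` collects the `(ServiceId, LifecycleHandle)` pair returned by each `ServiceRunner::run` into a fixed-size array and converts it into a `ServicesLifeCycleHandle` through the new `TryFrom<[(ServiceId, LifecycleHandle); N]>` impl. A rough sketch of the expansion for a hypothetical app struct with two made-up `ServiceHandle` fields, `foo` and `bar`:

    // Hypothetical expansion for illustration only; the field names `foo` and
    // `bar` are invented, the macro substitutes the real struct fields.
    #[::tracing::instrument(skip(self), err)]
    fn start_all(
        &mut self,
    ) -> Result<::overwatch_rs::overwatch::ServicesLifeCycleHandle, ::overwatch_rs::overwatch::Error> {
        // Each `run()?` yields a `(ServiceId, LifecycleHandle)` pair; `try_into`
        // goes through the `TryFrom` impl, which errors on duplicated service ids.
        ::std::result::Result::Ok([
            self.foo.service_runner().run()?,
            self.bar.service_runner().run()?,
        ]
        .try_into()?)
    }

Because `ServicesLifeCycleHandle::try_from` rejects duplicated keys, `start_all` now fails fast when two services share the same `SERVICE_ID` instead of silently overwriting one handle.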