From 7124639b601819282c58f565608dab5e664a87c0 Mon Sep 17 00:00:00 2001 From: Giacomo Pasini Date: Tue, 25 Oct 2022 11:34:53 +0200 Subject: [PATCH] Rework Relay to expose OutboundRelay on connect (#2) * Rework Relay to expose OutboundRelay on connect This simplifies the implementation and enforces a correct usage of the channel by exposing send methods only after a successful connection. * update tests --- examples/waku-chat/src/chat.rs | 11 ++-- overwatch/src/services/relay.rs | 106 +++++++++++-------------------- overwatch/tests/print_service.rs | 2 +- 3 files changed, 45 insertions(+), 74 deletions(-) diff --git a/examples/waku-chat/src/chat.rs b/examples/waku-chat/src/chat.rs index b322294..f389743 100644 --- a/examples/waku-chat/src/chat.rs +++ b/examples/waku-chat/src/chat.rs @@ -1,7 +1,7 @@ use crate::network::*; use async_trait::async_trait; use overwatch::services::handle::ServiceStateHandle; -use overwatch::services::relay::{NoMessage, Relay}; +use overwatch::services::relay::{NoMessage, OutboundRelay}; use overwatch::services::state::{NoOperator, NoState}; use overwatch::services::{ServiceCore, ServiceData, ServiceId}; use serde::{Deserialize, Serialize}; @@ -38,10 +38,13 @@ impl ServiceCore for ChatService { mut service_state, .. } = self; // TODO: waku should not end up in the public interface of the network service, at least not as a type - let mut network_relay: Relay> = - service_state.overwatch_handle.relay(); + let mut network_relay = service_state + .overwatch_handle + .relay::>() + .connect() + .await + .unwrap(); let user = service_state.settings_reader.get_updated_settings(); - network_relay.connect().await.unwrap(); let (sender, mut receiver) = channel(1); // TODO: typestate so I can't call send if it's not connected network_relay diff --git a/overwatch/src/services/relay.rs b/overwatch/src/services/relay.rs index 89815f8..87fefd0 100644 --- a/overwatch/src/services/relay.rs +++ b/overwatch/src/services/relay.rs @@ -1,6 +1,7 @@ // std use std::any::Any; use std::fmt::Debug; +use std::marker::PhantomData; // crates use thiserror::Error; use tokio::sync::mpsc::{channel, Receiver, Sender}; @@ -47,21 +48,8 @@ pub type RelayResult = Result; /// Notice that it is bound to 'static. pub trait RelayMessage: 'static {} -enum RelayState { - Disconnected, - Connected(OutboundRelay), -} - -impl Clone for RelayState { - fn clone(&self) -> Self { - match self { - RelayState::Disconnected => RelayState::Disconnected, - RelayState::Connected(outbound) => RelayState::Connected(outbound.clone()), - } - } -} - /// Channel receiver of a relay connection +#[derive(Debug)] pub struct InboundRelay { receiver: Receiver, _stats: (), // placeholder @@ -73,20 +61,30 @@ pub struct OutboundRelay { _stats: (), // placeholder } +#[derive(Debug)] pub struct Relay { - state: RelayState, + _marker: PhantomData, overwatch_handle: OverwatchHandle, } impl Clone for Relay { fn clone(&self) -> Self { Self { - state: self.state.clone(), + _marker: PhantomData, overwatch_handle: self.overwatch_handle.clone(), } } } +impl Clone for OutboundRelay { + fn clone(&self) -> Self { + Self { + sender: self.sender.clone(), + _stats: (), + } + } +} + // TODO: make buffer_size const? /// Relay channel builder pub fn relay(buffer_size: usize) -> (InboundRelay, OutboundRelay) { @@ -109,71 +107,44 @@ impl InboundRelay { impl OutboundRelay { /// Send a message to the relay connection - pub async fn send(&mut self, message: M) -> Result<(), (RelayError, M)> { + pub async fn send(&self, message: M) -> Result<(), (RelayError, M)> { self.sender .send(message) .await .map_err(|e| (RelayError::Send, e.0)) } -} -impl Clone for OutboundRelay { - fn clone(&self) -> Self { - Self { - sender: self.sender.clone(), - _stats: (), - } + /// Send a message to the relay connection in a blocking fashion. + /// + /// The intended usage of this function is for sending data from + /// synchronous code to asynchronous code. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution + /// context. + /// + /// # Exa + pub fn blocking_send(&self, message: M) -> Result<(), (RelayError, M)> { + self.sender + .blocking_send(message) + .map_err(|e| (RelayError::Send, e.0)) } } impl Relay { pub fn new(overwatch_handle: OverwatchHandle) -> Self { Self { - state: RelayState::Disconnected, overwatch_handle, + _marker: PhantomData, } } #[instrument(skip(self), err(Debug))] - pub async fn connect(&mut self) -> Result<(), RelayError> { - if let RelayState::Disconnected = self.state { - let (reply, receiver) = oneshot::channel(); - self.request_relay(reply).await; - self.handle_relay_response(receiver).await - } else { - Err(RelayError::AlreadyConnected) - } - } - - #[instrument(skip(self), err(Debug))] - pub fn disconnect(&mut self) -> Result<(), RelayError> { - self.state = RelayState::Disconnected; - Ok(()) - } - - #[instrument(skip_all, err(Debug))] - pub async fn send(&mut self, message: S::Message) -> Result<(), RelayError> { - // TODO: we could make a retry system and/or add timeouts - if let RelayState::Connected(outbound_relay) = &mut self.state { - outbound_relay - .send(message) - .await - .map_err(|(e, _message)| e) - } else { - Err(RelayError::Disconnected) - } - } - - #[instrument(skip_all, err(Debug))] - pub fn blocking_send(&mut self, message: S::Message) -> Result<(), RelayError> { - if let RelayState::Connected(outbound_relay) = &mut self.state { - outbound_relay - .sender - .blocking_send(message) - .map_err(|_| RelayError::Send) - } else { - Err(RelayError::Disconnected) - } + pub async fn connect(&mut self) -> Result, RelayError> { + let (reply, receiver) = oneshot::channel(); + self.request_relay(reply).await; + self.handle_relay_response(receiver).await } async fn request_relay(&mut self, reply: oneshot::Sender) { @@ -188,14 +159,11 @@ impl Relay { async fn handle_relay_response( &mut self, receiver: oneshot::Receiver, - ) -> Result<(), RelayError> { + ) -> Result, RelayError> { let response = receiver.await; match response { Ok(Ok(message)) => match message.downcast::>() { - Ok(channel) => { - self.state = RelayState::Connected(*channel); - Ok(()) - } + Ok(channel) => Ok(*channel), Err(m) => Err(RelayError::InvalidMessage { type_id: format!("{:?}", m.type_id()), service_id: S::SERVICE_ID, diff --git a/overwatch/tests/print_service.rs b/overwatch/tests/print_service.rs index c6ddab3..9153c0e 100644 --- a/overwatch/tests/print_service.rs +++ b/overwatch/tests/print_service.rs @@ -90,7 +90,7 @@ fn derive_print_service() { let mut print_service_relay = handle.relay::(); overwatch.spawn(async move { - print_service_relay + let print_service_relay = print_service_relay .connect() .await .expect("A connection to the print service is established");