Rework Relay to expose OutboundRelay on connect (#2)

* Rework Relay to expose OutboundRelay on connect

This simplifies the implementation and enforces a correct usage
of the channel by exposing send methods only after a successful
connection.

* update tests
This commit is contained in:
Giacomo Pasini 2022-10-25 11:34:53 +02:00 committed by GitHub
parent cc20ecc918
commit 7124639b60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 74 deletions

View File

@ -1,7 +1,7 @@
use crate::network::*; use crate::network::*;
use async_trait::async_trait; use async_trait::async_trait;
use overwatch::services::handle::ServiceStateHandle; use overwatch::services::handle::ServiceStateHandle;
use overwatch::services::relay::{NoMessage, Relay}; use overwatch::services::relay::{NoMessage, OutboundRelay};
use overwatch::services::state::{NoOperator, NoState}; use overwatch::services::state::{NoOperator, NoState};
use overwatch::services::{ServiceCore, ServiceData, ServiceId}; use overwatch::services::{ServiceCore, ServiceData, ServiceId};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -38,10 +38,13 @@ impl ServiceCore for ChatService {
mut service_state, .. mut service_state, ..
} = self; } = self;
// TODO: waku should not end up in the public interface of the network service, at least not as a type // TODO: waku should not end up in the public interface of the network service, at least not as a type
let mut network_relay: Relay<NetworkService<waku::Waku>> = let mut network_relay = service_state
service_state.overwatch_handle.relay(); .overwatch_handle
.relay::<NetworkService<waku::Waku>>()
.connect()
.await
.unwrap();
let user = service_state.settings_reader.get_updated_settings(); let user = service_state.settings_reader.get_updated_settings();
network_relay.connect().await.unwrap();
let (sender, mut receiver) = channel(1); let (sender, mut receiver) = channel(1);
// TODO: typestate so I can't call send if it's not connected // TODO: typestate so I can't call send if it's not connected
network_relay network_relay

View File

@ -1,6 +1,7 @@
// std // std
use std::any::Any; use std::any::Any;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData;
// crates // crates
use thiserror::Error; use thiserror::Error;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
@ -47,21 +48,8 @@ pub type RelayResult = Result<AnyMessage, RelayError>;
/// Notice that it is bound to 'static. /// Notice that it is bound to 'static.
pub trait RelayMessage: 'static {} pub trait RelayMessage: 'static {}
enum RelayState<M> {
Disconnected,
Connected(OutboundRelay<M>),
}
impl<M> Clone for RelayState<M> {
fn clone(&self) -> Self {
match self {
RelayState::Disconnected => RelayState::Disconnected,
RelayState::Connected(outbound) => RelayState::Connected(outbound.clone()),
}
}
}
/// Channel receiver of a relay connection /// Channel receiver of a relay connection
#[derive(Debug)]
pub struct InboundRelay<M> { pub struct InboundRelay<M> {
receiver: Receiver<M>, receiver: Receiver<M>,
_stats: (), // placeholder _stats: (), // placeholder
@ -73,20 +61,30 @@ pub struct OutboundRelay<M> {
_stats: (), // placeholder _stats: (), // placeholder
} }
#[derive(Debug)]
pub struct Relay<S: ServiceCore> { pub struct Relay<S: ServiceCore> {
state: RelayState<S::Message>, _marker: PhantomData<S>,
overwatch_handle: OverwatchHandle, overwatch_handle: OverwatchHandle,
} }
impl<S: ServiceCore> Clone for Relay<S> { impl<S: ServiceCore> Clone for Relay<S> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
state: self.state.clone(), _marker: PhantomData,
overwatch_handle: self.overwatch_handle.clone(), overwatch_handle: self.overwatch_handle.clone(),
} }
} }
} }
impl<M> Clone for OutboundRelay<M> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
_stats: (),
}
}
}
// TODO: make buffer_size const? // TODO: make buffer_size const?
/// Relay channel builder /// Relay channel builder
pub fn relay<M>(buffer_size: usize) -> (InboundRelay<M>, OutboundRelay<M>) { pub fn relay<M>(buffer_size: usize) -> (InboundRelay<M>, OutboundRelay<M>) {
@ -109,71 +107,44 @@ impl<M> InboundRelay<M> {
impl<M> OutboundRelay<M> { impl<M> OutboundRelay<M> {
/// Send a message to the relay connection /// Send a message to the relay connection
pub async fn send(&mut self, message: M) -> Result<(), (RelayError, M)> { pub async fn send(&self, message: M) -> Result<(), (RelayError, M)> {
self.sender self.sender
.send(message) .send(message)
.await .await
.map_err(|e| (RelayError::Send, e.0)) .map_err(|e| (RelayError::Send, e.0))
} }
}
impl<M> Clone for OutboundRelay<M> { /// Send a message to the relay connection in a blocking fashion.
fn clone(&self) -> Self { ///
Self { /// The intended usage of this function is for sending data from
sender: self.sender.clone(), /// synchronous code to asynchronous code.
_stats: (), ///
} /// # Panics
///
/// This function panics if called within an asynchronous execution
/// context.
///
/// # Exa
pub fn blocking_send(&self, message: M) -> Result<(), (RelayError, M)> {
self.sender
.blocking_send(message)
.map_err(|e| (RelayError::Send, e.0))
} }
} }
impl<S: ServiceCore> Relay<S> { impl<S: ServiceCore> Relay<S> {
pub fn new(overwatch_handle: OverwatchHandle) -> Self { pub fn new(overwatch_handle: OverwatchHandle) -> Self {
Self { Self {
state: RelayState::Disconnected,
overwatch_handle, overwatch_handle,
_marker: PhantomData,
} }
} }
#[instrument(skip(self), err(Debug))] #[instrument(skip(self), err(Debug))]
pub async fn connect(&mut self) -> Result<(), RelayError> { pub async fn connect(&mut self) -> Result<OutboundRelay<S::Message>, RelayError> {
if let RelayState::Disconnected = self.state { let (reply, receiver) = oneshot::channel();
let (reply, receiver) = oneshot::channel(); self.request_relay(reply).await;
self.request_relay(reply).await; self.handle_relay_response(receiver).await
self.handle_relay_response(receiver).await
} else {
Err(RelayError::AlreadyConnected)
}
}
#[instrument(skip(self), err(Debug))]
pub fn disconnect(&mut self) -> Result<(), RelayError> {
self.state = RelayState::Disconnected;
Ok(())
}
#[instrument(skip_all, err(Debug))]
pub async fn send(&mut self, message: S::Message) -> Result<(), RelayError> {
// TODO: we could make a retry system and/or add timeouts
if let RelayState::Connected(outbound_relay) = &mut self.state {
outbound_relay
.send(message)
.await
.map_err(|(e, _message)| e)
} else {
Err(RelayError::Disconnected)
}
}
#[instrument(skip_all, err(Debug))]
pub fn blocking_send(&mut self, message: S::Message) -> Result<(), RelayError> {
if let RelayState::Connected(outbound_relay) = &mut self.state {
outbound_relay
.sender
.blocking_send(message)
.map_err(|_| RelayError::Send)
} else {
Err(RelayError::Disconnected)
}
} }
async fn request_relay(&mut self, reply: oneshot::Sender<RelayResult>) { async fn request_relay(&mut self, reply: oneshot::Sender<RelayResult>) {
@ -188,14 +159,11 @@ impl<S: ServiceCore> Relay<S> {
async fn handle_relay_response( async fn handle_relay_response(
&mut self, &mut self,
receiver: oneshot::Receiver<RelayResult>, receiver: oneshot::Receiver<RelayResult>,
) -> Result<(), RelayError> { ) -> Result<OutboundRelay<S::Message>, RelayError> {
let response = receiver.await; let response = receiver.await;
match response { match response {
Ok(Ok(message)) => match message.downcast::<OutboundRelay<S::Message>>() { Ok(Ok(message)) => match message.downcast::<OutboundRelay<S::Message>>() {
Ok(channel) => { Ok(channel) => Ok(*channel),
self.state = RelayState::Connected(*channel);
Ok(())
}
Err(m) => Err(RelayError::InvalidMessage { Err(m) => Err(RelayError::InvalidMessage {
type_id: format!("{:?}", m.type_id()), type_id: format!("{:?}", m.type_id()),
service_id: S::SERVICE_ID, service_id: S::SERVICE_ID,

View File

@ -90,7 +90,7 @@ fn derive_print_service() {
let mut print_service_relay = handle.relay::<PrintService>(); let mut print_service_relay = handle.relay::<PrintService>();
overwatch.spawn(async move { overwatch.spawn(async move {
print_service_relay let print_service_relay = print_service_relay
.connect() .connect()
.await .await
.expect("A connection to the print service is established"); .expect("A connection to the print service is established");