Rework Relay to expose OutboundRelay on connect (#2)

* Rework Relay to expose OutboundRelay on connect

This simplifies the implementation and enforces a correct usage
of the channel by exposing send methods only after a successful
connection.

* update tests
This commit is contained in:
Giacomo Pasini 2022-10-25 11:34:53 +02:00 committed by GitHub
parent cc20ecc918
commit 7124639b60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 74 deletions

View File

@ -1,7 +1,7 @@
use crate::network::*;
use async_trait::async_trait;
use overwatch::services::handle::ServiceStateHandle;
use overwatch::services::relay::{NoMessage, Relay};
use overwatch::services::relay::{NoMessage, OutboundRelay};
use overwatch::services::state::{NoOperator, NoState};
use overwatch::services::{ServiceCore, ServiceData, ServiceId};
use serde::{Deserialize, Serialize};
@ -38,10 +38,13 @@ impl ServiceCore for ChatService {
mut service_state, ..
} = self;
// TODO: waku should not end up in the public interface of the network service, at least not as a type
let mut network_relay: Relay<NetworkService<waku::Waku>> =
service_state.overwatch_handle.relay();
let mut network_relay = service_state
.overwatch_handle
.relay::<NetworkService<waku::Waku>>()
.connect()
.await
.unwrap();
let user = service_state.settings_reader.get_updated_settings();
network_relay.connect().await.unwrap();
let (sender, mut receiver) = channel(1);
// TODO: typestate so I can't call send if it's not connected
network_relay

View File

@ -1,6 +1,7 @@
// std
use std::any::Any;
use std::fmt::Debug;
use std::marker::PhantomData;
// crates
use thiserror::Error;
use tokio::sync::mpsc::{channel, Receiver, Sender};
@ -47,21 +48,8 @@ pub type RelayResult = Result<AnyMessage, RelayError>;
/// Notice that it is bound to 'static.
pub trait RelayMessage: 'static {}
enum RelayState<M> {
Disconnected,
Connected(OutboundRelay<M>),
}
impl<M> Clone for RelayState<M> {
fn clone(&self) -> Self {
match self {
RelayState::Disconnected => RelayState::Disconnected,
RelayState::Connected(outbound) => RelayState::Connected(outbound.clone()),
}
}
}
/// Channel receiver of a relay connection
#[derive(Debug)]
pub struct InboundRelay<M> {
receiver: Receiver<M>,
_stats: (), // placeholder
@ -73,20 +61,30 @@ pub struct OutboundRelay<M> {
_stats: (), // placeholder
}
#[derive(Debug)]
pub struct Relay<S: ServiceCore> {
state: RelayState<S::Message>,
_marker: PhantomData<S>,
overwatch_handle: OverwatchHandle,
}
impl<S: ServiceCore> Clone for Relay<S> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
_marker: PhantomData,
overwatch_handle: self.overwatch_handle.clone(),
}
}
}
impl<M> Clone for OutboundRelay<M> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
_stats: (),
}
}
}
// TODO: make buffer_size const?
/// Relay channel builder
pub fn relay<M>(buffer_size: usize) -> (InboundRelay<M>, OutboundRelay<M>) {
@ -109,71 +107,44 @@ impl<M> InboundRelay<M> {
impl<M> OutboundRelay<M> {
/// Send a message to the relay connection
pub async fn send(&mut self, message: M) -> Result<(), (RelayError, M)> {
pub async fn send(&self, message: M) -> Result<(), (RelayError, M)> {
self.sender
.send(message)
.await
.map_err(|e| (RelayError::Send, e.0))
}
}
impl<M> Clone for OutboundRelay<M> {
fn clone(&self) -> Self {
Self {
sender: self.sender.clone(),
_stats: (),
}
/// Send a message to the relay connection in a blocking fashion.
///
/// The intended usage of this function is for sending data from
/// synchronous code to asynchronous code.
///
/// # Panics
///
/// This function panics if called within an asynchronous execution
/// context.
///
/// # Exa
pub fn blocking_send(&self, message: M) -> Result<(), (RelayError, M)> {
self.sender
.blocking_send(message)
.map_err(|e| (RelayError::Send, e.0))
}
}
impl<S: ServiceCore> Relay<S> {
pub fn new(overwatch_handle: OverwatchHandle) -> Self {
Self {
state: RelayState::Disconnected,
overwatch_handle,
_marker: PhantomData,
}
}
#[instrument(skip(self), err(Debug))]
pub async fn connect(&mut self) -> Result<(), RelayError> {
if let RelayState::Disconnected = self.state {
let (reply, receiver) = oneshot::channel();
self.request_relay(reply).await;
self.handle_relay_response(receiver).await
} else {
Err(RelayError::AlreadyConnected)
}
}
#[instrument(skip(self), err(Debug))]
pub fn disconnect(&mut self) -> Result<(), RelayError> {
self.state = RelayState::Disconnected;
Ok(())
}
#[instrument(skip_all, err(Debug))]
pub async fn send(&mut self, message: S::Message) -> Result<(), RelayError> {
// TODO: we could make a retry system and/or add timeouts
if let RelayState::Connected(outbound_relay) = &mut self.state {
outbound_relay
.send(message)
.await
.map_err(|(e, _message)| e)
} else {
Err(RelayError::Disconnected)
}
}
#[instrument(skip_all, err(Debug))]
pub fn blocking_send(&mut self, message: S::Message) -> Result<(), RelayError> {
if let RelayState::Connected(outbound_relay) = &mut self.state {
outbound_relay
.sender
.blocking_send(message)
.map_err(|_| RelayError::Send)
} else {
Err(RelayError::Disconnected)
}
pub async fn connect(&mut self) -> Result<OutboundRelay<S::Message>, RelayError> {
let (reply, receiver) = oneshot::channel();
self.request_relay(reply).await;
self.handle_relay_response(receiver).await
}
async fn request_relay(&mut self, reply: oneshot::Sender<RelayResult>) {
@ -188,14 +159,11 @@ impl<S: ServiceCore> Relay<S> {
async fn handle_relay_response(
&mut self,
receiver: oneshot::Receiver<RelayResult>,
) -> Result<(), RelayError> {
) -> Result<OutboundRelay<S::Message>, RelayError> {
let response = receiver.await;
match response {
Ok(Ok(message)) => match message.downcast::<OutboundRelay<S::Message>>() {
Ok(channel) => {
self.state = RelayState::Connected(*channel);
Ok(())
}
Ok(channel) => Ok(*channel),
Err(m) => Err(RelayError::InvalidMessage {
type_id: format!("{:?}", m.type_id()),
service_id: S::SERVICE_ID,

View File

@ -90,7 +90,7 @@ fn derive_print_service() {
let mut print_service_relay = handle.relay::<PrintService>();
overwatch.spawn(async move {
print_service_relay
let print_service_relay = print_service_relay
.connect()
.await
.expect("A connection to the print service is established");