diff --git a/Cargo.lock b/Cargo.lock index 12c6fc2..d605a83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -818,6 +818,7 @@ dependencies = [ "thiserror", "tokio", "tokio-stream", + "tokio-util", "tracing", ] diff --git a/overwatch-rs/Cargo.toml b/overwatch-rs/Cargo.toml index a0b5257..696afc7 100644 --- a/overwatch-rs/Cargo.toml +++ b/overwatch-rs/Cargo.toml @@ -28,6 +28,7 @@ futures = "0.3" thiserror = "1.0" tokio = { version = "1.17", features = ["rt-multi-thread", "sync", "time"] } tokio-stream = {version ="0.1", features = ["sync"] } +tokio-util = "0.7" tracing = "0.1" [dev-dependencies] diff --git a/overwatch-rs/src/services/relay.rs b/overwatch-rs/src/services/relay.rs index 55beb38..7bbe97a 100644 --- a/overwatch-rs/src/services/relay.rs +++ b/overwatch-rs/src/services/relay.rs @@ -2,10 +2,14 @@ use std::any::Any; use std::fmt::Debug; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; // crates +use futures::{Sink, Stream}; use thiserror::Error; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot; +use tokio_util::sync::PollSender; use tracing::{error, instrument}; // internal use crate::overwatch::commands::{OverwatchCommand, RelayCommand, ReplyChannel}; @@ -105,7 +109,7 @@ impl InboundRelay { } } -impl OutboundRelay { +impl OutboundRelay { /// Send a message to the relay connection pub async fn send(&self, message: M) -> Result<(), (RelayError, M)> { self.sender @@ -130,6 +134,10 @@ impl OutboundRelay { .blocking_send(message) .map_err(|e| (RelayError::Send, e.0)) } + + pub fn into_sink(self) -> impl Sink { + PollSender::new(self.sender) + } } impl Relay { @@ -174,3 +182,11 @@ impl Relay { } } } + +impl Stream for InboundRelay { + type Item = M; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.receiver.poll_recv(cx) + } +}