From cb343156b7ac1bb88d57cfdde7df9c60acb2e9ae Mon Sep 17 00:00:00 2001 From: Al Liu Date: Mon, 25 Sep 2023 14:21:07 +0800 Subject: [PATCH] Add concrete Error implementation for mixnet (#405) Add concrete Error implementation --- mixnet/client/src/error.rs | 22 +++++++++ mixnet/client/src/lib.rs | 26 ++-------- mixnet/client/src/receiver.rs | 32 ++++--------- mixnet/client/src/sender.rs | 27 +++++------ mixnet/node/Cargo.toml | 1 + mixnet/node/src/client_notifier.rs | 14 +++--- mixnet/node/src/lib.rs | 52 +++++++++++++------- mixnet/protocol/Cargo.toml | 5 +- mixnet/protocol/src/lib.rs | 76 +++++++++++++++++------------- mixnet/topology/Cargo.toml | 1 + mixnet/topology/src/lib.rs | 20 +++----- mixnet/util/Cargo.toml | 3 +- mixnet/util/src/lib.rs | 11 ++++- 13 files changed, 160 insertions(+), 130 deletions(-) create mode 100644 mixnet/client/src/error.rs diff --git a/mixnet/client/src/error.rs b/mixnet/client/src/error.rs new file mode 100644 index 00000000..f1613e5f --- /dev/null +++ b/mixnet/client/src/error.rs @@ -0,0 +1,22 @@ +use mixnet_protocol::ProtocolError; +use nym_sphinx::addressing::nodes::NymNodeRoutingAddressError; + +#[derive(thiserror::Error, Debug)] +pub enum MixnetClientError { + #[error("mixnet node connect error")] + MixnetNodeConnectError, + #[error("mixnode stream has been closed")] + MixnetNodeStreamClosed, + #[error("unexpected stream body received")] + UnexpectedStreamBody, + #[error("invalid payload")] + InvalidPayload, + #[error("invalid routing address: {0}")] + InvalidRoutingAddress(#[from] NymNodeRoutingAddressError), + #[error("{0}")] + Protocol(#[from] ProtocolError), + #[error("{0}")] + Message(#[from] nym_sphinx::message::NymMessageError), +} + +pub type Result = core::result::Result; diff --git a/mixnet/client/src/lib.rs b/mixnet/client/src/lib.rs index a53c94a4..ec871eb1 100644 --- a/mixnet/client/src/lib.rs +++ b/mixnet/client/src/lib.rs @@ -1,8 +1,9 @@ pub mod config; +pub mod error; +pub use error::*; mod receiver; mod sender; -use std::error::Error; use std::time::Duration; pub use config::MixnetClientConfig; @@ -11,7 +12,6 @@ use futures::stream::BoxStream; use mixnet_util::ConnectionPool; use rand::Rng; use sender::Sender; -use thiserror::Error; // A client for sending packets to Mixnet and receiving packets from Mixnet. pub struct MixnetClient { @@ -19,7 +19,7 @@ pub struct MixnetClient { sender: Sender, } -pub type MessageStream = BoxStream<'static, Result, MixnetClientError>>; +pub type MessageStream = BoxStream<'static, Result>>; impl MixnetClient { pub fn new(config: MixnetClientConfig, rng: R) -> Self { @@ -36,27 +36,11 @@ impl MixnetClient { } } - pub async fn run(&self) -> Result { + pub async fn run(&self) -> Result { self.mode.run().await } - pub fn send( - &mut self, - msg: Vec, - total_delay: Duration, - ) -> Result<(), Box> { + pub fn send(&mut self, msg: Vec, total_delay: Duration) -> Result<()> { self.sender.send(msg, total_delay) } } - -#[derive(Error, Debug)] -pub enum MixnetClientError { - #[error("mixnet node connect error")] - MixnetNodeConnectError, - #[error("mixnode stream has been closed")] - MixnetNodeStreamClosed, - #[error("unexpected stream body received")] - UnexpectedStreamBody, - #[error("invalid payload")] - InvalidPayload, -} diff --git a/mixnet/client/src/receiver.rs b/mixnet/client/src/receiver.rs index e0b5a218..65ebd716 100644 --- a/mixnet/client/src/receiver.rs +++ b/mixnet/client/src/receiver.rs @@ -1,4 +1,4 @@ -use std::{error::Error, net::SocketAddr}; +use std::net::SocketAddr; use futures::{stream, Stream, StreamExt}; use mixnet_protocol::Body; @@ -9,6 +9,7 @@ use nym_sphinx::{ }; use tokio::net::TcpStream; +use super::error::*; use crate::MixnetClientError; // Receiver accepts TCP connections to receive incoming payloads from the Mixnet. @@ -21,12 +22,7 @@ impl Receiver { Self { node_address } } - pub async fn run( - &self, - ) -> Result< - impl Stream, MixnetClientError>> + Send + 'static, - MixnetClientError, - > { + pub async fn run(&self) -> Result>> + Send + 'static> { let Ok(socket) = TcpStream::connect(self.node_address).await else { return Err(MixnetClientError::MixnetNodeConnectError); }; @@ -36,9 +32,7 @@ impl Receiver { )))) } - fn fragment_stream( - socket: TcpStream, - ) -> impl Stream> + Send + 'static { + fn fragment_stream(socket: TcpStream) -> impl Stream> + Send + 'static { stream::unfold(socket, move |mut socket| { async move { let Ok(body) = Body::read(&mut socket).await else { @@ -60,11 +54,8 @@ impl Receiver { } fn message_stream( - fragment_stream: impl Stream> - + Send - + Unpin - + 'static, - ) -> impl Stream, MixnetClientError>> + Send + 'static { + fragment_stream: impl Stream> + Send + Unpin + 'static, + ) -> impl Stream>> + Send + 'static { // MessageReconstructor buffers all received fragments // and eventually returns reconstructed messages. let message_reconstructor: MessageReconstructor = Default::default(); @@ -80,7 +71,7 @@ impl Receiver { ) } - fn fragment_from_payload(payload: Payload) -> Result { + fn fragment_from_payload(payload: Payload) -> Result { let Ok(payload_plaintext) = payload.recover_plaintext() else { return Err(MixnetClientError::InvalidPayload); }; @@ -91,12 +82,9 @@ impl Receiver { } async fn reconstruct_message( - fragment_stream: &mut (impl Stream> - + Send - + Unpin - + 'static), + fragment_stream: &mut (impl Stream> + Send + Unpin + 'static), message_reconstructor: &mut MessageReconstructor, - ) -> Result, MixnetClientError> { + ) -> Result> { // Read fragments until at least one message is fully reconstructed. while let Some(next) = fragment_stream.next().await { match next { @@ -131,7 +119,7 @@ impl Receiver { } } - fn remove_padding(msg: Vec) -> Result, Box> { + fn remove_padding(msg: Vec) -> Result> { let padded_message = PaddedMessage::new_reconstructed(msg); // we need this because PaddedMessage.remove_padding requires it for other NymMessage types. let dummy_num_mix_hops = 0; diff --git a/mixnet/client/src/sender.rs b/mixnet/client/src/sender.rs index 32d0f644..a3450fd3 100644 --- a/mixnet/client/src/sender.rs +++ b/mixnet/client/src/sender.rs @@ -1,6 +1,6 @@ -use std::{error::Error, net::SocketAddr, time::Duration}; +use std::{net::SocketAddr, time::Duration}; -use mixnet_protocol::Body; +use mixnet_protocol::{Body, ProtocolError}; use mixnet_topology::MixnetTopology; use mixnet_util::ConnectionPool; use nym_sphinx::{ @@ -11,6 +11,8 @@ use nym_sphinx::{ use rand::{distributions::Uniform, prelude::Distribution, Rng}; use sphinx_packet::{route, SphinxPacket, SphinxPacketBuilder}; +use super::error::*; + // Sender splits messages into Sphinx packets and sends them to the Mixnet. pub struct Sender { //TODO: handle topology update @@ -38,11 +40,7 @@ impl Sender { } } - pub fn send( - &mut self, - msg: Vec, - total_delay: Duration, - ) -> Result<(), Box> { + pub fn send(&mut self, msg: Vec, total_delay: Duration) -> Result<()> { let destination = self.topology.random_destination(&mut self.rng)?; let destination = Destination::new( DestinationAddressBytes::from_bytes(destination.address.as_bytes()), @@ -52,7 +50,7 @@ impl Sender { self.pad_and_split_message(msg) .into_iter() .map(|fragment| self.build_sphinx_packet(fragment, &destination, total_delay)) - .collect::, _>>()? + .collect::>>()? .into_iter() .for_each(|(packet, first_node)| { let pool = self.pool.clone(); @@ -95,8 +93,7 @@ impl Sender { fragment: Fragment, destination: &Destination, total_delay: Duration, - ) -> Result<(sphinx_packet::SphinxPacket, route::Node), Box> - { + ) -> Result<(sphinx_packet::SphinxPacket, route::Node)> { let route = self.topology.random_route(&mut self.rng)?; let delays: Vec = @@ -110,7 +107,8 @@ impl Sender { let packet = SphinxPacketBuilder::new() .with_payload_size(payload.len() + PAYLOAD_OVERHEAD_SIZE) - .build_packet(payload, &route, destination, &delays)?; + .build_packet(payload, &route, destination, &delays) + .map_err(ProtocolError::InvalidSphinxPacket)?; let first_mixnode = route.first().cloned().expect("route is not empty"); @@ -123,8 +121,8 @@ impl Sender { retry_delay: Duration, packet: Box, addr: NodeAddressBytes, - ) -> Result<(), Box> { - let addr = SocketAddr::try_from(NymNodeRoutingAddress::try_from(addr)?)?; + ) -> Result<()> { + let addr = SocketAddr::from(NymNodeRoutingAddress::try_from(addr)?); tracing::debug!("Sending a Sphinx packet to the node: {addr:?}"); let mu: std::sync::Arc> = @@ -145,7 +143,8 @@ impl Sender { body, arc_socket, ) - .await; + .await + .map_err(Into::into); } Ok(()) } diff --git a/mixnet/node/Cargo.toml b/mixnet/node/Cargo.toml index 703f3412..37601022 100644 --- a/mixnet/node/Cargo.toml +++ b/mixnet/node/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" serde = { version = "1.0", features = ["derive"] } tracing = "0.1.37" tokio = { version = "1.32", features = ["net", "time"] } +thiserror = "1" sphinx-packet = "0.1.0" nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" } mixnet-protocol = { path = "../protocol" } diff --git a/mixnet/node/src/client_notifier.rs b/mixnet/node/src/client_notifier.rs index 688db25c..2a277689 100644 --- a/mixnet/node/src/client_notifier.rs +++ b/mixnet/node/src/client_notifier.rs @@ -1,4 +1,4 @@ -use std::{error::Error, net::SocketAddr}; +use std::net::SocketAddr; use mixnet_protocol::Body; use tokio::{ @@ -12,8 +12,10 @@ impl ClientNotifier { pub async fn run( listen_address: SocketAddr, mut rx: mpsc::Receiver, - ) -> Result<(), Box> { - let listener = TcpListener::bind(listen_address).await?; + ) -> super::Result<()> { + let listener = TcpListener::bind(listen_address) + .await + .map_err(super::ProtocolError::IO)?; tracing::info!("Listening mixnet client connections: {listen_address}"); // Currently, handling only a single incoming connection @@ -21,7 +23,7 @@ impl ClientNotifier { loop { match listener.accept().await { Ok((socket, remote_addr)) => { - tracing::debug!("Accepted incoming client connection from {remote_addr:?}"); + tracing::debug!("Accepted incoming client connection from {remote_addr}"); if let Err(e) = Self::handle_connection(socket, &mut rx).await { tracing::error!("failed to handle conn: {e}"); @@ -35,10 +37,10 @@ impl ClientNotifier { async fn handle_connection( mut socket: TcpStream, rx: &mut mpsc::Receiver, - ) -> Result<(), Box> { + ) -> super::Result<()> { while let Some(body) = rx.recv().await { if let Err(e) = body.write(&mut socket).await { - return Err(format!("error from client conn: {e}").into()); + return Err(super::MixnetNodeError::Client(e)); } } tracing::debug!("body receiver closed"); diff --git a/mixnet/node/src/lib.rs b/mixnet/node/src/lib.rs index 6997bdfe..61337689 100644 --- a/mixnet/node/src/lib.rs +++ b/mixnet/node/src/lib.rs @@ -1,16 +1,16 @@ mod client_notifier; pub mod config; -use std::{error::Error, net::SocketAddr, time::Duration}; +use std::{net::SocketAddr, time::Duration}; use client_notifier::ClientNotifier; pub use config::MixnetNodeConfig; -use mixnet_protocol::Body; +use mixnet_protocol::{Body, ProtocolError}; use mixnet_topology::MixnetNodeId; use mixnet_util::ConnectionPool; use nym_sphinx::{ - addressing::nodes::NymNodeRoutingAddress, Delay, DestinationAddressBytes, NodeAddressBytes, - Payload, PrivateKey, + addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError}, + Delay, DestinationAddressBytes, NodeAddressBytes, Payload, PrivateKey, }; pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE; use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket}; @@ -19,6 +19,20 @@ use tokio::{ sync::mpsc, }; +pub type Result = core::result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum MixnetNodeError { + #[error("{0}")] + Protocol(#[from] ProtocolError), + #[error("invalid routing address: {0}")] + InvalidRoutingAddress(#[from] NymNodeRoutingAddressError), + #[error("send error: {0}")] + SendError(#[from] tokio::sync::mpsc::error::TrySendError), + #[error("client: {0}")] + Client(ProtocolError), +} + // A mix node that routes packets in the Mixnet. pub struct MixnetNode { config: MixnetNodeConfig, @@ -41,7 +55,7 @@ impl MixnetNode { const CLIENT_NOTI_CHANNEL_SIZE: usize = 100; - pub async fn run(self) -> Result<(), Box> { + pub async fn run(self) -> Result<()> { tracing::info!("Public key: {:?}", self.public_key()); // Spawn a ClientNotifier @@ -55,7 +69,9 @@ impl MixnetNode { //TODO: Accepting ad-hoc TCP conns for now. Improve conn handling. //TODO: Add graceful shutdown - let listener = TcpListener::bind(self.config.listen_address).await?; + let listener = TcpListener::bind(self.config.listen_address) + .await + .map_err(ProtocolError::IO)?; tracing::info!( "Listening mixnet node connections: {}", self.config.listen_address @@ -96,7 +112,7 @@ impl MixnetNode { pool: ConnectionPool, private_key: [u8; PRIVATE_KEY_SIZE], client_tx: mpsc::Sender, - ) -> Result<(), Box> { + ) -> Result<()> { loop { let body = Body::read(&mut socket).await?; @@ -130,7 +146,7 @@ impl MixnetNode { pool: &ConnectionPool, private_key: &PrivateKey, client_tx: &mpsc::Sender, - ) -> Result<(), Box> { + ) -> Result<()> { match body { Body::SphinxPacket(packet) => { Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet) @@ -154,8 +170,11 @@ impl MixnetNode { retry_delay: Duration, private_key: &PrivateKey, packet: Box, - ) -> Result<(), Box> { - match packet.process(private_key)? { + ) -> Result<()> { + match packet + .process(private_key) + .map_err(ProtocolError::InvalidSphinxPacket)? + { ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => { Self::forward_packet_to_next_hop( pool, @@ -184,7 +203,7 @@ impl MixnetNode { _private_key: &PrivateKey, client_tx: &mpsc::Sender, body: Body, - ) -> Result<(), Box> { + ) -> Result<()> { // TODO: Decrypt the final payload using the private key, if it's encrypted // Do not wait when the channel is full or no receiver exists @@ -199,7 +218,7 @@ impl MixnetNode { packet: Box, next_node_addr: NodeAddressBytes, delay: Delay, - ) -> Result<(), Box> { + ) -> Result<()> { tracing::debug!("Delaying the packet for {delay:?}"); tokio::time::sleep(delay.to_duration()).await; @@ -219,7 +238,7 @@ impl MixnetNode { retry_delay: Duration, payload: Payload, destination_addr: DestinationAddressBytes, - ) -> Result<(), Box> { + ) -> Result<()> { tracing::debug!("Forwarding final payload to destination mixnode"); Self::forward( @@ -238,8 +257,8 @@ impl MixnetNode { retry_delay: Duration, body: Body, to: NymNodeRoutingAddress, - ) -> Result<(), Box> { - let addr = SocketAddr::try_from(to)?; + ) -> Result<()> { + let addr = SocketAddr::from(to); let arc_socket = pool.get_or_init(&addr).await?; if let Err(e) = { @@ -254,7 +273,8 @@ impl MixnetNode { body, arc_socket, ) - .await; + .await + .map_err(Into::into); } Ok(()) } diff --git a/mixnet/protocol/Cargo.toml b/mixnet/protocol/Cargo.toml index 13387034..e21c7ff4 100644 --- a/mixnet/protocol/Cargo.toml +++ b/mixnet/protocol/Cargo.toml @@ -6,7 +6,8 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tokio = "1.32" +tokio = { version = "1.32", features = ["sync", "net"] } sphinx-packet = "0.1.0" futures = "0.3" -tokio-util = {version = "0.7", features = ["io", "io-util"] } \ No newline at end of file +tokio-util = { version = "0.7", features = ["io", "io-util"] } +thiserror = "1" diff --git a/mixnet/protocol/src/lib.rs b/mixnet/protocol/src/lib.rs index 4b6701f9..459eac7b 100644 --- a/mixnet/protocol/src/lib.rs +++ b/mixnet/protocol/src/lib.rs @@ -1,12 +1,28 @@ use sphinx_packet::{payload::Payload, SphinxPacket}; -use std::{error::Error, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration}; +use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration}; use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::TcpStream, sync::Mutex, }; +pub type Result = core::result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum ProtocolError { + #[error("Unknown body type {0}")] + UnknownBodyType(u8), + #[error("{0}")] + InvalidSphinxPacket(sphinx_packet::Error), + #[error("{0}")] + InvalidPayload(sphinx_packet::Error), + #[error("{0}")] + IO(#[from] io::Error), + #[error("fail to send packet, reach maximum retries {0}")] + ReachMaxRetries(usize), +} + #[non_exhaustive] pub enum Body { SphinxPacket(Box), @@ -29,7 +45,7 @@ impl Body { } } - pub async fn read(reader: &mut R) -> Result> + pub async fn read(reader: &mut R) -> Result where R: AsyncRead + Unpin, { @@ -37,20 +53,17 @@ impl Body { match id { 0 => Self::read_sphinx_packet(reader).await, 1 => Self::read_final_payload(reader).await, - _ => Err("Invalid body type".into()), + id => Err(ProtocolError::UnknownBodyType(id)), } } - fn sphinx_packet_from_bytes( - data: &[u8], - ) -> Result> { - let packet = SphinxPacket::from_bytes(data)?; - Ok(Self::new_sphinx(Box::new(packet))) + fn sphinx_packet_from_bytes(data: &[u8]) -> Result { + SphinxPacket::from_bytes(data) + .map(|packet| Self::new_sphinx(Box::new(packet))) + .map_err(ProtocolError::InvalidPayload) } - async fn read_sphinx_packet( - reader: &mut R, - ) -> Result> + async fn read_sphinx_packet(reader: &mut R) -> Result where R: AsyncRead + Unpin, { @@ -60,16 +73,13 @@ impl Body { Self::sphinx_packet_from_bytes(&buf) } - pub fn final_payload_from_bytes( - data: &[u8], - ) -> Result> { - let payload = Payload::from_bytes(data)?; - Ok(Self::new_final_payload(payload)) + pub fn final_payload_from_bytes(data: &[u8]) -> Result { + Payload::from_bytes(data) + .map(Self::new_final_payload) + .map_err(ProtocolError::InvalidPayload) } - async fn read_final_payload( - reader: &mut R, - ) -> Result> + async fn read_final_payload(reader: &mut R) -> Result where R: AsyncRead + Unpin, { @@ -80,10 +90,7 @@ impl Body { Self::final_payload_from_bytes(&buf) } - pub async fn write( - &self, - writer: &mut W, - ) -> Result<(), Box> + pub async fn write(&self, writer: &mut W) -> Result<()> where W: AsyncWrite + Unpin + ?Sized, { @@ -111,7 +118,7 @@ pub async fn retry_backoff( retry_delay: Duration, body: Body, socket: Arc>, -) -> Result<(), Box> { +) -> Result<()> { for idx in 0..max_retries { // backoff let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32)); @@ -121,19 +128,22 @@ pub async fn retry_backoff( match body.write(&mut *socket).await { Ok(_) => return Ok(()), Err(e) => { - if let Some(err) = e.downcast_ref::() { - match err.kind() { - ErrorKind::Unsupported => return Err(e), - _ => { - // update the connection - if let Ok(tcp) = TcpStream::connect(peer_addr).await { - *socket = tcp; + match &e { + ProtocolError::IO(err) => { + match err.kind() { + ErrorKind::Unsupported => return Err(e), + _ => { + // update the connection + if let Ok(tcp) = TcpStream::connect(peer_addr).await { + *socket = tcp; + } } } } + _ => return Err(e), } } } } - Err(format!("Failure after {max_retries} retries").into()) + Err(ProtocolError::ReachMaxRetries(max_retries)) } diff --git a/mixnet/topology/Cargo.toml b/mixnet/topology/Cargo.toml index 57ac2d46..7171b17f 100644 --- a/mixnet/topology/Cargo.toml +++ b/mixnet/topology/Cargo.toml @@ -10,3 +10,4 @@ rand = "0.7.3" serde = { version = "1.0", features = ["derive"] } sphinx-packet = "0.1.0" nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" } +thiserror = "1" diff --git a/mixnet/topology/src/lib.rs b/mixnet/topology/src/lib.rs index d276f08d..8fd9d6c8 100644 --- a/mixnet/topology/src/lib.rs +++ b/mixnet/topology/src/lib.rs @@ -1,4 +1,4 @@ -use std::{error::Error, net::SocketAddr}; +use std::net::SocketAddr; use nym_sphinx::addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError}; use rand::{seq::IteratorRandom, Rng}; @@ -7,6 +7,8 @@ use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, route}; pub type MixnetNodeId = [u8; PUBLIC_KEY_SIZE]; +pub type Result = core::result::Result; + #[derive(Serialize, Deserialize, Clone, Debug, Default)] pub struct MixnetTopology { pub layers: Vec, @@ -54,10 +56,7 @@ mod hex_serde { } impl MixnetTopology { - pub fn random_route( - &self, - rng: &mut R, - ) -> Result, Box> { + pub fn random_route(&self, rng: &mut R) -> Result> { let num_hops = self.layers.len(); let route: Vec = self @@ -78,19 +77,14 @@ impl MixnetTopology { } // Choose a destination mixnet node randomly from the last layer. - pub fn random_destination( - &self, - rng: &mut R, - ) -> Result> { - Ok(self - .layers + pub fn random_destination(&self, rng: &mut R) -> Result { + self.layers .last() .expect("topology is not empty") .random_node(rng) .expect("layer is not empty") .clone() .try_into() - .unwrap()) } } @@ -103,7 +97,7 @@ impl Layer { impl TryInto for Node { type Error = NymNodeRoutingAddressError; - fn try_into(self) -> Result { + fn try_into(self) -> Result { Ok(route::Node { address: NymNodeRoutingAddress::from(self.address).try_into()?, pub_key: self.public_key.into(), diff --git a/mixnet/util/Cargo.toml b/mixnet/util/Cargo.toml index 0565ae3b..7ab67774 100644 --- a/mixnet/util/Cargo.toml +++ b/mixnet/util/Cargo.toml @@ -5,4 +5,5 @@ edition = "2021" [dependencies] tokio = { version = "1.32", default-features = false, features = ["sync", "net"] } -parking_lot = { version = "0.12", features = ["send_guard"] } \ No newline at end of file +parking_lot = { version = "0.12", features = ["send_guard"] } +mixnet-protocol = { path = "../protocol" } \ No newline at end of file diff --git a/mixnet/util/src/lib.rs b/mixnet/util/src/lib.rs index d5e47d76..e00f3714 100644 --- a/mixnet/util/src/lib.rs +++ b/mixnet/util/src/lib.rs @@ -15,12 +15,19 @@ impl ConnectionPool { } } - pub async fn get_or_init(&self, addr: &SocketAddr) -> std::io::Result>> { + pub async fn get_or_init( + &self, + addr: &SocketAddr, + ) -> mixnet_protocol::Result>> { let mut pool = self.pool.lock().await; match pool.get(addr).cloned() { Some(tcp) => Ok(tcp), None => { - let tcp = Arc::new(Mutex::new(TcpStream::connect(addr).await?)); + let tcp = Arc::new(Mutex::new( + TcpStream::connect(addr) + .await + .map_err(mixnet_protocol::ProtocolError::IO)?, + )); pool.insert(*addr, tcp.clone()); Ok(tcp) }