1
0
mirror of synced 2025-02-23 13:08:19 +00:00

Add concrete Error implementation for mixnet (#405)

Add concrete Error implementation
This commit is contained in:
Al Liu 2023-09-25 14:21:07 +08:00 committed by GitHub
parent 46c7e25f6f
commit cb343156b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 160 additions and 130 deletions

View File

@ -0,0 +1,22 @@
use mixnet_protocol::ProtocolError;
use nym_sphinx::addressing::nodes::NymNodeRoutingAddressError;
#[derive(thiserror::Error, Debug)]
pub enum MixnetClientError {
#[error("mixnet node connect error")]
MixnetNodeConnectError,
#[error("mixnode stream has been closed")]
MixnetNodeStreamClosed,
#[error("unexpected stream body received")]
UnexpectedStreamBody,
#[error("invalid payload")]
InvalidPayload,
#[error("invalid routing address: {0}")]
InvalidRoutingAddress(#[from] NymNodeRoutingAddressError),
#[error("{0}")]
Protocol(#[from] ProtocolError),
#[error("{0}")]
Message(#[from] nym_sphinx::message::NymMessageError),
}
pub type Result<T> = core::result::Result<T, MixnetClientError>;

View File

@ -1,8 +1,9 @@
pub mod config;
pub mod error;
pub use error::*;
mod receiver;
mod sender;
use std::error::Error;
use std::time::Duration;
pub use config::MixnetClientConfig;
@ -11,7 +12,6 @@ use futures::stream::BoxStream;
use mixnet_util::ConnectionPool;
use rand::Rng;
use sender::Sender;
use thiserror::Error;
// A client for sending packets to Mixnet and receiving packets from Mixnet.
pub struct MixnetClient<R: Rng> {
@ -19,7 +19,7 @@ pub struct MixnetClient<R: Rng> {
sender: Sender<R>,
}
pub type MessageStream = BoxStream<'static, Result<Vec<u8>, MixnetClientError>>;
pub type MessageStream = BoxStream<'static, Result<Vec<u8>>>;
impl<R: Rng> MixnetClient<R> {
pub fn new(config: MixnetClientConfig, rng: R) -> Self {
@ -36,27 +36,11 @@ impl<R: Rng> MixnetClient<R> {
}
}
pub async fn run(&self) -> Result<MessageStream, MixnetClientError> {
pub async fn run(&self) -> Result<MessageStream> {
self.mode.run().await
}
pub fn send(
&mut self,
msg: Vec<u8>,
total_delay: Duration,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
pub fn send(&mut self, msg: Vec<u8>, total_delay: Duration) -> Result<()> {
self.sender.send(msg, total_delay)
}
}
#[derive(Error, Debug)]
pub enum MixnetClientError {
#[error("mixnet node connect error")]
MixnetNodeConnectError,
#[error("mixnode stream has been closed")]
MixnetNodeStreamClosed,
#[error("unexpected stream body received")]
UnexpectedStreamBody,
#[error("invalid payload")]
InvalidPayload,
}

View File

@ -1,4 +1,4 @@
use std::{error::Error, net::SocketAddr};
use std::net::SocketAddr;
use futures::{stream, Stream, StreamExt};
use mixnet_protocol::Body;
@ -9,6 +9,7 @@ use nym_sphinx::{
};
use tokio::net::TcpStream;
use super::error::*;
use crate::MixnetClientError;
// Receiver accepts TCP connections to receive incoming payloads from the Mixnet.
@ -21,12 +22,7 @@ impl Receiver {
Self { node_address }
}
pub async fn run(
&self,
) -> Result<
impl Stream<Item = Result<Vec<u8>, MixnetClientError>> + Send + 'static,
MixnetClientError,
> {
pub async fn run(&self) -> Result<impl Stream<Item = Result<Vec<u8>>> + Send + 'static> {
let Ok(socket) = TcpStream::connect(self.node_address).await else {
return Err(MixnetClientError::MixnetNodeConnectError);
};
@ -36,9 +32,7 @@ impl Receiver {
))))
}
fn fragment_stream(
socket: TcpStream,
) -> impl Stream<Item = Result<Fragment, MixnetClientError>> + Send + 'static {
fn fragment_stream(socket: TcpStream) -> impl Stream<Item = Result<Fragment>> + Send + 'static {
stream::unfold(socket, move |mut socket| {
async move {
let Ok(body) = Body::read(&mut socket).await else {
@ -60,11 +54,8 @@ impl Receiver {
}
fn message_stream(
fragment_stream: impl Stream<Item = Result<Fragment, MixnetClientError>>
+ Send
+ Unpin
+ 'static,
) -> impl Stream<Item = Result<Vec<u8>, MixnetClientError>> + Send + 'static {
fragment_stream: impl Stream<Item = Result<Fragment>> + Send + Unpin + 'static,
) -> impl Stream<Item = Result<Vec<u8>>> + Send + 'static {
// MessageReconstructor buffers all received fragments
// and eventually returns reconstructed messages.
let message_reconstructor: MessageReconstructor = Default::default();
@ -80,7 +71,7 @@ impl Receiver {
)
}
fn fragment_from_payload(payload: Payload) -> Result<Fragment, MixnetClientError> {
fn fragment_from_payload(payload: Payload) -> Result<Fragment> {
let Ok(payload_plaintext) = payload.recover_plaintext() else {
return Err(MixnetClientError::InvalidPayload);
};
@ -91,12 +82,9 @@ impl Receiver {
}
async fn reconstruct_message(
fragment_stream: &mut (impl Stream<Item = Result<Fragment, MixnetClientError>>
+ Send
+ Unpin
+ 'static),
fragment_stream: &mut (impl Stream<Item = Result<Fragment>> + Send + Unpin + 'static),
message_reconstructor: &mut MessageReconstructor,
) -> Result<Vec<u8>, MixnetClientError> {
) -> Result<Vec<u8>> {
// Read fragments until at least one message is fully reconstructed.
while let Some(next) = fragment_stream.next().await {
match next {
@ -131,7 +119,7 @@ impl Receiver {
}
}
fn remove_padding(msg: Vec<u8>) -> Result<Vec<u8>, Box<dyn Error>> {
fn remove_padding(msg: Vec<u8>) -> Result<Vec<u8>> {
let padded_message = PaddedMessage::new_reconstructed(msg);
// we need this because PaddedMessage.remove_padding requires it for other NymMessage types.
let dummy_num_mix_hops = 0;

View File

@ -1,6 +1,6 @@
use std::{error::Error, net::SocketAddr, time::Duration};
use std::{net::SocketAddr, time::Duration};
use mixnet_protocol::Body;
use mixnet_protocol::{Body, ProtocolError};
use mixnet_topology::MixnetTopology;
use mixnet_util::ConnectionPool;
use nym_sphinx::{
@ -11,6 +11,8 @@ use nym_sphinx::{
use rand::{distributions::Uniform, prelude::Distribution, Rng};
use sphinx_packet::{route, SphinxPacket, SphinxPacketBuilder};
use super::error::*;
// Sender splits messages into Sphinx packets and sends them to the Mixnet.
pub struct Sender<R: Rng> {
//TODO: handle topology update
@ -38,11 +40,7 @@ impl<R: Rng> Sender<R> {
}
}
pub fn send(
&mut self,
msg: Vec<u8>,
total_delay: Duration,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
pub fn send(&mut self, msg: Vec<u8>, total_delay: Duration) -> Result<()> {
let destination = self.topology.random_destination(&mut self.rng)?;
let destination = Destination::new(
DestinationAddressBytes::from_bytes(destination.address.as_bytes()),
@ -52,7 +50,7 @@ impl<R: Rng> Sender<R> {
self.pad_and_split_message(msg)
.into_iter()
.map(|fragment| self.build_sphinx_packet(fragment, &destination, total_delay))
.collect::<Result<Vec<_>, _>>()?
.collect::<Result<Vec<_>>>()?
.into_iter()
.for_each(|(packet, first_node)| {
let pool = self.pool.clone();
@ -95,8 +93,7 @@ impl<R: Rng> Sender<R> {
fragment: Fragment,
destination: &Destination,
total_delay: Duration,
) -> Result<(sphinx_packet::SphinxPacket, route::Node), Box<dyn Error + Send + Sync + 'static>>
{
) -> Result<(sphinx_packet::SphinxPacket, route::Node)> {
let route = self.topology.random_route(&mut self.rng)?;
let delays: Vec<Delay> =
@ -110,7 +107,8 @@ impl<R: Rng> Sender<R> {
let packet = SphinxPacketBuilder::new()
.with_payload_size(payload.len() + PAYLOAD_OVERHEAD_SIZE)
.build_packet(payload, &route, destination, &delays)?;
.build_packet(payload, &route, destination, &delays)
.map_err(ProtocolError::InvalidSphinxPacket)?;
let first_mixnode = route.first().cloned().expect("route is not empty");
@ -123,8 +121,8 @@ impl<R: Rng> Sender<R> {
retry_delay: Duration,
packet: Box<SphinxPacket>,
addr: NodeAddressBytes,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
let addr = SocketAddr::try_from(NymNodeRoutingAddress::try_from(addr)?)?;
) -> Result<()> {
let addr = SocketAddr::from(NymNodeRoutingAddress::try_from(addr)?);
tracing::debug!("Sending a Sphinx packet to the node: {addr:?}");
let mu: std::sync::Arc<tokio::sync::Mutex<tokio::net::TcpStream>> =
@ -145,7 +143,8 @@ impl<R: Rng> Sender<R> {
body,
arc_socket,
)
.await;
.await
.map_err(Into::into);
}
Ok(())
}

View File

@ -7,6 +7,7 @@ edition = "2021"
serde = { version = "1.0", features = ["derive"] }
tracing = "0.1.37"
tokio = { version = "1.32", features = ["net", "time"] }
thiserror = "1"
sphinx-packet = "0.1.0"
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }
mixnet-protocol = { path = "../protocol" }

View File

@ -1,4 +1,4 @@
use std::{error::Error, net::SocketAddr};
use std::net::SocketAddr;
use mixnet_protocol::Body;
use tokio::{
@ -12,8 +12,10 @@ impl ClientNotifier {
pub async fn run(
listen_address: SocketAddr,
mut rx: mpsc::Receiver<Body>,
) -> Result<(), Box<dyn Error>> {
let listener = TcpListener::bind(listen_address).await?;
) -> super::Result<()> {
let listener = TcpListener::bind(listen_address)
.await
.map_err(super::ProtocolError::IO)?;
tracing::info!("Listening mixnet client connections: {listen_address}");
// Currently, handling only a single incoming connection
@ -21,7 +23,7 @@ impl ClientNotifier {
loop {
match listener.accept().await {
Ok((socket, remote_addr)) => {
tracing::debug!("Accepted incoming client connection from {remote_addr:?}");
tracing::debug!("Accepted incoming client connection from {remote_addr}");
if let Err(e) = Self::handle_connection(socket, &mut rx).await {
tracing::error!("failed to handle conn: {e}");
@ -35,10 +37,10 @@ impl ClientNotifier {
async fn handle_connection(
mut socket: TcpStream,
rx: &mut mpsc::Receiver<Body>,
) -> Result<(), Box<dyn Error>> {
) -> super::Result<()> {
while let Some(body) = rx.recv().await {
if let Err(e) = body.write(&mut socket).await {
return Err(format!("error from client conn: {e}").into());
return Err(super::MixnetNodeError::Client(e));
}
}
tracing::debug!("body receiver closed");

View File

@ -1,16 +1,16 @@
mod client_notifier;
pub mod config;
use std::{error::Error, net::SocketAddr, time::Duration};
use std::{net::SocketAddr, time::Duration};
use client_notifier::ClientNotifier;
pub use config::MixnetNodeConfig;
use mixnet_protocol::Body;
use mixnet_protocol::{Body, ProtocolError};
use mixnet_topology::MixnetNodeId;
use mixnet_util::ConnectionPool;
use nym_sphinx::{
addressing::nodes::NymNodeRoutingAddress, Delay, DestinationAddressBytes, NodeAddressBytes,
Payload, PrivateKey,
addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError},
Delay, DestinationAddressBytes, NodeAddressBytes, Payload, PrivateKey,
};
pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE;
use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket};
@ -19,6 +19,20 @@ use tokio::{
sync::mpsc,
};
pub type Result<T> = core::result::Result<T, MixnetNodeError>;
#[derive(Debug, thiserror::Error)]
pub enum MixnetNodeError {
#[error("{0}")]
Protocol(#[from] ProtocolError),
#[error("invalid routing address: {0}")]
InvalidRoutingAddress(#[from] NymNodeRoutingAddressError),
#[error("send error: {0}")]
SendError(#[from] tokio::sync::mpsc::error::TrySendError<Body>),
#[error("client: {0}")]
Client(ProtocolError),
}
// A mix node that routes packets in the Mixnet.
pub struct MixnetNode {
config: MixnetNodeConfig,
@ -41,7 +55,7 @@ impl MixnetNode {
const CLIENT_NOTI_CHANNEL_SIZE: usize = 100;
pub async fn run(self) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
pub async fn run(self) -> Result<()> {
tracing::info!("Public key: {:?}", self.public_key());
// Spawn a ClientNotifier
@ -55,7 +69,9 @@ impl MixnetNode {
//TODO: Accepting ad-hoc TCP conns for now. Improve conn handling.
//TODO: Add graceful shutdown
let listener = TcpListener::bind(self.config.listen_address).await?;
let listener = TcpListener::bind(self.config.listen_address)
.await
.map_err(ProtocolError::IO)?;
tracing::info!(
"Listening mixnet node connections: {}",
self.config.listen_address
@ -96,7 +112,7 @@ impl MixnetNode {
pool: ConnectionPool,
private_key: [u8; PRIVATE_KEY_SIZE],
client_tx: mpsc::Sender<Body>,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
) -> Result<()> {
loop {
let body = Body::read(&mut socket).await?;
@ -130,7 +146,7 @@ impl MixnetNode {
pool: &ConnectionPool,
private_key: &PrivateKey,
client_tx: &mpsc::Sender<Body>,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
) -> Result<()> {
match body {
Body::SphinxPacket(packet) => {
Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet)
@ -154,8 +170,11 @@ impl MixnetNode {
retry_delay: Duration,
private_key: &PrivateKey,
packet: Box<SphinxPacket>,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
match packet.process(private_key)? {
) -> Result<()> {
match packet
.process(private_key)
.map_err(ProtocolError::InvalidSphinxPacket)?
{
ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => {
Self::forward_packet_to_next_hop(
pool,
@ -184,7 +203,7 @@ impl MixnetNode {
_private_key: &PrivateKey,
client_tx: &mpsc::Sender<Body>,
body: Body,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
) -> Result<()> {
// TODO: Decrypt the final payload using the private key, if it's encrypted
// Do not wait when the channel is full or no receiver exists
@ -199,7 +218,7 @@ impl MixnetNode {
packet: Box<SphinxPacket>,
next_node_addr: NodeAddressBytes,
delay: Delay,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
) -> Result<()> {
tracing::debug!("Delaying the packet for {delay:?}");
tokio::time::sleep(delay.to_duration()).await;
@ -219,7 +238,7 @@ impl MixnetNode {
retry_delay: Duration,
payload: Payload,
destination_addr: DestinationAddressBytes,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
) -> Result<()> {
tracing::debug!("Forwarding final payload to destination mixnode");
Self::forward(
@ -238,8 +257,8 @@ impl MixnetNode {
retry_delay: Duration,
body: Body,
to: NymNodeRoutingAddress,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
let addr = SocketAddr::try_from(to)?;
) -> Result<()> {
let addr = SocketAddr::from(to);
let arc_socket = pool.get_or_init(&addr).await?;
if let Err(e) = {
@ -254,7 +273,8 @@ impl MixnetNode {
body,
arc_socket,
)
.await;
.await
.map_err(Into::into);
}
Ok(())
}

View File

@ -6,7 +6,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = "1.32"
tokio = { version = "1.32", features = ["sync", "net"] }
sphinx-packet = "0.1.0"
futures = "0.3"
tokio-util = {version = "0.7", features = ["io", "io-util"] }
tokio-util = { version = "0.7", features = ["io", "io-util"] }
thiserror = "1"

View File

@ -1,12 +1,28 @@
use sphinx_packet::{payload::Payload, SphinxPacket};
use std::{error::Error, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
sync::Mutex,
};
pub type Result<T> = core::result::Result<T, ProtocolError>;
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
#[error("Unknown body type {0}")]
UnknownBodyType(u8),
#[error("{0}")]
InvalidSphinxPacket(sphinx_packet::Error),
#[error("{0}")]
InvalidPayload(sphinx_packet::Error),
#[error("{0}")]
IO(#[from] io::Error),
#[error("fail to send packet, reach maximum retries {0}")]
ReachMaxRetries(usize),
}
#[non_exhaustive]
pub enum Body {
SphinxPacket(Box<SphinxPacket>),
@ -29,7 +45,7 @@ impl Body {
}
}
pub async fn read<R>(reader: &mut R) -> Result<Body, Box<dyn Error + Send + Sync + 'static>>
pub async fn read<R>(reader: &mut R) -> Result<Body>
where
R: AsyncRead + Unpin,
{
@ -37,20 +53,17 @@ impl Body {
match id {
0 => Self::read_sphinx_packet(reader).await,
1 => Self::read_final_payload(reader).await,
_ => Err("Invalid body type".into()),
id => Err(ProtocolError::UnknownBodyType(id)),
}
}
fn sphinx_packet_from_bytes(
data: &[u8],
) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
let packet = SphinxPacket::from_bytes(data)?;
Ok(Self::new_sphinx(Box::new(packet)))
fn sphinx_packet_from_bytes(data: &[u8]) -> Result<Self> {
SphinxPacket::from_bytes(data)
.map(|packet| Self::new_sphinx(Box::new(packet)))
.map_err(ProtocolError::InvalidPayload)
}
async fn read_sphinx_packet<R>(
reader: &mut R,
) -> Result<Body, Box<dyn Error + Send + Sync + 'static>>
async fn read_sphinx_packet<R>(reader: &mut R) -> Result<Body>
where
R: AsyncRead + Unpin,
{
@ -60,16 +73,13 @@ impl Body {
Self::sphinx_packet_from_bytes(&buf)
}
pub fn final_payload_from_bytes(
data: &[u8],
) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
let payload = Payload::from_bytes(data)?;
Ok(Self::new_final_payload(payload))
pub fn final_payload_from_bytes(data: &[u8]) -> Result<Self> {
Payload::from_bytes(data)
.map(Self::new_final_payload)
.map_err(ProtocolError::InvalidPayload)
}
async fn read_final_payload<R>(
reader: &mut R,
) -> Result<Body, Box<dyn Error + Send + Sync + 'static>>
async fn read_final_payload<R>(reader: &mut R) -> Result<Body>
where
R: AsyncRead + Unpin,
{
@ -80,10 +90,7 @@ impl Body {
Self::final_payload_from_bytes(&buf)
}
pub async fn write<W>(
&self,
writer: &mut W,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
pub async fn write<W>(&self, writer: &mut W) -> Result<()>
where
W: AsyncWrite + Unpin + ?Sized,
{
@ -111,7 +118,7 @@ pub async fn retry_backoff(
retry_delay: Duration,
body: Body,
socket: Arc<Mutex<TcpStream>>,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
) -> Result<()> {
for idx in 0..max_retries {
// backoff
let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32));
@ -121,19 +128,22 @@ pub async fn retry_backoff(
match body.write(&mut *socket).await {
Ok(_) => return Ok(()),
Err(e) => {
if let Some(err) = e.downcast_ref::<std::io::Error>() {
match err.kind() {
ErrorKind::Unsupported => return Err(e),
_ => {
// update the connection
if let Ok(tcp) = TcpStream::connect(peer_addr).await {
*socket = tcp;
match &e {
ProtocolError::IO(err) => {
match err.kind() {
ErrorKind::Unsupported => return Err(e),
_ => {
// update the connection
if let Ok(tcp) = TcpStream::connect(peer_addr).await {
*socket = tcp;
}
}
}
}
_ => return Err(e),
}
}
}
}
Err(format!("Failure after {max_retries} retries").into())
Err(ProtocolError::ReachMaxRetries(max_retries))
}

View File

@ -10,3 +10,4 @@ rand = "0.7.3"
serde = { version = "1.0", features = ["derive"] }
sphinx-packet = "0.1.0"
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }
thiserror = "1"

View File

@ -1,4 +1,4 @@
use std::{error::Error, net::SocketAddr};
use std::net::SocketAddr;
use nym_sphinx::addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError};
use rand::{seq::IteratorRandom, Rng};
@ -7,6 +7,8 @@ use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, route};
pub type MixnetNodeId = [u8; PUBLIC_KEY_SIZE];
pub type Result<T> = core::result::Result<T, NymNodeRoutingAddressError>;
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct MixnetTopology {
pub layers: Vec<Layer>,
@ -54,10 +56,7 @@ mod hex_serde {
}
impl MixnetTopology {
pub fn random_route<R: Rng>(
&self,
rng: &mut R,
) -> Result<Vec<route::Node>, Box<dyn Error + Send + Sync + 'static>> {
pub fn random_route<R: Rng>(&self, rng: &mut R) -> Result<Vec<route::Node>> {
let num_hops = self.layers.len();
let route: Vec<route::Node> = self
@ -78,19 +77,14 @@ impl MixnetTopology {
}
// Choose a destination mixnet node randomly from the last layer.
pub fn random_destination<R: Rng>(
&self,
rng: &mut R,
) -> Result<route::Node, Box<dyn Error + Send + Sync + 'static>> {
Ok(self
.layers
pub fn random_destination<R: Rng>(&self, rng: &mut R) -> Result<route::Node> {
self.layers
.last()
.expect("topology is not empty")
.random_node(rng)
.expect("layer is not empty")
.clone()
.try_into()
.unwrap())
}
}
@ -103,7 +97,7 @@ impl Layer {
impl TryInto<route::Node> for Node {
type Error = NymNodeRoutingAddressError;
fn try_into(self) -> Result<route::Node, Self::Error> {
fn try_into(self) -> Result<route::Node> {
Ok(route::Node {
address: NymNodeRoutingAddress::from(self.address).try_into()?,
pub_key: self.public_key.into(),

View File

@ -5,4 +5,5 @@ edition = "2021"
[dependencies]
tokio = { version = "1.32", default-features = false, features = ["sync", "net"] }
parking_lot = { version = "0.12", features = ["send_guard"] }
parking_lot = { version = "0.12", features = ["send_guard"] }
mixnet-protocol = { path = "../protocol" }

View File

@ -15,12 +15,19 @@ impl ConnectionPool {
}
}
pub async fn get_or_init(&self, addr: &SocketAddr) -> std::io::Result<Arc<Mutex<TcpStream>>> {
pub async fn get_or_init(
&self,
addr: &SocketAddr,
) -> mixnet_protocol::Result<Arc<Mutex<TcpStream>>> {
let mut pool = self.pool.lock().await;
match pool.get(addr).cloned() {
Some(tcp) => Ok(tcp),
None => {
let tcp = Arc::new(Mutex::new(TcpStream::connect(addr).await?));
let tcp = Arc::new(Mutex::new(
TcpStream::connect(addr)
.await
.map_err(mixnet_protocol::ProtocolError::IO)?,
));
pool.insert(*addr, tcp.clone());
Ok(tcp)
}