From 2429893ac23d8c6649f9200507d1b9322587545e Mon Sep 17 00:00:00 2001 From: Al Liu Date: Mon, 18 Sep 2023 18:26:54 +0800 Subject: [PATCH] Add mixnet retry mechanism (#386) Imiplement mixnet retry --- mixnet/client/Cargo.toml | 2 +- mixnet/client/src/config.rs | 15 ++++++ mixnet/client/src/lib.rs | 8 ++- mixnet/client/src/receiver.rs | 23 +++++---- mixnet/client/src/sender.rs | 46 +++++++++++++++--- mixnet/node/src/config.rs | 26 +++++++++- mixnet/node/src/lib.rs | 92 ++++++++++++++++++++++++++++++----- mixnet/protocol/Cargo.toml | 2 +- mixnet/protocol/src/lib.rs | 48 ++++++++++++++++-- tests/src/benches/mixnet.rs | 4 ++ tests/src/nodes/mixnode.rs | 1 + tests/src/nodes/nomos.rs | 2 + tests/src/tests/mixnet.rs | 4 ++ 13 files changed, 237 insertions(+), 36 deletions(-) diff --git a/mixnet/client/Cargo.toml b/mixnet/client/Cargo.toml index ebf7eafa..d30cc963 100644 --- a/mixnet/client/Cargo.toml +++ b/mixnet/client/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] serde = { version = "1.0", features = ["derive"] } tracing = "0.1.37" -tokio = { version = "1.29.1", features = ["net"] } +tokio = { version = "1.32", features = ["net"] } sphinx-packet = "0.1.0" nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" } # Using an older version, since `nym-sphinx` depends on `rand` v0.7.3. diff --git a/mixnet/client/src/config.rs b/mixnet/client/src/config.rs index 3fb1a489..16ca684d 100644 --- a/mixnet/client/src/config.rs +++ b/mixnet/client/src/config.rs @@ -11,6 +11,21 @@ pub struct MixnetClientConfig { pub mode: MixnetClientMode, pub topology: MixnetTopology, pub connection_pool_size: usize, + pub max_retries: usize, + pub retry_delay: std::time::Duration, +} + +impl MixnetClientConfig { + /// Creates a new `MixnetClientConfig` with default values. + pub fn new(mode: MixnetClientMode, topology: MixnetTopology) -> Self { + Self { + mode, + topology, + connection_pool_size: 256, + max_retries: 3, + retry_delay: std::time::Duration::from_secs(5), + } + } } #[derive(Serialize, Deserialize, Clone, Debug)] diff --git a/mixnet/client/src/lib.rs b/mixnet/client/src/lib.rs index b08e9c3e..a53c94a4 100644 --- a/mixnet/client/src/lib.rs +++ b/mixnet/client/src/lib.rs @@ -26,7 +26,13 @@ impl MixnetClient { let cache = ConnectionPool::new(config.connection_pool_size); Self { mode: config.mode, - sender: Sender::new(config.topology, cache, rng), + sender: Sender::new( + config.topology, + cache, + rng, + config.max_retries, + config.retry_delay, + ), } } diff --git a/mixnet/client/src/receiver.rs b/mixnet/client/src/receiver.rs index 4c0bfcc1..e0b5a218 100644 --- a/mixnet/client/src/receiver.rs +++ b/mixnet/client/src/receiver.rs @@ -39,17 +39,22 @@ impl Receiver { fn fragment_stream( socket: TcpStream, ) -> impl Stream> + Send + 'static { - stream::unfold(socket, |mut socket| async move { - let Ok(body) = Body::read(&mut socket).await else { - // TODO: Maybe this is a hard error and the stream is corrupted? In that case stop the stream - return Some((Err(MixnetClientError::MixnetNodeStreamClosed), socket)); - }; + stream::unfold(socket, move |mut socket| { + async move { + let Ok(body) = Body::read(&mut socket).await else { + // TODO: Maybe this is a hard error and the stream is corrupted? In that case stop the stream + return Some((Err(MixnetClientError::MixnetNodeStreamClosed), socket)); + }; - match body { - Body::SphinxPacket(_) => { - Some((Err(MixnetClientError::UnexpectedStreamBody), socket)) + match body { + Body::SphinxPacket(_) => { + Some((Err(MixnetClientError::UnexpectedStreamBody), socket)) + } + Body::FinalPayload(payload) => { + Some((Self::fragment_from_payload(payload), socket)) + } + _ => unreachable!(), } - Body::FinalPayload(payload) => Some((Self::fragment_from_payload(payload), socket)), } }) } diff --git a/mixnet/client/src/sender.rs b/mixnet/client/src/sender.rs index b63fb3d2..32d0f644 100644 --- a/mixnet/client/src/sender.rs +++ b/mixnet/client/src/sender.rs @@ -16,15 +16,25 @@ pub struct Sender { //TODO: handle topology update topology: MixnetTopology, pool: ConnectionPool, + max_retries: usize, + retry_delay: Duration, rng: R, } impl Sender { - pub fn new(topology: MixnetTopology, pool: ConnectionPool, rng: R) -> Self { + pub fn new( + topology: MixnetTopology, + pool: ConnectionPool, + rng: R, + max_retries: usize, + retry_delay: Duration, + ) -> Self { Self { topology, rng, pool, + max_retries, + retry_delay, } } @@ -46,9 +56,17 @@ impl Sender { .into_iter() .for_each(|(packet, first_node)| { let pool = self.pool.clone(); + let max_retries = self.max_retries; + let retry_delay = self.retry_delay; tokio::spawn(async move { - if let Err(e) = - Self::send_packet(&pool, Box::new(packet), first_node.address).await + if let Err(e) = Self::send_packet( + &pool, + max_retries, + retry_delay, + Box::new(packet), + first_node.address, + ) + .await { tracing::error!("failed to send packet to the first node: {e}"); } @@ -101,6 +119,8 @@ impl Sender { async fn send_packet( pool: &ConnectionPool, + max_retries: usize, + retry_delay: Duration, packet: Box, addr: NodeAddressBytes, ) -> Result<(), Box> { @@ -109,12 +129,24 @@ impl Sender { let mu: std::sync::Arc> = pool.get_or_init(&addr).await?; - let mut socket = mu.lock().await; - let body = Body::new_sphinx(packet); - body.write(&mut *socket).await?; + let arc_socket = mu.clone(); - tracing::debug!("Sent a Sphinx packet successuflly to the node: {addr:?}"); + let body = Body::SphinxPacket(packet); + if let Err(e) = { + let mut socket = mu.lock().await; + body.write(&mut *socket).await + } { + tracing::error!("Failed to send packet to {addr} with error: {e}. Retrying..."); + return mixnet_protocol::retry_backoff( + addr, + max_retries, + retry_delay, + body, + arc_socket, + ) + .await; + } Ok(()) } } diff --git a/mixnet/node/src/config.rs b/mixnet/node/src/config.rs index 93e989a5..6d325c9b 100644 --- a/mixnet/node/src/config.rs +++ b/mixnet/node/src/config.rs @@ -1,4 +1,7 @@ -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + time::Duration, +}; use nym_sphinx::{PrivateKey, PublicKey}; use serde::{Deserialize, Serialize}; @@ -13,7 +16,14 @@ pub struct MixnetNodeConfig { /// A key for decrypting Sphinx packets pub private_key: [u8; PRIVATE_KEY_SIZE], /// The size of the connection pool. + #[serde(default = "MixnetNodeConfig::default_connection_pool_size")] pub connection_pool_size: usize, + /// The maximum number of retries. + #[serde(default = "MixnetNodeConfig::default_max_retries")] + pub max_retries: usize, + /// The retry delay between retries. + #[serde(default = "MixnetNodeConfig::default_retry_delay")] + pub retry_delay: Duration, } impl Default for MixnetNodeConfig { @@ -26,11 +36,25 @@ impl Default for MixnetNodeConfig { )), private_key: PrivateKey::new().to_bytes(), connection_pool_size: 255, + max_retries: 3, + retry_delay: Duration::from_secs(5), } } } impl MixnetNodeConfig { + const fn default_connection_pool_size() -> usize { + 255 + } + + const fn default_max_retries() -> usize { + 3 + } + + const fn default_retry_delay() -> Duration { + Duration::from_secs(5) + } + pub fn public_key(&self) -> [u8; PUBLIC_KEY_SIZE] { *PublicKey::from(&PrivateKey::from(self.private_key)).as_bytes() } diff --git a/mixnet/node/src/lib.rs b/mixnet/node/src/lib.rs index e38b4b56..6997bdfe 100644 --- a/mixnet/node/src/lib.rs +++ b/mixnet/node/src/lib.rs @@ -1,7 +1,7 @@ mod client_notifier; pub mod config; -use std::{error::Error, net::SocketAddr}; +use std::{error::Error, net::SocketAddr, time::Duration}; use client_notifier::ClientNotifier; pub use config::MixnetNodeConfig; @@ -70,8 +70,15 @@ impl MixnetNode { let private_key = self.config.private_key; let pool = self.pool.clone(); tokio::spawn(async move { - if let Err(e) = - Self::handle_connection(socket, pool, private_key, client_tx).await + if let Err(e) = Self::handle_connection( + socket, + self.config.max_retries, + self.config.retry_delay, + pool, + private_key, + client_tx, + ) + .await { tracing::error!("failed to handle conn: {e}"); } @@ -84,6 +91,8 @@ impl MixnetNode { async fn handle_connection( mut socket: TcpStream, + max_retries: usize, + retry_delay: Duration, pool: ConnectionPool, private_key: [u8; PRIVATE_KEY_SIZE], client_tx: mpsc::Sender, @@ -96,14 +105,27 @@ impl MixnetNode { let client_tx = client_tx.clone(); tokio::spawn(async move { - if let Err(e) = Self::handle_body(body, &pool, &private_key, &client_tx).await { + if let Err(e) = Self::handle_body( + max_retries, + retry_delay, + body, + &pool, + &private_key, + &client_tx, + ) + .await + { tracing::error!("failed to handle body: {e}"); } }); } } + // TODO: refactor this fn to make it receive less arguments + #[allow(clippy::too_many_arguments)] async fn handle_body( + max_retries: usize, + retry_delay: Duration, body: Body, pool: &ConnectionPool, private_key: &PrivateKey, @@ -111,25 +133,49 @@ impl MixnetNode { ) -> Result<(), Box> { match body { Body::SphinxPacket(packet) => { - Self::handle_sphinx_packet(pool, private_key, packet).await + Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet) + .await } - _body @ Body::FinalPayload(_) => { - Self::forward_body_to_client_notifier(private_key, client_tx, _body).await + Body::FinalPayload(payload) => { + Self::forward_body_to_client_notifier( + private_key, + client_tx, + Body::FinalPayload(payload), + ) + .await } + _ => unreachable!(), } } async fn handle_sphinx_packet( pool: &ConnectionPool, + max_retries: usize, + retry_delay: Duration, private_key: &PrivateKey, packet: Box, ) -> Result<(), Box> { match packet.process(private_key)? { ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => { - Self::forward_packet_to_next_hop(pool, packet, next_node_addr, delay).await + Self::forward_packet_to_next_hop( + pool, + max_retries, + retry_delay, + packet, + next_node_addr, + delay, + ) + .await } ProcessedPacket::FinalHop(destination_addr, _, payload) => { - Self::forward_payload_to_destination(pool, payload, destination_addr).await + Self::forward_payload_to_destination( + pool, + max_retries, + retry_delay, + payload, + destination_addr, + ) + .await } } } @@ -148,6 +194,8 @@ impl MixnetNode { async fn forward_packet_to_next_hop( pool: &ConnectionPool, + max_retries: usize, + retry_delay: Duration, packet: Box, next_node_addr: NodeAddressBytes, delay: Delay, @@ -157,6 +205,8 @@ impl MixnetNode { Self::forward( pool, + max_retries, + retry_delay, Body::new_sphinx(packet), NymNodeRoutingAddress::try_from(next_node_addr)?, ) @@ -165,6 +215,8 @@ impl MixnetNode { async fn forward_payload_to_destination( pool: &ConnectionPool, + max_retries: usize, + retry_delay: Duration, payload: Payload, destination_addr: DestinationAddressBytes, ) -> Result<(), Box> { @@ -172,6 +224,8 @@ impl MixnetNode { Self::forward( pool, + max_retries, + retry_delay, Body::new_final_payload(payload), NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?, ) @@ -180,12 +234,28 @@ impl MixnetNode { async fn forward( pool: &ConnectionPool, + max_retries: usize, + retry_delay: Duration, body: Body, to: NymNodeRoutingAddress, ) -> Result<(), Box> { let addr = SocketAddr::try_from(to)?; - body.write(&mut *pool.get_or_init(&addr).await?.lock().await) - .await?; + let arc_socket = pool.get_or_init(&addr).await?; + + if let Err(e) = { + let mut socket = arc_socket.lock().await; + body.write(&mut *socket).await + } { + tracing::error!("Failed to forward packet to {addr} with error: {e}. Retrying..."); + return mixnet_protocol::retry_backoff( + addr, + max_retries, + retry_delay, + body, + arc_socket, + ) + .await; + } Ok(()) } } diff --git a/mixnet/protocol/Cargo.toml b/mixnet/protocol/Cargo.toml index a587c82b..13387034 100644 --- a/mixnet/protocol/Cargo.toml +++ b/mixnet/protocol/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tokio = "1.29.1" +tokio = "1.32" sphinx-packet = "0.1.0" futures = "0.3" tokio-util = {version = "0.7", features = ["io", "io-util"] } \ No newline at end of file diff --git a/mixnet/protocol/src/lib.rs b/mixnet/protocol/src/lib.rs index 5bb1422d..4b6701f9 100644 --- a/mixnet/protocol/src/lib.rs +++ b/mixnet/protocol/src/lib.rs @@ -1,8 +1,13 @@ use sphinx_packet::{payload::Payload, SphinxPacket}; -use std::error::Error; +use std::{error::Error, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpStream, + sync::Mutex, +}; +#[non_exhaustive] pub enum Body { SphinxPacket(Box), FinalPayload(Payload), @@ -76,7 +81,7 @@ impl Body { } pub async fn write( - self, + &self, writer: &mut W, ) -> Result<(), Box> where @@ -85,12 +90,12 @@ impl Body { let variant = self.variant_as_u8(); writer.write_u8(variant).await?; match self { - Body::SphinxPacket(packet) => { + Self::SphinxPacket(packet) => { let data = packet.to_bytes(); writer.write_u64(data.len() as u64).await?; writer.write_all(&data).await?; } - Body::FinalPayload(payload) => { + Self::FinalPayload(payload) => { let data = payload.as_bytes(); writer.write_u64(data.len() as u64).await?; writer.write_all(data).await?; @@ -99,3 +104,36 @@ impl Body { Ok(()) } } + +pub async fn retry_backoff( + peer_addr: SocketAddr, + max_retries: usize, + retry_delay: Duration, + body: Body, + socket: Arc>, +) -> Result<(), Box> { + for idx in 0..max_retries { + // backoff + let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32)); + + tokio::time::sleep(wait).await; + let mut socket = socket.lock().await; + match body.write(&mut *socket).await { + Ok(_) => return Ok(()), + Err(e) => { + if let Some(err) = e.downcast_ref::() { + match err.kind() { + ErrorKind::Unsupported => return Err(e), + _ => { + // update the connection + if let Ok(tcp) = TcpStream::connect(peer_addr).await { + *socket = tcp; + } + } + } + } + } + } + } + Err(format!("Failure after {max_retries} retries").into()) +} diff --git a/tests/src/benches/mixnet.rs b/tests/src/benches/mixnet.rs index 85470feb..5d56ca43 100644 --- a/tests/src/benches/mixnet.rs +++ b/tests/src/benches/mixnet.rs @@ -34,6 +34,8 @@ async fn setup(msg_size: usize) -> (Vec, MixnetClient, MessageSt mode: MixnetClientMode::Sender, topology: topology.clone(), connection_pool_size: 255, + max_retries: 3, + retry_delay: Duration::from_secs(5), }, OsRng, ); @@ -47,6 +49,8 @@ async fn setup(msg_size: usize) -> (Vec, MixnetClient, MessageSt ), topology, connection_pool_size: 255, + max_retries: 3, + retry_delay: Duration::from_secs(5), }, OsRng, ); diff --git a/tests/src/nodes/mixnode.rs b/tests/src/nodes/mixnode.rs index 854f9b2c..b55de6e0 100644 --- a/tests/src/nodes/mixnode.rs +++ b/tests/src/nodes/mixnode.rs @@ -65,6 +65,7 @@ impl MixNode { )), private_key, connection_pool_size: 255, + ..Default::default() }; configs.push(config); } diff --git a/tests/src/nodes/nomos.rs b/tests/src/nodes/nomos.rs index 4c7f2ed9..c41ef75c 100644 --- a/tests/src/nodes/nomos.rs +++ b/tests/src/nodes/nomos.rs @@ -255,6 +255,8 @@ fn create_node_config( mode: mixnet_client_mode, topology: mixnet_topology, connection_pool_size: 255, + max_retries: 3, + retry_delay: Duration::from_secs(5), }, mixnet_delay: Duration::ZERO..Duration::from_millis(10), }, diff --git a/tests/src/tests/mixnet.rs b/tests/src/tests/mixnet.rs index ea7f6702..a0a5dea1 100644 --- a/tests/src/tests/mixnet.rs +++ b/tests/src/tests/mixnet.rs @@ -24,6 +24,8 @@ async fn mixnet() { mode: MixnetClientMode::Sender, topology: topology.clone(), connection_pool_size: 255, + max_retries: 3, + retry_delay: Duration::from_secs(5), }, OsRng, ); @@ -126,6 +128,8 @@ async fn run_nodes_and_destination_client() -> ( mode: MixnetClientMode::SenderReceiver(config3.client_listen_address), topology: topology.clone(), connection_pool_size: 255, + max_retries: 3, + retry_delay: Duration::from_secs(5), }, OsRng, );