Add mixnet retry mechanism (#386)

Imiplement mixnet retry
This commit is contained in:
Al Liu 2023-09-18 18:26:54 +08:00 committed by GitHub
parent 5e194922c6
commit 2429893ac2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 237 additions and 36 deletions

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
tracing = "0.1.37" tracing = "0.1.37"
tokio = { version = "1.29.1", features = ["net"] } tokio = { version = "1.32", features = ["net"] }
sphinx-packet = "0.1.0" sphinx-packet = "0.1.0"
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" } nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }
# Using an older version, since `nym-sphinx` depends on `rand` v0.7.3. # Using an older version, since `nym-sphinx` depends on `rand` v0.7.3.

View File

@ -11,6 +11,21 @@ pub struct MixnetClientConfig {
pub mode: MixnetClientMode, pub mode: MixnetClientMode,
pub topology: MixnetTopology, pub topology: MixnetTopology,
pub connection_pool_size: usize, pub connection_pool_size: usize,
pub max_retries: usize,
pub retry_delay: std::time::Duration,
}
impl MixnetClientConfig {
/// Creates a new `MixnetClientConfig` with default values.
pub fn new(mode: MixnetClientMode, topology: MixnetTopology) -> Self {
Self {
mode,
topology,
connection_pool_size: 256,
max_retries: 3,
retry_delay: std::time::Duration::from_secs(5),
}
}
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]

View File

@ -26,7 +26,13 @@ impl<R: Rng> MixnetClient<R> {
let cache = ConnectionPool::new(config.connection_pool_size); let cache = ConnectionPool::new(config.connection_pool_size);
Self { Self {
mode: config.mode, mode: config.mode,
sender: Sender::new(config.topology, cache, rng), sender: Sender::new(
config.topology,
cache,
rng,
config.max_retries,
config.retry_delay,
),
} }
} }

View File

@ -39,7 +39,8 @@ impl Receiver {
fn fragment_stream( fn fragment_stream(
socket: TcpStream, socket: TcpStream,
) -> impl Stream<Item = Result<Fragment, MixnetClientError>> + Send + 'static { ) -> impl Stream<Item = Result<Fragment, MixnetClientError>> + Send + 'static {
stream::unfold(socket, |mut socket| async move { stream::unfold(socket, move |mut socket| {
async move {
let Ok(body) = Body::read(&mut socket).await else { let Ok(body) = Body::read(&mut socket).await else {
// TODO: Maybe this is a hard error and the stream is corrupted? In that case stop the stream // TODO: Maybe this is a hard error and the stream is corrupted? In that case stop the stream
return Some((Err(MixnetClientError::MixnetNodeStreamClosed), socket)); return Some((Err(MixnetClientError::MixnetNodeStreamClosed), socket));
@ -49,7 +50,11 @@ impl Receiver {
Body::SphinxPacket(_) => { Body::SphinxPacket(_) => {
Some((Err(MixnetClientError::UnexpectedStreamBody), socket)) Some((Err(MixnetClientError::UnexpectedStreamBody), socket))
} }
Body::FinalPayload(payload) => Some((Self::fragment_from_payload(payload), socket)), Body::FinalPayload(payload) => {
Some((Self::fragment_from_payload(payload), socket))
}
_ => unreachable!(),
}
} }
}) })
} }

View File

@ -16,15 +16,25 @@ pub struct Sender<R: Rng> {
//TODO: handle topology update //TODO: handle topology update
topology: MixnetTopology, topology: MixnetTopology,
pool: ConnectionPool, pool: ConnectionPool,
max_retries: usize,
retry_delay: Duration,
rng: R, rng: R,
} }
impl<R: Rng> Sender<R> { impl<R: Rng> Sender<R> {
pub fn new(topology: MixnetTopology, pool: ConnectionPool, rng: R) -> Self { pub fn new(
topology: MixnetTopology,
pool: ConnectionPool,
rng: R,
max_retries: usize,
retry_delay: Duration,
) -> Self {
Self { Self {
topology, topology,
rng, rng,
pool, pool,
max_retries,
retry_delay,
} }
} }
@ -46,9 +56,17 @@ impl<R: Rng> Sender<R> {
.into_iter() .into_iter()
.for_each(|(packet, first_node)| { .for_each(|(packet, first_node)| {
let pool = self.pool.clone(); let pool = self.pool.clone();
let max_retries = self.max_retries;
let retry_delay = self.retry_delay;
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = if let Err(e) = Self::send_packet(
Self::send_packet(&pool, Box::new(packet), first_node.address).await &pool,
max_retries,
retry_delay,
Box::new(packet),
first_node.address,
)
.await
{ {
tracing::error!("failed to send packet to the first node: {e}"); tracing::error!("failed to send packet to the first node: {e}");
} }
@ -101,6 +119,8 @@ impl<R: Rng> Sender<R> {
async fn send_packet( async fn send_packet(
pool: &ConnectionPool, pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
packet: Box<SphinxPacket>, packet: Box<SphinxPacket>,
addr: NodeAddressBytes, addr: NodeAddressBytes,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> { ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
@ -109,12 +129,24 @@ impl<R: Rng> Sender<R> {
let mu: std::sync::Arc<tokio::sync::Mutex<tokio::net::TcpStream>> = let mu: std::sync::Arc<tokio::sync::Mutex<tokio::net::TcpStream>> =
pool.get_or_init(&addr).await?; pool.get_or_init(&addr).await?;
let arc_socket = mu.clone();
let body = Body::SphinxPacket(packet);
if let Err(e) = {
let mut socket = mu.lock().await; let mut socket = mu.lock().await;
let body = Body::new_sphinx(packet); body.write(&mut *socket).await
body.write(&mut *socket).await?; } {
tracing::error!("Failed to send packet to {addr} with error: {e}. Retrying...");
tracing::debug!("Sent a Sphinx packet successuflly to the node: {addr:?}"); return mixnet_protocol::retry_backoff(
addr,
max_retries,
retry_delay,
body,
arc_socket,
)
.await;
}
Ok(()) Ok(())
} }
} }

View File

@ -1,4 +1,7 @@
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
time::Duration,
};
use nym_sphinx::{PrivateKey, PublicKey}; use nym_sphinx::{PrivateKey, PublicKey};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -13,7 +16,14 @@ pub struct MixnetNodeConfig {
/// A key for decrypting Sphinx packets /// A key for decrypting Sphinx packets
pub private_key: [u8; PRIVATE_KEY_SIZE], pub private_key: [u8; PRIVATE_KEY_SIZE],
/// The size of the connection pool. /// The size of the connection pool.
#[serde(default = "MixnetNodeConfig::default_connection_pool_size")]
pub connection_pool_size: usize, pub connection_pool_size: usize,
/// The maximum number of retries.
#[serde(default = "MixnetNodeConfig::default_max_retries")]
pub max_retries: usize,
/// The retry delay between retries.
#[serde(default = "MixnetNodeConfig::default_retry_delay")]
pub retry_delay: Duration,
} }
impl Default for MixnetNodeConfig { impl Default for MixnetNodeConfig {
@ -26,11 +36,25 @@ impl Default for MixnetNodeConfig {
)), )),
private_key: PrivateKey::new().to_bytes(), private_key: PrivateKey::new().to_bytes(),
connection_pool_size: 255, connection_pool_size: 255,
max_retries: 3,
retry_delay: Duration::from_secs(5),
} }
} }
} }
impl MixnetNodeConfig { impl MixnetNodeConfig {
const fn default_connection_pool_size() -> usize {
255
}
const fn default_max_retries() -> usize {
3
}
const fn default_retry_delay() -> Duration {
Duration::from_secs(5)
}
pub fn public_key(&self) -> [u8; PUBLIC_KEY_SIZE] { pub fn public_key(&self) -> [u8; PUBLIC_KEY_SIZE] {
*PublicKey::from(&PrivateKey::from(self.private_key)).as_bytes() *PublicKey::from(&PrivateKey::from(self.private_key)).as_bytes()
} }

View File

@ -1,7 +1,7 @@
mod client_notifier; mod client_notifier;
pub mod config; pub mod config;
use std::{error::Error, net::SocketAddr}; use std::{error::Error, net::SocketAddr, time::Duration};
use client_notifier::ClientNotifier; use client_notifier::ClientNotifier;
pub use config::MixnetNodeConfig; pub use config::MixnetNodeConfig;
@ -70,8 +70,15 @@ impl MixnetNode {
let private_key = self.config.private_key; let private_key = self.config.private_key;
let pool = self.pool.clone(); let pool = self.pool.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = if let Err(e) = Self::handle_connection(
Self::handle_connection(socket, pool, private_key, client_tx).await socket,
self.config.max_retries,
self.config.retry_delay,
pool,
private_key,
client_tx,
)
.await
{ {
tracing::error!("failed to handle conn: {e}"); tracing::error!("failed to handle conn: {e}");
} }
@ -84,6 +91,8 @@ impl MixnetNode {
async fn handle_connection( async fn handle_connection(
mut socket: TcpStream, mut socket: TcpStream,
max_retries: usize,
retry_delay: Duration,
pool: ConnectionPool, pool: ConnectionPool,
private_key: [u8; PRIVATE_KEY_SIZE], private_key: [u8; PRIVATE_KEY_SIZE],
client_tx: mpsc::Sender<Body>, client_tx: mpsc::Sender<Body>,
@ -96,14 +105,27 @@ impl MixnetNode {
let client_tx = client_tx.clone(); let client_tx = client_tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = Self::handle_body(body, &pool, &private_key, &client_tx).await { if let Err(e) = Self::handle_body(
max_retries,
retry_delay,
body,
&pool,
&private_key,
&client_tx,
)
.await
{
tracing::error!("failed to handle body: {e}"); tracing::error!("failed to handle body: {e}");
} }
}); });
} }
} }
// TODO: refactor this fn to make it receive less arguments
#[allow(clippy::too_many_arguments)]
async fn handle_body( async fn handle_body(
max_retries: usize,
retry_delay: Duration,
body: Body, body: Body,
pool: &ConnectionPool, pool: &ConnectionPool,
private_key: &PrivateKey, private_key: &PrivateKey,
@ -111,25 +133,49 @@ impl MixnetNode {
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> { ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
match body { match body {
Body::SphinxPacket(packet) => { Body::SphinxPacket(packet) => {
Self::handle_sphinx_packet(pool, private_key, packet).await Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet)
.await
} }
_body @ Body::FinalPayload(_) => { Body::FinalPayload(payload) => {
Self::forward_body_to_client_notifier(private_key, client_tx, _body).await Self::forward_body_to_client_notifier(
private_key,
client_tx,
Body::FinalPayload(payload),
)
.await
} }
_ => unreachable!(),
} }
} }
async fn handle_sphinx_packet( async fn handle_sphinx_packet(
pool: &ConnectionPool, pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
private_key: &PrivateKey, private_key: &PrivateKey,
packet: Box<SphinxPacket>, packet: Box<SphinxPacket>,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> { ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
match packet.process(private_key)? { match packet.process(private_key)? {
ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => { ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => {
Self::forward_packet_to_next_hop(pool, packet, next_node_addr, delay).await Self::forward_packet_to_next_hop(
pool,
max_retries,
retry_delay,
packet,
next_node_addr,
delay,
)
.await
} }
ProcessedPacket::FinalHop(destination_addr, _, payload) => { ProcessedPacket::FinalHop(destination_addr, _, payload) => {
Self::forward_payload_to_destination(pool, payload, destination_addr).await Self::forward_payload_to_destination(
pool,
max_retries,
retry_delay,
payload,
destination_addr,
)
.await
} }
} }
} }
@ -148,6 +194,8 @@ impl MixnetNode {
async fn forward_packet_to_next_hop( async fn forward_packet_to_next_hop(
pool: &ConnectionPool, pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
packet: Box<SphinxPacket>, packet: Box<SphinxPacket>,
next_node_addr: NodeAddressBytes, next_node_addr: NodeAddressBytes,
delay: Delay, delay: Delay,
@ -157,6 +205,8 @@ impl MixnetNode {
Self::forward( Self::forward(
pool, pool,
max_retries,
retry_delay,
Body::new_sphinx(packet), Body::new_sphinx(packet),
NymNodeRoutingAddress::try_from(next_node_addr)?, NymNodeRoutingAddress::try_from(next_node_addr)?,
) )
@ -165,6 +215,8 @@ impl MixnetNode {
async fn forward_payload_to_destination( async fn forward_payload_to_destination(
pool: &ConnectionPool, pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
payload: Payload, payload: Payload,
destination_addr: DestinationAddressBytes, destination_addr: DestinationAddressBytes,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> { ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
@ -172,6 +224,8 @@ impl MixnetNode {
Self::forward( Self::forward(
pool, pool,
max_retries,
retry_delay,
Body::new_final_payload(payload), Body::new_final_payload(payload),
NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?, NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?,
) )
@ -180,12 +234,28 @@ impl MixnetNode {
async fn forward( async fn forward(
pool: &ConnectionPool, pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
body: Body, body: Body,
to: NymNodeRoutingAddress, to: NymNodeRoutingAddress,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> { ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
let addr = SocketAddr::try_from(to)?; let addr = SocketAddr::try_from(to)?;
body.write(&mut *pool.get_or_init(&addr).await?.lock().await) let arc_socket = pool.get_or_init(&addr).await?;
.await?;
if let Err(e) = {
let mut socket = arc_socket.lock().await;
body.write(&mut *socket).await
} {
tracing::error!("Failed to forward packet to {addr} with error: {e}. Retrying...");
return mixnet_protocol::retry_backoff(
addr,
max_retries,
retry_delay,
body,
arc_socket,
)
.await;
}
Ok(()) Ok(())
} }
} }

View File

@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
tokio = "1.29.1" tokio = "1.32"
sphinx-packet = "0.1.0" sphinx-packet = "0.1.0"
futures = "0.3" futures = "0.3"
tokio-util = {version = "0.7", features = ["io", "io-util"] } tokio-util = {version = "0.7", features = ["io", "io-util"] }

View File

@ -1,8 +1,13 @@
use sphinx_packet::{payload::Payload, SphinxPacket}; use sphinx_packet::{payload::Payload, SphinxPacket};
use std::error::Error; use std::{error::Error, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
sync::Mutex,
};
#[non_exhaustive]
pub enum Body { pub enum Body {
SphinxPacket(Box<SphinxPacket>), SphinxPacket(Box<SphinxPacket>),
FinalPayload(Payload), FinalPayload(Payload),
@ -76,7 +81,7 @@ impl Body {
} }
pub async fn write<W>( pub async fn write<W>(
self, &self,
writer: &mut W, writer: &mut W,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> ) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
where where
@ -85,12 +90,12 @@ impl Body {
let variant = self.variant_as_u8(); let variant = self.variant_as_u8();
writer.write_u8(variant).await?; writer.write_u8(variant).await?;
match self { match self {
Body::SphinxPacket(packet) => { Self::SphinxPacket(packet) => {
let data = packet.to_bytes(); let data = packet.to_bytes();
writer.write_u64(data.len() as u64).await?; writer.write_u64(data.len() as u64).await?;
writer.write_all(&data).await?; writer.write_all(&data).await?;
} }
Body::FinalPayload(payload) => { Self::FinalPayload(payload) => {
let data = payload.as_bytes(); let data = payload.as_bytes();
writer.write_u64(data.len() as u64).await?; writer.write_u64(data.len() as u64).await?;
writer.write_all(data).await?; writer.write_all(data).await?;
@ -99,3 +104,36 @@ impl Body {
Ok(()) Ok(())
} }
} }
pub async fn retry_backoff(
peer_addr: SocketAddr,
max_retries: usize,
retry_delay: Duration,
body: Body,
socket: Arc<Mutex<TcpStream>>,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
for idx in 0..max_retries {
// backoff
let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32));
tokio::time::sleep(wait).await;
let mut socket = socket.lock().await;
match body.write(&mut *socket).await {
Ok(_) => return Ok(()),
Err(e) => {
if let Some(err) = e.downcast_ref::<std::io::Error>() {
match err.kind() {
ErrorKind::Unsupported => return Err(e),
_ => {
// update the connection
if let Ok(tcp) = TcpStream::connect(peer_addr).await {
*socket = tcp;
}
}
}
}
}
}
}
Err(format!("Failure after {max_retries} retries").into())
}

View File

@ -34,6 +34,8 @@ async fn setup(msg_size: usize) -> (Vec<MixNode>, MixnetClient<OsRng>, MessageSt
mode: MixnetClientMode::Sender, mode: MixnetClientMode::Sender,
topology: topology.clone(), topology: topology.clone(),
connection_pool_size: 255, connection_pool_size: 255,
max_retries: 3,
retry_delay: Duration::from_secs(5),
}, },
OsRng, OsRng,
); );
@ -47,6 +49,8 @@ async fn setup(msg_size: usize) -> (Vec<MixNode>, MixnetClient<OsRng>, MessageSt
), ),
topology, topology,
connection_pool_size: 255, connection_pool_size: 255,
max_retries: 3,
retry_delay: Duration::from_secs(5),
}, },
OsRng, OsRng,
); );

View File

@ -65,6 +65,7 @@ impl MixNode {
)), )),
private_key, private_key,
connection_pool_size: 255, connection_pool_size: 255,
..Default::default()
}; };
configs.push(config); configs.push(config);
} }

View File

@ -255,6 +255,8 @@ fn create_node_config(
mode: mixnet_client_mode, mode: mixnet_client_mode,
topology: mixnet_topology, topology: mixnet_topology,
connection_pool_size: 255, connection_pool_size: 255,
max_retries: 3,
retry_delay: Duration::from_secs(5),
}, },
mixnet_delay: Duration::ZERO..Duration::from_millis(10), mixnet_delay: Duration::ZERO..Duration::from_millis(10),
}, },

View File

@ -24,6 +24,8 @@ async fn mixnet() {
mode: MixnetClientMode::Sender, mode: MixnetClientMode::Sender,
topology: topology.clone(), topology: topology.clone(),
connection_pool_size: 255, connection_pool_size: 255,
max_retries: 3,
retry_delay: Duration::from_secs(5),
}, },
OsRng, OsRng,
); );
@ -126,6 +128,8 @@ async fn run_nodes_and_destination_client() -> (
mode: MixnetClientMode::SenderReceiver(config3.client_listen_address), mode: MixnetClientMode::SenderReceiver(config3.client_listen_address),
topology: topology.clone(), topology: topology.clone(),
connection_pool_size: 255, connection_pool_size: 255,
max_retries: 3,
retry_delay: Duration::from_secs(5),
}, },
OsRng, OsRng,
); );