diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index f1112a46..00000000
Binary files a/.DS_Store and /dev/null differ
diff --git a/mixnet/node/Cargo.toml b/mixnet/node/Cargo.toml
index 37601022..f1fed044 100644
--- a/mixnet/node/Cargo.toml
+++ b/mixnet/node/Cargo.toml
@@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
tracing = "0.1.37"
-tokio = { version = "1.32", features = ["net", "time"] }
+tokio = { version = "1.32", features = ["net", "time", "signal"] }
thiserror = "1"
sphinx-packet = "0.1.0"
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }
diff --git a/mixnet/node/src/config.rs b/mixnet/node/src/config.rs
index 6d325c9b..50b1061f 100644
--- a/mixnet/node/src/config.rs
+++ b/mixnet/node/src/config.rs
@@ -7,7 +7,7 @@ use nym_sphinx::{PrivateKey, PublicKey};
use serde::{Deserialize, Serialize};
use sphinx_packet::crypto::{PRIVATE_KEY_SIZE, PUBLIC_KEY_SIZE};
-#[derive(Serialize, Deserialize, Clone, Debug)]
+#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
pub struct MixnetNodeConfig {
/// A listen address for receiving Sphinx packets
pub listen_address: SocketAddr,
diff --git a/mixnet/node/src/lib.rs b/mixnet/node/src/lib.rs
index 61337689..f79525ca 100644
--- a/mixnet/node/src/lib.rs
+++ b/mixnet/node/src/lib.rs
@@ -1,16 +1,15 @@
mod client_notifier;
pub mod config;
-use std::{net::SocketAddr, time::Duration};
+use std::{collections::HashMap, net::SocketAddr, time::Duration};
use client_notifier::ClientNotifier;
pub use config::MixnetNodeConfig;
use mixnet_protocol::{Body, ProtocolError};
use mixnet_topology::MixnetNodeId;
-use mixnet_util::ConnectionPool;
use nym_sphinx::{
addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError},
- Delay, DestinationAddressBytes, NodeAddressBytes, Payload, PrivateKey,
+ Delay, DestinationAddressBytes, NodeAddressBytes, PrivateKey,
};
pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE;
use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket};
@@ -28,7 +27,9 @@ pub enum MixnetNodeError {
#[error("invalid routing address: {0}")]
InvalidRoutingAddress(#[from] NymNodeRoutingAddressError),
#[error("send error: {0}")]
- SendError(#[from] tokio::sync::mpsc::error::TrySendError
),
+ PacketSendError(#[from] tokio::sync::mpsc::error::SendError),
+ #[error("send error: fail to send {0} to client")]
+ ClientSendError(#[from] tokio::sync::mpsc::error::TrySendError),
#[error("client: {0}")]
Client(ProtocolError),
}
@@ -36,13 +37,11 @@ pub enum MixnetNodeError {
// A mix node that routes packets in the Mixnet.
pub struct MixnetNode {
config: MixnetNodeConfig,
- pool: ConnectionPool,
}
impl MixnetNode {
pub fn new(config: MixnetNodeConfig) -> Self {
- let pool = ConnectionPool::new(config.connection_pool_size);
- Self { config, pool }
+ Self { config }
}
pub fn id(&self) -> MixnetNodeId {
@@ -68,7 +67,6 @@ impl MixnetNode {
});
//TODO: Accepting ad-hoc TCP conns for now. Improve conn handling.
- //TODO: Add graceful shutdown
let listener = TcpListener::bind(self.config.listen_address)
.await
.map_err(ProtocolError::IO)?;
@@ -77,205 +75,239 @@ impl MixnetNode {
self.config.listen_address
);
- loop {
- match listener.accept().await {
- Ok((socket, remote_addr)) => {
- tracing::debug!("Accepted incoming connection from {remote_addr:?}");
+ let (tx, rx) = mpsc::unbounded_channel();
- let client_tx = client_tx.clone();
- let private_key = self.config.private_key;
- let pool = self.pool.clone();
- tokio::spawn(async move {
- if let Err(e) = Self::handle_connection(
- socket,
- self.config.max_retries,
- self.config.retry_delay,
- pool,
- private_key,
- client_tx,
- )
- .await
- {
- tracing::error!("failed to handle conn: {e}");
+ let packet_forwarder = PacketForwarder::new(tx.clone(), rx, self.config);
+
+ tokio::spawn(async move {
+ packet_forwarder.run().await;
+ });
+
+ let runner = MixnetNodeRunner {
+ config: self.config,
+ client_tx,
+ packet_tx: tx,
+ };
+
+ loop {
+ tokio::select! {
+ res = listener.accept() => {
+ match res {
+ Ok((socket, remote_addr)) => {
+ tracing::debug!("Accepted incoming connection from {remote_addr:?}");
+
+ let runner = runner.clone();
+ tokio::spawn(async move {
+ if let Err(e) = runner.handle_connection(socket).await {
+ tracing::error!("failed to handle conn: {e}");
+ }
+ });
}
- });
+ Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"),
+ }
+ }
+ _ = tokio::signal::ctrl_c() => {
+ tracing::info!("Shutting down...");
+ return Ok(());
}
- Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"),
}
}
}
+}
- async fn handle_connection(
- mut socket: TcpStream,
- max_retries: usize,
- retry_delay: Duration,
- pool: ConnectionPool,
- private_key: [u8; PRIVATE_KEY_SIZE],
- client_tx: mpsc::Sender,
- ) -> Result<()> {
+#[derive(Clone)]
+struct MixnetNodeRunner {
+ config: MixnetNodeConfig,
+ client_tx: mpsc::Sender,
+ packet_tx: mpsc::UnboundedSender,
+}
+
+impl MixnetNodeRunner {
+ async fn handle_connection(&self, mut socket: TcpStream) -> Result<()> {
loop {
let body = Body::read(&mut socket).await?;
-
- let pool = pool.clone();
- let private_key = PrivateKey::from(private_key);
- let client_tx = client_tx.clone();
-
+ let this = self.clone();
tokio::spawn(async move {
- if let Err(e) = Self::handle_body(
- max_retries,
- retry_delay,
- body,
- &pool,
- &private_key,
- &client_tx,
- )
- .await
- {
+ if let Err(e) = this.handle_body(body).await {
tracing::error!("failed to handle body: {e}");
}
});
}
}
- // TODO: refactor this fn to make it receive less arguments
- #[allow(clippy::too_many_arguments)]
- async fn handle_body(
- max_retries: usize,
- retry_delay: Duration,
- body: Body,
- pool: &ConnectionPool,
- private_key: &PrivateKey,
- client_tx: &mpsc::Sender,
- ) -> Result<()> {
- match body {
- Body::SphinxPacket(packet) => {
- Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet)
- .await
- }
+ async fn handle_body(&self, pkt: Body) -> Result<()> {
+ match pkt {
+ Body::SphinxPacket(packet) => self.handle_sphinx_packet(packet).await,
Body::FinalPayload(payload) => {
- Self::forward_body_to_client_notifier(
- private_key,
- client_tx,
- Body::FinalPayload(payload),
- )
- .await
+ self.forward_body_to_client_notifier(Body::FinalPayload(payload))
+ .await
}
_ => unreachable!(),
}
}
- async fn handle_sphinx_packet(
- pool: &ConnectionPool,
- max_retries: usize,
- retry_delay: Duration,
- private_key: &PrivateKey,
- packet: Box,
- ) -> Result<()> {
+ async fn handle_sphinx_packet(&self, packet: Box) -> Result<()> {
match packet
- .process(private_key)
+ .process(&PrivateKey::from(self.config.private_key))
.map_err(ProtocolError::InvalidSphinxPacket)?
{
ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => {
- Self::forward_packet_to_next_hop(
- pool,
- max_retries,
- retry_delay,
- packet,
- next_node_addr,
- delay,
- )
- .await
+ self.forward_packet_to_next_hop(Body::SphinxPacket(packet), next_node_addr, delay)
+ .await
}
ProcessedPacket::FinalHop(destination_addr, _, payload) => {
- Self::forward_payload_to_destination(
- pool,
- max_retries,
- retry_delay,
- payload,
- destination_addr,
- )
- .await
+ self.forward_payload_to_destination(Body::FinalPayload(payload), destination_addr)
+ .await
}
}
}
- async fn forward_body_to_client_notifier(
- _private_key: &PrivateKey,
- client_tx: &mpsc::Sender,
- body: Body,
- ) -> Result<()> {
+ async fn forward_body_to_client_notifier(&self, body: Body) -> Result<()> {
// TODO: Decrypt the final payload using the private key, if it's encrypted
// Do not wait when the channel is full or no receiver exists
- client_tx.try_send(body)?;
+ self.client_tx.try_send(body)?;
Ok(())
}
async fn forward_packet_to_next_hop(
- pool: &ConnectionPool,
- max_retries: usize,
- retry_delay: Duration,
- packet: Box,
+ &self,
+ packet: Body,
next_node_addr: NodeAddressBytes,
delay: Delay,
) -> Result<()> {
tracing::debug!("Delaying the packet for {delay:?}");
tokio::time::sleep(delay.to_duration()).await;
- Self::forward(
- pool,
- max_retries,
- retry_delay,
- Body::new_sphinx(packet),
- NymNodeRoutingAddress::try_from(next_node_addr)?,
- )
- .await
+ self.forward(packet, NymNodeRoutingAddress::try_from(next_node_addr)?)
+ .await
}
async fn forward_payload_to_destination(
- pool: &ConnectionPool,
- max_retries: usize,
- retry_delay: Duration,
- payload: Payload,
+ &self,
+ payload: Body,
destination_addr: DestinationAddressBytes,
) -> Result<()> {
tracing::debug!("Forwarding final payload to destination mixnode");
- Self::forward(
- pool,
- max_retries,
- retry_delay,
- Body::new_final_payload(payload),
+ self.forward(
+ payload,
NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?,
)
.await
}
- async fn forward(
- pool: &ConnectionPool,
- max_retries: usize,
- retry_delay: Duration,
- body: Body,
- to: NymNodeRoutingAddress,
- ) -> Result<()> {
+ async fn forward(&self, pkt: Body, to: NymNodeRoutingAddress) -> Result<()> {
let addr = SocketAddr::from(to);
- let arc_socket = pool.get_or_init(&addr).await?;
- if let Err(e) = {
- let mut socket = arc_socket.lock().await;
- body.write(&mut *socket).await
- } {
- tracing::error!("Failed to forward packet to {addr} with error: {e}. Retrying...");
- return mixnet_protocol::retry_backoff(
- addr,
- max_retries,
- retry_delay,
- body,
- arc_socket,
- )
- .await
- .map_err(Into::into);
- }
+ self.packet_tx.send(Packet::new(addr, pkt))?;
Ok(())
}
}
+
+struct PacketForwarder {
+ config: MixnetNodeConfig,
+ packet_rx: mpsc::UnboundedReceiver,
+ packet_tx: mpsc::UnboundedSender,
+ connections: HashMap,
+}
+
+impl PacketForwarder {
+ pub fn new(
+ packet_tx: mpsc::UnboundedSender,
+ packet_rx: mpsc::UnboundedReceiver,
+ config: MixnetNodeConfig,
+ ) -> Self {
+ Self {
+ packet_tx,
+ packet_rx,
+ connections: HashMap::with_capacity(config.connection_pool_size),
+ config,
+ }
+ }
+
+ pub async fn run(mut self) {
+ loop {
+ tokio::select! {
+ pkt = self.packet_rx.recv() => {
+ if let Some(pkt) = pkt {
+ self.send(pkt).await;
+ } else {
+ unreachable!("Packet channel should not be closed, because PacketForwarder is also holding the send half");
+ }
+ },
+ _ = tokio::signal::ctrl_c() => {
+ tracing::info!("Shutting down packet forwarder task...");
+ return;
+ }
+ }
+ }
+ }
+
+ async fn try_send(&mut self, target: SocketAddr, body: &Body) -> Result<()> {
+ if let std::collections::hash_map::Entry::Vacant(e) = self.connections.entry(target) {
+ match TcpStream::connect(target).await {
+ Ok(tcp) => {
+ e.insert(tcp);
+ }
+ Err(e) => {
+ tracing::error!("failed to connect to {}: {e}", target);
+ return Err(MixnetNodeError::Protocol(e.into()));
+ }
+ }
+ }
+ Ok(body
+ .write(self.connections.get_mut(&target).unwrap())
+ .await?)
+ }
+
+ async fn send(&mut self, pkt: Packet) {
+ if let Err(err) = self.try_send(pkt.target, &pkt.body).await {
+ match err {
+ MixnetNodeError::Protocol(ProtocolError::IO(e))
+ if e.kind() == std::io::ErrorKind::Unsupported =>
+ {
+ tracing::error!("fail to send message to {}: {e}", pkt.target);
+ }
+ _ => self.handle_retry(pkt),
+ }
+ }
+ }
+
+ fn handle_retry(&self, mut pkt: Packet) {
+ if pkt.retry_count < self.config.max_retries {
+ let delay = Duration::from_millis(
+ (self.config.retry_delay.as_millis() as u64).pow(pkt.retry_count as u32),
+ );
+ let tx = self.packet_tx.clone();
+ tokio::spawn(async move {
+ tokio::time::sleep(delay).await;
+ pkt.retry_count += 1;
+ if let Err(e) = tx.send(pkt) {
+ tracing::error!("fail to enqueue retry message: {e}");
+ }
+ });
+ } else {
+ tracing::error!(
+ "fail to send message to {}: reach maximum retries",
+ pkt.target
+ );
+ }
+ }
+}
+
+pub struct Packet {
+ target: SocketAddr,
+ body: Body,
+ retry_count: usize,
+}
+
+impl Packet {
+ fn new(target: SocketAddr, body: Body) -> Self {
+ Self {
+ target,
+ body,
+ retry_count: 0,
+ }
+ }
+}
diff --git a/mixnet/protocol/Cargo.toml b/mixnet/protocol/Cargo.toml
index c0db927b..d0673195 100644
--- a/mixnet/protocol/Cargo.toml
+++ b/mixnet/protocol/Cargo.toml
@@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
-tokio = { version = "1.32", features = ["sync", "net"] }
+tokio = { version = "1.32", features = ["sync", "net", "time"] }
sphinx-packet = "0.1.0"
futures = "0.3"
tokio-util = { version = "0.7", features = ["io", "io-util"] }
diff --git a/mixnet/protocol/src/lib.rs b/mixnet/protocol/src/lib.rs
index 459eac7b..af6fdf7c 100644
--- a/mixnet/protocol/src/lib.rs
+++ b/mixnet/protocol/src/lib.rs
@@ -1,6 +1,7 @@
use sphinx_packet::{payload::Payload, SphinxPacket};
use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
+
use tokio::{
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
@@ -122,7 +123,6 @@ pub async fn retry_backoff(
for idx in 0..max_retries {
// backoff
let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32));
-
tokio::time::sleep(wait).await;
let mut socket = socket.lock().await;
match body.write(&mut *socket).await {
diff --git a/mixnet/util/Cargo.toml b/mixnet/util/Cargo.toml
index 7ab67774..66b323f0 100644
--- a/mixnet/util/Cargo.toml
+++ b/mixnet/util/Cargo.toml
@@ -5,5 +5,4 @@ edition = "2021"
[dependencies]
tokio = { version = "1.32", default-features = false, features = ["sync", "net"] }
-parking_lot = { version = "0.12", features = ["send_guard"] }
-mixnet-protocol = { path = "../protocol" }
\ No newline at end of file
+mixnet-protocol = { path = "../protocol" }
diff --git a/mixnet/util/src/lib.rs b/mixnet/util/src/lib.rs
index e00f3714..a8431eb2 100644
--- a/mixnet/util/src/lib.rs
+++ b/mixnet/util/src/lib.rs
@@ -1,7 +1,6 @@
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
-use tokio::net::TcpStream;
-use tokio::sync::Mutex;
+use tokio::{net::TcpStream, sync::Mutex};
#[derive(Clone)]
pub struct ConnectionPool {
diff --git a/tests/src/nodes/mixnode.rs b/tests/src/nodes/mixnode.rs
index b55de6e0..25a23060 100644
--- a/tests/src/nodes/mixnode.rs
+++ b/tests/src/nodes/mixnode.rs
@@ -72,7 +72,7 @@ impl MixNode {
let mut nodes = Vec::::new();
for config in &configs {
- nodes.push(Self::spawn(config.clone()).await);
+ nodes.push(Self::spawn(*config).await);
}
// We need to return configs as well, to configure mixclients accordingly