From 8ea341dbafbf13915c10de37db66f66bd0826df9 Mon Sep 17 00:00:00 2001 From: Al Liu Date: Tue, 26 Sep 2023 18:23:34 +0800 Subject: [PATCH] Mixnet packet handle (#435) * finish mixnet packet sending fan-in model --- .DS_Store | Bin 6148 -> 0 bytes mixnet/node/Cargo.toml | 2 +- mixnet/node/src/config.rs | 2 +- mixnet/node/src/lib.rs | 328 ++++++++++++++++++++----------------- mixnet/protocol/Cargo.toml | 2 +- mixnet/protocol/src/lib.rs | 2 +- mixnet/util/Cargo.toml | 3 +- mixnet/util/src/lib.rs | 3 +- tests/src/nodes/mixnode.rs | 2 +- 9 files changed, 187 insertions(+), 157 deletions(-) delete mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index f1112a46af5356070b3e3426e2915aae5b41750c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKyG{c^3>-s>NJuCp<^BK#f3S+e7w`j!gaRpW5+_n$#dq;(j2{KjK^4(NW63+a zUXM>V#rX`t)*rh&U6nG}!$Qa}nw0V!~@0@YHNS0`6e0VyB_&P@UTJ~X;xFPswN)4?T1 z0OEq-Fz#cPAT|#Wd*PJG2+fj8Osdt0VM%AaRbDTg5|a+A;lt|5RuhWF(^$=WSb@ttuDt)>)Bl+Nk4ZX70V!}+3fN-vuvzn!s<+Nw&U), + PacketSendError(#[from] tokio::sync::mpsc::error::SendError), + #[error("send error: fail to send {0} to client")] + ClientSendError(#[from] tokio::sync::mpsc::error::TrySendError), #[error("client: {0}")] Client(ProtocolError), } @@ -36,13 +37,11 @@ pub enum MixnetNodeError { // A mix node that routes packets in the Mixnet. pub struct MixnetNode { config: MixnetNodeConfig, - pool: ConnectionPool, } impl MixnetNode { pub fn new(config: MixnetNodeConfig) -> Self { - let pool = ConnectionPool::new(config.connection_pool_size); - Self { config, pool } + Self { config } } pub fn id(&self) -> MixnetNodeId { @@ -68,7 +67,6 @@ impl MixnetNode { }); //TODO: Accepting ad-hoc TCP conns for now. Improve conn handling. - //TODO: Add graceful shutdown let listener = TcpListener::bind(self.config.listen_address) .await .map_err(ProtocolError::IO)?; @@ -77,205 +75,239 @@ impl MixnetNode { self.config.listen_address ); - loop { - match listener.accept().await { - Ok((socket, remote_addr)) => { - tracing::debug!("Accepted incoming connection from {remote_addr:?}"); + let (tx, rx) = mpsc::unbounded_channel(); - let client_tx = client_tx.clone(); - let private_key = self.config.private_key; - let pool = self.pool.clone(); - tokio::spawn(async move { - if let Err(e) = Self::handle_connection( - socket, - self.config.max_retries, - self.config.retry_delay, - pool, - private_key, - client_tx, - ) - .await - { - tracing::error!("failed to handle conn: {e}"); + let packet_forwarder = PacketForwarder::new(tx.clone(), rx, self.config); + + tokio::spawn(async move { + packet_forwarder.run().await; + }); + + let runner = MixnetNodeRunner { + config: self.config, + client_tx, + packet_tx: tx, + }; + + loop { + tokio::select! { + res = listener.accept() => { + match res { + Ok((socket, remote_addr)) => { + tracing::debug!("Accepted incoming connection from {remote_addr:?}"); + + let runner = runner.clone(); + tokio::spawn(async move { + if let Err(e) = runner.handle_connection(socket).await { + tracing::error!("failed to handle conn: {e}"); + } + }); } - }); + Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"), + } + } + _ = tokio::signal::ctrl_c() => { + tracing::info!("Shutting down..."); + return Ok(()); } - Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"), } } } +} - async fn handle_connection( - mut socket: TcpStream, - max_retries: usize, - retry_delay: Duration, - pool: ConnectionPool, - private_key: [u8; PRIVATE_KEY_SIZE], - client_tx: mpsc::Sender, - ) -> Result<()> { +#[derive(Clone)] +struct MixnetNodeRunner { + config: MixnetNodeConfig, + client_tx: mpsc::Sender, + packet_tx: mpsc::UnboundedSender, +} + +impl MixnetNodeRunner { + async fn handle_connection(&self, mut socket: TcpStream) -> Result<()> { loop { let body = Body::read(&mut socket).await?; - - let pool = pool.clone(); - let private_key = PrivateKey::from(private_key); - let client_tx = client_tx.clone(); - + let this = self.clone(); tokio::spawn(async move { - if let Err(e) = Self::handle_body( - max_retries, - retry_delay, - body, - &pool, - &private_key, - &client_tx, - ) - .await - { + if let Err(e) = this.handle_body(body).await { tracing::error!("failed to handle body: {e}"); } }); } } - // TODO: refactor this fn to make it receive less arguments - #[allow(clippy::too_many_arguments)] - async fn handle_body( - max_retries: usize, - retry_delay: Duration, - body: Body, - pool: &ConnectionPool, - private_key: &PrivateKey, - client_tx: &mpsc::Sender, - ) -> Result<()> { - match body { - Body::SphinxPacket(packet) => { - Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet) - .await - } + async fn handle_body(&self, pkt: Body) -> Result<()> { + match pkt { + Body::SphinxPacket(packet) => self.handle_sphinx_packet(packet).await, Body::FinalPayload(payload) => { - Self::forward_body_to_client_notifier( - private_key, - client_tx, - Body::FinalPayload(payload), - ) - .await + self.forward_body_to_client_notifier(Body::FinalPayload(payload)) + .await } _ => unreachable!(), } } - async fn handle_sphinx_packet( - pool: &ConnectionPool, - max_retries: usize, - retry_delay: Duration, - private_key: &PrivateKey, - packet: Box, - ) -> Result<()> { + async fn handle_sphinx_packet(&self, packet: Box) -> Result<()> { match packet - .process(private_key) + .process(&PrivateKey::from(self.config.private_key)) .map_err(ProtocolError::InvalidSphinxPacket)? { ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => { - Self::forward_packet_to_next_hop( - pool, - max_retries, - retry_delay, - packet, - next_node_addr, - delay, - ) - .await + self.forward_packet_to_next_hop(Body::SphinxPacket(packet), next_node_addr, delay) + .await } ProcessedPacket::FinalHop(destination_addr, _, payload) => { - Self::forward_payload_to_destination( - pool, - max_retries, - retry_delay, - payload, - destination_addr, - ) - .await + self.forward_payload_to_destination(Body::FinalPayload(payload), destination_addr) + .await } } } - async fn forward_body_to_client_notifier( - _private_key: &PrivateKey, - client_tx: &mpsc::Sender, - body: Body, - ) -> Result<()> { + async fn forward_body_to_client_notifier(&self, body: Body) -> Result<()> { // TODO: Decrypt the final payload using the private key, if it's encrypted // Do not wait when the channel is full or no receiver exists - client_tx.try_send(body)?; + self.client_tx.try_send(body)?; Ok(()) } async fn forward_packet_to_next_hop( - pool: &ConnectionPool, - max_retries: usize, - retry_delay: Duration, - packet: Box, + &self, + packet: Body, next_node_addr: NodeAddressBytes, delay: Delay, ) -> Result<()> { tracing::debug!("Delaying the packet for {delay:?}"); tokio::time::sleep(delay.to_duration()).await; - Self::forward( - pool, - max_retries, - retry_delay, - Body::new_sphinx(packet), - NymNodeRoutingAddress::try_from(next_node_addr)?, - ) - .await + self.forward(packet, NymNodeRoutingAddress::try_from(next_node_addr)?) + .await } async fn forward_payload_to_destination( - pool: &ConnectionPool, - max_retries: usize, - retry_delay: Duration, - payload: Payload, + &self, + payload: Body, destination_addr: DestinationAddressBytes, ) -> Result<()> { tracing::debug!("Forwarding final payload to destination mixnode"); - Self::forward( - pool, - max_retries, - retry_delay, - Body::new_final_payload(payload), + self.forward( + payload, NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?, ) .await } - async fn forward( - pool: &ConnectionPool, - max_retries: usize, - retry_delay: Duration, - body: Body, - to: NymNodeRoutingAddress, - ) -> Result<()> { + async fn forward(&self, pkt: Body, to: NymNodeRoutingAddress) -> Result<()> { let addr = SocketAddr::from(to); - let arc_socket = pool.get_or_init(&addr).await?; - if let Err(e) = { - let mut socket = arc_socket.lock().await; - body.write(&mut *socket).await - } { - tracing::error!("Failed to forward packet to {addr} with error: {e}. Retrying..."); - return mixnet_protocol::retry_backoff( - addr, - max_retries, - retry_delay, - body, - arc_socket, - ) - .await - .map_err(Into::into); - } + self.packet_tx.send(Packet::new(addr, pkt))?; Ok(()) } } + +struct PacketForwarder { + config: MixnetNodeConfig, + packet_rx: mpsc::UnboundedReceiver, + packet_tx: mpsc::UnboundedSender, + connections: HashMap, +} + +impl PacketForwarder { + pub fn new( + packet_tx: mpsc::UnboundedSender, + packet_rx: mpsc::UnboundedReceiver, + config: MixnetNodeConfig, + ) -> Self { + Self { + packet_tx, + packet_rx, + connections: HashMap::with_capacity(config.connection_pool_size), + config, + } + } + + pub async fn run(mut self) { + loop { + tokio::select! { + pkt = self.packet_rx.recv() => { + if let Some(pkt) = pkt { + self.send(pkt).await; + } else { + unreachable!("Packet channel should not be closed, because PacketForwarder is also holding the send half"); + } + }, + _ = tokio::signal::ctrl_c() => { + tracing::info!("Shutting down packet forwarder task..."); + return; + } + } + } + } + + async fn try_send(&mut self, target: SocketAddr, body: &Body) -> Result<()> { + if let std::collections::hash_map::Entry::Vacant(e) = self.connections.entry(target) { + match TcpStream::connect(target).await { + Ok(tcp) => { + e.insert(tcp); + } + Err(e) => { + tracing::error!("failed to connect to {}: {e}", target); + return Err(MixnetNodeError::Protocol(e.into())); + } + } + } + Ok(body + .write(self.connections.get_mut(&target).unwrap()) + .await?) + } + + async fn send(&mut self, pkt: Packet) { + if let Err(err) = self.try_send(pkt.target, &pkt.body).await { + match err { + MixnetNodeError::Protocol(ProtocolError::IO(e)) + if e.kind() == std::io::ErrorKind::Unsupported => + { + tracing::error!("fail to send message to {}: {e}", pkt.target); + } + _ => self.handle_retry(pkt), + } + } + } + + fn handle_retry(&self, mut pkt: Packet) { + if pkt.retry_count < self.config.max_retries { + let delay = Duration::from_millis( + (self.config.retry_delay.as_millis() as u64).pow(pkt.retry_count as u32), + ); + let tx = self.packet_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(delay).await; + pkt.retry_count += 1; + if let Err(e) = tx.send(pkt) { + tracing::error!("fail to enqueue retry message: {e}"); + } + }); + } else { + tracing::error!( + "fail to send message to {}: reach maximum retries", + pkt.target + ); + } + } +} + +pub struct Packet { + target: SocketAddr, + body: Body, + retry_count: usize, +} + +impl Packet { + fn new(target: SocketAddr, body: Body) -> Self { + Self { + target, + body, + retry_count: 0, + } + } +} diff --git a/mixnet/protocol/Cargo.toml b/mixnet/protocol/Cargo.toml index c0db927b..d0673195 100644 --- a/mixnet/protocol/Cargo.toml +++ b/mixnet/protocol/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tokio = { version = "1.32", features = ["sync", "net"] } +tokio = { version = "1.32", features = ["sync", "net", "time"] } sphinx-packet = "0.1.0" futures = "0.3" tokio-util = { version = "0.7", features = ["io", "io-util"] } diff --git a/mixnet/protocol/src/lib.rs b/mixnet/protocol/src/lib.rs index 459eac7b..af6fdf7c 100644 --- a/mixnet/protocol/src/lib.rs +++ b/mixnet/protocol/src/lib.rs @@ -1,6 +1,7 @@ use sphinx_packet::{payload::Payload, SphinxPacket}; use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration}; + use tokio::{ io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::TcpStream, @@ -122,7 +123,6 @@ pub async fn retry_backoff( for idx in 0..max_retries { // backoff let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32)); - tokio::time::sleep(wait).await; let mut socket = socket.lock().await; match body.write(&mut *socket).await { diff --git a/mixnet/util/Cargo.toml b/mixnet/util/Cargo.toml index 7ab67774..66b323f0 100644 --- a/mixnet/util/Cargo.toml +++ b/mixnet/util/Cargo.toml @@ -5,5 +5,4 @@ edition = "2021" [dependencies] tokio = { version = "1.32", default-features = false, features = ["sync", "net"] } -parking_lot = { version = "0.12", features = ["send_guard"] } -mixnet-protocol = { path = "../protocol" } \ No newline at end of file +mixnet-protocol = { path = "../protocol" } diff --git a/mixnet/util/src/lib.rs b/mixnet/util/src/lib.rs index e00f3714..a8431eb2 100644 --- a/mixnet/util/src/lib.rs +++ b/mixnet/util/src/lib.rs @@ -1,7 +1,6 @@ use std::{collections::HashMap, net::SocketAddr, sync::Arc}; -use tokio::net::TcpStream; -use tokio::sync::Mutex; +use tokio::{net::TcpStream, sync::Mutex}; #[derive(Clone)] pub struct ConnectionPool { diff --git a/tests/src/nodes/mixnode.rs b/tests/src/nodes/mixnode.rs index b55de6e0..25a23060 100644 --- a/tests/src/nodes/mixnode.rs +++ b/tests/src/nodes/mixnode.rs @@ -72,7 +72,7 @@ impl MixNode { let mut nodes = Vec::::new(); for config in &configs { - nodes.push(Self::spawn(config.clone()).await); + nodes.push(Self::spawn(*config).await); } // We need to return configs as well, to configure mixclients accordingly