Mixnet packet handle (#435)

* finish mixnet packet sending fan-in model
This commit is contained in:
Al Liu 2023-09-26 18:23:34 +08:00 committed by GitHub
parent 03973cd422
commit 8ea341dbaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 187 additions and 157 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
tracing = "0.1.37" tracing = "0.1.37"
tokio = { version = "1.32", features = ["net", "time"] } tokio = { version = "1.32", features = ["net", "time", "signal"] }
thiserror = "1" thiserror = "1"
sphinx-packet = "0.1.0" sphinx-packet = "0.1.0"
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" } nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }

View File

@ -7,7 +7,7 @@ use nym_sphinx::{PrivateKey, PublicKey};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sphinx_packet::crypto::{PRIVATE_KEY_SIZE, PUBLIC_KEY_SIZE}; use sphinx_packet::crypto::{PRIVATE_KEY_SIZE, PUBLIC_KEY_SIZE};
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Copy, Clone, Debug)]
pub struct MixnetNodeConfig { pub struct MixnetNodeConfig {
/// A listen address for receiving Sphinx packets /// A listen address for receiving Sphinx packets
pub listen_address: SocketAddr, pub listen_address: SocketAddr,

View File

@ -1,16 +1,15 @@
mod client_notifier; mod client_notifier;
pub mod config; pub mod config;
use std::{net::SocketAddr, time::Duration}; use std::{collections::HashMap, net::SocketAddr, time::Duration};
use client_notifier::ClientNotifier; use client_notifier::ClientNotifier;
pub use config::MixnetNodeConfig; pub use config::MixnetNodeConfig;
use mixnet_protocol::{Body, ProtocolError}; use mixnet_protocol::{Body, ProtocolError};
use mixnet_topology::MixnetNodeId; use mixnet_topology::MixnetNodeId;
use mixnet_util::ConnectionPool;
use nym_sphinx::{ use nym_sphinx::{
addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError}, addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError},
Delay, DestinationAddressBytes, NodeAddressBytes, Payload, PrivateKey, Delay, DestinationAddressBytes, NodeAddressBytes, PrivateKey,
}; };
pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE; pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE;
use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket}; use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket};
@ -28,7 +27,9 @@ pub enum MixnetNodeError {
#[error("invalid routing address: {0}")] #[error("invalid routing address: {0}")]
InvalidRoutingAddress(#[from] NymNodeRoutingAddressError), InvalidRoutingAddress(#[from] NymNodeRoutingAddressError),
#[error("send error: {0}")] #[error("send error: {0}")]
SendError(#[from] tokio::sync::mpsc::error::TrySendError<Body>), PacketSendError(#[from] tokio::sync::mpsc::error::SendError<Packet>),
#[error("send error: fail to send {0} to client")]
ClientSendError(#[from] tokio::sync::mpsc::error::TrySendError<Body>),
#[error("client: {0}")] #[error("client: {0}")]
Client(ProtocolError), Client(ProtocolError),
} }
@ -36,13 +37,11 @@ pub enum MixnetNodeError {
// A mix node that routes packets in the Mixnet. // A mix node that routes packets in the Mixnet.
pub struct MixnetNode { pub struct MixnetNode {
config: MixnetNodeConfig, config: MixnetNodeConfig,
pool: ConnectionPool,
} }
impl MixnetNode { impl MixnetNode {
pub fn new(config: MixnetNodeConfig) -> Self { pub fn new(config: MixnetNodeConfig) -> Self {
let pool = ConnectionPool::new(config.connection_pool_size); Self { config }
Self { config, pool }
} }
pub fn id(&self) -> MixnetNodeId { pub fn id(&self) -> MixnetNodeId {
@ -68,7 +67,6 @@ impl MixnetNode {
}); });
//TODO: Accepting ad-hoc TCP conns for now. Improve conn handling. //TODO: Accepting ad-hoc TCP conns for now. Improve conn handling.
//TODO: Add graceful shutdown
let listener = TcpListener::bind(self.config.listen_address) let listener = TcpListener::bind(self.config.listen_address)
.await .await
.map_err(ProtocolError::IO)?; .map_err(ProtocolError::IO)?;
@ -77,205 +75,239 @@ impl MixnetNode {
self.config.listen_address self.config.listen_address
); );
loop { let (tx, rx) = mpsc::unbounded_channel();
match listener.accept().await {
Ok((socket, remote_addr)) => {
tracing::debug!("Accepted incoming connection from {remote_addr:?}");
let client_tx = client_tx.clone(); let packet_forwarder = PacketForwarder::new(tx.clone(), rx, self.config);
let private_key = self.config.private_key;
let pool = self.pool.clone(); tokio::spawn(async move {
tokio::spawn(async move { packet_forwarder.run().await;
if let Err(e) = Self::handle_connection( });
socket,
self.config.max_retries, let runner = MixnetNodeRunner {
self.config.retry_delay, config: self.config,
pool, client_tx,
private_key, packet_tx: tx,
client_tx, };
)
.await loop {
{ tokio::select! {
tracing::error!("failed to handle conn: {e}"); res = listener.accept() => {
match res {
Ok((socket, remote_addr)) => {
tracing::debug!("Accepted incoming connection from {remote_addr:?}");
let runner = runner.clone();
tokio::spawn(async move {
if let Err(e) = runner.handle_connection(socket).await {
tracing::error!("failed to handle conn: {e}");
}
});
} }
}); Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"),
}
}
_ = tokio::signal::ctrl_c() => {
tracing::info!("Shutting down...");
return Ok(());
} }
Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"),
} }
} }
} }
}
async fn handle_connection( #[derive(Clone)]
mut socket: TcpStream, struct MixnetNodeRunner {
max_retries: usize, config: MixnetNodeConfig,
retry_delay: Duration, client_tx: mpsc::Sender<Body>,
pool: ConnectionPool, packet_tx: mpsc::UnboundedSender<Packet>,
private_key: [u8; PRIVATE_KEY_SIZE], }
client_tx: mpsc::Sender<Body>,
) -> Result<()> { impl MixnetNodeRunner {
async fn handle_connection(&self, mut socket: TcpStream) -> Result<()> {
loop { loop {
let body = Body::read(&mut socket).await?; let body = Body::read(&mut socket).await?;
let this = self.clone();
let pool = pool.clone();
let private_key = PrivateKey::from(private_key);
let client_tx = client_tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = Self::handle_body( if let Err(e) = this.handle_body(body).await {
max_retries,
retry_delay,
body,
&pool,
&private_key,
&client_tx,
)
.await
{
tracing::error!("failed to handle body: {e}"); tracing::error!("failed to handle body: {e}");
} }
}); });
} }
} }
// TODO: refactor this fn to make it receive less arguments async fn handle_body(&self, pkt: Body) -> Result<()> {
#[allow(clippy::too_many_arguments)] match pkt {
async fn handle_body( Body::SphinxPacket(packet) => self.handle_sphinx_packet(packet).await,
max_retries: usize,
retry_delay: Duration,
body: Body,
pool: &ConnectionPool,
private_key: &PrivateKey,
client_tx: &mpsc::Sender<Body>,
) -> Result<()> {
match body {
Body::SphinxPacket(packet) => {
Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet)
.await
}
Body::FinalPayload(payload) => { Body::FinalPayload(payload) => {
Self::forward_body_to_client_notifier( self.forward_body_to_client_notifier(Body::FinalPayload(payload))
private_key, .await
client_tx,
Body::FinalPayload(payload),
)
.await
} }
_ => unreachable!(), _ => unreachable!(),
} }
} }
async fn handle_sphinx_packet( async fn handle_sphinx_packet(&self, packet: Box<SphinxPacket>) -> Result<()> {
pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
private_key: &PrivateKey,
packet: Box<SphinxPacket>,
) -> Result<()> {
match packet match packet
.process(private_key) .process(&PrivateKey::from(self.config.private_key))
.map_err(ProtocolError::InvalidSphinxPacket)? .map_err(ProtocolError::InvalidSphinxPacket)?
{ {
ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => { ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => {
Self::forward_packet_to_next_hop( self.forward_packet_to_next_hop(Body::SphinxPacket(packet), next_node_addr, delay)
pool, .await
max_retries,
retry_delay,
packet,
next_node_addr,
delay,
)
.await
} }
ProcessedPacket::FinalHop(destination_addr, _, payload) => { ProcessedPacket::FinalHop(destination_addr, _, payload) => {
Self::forward_payload_to_destination( self.forward_payload_to_destination(Body::FinalPayload(payload), destination_addr)
pool, .await
max_retries,
retry_delay,
payload,
destination_addr,
)
.await
} }
} }
} }
async fn forward_body_to_client_notifier( async fn forward_body_to_client_notifier(&self, body: Body) -> Result<()> {
_private_key: &PrivateKey,
client_tx: &mpsc::Sender<Body>,
body: Body,
) -> Result<()> {
// TODO: Decrypt the final payload using the private key, if it's encrypted // TODO: Decrypt the final payload using the private key, if it's encrypted
// Do not wait when the channel is full or no receiver exists // Do not wait when the channel is full or no receiver exists
client_tx.try_send(body)?; self.client_tx.try_send(body)?;
Ok(()) Ok(())
} }
async fn forward_packet_to_next_hop( async fn forward_packet_to_next_hop(
pool: &ConnectionPool, &self,
max_retries: usize, packet: Body,
retry_delay: Duration,
packet: Box<SphinxPacket>,
next_node_addr: NodeAddressBytes, next_node_addr: NodeAddressBytes,
delay: Delay, delay: Delay,
) -> Result<()> { ) -> Result<()> {
tracing::debug!("Delaying the packet for {delay:?}"); tracing::debug!("Delaying the packet for {delay:?}");
tokio::time::sleep(delay.to_duration()).await; tokio::time::sleep(delay.to_duration()).await;
Self::forward( self.forward(packet, NymNodeRoutingAddress::try_from(next_node_addr)?)
pool, .await
max_retries,
retry_delay,
Body::new_sphinx(packet),
NymNodeRoutingAddress::try_from(next_node_addr)?,
)
.await
} }
async fn forward_payload_to_destination( async fn forward_payload_to_destination(
pool: &ConnectionPool, &self,
max_retries: usize, payload: Body,
retry_delay: Duration,
payload: Payload,
destination_addr: DestinationAddressBytes, destination_addr: DestinationAddressBytes,
) -> Result<()> { ) -> Result<()> {
tracing::debug!("Forwarding final payload to destination mixnode"); tracing::debug!("Forwarding final payload to destination mixnode");
Self::forward( self.forward(
pool, payload,
max_retries,
retry_delay,
Body::new_final_payload(payload),
NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?, NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?,
) )
.await .await
} }
async fn forward( async fn forward(&self, pkt: Body, to: NymNodeRoutingAddress) -> Result<()> {
pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
body: Body,
to: NymNodeRoutingAddress,
) -> Result<()> {
let addr = SocketAddr::from(to); let addr = SocketAddr::from(to);
let arc_socket = pool.get_or_init(&addr).await?;
if let Err(e) = { self.packet_tx.send(Packet::new(addr, pkt))?;
let mut socket = arc_socket.lock().await;
body.write(&mut *socket).await
} {
tracing::error!("Failed to forward packet to {addr} with error: {e}. Retrying...");
return mixnet_protocol::retry_backoff(
addr,
max_retries,
retry_delay,
body,
arc_socket,
)
.await
.map_err(Into::into);
}
Ok(()) Ok(())
} }
} }
struct PacketForwarder {
config: MixnetNodeConfig,
packet_rx: mpsc::UnboundedReceiver<Packet>,
packet_tx: mpsc::UnboundedSender<Packet>,
connections: HashMap<SocketAddr, TcpStream>,
}
impl PacketForwarder {
pub fn new(
packet_tx: mpsc::UnboundedSender<Packet>,
packet_rx: mpsc::UnboundedReceiver<Packet>,
config: MixnetNodeConfig,
) -> Self {
Self {
packet_tx,
packet_rx,
connections: HashMap::with_capacity(config.connection_pool_size),
config,
}
}
pub async fn run(mut self) {
loop {
tokio::select! {
pkt = self.packet_rx.recv() => {
if let Some(pkt) = pkt {
self.send(pkt).await;
} else {
unreachable!("Packet channel should not be closed, because PacketForwarder is also holding the send half");
}
},
_ = tokio::signal::ctrl_c() => {
tracing::info!("Shutting down packet forwarder task...");
return;
}
}
}
}
async fn try_send(&mut self, target: SocketAddr, body: &Body) -> Result<()> {
if let std::collections::hash_map::Entry::Vacant(e) = self.connections.entry(target) {
match TcpStream::connect(target).await {
Ok(tcp) => {
e.insert(tcp);
}
Err(e) => {
tracing::error!("failed to connect to {}: {e}", target);
return Err(MixnetNodeError::Protocol(e.into()));
}
}
}
Ok(body
.write(self.connections.get_mut(&target).unwrap())
.await?)
}
async fn send(&mut self, pkt: Packet) {
if let Err(err) = self.try_send(pkt.target, &pkt.body).await {
match err {
MixnetNodeError::Protocol(ProtocolError::IO(e))
if e.kind() == std::io::ErrorKind::Unsupported =>
{
tracing::error!("fail to send message to {}: {e}", pkt.target);
}
_ => self.handle_retry(pkt),
}
}
}
fn handle_retry(&self, mut pkt: Packet) {
if pkt.retry_count < self.config.max_retries {
let delay = Duration::from_millis(
(self.config.retry_delay.as_millis() as u64).pow(pkt.retry_count as u32),
);
let tx = self.packet_tx.clone();
tokio::spawn(async move {
tokio::time::sleep(delay).await;
pkt.retry_count += 1;
if let Err(e) = tx.send(pkt) {
tracing::error!("fail to enqueue retry message: {e}");
}
});
} else {
tracing::error!(
"fail to send message to {}: reach maximum retries",
pkt.target
);
}
}
}
pub struct Packet {
target: SocketAddr,
body: Body,
retry_count: usize,
}
impl Packet {
fn new(target: SocketAddr, body: Body) -> Self {
Self {
target,
body,
retry_count: 0,
}
}
}

View File

@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
tokio = { version = "1.32", features = ["sync", "net"] } tokio = { version = "1.32", features = ["sync", "net", "time"] }
sphinx-packet = "0.1.0" sphinx-packet = "0.1.0"
futures = "0.3" futures = "0.3"
tokio-util = { version = "0.7", features = ["io", "io-util"] } tokio-util = { version = "0.7", features = ["io", "io-util"] }

View File

@ -1,6 +1,7 @@
use sphinx_packet::{payload::Payload, SphinxPacket}; use sphinx_packet::{payload::Payload, SphinxPacket};
use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration}; use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use tokio::{ use tokio::{
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream, net::TcpStream,
@ -122,7 +123,6 @@ pub async fn retry_backoff(
for idx in 0..max_retries { for idx in 0..max_retries {
// backoff // backoff
let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32)); let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32));
tokio::time::sleep(wait).await; tokio::time::sleep(wait).await;
let mut socket = socket.lock().await; let mut socket = socket.lock().await;
match body.write(&mut *socket).await { match body.write(&mut *socket).await {

View File

@ -5,5 +5,4 @@ edition = "2021"
[dependencies] [dependencies]
tokio = { version = "1.32", default-features = false, features = ["sync", "net"] } tokio = { version = "1.32", default-features = false, features = ["sync", "net"] }
parking_lot = { version = "0.12", features = ["send_guard"] } mixnet-protocol = { path = "../protocol" }
mixnet-protocol = { path = "../protocol" }

View File

@ -1,7 +1,6 @@
use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::net::TcpStream; use tokio::{net::TcpStream, sync::Mutex};
use tokio::sync::Mutex;
#[derive(Clone)] #[derive(Clone)]
pub struct ConnectionPool { pub struct ConnectionPool {

View File

@ -72,7 +72,7 @@ impl MixNode {
let mut nodes = Vec::<MixNode>::new(); let mut nodes = Vec::<MixNode>::new();
for config in &configs { for config in &configs {
nodes.push(Self::spawn(config.clone()).await); nodes.push(Self::spawn(*config).await);
} }
// We need to return configs as well, to configure mixclients accordingly // We need to return configs as well, to configure mixclients accordingly