Mixnet packet handle (#435)

* finish mixnet packet sending fan-in model
This commit is contained in:
Al Liu 2023-09-26 18:23:34 +08:00 committed by GitHub
parent 03973cd422
commit 8ea341dbaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 187 additions and 157 deletions

BIN
.DS_Store vendored

Binary file not shown.

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
tracing = "0.1.37"
tokio = { version = "1.32", features = ["net", "time"] }
tokio = { version = "1.32", features = ["net", "time", "signal"] }
thiserror = "1"
sphinx-packet = "0.1.0"
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }

View File

@ -7,7 +7,7 @@ use nym_sphinx::{PrivateKey, PublicKey};
use serde::{Deserialize, Serialize};
use sphinx_packet::crypto::{PRIVATE_KEY_SIZE, PUBLIC_KEY_SIZE};
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
pub struct MixnetNodeConfig {
/// A listen address for receiving Sphinx packets
pub listen_address: SocketAddr,

View File

@ -1,16 +1,15 @@
mod client_notifier;
pub mod config;
use std::{net::SocketAddr, time::Duration};
use std::{collections::HashMap, net::SocketAddr, time::Duration};
use client_notifier::ClientNotifier;
pub use config::MixnetNodeConfig;
use mixnet_protocol::{Body, ProtocolError};
use mixnet_topology::MixnetNodeId;
use mixnet_util::ConnectionPool;
use nym_sphinx::{
addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError},
Delay, DestinationAddressBytes, NodeAddressBytes, Payload, PrivateKey,
Delay, DestinationAddressBytes, NodeAddressBytes, PrivateKey,
};
pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE;
use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket};
@ -28,7 +27,9 @@ pub enum MixnetNodeError {
#[error("invalid routing address: {0}")]
InvalidRoutingAddress(#[from] NymNodeRoutingAddressError),
#[error("send error: {0}")]
SendError(#[from] tokio::sync::mpsc::error::TrySendError<Body>),
PacketSendError(#[from] tokio::sync::mpsc::error::SendError<Packet>),
#[error("send error: fail to send {0} to client")]
ClientSendError(#[from] tokio::sync::mpsc::error::TrySendError<Body>),
#[error("client: {0}")]
Client(ProtocolError),
}
@ -36,13 +37,11 @@ pub enum MixnetNodeError {
// A mix node that routes packets in the Mixnet.
pub struct MixnetNode {
config: MixnetNodeConfig,
pool: ConnectionPool,
}
impl MixnetNode {
pub fn new(config: MixnetNodeConfig) -> Self {
let pool = ConnectionPool::new(config.connection_pool_size);
Self { config, pool }
Self { config }
}
pub fn id(&self) -> MixnetNodeId {
@ -68,7 +67,6 @@ impl MixnetNode {
});
//TODO: Accepting ad-hoc TCP conns for now. Improve conn handling.
//TODO: Add graceful shutdown
let listener = TcpListener::bind(self.config.listen_address)
.await
.map_err(ProtocolError::IO)?;
@ -77,205 +75,239 @@ impl MixnetNode {
self.config.listen_address
);
loop {
match listener.accept().await {
Ok((socket, remote_addr)) => {
tracing::debug!("Accepted incoming connection from {remote_addr:?}");
let (tx, rx) = mpsc::unbounded_channel();
let client_tx = client_tx.clone();
let private_key = self.config.private_key;
let pool = self.pool.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_connection(
socket,
self.config.max_retries,
self.config.retry_delay,
pool,
private_key,
client_tx,
)
.await
{
tracing::error!("failed to handle conn: {e}");
let packet_forwarder = PacketForwarder::new(tx.clone(), rx, self.config);
tokio::spawn(async move {
packet_forwarder.run().await;
});
let runner = MixnetNodeRunner {
config: self.config,
client_tx,
packet_tx: tx,
};
loop {
tokio::select! {
res = listener.accept() => {
match res {
Ok((socket, remote_addr)) => {
tracing::debug!("Accepted incoming connection from {remote_addr:?}");
let runner = runner.clone();
tokio::spawn(async move {
if let Err(e) = runner.handle_connection(socket).await {
tracing::error!("failed to handle conn: {e}");
}
});
}
});
Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"),
}
}
_ = tokio::signal::ctrl_c() => {
tracing::info!("Shutting down...");
return Ok(());
}
Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"),
}
}
}
}
async fn handle_connection(
mut socket: TcpStream,
max_retries: usize,
retry_delay: Duration,
pool: ConnectionPool,
private_key: [u8; PRIVATE_KEY_SIZE],
client_tx: mpsc::Sender<Body>,
) -> Result<()> {
#[derive(Clone)]
struct MixnetNodeRunner {
config: MixnetNodeConfig,
client_tx: mpsc::Sender<Body>,
packet_tx: mpsc::UnboundedSender<Packet>,
}
impl MixnetNodeRunner {
async fn handle_connection(&self, mut socket: TcpStream) -> Result<()> {
loop {
let body = Body::read(&mut socket).await?;
let pool = pool.clone();
let private_key = PrivateKey::from(private_key);
let client_tx = client_tx.clone();
let this = self.clone();
tokio::spawn(async move {
if let Err(e) = Self::handle_body(
max_retries,
retry_delay,
body,
&pool,
&private_key,
&client_tx,
)
.await
{
if let Err(e) = this.handle_body(body).await {
tracing::error!("failed to handle body: {e}");
}
});
}
}
// TODO: refactor this fn to make it receive less arguments
#[allow(clippy::too_many_arguments)]
async fn handle_body(
max_retries: usize,
retry_delay: Duration,
body: Body,
pool: &ConnectionPool,
private_key: &PrivateKey,
client_tx: &mpsc::Sender<Body>,
) -> Result<()> {
match body {
Body::SphinxPacket(packet) => {
Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet)
.await
}
async fn handle_body(&self, pkt: Body) -> Result<()> {
match pkt {
Body::SphinxPacket(packet) => self.handle_sphinx_packet(packet).await,
Body::FinalPayload(payload) => {
Self::forward_body_to_client_notifier(
private_key,
client_tx,
Body::FinalPayload(payload),
)
.await
self.forward_body_to_client_notifier(Body::FinalPayload(payload))
.await
}
_ => unreachable!(),
}
}
async fn handle_sphinx_packet(
pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
private_key: &PrivateKey,
packet: Box<SphinxPacket>,
) -> Result<()> {
async fn handle_sphinx_packet(&self, packet: Box<SphinxPacket>) -> Result<()> {
match packet
.process(private_key)
.process(&PrivateKey::from(self.config.private_key))
.map_err(ProtocolError::InvalidSphinxPacket)?
{
ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => {
Self::forward_packet_to_next_hop(
pool,
max_retries,
retry_delay,
packet,
next_node_addr,
delay,
)
.await
self.forward_packet_to_next_hop(Body::SphinxPacket(packet), next_node_addr, delay)
.await
}
ProcessedPacket::FinalHop(destination_addr, _, payload) => {
Self::forward_payload_to_destination(
pool,
max_retries,
retry_delay,
payload,
destination_addr,
)
.await
self.forward_payload_to_destination(Body::FinalPayload(payload), destination_addr)
.await
}
}
}
async fn forward_body_to_client_notifier(
_private_key: &PrivateKey,
client_tx: &mpsc::Sender<Body>,
body: Body,
) -> Result<()> {
async fn forward_body_to_client_notifier(&self, body: Body) -> Result<()> {
// TODO: Decrypt the final payload using the private key, if it's encrypted
// Do not wait when the channel is full or no receiver exists
client_tx.try_send(body)?;
self.client_tx.try_send(body)?;
Ok(())
}
async fn forward_packet_to_next_hop(
pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
packet: Box<SphinxPacket>,
&self,
packet: Body,
next_node_addr: NodeAddressBytes,
delay: Delay,
) -> Result<()> {
tracing::debug!("Delaying the packet for {delay:?}");
tokio::time::sleep(delay.to_duration()).await;
Self::forward(
pool,
max_retries,
retry_delay,
Body::new_sphinx(packet),
NymNodeRoutingAddress::try_from(next_node_addr)?,
)
.await
self.forward(packet, NymNodeRoutingAddress::try_from(next_node_addr)?)
.await
}
async fn forward_payload_to_destination(
pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
payload: Payload,
&self,
payload: Body,
destination_addr: DestinationAddressBytes,
) -> Result<()> {
tracing::debug!("Forwarding final payload to destination mixnode");
Self::forward(
pool,
max_retries,
retry_delay,
Body::new_final_payload(payload),
self.forward(
payload,
NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?,
)
.await
}
async fn forward(
pool: &ConnectionPool,
max_retries: usize,
retry_delay: Duration,
body: Body,
to: NymNodeRoutingAddress,
) -> Result<()> {
async fn forward(&self, pkt: Body, to: NymNodeRoutingAddress) -> Result<()> {
let addr = SocketAddr::from(to);
let arc_socket = pool.get_or_init(&addr).await?;
if let Err(e) = {
let mut socket = arc_socket.lock().await;
body.write(&mut *socket).await
} {
tracing::error!("Failed to forward packet to {addr} with error: {e}. Retrying...");
return mixnet_protocol::retry_backoff(
addr,
max_retries,
retry_delay,
body,
arc_socket,
)
.await
.map_err(Into::into);
}
self.packet_tx.send(Packet::new(addr, pkt))?;
Ok(())
}
}
struct PacketForwarder {
config: MixnetNodeConfig,
packet_rx: mpsc::UnboundedReceiver<Packet>,
packet_tx: mpsc::UnboundedSender<Packet>,
connections: HashMap<SocketAddr, TcpStream>,
}
impl PacketForwarder {
pub fn new(
packet_tx: mpsc::UnboundedSender<Packet>,
packet_rx: mpsc::UnboundedReceiver<Packet>,
config: MixnetNodeConfig,
) -> Self {
Self {
packet_tx,
packet_rx,
connections: HashMap::with_capacity(config.connection_pool_size),
config,
}
}
pub async fn run(mut self) {
loop {
tokio::select! {
pkt = self.packet_rx.recv() => {
if let Some(pkt) = pkt {
self.send(pkt).await;
} else {
unreachable!("Packet channel should not be closed, because PacketForwarder is also holding the send half");
}
},
_ = tokio::signal::ctrl_c() => {
tracing::info!("Shutting down packet forwarder task...");
return;
}
}
}
}
async fn try_send(&mut self, target: SocketAddr, body: &Body) -> Result<()> {
if let std::collections::hash_map::Entry::Vacant(e) = self.connections.entry(target) {
match TcpStream::connect(target).await {
Ok(tcp) => {
e.insert(tcp);
}
Err(e) => {
tracing::error!("failed to connect to {}: {e}", target);
return Err(MixnetNodeError::Protocol(e.into()));
}
}
}
Ok(body
.write(self.connections.get_mut(&target).unwrap())
.await?)
}
async fn send(&mut self, pkt: Packet) {
if let Err(err) = self.try_send(pkt.target, &pkt.body).await {
match err {
MixnetNodeError::Protocol(ProtocolError::IO(e))
if e.kind() == std::io::ErrorKind::Unsupported =>
{
tracing::error!("fail to send message to {}: {e}", pkt.target);
}
_ => self.handle_retry(pkt),
}
}
}
fn handle_retry(&self, mut pkt: Packet) {
if pkt.retry_count < self.config.max_retries {
let delay = Duration::from_millis(
(self.config.retry_delay.as_millis() as u64).pow(pkt.retry_count as u32),
);
let tx = self.packet_tx.clone();
tokio::spawn(async move {
tokio::time::sleep(delay).await;
pkt.retry_count += 1;
if let Err(e) = tx.send(pkt) {
tracing::error!("fail to enqueue retry message: {e}");
}
});
} else {
tracing::error!(
"fail to send message to {}: reach maximum retries",
pkt.target
);
}
}
}
pub struct Packet {
target: SocketAddr,
body: Body,
retry_count: usize,
}
impl Packet {
fn new(target: SocketAddr, body: Body) -> Self {
Self {
target,
body,
retry_count: 0,
}
}
}

View File

@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1.32", features = ["sync", "net"] }
tokio = { version = "1.32", features = ["sync", "net", "time"] }
sphinx-packet = "0.1.0"
futures = "0.3"
tokio-util = { version = "0.7", features = ["io", "io-util"] }

View File

@ -1,6 +1,7 @@
use sphinx_packet::{payload::Payload, SphinxPacket};
use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
use tokio::{
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
@ -122,7 +123,6 @@ pub async fn retry_backoff(
for idx in 0..max_retries {
// backoff
let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32));
tokio::time::sleep(wait).await;
let mut socket = socket.lock().await;
match body.write(&mut *socket).await {

View File

@ -5,5 +5,4 @@ edition = "2021"
[dependencies]
tokio = { version = "1.32", default-features = false, features = ["sync", "net"] }
parking_lot = { version = "0.12", features = ["send_guard"] }
mixnet-protocol = { path = "../protocol" }
mixnet-protocol = { path = "../protocol" }

View File

@ -1,7 +1,6 @@
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::{net::TcpStream, sync::Mutex};
#[derive(Clone)]
pub struct ConnectionPool {

View File

@ -72,7 +72,7 @@ impl MixNode {
let mut nodes = Vec::<MixNode>::new();
for config in &configs {
nodes.push(Self::spawn(config.clone()).await);
nodes.push(Self::spawn(*config).await);
}
// We need to return configs as well, to configure mixclients accordingly