Mixnet packet handle (#435)
* finish mixnet packet sending fan-in model
This commit is contained in:
parent
03973cd422
commit
8ea341dbaf
|
@ -6,7 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tracing = "0.1.37"
|
||||
tokio = { version = "1.32", features = ["net", "time"] }
|
||||
tokio = { version = "1.32", features = ["net", "time", "signal"] }
|
||||
thiserror = "1"
|
||||
sphinx-packet = "0.1.0"
|
||||
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }
|
||||
|
|
|
@ -7,7 +7,7 @@ use nym_sphinx::{PrivateKey, PublicKey};
|
|||
use serde::{Deserialize, Serialize};
|
||||
use sphinx_packet::crypto::{PRIVATE_KEY_SIZE, PUBLIC_KEY_SIZE};
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
|
||||
pub struct MixnetNodeConfig {
|
||||
/// A listen address for receiving Sphinx packets
|
||||
pub listen_address: SocketAddr,
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
mod client_notifier;
|
||||
pub mod config;
|
||||
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
use std::{collections::HashMap, net::SocketAddr, time::Duration};
|
||||
|
||||
use client_notifier::ClientNotifier;
|
||||
pub use config::MixnetNodeConfig;
|
||||
use mixnet_protocol::{Body, ProtocolError};
|
||||
use mixnet_topology::MixnetNodeId;
|
||||
use mixnet_util::ConnectionPool;
|
||||
use nym_sphinx::{
|
||||
addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError},
|
||||
Delay, DestinationAddressBytes, NodeAddressBytes, Payload, PrivateKey,
|
||||
Delay, DestinationAddressBytes, NodeAddressBytes, PrivateKey,
|
||||
};
|
||||
pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE;
|
||||
use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket};
|
||||
|
@ -28,7 +27,9 @@ pub enum MixnetNodeError {
|
|||
#[error("invalid routing address: {0}")]
|
||||
InvalidRoutingAddress(#[from] NymNodeRoutingAddressError),
|
||||
#[error("send error: {0}")]
|
||||
SendError(#[from] tokio::sync::mpsc::error::TrySendError<Body>),
|
||||
PacketSendError(#[from] tokio::sync::mpsc::error::SendError<Packet>),
|
||||
#[error("send error: fail to send {0} to client")]
|
||||
ClientSendError(#[from] tokio::sync::mpsc::error::TrySendError<Body>),
|
||||
#[error("client: {0}")]
|
||||
Client(ProtocolError),
|
||||
}
|
||||
|
@ -36,13 +37,11 @@ pub enum MixnetNodeError {
|
|||
// A mix node that routes packets in the Mixnet.
|
||||
pub struct MixnetNode {
|
||||
config: MixnetNodeConfig,
|
||||
pool: ConnectionPool,
|
||||
}
|
||||
|
||||
impl MixnetNode {
|
||||
pub fn new(config: MixnetNodeConfig) -> Self {
|
||||
let pool = ConnectionPool::new(config.connection_pool_size);
|
||||
Self { config, pool }
|
||||
Self { config }
|
||||
}
|
||||
|
||||
pub fn id(&self) -> MixnetNodeId {
|
||||
|
@ -68,7 +67,6 @@ impl MixnetNode {
|
|||
});
|
||||
|
||||
//TODO: Accepting ad-hoc TCP conns for now. Improve conn handling.
|
||||
//TODO: Add graceful shutdown
|
||||
let listener = TcpListener::bind(self.config.listen_address)
|
||||
.await
|
||||
.map_err(ProtocolError::IO)?;
|
||||
|
@ -77,205 +75,239 @@ impl MixnetNode {
|
|||
self.config.listen_address
|
||||
);
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, remote_addr)) => {
|
||||
tracing::debug!("Accepted incoming connection from {remote_addr:?}");
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
|
||||
let client_tx = client_tx.clone();
|
||||
let private_key = self.config.private_key;
|
||||
let pool = self.pool.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = Self::handle_connection(
|
||||
socket,
|
||||
self.config.max_retries,
|
||||
self.config.retry_delay,
|
||||
pool,
|
||||
private_key,
|
||||
client_tx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("failed to handle conn: {e}");
|
||||
let packet_forwarder = PacketForwarder::new(tx.clone(), rx, self.config);
|
||||
|
||||
tokio::spawn(async move {
|
||||
packet_forwarder.run().await;
|
||||
});
|
||||
|
||||
let runner = MixnetNodeRunner {
|
||||
config: self.config,
|
||||
client_tx,
|
||||
packet_tx: tx,
|
||||
};
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
res = listener.accept() => {
|
||||
match res {
|
||||
Ok((socket, remote_addr)) => {
|
||||
tracing::debug!("Accepted incoming connection from {remote_addr:?}");
|
||||
|
||||
let runner = runner.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = runner.handle_connection(socket).await {
|
||||
tracing::error!("failed to handle conn: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"),
|
||||
}
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
tracing::info!("Shutting down...");
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => tracing::warn!("Failed to accept incoming connection: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_connection(
|
||||
mut socket: TcpStream,
|
||||
max_retries: usize,
|
||||
retry_delay: Duration,
|
||||
pool: ConnectionPool,
|
||||
private_key: [u8; PRIVATE_KEY_SIZE],
|
||||
client_tx: mpsc::Sender<Body>,
|
||||
) -> Result<()> {
|
||||
#[derive(Clone)]
|
||||
struct MixnetNodeRunner {
|
||||
config: MixnetNodeConfig,
|
||||
client_tx: mpsc::Sender<Body>,
|
||||
packet_tx: mpsc::UnboundedSender<Packet>,
|
||||
}
|
||||
|
||||
impl MixnetNodeRunner {
|
||||
async fn handle_connection(&self, mut socket: TcpStream) -> Result<()> {
|
||||
loop {
|
||||
let body = Body::read(&mut socket).await?;
|
||||
|
||||
let pool = pool.clone();
|
||||
let private_key = PrivateKey::from(private_key);
|
||||
let client_tx = client_tx.clone();
|
||||
|
||||
let this = self.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = Self::handle_body(
|
||||
max_retries,
|
||||
retry_delay,
|
||||
body,
|
||||
&pool,
|
||||
&private_key,
|
||||
&client_tx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
if let Err(e) = this.handle_body(body).await {
|
||||
tracing::error!("failed to handle body: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: refactor this fn to make it receive less arguments
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_body(
|
||||
max_retries: usize,
|
||||
retry_delay: Duration,
|
||||
body: Body,
|
||||
pool: &ConnectionPool,
|
||||
private_key: &PrivateKey,
|
||||
client_tx: &mpsc::Sender<Body>,
|
||||
) -> Result<()> {
|
||||
match body {
|
||||
Body::SphinxPacket(packet) => {
|
||||
Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet)
|
||||
.await
|
||||
}
|
||||
async fn handle_body(&self, pkt: Body) -> Result<()> {
|
||||
match pkt {
|
||||
Body::SphinxPacket(packet) => self.handle_sphinx_packet(packet).await,
|
||||
Body::FinalPayload(payload) => {
|
||||
Self::forward_body_to_client_notifier(
|
||||
private_key,
|
||||
client_tx,
|
||||
Body::FinalPayload(payload),
|
||||
)
|
||||
.await
|
||||
self.forward_body_to_client_notifier(Body::FinalPayload(payload))
|
||||
.await
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_sphinx_packet(
|
||||
pool: &ConnectionPool,
|
||||
max_retries: usize,
|
||||
retry_delay: Duration,
|
||||
private_key: &PrivateKey,
|
||||
packet: Box<SphinxPacket>,
|
||||
) -> Result<()> {
|
||||
async fn handle_sphinx_packet(&self, packet: Box<SphinxPacket>) -> Result<()> {
|
||||
match packet
|
||||
.process(private_key)
|
||||
.process(&PrivateKey::from(self.config.private_key))
|
||||
.map_err(ProtocolError::InvalidSphinxPacket)?
|
||||
{
|
||||
ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => {
|
||||
Self::forward_packet_to_next_hop(
|
||||
pool,
|
||||
max_retries,
|
||||
retry_delay,
|
||||
packet,
|
||||
next_node_addr,
|
||||
delay,
|
||||
)
|
||||
.await
|
||||
self.forward_packet_to_next_hop(Body::SphinxPacket(packet), next_node_addr, delay)
|
||||
.await
|
||||
}
|
||||
ProcessedPacket::FinalHop(destination_addr, _, payload) => {
|
||||
Self::forward_payload_to_destination(
|
||||
pool,
|
||||
max_retries,
|
||||
retry_delay,
|
||||
payload,
|
||||
destination_addr,
|
||||
)
|
||||
.await
|
||||
self.forward_payload_to_destination(Body::FinalPayload(payload), destination_addr)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn forward_body_to_client_notifier(
|
||||
_private_key: &PrivateKey,
|
||||
client_tx: &mpsc::Sender<Body>,
|
||||
body: Body,
|
||||
) -> Result<()> {
|
||||
async fn forward_body_to_client_notifier(&self, body: Body) -> Result<()> {
|
||||
// TODO: Decrypt the final payload using the private key, if it's encrypted
|
||||
|
||||
// Do not wait when the channel is full or no receiver exists
|
||||
client_tx.try_send(body)?;
|
||||
self.client_tx.try_send(body)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn forward_packet_to_next_hop(
|
||||
pool: &ConnectionPool,
|
||||
max_retries: usize,
|
||||
retry_delay: Duration,
|
||||
packet: Box<SphinxPacket>,
|
||||
&self,
|
||||
packet: Body,
|
||||
next_node_addr: NodeAddressBytes,
|
||||
delay: Delay,
|
||||
) -> Result<()> {
|
||||
tracing::debug!("Delaying the packet for {delay:?}");
|
||||
tokio::time::sleep(delay.to_duration()).await;
|
||||
|
||||
Self::forward(
|
||||
pool,
|
||||
max_retries,
|
||||
retry_delay,
|
||||
Body::new_sphinx(packet),
|
||||
NymNodeRoutingAddress::try_from(next_node_addr)?,
|
||||
)
|
||||
.await
|
||||
self.forward(packet, NymNodeRoutingAddress::try_from(next_node_addr)?)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn forward_payload_to_destination(
|
||||
pool: &ConnectionPool,
|
||||
max_retries: usize,
|
||||
retry_delay: Duration,
|
||||
payload: Payload,
|
||||
&self,
|
||||
payload: Body,
|
||||
destination_addr: DestinationAddressBytes,
|
||||
) -> Result<()> {
|
||||
tracing::debug!("Forwarding final payload to destination mixnode");
|
||||
|
||||
Self::forward(
|
||||
pool,
|
||||
max_retries,
|
||||
retry_delay,
|
||||
Body::new_final_payload(payload),
|
||||
self.forward(
|
||||
payload,
|
||||
NymNodeRoutingAddress::try_from_bytes(&destination_addr.as_bytes())?,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn forward(
|
||||
pool: &ConnectionPool,
|
||||
max_retries: usize,
|
||||
retry_delay: Duration,
|
||||
body: Body,
|
||||
to: NymNodeRoutingAddress,
|
||||
) -> Result<()> {
|
||||
async fn forward(&self, pkt: Body, to: NymNodeRoutingAddress) -> Result<()> {
|
||||
let addr = SocketAddr::from(to);
|
||||
let arc_socket = pool.get_or_init(&addr).await?;
|
||||
|
||||
if let Err(e) = {
|
||||
let mut socket = arc_socket.lock().await;
|
||||
body.write(&mut *socket).await
|
||||
} {
|
||||
tracing::error!("Failed to forward packet to {addr} with error: {e}. Retrying...");
|
||||
return mixnet_protocol::retry_backoff(
|
||||
addr,
|
||||
max_retries,
|
||||
retry_delay,
|
||||
body,
|
||||
arc_socket,
|
||||
)
|
||||
.await
|
||||
.map_err(Into::into);
|
||||
}
|
||||
self.packet_tx.send(Packet::new(addr, pkt))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct PacketForwarder {
|
||||
config: MixnetNodeConfig,
|
||||
packet_rx: mpsc::UnboundedReceiver<Packet>,
|
||||
packet_tx: mpsc::UnboundedSender<Packet>,
|
||||
connections: HashMap<SocketAddr, TcpStream>,
|
||||
}
|
||||
|
||||
impl PacketForwarder {
|
||||
pub fn new(
|
||||
packet_tx: mpsc::UnboundedSender<Packet>,
|
||||
packet_rx: mpsc::UnboundedReceiver<Packet>,
|
||||
config: MixnetNodeConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
packet_tx,
|
||||
packet_rx,
|
||||
connections: HashMap::with_capacity(config.connection_pool_size),
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(mut self) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
pkt = self.packet_rx.recv() => {
|
||||
if let Some(pkt) = pkt {
|
||||
self.send(pkt).await;
|
||||
} else {
|
||||
unreachable!("Packet channel should not be closed, because PacketForwarder is also holding the send half");
|
||||
}
|
||||
},
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
tracing::info!("Shutting down packet forwarder task...");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_send(&mut self, target: SocketAddr, body: &Body) -> Result<()> {
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = self.connections.entry(target) {
|
||||
match TcpStream::connect(target).await {
|
||||
Ok(tcp) => {
|
||||
e.insert(tcp);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to connect to {}: {e}", target);
|
||||
return Err(MixnetNodeError::Protocol(e.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(body
|
||||
.write(self.connections.get_mut(&target).unwrap())
|
||||
.await?)
|
||||
}
|
||||
|
||||
async fn send(&mut self, pkt: Packet) {
|
||||
if let Err(err) = self.try_send(pkt.target, &pkt.body).await {
|
||||
match err {
|
||||
MixnetNodeError::Protocol(ProtocolError::IO(e))
|
||||
if e.kind() == std::io::ErrorKind::Unsupported =>
|
||||
{
|
||||
tracing::error!("fail to send message to {}: {e}", pkt.target);
|
||||
}
|
||||
_ => self.handle_retry(pkt),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_retry(&self, mut pkt: Packet) {
|
||||
if pkt.retry_count < self.config.max_retries {
|
||||
let delay = Duration::from_millis(
|
||||
(self.config.retry_delay.as_millis() as u64).pow(pkt.retry_count as u32),
|
||||
);
|
||||
let tx = self.packet_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(delay).await;
|
||||
pkt.retry_count += 1;
|
||||
if let Err(e) = tx.send(pkt) {
|
||||
tracing::error!("fail to enqueue retry message: {e}");
|
||||
}
|
||||
});
|
||||
} else {
|
||||
tracing::error!(
|
||||
"fail to send message to {}: reach maximum retries",
|
||||
pkt.target
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Packet {
|
||||
target: SocketAddr,
|
||||
body: Body,
|
||||
retry_count: usize,
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
fn new(target: SocketAddr, body: Body) -> Self {
|
||||
Self {
|
||||
target,
|
||||
body,
|
||||
retry_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.32", features = ["sync", "net"] }
|
||||
tokio = { version = "1.32", features = ["sync", "net", "time"] }
|
||||
sphinx-packet = "0.1.0"
|
||||
futures = "0.3"
|
||||
tokio-util = { version = "0.7", features = ["io", "io-util"] }
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use sphinx_packet::{payload::Payload, SphinxPacket};
|
||||
|
||||
use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
use tokio::{
|
||||
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
|
@ -122,7 +123,6 @@ pub async fn retry_backoff(
|
|||
for idx in 0..max_retries {
|
||||
// backoff
|
||||
let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32));
|
||||
|
||||
tokio::time::sleep(wait).await;
|
||||
let mut socket = socket.lock().await;
|
||||
match body.write(&mut *socket).await {
|
||||
|
|
|
@ -5,5 +5,4 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
tokio = { version = "1.32", default-features = false, features = ["sync", "net"] }
|
||||
parking_lot = { version = "0.12", features = ["send_guard"] }
|
||||
mixnet-protocol = { path = "../protocol" }
|
|
@ -1,7 +1,6 @@
|
|||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::{net::TcpStream, sync::Mutex};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConnectionPool {
|
||||
|
|
|
@ -72,7 +72,7 @@ impl MixNode {
|
|||
|
||||
let mut nodes = Vec::<MixNode>::new();
|
||||
for config in &configs {
|
||||
nodes.push(Self::spawn(config.clone()).await);
|
||||
nodes.push(Self::spawn(*config).await);
|
||||
}
|
||||
|
||||
// We need to return configs as well, to configure mixclients accordingly
|
||||
|
|
Loading…
Reference in New Issue