Add concrete Error implementation for mixnet (#405)
Add concrete Error implementation
This commit is contained in:
parent
46c7e25f6f
commit
cb343156b7
22
mixnet/client/src/error.rs
Normal file
22
mixnet/client/src/error.rs
Normal file
@ -0,0 +1,22 @@
|
||||
use mixnet_protocol::ProtocolError;
|
||||
use nym_sphinx::addressing::nodes::NymNodeRoutingAddressError;
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MixnetClientError {
|
||||
#[error("mixnet node connect error")]
|
||||
MixnetNodeConnectError,
|
||||
#[error("mixnode stream has been closed")]
|
||||
MixnetNodeStreamClosed,
|
||||
#[error("unexpected stream body received")]
|
||||
UnexpectedStreamBody,
|
||||
#[error("invalid payload")]
|
||||
InvalidPayload,
|
||||
#[error("invalid routing address: {0}")]
|
||||
InvalidRoutingAddress(#[from] NymNodeRoutingAddressError),
|
||||
#[error("{0}")]
|
||||
Protocol(#[from] ProtocolError),
|
||||
#[error("{0}")]
|
||||
Message(#[from] nym_sphinx::message::NymMessageError),
|
||||
}
|
||||
|
||||
pub type Result<T> = core::result::Result<T, MixnetClientError>;
|
@ -1,8 +1,9 @@
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub use error::*;
|
||||
mod receiver;
|
||||
mod sender;
|
||||
|
||||
use std::error::Error;
|
||||
use std::time::Duration;
|
||||
|
||||
pub use config::MixnetClientConfig;
|
||||
@ -11,7 +12,6 @@ use futures::stream::BoxStream;
|
||||
use mixnet_util::ConnectionPool;
|
||||
use rand::Rng;
|
||||
use sender::Sender;
|
||||
use thiserror::Error;
|
||||
|
||||
// A client for sending packets to Mixnet and receiving packets from Mixnet.
|
||||
pub struct MixnetClient<R: Rng> {
|
||||
@ -19,7 +19,7 @@ pub struct MixnetClient<R: Rng> {
|
||||
sender: Sender<R>,
|
||||
}
|
||||
|
||||
pub type MessageStream = BoxStream<'static, Result<Vec<u8>, MixnetClientError>>;
|
||||
pub type MessageStream = BoxStream<'static, Result<Vec<u8>>>;
|
||||
|
||||
impl<R: Rng> MixnetClient<R> {
|
||||
pub fn new(config: MixnetClientConfig, rng: R) -> Self {
|
||||
@ -36,27 +36,11 @@ impl<R: Rng> MixnetClient<R> {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<MessageStream, MixnetClientError> {
|
||||
pub async fn run(&self) -> Result<MessageStream> {
|
||||
self.mode.run().await
|
||||
}
|
||||
|
||||
pub fn send(
|
||||
&mut self,
|
||||
msg: Vec<u8>,
|
||||
total_delay: Duration,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
pub fn send(&mut self, msg: Vec<u8>, total_delay: Duration) -> Result<()> {
|
||||
self.sender.send(msg, total_delay)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum MixnetClientError {
|
||||
#[error("mixnet node connect error")]
|
||||
MixnetNodeConnectError,
|
||||
#[error("mixnode stream has been closed")]
|
||||
MixnetNodeStreamClosed,
|
||||
#[error("unexpected stream body received")]
|
||||
UnexpectedStreamBody,
|
||||
#[error("invalid payload")]
|
||||
InvalidPayload,
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{error::Error, net::SocketAddr};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use futures::{stream, Stream, StreamExt};
|
||||
use mixnet_protocol::Body;
|
||||
@ -9,6 +9,7 @@ use nym_sphinx::{
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use super::error::*;
|
||||
use crate::MixnetClientError;
|
||||
|
||||
// Receiver accepts TCP connections to receive incoming payloads from the Mixnet.
|
||||
@ -21,12 +22,7 @@ impl Receiver {
|
||||
Self { node_address }
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
&self,
|
||||
) -> Result<
|
||||
impl Stream<Item = Result<Vec<u8>, MixnetClientError>> + Send + 'static,
|
||||
MixnetClientError,
|
||||
> {
|
||||
pub async fn run(&self) -> Result<impl Stream<Item = Result<Vec<u8>>> + Send + 'static> {
|
||||
let Ok(socket) = TcpStream::connect(self.node_address).await else {
|
||||
return Err(MixnetClientError::MixnetNodeConnectError);
|
||||
};
|
||||
@ -36,9 +32,7 @@ impl Receiver {
|
||||
))))
|
||||
}
|
||||
|
||||
fn fragment_stream(
|
||||
socket: TcpStream,
|
||||
) -> impl Stream<Item = Result<Fragment, MixnetClientError>> + Send + 'static {
|
||||
fn fragment_stream(socket: TcpStream) -> impl Stream<Item = Result<Fragment>> + Send + 'static {
|
||||
stream::unfold(socket, move |mut socket| {
|
||||
async move {
|
||||
let Ok(body) = Body::read(&mut socket).await else {
|
||||
@ -60,11 +54,8 @@ impl Receiver {
|
||||
}
|
||||
|
||||
fn message_stream(
|
||||
fragment_stream: impl Stream<Item = Result<Fragment, MixnetClientError>>
|
||||
+ Send
|
||||
+ Unpin
|
||||
+ 'static,
|
||||
) -> impl Stream<Item = Result<Vec<u8>, MixnetClientError>> + Send + 'static {
|
||||
fragment_stream: impl Stream<Item = Result<Fragment>> + Send + Unpin + 'static,
|
||||
) -> impl Stream<Item = Result<Vec<u8>>> + Send + 'static {
|
||||
// MessageReconstructor buffers all received fragments
|
||||
// and eventually returns reconstructed messages.
|
||||
let message_reconstructor: MessageReconstructor = Default::default();
|
||||
@ -80,7 +71,7 @@ impl Receiver {
|
||||
)
|
||||
}
|
||||
|
||||
fn fragment_from_payload(payload: Payload) -> Result<Fragment, MixnetClientError> {
|
||||
fn fragment_from_payload(payload: Payload) -> Result<Fragment> {
|
||||
let Ok(payload_plaintext) = payload.recover_plaintext() else {
|
||||
return Err(MixnetClientError::InvalidPayload);
|
||||
};
|
||||
@ -91,12 +82,9 @@ impl Receiver {
|
||||
}
|
||||
|
||||
async fn reconstruct_message(
|
||||
fragment_stream: &mut (impl Stream<Item = Result<Fragment, MixnetClientError>>
|
||||
+ Send
|
||||
+ Unpin
|
||||
+ 'static),
|
||||
fragment_stream: &mut (impl Stream<Item = Result<Fragment>> + Send + Unpin + 'static),
|
||||
message_reconstructor: &mut MessageReconstructor,
|
||||
) -> Result<Vec<u8>, MixnetClientError> {
|
||||
) -> Result<Vec<u8>> {
|
||||
// Read fragments until at least one message is fully reconstructed.
|
||||
while let Some(next) = fragment_stream.next().await {
|
||||
match next {
|
||||
@ -131,7 +119,7 @@ impl Receiver {
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_padding(msg: Vec<u8>) -> Result<Vec<u8>, Box<dyn Error>> {
|
||||
fn remove_padding(msg: Vec<u8>) -> Result<Vec<u8>> {
|
||||
let padded_message = PaddedMessage::new_reconstructed(msg);
|
||||
// we need this because PaddedMessage.remove_padding requires it for other NymMessage types.
|
||||
let dummy_num_mix_hops = 0;
|
||||
|
@ -1,6 +1,6 @@
|
||||
use std::{error::Error, net::SocketAddr, time::Duration};
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
|
||||
use mixnet_protocol::Body;
|
||||
use mixnet_protocol::{Body, ProtocolError};
|
||||
use mixnet_topology::MixnetTopology;
|
||||
use mixnet_util::ConnectionPool;
|
||||
use nym_sphinx::{
|
||||
@ -11,6 +11,8 @@ use nym_sphinx::{
|
||||
use rand::{distributions::Uniform, prelude::Distribution, Rng};
|
||||
use sphinx_packet::{route, SphinxPacket, SphinxPacketBuilder};
|
||||
|
||||
use super::error::*;
|
||||
|
||||
// Sender splits messages into Sphinx packets and sends them to the Mixnet.
|
||||
pub struct Sender<R: Rng> {
|
||||
//TODO: handle topology update
|
||||
@ -38,11 +40,7 @@ impl<R: Rng> Sender<R> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(
|
||||
&mut self,
|
||||
msg: Vec<u8>,
|
||||
total_delay: Duration,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
pub fn send(&mut self, msg: Vec<u8>, total_delay: Duration) -> Result<()> {
|
||||
let destination = self.topology.random_destination(&mut self.rng)?;
|
||||
let destination = Destination::new(
|
||||
DestinationAddressBytes::from_bytes(destination.address.as_bytes()),
|
||||
@ -52,7 +50,7 @@ impl<R: Rng> Sender<R> {
|
||||
self.pad_and_split_message(msg)
|
||||
.into_iter()
|
||||
.map(|fragment| self.build_sphinx_packet(fragment, &destination, total_delay))
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.collect::<Result<Vec<_>>>()?
|
||||
.into_iter()
|
||||
.for_each(|(packet, first_node)| {
|
||||
let pool = self.pool.clone();
|
||||
@ -95,8 +93,7 @@ impl<R: Rng> Sender<R> {
|
||||
fragment: Fragment,
|
||||
destination: &Destination,
|
||||
total_delay: Duration,
|
||||
) -> Result<(sphinx_packet::SphinxPacket, route::Node), Box<dyn Error + Send + Sync + 'static>>
|
||||
{
|
||||
) -> Result<(sphinx_packet::SphinxPacket, route::Node)> {
|
||||
let route = self.topology.random_route(&mut self.rng)?;
|
||||
|
||||
let delays: Vec<Delay> =
|
||||
@ -110,7 +107,8 @@ impl<R: Rng> Sender<R> {
|
||||
|
||||
let packet = SphinxPacketBuilder::new()
|
||||
.with_payload_size(payload.len() + PAYLOAD_OVERHEAD_SIZE)
|
||||
.build_packet(payload, &route, destination, &delays)?;
|
||||
.build_packet(payload, &route, destination, &delays)
|
||||
.map_err(ProtocolError::InvalidSphinxPacket)?;
|
||||
|
||||
let first_mixnode = route.first().cloned().expect("route is not empty");
|
||||
|
||||
@ -123,8 +121,8 @@ impl<R: Rng> Sender<R> {
|
||||
retry_delay: Duration,
|
||||
packet: Box<SphinxPacket>,
|
||||
addr: NodeAddressBytes,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
let addr = SocketAddr::try_from(NymNodeRoutingAddress::try_from(addr)?)?;
|
||||
) -> Result<()> {
|
||||
let addr = SocketAddr::from(NymNodeRoutingAddress::try_from(addr)?);
|
||||
tracing::debug!("Sending a Sphinx packet to the node: {addr:?}");
|
||||
|
||||
let mu: std::sync::Arc<tokio::sync::Mutex<tokio::net::TcpStream>> =
|
||||
@ -145,7 +143,8 @@ impl<R: Rng> Sender<R> {
|
||||
body,
|
||||
arc_socket,
|
||||
)
|
||||
.await;
|
||||
.await
|
||||
.map_err(Into::into);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ edition = "2021"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tracing = "0.1.37"
|
||||
tokio = { version = "1.32", features = ["net", "time"] }
|
||||
thiserror = "1"
|
||||
sphinx-packet = "0.1.0"
|
||||
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }
|
||||
mixnet-protocol = { path = "../protocol" }
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{error::Error, net::SocketAddr};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use mixnet_protocol::Body;
|
||||
use tokio::{
|
||||
@ -12,8 +12,10 @@ impl ClientNotifier {
|
||||
pub async fn run(
|
||||
listen_address: SocketAddr,
|
||||
mut rx: mpsc::Receiver<Body>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let listener = TcpListener::bind(listen_address).await?;
|
||||
) -> super::Result<()> {
|
||||
let listener = TcpListener::bind(listen_address)
|
||||
.await
|
||||
.map_err(super::ProtocolError::IO)?;
|
||||
tracing::info!("Listening mixnet client connections: {listen_address}");
|
||||
|
||||
// Currently, handling only a single incoming connection
|
||||
@ -21,7 +23,7 @@ impl ClientNotifier {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, remote_addr)) => {
|
||||
tracing::debug!("Accepted incoming client connection from {remote_addr:?}");
|
||||
tracing::debug!("Accepted incoming client connection from {remote_addr}");
|
||||
|
||||
if let Err(e) = Self::handle_connection(socket, &mut rx).await {
|
||||
tracing::error!("failed to handle conn: {e}");
|
||||
@ -35,10 +37,10 @@ impl ClientNotifier {
|
||||
async fn handle_connection(
|
||||
mut socket: TcpStream,
|
||||
rx: &mut mpsc::Receiver<Body>,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
) -> super::Result<()> {
|
||||
while let Some(body) = rx.recv().await {
|
||||
if let Err(e) = body.write(&mut socket).await {
|
||||
return Err(format!("error from client conn: {e}").into());
|
||||
return Err(super::MixnetNodeError::Client(e));
|
||||
}
|
||||
}
|
||||
tracing::debug!("body receiver closed");
|
||||
|
@ -1,16 +1,16 @@
|
||||
mod client_notifier;
|
||||
pub mod config;
|
||||
|
||||
use std::{error::Error, net::SocketAddr, time::Duration};
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
|
||||
use client_notifier::ClientNotifier;
|
||||
pub use config::MixnetNodeConfig;
|
||||
use mixnet_protocol::Body;
|
||||
use mixnet_protocol::{Body, ProtocolError};
|
||||
use mixnet_topology::MixnetNodeId;
|
||||
use mixnet_util::ConnectionPool;
|
||||
use nym_sphinx::{
|
||||
addressing::nodes::NymNodeRoutingAddress, Delay, DestinationAddressBytes, NodeAddressBytes,
|
||||
Payload, PrivateKey,
|
||||
addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError},
|
||||
Delay, DestinationAddressBytes, NodeAddressBytes, Payload, PrivateKey,
|
||||
};
|
||||
pub use sphinx_packet::crypto::PRIVATE_KEY_SIZE;
|
||||
use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, ProcessedPacket, SphinxPacket};
|
||||
@ -19,6 +19,20 @@ use tokio::{
|
||||
sync::mpsc,
|
||||
};
|
||||
|
||||
pub type Result<T> = core::result::Result<T, MixnetNodeError>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MixnetNodeError {
|
||||
#[error("{0}")]
|
||||
Protocol(#[from] ProtocolError),
|
||||
#[error("invalid routing address: {0}")]
|
||||
InvalidRoutingAddress(#[from] NymNodeRoutingAddressError),
|
||||
#[error("send error: {0}")]
|
||||
SendError(#[from] tokio::sync::mpsc::error::TrySendError<Body>),
|
||||
#[error("client: {0}")]
|
||||
Client(ProtocolError),
|
||||
}
|
||||
|
||||
// A mix node that routes packets in the Mixnet.
|
||||
pub struct MixnetNode {
|
||||
config: MixnetNodeConfig,
|
||||
@ -41,7 +55,7 @@ impl MixnetNode {
|
||||
|
||||
const CLIENT_NOTI_CHANNEL_SIZE: usize = 100;
|
||||
|
||||
pub async fn run(self) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
pub async fn run(self) -> Result<()> {
|
||||
tracing::info!("Public key: {:?}", self.public_key());
|
||||
|
||||
// Spawn a ClientNotifier
|
||||
@ -55,7 +69,9 @@ impl MixnetNode {
|
||||
|
||||
//TODO: Accepting ad-hoc TCP conns for now. Improve conn handling.
|
||||
//TODO: Add graceful shutdown
|
||||
let listener = TcpListener::bind(self.config.listen_address).await?;
|
||||
let listener = TcpListener::bind(self.config.listen_address)
|
||||
.await
|
||||
.map_err(ProtocolError::IO)?;
|
||||
tracing::info!(
|
||||
"Listening mixnet node connections: {}",
|
||||
self.config.listen_address
|
||||
@ -96,7 +112,7 @@ impl MixnetNode {
|
||||
pool: ConnectionPool,
|
||||
private_key: [u8; PRIVATE_KEY_SIZE],
|
||||
client_tx: mpsc::Sender<Body>,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
) -> Result<()> {
|
||||
loop {
|
||||
let body = Body::read(&mut socket).await?;
|
||||
|
||||
@ -130,7 +146,7 @@ impl MixnetNode {
|
||||
pool: &ConnectionPool,
|
||||
private_key: &PrivateKey,
|
||||
client_tx: &mpsc::Sender<Body>,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
) -> Result<()> {
|
||||
match body {
|
||||
Body::SphinxPacket(packet) => {
|
||||
Self::handle_sphinx_packet(pool, max_retries, retry_delay, private_key, packet)
|
||||
@ -154,8 +170,11 @@ impl MixnetNode {
|
||||
retry_delay: Duration,
|
||||
private_key: &PrivateKey,
|
||||
packet: Box<SphinxPacket>,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
match packet.process(private_key)? {
|
||||
) -> Result<()> {
|
||||
match packet
|
||||
.process(private_key)
|
||||
.map_err(ProtocolError::InvalidSphinxPacket)?
|
||||
{
|
||||
ProcessedPacket::ForwardHop(packet, next_node_addr, delay) => {
|
||||
Self::forward_packet_to_next_hop(
|
||||
pool,
|
||||
@ -184,7 +203,7 @@ impl MixnetNode {
|
||||
_private_key: &PrivateKey,
|
||||
client_tx: &mpsc::Sender<Body>,
|
||||
body: Body,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
) -> Result<()> {
|
||||
// TODO: Decrypt the final payload using the private key, if it's encrypted
|
||||
|
||||
// Do not wait when the channel is full or no receiver exists
|
||||
@ -199,7 +218,7 @@ impl MixnetNode {
|
||||
packet: Box<SphinxPacket>,
|
||||
next_node_addr: NodeAddressBytes,
|
||||
delay: Delay,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
) -> Result<()> {
|
||||
tracing::debug!("Delaying the packet for {delay:?}");
|
||||
tokio::time::sleep(delay.to_duration()).await;
|
||||
|
||||
@ -219,7 +238,7 @@ impl MixnetNode {
|
||||
retry_delay: Duration,
|
||||
payload: Payload,
|
||||
destination_addr: DestinationAddressBytes,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
) -> Result<()> {
|
||||
tracing::debug!("Forwarding final payload to destination mixnode");
|
||||
|
||||
Self::forward(
|
||||
@ -238,8 +257,8 @@ impl MixnetNode {
|
||||
retry_delay: Duration,
|
||||
body: Body,
|
||||
to: NymNodeRoutingAddress,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
let addr = SocketAddr::try_from(to)?;
|
||||
) -> Result<()> {
|
||||
let addr = SocketAddr::from(to);
|
||||
let arc_socket = pool.get_or_init(&addr).await?;
|
||||
|
||||
if let Err(e) = {
|
||||
@ -254,7 +273,8 @@ impl MixnetNode {
|
||||
body,
|
||||
arc_socket,
|
||||
)
|
||||
.await;
|
||||
.await
|
||||
.map_err(Into::into);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -6,7 +6,8 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = "1.32"
|
||||
tokio = { version = "1.32", features = ["sync", "net"] }
|
||||
sphinx-packet = "0.1.0"
|
||||
futures = "0.3"
|
||||
tokio-util = {version = "0.7", features = ["io", "io-util"] }
|
||||
tokio-util = { version = "0.7", features = ["io", "io-util"] }
|
||||
thiserror = "1"
|
||||
|
@ -1,12 +1,28 @@
|
||||
use sphinx_packet::{payload::Payload, SphinxPacket};
|
||||
use std::{error::Error, io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
|
||||
|
||||
use std::{io::ErrorKind, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||
io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||
net::TcpStream,
|
||||
sync::Mutex,
|
||||
};
|
||||
|
||||
pub type Result<T> = core::result::Result<T, ProtocolError>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProtocolError {
|
||||
#[error("Unknown body type {0}")]
|
||||
UnknownBodyType(u8),
|
||||
#[error("{0}")]
|
||||
InvalidSphinxPacket(sphinx_packet::Error),
|
||||
#[error("{0}")]
|
||||
InvalidPayload(sphinx_packet::Error),
|
||||
#[error("{0}")]
|
||||
IO(#[from] io::Error),
|
||||
#[error("fail to send packet, reach maximum retries {0}")]
|
||||
ReachMaxRetries(usize),
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum Body {
|
||||
SphinxPacket(Box<SphinxPacket>),
|
||||
@ -29,7 +45,7 @@ impl Body {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read<R>(reader: &mut R) -> Result<Body, Box<dyn Error + Send + Sync + 'static>>
|
||||
pub async fn read<R>(reader: &mut R) -> Result<Body>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
@ -37,20 +53,17 @@ impl Body {
|
||||
match id {
|
||||
0 => Self::read_sphinx_packet(reader).await,
|
||||
1 => Self::read_final_payload(reader).await,
|
||||
_ => Err("Invalid body type".into()),
|
||||
id => Err(ProtocolError::UnknownBodyType(id)),
|
||||
}
|
||||
}
|
||||
|
||||
fn sphinx_packet_from_bytes(
|
||||
data: &[u8],
|
||||
) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
|
||||
let packet = SphinxPacket::from_bytes(data)?;
|
||||
Ok(Self::new_sphinx(Box::new(packet)))
|
||||
fn sphinx_packet_from_bytes(data: &[u8]) -> Result<Self> {
|
||||
SphinxPacket::from_bytes(data)
|
||||
.map(|packet| Self::new_sphinx(Box::new(packet)))
|
||||
.map_err(ProtocolError::InvalidPayload)
|
||||
}
|
||||
|
||||
async fn read_sphinx_packet<R>(
|
||||
reader: &mut R,
|
||||
) -> Result<Body, Box<dyn Error + Send + Sync + 'static>>
|
||||
async fn read_sphinx_packet<R>(reader: &mut R) -> Result<Body>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
@ -60,16 +73,13 @@ impl Body {
|
||||
Self::sphinx_packet_from_bytes(&buf)
|
||||
}
|
||||
|
||||
pub fn final_payload_from_bytes(
|
||||
data: &[u8],
|
||||
) -> Result<Self, Box<dyn Error + Send + Sync + 'static>> {
|
||||
let payload = Payload::from_bytes(data)?;
|
||||
Ok(Self::new_final_payload(payload))
|
||||
pub fn final_payload_from_bytes(data: &[u8]) -> Result<Self> {
|
||||
Payload::from_bytes(data)
|
||||
.map(Self::new_final_payload)
|
||||
.map_err(ProtocolError::InvalidPayload)
|
||||
}
|
||||
|
||||
async fn read_final_payload<R>(
|
||||
reader: &mut R,
|
||||
) -> Result<Body, Box<dyn Error + Send + Sync + 'static>>
|
||||
async fn read_final_payload<R>(reader: &mut R) -> Result<Body>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
@ -80,10 +90,7 @@ impl Body {
|
||||
Self::final_payload_from_bytes(&buf)
|
||||
}
|
||||
|
||||
pub async fn write<W>(
|
||||
&self,
|
||||
writer: &mut W,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
|
||||
pub async fn write<W>(&self, writer: &mut W) -> Result<()>
|
||||
where
|
||||
W: AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
@ -111,7 +118,7 @@ pub async fn retry_backoff(
|
||||
retry_delay: Duration,
|
||||
body: Body,
|
||||
socket: Arc<Mutex<TcpStream>>,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
) -> Result<()> {
|
||||
for idx in 0..max_retries {
|
||||
// backoff
|
||||
let wait = Duration::from_millis((retry_delay.as_millis() as u64).pow(idx as u32));
|
||||
@ -121,19 +128,22 @@ pub async fn retry_backoff(
|
||||
match body.write(&mut *socket).await {
|
||||
Ok(_) => return Ok(()),
|
||||
Err(e) => {
|
||||
if let Some(err) = e.downcast_ref::<std::io::Error>() {
|
||||
match err.kind() {
|
||||
ErrorKind::Unsupported => return Err(e),
|
||||
_ => {
|
||||
// update the connection
|
||||
if let Ok(tcp) = TcpStream::connect(peer_addr).await {
|
||||
*socket = tcp;
|
||||
match &e {
|
||||
ProtocolError::IO(err) => {
|
||||
match err.kind() {
|
||||
ErrorKind::Unsupported => return Err(e),
|
||||
_ => {
|
||||
// update the connection
|
||||
if let Ok(tcp) = TcpStream::connect(peer_addr).await {
|
||||
*socket = tcp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(format!("Failure after {max_retries} retries").into())
|
||||
Err(ProtocolError::ReachMaxRetries(max_retries))
|
||||
}
|
||||
|
@ -10,3 +10,4 @@ rand = "0.7.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
sphinx-packet = "0.1.0"
|
||||
nym-sphinx = { package = "nym-sphinx", git = "https://github.com/nymtech/nym", tag = "v1.1.22" }
|
||||
thiserror = "1"
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::{error::Error, net::SocketAddr};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use nym_sphinx::addressing::nodes::{NymNodeRoutingAddress, NymNodeRoutingAddressError};
|
||||
use rand::{seq::IteratorRandom, Rng};
|
||||
@ -7,6 +7,8 @@ use sphinx_packet::{crypto::PUBLIC_KEY_SIZE, route};
|
||||
|
||||
pub type MixnetNodeId = [u8; PUBLIC_KEY_SIZE];
|
||||
|
||||
pub type Result<T> = core::result::Result<T, NymNodeRoutingAddressError>;
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
|
||||
pub struct MixnetTopology {
|
||||
pub layers: Vec<Layer>,
|
||||
@ -54,10 +56,7 @@ mod hex_serde {
|
||||
}
|
||||
|
||||
impl MixnetTopology {
|
||||
pub fn random_route<R: Rng>(
|
||||
&self,
|
||||
rng: &mut R,
|
||||
) -> Result<Vec<route::Node>, Box<dyn Error + Send + Sync + 'static>> {
|
||||
pub fn random_route<R: Rng>(&self, rng: &mut R) -> Result<Vec<route::Node>> {
|
||||
let num_hops = self.layers.len();
|
||||
|
||||
let route: Vec<route::Node> = self
|
||||
@ -78,19 +77,14 @@ impl MixnetTopology {
|
||||
}
|
||||
|
||||
// Choose a destination mixnet node randomly from the last layer.
|
||||
pub fn random_destination<R: Rng>(
|
||||
&self,
|
||||
rng: &mut R,
|
||||
) -> Result<route::Node, Box<dyn Error + Send + Sync + 'static>> {
|
||||
Ok(self
|
||||
.layers
|
||||
pub fn random_destination<R: Rng>(&self, rng: &mut R) -> Result<route::Node> {
|
||||
self.layers
|
||||
.last()
|
||||
.expect("topology is not empty")
|
||||
.random_node(rng)
|
||||
.expect("layer is not empty")
|
||||
.clone()
|
||||
.try_into()
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,7 +97,7 @@ impl Layer {
|
||||
impl TryInto<route::Node> for Node {
|
||||
type Error = NymNodeRoutingAddressError;
|
||||
|
||||
fn try_into(self) -> Result<route::Node, Self::Error> {
|
||||
fn try_into(self) -> Result<route::Node> {
|
||||
Ok(route::Node {
|
||||
address: NymNodeRoutingAddress::from(self.address).try_into()?,
|
||||
pub_key: self.public_key.into(),
|
||||
|
@ -5,4 +5,5 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.32", default-features = false, features = ["sync", "net"] }
|
||||
parking_lot = { version = "0.12", features = ["send_guard"] }
|
||||
parking_lot = { version = "0.12", features = ["send_guard"] }
|
||||
mixnet-protocol = { path = "../protocol" }
|
@ -15,12 +15,19 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_or_init(&self, addr: &SocketAddr) -> std::io::Result<Arc<Mutex<TcpStream>>> {
|
||||
pub async fn get_or_init(
|
||||
&self,
|
||||
addr: &SocketAddr,
|
||||
) -> mixnet_protocol::Result<Arc<Mutex<TcpStream>>> {
|
||||
let mut pool = self.pool.lock().await;
|
||||
match pool.get(addr).cloned() {
|
||||
Some(tcp) => Ok(tcp),
|
||||
None => {
|
||||
let tcp = Arc::new(Mutex::new(TcpStream::connect(addr).await?));
|
||||
let tcp = Arc::new(Mutex::new(
|
||||
TcpStream::connect(addr)
|
||||
.await
|
||||
.map_err(mixnet_protocol::ProtocolError::IO)?,
|
||||
));
|
||||
pool.insert(*addr, tcp.clone());
|
||||
Ok(tcp)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user