281 lines
8.4 KiB
Rust
281 lines
8.4 KiB
Rust
|
use std::collections::HashMap;
|
||
|
|
||
|
use sphinx_packet::{constants::PAYLOAD_SIZE, payload::PAYLOAD_OVERHEAD_SIZE};
|
||
|
use uuid::Uuid;
|
||
|
|
||
|
use crate::error::MixnetError;
|
||
|
|
||
|
pub(crate) struct FragmentSet(Vec<Fragment>);
|
||
|
|
||
|
impl FragmentSet {
|
||
|
const MAX_PLAIN_PAYLOAD_SIZE: usize = PAYLOAD_SIZE - PAYLOAD_OVERHEAD_SIZE;
|
||
|
const CHUNK_SIZE: usize = Self::MAX_PLAIN_PAYLOAD_SIZE - FragmentHeader::SIZE;
|
||
|
|
||
|
pub(crate) fn new(msg: &[u8]) -> Result<Self, MixnetError> {
|
||
|
// For now, we don't support more than `u8::MAX + 1` fragments.
|
||
|
// If needed, we can devise the FragmentSet chaining to support larger messages, like Nym.
|
||
|
let last_fragment_id = FragmentId::try_from(Self::num_chunks(msg) - 1)
|
||
|
.map_err(|_| MixnetError::MessageTooLong(msg.len()))?;
|
||
|
let set_id = FragmentSetId::new();
|
||
|
|
||
|
Ok(FragmentSet(
|
||
|
msg.chunks(Self::CHUNK_SIZE)
|
||
|
.enumerate()
|
||
|
.map(|(i, chunk)| Fragment {
|
||
|
header: FragmentHeader {
|
||
|
set_id,
|
||
|
last_fragment_id,
|
||
|
fragment_id: FragmentId::try_from(i)
|
||
|
.expect("i is always in the right range"),
|
||
|
},
|
||
|
body: Vec::from(chunk),
|
||
|
})
|
||
|
.collect(),
|
||
|
))
|
||
|
}
|
||
|
|
||
|
fn num_chunks(msg: &[u8]) -> usize {
|
||
|
msg.len().div_ceil(Self::CHUNK_SIZE)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl AsRef<Vec<Fragment>> for FragmentSet {
|
||
|
fn as_ref(&self) -> &Vec<Fragment> {
|
||
|
&self.0
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(PartialEq, Eq, Debug, Clone)]
|
||
|
pub(crate) struct Fragment {
|
||
|
header: FragmentHeader,
|
||
|
body: Vec<u8>,
|
||
|
}
|
||
|
|
||
|
impl Fragment {
|
||
|
pub(crate) fn bytes(&self) -> Vec<u8> {
|
||
|
let mut out = Vec::with_capacity(FragmentHeader::SIZE + self.body.len());
|
||
|
out.extend(self.header.bytes());
|
||
|
out.extend(&self.body);
|
||
|
out
|
||
|
}
|
||
|
|
||
|
pub(crate) fn from_bytes(value: &[u8]) -> Result<Self, MixnetError> {
|
||
|
Ok(Self {
|
||
|
header: FragmentHeader::from_bytes(&value[0..FragmentHeader::SIZE])?,
|
||
|
body: value[FragmentHeader::SIZE..].to_vec(),
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
|
||
|
struct FragmentSetId(Uuid);
|
||
|
|
||
|
impl FragmentSetId {
|
||
|
const SIZE: usize = 16;
|
||
|
|
||
|
fn new() -> Self {
|
||
|
Self(Uuid::new_v4())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
|
||
|
struct FragmentId(u8);
|
||
|
|
||
|
impl FragmentId {
|
||
|
const SIZE: usize = std::mem::size_of::<u8>();
|
||
|
}
|
||
|
|
||
|
impl TryFrom<usize> for FragmentId {
|
||
|
type Error = MixnetError;
|
||
|
|
||
|
fn try_from(id: usize) -> Result<Self, Self::Error> {
|
||
|
if id > u8::MAX as usize {
|
||
|
return Err(MixnetError::InvalidFragmentId);
|
||
|
}
|
||
|
Ok(Self(id as u8))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl From<FragmentId> for usize {
|
||
|
fn from(id: FragmentId) -> Self {
|
||
|
id.0 as usize
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(PartialEq, Eq, Debug, Clone)]
|
||
|
struct FragmentHeader {
|
||
|
set_id: FragmentSetId,
|
||
|
last_fragment_id: FragmentId,
|
||
|
fragment_id: FragmentId,
|
||
|
}
|
||
|
|
||
|
impl FragmentHeader {
|
||
|
const SIZE: usize = FragmentSetId::SIZE + 2 * FragmentId::SIZE;
|
||
|
|
||
|
fn bytes(&self) -> [u8; Self::SIZE] {
|
||
|
let mut out = [0u8; Self::SIZE];
|
||
|
out[0..FragmentSetId::SIZE].copy_from_slice(self.set_id.0.as_bytes());
|
||
|
out[FragmentSetId::SIZE] = self.last_fragment_id.0;
|
||
|
out[FragmentSetId::SIZE + FragmentId::SIZE] = self.fragment_id.0;
|
||
|
out
|
||
|
}
|
||
|
|
||
|
fn from_bytes(value: &[u8]) -> Result<Self, MixnetError> {
|
||
|
if value.len() != Self::SIZE {
|
||
|
return Err(MixnetError::InvalidFragmentHeader);
|
||
|
}
|
||
|
|
||
|
Ok(Self {
|
||
|
set_id: FragmentSetId(Uuid::from_slice(&value[0..FragmentSetId::SIZE])?),
|
||
|
last_fragment_id: FragmentId(value[FragmentSetId::SIZE]),
|
||
|
fragment_id: FragmentId(value[FragmentSetId::SIZE + FragmentId::SIZE]),
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub struct MessageReconstructor {
|
||
|
fragment_sets: HashMap<FragmentSetId, FragmentSetReconstructor>,
|
||
|
}
|
||
|
|
||
|
impl MessageReconstructor {
|
||
|
pub fn new() -> Self {
|
||
|
Self {
|
||
|
fragment_sets: HashMap::new(),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// Adds a fragment to the reconstructor and tries to reconstruct a message from the fragment set.
|
||
|
/// This returns `None` if the message has not been reconstructed yet.
|
||
|
pub fn add_and_reconstruct(&mut self, fragment: Fragment) -> Option<Vec<u8>> {
|
||
|
let set_id = fragment.header.set_id;
|
||
|
let reconstructed_msg = self
|
||
|
.fragment_sets
|
||
|
.entry(set_id)
|
||
|
.or_insert(FragmentSetReconstructor::new(
|
||
|
fragment.header.last_fragment_id,
|
||
|
))
|
||
|
.add(fragment)
|
||
|
.try_reconstruct_message()?;
|
||
|
// A message has been reconstructed completely from the fragment set.
|
||
|
// Delete the fragment set from the reconstructor.
|
||
|
self.fragment_sets.remove(&set_id);
|
||
|
Some(reconstructed_msg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
struct FragmentSetReconstructor {
|
||
|
last_fragment_id: FragmentId,
|
||
|
fragments: HashMap<FragmentId, Fragment>,
|
||
|
// For mem optimization, accumulates the expected message size
|
||
|
// whenever a new fragment is added to the `fragments`.
|
||
|
message_size: usize,
|
||
|
}
|
||
|
|
||
|
impl FragmentSetReconstructor {
|
||
|
fn new(last_fragment_id: FragmentId) -> Self {
|
||
|
Self {
|
||
|
last_fragment_id,
|
||
|
fragments: HashMap::new(),
|
||
|
message_size: 0,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
fn add(&mut self, fragment: Fragment) -> &mut Self {
|
||
|
self.message_size += fragment.body.len();
|
||
|
if let Some(old_fragment) = self.fragments.insert(fragment.header.fragment_id, fragment) {
|
||
|
// In the case when a new fragment replaces the old one, adjust the `meesage_size`.
|
||
|
// e.g. The same fragment has been received multiple times.
|
||
|
self.message_size -= old_fragment.body.len();
|
||
|
}
|
||
|
self
|
||
|
}
|
||
|
|
||
|
/// Merges all fragments gathered if possible
|
||
|
fn try_reconstruct_message(&self) -> Option<Vec<u8>> {
|
||
|
(self.fragments.len() - 1 == self.last_fragment_id.into()).then(|| {
|
||
|
let mut msg = Vec::with_capacity(self.message_size);
|
||
|
for id in 0..=self.last_fragment_id.0 {
|
||
|
msg.extend(&self.fragments.get(&FragmentId(id)).unwrap().body);
|
||
|
}
|
||
|
msg
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rand::RngCore;

    use super::*;

    /// Builds a message of `len` random bytes.
    fn random_msg(len: usize) -> Vec<u8> {
        let mut buf = vec![0u8; len];
        rand::thread_rng().fill_bytes(&mut buf);
        buf
    }

    /// A header survives an encode/decode round trip and encodes to exactly `SIZE` bytes.
    #[test]
    fn fragment_header() {
        let original = FragmentHeader {
            set_id: FragmentSetId::new(),
            last_fragment_id: FragmentId(19),
            fragment_id: FragmentId(0),
        };

        let encoded = original.bytes();
        assert_eq!(FragmentHeader::SIZE, encoded.len());

        let decoded = FragmentHeader::from_bytes(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// A fragment survives an encode/decode round trip; its encoding is header + body.
    #[test]
    fn fragment() {
        let original = Fragment {
            header: FragmentHeader {
                set_id: FragmentSetId::new(),
                last_fragment_id: FragmentId(19),
                fragment_id: FragmentId(0),
            },
            body: vec![1, 2, 3, 4],
        };

        let encoded = original.bytes();
        assert_eq!(FragmentHeader::SIZE + original.body.len(), encoded.len());

        let decoded = Fragment::from_bytes(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// A 3.5-chunk message becomes 4 sequentially numbered fragments sharing one set ID.
    #[test]
    fn fragment_set() {
        let msg = random_msg(FragmentSet::CHUNK_SIZE * 3 + FragmentSet::CHUNK_SIZE / 2);
        assert_eq!(4, FragmentSet::num_chunks(&msg));

        let set = FragmentSet::new(&msg).unwrap();
        let fragments = set.as_ref();
        assert_eq!(4, fragments.len());

        let set_ids: HashSet<FragmentSetId> =
            fragments.iter().map(|fragment| fragment.header.set_id).collect();
        assert_eq!(1, set_ids.len());

        for (expected_id, fragment) in fragments.iter().enumerate() {
            assert_eq!(expected_id, usize::from(fragment.header.fragment_id));
        }
    }

    /// The message is only returned once its final fragment has been added.
    #[test]
    fn message_reconstructor() {
        let msg = random_msg(FragmentSet::CHUNK_SIZE * 2);
        let set = FragmentSet::new(&msg).unwrap();

        let mut reconstructor = MessageReconstructor::new();
        let mut pending = set.as_ref().iter().cloned();

        let first = pending.next().unwrap();
        assert_eq!(None, reconstructor.add_and_reconstruct(first));

        let second = pending.next().unwrap();
        assert_eq!(Some(msg), reconstructor.add_and_reconstruct(second));
    }
}
|