1
0
mirror of synced 2025-02-03 03:14:43 +00:00

Mix: Packet header encryption (#897)

* Mix: Packet header encryption

* modularization

* apply feedbacks

* use itertools::chain! instead of concat_bytes

* use explicit RoutingFlag type in error

* define size of layered encryption data as const

* remove MixMessage::Settings

* Use proper size for drop message

* remove TODOs
This commit is contained in:
Youngjoon Lee 2024-11-28 10:36:14 +09:00 committed by GitHub
parent f71db1a7fe
commit 0ff16ef5f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 760 additions and 82 deletions

View File

@ -38,7 +38,7 @@ where
} }
} }
pub fn wrap_message(&mut self, message: &[u8]) -> Result<Vec<u8>, nomos_mix_message::Error> { pub fn wrap_message(&mut self, message: &[u8]) -> Result<Vec<u8>, M::Error> {
// TODO: Use the actual Sphinx encoding instead of mock. // TODO: Use the actual Sphinx encoding instead of mock.
let public_keys = self let public_keys = self
.membership .membership
@ -50,10 +50,7 @@ where
M::build_message(message, &public_keys) M::build_message(message, &public_keys)
} }
pub fn unwrap_message( pub fn unwrap_message(&self, message: &[u8]) -> Result<(Vec<u8>, bool), M::Error> {
&self,
message: &[u8],
) -> Result<(Vec<u8>, bool), nomos_mix_message::Error> {
M::unwrap_message(message, &self.settings.private_key) M::unwrap_message(message, &self.settings.private_key)
} }
} }

View File

@ -4,6 +4,7 @@ pub mod temporal;
pub use crypto::CryptographicProcessorSettings; pub use crypto::CryptographicProcessorSettings;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use rand::RngCore; use rand::RngCore;
use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -52,6 +53,7 @@ where
M: MixMessage, M: MixMessage,
M::PrivateKey: Serialize + DeserializeOwned, M::PrivateKey: Serialize + DeserializeOwned,
M::PublicKey: Clone + PartialEq, M::PublicKey: Clone + PartialEq,
M::Error: Debug,
Scheduler: Stream<Item = ()> + Unpin + Send + Sync + 'static, Scheduler: Stream<Item = ()> + Unpin + Send + Sync + 'static,
{ {
pub fn new( pub fn new(
@ -91,9 +93,6 @@ where
tracing::error!("Failed to send message to the outbound channel: {e:?}"); tracing::error!("Failed to send message to the outbound channel: {e:?}");
} }
} }
Err(nomos_mix_message::Error::MsgUnwrapNotAllowed) => {
tracing::debug!("Message cannot be unwrapped by this node");
}
Err(e) => { Err(e) => {
tracing::error!("Failed to unwrap message: {:?}", e); tracing::error!("Failed to unwrap message: {:?}", e);
} }
@ -108,6 +107,7 @@ where
M: MixMessage + Unpin, M: MixMessage + Unpin,
M::PrivateKey: Serialize + DeserializeOwned + Unpin, M::PrivateKey: Serialize + DeserializeOwned + Unpin,
M::PublicKey: Clone + PartialEq + Unpin, M::PublicKey: Clone + PartialEq + Unpin,
M::Error: Debug,
Scheduler: Stream<Item = ()> + Unpin + Send + Sync + 'static, Scheduler: Stream<Item = ()> + Unpin + Send + Sync + 'static,
{ {
type Item = MixOutgoingMessage; type Item = MixOutgoingMessage;
@ -126,6 +126,7 @@ where
M: MixMessage, M: MixMessage,
M::PrivateKey: Serialize + DeserializeOwned, M::PrivateKey: Serialize + DeserializeOwned,
M::PublicKey: Clone + PartialEq, M::PublicKey: Clone + PartialEq,
M::Error: Debug,
Scheduler: Stream<Item = ()> + Unpin + Send + Sync + 'static, Scheduler: Stream<Item = ()> + Unpin + Send + Sync + 'static,
{ {
fn blend( fn blend(
@ -155,6 +156,7 @@ where
M: MixMessage, M: MixMessage,
M::PrivateKey: Clone + Serialize + DeserializeOwned + PartialEq, M::PrivateKey: Clone + Serialize + DeserializeOwned + PartialEq,
M::PublicKey: Clone + Serialize + DeserializeOwned + PartialEq, M::PublicKey: Clone + Serialize + DeserializeOwned + PartialEq,
M::Error: Debug,
S: Stream<Item = ()> + Unpin + Send + Sync + 'static, S: Stream<Item = ()> + Unpin + Send + Sync + 'static,
{ {
} }

View File

@ -4,7 +4,9 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
serde = { version = "1.0", features = ["derive"] } itertools = "0.13"
rand_chacha = "0.3"
sha2 = "0.10"
sphinx-packet = "0.2" sphinx-packet = "0.2"
thiserror = "1.0.65" thiserror = "1.0.65"
x25519-dalek = { version = "2.0.1", features = [ x25519-dalek = { version = "2.0.1", features = [

View File

@ -1,15 +1,16 @@
mod error;
pub mod mock; pub mod mock;
pub mod packet; pub mod sphinx;
pub use error::Error;
pub trait MixMessage { pub trait MixMessage {
type PublicKey; type PublicKey;
type PrivateKey; type PrivateKey;
type Error;
const DROP_MESSAGE: &'static [u8]; const DROP_MESSAGE: &'static [u8];
fn build_message(payload: &[u8], public_keys: &[Self::PublicKey]) -> Result<Vec<u8>, Error>; fn build_message(
payload: &[u8],
public_keys: &[Self::PublicKey],
) -> Result<Vec<u8>, Self::Error>;
/// Unwrap the message one layer. /// Unwrap the message one layer.
/// ///
/// This function returns the unwrapped message and a boolean indicating whether the message was fully unwrapped. /// This function returns the unwrapped message and a boolean indicating whether the message was fully unwrapped.
@ -20,7 +21,7 @@ pub trait MixMessage {
fn unwrap_message( fn unwrap_message(
message: &[u8], message: &[u8],
private_key: &Self::PrivateKey, private_key: &Self::PrivateKey,
) -> Result<(Vec<u8>, bool), Error>; ) -> Result<(Vec<u8>, bool), Self::Error>;
fn is_drop_message(message: &[u8]) -> bool { fn is_drop_message(message: &[u8]) -> bool {
message == Self::DROP_MESSAGE message == Self::DROP_MESSAGE
} }

View File

@ -6,8 +6,6 @@ pub enum Error {
PayloadTooLarge, PayloadTooLarge,
#[error("Invalid number of layers")] #[error("Invalid number of layers")]
InvalidNumberOfLayers, InvalidNumberOfLayers,
#[error("Sphinx packet error: {0}")]
SphinxPacketError(#[from] sphinx_packet::Error),
#[error("Unwrapping a message is not allowed to this node")] #[error("Unwrapping a message is not allowed to this node")]
/// e.g. the message cannot be unwrapped using the private key provided /// e.g. the message cannot be unwrapped using the private key provided
MsgUnwrapNotAllowed, MsgUnwrapNotAllowed,

View File

@ -1,4 +1,8 @@
use crate::{Error, MixMessage}; pub mod error;
use error::Error;
use crate::MixMessage;
// TODO: Remove all the mock below once the actual implementation is integrated to the system. // TODO: Remove all the mock below once the actual implementation is integrated to the system.
// //
/// A mock implementation of the Sphinx encoding. /// A mock implementation of the Sphinx encoding.
@ -17,13 +21,17 @@ pub struct MockMixMessage;
impl MixMessage for MockMixMessage { impl MixMessage for MockMixMessage {
type PublicKey = [u8; NODE_ID_SIZE]; type PublicKey = [u8; NODE_ID_SIZE];
type PrivateKey = [u8; NODE_ID_SIZE]; type PrivateKey = [u8; NODE_ID_SIZE];
type Error = Error;
const DROP_MESSAGE: &'static [u8] = &[0; MESSAGE_SIZE]; const DROP_MESSAGE: &'static [u8] = &[0; MESSAGE_SIZE];
/// The length of the encoded message is fixed to [`MESSAGE_SIZE`] bytes. /// The length of the encoded message is fixed to [`MESSAGE_SIZE`] bytes.
/// The [`MAX_LAYERS`] number of [`NodeId`]s are concatenated in front of the payload. /// The [`MAX_LAYERS`] number of [`NodeId`]s are concatenated in front of the payload.
/// The payload is zero-padded to the end. /// The payload is zero-padded to the end.
/// ///
fn build_message(payload: &[u8], public_keys: &[Self::PublicKey]) -> Result<Vec<u8>, Error> { fn build_message(
payload: &[u8],
public_keys: &[Self::PublicKey],
) -> Result<Vec<u8>, Self::Error> {
// In this mock, we don't encrypt anything. So, we use public key as just a node ID. // In this mock, we don't encrypt anything. So, we use public key as just a node ID.
let node_ids = public_keys; let node_ids = public_keys;
if node_ids.is_empty() || node_ids.len() > MAX_LAYERS { if node_ids.is_empty() || node_ids.len() > MAX_LAYERS {
@ -54,7 +62,7 @@ impl MixMessage for MockMixMessage {
fn unwrap_message( fn unwrap_message(
message: &[u8], message: &[u8],
private_key: &Self::PrivateKey, private_key: &Self::PrivateKey,
) -> Result<(Vec<u8>, bool), Error> { ) -> Result<(Vec<u8>, bool), Self::Error> {
if message.len() != MESSAGE_SIZE { if message.len() != MESSAGE_SIZE {
return Err(Error::InvalidMixMessage); return Err(Error::InvalidMixMessage);
} }

View File

@ -0,0 +1,15 @@
use sphinx_packet::header::routing::RoutingFlag;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Sphinx packet error: {0}")]
SphinxPacketError(#[from] sphinx_packet::Error),
#[error("Invalid packet bytes")]
InvalidPacketBytes,
#[error("Invalid routing flag: {0}")]
InvalidRoutingFlag(RoutingFlag),
#[error("Invalid routing length: {0} bytes")]
InvalidEncryptedRoutingInfoLength(usize),
#[error("ConsistentLengthLayeredEncryptionError: {0}")]
ConsistentLengthLayeredEncryptionError(#[from] super::layered_cipher::Error),
}

View File

@ -0,0 +1,345 @@
use std::marker::PhantomData;
use rand_chacha::{
rand_core::{RngCore, SeedableRng},
ChaCha12Rng,
};
use sphinx_packet::{
constants::HEADER_INTEGRITY_MAC_SIZE,
crypto::STREAM_CIPHER_INIT_VECTOR,
header::{
keys::{HeaderIntegrityMacKey, StreamCipherKey},
mac::HeaderIntegrityMac,
},
};
use super::parse_bytes;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Invalid cipher text length")]
InvalidCipherTextLength,
#[error("Invalid encryption param")]
InvalidEncryptionParam,
#[error("Integrity MAC verification failed")]
IntegrityMacVerificationFailed,
}
type Result<T> = std::result::Result<T, Error>;
/// A cipher to encrypt/decrypt a list of data of the same size using a list of keys.
///
/// The cipher performs the layered encryption.
/// The following example shows the simplified output.
/// - Input: [[data0, k0], [data1, k1]]
/// - Output: encrypt(k0, [data0, encrypt(k1, [data1])])
///
/// The max number of layers is limited to the `max_layers` parameter.
/// Even if the number of data and keys provided for encryption is smaller than `max_layers`,
/// The cipher always produces the max-size output regardless of the number of data and keys provided,
/// in order to ensure that all outputs generated by the cipher are the same size.
///
/// The cipher also provides the length-preserved decryption.
/// Even if one layer of encryptions is decrypted, the length of decrypted data is
/// the same as the length of the original data.
/// For example:
/// len(encrypt(k0, [data0, encrypt(k1, [data1])])) == len(encrypt(k1, [data1]))
pub struct ConsistentLengthLayeredCipher<D> {
/// All encrypted data produced by the cipher has the same size according to the `max_layers`.
pub max_layers: usize,
_data: PhantomData<D>,
}
pub trait ConsistentLengthLayeredCipherData {
// Returns the serialized bytes for an instance of the implementing type
fn to_bytes(&self) -> Vec<u8>;
// The size of the serialized data.
const SIZE: usize;
}
/// A parameter for one layer of encryption
pub struct EncryptionParam<D: ConsistentLengthLayeredCipherData> {
/// A data to be included in the layer.
pub data: D,
/// A [`Key`] to encrypt the layer that will include the [`Self::data`].
pub key: Key,
}
/// A set of keys to encrypt/decrypt a single layer.
pub struct Key {
/// A 128-bit key for encryption/decryption
pub stream_cipher_key: StreamCipherKey,
/// A 128-bit key for computing/verifying integrity MAC
pub integrity_mac_key: HeaderIntegrityMacKey,
}
impl<D: ConsistentLengthLayeredCipherData> ConsistentLengthLayeredCipher<D> {
pub fn new(max_layers: usize) -> Self {
Self {
max_layers,
_data: Default::default(),
}
}
/// The total size of fully encrypted output that includes all layers.
/// This size is determined by [`D::size`] and [`max_layers`].
pub const fn total_size(max_layers: usize) -> usize {
Self::SINGLE_LAYER_SIZE * max_layers
}
/// The size of a single layer that contains a data and a MAC.
/// The MAC is used to verify integrity of the encrypted next layer.
const SINGLE_LAYER_SIZE: usize = D::SIZE + HEADER_INTEGRITY_MAC_SIZE;
/// Perform the layered encryption.
pub fn encrypt(&self, params: &[EncryptionParam<D>]) -> Result<(Vec<u8>, HeaderIntegrityMac)> {
if params.is_empty() || params.len() > self.max_layers {
return Err(Error::InvalidEncryptionParam);
}
params
.iter()
.take(params.len() - 1) // Exclude the last param that will be treated separately below.
.rev() // Data and keys must be used in reverse order to encrypt the inner-most layer first
.try_fold(self.build_last_layer(params)?, |(encrypted, mac), param| {
self.build_intermediate_layer(param, mac, encrypted)
})
}
/// Build an intermediate layer of encryption that wraps subsequent layers already encrypted.
/// The output has the same size as [`Self::total_size`],
/// regardless of how many subsequent layers that this layer wraps.
fn build_intermediate_layer(
&self,
param: &EncryptionParam<D>,
next_mac: HeaderIntegrityMac,
next_encrypted_data: Vec<u8>,
) -> Result<(Vec<u8>, HeaderIntegrityMac)> {
// Concatenate the data with the encrypted subsequent layers and its MAC.
let data = param.data.to_bytes();
let total_data = itertools::chain!(
&data,
next_mac.as_bytes(),
// Truncate last bytes for the length-preserved decryption later.
// They will be restored by a filler during the decryption process.
&next_encrypted_data[..next_encrypted_data.len() - Self::SINGLE_LAYER_SIZE],
)
.copied()
.collect::<Vec<_>>();
// Encrypt the concatenated bytes, and compute MAC.
let mut encrypted = total_data;
self.apply_streamcipher(
&mut encrypted,
&param.key.stream_cipher_key,
StreamCipherOption::FromFront,
);
let mac = Self::compute_mac(&param.key.integrity_mac_key, &encrypted);
assert_eq!(encrypted.len(), Self::total_size(self.max_layers));
Ok((encrypted, mac))
}
/// Build the last layer of encryption.
/// The output has the same size as [`Self::total_size`] by using fillers,
/// even though it doesn't wrap any subsequent layer.
/// This is for the length-preserved decryption.
fn build_last_layer(
&self,
params: &[EncryptionParam<D>],
) -> Result<(Vec<u8>, HeaderIntegrityMac)> {
let last_param = params.last().ok_or(Error::InvalidEncryptionParam)?;
// Build fillers that will be appended to the last data.
// The number of fillers must be the same as the number of intermediate layers
// (excluding the last layer) that will be decrypted later.
let fillers = self.build_fillers(&params[..params.len() - 1]);
// Header integrity MAC doesn't need to be included in the last layer
// because there is no next encrypted layer.
// Instead, random bytes are used to fill the space between data and fillers.
// The size of random bytes depends on the [`self.max_layers`].
let random_bytes =
random_bytes(Self::total_size(self.max_layers) - D::SIZE - fillers.len());
// First, concat the data and the random bytes, and encrypt it.
let last_data = last_param.data.to_bytes();
let total_data_without_fillers = itertools::chain!(&last_data, &random_bytes)
.copied()
.collect::<Vec<_>>();
let mut encrypted = total_data_without_fillers;
self.apply_streamcipher(
&mut encrypted,
&last_param.key.stream_cipher_key,
StreamCipherOption::FromFront,
);
// Append fillers to the encrypted bytes, and compute MAC.
encrypted.extend(fillers);
let mac = Self::compute_mac(&last_param.key.integrity_mac_key, &encrypted);
assert_eq!(encrypted.len(), Self::total_size(self.max_layers));
Ok((encrypted, mac))
}
/// Build as many fillers as the number of keys provided.
/// Fillers are encrypted in accumulated manner by keys.
fn build_fillers(&self, params: &[EncryptionParam<D>]) -> Vec<u8> {
let mut fillers = vec![0u8; Self::SINGLE_LAYER_SIZE * params.len()];
params
.iter()
.map(|param| &param.key.stream_cipher_key)
.enumerate()
.for_each(|(i, key)| {
self.apply_streamcipher(
&mut fillers[0..(i + 1) * Self::SINGLE_LAYER_SIZE],
key,
StreamCipherOption::FromBack,
)
});
fillers
}
/// Unpack one layer of encryption by performing the length-preserved decryption.
pub fn unpack(
&self,
mac: &HeaderIntegrityMac,
encrypted_total_data: &[u8],
key: &Key,
) -> Result<(Vec<u8>, HeaderIntegrityMac, Vec<u8>)> {
if encrypted_total_data.len() != Self::total_size(self.max_layers) {
return Err(Error::InvalidCipherTextLength);
}
// If a wrong key is used, the decryption should fail.
if !mac.verify(key.integrity_mac_key, encrypted_total_data) {
return Err(Error::IntegrityMacVerificationFailed);
}
// Extend the encrypted data by the length of a single layer
// in order to restore the truncated part (a encrypted filler)
// by [`Self::build_intermediate_layer`] during the encryption process.
let total_data_with_zero_filler = encrypted_total_data
.iter()
.copied()
.chain(std::iter::repeat(0u8).take(Self::SINGLE_LAYER_SIZE))
.collect::<Vec<_>>();
// Decrypt the extended data.
let mut decrypted = total_data_with_zero_filler;
self.apply_streamcipher(
&mut decrypted,
&key.stream_cipher_key,
StreamCipherOption::FromFront,
);
// Parse the decrypted data into 3 parts: data, MAC, and the next encrypted data.
let parsed = parse_bytes(
&decrypted,
&[
D::SIZE,
HEADER_INTEGRITY_MAC_SIZE,
Self::total_size(self.max_layers),
],
)
.unwrap();
let data = parsed[0].to_vec();
let next_mac = HeaderIntegrityMac::from_bytes(parsed[1].try_into().unwrap());
let next_encrypted_data = parsed[2].to_vec();
Ok((data, next_mac, next_encrypted_data))
}
fn apply_streamcipher(&self, data: &mut [u8], key: &StreamCipherKey, opt: StreamCipherOption) {
let pseudorandom_bytes = sphinx_packet::crypto::generate_pseudorandom_bytes(
key,
&STREAM_CIPHER_INIT_VECTOR,
Self::total_size(self.max_layers) + Self::SINGLE_LAYER_SIZE,
);
let pseudorandom_bytes = match opt {
StreamCipherOption::FromFront => &pseudorandom_bytes[..data.len()],
StreamCipherOption::FromBack => {
&pseudorandom_bytes[pseudorandom_bytes.len() - data.len()..]
}
};
Self::xor_in_place(data, pseudorandom_bytes)
}
// In-place XOR operation: b is applied to a.
fn xor_in_place(a: &mut [u8], b: &[u8]) {
assert_eq!(a.len(), b.len());
a.iter_mut().zip(b.iter()).for_each(|(x1, &x2)| *x1 ^= x2);
}
fn compute_mac(key: &HeaderIntegrityMacKey, data: &[u8]) -> HeaderIntegrityMac {
let mac = sphinx_packet::crypto::compute_keyed_hmac::<sha2::Sha256>(key, data).into_bytes();
assert!(mac.len() >= HEADER_INTEGRITY_MAC_SIZE);
HeaderIntegrityMac::from_bytes(
mac.into_iter()
.take(HEADER_INTEGRITY_MAC_SIZE)
.collect::<Vec<_>>()
.try_into()
.unwrap(),
)
}
}
fn random_bytes(size: usize) -> Vec<u8> {
let mut bytes = vec![0u8; size];
let mut rng = ChaCha12Rng::from_entropy();
rng.fill_bytes(&mut bytes);
bytes
}
enum StreamCipherOption {
FromFront,
FromBack,
}
#[cfg(test)]
mod tests {
use sphinx_packet::{constants::INTEGRITY_MAC_KEY_SIZE, crypto::STREAM_CIPHER_KEY_SIZE};
use super::*;
#[test]
fn build_and_unpack() {
let cipher = ConsistentLengthLayeredCipher::<[u8; 10]>::new(5);
let params = (0u8..3)
.map(|i| EncryptionParam::<[u8; 10]> {
data: [i; 10],
key: Key {
stream_cipher_key: [i * 10; STREAM_CIPHER_KEY_SIZE],
integrity_mac_key: [i * 20; INTEGRITY_MAC_KEY_SIZE],
},
})
.collect::<Vec<_>>();
let (encrypted, mac) = cipher.encrypt(&params).unwrap();
let next_encrypted = encrypted.clone();
let (data, next_mac, next_encrypted) = cipher
.unpack(&mac, &next_encrypted, &params[0].key)
.unwrap();
assert_eq!(data, params[0].data);
assert_eq!(next_encrypted.len(), encrypted.len());
let (data, next_mac, next_encrypted) = cipher
.unpack(&next_mac, &next_encrypted, &params[1].key)
.unwrap();
assert_eq!(data, params[1].data);
assert_eq!(next_encrypted.len(), encrypted.len());
let (data, _, next_encrypted) = cipher
.unpack(&next_mac, &next_encrypted, &params[2].key)
.unwrap();
assert_eq!(data, params[2].data);
assert_eq!(next_encrypted.len(), encrypted.len());
}
impl ConsistentLengthLayeredCipherData for [u8; 10] {
fn to_bytes(&self) -> Vec<u8> {
self.to_vec()
}
const SIZE: usize = 10;
}
}

View File

@ -0,0 +1,67 @@
use error::Error;
use packet::{Packet, UnpackedPacket};
use crate::MixMessage;
pub mod error;
mod layered_cipher;
pub mod packet;
mod routing;
pub struct SphinxMessage;
const ASYM_KEY_SIZE: usize = 32;
const PADDED_PAYLOAD_SIZE: usize = 2048;
const MAX_LAYERS: usize = 5;
impl MixMessage for SphinxMessage {
type PublicKey = [u8; ASYM_KEY_SIZE];
type PrivateKey = [u8; ASYM_KEY_SIZE];
type Error = Error;
const DROP_MESSAGE: &'static [u8] = &[0; Packet::size(MAX_LAYERS, PADDED_PAYLOAD_SIZE)];
fn build_message(
payload: &[u8],
public_keys: &[Self::PublicKey],
) -> Result<Vec<u8>, Self::Error> {
let packet = Packet::build(
&public_keys
.iter()
.map(|k| x25519_dalek::PublicKey::from(*k))
.collect::<Vec<_>>(),
MAX_LAYERS,
payload,
PADDED_PAYLOAD_SIZE,
)?;
Ok(packet.to_bytes())
}
fn unwrap_message(
message: &[u8],
private_key: &Self::PrivateKey,
) -> Result<(Vec<u8>, bool), Self::Error> {
let packet = Packet::from_bytes(message, MAX_LAYERS)?;
let unpacked_packet =
packet.unpack(&x25519_dalek::StaticSecret::from(*private_key), MAX_LAYERS)?;
match unpacked_packet {
UnpackedPacket::ToForward(packet) => Ok((packet.to_bytes(), false)),
UnpackedPacket::FullyUnpacked(payload) => Ok((payload, true)),
}
}
}
fn parse_bytes<'a>(data: &'a [u8], sizes: &[usize]) -> Result<Vec<&'a [u8]>, String> {
let mut i = 0;
sizes
.iter()
.map(|&size| {
if i + size > data.len() {
return Err("The sum of sizes exceeds the length of the input slice".to_string());
}
let slice = &data[i..i + size];
i += size;
Ok(slice)
})
.collect()
}

View File

@ -1,11 +1,20 @@
use crate::Error; use sphinx_packet::{
use serde::{Deserialize, Serialize}; constants::NODE_ADDRESS_LENGTH,
use sphinx_packet::constants::NODE_ADDRESS_LENGTH; header::{
keys::RoutingKeys,
routing::{FINAL_HOP, FORWARD_HOP},
},
payload::Payload,
};
use crate::sphinx::ASYM_KEY_SIZE;
use super::{error::Error, parse_bytes, routing::EncryptedRoutingInformation};
/// A packet that contains a header and a payload. /// A packet that contains a header and a payload.
/// The header and payload are encrypted for the selected recipients. /// The header and payload are encrypted for the selected recipients.
/// This packet can be serialized and sent over the network. /// This packet can be serialized and sent over the network.
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug)]
pub struct Packet { pub struct Packet {
header: Header, header: Header,
// This crate doesn't limit the payload size. // This crate doesn't limit the payload size.
@ -14,36 +23,29 @@ pub struct Packet {
} }
/// The packet header /// The packet header
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug)]
struct Header { struct Header {
/// The ephemeral public key for a recipient to derive the shared secret /// The ephemeral public key for a recipient to derive the shared secret
/// which can be used to decrypt the header and payload. /// which can be used to decrypt the header and payload.
ephemeral_public_key: x25519_dalek::PublicKey, ephemeral_public_key: x25519_dalek::PublicKey,
// TODO: Length-preserved layered encryption on RoutingInfo encrypted_routing_info: EncryptedRoutingInformation,
routing_info: RoutingInfo,
}
#[derive(Debug, Serialize, Deserialize)]
struct RoutingInfo {
// TODO: Change this to `is_final_layer: bool`
// by implementing the length-preserved layered encryption.
// It's not good to expose the info that how many layers remain to the intermediate recipients.
remaining_layers: u8,
// TODO:: Add the following fields
// header_integrity_hamc
// additional data (e.g. incentivization)
} }
impl Packet { impl Packet {
pub fn build( pub fn build(
recipient_pubkeys: &[x25519_dalek::PublicKey], recipient_pubkeys: &[x25519_dalek::PublicKey],
max_layers: usize,
payload: &[u8], payload: &[u8],
payload_size: usize, max_payload_size: usize,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
// Derive `[sphinx_packet::header::keys::KeyMaterial]` for all recipients. // Derive `[sphinx_packet::header::keys::KeyMaterial]` for all recipients.
let ephemeral_privkey = x25519_dalek::StaticSecret::random(); let ephemeral_privkey = x25519_dalek::StaticSecret::random();
let key_material = Self::derive_key_material(recipient_pubkeys, &ephemeral_privkey); let key_material = Self::derive_key_material(recipient_pubkeys, &ephemeral_privkey);
// Build the encrypted routing information.
let encrypted_routing_info =
EncryptedRoutingInformation::new(&key_material.routing_keys, max_layers)?;
// Encrypt the payload for all recipients. // Encrypt the payload for all recipients.
let payload_keys = key_material let payload_keys = key_material
.routing_keys .routing_keys
@ -53,22 +55,19 @@ impl Packet {
let payload = sphinx_packet::payload::Payload::encapsulate_message( let payload = sphinx_packet::payload::Payload::encapsulate_message(
payload, payload,
&payload_keys, &payload_keys,
payload_size, max_payload_size,
)?; )?;
Ok(Packet { Ok(Packet {
header: Header { header: Header {
ephemeral_public_key: x25519_dalek::PublicKey::from(&ephemeral_privkey), ephemeral_public_key: x25519_dalek::PublicKey::from(&ephemeral_privkey),
routing_info: RoutingInfo { encrypted_routing_info,
remaining_layers: u8::try_from(recipient_pubkeys.len())
.map_err(|_| Error::InvalidNumberOfLayers)?,
},
}, },
payload: payload.into_bytes(), payload: payload.into_bytes(),
}) })
} }
fn derive_key_material( pub(crate) fn derive_key_material(
recipient_pubkeys: &[x25519_dalek::PublicKey], recipient_pubkeys: &[x25519_dalek::PublicKey],
ephemeral_privkey: &x25519_dalek::StaticSecret, ephemeral_privkey: &x25519_dalek::StaticSecret,
) -> sphinx_packet::header::keys::KeyMaterial { ) -> sphinx_packet::header::keys::KeyMaterial {
@ -90,6 +89,7 @@ impl Packet {
pub fn unpack( pub fn unpack(
&self, &self,
private_key: &x25519_dalek::StaticSecret, private_key: &x25519_dalek::StaticSecret,
max_layers: usize,
) -> Result<UnpackedPacket, Error> { ) -> Result<UnpackedPacket, Error> {
// Derive the routing keys for the recipient // Derive the routing keys for the recipient
let routing_keys = sphinx_packet::header::SphinxHeader::compute_routing_keys( let routing_keys = sphinx_packet::header::SphinxHeader::compute_routing_keys(
@ -101,25 +101,40 @@ impl Packet {
let payload = sphinx_packet::payload::Payload::from_bytes(&self.payload)?; let payload = sphinx_packet::payload::Payload::from_bytes(&self.payload)?;
let payload = payload.unwrap(&routing_keys.payload_key)?; let payload = payload.unwrap(&routing_keys.payload_key)?;
// If this is the last layer of encryption, return the decrypted payload. // Unpack the routing information
if self.header.routing_info.remaining_layers == 1 { let (routing_info, next_encrypted_routing_info) = self
return Ok(UnpackedPacket::FullyUnpacked(payload.recover_plaintext()?)); .header
.encrypted_routing_info
.unpack(&routing_keys, max_layers)?;
match routing_info.flag {
FORWARD_HOP => Ok(UnpackedPacket::ToForward(self.build_next_packet(
&routing_keys,
next_encrypted_routing_info,
payload,
))),
FINAL_HOP => Ok(UnpackedPacket::FullyUnpacked(payload.recover_plaintext()?)),
_ => Err(Error::InvalidRoutingFlag(routing_info.flag)),
} }
}
fn build_next_packet(
&self,
routing_keys: &RoutingKeys,
next_encrypted_routing_info: EncryptedRoutingInformation,
payload: Payload,
) -> Packet {
// Derive the new ephemeral public key for the next recipient // Derive the new ephemeral public key for the next recipient
let next_ephemeral_pubkey = Self::derive_next_ephemeral_public_key( let next_ephemeral_pubkey = Self::derive_next_ephemeral_public_key(
&self.header.ephemeral_public_key, &self.header.ephemeral_public_key,
&routing_keys.blinding_factor, &routing_keys.blinding_factor,
); );
Ok(UnpackedPacket::ToForward(Packet { Packet {
header: Header { header: Header {
ephemeral_public_key: next_ephemeral_pubkey, ephemeral_public_key: next_ephemeral_pubkey,
routing_info: RoutingInfo { encrypted_routing_info: next_encrypted_routing_info,
remaining_layers: self.header.routing_info.remaining_layers - 1,
},
}, },
payload: payload.into_bytes(), payload: payload.into_bytes(),
})) }
} }
/// Derive the next ephemeral public key for the next recipient. /// Derive the next ephemeral public key for the next recipient.
@ -135,6 +150,49 @@ impl Packet {
let new_shared_secret = blinding_factor.diffie_hellman(cur_ephemeral_pubkey); let new_shared_secret = blinding_factor.diffie_hellman(cur_ephemeral_pubkey);
x25519_dalek::PublicKey::from(new_shared_secret.to_bytes()) x25519_dalek::PublicKey::from(new_shared_secret.to_bytes())
} }
pub fn to_bytes(&self) -> Vec<u8> {
let ephemeral_public_key = self.header.ephemeral_public_key.to_bytes();
let encrypted_routing_info = self.header.encrypted_routing_info.to_bytes();
itertools::chain!(
&ephemeral_public_key,
&encrypted_routing_info,
&self.payload,
)
.copied()
.collect()
}
pub fn from_bytes(data: &[u8], max_layers: usize) -> Result<Self, Error> {
let ephemeral_public_key_size = ASYM_KEY_SIZE;
let encrypted_routing_info_size = EncryptedRoutingInformation::size(max_layers);
let parsed = parse_bytes(
data,
&[
ephemeral_public_key_size,
encrypted_routing_info_size,
data.len() - ephemeral_public_key_size - encrypted_routing_info_size,
],
)
.map_err(|_| Error::InvalidPacketBytes)?;
Ok(Packet {
header: Header {
ephemeral_public_key: {
let bytes: [u8; 32] = parsed[0].try_into().unwrap();
x25519_dalek::PublicKey::from(bytes)
},
encrypted_routing_info: EncryptedRoutingInformation::from_bytes(
parsed[1], max_layers,
)?,
},
payload: parsed[2].to_vec(),
})
}
pub const fn size(max_layers: usize, max_payload_size: usize) -> usize {
ASYM_KEY_SIZE + EncryptedRoutingInformation::size(max_layers) + max_payload_size
}
} }
pub enum UnpackedPacket { pub enum UnpackedPacket {
@ -144,14 +202,12 @@ pub enum UnpackedPacket {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use nomos_core::wire;
use super::*; use super::*;
#[test] #[test]
fn unpack() { fn unpack() {
// Prepare keys of two recipients // Prepare keys of two recipients
let recipient_privkeys = (0..2) let recipient_privkeys = (0..3)
.map(|_| x25519_dalek::StaticSecret::random()) .map(|_| x25519_dalek::StaticSecret::random())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let recipient_pubkeys = recipient_privkeys let recipient_pubkeys = recipient_privkeys
@ -160,18 +216,26 @@ mod tests {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// Build a packet // Build a packet
let max_layers = 5;
let payload = [10u8; 512]; let payload = [10u8; 512];
let packet = Packet::build(&recipient_pubkeys, &payload, 1024).unwrap(); let packet = Packet::build(&recipient_pubkeys, max_layers, &payload, 1024).unwrap();
// The 1st recipient unpacks the packet // The 1st recipient unpacks the packet
let packet = match packet.unpack(&recipient_privkeys[0]).unwrap() { let packet = match packet.unpack(&recipient_privkeys[0], max_layers).unwrap() {
UnpackedPacket::ToForward(packet) => packet, UnpackedPacket::ToForward(packet) => packet,
UnpackedPacket::FullyUnpacked(_) => { UnpackedPacket::FullyUnpacked(_) => {
panic!("The unpacked packet should be the ToFoward type"); panic!("The unpacked packet should be the ToFoward type");
} }
}; };
// The 2nd recipient unpacks the packet // The 2nd recipient unpacks the packet
match packet.unpack(&recipient_privkeys[1]).unwrap() { let packet = match packet.unpack(&recipient_privkeys[1], max_layers).unwrap() {
UnpackedPacket::ToForward(packet) => packet,
UnpackedPacket::FullyUnpacked(_) => {
panic!("The unpacked packet should be the ToFoward type");
}
};
// The last recipient unpacks the packet
match packet.unpack(&recipient_privkeys[2], max_layers).unwrap() {
UnpackedPacket::ToForward(_) => { UnpackedPacket::ToForward(_) => {
panic!("The unpacked packet should be the FullyUnpacked type"); panic!("The unpacked packet should be the FullyUnpacked type");
} }
@ -185,34 +249,25 @@ mod tests {
#[test] #[test]
fn unpack_with_wrong_keys() { fn unpack_with_wrong_keys() {
// Build a packet with two public keys // Build a packet with two public keys
let max_layers = 5;
let payload = [10u8; 512]; let payload = [10u8; 512];
let packet = Packet::build( let packet = Packet::build(
&(0..2) &(0..2)
.map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random())) .map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random()))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
max_layers,
&payload, &payload,
1024, 1024,
) )
.unwrap(); .unwrap();
// The 1st recipient unpacks the packet with an wrong key
let packet = match packet
.unpack(&x25519_dalek::StaticSecret::random())
.unwrap()
{
UnpackedPacket::ToForward(packet) => packet,
UnpackedPacket::FullyUnpacked(_) => {
panic!("The unpacked packet should be the ToFoward type");
}
};
// The 2nd recipient unpacks the packet with an wrong key
assert!(packet assert!(packet
.unpack(&x25519_dalek::StaticSecret::random()) .unpack(&x25519_dalek::StaticSecret::random(), max_layers)
.is_err()); .is_err());
} }
#[test] #[test]
fn consistent_size_serialization() { fn consistent_size_after_unpack() {
// Prepare keys of two recipients // Prepare keys of two recipients
let recipient_privkeys = (0..2) let recipient_privkeys = (0..2)
.map(|_| x25519_dalek::StaticSecret::random()) .map(|_| x25519_dalek::StaticSecret::random())
@ -223,28 +278,76 @@ mod tests {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// Build a packet // Build a packet
let max_layers = 5;
let payload = [10u8; 512]; let payload = [10u8; 512];
let packet = Packet::build(&recipient_pubkeys, &payload, 1024).unwrap(); let max_payload_size = 1024;
let packet =
Packet::build(&recipient_pubkeys, max_layers, &payload, max_payload_size).unwrap();
// Calculate the expected packet size // Calculate the expected packet size
let pubkey_size = 32; let packet_size = Packet::size(max_layers, max_payload_size);
let routing_info_size = 1;
let payload_length_enconding_size = 8;
let payload_size = 1024;
let packet_size =
pubkey_size + routing_info_size + payload_length_enconding_size + payload_size;
// The serialized packet size must be the same as the expected size. // The serialized packet size must be the same as the expected size.
assert_eq!(wire::serialize(&packet).unwrap().len(), packet_size); assert_eq!(packet.to_bytes().len(), packet_size);
// The unpacked packet size must be the same as the original packet size. // The unpacked packet size must be the same as the original packet size.
match packet.unpack(&recipient_privkeys[0]).unwrap() { match packet.unpack(&recipient_privkeys[0], max_layers).unwrap() {
UnpackedPacket::ToForward(packet) => { UnpackedPacket::ToForward(packet) => {
assert_eq!(wire::serialize(&packet).unwrap().len(), packet_size); assert_eq!(packet.to_bytes().len(), packet_size);
} }
UnpackedPacket::FullyUnpacked(_) => { UnpackedPacket::FullyUnpacked(_) => {
panic!("The unpacked packet should be the ToFoward type"); panic!("The unpacked packet should be the ToFoward type");
} }
} }
} }
#[test]
fn consistent_size_with_any_num_layers() {
let max_layers = 5;
let payload = [10u8; 512];
// Build a packet with 2 recipients
let recipient_pubkeys = (0..2)
.map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random()))
.collect::<Vec<_>>();
let packet1 = Packet::build(&recipient_pubkeys, max_layers, &payload, 1024).unwrap();
// Build a packet with 3 recipients
let recipient_pubkeys = (0..3)
.map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random()))
.collect::<Vec<_>>();
let packet2 = Packet::build(&recipient_pubkeys, max_layers, &payload, 1024).unwrap();
assert_eq!(packet1.to_bytes().len(), packet2.to_bytes().len());
}
#[test]
fn to_from_bytes() {
let max_layers = 5;
let payload = [10u8; 512];
// Build a packet with 2 recipients
let recipient_pubkeys = (0..2)
.map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random()))
.collect::<Vec<_>>();
let packet = Packet::build(&recipient_pubkeys, max_layers, &payload, 1024).unwrap();
let bytes = packet.to_bytes();
let loaded_packet = Packet::from_bytes(&bytes, max_layers).unwrap();
// Manually compare packets because PartialEq is not implemented
// for [`sphinx_packet::header::mac::HeaderIntegrityMac`] used in our header.
assert_eq!(
packet.header.ephemeral_public_key,
loaded_packet.header.ephemeral_public_key
);
assert_eq!(
packet.header.encrypted_routing_info.encrypted_routing_info,
loaded_packet
.header
.encrypted_routing_info
.encrypted_routing_info
);
assert_eq!(packet.payload, loaded_packet.payload);
}
} }

View File

@ -0,0 +1,140 @@
use sphinx_packet::{
constants::HEADER_INTEGRITY_MAC_SIZE,
header::{
keys::RoutingKeys,
mac::HeaderIntegrityMac,
routing::{RoutingFlag, FINAL_HOP, FORWARD_HOP},
},
};
use super::{
error::Error,
layered_cipher::{
ConsistentLengthLayeredCipher, ConsistentLengthLayeredCipherData, EncryptionParam, Key,
},
parse_bytes,
};
/// A routing information that will be contained in a packet header
/// in the encrypted format.
pub struct RoutingInformation {
pub flag: RoutingFlag,
// Add additional fields here
}
impl RoutingInformation {
pub fn new(flag: RoutingFlag) -> Self {
Self { flag }
}
pub fn from_bytes(data: &[u8]) -> Result<Self, Error> {
if data.len() != Self::SIZE {
return Err(Error::InvalidEncryptedRoutingInfoLength(data.len()));
}
Ok(Self { flag: data[0] })
}
}
impl ConsistentLengthLayeredCipherData for RoutingInformation {
fn to_bytes(&self) -> Vec<u8> {
vec![self.flag]
}
const SIZE: usize = std::mem::size_of::<RoutingFlag>();
}
/// Encrypted routing information that will be contained in a packet header.
#[derive(Debug)]
pub struct EncryptedRoutingInformation {
/// A MAC to verify the integrity of [`Self::encrypted_routing_info`].
pub mac: HeaderIntegrityMac,
/// The actual encrypted routing information produced by [`ConsistentLengthLayeredCipher`].
/// Its size should be the same as [`ConsistentLengthLayeredCipher::total_size`].
pub encrypted_routing_info: Vec<u8>,
}
type LayeredCipher = ConsistentLengthLayeredCipher<RoutingInformation>;
impl EncryptedRoutingInformation {
/// Build all [`RoutingInformation`]s for the provides keys,
/// and encrypt them using [`ConsistentLengthLayeredCipher`].
pub fn new(routing_keys: &[RoutingKeys], max_layers: usize) -> Result<Self, Error> {
let cipher = LayeredCipher::new(max_layers);
let params = routing_keys
.iter()
.enumerate()
.map(|(i, k)| {
let flag = if i == routing_keys.len() - 1 {
FINAL_HOP
} else {
FORWARD_HOP
};
EncryptionParam::<RoutingInformation> {
data: RoutingInformation::new(flag),
key: Self::layered_cipher_key(k),
}
})
.collect::<Vec<_>>();
let (encrypted, mac) = cipher.encrypt(&params)?;
Ok(Self {
mac,
encrypted_routing_info: encrypted,
})
}
/// Unpack one layer of encryptions using the key provided.
/// Returns the decrypted routing information
/// and the next [`EncryptedRoutingInformation`] to be unpacked further.
pub fn unpack(
&self,
routing_key: &RoutingKeys,
max_layers: usize,
) -> Result<(RoutingInformation, Self), Error> {
let cipher = LayeredCipher::new(max_layers);
let (routing_info, next_mac, next_encrypted_routing_info) = cipher.unpack(
&self.mac,
&self.encrypted_routing_info,
&Self::layered_cipher_key(routing_key),
)?;
Ok((
RoutingInformation::from_bytes(&routing_info)?,
Self {
mac: next_mac,
encrypted_routing_info: next_encrypted_routing_info,
},
))
}
fn layered_cipher_key(routing_key: &RoutingKeys) -> Key {
Key {
stream_cipher_key: routing_key.stream_cipher_key,
integrity_mac_key: routing_key.header_integrity_hmac_key,
}
}
pub fn to_bytes(&self) -> Vec<u8> {
itertools::chain!(self.mac.as_bytes(), &self.encrypted_routing_info)
.copied()
.collect()
}
pub fn from_bytes(data: &[u8], max_layers: usize) -> Result<Self, Error> {
let parsed = parse_bytes(
data,
&[
HEADER_INTEGRITY_MAC_SIZE,
LayeredCipher::total_size(max_layers),
],
)
.map_err(|_| Error::InvalidEncryptedRoutingInfoLength(data.len()))?;
Ok(Self {
mac: HeaderIntegrityMac::from_bytes(parsed[0].try_into().unwrap()),
encrypted_routing_info: parsed[1].to_vec(),
})
}
pub const fn size(max_layers: usize) -> usize {
HEADER_INTEGRITY_MAC_SIZE + LayeredCipher::total_size(max_layers)
}
}