diff --git a/nomos-mix/core/src/message_blend/crypto.rs b/nomos-mix/core/src/message_blend/crypto.rs
index 67d80afa..62a8a865 100644
--- a/nomos-mix/core/src/message_blend/crypto.rs
+++ b/nomos-mix/core/src/message_blend/crypto.rs
@@ -38,7 +38,7 @@ where
         }
     }
 
-    pub fn wrap_message(&mut self, message: &[u8]) -> Result<Vec<u8>, nomos_mix_message::Error> {
+    pub fn wrap_message(&mut self, message: &[u8]) -> Result<Vec<u8>, M::Error> {
         // TODO: Use the actual Sphinx encoding instead of mock.
         let public_keys = self
             .membership
@@ -50,10 +50,7 @@ where
         M::build_message(message, &public_keys)
     }
 
-    pub fn unwrap_message(
-        &self,
-        message: &[u8],
-    ) -> Result<(Vec<u8>, bool), nomos_mix_message::Error> {
+    pub fn unwrap_message(&self, message: &[u8]) -> Result<(Vec<u8>, bool), M::Error> {
         M::unwrap_message(message, &self.settings.private_key)
     }
 }
diff --git a/nomos-mix/core/src/message_blend/mod.rs b/nomos-mix/core/src/message_blend/mod.rs
index 13d6a2ce..2cc988b6 100644
--- a/nomos-mix/core/src/message_blend/mod.rs
+++ b/nomos-mix/core/src/message_blend/mod.rs
@@ -4,6 +4,7 @@ pub mod temporal;
 pub use crypto::CryptographicProcessorSettings;
 use futures::{Stream, StreamExt};
 use rand::RngCore;
+use std::fmt::Debug;
 use std::marker::PhantomData;
 use std::pin::Pin;
 use std::task::{Context, Poll};
@@ -52,6 +53,7 @@ where
     M: MixMessage,
     M::PrivateKey: Serialize + DeserializeOwned,
     M::PublicKey: Clone + PartialEq,
+    M::Error: Debug,
     Scheduler: Stream + Unpin + Send + Sync + 'static,
 {
     pub fn new(
@@ -91,9 +93,6 @@ where
                         tracing::error!("Failed to send message to the outbound channel: {e:?}");
                     }
                 }
-                Err(nomos_mix_message::Error::MsgUnwrapNotAllowed) => {
-                    tracing::debug!("Message cannot be unwrapped by this node");
-                }
                 Err(e) => {
                     tracing::error!("Failed to unwrap message: {:?}", e);
                 }
@@ -108,6 +107,7 @@ where
     M: MixMessage + Unpin,
     M::PrivateKey: Serialize + DeserializeOwned + Unpin,
     M::PublicKey: Clone + PartialEq + Unpin,
+    M::Error: Debug,
     Scheduler: Stream + Unpin + Send + Sync + 'static,
 {
     type Item = MixOutgoingMessage;
@@ -126,6 +126,7 @@ where
     M: MixMessage,
     M::PrivateKey: Serialize + DeserializeOwned,
     M::PublicKey: Clone + PartialEq,
+    M::Error: Debug,
     Scheduler: Stream + Unpin + Send + Sync + 'static,
 {
     fn blend(
@@ -155,6 +156,7 @@ where
     M: MixMessage,
     M::PrivateKey: Clone + Serialize + DeserializeOwned + PartialEq,
     M::PublicKey: Clone + Serialize + DeserializeOwned + PartialEq,
+    M::Error: Debug,
     S: Stream + Unpin + Send + Sync + 'static,
 {
 }
diff --git a/nomos-mix/message/Cargo.toml b/nomos-mix/message/Cargo.toml
index 69488a8e..3a4c33c1 100644
--- a/nomos-mix/message/Cargo.toml
+++ b/nomos-mix/message/Cargo.toml
@@ -4,7 +4,9 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-serde = { version = "1.0", features = ["derive"] }
+itertools = "0.13"
+rand_chacha = "0.3"
+sha2 = "0.10"
 sphinx-packet = "0.2"
 thiserror = "1.0.65"
 x25519-dalek = { version = "2.0.1", features = [
diff --git a/nomos-mix/message/src/lib.rs b/nomos-mix/message/src/lib.rs
index 5debe180..cc8e4c27 100644
--- a/nomos-mix/message/src/lib.rs
+++ b/nomos-mix/message/src/lib.rs
@@ -1,15 +1,16 @@
-mod error;
 pub mod mock;
-pub mod packet;
-
-pub use error::Error;
+pub mod sphinx;
 
 pub trait MixMessage {
     type PublicKey;
     type PrivateKey;
+    type Error;
     const DROP_MESSAGE: &'static [u8];
-    fn build_message(payload: &[u8], public_keys: &[Self::PublicKey]) -> Result<Vec<u8>, Error>;
+    fn build_message(
+        payload: &[u8],
+        public_keys: &[Self::PublicKey],
+    ) -> Result<Vec<u8>, Self::Error>;
     /// Unwrap the message one layer.
     ///
     /// This function returns the unwrapped message and a boolean indicating whether the message was fully unwrapped.
@@ -20,7 +21,7 @@ pub trait MixMessage {
     fn unwrap_message(
         message: &[u8],
         private_key: &Self::PrivateKey,
-    ) -> Result<(Vec<u8>, bool), Error>;
+    ) -> Result<(Vec<u8>, bool), Self::Error>;
     fn is_drop_message(message: &[u8]) -> bool {
         message == Self::DROP_MESSAGE
     }
diff --git a/nomos-mix/message/src/error.rs b/nomos-mix/message/src/mock/error.rs
similarity index 81%
rename from nomos-mix/message/src/error.rs
rename to nomos-mix/message/src/mock/error.rs
index 38b4e93a..bc54d0fc 100644
--- a/nomos-mix/message/src/error.rs
+++ b/nomos-mix/message/src/mock/error.rs
@@ -6,8 +6,6 @@ pub enum Error {
     PayloadTooLarge,
     #[error("Invalid number of layers")]
     InvalidNumberOfLayers,
-    #[error("Sphinx packet error: {0}")]
-    SphinxPacketError(#[from] sphinx_packet::Error),
     #[error("Unwrapping a message is not allowed to this node")]
     /// e.g. the message cannot be unwrapped using the private key provided
     MsgUnwrapNotAllowed,
diff --git a/nomos-mix/message/src/mock/mod.rs b/nomos-mix/message/src/mock/mod.rs
index ea8f6d5d..74e5738f 100644
--- a/nomos-mix/message/src/mock/mod.rs
+++ b/nomos-mix/message/src/mock/mod.rs
@@ -1,4 +1,8 @@
-use crate::{Error, MixMessage};
+pub mod error;
+
+use error::Error;
+
+use crate::MixMessage;
 // TODO: Remove all the mock below once the actual implementation is integrated to the system.
 //
 /// A mock implementation of the Sphinx encoding.
@@ -17,13 +21,17 @@ pub struct MockMixMessage;
 impl MixMessage for MockMixMessage {
     type PublicKey = [u8; NODE_ID_SIZE];
     type PrivateKey = [u8; NODE_ID_SIZE];
+    type Error = Error;
     const DROP_MESSAGE: &'static [u8] = &[0; MESSAGE_SIZE];
 
     /// The length of the encoded message is fixed to [`MESSAGE_SIZE`] bytes.
     /// The [`MAX_LAYERS`] number of [`NodeId`]s are concatenated in front of the payload.
     /// The payload is zero-padded to the end.
     ///
-    fn build_message(payload: &[u8], public_keys: &[Self::PublicKey]) -> Result<Vec<u8>, Error> {
+    fn build_message(
+        payload: &[u8],
+        public_keys: &[Self::PublicKey],
+    ) -> Result<Vec<u8>, Self::Error> {
         // In this mock, we don't encrypt anything. So, we use public key as just a node ID.
         let node_ids = public_keys;
         if node_ids.is_empty() || node_ids.len() > MAX_LAYERS {
@@ -54,7 +62,7 @@ impl MixMessage for MockMixMessage {
     fn unwrap_message(
         message: &[u8],
         private_key: &Self::PrivateKey,
-    ) -> Result<(Vec<u8>, bool), Error> {
+    ) -> Result<(Vec<u8>, bool), Self::Error> {
         if message.len() != MESSAGE_SIZE {
             return Err(Error::InvalidMixMessage);
         }
diff --git a/nomos-mix/message/src/sphinx/error.rs b/nomos-mix/message/src/sphinx/error.rs
new file mode 100644
index 00000000..55af8cd2
--- /dev/null
+++ b/nomos-mix/message/src/sphinx/error.rs
@@ -0,0 +1,15 @@
+use sphinx_packet::header::routing::RoutingFlag;
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    #[error("Sphinx packet error: {0}")]
+    SphinxPacketError(#[from] sphinx_packet::Error),
+    #[error("Invalid packet bytes")]
+    InvalidPacketBytes,
+    #[error("Invalid routing flag: {0}")]
+    InvalidRoutingFlag(RoutingFlag),
+    #[error("Invalid routing length: {0} bytes")]
+    InvalidEncryptedRoutingInfoLength(usize),
+    #[error("ConsistentLengthLayeredEncryptionError: {0}")]
+    ConsistentLengthLayeredEncryptionError(#[from] super::layered_cipher::Error),
+}
diff --git a/nomos-mix/message/src/sphinx/layered_cipher.rs b/nomos-mix/message/src/sphinx/layered_cipher.rs
new file mode 100644
index 00000000..0bfcf249
--- /dev/null
+++ b/nomos-mix/message/src/sphinx/layered_cipher.rs
@@ -0,0 +1,345 @@
+use std::marker::PhantomData;
+
+use rand_chacha::{
+    rand_core::{RngCore, SeedableRng},
+    ChaCha12Rng,
+};
+use sphinx_packet::{
+    constants::HEADER_INTEGRITY_MAC_SIZE,
+    crypto::STREAM_CIPHER_INIT_VECTOR,
+    header::{
+        keys::{HeaderIntegrityMacKey, StreamCipherKey},
+        mac::HeaderIntegrityMac,
+    },
+};
+
+use super::parse_bytes;
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    #[error("Invalid cipher text length")]
+    InvalidCipherTextLength,
+    #[error("Invalid encryption param")]
+    InvalidEncryptionParam,
+    #[error("Integrity MAC verification failed")]
+    IntegrityMacVerificationFailed,
+}
+
+type Result<T> = std::result::Result<T, Error>;
+
+/// A cipher to encrypt/decrypt a list of data of the same size using a list of keys.
+///
+/// The cipher performs the layered encryption.
+/// The following example shows the simplified output.
+/// - Input: [[data0, k0], [data1, k1]]
+/// - Output: encrypt(k0, [data0, encrypt(k1, [data1])])
+///
+/// The max number of layers is limited by the `max_layers` parameter.
+/// Even if the number of data and keys provided for encryption is smaller than `max_layers`,
+/// the cipher always produces an output of the maximum size,
+/// in order to ensure that all outputs generated by the cipher are the same size.
+///
+/// The cipher also provides length-preserved decryption.
+/// Even if one layer of encryption is decrypted, the length of the decrypted data is
+/// the same as the length of the original data.
+/// For example:
+/// len(encrypt(k0, [data0, encrypt(k1, [data1])])) == len(encrypt(k1, [data1]))
+pub struct ConsistentLengthLayeredCipher<D> {
+    /// All encrypted data produced by the cipher has the same size according to the `max_layers`.
+    pub max_layers: usize,
+    _data: PhantomData<D>,
+}
+
+pub trait ConsistentLengthLayeredCipherData {
+    // Returns the serialized bytes for an instance of the implementing type.
+    fn to_bytes(&self) -> Vec<u8>;
+    // The size of the serialized data.
+    const SIZE: usize;
+}
+
+/// A parameter for one layer of encryption
+pub struct EncryptionParam<D> {
+    /// The data to be included in this layer.
+    pub data: D,
+    /// A [`Key`] to encrypt the layer that will include the [`Self::data`].
+    pub key: Key,
+}
+
+/// A set of keys to encrypt/decrypt a single layer.
+pub struct Key {
+    /// A 128-bit key for encryption/decryption
+    pub stream_cipher_key: StreamCipherKey,
+    /// A 128-bit key for computing/verifying integrity MAC
+    pub integrity_mac_key: HeaderIntegrityMacKey,
+}
+
+impl<D: ConsistentLengthLayeredCipherData> ConsistentLengthLayeredCipher<D> {
+    pub fn new(max_layers: usize) -> Self {
+        Self {
+            max_layers,
+            _data: Default::default(),
+        }
+    }
+
+    /// The total size of the fully encrypted output that includes all layers.
+    /// This size is determined by [`D::SIZE`] and [`max_layers`].
+    pub const fn total_size(max_layers: usize) -> usize {
+        Self::SINGLE_LAYER_SIZE * max_layers
+    }
+
+    /// The size of a single layer that contains the serialized data and a MAC.
+    /// The MAC is used to verify the integrity of the encrypted next layer.
+    const SINGLE_LAYER_SIZE: usize = D::SIZE + HEADER_INTEGRITY_MAC_SIZE;
+
+    /// Perform the layered encryption.
+    pub fn encrypt(&self, params: &[EncryptionParam<D>]) -> Result<(Vec<u8>, HeaderIntegrityMac)> {
+        if params.is_empty() || params.len() > self.max_layers {
+            return Err(Error::InvalidEncryptionParam);
+        }
+
+        params
+            .iter()
+            .take(params.len() - 1) // Exclude the last param that will be treated separately below.
+            .rev() // Data and keys must be used in reverse order to encrypt the inner-most layer first
+            .try_fold(self.build_last_layer(params)?, |(encrypted, mac), param| {
+                self.build_intermediate_layer(param, mac, encrypted)
+            })
+    }
+
+    /// Build an intermediate layer of encryption that wraps subsequent layers already encrypted.
+    /// The output has the same size as [`Self::total_size`],
+    /// regardless of how many subsequent layers this layer wraps.
+    fn build_intermediate_layer(
+        &self,
+        param: &EncryptionParam<D>,
+        next_mac: HeaderIntegrityMac,
+        next_encrypted_data: Vec<u8>,
+    ) -> Result<(Vec<u8>, HeaderIntegrityMac)> {
+        // Concatenate the data with the encrypted subsequent layers and its MAC.
+        let data = param.data.to_bytes();
+        let total_data = itertools::chain!(
+            &data,
+            next_mac.as_bytes(),
+            // Truncate last bytes for the length-preserved decryption later.
+            // They will be restored by a filler during the decryption process.
+            &next_encrypted_data[..next_encrypted_data.len() - Self::SINGLE_LAYER_SIZE],
+        )
+        .copied()
+        .collect::<Vec<_>>();
+
+        // Encrypt the concatenated bytes, and compute MAC.
+        let mut encrypted = total_data;
+        self.apply_streamcipher(
+            &mut encrypted,
+            &param.key.stream_cipher_key,
+            StreamCipherOption::FromFront,
+        );
+        let mac = Self::compute_mac(&param.key.integrity_mac_key, &encrypted);
+
+        assert_eq!(encrypted.len(), Self::total_size(self.max_layers));
+        Ok((encrypted, mac))
+    }
+
+    /// Build the last layer of encryption.
+    /// The output has the same size as [`Self::total_size`] by using fillers,
+    /// even though it doesn't wrap any subsequent layer.
+    /// This is for the length-preserved decryption.
+    fn build_last_layer(
+        &self,
+        params: &[EncryptionParam<D>],
+    ) -> Result<(Vec<u8>, HeaderIntegrityMac)> {
+        let last_param = params.last().ok_or(Error::InvalidEncryptionParam)?;
+
+        // Build fillers that will be appended to the last data.
+        // The number of fillers must be the same as the number of intermediate layers
+        // (excluding the last layer) that will be decrypted later.
+        let fillers = self.build_fillers(&params[..params.len() - 1]);
+        // Header integrity MAC doesn't need to be included in the last layer
+        // because there is no next encrypted layer.
+        // Instead, random bytes are used to fill the space between the data and the fillers.
+        // The size of the random bytes depends on [`self.max_layers`].
+        let random_bytes =
+            random_bytes(Self::total_size(self.max_layers) - D::SIZE - fillers.len());
+
+        // First, concatenate the data and the random bytes, and encrypt it.
+        let last_data = last_param.data.to_bytes();
+        let total_data_without_fillers = itertools::chain!(&last_data, &random_bytes)
+            .copied()
+            .collect::<Vec<_>>();
+        let mut encrypted = total_data_without_fillers;
+        self.apply_streamcipher(
+            &mut encrypted,
+            &last_param.key.stream_cipher_key,
+            StreamCipherOption::FromFront,
+        );
+
+        // Append fillers to the encrypted bytes, and compute MAC.
+        encrypted.extend(fillers);
+        let mac = Self::compute_mac(&last_param.key.integrity_mac_key, &encrypted);
+
+        assert_eq!(encrypted.len(), Self::total_size(self.max_layers));
+        Ok((encrypted, mac))
+    }
+
+    /// Build as many fillers as the number of keys provided.
+    /// Fillers are encrypted in an accumulated manner by the keys.
+    fn build_fillers(&self, params: &[EncryptionParam<D>]) -> Vec<u8> {
+        let mut fillers = vec![0u8; Self::SINGLE_LAYER_SIZE * params.len()];
+        params
+            .iter()
+            .map(|param| &param.key.stream_cipher_key)
+            .enumerate()
+            .for_each(|(i, key)| {
+                self.apply_streamcipher(
+                    &mut fillers[0..(i + 1) * Self::SINGLE_LAYER_SIZE],
+                    key,
+                    StreamCipherOption::FromBack,
+                )
+            });
+        fillers
+    }
+
+    /// Unpack one layer of encryption by performing the length-preserved decryption.
+    pub fn unpack(
+        &self,
+        mac: &HeaderIntegrityMac,
+        encrypted_total_data: &[u8],
+        key: &Key,
+    ) -> Result<(Vec<u8>, HeaderIntegrityMac, Vec<u8>)> {
+        if encrypted_total_data.len() != Self::total_size(self.max_layers) {
+            return Err(Error::InvalidCipherTextLength);
+        }
+        // If a wrong key is used, the decryption should fail.
+        if !mac.verify(key.integrity_mac_key, encrypted_total_data) {
+            return Err(Error::IntegrityMacVerificationFailed);
+        }
+
+        // Extend the encrypted data by the length of a single layer
+        // in order to restore the part (an encrypted filler) truncated by
+        // [`Self::build_intermediate_layer`] during the encryption process.
+        let total_data_with_zero_filler = encrypted_total_data
+            .iter()
+            .copied()
+            .chain(std::iter::repeat(0u8).take(Self::SINGLE_LAYER_SIZE))
+            .collect::<Vec<_>>();
+
+        // Decrypt the extended data.
+        let mut decrypted = total_data_with_zero_filler;
+        self.apply_streamcipher(
+            &mut decrypted,
+            &key.stream_cipher_key,
+            StreamCipherOption::FromFront,
+        );
+
+        // Parse the decrypted data into 3 parts: data, MAC, and the next encrypted data.
+        let parsed = parse_bytes(
+            &decrypted,
+            &[
+                D::SIZE,
+                HEADER_INTEGRITY_MAC_SIZE,
+                Self::total_size(self.max_layers),
+            ],
+        )
+        .unwrap();
+        let data = parsed[0].to_vec();
+        let next_mac = HeaderIntegrityMac::from_bytes(parsed[1].try_into().unwrap());
+        let next_encrypted_data = parsed[2].to_vec();
+        Ok((data, next_mac, next_encrypted_data))
+    }
+
+    fn apply_streamcipher(&self, data: &mut [u8], key: &StreamCipherKey, opt: StreamCipherOption) {
+        let pseudorandom_bytes = sphinx_packet::crypto::generate_pseudorandom_bytes(
+            key,
+            &STREAM_CIPHER_INIT_VECTOR,
+            Self::total_size(self.max_layers) + Self::SINGLE_LAYER_SIZE,
+        );
+        let pseudorandom_bytes = match opt {
+            StreamCipherOption::FromFront => &pseudorandom_bytes[..data.len()],
+            StreamCipherOption::FromBack => {
+                &pseudorandom_bytes[pseudorandom_bytes.len() - data.len()..]
+            }
+        };
+        Self::xor_in_place(data, pseudorandom_bytes)
+    }
+
+    // In-place XOR operation: b is applied to a.
+    fn xor_in_place(a: &mut [u8], b: &[u8]) {
+        assert_eq!(a.len(), b.len());
+        a.iter_mut().zip(b.iter()).for_each(|(x1, &x2)| *x1 ^= x2);
+    }
+
+    fn compute_mac(key: &HeaderIntegrityMacKey, data: &[u8]) -> HeaderIntegrityMac {
+        let mac = sphinx_packet::crypto::compute_keyed_hmac::<sha2::Sha256>(key, data).into_bytes();
+        assert!(mac.len() >= HEADER_INTEGRITY_MAC_SIZE);
+        HeaderIntegrityMac::from_bytes(
+            mac.into_iter()
+                .take(HEADER_INTEGRITY_MAC_SIZE)
+                .collect::<Vec<_>>()
+                .try_into()
+                .unwrap(),
+        )
+    }
+}
+
+fn random_bytes(size: usize) -> Vec<u8> {
+    let mut bytes = vec![0u8; size];
+    let mut rng = ChaCha12Rng::from_entropy();
+    rng.fill_bytes(&mut bytes);
+    bytes
+}
+
+enum StreamCipherOption {
+    FromFront,
+    FromBack,
+}
+
+#[cfg(test)]
+mod tests {
+    use sphinx_packet::{constants::INTEGRITY_MAC_KEY_SIZE, crypto::STREAM_CIPHER_KEY_SIZE};
+
+    use super::*;
+
+    #[test]
+    fn build_and_unpack() {
+        let cipher = ConsistentLengthLayeredCipher::<[u8; 10]>::new(5);
+
+        let params = (0u8..3)
+            .map(|i| EncryptionParam::<[u8; 10]> {
+                data: [i; 10],
+                key: Key {
+                    stream_cipher_key: [i * 10; STREAM_CIPHER_KEY_SIZE],
+                    integrity_mac_key: [i * 20; INTEGRITY_MAC_KEY_SIZE],
+                },
+            })
+            .collect::<Vec<_>>();
+
+        let (encrypted, mac) = cipher.encrypt(&params).unwrap();
+
+        let next_encrypted = encrypted.clone();
+        let (data, next_mac, next_encrypted) = cipher
+            .unpack(&mac, &next_encrypted, &params[0].key)
+            .unwrap();
+        assert_eq!(data, params[0].data);
+        assert_eq!(next_encrypted.len(), encrypted.len());
+
+        let (data, next_mac, next_encrypted) = cipher
+            .unpack(&next_mac, &next_encrypted, &params[1].key)
+            .unwrap();
+        assert_eq!(data, params[1].data);
+        assert_eq!(next_encrypted.len(), encrypted.len());
+
+        let (data, _, next_encrypted) = cipher
+            .unpack(&next_mac, &next_encrypted, &params[2].key)
+            .unwrap();
+        assert_eq!(data, params[2].data);
+        assert_eq!(next_encrypted.len(), encrypted.len());
+    }
+
+    impl ConsistentLengthLayeredCipherData for [u8; 10] {
+        fn to_bytes(&self) -> Vec<u8> {
+            self.to_vec()
+        }
+
+        const SIZE: usize = 10;
+    }
+}
diff --git a/nomos-mix/message/src/sphinx/mod.rs b/nomos-mix/message/src/sphinx/mod.rs
new file mode 100644
index 00000000..9c53217e
--- /dev/null
+++ b/nomos-mix/message/src/sphinx/mod.rs
@@ -0,0 +1,67 @@
+use error::Error;
+use packet::{Packet, UnpackedPacket};
+
+use crate::MixMessage;
+
+pub mod error;
+mod layered_cipher;
+pub mod packet;
+mod routing;
+
+pub struct SphinxMessage;
+
+const ASYM_KEY_SIZE: usize = 32;
+const PADDED_PAYLOAD_SIZE: usize = 2048;
+const MAX_LAYERS: usize = 5;
+
+impl MixMessage for SphinxMessage {
+    type PublicKey = [u8; ASYM_KEY_SIZE];
+    type PrivateKey = [u8; ASYM_KEY_SIZE];
+    type Error = Error;
+
+    const DROP_MESSAGE: &'static [u8] = &[0; Packet::size(MAX_LAYERS, PADDED_PAYLOAD_SIZE)];
+
+    fn build_message(
+        payload: &[u8],
+        public_keys: &[Self::PublicKey],
+    ) -> Result<Vec<u8>, Self::Error> {
+        let packet = Packet::build(
+            &public_keys
+                .iter()
+                .map(|k| x25519_dalek::PublicKey::from(*k))
+                .collect::<Vec<_>>(),
+            MAX_LAYERS,
+            payload,
+            PADDED_PAYLOAD_SIZE,
+        )?;
+        Ok(packet.to_bytes())
+    }
+
+    fn unwrap_message(
+        message: &[u8],
+        private_key: &Self::PrivateKey,
+    ) -> Result<(Vec<u8>, bool), Self::Error> {
+        let packet = Packet::from_bytes(message, MAX_LAYERS)?;
+        let unpacked_packet =
+            packet.unpack(&x25519_dalek::StaticSecret::from(*private_key), MAX_LAYERS)?;
+        match unpacked_packet {
+            UnpackedPacket::ToForward(packet) => Ok((packet.to_bytes(), false)),
+            UnpackedPacket::FullyUnpacked(payload) => Ok((payload, true)),
+        }
+    }
+}
+
+fn parse_bytes<'a>(data: &'a [u8], sizes: &[usize]) -> Result<Vec<&'a [u8]>, String> {
+    let mut i = 0;
+    sizes
+        .iter()
+        .map(|&size| {
+            if i + size > data.len() {
+                return Err("The sum of sizes exceeds the length of the input slice".to_string());
+            }
+            let slice = &data[i..i + size];
+            i += size;
+            Ok(slice)
+        })
+        .collect()
+}
diff --git a/nomos-mix/message/src/packet.rs b/nomos-mix/message/src/sphinx/packet.rs
similarity index 52%
rename from nomos-mix/message/src/packet.rs
rename to nomos-mix/message/src/sphinx/packet.rs
index 8e9267c3..583ac3ce 100644
--- a/nomos-mix/message/src/packet.rs
+++ b/nomos-mix/message/src/sphinx/packet.rs
@@ -1,11 +1,20 @@
-use crate::Error;
-use serde::{Deserialize, Serialize};
-use sphinx_packet::constants::NODE_ADDRESS_LENGTH;
+use sphinx_packet::{
+    constants::NODE_ADDRESS_LENGTH,
+    header::{
+        keys::RoutingKeys,
+        routing::{FINAL_HOP, FORWARD_HOP},
+    },
+    payload::Payload,
+};
+
+use crate::sphinx::ASYM_KEY_SIZE;
+
+use super::{error::Error, parse_bytes, routing::EncryptedRoutingInformation};
 
 /// A packet that contains a header and a payload.
 /// The header and payload are encrypted for the selected recipients.
 /// This packet can be serialized and sent over the network.
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug)]
 pub struct Packet {
     header: Header,
     // This crate doesn't limit the payload size.
@@ -14,36 +23,29 @@ pub struct Packet {
 }
 
 /// The packet header
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug)]
 struct Header {
     /// The ephemeral public key for a recipient to derive the shared secret
     /// which can be used to decrypt the header and payload.
     ephemeral_public_key: x25519_dalek::PublicKey,
-    // TODO: Length-preserved layered encryption on RoutingInfo
-    routing_info: RoutingInfo,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct RoutingInfo {
-    // TODO: Change this to `is_final_layer: bool`
-    // by implementing the length-preserved layered encryption.
-    // It's not good to expose the info that how many layers remain to the intermediate recipients.
-    remaining_layers: u8,
-    // TODO:: Add the following fields
-    // header_integrity_hamc
-    // additional data (e.g. incentivization)
+    encrypted_routing_info: EncryptedRoutingInformation,
 }
 
 impl Packet {
     pub fn build(
         recipient_pubkeys: &[x25519_dalek::PublicKey],
+        max_layers: usize,
         payload: &[u8],
-        payload_size: usize,
+        max_payload_size: usize,
     ) -> Result<Self, Error> {
         // Derive `[sphinx_packet::header::keys::KeyMaterial]` for all recipients.
         let ephemeral_privkey = x25519_dalek::StaticSecret::random();
         let key_material = Self::derive_key_material(recipient_pubkeys, &ephemeral_privkey);
+        // Build the encrypted routing information.
+        let encrypted_routing_info =
+            EncryptedRoutingInformation::new(&key_material.routing_keys, max_layers)?;
+
+        // Encrypt the payload for all recipients.
         let payload_keys = key_material
             .routing_keys
@@ -53,22 +55,19 @@ impl Packet {
         let payload = sphinx_packet::payload::Payload::encapsulate_message(
             payload,
             &payload_keys,
-            payload_size,
+            max_payload_size,
         )?;
 
         Ok(Packet {
             header: Header {
                 ephemeral_public_key: x25519_dalek::PublicKey::from(&ephemeral_privkey),
-                routing_info: RoutingInfo {
-                    remaining_layers: u8::try_from(recipient_pubkeys.len())
-                        .map_err(|_| Error::InvalidNumberOfLayers)?,
-                },
+                encrypted_routing_info,
             },
             payload: payload.into_bytes(),
         })
     }
 
-    fn derive_key_material(
+    pub(crate) fn derive_key_material(
         recipient_pubkeys: &[x25519_dalek::PublicKey],
         ephemeral_privkey: &x25519_dalek::StaticSecret,
     ) -> sphinx_packet::header::keys::KeyMaterial {
@@ -90,6 +89,7 @@ impl Packet {
     pub fn unpack(
         &self,
         private_key: &x25519_dalek::StaticSecret,
+        max_layers: usize,
     ) -> Result<UnpackedPacket, Error> {
         // Derive the routing keys for the recipient
         let routing_keys = sphinx_packet::header::SphinxHeader::compute_routing_keys(
@@ -101,25 +101,40 @@ impl Packet {
         let payload = sphinx_packet::payload::Payload::from_bytes(&self.payload)?;
         let payload = payload.unwrap(&routing_keys.payload_key)?;
 
-        // If this is the last layer of encryption, return the decrypted payload.
-        if self.header.routing_info.remaining_layers == 1 {
-            return Ok(UnpackedPacket::FullyUnpacked(payload.recover_plaintext()?));
+        // Unpack the routing information
+        let (routing_info, next_encrypted_routing_info) = self
+            .header
+            .encrypted_routing_info
+            .unpack(&routing_keys, max_layers)?;
+        match routing_info.flag {
+            FORWARD_HOP => Ok(UnpackedPacket::ToForward(self.build_next_packet(
+                &routing_keys,
+                next_encrypted_routing_info,
+                payload,
+            ))),
+            FINAL_HOP => Ok(UnpackedPacket::FullyUnpacked(payload.recover_plaintext()?)),
+            _ => Err(Error::InvalidRoutingFlag(routing_info.flag)),
         }
+    }
 
+    fn build_next_packet(
+        &self,
+        routing_keys: &RoutingKeys,
+        next_encrypted_routing_info: EncryptedRoutingInformation,
+        payload: Payload,
+    ) -> Packet {
         // Derive the new ephemeral public key for the next recipient
         let next_ephemeral_pubkey = Self::derive_next_ephemeral_public_key(
             &self.header.ephemeral_public_key,
             &routing_keys.blinding_factor,
         );
-        Ok(UnpackedPacket::ToForward(Packet {
+        Packet {
             header: Header {
                 ephemeral_public_key: next_ephemeral_pubkey,
-                routing_info: RoutingInfo {
-                    remaining_layers: self.header.routing_info.remaining_layers - 1,
-                },
+                encrypted_routing_info: next_encrypted_routing_info,
             },
             payload: payload.into_bytes(),
-        }))
+        }
     }
 
     /// Derive the next ephemeral public key for the next recipient.
@@ -135,6 +150,49 @@ impl Packet {
         let new_shared_secret = blinding_factor.diffie_hellman(cur_ephemeral_pubkey);
         x25519_dalek::PublicKey::from(new_shared_secret.to_bytes())
     }
+
+    pub fn to_bytes(&self) -> Vec<u8> {
+        let ephemeral_public_key = self.header.ephemeral_public_key.to_bytes();
+        let encrypted_routing_info = self.header.encrypted_routing_info.to_bytes();
+        itertools::chain!(
+            &ephemeral_public_key,
+            &encrypted_routing_info,
+            &self.payload,
+        )
+        .copied()
+        .collect()
+    }
+
+    pub fn from_bytes(data: &[u8], max_layers: usize) -> Result<Self, Error> {
+        let ephemeral_public_key_size = ASYM_KEY_SIZE;
+        let encrypted_routing_info_size = EncryptedRoutingInformation::size(max_layers);
+        let parsed = parse_bytes(
+            data,
+            &[
+                ephemeral_public_key_size,
+                encrypted_routing_info_size,
+                data.len() - ephemeral_public_key_size - encrypted_routing_info_size,
+            ],
+        )
+        .map_err(|_| Error::InvalidPacketBytes)?;
+
+        Ok(Packet {
+            header: Header {
+                ephemeral_public_key: {
+                    let bytes: [u8; 32] = parsed[0].try_into().unwrap();
+                    x25519_dalek::PublicKey::from(bytes)
+                },
+                encrypted_routing_info: EncryptedRoutingInformation::from_bytes(
+                    parsed[1], max_layers,
+                )?,
+            },
+            payload: parsed[2].to_vec(),
+        })
+    }
+
+    pub const fn size(max_layers: usize, max_payload_size: usize) -> usize {
+        ASYM_KEY_SIZE + EncryptedRoutingInformation::size(max_layers) + max_payload_size
+    }
 }
 
 pub enum UnpackedPacket {
@@ -144,14 +202,12 @@ pub enum UnpackedPacket {
 
 #[cfg(test)]
 mod tests {
-    use nomos_core::wire;
-
     use super::*;
 
     #[test]
     fn unpack() {
         // Prepare keys of two recipients
-        let recipient_privkeys = (0..2)
+        let recipient_privkeys = (0..3)
            .map(|_| x25519_dalek::StaticSecret::random())
             .collect::<Vec<_>>();
         let recipient_pubkeys = recipient_privkeys
@@ -160,18 +216,26 @@ pub enum UnpackedPacket {
             .collect::<Vec<_>>();
 
         // Build a packet
+        let max_layers = 5;
         let payload = [10u8; 512];
-        let packet = Packet::build(&recipient_pubkeys, &payload, 1024).unwrap();
+        let packet = Packet::build(&recipient_pubkeys, max_layers, &payload, 1024).unwrap();
 
         // The 1st recipient unpacks the packet
-        let packet = match packet.unpack(&recipient_privkeys[0]).unwrap() {
+        let packet = match packet.unpack(&recipient_privkeys[0], max_layers).unwrap() {
             UnpackedPacket::ToForward(packet) => packet,
             UnpackedPacket::FullyUnpacked(_) => {
                 panic!("The unpacked packet should be the ToFoward type");
             }
         };
         // The 2nd recipient unpacks the packet
-        match packet.unpack(&recipient_privkeys[1]).unwrap() {
+        let packet = match packet.unpack(&recipient_privkeys[1], max_layers).unwrap() {
+            UnpackedPacket::ToForward(packet) => packet,
+            UnpackedPacket::FullyUnpacked(_) => {
+                panic!("The unpacked packet should be the ToFoward type");
+            }
+        };
+        // The last recipient unpacks the packet
+        match packet.unpack(&recipient_privkeys[2], max_layers).unwrap() {
             UnpackedPacket::ToForward(_) => {
                 panic!("The unpacked packet should be the FullyUnpacked type");
             }
@@ -185,34 +249,25 @@
     #[test]
     fn unpack_with_wrong_keys() {
         // Build a packet with two public keys
+        let max_layers = 5;
         let payload = [10u8; 512];
         let packet = Packet::build(
             &(0..2)
                 .map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random()))
                 .collect::<Vec<_>>(),
+            max_layers,
             &payload,
             1024,
         )
         .unwrap();
 
-        // The 1st recipient unpacks the packet with an wrong key
-        let packet = match packet
-            .unpack(&x25519_dalek::StaticSecret::random())
-            .unwrap()
-        {
-            UnpackedPacket::ToForward(packet) => packet,
-            UnpackedPacket::FullyUnpacked(_) => {
-                panic!("The unpacked packet should be the ToFoward type");
-            }
-        };
-        // The 2nd recipient unpacks the packet with an wrong key
         assert!(packet
-            .unpack(&x25519_dalek::StaticSecret::random())
+            .unpack(&x25519_dalek::StaticSecret::random(), max_layers)
             .is_err());
     }
 
     #[test]
-    fn consistent_size_serialization() {
+    fn consistent_size_after_unpack() {
         // Prepare keys of two recipients
         let recipient_privkeys = (0..2)
             .map(|_| x25519_dalek::StaticSecret::random())
             .collect::<Vec<_>>();
         let recipient_pubkeys = recipient_privkeys
@@ -223,28 +278,76 @@ pub enum UnpackedPacket {
             .collect::<Vec<_>>();
 
         // Build a packet
+        let max_layers = 5;
         let payload = [10u8; 512];
-        let packet = Packet::build(&recipient_pubkeys, &payload, 1024).unwrap();
+        let max_payload_size = 1024;
+        let packet =
+            Packet::build(&recipient_pubkeys, max_layers, &payload, max_payload_size).unwrap();
 
         // Calculate the expected packet size
-        let pubkey_size = 32;
-        let routing_info_size = 1;
-        let payload_length_enconding_size = 8;
-        let payload_size = 1024;
-        let packet_size =
-            pubkey_size + routing_info_size + payload_length_enconding_size + payload_size;
+        let packet_size = Packet::size(max_layers, max_payload_size);
 
         // The serialized packet size must be the same as the expected size.
-        assert_eq!(wire::serialize(&packet).unwrap().len(), packet_size);
+        assert_eq!(packet.to_bytes().len(), packet_size);
 
         // The unpacked packet size must be the same as the original packet size.
-        match packet.unpack(&recipient_privkeys[0]).unwrap() {
+        match packet.unpack(&recipient_privkeys[0], max_layers).unwrap() {
             UnpackedPacket::ToForward(packet) => {
-                assert_eq!(wire::serialize(&packet).unwrap().len(), packet_size);
+                assert_eq!(packet.to_bytes().len(), packet_size);
             }
             UnpackedPacket::FullyUnpacked(_) => {
                 panic!("The unpacked packet should be the ToFoward type");
             }
         }
     }
+
+    #[test]
+    fn consistent_size_with_any_num_layers() {
+        let max_layers = 5;
+        let payload = [10u8; 512];
+
+        // Build a packet with 2 recipients
+        let recipient_pubkeys = (0..2)
+            .map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random()))
+            .collect::<Vec<_>>();
+        let packet1 = Packet::build(&recipient_pubkeys, max_layers, &payload, 1024).unwrap();
+
+        // Build a packet with 3 recipients
+        let recipient_pubkeys = (0..3)
+            .map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random()))
+            .collect::<Vec<_>>();
+        let packet2 = Packet::build(&recipient_pubkeys, max_layers, &payload, 1024).unwrap();
+
+        assert_eq!(packet1.to_bytes().len(), packet2.to_bytes().len());
+    }
+
+    #[test]
+    fn to_from_bytes() {
+        let max_layers = 5;
+        let payload = [10u8; 512];
+
+        // Build a packet with 2 recipients
+        let recipient_pubkeys = (0..2)
+            .map(|_| x25519_dalek::PublicKey::from(&x25519_dalek::StaticSecret::random()))
+            .collect::<Vec<_>>();
+        let packet = Packet::build(&recipient_pubkeys, max_layers, &payload, 1024).unwrap();
+
+        let bytes = packet.to_bytes();
+        let loaded_packet = Packet::from_bytes(&bytes, max_layers).unwrap();
+
+        // Manually compare packets because PartialEq is not implemented
+        // for [`sphinx_packet::header::mac::HeaderIntegrityMac`] used in our header.
+        assert_eq!(
+            packet.header.ephemeral_public_key,
+            loaded_packet.header.ephemeral_public_key
+        );
+        assert_eq!(
+            packet.header.encrypted_routing_info.encrypted_routing_info,
+            loaded_packet
+                .header
+                .encrypted_routing_info
+                .encrypted_routing_info
+        );
+        assert_eq!(packet.payload, loaded_packet.payload);
+    }
 }
diff --git a/nomos-mix/message/src/sphinx/routing.rs b/nomos-mix/message/src/sphinx/routing.rs
new file mode 100644
index 00000000..a0f22b75
--- /dev/null
+++ b/nomos-mix/message/src/sphinx/routing.rs
@@ -0,0 +1,140 @@
+use sphinx_packet::{
+    constants::HEADER_INTEGRITY_MAC_SIZE,
+    header::{
+        keys::RoutingKeys,
+        mac::HeaderIntegrityMac,
+        routing::{RoutingFlag, FINAL_HOP, FORWARD_HOP},
+    },
+};
+
+use super::{
+    error::Error,
+    layered_cipher::{
+        ConsistentLengthLayeredCipher, ConsistentLengthLayeredCipherData, EncryptionParam, Key,
+    },
+    parse_bytes,
+};
+
+/// Routing information that will be contained in a packet header
+/// in encrypted form.
+pub struct RoutingInformation {
+    pub flag: RoutingFlag,
+    // Add additional fields here
+}
+
+impl RoutingInformation {
+    pub fn new(flag: RoutingFlag) -> Self {
+        Self { flag }
+    }
+
+    pub fn from_bytes(data: &[u8]) -> Result<Self, Error> {
+        if data.len() != Self::SIZE {
+            return Err(Error::InvalidEncryptedRoutingInfoLength(data.len()));
+        }
+        Ok(Self { flag: data[0] })
+    }
+}
+
+impl ConsistentLengthLayeredCipherData for RoutingInformation {
+    fn to_bytes(&self) -> Vec<u8> {
+        vec![self.flag]
+    }
+
+    const SIZE: usize = std::mem::size_of::<RoutingFlag>();
+}
+
+/// Encrypted routing information that will be contained in a packet header.
+#[derive(Debug)]
+pub struct EncryptedRoutingInformation {
+    /// A MAC to verify the integrity of [`Self::encrypted_routing_info`].
+    pub mac: HeaderIntegrityMac,
+    /// The actual encrypted routing information produced by [`ConsistentLengthLayeredCipher`].
+    /// Its size should be the same as [`ConsistentLengthLayeredCipher::total_size`].
+    pub encrypted_routing_info: Vec<u8>,
+}
+
+type LayeredCipher = ConsistentLengthLayeredCipher<RoutingInformation>;
+
+impl EncryptedRoutingInformation {
+    /// Build all [`RoutingInformation`]s for the provided keys,
+    /// and encrypt them using [`ConsistentLengthLayeredCipher`].
+    pub fn new(routing_keys: &[RoutingKeys], max_layers: usize) -> Result<Self, Error> {
+        let cipher = LayeredCipher::new(max_layers);
+        let params = routing_keys
+            .iter()
+            .enumerate()
+            .map(|(i, k)| {
+                let flag = if i == routing_keys.len() - 1 {
+                    FINAL_HOP
+                } else {
+                    FORWARD_HOP
+                };
+                EncryptionParam::<RoutingInformation> {
+                    data: RoutingInformation::new(flag),
+                    key: Self::layered_cipher_key(k),
+                }
+            })
+            .collect::<Vec<_>>();
+        let (encrypted, mac) = cipher.encrypt(&params)?;
+
+        Ok(Self {
+            mac,
+            encrypted_routing_info: encrypted,
+        })
+    }
+
+    /// Unpack one layer of encryption using the key provided.
+    /// Returns the decrypted routing information
+    /// and the next [`EncryptedRoutingInformation`] to be unpacked further.
+    pub fn unpack(
+        &self,
+        routing_key: &RoutingKeys,
+        max_layers: usize,
+    ) -> Result<(RoutingInformation, Self), Error> {
+        let cipher = LayeredCipher::new(max_layers);
+        let (routing_info, next_mac, next_encrypted_routing_info) = cipher.unpack(
+            &self.mac,
+            &self.encrypted_routing_info,
+            &Self::layered_cipher_key(routing_key),
+        )?;
+        Ok((
+            RoutingInformation::from_bytes(&routing_info)?,
+            Self {
+                mac: next_mac,
+                encrypted_routing_info: next_encrypted_routing_info,
+            },
+        ))
+    }
+
+    fn layered_cipher_key(routing_key: &RoutingKeys) -> Key {
+        Key {
+            stream_cipher_key: routing_key.stream_cipher_key,
+            integrity_mac_key: routing_key.header_integrity_hmac_key,
+        }
+    }
+
+    pub fn to_bytes(&self) -> Vec<u8> {
+        itertools::chain!(self.mac.as_bytes(), &self.encrypted_routing_info)
+            .copied()
+            .collect()
+    }
+
+    pub fn from_bytes(data: &[u8], max_layers: usize) -> Result<Self, Error> {
+        let parsed = parse_bytes(
+            data,
+            &[
+                HEADER_INTEGRITY_MAC_SIZE,
+                LayeredCipher::total_size(max_layers),
+            ],
+        )
+        .map_err(|_| Error::InvalidEncryptedRoutingInfoLength(data.len()))?;
+        Ok(Self {
+            mac: HeaderIntegrityMac::from_bytes(parsed[0].try_into().unwrap()),
+            encrypted_routing_info: parsed[1].to_vec(),
+        })
+    }
+
+    pub const fn size(max_layers: usize) -> usize {
+        HEADER_INTEGRITY_MAC_SIZE + LayeredCipher::total_size(max_layers)
+    }
+}
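Reviewer note (not part of the patch): the sketch below exercises the new SphinxMessage end to end through the MixMessage trait, the same way CryptographicProcessor::wrap_message/unwrap_message drive it. The three-hop key setup and the direct x25519-dalek calls are illustrative assumptions, not code from this diff.

// A minimal round-trip sketch, assuming the crate layout introduced in this PR
// (nomos_mix_message::sphinx::SphinxMessage) and x25519-dalek in scope.
use nomos_mix_message::{sphinx::SphinxMessage, MixMessage};

fn main() {
    // Three hypothetical recipients, each with an x25519 key pair.
    let private_keys: Vec<x25519_dalek::StaticSecret> = (0..3)
        .map(|_| x25519_dalek::StaticSecret::random())
        .collect();
    let public_keys: Vec<[u8; 32]> = private_keys
        .iter()
        .map(|sk| x25519_dalek::PublicKey::from(sk).to_bytes())
        .collect();

    // Wrap the payload once for the whole route.
    let payload = b"hello mix".to_vec();
    let mut message = SphinxMessage::build_message(&payload, &public_keys).unwrap();

    // Each hop strips exactly one layer; only the last hop sees `fully_unwrapped == true`.
    for (i, sk) in private_keys.iter().enumerate() {
        let (next, fully_unwrapped) =
            SphinxMessage::unwrap_message(&message, &sk.to_bytes()).unwrap();
        assert_eq!(fully_unwrapped, i == private_keys.len() - 1);
        message = next;
    }
    assert_eq!(message, payload);
}

The boolean mirrors the trait doc: it only tells the caller whether the message was fully unwrapped, without revealing to intermediate hops how many layers remain.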
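A quick size check that may help when reviewing the constants: the arithmetic below restates Packet::size and ConsistentLengthLayeredCipher::total_size from this diff, under the assumption that sphinx-packet's HEADER_INTEGRITY_MAC_SIZE is 16 bytes and that RoutingFlag is one byte; it is a sketch, not part of the patch.

// Packet layout arithmetic mirroring this PR's formulas (assumed sizes noted above).
const HEADER_INTEGRITY_MAC_SIZE: usize = 16;
const ROUTING_INFO_SIZE: usize = 1; // RoutingInformation currently carries only the flag
const ASYM_KEY_SIZE: usize = 32;
const MAX_LAYERS: usize = 5;
const PADDED_PAYLOAD_SIZE: usize = 2048;

// One layer of the layered cipher = serialized data + MAC over the next layer.
const SINGLE_LAYER_SIZE: usize = ROUTING_INFO_SIZE + HEADER_INTEGRITY_MAC_SIZE; // 17
// The layered-cipher output is always max_layers * SINGLE_LAYER_SIZE, padded with fillers.
const LAYERED_CIPHER_SIZE: usize = SINGLE_LAYER_SIZE * MAX_LAYERS; // 85
// EncryptedRoutingInformation = its own MAC + the layered-cipher output.
const ENCRYPTED_ROUTING_INFO_SIZE: usize = HEADER_INTEGRITY_MAC_SIZE + LAYERED_CIPHER_SIZE; // 101

fn main() {
    // Packet::size(MAX_LAYERS, PADDED_PAYLOAD_SIZE)
    let packet_size = ASYM_KEY_SIZE + ENCRYPTED_ROUTING_INFO_SIZE + PADDED_PAYLOAD_SIZE;
    println!("every SphinxMessage (and DROP_MESSAGE) is {packet_size} bytes");
}

If those assumptions hold, DROP_MESSAGE and every wrapped message come out at the same fixed length, which is the property the consistent_size_* tests assert.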