From 10940321ff362754782bc4ee8406909538ac1f7f Mon Sep 17 00:00:00 2001 From: kaichao Date: Thu, 29 Jan 2026 09:19:52 +0800 Subject: [PATCH] Encode ratchet sate for serialization (#20) * feat: costom encode for double ratchet * chore: correct capacity * chore: refactor reference * chore: reader for parse bytes * chore: extract reader * chore: example with persist state. * chore: update example * chore: implement serde compatibility. * chore: as_bytes * chore: zerorize the secrec material * chore: use as_types to return reference for static key. * chore: extract example from basic demo --- Cargo.lock | 1 + double-ratchets/Cargo.toml | 1 + .../examples/serialization_demo.rs | 75 +++++ double-ratchets/src/errors.rs | 3 + double-ratchets/src/keypair.rs | 8 +- double-ratchets/src/lib.rs | 1 + double-ratchets/src/reader.rs | 135 ++++++++ double-ratchets/src/state.rs | 303 ++++++++++++++++++ 8 files changed, 523 insertions(+), 4 deletions(-) create mode 100644 double-ratchets/examples/serialization_demo.rs create mode 100644 double-ratchets/src/reader.rs diff --git a/Cargo.lock b/Cargo.lock index d4f5b92..e9850d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,6 +223,7 @@ dependencies = [ "rand_core", "rusqlite", "safer-ffi", + "serde", "thiserror", "x25519-dalek", "zeroize", diff --git a/double-ratchets/Cargo.toml b/double-ratchets/Cargo.toml index 048011d..de78550 100644 --- a/double-ratchets/Cargo.toml +++ b/double-ratchets/Cargo.toml @@ -20,6 +20,7 @@ thiserror = "2" blake2 = "0.10.6" safer-ffi = "0.1.13" zeroize = "1.8.2" +serde = "1.0" rusqlite = { version = "0.35", optional = true, features = ["bundled"] } [features] diff --git a/double-ratchets/examples/serialization_demo.rs b/double-ratchets/examples/serialization_demo.rs new file mode 100644 index 0000000..76a5878 --- /dev/null +++ b/double-ratchets/examples/serialization_demo.rs @@ -0,0 +1,75 @@ +use double_ratchets::{InstallationKeyPair, RatchetState, hkdf::PrivateV1Domain}; + +fn main() { + // === Initial shared secret (X3DH / prekey result in real systems) === + let shared_secret = [42u8; 32]; + + let bob_dh = InstallationKeyPair::generate(); + + let mut alice: RatchetState = + RatchetState::init_sender(shared_secret, bob_dh.public().clone()); + let mut bob: RatchetState = RatchetState::init_receiver(shared_secret, bob_dh); + + let (ciphertext, header) = alice.encrypt_message(b"Hello Bob!"); + + // === Bob receives === + let plaintext = bob.decrypt_message(&ciphertext, header); + println!( + "Bob received: {}", + String::from_utf8_lossy(&plaintext.unwrap()) + ); + + // === Bob replies (triggers DH ratchet) === + let (ciphertext, header) = bob.encrypt_message(b"Hi Alice!"); + + let plaintext = alice.decrypt_message(&ciphertext, header); + println!( + "Alice received: {}", + String::from_utf8_lossy(&plaintext.unwrap()) + ); + + // === Serialize the state of alice and bob === + println!("Before restart, persist the state"); + let alice_state = alice.as_bytes(); + let bob_state = bob.as_bytes(); + + // === Deserialize alice and bob state from bytes === + println!("Restart alice and bob"); + let mut alice_new: RatchetState = + RatchetState::from_bytes(&alice_state).unwrap(); + let mut bob_new: RatchetState = RatchetState::from_bytes(&bob_state).unwrap(); + + // === Alice sends a message === + let (ciphertext, header) = alice_new.encrypt_message(b"Hello Bob!"); + + // === Bob receives === + let plaintext = bob_new.decrypt_message(&ciphertext, header); + println!( + "New Bob received: {}", + String::from_utf8_lossy(&plaintext.unwrap()) + ); + + // === Bob replies (triggers DH ratchet) === + let (ciphertext, header) = bob_new.encrypt_message(b"Hi Alice!"); + + let plaintext = alice_new.decrypt_message(&ciphertext, header); + println!( + "New Alice received: {}", + String::from_utf8_lossy(&plaintext.unwrap()) + ); + + let (skipped_ciphertext, skipped_header) = bob_new.encrypt_message(b"Hi Alice skipped!"); + let (resumed_ciphertext, resumed_header) = bob_new.encrypt_message(b"Hi Alice resumed!"); + + let plaintext = alice_new.decrypt_message(&resumed_ciphertext, resumed_header); + println!( + "New Alice received: {}", + String::from_utf8_lossy(&plaintext.unwrap()) + ); + + let plaintext = alice_new.decrypt_message(&skipped_ciphertext, skipped_header); + println!( + "New Alice received: {}", + String::from_utf8_lossy(&plaintext.unwrap()) + ); +} diff --git a/double-ratchets/src/errors.rs b/double-ratchets/src/errors.rs index c0e15c7..1787a57 100644 --- a/double-ratchets/src/errors.rs +++ b/double-ratchets/src/errors.rs @@ -23,4 +23,7 @@ pub enum RatchetError { #[error("missing receiving chain")] MissingReceivingChain, + + #[error("deserialization failed")] + DeserializationFailed, } diff --git a/double-ratchets/src/keypair.rs b/double-ratchets/src/keypair.rs index c32adb9..7943646 100644 --- a/double-ratchets/src/keypair.rs +++ b/double-ratchets/src/keypair.rs @@ -25,12 +25,12 @@ impl InstallationKeyPair { &self.public } - /// Export the secret key as raw bytes for storage. - pub fn secret_bytes(&self) -> [u8; 32] { - self.secret.to_bytes() + /// Export the secret key as raw bytes for serialization/storage. + pub fn secret_bytes(&self) -> &[u8; 32] { + self.secret.as_bytes() } - /// Reconstruct from secret key bytes. + /// Import the secret key from raw bytes. pub fn from_secret_bytes(bytes: [u8; 32]) -> Self { let secret = StaticSecret::from(bytes); let public = PublicKey::from(&secret); diff --git a/double-ratchets/src/lib.rs b/double-ratchets/src/lib.rs index 1b9a566..f2cd789 100644 --- a/double-ratchets/src/lib.rs +++ b/double-ratchets/src/lib.rs @@ -3,6 +3,7 @@ pub mod errors; pub mod ffi; pub mod hkdf; pub mod keypair; +pub mod reader; pub mod state; #[cfg(feature = "storage")] pub mod storage; diff --git a/double-ratchets/src/reader.rs b/double-ratchets/src/reader.rs new file mode 100644 index 0000000..bb4b89f --- /dev/null +++ b/double-ratchets/src/reader.rs @@ -0,0 +1,135 @@ +use crate::errors::RatchetError; + +pub struct Reader<'a> { + data: &'a [u8], + pos: usize, +} + +impl<'a> Reader<'a> { + pub fn new(data: &'a [u8]) -> Self { + Self { data, pos: 0 } + } + + pub fn read_bytes(&mut self, n: usize) -> Result<&[u8], RatchetError> { + if self.pos + n > self.data.len() { + return Err(RatchetError::DeserializationFailed); + } + let slice = &self.data[self.pos..self.pos + n]; + self.pos += n; + Ok(slice) + } + + pub fn read_array(&mut self) -> Result<[u8; N], RatchetError> { + self.read_bytes(N)? + .try_into() + .map_err(|_| RatchetError::DeserializationFailed) + } + + pub fn read_u8(&mut self) -> Result { + Ok(self.read_bytes(1)?[0]) + } + + pub fn read_u32(&mut self) -> Result { + Ok(u32::from_be_bytes(self.read_array()?)) + } + + pub fn read_option(&mut self) -> Result, RatchetError> { + match self.read_u8()? { + 0x00 => Ok(None), + 0x01 => Ok(Some(self.read_array()?)), + _ => Err(RatchetError::DeserializationFailed), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_bytes() { + let data = [1, 2, 3, 4, 5]; + let mut reader = Reader::new(&data); + + assert_eq!(reader.read_bytes(2).unwrap(), &[1, 2]); + assert_eq!(reader.read_bytes(3).unwrap(), &[3, 4, 5]); + } + + #[test] + fn test_read_bytes_overflow() { + let data = [1, 2, 3]; + let mut reader = Reader::new(&data); + + assert!(matches!( + reader.read_bytes(4), + Err(RatchetError::DeserializationFailed) + )); + } + + #[test] + fn test_read_array() { + let data = [1, 2, 3, 4]; + let mut reader = Reader::new(&data); + + let arr: [u8; 4] = reader.read_array().unwrap(); + assert_eq!(arr, [1, 2, 3, 4]); + } + + #[test] + fn test_read_u8() { + let data = [0x42, 0xFF]; + let mut reader = Reader::new(&data); + + assert_eq!(reader.read_u8().unwrap(), 0x42); + assert_eq!(reader.read_u8().unwrap(), 0xFF); + } + + #[test] + fn test_read_u32() { + let data = [0x00, 0x01, 0x02, 0x03]; + let mut reader = Reader::new(&data); + + assert_eq!(reader.read_u32().unwrap(), 0x00010203); + } + + #[test] + fn test_read_option_none() { + let data = [0x00]; + let mut reader = Reader::new(&data); + + assert_eq!(reader.read_option().unwrap(), None); + } + + #[test] + fn test_read_option_some() { + let mut data = vec![0x01]; + data.extend_from_slice(&[0x42; 32]); + let mut reader = Reader::new(&data); + + assert_eq!(reader.read_option().unwrap(), Some([0x42; 32])); + } + + #[test] + fn test_read_option_invalid_flag() { + let data = [0x02]; + let mut reader = Reader::new(&data); + + assert!(matches!( + reader.read_option(), + Err(RatchetError::DeserializationFailed) + )); + } + + #[test] + fn test_sequential_reads() { + let mut data = vec![0x01]; // version + data.extend_from_slice(&[0xAA; 32]); // 32-byte array + data.extend_from_slice(&[0x00, 0x00, 0x00, 0x10]); // u32 = 16 + + let mut reader = Reader::new(&data); + + assert_eq!(reader.read_u8().unwrap(), 0x01); + assert_eq!(reader.read_array::<32>().unwrap(), [0xAA; 32]); + assert_eq!(reader.read_u32().unwrap(), 16); + } +} diff --git a/double-ratchets/src/state.rs b/double-ratchets/src/state.rs index dc92ee1..48ad359 100644 --- a/double-ratchets/src/state.rs +++ b/double-ratchets/src/state.rs @@ -1,15 +1,21 @@ use std::{collections::HashMap, marker::PhantomData}; +use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as DeError}; use x25519_dalek::PublicKey; +use zeroize::{Zeroize, Zeroizing}; use crate::{ aead::{decrypt, encrypt}, errors::RatchetError, hkdf::{DefaultDomain, HkdfInfo, kdf_chain, kdf_root}, keypair::InstallationKeyPair, + reader::Reader, types::{ChainKey, MessageKey, Nonce, RootKey, SharedSecret}, }; +/// Current binary format version. +const SERIALIZATION_VERSION: u8 = 1; + /// Represents the local state of the Double Ratchet algorithm for one conversation. /// /// This struct maintains all keys and counters required to perform the Double Ratchet @@ -42,6 +48,153 @@ pub struct SkippedKey { pub message_key: MessageKey, } +impl RatchetState { + /// Serializes the ratchet state to a binary format. + /// + /// # Binary Format (Version 1) + /// + /// ```text + /// | Field | Size (bytes) | Description | + /// |--------------------|--------------|--------------------------------------| + /// | version | 1 | Format version (0x01) | + /// | root_key | 32 | Root key | + /// | sending_chain_flag | 1 | 0x00 = None, 0x01 = Some | + /// | sending_chain | 0 or 32 | Chain key if flag is 0x01 | + /// | receiving_chain_flag| 1 | 0x00 = None, 0x01 = Some | + /// | receiving_chain | 0 or 32 | Chain key if flag is 0x01 | + /// | dh_self_secret | 32 | DH secret key | + /// | dh_remote_flag | 1 | 0x00 = None, 0x01 = Some | + /// | dh_remote | 0 or 32 | DH public key if flag is 0x01 | + /// | msg_send | 4 | Send counter (big-endian) | + /// | msg_recv | 4 | Receive counter (big-endian) | + /// | prev_chain_len | 4 | Previous chain length (big-endian) | + /// | skipped_count | 4 | Number of skipped keys (big-endian) | + /// | skipped_keys | 68 * count | Each: pubkey(32) + msg_num(4) + key(32) | + /// ``` + pub fn as_bytes(&self) -> Zeroizing> { + fn option_size(opt: Option<[u8; 32]>) -> usize { + 1 + opt.map_or(0, |_| 32) + } + + fn write_option(buf: &mut Vec, opt: Option<[u8; 32]>) { + match opt { + Some(data) => { + buf.push(0x01); + buf.extend_from_slice(&data); + } + None => buf.push(0x00), + } + } + + let skipped_count = self.skipped_keys.len(); + let dh_remote = self.dh_remote.map(|pk| pk.to_bytes()); + + let capacity = 1 + 32 // version + root_key + + option_size(self.sending_chain) + + option_size(self.receiving_chain) + + 32 // dh_self + + option_size(dh_remote) + + 12 // counters + + 4 + (skipped_count * 68); // skipped keys + + let mut buf = Zeroizing::new(Vec::with_capacity(capacity)); + + buf.push(SERIALIZATION_VERSION); + buf.extend_from_slice(&self.root_key); + write_option(&mut buf, self.sending_chain); + write_option(&mut buf, self.receiving_chain); + + let dh_secret = self.dh_self.secret_bytes(); + buf.extend_from_slice(dh_secret); + + write_option(&mut buf, dh_remote); + + buf.extend_from_slice(&self.msg_send.to_be_bytes()); + buf.extend_from_slice(&self.msg_recv.to_be_bytes()); + buf.extend_from_slice(&self.prev_chain_len.to_be_bytes()); + + buf.extend_from_slice(&(skipped_count as u32).to_be_bytes()); + for ((pk, msg_num), mk) in &self.skipped_keys { + buf.extend_from_slice(pk.as_bytes()); + buf.extend_from_slice(&msg_num.to_be_bytes()); + buf.extend_from_slice(mk); + } + + buf + } + + /// Deserializes a ratchet state from binary data. + /// + /// # Errors + /// + /// Returns `RatchetError::DeserializationFailed` if the data is invalid or truncated. + pub fn from_bytes(data: &[u8]) -> Result { + let mut reader = Reader::new(data); + + let version = reader.read_u8()?; + if version != SERIALIZATION_VERSION { + return Err(RatchetError::DeserializationFailed); + } + + let root_key: RootKey = reader.read_array()?; + let sending_chain = reader.read_option()?; + let receiving_chain = reader.read_option()?; + + let mut dh_self_bytes: [u8; 32] = reader.read_array()?; + let dh_self = InstallationKeyPair::from_secret_bytes(dh_self_bytes); + dh_self_bytes.zeroize(); + + let dh_remote = reader.read_option()?.map(PublicKey::from); + + let msg_send = reader.read_u32()?; + let msg_recv = reader.read_u32()?; + let prev_chain_len = reader.read_u32()?; + + let skipped_count = reader.read_u32()? as usize; + let mut skipped_keys = HashMap::with_capacity(skipped_count); + for _ in 0..skipped_count { + let pk = PublicKey::from(reader.read_array::<32>()?); + let msg_num = reader.read_u32()?; + let mk: MessageKey = reader.read_array()?; + skipped_keys.insert((pk, msg_num), mk); + } + + Ok(Self { + root_key, + sending_chain, + receiving_chain, + dh_self, + dh_remote, + msg_send, + msg_recv, + prev_chain_len, + skipped_keys, + _domain: PhantomData, + }) + } +} + +/// Custom serde Serialize implementation that uses our binary format. +impl Serialize for RatchetState { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_bytes(&self.as_bytes()) + } +} + +/// Custom serde Deserialize implementation that uses our binary format. +impl<'de, D: HkdfInfo> Deserialize<'de> for RatchetState { + fn deserialize(deserializer: De) -> Result + where + De: Deserializer<'de>, + { + let bytes = >::deserialize(deserializer)?; + Self::from_bytes(&bytes).map_err(DeError::custom) + } +} + /// Public header attached to every encrypted message (unencrypted but authenticated). #[derive(Clone, Debug)] pub struct Header { @@ -513,6 +666,156 @@ mod tests { assert_eq!(result.unwrap_err(), RatchetError::MessageReplay); } + #[test] + fn test_serialize_deserialize_sender_state() { + let (alice, _, _) = setup_alice_bob(); + + // Serialize to binary + let bytes = alice.as_bytes(); + + // Deserialize back + let restored: RatchetState = RatchetState::from_bytes(&bytes).unwrap(); + + // Verify key fields match + assert_eq!(alice.root_key, restored.root_key); + assert_eq!(alice.sending_chain, restored.sending_chain); + assert_eq!(alice.receiving_chain, restored.receiving_chain); + assert_eq!(alice.msg_send, restored.msg_send); + assert_eq!(alice.msg_recv, restored.msg_recv); + assert_eq!(alice.prev_chain_len, restored.prev_chain_len); + assert_eq!( + alice.dh_remote.map(|pk| pk.to_bytes()), + restored.dh_remote.map(|pk| pk.to_bytes()) + ); + assert_eq!( + alice.dh_self.public().to_bytes(), + restored.dh_self.public().to_bytes() + ); + } + + #[test] + fn test_serialize_deserialize_receiver_state() { + let (_, bob, _) = setup_alice_bob(); + + // Serialize to binary + let bytes = bob.as_bytes(); + + // Deserialize back + let restored: RatchetState = RatchetState::from_bytes(&bytes).unwrap(); + + // Verify key fields match + assert_eq!(bob.root_key, restored.root_key); + assert_eq!(bob.sending_chain, restored.sending_chain); + assert_eq!(bob.receiving_chain, restored.receiving_chain); + assert_eq!(bob.msg_send, restored.msg_send); + assert_eq!(bob.msg_recv, restored.msg_recv); + assert_eq!(bob.prev_chain_len, restored.prev_chain_len); + assert!(bob.dh_remote.is_none()); + assert!(restored.dh_remote.is_none()); + } + + #[test] + fn test_serialize_deserialize_with_skipped_keys() { + let (mut alice, mut bob, _) = setup_alice_bob(); + + // Alice sends 3 messages + let mut sent = vec![]; + for i in 0..3 { + let plaintext = format!("Message {}", i + 1).into_bytes(); + let (ct, header) = alice.encrypt_message(&plaintext); + sent.push((ct, header, plaintext)); + } + + // Bob receives only msg0 and msg2, skipping msg1 + bob.decrypt_message(&sent[0].0, sent[0].1.clone()).unwrap(); + bob.decrypt_message(&sent[2].0, sent[2].1.clone()).unwrap(); + + // Bob should have one skipped key + assert_eq!(bob.skipped_keys.len(), 1); + + // Serialize Bob's state + let bytes = bob.as_bytes(); + + // Deserialize + let mut restored: RatchetState = RatchetState::from_bytes(&bytes).unwrap(); + + // Restored state should have the skipped key + assert_eq!(restored.skipped_keys.len(), 1); + + // The restored state should be able to decrypt the skipped message + let pt1 = restored + .decrypt_message(&sent[1].0, sent[1].1.clone()) + .unwrap(); + assert_eq!(pt1, sent[1].2); + } + + #[test] + fn test_serialize_deserialize_continue_conversation() { + let (mut alice, mut bob, _) = setup_alice_bob(); + + // Exchange some messages + let (ct1, h1) = alice.encrypt_message(b"Hello Bob"); + bob.decrypt_message(&ct1, h1).unwrap(); + + let (ct2, h2) = bob.encrypt_message(b"Hello Alice"); + alice.decrypt_message(&ct2, h2).unwrap(); + + // Serialize both states + let alice_bytes = alice.as_bytes(); + let bob_bytes = bob.as_bytes(); + + // Deserialize + let mut alice_restored: RatchetState = RatchetState::from_bytes(&alice_bytes).unwrap(); + let mut bob_restored: RatchetState = RatchetState::from_bytes(&bob_bytes).unwrap(); + + // Continue the conversation with restored states + let (ct3, h3) = alice_restored.encrypt_message(b"Message after restore"); + let pt3 = bob_restored.decrypt_message(&ct3, h3).unwrap(); + assert_eq!(pt3, b"Message after restore"); + + let (ct4, h4) = bob_restored.encrypt_message(b"Reply after restore"); + let pt4 = alice_restored.decrypt_message(&ct4, h4).unwrap(); + assert_eq!(pt4, b"Reply after restore"); + } + + #[test] + fn test_serialization_version_check() { + let (alice, _, _) = setup_alice_bob(); + let mut bytes = alice.as_bytes(); + + // Tamper with version byte + bytes[0] = 0xFF; + + let result = RatchetState::::from_bytes(&bytes); + assert!(matches!(result, Err(RatchetError::DeserializationFailed))); + } + + #[test] + fn test_serialization_truncated_data() { + let (alice, _, _) = setup_alice_bob(); + let bytes = alice.as_bytes(); + + // Truncate the data + let truncated = &bytes[..10]; + + let result = RatchetState::::from_bytes(truncated); + assert!(matches!(result, Err(RatchetError::DeserializationFailed))); + } + + #[test] + fn test_serialization_size_efficiency() { + let (alice, _, _) = setup_alice_bob(); + let bytes = alice.as_bytes(); + + // Minimum size: version(1) + root_key(32) + sending_flag(1) + sending(32) + + // receiving_flag(1) + dh_self(32) + dh_remote_flag(1) + dh_remote(32) + + // counters(12) + skipped_count(4) = 148 bytes for sender with no skipped keys + assert!(bytes.len() < 200, "Serialized size should be compact"); + + // Verify version byte + assert_eq!(bytes[0], 1, "Version should be 1"); + } + #[test] fn test_skipped_keys_export() { let (mut alice, mut bob, _) = setup_alice_bob();