diff --git a/double-ratchets-storage/src/lib.rs b/double-ratchets-storage/src/lib.rs index 9536cf6..99ee462 100644 --- a/double-ratchets-storage/src/lib.rs +++ b/double-ratchets-storage/src/lib.rs @@ -1,10 +1,13 @@ //! Persistent storage for Double Ratchet state. //! //! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState) -//! across application restarts. It includes: +//! with automatic field-level persistence on each ratchet operation. //! -//! - [`MemoryStorage`] - In-memory storage for testing -//! - [`SqliteStorage`] - SQLite storage with field-level encryption (requires `sqlite` feature) +//! # Main API +//! +//! - [`PersistentRatchet`] - Wrapper that auto-persists state changes during encrypt/decrypt +//! - [`SqliteRatchetStore`] - SQLite backend with field-level encryption +//! - [`RatchetStore`] - Trait for implementing custom storage backends //! //! # Features //! @@ -13,166 +16,233 @@ //! //! # Security //! -//! Private keys (`dh_self_secret`) are always encrypted with ChaCha20Poly1305 before storage, -//! even when using plain SQLite. For additional security, enable the `sqlcipher` feature -//! for full database encryption. +//! Private keys (`dh_self_secret`) and message keys are always encrypted with ChaCha20Poly1305 +//! before storage, even when using plain SQLite. For additional security, enable the `sqlcipher` +//! feature for full database encryption. //! //! # Example //! //! ```no_run +//! use std::sync::Arc; //! use double_ratchets::hkdf::DefaultDomain; -//! use double_ratchets::state::RatchetState; //! use double_ratchets::InstallationKeyPair; -//! use double_ratchets_storage::{ -//! RatchetStorage, SqliteStorage, StorableRatchetState, -//! }; -//! -//! // Create a ratchet state -//! let bob_keypair = InstallationKeyPair::generate(); -//! let shared_secret = [0x42u8; 32]; -//! let state: RatchetState = -//! RatchetState::init_sender(shared_secret, *bob_keypair.public()); +//! use double_ratchets_storage::{PersistentRatchet, RatchetStore, SqliteRatchetStore}; //! //! // Open storage //! let encryption_key = [0u8; 32]; // Use proper key derivation! -//! let storage = SqliteStorage::open("ratchets.db", encryption_key).unwrap(); +//! let store: Arc = +//! Arc::new(SqliteRatchetStore::open("ratchets.db", encryption_key).unwrap()); //! -//! // Save state +//! // Initialize sender +//! let bob_keypair = InstallationKeyPair::generate(); +//! let shared_secret = [0x42u8; 32]; //! let session_id = [1u8; 32]; -//! let storable = StorableRatchetState::from_ratchet_state(&state, "default"); -//! storage.save(&session_id, &storable).unwrap(); //! -//! // Load state -//! let loaded = storage.load(&session_id).unwrap().unwrap(); -//! let restored: RatchetState = loaded.to_ratchet_state().unwrap(); +//! let mut alice: PersistentRatchet = PersistentRatchet::init_sender( +//! Arc::clone(&store), +//! session_id, +//! shared_secret, +//! *bob_keypair.public(), +//! ).unwrap(); +//! +//! // Encrypt - state is automatically persisted +//! let (ciphertext, header) = alice.encrypt_message(b"Hello!").unwrap(); +//! +//! // Later: load from storage +//! let alice_restored: PersistentRatchet = +//! PersistentRatchet::load(store, session_id).unwrap().unwrap(); //! ``` pub mod error; -pub mod memory; +#[cfg(any(feature = "sqlite", feature = "sqlcipher"))] +pub mod persistent; #[cfg(any(feature = "sqlite", feature = "sqlcipher"))] pub mod sqlite; pub mod traits; -pub mod types; // Re-exports for convenience pub use error::StorageError; -pub use memory::MemoryStorage; #[cfg(any(feature = "sqlite", feature = "sqlcipher"))] -pub use sqlite::{EncryptionKey, SqliteStorage}; -pub use traits::{RatchetStorage, SessionId}; -pub use types::{SkippedKey, StorableRatchetState}; +pub use persistent::{PersistentRatchet, PersistentRatchetError}; +#[cfg(any(feature = "sqlite", feature = "sqlcipher"))] +pub use sqlite::{EncryptionKey, SqliteRatchetStore}; +pub use traits::{RatchetStore, SessionId, SkippedKeyEntry, StoredState}; -#[cfg(test)] +#[cfg(all(test, any(feature = "sqlite", feature = "sqlcipher")))] mod integration_tests { use super::*; use double_ratchets::hkdf::DefaultDomain; - use double_ratchets::state::RatchetState; use double_ratchets::InstallationKeyPair; + use std::sync::Arc; - /// Integration test: full encryption/decryption cycle with storage + /// Integration test: full conversation with auto-persist #[test] - fn test_full_conversation_with_storage_roundtrip() { - // Setup Alice and Bob - let bob_keypair = InstallationKeyPair::generate(); - let shared_secret = [0x42u8; 32]; - - let mut alice: RatchetState = - RatchetState::init_sender(shared_secret, *bob_keypair.public()); - let mut bob: RatchetState = - RatchetState::init_receiver(shared_secret, bob_keypair); - - let storage = MemoryStorage::new(); + fn test_full_conversation_with_auto_persist() { + let store: Arc = + Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap()); let alice_session = [0xAA; 32]; let bob_session = [0xBB; 32]; - // Alice sends a message - let (ct1, header1) = alice.encrypt_message(b"Hello Bob!"); + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; - // Save Alice's state - let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default"); - storage.save(&alice_session, &alice_storable).unwrap(); + // Initialize both parties - state is auto-persisted + let mut alice: PersistentRatchet = PersistentRatchet::init_sender( + Arc::clone(&store), + alice_session, + shared_secret, + *bob_keypair.public(), + ) + .unwrap(); - // Bob receives the message + let mut bob: PersistentRatchet = PersistentRatchet::init_receiver( + Arc::clone(&store), + bob_session, + shared_secret, + bob_keypair, + ) + .unwrap(); + + // Alice sends - state auto-persisted + let (ct1, header1) = alice.encrypt_message(b"Hello Bob!").unwrap(); let pt1 = bob.decrypt_message(&ct1, header1).unwrap(); assert_eq!(pt1, b"Hello Bob!"); - // Save Bob's state - let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default"); - storage.save(&bob_session, &bob_storable).unwrap(); + // Verify state was persisted + assert!(store.session_exists(&alice_session).unwrap()); + assert!(store.session_exists(&bob_session).unwrap()); - // Simulate restart: load states from storage - let alice_loaded = storage.load(&alice_session).unwrap().unwrap(); - let bob_loaded = storage.load(&bob_session).unwrap().unwrap(); - - let mut alice_restored: RatchetState = - alice_loaded.to_ratchet_state().unwrap(); - let mut bob_restored: RatchetState = - bob_loaded.to_ratchet_state().unwrap(); - - // Bob replies - let (ct2, header2) = bob_restored.encrypt_message(b"Hi Alice!"); - let pt2 = alice_restored.decrypt_message(&ct2, header2).unwrap(); + // Bob replies - state auto-persisted + let (ct2, header2) = bob.encrypt_message(b"Hi Alice!").unwrap(); + let pt2 = alice.decrypt_message(&ct2, header2).unwrap(); assert_eq!(pt2, b"Hi Alice!"); - // Alice sends another message - let (ct3, header3) = alice_restored.encrypt_message(b"How are you?"); - let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap(); + // Alice sends another - state auto-persisted + let (ct3, header3) = alice.encrypt_message(b"How are you?").unwrap(); + let pt3 = bob.decrypt_message(&ct3, header3).unwrap(); assert_eq!(pt3, b"How are you?"); } - /// Integration test: verify SQLite storage with encryption works - #[cfg(any(feature = "sqlite", feature = "sqlcipher"))] + /// Integration test: verify SQLite storage with file persistence #[test] - fn test_sqlite_integration() { + fn test_sqlite_file_persistence() { let dir = tempfile::tempdir().unwrap(); let db_path = dir.path().join("integration_test.db"); let key = [0x42u8; 32]; - // Setup let bob_keypair = InstallationKeyPair::generate(); + let bob_pub = *bob_keypair.public(); let shared_secret = [0x42u8; 32]; - let mut alice: RatchetState = - RatchetState::init_sender(shared_secret, *bob_keypair.public()); - let mut bob: RatchetState = - RatchetState::init_receiver(shared_secret, bob_keypair); - let alice_session = [0xAA; 32]; let bob_session = [0xBB; 32]; - // Exchange messages - let (ct1, header1) = alice.encrypt_message(b"Message 1"); - bob.decrypt_message(&ct1, header1).unwrap(); - - let (ct2, header2) = bob.encrypt_message(b"Response 1"); - alice.decrypt_message(&ct2, header2).unwrap(); - - // Save both states + // First session: exchange messages { - let storage = SqliteStorage::open(&db_path, key).unwrap(); + let store: Arc = + Arc::new(SqliteRatchetStore::open(&db_path, key).unwrap()); - let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default"); - let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default"); + let mut alice: PersistentRatchet = PersistentRatchet::init_sender( + Arc::clone(&store), + alice_session, + shared_secret, + bob_pub, + ) + .unwrap(); - storage.save(&alice_session, &alice_storable).unwrap(); - storage.save(&bob_session, &bob_storable).unwrap(); + let mut bob: PersistentRatchet = PersistentRatchet::init_receiver( + Arc::clone(&store), + bob_session, + shared_secret, + bob_keypair, + ) + .unwrap(); + + let (ct1, h1) = alice.encrypt_message(b"Message 1").unwrap(); + bob.decrypt_message(&ct1, h1).unwrap(); + + let (ct2, h2) = bob.encrypt_message(b"Response 1").unwrap(); + alice.decrypt_message(&ct2, h2).unwrap(); } // Reopen database (simulating restart) { - let storage = SqliteStorage::open(&db_path, key).unwrap(); + let store: Arc = + Arc::new(SqliteRatchetStore::open(&db_path, key).unwrap()); - let alice_loaded = storage.load(&alice_session).unwrap().unwrap(); - let bob_loaded = storage.load(&bob_session).unwrap().unwrap(); + let mut alice: PersistentRatchet = + PersistentRatchet::load(Arc::clone(&store), alice_session) + .unwrap() + .unwrap(); - let mut alice_restored: RatchetState = - alice_loaded.to_ratchet_state().unwrap(); - let mut bob_restored: RatchetState = - bob_loaded.to_ratchet_state().unwrap(); + let mut bob: PersistentRatchet = + PersistentRatchet::load(Arc::clone(&store), bob_session) + .unwrap() + .unwrap(); // Continue conversation - let (ct3, header3) = alice_restored.encrypt_message(b"Message 2"); - let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap(); + let (ct3, h3) = alice.encrypt_message(b"Message 2").unwrap(); + let pt3 = bob.decrypt_message(&ct3, h3).unwrap(); assert_eq!(pt3, b"Message 2"); } } + + /// Integration test: out-of-order messages with skipped keys + #[test] + fn test_out_of_order_messages_persisted() { + let store: Arc = + Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap()); + let alice_session = [0xAA; 32]; + let bob_session = [0xBB; 32]; + + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + + let mut alice: PersistentRatchet = PersistentRatchet::init_sender( + Arc::clone(&store), + alice_session, + shared_secret, + *bob_keypair.public(), + ) + .unwrap(); + + let mut bob: PersistentRatchet = PersistentRatchet::init_receiver( + Arc::clone(&store), + bob_session, + shared_secret, + bob_keypair, + ) + .unwrap(); + + // Alice sends 4 messages + let mut messages = vec![]; + for i in 0..4 { + let (ct, h) = alice + .encrypt_message(format!("Message {}", i).as_bytes()) + .unwrap(); + messages.push((ct, h)); + } + + // Bob receives 0, 2, 3 (skipping 1) + bob.decrypt_message(&messages[0].0, messages[0].1.clone()) + .unwrap(); + bob.decrypt_message(&messages[2].0, messages[2].1.clone()) + .unwrap(); + bob.decrypt_message(&messages[3].0, messages[3].1.clone()) + .unwrap(); + + // Verify skipped key is persisted + let bob_state = store.load_state(&bob_session).unwrap().unwrap(); + assert_eq!(bob_state.skipped_keys.len(), 1); + assert_eq!(bob_state.skipped_keys[0].msg_num, 1); + + // Now receive the skipped message + let pt1 = bob + .decrypt_message(&messages[1].0, messages[1].1.clone()) + .unwrap(); + assert_eq!(pt1, b"Message 1"); + + // Skipped key should be removed from storage + let bob_state = store.load_state(&bob_session).unwrap().unwrap(); + assert!(bob_state.skipped_keys.is_empty()); + } } diff --git a/double-ratchets-storage/src/memory.rs b/double-ratchets-storage/src/memory.rs deleted file mode 100644 index cc2791e..0000000 --- a/double-ratchets-storage/src/memory.rs +++ /dev/null @@ -1,216 +0,0 @@ -//! In-memory storage implementation for testing. - -use std::collections::HashMap; -use std::sync::RwLock; - -use crate::error::StorageError; -use crate::traits::{RatchetStorage, SessionId}; -use crate::types::StorableRatchetState; - -/// In-memory storage backend for testing purposes. -/// -/// This implementation stores ratchet states in a `HashMap` wrapped in a `RwLock` -/// for thread-safe access. Data is not persisted across process restarts. -/// -/// # Example -/// -/// ``` -/// use double_ratchets_storage::{MemoryStorage, RatchetStorage}; -/// -/// let storage = MemoryStorage::new(); -/// assert!(storage.list_sessions().unwrap().is_empty()); -/// ``` -pub struct MemoryStorage { - states: RwLock>, -} - -impl MemoryStorage { - /// Create a new empty in-memory storage. - pub fn new() -> Self { - Self { - states: RwLock::new(HashMap::new()), - } - } - - /// Get the number of stored sessions. - pub fn len(&self) -> usize { - self.states.read().unwrap().len() - } - - /// Check if the storage is empty. - pub fn is_empty(&self) -> bool { - self.states.read().unwrap().is_empty() - } - - /// Clear all stored sessions. - pub fn clear(&self) { - self.states.write().unwrap().clear(); - } -} - -impl Default for MemoryStorage { - fn default() -> Self { - Self::new() - } -} - -impl RatchetStorage for MemoryStorage { - fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> { - let mut states = self.states.write().unwrap(); - states.insert(*session_id, state.clone()); - Ok(()) - } - - fn load(&self, session_id: &SessionId) -> Result, StorageError> { - let states = self.states.read().unwrap(); - Ok(states.get(session_id).cloned()) - } - - fn delete(&self, session_id: &SessionId) -> Result { - let mut states = self.states.write().unwrap(); - Ok(states.remove(session_id).is_some()) - } - - fn exists(&self, session_id: &SessionId) -> Result { - let states = self.states.read().unwrap(); - Ok(states.contains_key(session_id)) - } - - fn list_sessions(&self) -> Result, StorageError> { - let states = self.states.read().unwrap(); - Ok(states.keys().copied().collect()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use double_ratchets::hkdf::DefaultDomain; - use double_ratchets::state::RatchetState; - use double_ratchets::InstallationKeyPair; - - fn create_test_state() -> StorableRatchetState { - let bob_keypair = InstallationKeyPair::generate(); - let shared_secret = [0x42u8; 32]; - let state: RatchetState = - RatchetState::init_sender(shared_secret, *bob_keypair.public()); - StorableRatchetState::from_ratchet_state(&state, "default") - } - - #[test] - fn test_save_and_load() { - let storage = MemoryStorage::new(); - let session_id = [1u8; 32]; - let state = create_test_state(); - - storage.save(&session_id, &state).unwrap(); - let loaded = storage.load(&session_id).unwrap(); - - assert!(loaded.is_some()); - let loaded = loaded.unwrap(); - assert_eq!(loaded.root_key, state.root_key); - } - - #[test] - fn test_load_nonexistent() { - let storage = MemoryStorage::new(); - let session_id = [1u8; 32]; - - let loaded = storage.load(&session_id).unwrap(); - assert!(loaded.is_none()); - } - - #[test] - fn test_delete() { - let storage = MemoryStorage::new(); - let session_id = [1u8; 32]; - let state = create_test_state(); - - storage.save(&session_id, &state).unwrap(); - assert!(storage.exists(&session_id).unwrap()); - - let deleted = storage.delete(&session_id).unwrap(); - assert!(deleted); - assert!(!storage.exists(&session_id).unwrap()); - - // Deleting again should return false - let deleted = storage.delete(&session_id).unwrap(); - assert!(!deleted); - } - - #[test] - fn test_exists() { - let storage = MemoryStorage::new(); - let session_id = [1u8; 32]; - - assert!(!storage.exists(&session_id).unwrap()); - - let state = create_test_state(); - storage.save(&session_id, &state).unwrap(); - - assert!(storage.exists(&session_id).unwrap()); - } - - #[test] - fn test_list_sessions() { - let storage = MemoryStorage::new(); - - assert!(storage.list_sessions().unwrap().is_empty()); - - let state = create_test_state(); - let session_ids: Vec = (0..3).map(|i| [i; 32]).collect(); - - for id in &session_ids { - storage.save(id, &state).unwrap(); - } - - let mut listed = storage.list_sessions().unwrap(); - listed.sort(); - let mut expected = session_ids.clone(); - expected.sort(); - - assert_eq!(listed, expected); - } - - #[test] - fn test_overwrite() { - let storage = MemoryStorage::new(); - let session_id = [1u8; 32]; - - // Create first state - let bob_keypair1 = InstallationKeyPair::generate(); - let state1: RatchetState = - RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public()); - let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default"); - - // Create second state with different root - let bob_keypair2 = InstallationKeyPair::generate(); - let state2: RatchetState = - RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public()); - let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default"); - - // Save first, then overwrite with second - storage.save(&session_id, &storable1).unwrap(); - storage.save(&session_id, &storable2).unwrap(); - - // Should have the second state - let loaded = storage.load(&session_id).unwrap().unwrap(); - assert_eq!(loaded.root_key, storable2.root_key); - assert_ne!(loaded.root_key, storable1.root_key); - } - - #[test] - fn test_clear() { - let storage = MemoryStorage::new(); - let state = create_test_state(); - - for i in 0..5 { - storage.save(&[i; 32], &state).unwrap(); - } - - assert_eq!(storage.len(), 5); - - storage.clear(); - assert!(storage.is_empty()); - } -} diff --git a/double-ratchets-storage/src/persistent.rs b/double-ratchets-storage/src/persistent.rs new file mode 100644 index 0000000..ca60d1b --- /dev/null +++ b/double-ratchets-storage/src/persistent.rs @@ -0,0 +1,554 @@ +//! Persistent ratchet wrapper that auto-saves state changes. + +use std::sync::Arc; + +use double_ratchets::errors::RatchetError; +use double_ratchets::hkdf::HkdfInfo; +use double_ratchets::state::{Header, RatchetState}; +use double_ratchets::InstallationKeyPair; +use x25519_dalek::PublicKey; + +use crate::error::StorageError; +use crate::traits::{RatchetStore, SessionId}; + +/// A wrapper around `RatchetState` that automatically persists state changes. +/// +/// This wrapper intercepts `encrypt_message` and `decrypt_message` calls, +/// delegates to the underlying `RatchetState`, and then persists the changed +/// fields to storage. +/// +/// # Example +/// +/// ```no_run +/// use double_ratchets::hkdf::DefaultDomain; +/// use double_ratchets::InstallationKeyPair; +/// use double_ratchets_storage::{PersistentRatchet, SqliteRatchetStore}; +/// use std::sync::Arc; +/// +/// let store = Arc::new(SqliteRatchetStore::open_in_memory([0u8; 32]).unwrap()); +/// let session_id = [1u8; 32]; +/// let bob_pub = InstallationKeyPair::generate().public().clone(); +/// let shared_secret = [0x42u8; 32]; +/// +/// let mut ratchet: PersistentRatchet = +/// PersistentRatchet::init_sender(store, session_id, shared_secret, bob_pub).unwrap(); +/// +/// let (ciphertext, header) = ratchet.encrypt_message(b"Hello!").unwrap(); +/// ``` +pub struct PersistentRatchet { + state: RatchetState, + store: Arc, + session_id: SessionId, +} + +impl PersistentRatchet { + /// Initialize as the sender (first to send a message). + /// + /// Creates a new ratchet state and persists it to storage. + pub fn init_sender( + store: Arc, + session_id: SessionId, + shared_secret: [u8; 32], + remote_pub: PublicKey, + ) -> Result { + let state = RatchetState::::init_sender(shared_secret, remote_pub); + + // Persist initial state + store.init_session( + &session_id, + &state.root_key, + state.sending_chain.as_ref(), + state.receiving_chain.as_ref(), + &state.dh_self.secret_bytes(), + state.dh_self.public().as_bytes(), + state.dh_remote.as_ref().map(|pk| pk.as_bytes()), + state.msg_send, + state.msg_recv, + state.prev_chain_len, + )?; + + Ok(Self { + state, + store, + session_id, + }) + } + + /// Initialize as the receiver (first to receive a message). + /// + /// Creates a new ratchet state and persists it to storage. + pub fn init_receiver( + store: Arc, + session_id: SessionId, + shared_secret: [u8; 32], + dh_self: InstallationKeyPair, + ) -> Result { + let state = RatchetState::::init_receiver(shared_secret, dh_self); + + // Persist initial state + store.init_session( + &session_id, + &state.root_key, + state.sending_chain.as_ref(), + state.receiving_chain.as_ref(), + &state.dh_self.secret_bytes(), + state.dh_self.public().as_bytes(), + state.dh_remote.as_ref().map(|pk| pk.as_bytes()), + state.msg_send, + state.msg_recv, + state.prev_chain_len, + )?; + + Ok(Self { + state, + store, + session_id, + }) + } + + /// Load an existing session from storage. + pub fn load( + store: Arc, + session_id: SessionId, + ) -> Result, StorageError> { + let Some(stored) = store.load_state(&session_id)? else { + return Ok(None); + }; + + // Reconstruct the keypair + let dh_self = InstallationKeyPair::from_bytes(stored.dh_self_secret, stored.dh_self_public) + .map_err(|e| StorageError::KeyReconstruction(e.to_string()))?; + + // Reconstruct skipped keys + let mut skipped_keys = std::collections::HashMap::new(); + for entry in stored.skipped_keys { + let pub_key = PublicKey::from(entry.dh_public); + skipped_keys.insert((pub_key, entry.msg_num), entry.message_key); + } + + let state = RatchetState::::from_parts( + stored.root_key, + stored.sending_chain, + stored.receiving_chain, + dh_self, + stored.dh_remote.map(PublicKey::from), + stored.msg_send, + stored.msg_recv, + stored.prev_chain_len, + skipped_keys, + ); + + Ok(Some(Self { + state, + store, + session_id, + })) + } + + /// Encrypt a message and persist state changes. + /// + /// This may trigger a DH ratchet if the sending direction changed. + /// All state changes are persisted to storage after encryption. + pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<(Vec, Header), StorageError> { + // Check if we'll do a DH ratchet (no sending chain) + let will_ratchet = self.state.sending_chain.is_none(); + + // Perform encryption + let (ciphertext, header) = self.state.encrypt_message(plaintext); + + // Persist changes + if will_ratchet { + // DH ratchet happened: root_key, sending_chain, dh_self all changed + self.store.store_root_and_chains( + &self.session_id, + &self.state.root_key, + self.state.sending_chain.as_ref(), + self.state.receiving_chain.as_ref(), + )?; + self.store.store_dh_self( + &self.session_id, + &self.state.dh_self.secret_bytes(), + self.state.dh_self.public().as_bytes(), + )?; + } else { + // Only sending chain changed + self.store.store_root_and_chains( + &self.session_id, + &self.state.root_key, + self.state.sending_chain.as_ref(), + self.state.receiving_chain.as_ref(), + )?; + } + + // Counters always change + self.store.store_counters( + &self.session_id, + self.state.msg_send, + self.state.msg_recv, + self.state.prev_chain_len, + )?; + + Ok((ciphertext, header)) + } + + /// Decrypt a message and persist state changes. + /// + /// Handles DH ratcheting, skipped messages, and replay protection. + /// All state changes are persisted to storage after decryption. + pub fn decrypt_message( + &mut self, + ciphertext_with_nonce: &[u8], + header: Header, + ) -> Result, PersistentRatchetError> { + // Track skipped keys before decryption + let skipped_before: std::collections::HashSet<_> = self + .state + .skipped_keys + .keys() + .map(|(pk, n)| (*pk.as_bytes(), *n)) + .collect(); + + // Check if we'll do a DH ratchet + let will_ratchet = self.state.dh_remote.as_ref() != Some(&header.dh_pub); + + // Perform decryption + let plaintext = self + .state + .decrypt_message(ciphertext_with_nonce, header) + .map_err(PersistentRatchetError::Ratchet)?; + + // Track skipped keys after decryption + let skipped_after: std::collections::HashSet<_> = self + .state + .skipped_keys + .keys() + .map(|(pk, n)| (*pk.as_bytes(), *n)) + .collect(); + + // Persist changes + if will_ratchet { + // DH ratchet happened + self.store + .store_root_and_chains( + &self.session_id, + &self.state.root_key, + self.state.sending_chain.as_ref(), + self.state.receiving_chain.as_ref(), + ) + .map_err(PersistentRatchetError::Storage)?; + self.store + .store_dh_remote( + &self.session_id, + self.state.dh_remote.as_ref().map(|pk| pk.as_bytes()), + ) + .map_err(PersistentRatchetError::Storage)?; + } else { + // Only receiving chain changed + self.store + .store_root_and_chains( + &self.session_id, + &self.state.root_key, + self.state.sending_chain.as_ref(), + self.state.receiving_chain.as_ref(), + ) + .map_err(PersistentRatchetError::Storage)?; + } + + // Counters + self.store + .store_counters( + &self.session_id, + self.state.msg_send, + self.state.msg_recv, + self.state.prev_chain_len, + ) + .map_err(PersistentRatchetError::Storage)?; + + // Handle skipped keys changes + // New skipped keys (added during skip_message_keys) + for (pk_bytes, msg_num) in skipped_after.difference(&skipped_before) { + let pk = PublicKey::from(*pk_bytes); + if let Some(key) = self.state.skipped_keys.get(&(pk, *msg_num)) { + self.store + .add_skipped_key(&self.session_id, pk_bytes, *msg_num, key) + .map_err(PersistentRatchetError::Storage)?; + } + } + + // Removed skipped keys (used for decryption) + for (pk_bytes, msg_num) in skipped_before.difference(&skipped_after) { + self.store + .remove_skipped_key(&self.session_id, pk_bytes, *msg_num) + .map_err(PersistentRatchetError::Storage)?; + } + + Ok(plaintext) + } + + /// Get a reference to the underlying state (read-only). + pub fn state(&self) -> &RatchetState { + &self.state + } + + /// Get the session ID. + pub fn session_id(&self) -> &SessionId { + &self.session_id + } + + /// Delete this session from storage. + pub fn delete(self) -> Result { + self.store.delete_session(&self.session_id) + } +} + +/// Error type for persistent ratchet operations. +#[derive(Debug)] +pub enum PersistentRatchetError { + /// Storage operation failed. + Storage(StorageError), + /// Ratchet operation failed. + Ratchet(RatchetError), +} + +impl std::fmt::Display for PersistentRatchetError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Storage(e) => write!(f, "storage error: {}", e), + Self::Ratchet(e) => write!(f, "ratchet error: {:?}", e), + } + } +} + +impl std::error::Error for PersistentRatchetError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Storage(e) => Some(e), + Self::Ratchet(_) => None, + } + } +} + +impl From for PersistentRatchetError { + fn from(e: StorageError) -> Self { + Self::Storage(e) + } +} + +impl From for PersistentRatchetError { + fn from(e: RatchetError) -> Self { + Self::Ratchet(e) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sqlite::SqliteRatchetStore; + use double_ratchets::hkdf::DefaultDomain; + + fn test_store() -> Arc { + Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap()) + } + + #[test] + fn test_basic_roundtrip() { + let store = test_store(); + let alice_session = [0xAA; 32]; + let bob_session = [0xBB; 32]; + + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + + // Initialize both parties + let mut alice: PersistentRatchet = PersistentRatchet::init_sender( + Arc::clone(&store), + alice_session, + shared_secret, + *bob_keypair.public(), + ) + .unwrap(); + + let mut bob: PersistentRatchet = + PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair) + .unwrap(); + + // Alice sends a message + let (ct, header) = alice.encrypt_message(b"Hello Bob!").unwrap(); + let pt = bob.decrypt_message(&ct, header).unwrap(); + assert_eq!(pt, b"Hello Bob!"); + + // Verify state was persisted + let alice_loaded = store.load_state(&alice_session).unwrap().unwrap(); + assert_eq!(alice_loaded.msg_send, 1); + + let bob_loaded = store.load_state(&bob_session).unwrap().unwrap(); + assert_eq!(bob_loaded.msg_recv, 1); + } + + #[test] + fn test_load_and_continue() { + let store = test_store(); + let alice_session = [0xAA; 32]; + let bob_session = [0xBB; 32]; + + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + + // Initialize and exchange one message + { + let mut alice: PersistentRatchet = PersistentRatchet::init_sender( + Arc::clone(&store), + alice_session, + shared_secret, + *bob_keypair.public(), + ) + .unwrap(); + + let mut bob: PersistentRatchet = PersistentRatchet::init_receiver( + Arc::clone(&store), + bob_session, + shared_secret, + bob_keypair, + ) + .unwrap(); + + let (ct, header) = alice.encrypt_message(b"Message 1").unwrap(); + bob.decrypt_message(&ct, header).unwrap(); + } + + // Load from storage and continue + { + let mut alice: PersistentRatchet = + PersistentRatchet::load(Arc::clone(&store), alice_session) + .unwrap() + .unwrap(); + + let mut bob: PersistentRatchet = + PersistentRatchet::load(Arc::clone(&store), bob_session) + .unwrap() + .unwrap(); + + // Bob replies + let (ct, header) = bob.encrypt_message(b"Reply from Bob").unwrap(); + let pt = alice.decrypt_message(&ct, header).unwrap(); + assert_eq!(pt, b"Reply from Bob"); + + // Alice sends another + let (ct2, header2) = alice.encrypt_message(b"Message 2").unwrap(); + let pt2 = bob.decrypt_message(&ct2, header2).unwrap(); + assert_eq!(pt2, b"Message 2"); + } + } + + #[test] + fn test_skipped_keys_persisted() { + let store = test_store(); + let alice_session = [0xAA; 32]; + let bob_session = [0xBB; 32]; + + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + + let mut alice: PersistentRatchet = PersistentRatchet::init_sender( + Arc::clone(&store), + alice_session, + shared_secret, + *bob_keypair.public(), + ) + .unwrap(); + + let mut bob: PersistentRatchet = + PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair) + .unwrap(); + + // Alice sends 3 messages + let mut messages = vec![]; + for i in 0..3 { + let (ct, header) = alice + .encrypt_message(format!("Message {}", i).as_bytes()) + .unwrap(); + messages.push((ct, header)); + } + + // Bob receives them out of order: 0, 2 (skipping 1) + bob.decrypt_message(&messages[0].0, messages[0].1.clone()) + .unwrap(); + bob.decrypt_message(&messages[2].0, messages[2].1.clone()) + .unwrap(); + + // Check skipped key was persisted + let bob_loaded = store.load_state(&bob_session).unwrap().unwrap(); + assert_eq!(bob_loaded.skipped_keys.len(), 1); + assert_eq!(bob_loaded.skipped_keys[0].msg_num, 1); + + // Now receive the skipped message + bob.decrypt_message(&messages[1].0, messages[1].1.clone()) + .unwrap(); + + // Skipped key should be removed + let bob_loaded = store.load_state(&bob_session).unwrap().unwrap(); + assert!(bob_loaded.skipped_keys.is_empty()); + } + + #[test] + fn test_dh_ratchet_persisted() { + let store = test_store(); + let alice_session = [0xAA; 32]; + let bob_session = [0xBB; 32]; + + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + + let mut alice: PersistentRatchet = PersistentRatchet::init_sender( + Arc::clone(&store), + alice_session, + shared_secret, + *bob_keypair.public(), + ) + .unwrap(); + + let mut bob: PersistentRatchet = + PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair) + .unwrap(); + + // Alice sends + let (ct1, h1) = alice.encrypt_message(b"Hello").unwrap(); + bob.decrypt_message(&ct1, h1).unwrap(); + + // Get Bob's initial DH public + let bob_initial_pub = bob.state().dh_self.public().as_bytes().clone(); + + // Bob replies (triggers DH ratchet) + let (ct2, h2) = bob.encrypt_message(b"Hi").unwrap(); + alice.decrypt_message(&ct2, h2).unwrap(); + + // Bob's DH key should have changed + let bob_new_pub = bob.state().dh_self.public().as_bytes().clone(); + assert_ne!(bob_initial_pub, bob_new_pub); + + // Verify persisted + let bob_loaded = store.load_state(&bob_session).unwrap().unwrap(); + assert_eq!(bob_loaded.dh_self_public, bob_new_pub); + } + + #[test] + fn test_delete_session() { + let store = test_store(); + let session_id = [0x11; 32]; + let bob_keypair = InstallationKeyPair::generate(); + + let ratchet: PersistentRatchet = PersistentRatchet::init_sender( + Arc::clone(&store), + session_id, + [0x42u8; 32], + *bob_keypair.public(), + ) + .unwrap(); + + assert!(store.session_exists(&session_id).unwrap()); + + ratchet.delete().unwrap(); + + assert!(!store.session_exists(&session_id).unwrap()); + } +} diff --git a/double-ratchets-storage/src/sqlite.rs b/double-ratchets-storage/src/sqlite.rs index bac9dfc..ddbf3bc 100644 --- a/double-ratchets-storage/src/sqlite.rs +++ b/double-ratchets-storage/src/sqlite.rs @@ -2,7 +2,6 @@ use std::path::Path; use std::sync::Mutex; -use std::time::{SystemTime, UNIX_EPOCH}; use chacha20poly1305::{ aead::{Aead, KeyInit}, @@ -12,85 +11,37 @@ use rand::RngCore; use rusqlite::{params, Connection, OptionalExtension}; use crate::error::StorageError; -use crate::traits::{RatchetStorage, SessionId}; -use crate::types::{SkippedKey, StorableRatchetState}; +use crate::traits::{RatchetStore, SessionId, SkippedKeyEntry, StoredState}; /// Field encryption key type (32 bytes for ChaCha20Poly1305). pub type EncryptionKey = [u8; 32]; -/// SQLite storage backend with field-level encryption for secrets. +/// SQLite storage with field-level encryption for secrets. /// -/// This implementation stores ratchet states in SQLite with: -/// - Field-level encryption for private keys using ChaCha20Poly1305 -/// - WAL mode for better concurrent performance -/// - Foreign keys and cascading deletes for data integrity -/// - Atomic transactions to prevent partial writes +/// Schema: +/// - `sessions`: Core ratchet state (one row per session) +/// - `skipped_keys`: Skipped message keys (many per session) /// -/// # Security -/// -/// The `dh_self_secret` field is encrypted with the provided encryption key. -/// For additional security, consider using SQLCipher for full database encryption -/// via the `open_encrypted` method (requires `sqlcipher` feature). -/// -/// # Example -/// -/// ```no_run -/// use double_ratchets_storage::SqliteStorage; -/// -/// let key = [0u8; 32]; // Use a proper key derivation function -/// let storage = SqliteStorage::open("ratchets.db", key).unwrap(); -/// ``` -pub struct SqliteStorage { +/// Encrypted fields: dh_self_secret, skipped message_key +pub struct SqliteRatchetStore { conn: Mutex, encryption_key: EncryptionKey, } -impl SqliteStorage { - /// Open or create a SQLite database with field-level encryption. - /// - /// # Arguments - /// - /// * `path` - Path to the database file. - /// * `encryption_key` - 32-byte key for field-level encryption. - /// - /// # Returns - /// - /// * `Ok(SqliteStorage)` on success. - /// * `Err(StorageError)` on failure. +impl SqliteRatchetStore { + /// Open or create a SQLite database. pub fn open>(path: P, encryption_key: EncryptionKey) -> Result { let conn = Connection::open(path)?; Self::initialize(conn, encryption_key) } - /// Create an in-memory SQLite database (for testing). - /// - /// # Arguments - /// - /// * `encryption_key` - 32-byte key for field-level encryption. - /// - /// # Returns - /// - /// * `Ok(SqliteStorage)` on success. - /// * `Err(StorageError)` on failure. + /// Create an in-memory database (for testing). pub fn open_in_memory(encryption_key: EncryptionKey) -> Result { let conn = Connection::open_in_memory()?; Self::initialize(conn, encryption_key) } - /// Open or create a SQLCipher-encrypted database. - /// - /// This method requires the `sqlcipher` feature to be enabled. - /// - /// # Arguments - /// - /// * `path` - Path to the database file. - /// * `db_password` - Password for SQLCipher database encryption. - /// * `field_key` - 32-byte key for additional field-level encryption. - /// - /// # Returns - /// - /// * `Ok(SqliteStorage)` on success. - /// * `Err(StorageError)` on failure. + /// Open with SQLCipher full-database encryption. #[cfg(feature = "sqlcipher")] pub fn open_encrypted>( path: P, @@ -98,50 +49,39 @@ impl SqliteStorage { field_key: EncryptionKey, ) -> Result { let conn = Connection::open(path)?; - - // Set SQLCipher key conn.pragma_update(None, "key", db_password)?; - Self::initialize(conn, field_key) } fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result { - // Enable WAL mode for better performance conn.pragma_update(None, "journal_mode", "WAL")?; - - // Enable foreign keys conn.pragma_update(None, "foreign_keys", "ON")?; - // Create tables conn.execute_batch( r#" - CREATE TABLE IF NOT EXISTS ratchet_states ( - session_id BLOB PRIMARY KEY, + CREATE TABLE IF NOT EXISTS sessions ( + session_id BLOB PRIMARY KEY NOT NULL, root_key BLOB NOT NULL, sending_chain BLOB, receiving_chain BLOB, - dh_self_secret_encrypted BLOB NOT NULL, - dh_self_secret_nonce BLOB NOT NULL, - dh_self_public BLOB NOT NULL, + dh_secret_enc BLOB NOT NULL, + dh_secret_nonce BLOB NOT NULL, + dh_public BLOB NOT NULL, dh_remote BLOB, msg_send INTEGER NOT NULL, msg_recv INTEGER NOT NULL, - prev_chain_len INTEGER NOT NULL, - domain_id TEXT NOT NULL, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL + prev_chain_len INTEGER NOT NULL ); CREATE TABLE IF NOT EXISTS skipped_keys ( - id INTEGER PRIMARY KEY, - session_id BLOB NOT NULL REFERENCES ratchet_states(session_id) ON DELETE CASCADE, - public_key BLOB NOT NULL, + session_id BLOB NOT NULL, + dh_public BLOB NOT NULL, msg_num INTEGER NOT NULL, - message_key BLOB NOT NULL, - UNIQUE(session_id, public_key, msg_num) + message_key_enc BLOB NOT NULL, + message_key_nonce BLOB NOT NULL, + PRIMARY KEY (session_id, dh_public, msg_num), + FOREIGN KEY (session_id) REFERENCES sessions(session_id) ON DELETE CASCADE ); - - CREATE INDEX IF NOT EXISTS idx_skipped_keys_session ON skipped_keys(session_id); "#, )?; @@ -151,23 +91,20 @@ impl SqliteStorage { }) } - /// Encrypt a 32-byte secret using ChaCha20Poly1305. - fn encrypt_secret(&self, secret: &[u8; 32]) -> Result<(Vec, [u8; 12]), StorageError> { + fn encrypt(&self, plaintext: &[u8; 32]) -> Result<(Vec, [u8; 12]), StorageError> { let cipher = ChaCha20Poly1305::new((&self.encryption_key).into()); - let mut nonce_bytes = [0u8; 12]; rand::thread_rng().fill_bytes(&mut nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes); let ciphertext = cipher - .encrypt(nonce, secret.as_ref()) + .encrypt(nonce, plaintext.as_ref()) .map_err(|e| StorageError::Encryption(e.to_string()))?; Ok((ciphertext, nonce_bytes)) } - /// Decrypt a secret using ChaCha20Poly1305. - fn decrypt_secret(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> { + fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> { let cipher = ChaCha20Poly1305::new((&self.encryption_key).into()); let nonce = Nonce::from_slice(nonce); @@ -177,135 +114,110 @@ impl SqliteStorage { plaintext .try_into() - .map_err(|_| StorageError::CorruptedState("decrypted secret has wrong length".to_string())) - } - - fn current_timestamp() -> i64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64 + .map_err(|_| StorageError::CorruptedState("wrong decrypted length".into())) } } -impl RatchetStorage for SqliteStorage { - fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> { - let (encrypted_secret, nonce) = self.encrypt_secret(&state.dh_self_secret)?; - +impl RatchetStore for SqliteRatchetStore { + fn store_root_and_chains( + &self, + session_id: &SessionId, + root_key: &[u8; 32], + sending_chain: Option<&[u8; 32]>, + receiving_chain: Option<&[u8; 32]>, + ) -> Result<(), StorageError> { let conn = self.conn.lock().unwrap(); - let tx = conn.unchecked_transaction()?; - - let now = Self::current_timestamp(); - - // Check if session exists to determine created_at - let exists: bool = tx.query_row( - "SELECT 1 FROM ratchet_states WHERE session_id = ?", - [session_id.as_slice()], - |_| Ok(true), - ).optional()?.unwrap_or(false); - - if exists { - // Update existing session - tx.execute( - r#" - UPDATE ratchet_states SET - root_key = ?, - sending_chain = ?, - receiving_chain = ?, - dh_self_secret_encrypted = ?, - dh_self_secret_nonce = ?, - dh_self_public = ?, - dh_remote = ?, - msg_send = ?, - msg_recv = ?, - prev_chain_len = ?, - domain_id = ?, - updated_at = ? - WHERE session_id = ? - "#, - params![ - state.root_key.as_slice(), - state.sending_chain.as_ref().map(|c| c.as_slice()), - state.receiving_chain.as_ref().map(|c| c.as_slice()), - encrypted_secret.as_slice(), - nonce.as_slice(), - state.dh_self_public.as_slice(), - state.dh_remote.as_ref().map(|pk| pk.as_slice()), - state.msg_send, - state.msg_recv, - state.prev_chain_len, - &state.domain_id, - now, - session_id.as_slice(), - ], - )?; - - // Delete existing skipped keys - tx.execute( - "DELETE FROM skipped_keys WHERE session_id = ?", - [session_id.as_slice()], - )?; - } else { - // Insert new session - tx.execute( - r#" - INSERT INTO ratchet_states ( - session_id, root_key, sending_chain, receiving_chain, - dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public, - dh_remote, msg_send, msg_recv, prev_chain_len, domain_id, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - "#, - params![ - session_id.as_slice(), - state.root_key.as_slice(), - state.sending_chain.as_ref().map(|c| c.as_slice()), - state.receiving_chain.as_ref().map(|c| c.as_slice()), - encrypted_secret.as_slice(), - nonce.as_slice(), - state.dh_self_public.as_slice(), - state.dh_remote.as_ref().map(|pk| pk.as_slice()), - state.msg_send, - state.msg_recv, - state.prev_chain_len, - &state.domain_id, - now, - now, - ], - )?; - } - - // Insert skipped keys - for sk in &state.skipped_keys { - tx.execute( - r#" - INSERT INTO skipped_keys (session_id, public_key, msg_num, message_key) - VALUES (?, ?, ?, ?) - "#, - params![ - session_id.as_slice(), - sk.public_key.as_slice(), - sk.msg_num, - sk.message_key.as_slice(), - ], - )?; - } - - tx.commit()?; + conn.execute( + "UPDATE sessions SET root_key = ?, sending_chain = ?, receiving_chain = ? WHERE session_id = ?", + params![ + root_key.as_slice(), + sending_chain.map(|c| c.as_slice()), + receiving_chain.map(|c| c.as_slice()), + session_id.as_slice(), + ], + )?; Ok(()) } - fn load(&self, session_id: &SessionId) -> Result, StorageError> { + fn store_dh_self( + &self, + session_id: &SessionId, + secret: &[u8; 32], + public: &[u8; 32], + ) -> Result<(), StorageError> { + let (enc, nonce) = self.encrypt(secret)?; + let conn = self.conn.lock().unwrap(); + conn.execute( + "UPDATE sessions SET dh_secret_enc = ?, dh_secret_nonce = ?, dh_public = ? WHERE session_id = ?", + params![enc.as_slice(), nonce.as_slice(), public.as_slice(), session_id.as_slice()], + )?; + Ok(()) + } + + fn store_dh_remote( + &self, + session_id: &SessionId, + remote: Option<&[u8; 32]>, + ) -> Result<(), StorageError> { + let conn = self.conn.lock().unwrap(); + conn.execute( + "UPDATE sessions SET dh_remote = ? WHERE session_id = ?", + params![remote.map(|r| r.as_slice()), session_id.as_slice()], + )?; + Ok(()) + } + + fn store_counters( + &self, + session_id: &SessionId, + msg_send: u32, + msg_recv: u32, + prev_chain_len: u32, + ) -> Result<(), StorageError> { + let conn = self.conn.lock().unwrap(); + conn.execute( + "UPDATE sessions SET msg_send = ?, msg_recv = ?, prev_chain_len = ? WHERE session_id = ?", + params![msg_send, msg_recv, prev_chain_len, session_id.as_slice()], + )?; + Ok(()) + } + + fn add_skipped_key( + &self, + session_id: &SessionId, + dh_public: &[u8; 32], + msg_num: u32, + message_key: &[u8; 32], + ) -> Result<(), StorageError> { + let (enc, nonce) = self.encrypt(message_key)?; + let conn = self.conn.lock().unwrap(); + conn.execute( + "INSERT OR REPLACE INTO skipped_keys (session_id, dh_public, msg_num, message_key_enc, message_key_nonce) VALUES (?, ?, ?, ?, ?)", + params![session_id.as_slice(), dh_public.as_slice(), msg_num, enc.as_slice(), nonce.as_slice()], + )?; + Ok(()) + } + + fn remove_skipped_key( + &self, + session_id: &SessionId, + dh_public: &[u8; 32], + msg_num: u32, + ) -> Result<(), StorageError> { + let conn = self.conn.lock().unwrap(); + conn.execute( + "DELETE FROM skipped_keys WHERE session_id = ? AND dh_public = ? AND msg_num = ?", + params![session_id.as_slice(), dh_public.as_slice(), msg_num], + )?; + Ok(()) + } + + fn load_state(&self, session_id: &SessionId) -> Result, StorageError> { let conn = self.conn.lock().unwrap(); let row = conn .query_row( - r#" - SELECT root_key, sending_chain, receiving_chain, - dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public, - dh_remote, msg_send, msg_recv, prev_chain_len, domain_id - FROM ratchet_states WHERE session_id = ? - "#, + "SELECT root_key, sending_chain, receiving_chain, dh_secret_enc, dh_secret_nonce, dh_public, dh_remote, msg_send, msg_recv, prev_chain_len FROM sessions WHERE session_id = ?", [session_id.as_slice()], |row| { Ok(( @@ -319,91 +231,64 @@ impl RatchetStorage for SqliteStorage { row.get::<_, u32>(7)?, row.get::<_, u32>(8)?, row.get::<_, u32>(9)?, - row.get::<_, String>(10)?, )) }, ) .optional()?; let Some(( - root_key_bytes, - sending_chain_bytes, - receiving_chain_bytes, - encrypted_secret, - nonce_bytes, - dh_self_public_bytes, - dh_remote_bytes, - msg_send, - msg_recv, - prev_chain_len, - domain_id, - )) = row - else { + root_key_bytes, sending_bytes, receiving_bytes, + secret_enc, secret_nonce, dh_pub_bytes, dh_remote_bytes, + msg_send, msg_recv, prev_chain_len, + )) = row else { return Ok(None); }; - // Decrypt the secret - let nonce: [u8; 12] = nonce_bytes - .try_into() - .map_err(|_| StorageError::CorruptedState("invalid nonce length".to_string()))?; - let dh_self_secret = self.decrypt_secret(&encrypted_secret, &nonce)?; + let nonce: [u8; 12] = secret_nonce.try_into() + .map_err(|_| StorageError::CorruptedState("invalid nonce".into()))?; + let dh_self_secret = self.decrypt(&secret_enc, &nonce)?; - // Convert byte vectors to arrays - let root_key: [u8; 32] = root_key_bytes - .try_into() - .map_err(|_| StorageError::CorruptedState("invalid root_key length".to_string()))?; - let dh_self_public: [u8; 32] = dh_self_public_bytes - .try_into() - .map_err(|_| StorageError::CorruptedState("invalid dh_self_public length".to_string()))?; + let root_key: [u8; 32] = root_key_bytes.try_into() + .map_err(|_| StorageError::CorruptedState("invalid root_key".into()))?; + let dh_self_public: [u8; 32] = dh_pub_bytes.try_into() + .map_err(|_| StorageError::CorruptedState("invalid dh_public".into()))?; - let sending_chain = sending_chain_bytes - .map(|b| { - b.try_into() - .map_err(|_| StorageError::CorruptedState("invalid sending_chain length".to_string())) - }) + let sending_chain = sending_bytes + .map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid sending_chain".into()))) .transpose()?; - let receiving_chain = receiving_chain_bytes - .map(|b| { - b.try_into() - .map_err(|_| StorageError::CorruptedState("invalid receiving_chain length".to_string())) - }) + let receiving_chain = receiving_bytes + .map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid receiving_chain".into()))) .transpose()?; let dh_remote = dh_remote_bytes - .map(|b| { - b.try_into() - .map_err(|_| StorageError::CorruptedState("invalid dh_remote length".to_string())) - }) + .map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid dh_remote".into()))) .transpose()?; // Load skipped keys let mut stmt = conn.prepare( - "SELECT public_key, msg_num, message_key FROM skipped_keys WHERE session_id = ?", + "SELECT dh_public, msg_num, message_key_enc, message_key_nonce FROM skipped_keys WHERE session_id = ?", )?; - let skipped_keys: Vec = stmt - .query_map([session_id.as_slice()], |row| { - let pk_bytes: Vec = row.get(0)?; - let msg_num: u32 = row.get(1)?; - let mk_bytes: Vec = row.get(2)?; - Ok((pk_bytes, msg_num, mk_bytes)) - })? - .collect::, _>>()? - .into_iter() - .map(|(pk_bytes, msg_num, mk_bytes)| { - let public_key: [u8; 32] = pk_bytes - .try_into() - .map_err(|_| StorageError::CorruptedState("invalid skipped public_key length".to_string()))?; - let message_key: [u8; 32] = mk_bytes - .try_into() - .map_err(|_| StorageError::CorruptedState("invalid skipped message_key length".to_string()))?; - Ok(SkippedKey { - public_key, - msg_num, - message_key, - }) - }) - .collect::, StorageError>>()?; + let mut skipped_keys = Vec::new(); + let rows = stmt.query_map([session_id.as_slice()], |row| { + Ok(( + row.get::<_, Vec>(0)?, + row.get::<_, u32>(1)?, + row.get::<_, Vec>(2)?, + row.get::<_, Vec>(3)?, + )) + })?; - Ok(Some(StorableRatchetState { + for row in rows { + let (pk_bytes, msg_num, key_enc, key_nonce) = row?; + let dh_public: [u8; 32] = pk_bytes.try_into() + .map_err(|_| StorageError::CorruptedState("invalid skipped dh_public".into()))?; + let nonce: [u8; 12] = key_nonce.try_into() + .map_err(|_| StorageError::CorruptedState("invalid skipped nonce".into()))?; + let message_key = self.decrypt(&key_enc, &nonce)?; + + skipped_keys.push(SkippedKeyEntry { dh_public, msg_num, message_key }); + } + + Ok(Some(StoredState { root_key, sending_chain, receiving_chain, @@ -414,24 +299,72 @@ impl RatchetStorage for SqliteStorage { msg_recv, prev_chain_len, skipped_keys, - domain_id, })) } - fn delete(&self, session_id: &SessionId) -> Result { + fn init_session( + &self, + session_id: &SessionId, + root_key: &[u8; 32], + sending_chain: Option<&[u8; 32]>, + receiving_chain: Option<&[u8; 32]>, + dh_self_secret: &[u8; 32], + dh_self_public: &[u8; 32], + dh_remote: Option<&[u8; 32]>, + msg_send: u32, + msg_recv: u32, + prev_chain_len: u32, + ) -> Result<(), StorageError> { + let (secret_enc, secret_nonce) = self.encrypt(dh_self_secret)?; + let conn = self.conn.lock().unwrap(); + + conn.execute( + r#" + INSERT INTO sessions (session_id, root_key, sending_chain, receiving_chain, dh_secret_enc, dh_secret_nonce, dh_public, dh_remote, msg_send, msg_recv, prev_chain_len) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(session_id) DO UPDATE SET + root_key = excluded.root_key, + sending_chain = excluded.sending_chain, + receiving_chain = excluded.receiving_chain, + dh_secret_enc = excluded.dh_secret_enc, + dh_secret_nonce = excluded.dh_secret_nonce, + dh_public = excluded.dh_public, + dh_remote = excluded.dh_remote, + msg_send = excluded.msg_send, + msg_recv = excluded.msg_recv, + prev_chain_len = excluded.prev_chain_len + "#, + params![ + session_id.as_slice(), + root_key.as_slice(), + sending_chain.map(|c| c.as_slice()), + receiving_chain.map(|c| c.as_slice()), + secret_enc.as_slice(), + secret_nonce.as_slice(), + dh_self_public.as_slice(), + dh_remote.map(|r| r.as_slice()), + msg_send, + msg_recv, + prev_chain_len, + ], + )?; + Ok(()) + } + + fn delete_session(&self, session_id: &SessionId) -> Result { let conn = self.conn.lock().unwrap(); let changes = conn.execute( - "DELETE FROM ratchet_states WHERE session_id = ?", + "DELETE FROM sessions WHERE session_id = ?", [session_id.as_slice()], )?; Ok(changes > 0) } - fn exists(&self, session_id: &SessionId) -> Result { + fn session_exists(&self, session_id: &SessionId) -> Result { let conn = self.conn.lock().unwrap(); let exists: bool = conn .query_row( - "SELECT 1 FROM ratchet_states WHERE session_id = ?", + "SELECT 1 FROM sessions WHERE session_id = ?", [session_id.as_slice()], |_| Ok(true), ) @@ -442,15 +375,11 @@ impl RatchetStorage for SqliteStorage { fn list_sessions(&self) -> Result, StorageError> { let conn = self.conn.lock().unwrap(); - let mut stmt = conn.prepare("SELECT session_id FROM ratchet_states")?; - let sessions: Vec = stmt - .query_map([], |row| { - let bytes: Vec = row.get(0)?; - Ok(bytes) - })? - .collect::, _>>()? - .into_iter() - .filter_map(|bytes| bytes.try_into().ok()) + let mut stmt = conn.prepare("SELECT session_id FROM sessions")?; + let sessions = stmt + .query_map([], |row| row.get::<_, Vec>(0))? + .filter_map(|r| r.ok()) + .filter_map(|v| v.try_into().ok()) .collect(); Ok(sessions) } @@ -459,288 +388,99 @@ impl RatchetStorage for SqliteStorage { #[cfg(test)] mod tests { use super::*; - use double_ratchets::hkdf::DefaultDomain; - use double_ratchets::state::RatchetState; - use double_ratchets::InstallationKeyPair; - fn create_test_storage() -> SqliteStorage { - let key = [0x42u8; 32]; - SqliteStorage::open_in_memory(key).unwrap() - } - - fn create_test_state() -> StorableRatchetState { - let bob_keypair = InstallationKeyPair::generate(); - let shared_secret = [0x42u8; 32]; - let state: RatchetState = - RatchetState::init_sender(shared_secret, *bob_keypair.public()); - StorableRatchetState::from_ratchet_state(&state, "default") + fn test_store() -> SqliteRatchetStore { + SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap() } #[test] - fn test_save_and_load() { - let storage = create_test_storage(); - let session_id = [1u8; 32]; - let state = create_test_state(); - - storage.save(&session_id, &state).unwrap(); - let loaded = storage.load(&session_id).unwrap(); - - assert!(loaded.is_some()); - let loaded = loaded.unwrap(); - assert_eq!(loaded.root_key, state.root_key); - assert_eq!(loaded.dh_self_public, state.dh_self_public); - // Secret should be decrypted correctly - assert_eq!(loaded.dh_self_secret, state.dh_self_secret); - } - - #[test] - fn test_load_nonexistent() { - let storage = create_test_storage(); + fn test_init_and_load() { + let store = test_store(); let session_id = [1u8; 32]; - let loaded = storage.load(&session_id).unwrap(); - assert!(loaded.is_none()); + store.init_session( + &session_id, + &[0xAA; 32], // root + Some(&[0xBB; 32]), // sending + None, // receiving + &[0xCC; 32], // secret + &[0xDD; 32], // public + Some(&[0xEE; 32]), // remote + 5, 3, 2, + ).unwrap(); + + let state = store.load_state(&session_id).unwrap().unwrap(); + assert_eq!(state.root_key, [0xAA; 32]); + assert_eq!(state.sending_chain, Some([0xBB; 32])); + assert_eq!(state.receiving_chain, None); + assert_eq!(state.dh_self_secret, [0xCC; 32]); + assert_eq!(state.dh_self_public, [0xDD; 32]); + assert_eq!(state.dh_remote, Some([0xEE; 32])); + assert_eq!(state.msg_send, 5); + assert_eq!(state.msg_recv, 3); + assert_eq!(state.prev_chain_len, 2); } #[test] - fn test_delete() { - let storage = create_test_storage(); - let session_id = [1u8; 32]; - let state = create_test_state(); - - storage.save(&session_id, &state).unwrap(); - assert!(storage.exists(&session_id).unwrap()); - - let deleted = storage.delete(&session_id).unwrap(); - assert!(deleted); - assert!(!storage.exists(&session_id).unwrap()); - - // Deleting again should return false - let deleted = storage.delete(&session_id).unwrap(); - assert!(!deleted); - } - - #[test] - fn test_exists() { - let storage = create_test_storage(); + fn test_update_fields() { + let store = test_store(); let session_id = [1u8; 32]; - assert!(!storage.exists(&session_id).unwrap()); + store.init_session( + &session_id, &[0xAA; 32], None, None, + &[0xCC; 32], &[0xDD; 32], None, 0, 0, 0, + ).unwrap(); - let state = create_test_state(); - storage.save(&session_id, &state).unwrap(); + // Update root and chains + store.store_root_and_chains(&session_id, &[0x11; 32], Some(&[0x22; 32]), Some(&[0x33; 32])).unwrap(); - assert!(storage.exists(&session_id).unwrap()); + let state = store.load_state(&session_id).unwrap().unwrap(); + assert_eq!(state.root_key, [0x11; 32]); + assert_eq!(state.sending_chain, Some([0x22; 32])); + assert_eq!(state.receiving_chain, Some([0x33; 32])); } #[test] - fn test_list_sessions() { - let storage = create_test_storage(); - - assert!(storage.list_sessions().unwrap().is_empty()); - - let state = create_test_state(); - let session_ids: Vec = (0..3).map(|i| [i; 32]).collect(); - - for id in &session_ids { - storage.save(id, &state).unwrap(); - } - - let mut listed = storage.list_sessions().unwrap(); - listed.sort(); - let mut expected = session_ids.clone(); - expected.sort(); - - assert_eq!(listed, expected); - } - - #[test] - fn test_overwrite() { - let storage = create_test_storage(); + fn test_skipped_keys() { + let store = test_store(); let session_id = [1u8; 32]; - // Create first state - let bob_keypair1 = InstallationKeyPair::generate(); - let state1: RatchetState = - RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public()); - let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default"); + store.init_session( + &session_id, &[0xAA; 32], None, None, + &[0xCC; 32], &[0xDD; 32], None, 0, 0, 0, + ).unwrap(); - // Create second state with different root - let bob_keypair2 = InstallationKeyPair::generate(); - let state2: RatchetState = - RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public()); - let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default"); + // Add skipped keys + store.add_skipped_key(&session_id, &[0x11; 32], 5, &[0xAB; 32]).unwrap(); + store.add_skipped_key(&session_id, &[0x11; 32], 6, &[0xCD; 32]).unwrap(); - // Save first, then overwrite with second - storage.save(&session_id, &storable1).unwrap(); - storage.save(&session_id, &storable2).unwrap(); + let state = store.load_state(&session_id).unwrap().unwrap(); + assert_eq!(state.skipped_keys.len(), 2); - // Should have the second state - let loaded = storage.load(&session_id).unwrap().unwrap(); - assert_eq!(loaded.root_key, storable2.root_key); - assert_ne!(loaded.root_key, storable1.root_key); + // Remove one + store.remove_skipped_key(&session_id, &[0x11; 32], 5).unwrap(); + + let state = store.load_state(&session_id).unwrap().unwrap(); + assert_eq!(state.skipped_keys.len(), 1); + assert_eq!(state.skipped_keys[0].msg_num, 6); } #[test] - fn test_skipped_keys_storage() { - let storage = create_test_storage(); + fn test_delete_cascade() { + let store = test_store(); let session_id = [1u8; 32]; - // Create states and generate skipped keys - let bob_keypair = InstallationKeyPair::generate(); - let shared_secret = [0x42u8; 32]; - let mut alice: RatchetState = - RatchetState::init_sender(shared_secret, *bob_keypair.public()); - let mut bob: RatchetState = - RatchetState::init_receiver(shared_secret, bob_keypair); + store.init_session( + &session_id, &[0xAA; 32], None, None, + &[0xCC; 32], &[0xDD; 32], None, 0, 0, 0, + ).unwrap(); + store.add_skipped_key(&session_id, &[0x11; 32], 5, &[0xAB; 32]).unwrap(); - // Alice sends multiple messages - let mut messages = vec![]; - for i in 0..3 { - let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes()); - messages.push((ct, header)); - } + assert!(store.session_exists(&session_id).unwrap()); - // Bob receives out of order to create skipped keys - bob.decrypt_message(&messages[0].0, messages[0].1.clone()) - .unwrap(); - bob.decrypt_message(&messages[2].0, messages[2].1.clone()) - .unwrap(); + store.delete_session(&session_id).unwrap(); - assert!(!bob.skipped_keys.is_empty()); - - // Save and reload - let storable = StorableRatchetState::from_ratchet_state(&bob, "default"); - storage.save(&session_id, &storable).unwrap(); - - let loaded = storage.load(&session_id).unwrap().unwrap(); - assert_eq!(loaded.skipped_keys.len(), storable.skipped_keys.len()); - - // Restore and verify we can decrypt the skipped message - let mut restored: RatchetState = loaded.to_ratchet_state().unwrap(); - let pt = restored - .decrypt_message(&messages[1].0, messages[1].1.clone()) - .unwrap(); - assert_eq!(pt, b"Message 1"); - } - - #[test] - fn test_encryption_uses_different_nonces() { - let storage = create_test_storage(); - let state = create_test_state(); - - // Save the same state twice with different session IDs - storage.save(&[1u8; 32], &state).unwrap(); - storage.save(&[2u8; 32], &state).unwrap(); - - // Both should load correctly (encryption with different nonces) - let loaded1 = storage.load(&[1u8; 32]).unwrap().unwrap(); - let loaded2 = storage.load(&[2u8; 32]).unwrap().unwrap(); - - assert_eq!(loaded1.dh_self_secret, loaded2.dh_self_secret); - } - - #[test] - fn test_cascade_delete_skipped_keys() { - let storage = create_test_storage(); - let session_id = [1u8; 32]; - - // Create a state with skipped keys - let bob_keypair = InstallationKeyPair::generate(); - let shared_secret = [0x42u8; 32]; - let mut alice: RatchetState = - RatchetState::init_sender(shared_secret, *bob_keypair.public()); - let mut bob: RatchetState = - RatchetState::init_receiver(shared_secret, bob_keypair); - - let mut messages = vec![]; - for i in 0..3 { - let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes()); - messages.push((ct, header)); - } - - bob.decrypt_message(&messages[0].0, messages[0].1.clone()) - .unwrap(); - bob.decrypt_message(&messages[2].0, messages[2].1.clone()) - .unwrap(); - - let storable = StorableRatchetState::from_ratchet_state(&bob, "default"); - storage.save(&session_id, &storable).unwrap(); - - // Verify skipped keys exist - { - let conn = storage.conn.lock().unwrap(); - let count: i64 = conn - .query_row( - "SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?", - [session_id.as_slice()], - |row| row.get(0), - ) - .unwrap(); - assert!(count > 0); - } - - // Delete session - storage.delete(&session_id).unwrap(); - - // Verify skipped keys were also deleted (cascade) - { - let conn = storage.conn.lock().unwrap(); - let count: i64 = conn - .query_row( - "SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?", - [session_id.as_slice()], - |row| row.get(0), - ) - .unwrap(); - assert_eq!(count, 0); - } - } - - #[test] - fn test_file_storage() { - let dir = tempfile::tempdir().unwrap(); - let db_path = dir.path().join("test.db"); - let key = [0x42u8; 32]; - - let state = create_test_state(); - let session_id = [1u8; 32]; - - // Save in one instance - { - let storage = SqliteStorage::open(&db_path, key).unwrap(); - storage.save(&session_id, &state).unwrap(); - } - - // Load in another instance - { - let storage = SqliteStorage::open(&db_path, key).unwrap(); - let loaded = storage.load(&session_id).unwrap().unwrap(); - assert_eq!(loaded.root_key, state.root_key); - } - } - - #[test] - fn test_wrong_key_fails_decryption() { - let dir = tempfile::tempdir().unwrap(); - let db_path = dir.path().join("test.db"); - let key1 = [0x42u8; 32]; - let key2 = [0x43u8; 32]; - - let state = create_test_state(); - let session_id = [1u8; 32]; - - // Save with key1 - { - let storage = SqliteStorage::open(&db_path, key1).unwrap(); - storage.save(&session_id, &state).unwrap(); - } - - // Try to load with key2 - should fail decryption - { - let storage = SqliteStorage::open(&db_path, key2).unwrap(); - let result = storage.load(&session_id); - assert!(result.is_err()); - } + assert!(!store.session_exists(&session_id).unwrap()); + // Skipped keys should be gone too (cascade) } } diff --git a/double-ratchets-storage/src/traits.rs b/double-ratchets-storage/src/traits.rs index 5607a65..22f540c 100644 --- a/double-ratchets-storage/src/traits.rs +++ b/double-ratchets-storage/src/traits.rs @@ -1,74 +1,112 @@ -//! Storage trait definitions. +//! Storage trait for field-level ratchet state persistence. use crate::error::StorageError; -use crate::types::StorableRatchetState; /// A 32-byte session identifier. pub type SessionId = [u8; 32]; -/// Abstract storage interface for ratchet states. +/// Field-level storage interface for ratchet state. /// -/// Implementations must be thread-safe (`Send + Sync`). -pub trait RatchetStorage: Send + Sync { - /// Save a ratchet state for the given session. - /// - /// If a state already exists for this session, it will be overwritten. - /// - /// # Arguments - /// - /// * `session_id` - Unique identifier for the session. - /// * `state` - The ratchet state to store. - /// - /// # Returns - /// - /// * `Ok(())` on success. - /// * `Err(StorageError)` on failure. - fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError>; +/// This trait provides granular storage operations that are called automatically +/// during ratchet operations. Each method persists only the fields that changed. +pub trait RatchetStore: Send + Sync { + /// Store the root key and chain keys after a DH ratchet step. + fn store_root_and_chains( + &self, + session_id: &SessionId, + root_key: &[u8; 32], + sending_chain: Option<&[u8; 32]>, + receiving_chain: Option<&[u8; 32]>, + ) -> Result<(), StorageError>; - /// Load a ratchet state for the given session. - /// - /// # Arguments - /// - /// * `session_id` - Unique identifier for the session. - /// - /// # Returns - /// - /// * `Ok(Some(state))` if found. - /// * `Ok(None)` if not found. - /// * `Err(StorageError)` on failure. - fn load(&self, session_id: &SessionId) -> Result, StorageError>; + /// Store our DH keypair (secret encrypted, public plaintext). + fn store_dh_self( + &self, + session_id: &SessionId, + secret: &[u8; 32], + public: &[u8; 32], + ) -> Result<(), StorageError>; - /// Delete a ratchet state for the given session. - /// - /// # Arguments - /// - /// * `session_id` - Unique identifier for the session. - /// - /// # Returns - /// - /// * `Ok(true)` if the session existed and was deleted. - /// * `Ok(false)` if the session did not exist. - /// * `Err(StorageError)` on failure. - fn delete(&self, session_id: &SessionId) -> Result; + /// Store the remote party's DH public key. + fn store_dh_remote( + &self, + session_id: &SessionId, + remote: Option<&[u8; 32]>, + ) -> Result<(), StorageError>; - /// Check if a session exists in storage. - /// - /// # Arguments - /// - /// * `session_id` - Unique identifier for the session. - /// - /// # Returns - /// - /// * `Ok(true)` if the session exists. - /// * `Ok(false)` if the session does not exist. - /// * `Err(StorageError)` on failure. - fn exists(&self, session_id: &SessionId) -> Result; + /// Store message counters. + fn store_counters( + &self, + session_id: &SessionId, + msg_send: u32, + msg_recv: u32, + prev_chain_len: u32, + ) -> Result<(), StorageError>; - /// List all session IDs in storage. - /// - /// # Returns - /// - /// * `Ok(Vec)` containing all session IDs. - /// * `Err(StorageError)` on failure. + /// Add a skipped message key. + fn add_skipped_key( + &self, + session_id: &SessionId, + dh_public: &[u8; 32], + msg_num: u32, + message_key: &[u8; 32], + ) -> Result<(), StorageError>; + + /// Remove a skipped message key (after use). + fn remove_skipped_key( + &self, + session_id: &SessionId, + dh_public: &[u8; 32], + msg_num: u32, + ) -> Result<(), StorageError>; + + /// Load all state for a session. Returns None if session doesn't exist. + fn load_state(&self, session_id: &SessionId) -> Result, StorageError>; + + /// Initialize a new session with all fields. + fn init_session( + &self, + session_id: &SessionId, + root_key: &[u8; 32], + sending_chain: Option<&[u8; 32]>, + receiving_chain: Option<&[u8; 32]>, + dh_self_secret: &[u8; 32], + dh_self_public: &[u8; 32], + dh_remote: Option<&[u8; 32]>, + msg_send: u32, + msg_recv: u32, + prev_chain_len: u32, + ) -> Result<(), StorageError>; + + /// Delete a session and all its data. + fn delete_session(&self, session_id: &SessionId) -> Result; + + /// Check if a session exists. + fn session_exists(&self, session_id: &SessionId) -> Result; + + /// List all session IDs. fn list_sessions(&self) -> Result, StorageError>; } + +/// Complete state loaded from storage. +#[derive(Debug, Clone)] +pub struct StoredState { + pub root_key: [u8; 32], + pub sending_chain: Option<[u8; 32]>, + pub receiving_chain: Option<[u8; 32]>, + pub dh_self_secret: [u8; 32], + pub dh_self_public: [u8; 32], + pub dh_remote: Option<[u8; 32]>, + pub msg_send: u32, + pub msg_recv: u32, + pub prev_chain_len: u32, + pub skipped_keys: Vec, +} + +/// A skipped key entry from storage. +#[derive(Debug, Clone)] +pub struct SkippedKeyEntry { + pub dh_public: [u8; 32], + pub msg_num: u32, + pub message_key: [u8; 32], +} diff --git a/double-ratchets-storage/src/types.rs b/double-ratchets-storage/src/types.rs deleted file mode 100644 index 59806a5..0000000 --- a/double-ratchets-storage/src/types.rs +++ /dev/null @@ -1,225 +0,0 @@ -//! Serializable types for ratchet state storage. - -use std::collections::HashMap; - -use double_ratchets::state::RatchetState; -use double_ratchets::hkdf::HkdfInfo; -use double_ratchets::InstallationKeyPair; -use serde::{Deserialize, Serialize}; -use x25519_dalek::PublicKey; - -use crate::error::StorageError; - -/// A skipped message key entry for storage. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SkippedKey { - /// The public key associated with this skipped message. - pub public_key: [u8; 32], - /// The message number. - pub msg_num: u32, - /// The 32-byte message key. - pub message_key: [u8; 32], -} - -/// Serializable version of `RatchetState`. -/// -/// This struct stores all keys as raw byte arrays for easy serialization -/// and database storage. Use `from_ratchet_state()` and `to_ratchet_state()` -/// for conversion. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StorableRatchetState { - /// The current root key (32 bytes). - pub root_key: [u8; 32], - - /// The current sending chain key, if any (32 bytes). - pub sending_chain: Option<[u8; 32]>, - - /// The current receiving chain key, if any (32 bytes). - pub receiving_chain: Option<[u8; 32]>, - - /// Our DH secret key (32 bytes). - /// - /// **Security**: This should be encrypted before storage. - pub dh_self_secret: [u8; 32], - - /// Our DH public key (32 bytes). - pub dh_self_public: [u8; 32], - - /// Remote party's DH public key, if known (32 bytes). - pub dh_remote: Option<[u8; 32]>, - - /// Number of messages sent in the current sending chain. - pub msg_send: u32, - - /// Number of messages received in the current receiving chain. - pub msg_recv: u32, - - /// Length of the previous sending chain. - pub prev_chain_len: u32, - - /// Skipped message keys for out-of-order message handling. - pub skipped_keys: Vec, - - /// Domain identifier for HKDF info. - pub domain_id: String, -} - -impl StorableRatchetState { - /// Convert a `RatchetState` into a `StorableRatchetState`. - /// - /// # Type Parameters - /// - /// * `D` - The HKDF domain type implementing `HkdfInfo`. - /// - /// # Arguments - /// - /// * `state` - The ratchet state to convert. - /// * `domain_id` - A string identifier for the domain (used to reconstruct the correct domain type). - pub fn from_ratchet_state(state: &RatchetState, domain_id: &str) -> Self { - let skipped_keys: Vec = state - .skipped_keys - .iter() - .map(|((pub_key, msg_num), msg_key)| SkippedKey { - public_key: *pub_key.as_bytes(), - msg_num: *msg_num, - message_key: *msg_key, - }) - .collect(); - - StorableRatchetState { - root_key: state.root_key, - sending_chain: state.sending_chain, - receiving_chain: state.receiving_chain, - dh_self_secret: state.dh_self.secret_bytes(), - dh_self_public: *state.dh_self.public().as_bytes(), - dh_remote: state.dh_remote.map(|pk| *pk.as_bytes()), - msg_send: state.msg_send, - msg_recv: state.msg_recv, - prev_chain_len: state.prev_chain_len, - skipped_keys, - domain_id: domain_id.to_string(), - } - } - - /// Convert this `StorableRatchetState` back into a `RatchetState`. - /// - /// # Type Parameters - /// - /// * `D` - The HKDF domain type implementing `HkdfInfo`. - /// - /// # Returns - /// - /// * `Ok(RatchetState)` on success. - /// * `Err(StorageError)` if key reconstruction fails. - pub fn to_ratchet_state(&self) -> Result, StorageError> { - // Reconstruct the keypair - let dh_self = InstallationKeyPair::from_bytes(self.dh_self_secret, self.dh_self_public) - .map_err(|e| StorageError::KeyReconstruction(e.to_string()))?; - - // Reconstruct skipped keys HashMap - let skipped_keys: HashMap<(PublicKey, u32), [u8; 32]> = self - .skipped_keys - .iter() - .map(|sk| { - let pub_key = PublicKey::from(sk.public_key); - ((pub_key, sk.msg_num), sk.message_key) - }) - .collect(); - - Ok(RatchetState::from_parts( - self.root_key, - self.sending_chain, - self.receiving_chain, - dh_self, - self.dh_remote.map(PublicKey::from), - self.msg_send, - self.msg_recv, - self.prev_chain_len, - skipped_keys, - )) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use double_ratchets::hkdf::DefaultDomain; - - #[test] - fn test_roundtrip_sender_state() { - // Create a sender state - let bob_keypair = InstallationKeyPair::generate(); - let shared_secret = [0x42u8; 32]; - let state: RatchetState = - RatchetState::init_sender(shared_secret, *bob_keypair.public()); - - // Convert to storable and back - let storable = StorableRatchetState::from_ratchet_state(&state, "default"); - let restored: RatchetState = storable.to_ratchet_state().unwrap(); - - // Verify fields match - assert_eq!(state.root_key, restored.root_key); - assert_eq!(state.sending_chain, restored.sending_chain); - assert_eq!(state.receiving_chain, restored.receiving_chain); - assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes()); - assert_eq!(state.dh_remote.map(|pk| *pk.as_bytes()), restored.dh_remote.map(|pk| *pk.as_bytes())); - assert_eq!(state.msg_send, restored.msg_send); - assert_eq!(state.msg_recv, restored.msg_recv); - assert_eq!(state.prev_chain_len, restored.prev_chain_len); - } - - #[test] - fn test_roundtrip_receiver_state() { - // Create a receiver state - let keypair = InstallationKeyPair::generate(); - let shared_secret = [0x42u8; 32]; - let state: RatchetState = - RatchetState::init_receiver(shared_secret, keypair); - - // Convert to storable and back - let storable = StorableRatchetState::from_ratchet_state(&state, "default"); - let restored: RatchetState = storable.to_ratchet_state().unwrap(); - - // Verify fields match - assert_eq!(state.root_key, restored.root_key); - assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes()); - assert!(restored.dh_remote.is_none()); - } - - #[test] - fn test_roundtrip_with_skipped_keys() { - // Create states and exchange messages to generate skipped keys - let bob_keypair = InstallationKeyPair::generate(); - let shared_secret = [0x42u8; 32]; - let mut alice: RatchetState = - RatchetState::init_sender(shared_secret, *bob_keypair.public()); - let mut bob: RatchetState = - RatchetState::init_receiver(shared_secret, bob_keypair); - - // Alice sends multiple messages - let mut messages = vec![]; - for i in 0..3 { - let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes()); - messages.push((ct, header)); - } - - // Bob receives them out of order to create skipped keys - bob.decrypt_message(&messages[0].0, messages[0].1.clone()).unwrap(); - bob.decrypt_message(&messages[2].0, messages[2].1.clone()).unwrap(); - // Message 1 key is now in skipped_keys - - assert!(!bob.skipped_keys.is_empty()); - - // Convert to storable and back - let storable = StorableRatchetState::from_ratchet_state(&bob, "default"); - let restored: RatchetState = storable.to_ratchet_state().unwrap(); - - // Verify skipped keys are preserved - assert_eq!(bob.skipped_keys.len(), restored.skipped_keys.len()); - - // The restored state should be able to decrypt the skipped message - let mut restored = restored; - let pt = restored.decrypt_message(&messages[1].0, messages[1].1.clone()).unwrap(); - assert_eq!(pt, b"Message 1"); - } -}