diff --git a/double-ratchets/examples/out_of_order_demo.rs b/double-ratchets/examples/out_of_order_demo.rs index d40c6a5..63637a3 100644 --- a/double-ratchets/examples/out_of_order_demo.rs +++ b/double-ratchets/examples/out_of_order_demo.rs @@ -2,12 +2,13 @@ //! //! Run with: cargo run --example out_of_order_demo -p double-ratchets -use double_ratchets::{InstallationKeyPair, RatchetSession, RatchetStorage, hkdf::DefaultDomain}; +use double_ratchets::{ + InstallationKeyPair, RatchetSession, SqliteRatchetStore, hkdf::DefaultDomain, +}; fn main() { println!("=== Out-of-Order Message Handling Demo ===\n"); - // Setup ensure_tmp_directory(); let alice_db_path = "./tmp/out_of_order_demo_alice.db"; let bob_db_path = "./tmp/out_of_order_demo_bob.db"; @@ -20,28 +21,27 @@ fn main() { let bob_public = bob_keypair.public().clone(); let conv_id = "out_of_order_conv"; - // Collect messages for out-of-order delivery let mut messages: Vec<(Vec, double_ratchets::Header)> = Vec::new(); // Phase 1: Alice sends 5 messages, Bob receives 1, 3, 5 (skipping 2, 4) { - let mut alice_storage = RatchetStorage::new(alice_db_path, encryption_key) + let alice_storage = SqliteRatchetStore::new(alice_db_path, encryption_key) .expect("Failed to create Alice storage"); - let mut bob_storage = - RatchetStorage::new(bob_db_path, encryption_key).expect("Failed to create Bob storage"); + let bob_storage = SqliteRatchetStore::new(bob_db_path, encryption_key) + .expect("Failed to create Bob storage"); - let mut alice_session: RatchetSession = + let mut alice_session: RatchetSession = RatchetSession::create_sender_session( - &mut alice_storage, + alice_storage, conv_id, shared_secret, bob_public, ) .unwrap(); - let mut bob_session: RatchetSession = + let mut bob_session: RatchetSession = RatchetSession::create_receiver_session( - &mut bob_storage, + bob_storage, conv_id, shared_secret, bob_keypair, @@ -73,11 +73,11 @@ fn main() { // Phase 2: Simulate app restart by reopening storage println!("\n Simulating app restart..."); { - let mut bob_storage = - RatchetStorage::new(bob_db_path, encryption_key).expect("Failed to reopen Bob storage"); + let bob_storage = SqliteRatchetStore::new(bob_db_path, encryption_key) + .expect("Failed to reopen Bob storage"); - let bob_session: RatchetSession = - RatchetSession::open(&mut bob_storage, conv_id).unwrap(); + let bob_session: RatchetSession = + RatchetSession::open(bob_storage, conv_id).unwrap(); println!( " After restart, Bob's skipped_keys: {}", bob_session.state().skipped_keys.len() @@ -86,13 +86,13 @@ fn main() { // Phase 3: Bob receives the delayed messages println!("\nBob receives delayed message 2..."); - let (ct4, header4) = messages[3].clone(); // Save for replay test + let (ct4, header4) = messages[3].clone(); { - let mut bob_storage = - RatchetStorage::new(bob_db_path, encryption_key).expect("Failed to open Bob storage"); + let bob_storage = SqliteRatchetStore::new(bob_db_path, encryption_key) + .expect("Failed to open Bob storage"); - let mut bob_session: RatchetSession = - RatchetSession::open(&mut bob_storage, conv_id).unwrap(); + let mut bob_session: RatchetSession = + RatchetSession::open(bob_storage, conv_id).unwrap(); let (ct, header) = &messages[1]; let pt = bob_session.decrypt_message(ct, header.clone()).unwrap(); @@ -105,11 +105,11 @@ fn main() { println!("\nBob receives delayed message 4..."); { - let mut bob_storage = - RatchetStorage::new(bob_db_path, encryption_key).expect("Failed to open Bob storage"); + let bob_storage = SqliteRatchetStore::new(bob_db_path, encryption_key) + .expect("Failed to open Bob storage"); - let mut bob_session: RatchetSession = - RatchetSession::open(&mut bob_storage, conv_id).unwrap(); + let mut bob_session: RatchetSession = + RatchetSession::open(bob_storage, conv_id).unwrap(); let pt = bob_session.decrypt_message(&ct4, header4.clone()).unwrap(); println!(" Received: \"{}\"", String::from_utf8_lossy(&pt)); @@ -123,11 +123,11 @@ fn main() { println!("\n--- Replay Protection Demo ---"); println!("Trying to decrypt message 4 again (should fail)..."); { - let mut bob_storage = - RatchetStorage::new(bob_db_path, encryption_key).expect("Failed to open Bob storage"); + let bob_storage = SqliteRatchetStore::new(bob_db_path, encryption_key) + .expect("Failed to open Bob storage"); - let mut bob_session: RatchetSession = - RatchetSession::open(&mut bob_storage, conv_id).unwrap(); + let mut bob_session: RatchetSession = + RatchetSession::open(bob_storage, conv_id).unwrap(); match bob_session.decrypt_message(&ct4, header4) { Ok(_) => println!(" ERROR: Replay attack succeeded!"), @@ -135,7 +135,6 @@ fn main() { } } - // Cleanup let _ = std::fs::remove_file(alice_db_path); let _ = std::fs::remove_file(bob_db_path); diff --git a/double-ratchets/examples/storage_demo.rs b/double-ratchets/examples/storage_demo.rs index 9b08a1e..34184a5 100644 --- a/double-ratchets/examples/storage_demo.rs +++ b/double-ratchets/examples/storage_demo.rs @@ -2,7 +2,9 @@ //! //! Run with: cargo run --example storage_demo -p double-ratchets -use double_ratchets::{InstallationKeyPair, RatchetSession, RatchetStorage, hkdf::PrivateV1Domain}; +use double_ratchets::{ + InstallationKeyPair, RatchetSession, SqliteRatchetStore, hkdf::PrivateV1Domain, +}; fn main() { println!("=== Double Ratchet Storage Demo ===\n"); @@ -16,25 +18,25 @@ fn main() { // Initial conversation with encryption { - let mut alice_storage = RatchetStorage::new(alice_db_path, encryption_key) + let alice_storage = SqliteRatchetStore::new(alice_db_path, encryption_key) .expect("Failed to create alice encrypted storage"); - let mut bob_storage = RatchetStorage::new(bob_db_path, encryption_key) + let bob_storage = SqliteRatchetStore::new(bob_db_path, encryption_key) .expect("Failed to create bob encrypted storage"); println!( " Encrypted database created at: {}, {}", alice_db_path, bob_db_path ); - run_conversation(&mut alice_storage, &mut bob_storage); + run_conversation(alice_storage, bob_storage); } // Restart with correct key println!("\n Simulating restart with encryption key..."); { - let mut alice_storage = RatchetStorage::new(alice_db_path, encryption_key) + let alice_storage = SqliteRatchetStore::new(alice_db_path, encryption_key) .expect("Failed to create alice encrypted storage"); - let mut bob_storage = RatchetStorage::new(bob_db_path, encryption_key) + let bob_storage = SqliteRatchetStore::new(bob_db_path, encryption_key) .expect("Failed to create bob encrypted storage"); - continue_after_restart(&mut alice_storage, &mut bob_storage); + continue_after_restart(alice_storage, bob_storage); } let _ = std::fs::remove_file(alice_db_path); @@ -48,72 +50,53 @@ fn ensure_tmp_directory() { } } -/// Simulates a conversation between Alice and Bob. -/// Each party saves/loads state from storage for each operation. -fn run_conversation(alice_storage: &mut RatchetStorage, bob_storage: &mut RatchetStorage) { - // === Setup: Simulate X3DH key exchange === - let shared_secret = [0x42u8; 32]; // In reality, this comes from X3DH +fn run_conversation(alice_storage: SqliteRatchetStore, bob_storage: SqliteRatchetStore) { + let shared_secret = [0x42u8; 32]; let bob_keypair = InstallationKeyPair::generate(); - let conv_id = "conv1"; - let mut alice_session: RatchetSession = RatchetSession::create_sender_session( - alice_storage, - conv_id, - shared_secret, - bob_keypair.public().clone(), - ) - .unwrap(); + let mut alice_session: RatchetSession = + RatchetSession::create_sender_session( + alice_storage, + conv_id, + shared_secret, + bob_keypair.public().clone(), + ) + .unwrap(); - let mut bob_session: RatchetSession = + let mut bob_session: RatchetSession = RatchetSession::create_receiver_session(bob_storage, conv_id, shared_secret, bob_keypair) .unwrap(); println!(" Sessions created for Alice and Bob"); - // === Message 1: Alice -> Bob === - let (ct1, h1) = { - let result = alice_session - .encrypt_message(b"Hello Bob! This is message 1.") - .unwrap(); - println!(" Alice sent: \"Hello Bob! This is message 1.\""); - result - }; + // Message 1: Alice -> Bob + let (ct1, h1) = alice_session + .encrypt_message(b"Hello Bob! This is message 1.") + .unwrap(); + println!(" Alice sent: \"Hello Bob! This is message 1.\""); - { - let pt = bob_session.decrypt_message(&ct1, h1).unwrap(); - println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); - } + let pt = bob_session.decrypt_message(&ct1, h1).unwrap(); + println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); - // === Message 2: Bob -> Alice (triggers DH ratchet) === - let (ct2, h2) = { - let result = bob_session - .encrypt_message(b"Hi Alice! Got your message.") - .unwrap(); - println!(" Bob sent: \"Hi Alice! Got your message.\""); - result - }; + // Message 2: Bob -> Alice + let (ct2, h2) = bob_session + .encrypt_message(b"Hi Alice! Got your message.") + .unwrap(); + println!(" Bob sent: \"Hi Alice! Got your message.\""); - { - let pt = alice_session.decrypt_message(&ct2, h2).unwrap(); - println!(" Alice received: \"{}\"", String::from_utf8_lossy(&pt)); - } + let pt = alice_session.decrypt_message(&ct2, h2).unwrap(); + println!(" Alice received: \"{}\"", String::from_utf8_lossy(&pt)); - // === Message 3: Alice -> Bob === - let (ct3, h3) = { - let result = alice_session - .encrypt_message(b"Great! Let's keep chatting.") - .unwrap(); - println!(" Alice sent: \"Great! Let's keep chatting.\""); - result - }; + // Message 3: Alice -> Bob + let (ct3, h3) = alice_session + .encrypt_message(b"Great! Let's keep chatting.") + .unwrap(); + println!(" Alice sent: \"Great! Let's keep chatting.\""); - { - let pt = bob_session.decrypt_message(&ct3, h3).unwrap(); - println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); - } + let pt = bob_session.decrypt_message(&ct3, h3).unwrap(); + println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); - // Print final state println!( " State after conversation: Alice msg_send={}, Bob msg_recv={}", alice_session.msg_send(), @@ -121,29 +104,22 @@ fn run_conversation(alice_storage: &mut RatchetStorage, bob_storage: &mut Ratche ); } -fn continue_after_restart(alice_storage: &mut RatchetStorage, bob_storage: &mut RatchetStorage) { - // Load persisted states +fn continue_after_restart(alice_storage: SqliteRatchetStore, bob_storage: SqliteRatchetStore) { let conv_id = "conv1"; - let mut alice_session: RatchetSession = + let mut alice_session: RatchetSession = RatchetSession::open(alice_storage, conv_id).unwrap(); - let mut bob_session: RatchetSession = + let mut bob_session: RatchetSession = RatchetSession::open(bob_storage, conv_id).unwrap(); - println!(" Sessions restored for Alice and Bob",); + println!(" Sessions restored for Alice and Bob"); - // Continue conversation - let (ct, header) = { - let result = alice_session - .encrypt_message(b"Message after restart!") - .unwrap(); - println!(" Alice sent: \"Message after restart!\""); - result - }; + let (ct, header) = alice_session + .encrypt_message(b"Message after restart!") + .unwrap(); + println!(" Alice sent: \"Message after restart!\""); - { - let pt = bob_session.decrypt_message(&ct, header).unwrap(); - println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); - } + let pt = bob_session.decrypt_message(&ct, header).unwrap(); + println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); println!( " Final state: Alice msg_send={}, Bob msg_recv={}", diff --git a/double-ratchets/src/keypair.rs b/double-ratchets/src/keypair.rs index 7943646..00f8724 100644 --- a/double-ratchets/src/keypair.rs +++ b/double-ratchets/src/keypair.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use rand_core::OsRng; use x25519_dalek::{PublicKey, StaticSecret}; use zeroize::{Zeroize, ZeroizeOnDrop}; @@ -10,6 +12,15 @@ pub struct InstallationKeyPair { public: PublicKey, } +impl Debug for InstallationKeyPair { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InstallationKeyPair") + .field("public", &self.public.as_bytes()) + .field("secret", &"[REDACTED]") + .finish() + } +} + impl InstallationKeyPair { pub fn generate() -> Self { let secret = StaticSecret::random_from_rng(OsRng); @@ -36,4 +47,12 @@ impl InstallationKeyPair { let public = PublicKey::from(&secret); Self { secret, public } } + + /// Import the key pair from both secret and public bytes. + pub fn from_bytes(secret: [u8; 32], public: [u8; 32]) -> Self { + Self { + secret: StaticSecret::from(secret), + public: PublicKey::from(public), + } + } } diff --git a/double-ratchets/src/lib.rs b/double-ratchets/src/lib.rs index c5abe43..b8f73cd 100644 --- a/double-ratchets/src/lib.rs +++ b/double-ratchets/src/lib.rs @@ -10,5 +10,7 @@ pub mod types; pub use keypair::InstallationKeyPair; pub use state::{Header, RatchetState, SkippedKey}; -pub use storage::StorageConfig; -pub use storage::{RatchetSession, RatchetStorage, SessionError}; +pub use storage::{ + EphemeralStore, RatchetSession, RatchetStateData, RatchetStore, SessionError, SkippedKeyId, + SkippedMessageKey, SqliteRatchetStore, StoreError, +}; diff --git a/double-ratchets/src/storage/db.rs b/double-ratchets/src/storage/db.rs deleted file mode 100644 index 2c216d7..0000000 --- a/double-ratchets/src/storage/db.rs +++ /dev/null @@ -1,320 +0,0 @@ -//! Ratchet-specific storage implementation. - -use std::collections::HashSet; - -use storage::{SqliteDb, StorageError, params}; - -use super::types::RatchetStateRecord; -use crate::{ - hkdf::HkdfInfo, - state::{RatchetState, SkippedKey}, -}; - -/// Schema for ratchet state tables. -const RATCHET_SCHEMA: &str = " - CREATE TABLE IF NOT EXISTS ratchet_state ( - conversation_id TEXT PRIMARY KEY, - root_key BLOB NOT NULL, - sending_chain BLOB, - receiving_chain BLOB, - dh_self_secret BLOB NOT NULL, - dh_remote BLOB, - msg_send INTEGER NOT NULL, - msg_recv INTEGER NOT NULL, - prev_chain_len INTEGER NOT NULL - ); - - CREATE TABLE IF NOT EXISTS skipped_keys ( - conversation_id TEXT NOT NULL, - public_key BLOB NOT NULL, - msg_num INTEGER NOT NULL, - message_key BLOB NOT NULL, - created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), - PRIMARY KEY (conversation_id, public_key, msg_num), - FOREIGN KEY (conversation_id) REFERENCES ratchet_state(conversation_id) ON DELETE CASCADE - ); - - CREATE INDEX IF NOT EXISTS idx_skipped_keys_conversation - ON skipped_keys(conversation_id); -"; - -/// Ratchet-specific storage operations. -/// -/// This struct wraps a `SqliteDb` and provides domain-specific -/// storage operations for ratchet state. -pub struct RatchetStorage { - db: SqliteDb, -} - -impl RatchetStorage { - /// Opens an existing encrypted database file. - pub fn new(path: &str, key: &str) -> Result { - let db = SqliteDb::sqlcipher(path.to_string(), key.to_string())?; - Self::run_migration(db) - } - - /// Creates an in-memory storage (useful for testing). - pub fn in_memory() -> Result { - let db = SqliteDb::in_memory()?; - Self::run_migration(db) - } - - /// Creates a new ratchet storage with the given database. - fn run_migration(db: SqliteDb) -> Result { - // Initialize schema - db.connection().execute_batch(RATCHET_SCHEMA)?; - Ok(Self { db }) - } - - /// Saves the ratchet state for a conversation. - pub fn save( - &mut self, - conversation_id: &str, - state: &RatchetState, - ) -> Result<(), StorageError> { - let tx = self.db.transaction()?; - - let data = RatchetStateRecord::from(state); - let skipped_keys: Vec = state.skipped_keys(); - - // Upsert main state - tx.execute( - " - INSERT INTO ratchet_state ( - conversation_id, root_key, sending_chain, receiving_chain, - dh_self_secret, dh_remote, msg_send, msg_recv, prev_chain_len - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) - ON CONFLICT(conversation_id) DO UPDATE SET - root_key = excluded.root_key, - sending_chain = excluded.sending_chain, - receiving_chain = excluded.receiving_chain, - dh_self_secret = excluded.dh_self_secret, - dh_remote = excluded.dh_remote, - msg_send = excluded.msg_send, - msg_recv = excluded.msg_recv, - prev_chain_len = excluded.prev_chain_len - ", - params![ - conversation_id, - data.root_key.as_slice(), - data.sending_chain.as_ref().map(|c| c.as_slice()), - data.receiving_chain.as_ref().map(|c| c.as_slice()), - data.dh_self_secret.as_slice(), - data.dh_remote.as_ref().map(|c| c.as_slice()), - data.msg_send, - data.msg_recv, - data.prev_chain_len, - ], - )?; - - // Sync skipped keys - sync_skipped_keys(&tx, conversation_id, skipped_keys)?; - - tx.commit()?; - Ok(()) - } - - /// Loads the ratchet state for a conversation. - pub fn load( - &self, - conversation_id: &str, - ) -> Result, StorageError> { - let data = self.load_state_data(conversation_id)?; - let skipped_keys = self.load_skipped_keys(conversation_id)?; - Ok(data.into_ratchet_state(skipped_keys)) - } - - fn load_state_data(&self, conversation_id: &str) -> Result { - let conn = self.db.connection(); - let mut stmt = conn.prepare( - " - SELECT root_key, sending_chain, receiving_chain, dh_self_secret, - dh_remote, msg_send, msg_recv, prev_chain_len - FROM ratchet_state - WHERE conversation_id = ?1 - ", - )?; - - stmt.query_row(params![conversation_id], |row| { - Ok(RatchetStateRecord { - root_key: blob_to_array(row.get::<_, Vec>(0)?), - sending_chain: row.get::<_, Option>>(1)?.map(blob_to_array), - receiving_chain: row.get::<_, Option>>(2)?.map(blob_to_array), - dh_self_secret: blob_to_array(row.get::<_, Vec>(3)?), - dh_remote: row.get::<_, Option>>(4)?.map(blob_to_array), - msg_send: row.get(5)?, - msg_recv: row.get(6)?, - prev_chain_len: row.get(7)?, - }) - }) - .map_err(|e| match e { - storage::RusqliteError::QueryReturnedNoRows => { - StorageError::NotFound(conversation_id.to_string()) - } - e => StorageError::Database(e.to_string()), - }) - } - - fn load_skipped_keys(&self, conversation_id: &str) -> Result, StorageError> { - let conn = self.db.connection(); - let mut stmt = conn.prepare( - " - SELECT public_key, msg_num, message_key - FROM skipped_keys - WHERE conversation_id = ?1 - ", - )?; - - let rows = stmt.query_map(params![conversation_id], |row| { - Ok(SkippedKey { - public_key: blob_to_array(row.get::<_, Vec>(0)?), - msg_num: row.get(1)?, - message_key: blob_to_array(row.get::<_, Vec>(2)?), - }) - })?; - - rows.collect::, _>>() - .map_err(|e| StorageError::Database(e.to_string())) - } - - /// Checks if a conversation exists. - pub fn exists(&self, conversation_id: &str) -> Result { - let conn = self.db.connection(); - let count: i64 = conn.query_row( - "SELECT COUNT(*) FROM ratchet_state WHERE conversation_id = ?1", - params![conversation_id], - |row| row.get(0), - )?; - Ok(count > 0) - } - - /// Deletes a conversation and its skipped keys. - pub fn delete(&mut self, conversation_id: &str) -> Result<(), StorageError> { - let tx = self.db.transaction()?; - tx.execute( - "DELETE FROM skipped_keys WHERE conversation_id = ?1", - params![conversation_id], - )?; - tx.execute( - "DELETE FROM ratchet_state WHERE conversation_id = ?1", - params![conversation_id], - )?; - tx.commit()?; - Ok(()) - } - - /// Cleans up old skipped keys older than the given age in seconds. - pub fn cleanup_old_skipped_keys(&mut self, max_age_secs: i64) -> Result { - let conn = self.db.connection(); - let deleted = conn.execute( - "DELETE FROM skipped_keys WHERE created_at < strftime('%s', 'now') - ?1", - params![max_age_secs], - )?; - Ok(deleted) - } -} - -/// Syncs skipped keys efficiently by computing diff and only inserting/deleting changes. -fn sync_skipped_keys( - tx: &storage::Transaction, - conversation_id: &str, - current_keys: Vec, -) -> Result<(), StorageError> { - // Get existing keys from DB (just the identifiers) - let mut stmt = - tx.prepare("SELECT public_key, msg_num FROM skipped_keys WHERE conversation_id = ?1")?; - let existing: HashSet<([u8; 32], u32)> = stmt - .query_map(params![conversation_id], |row| { - Ok(( - blob_to_array(row.get::<_, Vec>(0)?), - row.get::<_, u32>(1)?, - )) - })? - .filter_map(|r| r.ok()) - .collect(); - - // Build set of current keys - let current_set: HashSet<([u8; 32], u32)> = current_keys - .iter() - .map(|sk| (sk.public_key, sk.msg_num)) - .collect(); - - // Delete keys that were removed (used for decryption) - for (pk, msg_num) in existing.difference(¤t_set) { - tx.execute( - "DELETE FROM skipped_keys WHERE conversation_id = ?1 AND public_key = ?2 AND msg_num = ?3", - params![conversation_id, pk.as_slice(), msg_num], - )?; - } - - // Insert new keys - for sk in ¤t_keys { - let key = (sk.public_key, sk.msg_num); - if !existing.contains(&key) { - tx.execute( - "INSERT INTO skipped_keys (conversation_id, public_key, msg_num, message_key) - VALUES (?1, ?2, ?3, ?4)", - params![ - conversation_id, - sk.public_key.as_slice(), - sk.msg_num, - sk.message_key.as_slice(), - ], - )?; - } - } - - Ok(()) -} - -fn blob_to_array(blob: Vec) -> [u8; N] { - blob.try_into() - .unwrap_or_else(|v: Vec| panic!("Expected {} bytes, got {}", N, v.len())) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{keypair::InstallationKeyPair, state::RatchetState, types::SharedSecret}; - - fn create_test_state() -> (RatchetState, SharedSecret) { - let shared_secret = [0x42u8; 32]; - let bob_keypair = InstallationKeyPair::generate(); - let state = RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); - (state, shared_secret) - } - - #[test] - fn test_save_and_load() { - let mut storage = RatchetStorage::in_memory().unwrap(); - let (state, _) = create_test_state(); - - storage.save("conv1", &state).unwrap(); - let loaded: RatchetState = storage.load("conv1").unwrap(); - - assert_eq!(state.root_key, loaded.root_key); - assert_eq!(state.msg_send, loaded.msg_send); - } - - #[test] - fn test_exists() { - let mut storage = RatchetStorage::in_memory().unwrap(); - let (state, _) = create_test_state(); - - assert!(!storage.exists("conv1").unwrap()); - storage.save("conv1", &state).unwrap(); - assert!(storage.exists("conv1").unwrap()); - } - - #[test] - fn test_delete() { - let mut storage = RatchetStorage::in_memory().unwrap(); - let (state, _) = create_test_state(); - - storage.save("conv1", &state).unwrap(); - assert!(storage.exists("conv1").unwrap()); - - storage.delete("conv1").unwrap(); - assert!(!storage.exists("conv1").unwrap()); - } -} diff --git a/double-ratchets/src/storage/ephemeral.rs b/double-ratchets/src/storage/ephemeral.rs new file mode 100644 index 0000000..5c49ff4 --- /dev/null +++ b/double-ratchets/src/storage/ephemeral.rs @@ -0,0 +1,171 @@ +//! In-memory ephemeral storage for testing. +//! +//! This store keeps all data in memory and is useful for testing +//! or scenarios where persistence is not needed. + +use std::collections::HashMap; + +use super::store::{RatchetStateData, RatchetStore, SkippedKeyId, SkippedMessageKey, StoreError}; + +/// In-memory storage implementation. +/// +/// All data is lost when the store is dropped. +#[derive(Debug, Default)] +pub struct EphemeralStore { + states: HashMap, + skipped_keys: HashMap>, +} + +impl EphemeralStore { + /// Creates a new empty ephemeral store. + pub fn new() -> Self { + Self::default() + } +} + +impl RatchetStore for EphemeralStore { + fn save_state( + &mut self, + conversation_id: &str, + state: &RatchetStateData, + ) -> Result<(), StoreError> { + self.states.insert(conversation_id.to_string(), state.clone()); + Ok(()) + } + + fn load_state(&self, conversation_id: &str) -> Result { + self.states + .get(conversation_id) + .cloned() + .ok_or_else(|| StoreError::NotFound(conversation_id.to_string())) + } + + fn exists(&self, conversation_id: &str) -> Result { + Ok(self.states.contains_key(conversation_id)) + } + + fn delete(&mut self, conversation_id: &str) -> Result<(), StoreError> { + self.states.remove(conversation_id); + self.skipped_keys.remove(conversation_id); + Ok(()) + } + + fn get_skipped_key( + &self, + conversation_id: &str, + id: &SkippedKeyId, + ) -> Result, StoreError> { + Ok(self + .skipped_keys + .get(conversation_id) + .and_then(|keys| keys.get(id)) + .map(|sk| sk.message_key)) + } + + fn add_skipped_key( + &mut self, + conversation_id: &str, + key: SkippedMessageKey, + ) -> Result<(), StoreError> { + self.skipped_keys + .entry(conversation_id.to_string()) + .or_default() + .insert(key.id.clone(), key); + Ok(()) + } + + fn remove_skipped_key( + &mut self, + conversation_id: &str, + id: &SkippedKeyId, + ) -> Result<(), StoreError> { + if let Some(keys) = self.skipped_keys.get_mut(conversation_id) { + keys.remove(id); + } + Ok(()) + } + + fn get_all_skipped_keys( + &self, + conversation_id: &str, + ) -> Result, StoreError> { + Ok(self + .skipped_keys + .get(conversation_id) + .map(|keys| keys.values().cloned().collect()) + .unwrap_or_default()) + } + + fn clear_skipped_keys(&mut self, conversation_id: &str) -> Result<(), StoreError> { + self.skipped_keys.remove(conversation_id); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::keypair::InstallationKeyPair; + + fn create_test_state() -> RatchetStateData { + RatchetStateData { + root_key: [0x42; 32], + sending_chain: Some([0x01; 32]), + receiving_chain: None, + dh_self: InstallationKeyPair::generate(), + dh_remote: Some([0x02; 32]), + msg_send: 0, + msg_recv: 0, + prev_chain_len: 0, + } + } + + #[test] + fn test_save_and_load() { + let mut store = EphemeralStore::new(); + let state = create_test_state(); + + store.save_state("conv1", &state).unwrap(); + let loaded = store.load_state("conv1").unwrap(); + + assert_eq!(state.root_key, loaded.root_key); + assert_eq!(state.msg_send, loaded.msg_send); + } + + #[test] + fn test_exists() { + let mut store = EphemeralStore::new(); + let state = create_test_state(); + + assert!(!store.exists("conv1").unwrap()); + store.save_state("conv1", &state).unwrap(); + assert!(store.exists("conv1").unwrap()); + } + + #[test] + fn test_skipped_keys() { + let mut store = EphemeralStore::new(); + let state = create_test_state(); + store.save_state("conv1", &state).unwrap(); + + let id = SkippedKeyId { + public_key: [0x01; 32], + msg_num: 5, + }; + let key = SkippedMessageKey { + id: id.clone(), + message_key: [0xAB; 32], + }; + + // Add key + store.add_skipped_key("conv1", key.clone()).unwrap(); + assert_eq!( + store.get_skipped_key("conv1", &id).unwrap(), + Some([0xAB; 32]) + ); + + // Remove key + store.remove_skipped_key("conv1", &id).unwrap(); + assert_eq!(store.get_skipped_key("conv1", &id).unwrap(), None); + } +} diff --git a/double-ratchets/src/storage/errors.rs b/double-ratchets/src/storage/errors.rs index 39f2ebc..cc4be09 100644 --- a/double-ratchets/src/storage/errors.rs +++ b/double-ratchets/src/storage/errors.rs @@ -13,4 +13,13 @@ pub enum SessionError { #[error("conversation already exists: {0}")] ConvAlreadyExists(String), + + #[error("conversation not found: {0}")] + ConvNotFound(String), + + #[error("storage backend error: {0}")] + StorageError(String), + + #[error("deserialization failed: {0}")] + DeserializationFailed(String), } diff --git a/double-ratchets/src/storage/mod.rs b/double-ratchets/src/storage/mod.rs index 354cae6..ca183b1 100644 --- a/double-ratchets/src/storage/mod.rs +++ b/double-ratchets/src/storage/mod.rs @@ -1,15 +1,23 @@ //! Storage module for persisting ratchet state. //! -//! This module provides storage implementations for the double ratchet state, -//! built on top of the shared `storage` crate. +//! This module provides a trait-based abstraction for storage, allowing +//! the double ratchet to be agnostic to how data is persisted. +//! +//! # Architecture +//! +//! - [`RatchetStore`] - Trait defining storage needs for double ratchet state +//! - [`RatchetSession`] - High-level wrapper with automatic persistence +//! - [`EphemeralStore`] - In-memory implementation for testing +//! - [`SqliteRatchetStore`] - SQLite/SQLCipher implementation for production -mod db; +mod ephemeral; mod errors; mod session; -mod types; +mod sqlite; +mod store; -pub use db::RatchetStorage; +pub use ephemeral::EphemeralStore; pub use errors::SessionError; pub use session::RatchetSession; -pub use storage::{SqliteDb, StorageConfig, StorageError}; -pub use types::RatchetStateRecord; +pub use sqlite::SqliteRatchetStore; +pub use store::{RatchetStateData, RatchetStore, SkippedKeyId, SkippedMessageKey, StoreError}; diff --git a/double-ratchets/src/storage/session.rs b/double-ratchets/src/storage/session.rs index 40be01f..65c55ef 100644 --- a/double-ratchets/src/storage/session.rs +++ b/double-ratchets/src/storage/session.rs @@ -1,147 +1,166 @@ //! Session wrapper for automatic state persistence. +use std::{collections::HashMap, marker::PhantomData}; + use x25519_dalek::PublicKey; use crate::{ - InstallationKeyPair, SessionError, + InstallationKeyPair, hkdf::HkdfInfo, state::{Header, RatchetState}, types::SharedSecret, }; -use super::RatchetStorage; +use super::{ + SessionError, + store::{RatchetStateData, RatchetStore, SkippedKeyId, SkippedMessageKey, StoreError}, +}; -/// A session wrapper that automatically persists ratchet state after operations. -/// Provides rollback semantics - state is only saved if the operation succeeds. -pub struct RatchetSession<'a, D: HkdfInfo + Clone> { - storage: &'a mut RatchetStorage, +/// Session wrapper with automatic persistence. +pub struct RatchetSession { + store: S, conversation_id: String, state: RatchetState, } -impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { +impl RatchetSession { /// Opens an existing session from storage. - pub fn open( - storage: &'a mut RatchetStorage, - conversation_id: impl Into, - ) -> Result { + pub fn open(store: S, conversation_id: impl Into) -> Result { let conversation_id = conversation_id.into(); - let state = storage.load(&conversation_id)?; + let data = store + .load_state(&conversation_id) + .map_err(|e| map_store_error(e, &conversation_id))?; + let skipped_keys = store + .get_all_skipped_keys(&conversation_id) + .map_err(|e| map_store_error(e, &conversation_id))?; + let state = state_from_data(data, skipped_keys); + Ok(Self { - storage, + store, conversation_id, state, }) } - /// Creates a new session and persists the initial state. + /// Creates a new session with the given state. pub fn create( - storage: &'a mut RatchetStorage, + mut store: S, conversation_id: impl Into, state: RatchetState, ) -> Result { let conversation_id = conversation_id.into(); - storage.save(&conversation_id, &state)?; + let data = state_to_data(&state); + store + .save_state(&conversation_id, &data) + .map_err(|e| map_store_error(e, &conversation_id))?; + + for key in get_skipped_keys(&state) { + store + .add_skipped_key(&conversation_id, key) + .map_err(|e| map_store_error(e, &conversation_id))?; + } + Ok(Self { - storage, + store, conversation_id, state, }) } - /// Initializes a new session as a sender and persists the initial state. + /// Creates sender session. pub fn create_sender_session( - storage: &'a mut RatchetStorage, + store: S, conversation_id: &str, shared_secret: SharedSecret, remote_pub: PublicKey, ) -> Result { - if storage.exists(conversation_id)? { + let temp_store = store; + if temp_store + .exists(conversation_id) + .map_err(|e| map_store_error(e, conversation_id))? + { return Err(SessionError::ConvAlreadyExists(conversation_id.to_string())); } let state = RatchetState::::init_sender(shared_secret, remote_pub); - Ok(Self::create(storage, conversation_id, state)?) + Self::create(temp_store, conversation_id, state) } - /// Initializes a new session as a receiver and persists the initial state. + /// Creates receiver session. pub fn create_receiver_session( - storage: &'a mut RatchetStorage, + store: S, conversation_id: &str, shared_secret: SharedSecret, dh_self: InstallationKeyPair, ) -> Result { - if storage.exists(conversation_id)? { + let temp_store = store; + if temp_store + .exists(conversation_id) + .map_err(|e| map_store_error(e, conversation_id))? + { return Err(SessionError::ConvAlreadyExists(conversation_id.to_string())); } - let state = RatchetState::::init_receiver(shared_secret, dh_self); - Ok(Self::create(storage, conversation_id, state)?) + Self::create(temp_store, conversation_id, state) } - /// Encrypts a message and persists the updated state. - /// If persistence fails, the in-memory state is NOT modified. + /// Encrypts a message. pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<(Vec, Header), SessionError> { - // Clone state for rollback let state_backup = self.state.clone(); - - // Perform encryption (modifies state) let result = self.state.encrypt_message(plaintext); - - // Try to persist - if let Err(e) = self.storage.save(&self.conversation_id, &self.state) { - // Rollback + if let Err(e) = self.persist_state() { self.state = state_backup; - return Err(e.into()); + return Err(e); } - Ok(result) } - /// Decrypts a message and persists the updated state. - /// If decryption or persistence fails, the in-memory state is NOT modified. + /// Decrypts a message. pub fn decrypt_message( &mut self, ciphertext_with_nonce: &[u8], header: Header, ) -> Result, SessionError> { - // Clone state for rollback let state_backup = self.state.clone(); - - // Perform decryption (modifies state) let plaintext = match self.state.decrypt_message(ciphertext_with_nonce, header) { Ok(pt) => pt, Err(e) => { - // Rollback on decrypt failure self.state = state_backup; return Err(e.into()); } }; - - // Try to persist - if let Err(e) = self.storage.save(&self.conversation_id, &self.state) { - // Rollback + if let Err(e) = self.persist_state() { self.state = state_backup; - return Err(e.into()); + return Err(e); } - Ok(plaintext) } - /// Returns a reference to the current state (read-only). + fn persist_state(&mut self) -> Result<(), SessionError> { + let data = state_to_data(&self.state); + self.store + .save_state(&self.conversation_id, &data) + .map_err(|e| map_store_error(e, &self.conversation_id))?; + self.store + .clear_skipped_keys(&self.conversation_id) + .map_err(|e| map_store_error(e, &self.conversation_id))?; + for key in get_skipped_keys(&self.state) { + self.store + .add_skipped_key(&self.conversation_id, key) + .map_err(|e| map_store_error(e, &self.conversation_id))?; + } + Ok(()) + } + pub fn state(&self) -> &RatchetState { &self.state } - /// Returns the conversation ID. pub fn conversation_id(&self) -> &str { &self.conversation_id } - /// Manually saves the current state. pub fn save(&mut self) -> Result<(), SessionError> { - self.storage - .save(&self.conversation_id, &self.state) - .map_err(|error| error.into()) + self.persist_state() } pub fn msg_send(&self) -> u32 { @@ -151,202 +170,113 @@ impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { pub fn msg_recv(&self) -> u32 { self.state.msg_recv } + + pub fn into_store(self) -> S { + self.store + } +} + +fn state_to_data(state: &RatchetState) -> RatchetStateData { + RatchetStateData { + root_key: state.root_key, + sending_chain: state.sending_chain, + receiving_chain: state.receiving_chain, + dh_self: state.dh_self.clone(), + dh_remote: state.dh_remote.map(|pk| pk.to_bytes()), + msg_send: state.msg_send, + msg_recv: state.msg_recv, + prev_chain_len: state.prev_chain_len, + } +} + +fn state_from_data( + data: RatchetStateData, + skipped_keys: Vec, +) -> RatchetState { + let skipped_map = skipped_keys + .into_iter() + .map(|sk| { + let pk = PublicKey::from(sk.id.public_key); + ((pk, sk.id.msg_num), sk.message_key) + }) + .collect::>(); + + RatchetState { + root_key: data.root_key, + sending_chain: data.sending_chain, + receiving_chain: data.receiving_chain, + dh_self: data.dh_self, + dh_remote: data.dh_remote.map(PublicKey::from), + msg_send: data.msg_send, + msg_recv: data.msg_recv, + prev_chain_len: data.prev_chain_len, + skipped_keys: skipped_map, + _domain: PhantomData, + } +} + +fn get_skipped_keys(state: &RatchetState) -> Vec { + state + .skipped_keys + .iter() + .map(|((pk, msg_num), mk)| SkippedMessageKey { + id: SkippedKeyId { + public_key: pk.to_bytes(), + msg_num: *msg_num, + }, + message_key: *mk, + }) + .collect() +} + +fn map_store_error(e: StoreError, conversation_id: &str) -> SessionError { + match e { + StoreError::NotFound(_) => SessionError::ConvNotFound(conversation_id.to_string()), + StoreError::AlreadyExists(_) => { + SessionError::ConvAlreadyExists(conversation_id.to_string()) + } + StoreError::Storage(s) => SessionError::StorageError(s), + StoreError::Serialization(s) => SessionError::DeserializationFailed(s), + } } #[cfg(test)] mod tests { use super::*; - use crate::hkdf::DefaultDomain; - - fn create_test_storage() -> RatchetStorage { - RatchetStorage::in_memory().unwrap() - } + use crate::{hkdf::DefaultDomain, storage::EphemeralStore}; #[test] fn test_session_create_and_open() { - let mut storage = create_test_storage(); - + let store = EphemeralStore::new(); let shared_secret = [0x42; 32]; let bob_keypair = InstallationKeyPair::generate(); let alice: RatchetState = RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); - // Create session - { - let session = RatchetSession::create(&mut storage, "conv1", alice).unwrap(); - assert_eq!(session.conversation_id(), "conv1"); - } + let session = RatchetSession::create(store, "conv1", alice).unwrap(); + assert_eq!(session.conversation_id(), "conv1"); - // Open existing session - { - let session: RatchetSession = - RatchetSession::open(&mut storage, "conv1").unwrap(); - assert_eq!(session.state().msg_send, 0); - } + let store = session.into_store(); + let session: RatchetSession<_, DefaultDomain> = + RatchetSession::open(store, "conv1").unwrap(); + assert_eq!(session.state().msg_send, 0); } #[test] fn test_session_encrypt_persists() { - let mut storage = create_test_storage(); - + let store = EphemeralStore::new(); let shared_secret = [0x42; 32]; let bob_keypair = InstallationKeyPair::generate(); let alice: RatchetState = RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); - // Create and encrypt - { - let mut session = RatchetSession::create(&mut storage, "conv1", alice).unwrap(); - session.encrypt_message(b"Hello").unwrap(); - assert_eq!(session.state().msg_send, 1); - } + let mut session = RatchetSession::create(store, "conv1", alice).unwrap(); + session.encrypt_message(b"Hello").unwrap(); + assert_eq!(session.state().msg_send, 1); - // Reopen - state should be persisted - { - let session: RatchetSession = - RatchetSession::open(&mut storage, "conv1").unwrap(); - assert_eq!(session.state().msg_send, 1); - } - } - - #[test] - fn test_session_full_conversation() { - let mut storage = create_test_storage(); - - let shared_secret = [0x42; 32]; - let bob_keypair = InstallationKeyPair::generate(); - let alice: RatchetState = - RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); - let bob: RatchetState = - RatchetState::init_receiver(shared_secret, bob_keypair); - - // Alice sends - let (ct, header) = { - let mut session = RatchetSession::create(&mut storage, "alice", alice).unwrap(); - session.encrypt_message(b"Hello Bob").unwrap() - }; - - // Bob receives - let plaintext = { - let mut session = RatchetSession::create(&mut storage, "bob", bob).unwrap(); - session.decrypt_message(&ct, header).unwrap() - }; - assert_eq!(plaintext, b"Hello Bob"); - - // Bob replies - let (ct2, header2) = { - let mut session: RatchetSession = - RatchetSession::open(&mut storage, "bob").unwrap(); - session.encrypt_message(b"Hi Alice").unwrap() - }; - - // Alice receives - let plaintext2 = { - let mut session: RatchetSession = - RatchetSession::open(&mut storage, "alice").unwrap(); - session.decrypt_message(&ct2, header2).unwrap() - }; - assert_eq!(plaintext2, b"Hi Alice"); - } - - #[test] - fn test_session_open_or_create() { - let mut storage = create_test_storage(); - - let shared_secret = [0x42; 32]; - let bob_keypair = InstallationKeyPair::generate(); - let bob_pub = bob_keypair.public().clone(); - - // First call creates - { - let session: RatchetSession = RatchetSession::create_sender_session( - &mut storage, - "conv1", - shared_secret, - bob_pub.clone(), - ) - .unwrap(); - assert_eq!(session.state().msg_send, 0); - } - - // Second call opens existing - { - let mut session: RatchetSession = - RatchetSession::open(&mut storage, "conv1").unwrap(); - session.encrypt_message(b"test").unwrap(); - } - - // Verify persistence - { - let session: RatchetSession = - RatchetSession::open(&mut storage, "conv1").unwrap(); - assert_eq!(session.state().msg_send, 1); - } - } - - #[test] - fn test_create_sender_session_fails_when_conversation_exists() { - let mut storage = create_test_storage(); - - let shared_secret = [0x42; 32]; - let bob_keypair = InstallationKeyPair::generate(); - let bob_pub = bob_keypair.public().clone(); - - // First creation succeeds - { - let _session: RatchetSession = RatchetSession::create_sender_session( - &mut storage, - "conv1", - shared_secret, - bob_pub.clone(), - ) - .unwrap(); - } - - // Second creation should fail with ConversationAlreadyExists - { - let result: Result, _> = - RatchetSession::create_sender_session( - &mut storage, - "conv1", - shared_secret, - bob_pub.clone(), - ); - - assert!(matches!(result, Err(SessionError::ConvAlreadyExists(_)))); - } - } - - #[test] - fn test_create_receiver_session_fails_when_conversation_exists() { - let mut storage = create_test_storage(); - - let shared_secret = [0x42; 32]; - let bob_keypair = InstallationKeyPair::generate(); - - // First creation succeeds - { - let _session: RatchetSession = RatchetSession::create_receiver_session( - &mut storage, - "conv1", - shared_secret, - bob_keypair, - ) - .unwrap(); - } - - // Second creation should fail with ConversationAlreadyExists - { - let another_keypair = InstallationKeyPair::generate(); - let result: Result, _> = - RatchetSession::create_receiver_session( - &mut storage, - "conv1", - shared_secret, - another_keypair, - ); - - assert!(matches!(result, Err(SessionError::ConvAlreadyExists(_)))); - } + let store = session.into_store(); + let session: RatchetSession<_, DefaultDomain> = + RatchetSession::open(store, "conv1").unwrap(); + assert_eq!(session.state().msg_send, 1); } } diff --git a/double-ratchets/src/storage/sqlite.rs b/double-ratchets/src/storage/sqlite.rs new file mode 100644 index 0000000..e4d9751 --- /dev/null +++ b/double-ratchets/src/storage/sqlite.rs @@ -0,0 +1,355 @@ +//! SQLite/SQLCipher implementation of RatchetStore. + +use storage::{params, RusqliteError, SqliteDb}; + +use super::store::{RatchetStateData, RatchetStore, SkippedKeyId, SkippedMessageKey, StoreError}; +use crate::keypair::InstallationKeyPair; +use crate::types::MessageKey; + +/// Schema for ratchet state tables. +const RATCHET_SCHEMA: &str = " + CREATE TABLE IF NOT EXISTS ratchet_state ( + conversation_id TEXT PRIMARY KEY, + root_key BLOB NOT NULL, + sending_chain BLOB, + receiving_chain BLOB, + dh_self_secret BLOB NOT NULL, + dh_self_public BLOB NOT NULL, + dh_remote BLOB, + msg_send INTEGER NOT NULL, + msg_recv INTEGER NOT NULL, + prev_chain_len INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS skipped_keys ( + conversation_id TEXT NOT NULL, + public_key BLOB NOT NULL, + msg_num INTEGER NOT NULL, + message_key BLOB NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + PRIMARY KEY (conversation_id, public_key, msg_num), + FOREIGN KEY (conversation_id) REFERENCES ratchet_state(conversation_id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_skipped_keys_conversation + ON skipped_keys(conversation_id); +"; + +/// SQLite/SQLCipher backed ratchet store. +pub struct SqliteRatchetStore { + db: SqliteDb, +} + +impl SqliteRatchetStore { + /// Creates a new encrypted SQLite store. + pub fn new(path: &str, key: &str) -> Result { + let db = SqliteDb::sqlcipher(path.to_string(), key.to_string()) + .map_err(|e| StoreError::Storage(e.to_string()))?; + Self::init(db) + } + + /// Creates an in-memory store (useful for testing). + pub fn in_memory() -> Result { + let db = SqliteDb::in_memory().map_err(|e| StoreError::Storage(e.to_string()))?; + Self::init(db) + } + + fn init(db: SqliteDb) -> Result { + db.connection() + .execute_batch(RATCHET_SCHEMA) + .map_err(|e| StoreError::Storage(e.to_string()))?; + Ok(Self { db }) + } +} + +impl RatchetStore for SqliteRatchetStore { + fn save_state( + &mut self, + conversation_id: &str, + state: &RatchetStateData, + ) -> Result<(), StoreError> { + let conn = self.db.connection(); + conn.execute( + " + INSERT INTO ratchet_state ( + conversation_id, root_key, sending_chain, receiving_chain, + dh_self_secret, dh_self_public, dh_remote, msg_send, msg_recv, prev_chain_len + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) + ON CONFLICT(conversation_id) DO UPDATE SET + root_key = excluded.root_key, + sending_chain = excluded.sending_chain, + receiving_chain = excluded.receiving_chain, + dh_self_secret = excluded.dh_self_secret, + dh_self_public = excluded.dh_self_public, + dh_remote = excluded.dh_remote, + msg_send = excluded.msg_send, + msg_recv = excluded.msg_recv, + prev_chain_len = excluded.prev_chain_len + ", + params![ + conversation_id, + state.root_key.as_slice(), + state.sending_chain.as_ref().map(|c| c.as_slice()), + state.receiving_chain.as_ref().map(|c| c.as_slice()), + state.dh_self.secret_bytes().as_slice(), + state.dh_self.public().as_bytes().as_slice(), + state.dh_remote.as_ref().map(|c| c.as_slice()), + state.msg_send, + state.msg_recv, + state.prev_chain_len, + ], + ) + .map_err(|e| StoreError::Storage(e.to_string()))?; + Ok(()) + } + + fn load_state(&self, conversation_id: &str) -> Result { + let conn = self.db.connection(); + let mut stmt = conn + .prepare( + " + SELECT root_key, sending_chain, receiving_chain, dh_self_secret, dh_self_public, + dh_remote, msg_send, msg_recv, prev_chain_len + FROM ratchet_state + WHERE conversation_id = ?1 + ", + ) + .map_err(|e| StoreError::Storage(e.to_string()))?; + + stmt.query_row(params![conversation_id], |row| { + let secret_bytes: Vec = row.get(3)?; + let public_bytes: Vec = row.get(4)?; + + Ok(RatchetStateData { + root_key: blob_to_array(row.get::<_, Vec>(0)?), + sending_chain: row.get::<_, Option>>(1)?.map(blob_to_array), + receiving_chain: row.get::<_, Option>>(2)?.map(blob_to_array), + dh_self: InstallationKeyPair::from_bytes( + blob_to_array(secret_bytes), + blob_to_array(public_bytes), + ), + dh_remote: row.get::<_, Option>>(5)?.map(blob_to_array), + msg_send: row.get(6)?, + msg_recv: row.get(7)?, + prev_chain_len: row.get(8)?, + }) + }) + .map_err(|e| match e { + RusqliteError::QueryReturnedNoRows => { + StoreError::NotFound(conversation_id.to_string()) + } + e => StoreError::Storage(e.to_string()), + }) + } + + fn exists(&self, conversation_id: &str) -> Result { + let conn = self.db.connection(); + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM ratchet_state WHERE conversation_id = ?1", + params![conversation_id], + |row| row.get(0), + ) + .map_err(|e| StoreError::Storage(e.to_string()))?; + Ok(count > 0) + } + + fn delete(&mut self, conversation_id: &str) -> Result<(), StoreError> { + let conn = self.db.connection(); + // Skipped keys are deleted via CASCADE + conn.execute( + "DELETE FROM ratchet_state WHERE conversation_id = ?1", + params![conversation_id], + ) + .map_err(|e| StoreError::Storage(e.to_string()))?; + Ok(()) + } + + fn get_skipped_key( + &self, + conversation_id: &str, + id: &SkippedKeyId, + ) -> Result, StoreError> { + let conn = self.db.connection(); + let result: Result, _> = conn.query_row( + "SELECT message_key FROM skipped_keys + WHERE conversation_id = ?1 AND public_key = ?2 AND msg_num = ?3", + params![conversation_id, id.public_key.as_slice(), id.msg_num], + |row| row.get(0), + ); + + match result { + Ok(bytes) => Ok(Some(blob_to_array(bytes))), + Err(RusqliteError::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(StoreError::Storage(e.to_string())), + } + } + + fn add_skipped_key( + &mut self, + conversation_id: &str, + key: SkippedMessageKey, + ) -> Result<(), StoreError> { + let conn = self.db.connection(); + conn.execute( + "INSERT OR REPLACE INTO skipped_keys (conversation_id, public_key, msg_num, message_key) + VALUES (?1, ?2, ?3, ?4)", + params![ + conversation_id, + key.id.public_key.as_slice(), + key.id.msg_num, + key.message_key.as_slice(), + ], + ) + .map_err(|e| StoreError::Storage(e.to_string()))?; + Ok(()) + } + + fn remove_skipped_key( + &mut self, + conversation_id: &str, + id: &SkippedKeyId, + ) -> Result<(), StoreError> { + let conn = self.db.connection(); + conn.execute( + "DELETE FROM skipped_keys WHERE conversation_id = ?1 AND public_key = ?2 AND msg_num = ?3", + params![conversation_id, id.public_key.as_slice(), id.msg_num], + ) + .map_err(|e| StoreError::Storage(e.to_string()))?; + Ok(()) + } + + fn get_all_skipped_keys( + &self, + conversation_id: &str, + ) -> Result, StoreError> { + let conn = self.db.connection(); + let mut stmt = conn + .prepare( + "SELECT public_key, msg_num, message_key FROM skipped_keys WHERE conversation_id = ?1", + ) + .map_err(|e| StoreError::Storage(e.to_string()))?; + + let rows = stmt + .query_map(params![conversation_id], |row| { + Ok(SkippedMessageKey { + id: SkippedKeyId { + public_key: blob_to_array(row.get::<_, Vec>(0)?), + msg_num: row.get(1)?, + }, + message_key: blob_to_array(row.get::<_, Vec>(2)?), + }) + }) + .map_err(|e| StoreError::Storage(e.to_string()))?; + + rows.collect::, _>>() + .map_err(|e| StoreError::Storage(e.to_string())) + } + + fn clear_skipped_keys(&mut self, conversation_id: &str) -> Result<(), StoreError> { + let conn = self.db.connection(); + conn.execute( + "DELETE FROM skipped_keys WHERE conversation_id = ?1", + params![conversation_id], + ) + .map_err(|e| StoreError::Storage(e.to_string()))?; + Ok(()) + } +} + +fn blob_to_array(blob: Vec) -> [u8; N] { + blob.try_into() + .unwrap_or_else(|v: Vec| panic!("Expected {} bytes, got {}", N, v.len())) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_state() -> RatchetStateData { + RatchetStateData { + root_key: [0x42; 32], + sending_chain: Some([0x01; 32]), + receiving_chain: None, + dh_self: InstallationKeyPair::generate(), + dh_remote: Some([0x02; 32]), + msg_send: 0, + msg_recv: 0, + prev_chain_len: 0, + } + } + + #[test] + fn test_save_and_load() { + let mut store = SqliteRatchetStore::in_memory().unwrap(); + let state = create_test_state(); + + store.save_state("conv1", &state).unwrap(); + let loaded = store.load_state("conv1").unwrap(); + + assert_eq!(state.root_key, loaded.root_key); + assert_eq!(state.msg_send, loaded.msg_send); + } + + #[test] + fn test_exists() { + let mut store = SqliteRatchetStore::in_memory().unwrap(); + let state = create_test_state(); + + assert!(!store.exists("conv1").unwrap()); + store.save_state("conv1", &state).unwrap(); + assert!(store.exists("conv1").unwrap()); + } + + #[test] + fn test_skipped_keys() { + let mut store = SqliteRatchetStore::in_memory().unwrap(); + let state = create_test_state(); + store.save_state("conv1", &state).unwrap(); + + let id = SkippedKeyId { + public_key: [0x01; 32], + msg_num: 5, + }; + let key = SkippedMessageKey { + id: id.clone(), + message_key: [0xAB; 32], + }; + + // Add key + store.add_skipped_key("conv1", key.clone()).unwrap(); + assert_eq!( + store.get_skipped_key("conv1", &id).unwrap(), + Some([0xAB; 32]) + ); + + // Get all + let all = store.get_all_skipped_keys("conv1").unwrap(); + assert_eq!(all.len(), 1); + + // Remove key + store.remove_skipped_key("conv1", &id).unwrap(); + assert_eq!(store.get_skipped_key("conv1", &id).unwrap(), None); + } + + #[test] + fn test_delete_cascades() { + let mut store = SqliteRatchetStore::in_memory().unwrap(); + let state = create_test_state(); + store.save_state("conv1", &state).unwrap(); + + let id = SkippedKeyId { + public_key: [0x01; 32], + msg_num: 5, + }; + let key = SkippedMessageKey { + id: id.clone(), + message_key: [0xAB; 32], + }; + store.add_skipped_key("conv1", key).unwrap(); + + // Delete conversation - skipped keys should be deleted too + store.delete("conv1").unwrap(); + assert!(!store.exists("conv1").unwrap()); + } +} diff --git a/double-ratchets/src/storage/store.rs b/double-ratchets/src/storage/store.rs new file mode 100644 index 0000000..7e5653e --- /dev/null +++ b/double-ratchets/src/storage/store.rs @@ -0,0 +1,115 @@ +//! Storage trait for double ratchet persistence. +//! +//! This module defines the `RatchetStore` trait that abstracts storage needs +//! for the double ratchet algorithm. Implementations can be backed by SQLite, +//! PostgreSQL, in-memory storage, or any other backend. + +use crate::{ + keypair::InstallationKeyPair, + types::{ChainKey, MessageKey, RootKey}, +}; + +/// Identifier for a skipped message key. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SkippedKeyId { + pub public_key: [u8; 32], + pub msg_num: u32, +} + +/// A skipped message key with its identifier. +#[derive(Debug, Clone)] +pub struct SkippedMessageKey { + pub id: SkippedKeyId, + pub message_key: MessageKey, +} + +/// The core ratchet state that needs to be persisted. +#[derive(Debug, Clone)] +pub struct RatchetStateData { + pub root_key: RootKey, + pub sending_chain: Option, + pub receiving_chain: Option, + pub dh_self: InstallationKeyPair, + pub dh_remote: Option<[u8; 32]>, + pub msg_send: u32, + pub msg_recv: u32, + pub prev_chain_len: u32, +} + +/// Error type for store operations. +#[derive(Debug, thiserror::Error)] +pub enum StoreError { + #[error("not found: {0}")] + NotFound(String), + + #[error("already exists: {0}")] + AlreadyExists(String), + + #[error("storage error: {0}")] + Storage(String), + + #[error("serialization error: {0}")] + Serialization(String), +} + +/// Trait defining storage requirements for the double ratchet algorithm. +/// +/// This trait abstracts the storage layer, allowing the double ratchet +/// implementation to be agnostic to the underlying storage mechanism. +/// +/// # Example Implementations +/// +/// - `SqliteRatchetStore` - SQLite/SQLCipher backed storage +/// - `EphemeralStore` - In-memory storage for testing +/// - `PostgresRatchetStore` - PostgreSQL backed storage (external) +pub trait RatchetStore { + // === Ratchet State Operations === + + /// Saves the ratchet state for a conversation. + fn save_state( + &mut self, + conversation_id: &str, + state: &RatchetStateData, + ) -> Result<(), StoreError>; + + /// Loads the ratchet state for a conversation. + fn load_state(&self, conversation_id: &str) -> Result; + + /// Checks if a conversation exists. + fn exists(&self, conversation_id: &str) -> Result; + + /// Deletes a conversation and all its associated data. + fn delete(&mut self, conversation_id: &str) -> Result<(), StoreError>; + + // === Skipped Message Key Operations === + + /// Gets a skipped message key if it exists. + fn get_skipped_key( + &self, + conversation_id: &str, + id: &SkippedKeyId, + ) -> Result, StoreError>; + + /// Adds a skipped message key. + fn add_skipped_key( + &mut self, + conversation_id: &str, + key: SkippedMessageKey, + ) -> Result<(), StoreError>; + + /// Removes a skipped message key (after successful decryption). + fn remove_skipped_key( + &mut self, + conversation_id: &str, + id: &SkippedKeyId, + ) -> Result<(), StoreError>; + + /// Gets all skipped keys for a conversation. + fn get_all_skipped_keys( + &self, + conversation_id: &str, + ) -> Result, StoreError>; + + /// Clears all skipped keys for a conversation. + fn clear_skipped_keys(&mut self, conversation_id: &str) -> Result<(), StoreError>; +} diff --git a/double-ratchets/src/storage/types.rs b/double-ratchets/src/storage/types.rs deleted file mode 100644 index 485e67a..0000000 --- a/double-ratchets/src/storage/types.rs +++ /dev/null @@ -1,65 +0,0 @@ -//! Storage types for ratchet state. - -use crate::{ - hkdf::HkdfInfo, - state::{RatchetState, SkippedKey}, - types::MessageKey, -}; -use x25519_dalek::PublicKey; - -/// Raw state data for storage (without generic parameter). -#[derive(Debug, Clone)] -pub struct RatchetStateRecord { - pub root_key: [u8; 32], - pub sending_chain: Option<[u8; 32]>, - pub receiving_chain: Option<[u8; 32]>, - pub dh_self_secret: [u8; 32], - pub dh_remote: Option<[u8; 32]>, - pub msg_send: u32, - pub msg_recv: u32, - pub prev_chain_len: u32, -} - -impl From<&RatchetState> for RatchetStateRecord { - fn from(state: &RatchetState) -> Self { - Self { - root_key: state.root_key, - sending_chain: state.sending_chain, - receiving_chain: state.receiving_chain, - dh_self_secret: *state.dh_self.secret_bytes(), - dh_remote: state.dh_remote.map(|pk| pk.to_bytes()), - msg_send: state.msg_send, - msg_recv: state.msg_recv, - prev_chain_len: state.prev_chain_len, - } - } -} - -impl RatchetStateRecord { - pub fn into_ratchet_state(self, skipped_keys: Vec) -> RatchetState { - use crate::keypair::InstallationKeyPair; - use std::collections::HashMap; - use std::marker::PhantomData; - - let dh_self = InstallationKeyPair::from_secret_bytes(self.dh_self_secret); - let dh_remote = self.dh_remote.map(PublicKey::from); - - let skipped: HashMap<(PublicKey, u32), MessageKey> = skipped_keys - .into_iter() - .map(|sk| ((PublicKey::from(sk.public_key), sk.msg_num), sk.message_key)) - .collect(); - - RatchetState { - root_key: self.root_key, - sending_chain: self.sending_chain, - receiving_chain: self.receiving_chain, - dh_self, - dh_remote, - msg_send: self.msg_send, - msg_recv: self.msg_recv, - prev_chain_len: self.prev_chain_len, - skipped_keys: skipped, - _domain: PhantomData, - } - } -}