diff --git a/Cargo.lock b/Cargo.lock index 1970c57..29a4609 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -221,8 +221,8 @@ dependencies = [ "hkdf", "rand", "rand_core", - "rusqlite", "safer-ffi", + "storage", "thiserror", "x25519-dalek", "zeroize", @@ -845,6 +845,14 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "storage" +version = "0.1.0" +dependencies = [ + "rusqlite", + "thiserror", +] + [[package]] name = "subtle" version = "2.6.1" diff --git a/Cargo.toml b/Cargo.toml index 4a118c2..0381c7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,9 @@ members = [ "conversations", "crypto", "double-ratchets", + "storage", ] [workspace.dependencies] blake2 = "0.10" +storage = { path = "storage" } diff --git a/double-ratchets/Cargo.toml b/double-ratchets/Cargo.toml index 048011d..6d28b1b 100644 --- a/double-ratchets/Cargo.toml +++ b/double-ratchets/Cargo.toml @@ -20,10 +20,10 @@ thiserror = "2" blake2 = "0.10.6" safer-ffi = "0.1.13" zeroize = "1.8.2" -rusqlite = { version = "0.35", optional = true, features = ["bundled"] } +storage = { workspace = true, optional = true, features = ["sqlite"] } [features] default = [] -storage = ["rusqlite"] -sqlcipher = ["storage", "rusqlite/bundled-sqlcipher-vendored-openssl"] +persist = ["storage"] +sqlcipher = ["persist", "storage/sqlcipher"] headers = ["safer-ffi/headers"] diff --git a/double-ratchets/examples/out_of_order_demo.rs b/double-ratchets/examples/out_of_order_demo.rs index a2dbb4d..7217f21 100644 --- a/double-ratchets/examples/out_of_order_demo.rs +++ b/double-ratchets/examples/out_of_order_demo.rs @@ -4,7 +4,7 @@ #[cfg(feature = "storage")] use double_ratchets::{ - InstallationKeyPair, RatchetState, SqliteStorage, StorageConfig, hkdf::DefaultDomain, + InstallationKeyPair, RatchetState, RatchetStorage, StorageConfig, hkdf::DefaultDomain, state::Header, }; @@ -18,7 +18,7 @@ fn main() { #[cfg(feature = "storage")] fn run_demo() { let mut storage = - SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage"); + RatchetStorage::with_config(StorageConfig::InMemory).expect("Failed to create storage"); // Setup let shared_secret = [0x42u8; 32]; @@ -77,7 +77,7 @@ fn run_demo() { let _ = std::fs::remove_file(db_path); // Redo with file storage - let mut storage = SqliteStorage::new(StorageConfig::File(db_path.to_string())) + let mut storage = RatchetStorage::with_config(StorageConfig::File(db_path.to_string())) .expect("Failed to create storage"); // Re-setup @@ -118,7 +118,7 @@ fn run_demo() { // Close and reopen storage (simulating app restart) drop(storage); let mut storage = - SqliteStorage::new(StorageConfig::File(db_path.to_string())).expect("Failed to reopen"); + RatchetStorage::with_config(StorageConfig::File(db_path.to_string())).expect("Failed to reopen"); let bob: RatchetState = storage.load("bob").unwrap(); println!( diff --git a/double-ratchets/examples/storage_demo.rs b/double-ratchets/examples/storage_demo.rs index ce05bd4..ef6e7c1 100644 --- a/double-ratchets/examples/storage_demo.rs +++ b/double-ratchets/examples/storage_demo.rs @@ -5,7 +5,7 @@ #[cfg(feature = "storage")] use double_ratchets::{ - InstallationKeyPair, RatchetSession, SqliteStorage, StorageConfig, hkdf::PrivateV1Domain, + InstallationKeyPair, RatchetSession, RatchetStorage, StorageConfig, hkdf::PrivateV1Domain, }; fn main() { @@ -37,9 +37,9 @@ fn main() { #[cfg(feature = "storage")] fn demo_in_memory() { let mut alice_storage = - SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage"); + RatchetStorage::with_config(StorageConfig::InMemory).expect("Failed to create storage"); let mut bob_storage = - SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage"); + RatchetStorage::with_config(StorageConfig::InMemory).expect("Failed to create storage"); run_conversation(&mut alice_storage, &mut bob_storage); } @@ -54,10 +54,10 @@ fn demo_file_storage() { // Initial conversation { - let mut alice_storage = SqliteStorage::new(StorageConfig::File(db_path_alice.to_string())) + let mut alice_storage = RatchetStorage::with_config(StorageConfig::File(db_path_alice.to_string())) .expect("Failed to create storage"); - let mut bob_storage = SqliteStorage::new(StorageConfig::File(db_path_bob.to_string())) + let mut bob_storage = RatchetStorage::with_config(StorageConfig::File(db_path_bob.to_string())) .expect("Failed to create storage"); println!(" Database created at: {}, {}", db_path_alice, db_path_bob); @@ -67,9 +67,9 @@ fn demo_file_storage() { // Simulate restart - reopen and continue println!("\n Simulating application restart..."); { - let mut alice_storage = SqliteStorage::new(StorageConfig::File(db_path_alice.to_string())) + let mut alice_storage = RatchetStorage::with_config(StorageConfig::File(db_path_alice.to_string())) .expect("Failed to reopen storage"); - let mut bob_storage = SqliteStorage::new(StorageConfig::File(db_path_bob.to_string())) + let mut bob_storage = RatchetStorage::with_config(StorageConfig::File(db_path_bob.to_string())) .expect("Failed to reopen storage"); continue_after_restart(&mut alice_storage, &mut bob_storage); } @@ -89,12 +89,12 @@ fn demo_sqlcipher() { // Initial conversation with encryption { - let mut alice_storage = SqliteStorage::new(StorageConfig::Encrypted { + let mut alice_storage = RatchetStorage::with_config(StorageConfig::Encrypted { path: alice_db_path.to_string(), key: encryption_key.to_string(), }) .expect("Failed to create encrypted storage"); - let mut bob_storage = SqliteStorage::new(StorageConfig::Encrypted { + let mut bob_storage = RatchetStorage::with_config(StorageConfig::Encrypted { path: bob_db_path.to_string(), key: encryption_key.to_string(), }) @@ -109,12 +109,12 @@ fn demo_sqlcipher() { // Restart with correct key println!("\n Simulating restart with encryption key..."); { - let mut alice_storage = SqliteStorage::new(StorageConfig::Encrypted { + let mut alice_storage = RatchetStorage::with_config(StorageConfig::Encrypted { path: alice_db_path.to_string(), key: encryption_key.to_string(), }) .expect("Failed to create encrypted storage"); - let mut bob_storage = SqliteStorage::new(StorageConfig::Encrypted { + let mut bob_storage = RatchetStorage::with_config(StorageConfig::Encrypted { path: bob_db_path.to_string(), key: encryption_key.to_string(), }) @@ -137,7 +137,7 @@ fn ensure_tmp_directory() { /// Simulates a conversation between Alice and Bob. /// Each party saves/loads state from storage for each operation. #[cfg(feature = "storage")] -fn run_conversation(alice_storage: &mut SqliteStorage, bob_storage: &mut SqliteStorage) { +fn run_conversation(alice_storage: &mut RatchetStorage, bob_storage: &mut RatchetStorage) { // === Setup: Simulate X3DH key exchange === let shared_secret = [0x42u8; 32]; // In reality, this comes from X3DH let bob_keypair = InstallationKeyPair::generate(); @@ -209,7 +209,7 @@ fn run_conversation(alice_storage: &mut SqliteStorage, bob_storage: &mut SqliteS } #[cfg(feature = "storage")] -fn continue_after_restart(alice_storage: &mut SqliteStorage, bob_storage: &mut SqliteStorage) { +fn continue_after_restart(alice_storage: &mut RatchetStorage, bob_storage: &mut RatchetStorage) { // Load persisted states let conv_id = "conv1"; diff --git a/double-ratchets/src/lib.rs b/double-ratchets/src/lib.rs index 1b9a566..c055348 100644 --- a/double-ratchets/src/lib.rs +++ b/double-ratchets/src/lib.rs @@ -4,11 +4,13 @@ pub mod ffi; pub mod hkdf; pub mod keypair; pub mod state; -#[cfg(feature = "storage")] +#[cfg(feature = "persist")] pub mod storage; pub mod types; pub use keypair::InstallationKeyPair; -pub use state::{Header, RatchetState}; -#[cfg(feature = "storage")] -pub use storage::{RatchetSession, SessionError, SqliteStorage, StorageConfig, StorageError}; +pub use state::{Header, RatchetState, SkippedKey}; +#[cfg(feature = "persist")] +pub use storage::StorageConfig; +#[cfg(feature = "persist")] +pub use storage::{RatchetSession, RatchetStorage, SessionError}; diff --git a/double-ratchets/src/storage/mod.rs b/double-ratchets/src/storage/mod.rs index e26ec70..1625f95 100644 --- a/double-ratchets/src/storage/mod.rs +++ b/double-ratchets/src/storage/mod.rs @@ -1,5 +1,13 @@ -mod session; -mod sqlite; +//! Storage module for persisting ratchet state. +//! +//! This module provides storage implementations for the double ratchet state, +//! built on top of the shared `storage` crate. +mod ratchet_storage; +mod session; +mod types; + +pub use ratchet_storage::RatchetStorage; pub use session::{RatchetSession, SessionError}; -pub use sqlite::{SqliteStorage, StorageConfig}; +pub use storage::{SqliteDb, StorageConfig, StorageError}; +pub use types::RatchetStateRecord; diff --git a/double-ratchets/src/storage/mod.rs.bak b/double-ratchets/src/storage/mod.rs.bak new file mode 100644 index 0000000..e26ec70 --- /dev/null +++ b/double-ratchets/src/storage/mod.rs.bak @@ -0,0 +1,5 @@ +mod session; +mod sqlite; + +pub use session::{RatchetSession, SessionError}; +pub use sqlite::{SqliteStorage, StorageConfig}; diff --git a/double-ratchets/src/storage/ratchet_storage.rs b/double-ratchets/src/storage/ratchet_storage.rs new file mode 100644 index 0000000..ad3b549 --- /dev/null +++ b/double-ratchets/src/storage/ratchet_storage.rs @@ -0,0 +1,320 @@ +//! Ratchet-specific storage implementation. + +use std::collections::HashSet; + +use storage::{SqliteDb, StorageBackend, StorageError, params}; + +use super::types::RatchetStateRecord; +use crate::{ + hkdf::HkdfInfo, + state::{RatchetState, SkippedKey}, +}; + +/// Schema for ratchet state tables. +const RATCHET_SCHEMA: &str = " + CREATE TABLE IF NOT EXISTS ratchet_state ( + conversation_id TEXT PRIMARY KEY, + root_key BLOB NOT NULL, + sending_chain BLOB, + receiving_chain BLOB, + dh_self_secret BLOB NOT NULL, + dh_remote BLOB, + msg_send INTEGER NOT NULL, + msg_recv INTEGER NOT NULL, + prev_chain_len INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS skipped_keys ( + conversation_id TEXT NOT NULL, + public_key BLOB NOT NULL, + msg_num INTEGER NOT NULL, + message_key BLOB NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + PRIMARY KEY (conversation_id, public_key, msg_num), + FOREIGN KEY (conversation_id) REFERENCES ratchet_state(conversation_id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_skipped_keys_conversation + ON skipped_keys(conversation_id); +"; + +/// Ratchet-specific storage operations. +/// +/// This struct wraps a `SqliteDb` and provides domain-specific +/// storage operations for ratchet state. +pub struct RatchetStorage { + db: SqliteDb, +} + +impl RatchetStorage { + /// Creates a new ratchet storage with the given database. + pub fn new(db: SqliteDb) -> Result { + // Initialize schema + db.execute_batch(RATCHET_SCHEMA)?; + Ok(Self { db }) + } + + /// Creates a new ratchet storage with the given configuration. + pub fn with_config(config: storage::StorageConfig) -> Result { + let db = SqliteDb::new(config)?; + Self::new(db) + } + + /// Creates an in-memory storage (useful for testing). + pub fn in_memory() -> Result { + let db = SqliteDb::in_memory()?; + Self::new(db) + } + + /// Saves the ratchet state for a conversation. + pub fn save( + &mut self, + conversation_id: &str, + state: &RatchetState, + ) -> Result<(), StorageError> { + let tx = self.db.transaction()?; + + let data = RatchetStateRecord::from(state); + let skipped_keys: Vec = state.skipped_keys(); + + // Upsert main state + tx.execute( + " + INSERT INTO ratchet_state ( + conversation_id, root_key, sending_chain, receiving_chain, + dh_self_secret, dh_remote, msg_send, msg_recv, prev_chain_len + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) + ON CONFLICT(conversation_id) DO UPDATE SET + root_key = excluded.root_key, + sending_chain = excluded.sending_chain, + receiving_chain = excluded.receiving_chain, + dh_self_secret = excluded.dh_self_secret, + dh_remote = excluded.dh_remote, + msg_send = excluded.msg_send, + msg_recv = excluded.msg_recv, + prev_chain_len = excluded.prev_chain_len + ", + params![ + conversation_id, + data.root_key.as_slice(), + data.sending_chain.as_ref().map(|c| c.as_slice()), + data.receiving_chain.as_ref().map(|c| c.as_slice()), + data.dh_self_secret.as_slice(), + data.dh_remote.as_ref().map(|c| c.as_slice()), + data.msg_send, + data.msg_recv, + data.prev_chain_len, + ], + )?; + + // Sync skipped keys + sync_skipped_keys(&tx, conversation_id, skipped_keys)?; + + tx.commit()?; + Ok(()) + } + + /// Loads the ratchet state for a conversation. + pub fn load( + &self, + conversation_id: &str, + ) -> Result, StorageError> { + let data = self.load_state_data(conversation_id)?; + let skipped_keys = self.load_skipped_keys(conversation_id)?; + Ok(data.into_ratchet_state(skipped_keys)) + } + + fn load_state_data(&self, conversation_id: &str) -> Result { + let conn = self.db.connection(); + let mut stmt = conn.prepare( + " + SELECT root_key, sending_chain, receiving_chain, dh_self_secret, + dh_remote, msg_send, msg_recv, prev_chain_len + FROM ratchet_state + WHERE conversation_id = ?1 + ", + )?; + + stmt.query_row(params![conversation_id], |row| { + Ok(RatchetStateRecord { + root_key: blob_to_array(row.get::<_, Vec>(0)?), + sending_chain: row.get::<_, Option>>(1)?.map(blob_to_array), + receiving_chain: row.get::<_, Option>>(2)?.map(blob_to_array), + dh_self_secret: blob_to_array(row.get::<_, Vec>(3)?), + dh_remote: row.get::<_, Option>>(4)?.map(blob_to_array), + msg_send: row.get(5)?, + msg_recv: row.get(6)?, + prev_chain_len: row.get(7)?, + }) + }) + .map_err(|e| match e { + storage::RusqliteError::QueryReturnedNoRows => { + StorageError::NotFound(conversation_id.to_string()) + } + e => StorageError::Database(e.to_string()), + }) + } + + fn load_skipped_keys(&self, conversation_id: &str) -> Result, StorageError> { + let conn = self.db.connection(); + let mut stmt = conn.prepare( + " + SELECT public_key, msg_num, message_key + FROM skipped_keys + WHERE conversation_id = ?1 + ", + )?; + + let rows = stmt.query_map(params![conversation_id], |row| { + Ok(SkippedKey { + public_key: blob_to_array(row.get::<_, Vec>(0)?), + msg_num: row.get(1)?, + message_key: blob_to_array(row.get::<_, Vec>(2)?), + }) + })?; + + rows.collect::, _>>() + .map_err(|e| StorageError::Database(e.to_string())) + } + + /// Checks if a conversation exists. + pub fn exists(&self, conversation_id: &str) -> Result { + let conn = self.db.connection(); + let count: i64 = conn.query_row( + "SELECT COUNT(*) FROM ratchet_state WHERE conversation_id = ?1", + params![conversation_id], + |row| row.get(0), + )?; + Ok(count > 0) + } + + /// Deletes a conversation and its skipped keys. + pub fn delete(&mut self, conversation_id: &str) -> Result<(), StorageError> { + let tx = self.db.transaction()?; + tx.execute( + "DELETE FROM skipped_keys WHERE conversation_id = ?1", + params![conversation_id], + )?; + tx.execute( + "DELETE FROM ratchet_state WHERE conversation_id = ?1", + params![conversation_id], + )?; + tx.commit()?; + Ok(()) + } + + /// Cleans up old skipped keys older than the given age in seconds. + pub fn cleanup_old_skipped_keys(&mut self, max_age_secs: i64) -> Result { + let conn = self.db.connection(); + let deleted = conn.execute( + "DELETE FROM skipped_keys WHERE created_at < strftime('%s', 'now') - ?1", + params![max_age_secs], + )?; + Ok(deleted) + } +} + +/// Syncs skipped keys efficiently by computing diff and only inserting/deleting changes. +fn sync_skipped_keys( + tx: &storage::Transaction, + conversation_id: &str, + current_keys: Vec, +) -> Result<(), StorageError> { + // Get existing keys from DB (just the identifiers) + let mut stmt = + tx.prepare("SELECT public_key, msg_num FROM skipped_keys WHERE conversation_id = ?1")?; + let existing: HashSet<([u8; 32], u32)> = stmt + .query_map(params![conversation_id], |row| { + Ok(( + blob_to_array(row.get::<_, Vec>(0)?), + row.get::<_, u32>(1)?, + )) + })? + .filter_map(|r| r.ok()) + .collect(); + + // Build set of current keys + let current_set: HashSet<([u8; 32], u32)> = current_keys + .iter() + .map(|sk| (sk.public_key, sk.msg_num)) + .collect(); + + // Delete keys that were removed (used for decryption) + for (pk, msg_num) in existing.difference(¤t_set) { + tx.execute( + "DELETE FROM skipped_keys WHERE conversation_id = ?1 AND public_key = ?2 AND msg_num = ?3", + params![conversation_id, pk.as_slice(), msg_num], + )?; + } + + // Insert new keys + for sk in ¤t_keys { + let key = (sk.public_key, sk.msg_num); + if !existing.contains(&key) { + tx.execute( + "INSERT INTO skipped_keys (conversation_id, public_key, msg_num, message_key) + VALUES (?1, ?2, ?3, ?4)", + params![ + conversation_id, + sk.public_key.as_slice(), + sk.msg_num, + sk.message_key.as_slice(), + ], + )?; + } + } + + Ok(()) +} + +fn blob_to_array(blob: Vec) -> [u8; N] { + blob.try_into() + .unwrap_or_else(|v: Vec| panic!("Expected {} bytes, got {}", N, v.len())) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{keypair::InstallationKeyPair, state::RatchetState, types::SharedSecret}; + + fn create_test_state() -> (RatchetState, SharedSecret) { + let shared_secret = [0x42u8; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let state = RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + (state, shared_secret) + } + + #[test] + fn test_save_and_load() { + let mut storage = RatchetStorage::in_memory().unwrap(); + let (state, _) = create_test_state(); + + storage.save("conv1", &state).unwrap(); + let loaded: RatchetState = storage.load("conv1").unwrap(); + + assert_eq!(state.root_key, loaded.root_key); + assert_eq!(state.msg_send, loaded.msg_send); + } + + #[test] + fn test_exists() { + let mut storage = RatchetStorage::in_memory().unwrap(); + let (state, _) = create_test_state(); + + assert!(!storage.exists("conv1").unwrap()); + storage.save("conv1", &state).unwrap(); + assert!(storage.exists("conv1").unwrap()); + } + + #[test] + fn test_delete() { + let mut storage = RatchetStorage::in_memory().unwrap(); + let (state, _) = create_test_state(); + + storage.save("conv1", &state).unwrap(); + assert!(storage.exists("conv1").unwrap()); + + storage.delete("conv1").unwrap(); + assert!(!storage.exists("conv1").unwrap()); + } +} diff --git a/double-ratchets/src/storage/session.rs b/double-ratchets/src/storage/session.rs index 399af8d..cdab488 100644 --- a/double-ratchets/src/storage/session.rs +++ b/double-ratchets/src/storage/session.rs @@ -1,4 +1,7 @@ +//! Session wrapper for automatic state persistence. + use x25519_dalek::PublicKey; +use storage::StorageError; use crate::{ InstallationKeyPair, @@ -8,12 +11,12 @@ use crate::{ types::SharedSecret, }; -use super::{SqliteStorage, StorageError}; +use super::RatchetStorage; /// A session wrapper that automatically persists ratchet state after operations. /// Provides rollback semantics - state is only saved if the operation succeeds. pub struct RatchetSession<'a, D: HkdfInfo + Clone> { - storage: &'a mut SqliteStorage, + storage: &'a mut RatchetStorage, conversation_id: String, state: RatchetState, } @@ -50,7 +53,7 @@ impl std::error::Error for SessionError {} impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { /// Opens an existing session from storage. pub fn open( - storage: &'a mut SqliteStorage, + storage: &'a mut RatchetStorage, conversation_id: impl Into, ) -> Result { let conversation_id = conversation_id.into(); @@ -64,7 +67,7 @@ impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { /// Creates a new session and persists the initial state. pub fn create( - storage: &'a mut SqliteStorage, + storage: &'a mut RatchetStorage, conversation_id: impl Into, state: RatchetState, ) -> Result { @@ -79,7 +82,7 @@ impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { /// Initializes a new session as a sender and persists the initial state. pub fn create_sender_session( - storage: &'a mut SqliteStorage, + storage: &'a mut RatchetStorage, conversation_id: impl Into, shared_secret: SharedSecret, remote_pub: PublicKey, @@ -90,7 +93,7 @@ impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { /// Initializes a new session as a receiver and persists the initial state. pub fn create_receiver_session( - storage: &'a mut SqliteStorage, + storage: &'a mut RatchetStorage, conversation_id: impl Into, shared_secret: SharedSecret, dh_self: InstallationKeyPair, @@ -180,10 +183,10 @@ impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { #[cfg(test)] mod tests { use super::*; - use crate::{hkdf::DefaultDomain, keypair::InstallationKeyPair, storage::StorageConfig}; + use crate::hkdf::DefaultDomain; - fn create_test_storage() -> SqliteStorage { - SqliteStorage::new(StorageConfig::InMemory).unwrap() + fn create_test_storage() -> RatchetStorage { + RatchetStorage::in_memory().unwrap() } #[test] diff --git a/double-ratchets/src/storage/session.rs.bak b/double-ratchets/src/storage/session.rs.bak new file mode 100644 index 0000000..399af8d --- /dev/null +++ b/double-ratchets/src/storage/session.rs.bak @@ -0,0 +1,310 @@ +use x25519_dalek::PublicKey; + +use crate::{ + InstallationKeyPair, + errors::RatchetError, + hkdf::HkdfInfo, + state::{Header, RatchetState}, + types::SharedSecret, +}; + +use super::{SqliteStorage, StorageError}; + +/// A session wrapper that automatically persists ratchet state after operations. +/// Provides rollback semantics - state is only saved if the operation succeeds. +pub struct RatchetSession<'a, D: HkdfInfo + Clone> { + storage: &'a mut SqliteStorage, + conversation_id: String, + state: RatchetState, +} + +#[derive(Debug)] +pub enum SessionError { + Storage(StorageError), + Ratchet(RatchetError), +} + +impl From for SessionError { + fn from(e: StorageError) -> Self { + SessionError::Storage(e) + } +} + +impl From for SessionError { + fn from(e: RatchetError) -> Self { + SessionError::Ratchet(e) + } +} + +impl std::fmt::Display for SessionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SessionError::Storage(e) => write!(f, "storage error: {}", e), + SessionError::Ratchet(e) => write!(f, "ratchet error: {}", e), + } + } +} + +impl std::error::Error for SessionError {} + +impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { + /// Opens an existing session from storage. + pub fn open( + storage: &'a mut SqliteStorage, + conversation_id: impl Into, + ) -> Result { + let conversation_id = conversation_id.into(); + let state = storage.load(&conversation_id)?; + Ok(Self { + storage, + conversation_id, + state, + }) + } + + /// Creates a new session and persists the initial state. + pub fn create( + storage: &'a mut SqliteStorage, + conversation_id: impl Into, + state: RatchetState, + ) -> Result { + let conversation_id = conversation_id.into(); + storage.save(&conversation_id, &state)?; + Ok(Self { + storage, + conversation_id, + state, + }) + } + + /// Initializes a new session as a sender and persists the initial state. + pub fn create_sender_session( + storage: &'a mut SqliteStorage, + conversation_id: impl Into, + shared_secret: SharedSecret, + remote_pub: PublicKey, + ) -> Result { + let state = RatchetState::::init_sender(shared_secret, remote_pub); + Self::create(storage, conversation_id, state) + } + + /// Initializes a new session as a receiver and persists the initial state. + pub fn create_receiver_session( + storage: &'a mut SqliteStorage, + conversation_id: impl Into, + shared_secret: SharedSecret, + dh_self: InstallationKeyPair, + ) -> Result { + let conversation_id = conversation_id.into(); + if storage.exists(&conversation_id)? { + return Self::open(storage, conversation_id); + } + + let state = RatchetState::::init_receiver(shared_secret, dh_self); + Self::create(storage, conversation_id, state) + } + + /// Encrypts a message and persists the updated state. + /// If persistence fails, the in-memory state is NOT modified. + pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<(Vec, Header), SessionError> { + // Clone state for rollback + let state_backup = self.state.clone(); + + // Perform encryption (modifies state) + let result = self.state.encrypt_message(plaintext); + + // Try to persist + if let Err(e) = self.storage.save(&self.conversation_id, &self.state) { + // Rollback + self.state = state_backup; + return Err(SessionError::Storage(e)); + } + + Ok(result) + } + + /// Decrypts a message and persists the updated state. + /// If decryption or persistence fails, the in-memory state is NOT modified. + pub fn decrypt_message( + &mut self, + ciphertext_with_nonce: &[u8], + header: Header, + ) -> Result, SessionError> { + // Clone state for rollback + let state_backup = self.state.clone(); + + // Perform decryption (modifies state) + let plaintext = match self.state.decrypt_message(ciphertext_with_nonce, header) { + Ok(pt) => pt, + Err(e) => { + // Rollback on decrypt failure + self.state = state_backup; + return Err(SessionError::Ratchet(e)); + } + }; + + // Try to persist + if let Err(e) = self.storage.save(&self.conversation_id, &self.state) { + // Rollback + self.state = state_backup; + return Err(SessionError::Storage(e)); + } + + Ok(plaintext) + } + + /// Returns a reference to the current state (read-only). + pub fn state(&self) -> &RatchetState { + &self.state + } + + /// Returns the conversation ID. + pub fn conversation_id(&self) -> &str { + &self.conversation_id + } + + /// Manually saves the current state. + pub fn save(&mut self) -> Result<(), StorageError> { + self.storage.save(&self.conversation_id, &self.state) + } + + pub fn msg_send(&self) -> u32 { + self.state.msg_send + } + + pub fn msg_recv(&self) -> u32 { + self.state.msg_recv + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{hkdf::DefaultDomain, keypair::InstallationKeyPair, storage::StorageConfig}; + + fn create_test_storage() -> SqliteStorage { + SqliteStorage::new(StorageConfig::InMemory).unwrap() + } + + #[test] + fn test_session_create_and_open() { + let mut storage = create_test_storage(); + + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let alice: RatchetState = + RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + + // Create session + { + let session = RatchetSession::create(&mut storage, "conv1", alice).unwrap(); + assert_eq!(session.conversation_id(), "conv1"); + } + + // Open existing session + { + let session: RatchetSession = + RatchetSession::open(&mut storage, "conv1").unwrap(); + assert_eq!(session.state().msg_send, 0); + } + } + + #[test] + fn test_session_encrypt_persists() { + let mut storage = create_test_storage(); + + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let alice: RatchetState = + RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + + // Create and encrypt + { + let mut session = RatchetSession::create(&mut storage, "conv1", alice).unwrap(); + session.encrypt_message(b"Hello").unwrap(); + assert_eq!(session.state().msg_send, 1); + } + + // Reopen - state should be persisted + { + let session: RatchetSession = + RatchetSession::open(&mut storage, "conv1").unwrap(); + assert_eq!(session.state().msg_send, 1); + } + } + + #[test] + fn test_session_full_conversation() { + let mut storage = create_test_storage(); + + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let alice: RatchetState = + RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + let bob: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + // Alice sends + let (ct, header) = { + let mut session = RatchetSession::create(&mut storage, "alice", alice).unwrap(); + session.encrypt_message(b"Hello Bob").unwrap() + }; + + // Bob receives + let plaintext = { + let mut session = RatchetSession::create(&mut storage, "bob", bob).unwrap(); + session.decrypt_message(&ct, header).unwrap() + }; + assert_eq!(plaintext, b"Hello Bob"); + + // Bob replies + let (ct2, header2) = { + let mut session: RatchetSession = + RatchetSession::open(&mut storage, "bob").unwrap(); + session.encrypt_message(b"Hi Alice").unwrap() + }; + + // Alice receives + let plaintext2 = { + let mut session: RatchetSession = + RatchetSession::open(&mut storage, "alice").unwrap(); + session.decrypt_message(&ct2, header2).unwrap() + }; + assert_eq!(plaintext2, b"Hi Alice"); + } + + #[test] + fn test_session_open_or_create() { + let mut storage = create_test_storage(); + + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let bob_pub = bob_keypair.public().clone(); + + // First call creates + { + let session: RatchetSession = RatchetSession::create_sender_session( + &mut storage, + "conv1", + shared_secret, + bob_pub.clone(), + ) + .unwrap(); + assert_eq!(session.state().msg_send, 0); + } + + // Second call opens existing + { + let mut session: RatchetSession = + RatchetSession::open(&mut storage, "conv1").unwrap(); + session.encrypt_message(b"test").unwrap(); + } + + // Verify persistence + { + let session: RatchetSession = + RatchetSession::open(&mut storage, "conv1").unwrap(); + assert_eq!(session.state().msg_send, 1); + } + } +} diff --git a/double-ratchets/src/storage/sqlite.rs b/double-ratchets/src/storage/sqlite.rs deleted file mode 100644 index 2c061f8..0000000 --- a/double-ratchets/src/storage/sqlite.rs +++ /dev/null @@ -1,437 +0,0 @@ -use rusqlite::{Connection, params}; - -use super::{RatchetStateRecord, SkippedKey, StorageError}; -use crate::{hkdf::HkdfInfo, state::RatchetState}; - -/// Configuration for SQLite storage. -#[derive(Debug, Clone)] -pub enum StorageConfig { - /// In-memory database (for testing). - InMemory, - /// File-based SQLite database (unencrypted, for local dev). - File(String), - /// SQLCipher encrypted database (for production). - /// Requires the `sqlcipher` feature. - #[cfg(feature = "sqlcipher")] - Encrypted { path: String, key: String }, -} - -/// SQLite-based storage for ratchet state. -pub struct SqliteStorage { - conn: Connection, -} - -impl SqliteStorage { - /// Creates a new SQLite storage with the given configuration. - pub fn new(config: StorageConfig) -> Result { - let conn = match config { - StorageConfig::InMemory => Connection::open_in_memory()?, - StorageConfig::File(path) => Connection::open(path)?, - #[cfg(feature = "sqlcipher")] - StorageConfig::Encrypted { path, key } => { - let conn = Connection::open(path)?; - conn.pragma_update(None, "key", &key)?; - conn - } - }; - - let storage = Self { conn }; - storage.init_schema()?; - Ok(storage) - } - - fn init_schema(&self) -> Result<(), StorageError> { - self.conn.execute_batch( - " - CREATE TABLE IF NOT EXISTS ratchet_state ( - conversation_id TEXT PRIMARY KEY, - root_key BLOB NOT NULL, - sending_chain BLOB, - receiving_chain BLOB, - dh_self_secret BLOB NOT NULL, - dh_remote BLOB, - msg_send INTEGER NOT NULL, - msg_recv INTEGER NOT NULL, - prev_chain_len INTEGER NOT NULL - ); - - CREATE TABLE IF NOT EXISTS skipped_keys ( - conversation_id TEXT NOT NULL, - public_key BLOB NOT NULL, - msg_num INTEGER NOT NULL, - message_key BLOB NOT NULL, - created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), - PRIMARY KEY (conversation_id, public_key, msg_num), - FOREIGN KEY (conversation_id) REFERENCES ratchet_state(conversation_id) ON DELETE CASCADE - ); - - CREATE INDEX IF NOT EXISTS idx_skipped_keys_conversation - ON skipped_keys(conversation_id); - ", - )?; - Ok(()) - } - - /// Saves the ratchet state for a conversation within a transaction. - /// Rolls back automatically if any error occurs. - pub fn save( - &mut self, - conversation_id: &str, - state: &RatchetState, - ) -> Result<(), StorageError> { - let tx = self.conn.transaction()?; - - let data = RatchetStateRecord::from(state); - let skipped_keys: Vec = state.skipped_keys(); - - // Upsert main state - tx.execute( - " - INSERT INTO ratchet_state ( - conversation_id, root_key, sending_chain, receiving_chain, - dh_self_secret, dh_remote, msg_send, msg_recv, prev_chain_len - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) - ON CONFLICT(conversation_id) DO UPDATE SET - root_key = excluded.root_key, - sending_chain = excluded.sending_chain, - receiving_chain = excluded.receiving_chain, - dh_self_secret = excluded.dh_self_secret, - dh_remote = excluded.dh_remote, - msg_send = excluded.msg_send, - msg_recv = excluded.msg_recv, - prev_chain_len = excluded.prev_chain_len - ", - params![ - conversation_id, - data.root_key.as_slice(), - data.sending_chain.as_ref().map(|c| c.as_slice()), - data.receiving_chain.as_ref().map(|c| c.as_slice()), - data.dh_self_secret.as_slice(), - data.dh_remote.as_ref().map(|c| c.as_slice()), - data.msg_send, - data.msg_recv, - data.prev_chain_len, - ], - )?; - - // Sync skipped keys efficiently - only insert new, delete removed - sync_skipped_keys(&tx, conversation_id, skipped_keys)?; - - tx.commit()?; - Ok(()) - } - - /// Loads the ratchet state for a conversation. - pub fn load( - &self, - conversation_id: &str, - ) -> Result, StorageError> { - let data = self.load_state_data(conversation_id)?; - let skipped_keys = self.load_skipped_keys(conversation_id)?; - Ok(data.into_ratchet_state(skipped_keys)) - } - - fn load_state_data(&self, conversation_id: &str) -> Result { - let mut stmt = self.conn.prepare( - " - SELECT root_key, sending_chain, receiving_chain, dh_self_secret, - dh_remote, msg_send, msg_recv, prev_chain_len - FROM ratchet_state - WHERE conversation_id = ?1 - ", - )?; - - stmt.query_row(params![conversation_id], |row| { - Ok(RatchetStateRecord { - root_key: blob_to_array(row.get::<_, Vec>(0)?), - sending_chain: row.get::<_, Option>>(1)?.map(blob_to_array), - receiving_chain: row.get::<_, Option>>(2)?.map(blob_to_array), - dh_self_secret: blob_to_array(row.get::<_, Vec>(3)?), - dh_remote: row.get::<_, Option>>(4)?.map(blob_to_array), - msg_send: row.get(5)?, - msg_recv: row.get(6)?, - prev_chain_len: row.get(7)?, - }) - }) - .map_err(|e| match e { - rusqlite::Error::QueryReturnedNoRows => { - StorageError::ConversationNotFound(conversation_id.to_string()) - } - e => StorageError::Database(e), - }) - } - - fn load_skipped_keys(&self, conversation_id: &str) -> Result, StorageError> { - let mut stmt = self.conn.prepare( - " - SELECT public_key, msg_num, message_key - FROM skipped_keys - WHERE conversation_id = ?1 - ", - )?; - - let rows = stmt.query_map(params![conversation_id], |row| { - Ok(SkippedKey { - public_key: blob_to_array(row.get::<_, Vec>(0)?), - msg_num: row.get(1)?, - message_key: blob_to_array(row.get::<_, Vec>(2)?), - }) - })?; - - rows.collect::, _>>() - .map_err(StorageError::Database) - } - - /// Checks if a conversation exists. - pub fn exists(&self, conversation_id: &str) -> Result { - let count: i64 = self.conn.query_row( - "SELECT COUNT(*) FROM ratchet_state WHERE conversation_id = ?1", - params![conversation_id], - |row| row.get(0), - )?; - Ok(count > 0) - } - - /// Deletes a conversation and its skipped keys. - pub fn delete(&mut self, conversation_id: &str) -> Result<(), StorageError> { - let tx = self.conn.transaction()?; - tx.execute( - "DELETE FROM skipped_keys WHERE conversation_id = ?1", - params![conversation_id], - )?; - tx.execute( - "DELETE FROM ratchet_state WHERE conversation_id = ?1", - params![conversation_id], - )?; - tx.commit()?; - Ok(()) - } - - /// Cleans up old skipped keys older than the given age in seconds. - pub fn cleanup_old_skipped_keys(&mut self, max_age_secs: i64) -> Result { - let deleted = self.conn.execute( - "DELETE FROM skipped_keys WHERE created_at < strftime('%s', 'now') - ?1", - params![max_age_secs], - )?; - Ok(deleted) - } -} - -/// Syncs skipped keys efficiently by computing diff and only inserting/deleting changes. -fn sync_skipped_keys( - tx: &rusqlite::Transaction, - conversation_id: &str, - current_keys: Vec, -) -> Result<(), StorageError> { - use std::collections::HashSet; - - // Get existing keys from DB (just the identifiers) - let mut stmt = - tx.prepare("SELECT public_key, msg_num FROM skipped_keys WHERE conversation_id = ?1")?; - let existing: HashSet<([u8; 32], u32)> = stmt - .query_map(params![conversation_id], |row| { - Ok(( - blob_to_array(row.get::<_, Vec>(0)?), - row.get::<_, u32>(1)?, - )) - })? - .filter_map(|r| r.ok()) - .collect(); - - // Build set of current keys - let current_set: HashSet<([u8; 32], u32)> = current_keys - .iter() - .map(|sk| (sk.public_key, sk.msg_num)) - .collect(); - - // Delete keys that were removed (used for decryption) - for (pk, msg_num) in existing.difference(¤t_set) { - tx.execute( - "DELETE FROM skipped_keys WHERE conversation_id = ?1 AND public_key = ?2 AND msg_num = ?3", - params![conversation_id, pk.as_slice(), msg_num], - )?; - } - - // Insert new keys - for sk in ¤t_keys { - let key = (sk.public_key, sk.msg_num); - if !existing.contains(&key) { - tx.execute( - "INSERT INTO skipped_keys (conversation_id, public_key, msg_num, message_key) - VALUES (?1, ?2, ?3, ?4)", - params![ - conversation_id, - sk.public_key.as_slice(), - sk.msg_num, - sk.message_key.as_slice(), - ], - )?; - } - } - - Ok(()) -} - -fn blob_to_array(blob: Vec) -> [u8; N] { - blob.try_into() - .unwrap_or_else(|v: Vec| panic!("Expected {} bytes, got {}", N, v.len())) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{hkdf::DefaultDomain, keypair::InstallationKeyPair}; - - fn create_test_storage() -> SqliteStorage { - SqliteStorage::new(StorageConfig::InMemory).unwrap() - } - - fn create_test_state() -> (RatchetState, RatchetState) { - let shared_secret = [0x42; 32]; - let bob_keypair = InstallationKeyPair::generate(); - let alice = RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); - let bob = RatchetState::init_receiver(shared_secret, bob_keypair); - (alice, bob) - } - - #[test] - fn test_save_and_load_sender() { - let mut storage = create_test_storage(); - let (alice, _) = create_test_state(); - - storage.save("conv1", &alice).unwrap(); - let loaded: RatchetState = storage.load("conv1").unwrap(); - - assert_eq!(alice.root_key, loaded.root_key); - assert_eq!(alice.sending_chain, loaded.sending_chain); - assert_eq!(alice.receiving_chain, loaded.receiving_chain); - assert_eq!(alice.msg_send, loaded.msg_send); - assert_eq!(alice.msg_recv, loaded.msg_recv); - assert_eq!(alice.prev_chain_len, loaded.prev_chain_len); - assert_eq!( - alice.dh_self.public().to_bytes(), - loaded.dh_self.public().to_bytes() - ); - } - - #[test] - fn test_save_and_load_receiver() { - let mut storage = create_test_storage(); - let (_, bob) = create_test_state(); - - storage.save("conv1", &bob).unwrap(); - let loaded: RatchetState = storage.load("conv1").unwrap(); - - assert_eq!(bob.root_key, loaded.root_key); - assert!(loaded.dh_remote.is_none()); - } - - #[test] - fn test_load_not_found() { - let storage = create_test_storage(); - let result: Result, _> = storage.load("nonexistent"); - assert!(matches!(result, Err(StorageError::ConversationNotFound(_)))); - } - - #[test] - fn test_save_with_skipped_keys() { - let mut storage = create_test_storage(); - let (mut alice, mut bob) = create_test_state(); - - // Alice sends 3 messages - let mut sent = vec![]; - for i in 0..3 { - let plaintext = format!("Message {}", i + 1).into_bytes(); - let (ct, header) = alice.encrypt_message(&plaintext); - sent.push((ct, header, plaintext)); - } - - // Bob receives 0 and 2, skipping 1 - bob.decrypt_message(&sent[0].0, sent[0].1.clone()).unwrap(); - bob.decrypt_message(&sent[2].0, sent[2].1.clone()).unwrap(); - - assert_eq!(bob.skipped_keys.len(), 1); - - // Save and reload - storage.save("conv1", &bob).unwrap(); - let mut loaded: RatchetState = storage.load("conv1").unwrap(); - - assert_eq!(loaded.skipped_keys.len(), 1); - - // Should be able to decrypt skipped message - let pt = loaded - .decrypt_message(&sent[1].0, sent[1].1.clone()) - .unwrap(); - assert_eq!(pt, sent[1].2); - } - - #[test] - fn test_update_existing() { - let mut storage = create_test_storage(); - let (mut alice, mut bob) = create_test_state(); - - storage.save("conv1", &alice).unwrap(); - - // Exchange a message - let (ct, header) = alice.encrypt_message(b"Hello"); - bob.decrypt_message(&ct, header).unwrap(); - - // Update Alice's state - storage.save("conv1", &alice).unwrap(); - - let loaded: RatchetState = storage.load("conv1").unwrap(); - assert_eq!(loaded.msg_send, 1); - } - - #[test] - fn test_exists() { - let mut storage = create_test_storage(); - let (alice, _) = create_test_state(); - - assert!(!storage.exists("conv1").unwrap()); - storage.save("conv1", &alice).unwrap(); - assert!(storage.exists("conv1").unwrap()); - } - - #[test] - fn test_delete() { - let mut storage = create_test_storage(); - let (alice, _) = create_test_state(); - - storage.save("conv1", &alice).unwrap(); - assert!(storage.exists("conv1").unwrap()); - - storage.delete("conv1").unwrap(); - assert!(!storage.exists("conv1").unwrap()); - } - - #[test] - fn test_continue_conversation_after_reload() { - let mut storage = create_test_storage(); - let (mut alice, mut bob) = create_test_state(); - - // Exchange messages - let (ct1, h1) = alice.encrypt_message(b"Hello Bob"); - bob.decrypt_message(&ct1, h1).unwrap(); - - let (ct2, h2) = bob.encrypt_message(b"Hello Alice"); - alice.decrypt_message(&ct2, h2).unwrap(); - - // Save both - storage.save("alice", &alice).unwrap(); - storage.save("bob", &bob).unwrap(); - - // Reload - let mut alice_new: RatchetState = storage.load("alice").unwrap(); - let mut bob_new: RatchetState = storage.load("bob").unwrap(); - - // Continue conversation - let (ct3, h3) = alice_new.encrypt_message(b"After reload"); - let pt3 = bob_new.decrypt_message(&ct3, h3).unwrap(); - assert_eq!(pt3, b"After reload"); - - let (ct4, h4) = bob_new.encrypt_message(b"Reply after reload"); - let pt4 = alice_new.decrypt_message(&ct4, h4).unwrap(); - assert_eq!(pt4, b"Reply after reload"); - } -} diff --git a/double-ratchets/src/storage/types.rs b/double-ratchets/src/storage/types.rs index 6a1cd80..b667b5d 100644 --- a/double-ratchets/src/storage/types.rs +++ b/double-ratchets/src/storage/types.rs @@ -1,28 +1,12 @@ +//! Storage types for ratchet state. + use crate::{ hkdf::HkdfInfo, state::{RatchetState, SkippedKey}, types::MessageKey, }; -use thiserror::Error; use x25519_dalek::PublicKey; -#[derive(Debug, Error)] -pub enum StorageError { - #[error("database error: {0}")] - Database(#[from] rusqlite::Error), - - #[error("conversation not found: {0}")] - ConversationNotFound(String), - - #[error("serialization error")] - Serialization, - - #[error("deserialization error")] - Deserialization, -} - -/// Stored representation of a skipped message key. - /// Raw state data for storage (without generic parameter). #[derive(Debug, Clone)] pub struct RatchetStateRecord { diff --git a/double-ratchets/src/storage/types.rs.bak b/double-ratchets/src/storage/types.rs.bak new file mode 100644 index 0000000..6a1cd80 --- /dev/null +++ b/double-ratchets/src/storage/types.rs.bak @@ -0,0 +1,81 @@ +use crate::{ + hkdf::HkdfInfo, + state::{RatchetState, SkippedKey}, + types::MessageKey, +}; +use thiserror::Error; +use x25519_dalek::PublicKey; + +#[derive(Debug, Error)] +pub enum StorageError { + #[error("database error: {0}")] + Database(#[from] rusqlite::Error), + + #[error("conversation not found: {0}")] + ConversationNotFound(String), + + #[error("serialization error")] + Serialization, + + #[error("deserialization error")] + Deserialization, +} + +/// Stored representation of a skipped message key. + +/// Raw state data for storage (without generic parameter). +#[derive(Debug, Clone)] +pub struct RatchetStateRecord { + pub root_key: [u8; 32], + pub sending_chain: Option<[u8; 32]>, + pub receiving_chain: Option<[u8; 32]>, + pub dh_self_secret: [u8; 32], + pub dh_remote: Option<[u8; 32]>, + pub msg_send: u32, + pub msg_recv: u32, + pub prev_chain_len: u32, +} + +impl From<&RatchetState> for RatchetStateRecord { + fn from(state: &RatchetState) -> Self { + Self { + root_key: state.root_key, + sending_chain: state.sending_chain, + receiving_chain: state.receiving_chain, + dh_self_secret: state.dh_self.secret_bytes(), + dh_remote: state.dh_remote.map(|pk| pk.to_bytes()), + msg_send: state.msg_send, + msg_recv: state.msg_recv, + prev_chain_len: state.prev_chain_len, + } + } +} + +impl RatchetStateRecord { + pub fn into_ratchet_state(self, skipped_keys: Vec) -> RatchetState { + use crate::keypair::InstallationKeyPair; + use std::collections::HashMap; + use std::marker::PhantomData; + + let dh_self = InstallationKeyPair::from_secret_bytes(self.dh_self_secret); + let dh_remote = self.dh_remote.map(PublicKey::from); + + let skipped: HashMap<(PublicKey, u32), MessageKey> = skipped_keys + .into_iter() + .map(|sk| ((PublicKey::from(sk.public_key), sk.msg_num), sk.message_key)) + .collect(); + + RatchetState { + root_key: self.root_key, + sending_chain: self.sending_chain, + receiving_chain: self.receiving_chain, + dh_self, + dh_remote, + msg_send: self.msg_send, + msg_recv: self.msg_recv, + prev_chain_len: self.prev_chain_len, + skipped_keys: skipped, + _domain: PhantomData, + } + } +} diff --git a/storage/Cargo.toml b/storage/Cargo.toml new file mode 100644 index 0000000..aecff59 --- /dev/null +++ b/storage/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "storage" +version = "0.1.0" +edition = "2024" +description = "Shared storage layer for libchat" + +[dependencies] +thiserror = "2" +rusqlite = { version = "0.35", optional = true, features = ["bundled"] } + +[features] +default = [] +sqlite = ["rusqlite"] +sqlcipher = ["sqlite", "rusqlite/bundled-sqlcipher-vendored-openssl"] diff --git a/storage/src/error.rs b/storage/src/error.rs new file mode 100644 index 0000000..6e68cf4 --- /dev/null +++ b/storage/src/error.rs @@ -0,0 +1,36 @@ +use thiserror::Error; + +/// Common storage errors. +#[derive(Debug, Error)] +pub enum StorageError { + /// Database error (wraps rusqlite::Error when sqlite feature is enabled). + #[error("database error: {0}")] + Database(String), + + /// Record not found. + #[error("not found: {0}")] + NotFound(String), + + /// Serialization error. + #[error("serialization error: {0}")] + Serialization(String), + + /// Deserialization error. + #[error("deserialization error: {0}")] + Deserialization(String), + + /// Schema migration error. + #[error("migration error: {0}")] + Migration(String), + + /// Transaction error. + #[error("transaction error: {0}")] + Transaction(String), +} + +#[cfg(feature = "sqlite")] +impl From for StorageError { + fn from(e: rusqlite::Error) -> Self { + StorageError::Database(e.to_string()) + } +} diff --git a/storage/src/lib.rs b/storage/src/lib.rs new file mode 100644 index 0000000..6c5e612 --- /dev/null +++ b/storage/src/lib.rs @@ -0,0 +1,43 @@ +//! Shared storage layer for libchat. +//! +//! This crate provides a common storage abstraction that can be used by +//! multiple crates in the libchat workspace (double-ratchets, conversations, etc.). +//! +//! # Features +//! +//! - `sqlite`: Enable SQLite-based storage +//! - `sqlcipher`: Enable encrypted SQLite storage via SQLCipher + +mod error; + +#[cfg(feature = "sqlite")] +mod sqlite; + +pub use error::StorageError; + +#[cfg(feature = "sqlite")] +pub use sqlite::{SqliteDb, StorageConfig}; + +// Re-export rusqlite types that domain crates will need +#[cfg(feature = "sqlite")] +pub use rusqlite::{params, Transaction, Error as RusqliteError}; + +/// Trait for types that can be stored and retrieved. +/// +/// Implement this trait for domain-specific storage operations. +pub trait Storable: Sized { + /// The key type used to identify records. + type Key; + + /// The error type returned by storage operations. + type Error: From; +} + +/// Trait for storage backends. +pub trait StorageBackend { + /// Initialize the storage (e.g., create tables). + fn init(&self) -> Result<(), StorageError>; + + /// Execute a batch of SQL statements (for schema migrations). + fn execute_batch(&self, sql: &str) -> Result<(), StorageError>; +} diff --git a/storage/src/sqlite.rs b/storage/src/sqlite.rs new file mode 100644 index 0000000..7ee22fb --- /dev/null +++ b/storage/src/sqlite.rs @@ -0,0 +1,103 @@ +//! SQLite storage backend. + +use rusqlite::Connection; +use std::path::Path; + +use crate::{StorageBackend, StorageError}; + +/// Configuration for SQLite storage. +#[derive(Debug, Clone)] +pub enum StorageConfig { + /// In-memory database (for testing). + InMemory, + /// File-based SQLite database. + File(String), + /// SQLCipher encrypted database (requires `sqlcipher` feature). + #[cfg(feature = "sqlcipher")] + Encrypted { + path: String, + key: String, + }, +} + +/// SQLite database wrapper. +/// +/// This provides the core database connection and can be shared +/// across different domain-specific storage implementations. +pub struct SqliteDb { + conn: Connection, +} + +impl SqliteDb { + /// Creates a new SQLite database with the given configuration. + pub fn new(config: StorageConfig) -> Result { + let conn = match config { + StorageConfig::InMemory => Connection::open_in_memory()?, + StorageConfig::File(ref path) => Connection::open(path)?, + #[cfg(feature = "sqlcipher")] + StorageConfig::Encrypted { ref path, ref key } => { + let conn = Connection::open(path)?; + conn.pragma_update(None, "key", key)?; + conn + } + }; + + // Enable foreign keys + conn.execute_batch("PRAGMA foreign_keys = ON;")?; + + Ok(Self { conn }) + } + + /// Opens an existing database file. + pub fn open>(path: P) -> Result { + let conn = Connection::open(path)?; + conn.execute_batch("PRAGMA foreign_keys = ON;")?; + Ok(Self { conn }) + } + + /// Creates an in-memory database (useful for testing). + pub fn in_memory() -> Result { + Self::new(StorageConfig::InMemory) + } + + /// Returns a reference to the underlying connection. + /// + /// Use this for domain-specific storage operations. + pub fn connection(&self) -> &Connection { + &self.conn + } + + /// Returns a mutable reference to the underlying connection. + /// + /// Use this for operations requiring a transaction. + pub fn connection_mut(&mut self) -> &mut Connection { + &mut self.conn + } + + /// Begins a transaction. + pub fn transaction(&mut self) -> Result, StorageError> { + Ok(self.conn.transaction()?) + } + + /// Checks if a table exists. + pub fn table_exists(&self, table_name: &str) -> Result { + let count: i32 = self.conn.query_row( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1", + [table_name], + |row| row.get(0), + )?; + Ok(count > 0) + } +} + +impl StorageBackend for SqliteDb { + fn init(&self) -> Result<(), StorageError> { + // Base initialization is done in new() + Ok(()) + } + + fn execute_batch(&self, sql: &str) -> Result<(), StorageError> { + self.conn.execute_batch(sql)?; + Ok(()) + } +}