diff --git a/Cargo.lock b/Cargo.lock index 6168294..e5f7411 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12,6 +12,24 @@ dependencies = [ "generic-array", ] +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + [[package]] name = "blake2" version = "0.10.6" @@ -30,6 +48,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "cc" +version = "1.2.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "755d2fce177175ffca841e9a06afdb2c4ab0f593d53b4dee48147dfaade85932" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -141,12 +169,60 @@ dependencies = [ "x25519-dalek", ] +[[package]] +name = "double-ratchets-storage" +version = "0.1.0" +dependencies = [ + "chacha20poly1305", + "double-ratchets", + "rand", + "rusqlite", + "serde", + "tempfile", + "thiserror", + "x25519-dalek", +] + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fiat-crypto" version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" +[[package]] +name = "find-msvc-tools" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db" + [[package]] name = "generic-array" version = "0.14.7" @@ -168,6 +244,36 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown", +] + [[package]] name = "hkdf" version = "0.12.4" @@ -201,6 +307,23 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "logos-chat" version = "0.1.0" @@ -208,12 +331,24 @@ dependencies = [ "thiserror", ] +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + [[package]] name = "opaque-debug" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "poly1305" version = "0.8.0" @@ -252,6 +387,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "rand" version = "0.8.5" @@ -279,7 +420,21 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.16", +] + +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", ] [[package]] @@ -291,6 +446,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "semver" version = "1.0.27" @@ -327,6 +495,18 @@ dependencies = [ "syn", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + [[package]] name = "subtle" version = "2.6.1" @@ -344,6 +524,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -386,6 +579,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -398,6 +597,36 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + [[package]] name = "x25519-dalek" version = "2.0.1" diff --git a/Cargo.toml b/Cargo.toml index bc37360..24dc941 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,4 +5,5 @@ resolver = "3" members = [ "conversations", "double-ratchets", + "double-ratchets-storage", ] diff --git a/double-ratchets-storage/Cargo.toml b/double-ratchets-storage/Cargo.toml new file mode 100644 index 0000000..4a71fc2 --- /dev/null +++ b/double-ratchets-storage/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "double-ratchets-storage" +version = "0.1.0" +edition = "2024" + +[features] +default = ["sqlite"] +sqlite = ["rusqlite/bundled"] +sqlcipher = ["rusqlite/bundled-sqlcipher"] + +[dependencies] +double-ratchets = { path = "../double-ratchets" } +x25519-dalek = { version = "2.0.1", features = ["static_secrets"] } +rusqlite = { version = "0.32", optional = true } +serde = { version = "1.0", features = ["derive"] } +thiserror = "2" +chacha20poly1305 = "0.10" +rand = "0.8" + +[dev-dependencies] +tempfile = "3" diff --git a/double-ratchets-storage/src/error.rs b/double-ratchets-storage/src/error.rs new file mode 100644 index 0000000..205c71e --- /dev/null +++ b/double-ratchets-storage/src/error.rs @@ -0,0 +1,46 @@ +//! Error types for the storage module. + +use thiserror::Error; + +/// Errors that can occur during storage operations. +#[derive(Error, Debug)] +pub enum StorageError { + /// Database operation failed. + #[cfg(any(feature = "sqlite", feature = "sqlcipher"))] + #[error("database error: {0}")] + Database(#[from] rusqlite::Error), + + /// Field-level encryption failed. + #[error("encryption failed: {0}")] + Encryption(String), + + /// Field-level decryption failed. + #[error("decryption failed: {0}")] + Decryption(String), + + /// Stored state is corrupted or invalid. + #[error("corrupted state: {0}")] + CorruptedState(String), + + /// Session was not found in storage. + #[error("session not found: {}", hex::encode(.session_id))] + SessionNotFound { + /// The session ID that was not found. + session_id: [u8; 32], + }, + + /// I/O operation failed. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// Key reconstruction failed. + #[error("key reconstruction failed: {0}")] + KeyReconstruction(String), +} + +/// Helper module for hex encoding (used in error messages). +mod hex { + pub fn encode(bytes: &[u8]) -> String { + bytes.iter().map(|b| format!("{:02x}", b)).collect() + } +} diff --git a/double-ratchets-storage/src/lib.rs b/double-ratchets-storage/src/lib.rs new file mode 100644 index 0000000..9536cf6 --- /dev/null +++ b/double-ratchets-storage/src/lib.rs @@ -0,0 +1,178 @@ +//! Persistent storage for Double Ratchet state. +//! +//! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState) +//! across application restarts. It includes: +//! +//! - [`MemoryStorage`] - In-memory storage for testing +//! - [`SqliteStorage`] - SQLite storage with field-level encryption (requires `sqlite` feature) +//! +//! # Features +//! +//! - `sqlite` (default) - Enables SQLite storage with `rusqlite/bundled` +//! - `sqlcipher` - Enables SQLCipher full-database encryption (mutually exclusive with `sqlite`) +//! +//! # Security +//! +//! Private keys (`dh_self_secret`) are always encrypted with ChaCha20Poly1305 before storage, +//! even when using plain SQLite. For additional security, enable the `sqlcipher` feature +//! for full database encryption. +//! +//! # Example +//! +//! ```no_run +//! use double_ratchets::hkdf::DefaultDomain; +//! use double_ratchets::state::RatchetState; +//! use double_ratchets::InstallationKeyPair; +//! use double_ratchets_storage::{ +//! RatchetStorage, SqliteStorage, StorableRatchetState, +//! }; +//! +//! // Create a ratchet state +//! let bob_keypair = InstallationKeyPair::generate(); +//! let shared_secret = [0x42u8; 32]; +//! let state: RatchetState = +//! RatchetState::init_sender(shared_secret, *bob_keypair.public()); +//! +//! // Open storage +//! let encryption_key = [0u8; 32]; // Use proper key derivation! +//! let storage = SqliteStorage::open("ratchets.db", encryption_key).unwrap(); +//! +//! // Save state +//! let session_id = [1u8; 32]; +//! let storable = StorableRatchetState::from_ratchet_state(&state, "default"); +//! storage.save(&session_id, &storable).unwrap(); +//! +//! // Load state +//! let loaded = storage.load(&session_id).unwrap().unwrap(); +//! let restored: RatchetState = loaded.to_ratchet_state().unwrap(); +//! ``` + +pub mod error; +pub mod memory; +#[cfg(any(feature = "sqlite", feature = "sqlcipher"))] +pub mod sqlite; +pub mod traits; +pub mod types; + +// Re-exports for convenience +pub use error::StorageError; +pub use memory::MemoryStorage; +#[cfg(any(feature = "sqlite", feature = "sqlcipher"))] +pub use sqlite::{EncryptionKey, SqliteStorage}; +pub use traits::{RatchetStorage, SessionId}; +pub use types::{SkippedKey, StorableRatchetState}; + +#[cfg(test)] +mod integration_tests { + use super::*; + use double_ratchets::hkdf::DefaultDomain; + use double_ratchets::state::RatchetState; + use double_ratchets::InstallationKeyPair; + + /// Integration test: full encryption/decryption cycle with storage + #[test] + fn test_full_conversation_with_storage_roundtrip() { + // Setup Alice and Bob + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + + let mut alice: RatchetState = + RatchetState::init_sender(shared_secret, *bob_keypair.public()); + let mut bob: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + let storage = MemoryStorage::new(); + let alice_session = [0xAA; 32]; + let bob_session = [0xBB; 32]; + + // Alice sends a message + let (ct1, header1) = alice.encrypt_message(b"Hello Bob!"); + + // Save Alice's state + let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default"); + storage.save(&alice_session, &alice_storable).unwrap(); + + // Bob receives the message + let pt1 = bob.decrypt_message(&ct1, header1).unwrap(); + assert_eq!(pt1, b"Hello Bob!"); + + // Save Bob's state + let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default"); + storage.save(&bob_session, &bob_storable).unwrap(); + + // Simulate restart: load states from storage + let alice_loaded = storage.load(&alice_session).unwrap().unwrap(); + let bob_loaded = storage.load(&bob_session).unwrap().unwrap(); + + let mut alice_restored: RatchetState = + alice_loaded.to_ratchet_state().unwrap(); + let mut bob_restored: RatchetState = + bob_loaded.to_ratchet_state().unwrap(); + + // Bob replies + let (ct2, header2) = bob_restored.encrypt_message(b"Hi Alice!"); + let pt2 = alice_restored.decrypt_message(&ct2, header2).unwrap(); + assert_eq!(pt2, b"Hi Alice!"); + + // Alice sends another message + let (ct3, header3) = alice_restored.encrypt_message(b"How are you?"); + let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap(); + assert_eq!(pt3, b"How are you?"); + } + + /// Integration test: verify SQLite storage with encryption works + #[cfg(any(feature = "sqlite", feature = "sqlcipher"))] + #[test] + fn test_sqlite_integration() { + let dir = tempfile::tempdir().unwrap(); + let db_path = dir.path().join("integration_test.db"); + let key = [0x42u8; 32]; + + // Setup + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + let mut alice: RatchetState = + RatchetState::init_sender(shared_secret, *bob_keypair.public()); + let mut bob: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + let alice_session = [0xAA; 32]; + let bob_session = [0xBB; 32]; + + // Exchange messages + let (ct1, header1) = alice.encrypt_message(b"Message 1"); + bob.decrypt_message(&ct1, header1).unwrap(); + + let (ct2, header2) = bob.encrypt_message(b"Response 1"); + alice.decrypt_message(&ct2, header2).unwrap(); + + // Save both states + { + let storage = SqliteStorage::open(&db_path, key).unwrap(); + + let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default"); + let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default"); + + storage.save(&alice_session, &alice_storable).unwrap(); + storage.save(&bob_session, &bob_storable).unwrap(); + } + + // Reopen database (simulating restart) + { + let storage = SqliteStorage::open(&db_path, key).unwrap(); + + let alice_loaded = storage.load(&alice_session).unwrap().unwrap(); + let bob_loaded = storage.load(&bob_session).unwrap().unwrap(); + + let mut alice_restored: RatchetState = + alice_loaded.to_ratchet_state().unwrap(); + let mut bob_restored: RatchetState = + bob_loaded.to_ratchet_state().unwrap(); + + // Continue conversation + let (ct3, header3) = alice_restored.encrypt_message(b"Message 2"); + let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap(); + assert_eq!(pt3, b"Message 2"); + } + } +} diff --git a/double-ratchets-storage/src/memory.rs b/double-ratchets-storage/src/memory.rs new file mode 100644 index 0000000..cc2791e --- /dev/null +++ b/double-ratchets-storage/src/memory.rs @@ -0,0 +1,216 @@ +//! In-memory storage implementation for testing. + +use std::collections::HashMap; +use std::sync::RwLock; + +use crate::error::StorageError; +use crate::traits::{RatchetStorage, SessionId}; +use crate::types::StorableRatchetState; + +/// In-memory storage backend for testing purposes. +/// +/// This implementation stores ratchet states in a `HashMap` wrapped in a `RwLock` +/// for thread-safe access. Data is not persisted across process restarts. +/// +/// # Example +/// +/// ``` +/// use double_ratchets_storage::{MemoryStorage, RatchetStorage}; +/// +/// let storage = MemoryStorage::new(); +/// assert!(storage.list_sessions().unwrap().is_empty()); +/// ``` +pub struct MemoryStorage { + states: RwLock>, +} + +impl MemoryStorage { + /// Create a new empty in-memory storage. + pub fn new() -> Self { + Self { + states: RwLock::new(HashMap::new()), + } + } + + /// Get the number of stored sessions. + pub fn len(&self) -> usize { + self.states.read().unwrap().len() + } + + /// Check if the storage is empty. + pub fn is_empty(&self) -> bool { + self.states.read().unwrap().is_empty() + } + + /// Clear all stored sessions. + pub fn clear(&self) { + self.states.write().unwrap().clear(); + } +} + +impl Default for MemoryStorage { + fn default() -> Self { + Self::new() + } +} + +impl RatchetStorage for MemoryStorage { + fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> { + let mut states = self.states.write().unwrap(); + states.insert(*session_id, state.clone()); + Ok(()) + } + + fn load(&self, session_id: &SessionId) -> Result, StorageError> { + let states = self.states.read().unwrap(); + Ok(states.get(session_id).cloned()) + } + + fn delete(&self, session_id: &SessionId) -> Result { + let mut states = self.states.write().unwrap(); + Ok(states.remove(session_id).is_some()) + } + + fn exists(&self, session_id: &SessionId) -> Result { + let states = self.states.read().unwrap(); + Ok(states.contains_key(session_id)) + } + + fn list_sessions(&self) -> Result, StorageError> { + let states = self.states.read().unwrap(); + Ok(states.keys().copied().collect()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use double_ratchets::hkdf::DefaultDomain; + use double_ratchets::state::RatchetState; + use double_ratchets::InstallationKeyPair; + + fn create_test_state() -> StorableRatchetState { + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + let state: RatchetState = + RatchetState::init_sender(shared_secret, *bob_keypair.public()); + StorableRatchetState::from_ratchet_state(&state, "default") + } + + #[test] + fn test_save_and_load() { + let storage = MemoryStorage::new(); + let session_id = [1u8; 32]; + let state = create_test_state(); + + storage.save(&session_id, &state).unwrap(); + let loaded = storage.load(&session_id).unwrap(); + + assert!(loaded.is_some()); + let loaded = loaded.unwrap(); + assert_eq!(loaded.root_key, state.root_key); + } + + #[test] + fn test_load_nonexistent() { + let storage = MemoryStorage::new(); + let session_id = [1u8; 32]; + + let loaded = storage.load(&session_id).unwrap(); + assert!(loaded.is_none()); + } + + #[test] + fn test_delete() { + let storage = MemoryStorage::new(); + let session_id = [1u8; 32]; + let state = create_test_state(); + + storage.save(&session_id, &state).unwrap(); + assert!(storage.exists(&session_id).unwrap()); + + let deleted = storage.delete(&session_id).unwrap(); + assert!(deleted); + assert!(!storage.exists(&session_id).unwrap()); + + // Deleting again should return false + let deleted = storage.delete(&session_id).unwrap(); + assert!(!deleted); + } + + #[test] + fn test_exists() { + let storage = MemoryStorage::new(); + let session_id = [1u8; 32]; + + assert!(!storage.exists(&session_id).unwrap()); + + let state = create_test_state(); + storage.save(&session_id, &state).unwrap(); + + assert!(storage.exists(&session_id).unwrap()); + } + + #[test] + fn test_list_sessions() { + let storage = MemoryStorage::new(); + + assert!(storage.list_sessions().unwrap().is_empty()); + + let state = create_test_state(); + let session_ids: Vec = (0..3).map(|i| [i; 32]).collect(); + + for id in &session_ids { + storage.save(id, &state).unwrap(); + } + + let mut listed = storage.list_sessions().unwrap(); + listed.sort(); + let mut expected = session_ids.clone(); + expected.sort(); + + assert_eq!(listed, expected); + } + + #[test] + fn test_overwrite() { + let storage = MemoryStorage::new(); + let session_id = [1u8; 32]; + + // Create first state + let bob_keypair1 = InstallationKeyPair::generate(); + let state1: RatchetState = + RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public()); + let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default"); + + // Create second state with different root + let bob_keypair2 = InstallationKeyPair::generate(); + let state2: RatchetState = + RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public()); + let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default"); + + // Save first, then overwrite with second + storage.save(&session_id, &storable1).unwrap(); + storage.save(&session_id, &storable2).unwrap(); + + // Should have the second state + let loaded = storage.load(&session_id).unwrap().unwrap(); + assert_eq!(loaded.root_key, storable2.root_key); + assert_ne!(loaded.root_key, storable1.root_key); + } + + #[test] + fn test_clear() { + let storage = MemoryStorage::new(); + let state = create_test_state(); + + for i in 0..5 { + storage.save(&[i; 32], &state).unwrap(); + } + + assert_eq!(storage.len(), 5); + + storage.clear(); + assert!(storage.is_empty()); + } +} diff --git a/double-ratchets-storage/src/sqlite.rs b/double-ratchets-storage/src/sqlite.rs new file mode 100644 index 0000000..bac9dfc --- /dev/null +++ b/double-ratchets-storage/src/sqlite.rs @@ -0,0 +1,746 @@ +//! SQLite storage implementation with field-level encryption. + +use std::path::Path; +use std::sync::Mutex; +use std::time::{SystemTime, UNIX_EPOCH}; + +use chacha20poly1305::{ + aead::{Aead, KeyInit}, + ChaCha20Poly1305, Nonce, +}; +use rand::RngCore; +use rusqlite::{params, Connection, OptionalExtension}; + +use crate::error::StorageError; +use crate::traits::{RatchetStorage, SessionId}; +use crate::types::{SkippedKey, StorableRatchetState}; + +/// Field encryption key type (32 bytes for ChaCha20Poly1305). +pub type EncryptionKey = [u8; 32]; + +/// SQLite storage backend with field-level encryption for secrets. +/// +/// This implementation stores ratchet states in SQLite with: +/// - Field-level encryption for private keys using ChaCha20Poly1305 +/// - WAL mode for better concurrent performance +/// - Foreign keys and cascading deletes for data integrity +/// - Atomic transactions to prevent partial writes +/// +/// # Security +/// +/// The `dh_self_secret` field is encrypted with the provided encryption key. +/// For additional security, consider using SQLCipher for full database encryption +/// via the `open_encrypted` method (requires `sqlcipher` feature). +/// +/// # Example +/// +/// ```no_run +/// use double_ratchets_storage::SqliteStorage; +/// +/// let key = [0u8; 32]; // Use a proper key derivation function +/// let storage = SqliteStorage::open("ratchets.db", key).unwrap(); +/// ``` +pub struct SqliteStorage { + conn: Mutex, + encryption_key: EncryptionKey, +} + +impl SqliteStorage { + /// Open or create a SQLite database with field-level encryption. + /// + /// # Arguments + /// + /// * `path` - Path to the database file. + /// * `encryption_key` - 32-byte key for field-level encryption. + /// + /// # Returns + /// + /// * `Ok(SqliteStorage)` on success. + /// * `Err(StorageError)` on failure. + pub fn open>(path: P, encryption_key: EncryptionKey) -> Result { + let conn = Connection::open(path)?; + Self::initialize(conn, encryption_key) + } + + /// Create an in-memory SQLite database (for testing). + /// + /// # Arguments + /// + /// * `encryption_key` - 32-byte key for field-level encryption. + /// + /// # Returns + /// + /// * `Ok(SqliteStorage)` on success. + /// * `Err(StorageError)` on failure. + pub fn open_in_memory(encryption_key: EncryptionKey) -> Result { + let conn = Connection::open_in_memory()?; + Self::initialize(conn, encryption_key) + } + + /// Open or create a SQLCipher-encrypted database. + /// + /// This method requires the `sqlcipher` feature to be enabled. + /// + /// # Arguments + /// + /// * `path` - Path to the database file. + /// * `db_password` - Password for SQLCipher database encryption. + /// * `field_key` - 32-byte key for additional field-level encryption. + /// + /// # Returns + /// + /// * `Ok(SqliteStorage)` on success. + /// * `Err(StorageError)` on failure. + #[cfg(feature = "sqlcipher")] + pub fn open_encrypted>( + path: P, + db_password: &str, + field_key: EncryptionKey, + ) -> Result { + let conn = Connection::open(path)?; + + // Set SQLCipher key + conn.pragma_update(None, "key", db_password)?; + + Self::initialize(conn, field_key) + } + + fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result { + // Enable WAL mode for better performance + conn.pragma_update(None, "journal_mode", "WAL")?; + + // Enable foreign keys + conn.pragma_update(None, "foreign_keys", "ON")?; + + // Create tables + conn.execute_batch( + r#" + CREATE TABLE IF NOT EXISTS ratchet_states ( + session_id BLOB PRIMARY KEY, + root_key BLOB NOT NULL, + sending_chain BLOB, + receiving_chain BLOB, + dh_self_secret_encrypted BLOB NOT NULL, + dh_self_secret_nonce BLOB NOT NULL, + dh_self_public BLOB NOT NULL, + dh_remote BLOB, + msg_send INTEGER NOT NULL, + msg_recv INTEGER NOT NULL, + prev_chain_len INTEGER NOT NULL, + domain_id TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS skipped_keys ( + id INTEGER PRIMARY KEY, + session_id BLOB NOT NULL REFERENCES ratchet_states(session_id) ON DELETE CASCADE, + public_key BLOB NOT NULL, + msg_num INTEGER NOT NULL, + message_key BLOB NOT NULL, + UNIQUE(session_id, public_key, msg_num) + ); + + CREATE INDEX IF NOT EXISTS idx_skipped_keys_session ON skipped_keys(session_id); + "#, + )?; + + Ok(Self { + conn: Mutex::new(conn), + encryption_key, + }) + } + + /// Encrypt a 32-byte secret using ChaCha20Poly1305. + fn encrypt_secret(&self, secret: &[u8; 32]) -> Result<(Vec, [u8; 12]), StorageError> { + let cipher = ChaCha20Poly1305::new((&self.encryption_key).into()); + + let mut nonce_bytes = [0u8; 12]; + rand::thread_rng().fill_bytes(&mut nonce_bytes); + let nonce = Nonce::from_slice(&nonce_bytes); + + let ciphertext = cipher + .encrypt(nonce, secret.as_ref()) + .map_err(|e| StorageError::Encryption(e.to_string()))?; + + Ok((ciphertext, nonce_bytes)) + } + + /// Decrypt a secret using ChaCha20Poly1305. + fn decrypt_secret(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> { + let cipher = ChaCha20Poly1305::new((&self.encryption_key).into()); + let nonce = Nonce::from_slice(nonce); + + let plaintext = cipher + .decrypt(nonce, ciphertext) + .map_err(|e| StorageError::Decryption(e.to_string()))?; + + plaintext + .try_into() + .map_err(|_| StorageError::CorruptedState("decrypted secret has wrong length".to_string())) + } + + fn current_timestamp() -> i64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64 + } +} + +impl RatchetStorage for SqliteStorage { + fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> { + let (encrypted_secret, nonce) = self.encrypt_secret(&state.dh_self_secret)?; + + let conn = self.conn.lock().unwrap(); + let tx = conn.unchecked_transaction()?; + + let now = Self::current_timestamp(); + + // Check if session exists to determine created_at + let exists: bool = tx.query_row( + "SELECT 1 FROM ratchet_states WHERE session_id = ?", + [session_id.as_slice()], + |_| Ok(true), + ).optional()?.unwrap_or(false); + + if exists { + // Update existing session + tx.execute( + r#" + UPDATE ratchet_states SET + root_key = ?, + sending_chain = ?, + receiving_chain = ?, + dh_self_secret_encrypted = ?, + dh_self_secret_nonce = ?, + dh_self_public = ?, + dh_remote = ?, + msg_send = ?, + msg_recv = ?, + prev_chain_len = ?, + domain_id = ?, + updated_at = ? + WHERE session_id = ? + "#, + params![ + state.root_key.as_slice(), + state.sending_chain.as_ref().map(|c| c.as_slice()), + state.receiving_chain.as_ref().map(|c| c.as_slice()), + encrypted_secret.as_slice(), + nonce.as_slice(), + state.dh_self_public.as_slice(), + state.dh_remote.as_ref().map(|pk| pk.as_slice()), + state.msg_send, + state.msg_recv, + state.prev_chain_len, + &state.domain_id, + now, + session_id.as_slice(), + ], + )?; + + // Delete existing skipped keys + tx.execute( + "DELETE FROM skipped_keys WHERE session_id = ?", + [session_id.as_slice()], + )?; + } else { + // Insert new session + tx.execute( + r#" + INSERT INTO ratchet_states ( + session_id, root_key, sending_chain, receiving_chain, + dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public, + dh_remote, msg_send, msg_recv, prev_chain_len, domain_id, + created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + "#, + params![ + session_id.as_slice(), + state.root_key.as_slice(), + state.sending_chain.as_ref().map(|c| c.as_slice()), + state.receiving_chain.as_ref().map(|c| c.as_slice()), + encrypted_secret.as_slice(), + nonce.as_slice(), + state.dh_self_public.as_slice(), + state.dh_remote.as_ref().map(|pk| pk.as_slice()), + state.msg_send, + state.msg_recv, + state.prev_chain_len, + &state.domain_id, + now, + now, + ], + )?; + } + + // Insert skipped keys + for sk in &state.skipped_keys { + tx.execute( + r#" + INSERT INTO skipped_keys (session_id, public_key, msg_num, message_key) + VALUES (?, ?, ?, ?) + "#, + params![ + session_id.as_slice(), + sk.public_key.as_slice(), + sk.msg_num, + sk.message_key.as_slice(), + ], + )?; + } + + tx.commit()?; + Ok(()) + } + + fn load(&self, session_id: &SessionId) -> Result, StorageError> { + let conn = self.conn.lock().unwrap(); + + let row = conn + .query_row( + r#" + SELECT root_key, sending_chain, receiving_chain, + dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public, + dh_remote, msg_send, msg_recv, prev_chain_len, domain_id + FROM ratchet_states WHERE session_id = ? + "#, + [session_id.as_slice()], + |row| { + Ok(( + row.get::<_, Vec>(0)?, + row.get::<_, Option>>(1)?, + row.get::<_, Option>>(2)?, + row.get::<_, Vec>(3)?, + row.get::<_, Vec>(4)?, + row.get::<_, Vec>(5)?, + row.get::<_, Option>>(6)?, + row.get::<_, u32>(7)?, + row.get::<_, u32>(8)?, + row.get::<_, u32>(9)?, + row.get::<_, String>(10)?, + )) + }, + ) + .optional()?; + + let Some(( + root_key_bytes, + sending_chain_bytes, + receiving_chain_bytes, + encrypted_secret, + nonce_bytes, + dh_self_public_bytes, + dh_remote_bytes, + msg_send, + msg_recv, + prev_chain_len, + domain_id, + )) = row + else { + return Ok(None); + }; + + // Decrypt the secret + let nonce: [u8; 12] = nonce_bytes + .try_into() + .map_err(|_| StorageError::CorruptedState("invalid nonce length".to_string()))?; + let dh_self_secret = self.decrypt_secret(&encrypted_secret, &nonce)?; + + // Convert byte vectors to arrays + let root_key: [u8; 32] = root_key_bytes + .try_into() + .map_err(|_| StorageError::CorruptedState("invalid root_key length".to_string()))?; + let dh_self_public: [u8; 32] = dh_self_public_bytes + .try_into() + .map_err(|_| StorageError::CorruptedState("invalid dh_self_public length".to_string()))?; + + let sending_chain = sending_chain_bytes + .map(|b| { + b.try_into() + .map_err(|_| StorageError::CorruptedState("invalid sending_chain length".to_string())) + }) + .transpose()?; + let receiving_chain = receiving_chain_bytes + .map(|b| { + b.try_into() + .map_err(|_| StorageError::CorruptedState("invalid receiving_chain length".to_string())) + }) + .transpose()?; + let dh_remote = dh_remote_bytes + .map(|b| { + b.try_into() + .map_err(|_| StorageError::CorruptedState("invalid dh_remote length".to_string())) + }) + .transpose()?; + + // Load skipped keys + let mut stmt = conn.prepare( + "SELECT public_key, msg_num, message_key FROM skipped_keys WHERE session_id = ?", + )?; + let skipped_keys: Vec = stmt + .query_map([session_id.as_slice()], |row| { + let pk_bytes: Vec = row.get(0)?; + let msg_num: u32 = row.get(1)?; + let mk_bytes: Vec = row.get(2)?; + Ok((pk_bytes, msg_num, mk_bytes)) + })? + .collect::, _>>()? + .into_iter() + .map(|(pk_bytes, msg_num, mk_bytes)| { + let public_key: [u8; 32] = pk_bytes + .try_into() + .map_err(|_| StorageError::CorruptedState("invalid skipped public_key length".to_string()))?; + let message_key: [u8; 32] = mk_bytes + .try_into() + .map_err(|_| StorageError::CorruptedState("invalid skipped message_key length".to_string()))?; + Ok(SkippedKey { + public_key, + msg_num, + message_key, + }) + }) + .collect::, StorageError>>()?; + + Ok(Some(StorableRatchetState { + root_key, + sending_chain, + receiving_chain, + dh_self_secret, + dh_self_public, + dh_remote, + msg_send, + msg_recv, + prev_chain_len, + skipped_keys, + domain_id, + })) + } + + fn delete(&self, session_id: &SessionId) -> Result { + let conn = self.conn.lock().unwrap(); + let changes = conn.execute( + "DELETE FROM ratchet_states WHERE session_id = ?", + [session_id.as_slice()], + )?; + Ok(changes > 0) + } + + fn exists(&self, session_id: &SessionId) -> Result { + let conn = self.conn.lock().unwrap(); + let exists: bool = conn + .query_row( + "SELECT 1 FROM ratchet_states WHERE session_id = ?", + [session_id.as_slice()], + |_| Ok(true), + ) + .optional()? + .unwrap_or(false); + Ok(exists) + } + + fn list_sessions(&self) -> Result, StorageError> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT session_id FROM ratchet_states")?; + let sessions: Vec = stmt + .query_map([], |row| { + let bytes: Vec = row.get(0)?; + Ok(bytes) + })? + .collect::, _>>()? + .into_iter() + .filter_map(|bytes| bytes.try_into().ok()) + .collect(); + Ok(sessions) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use double_ratchets::hkdf::DefaultDomain; + use double_ratchets::state::RatchetState; + use double_ratchets::InstallationKeyPair; + + fn create_test_storage() -> SqliteStorage { + let key = [0x42u8; 32]; + SqliteStorage::open_in_memory(key).unwrap() + } + + fn create_test_state() -> StorableRatchetState { + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + let state: RatchetState = + RatchetState::init_sender(shared_secret, *bob_keypair.public()); + StorableRatchetState::from_ratchet_state(&state, "default") + } + + #[test] + fn test_save_and_load() { + let storage = create_test_storage(); + let session_id = [1u8; 32]; + let state = create_test_state(); + + storage.save(&session_id, &state).unwrap(); + let loaded = storage.load(&session_id).unwrap(); + + assert!(loaded.is_some()); + let loaded = loaded.unwrap(); + assert_eq!(loaded.root_key, state.root_key); + assert_eq!(loaded.dh_self_public, state.dh_self_public); + // Secret should be decrypted correctly + assert_eq!(loaded.dh_self_secret, state.dh_self_secret); + } + + #[test] + fn test_load_nonexistent() { + let storage = create_test_storage(); + let session_id = [1u8; 32]; + + let loaded = storage.load(&session_id).unwrap(); + assert!(loaded.is_none()); + } + + #[test] + fn test_delete() { + let storage = create_test_storage(); + let session_id = [1u8; 32]; + let state = create_test_state(); + + storage.save(&session_id, &state).unwrap(); + assert!(storage.exists(&session_id).unwrap()); + + let deleted = storage.delete(&session_id).unwrap(); + assert!(deleted); + assert!(!storage.exists(&session_id).unwrap()); + + // Deleting again should return false + let deleted = storage.delete(&session_id).unwrap(); + assert!(!deleted); + } + + #[test] + fn test_exists() { + let storage = create_test_storage(); + let session_id = [1u8; 32]; + + assert!(!storage.exists(&session_id).unwrap()); + + let state = create_test_state(); + storage.save(&session_id, &state).unwrap(); + + assert!(storage.exists(&session_id).unwrap()); + } + + #[test] + fn test_list_sessions() { + let storage = create_test_storage(); + + assert!(storage.list_sessions().unwrap().is_empty()); + + let state = create_test_state(); + let session_ids: Vec = (0..3).map(|i| [i; 32]).collect(); + + for id in &session_ids { + storage.save(id, &state).unwrap(); + } + + let mut listed = storage.list_sessions().unwrap(); + listed.sort(); + let mut expected = session_ids.clone(); + expected.sort(); + + assert_eq!(listed, expected); + } + + #[test] + fn test_overwrite() { + let storage = create_test_storage(); + let session_id = [1u8; 32]; + + // Create first state + let bob_keypair1 = InstallationKeyPair::generate(); + let state1: RatchetState = + RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public()); + let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default"); + + // Create second state with different root + let bob_keypair2 = InstallationKeyPair::generate(); + let state2: RatchetState = + RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public()); + let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default"); + + // Save first, then overwrite with second + storage.save(&session_id, &storable1).unwrap(); + storage.save(&session_id, &storable2).unwrap(); + + // Should have the second state + let loaded = storage.load(&session_id).unwrap().unwrap(); + assert_eq!(loaded.root_key, storable2.root_key); + assert_ne!(loaded.root_key, storable1.root_key); + } + + #[test] + fn test_skipped_keys_storage() { + let storage = create_test_storage(); + let session_id = [1u8; 32]; + + // Create states and generate skipped keys + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + let mut alice: RatchetState = + RatchetState::init_sender(shared_secret, *bob_keypair.public()); + let mut bob: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + // Alice sends multiple messages + let mut messages = vec![]; + for i in 0..3 { + let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes()); + messages.push((ct, header)); + } + + // Bob receives out of order to create skipped keys + bob.decrypt_message(&messages[0].0, messages[0].1.clone()) + .unwrap(); + bob.decrypt_message(&messages[2].0, messages[2].1.clone()) + .unwrap(); + + assert!(!bob.skipped_keys.is_empty()); + + // Save and reload + let storable = StorableRatchetState::from_ratchet_state(&bob, "default"); + storage.save(&session_id, &storable).unwrap(); + + let loaded = storage.load(&session_id).unwrap().unwrap(); + assert_eq!(loaded.skipped_keys.len(), storable.skipped_keys.len()); + + // Restore and verify we can decrypt the skipped message + let mut restored: RatchetState = loaded.to_ratchet_state().unwrap(); + let pt = restored + .decrypt_message(&messages[1].0, messages[1].1.clone()) + .unwrap(); + assert_eq!(pt, b"Message 1"); + } + + #[test] + fn test_encryption_uses_different_nonces() { + let storage = create_test_storage(); + let state = create_test_state(); + + // Save the same state twice with different session IDs + storage.save(&[1u8; 32], &state).unwrap(); + storage.save(&[2u8; 32], &state).unwrap(); + + // Both should load correctly (encryption with different nonces) + let loaded1 = storage.load(&[1u8; 32]).unwrap().unwrap(); + let loaded2 = storage.load(&[2u8; 32]).unwrap().unwrap(); + + assert_eq!(loaded1.dh_self_secret, loaded2.dh_self_secret); + } + + #[test] + fn test_cascade_delete_skipped_keys() { + let storage = create_test_storage(); + let session_id = [1u8; 32]; + + // Create a state with skipped keys + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + let mut alice: RatchetState = + RatchetState::init_sender(shared_secret, *bob_keypair.public()); + let mut bob: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + let mut messages = vec![]; + for i in 0..3 { + let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes()); + messages.push((ct, header)); + } + + bob.decrypt_message(&messages[0].0, messages[0].1.clone()) + .unwrap(); + bob.decrypt_message(&messages[2].0, messages[2].1.clone()) + .unwrap(); + + let storable = StorableRatchetState::from_ratchet_state(&bob, "default"); + storage.save(&session_id, &storable).unwrap(); + + // Verify skipped keys exist + { + let conn = storage.conn.lock().unwrap(); + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?", + [session_id.as_slice()], + |row| row.get(0), + ) + .unwrap(); + assert!(count > 0); + } + + // Delete session + storage.delete(&session_id).unwrap(); + + // Verify skipped keys were also deleted (cascade) + { + let conn = storage.conn.lock().unwrap(); + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?", + [session_id.as_slice()], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(count, 0); + } + } + + #[test] + fn test_file_storage() { + let dir = tempfile::tempdir().unwrap(); + let db_path = dir.path().join("test.db"); + let key = [0x42u8; 32]; + + let state = create_test_state(); + let session_id = [1u8; 32]; + + // Save in one instance + { + let storage = SqliteStorage::open(&db_path, key).unwrap(); + storage.save(&session_id, &state).unwrap(); + } + + // Load in another instance + { + let storage = SqliteStorage::open(&db_path, key).unwrap(); + let loaded = storage.load(&session_id).unwrap().unwrap(); + assert_eq!(loaded.root_key, state.root_key); + } + } + + #[test] + fn test_wrong_key_fails_decryption() { + let dir = tempfile::tempdir().unwrap(); + let db_path = dir.path().join("test.db"); + let key1 = [0x42u8; 32]; + let key2 = [0x43u8; 32]; + + let state = create_test_state(); + let session_id = [1u8; 32]; + + // Save with key1 + { + let storage = SqliteStorage::open(&db_path, key1).unwrap(); + storage.save(&session_id, &state).unwrap(); + } + + // Try to load with key2 - should fail decryption + { + let storage = SqliteStorage::open(&db_path, key2).unwrap(); + let result = storage.load(&session_id); + assert!(result.is_err()); + } + } +} diff --git a/double-ratchets-storage/src/traits.rs b/double-ratchets-storage/src/traits.rs new file mode 100644 index 0000000..5607a65 --- /dev/null +++ b/double-ratchets-storage/src/traits.rs @@ -0,0 +1,74 @@ +//! Storage trait definitions. + +use crate::error::StorageError; +use crate::types::StorableRatchetState; + +/// A 32-byte session identifier. +pub type SessionId = [u8; 32]; + +/// Abstract storage interface for ratchet states. +/// +/// Implementations must be thread-safe (`Send + Sync`). +pub trait RatchetStorage: Send + Sync { + /// Save a ratchet state for the given session. + /// + /// If a state already exists for this session, it will be overwritten. + /// + /// # Arguments + /// + /// * `session_id` - Unique identifier for the session. + /// * `state` - The ratchet state to store. + /// + /// # Returns + /// + /// * `Ok(())` on success. + /// * `Err(StorageError)` on failure. + fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError>; + + /// Load a ratchet state for the given session. + /// + /// # Arguments + /// + /// * `session_id` - Unique identifier for the session. + /// + /// # Returns + /// + /// * `Ok(Some(state))` if found. + /// * `Ok(None)` if not found. + /// * `Err(StorageError)` on failure. + fn load(&self, session_id: &SessionId) -> Result, StorageError>; + + /// Delete a ratchet state for the given session. + /// + /// # Arguments + /// + /// * `session_id` - Unique identifier for the session. + /// + /// # Returns + /// + /// * `Ok(true)` if the session existed and was deleted. + /// * `Ok(false)` if the session did not exist. + /// * `Err(StorageError)` on failure. + fn delete(&self, session_id: &SessionId) -> Result; + + /// Check if a session exists in storage. + /// + /// # Arguments + /// + /// * `session_id` - Unique identifier for the session. + /// + /// # Returns + /// + /// * `Ok(true)` if the session exists. + /// * `Ok(false)` if the session does not exist. + /// * `Err(StorageError)` on failure. + fn exists(&self, session_id: &SessionId) -> Result; + + /// List all session IDs in storage. + /// + /// # Returns + /// + /// * `Ok(Vec)` containing all session IDs. + /// * `Err(StorageError)` on failure. + fn list_sessions(&self) -> Result, StorageError>; +} diff --git a/double-ratchets-storage/src/types.rs b/double-ratchets-storage/src/types.rs new file mode 100644 index 0000000..59806a5 --- /dev/null +++ b/double-ratchets-storage/src/types.rs @@ -0,0 +1,225 @@ +//! Serializable types for ratchet state storage. + +use std::collections::HashMap; + +use double_ratchets::state::RatchetState; +use double_ratchets::hkdf::HkdfInfo; +use double_ratchets::InstallationKeyPair; +use serde::{Deserialize, Serialize}; +use x25519_dalek::PublicKey; + +use crate::error::StorageError; + +/// A skipped message key entry for storage. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SkippedKey { + /// The public key associated with this skipped message. + pub public_key: [u8; 32], + /// The message number. + pub msg_num: u32, + /// The 32-byte message key. + pub message_key: [u8; 32], +} + +/// Serializable version of `RatchetState`. +/// +/// This struct stores all keys as raw byte arrays for easy serialization +/// and database storage. Use `from_ratchet_state()` and `to_ratchet_state()` +/// for conversion. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StorableRatchetState { + /// The current root key (32 bytes). + pub root_key: [u8; 32], + + /// The current sending chain key, if any (32 bytes). + pub sending_chain: Option<[u8; 32]>, + + /// The current receiving chain key, if any (32 bytes). + pub receiving_chain: Option<[u8; 32]>, + + /// Our DH secret key (32 bytes). + /// + /// **Security**: This should be encrypted before storage. + pub dh_self_secret: [u8; 32], + + /// Our DH public key (32 bytes). + pub dh_self_public: [u8; 32], + + /// Remote party's DH public key, if known (32 bytes). + pub dh_remote: Option<[u8; 32]>, + + /// Number of messages sent in the current sending chain. + pub msg_send: u32, + + /// Number of messages received in the current receiving chain. + pub msg_recv: u32, + + /// Length of the previous sending chain. + pub prev_chain_len: u32, + + /// Skipped message keys for out-of-order message handling. + pub skipped_keys: Vec, + + /// Domain identifier for HKDF info. + pub domain_id: String, +} + +impl StorableRatchetState { + /// Convert a `RatchetState` into a `StorableRatchetState`. + /// + /// # Type Parameters + /// + /// * `D` - The HKDF domain type implementing `HkdfInfo`. + /// + /// # Arguments + /// + /// * `state` - The ratchet state to convert. + /// * `domain_id` - A string identifier for the domain (used to reconstruct the correct domain type). + pub fn from_ratchet_state(state: &RatchetState, domain_id: &str) -> Self { + let skipped_keys: Vec = state + .skipped_keys + .iter() + .map(|((pub_key, msg_num), msg_key)| SkippedKey { + public_key: *pub_key.as_bytes(), + msg_num: *msg_num, + message_key: *msg_key, + }) + .collect(); + + StorableRatchetState { + root_key: state.root_key, + sending_chain: state.sending_chain, + receiving_chain: state.receiving_chain, + dh_self_secret: state.dh_self.secret_bytes(), + dh_self_public: *state.dh_self.public().as_bytes(), + dh_remote: state.dh_remote.map(|pk| *pk.as_bytes()), + msg_send: state.msg_send, + msg_recv: state.msg_recv, + prev_chain_len: state.prev_chain_len, + skipped_keys, + domain_id: domain_id.to_string(), + } + } + + /// Convert this `StorableRatchetState` back into a `RatchetState`. + /// + /// # Type Parameters + /// + /// * `D` - The HKDF domain type implementing `HkdfInfo`. + /// + /// # Returns + /// + /// * `Ok(RatchetState)` on success. + /// * `Err(StorageError)` if key reconstruction fails. + pub fn to_ratchet_state(&self) -> Result, StorageError> { + // Reconstruct the keypair + let dh_self = InstallationKeyPair::from_bytes(self.dh_self_secret, self.dh_self_public) + .map_err(|e| StorageError::KeyReconstruction(e.to_string()))?; + + // Reconstruct skipped keys HashMap + let skipped_keys: HashMap<(PublicKey, u32), [u8; 32]> = self + .skipped_keys + .iter() + .map(|sk| { + let pub_key = PublicKey::from(sk.public_key); + ((pub_key, sk.msg_num), sk.message_key) + }) + .collect(); + + Ok(RatchetState::from_parts( + self.root_key, + self.sending_chain, + self.receiving_chain, + dh_self, + self.dh_remote.map(PublicKey::from), + self.msg_send, + self.msg_recv, + self.prev_chain_len, + skipped_keys, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use double_ratchets::hkdf::DefaultDomain; + + #[test] + fn test_roundtrip_sender_state() { + // Create a sender state + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + let state: RatchetState = + RatchetState::init_sender(shared_secret, *bob_keypair.public()); + + // Convert to storable and back + let storable = StorableRatchetState::from_ratchet_state(&state, "default"); + let restored: RatchetState = storable.to_ratchet_state().unwrap(); + + // Verify fields match + assert_eq!(state.root_key, restored.root_key); + assert_eq!(state.sending_chain, restored.sending_chain); + assert_eq!(state.receiving_chain, restored.receiving_chain); + assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes()); + assert_eq!(state.dh_remote.map(|pk| *pk.as_bytes()), restored.dh_remote.map(|pk| *pk.as_bytes())); + assert_eq!(state.msg_send, restored.msg_send); + assert_eq!(state.msg_recv, restored.msg_recv); + assert_eq!(state.prev_chain_len, restored.prev_chain_len); + } + + #[test] + fn test_roundtrip_receiver_state() { + // Create a receiver state + let keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + let state: RatchetState = + RatchetState::init_receiver(shared_secret, keypair); + + // Convert to storable and back + let storable = StorableRatchetState::from_ratchet_state(&state, "default"); + let restored: RatchetState = storable.to_ratchet_state().unwrap(); + + // Verify fields match + assert_eq!(state.root_key, restored.root_key); + assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes()); + assert!(restored.dh_remote.is_none()); + } + + #[test] + fn test_roundtrip_with_skipped_keys() { + // Create states and exchange messages to generate skipped keys + let bob_keypair = InstallationKeyPair::generate(); + let shared_secret = [0x42u8; 32]; + let mut alice: RatchetState = + RatchetState::init_sender(shared_secret, *bob_keypair.public()); + let mut bob: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + // Alice sends multiple messages + let mut messages = vec![]; + for i in 0..3 { + let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes()); + messages.push((ct, header)); + } + + // Bob receives them out of order to create skipped keys + bob.decrypt_message(&messages[0].0, messages[0].1.clone()).unwrap(); + bob.decrypt_message(&messages[2].0, messages[2].1.clone()).unwrap(); + // Message 1 key is now in skipped_keys + + assert!(!bob.skipped_keys.is_empty()); + + // Convert to storable and back + let storable = StorableRatchetState::from_ratchet_state(&bob, "default"); + let restored: RatchetState = storable.to_ratchet_state().unwrap(); + + // Verify skipped keys are preserved + assert_eq!(bob.skipped_keys.len(), restored.skipped_keys.len()); + + // The restored state should be able to decrypt the skipped message + let mut restored = restored; + let pt = restored.decrypt_message(&messages[1].0, messages[1].1.clone()).unwrap(); + assert_eq!(pt, b"Message 1"); + } +} diff --git a/double-ratchets/src/keypair.rs b/double-ratchets/src/keypair.rs index 26463a4..704f613 100644 --- a/double-ratchets/src/keypair.rs +++ b/double-ratchets/src/keypair.rs @@ -23,4 +23,37 @@ impl InstallationKeyPair { pub fn public(&self) -> &PublicKey { &self.public } + + /// Export the secret key as raw bytes for serialization/storage. + /// + /// # Security Warning + /// + /// The returned bytes contain the private key material. Handle with care + /// and ensure proper encryption when storing. + pub fn secret_bytes(&self) -> [u8; 32] { + self.secret.to_bytes() + } + + /// Reconstruct a keypair from raw secret and public key bytes. + /// + /// # Arguments + /// + /// * `secret` - The 32-byte secret key. + /// * `public` - The 32-byte public key. + /// + /// # Returns + /// + /// * `Ok(InstallationKeyPair)` if the keys are valid and consistent. + /// * `Err(&'static str)` if the public key doesn't match the secret key. + pub fn from_bytes(secret: [u8; 32], public: [u8; 32]) -> Result { + let secret = StaticSecret::from(secret); + let expected_public = PublicKey::from(&secret); + let public = PublicKey::from(public); + + if expected_public.as_bytes() != public.as_bytes() { + return Err("public key does not match secret key"); + } + + Ok(Self { secret, public }) + } } diff --git a/double-ratchets/src/state.rs b/double-ratchets/src/state.rs index b34ef46..720eb69 100644 --- a/double-ratchets/src/state.rs +++ b/double-ratchets/src/state.rs @@ -59,6 +59,46 @@ impl Header { } impl RatchetState { + /// Reconstruct a `RatchetState` from its component parts. + /// + /// This is primarily used for deserialization from storage. + /// + /// # Arguments + /// + /// * `root_key` - The current root key. + /// * `sending_chain` - The current sending chain key, if any. + /// * `receiving_chain` - The current receiving chain key, if any. + /// * `dh_self` - Our DH key pair. + /// * `dh_remote` - Remote party's DH public key, if known. + /// * `msg_send` - Number of messages sent in current sending chain. + /// * `msg_recv` - Number of messages received in current receiving chain. + /// * `prev_chain_len` - Length of the previous sending chain. + /// * `skipped_keys` - Map of skipped message keys. + pub fn from_parts( + root_key: RootKey, + sending_chain: Option, + receiving_chain: Option, + dh_self: InstallationKeyPair, + dh_remote: Option, + msg_send: u32, + msg_recv: u32, + prev_chain_len: u32, + skipped_keys: HashMap<(PublicKey, u32), MessageKey>, + ) -> Self { + Self { + root_key, + sending_chain, + receiving_chain, + dh_self, + dh_remote, + msg_send, + msg_recv, + prev_chain_len, + skipped_keys, + _domain: PhantomData, + } + } + /// Initializes the party that sends the first message. /// /// Performs the initial Diffie-Hellman computation with the remote public key