From 8e2b5211b4a4e5ae03946f333e70de523e4aedf1 Mon Sep 17 00:00:00 2001 From: kaichao Date: Wed, 28 Jan 2026 14:54:16 +0800 Subject: [PATCH] Managed data storage for Ratchet state (#21) * feat: managed persist storage with sqlite * chore: sync skipped keys * chore: refactor * chore: refactor * chore: clean code * chore: export skipped keys from state. * chore: renaming data to record * chore: remove types from stroage mod file --- .gitignore | 3 + Cargo.lock | 187 ++++++-- double-ratchets/Cargo.toml | 4 + double-ratchets/README.md | 5 +- double-ratchets/examples/out_of_order_demo.rs | 166 +++++++ double-ratchets/examples/storage_demo.rs | 241 ++++++++++ double-ratchets/src/keypair.rs | 12 + double-ratchets/src/lib.rs | 4 + double-ratchets/src/state.rs | 75 ++- double-ratchets/src/storage/mod.rs | 5 + double-ratchets/src/storage/session.rs | 310 +++++++++++++ double-ratchets/src/storage/sqlite.rs | 437 ++++++++++++++++++ double-ratchets/src/storage/types.rs | 81 ++++ 13 files changed, 1500 insertions(+), 30 deletions(-) create mode 100644 double-ratchets/examples/out_of_order_demo.rs create mode 100644 double-ratchets/examples/storage_demo.rs create mode 100644 double-ratchets/src/storage/mod.rs create mode 100644 double-ratchets/src/storage/session.rs create mode 100644 double-ratchets/src/storage/sqlite.rs create mode 100644 double-ratchets/src/storage/types.rs diff --git a/.gitignore b/.gitignore index e00b4cb..1fca2de 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,6 @@ target # Compiled binary **/ffi_nim_example + +# Temporary data folder +tmp diff --git a/Cargo.lock b/Cargo.lock index ed54022..1970c57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,6 +24,12 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + [[package]] name = "blake2" version = "0.10.6" @@ -48,6 +54,16 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +[[package]] +name = "cc" +version = "1.2.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6354c81bbfd62d9cfa9cb3c773c2b7b2a3a482d569de977fd0e961f6e7c00583" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -161,7 +177,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -182,7 +198,7 @@ checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -205,6 +221,7 @@ dependencies = [ "hkdf", "rand", "rand_core", + "rusqlite", "safer-ffi", "thiserror", "x25519-dalek", @@ -282,12 +299,36 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320bea982e85d42441eb25c49b41218e7eaa2657e8f90bc4eca7437376751e23" +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fiat-crypto" version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" +[[package]] +name = "find-msvc-tools" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "generic-array" version = "0.14.7" @@ -310,21 +351,39 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "libc", "wasi", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + [[package]] name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.5", +] + [[package]] name = "hex" version = "0.4.3" @@ -356,7 +415,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", ] [[package]] @@ -388,9 +447,21 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.178" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" + +[[package]] +name = "libsqlite3-sys" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "947e6816f7825b2b45027c2c32e7085da9934defa535de4a6a46b10a4d5257fa" +dependencies = [ + "cc", + "openssl-sys", + "pkg-config", + "vcpkg", +] [[package]] name = "logos-chat" @@ -434,6 +505,28 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl-src" +version = "300.5.4+3.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a507b3792995dae9b0df8a1c1e3771e8418b7c2d9f0baeba32e6fe8b06c7cb72" +dependencies = [ + "cc", +] + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "openssl-src", + "pkg-config", + "vcpkg", +] + [[package]] name = "paste" version = "1.0.15" @@ -450,6 +543,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "poly1305" version = "0.8.0" @@ -491,9 +590,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.103" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] @@ -518,14 +617,14 @@ dependencies = [ "itertools", "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] name = "quote" -version = "1.0.42" +version = "1.0.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" dependencies = [ "proc-macro2", ] @@ -560,6 +659,20 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rusqlite" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a22715a5d6deef63c637207afbe68d0c72c3f8d0022d7cf9714c442d6157606b" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc_version" version = "0.4.1" @@ -646,7 +759,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -666,6 +779,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f179d4e11094a893b82fff208f74d448a7512f99f5a0acbd5c679b705f83ed9" +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signature" version = "2.2.0" @@ -675,6 +794,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + [[package]] name = "spki" version = "0.7.3" @@ -739,9 +864,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.111" +version = "2.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" dependencies = [ "proc-macro2", "quote", @@ -750,22 +875,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.17" +version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -835,6 +960,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0976c77def3f1f75c4ef892a292c31c0bbe9e3d0702c63044d7c76db298171a3" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -906,22 +1037,22 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.31" +version = "0.8.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" +checksum = "71ddd76bcebeed25db614f82bf31a9f4222d3fbba300e6fb6c00afa26cbd4d9d" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.31" +version = "0.8.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" +checksum = "d8187381b52e32220d50b255276aa16a084ec0a9017a0ca2152a1f55c539758d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] [[package]] @@ -935,11 +1066,11 @@ dependencies = [ [[package]] name = "zeroize_derive" -version = "1.4.2" +version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.111", + "syn 2.0.114", ] diff --git a/double-ratchets/Cargo.toml b/double-ratchets/Cargo.toml index 3685df2..048011d 100644 --- a/double-ratchets/Cargo.toml +++ b/double-ratchets/Cargo.toml @@ -20,6 +20,10 @@ thiserror = "2" blake2 = "0.10.6" safer-ffi = "0.1.13" zeroize = "1.8.2" +rusqlite = { version = "0.35", optional = true, features = ["bundled"] } [features] +default = [] +storage = ["rusqlite"] +sqlcipher = ["storage", "rusqlite/bundled-sqlcipher-vendored-openssl"] headers = ["safer-ffi/headers"] diff --git a/double-ratchets/README.md b/double-ratchets/README.md index bad8a7b..974deda 100644 --- a/double-ratchets/README.md +++ b/double-ratchets/README.md @@ -17,8 +17,11 @@ let plaintext = bob.decrypt_message(&ciphertext, header); Run examples, -``` +```bash cargo run --example double_ratchet_basic + +cargo run --example storage_demo --features storage +cargo run --example storage_demo --features sqlcipher ``` Run Nim FFI example, diff --git a/double-ratchets/examples/out_of_order_demo.rs b/double-ratchets/examples/out_of_order_demo.rs new file mode 100644 index 0000000..a2dbb4d --- /dev/null +++ b/double-ratchets/examples/out_of_order_demo.rs @@ -0,0 +1,166 @@ +//! Demonstrates out-of-order message handling with skipped keys persistence. +//! +//! Run with: cargo run --example out_of_order_demo --features storage + +#[cfg(feature = "storage")] +use double_ratchets::{ + InstallationKeyPair, RatchetState, SqliteStorage, StorageConfig, hkdf::DefaultDomain, + state::Header, +}; + +fn main() { + println!("=== Out-of-Order Message Handling Demo (skipped - enable 'storage' feature) ===\n"); + + #[cfg(feature = "storage")] + run_demo(); +} + +#[cfg(feature = "storage")] +fn run_demo() { + let mut storage = + SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage"); + + // Setup + let shared_secret = [0x42u8; 32]; + let bob_keypair = InstallationKeyPair::generate(); + + let alice_state: RatchetState = + RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + let bob_state: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + storage.save("alice", &alice_state).unwrap(); + storage.save("bob", &bob_state).unwrap(); + + // === Alice sends 5 messages === + println!("Alice sends 5 messages..."); + let mut messages: Vec<(Vec, Header)> = Vec::new(); + + for i in 1..=5 { + let mut alice: RatchetState = storage.load("alice").unwrap(); + let msg = format!("Message #{}", i); + let (ct, header) = alice.encrypt_message(msg.as_bytes()); + storage.save("alice", &alice).unwrap(); + messages.push((ct, header)); + println!(" Sent: \"{}\"", msg); + } + + // === Bob receives messages out of order: 1, 3, 5 === + println!("\nBob receives messages 1, 3, 5 (out of order)..."); + + for &idx in &[0, 2, 4] { + let mut bob: RatchetState = storage.load("bob").unwrap(); + let (ct, header) = &messages[idx]; + let pt = bob + .decrypt_message(ct, header.clone()) + .expect("Decrypt failed"); + storage.save("bob", &bob).unwrap(); + println!(" Received: \"{}\"", String::from_utf8_lossy(&pt)); + } + + let bob: RatchetState = storage.load("bob").unwrap(); + println!("\nBob's skipped_keys count: {}", bob.skipped_keys.len()); + println!(" (Messages 2 and 4 keys are stored for later)"); + + // === Simulate Bob's app restart === + println!("\n--- Simulating Bob's app restart ---"); + drop(storage); + + // In-memory storage doesn't persist across restarts. + // Use file storage to properly demonstrate persistence: + println!(" (Using file storage to demonstrate real persistence)"); + if let Err(e) = std::fs::create_dir_all("./tmp") { + eprintln!("Failed to create tmp directory: {}", e); + return; // Or handle as needed + } + let db_path = "./tmp/out_of_order_demo.db"; + let _ = std::fs::remove_file(db_path); + + // Redo with file storage + let mut storage = SqliteStorage::new(StorageConfig::File(db_path.to_string())) + .expect("Failed to create storage"); + + // Re-setup + let bob_keypair = InstallationKeyPair::generate(); + let alice_state: RatchetState = + RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + let bob_state: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + storage.save("alice", &alice_state).unwrap(); + storage.save("bob", &bob_state).unwrap(); + + // Alice sends 5 messages + let mut messages: Vec<(Vec, Header)> = Vec::new(); + for i in 1..=5 { + let mut alice: RatchetState = storage.load("alice").unwrap(); + let msg = format!("Message #{}", i); + let (ct, header) = alice.encrypt_message(msg.as_bytes()); + storage.save("alice", &alice).unwrap(); + messages.push((ct, header)); + } + println!(" Alice sent 5 messages"); + + // Bob receives 1, 3, 5 (skips 2, 4) + for &idx in &[0, 2, 4] { + let mut bob: RatchetState = storage.load("bob").unwrap(); + let (ct, header) = &messages[idx]; + bob.decrypt_message(ct, header.clone()).unwrap(); + storage.save("bob", &bob).unwrap(); + } + + let bob: RatchetState = storage.load("bob").unwrap(); + println!( + " Bob received 1,3,5. Skipped keys stored: {}", + bob.skipped_keys.len() + ); + + // Close and reopen storage (simulating app restart) + drop(storage); + let mut storage = + SqliteStorage::new(StorageConfig::File(db_path.to_string())).expect("Failed to reopen"); + + let bob: RatchetState = storage.load("bob").unwrap(); + println!( + "\n After restart, Bob's skipped_keys: {}", + bob.skipped_keys.len() + ); + + // === Now Bob receives the delayed messages === + println!("\nBob receives delayed message 2..."); + { + let mut bob: RatchetState = storage.load("bob").unwrap(); + let (ct, header) = &messages[1]; + let pt = bob.decrypt_message(ct, header.clone()).unwrap(); + storage.save("bob", &bob).unwrap(); + println!(" Received: \"{}\"", String::from_utf8_lossy(&pt)); + println!(" Remaining skipped_keys: {}", bob.skipped_keys.len()); + } + + println!("\nBob receives delayed message 4..."); + let (ct4, header4) = messages[3].clone(); + { + let mut bob: RatchetState = storage.load("bob").unwrap(); + let pt = bob.decrypt_message(&ct4, header4.clone()).unwrap(); + storage.save("bob", &bob).unwrap(); + println!(" Received: \"{}\"", String::from_utf8_lossy(&pt)); + println!(" Remaining skipped_keys: {}", bob.skipped_keys.len()); + } + + // === Demonstrate replay protection === + println!("\n--- Replay Protection Demo ---"); + println!("Trying to decrypt message 4 again (should fail)..."); + + { + let mut bob: RatchetState = storage.load("bob").unwrap(); + match bob.decrypt_message(&ct4, header4) { + Ok(_) => println!(" ERROR: Replay attack succeeded!"), + Err(e) => println!(" Correctly rejected: {:?}", e), + } + } + + // Cleanup + let _ = std::fs::remove_file(db_path); + + println!("\n=== Demo Complete ==="); +} diff --git a/double-ratchets/examples/storage_demo.rs b/double-ratchets/examples/storage_demo.rs new file mode 100644 index 0000000..ce05bd4 --- /dev/null +++ b/double-ratchets/examples/storage_demo.rs @@ -0,0 +1,241 @@ +//! Demonstrates SQLite storage for Double Ratchet state persistence. +//! +//! Run with: cargo run --example storage_demo --features storage +//! For SQLCipher: cargo run --example storage_demo --features sqlcipher + +#[cfg(feature = "storage")] +use double_ratchets::{ + InstallationKeyPair, RatchetSession, SqliteStorage, StorageConfig, hkdf::PrivateV1Domain, +}; + +fn main() { + println!("=== Double Ratchet Storage Demo ===\n"); + + // Demo 1: In-memory storage (for testing) + println!("--- Demo 1: In-Memory Storage (skipped - enable 'storage' feature) ---"); + #[cfg(feature = "storage")] + demo_in_memory(); + + // Demo 2: File-based storage (for local development) + println!("\n--- Demo 2: File-Based Storage (skipped - enable 'storage' feature) ---"); + #[cfg(feature = "storage")] + demo_file_storage(); + + // Demo 3: SQLCipher encrypted storage (for production) + #[cfg(feature = "sqlcipher")] + { + println!("\n--- Demo 3: SQLCipher Encrypted Storage ---"); + demo_sqlcipher(); + } + + #[cfg(not(feature = "sqlcipher"))] + { + println!("\n--- Demo 3: SQLCipher (skipped - enable 'sqlcipher' feature) ---"); + } +} + +#[cfg(feature = "storage")] +fn demo_in_memory() { + let mut alice_storage = + SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage"); + let mut bob_storage = + SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage"); + run_conversation(&mut alice_storage, &mut bob_storage); +} + +#[cfg(feature = "storage")] +fn demo_file_storage() { + ensure_tmp_directory(); + + let db_path_alice = "./tmp/double_ratchet_demo_alice.db"; + let db_path_bob = "./tmp/double_ratchet_demo_bob.db"; + let _ = std::fs::remove_file(db_path_alice); + let _ = std::fs::remove_file(db_path_bob); + + // Initial conversation + { + let mut alice_storage = SqliteStorage::new(StorageConfig::File(db_path_alice.to_string())) + .expect("Failed to create storage"); + + let mut bob_storage = SqliteStorage::new(StorageConfig::File(db_path_bob.to_string())) + .expect("Failed to create storage"); + + println!(" Database created at: {}, {}", db_path_alice, db_path_bob); + run_conversation(&mut alice_storage, &mut bob_storage); + } + + // Simulate restart - reopen and continue + println!("\n Simulating application restart..."); + { + let mut alice_storage = SqliteStorage::new(StorageConfig::File(db_path_alice.to_string())) + .expect("Failed to reopen storage"); + let mut bob_storage = SqliteStorage::new(StorageConfig::File(db_path_bob.to_string())) + .expect("Failed to reopen storage"); + continue_after_restart(&mut alice_storage, &mut bob_storage); + } + + let _ = std::fs::remove_file(db_path_alice); + let _ = std::fs::remove_file(db_path_bob); +} + +#[cfg(feature = "sqlcipher")] +fn demo_sqlcipher() { + ensure_tmp_directory(); + let alice_db_path = "./tmp/double_ratchet_encrypted_alice.db"; + let bob_db_path = "./tmp/double_ratchet_encrypted_bob.db"; + let encryption_key = "super-secret-key-123!"; + let _ = std::fs::remove_file(alice_db_path); + let _ = std::fs::remove_file(bob_db_path); + + // Initial conversation with encryption + { + let mut alice_storage = SqliteStorage::new(StorageConfig::Encrypted { + path: alice_db_path.to_string(), + key: encryption_key.to_string(), + }) + .expect("Failed to create encrypted storage"); + let mut bob_storage = SqliteStorage::new(StorageConfig::Encrypted { + path: bob_db_path.to_string(), + key: encryption_key.to_string(), + }) + .expect("Failed to create encrypted storage"); + println!( + " Encrypted database created at: {}, {}", + alice_db_path, bob_db_path + ); + run_conversation(&mut alice_storage, &mut bob_storage); + } + + // Restart with correct key + println!("\n Simulating restart with encryption key..."); + { + let mut alice_storage = SqliteStorage::new(StorageConfig::Encrypted { + path: alice_db_path.to_string(), + key: encryption_key.to_string(), + }) + .expect("Failed to create encrypted storage"); + let mut bob_storage = SqliteStorage::new(StorageConfig::Encrypted { + path: bob_db_path.to_string(), + key: encryption_key.to_string(), + }) + .expect("Failed to create encrypted storage"); + continue_after_restart(&mut alice_storage, &mut bob_storage); + } + + let _ = std::fs::remove_file(alice_db_path); + let _ = std::fs::remove_file(bob_db_path); +} + +#[allow(dead_code)] +fn ensure_tmp_directory() { + if let Err(e) = std::fs::create_dir_all("./tmp") { + eprintln!("Failed to create tmp directory: {}", e); + return; // Or handle as needed + } +} + +/// Simulates a conversation between Alice and Bob. +/// Each party saves/loads state from storage for each operation. +#[cfg(feature = "storage")] +fn run_conversation(alice_storage: &mut SqliteStorage, bob_storage: &mut SqliteStorage) { + // === Setup: Simulate X3DH key exchange === + let shared_secret = [0x42u8; 32]; // In reality, this comes from X3DH + let bob_keypair = InstallationKeyPair::generate(); + + let conv_id = "conv1"; + + let mut alice_session: RatchetSession = RatchetSession::create_sender_session( + alice_storage, + conv_id, + shared_secret, + bob_keypair.public().clone(), + ) + .unwrap(); + + let mut bob_session: RatchetSession = + RatchetSession::create_receiver_session(bob_storage, conv_id, shared_secret, bob_keypair) + .unwrap(); + + println!(" Sessions created for Alice and Bob"); + + // === Message 1: Alice -> Bob === + let (ct1, h1) = { + let result = alice_session + .encrypt_message(b"Hello Bob! This is message 1.") + .unwrap(); + println!(" Alice sent: \"Hello Bob! This is message 1.\""); + result + }; + + { + let pt = bob_session.decrypt_message(&ct1, h1).unwrap(); + println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); + } + + // === Message 2: Bob -> Alice (triggers DH ratchet) === + let (ct2, h2) = { + let result = bob_session + .encrypt_message(b"Hi Alice! Got your message.") + .unwrap(); + println!(" Bob sent: \"Hi Alice! Got your message.\""); + result + }; + + { + let pt = alice_session.decrypt_message(&ct2, h2).unwrap(); + println!(" Alice received: \"{}\"", String::from_utf8_lossy(&pt)); + } + + // === Message 3: Alice -> Bob === + let (ct3, h3) = { + let result = alice_session + .encrypt_message(b"Great! Let's keep chatting.") + .unwrap(); + println!(" Alice sent: \"Great! Let's keep chatting.\""); + result + }; + + { + let pt = bob_session.decrypt_message(&ct3, h3).unwrap(); + println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); + } + + // Print final state + println!( + " State after conversation: Alice msg_send={}, Bob msg_recv={}", + alice_session.msg_send(), + bob_session.msg_recv() + ); +} + +#[cfg(feature = "storage")] +fn continue_after_restart(alice_storage: &mut SqliteStorage, bob_storage: &mut SqliteStorage) { + // Load persisted states + let conv_id = "conv1"; + + let mut alice_session: RatchetSession = + RatchetSession::open(alice_storage, conv_id).unwrap(); + let mut bob_session: RatchetSession = + RatchetSession::open(bob_storage, conv_id).unwrap(); + println!(" Sessions restored for Alice and Bob",); + + // Continue conversation + let (ct, header) = { + let result = alice_session + .encrypt_message(b"Message after restart!") + .unwrap(); + println!(" Alice sent: \"Message after restart!\""); + result + }; + + { + let pt = bob_session.decrypt_message(&ct, header).unwrap(); + println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt)); + } + + println!( + " Final state: Alice msg_send={}, Bob msg_recv={}", + alice_session.msg_send(), + bob_session.msg_recv() + ); +} diff --git a/double-ratchets/src/keypair.rs b/double-ratchets/src/keypair.rs index 7d9dc07..c32adb9 100644 --- a/double-ratchets/src/keypair.rs +++ b/double-ratchets/src/keypair.rs @@ -24,4 +24,16 @@ impl InstallationKeyPair { pub fn public(&self) -> &PublicKey { &self.public } + + /// Export the secret key as raw bytes for storage. + pub fn secret_bytes(&self) -> [u8; 32] { + self.secret.to_bytes() + } + + /// Reconstruct from secret key bytes. + pub fn from_secret_bytes(bytes: [u8; 32]) -> Self { + let secret = StaticSecret::from(bytes); + let public = PublicKey::from(&secret); + Self { secret, public } + } } diff --git a/double-ratchets/src/lib.rs b/double-ratchets/src/lib.rs index 3bd8b46..1b9a566 100644 --- a/double-ratchets/src/lib.rs +++ b/double-ratchets/src/lib.rs @@ -4,7 +4,11 @@ pub mod ffi; pub mod hkdf; pub mod keypair; pub mod state; +#[cfg(feature = "storage")] +pub mod storage; pub mod types; pub use keypair::InstallationKeyPair; pub use state::{Header, RatchetState}; +#[cfg(feature = "storage")] +pub use storage::{RatchetSession, SessionError, SqliteStorage, StorageConfig, StorageError}; diff --git a/double-ratchets/src/state.rs b/double-ratchets/src/state.rs index 9aec0cd..dc92ee1 100644 --- a/double-ratchets/src/state.rs +++ b/double-ratchets/src/state.rs @@ -31,7 +31,15 @@ pub struct RatchetState { pub skipped_keys: HashMap<(PublicKey, u32), MessageKey>, - _domain: PhantomData, + pub(crate) _domain: PhantomData, +} + +/// Represents a skipped message key for storage or inspection. +#[derive(Debug, Clone)] +pub struct SkippedKey { + pub public_key: [u8; 32], + pub msg_num: u32, + pub message_key: MessageKey, } /// Public header attached to every encrypted message (unencrypted but authenticated). @@ -290,6 +298,22 @@ impl RatchetState { Ok(()) } + + /// Exports the skipped keys for storage or inspection. + /// + /// # Returns + /// + /// A vector of `SkippedKey` representing the currently stored skipped message keys. + pub fn skipped_keys(&self) -> Vec { + self.skipped_keys + .iter() + .map(|((pk, msg_num), mk)| SkippedKey { + public_key: pk.to_bytes(), + msg_num: *msg_num, + message_key: *mk, + }) + .collect() + } } #[cfg(test)] @@ -488,4 +512,53 @@ mod tests { assert!(result.is_err()); assert_eq!(result.unwrap_err(), RatchetError::MessageReplay); } + + #[test] + fn test_skipped_keys_export() { + let (mut alice, mut bob, _) = setup_alice_bob(); + + // Initially no skipped keys + assert!(bob.skipped_keys().is_empty()); + + // Alice sends 4 messages + let mut encrypted = vec![]; + for i in 0..4 { + let msg = format!("Message {}", i).into_bytes(); + let (ct, h) = alice.encrypt_message(&msg); + encrypted.push((ct, h, msg)); + } + + // Bob receives message 0 first + bob.decrypt_message(&encrypted[0].0, encrypted[0].1.clone()) + .unwrap(); + assert!(bob.skipped_keys().is_empty()); + + // Bob receives message 3, skipping 1 and 2 + bob.decrypt_message(&encrypted[3].0, encrypted[3].1.clone()) + .unwrap(); + + // Now we should have 2 skipped keys (for messages 1 and 2) + let skipped = bob.skipped_keys(); + assert_eq!(skipped.len(), 2); + + // Verify the skipped keys have the expected message numbers + let msg_nums: Vec = skipped.iter().map(|sk| sk.msg_num).collect(); + assert!(msg_nums.contains(&1)); + assert!(msg_nums.contains(&2)); + + // Verify each skipped key has valid data + for sk in &skipped { + assert_eq!(sk.public_key.len(), 32); + assert_eq!(sk.message_key.len(), 32); + } + + // Now decrypt message 1 using the skipped key + bob.decrypt_message(&encrypted[1].0, encrypted[1].1.clone()) + .unwrap(); + + // Should only have 1 skipped key left (for message 2) + let skipped_after = bob.skipped_keys(); + assert_eq!(skipped_after.len(), 1); + assert_eq!(skipped_after[0].msg_num, 2); + } } diff --git a/double-ratchets/src/storage/mod.rs b/double-ratchets/src/storage/mod.rs new file mode 100644 index 0000000..e26ec70 --- /dev/null +++ b/double-ratchets/src/storage/mod.rs @@ -0,0 +1,5 @@ +mod session; +mod sqlite; + +pub use session::{RatchetSession, SessionError}; +pub use sqlite::{SqliteStorage, StorageConfig}; diff --git a/double-ratchets/src/storage/session.rs b/double-ratchets/src/storage/session.rs new file mode 100644 index 0000000..399af8d --- /dev/null +++ b/double-ratchets/src/storage/session.rs @@ -0,0 +1,310 @@ +use x25519_dalek::PublicKey; + +use crate::{ + InstallationKeyPair, + errors::RatchetError, + hkdf::HkdfInfo, + state::{Header, RatchetState}, + types::SharedSecret, +}; + +use super::{SqliteStorage, StorageError}; + +/// A session wrapper that automatically persists ratchet state after operations. +/// Provides rollback semantics - state is only saved if the operation succeeds. +pub struct RatchetSession<'a, D: HkdfInfo + Clone> { + storage: &'a mut SqliteStorage, + conversation_id: String, + state: RatchetState, +} + +#[derive(Debug)] +pub enum SessionError { + Storage(StorageError), + Ratchet(RatchetError), +} + +impl From for SessionError { + fn from(e: StorageError) -> Self { + SessionError::Storage(e) + } +} + +impl From for SessionError { + fn from(e: RatchetError) -> Self { + SessionError::Ratchet(e) + } +} + +impl std::fmt::Display for SessionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SessionError::Storage(e) => write!(f, "storage error: {}", e), + SessionError::Ratchet(e) => write!(f, "ratchet error: {}", e), + } + } +} + +impl std::error::Error for SessionError {} + +impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> { + /// Opens an existing session from storage. + pub fn open( + storage: &'a mut SqliteStorage, + conversation_id: impl Into, + ) -> Result { + let conversation_id = conversation_id.into(); + let state = storage.load(&conversation_id)?; + Ok(Self { + storage, + conversation_id, + state, + }) + } + + /// Creates a new session and persists the initial state. + pub fn create( + storage: &'a mut SqliteStorage, + conversation_id: impl Into, + state: RatchetState, + ) -> Result { + let conversation_id = conversation_id.into(); + storage.save(&conversation_id, &state)?; + Ok(Self { + storage, + conversation_id, + state, + }) + } + + /// Initializes a new session as a sender and persists the initial state. + pub fn create_sender_session( + storage: &'a mut SqliteStorage, + conversation_id: impl Into, + shared_secret: SharedSecret, + remote_pub: PublicKey, + ) -> Result { + let state = RatchetState::::init_sender(shared_secret, remote_pub); + Self::create(storage, conversation_id, state) + } + + /// Initializes a new session as a receiver and persists the initial state. + pub fn create_receiver_session( + storage: &'a mut SqliteStorage, + conversation_id: impl Into, + shared_secret: SharedSecret, + dh_self: InstallationKeyPair, + ) -> Result { + let conversation_id = conversation_id.into(); + if storage.exists(&conversation_id)? { + return Self::open(storage, conversation_id); + } + + let state = RatchetState::::init_receiver(shared_secret, dh_self); + Self::create(storage, conversation_id, state) + } + + /// Encrypts a message and persists the updated state. + /// If persistence fails, the in-memory state is NOT modified. + pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<(Vec, Header), SessionError> { + // Clone state for rollback + let state_backup = self.state.clone(); + + // Perform encryption (modifies state) + let result = self.state.encrypt_message(plaintext); + + // Try to persist + if let Err(e) = self.storage.save(&self.conversation_id, &self.state) { + // Rollback + self.state = state_backup; + return Err(SessionError::Storage(e)); + } + + Ok(result) + } + + /// Decrypts a message and persists the updated state. + /// If decryption or persistence fails, the in-memory state is NOT modified. + pub fn decrypt_message( + &mut self, + ciphertext_with_nonce: &[u8], + header: Header, + ) -> Result, SessionError> { + // Clone state for rollback + let state_backup = self.state.clone(); + + // Perform decryption (modifies state) + let plaintext = match self.state.decrypt_message(ciphertext_with_nonce, header) { + Ok(pt) => pt, + Err(e) => { + // Rollback on decrypt failure + self.state = state_backup; + return Err(SessionError::Ratchet(e)); + } + }; + + // Try to persist + if let Err(e) = self.storage.save(&self.conversation_id, &self.state) { + // Rollback + self.state = state_backup; + return Err(SessionError::Storage(e)); + } + + Ok(plaintext) + } + + /// Returns a reference to the current state (read-only). + pub fn state(&self) -> &RatchetState { + &self.state + } + + /// Returns the conversation ID. + pub fn conversation_id(&self) -> &str { + &self.conversation_id + } + + /// Manually saves the current state. + pub fn save(&mut self) -> Result<(), StorageError> { + self.storage.save(&self.conversation_id, &self.state) + } + + pub fn msg_send(&self) -> u32 { + self.state.msg_send + } + + pub fn msg_recv(&self) -> u32 { + self.state.msg_recv + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{hkdf::DefaultDomain, keypair::InstallationKeyPair, storage::StorageConfig}; + + fn create_test_storage() -> SqliteStorage { + SqliteStorage::new(StorageConfig::InMemory).unwrap() + } + + #[test] + fn test_session_create_and_open() { + let mut storage = create_test_storage(); + + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let alice: RatchetState = + RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + + // Create session + { + let session = RatchetSession::create(&mut storage, "conv1", alice).unwrap(); + assert_eq!(session.conversation_id(), "conv1"); + } + + // Open existing session + { + let session: RatchetSession = + RatchetSession::open(&mut storage, "conv1").unwrap(); + assert_eq!(session.state().msg_send, 0); + } + } + + #[test] + fn test_session_encrypt_persists() { + let mut storage = create_test_storage(); + + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let alice: RatchetState = + RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + + // Create and encrypt + { + let mut session = RatchetSession::create(&mut storage, "conv1", alice).unwrap(); + session.encrypt_message(b"Hello").unwrap(); + assert_eq!(session.state().msg_send, 1); + } + + // Reopen - state should be persisted + { + let session: RatchetSession = + RatchetSession::open(&mut storage, "conv1").unwrap(); + assert_eq!(session.state().msg_send, 1); + } + } + + #[test] + fn test_session_full_conversation() { + let mut storage = create_test_storage(); + + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let alice: RatchetState = + RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + let bob: RatchetState = + RatchetState::init_receiver(shared_secret, bob_keypair); + + // Alice sends + let (ct, header) = { + let mut session = RatchetSession::create(&mut storage, "alice", alice).unwrap(); + session.encrypt_message(b"Hello Bob").unwrap() + }; + + // Bob receives + let plaintext = { + let mut session = RatchetSession::create(&mut storage, "bob", bob).unwrap(); + session.decrypt_message(&ct, header).unwrap() + }; + assert_eq!(plaintext, b"Hello Bob"); + + // Bob replies + let (ct2, header2) = { + let mut session: RatchetSession = + RatchetSession::open(&mut storage, "bob").unwrap(); + session.encrypt_message(b"Hi Alice").unwrap() + }; + + // Alice receives + let plaintext2 = { + let mut session: RatchetSession = + RatchetSession::open(&mut storage, "alice").unwrap(); + session.decrypt_message(&ct2, header2).unwrap() + }; + assert_eq!(plaintext2, b"Hi Alice"); + } + + #[test] + fn test_session_open_or_create() { + let mut storage = create_test_storage(); + + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let bob_pub = bob_keypair.public().clone(); + + // First call creates + { + let session: RatchetSession = RatchetSession::create_sender_session( + &mut storage, + "conv1", + shared_secret, + bob_pub.clone(), + ) + .unwrap(); + assert_eq!(session.state().msg_send, 0); + } + + // Second call opens existing + { + let mut session: RatchetSession = + RatchetSession::open(&mut storage, "conv1").unwrap(); + session.encrypt_message(b"test").unwrap(); + } + + // Verify persistence + { + let session: RatchetSession = + RatchetSession::open(&mut storage, "conv1").unwrap(); + assert_eq!(session.state().msg_send, 1); + } + } +} diff --git a/double-ratchets/src/storage/sqlite.rs b/double-ratchets/src/storage/sqlite.rs new file mode 100644 index 0000000..2c061f8 --- /dev/null +++ b/double-ratchets/src/storage/sqlite.rs @@ -0,0 +1,437 @@ +use rusqlite::{Connection, params}; + +use super::{RatchetStateRecord, SkippedKey, StorageError}; +use crate::{hkdf::HkdfInfo, state::RatchetState}; + +/// Configuration for SQLite storage. +#[derive(Debug, Clone)] +pub enum StorageConfig { + /// In-memory database (for testing). + InMemory, + /// File-based SQLite database (unencrypted, for local dev). + File(String), + /// SQLCipher encrypted database (for production). + /// Requires the `sqlcipher` feature. + #[cfg(feature = "sqlcipher")] + Encrypted { path: String, key: String }, +} + +/// SQLite-based storage for ratchet state. +pub struct SqliteStorage { + conn: Connection, +} + +impl SqliteStorage { + /// Creates a new SQLite storage with the given configuration. + pub fn new(config: StorageConfig) -> Result { + let conn = match config { + StorageConfig::InMemory => Connection::open_in_memory()?, + StorageConfig::File(path) => Connection::open(path)?, + #[cfg(feature = "sqlcipher")] + StorageConfig::Encrypted { path, key } => { + let conn = Connection::open(path)?; + conn.pragma_update(None, "key", &key)?; + conn + } + }; + + let storage = Self { conn }; + storage.init_schema()?; + Ok(storage) + } + + fn init_schema(&self) -> Result<(), StorageError> { + self.conn.execute_batch( + " + CREATE TABLE IF NOT EXISTS ratchet_state ( + conversation_id TEXT PRIMARY KEY, + root_key BLOB NOT NULL, + sending_chain BLOB, + receiving_chain BLOB, + dh_self_secret BLOB NOT NULL, + dh_remote BLOB, + msg_send INTEGER NOT NULL, + msg_recv INTEGER NOT NULL, + prev_chain_len INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS skipped_keys ( + conversation_id TEXT NOT NULL, + public_key BLOB NOT NULL, + msg_num INTEGER NOT NULL, + message_key BLOB NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + PRIMARY KEY (conversation_id, public_key, msg_num), + FOREIGN KEY (conversation_id) REFERENCES ratchet_state(conversation_id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_skipped_keys_conversation + ON skipped_keys(conversation_id); + ", + )?; + Ok(()) + } + + /// Saves the ratchet state for a conversation within a transaction. + /// Rolls back automatically if any error occurs. + pub fn save( + &mut self, + conversation_id: &str, + state: &RatchetState, + ) -> Result<(), StorageError> { + let tx = self.conn.transaction()?; + + let data = RatchetStateRecord::from(state); + let skipped_keys: Vec = state.skipped_keys(); + + // Upsert main state + tx.execute( + " + INSERT INTO ratchet_state ( + conversation_id, root_key, sending_chain, receiving_chain, + dh_self_secret, dh_remote, msg_send, msg_recv, prev_chain_len + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) + ON CONFLICT(conversation_id) DO UPDATE SET + root_key = excluded.root_key, + sending_chain = excluded.sending_chain, + receiving_chain = excluded.receiving_chain, + dh_self_secret = excluded.dh_self_secret, + dh_remote = excluded.dh_remote, + msg_send = excluded.msg_send, + msg_recv = excluded.msg_recv, + prev_chain_len = excluded.prev_chain_len + ", + params![ + conversation_id, + data.root_key.as_slice(), + data.sending_chain.as_ref().map(|c| c.as_slice()), + data.receiving_chain.as_ref().map(|c| c.as_slice()), + data.dh_self_secret.as_slice(), + data.dh_remote.as_ref().map(|c| c.as_slice()), + data.msg_send, + data.msg_recv, + data.prev_chain_len, + ], + )?; + + // Sync skipped keys efficiently - only insert new, delete removed + sync_skipped_keys(&tx, conversation_id, skipped_keys)?; + + tx.commit()?; + Ok(()) + } + + /// Loads the ratchet state for a conversation. + pub fn load( + &self, + conversation_id: &str, + ) -> Result, StorageError> { + let data = self.load_state_data(conversation_id)?; + let skipped_keys = self.load_skipped_keys(conversation_id)?; + Ok(data.into_ratchet_state(skipped_keys)) + } + + fn load_state_data(&self, conversation_id: &str) -> Result { + let mut stmt = self.conn.prepare( + " + SELECT root_key, sending_chain, receiving_chain, dh_self_secret, + dh_remote, msg_send, msg_recv, prev_chain_len + FROM ratchet_state + WHERE conversation_id = ?1 + ", + )?; + + stmt.query_row(params![conversation_id], |row| { + Ok(RatchetStateRecord { + root_key: blob_to_array(row.get::<_, Vec>(0)?), + sending_chain: row.get::<_, Option>>(1)?.map(blob_to_array), + receiving_chain: row.get::<_, Option>>(2)?.map(blob_to_array), + dh_self_secret: blob_to_array(row.get::<_, Vec>(3)?), + dh_remote: row.get::<_, Option>>(4)?.map(blob_to_array), + msg_send: row.get(5)?, + msg_recv: row.get(6)?, + prev_chain_len: row.get(7)?, + }) + }) + .map_err(|e| match e { + rusqlite::Error::QueryReturnedNoRows => { + StorageError::ConversationNotFound(conversation_id.to_string()) + } + e => StorageError::Database(e), + }) + } + + fn load_skipped_keys(&self, conversation_id: &str) -> Result, StorageError> { + let mut stmt = self.conn.prepare( + " + SELECT public_key, msg_num, message_key + FROM skipped_keys + WHERE conversation_id = ?1 + ", + )?; + + let rows = stmt.query_map(params![conversation_id], |row| { + Ok(SkippedKey { + public_key: blob_to_array(row.get::<_, Vec>(0)?), + msg_num: row.get(1)?, + message_key: blob_to_array(row.get::<_, Vec>(2)?), + }) + })?; + + rows.collect::, _>>() + .map_err(StorageError::Database) + } + + /// Checks if a conversation exists. + pub fn exists(&self, conversation_id: &str) -> Result { + let count: i64 = self.conn.query_row( + "SELECT COUNT(*) FROM ratchet_state WHERE conversation_id = ?1", + params![conversation_id], + |row| row.get(0), + )?; + Ok(count > 0) + } + + /// Deletes a conversation and its skipped keys. + pub fn delete(&mut self, conversation_id: &str) -> Result<(), StorageError> { + let tx = self.conn.transaction()?; + tx.execute( + "DELETE FROM skipped_keys WHERE conversation_id = ?1", + params![conversation_id], + )?; + tx.execute( + "DELETE FROM ratchet_state WHERE conversation_id = ?1", + params![conversation_id], + )?; + tx.commit()?; + Ok(()) + } + + /// Cleans up old skipped keys older than the given age in seconds. + pub fn cleanup_old_skipped_keys(&mut self, max_age_secs: i64) -> Result { + let deleted = self.conn.execute( + "DELETE FROM skipped_keys WHERE created_at < strftime('%s', 'now') - ?1", + params![max_age_secs], + )?; + Ok(deleted) + } +} + +/// Syncs skipped keys efficiently by computing diff and only inserting/deleting changes. +fn sync_skipped_keys( + tx: &rusqlite::Transaction, + conversation_id: &str, + current_keys: Vec, +) -> Result<(), StorageError> { + use std::collections::HashSet; + + // Get existing keys from DB (just the identifiers) + let mut stmt = + tx.prepare("SELECT public_key, msg_num FROM skipped_keys WHERE conversation_id = ?1")?; + let existing: HashSet<([u8; 32], u32)> = stmt + .query_map(params![conversation_id], |row| { + Ok(( + blob_to_array(row.get::<_, Vec>(0)?), + row.get::<_, u32>(1)?, + )) + })? + .filter_map(|r| r.ok()) + .collect(); + + // Build set of current keys + let current_set: HashSet<([u8; 32], u32)> = current_keys + .iter() + .map(|sk| (sk.public_key, sk.msg_num)) + .collect(); + + // Delete keys that were removed (used for decryption) + for (pk, msg_num) in existing.difference(¤t_set) { + tx.execute( + "DELETE FROM skipped_keys WHERE conversation_id = ?1 AND public_key = ?2 AND msg_num = ?3", + params![conversation_id, pk.as_slice(), msg_num], + )?; + } + + // Insert new keys + for sk in ¤t_keys { + let key = (sk.public_key, sk.msg_num); + if !existing.contains(&key) { + tx.execute( + "INSERT INTO skipped_keys (conversation_id, public_key, msg_num, message_key) + VALUES (?1, ?2, ?3, ?4)", + params![ + conversation_id, + sk.public_key.as_slice(), + sk.msg_num, + sk.message_key.as_slice(), + ], + )?; + } + } + + Ok(()) +} + +fn blob_to_array(blob: Vec) -> [u8; N] { + blob.try_into() + .unwrap_or_else(|v: Vec| panic!("Expected {} bytes, got {}", N, v.len())) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{hkdf::DefaultDomain, keypair::InstallationKeyPair}; + + fn create_test_storage() -> SqliteStorage { + SqliteStorage::new(StorageConfig::InMemory).unwrap() + } + + fn create_test_state() -> (RatchetState, RatchetState) { + let shared_secret = [0x42; 32]; + let bob_keypair = InstallationKeyPair::generate(); + let alice = RatchetState::init_sender(shared_secret, bob_keypair.public().clone()); + let bob = RatchetState::init_receiver(shared_secret, bob_keypair); + (alice, bob) + } + + #[test] + fn test_save_and_load_sender() { + let mut storage = create_test_storage(); + let (alice, _) = create_test_state(); + + storage.save("conv1", &alice).unwrap(); + let loaded: RatchetState = storage.load("conv1").unwrap(); + + assert_eq!(alice.root_key, loaded.root_key); + assert_eq!(alice.sending_chain, loaded.sending_chain); + assert_eq!(alice.receiving_chain, loaded.receiving_chain); + assert_eq!(alice.msg_send, loaded.msg_send); + assert_eq!(alice.msg_recv, loaded.msg_recv); + assert_eq!(alice.prev_chain_len, loaded.prev_chain_len); + assert_eq!( + alice.dh_self.public().to_bytes(), + loaded.dh_self.public().to_bytes() + ); + } + + #[test] + fn test_save_and_load_receiver() { + let mut storage = create_test_storage(); + let (_, bob) = create_test_state(); + + storage.save("conv1", &bob).unwrap(); + let loaded: RatchetState = storage.load("conv1").unwrap(); + + assert_eq!(bob.root_key, loaded.root_key); + assert!(loaded.dh_remote.is_none()); + } + + #[test] + fn test_load_not_found() { + let storage = create_test_storage(); + let result: Result, _> = storage.load("nonexistent"); + assert!(matches!(result, Err(StorageError::ConversationNotFound(_)))); + } + + #[test] + fn test_save_with_skipped_keys() { + let mut storage = create_test_storage(); + let (mut alice, mut bob) = create_test_state(); + + // Alice sends 3 messages + let mut sent = vec![]; + for i in 0..3 { + let plaintext = format!("Message {}", i + 1).into_bytes(); + let (ct, header) = alice.encrypt_message(&plaintext); + sent.push((ct, header, plaintext)); + } + + // Bob receives 0 and 2, skipping 1 + bob.decrypt_message(&sent[0].0, sent[0].1.clone()).unwrap(); + bob.decrypt_message(&sent[2].0, sent[2].1.clone()).unwrap(); + + assert_eq!(bob.skipped_keys.len(), 1); + + // Save and reload + storage.save("conv1", &bob).unwrap(); + let mut loaded: RatchetState = storage.load("conv1").unwrap(); + + assert_eq!(loaded.skipped_keys.len(), 1); + + // Should be able to decrypt skipped message + let pt = loaded + .decrypt_message(&sent[1].0, sent[1].1.clone()) + .unwrap(); + assert_eq!(pt, sent[1].2); + } + + #[test] + fn test_update_existing() { + let mut storage = create_test_storage(); + let (mut alice, mut bob) = create_test_state(); + + storage.save("conv1", &alice).unwrap(); + + // Exchange a message + let (ct, header) = alice.encrypt_message(b"Hello"); + bob.decrypt_message(&ct, header).unwrap(); + + // Update Alice's state + storage.save("conv1", &alice).unwrap(); + + let loaded: RatchetState = storage.load("conv1").unwrap(); + assert_eq!(loaded.msg_send, 1); + } + + #[test] + fn test_exists() { + let mut storage = create_test_storage(); + let (alice, _) = create_test_state(); + + assert!(!storage.exists("conv1").unwrap()); + storage.save("conv1", &alice).unwrap(); + assert!(storage.exists("conv1").unwrap()); + } + + #[test] + fn test_delete() { + let mut storage = create_test_storage(); + let (alice, _) = create_test_state(); + + storage.save("conv1", &alice).unwrap(); + assert!(storage.exists("conv1").unwrap()); + + storage.delete("conv1").unwrap(); + assert!(!storage.exists("conv1").unwrap()); + } + + #[test] + fn test_continue_conversation_after_reload() { + let mut storage = create_test_storage(); + let (mut alice, mut bob) = create_test_state(); + + // Exchange messages + let (ct1, h1) = alice.encrypt_message(b"Hello Bob"); + bob.decrypt_message(&ct1, h1).unwrap(); + + let (ct2, h2) = bob.encrypt_message(b"Hello Alice"); + alice.decrypt_message(&ct2, h2).unwrap(); + + // Save both + storage.save("alice", &alice).unwrap(); + storage.save("bob", &bob).unwrap(); + + // Reload + let mut alice_new: RatchetState = storage.load("alice").unwrap(); + let mut bob_new: RatchetState = storage.load("bob").unwrap(); + + // Continue conversation + let (ct3, h3) = alice_new.encrypt_message(b"After reload"); + let pt3 = bob_new.decrypt_message(&ct3, h3).unwrap(); + assert_eq!(pt3, b"After reload"); + + let (ct4, h4) = bob_new.encrypt_message(b"Reply after reload"); + let pt4 = alice_new.decrypt_message(&ct4, h4).unwrap(); + assert_eq!(pt4, b"Reply after reload"); + } +} diff --git a/double-ratchets/src/storage/types.rs b/double-ratchets/src/storage/types.rs new file mode 100644 index 0000000..6a1cd80 --- /dev/null +++ b/double-ratchets/src/storage/types.rs @@ -0,0 +1,81 @@ +use crate::{ + hkdf::HkdfInfo, + state::{RatchetState, SkippedKey}, + types::MessageKey, +}; +use thiserror::Error; +use x25519_dalek::PublicKey; + +#[derive(Debug, Error)] +pub enum StorageError { + #[error("database error: {0}")] + Database(#[from] rusqlite::Error), + + #[error("conversation not found: {0}")] + ConversationNotFound(String), + + #[error("serialization error")] + Serialization, + + #[error("deserialization error")] + Deserialization, +} + +/// Stored representation of a skipped message key. + +/// Raw state data for storage (without generic parameter). +#[derive(Debug, Clone)] +pub struct RatchetStateRecord { + pub root_key: [u8; 32], + pub sending_chain: Option<[u8; 32]>, + pub receiving_chain: Option<[u8; 32]>, + pub dh_self_secret: [u8; 32], + pub dh_remote: Option<[u8; 32]>, + pub msg_send: u32, + pub msg_recv: u32, + pub prev_chain_len: u32, +} + +impl From<&RatchetState> for RatchetStateRecord { + fn from(state: &RatchetState) -> Self { + Self { + root_key: state.root_key, + sending_chain: state.sending_chain, + receiving_chain: state.receiving_chain, + dh_self_secret: state.dh_self.secret_bytes(), + dh_remote: state.dh_remote.map(|pk| pk.to_bytes()), + msg_send: state.msg_send, + msg_recv: state.msg_recv, + prev_chain_len: state.prev_chain_len, + } + } +} + +impl RatchetStateRecord { + pub fn into_ratchet_state(self, skipped_keys: Vec) -> RatchetState { + use crate::keypair::InstallationKeyPair; + use std::collections::HashMap; + use std::marker::PhantomData; + + let dh_self = InstallationKeyPair::from_secret_bytes(self.dh_self_secret); + let dh_remote = self.dh_remote.map(PublicKey::from); + + let skipped: HashMap<(PublicKey, u32), MessageKey> = skipped_keys + .into_iter() + .map(|sk| ((PublicKey::from(sk.public_key), sk.msg_num), sk.message_key)) + .collect(); + + RatchetState { + root_key: self.root_key, + sending_chain: self.sending_chain, + receiving_chain: self.receiving_chain, + dh_self, + dh_remote, + msg_send: self.msg_send, + msg_recv: self.msg_recv, + prev_chain_len: self.prev_chain_len, + skipped_keys: skipped, + _domain: PhantomData, + } + } +}