Managed data storage for Ratchet state (#21)

* feat: managed persist storage with sqlite

* chore: sync skipped keys

* chore: refactor

* chore: refactor

* chore: clean code

* chore: export skipped keys from state.

* chore: renaming data to record

* chore: remove types from stroage mod file
This commit is contained in:
kaichao 2026-01-28 14:54:16 +08:00 committed by GitHub
parent 4b1069a4a8
commit 8e2b5211b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1500 additions and 30 deletions

3
.gitignore vendored
View File

@ -24,3 +24,6 @@ target
# Compiled binary
**/ffi_nim_example
# Temporary data folder
tmp

187
Cargo.lock generated
View File

@ -24,6 +24,12 @@ version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06"
[[package]]
name = "bitflags"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
[[package]]
name = "blake2"
version = "0.10.6"
@ -48,6 +54,16 @@ version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
[[package]]
name = "cc"
version = "1.2.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6354c81bbfd62d9cfa9cb3c773c2b7b2a3a482d569de977fd0e961f6e7c00583"
dependencies = [
"find-msvc-tools",
"shlex",
]
[[package]]
name = "cfg-if"
version = "1.0.4"
@ -161,7 +177,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
"syn 2.0.114",
]
[[package]]
@ -182,7 +198,7 @@ checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
"syn 2.0.114",
]
[[package]]
@ -205,6 +221,7 @@ dependencies = [
"hkdf",
"rand",
"rand_core",
"rusqlite",
"safer-ffi",
"thiserror",
"x25519-dalek",
@ -282,12 +299,36 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320bea982e85d42441eb25c49b41218e7eaa2657e8f90bc4eca7437376751e23"
[[package]]
name = "fallible-iterator"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
[[package]]
name = "fallible-streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fiat-crypto"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
[[package]]
name = "find-msvc-tools"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db"
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "generic-array"
version = "0.14.7"
@ -310,21 +351,39 @@ dependencies = [
[[package]]
name = "getrandom"
version = "0.2.16"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"foldhash",
]
[[package]]
name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
[[package]]
name = "hashlink"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
dependencies = [
"hashbrown 0.15.5",
]
[[package]]
name = "hex"
version = "0.4.3"
@ -356,7 +415,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017"
dependencies = [
"equivalent",
"hashbrown",
"hashbrown 0.16.1",
]
[[package]]
@ -388,9 +447,21 @@ dependencies = [
[[package]]
name = "libc"
version = "0.2.178"
version = "0.2.180"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
[[package]]
name = "libsqlite3-sys"
version = "0.33.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "947e6816f7825b2b45027c2c32e7085da9934defa535de4a6a46b10a4d5257fa"
dependencies = [
"cc",
"openssl-sys",
"pkg-config",
"vcpkg",
]
[[package]]
name = "logos-chat"
@ -434,6 +505,28 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openssl-src"
version = "300.5.4+3.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507b3792995dae9b0df8a1c1e3771e8418b7c2d9f0baeba32e6fe8b06c7cb72"
dependencies = [
"cc",
]
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"openssl-src",
"pkg-config",
"vcpkg",
]
[[package]]
name = "paste"
version = "1.0.15"
@ -450,6 +543,12 @@ dependencies = [
"spki",
]
[[package]]
name = "pkg-config"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "poly1305"
version = "0.8.0"
@ -491,9 +590,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.103"
version = "1.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
dependencies = [
"unicode-ident",
]
@ -518,14 +617,14 @@ dependencies = [
"itertools",
"proc-macro2",
"quote",
"syn 2.0.111",
"syn 2.0.114",
]
[[package]]
name = "quote"
version = "1.0.42"
version = "1.0.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
dependencies = [
"proc-macro2",
]
@ -560,6 +659,20 @@ dependencies = [
"getrandom",
]
[[package]]
name = "rusqlite"
version = "0.35.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a22715a5d6deef63c637207afbe68d0c72c3f8d0022d7cf9714c442d6157606b"
dependencies = [
"bitflags",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink",
"libsqlite3-sys",
"smallvec",
]
[[package]]
name = "rustc_version"
version = "0.4.1"
@ -646,7 +759,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
"syn 2.0.114",
]
[[package]]
@ -666,6 +779,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f179d4e11094a893b82fff208f74d448a7512f99f5a0acbd5c679b705f83ed9"
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signature"
version = "2.2.0"
@ -675,6 +794,12 @@ dependencies = [
"rand_core",
]
[[package]]
name = "smallvec"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "spki"
version = "0.7.3"
@ -739,9 +864,9 @@ dependencies = [
[[package]]
name = "syn"
version = "2.0.111"
version = "2.0.114"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
dependencies = [
"proc-macro2",
"quote",
@ -750,22 +875,22 @@ dependencies = [
[[package]]
name = "thiserror"
version = "2.0.17"
version = "2.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8"
checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "2.0.17"
version = "2.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
"syn 2.0.114",
]
[[package]]
@ -835,6 +960,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0976c77def3f1f75c4ef892a292c31c0bbe9e3d0702c63044d7c76db298171a3"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.9.5"
@ -906,22 +1037,22 @@ dependencies = [
[[package]]
name = "zerocopy"
version = "0.8.31"
version = "0.8.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3"
checksum = "71ddd76bcebeed25db614f82bf31a9f4222d3fbba300e6fb6c00afa26cbd4d9d"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.31"
version = "0.8.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a"
checksum = "d8187381b52e32220d50b255276aa16a084ec0a9017a0ca2152a1f55c539758d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
"syn 2.0.114",
]
[[package]]
@ -935,11 +1066,11 @@ dependencies = [
[[package]]
name = "zeroize_derive"
version = "1.4.2"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
"syn 2.0.114",
]

View File

@ -20,6 +20,10 @@ thiserror = "2"
blake2 = "0.10.6"
safer-ffi = "0.1.13"
zeroize = "1.8.2"
rusqlite = { version = "0.35", optional = true, features = ["bundled"] }
[features]
default = []
storage = ["rusqlite"]
sqlcipher = ["storage", "rusqlite/bundled-sqlcipher-vendored-openssl"]
headers = ["safer-ffi/headers"]

View File

@ -17,8 +17,11 @@ let plaintext = bob.decrypt_message(&ciphertext, header);
Run examples,
```
```bash
cargo run --example double_ratchet_basic
cargo run --example storage_demo --features storage
cargo run --example storage_demo --features sqlcipher
```
Run Nim FFI example,

View File

@ -0,0 +1,166 @@
//! Demonstrates out-of-order message handling with skipped keys persistence.
//!
//! Run with: cargo run --example out_of_order_demo --features storage
#[cfg(feature = "storage")]
use double_ratchets::{
InstallationKeyPair, RatchetState, SqliteStorage, StorageConfig, hkdf::DefaultDomain,
state::Header,
};
fn main() {
println!("=== Out-of-Order Message Handling Demo (skipped - enable 'storage' feature) ===\n");
#[cfg(feature = "storage")]
run_demo();
}
#[cfg(feature = "storage")]
fn run_demo() {
let mut storage =
SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage");
// Setup
let shared_secret = [0x42u8; 32];
let bob_keypair = InstallationKeyPair::generate();
let alice_state: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, bob_keypair.public().clone());
let bob_state: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
storage.save("alice", &alice_state).unwrap();
storage.save("bob", &bob_state).unwrap();
// === Alice sends 5 messages ===
println!("Alice sends 5 messages...");
let mut messages: Vec<(Vec<u8>, Header)> = Vec::new();
for i in 1..=5 {
let mut alice: RatchetState<DefaultDomain> = storage.load("alice").unwrap();
let msg = format!("Message #{}", i);
let (ct, header) = alice.encrypt_message(msg.as_bytes());
storage.save("alice", &alice).unwrap();
messages.push((ct, header));
println!(" Sent: \"{}\"", msg);
}
// === Bob receives messages out of order: 1, 3, 5 ===
println!("\nBob receives messages 1, 3, 5 (out of order)...");
for &idx in &[0, 2, 4] {
let mut bob: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
let (ct, header) = &messages[idx];
let pt = bob
.decrypt_message(ct, header.clone())
.expect("Decrypt failed");
storage.save("bob", &bob).unwrap();
println!(" Received: \"{}\"", String::from_utf8_lossy(&pt));
}
let bob: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
println!("\nBob's skipped_keys count: {}", bob.skipped_keys.len());
println!(" (Messages 2 and 4 keys are stored for later)");
// === Simulate Bob's app restart ===
println!("\n--- Simulating Bob's app restart ---");
drop(storage);
// In-memory storage doesn't persist across restarts.
// Use file storage to properly demonstrate persistence:
println!(" (Using file storage to demonstrate real persistence)");
if let Err(e) = std::fs::create_dir_all("./tmp") {
eprintln!("Failed to create tmp directory: {}", e);
return; // Or handle as needed
}
let db_path = "./tmp/out_of_order_demo.db";
let _ = std::fs::remove_file(db_path);
// Redo with file storage
let mut storage = SqliteStorage::new(StorageConfig::File(db_path.to_string()))
.expect("Failed to create storage");
// Re-setup
let bob_keypair = InstallationKeyPair::generate();
let alice_state: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, bob_keypair.public().clone());
let bob_state: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
storage.save("alice", &alice_state).unwrap();
storage.save("bob", &bob_state).unwrap();
// Alice sends 5 messages
let mut messages: Vec<(Vec<u8>, Header)> = Vec::new();
for i in 1..=5 {
let mut alice: RatchetState<DefaultDomain> = storage.load("alice").unwrap();
let msg = format!("Message #{}", i);
let (ct, header) = alice.encrypt_message(msg.as_bytes());
storage.save("alice", &alice).unwrap();
messages.push((ct, header));
}
println!(" Alice sent 5 messages");
// Bob receives 1, 3, 5 (skips 2, 4)
for &idx in &[0, 2, 4] {
let mut bob: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
let (ct, header) = &messages[idx];
bob.decrypt_message(ct, header.clone()).unwrap();
storage.save("bob", &bob).unwrap();
}
let bob: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
println!(
" Bob received 1,3,5. Skipped keys stored: {}",
bob.skipped_keys.len()
);
// Close and reopen storage (simulating app restart)
drop(storage);
let mut storage =
SqliteStorage::new(StorageConfig::File(db_path.to_string())).expect("Failed to reopen");
let bob: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
println!(
"\n After restart, Bob's skipped_keys: {}",
bob.skipped_keys.len()
);
// === Now Bob receives the delayed messages ===
println!("\nBob receives delayed message 2...");
{
let mut bob: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
let (ct, header) = &messages[1];
let pt = bob.decrypt_message(ct, header.clone()).unwrap();
storage.save("bob", &bob).unwrap();
println!(" Received: \"{}\"", String::from_utf8_lossy(&pt));
println!(" Remaining skipped_keys: {}", bob.skipped_keys.len());
}
println!("\nBob receives delayed message 4...");
let (ct4, header4) = messages[3].clone();
{
let mut bob: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
let pt = bob.decrypt_message(&ct4, header4.clone()).unwrap();
storage.save("bob", &bob).unwrap();
println!(" Received: \"{}\"", String::from_utf8_lossy(&pt));
println!(" Remaining skipped_keys: {}", bob.skipped_keys.len());
}
// === Demonstrate replay protection ===
println!("\n--- Replay Protection Demo ---");
println!("Trying to decrypt message 4 again (should fail)...");
{
let mut bob: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
match bob.decrypt_message(&ct4, header4) {
Ok(_) => println!(" ERROR: Replay attack succeeded!"),
Err(e) => println!(" Correctly rejected: {:?}", e),
}
}
// Cleanup
let _ = std::fs::remove_file(db_path);
println!("\n=== Demo Complete ===");
}

View File

@ -0,0 +1,241 @@
//! Demonstrates SQLite storage for Double Ratchet state persistence.
//!
//! Run with: cargo run --example storage_demo --features storage
//! For SQLCipher: cargo run --example storage_demo --features sqlcipher
#[cfg(feature = "storage")]
use double_ratchets::{
InstallationKeyPair, RatchetSession, SqliteStorage, StorageConfig, hkdf::PrivateV1Domain,
};
fn main() {
println!("=== Double Ratchet Storage Demo ===\n");
// Demo 1: In-memory storage (for testing)
println!("--- Demo 1: In-Memory Storage (skipped - enable 'storage' feature) ---");
#[cfg(feature = "storage")]
demo_in_memory();
// Demo 2: File-based storage (for local development)
println!("\n--- Demo 2: File-Based Storage (skipped - enable 'storage' feature) ---");
#[cfg(feature = "storage")]
demo_file_storage();
// Demo 3: SQLCipher encrypted storage (for production)
#[cfg(feature = "sqlcipher")]
{
println!("\n--- Demo 3: SQLCipher Encrypted Storage ---");
demo_sqlcipher();
}
#[cfg(not(feature = "sqlcipher"))]
{
println!("\n--- Demo 3: SQLCipher (skipped - enable 'sqlcipher' feature) ---");
}
}
#[cfg(feature = "storage")]
fn demo_in_memory() {
let mut alice_storage =
SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage");
let mut bob_storage =
SqliteStorage::new(StorageConfig::InMemory).expect("Failed to create storage");
run_conversation(&mut alice_storage, &mut bob_storage);
}
#[cfg(feature = "storage")]
fn demo_file_storage() {
ensure_tmp_directory();
let db_path_alice = "./tmp/double_ratchet_demo_alice.db";
let db_path_bob = "./tmp/double_ratchet_demo_bob.db";
let _ = std::fs::remove_file(db_path_alice);
let _ = std::fs::remove_file(db_path_bob);
// Initial conversation
{
let mut alice_storage = SqliteStorage::new(StorageConfig::File(db_path_alice.to_string()))
.expect("Failed to create storage");
let mut bob_storage = SqliteStorage::new(StorageConfig::File(db_path_bob.to_string()))
.expect("Failed to create storage");
println!(" Database created at: {}, {}", db_path_alice, db_path_bob);
run_conversation(&mut alice_storage, &mut bob_storage);
}
// Simulate restart - reopen and continue
println!("\n Simulating application restart...");
{
let mut alice_storage = SqliteStorage::new(StorageConfig::File(db_path_alice.to_string()))
.expect("Failed to reopen storage");
let mut bob_storage = SqliteStorage::new(StorageConfig::File(db_path_bob.to_string()))
.expect("Failed to reopen storage");
continue_after_restart(&mut alice_storage, &mut bob_storage);
}
let _ = std::fs::remove_file(db_path_alice);
let _ = std::fs::remove_file(db_path_bob);
}
#[cfg(feature = "sqlcipher")]
fn demo_sqlcipher() {
ensure_tmp_directory();
let alice_db_path = "./tmp/double_ratchet_encrypted_alice.db";
let bob_db_path = "./tmp/double_ratchet_encrypted_bob.db";
let encryption_key = "super-secret-key-123!";
let _ = std::fs::remove_file(alice_db_path);
let _ = std::fs::remove_file(bob_db_path);
// Initial conversation with encryption
{
let mut alice_storage = SqliteStorage::new(StorageConfig::Encrypted {
path: alice_db_path.to_string(),
key: encryption_key.to_string(),
})
.expect("Failed to create encrypted storage");
let mut bob_storage = SqliteStorage::new(StorageConfig::Encrypted {
path: bob_db_path.to_string(),
key: encryption_key.to_string(),
})
.expect("Failed to create encrypted storage");
println!(
" Encrypted database created at: {}, {}",
alice_db_path, bob_db_path
);
run_conversation(&mut alice_storage, &mut bob_storage);
}
// Restart with correct key
println!("\n Simulating restart with encryption key...");
{
let mut alice_storage = SqliteStorage::new(StorageConfig::Encrypted {
path: alice_db_path.to_string(),
key: encryption_key.to_string(),
})
.expect("Failed to create encrypted storage");
let mut bob_storage = SqliteStorage::new(StorageConfig::Encrypted {
path: bob_db_path.to_string(),
key: encryption_key.to_string(),
})
.expect("Failed to create encrypted storage");
continue_after_restart(&mut alice_storage, &mut bob_storage);
}
let _ = std::fs::remove_file(alice_db_path);
let _ = std::fs::remove_file(bob_db_path);
}
#[allow(dead_code)]
fn ensure_tmp_directory() {
if let Err(e) = std::fs::create_dir_all("./tmp") {
eprintln!("Failed to create tmp directory: {}", e);
return; // Or handle as needed
}
}
/// Simulates a conversation between Alice and Bob.
/// Each party saves/loads state from storage for each operation.
#[cfg(feature = "storage")]
fn run_conversation(alice_storage: &mut SqliteStorage, bob_storage: &mut SqliteStorage) {
// === Setup: Simulate X3DH key exchange ===
let shared_secret = [0x42u8; 32]; // In reality, this comes from X3DH
let bob_keypair = InstallationKeyPair::generate();
let conv_id = "conv1";
let mut alice_session: RatchetSession<PrivateV1Domain> = RatchetSession::create_sender_session(
alice_storage,
conv_id,
shared_secret,
bob_keypair.public().clone(),
)
.unwrap();
let mut bob_session: RatchetSession<PrivateV1Domain> =
RatchetSession::create_receiver_session(bob_storage, conv_id, shared_secret, bob_keypair)
.unwrap();
println!(" Sessions created for Alice and Bob");
// === Message 1: Alice -> Bob ===
let (ct1, h1) = {
let result = alice_session
.encrypt_message(b"Hello Bob! This is message 1.")
.unwrap();
println!(" Alice sent: \"Hello Bob! This is message 1.\"");
result
};
{
let pt = bob_session.decrypt_message(&ct1, h1).unwrap();
println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt));
}
// === Message 2: Bob -> Alice (triggers DH ratchet) ===
let (ct2, h2) = {
let result = bob_session
.encrypt_message(b"Hi Alice! Got your message.")
.unwrap();
println!(" Bob sent: \"Hi Alice! Got your message.\"");
result
};
{
let pt = alice_session.decrypt_message(&ct2, h2).unwrap();
println!(" Alice received: \"{}\"", String::from_utf8_lossy(&pt));
}
// === Message 3: Alice -> Bob ===
let (ct3, h3) = {
let result = alice_session
.encrypt_message(b"Great! Let's keep chatting.")
.unwrap();
println!(" Alice sent: \"Great! Let's keep chatting.\"");
result
};
{
let pt = bob_session.decrypt_message(&ct3, h3).unwrap();
println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt));
}
// Print final state
println!(
" State after conversation: Alice msg_send={}, Bob msg_recv={}",
alice_session.msg_send(),
bob_session.msg_recv()
);
}
#[cfg(feature = "storage")]
fn continue_after_restart(alice_storage: &mut SqliteStorage, bob_storage: &mut SqliteStorage) {
// Load persisted states
let conv_id = "conv1";
let mut alice_session: RatchetSession<PrivateV1Domain> =
RatchetSession::open(alice_storage, conv_id).unwrap();
let mut bob_session: RatchetSession<PrivateV1Domain> =
RatchetSession::open(bob_storage, conv_id).unwrap();
println!(" Sessions restored for Alice and Bob",);
// Continue conversation
let (ct, header) = {
let result = alice_session
.encrypt_message(b"Message after restart!")
.unwrap();
println!(" Alice sent: \"Message after restart!\"");
result
};
{
let pt = bob_session.decrypt_message(&ct, header).unwrap();
println!(" Bob received: \"{}\"", String::from_utf8_lossy(&pt));
}
println!(
" Final state: Alice msg_send={}, Bob msg_recv={}",
alice_session.msg_send(),
bob_session.msg_recv()
);
}

View File

@ -24,4 +24,16 @@ impl InstallationKeyPair {
pub fn public(&self) -> &PublicKey {
&self.public
}
/// Export the secret key as raw bytes for storage.
pub fn secret_bytes(&self) -> [u8; 32] {
self.secret.to_bytes()
}
/// Reconstruct from secret key bytes.
pub fn from_secret_bytes(bytes: [u8; 32]) -> Self {
let secret = StaticSecret::from(bytes);
let public = PublicKey::from(&secret);
Self { secret, public }
}
}

View File

@ -4,7 +4,11 @@ pub mod ffi;
pub mod hkdf;
pub mod keypair;
pub mod state;
#[cfg(feature = "storage")]
pub mod storage;
pub mod types;
pub use keypair::InstallationKeyPair;
pub use state::{Header, RatchetState};
#[cfg(feature = "storage")]
pub use storage::{RatchetSession, SessionError, SqliteStorage, StorageConfig, StorageError};

View File

@ -31,7 +31,15 @@ pub struct RatchetState<D: HkdfInfo = DefaultDomain> {
pub skipped_keys: HashMap<(PublicKey, u32), MessageKey>,
_domain: PhantomData<D>,
pub(crate) _domain: PhantomData<D>,
}
/// Represents a skipped message key for storage or inspection.
#[derive(Debug, Clone)]
pub struct SkippedKey {
pub public_key: [u8; 32],
pub msg_num: u32,
pub message_key: MessageKey,
}
/// Public header attached to every encrypted message (unencrypted but authenticated).
@ -290,6 +298,22 @@ impl<D: HkdfInfo> RatchetState<D> {
Ok(())
}
/// Exports the skipped keys for storage or inspection.
///
/// # Returns
///
/// A vector of `SkippedKey` representing the currently stored skipped message keys.
pub fn skipped_keys(&self) -> Vec<SkippedKey> {
self.skipped_keys
.iter()
.map(|((pk, msg_num), mk)| SkippedKey {
public_key: pk.to_bytes(),
msg_num: *msg_num,
message_key: *mk,
})
.collect()
}
}
#[cfg(test)]
@ -488,4 +512,53 @@ mod tests {
assert!(result.is_err());
assert_eq!(result.unwrap_err(), RatchetError::MessageReplay);
}
#[test]
fn test_skipped_keys_export() {
let (mut alice, mut bob, _) = setup_alice_bob();
// Initially no skipped keys
assert!(bob.skipped_keys().is_empty());
// Alice sends 4 messages
let mut encrypted = vec![];
for i in 0..4 {
let msg = format!("Message {}", i).into_bytes();
let (ct, h) = alice.encrypt_message(&msg);
encrypted.push((ct, h, msg));
}
// Bob receives message 0 first
bob.decrypt_message(&encrypted[0].0, encrypted[0].1.clone())
.unwrap();
assert!(bob.skipped_keys().is_empty());
// Bob receives message 3, skipping 1 and 2
bob.decrypt_message(&encrypted[3].0, encrypted[3].1.clone())
.unwrap();
// Now we should have 2 skipped keys (for messages 1 and 2)
let skipped = bob.skipped_keys();
assert_eq!(skipped.len(), 2);
// Verify the skipped keys have the expected message numbers
let msg_nums: Vec<u32> = skipped.iter().map(|sk| sk.msg_num).collect();
assert!(msg_nums.contains(&1));
assert!(msg_nums.contains(&2));
// Verify each skipped key has valid data
for sk in &skipped {
assert_eq!(sk.public_key.len(), 32);
assert_eq!(sk.message_key.len(), 32);
}
// Now decrypt message 1 using the skipped key
bob.decrypt_message(&encrypted[1].0, encrypted[1].1.clone())
.unwrap();
// Should only have 1 skipped key left (for message 2)
let skipped_after = bob.skipped_keys();
assert_eq!(skipped_after.len(), 1);
assert_eq!(skipped_after[0].msg_num, 2);
}
}

View File

@ -0,0 +1,5 @@
mod session;
mod sqlite;
pub use session::{RatchetSession, SessionError};
pub use sqlite::{SqliteStorage, StorageConfig};

View File

@ -0,0 +1,310 @@
use x25519_dalek::PublicKey;
use crate::{
InstallationKeyPair,
errors::RatchetError,
hkdf::HkdfInfo,
state::{Header, RatchetState},
types::SharedSecret,
};
use super::{SqliteStorage, StorageError};
/// A session wrapper that automatically persists ratchet state after operations.
/// Provides rollback semantics - state is only saved if the operation succeeds.
pub struct RatchetSession<'a, D: HkdfInfo + Clone> {
storage: &'a mut SqliteStorage,
conversation_id: String,
state: RatchetState<D>,
}
#[derive(Debug)]
pub enum SessionError {
Storage(StorageError),
Ratchet(RatchetError),
}
impl From<StorageError> for SessionError {
fn from(e: StorageError) -> Self {
SessionError::Storage(e)
}
}
impl From<RatchetError> for SessionError {
fn from(e: RatchetError) -> Self {
SessionError::Ratchet(e)
}
}
impl std::fmt::Display for SessionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SessionError::Storage(e) => write!(f, "storage error: {}", e),
SessionError::Ratchet(e) => write!(f, "ratchet error: {}", e),
}
}
}
impl std::error::Error for SessionError {}
impl<'a, D: HkdfInfo + Clone> RatchetSession<'a, D> {
/// Opens an existing session from storage.
pub fn open(
storage: &'a mut SqliteStorage,
conversation_id: impl Into<String>,
) -> Result<Self, StorageError> {
let conversation_id = conversation_id.into();
let state = storage.load(&conversation_id)?;
Ok(Self {
storage,
conversation_id,
state,
})
}
/// Creates a new session and persists the initial state.
pub fn create(
storage: &'a mut SqliteStorage,
conversation_id: impl Into<String>,
state: RatchetState<D>,
) -> Result<Self, StorageError> {
let conversation_id = conversation_id.into();
storage.save(&conversation_id, &state)?;
Ok(Self {
storage,
conversation_id,
state,
})
}
/// Initializes a new session as a sender and persists the initial state.
pub fn create_sender_session(
storage: &'a mut SqliteStorage,
conversation_id: impl Into<String>,
shared_secret: SharedSecret,
remote_pub: PublicKey,
) -> Result<Self, StorageError> {
let state = RatchetState::<D>::init_sender(shared_secret, remote_pub);
Self::create(storage, conversation_id, state)
}
/// Initializes a new session as a receiver and persists the initial state.
pub fn create_receiver_session(
storage: &'a mut SqliteStorage,
conversation_id: impl Into<String>,
shared_secret: SharedSecret,
dh_self: InstallationKeyPair,
) -> Result<Self, StorageError> {
let conversation_id = conversation_id.into();
if storage.exists(&conversation_id)? {
return Self::open(storage, conversation_id);
}
let state = RatchetState::<D>::init_receiver(shared_secret, dh_self);
Self::create(storage, conversation_id, state)
}
/// Encrypts a message and persists the updated state.
/// If persistence fails, the in-memory state is NOT modified.
pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<(Vec<u8>, Header), SessionError> {
// Clone state for rollback
let state_backup = self.state.clone();
// Perform encryption (modifies state)
let result = self.state.encrypt_message(plaintext);
// Try to persist
if let Err(e) = self.storage.save(&self.conversation_id, &self.state) {
// Rollback
self.state = state_backup;
return Err(SessionError::Storage(e));
}
Ok(result)
}
/// Decrypts a message and persists the updated state.
/// If decryption or persistence fails, the in-memory state is NOT modified.
pub fn decrypt_message(
&mut self,
ciphertext_with_nonce: &[u8],
header: Header,
) -> Result<Vec<u8>, SessionError> {
// Clone state for rollback
let state_backup = self.state.clone();
// Perform decryption (modifies state)
let plaintext = match self.state.decrypt_message(ciphertext_with_nonce, header) {
Ok(pt) => pt,
Err(e) => {
// Rollback on decrypt failure
self.state = state_backup;
return Err(SessionError::Ratchet(e));
}
};
// Try to persist
if let Err(e) = self.storage.save(&self.conversation_id, &self.state) {
// Rollback
self.state = state_backup;
return Err(SessionError::Storage(e));
}
Ok(plaintext)
}
/// Returns a reference to the current state (read-only).
pub fn state(&self) -> &RatchetState<D> {
&self.state
}
/// Returns the conversation ID.
pub fn conversation_id(&self) -> &str {
&self.conversation_id
}
/// Manually saves the current state.
pub fn save(&mut self) -> Result<(), StorageError> {
self.storage.save(&self.conversation_id, &self.state)
}
pub fn msg_send(&self) -> u32 {
self.state.msg_send
}
pub fn msg_recv(&self) -> u32 {
self.state.msg_recv
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{hkdf::DefaultDomain, keypair::InstallationKeyPair, storage::StorageConfig};
fn create_test_storage() -> SqliteStorage {
SqliteStorage::new(StorageConfig::InMemory).unwrap()
}
#[test]
fn test_session_create_and_open() {
let mut storage = create_test_storage();
let shared_secret = [0x42; 32];
let bob_keypair = InstallationKeyPair::generate();
let alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, bob_keypair.public().clone());
// Create session
{
let session = RatchetSession::create(&mut storage, "conv1", alice).unwrap();
assert_eq!(session.conversation_id(), "conv1");
}
// Open existing session
{
let session: RatchetSession<DefaultDomain> =
RatchetSession::open(&mut storage, "conv1").unwrap();
assert_eq!(session.state().msg_send, 0);
}
}
#[test]
fn test_session_encrypt_persists() {
let mut storage = create_test_storage();
let shared_secret = [0x42; 32];
let bob_keypair = InstallationKeyPair::generate();
let alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, bob_keypair.public().clone());
// Create and encrypt
{
let mut session = RatchetSession::create(&mut storage, "conv1", alice).unwrap();
session.encrypt_message(b"Hello").unwrap();
assert_eq!(session.state().msg_send, 1);
}
// Reopen - state should be persisted
{
let session: RatchetSession<DefaultDomain> =
RatchetSession::open(&mut storage, "conv1").unwrap();
assert_eq!(session.state().msg_send, 1);
}
}
#[test]
fn test_session_full_conversation() {
let mut storage = create_test_storage();
let shared_secret = [0x42; 32];
let bob_keypair = InstallationKeyPair::generate();
let alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, bob_keypair.public().clone());
let bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
// Alice sends
let (ct, header) = {
let mut session = RatchetSession::create(&mut storage, "alice", alice).unwrap();
session.encrypt_message(b"Hello Bob").unwrap()
};
// Bob receives
let plaintext = {
let mut session = RatchetSession::create(&mut storage, "bob", bob).unwrap();
session.decrypt_message(&ct, header).unwrap()
};
assert_eq!(plaintext, b"Hello Bob");
// Bob replies
let (ct2, header2) = {
let mut session: RatchetSession<DefaultDomain> =
RatchetSession::open(&mut storage, "bob").unwrap();
session.encrypt_message(b"Hi Alice").unwrap()
};
// Alice receives
let plaintext2 = {
let mut session: RatchetSession<DefaultDomain> =
RatchetSession::open(&mut storage, "alice").unwrap();
session.decrypt_message(&ct2, header2).unwrap()
};
assert_eq!(plaintext2, b"Hi Alice");
}
#[test]
fn test_session_open_or_create() {
let mut storage = create_test_storage();
let shared_secret = [0x42; 32];
let bob_keypair = InstallationKeyPair::generate();
let bob_pub = bob_keypair.public().clone();
// First call creates
{
let session: RatchetSession<DefaultDomain> = RatchetSession::create_sender_session(
&mut storage,
"conv1",
shared_secret,
bob_pub.clone(),
)
.unwrap();
assert_eq!(session.state().msg_send, 0);
}
// Second call opens existing
{
let mut session: RatchetSession<DefaultDomain> =
RatchetSession::open(&mut storage, "conv1").unwrap();
session.encrypt_message(b"test").unwrap();
}
// Verify persistence
{
let session: RatchetSession<DefaultDomain> =
RatchetSession::open(&mut storage, "conv1").unwrap();
assert_eq!(session.state().msg_send, 1);
}
}
}

View File

@ -0,0 +1,437 @@
use rusqlite::{Connection, params};
use super::{RatchetStateRecord, SkippedKey, StorageError};
use crate::{hkdf::HkdfInfo, state::RatchetState};
/// Configuration for SQLite storage.
#[derive(Debug, Clone)]
pub enum StorageConfig {
/// In-memory database (for testing).
InMemory,
/// File-based SQLite database (unencrypted, for local dev).
File(String),
/// SQLCipher encrypted database (for production).
/// Requires the `sqlcipher` feature.
#[cfg(feature = "sqlcipher")]
Encrypted { path: String, key: String },
}
/// SQLite-based storage for ratchet state.
pub struct SqliteStorage {
conn: Connection,
}
impl SqliteStorage {
/// Creates a new SQLite storage with the given configuration.
pub fn new(config: StorageConfig) -> Result<Self, StorageError> {
let conn = match config {
StorageConfig::InMemory => Connection::open_in_memory()?,
StorageConfig::File(path) => Connection::open(path)?,
#[cfg(feature = "sqlcipher")]
StorageConfig::Encrypted { path, key } => {
let conn = Connection::open(path)?;
conn.pragma_update(None, "key", &key)?;
conn
}
};
let storage = Self { conn };
storage.init_schema()?;
Ok(storage)
}
fn init_schema(&self) -> Result<(), StorageError> {
self.conn.execute_batch(
"
CREATE TABLE IF NOT EXISTS ratchet_state (
conversation_id TEXT PRIMARY KEY,
root_key BLOB NOT NULL,
sending_chain BLOB,
receiving_chain BLOB,
dh_self_secret BLOB NOT NULL,
dh_remote BLOB,
msg_send INTEGER NOT NULL,
msg_recv INTEGER NOT NULL,
prev_chain_len INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS skipped_keys (
conversation_id TEXT NOT NULL,
public_key BLOB NOT NULL,
msg_num INTEGER NOT NULL,
message_key BLOB NOT NULL,
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
PRIMARY KEY (conversation_id, public_key, msg_num),
FOREIGN KEY (conversation_id) REFERENCES ratchet_state(conversation_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_skipped_keys_conversation
ON skipped_keys(conversation_id);
",
)?;
Ok(())
}
/// Saves the ratchet state for a conversation within a transaction.
/// Rolls back automatically if any error occurs.
pub fn save<D: HkdfInfo>(
&mut self,
conversation_id: &str,
state: &RatchetState<D>,
) -> Result<(), StorageError> {
let tx = self.conn.transaction()?;
let data = RatchetStateRecord::from(state);
let skipped_keys: Vec<SkippedKey> = state.skipped_keys();
// Upsert main state
tx.execute(
"
INSERT INTO ratchet_state (
conversation_id, root_key, sending_chain, receiving_chain,
dh_self_secret, dh_remote, msg_send, msg_recv, prev_chain_len
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
ON CONFLICT(conversation_id) DO UPDATE SET
root_key = excluded.root_key,
sending_chain = excluded.sending_chain,
receiving_chain = excluded.receiving_chain,
dh_self_secret = excluded.dh_self_secret,
dh_remote = excluded.dh_remote,
msg_send = excluded.msg_send,
msg_recv = excluded.msg_recv,
prev_chain_len = excluded.prev_chain_len
",
params![
conversation_id,
data.root_key.as_slice(),
data.sending_chain.as_ref().map(|c| c.as_slice()),
data.receiving_chain.as_ref().map(|c| c.as_slice()),
data.dh_self_secret.as_slice(),
data.dh_remote.as_ref().map(|c| c.as_slice()),
data.msg_send,
data.msg_recv,
data.prev_chain_len,
],
)?;
// Sync skipped keys efficiently - only insert new, delete removed
sync_skipped_keys(&tx, conversation_id, skipped_keys)?;
tx.commit()?;
Ok(())
}
/// Loads the ratchet state for a conversation.
pub fn load<D: HkdfInfo>(
&self,
conversation_id: &str,
) -> Result<RatchetState<D>, StorageError> {
let data = self.load_state_data(conversation_id)?;
let skipped_keys = self.load_skipped_keys(conversation_id)?;
Ok(data.into_ratchet_state(skipped_keys))
}
fn load_state_data(&self, conversation_id: &str) -> Result<RatchetStateRecord, StorageError> {
let mut stmt = self.conn.prepare(
"
SELECT root_key, sending_chain, receiving_chain, dh_self_secret,
dh_remote, msg_send, msg_recv, prev_chain_len
FROM ratchet_state
WHERE conversation_id = ?1
",
)?;
stmt.query_row(params![conversation_id], |row| {
Ok(RatchetStateRecord {
root_key: blob_to_array(row.get::<_, Vec<u8>>(0)?),
sending_chain: row.get::<_, Option<Vec<u8>>>(1)?.map(blob_to_array),
receiving_chain: row.get::<_, Option<Vec<u8>>>(2)?.map(blob_to_array),
dh_self_secret: blob_to_array(row.get::<_, Vec<u8>>(3)?),
dh_remote: row.get::<_, Option<Vec<u8>>>(4)?.map(blob_to_array),
msg_send: row.get(5)?,
msg_recv: row.get(6)?,
prev_chain_len: row.get(7)?,
})
})
.map_err(|e| match e {
rusqlite::Error::QueryReturnedNoRows => {
StorageError::ConversationNotFound(conversation_id.to_string())
}
e => StorageError::Database(e),
})
}
fn load_skipped_keys(&self, conversation_id: &str) -> Result<Vec<SkippedKey>, StorageError> {
let mut stmt = self.conn.prepare(
"
SELECT public_key, msg_num, message_key
FROM skipped_keys
WHERE conversation_id = ?1
",
)?;
let rows = stmt.query_map(params![conversation_id], |row| {
Ok(SkippedKey {
public_key: blob_to_array(row.get::<_, Vec<u8>>(0)?),
msg_num: row.get(1)?,
message_key: blob_to_array(row.get::<_, Vec<u8>>(2)?),
})
})?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(StorageError::Database)
}
/// Checks if a conversation exists.
pub fn exists(&self, conversation_id: &str) -> Result<bool, StorageError> {
let count: i64 = self.conn.query_row(
"SELECT COUNT(*) FROM ratchet_state WHERE conversation_id = ?1",
params![conversation_id],
|row| row.get(0),
)?;
Ok(count > 0)
}
/// Deletes a conversation and its skipped keys.
pub fn delete(&mut self, conversation_id: &str) -> Result<(), StorageError> {
let tx = self.conn.transaction()?;
tx.execute(
"DELETE FROM skipped_keys WHERE conversation_id = ?1",
params![conversation_id],
)?;
tx.execute(
"DELETE FROM ratchet_state WHERE conversation_id = ?1",
params![conversation_id],
)?;
tx.commit()?;
Ok(())
}
/// Cleans up old skipped keys older than the given age in seconds.
pub fn cleanup_old_skipped_keys(&mut self, max_age_secs: i64) -> Result<usize, StorageError> {
let deleted = self.conn.execute(
"DELETE FROM skipped_keys WHERE created_at < strftime('%s', 'now') - ?1",
params![max_age_secs],
)?;
Ok(deleted)
}
}
/// Syncs skipped keys efficiently by computing diff and only inserting/deleting changes.
fn sync_skipped_keys(
tx: &rusqlite::Transaction,
conversation_id: &str,
current_keys: Vec<SkippedKey>,
) -> Result<(), StorageError> {
use std::collections::HashSet;
// Get existing keys from DB (just the identifiers)
let mut stmt =
tx.prepare("SELECT public_key, msg_num FROM skipped_keys WHERE conversation_id = ?1")?;
let existing: HashSet<([u8; 32], u32)> = stmt
.query_map(params![conversation_id], |row| {
Ok((
blob_to_array(row.get::<_, Vec<u8>>(0)?),
row.get::<_, u32>(1)?,
))
})?
.filter_map(|r| r.ok())
.collect();
// Build set of current keys
let current_set: HashSet<([u8; 32], u32)> = current_keys
.iter()
.map(|sk| (sk.public_key, sk.msg_num))
.collect();
// Delete keys that were removed (used for decryption)
for (pk, msg_num) in existing.difference(&current_set) {
tx.execute(
"DELETE FROM skipped_keys WHERE conversation_id = ?1 AND public_key = ?2 AND msg_num = ?3",
params![conversation_id, pk.as_slice(), msg_num],
)?;
}
// Insert new keys
for sk in &current_keys {
let key = (sk.public_key, sk.msg_num);
if !existing.contains(&key) {
tx.execute(
"INSERT INTO skipped_keys (conversation_id, public_key, msg_num, message_key)
VALUES (?1, ?2, ?3, ?4)",
params![
conversation_id,
sk.public_key.as_slice(),
sk.msg_num,
sk.message_key.as_slice(),
],
)?;
}
}
Ok(())
}
fn blob_to_array<const N: usize>(blob: Vec<u8>) -> [u8; N] {
blob.try_into()
.unwrap_or_else(|v: Vec<u8>| panic!("Expected {} bytes, got {}", N, v.len()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{hkdf::DefaultDomain, keypair::InstallationKeyPair};
fn create_test_storage() -> SqliteStorage {
SqliteStorage::new(StorageConfig::InMemory).unwrap()
}
fn create_test_state() -> (RatchetState<DefaultDomain>, RatchetState<DefaultDomain>) {
let shared_secret = [0x42; 32];
let bob_keypair = InstallationKeyPair::generate();
let alice = RatchetState::init_sender(shared_secret, bob_keypair.public().clone());
let bob = RatchetState::init_receiver(shared_secret, bob_keypair);
(alice, bob)
}
#[test]
fn test_save_and_load_sender() {
let mut storage = create_test_storage();
let (alice, _) = create_test_state();
storage.save("conv1", &alice).unwrap();
let loaded: RatchetState<DefaultDomain> = storage.load("conv1").unwrap();
assert_eq!(alice.root_key, loaded.root_key);
assert_eq!(alice.sending_chain, loaded.sending_chain);
assert_eq!(alice.receiving_chain, loaded.receiving_chain);
assert_eq!(alice.msg_send, loaded.msg_send);
assert_eq!(alice.msg_recv, loaded.msg_recv);
assert_eq!(alice.prev_chain_len, loaded.prev_chain_len);
assert_eq!(
alice.dh_self.public().to_bytes(),
loaded.dh_self.public().to_bytes()
);
}
#[test]
fn test_save_and_load_receiver() {
let mut storage = create_test_storage();
let (_, bob) = create_test_state();
storage.save("conv1", &bob).unwrap();
let loaded: RatchetState<DefaultDomain> = storage.load("conv1").unwrap();
assert_eq!(bob.root_key, loaded.root_key);
assert!(loaded.dh_remote.is_none());
}
#[test]
fn test_load_not_found() {
let storage = create_test_storage();
let result: Result<RatchetState<DefaultDomain>, _> = storage.load("nonexistent");
assert!(matches!(result, Err(StorageError::ConversationNotFound(_))));
}
#[test]
fn test_save_with_skipped_keys() {
let mut storage = create_test_storage();
let (mut alice, mut bob) = create_test_state();
// Alice sends 3 messages
let mut sent = vec![];
for i in 0..3 {
let plaintext = format!("Message {}", i + 1).into_bytes();
let (ct, header) = alice.encrypt_message(&plaintext);
sent.push((ct, header, plaintext));
}
// Bob receives 0 and 2, skipping 1
bob.decrypt_message(&sent[0].0, sent[0].1.clone()).unwrap();
bob.decrypt_message(&sent[2].0, sent[2].1.clone()).unwrap();
assert_eq!(bob.skipped_keys.len(), 1);
// Save and reload
storage.save("conv1", &bob).unwrap();
let mut loaded: RatchetState<DefaultDomain> = storage.load("conv1").unwrap();
assert_eq!(loaded.skipped_keys.len(), 1);
// Should be able to decrypt skipped message
let pt = loaded
.decrypt_message(&sent[1].0, sent[1].1.clone())
.unwrap();
assert_eq!(pt, sent[1].2);
}
#[test]
fn test_update_existing() {
let mut storage = create_test_storage();
let (mut alice, mut bob) = create_test_state();
storage.save("conv1", &alice).unwrap();
// Exchange a message
let (ct, header) = alice.encrypt_message(b"Hello");
bob.decrypt_message(&ct, header).unwrap();
// Update Alice's state
storage.save("conv1", &alice).unwrap();
let loaded: RatchetState<DefaultDomain> = storage.load("conv1").unwrap();
assert_eq!(loaded.msg_send, 1);
}
#[test]
fn test_exists() {
let mut storage = create_test_storage();
let (alice, _) = create_test_state();
assert!(!storage.exists("conv1").unwrap());
storage.save("conv1", &alice).unwrap();
assert!(storage.exists("conv1").unwrap());
}
#[test]
fn test_delete() {
let mut storage = create_test_storage();
let (alice, _) = create_test_state();
storage.save("conv1", &alice).unwrap();
assert!(storage.exists("conv1").unwrap());
storage.delete("conv1").unwrap();
assert!(!storage.exists("conv1").unwrap());
}
#[test]
fn test_continue_conversation_after_reload() {
let mut storage = create_test_storage();
let (mut alice, mut bob) = create_test_state();
// Exchange messages
let (ct1, h1) = alice.encrypt_message(b"Hello Bob");
bob.decrypt_message(&ct1, h1).unwrap();
let (ct2, h2) = bob.encrypt_message(b"Hello Alice");
alice.decrypt_message(&ct2, h2).unwrap();
// Save both
storage.save("alice", &alice).unwrap();
storage.save("bob", &bob).unwrap();
// Reload
let mut alice_new: RatchetState<DefaultDomain> = storage.load("alice").unwrap();
let mut bob_new: RatchetState<DefaultDomain> = storage.load("bob").unwrap();
// Continue conversation
let (ct3, h3) = alice_new.encrypt_message(b"After reload");
let pt3 = bob_new.decrypt_message(&ct3, h3).unwrap();
assert_eq!(pt3, b"After reload");
let (ct4, h4) = bob_new.encrypt_message(b"Reply after reload");
let pt4 = alice_new.decrypt_message(&ct4, h4).unwrap();
assert_eq!(pt4, b"Reply after reload");
}
}

View File

@ -0,0 +1,81 @@
use crate::{
hkdf::HkdfInfo,
state::{RatchetState, SkippedKey},
types::MessageKey,
};
use thiserror::Error;
use x25519_dalek::PublicKey;
#[derive(Debug, Error)]
pub enum StorageError {
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
#[error("conversation not found: {0}")]
ConversationNotFound(String),
#[error("serialization error")]
Serialization,
#[error("deserialization error")]
Deserialization,
}
/// Stored representation of a skipped message key.
/// Raw state data for storage (without generic parameter).
#[derive(Debug, Clone)]
pub struct RatchetStateRecord {
pub root_key: [u8; 32],
pub sending_chain: Option<[u8; 32]>,
pub receiving_chain: Option<[u8; 32]>,
pub dh_self_secret: [u8; 32],
pub dh_remote: Option<[u8; 32]>,
pub msg_send: u32,
pub msg_recv: u32,
pub prev_chain_len: u32,
}
impl<D: HkdfInfo> From<&RatchetState<D>> for RatchetStateRecord {
fn from(state: &RatchetState<D>) -> Self {
Self {
root_key: state.root_key,
sending_chain: state.sending_chain,
receiving_chain: state.receiving_chain,
dh_self_secret: state.dh_self.secret_bytes(),
dh_remote: state.dh_remote.map(|pk| pk.to_bytes()),
msg_send: state.msg_send,
msg_recv: state.msg_recv,
prev_chain_len: state.prev_chain_len,
}
}
}
impl RatchetStateRecord {
pub fn into_ratchet_state<D: HkdfInfo>(self, skipped_keys: Vec<SkippedKey>) -> RatchetState<D> {
use crate::keypair::InstallationKeyPair;
use std::collections::HashMap;
use std::marker::PhantomData;
let dh_self = InstallationKeyPair::from_secret_bytes(self.dh_self_secret);
let dh_remote = self.dh_remote.map(PublicKey::from);
let skipped: HashMap<(PublicKey, u32), MessageKey> = skipped_keys
.into_iter()
.map(|sk| ((PublicKey::from(sk.public_key), sk.msg_num), sk.message_key))
.collect();
RatchetState {
root_key: self.root_key,
sending_chain: self.sending_chain,
receiving_chain: self.receiving_chain,
dh_self,
dh_remote,
msg_send: self.msg_send,
msg_recv: self.msg_recv,
prev_chain_len: self.prev_chain_len,
skipped_keys: skipped,
_domain: PhantomData,
}
}
}