feat: memory and sqlcipher storage

This commit is contained in:
kaichaosun 2026-01-20 17:36:16 +08:00
parent fc76453f4c
commit 34a03275cc
No known key found for this signature in database
GPG Key ID: 223E0F992F4F03BF
11 changed files with 1810 additions and 1 deletions

231
Cargo.lock generated
View File

@ -12,6 +12,24 @@ dependencies = [
"generic-array",
]
[[package]]
name = "ahash"
version = "0.8.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
dependencies = [
"cfg-if",
"once_cell",
"version_check",
"zerocopy",
]
[[package]]
name = "bitflags"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
[[package]]
name = "blake2"
version = "0.10.6"
@ -30,6 +48,16 @@ dependencies = [
"generic-array",
]
[[package]]
name = "cc"
version = "1.2.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "755d2fce177175ffca841e9a06afdb2c4ab0f593d53b4dee48147dfaade85932"
dependencies = [
"find-msvc-tools",
"shlex",
]
[[package]]
name = "cfg-if"
version = "1.0.4"
@ -141,12 +169,60 @@ dependencies = [
"x25519-dalek",
]
[[package]]
name = "double-ratchets-storage"
version = "0.1.0"
dependencies = [
"chacha20poly1305",
"double-ratchets",
"rand",
"rusqlite",
"serde",
"tempfile",
"thiserror",
"x25519-dalek",
]
[[package]]
name = "errno"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys",
]
[[package]]
name = "fallible-iterator"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
[[package]]
name = "fallible-streaming-iterator"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fiat-crypto"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
[[package]]
name = "find-msvc-tools"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db"
[[package]]
name = "generic-array"
version = "0.14.7"
@ -168,6 +244,36 @@ dependencies = [
"wasi",
]
[[package]]
name = "getrandom"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasip2",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
dependencies = [
"ahash",
]
[[package]]
name = "hashlink"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af"
dependencies = [
"hashbrown",
]
[[package]]
name = "hkdf"
version = "0.12.4"
@ -201,6 +307,23 @@ version = "0.2.178"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
[[package]]
name = "libsqlite3-sys"
version = "0.30.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]]
name = "logos-chat"
version = "0.1.0"
@ -208,12 +331,24 @@ dependencies = [
"thiserror",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "pkg-config"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "poly1305"
version = "0.8.0"
@ -252,6 +387,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.8.5"
@ -279,7 +420,21 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
"getrandom 0.2.16",
]
[[package]]
name = "rusqlite"
version = "0.32.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e"
dependencies = [
"bitflags",
"fallible-iterator",
"fallible-streaming-iterator",
"hashlink",
"libsqlite3-sys",
"smallvec",
]
[[package]]
@ -291,6 +446,19 @@ dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
dependencies = [
"bitflags",
"errno",
"libc",
"linux-raw-sys",
"windows-sys",
]
[[package]]
name = "semver"
version = "1.0.27"
@ -327,6 +495,18 @@ dependencies = [
"syn",
]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "smallvec"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "subtle"
version = "2.6.1"
@ -344,6 +524,19 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "tempfile"
version = "3.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c"
dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys",
]
[[package]]
name = "thiserror"
version = "2.0.17"
@ -386,6 +579,12 @@ dependencies = [
"subtle",
]
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.9.5"
@ -398,6 +597,36 @@ version = "0.11.1+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
[[package]]
name = "wasip2"
version = "1.0.2+wasi-0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "windows-link"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-sys"
version = "0.61.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc"
dependencies = [
"windows-link",
]
[[package]]
name = "wit-bindgen"
version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5"
[[package]]
name = "x25519-dalek"
version = "2.0.1"

View File

@ -5,4 +5,5 @@ resolver = "3"
members = [
"conversations",
"double-ratchets",
"double-ratchets-storage",
]

View File

@ -0,0 +1,21 @@
[package]
name = "double-ratchets-storage"
version = "0.1.0"
edition = "2024"
[features]
default = ["sqlite"]
sqlite = ["rusqlite/bundled"]
sqlcipher = ["rusqlite/bundled-sqlcipher"]
[dependencies]
double-ratchets = { path = "../double-ratchets" }
x25519-dalek = { version = "2.0.1", features = ["static_secrets"] }
rusqlite = { version = "0.32", optional = true }
serde = { version = "1.0", features = ["derive"] }
thiserror = "2"
chacha20poly1305 = "0.10"
rand = "0.8"
[dev-dependencies]
tempfile = "3"

View File

@ -0,0 +1,46 @@
//! Error types for the storage module.
use thiserror::Error;
/// Errors that can occur during storage operations.
#[derive(Error, Debug)]
pub enum StorageError {
/// Database operation failed.
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
#[error("database error: {0}")]
Database(#[from] rusqlite::Error),
/// Field-level encryption failed.
#[error("encryption failed: {0}")]
Encryption(String),
/// Field-level decryption failed.
#[error("decryption failed: {0}")]
Decryption(String),
/// Stored state is corrupted or invalid.
#[error("corrupted state: {0}")]
CorruptedState(String),
/// Session was not found in storage.
#[error("session not found: {}", hex::encode(.session_id))]
SessionNotFound {
/// The session ID that was not found.
session_id: [u8; 32],
},
/// I/O operation failed.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// Key reconstruction failed.
#[error("key reconstruction failed: {0}")]
KeyReconstruction(String),
}
/// Helper module for hex encoding (used in error messages).
mod hex {
pub fn encode(bytes: &[u8]) -> String {
bytes.iter().map(|b| format!("{:02x}", b)).collect()
}
}

View File

@ -0,0 +1,178 @@
//! Persistent storage for Double Ratchet state.
//!
//! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState)
//! across application restarts. It includes:
//!
//! - [`MemoryStorage`] - In-memory storage for testing
//! - [`SqliteStorage`] - SQLite storage with field-level encryption (requires `sqlite` feature)
//!
//! # Features
//!
//! - `sqlite` (default) - Enables SQLite storage with `rusqlite/bundled`
//! - `sqlcipher` - Enables SQLCipher full-database encryption (mutually exclusive with `sqlite`)
//!
//! # Security
//!
//! Private keys (`dh_self_secret`) are always encrypted with ChaCha20Poly1305 before storage,
//! even when using plain SQLite. For additional security, enable the `sqlcipher` feature
//! for full database encryption.
//!
//! # Example
//!
//! ```no_run
//! use double_ratchets::hkdf::DefaultDomain;
//! use double_ratchets::state::RatchetState;
//! use double_ratchets::InstallationKeyPair;
//! use double_ratchets_storage::{
//! RatchetStorage, SqliteStorage, StorableRatchetState,
//! };
//!
//! // Create a ratchet state
//! let bob_keypair = InstallationKeyPair::generate();
//! let shared_secret = [0x42u8; 32];
//! let state: RatchetState<DefaultDomain> =
//! RatchetState::init_sender(shared_secret, *bob_keypair.public());
//!
//! // Open storage
//! let encryption_key = [0u8; 32]; // Use proper key derivation!
//! let storage = SqliteStorage::open("ratchets.db", encryption_key).unwrap();
//!
//! // Save state
//! let session_id = [1u8; 32];
//! let storable = StorableRatchetState::from_ratchet_state(&state, "default");
//! storage.save(&session_id, &storable).unwrap();
//!
//! // Load state
//! let loaded = storage.load(&session_id).unwrap().unwrap();
//! let restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
//! ```
pub mod error;
pub mod memory;
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
pub mod sqlite;
pub mod traits;
pub mod types;
// Re-exports for convenience
pub use error::StorageError;
pub use memory::MemoryStorage;
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
pub use sqlite::{EncryptionKey, SqliteStorage};
pub use traits::{RatchetStorage, SessionId};
pub use types::{SkippedKey, StorableRatchetState};
#[cfg(test)]
mod integration_tests {
use super::*;
use double_ratchets::hkdf::DefaultDomain;
use double_ratchets::state::RatchetState;
use double_ratchets::InstallationKeyPair;
/// Integration test: full encryption/decryption cycle with storage
#[test]
fn test_full_conversation_with_storage_roundtrip() {
// Setup Alice and Bob
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
let storage = MemoryStorage::new();
let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32];
// Alice sends a message
let (ct1, header1) = alice.encrypt_message(b"Hello Bob!");
// Save Alice's state
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default");
storage.save(&alice_session, &alice_storable).unwrap();
// Bob receives the message
let pt1 = bob.decrypt_message(&ct1, header1).unwrap();
assert_eq!(pt1, b"Hello Bob!");
// Save Bob's state
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default");
storage.save(&bob_session, &bob_storable).unwrap();
// Simulate restart: load states from storage
let alice_loaded = storage.load(&alice_session).unwrap().unwrap();
let bob_loaded = storage.load(&bob_session).unwrap().unwrap();
let mut alice_restored: RatchetState<DefaultDomain> =
alice_loaded.to_ratchet_state().unwrap();
let mut bob_restored: RatchetState<DefaultDomain> =
bob_loaded.to_ratchet_state().unwrap();
// Bob replies
let (ct2, header2) = bob_restored.encrypt_message(b"Hi Alice!");
let pt2 = alice_restored.decrypt_message(&ct2, header2).unwrap();
assert_eq!(pt2, b"Hi Alice!");
// Alice sends another message
let (ct3, header3) = alice_restored.encrypt_message(b"How are you?");
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap();
assert_eq!(pt3, b"How are you?");
}
/// Integration test: verify SQLite storage with encryption works
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
#[test]
fn test_sqlite_integration() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("integration_test.db");
let key = [0x42u8; 32];
// Setup
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32];
// Exchange messages
let (ct1, header1) = alice.encrypt_message(b"Message 1");
bob.decrypt_message(&ct1, header1).unwrap();
let (ct2, header2) = bob.encrypt_message(b"Response 1");
alice.decrypt_message(&ct2, header2).unwrap();
// Save both states
{
let storage = SqliteStorage::open(&db_path, key).unwrap();
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default");
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default");
storage.save(&alice_session, &alice_storable).unwrap();
storage.save(&bob_session, &bob_storable).unwrap();
}
// Reopen database (simulating restart)
{
let storage = SqliteStorage::open(&db_path, key).unwrap();
let alice_loaded = storage.load(&alice_session).unwrap().unwrap();
let bob_loaded = storage.load(&bob_session).unwrap().unwrap();
let mut alice_restored: RatchetState<DefaultDomain> =
alice_loaded.to_ratchet_state().unwrap();
let mut bob_restored: RatchetState<DefaultDomain> =
bob_loaded.to_ratchet_state().unwrap();
// Continue conversation
let (ct3, header3) = alice_restored.encrypt_message(b"Message 2");
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap();
assert_eq!(pt3, b"Message 2");
}
}
}

View File

@ -0,0 +1,216 @@
//! In-memory storage implementation for testing.
use std::collections::HashMap;
use std::sync::RwLock;
use crate::error::StorageError;
use crate::traits::{RatchetStorage, SessionId};
use crate::types::StorableRatchetState;
/// In-memory storage backend for testing purposes.
///
/// This implementation stores ratchet states in a `HashMap` wrapped in a `RwLock`
/// for thread-safe access. Data is not persisted across process restarts.
///
/// # Example
///
/// ```
/// use double_ratchets_storage::{MemoryStorage, RatchetStorage};
///
/// let storage = MemoryStorage::new();
/// assert!(storage.list_sessions().unwrap().is_empty());
/// ```
pub struct MemoryStorage {
states: RwLock<HashMap<SessionId, StorableRatchetState>>,
}
impl MemoryStorage {
/// Create a new empty in-memory storage.
pub fn new() -> Self {
Self {
states: RwLock::new(HashMap::new()),
}
}
/// Get the number of stored sessions.
pub fn len(&self) -> usize {
self.states.read().unwrap().len()
}
/// Check if the storage is empty.
pub fn is_empty(&self) -> bool {
self.states.read().unwrap().is_empty()
}
/// Clear all stored sessions.
pub fn clear(&self) {
self.states.write().unwrap().clear();
}
}
impl Default for MemoryStorage {
fn default() -> Self {
Self::new()
}
}
impl RatchetStorage for MemoryStorage {
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
let mut states = self.states.write().unwrap();
states.insert(*session_id, state.clone());
Ok(())
}
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
let states = self.states.read().unwrap();
Ok(states.get(session_id).cloned())
}
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
let mut states = self.states.write().unwrap();
Ok(states.remove(session_id).is_some())
}
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
let states = self.states.read().unwrap();
Ok(states.contains_key(session_id))
}
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
let states = self.states.read().unwrap();
Ok(states.keys().copied().collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
use double_ratchets::hkdf::DefaultDomain;
use double_ratchets::state::RatchetState;
use double_ratchets::InstallationKeyPair;
fn create_test_state() -> StorableRatchetState {
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let state: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
StorableRatchetState::from_ratchet_state(&state, "default")
}
#[test]
fn test_save_and_load() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
let loaded = storage.load(&session_id).unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.root_key, state.root_key);
}
#[test]
fn test_load_nonexistent() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
let loaded = storage.load(&session_id).unwrap();
assert!(loaded.is_none());
}
#[test]
fn test_delete() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
assert!(storage.exists(&session_id).unwrap());
let deleted = storage.delete(&session_id).unwrap();
assert!(deleted);
assert!(!storage.exists(&session_id).unwrap());
// Deleting again should return false
let deleted = storage.delete(&session_id).unwrap();
assert!(!deleted);
}
#[test]
fn test_exists() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
assert!(!storage.exists(&session_id).unwrap());
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
assert!(storage.exists(&session_id).unwrap());
}
#[test]
fn test_list_sessions() {
let storage = MemoryStorage::new();
assert!(storage.list_sessions().unwrap().is_empty());
let state = create_test_state();
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
for id in &session_ids {
storage.save(id, &state).unwrap();
}
let mut listed = storage.list_sessions().unwrap();
listed.sort();
let mut expected = session_ids.clone();
expected.sort();
assert_eq!(listed, expected);
}
#[test]
fn test_overwrite() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
// Create first state
let bob_keypair1 = InstallationKeyPair::generate();
let state1: RatchetState<DefaultDomain> =
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
// Create second state with different root
let bob_keypair2 = InstallationKeyPair::generate();
let state2: RatchetState<DefaultDomain> =
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
// Save first, then overwrite with second
storage.save(&session_id, &storable1).unwrap();
storage.save(&session_id, &storable2).unwrap();
// Should have the second state
let loaded = storage.load(&session_id).unwrap().unwrap();
assert_eq!(loaded.root_key, storable2.root_key);
assert_ne!(loaded.root_key, storable1.root_key);
}
#[test]
fn test_clear() {
let storage = MemoryStorage::new();
let state = create_test_state();
for i in 0..5 {
storage.save(&[i; 32], &state).unwrap();
}
assert_eq!(storage.len(), 5);
storage.clear();
assert!(storage.is_empty());
}
}

View File

@ -0,0 +1,746 @@
//! SQLite storage implementation with field-level encryption.
use std::path::Path;
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Nonce,
};
use rand::RngCore;
use rusqlite::{params, Connection, OptionalExtension};
use crate::error::StorageError;
use crate::traits::{RatchetStorage, SessionId};
use crate::types::{SkippedKey, StorableRatchetState};
/// Field encryption key type (32 bytes for ChaCha20Poly1305).
pub type EncryptionKey = [u8; 32];
/// SQLite storage backend with field-level encryption for secrets.
///
/// This implementation stores ratchet states in SQLite with:
/// - Field-level encryption for private keys using ChaCha20Poly1305
/// - WAL mode for better concurrent performance
/// - Foreign keys and cascading deletes for data integrity
/// - Atomic transactions to prevent partial writes
///
/// # Security
///
/// The `dh_self_secret` field is encrypted with the provided encryption key.
/// For additional security, consider using SQLCipher for full database encryption
/// via the `open_encrypted` method (requires `sqlcipher` feature).
///
/// # Example
///
/// ```no_run
/// use double_ratchets_storage::SqliteStorage;
///
/// let key = [0u8; 32]; // Use a proper key derivation function
/// let storage = SqliteStorage::open("ratchets.db", key).unwrap();
/// ```
pub struct SqliteStorage {
conn: Mutex<Connection>,
encryption_key: EncryptionKey,
}
impl SqliteStorage {
/// Open or create a SQLite database with field-level encryption.
///
/// # Arguments
///
/// * `path` - Path to the database file.
/// * `encryption_key` - 32-byte key for field-level encryption.
///
/// # Returns
///
/// * `Ok(SqliteStorage)` on success.
/// * `Err(StorageError)` on failure.
pub fn open<P: AsRef<Path>>(path: P, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
let conn = Connection::open(path)?;
Self::initialize(conn, encryption_key)
}
/// Create an in-memory SQLite database (for testing).
///
/// # Arguments
///
/// * `encryption_key` - 32-byte key for field-level encryption.
///
/// # Returns
///
/// * `Ok(SqliteStorage)` on success.
/// * `Err(StorageError)` on failure.
pub fn open_in_memory(encryption_key: EncryptionKey) -> Result<Self, StorageError> {
let conn = Connection::open_in_memory()?;
Self::initialize(conn, encryption_key)
}
/// Open or create a SQLCipher-encrypted database.
///
/// This method requires the `sqlcipher` feature to be enabled.
///
/// # Arguments
///
/// * `path` - Path to the database file.
/// * `db_password` - Password for SQLCipher database encryption.
/// * `field_key` - 32-byte key for additional field-level encryption.
///
/// # Returns
///
/// * `Ok(SqliteStorage)` on success.
/// * `Err(StorageError)` on failure.
#[cfg(feature = "sqlcipher")]
pub fn open_encrypted<P: AsRef<Path>>(
path: P,
db_password: &str,
field_key: EncryptionKey,
) -> Result<Self, StorageError> {
let conn = Connection::open(path)?;
// Set SQLCipher key
conn.pragma_update(None, "key", db_password)?;
Self::initialize(conn, field_key)
}
fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
// Enable WAL mode for better performance
conn.pragma_update(None, "journal_mode", "WAL")?;
// Enable foreign keys
conn.pragma_update(None, "foreign_keys", "ON")?;
// Create tables
conn.execute_batch(
r#"
CREATE TABLE IF NOT EXISTS ratchet_states (
session_id BLOB PRIMARY KEY,
root_key BLOB NOT NULL,
sending_chain BLOB,
receiving_chain BLOB,
dh_self_secret_encrypted BLOB NOT NULL,
dh_self_secret_nonce BLOB NOT NULL,
dh_self_public BLOB NOT NULL,
dh_remote BLOB,
msg_send INTEGER NOT NULL,
msg_recv INTEGER NOT NULL,
prev_chain_len INTEGER NOT NULL,
domain_id TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS skipped_keys (
id INTEGER PRIMARY KEY,
session_id BLOB NOT NULL REFERENCES ratchet_states(session_id) ON DELETE CASCADE,
public_key BLOB NOT NULL,
msg_num INTEGER NOT NULL,
message_key BLOB NOT NULL,
UNIQUE(session_id, public_key, msg_num)
);
CREATE INDEX IF NOT EXISTS idx_skipped_keys_session ON skipped_keys(session_id);
"#,
)?;
Ok(Self {
conn: Mutex::new(conn),
encryption_key,
})
}
/// Encrypt a 32-byte secret using ChaCha20Poly1305.
fn encrypt_secret(&self, secret: &[u8; 32]) -> Result<(Vec<u8>, [u8; 12]), StorageError> {
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
let mut nonce_bytes = [0u8; 12];
rand::thread_rng().fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, secret.as_ref())
.map_err(|e| StorageError::Encryption(e.to_string()))?;
Ok((ciphertext, nonce_bytes))
}
/// Decrypt a secret using ChaCha20Poly1305.
fn decrypt_secret(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> {
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
let nonce = Nonce::from_slice(nonce);
let plaintext = cipher
.decrypt(nonce, ciphertext)
.map_err(|e| StorageError::Decryption(e.to_string()))?;
plaintext
.try_into()
.map_err(|_| StorageError::CorruptedState("decrypted secret has wrong length".to_string()))
}
fn current_timestamp() -> i64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64
}
}
impl RatchetStorage for SqliteStorage {
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
let (encrypted_secret, nonce) = self.encrypt_secret(&state.dh_self_secret)?;
let conn = self.conn.lock().unwrap();
let tx = conn.unchecked_transaction()?;
let now = Self::current_timestamp();
// Check if session exists to determine created_at
let exists: bool = tx.query_row(
"SELECT 1 FROM ratchet_states WHERE session_id = ?",
[session_id.as_slice()],
|_| Ok(true),
).optional()?.unwrap_or(false);
if exists {
// Update existing session
tx.execute(
r#"
UPDATE ratchet_states SET
root_key = ?,
sending_chain = ?,
receiving_chain = ?,
dh_self_secret_encrypted = ?,
dh_self_secret_nonce = ?,
dh_self_public = ?,
dh_remote = ?,
msg_send = ?,
msg_recv = ?,
prev_chain_len = ?,
domain_id = ?,
updated_at = ?
WHERE session_id = ?
"#,
params![
state.root_key.as_slice(),
state.sending_chain.as_ref().map(|c| c.as_slice()),
state.receiving_chain.as_ref().map(|c| c.as_slice()),
encrypted_secret.as_slice(),
nonce.as_slice(),
state.dh_self_public.as_slice(),
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
state.msg_send,
state.msg_recv,
state.prev_chain_len,
&state.domain_id,
now,
session_id.as_slice(),
],
)?;
// Delete existing skipped keys
tx.execute(
"DELETE FROM skipped_keys WHERE session_id = ?",
[session_id.as_slice()],
)?;
} else {
// Insert new session
tx.execute(
r#"
INSERT INTO ratchet_states (
session_id, root_key, sending_chain, receiving_chain,
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
params![
session_id.as_slice(),
state.root_key.as_slice(),
state.sending_chain.as_ref().map(|c| c.as_slice()),
state.receiving_chain.as_ref().map(|c| c.as_slice()),
encrypted_secret.as_slice(),
nonce.as_slice(),
state.dh_self_public.as_slice(),
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
state.msg_send,
state.msg_recv,
state.prev_chain_len,
&state.domain_id,
now,
now,
],
)?;
}
// Insert skipped keys
for sk in &state.skipped_keys {
tx.execute(
r#"
INSERT INTO skipped_keys (session_id, public_key, msg_num, message_key)
VALUES (?, ?, ?, ?)
"#,
params![
session_id.as_slice(),
sk.public_key.as_slice(),
sk.msg_num,
sk.message_key.as_slice(),
],
)?;
}
tx.commit()?;
Ok(())
}
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
let conn = self.conn.lock().unwrap();
let row = conn
.query_row(
r#"
SELECT root_key, sending_chain, receiving_chain,
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id
FROM ratchet_states WHERE session_id = ?
"#,
[session_id.as_slice()],
|row| {
Ok((
row.get::<_, Vec<u8>>(0)?,
row.get::<_, Option<Vec<u8>>>(1)?,
row.get::<_, Option<Vec<u8>>>(2)?,
row.get::<_, Vec<u8>>(3)?,
row.get::<_, Vec<u8>>(4)?,
row.get::<_, Vec<u8>>(5)?,
row.get::<_, Option<Vec<u8>>>(6)?,
row.get::<_, u32>(7)?,
row.get::<_, u32>(8)?,
row.get::<_, u32>(9)?,
row.get::<_, String>(10)?,
))
},
)
.optional()?;
let Some((
root_key_bytes,
sending_chain_bytes,
receiving_chain_bytes,
encrypted_secret,
nonce_bytes,
dh_self_public_bytes,
dh_remote_bytes,
msg_send,
msg_recv,
prev_chain_len,
domain_id,
)) = row
else {
return Ok(None);
};
// Decrypt the secret
let nonce: [u8; 12] = nonce_bytes
.try_into()
.map_err(|_| StorageError::CorruptedState("invalid nonce length".to_string()))?;
let dh_self_secret = self.decrypt_secret(&encrypted_secret, &nonce)?;
// Convert byte vectors to arrays
let root_key: [u8; 32] = root_key_bytes
.try_into()
.map_err(|_| StorageError::CorruptedState("invalid root_key length".to_string()))?;
let dh_self_public: [u8; 32] = dh_self_public_bytes
.try_into()
.map_err(|_| StorageError::CorruptedState("invalid dh_self_public length".to_string()))?;
let sending_chain = sending_chain_bytes
.map(|b| {
b.try_into()
.map_err(|_| StorageError::CorruptedState("invalid sending_chain length".to_string()))
})
.transpose()?;
let receiving_chain = receiving_chain_bytes
.map(|b| {
b.try_into()
.map_err(|_| StorageError::CorruptedState("invalid receiving_chain length".to_string()))
})
.transpose()?;
let dh_remote = dh_remote_bytes
.map(|b| {
b.try_into()
.map_err(|_| StorageError::CorruptedState("invalid dh_remote length".to_string()))
})
.transpose()?;
// Load skipped keys
let mut stmt = conn.prepare(
"SELECT public_key, msg_num, message_key FROM skipped_keys WHERE session_id = ?",
)?;
let skipped_keys: Vec<SkippedKey> = stmt
.query_map([session_id.as_slice()], |row| {
let pk_bytes: Vec<u8> = row.get(0)?;
let msg_num: u32 = row.get(1)?;
let mk_bytes: Vec<u8> = row.get(2)?;
Ok((pk_bytes, msg_num, mk_bytes))
})?
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.map(|(pk_bytes, msg_num, mk_bytes)| {
let public_key: [u8; 32] = pk_bytes
.try_into()
.map_err(|_| StorageError::CorruptedState("invalid skipped public_key length".to_string()))?;
let message_key: [u8; 32] = mk_bytes
.try_into()
.map_err(|_| StorageError::CorruptedState("invalid skipped message_key length".to_string()))?;
Ok(SkippedKey {
public_key,
msg_num,
message_key,
})
})
.collect::<Result<Vec<_>, StorageError>>()?;
Ok(Some(StorableRatchetState {
root_key,
sending_chain,
receiving_chain,
dh_self_secret,
dh_self_public,
dh_remote,
msg_send,
msg_recv,
prev_chain_len,
skipped_keys,
domain_id,
}))
}
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
let conn = self.conn.lock().unwrap();
let changes = conn.execute(
"DELETE FROM ratchet_states WHERE session_id = ?",
[session_id.as_slice()],
)?;
Ok(changes > 0)
}
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
let conn = self.conn.lock().unwrap();
let exists: bool = conn
.query_row(
"SELECT 1 FROM ratchet_states WHERE session_id = ?",
[session_id.as_slice()],
|_| Ok(true),
)
.optional()?
.unwrap_or(false);
Ok(exists)
}
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
let conn = self.conn.lock().unwrap();
let mut stmt = conn.prepare("SELECT session_id FROM ratchet_states")?;
let sessions: Vec<SessionId> = stmt
.query_map([], |row| {
let bytes: Vec<u8> = row.get(0)?;
Ok(bytes)
})?
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.filter_map(|bytes| bytes.try_into().ok())
.collect();
Ok(sessions)
}
}
#[cfg(test)]
mod tests {
use super::*;
use double_ratchets::hkdf::DefaultDomain;
use double_ratchets::state::RatchetState;
use double_ratchets::InstallationKeyPair;
fn create_test_storage() -> SqliteStorage {
let key = [0x42u8; 32];
SqliteStorage::open_in_memory(key).unwrap()
}
fn create_test_state() -> StorableRatchetState {
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let state: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
StorableRatchetState::from_ratchet_state(&state, "default")
}
#[test]
fn test_save_and_load() {
let storage = create_test_storage();
let session_id = [1u8; 32];
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
let loaded = storage.load(&session_id).unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.root_key, state.root_key);
assert_eq!(loaded.dh_self_public, state.dh_self_public);
// Secret should be decrypted correctly
assert_eq!(loaded.dh_self_secret, state.dh_self_secret);
}
#[test]
fn test_load_nonexistent() {
let storage = create_test_storage();
let session_id = [1u8; 32];
let loaded = storage.load(&session_id).unwrap();
assert!(loaded.is_none());
}
#[test]
fn test_delete() {
let storage = create_test_storage();
let session_id = [1u8; 32];
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
assert!(storage.exists(&session_id).unwrap());
let deleted = storage.delete(&session_id).unwrap();
assert!(deleted);
assert!(!storage.exists(&session_id).unwrap());
// Deleting again should return false
let deleted = storage.delete(&session_id).unwrap();
assert!(!deleted);
}
#[test]
fn test_exists() {
let storage = create_test_storage();
let session_id = [1u8; 32];
assert!(!storage.exists(&session_id).unwrap());
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
assert!(storage.exists(&session_id).unwrap());
}
#[test]
fn test_list_sessions() {
let storage = create_test_storage();
assert!(storage.list_sessions().unwrap().is_empty());
let state = create_test_state();
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
for id in &session_ids {
storage.save(id, &state).unwrap();
}
let mut listed = storage.list_sessions().unwrap();
listed.sort();
let mut expected = session_ids.clone();
expected.sort();
assert_eq!(listed, expected);
}
#[test]
fn test_overwrite() {
let storage = create_test_storage();
let session_id = [1u8; 32];
// Create first state
let bob_keypair1 = InstallationKeyPair::generate();
let state1: RatchetState<DefaultDomain> =
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
// Create second state with different root
let bob_keypair2 = InstallationKeyPair::generate();
let state2: RatchetState<DefaultDomain> =
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
// Save first, then overwrite with second
storage.save(&session_id, &storable1).unwrap();
storage.save(&session_id, &storable2).unwrap();
// Should have the second state
let loaded = storage.load(&session_id).unwrap().unwrap();
assert_eq!(loaded.root_key, storable2.root_key);
assert_ne!(loaded.root_key, storable1.root_key);
}
#[test]
fn test_skipped_keys_storage() {
let storage = create_test_storage();
let session_id = [1u8; 32];
// Create states and generate skipped keys
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
// Alice sends multiple messages
let mut messages = vec![];
for i in 0..3 {
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
messages.push((ct, header));
}
// Bob receives out of order to create skipped keys
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
.unwrap();
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
.unwrap();
assert!(!bob.skipped_keys.is_empty());
// Save and reload
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
storage.save(&session_id, &storable).unwrap();
let loaded = storage.load(&session_id).unwrap().unwrap();
assert_eq!(loaded.skipped_keys.len(), storable.skipped_keys.len());
// Restore and verify we can decrypt the skipped message
let mut restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
let pt = restored
.decrypt_message(&messages[1].0, messages[1].1.clone())
.unwrap();
assert_eq!(pt, b"Message 1");
}
#[test]
fn test_encryption_uses_different_nonces() {
let storage = create_test_storage();
let state = create_test_state();
// Save the same state twice with different session IDs
storage.save(&[1u8; 32], &state).unwrap();
storage.save(&[2u8; 32], &state).unwrap();
// Both should load correctly (encryption with different nonces)
let loaded1 = storage.load(&[1u8; 32]).unwrap().unwrap();
let loaded2 = storage.load(&[2u8; 32]).unwrap().unwrap();
assert_eq!(loaded1.dh_self_secret, loaded2.dh_self_secret);
}
#[test]
fn test_cascade_delete_skipped_keys() {
let storage = create_test_storage();
let session_id = [1u8; 32];
// Create a state with skipped keys
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
let mut messages = vec![];
for i in 0..3 {
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
messages.push((ct, header));
}
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
.unwrap();
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
.unwrap();
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
storage.save(&session_id, &storable).unwrap();
// Verify skipped keys exist
{
let conn = storage.conn.lock().unwrap();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
[session_id.as_slice()],
|row| row.get(0),
)
.unwrap();
assert!(count > 0);
}
// Delete session
storage.delete(&session_id).unwrap();
// Verify skipped keys were also deleted (cascade)
{
let conn = storage.conn.lock().unwrap();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
[session_id.as_slice()],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 0);
}
}
#[test]
fn test_file_storage() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("test.db");
let key = [0x42u8; 32];
let state = create_test_state();
let session_id = [1u8; 32];
// Save in one instance
{
let storage = SqliteStorage::open(&db_path, key).unwrap();
storage.save(&session_id, &state).unwrap();
}
// Load in another instance
{
let storage = SqliteStorage::open(&db_path, key).unwrap();
let loaded = storage.load(&session_id).unwrap().unwrap();
assert_eq!(loaded.root_key, state.root_key);
}
}
#[test]
fn test_wrong_key_fails_decryption() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("test.db");
let key1 = [0x42u8; 32];
let key2 = [0x43u8; 32];
let state = create_test_state();
let session_id = [1u8; 32];
// Save with key1
{
let storage = SqliteStorage::open(&db_path, key1).unwrap();
storage.save(&session_id, &state).unwrap();
}
// Try to load with key2 - should fail decryption
{
let storage = SqliteStorage::open(&db_path, key2).unwrap();
let result = storage.load(&session_id);
assert!(result.is_err());
}
}
}

View File

@ -0,0 +1,74 @@
//! Storage trait definitions.
use crate::error::StorageError;
use crate::types::StorableRatchetState;
/// A 32-byte session identifier.
pub type SessionId = [u8; 32];
/// Abstract storage interface for ratchet states.
///
/// Implementations must be thread-safe (`Send + Sync`).
pub trait RatchetStorage: Send + Sync {
/// Save a ratchet state for the given session.
///
/// If a state already exists for this session, it will be overwritten.
///
/// # Arguments
///
/// * `session_id` - Unique identifier for the session.
/// * `state` - The ratchet state to store.
///
/// # Returns
///
/// * `Ok(())` on success.
/// * `Err(StorageError)` on failure.
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError>;
/// Load a ratchet state for the given session.
///
/// # Arguments
///
/// * `session_id` - Unique identifier for the session.
///
/// # Returns
///
/// * `Ok(Some(state))` if found.
/// * `Ok(None)` if not found.
/// * `Err(StorageError)` on failure.
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError>;
/// Delete a ratchet state for the given session.
///
/// # Arguments
///
/// * `session_id` - Unique identifier for the session.
///
/// # Returns
///
/// * `Ok(true)` if the session existed and was deleted.
/// * `Ok(false)` if the session did not exist.
/// * `Err(StorageError)` on failure.
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError>;
/// Check if a session exists in storage.
///
/// # Arguments
///
/// * `session_id` - Unique identifier for the session.
///
/// # Returns
///
/// * `Ok(true)` if the session exists.
/// * `Ok(false)` if the session does not exist.
/// * `Err(StorageError)` on failure.
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError>;
/// List all session IDs in storage.
///
/// # Returns
///
/// * `Ok(Vec<SessionId>)` containing all session IDs.
/// * `Err(StorageError)` on failure.
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError>;
}

View File

@ -0,0 +1,225 @@
//! Serializable types for ratchet state storage.
use std::collections::HashMap;
use double_ratchets::state::RatchetState;
use double_ratchets::hkdf::HkdfInfo;
use double_ratchets::InstallationKeyPair;
use serde::{Deserialize, Serialize};
use x25519_dalek::PublicKey;
use crate::error::StorageError;
/// A skipped message key entry for storage.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SkippedKey {
/// The public key associated with this skipped message.
pub public_key: [u8; 32],
/// The message number.
pub msg_num: u32,
/// The 32-byte message key.
pub message_key: [u8; 32],
}
/// Serializable version of `RatchetState`.
///
/// This struct stores all keys as raw byte arrays for easy serialization
/// and database storage. Use `from_ratchet_state()` and `to_ratchet_state()`
/// for conversion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorableRatchetState {
/// The current root key (32 bytes).
pub root_key: [u8; 32],
/// The current sending chain key, if any (32 bytes).
pub sending_chain: Option<[u8; 32]>,
/// The current receiving chain key, if any (32 bytes).
pub receiving_chain: Option<[u8; 32]>,
/// Our DH secret key (32 bytes).
///
/// **Security**: This should be encrypted before storage.
pub dh_self_secret: [u8; 32],
/// Our DH public key (32 bytes).
pub dh_self_public: [u8; 32],
/// Remote party's DH public key, if known (32 bytes).
pub dh_remote: Option<[u8; 32]>,
/// Number of messages sent in the current sending chain.
pub msg_send: u32,
/// Number of messages received in the current receiving chain.
pub msg_recv: u32,
/// Length of the previous sending chain.
pub prev_chain_len: u32,
/// Skipped message keys for out-of-order message handling.
pub skipped_keys: Vec<SkippedKey>,
/// Domain identifier for HKDF info.
pub domain_id: String,
}
impl StorableRatchetState {
/// Convert a `RatchetState` into a `StorableRatchetState`.
///
/// # Type Parameters
///
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
///
/// # Arguments
///
/// * `state` - The ratchet state to convert.
/// * `domain_id` - A string identifier for the domain (used to reconstruct the correct domain type).
pub fn from_ratchet_state<D: HkdfInfo>(state: &RatchetState<D>, domain_id: &str) -> Self {
let skipped_keys: Vec<SkippedKey> = state
.skipped_keys
.iter()
.map(|((pub_key, msg_num), msg_key)| SkippedKey {
public_key: *pub_key.as_bytes(),
msg_num: *msg_num,
message_key: *msg_key,
})
.collect();
StorableRatchetState {
root_key: state.root_key,
sending_chain: state.sending_chain,
receiving_chain: state.receiving_chain,
dh_self_secret: state.dh_self.secret_bytes(),
dh_self_public: *state.dh_self.public().as_bytes(),
dh_remote: state.dh_remote.map(|pk| *pk.as_bytes()),
msg_send: state.msg_send,
msg_recv: state.msg_recv,
prev_chain_len: state.prev_chain_len,
skipped_keys,
domain_id: domain_id.to_string(),
}
}
/// Convert this `StorableRatchetState` back into a `RatchetState`.
///
/// # Type Parameters
///
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
///
/// # Returns
///
/// * `Ok(RatchetState)` on success.
/// * `Err(StorageError)` if key reconstruction fails.
pub fn to_ratchet_state<D: HkdfInfo>(&self) -> Result<RatchetState<D>, StorageError> {
// Reconstruct the keypair
let dh_self = InstallationKeyPair::from_bytes(self.dh_self_secret, self.dh_self_public)
.map_err(|e| StorageError::KeyReconstruction(e.to_string()))?;
// Reconstruct skipped keys HashMap
let skipped_keys: HashMap<(PublicKey, u32), [u8; 32]> = self
.skipped_keys
.iter()
.map(|sk| {
let pub_key = PublicKey::from(sk.public_key);
((pub_key, sk.msg_num), sk.message_key)
})
.collect();
Ok(RatchetState::from_parts(
self.root_key,
self.sending_chain,
self.receiving_chain,
dh_self,
self.dh_remote.map(PublicKey::from),
self.msg_send,
self.msg_recv,
self.prev_chain_len,
skipped_keys,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use double_ratchets::hkdf::DefaultDomain;
#[test]
fn test_roundtrip_sender_state() {
// Create a sender state
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let state: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
// Convert to storable and back
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
// Verify fields match
assert_eq!(state.root_key, restored.root_key);
assert_eq!(state.sending_chain, restored.sending_chain);
assert_eq!(state.receiving_chain, restored.receiving_chain);
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
assert_eq!(state.dh_remote.map(|pk| *pk.as_bytes()), restored.dh_remote.map(|pk| *pk.as_bytes()));
assert_eq!(state.msg_send, restored.msg_send);
assert_eq!(state.msg_recv, restored.msg_recv);
assert_eq!(state.prev_chain_len, restored.prev_chain_len);
}
#[test]
fn test_roundtrip_receiver_state() {
// Create a receiver state
let keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let state: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, keypair);
// Convert to storable and back
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
// Verify fields match
assert_eq!(state.root_key, restored.root_key);
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
assert!(restored.dh_remote.is_none());
}
#[test]
fn test_roundtrip_with_skipped_keys() {
// Create states and exchange messages to generate skipped keys
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
// Alice sends multiple messages
let mut messages = vec![];
for i in 0..3 {
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
messages.push((ct, header));
}
// Bob receives them out of order to create skipped keys
bob.decrypt_message(&messages[0].0, messages[0].1.clone()).unwrap();
bob.decrypt_message(&messages[2].0, messages[2].1.clone()).unwrap();
// Message 1 key is now in skipped_keys
assert!(!bob.skipped_keys.is_empty());
// Convert to storable and back
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
// Verify skipped keys are preserved
assert_eq!(bob.skipped_keys.len(), restored.skipped_keys.len());
// The restored state should be able to decrypt the skipped message
let mut restored = restored;
let pt = restored.decrypt_message(&messages[1].0, messages[1].1.clone()).unwrap();
assert_eq!(pt, b"Message 1");
}
}

View File

@ -23,4 +23,37 @@ impl InstallationKeyPair {
pub fn public(&self) -> &PublicKey {
&self.public
}
/// Export the secret key as raw bytes for serialization/storage.
///
/// # Security Warning
///
/// The returned bytes contain the private key material. Handle with care
/// and ensure proper encryption when storing.
pub fn secret_bytes(&self) -> [u8; 32] {
self.secret.to_bytes()
}
/// Reconstruct a keypair from raw secret and public key bytes.
///
/// # Arguments
///
/// * `secret` - The 32-byte secret key.
/// * `public` - The 32-byte public key.
///
/// # Returns
///
/// * `Ok(InstallationKeyPair)` if the keys are valid and consistent.
/// * `Err(&'static str)` if the public key doesn't match the secret key.
pub fn from_bytes(secret: [u8; 32], public: [u8; 32]) -> Result<Self, &'static str> {
let secret = StaticSecret::from(secret);
let expected_public = PublicKey::from(&secret);
let public = PublicKey::from(public);
if expected_public.as_bytes() != public.as_bytes() {
return Err("public key does not match secret key");
}
Ok(Self { secret, public })
}
}

View File

@ -59,6 +59,46 @@ impl Header {
}
impl<D: HkdfInfo> RatchetState<D> {
/// Reconstruct a `RatchetState` from its component parts.
///
/// This is primarily used for deserialization from storage.
///
/// # Arguments
///
/// * `root_key` - The current root key.
/// * `sending_chain` - The current sending chain key, if any.
/// * `receiving_chain` - The current receiving chain key, if any.
/// * `dh_self` - Our DH key pair.
/// * `dh_remote` - Remote party's DH public key, if known.
/// * `msg_send` - Number of messages sent in current sending chain.
/// * `msg_recv` - Number of messages received in current receiving chain.
/// * `prev_chain_len` - Length of the previous sending chain.
/// * `skipped_keys` - Map of skipped message keys.
pub fn from_parts(
root_key: RootKey,
sending_chain: Option<ChainKey>,
receiving_chain: Option<ChainKey>,
dh_self: InstallationKeyPair,
dh_remote: Option<PublicKey>,
msg_send: u32,
msg_recv: u32,
prev_chain_len: u32,
skipped_keys: HashMap<(PublicKey, u32), MessageKey>,
) -> Self {
Self {
root_key,
sending_chain,
receiving_chain,
dh_self,
dh_remote,
msg_send,
msg_recv,
prev_chain_len,
skipped_keys,
_domain: PhantomData,
}
}
/// Initializes the party that sends the first message.
///
/// Performs the initial Diffie-Hellman computation with the remote public key