chore: refactor

This commit is contained in:
kaichaosun 2026-01-20 22:52:09 +08:00
parent 34a03275cc
commit 9f968fc80d
No known key found for this signature in database
GPG Key ID: 223E0F992F4F03BF
6 changed files with 1101 additions and 1140 deletions

View File

@ -1,10 +1,13 @@
//! Persistent storage for Double Ratchet state. //! Persistent storage for Double Ratchet state.
//! //!
//! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState) //! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState)
//! across application restarts. It includes: //! with automatic field-level persistence on each ratchet operation.
//! //!
//! - [`MemoryStorage`] - In-memory storage for testing //! # Main API
//! - [`SqliteStorage`] - SQLite storage with field-level encryption (requires `sqlite` feature) //!
//! - [`PersistentRatchet`] - Wrapper that auto-persists state changes during encrypt/decrypt
//! - [`SqliteRatchetStore`] - SQLite backend with field-level encryption
//! - [`RatchetStore`] - Trait for implementing custom storage backends
//! //!
//! # Features //! # Features
//! //!
@ -13,166 +16,233 @@
//! //!
//! # Security //! # Security
//! //!
//! Private keys (`dh_self_secret`) are always encrypted with ChaCha20Poly1305 before storage, //! Private keys (`dh_self_secret`) and message keys are always encrypted with ChaCha20Poly1305
//! even when using plain SQLite. For additional security, enable the `sqlcipher` feature //! before storage, even when using plain SQLite. For additional security, enable the `sqlcipher`
//! for full database encryption. //! feature for full database encryption.
//! //!
//! # Example //! # Example
//! //!
//! ```no_run //! ```no_run
//! use std::sync::Arc;
//! use double_ratchets::hkdf::DefaultDomain; //! use double_ratchets::hkdf::DefaultDomain;
//! use double_ratchets::state::RatchetState;
//! use double_ratchets::InstallationKeyPair; //! use double_ratchets::InstallationKeyPair;
//! use double_ratchets_storage::{ //! use double_ratchets_storage::{PersistentRatchet, RatchetStore, SqliteRatchetStore};
//! RatchetStorage, SqliteStorage, StorableRatchetState,
//! };
//!
//! // Create a ratchet state
//! let bob_keypair = InstallationKeyPair::generate();
//! let shared_secret = [0x42u8; 32];
//! let state: RatchetState<DefaultDomain> =
//! RatchetState::init_sender(shared_secret, *bob_keypair.public());
//! //!
//! // Open storage //! // Open storage
//! let encryption_key = [0u8; 32]; // Use proper key derivation! //! let encryption_key = [0u8; 32]; // Use proper key derivation!
//! let storage = SqliteStorage::open("ratchets.db", encryption_key).unwrap(); //! let store: Arc<dyn RatchetStore> =
//! Arc::new(SqliteRatchetStore::open("ratchets.db", encryption_key).unwrap());
//! //!
//! // Save state //! // Initialize sender
//! let bob_keypair = InstallationKeyPair::generate();
//! let shared_secret = [0x42u8; 32];
//! let session_id = [1u8; 32]; //! let session_id = [1u8; 32];
//! let storable = StorableRatchetState::from_ratchet_state(&state, "default");
//! storage.save(&session_id, &storable).unwrap();
//! //!
//! // Load state //! let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
//! let loaded = storage.load(&session_id).unwrap().unwrap(); //! Arc::clone(&store),
//! let restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap(); //! session_id,
//! shared_secret,
//! *bob_keypair.public(),
//! ).unwrap();
//!
//! // Encrypt - state is automatically persisted
//! let (ciphertext, header) = alice.encrypt_message(b"Hello!").unwrap();
//!
//! // Later: load from storage
//! let alice_restored: PersistentRatchet<DefaultDomain> =
//! PersistentRatchet::load(store, session_id).unwrap().unwrap();
//! ``` //! ```
pub mod error; pub mod error;
pub mod memory; #[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
pub mod persistent;
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))] #[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
pub mod sqlite; pub mod sqlite;
pub mod traits; pub mod traits;
pub mod types;
// Re-exports for convenience // Re-exports for convenience
pub use error::StorageError; pub use error::StorageError;
pub use memory::MemoryStorage;
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))] #[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
pub use sqlite::{EncryptionKey, SqliteStorage}; pub use persistent::{PersistentRatchet, PersistentRatchetError};
pub use traits::{RatchetStorage, SessionId}; #[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
pub use types::{SkippedKey, StorableRatchetState}; pub use sqlite::{EncryptionKey, SqliteRatchetStore};
pub use traits::{RatchetStore, SessionId, SkippedKeyEntry, StoredState};
#[cfg(test)] #[cfg(all(test, any(feature = "sqlite", feature = "sqlcipher")))]
mod integration_tests { mod integration_tests {
use super::*; use super::*;
use double_ratchets::hkdf::DefaultDomain; use double_ratchets::hkdf::DefaultDomain;
use double_ratchets::state::RatchetState;
use double_ratchets::InstallationKeyPair; use double_ratchets::InstallationKeyPair;
use std::sync::Arc;
/// Integration test: full encryption/decryption cycle with storage /// Integration test: full conversation with auto-persist
#[test] #[test]
fn test_full_conversation_with_storage_roundtrip() { fn test_full_conversation_with_auto_persist() {
// Setup Alice and Bob let store: Arc<dyn RatchetStore> =
let bob_keypair = InstallationKeyPair::generate(); Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap());
let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
let storage = MemoryStorage::new();
let alice_session = [0xAA; 32]; let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32]; let bob_session = [0xBB; 32];
// Alice sends a message let bob_keypair = InstallationKeyPair::generate();
let (ct1, header1) = alice.encrypt_message(b"Hello Bob!"); let shared_secret = [0x42u8; 32];
// Save Alice's state // Initialize both parties - state is auto-persisted
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default"); let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
storage.save(&alice_session, &alice_storable).unwrap(); Arc::clone(&store),
alice_session,
shared_secret,
*bob_keypair.public(),
)
.unwrap();
// Bob receives the message let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
Arc::clone(&store),
bob_session,
shared_secret,
bob_keypair,
)
.unwrap();
// Alice sends - state auto-persisted
let (ct1, header1) = alice.encrypt_message(b"Hello Bob!").unwrap();
let pt1 = bob.decrypt_message(&ct1, header1).unwrap(); let pt1 = bob.decrypt_message(&ct1, header1).unwrap();
assert_eq!(pt1, b"Hello Bob!"); assert_eq!(pt1, b"Hello Bob!");
// Save Bob's state // Verify state was persisted
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default"); assert!(store.session_exists(&alice_session).unwrap());
storage.save(&bob_session, &bob_storable).unwrap(); assert!(store.session_exists(&bob_session).unwrap());
// Simulate restart: load states from storage // Bob replies - state auto-persisted
let alice_loaded = storage.load(&alice_session).unwrap().unwrap(); let (ct2, header2) = bob.encrypt_message(b"Hi Alice!").unwrap();
let bob_loaded = storage.load(&bob_session).unwrap().unwrap(); let pt2 = alice.decrypt_message(&ct2, header2).unwrap();
let mut alice_restored: RatchetState<DefaultDomain> =
alice_loaded.to_ratchet_state().unwrap();
let mut bob_restored: RatchetState<DefaultDomain> =
bob_loaded.to_ratchet_state().unwrap();
// Bob replies
let (ct2, header2) = bob_restored.encrypt_message(b"Hi Alice!");
let pt2 = alice_restored.decrypt_message(&ct2, header2).unwrap();
assert_eq!(pt2, b"Hi Alice!"); assert_eq!(pt2, b"Hi Alice!");
// Alice sends another message // Alice sends another - state auto-persisted
let (ct3, header3) = alice_restored.encrypt_message(b"How are you?"); let (ct3, header3) = alice.encrypt_message(b"How are you?").unwrap();
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap(); let pt3 = bob.decrypt_message(&ct3, header3).unwrap();
assert_eq!(pt3, b"How are you?"); assert_eq!(pt3, b"How are you?");
} }
/// Integration test: verify SQLite storage with encryption works /// Integration test: verify SQLite storage with file persistence
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
#[test] #[test]
fn test_sqlite_integration() { fn test_sqlite_file_persistence() {
let dir = tempfile::tempdir().unwrap(); let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("integration_test.db"); let db_path = dir.path().join("integration_test.db");
let key = [0x42u8; 32]; let key = [0x42u8; 32];
// Setup
let bob_keypair = InstallationKeyPair::generate(); let bob_keypair = InstallationKeyPair::generate();
let bob_pub = *bob_keypair.public();
let shared_secret = [0x42u8; 32]; let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
let alice_session = [0xAA; 32]; let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32]; let bob_session = [0xBB; 32];
// Exchange messages // First session: exchange messages
let (ct1, header1) = alice.encrypt_message(b"Message 1");
bob.decrypt_message(&ct1, header1).unwrap();
let (ct2, header2) = bob.encrypt_message(b"Response 1");
alice.decrypt_message(&ct2, header2).unwrap();
// Save both states
{ {
let storage = SqliteStorage::open(&db_path, key).unwrap(); let store: Arc<dyn RatchetStore> =
Arc::new(SqliteRatchetStore::open(&db_path, key).unwrap());
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default"); let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default"); Arc::clone(&store),
alice_session,
shared_secret,
bob_pub,
)
.unwrap();
storage.save(&alice_session, &alice_storable).unwrap(); let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
storage.save(&bob_session, &bob_storable).unwrap(); Arc::clone(&store),
bob_session,
shared_secret,
bob_keypair,
)
.unwrap();
let (ct1, h1) = alice.encrypt_message(b"Message 1").unwrap();
bob.decrypt_message(&ct1, h1).unwrap();
let (ct2, h2) = bob.encrypt_message(b"Response 1").unwrap();
alice.decrypt_message(&ct2, h2).unwrap();
} }
// Reopen database (simulating restart) // Reopen database (simulating restart)
{ {
let storage = SqliteStorage::open(&db_path, key).unwrap(); let store: Arc<dyn RatchetStore> =
Arc::new(SqliteRatchetStore::open(&db_path, key).unwrap());
let alice_loaded = storage.load(&alice_session).unwrap().unwrap(); let mut alice: PersistentRatchet<DefaultDomain> =
let bob_loaded = storage.load(&bob_session).unwrap().unwrap(); PersistentRatchet::load(Arc::clone(&store), alice_session)
.unwrap()
.unwrap();
let mut alice_restored: RatchetState<DefaultDomain> = let mut bob: PersistentRatchet<DefaultDomain> =
alice_loaded.to_ratchet_state().unwrap(); PersistentRatchet::load(Arc::clone(&store), bob_session)
let mut bob_restored: RatchetState<DefaultDomain> = .unwrap()
bob_loaded.to_ratchet_state().unwrap(); .unwrap();
// Continue conversation // Continue conversation
let (ct3, header3) = alice_restored.encrypt_message(b"Message 2"); let (ct3, h3) = alice.encrypt_message(b"Message 2").unwrap();
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap(); let pt3 = bob.decrypt_message(&ct3, h3).unwrap();
assert_eq!(pt3, b"Message 2"); assert_eq!(pt3, b"Message 2");
} }
} }
/// Integration test: out-of-order messages with skipped keys
#[test]
fn test_out_of_order_messages_persisted() {
let store: Arc<dyn RatchetStore> =
Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap());
let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32];
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
Arc::clone(&store),
alice_session,
shared_secret,
*bob_keypair.public(),
)
.unwrap();
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
Arc::clone(&store),
bob_session,
shared_secret,
bob_keypair,
)
.unwrap();
// Alice sends 4 messages
let mut messages = vec![];
for i in 0..4 {
let (ct, h) = alice
.encrypt_message(format!("Message {}", i).as_bytes())
.unwrap();
messages.push((ct, h));
}
// Bob receives 0, 2, 3 (skipping 1)
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
.unwrap();
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
.unwrap();
bob.decrypt_message(&messages[3].0, messages[3].1.clone())
.unwrap();
// Verify skipped key is persisted
let bob_state = store.load_state(&bob_session).unwrap().unwrap();
assert_eq!(bob_state.skipped_keys.len(), 1);
assert_eq!(bob_state.skipped_keys[0].msg_num, 1);
// Now receive the skipped message
let pt1 = bob
.decrypt_message(&messages[1].0, messages[1].1.clone())
.unwrap();
assert_eq!(pt1, b"Message 1");
// Skipped key should be removed from storage
let bob_state = store.load_state(&bob_session).unwrap().unwrap();
assert!(bob_state.skipped_keys.is_empty());
}
} }

View File

@ -1,216 +0,0 @@
//! In-memory storage implementation for testing.
use std::collections::HashMap;
use std::sync::RwLock;
use crate::error::StorageError;
use crate::traits::{RatchetStorage, SessionId};
use crate::types::StorableRatchetState;
/// In-memory storage backend for testing purposes.
///
/// This implementation stores ratchet states in a `HashMap` wrapped in a `RwLock`
/// for thread-safe access. Data is not persisted across process restarts.
///
/// # Example
///
/// ```
/// use double_ratchets_storage::{MemoryStorage, RatchetStorage};
///
/// let storage = MemoryStorage::new();
/// assert!(storage.list_sessions().unwrap().is_empty());
/// ```
pub struct MemoryStorage {
states: RwLock<HashMap<SessionId, StorableRatchetState>>,
}
impl MemoryStorage {
/// Create a new empty in-memory storage.
pub fn new() -> Self {
Self {
states: RwLock::new(HashMap::new()),
}
}
/// Get the number of stored sessions.
pub fn len(&self) -> usize {
self.states.read().unwrap().len()
}
/// Check if the storage is empty.
pub fn is_empty(&self) -> bool {
self.states.read().unwrap().is_empty()
}
/// Clear all stored sessions.
pub fn clear(&self) {
self.states.write().unwrap().clear();
}
}
impl Default for MemoryStorage {
fn default() -> Self {
Self::new()
}
}
impl RatchetStorage for MemoryStorage {
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
let mut states = self.states.write().unwrap();
states.insert(*session_id, state.clone());
Ok(())
}
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
let states = self.states.read().unwrap();
Ok(states.get(session_id).cloned())
}
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
let mut states = self.states.write().unwrap();
Ok(states.remove(session_id).is_some())
}
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
let states = self.states.read().unwrap();
Ok(states.contains_key(session_id))
}
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
let states = self.states.read().unwrap();
Ok(states.keys().copied().collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
use double_ratchets::hkdf::DefaultDomain;
use double_ratchets::state::RatchetState;
use double_ratchets::InstallationKeyPair;
fn create_test_state() -> StorableRatchetState {
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let state: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
StorableRatchetState::from_ratchet_state(&state, "default")
}
#[test]
fn test_save_and_load() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
let loaded = storage.load(&session_id).unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.root_key, state.root_key);
}
#[test]
fn test_load_nonexistent() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
let loaded = storage.load(&session_id).unwrap();
assert!(loaded.is_none());
}
#[test]
fn test_delete() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
assert!(storage.exists(&session_id).unwrap());
let deleted = storage.delete(&session_id).unwrap();
assert!(deleted);
assert!(!storage.exists(&session_id).unwrap());
// Deleting again should return false
let deleted = storage.delete(&session_id).unwrap();
assert!(!deleted);
}
#[test]
fn test_exists() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
assert!(!storage.exists(&session_id).unwrap());
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
assert!(storage.exists(&session_id).unwrap());
}
#[test]
fn test_list_sessions() {
let storage = MemoryStorage::new();
assert!(storage.list_sessions().unwrap().is_empty());
let state = create_test_state();
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
for id in &session_ids {
storage.save(id, &state).unwrap();
}
let mut listed = storage.list_sessions().unwrap();
listed.sort();
let mut expected = session_ids.clone();
expected.sort();
assert_eq!(listed, expected);
}
#[test]
fn test_overwrite() {
let storage = MemoryStorage::new();
let session_id = [1u8; 32];
// Create first state
let bob_keypair1 = InstallationKeyPair::generate();
let state1: RatchetState<DefaultDomain> =
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
// Create second state with different root
let bob_keypair2 = InstallationKeyPair::generate();
let state2: RatchetState<DefaultDomain> =
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
// Save first, then overwrite with second
storage.save(&session_id, &storable1).unwrap();
storage.save(&session_id, &storable2).unwrap();
// Should have the second state
let loaded = storage.load(&session_id).unwrap().unwrap();
assert_eq!(loaded.root_key, storable2.root_key);
assert_ne!(loaded.root_key, storable1.root_key);
}
#[test]
fn test_clear() {
let storage = MemoryStorage::new();
let state = create_test_state();
for i in 0..5 {
storage.save(&[i; 32], &state).unwrap();
}
assert_eq!(storage.len(), 5);
storage.clear();
assert!(storage.is_empty());
}
}

View File

@ -0,0 +1,554 @@
//! Persistent ratchet wrapper that auto-saves state changes.
use std::sync::Arc;
use double_ratchets::errors::RatchetError;
use double_ratchets::hkdf::HkdfInfo;
use double_ratchets::state::{Header, RatchetState};
use double_ratchets::InstallationKeyPair;
use x25519_dalek::PublicKey;
use crate::error::StorageError;
use crate::traits::{RatchetStore, SessionId};
/// A wrapper around `RatchetState` that automatically persists state changes.
///
/// This wrapper intercepts `encrypt_message` and `decrypt_message` calls,
/// delegates to the underlying `RatchetState`, and then persists the changed
/// fields to storage.
///
/// # Example
///
/// ```no_run
/// use double_ratchets::hkdf::DefaultDomain;
/// use double_ratchets::InstallationKeyPair;
/// use double_ratchets_storage::{PersistentRatchet, SqliteRatchetStore};
/// use std::sync::Arc;
///
/// let store = Arc::new(SqliteRatchetStore::open_in_memory([0u8; 32]).unwrap());
/// let session_id = [1u8; 32];
/// let bob_pub = InstallationKeyPair::generate().public().clone();
/// let shared_secret = [0x42u8; 32];
///
/// let mut ratchet: PersistentRatchet<DefaultDomain> =
/// PersistentRatchet::init_sender(store, session_id, shared_secret, bob_pub).unwrap();
///
/// let (ciphertext, header) = ratchet.encrypt_message(b"Hello!").unwrap();
/// ```
pub struct PersistentRatchet<D: HkdfInfo> {
state: RatchetState<D>,
store: Arc<dyn RatchetStore>,
session_id: SessionId,
}
impl<D: HkdfInfo> PersistentRatchet<D> {
/// Initialize as the sender (first to send a message).
///
/// Creates a new ratchet state and persists it to storage.
pub fn init_sender(
store: Arc<dyn RatchetStore>,
session_id: SessionId,
shared_secret: [u8; 32],
remote_pub: PublicKey,
) -> Result<Self, StorageError> {
let state = RatchetState::<D>::init_sender(shared_secret, remote_pub);
// Persist initial state
store.init_session(
&session_id,
&state.root_key,
state.sending_chain.as_ref(),
state.receiving_chain.as_ref(),
&state.dh_self.secret_bytes(),
state.dh_self.public().as_bytes(),
state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
state.msg_send,
state.msg_recv,
state.prev_chain_len,
)?;
Ok(Self {
state,
store,
session_id,
})
}
/// Initialize as the receiver (first to receive a message).
///
/// Creates a new ratchet state and persists it to storage.
pub fn init_receiver(
store: Arc<dyn RatchetStore>,
session_id: SessionId,
shared_secret: [u8; 32],
dh_self: InstallationKeyPair,
) -> Result<Self, StorageError> {
let state = RatchetState::<D>::init_receiver(shared_secret, dh_self);
// Persist initial state
store.init_session(
&session_id,
&state.root_key,
state.sending_chain.as_ref(),
state.receiving_chain.as_ref(),
&state.dh_self.secret_bytes(),
state.dh_self.public().as_bytes(),
state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
state.msg_send,
state.msg_recv,
state.prev_chain_len,
)?;
Ok(Self {
state,
store,
session_id,
})
}
/// Load an existing session from storage.
pub fn load(
store: Arc<dyn RatchetStore>,
session_id: SessionId,
) -> Result<Option<Self>, StorageError> {
let Some(stored) = store.load_state(&session_id)? else {
return Ok(None);
};
// Reconstruct the keypair
let dh_self = InstallationKeyPair::from_bytes(stored.dh_self_secret, stored.dh_self_public)
.map_err(|e| StorageError::KeyReconstruction(e.to_string()))?;
// Reconstruct skipped keys
let mut skipped_keys = std::collections::HashMap::new();
for entry in stored.skipped_keys {
let pub_key = PublicKey::from(entry.dh_public);
skipped_keys.insert((pub_key, entry.msg_num), entry.message_key);
}
let state = RatchetState::<D>::from_parts(
stored.root_key,
stored.sending_chain,
stored.receiving_chain,
dh_self,
stored.dh_remote.map(PublicKey::from),
stored.msg_send,
stored.msg_recv,
stored.prev_chain_len,
skipped_keys,
);
Ok(Some(Self {
state,
store,
session_id,
}))
}
/// Encrypt a message and persist state changes.
///
/// This may trigger a DH ratchet if the sending direction changed.
/// All state changes are persisted to storage after encryption.
pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<(Vec<u8>, Header), StorageError> {
// Check if we'll do a DH ratchet (no sending chain)
let will_ratchet = self.state.sending_chain.is_none();
// Perform encryption
let (ciphertext, header) = self.state.encrypt_message(plaintext);
// Persist changes
if will_ratchet {
// DH ratchet happened: root_key, sending_chain, dh_self all changed
self.store.store_root_and_chains(
&self.session_id,
&self.state.root_key,
self.state.sending_chain.as_ref(),
self.state.receiving_chain.as_ref(),
)?;
self.store.store_dh_self(
&self.session_id,
&self.state.dh_self.secret_bytes(),
self.state.dh_self.public().as_bytes(),
)?;
} else {
// Only sending chain changed
self.store.store_root_and_chains(
&self.session_id,
&self.state.root_key,
self.state.sending_chain.as_ref(),
self.state.receiving_chain.as_ref(),
)?;
}
// Counters always change
self.store.store_counters(
&self.session_id,
self.state.msg_send,
self.state.msg_recv,
self.state.prev_chain_len,
)?;
Ok((ciphertext, header))
}
/// Decrypt a message and persist state changes.
///
/// Handles DH ratcheting, skipped messages, and replay protection.
/// All state changes are persisted to storage after decryption.
pub fn decrypt_message(
&mut self,
ciphertext_with_nonce: &[u8],
header: Header,
) -> Result<Vec<u8>, PersistentRatchetError> {
// Track skipped keys before decryption
let skipped_before: std::collections::HashSet<_> = self
.state
.skipped_keys
.keys()
.map(|(pk, n)| (*pk.as_bytes(), *n))
.collect();
// Check if we'll do a DH ratchet
let will_ratchet = self.state.dh_remote.as_ref() != Some(&header.dh_pub);
// Perform decryption
let plaintext = self
.state
.decrypt_message(ciphertext_with_nonce, header)
.map_err(PersistentRatchetError::Ratchet)?;
// Track skipped keys after decryption
let skipped_after: std::collections::HashSet<_> = self
.state
.skipped_keys
.keys()
.map(|(pk, n)| (*pk.as_bytes(), *n))
.collect();
// Persist changes
if will_ratchet {
// DH ratchet happened
self.store
.store_root_and_chains(
&self.session_id,
&self.state.root_key,
self.state.sending_chain.as_ref(),
self.state.receiving_chain.as_ref(),
)
.map_err(PersistentRatchetError::Storage)?;
self.store
.store_dh_remote(
&self.session_id,
self.state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
)
.map_err(PersistentRatchetError::Storage)?;
} else {
// Only receiving chain changed
self.store
.store_root_and_chains(
&self.session_id,
&self.state.root_key,
self.state.sending_chain.as_ref(),
self.state.receiving_chain.as_ref(),
)
.map_err(PersistentRatchetError::Storage)?;
}
// Counters
self.store
.store_counters(
&self.session_id,
self.state.msg_send,
self.state.msg_recv,
self.state.prev_chain_len,
)
.map_err(PersistentRatchetError::Storage)?;
// Handle skipped keys changes
// New skipped keys (added during skip_message_keys)
for (pk_bytes, msg_num) in skipped_after.difference(&skipped_before) {
let pk = PublicKey::from(*pk_bytes);
if let Some(key) = self.state.skipped_keys.get(&(pk, *msg_num)) {
self.store
.add_skipped_key(&self.session_id, pk_bytes, *msg_num, key)
.map_err(PersistentRatchetError::Storage)?;
}
}
// Removed skipped keys (used for decryption)
for (pk_bytes, msg_num) in skipped_before.difference(&skipped_after) {
self.store
.remove_skipped_key(&self.session_id, pk_bytes, *msg_num)
.map_err(PersistentRatchetError::Storage)?;
}
Ok(plaintext)
}
/// Get a reference to the underlying state (read-only).
pub fn state(&self) -> &RatchetState<D> {
&self.state
}
/// Get the session ID.
pub fn session_id(&self) -> &SessionId {
&self.session_id
}
/// Delete this session from storage.
pub fn delete(self) -> Result<bool, StorageError> {
self.store.delete_session(&self.session_id)
}
}
/// Error type for persistent ratchet operations.
#[derive(Debug)]
pub enum PersistentRatchetError {
/// Storage operation failed.
Storage(StorageError),
/// Ratchet operation failed.
Ratchet(RatchetError),
}
impl std::fmt::Display for PersistentRatchetError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Storage(e) => write!(f, "storage error: {}", e),
Self::Ratchet(e) => write!(f, "ratchet error: {:?}", e),
}
}
}
impl std::error::Error for PersistentRatchetError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Storage(e) => Some(e),
Self::Ratchet(_) => None,
}
}
}
impl From<StorageError> for PersistentRatchetError {
fn from(e: StorageError) -> Self {
Self::Storage(e)
}
}
impl From<RatchetError> for PersistentRatchetError {
fn from(e: RatchetError) -> Self {
Self::Ratchet(e)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sqlite::SqliteRatchetStore;
use double_ratchets::hkdf::DefaultDomain;
fn test_store() -> Arc<dyn RatchetStore> {
Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap())
}
#[test]
fn test_basic_roundtrip() {
let store = test_store();
let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32];
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
// Initialize both parties
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
Arc::clone(&store),
alice_session,
shared_secret,
*bob_keypair.public(),
)
.unwrap();
let mut bob: PersistentRatchet<DefaultDomain> =
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
.unwrap();
// Alice sends a message
let (ct, header) = alice.encrypt_message(b"Hello Bob!").unwrap();
let pt = bob.decrypt_message(&ct, header).unwrap();
assert_eq!(pt, b"Hello Bob!");
// Verify state was persisted
let alice_loaded = store.load_state(&alice_session).unwrap().unwrap();
assert_eq!(alice_loaded.msg_send, 1);
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
assert_eq!(bob_loaded.msg_recv, 1);
}
#[test]
fn test_load_and_continue() {
let store = test_store();
let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32];
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
// Initialize and exchange one message
{
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
Arc::clone(&store),
alice_session,
shared_secret,
*bob_keypair.public(),
)
.unwrap();
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
Arc::clone(&store),
bob_session,
shared_secret,
bob_keypair,
)
.unwrap();
let (ct, header) = alice.encrypt_message(b"Message 1").unwrap();
bob.decrypt_message(&ct, header).unwrap();
}
// Load from storage and continue
{
let mut alice: PersistentRatchet<DefaultDomain> =
PersistentRatchet::load(Arc::clone(&store), alice_session)
.unwrap()
.unwrap();
let mut bob: PersistentRatchet<DefaultDomain> =
PersistentRatchet::load(Arc::clone(&store), bob_session)
.unwrap()
.unwrap();
// Bob replies
let (ct, header) = bob.encrypt_message(b"Reply from Bob").unwrap();
let pt = alice.decrypt_message(&ct, header).unwrap();
assert_eq!(pt, b"Reply from Bob");
// Alice sends another
let (ct2, header2) = alice.encrypt_message(b"Message 2").unwrap();
let pt2 = bob.decrypt_message(&ct2, header2).unwrap();
assert_eq!(pt2, b"Message 2");
}
}
#[test]
fn test_skipped_keys_persisted() {
let store = test_store();
let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32];
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
Arc::clone(&store),
alice_session,
shared_secret,
*bob_keypair.public(),
)
.unwrap();
let mut bob: PersistentRatchet<DefaultDomain> =
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
.unwrap();
// Alice sends 3 messages
let mut messages = vec![];
for i in 0..3 {
let (ct, header) = alice
.encrypt_message(format!("Message {}", i).as_bytes())
.unwrap();
messages.push((ct, header));
}
// Bob receives them out of order: 0, 2 (skipping 1)
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
.unwrap();
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
.unwrap();
// Check skipped key was persisted
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
assert_eq!(bob_loaded.skipped_keys.len(), 1);
assert_eq!(bob_loaded.skipped_keys[0].msg_num, 1);
// Now receive the skipped message
bob.decrypt_message(&messages[1].0, messages[1].1.clone())
.unwrap();
// Skipped key should be removed
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
assert!(bob_loaded.skipped_keys.is_empty());
}
#[test]
fn test_dh_ratchet_persisted() {
let store = test_store();
let alice_session = [0xAA; 32];
let bob_session = [0xBB; 32];
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
Arc::clone(&store),
alice_session,
shared_secret,
*bob_keypair.public(),
)
.unwrap();
let mut bob: PersistentRatchet<DefaultDomain> =
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
.unwrap();
// Alice sends
let (ct1, h1) = alice.encrypt_message(b"Hello").unwrap();
bob.decrypt_message(&ct1, h1).unwrap();
// Get Bob's initial DH public
let bob_initial_pub = bob.state().dh_self.public().as_bytes().clone();
// Bob replies (triggers DH ratchet)
let (ct2, h2) = bob.encrypt_message(b"Hi").unwrap();
alice.decrypt_message(&ct2, h2).unwrap();
// Bob's DH key should have changed
let bob_new_pub = bob.state().dh_self.public().as_bytes().clone();
assert_ne!(bob_initial_pub, bob_new_pub);
// Verify persisted
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
assert_eq!(bob_loaded.dh_self_public, bob_new_pub);
}
#[test]
fn test_delete_session() {
let store = test_store();
let session_id = [0x11; 32];
let bob_keypair = InstallationKeyPair::generate();
let ratchet: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
Arc::clone(&store),
session_id,
[0x42u8; 32],
*bob_keypair.public(),
)
.unwrap();
assert!(store.session_exists(&session_id).unwrap());
ratchet.delete().unwrap();
assert!(!store.session_exists(&session_id).unwrap());
}
}

View File

@ -2,7 +2,6 @@
use std::path::Path; use std::path::Path;
use std::sync::Mutex; use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
use chacha20poly1305::{ use chacha20poly1305::{
aead::{Aead, KeyInit}, aead::{Aead, KeyInit},
@ -12,85 +11,37 @@ use rand::RngCore;
use rusqlite::{params, Connection, OptionalExtension}; use rusqlite::{params, Connection, OptionalExtension};
use crate::error::StorageError; use crate::error::StorageError;
use crate::traits::{RatchetStorage, SessionId}; use crate::traits::{RatchetStore, SessionId, SkippedKeyEntry, StoredState};
use crate::types::{SkippedKey, StorableRatchetState};
/// Field encryption key type (32 bytes for ChaCha20Poly1305). /// Field encryption key type (32 bytes for ChaCha20Poly1305).
pub type EncryptionKey = [u8; 32]; pub type EncryptionKey = [u8; 32];
/// SQLite storage backend with field-level encryption for secrets. /// SQLite storage with field-level encryption for secrets.
/// ///
/// This implementation stores ratchet states in SQLite with: /// Schema:
/// - Field-level encryption for private keys using ChaCha20Poly1305 /// - `sessions`: Core ratchet state (one row per session)
/// - WAL mode for better concurrent performance /// - `skipped_keys`: Skipped message keys (many per session)
/// - Foreign keys and cascading deletes for data integrity
/// - Atomic transactions to prevent partial writes
/// ///
/// # Security /// Encrypted fields: dh_self_secret, skipped message_key
/// pub struct SqliteRatchetStore {
/// The `dh_self_secret` field is encrypted with the provided encryption key.
/// For additional security, consider using SQLCipher for full database encryption
/// via the `open_encrypted` method (requires `sqlcipher` feature).
///
/// # Example
///
/// ```no_run
/// use double_ratchets_storage::SqliteStorage;
///
/// let key = [0u8; 32]; // Use a proper key derivation function
/// let storage = SqliteStorage::open("ratchets.db", key).unwrap();
/// ```
pub struct SqliteStorage {
conn: Mutex<Connection>, conn: Mutex<Connection>,
encryption_key: EncryptionKey, encryption_key: EncryptionKey,
} }
impl SqliteStorage { impl SqliteRatchetStore {
/// Open or create a SQLite database with field-level encryption. /// Open or create a SQLite database.
///
/// # Arguments
///
/// * `path` - Path to the database file.
/// * `encryption_key` - 32-byte key for field-level encryption.
///
/// # Returns
///
/// * `Ok(SqliteStorage)` on success.
/// * `Err(StorageError)` on failure.
pub fn open<P: AsRef<Path>>(path: P, encryption_key: EncryptionKey) -> Result<Self, StorageError> { pub fn open<P: AsRef<Path>>(path: P, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
let conn = Connection::open(path)?; let conn = Connection::open(path)?;
Self::initialize(conn, encryption_key) Self::initialize(conn, encryption_key)
} }
/// Create an in-memory SQLite database (for testing). /// Create an in-memory database (for testing).
///
/// # Arguments
///
/// * `encryption_key` - 32-byte key for field-level encryption.
///
/// # Returns
///
/// * `Ok(SqliteStorage)` on success.
/// * `Err(StorageError)` on failure.
pub fn open_in_memory(encryption_key: EncryptionKey) -> Result<Self, StorageError> { pub fn open_in_memory(encryption_key: EncryptionKey) -> Result<Self, StorageError> {
let conn = Connection::open_in_memory()?; let conn = Connection::open_in_memory()?;
Self::initialize(conn, encryption_key) Self::initialize(conn, encryption_key)
} }
/// Open or create a SQLCipher-encrypted database. /// Open with SQLCipher full-database encryption.
///
/// This method requires the `sqlcipher` feature to be enabled.
///
/// # Arguments
///
/// * `path` - Path to the database file.
/// * `db_password` - Password for SQLCipher database encryption.
/// * `field_key` - 32-byte key for additional field-level encryption.
///
/// # Returns
///
/// * `Ok(SqliteStorage)` on success.
/// * `Err(StorageError)` on failure.
#[cfg(feature = "sqlcipher")] #[cfg(feature = "sqlcipher")]
pub fn open_encrypted<P: AsRef<Path>>( pub fn open_encrypted<P: AsRef<Path>>(
path: P, path: P,
@ -98,50 +49,39 @@ impl SqliteStorage {
field_key: EncryptionKey, field_key: EncryptionKey,
) -> Result<Self, StorageError> { ) -> Result<Self, StorageError> {
let conn = Connection::open(path)?; let conn = Connection::open(path)?;
// Set SQLCipher key
conn.pragma_update(None, "key", db_password)?; conn.pragma_update(None, "key", db_password)?;
Self::initialize(conn, field_key) Self::initialize(conn, field_key)
} }
fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result<Self, StorageError> { fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
// Enable WAL mode for better performance
conn.pragma_update(None, "journal_mode", "WAL")?; conn.pragma_update(None, "journal_mode", "WAL")?;
// Enable foreign keys
conn.pragma_update(None, "foreign_keys", "ON")?; conn.pragma_update(None, "foreign_keys", "ON")?;
// Create tables
conn.execute_batch( conn.execute_batch(
r#" r#"
CREATE TABLE IF NOT EXISTS ratchet_states ( CREATE TABLE IF NOT EXISTS sessions (
session_id BLOB PRIMARY KEY, session_id BLOB PRIMARY KEY NOT NULL,
root_key BLOB NOT NULL, root_key BLOB NOT NULL,
sending_chain BLOB, sending_chain BLOB,
receiving_chain BLOB, receiving_chain BLOB,
dh_self_secret_encrypted BLOB NOT NULL, dh_secret_enc BLOB NOT NULL,
dh_self_secret_nonce BLOB NOT NULL, dh_secret_nonce BLOB NOT NULL,
dh_self_public BLOB NOT NULL, dh_public BLOB NOT NULL,
dh_remote BLOB, dh_remote BLOB,
msg_send INTEGER NOT NULL, msg_send INTEGER NOT NULL,
msg_recv INTEGER NOT NULL, msg_recv INTEGER NOT NULL,
prev_chain_len INTEGER NOT NULL, prev_chain_len INTEGER NOT NULL
domain_id TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
); );
CREATE TABLE IF NOT EXISTS skipped_keys ( CREATE TABLE IF NOT EXISTS skipped_keys (
id INTEGER PRIMARY KEY, session_id BLOB NOT NULL,
session_id BLOB NOT NULL REFERENCES ratchet_states(session_id) ON DELETE CASCADE, dh_public BLOB NOT NULL,
public_key BLOB NOT NULL,
msg_num INTEGER NOT NULL, msg_num INTEGER NOT NULL,
message_key BLOB NOT NULL, message_key_enc BLOB NOT NULL,
UNIQUE(session_id, public_key, msg_num) message_key_nonce BLOB NOT NULL,
PRIMARY KEY (session_id, dh_public, msg_num),
FOREIGN KEY (session_id) REFERENCES sessions(session_id) ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS idx_skipped_keys_session ON skipped_keys(session_id);
"#, "#,
)?; )?;
@ -151,23 +91,20 @@ impl SqliteStorage {
}) })
} }
/// Encrypt a 32-byte secret using ChaCha20Poly1305. fn encrypt(&self, plaintext: &[u8; 32]) -> Result<(Vec<u8>, [u8; 12]), StorageError> {
fn encrypt_secret(&self, secret: &[u8; 32]) -> Result<(Vec<u8>, [u8; 12]), StorageError> {
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into()); let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
let mut nonce_bytes = [0u8; 12]; let mut nonce_bytes = [0u8; 12];
rand::thread_rng().fill_bytes(&mut nonce_bytes); rand::thread_rng().fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = cipher let ciphertext = cipher
.encrypt(nonce, secret.as_ref()) .encrypt(nonce, plaintext.as_ref())
.map_err(|e| StorageError::Encryption(e.to_string()))?; .map_err(|e| StorageError::Encryption(e.to_string()))?;
Ok((ciphertext, nonce_bytes)) Ok((ciphertext, nonce_bytes))
} }
/// Decrypt a secret using ChaCha20Poly1305. fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> {
fn decrypt_secret(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> {
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into()); let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
let nonce = Nonce::from_slice(nonce); let nonce = Nonce::from_slice(nonce);
@ -177,135 +114,110 @@ impl SqliteStorage {
plaintext plaintext
.try_into() .try_into()
.map_err(|_| StorageError::CorruptedState("decrypted secret has wrong length".to_string())) .map_err(|_| StorageError::CorruptedState("wrong decrypted length".into()))
}
fn current_timestamp() -> i64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64
} }
} }
impl RatchetStorage for SqliteStorage { impl RatchetStore for SqliteRatchetStore {
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> { fn store_root_and_chains(
let (encrypted_secret, nonce) = self.encrypt_secret(&state.dh_self_secret)?; &self,
session_id: &SessionId,
root_key: &[u8; 32],
sending_chain: Option<&[u8; 32]>,
receiving_chain: Option<&[u8; 32]>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap(); let conn = self.conn.lock().unwrap();
let tx = conn.unchecked_transaction()?; conn.execute(
"UPDATE sessions SET root_key = ?, sending_chain = ?, receiving_chain = ? WHERE session_id = ?",
let now = Self::current_timestamp(); params![
root_key.as_slice(),
// Check if session exists to determine created_at sending_chain.map(|c| c.as_slice()),
let exists: bool = tx.query_row( receiving_chain.map(|c| c.as_slice()),
"SELECT 1 FROM ratchet_states WHERE session_id = ?", session_id.as_slice(),
[session_id.as_slice()], ],
|_| Ok(true), )?;
).optional()?.unwrap_or(false);
if exists {
// Update existing session
tx.execute(
r#"
UPDATE ratchet_states SET
root_key = ?,
sending_chain = ?,
receiving_chain = ?,
dh_self_secret_encrypted = ?,
dh_self_secret_nonce = ?,
dh_self_public = ?,
dh_remote = ?,
msg_send = ?,
msg_recv = ?,
prev_chain_len = ?,
domain_id = ?,
updated_at = ?
WHERE session_id = ?
"#,
params![
state.root_key.as_slice(),
state.sending_chain.as_ref().map(|c| c.as_slice()),
state.receiving_chain.as_ref().map(|c| c.as_slice()),
encrypted_secret.as_slice(),
nonce.as_slice(),
state.dh_self_public.as_slice(),
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
state.msg_send,
state.msg_recv,
state.prev_chain_len,
&state.domain_id,
now,
session_id.as_slice(),
],
)?;
// Delete existing skipped keys
tx.execute(
"DELETE FROM skipped_keys WHERE session_id = ?",
[session_id.as_slice()],
)?;
} else {
// Insert new session
tx.execute(
r#"
INSERT INTO ratchet_states (
session_id, root_key, sending_chain, receiving_chain,
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
params![
session_id.as_slice(),
state.root_key.as_slice(),
state.sending_chain.as_ref().map(|c| c.as_slice()),
state.receiving_chain.as_ref().map(|c| c.as_slice()),
encrypted_secret.as_slice(),
nonce.as_slice(),
state.dh_self_public.as_slice(),
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
state.msg_send,
state.msg_recv,
state.prev_chain_len,
&state.domain_id,
now,
now,
],
)?;
}
// Insert skipped keys
for sk in &state.skipped_keys {
tx.execute(
r#"
INSERT INTO skipped_keys (session_id, public_key, msg_num, message_key)
VALUES (?, ?, ?, ?)
"#,
params![
session_id.as_slice(),
sk.public_key.as_slice(),
sk.msg_num,
sk.message_key.as_slice(),
],
)?;
}
tx.commit()?;
Ok(()) Ok(())
} }
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> { fn store_dh_self(
&self,
session_id: &SessionId,
secret: &[u8; 32],
public: &[u8; 32],
) -> Result<(), StorageError> {
let (enc, nonce) = self.encrypt(secret)?;
let conn = self.conn.lock().unwrap();
conn.execute(
"UPDATE sessions SET dh_secret_enc = ?, dh_secret_nonce = ?, dh_public = ? WHERE session_id = ?",
params![enc.as_slice(), nonce.as_slice(), public.as_slice(), session_id.as_slice()],
)?;
Ok(())
}
fn store_dh_remote(
&self,
session_id: &SessionId,
remote: Option<&[u8; 32]>,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"UPDATE sessions SET dh_remote = ? WHERE session_id = ?",
params![remote.map(|r| r.as_slice()), session_id.as_slice()],
)?;
Ok(())
}
fn store_counters(
&self,
session_id: &SessionId,
msg_send: u32,
msg_recv: u32,
prev_chain_len: u32,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"UPDATE sessions SET msg_send = ?, msg_recv = ?, prev_chain_len = ? WHERE session_id = ?",
params![msg_send, msg_recv, prev_chain_len, session_id.as_slice()],
)?;
Ok(())
}
fn add_skipped_key(
&self,
session_id: &SessionId,
dh_public: &[u8; 32],
msg_num: u32,
message_key: &[u8; 32],
) -> Result<(), StorageError> {
let (enc, nonce) = self.encrypt(message_key)?;
let conn = self.conn.lock().unwrap();
conn.execute(
"INSERT OR REPLACE INTO skipped_keys (session_id, dh_public, msg_num, message_key_enc, message_key_nonce) VALUES (?, ?, ?, ?, ?)",
params![session_id.as_slice(), dh_public.as_slice(), msg_num, enc.as_slice(), nonce.as_slice()],
)?;
Ok(())
}
fn remove_skipped_key(
&self,
session_id: &SessionId,
dh_public: &[u8; 32],
msg_num: u32,
) -> Result<(), StorageError> {
let conn = self.conn.lock().unwrap();
conn.execute(
"DELETE FROM skipped_keys WHERE session_id = ? AND dh_public = ? AND msg_num = ?",
params![session_id.as_slice(), dh_public.as_slice(), msg_num],
)?;
Ok(())
}
fn load_state(&self, session_id: &SessionId) -> Result<Option<StoredState>, StorageError> {
let conn = self.conn.lock().unwrap(); let conn = self.conn.lock().unwrap();
let row = conn let row = conn
.query_row( .query_row(
r#" "SELECT root_key, sending_chain, receiving_chain, dh_secret_enc, dh_secret_nonce, dh_public, dh_remote, msg_send, msg_recv, prev_chain_len FROM sessions WHERE session_id = ?",
SELECT root_key, sending_chain, receiving_chain,
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id
FROM ratchet_states WHERE session_id = ?
"#,
[session_id.as_slice()], [session_id.as_slice()],
|row| { |row| {
Ok(( Ok((
@ -319,91 +231,64 @@ impl RatchetStorage for SqliteStorage {
row.get::<_, u32>(7)?, row.get::<_, u32>(7)?,
row.get::<_, u32>(8)?, row.get::<_, u32>(8)?,
row.get::<_, u32>(9)?, row.get::<_, u32>(9)?,
row.get::<_, String>(10)?,
)) ))
}, },
) )
.optional()?; .optional()?;
let Some(( let Some((
root_key_bytes, root_key_bytes, sending_bytes, receiving_bytes,
sending_chain_bytes, secret_enc, secret_nonce, dh_pub_bytes, dh_remote_bytes,
receiving_chain_bytes, msg_send, msg_recv, prev_chain_len,
encrypted_secret, )) = row else {
nonce_bytes,
dh_self_public_bytes,
dh_remote_bytes,
msg_send,
msg_recv,
prev_chain_len,
domain_id,
)) = row
else {
return Ok(None); return Ok(None);
}; };
// Decrypt the secret let nonce: [u8; 12] = secret_nonce.try_into()
let nonce: [u8; 12] = nonce_bytes .map_err(|_| StorageError::CorruptedState("invalid nonce".into()))?;
.try_into() let dh_self_secret = self.decrypt(&secret_enc, &nonce)?;
.map_err(|_| StorageError::CorruptedState("invalid nonce length".to_string()))?;
let dh_self_secret = self.decrypt_secret(&encrypted_secret, &nonce)?;
// Convert byte vectors to arrays let root_key: [u8; 32] = root_key_bytes.try_into()
let root_key: [u8; 32] = root_key_bytes .map_err(|_| StorageError::CorruptedState("invalid root_key".into()))?;
.try_into() let dh_self_public: [u8; 32] = dh_pub_bytes.try_into()
.map_err(|_| StorageError::CorruptedState("invalid root_key length".to_string()))?; .map_err(|_| StorageError::CorruptedState("invalid dh_public".into()))?;
let dh_self_public: [u8; 32] = dh_self_public_bytes
.try_into()
.map_err(|_| StorageError::CorruptedState("invalid dh_self_public length".to_string()))?;
let sending_chain = sending_chain_bytes let sending_chain = sending_bytes
.map(|b| { .map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid sending_chain".into())))
b.try_into()
.map_err(|_| StorageError::CorruptedState("invalid sending_chain length".to_string()))
})
.transpose()?; .transpose()?;
let receiving_chain = receiving_chain_bytes let receiving_chain = receiving_bytes
.map(|b| { .map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid receiving_chain".into())))
b.try_into()
.map_err(|_| StorageError::CorruptedState("invalid receiving_chain length".to_string()))
})
.transpose()?; .transpose()?;
let dh_remote = dh_remote_bytes let dh_remote = dh_remote_bytes
.map(|b| { .map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid dh_remote".into())))
b.try_into()
.map_err(|_| StorageError::CorruptedState("invalid dh_remote length".to_string()))
})
.transpose()?; .transpose()?;
// Load skipped keys // Load skipped keys
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
"SELECT public_key, msg_num, message_key FROM skipped_keys WHERE session_id = ?", "SELECT dh_public, msg_num, message_key_enc, message_key_nonce FROM skipped_keys WHERE session_id = ?",
)?; )?;
let skipped_keys: Vec<SkippedKey> = stmt let mut skipped_keys = Vec::new();
.query_map([session_id.as_slice()], |row| { let rows = stmt.query_map([session_id.as_slice()], |row| {
let pk_bytes: Vec<u8> = row.get(0)?; Ok((
let msg_num: u32 = row.get(1)?; row.get::<_, Vec<u8>>(0)?,
let mk_bytes: Vec<u8> = row.get(2)?; row.get::<_, u32>(1)?,
Ok((pk_bytes, msg_num, mk_bytes)) row.get::<_, Vec<u8>>(2)?,
})? row.get::<_, Vec<u8>>(3)?,
.collect::<Result<Vec<_>, _>>()? ))
.into_iter() })?;
.map(|(pk_bytes, msg_num, mk_bytes)| {
let public_key: [u8; 32] = pk_bytes
.try_into()
.map_err(|_| StorageError::CorruptedState("invalid skipped public_key length".to_string()))?;
let message_key: [u8; 32] = mk_bytes
.try_into()
.map_err(|_| StorageError::CorruptedState("invalid skipped message_key length".to_string()))?;
Ok(SkippedKey {
public_key,
msg_num,
message_key,
})
})
.collect::<Result<Vec<_>, StorageError>>()?;
Ok(Some(StorableRatchetState { for row in rows {
let (pk_bytes, msg_num, key_enc, key_nonce) = row?;
let dh_public: [u8; 32] = pk_bytes.try_into()
.map_err(|_| StorageError::CorruptedState("invalid skipped dh_public".into()))?;
let nonce: [u8; 12] = key_nonce.try_into()
.map_err(|_| StorageError::CorruptedState("invalid skipped nonce".into()))?;
let message_key = self.decrypt(&key_enc, &nonce)?;
skipped_keys.push(SkippedKeyEntry { dh_public, msg_num, message_key });
}
Ok(Some(StoredState {
root_key, root_key,
sending_chain, sending_chain,
receiving_chain, receiving_chain,
@ -414,24 +299,72 @@ impl RatchetStorage for SqliteStorage {
msg_recv, msg_recv,
prev_chain_len, prev_chain_len,
skipped_keys, skipped_keys,
domain_id,
})) }))
} }
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> { fn init_session(
&self,
session_id: &SessionId,
root_key: &[u8; 32],
sending_chain: Option<&[u8; 32]>,
receiving_chain: Option<&[u8; 32]>,
dh_self_secret: &[u8; 32],
dh_self_public: &[u8; 32],
dh_remote: Option<&[u8; 32]>,
msg_send: u32,
msg_recv: u32,
prev_chain_len: u32,
) -> Result<(), StorageError> {
let (secret_enc, secret_nonce) = self.encrypt(dh_self_secret)?;
let conn = self.conn.lock().unwrap();
conn.execute(
r#"
INSERT INTO sessions (session_id, root_key, sending_chain, receiving_chain, dh_secret_enc, dh_secret_nonce, dh_public, dh_remote, msg_send, msg_recv, prev_chain_len)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(session_id) DO UPDATE SET
root_key = excluded.root_key,
sending_chain = excluded.sending_chain,
receiving_chain = excluded.receiving_chain,
dh_secret_enc = excluded.dh_secret_enc,
dh_secret_nonce = excluded.dh_secret_nonce,
dh_public = excluded.dh_public,
dh_remote = excluded.dh_remote,
msg_send = excluded.msg_send,
msg_recv = excluded.msg_recv,
prev_chain_len = excluded.prev_chain_len
"#,
params![
session_id.as_slice(),
root_key.as_slice(),
sending_chain.map(|c| c.as_slice()),
receiving_chain.map(|c| c.as_slice()),
secret_enc.as_slice(),
secret_nonce.as_slice(),
dh_self_public.as_slice(),
dh_remote.map(|r| r.as_slice()),
msg_send,
msg_recv,
prev_chain_len,
],
)?;
Ok(())
}
fn delete_session(&self, session_id: &SessionId) -> Result<bool, StorageError> {
let conn = self.conn.lock().unwrap(); let conn = self.conn.lock().unwrap();
let changes = conn.execute( let changes = conn.execute(
"DELETE FROM ratchet_states WHERE session_id = ?", "DELETE FROM sessions WHERE session_id = ?",
[session_id.as_slice()], [session_id.as_slice()],
)?; )?;
Ok(changes > 0) Ok(changes > 0)
} }
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> { fn session_exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
let conn = self.conn.lock().unwrap(); let conn = self.conn.lock().unwrap();
let exists: bool = conn let exists: bool = conn
.query_row( .query_row(
"SELECT 1 FROM ratchet_states WHERE session_id = ?", "SELECT 1 FROM sessions WHERE session_id = ?",
[session_id.as_slice()], [session_id.as_slice()],
|_| Ok(true), |_| Ok(true),
) )
@ -442,15 +375,11 @@ impl RatchetStorage for SqliteStorage {
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> { fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
let conn = self.conn.lock().unwrap(); let conn = self.conn.lock().unwrap();
let mut stmt = conn.prepare("SELECT session_id FROM ratchet_states")?; let mut stmt = conn.prepare("SELECT session_id FROM sessions")?;
let sessions: Vec<SessionId> = stmt let sessions = stmt
.query_map([], |row| { .query_map([], |row| row.get::<_, Vec<u8>>(0))?
let bytes: Vec<u8> = row.get(0)?; .filter_map(|r| r.ok())
Ok(bytes) .filter_map(|v| v.try_into().ok())
})?
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.filter_map(|bytes| bytes.try_into().ok())
.collect(); .collect();
Ok(sessions) Ok(sessions)
} }
@ -459,288 +388,99 @@ impl RatchetStorage for SqliteStorage {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use double_ratchets::hkdf::DefaultDomain;
use double_ratchets::state::RatchetState;
use double_ratchets::InstallationKeyPair;
fn create_test_storage() -> SqliteStorage { fn test_store() -> SqliteRatchetStore {
let key = [0x42u8; 32]; SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap()
SqliteStorage::open_in_memory(key).unwrap()
}
fn create_test_state() -> StorableRatchetState {
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let state: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
StorableRatchetState::from_ratchet_state(&state, "default")
} }
#[test] #[test]
fn test_save_and_load() { fn test_init_and_load() {
let storage = create_test_storage(); let store = test_store();
let session_id = [1u8; 32];
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
let loaded = storage.load(&session_id).unwrap();
assert!(loaded.is_some());
let loaded = loaded.unwrap();
assert_eq!(loaded.root_key, state.root_key);
assert_eq!(loaded.dh_self_public, state.dh_self_public);
// Secret should be decrypted correctly
assert_eq!(loaded.dh_self_secret, state.dh_self_secret);
}
#[test]
fn test_load_nonexistent() {
let storage = create_test_storage();
let session_id = [1u8; 32]; let session_id = [1u8; 32];
let loaded = storage.load(&session_id).unwrap(); store.init_session(
assert!(loaded.is_none()); &session_id,
&[0xAA; 32], // root
Some(&[0xBB; 32]), // sending
None, // receiving
&[0xCC; 32], // secret
&[0xDD; 32], // public
Some(&[0xEE; 32]), // remote
5, 3, 2,
).unwrap();
let state = store.load_state(&session_id).unwrap().unwrap();
assert_eq!(state.root_key, [0xAA; 32]);
assert_eq!(state.sending_chain, Some([0xBB; 32]));
assert_eq!(state.receiving_chain, None);
assert_eq!(state.dh_self_secret, [0xCC; 32]);
assert_eq!(state.dh_self_public, [0xDD; 32]);
assert_eq!(state.dh_remote, Some([0xEE; 32]));
assert_eq!(state.msg_send, 5);
assert_eq!(state.msg_recv, 3);
assert_eq!(state.prev_chain_len, 2);
} }
#[test] #[test]
fn test_delete() { fn test_update_fields() {
let storage = create_test_storage(); let store = test_store();
let session_id = [1u8; 32];
let state = create_test_state();
storage.save(&session_id, &state).unwrap();
assert!(storage.exists(&session_id).unwrap());
let deleted = storage.delete(&session_id).unwrap();
assert!(deleted);
assert!(!storage.exists(&session_id).unwrap());
// Deleting again should return false
let deleted = storage.delete(&session_id).unwrap();
assert!(!deleted);
}
#[test]
fn test_exists() {
let storage = create_test_storage();
let session_id = [1u8; 32]; let session_id = [1u8; 32];
assert!(!storage.exists(&session_id).unwrap()); store.init_session(
&session_id, &[0xAA; 32], None, None,
&[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
).unwrap();
let state = create_test_state(); // Update root and chains
storage.save(&session_id, &state).unwrap(); store.store_root_and_chains(&session_id, &[0x11; 32], Some(&[0x22; 32]), Some(&[0x33; 32])).unwrap();
assert!(storage.exists(&session_id).unwrap()); let state = store.load_state(&session_id).unwrap().unwrap();
assert_eq!(state.root_key, [0x11; 32]);
assert_eq!(state.sending_chain, Some([0x22; 32]));
assert_eq!(state.receiving_chain, Some([0x33; 32]));
} }
#[test] #[test]
fn test_list_sessions() { fn test_skipped_keys() {
let storage = create_test_storage(); let store = test_store();
assert!(storage.list_sessions().unwrap().is_empty());
let state = create_test_state();
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
for id in &session_ids {
storage.save(id, &state).unwrap();
}
let mut listed = storage.list_sessions().unwrap();
listed.sort();
let mut expected = session_ids.clone();
expected.sort();
assert_eq!(listed, expected);
}
#[test]
fn test_overwrite() {
let storage = create_test_storage();
let session_id = [1u8; 32]; let session_id = [1u8; 32];
// Create first state store.init_session(
let bob_keypair1 = InstallationKeyPair::generate(); &session_id, &[0xAA; 32], None, None,
let state1: RatchetState<DefaultDomain> = &[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public()); ).unwrap();
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
// Create second state with different root // Add skipped keys
let bob_keypair2 = InstallationKeyPair::generate(); store.add_skipped_key(&session_id, &[0x11; 32], 5, &[0xAB; 32]).unwrap();
let state2: RatchetState<DefaultDomain> = store.add_skipped_key(&session_id, &[0x11; 32], 6, &[0xCD; 32]).unwrap();
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
// Save first, then overwrite with second let state = store.load_state(&session_id).unwrap().unwrap();
storage.save(&session_id, &storable1).unwrap(); assert_eq!(state.skipped_keys.len(), 2);
storage.save(&session_id, &storable2).unwrap();
// Should have the second state // Remove one
let loaded = storage.load(&session_id).unwrap().unwrap(); store.remove_skipped_key(&session_id, &[0x11; 32], 5).unwrap();
assert_eq!(loaded.root_key, storable2.root_key);
assert_ne!(loaded.root_key, storable1.root_key); let state = store.load_state(&session_id).unwrap().unwrap();
assert_eq!(state.skipped_keys.len(), 1);
assert_eq!(state.skipped_keys[0].msg_num, 6);
} }
#[test] #[test]
fn test_skipped_keys_storage() { fn test_delete_cascade() {
let storage = create_test_storage(); let store = test_store();
let session_id = [1u8; 32]; let session_id = [1u8; 32];
// Create states and generate skipped keys store.init_session(
let bob_keypair = InstallationKeyPair::generate(); &session_id, &[0xAA; 32], None, None,
let shared_secret = [0x42u8; 32]; &[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
let mut alice: RatchetState<DefaultDomain> = ).unwrap();
RatchetState::init_sender(shared_secret, *bob_keypair.public()); store.add_skipped_key(&session_id, &[0x11; 32], 5, &[0xAB; 32]).unwrap();
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
// Alice sends multiple messages assert!(store.session_exists(&session_id).unwrap());
let mut messages = vec![];
for i in 0..3 {
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
messages.push((ct, header));
}
// Bob receives out of order to create skipped keys store.delete_session(&session_id).unwrap();
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
.unwrap();
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
.unwrap();
assert!(!bob.skipped_keys.is_empty()); assert!(!store.session_exists(&session_id).unwrap());
// Skipped keys should be gone too (cascade)
// Save and reload
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
storage.save(&session_id, &storable).unwrap();
let loaded = storage.load(&session_id).unwrap().unwrap();
assert_eq!(loaded.skipped_keys.len(), storable.skipped_keys.len());
// Restore and verify we can decrypt the skipped message
let mut restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
let pt = restored
.decrypt_message(&messages[1].0, messages[1].1.clone())
.unwrap();
assert_eq!(pt, b"Message 1");
}
#[test]
fn test_encryption_uses_different_nonces() {
let storage = create_test_storage();
let state = create_test_state();
// Save the same state twice with different session IDs
storage.save(&[1u8; 32], &state).unwrap();
storage.save(&[2u8; 32], &state).unwrap();
// Both should load correctly (encryption with different nonces)
let loaded1 = storage.load(&[1u8; 32]).unwrap().unwrap();
let loaded2 = storage.load(&[2u8; 32]).unwrap().unwrap();
assert_eq!(loaded1.dh_self_secret, loaded2.dh_self_secret);
}
#[test]
fn test_cascade_delete_skipped_keys() {
let storage = create_test_storage();
let session_id = [1u8; 32];
// Create a state with skipped keys
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
let mut messages = vec![];
for i in 0..3 {
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
messages.push((ct, header));
}
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
.unwrap();
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
.unwrap();
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
storage.save(&session_id, &storable).unwrap();
// Verify skipped keys exist
{
let conn = storage.conn.lock().unwrap();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
[session_id.as_slice()],
|row| row.get(0),
)
.unwrap();
assert!(count > 0);
}
// Delete session
storage.delete(&session_id).unwrap();
// Verify skipped keys were also deleted (cascade)
{
let conn = storage.conn.lock().unwrap();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
[session_id.as_slice()],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 0);
}
}
#[test]
fn test_file_storage() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("test.db");
let key = [0x42u8; 32];
let state = create_test_state();
let session_id = [1u8; 32];
// Save in one instance
{
let storage = SqliteStorage::open(&db_path, key).unwrap();
storage.save(&session_id, &state).unwrap();
}
// Load in another instance
{
let storage = SqliteStorage::open(&db_path, key).unwrap();
let loaded = storage.load(&session_id).unwrap().unwrap();
assert_eq!(loaded.root_key, state.root_key);
}
}
#[test]
fn test_wrong_key_fails_decryption() {
let dir = tempfile::tempdir().unwrap();
let db_path = dir.path().join("test.db");
let key1 = [0x42u8; 32];
let key2 = [0x43u8; 32];
let state = create_test_state();
let session_id = [1u8; 32];
// Save with key1
{
let storage = SqliteStorage::open(&db_path, key1).unwrap();
storage.save(&session_id, &state).unwrap();
}
// Try to load with key2 - should fail decryption
{
let storage = SqliteStorage::open(&db_path, key2).unwrap();
let result = storage.load(&session_id);
assert!(result.is_err());
}
} }
} }

View File

@ -1,74 +1,112 @@
//! Storage trait definitions. //! Storage trait for field-level ratchet state persistence.
use crate::error::StorageError; use crate::error::StorageError;
use crate::types::StorableRatchetState;
/// A 32-byte session identifier. /// A 32-byte session identifier.
pub type SessionId = [u8; 32]; pub type SessionId = [u8; 32];
/// Abstract storage interface for ratchet states. /// Field-level storage interface for ratchet state.
/// ///
/// Implementations must be thread-safe (`Send + Sync`). /// This trait provides granular storage operations that are called automatically
pub trait RatchetStorage: Send + Sync { /// during ratchet operations. Each method persists only the fields that changed.
/// Save a ratchet state for the given session. pub trait RatchetStore: Send + Sync {
/// /// Store the root key and chain keys after a DH ratchet step.
/// If a state already exists for this session, it will be overwritten. fn store_root_and_chains(
/// &self,
/// # Arguments session_id: &SessionId,
/// root_key: &[u8; 32],
/// * `session_id` - Unique identifier for the session. sending_chain: Option<&[u8; 32]>,
/// * `state` - The ratchet state to store. receiving_chain: Option<&[u8; 32]>,
/// ) -> Result<(), StorageError>;
/// # Returns
///
/// * `Ok(())` on success.
/// * `Err(StorageError)` on failure.
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError>;
/// Load a ratchet state for the given session. /// Store our DH keypair (secret encrypted, public plaintext).
/// fn store_dh_self(
/// # Arguments &self,
/// session_id: &SessionId,
/// * `session_id` - Unique identifier for the session. secret: &[u8; 32],
/// public: &[u8; 32],
/// # Returns ) -> Result<(), StorageError>;
///
/// * `Ok(Some(state))` if found.
/// * `Ok(None)` if not found.
/// * `Err(StorageError)` on failure.
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError>;
/// Delete a ratchet state for the given session. /// Store the remote party's DH public key.
/// fn store_dh_remote(
/// # Arguments &self,
/// session_id: &SessionId,
/// * `session_id` - Unique identifier for the session. remote: Option<&[u8; 32]>,
/// ) -> Result<(), StorageError>;
/// # Returns
///
/// * `Ok(true)` if the session existed and was deleted.
/// * `Ok(false)` if the session did not exist.
/// * `Err(StorageError)` on failure.
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError>;
/// Check if a session exists in storage. /// Store message counters.
/// fn store_counters(
/// # Arguments &self,
/// session_id: &SessionId,
/// * `session_id` - Unique identifier for the session. msg_send: u32,
/// msg_recv: u32,
/// # Returns prev_chain_len: u32,
/// ) -> Result<(), StorageError>;
/// * `Ok(true)` if the session exists.
/// * `Ok(false)` if the session does not exist.
/// * `Err(StorageError)` on failure.
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError>;
/// List all session IDs in storage. /// Add a skipped message key.
/// fn add_skipped_key(
/// # Returns &self,
/// session_id: &SessionId,
/// * `Ok(Vec<SessionId>)` containing all session IDs. dh_public: &[u8; 32],
/// * `Err(StorageError)` on failure. msg_num: u32,
message_key: &[u8; 32],
) -> Result<(), StorageError>;
/// Remove a skipped message key (after use).
fn remove_skipped_key(
&self,
session_id: &SessionId,
dh_public: &[u8; 32],
msg_num: u32,
) -> Result<(), StorageError>;
/// Load all state for a session. Returns None if session doesn't exist.
fn load_state(&self, session_id: &SessionId) -> Result<Option<StoredState>, StorageError>;
/// Initialize a new session with all fields.
fn init_session(
&self,
session_id: &SessionId,
root_key: &[u8; 32],
sending_chain: Option<&[u8; 32]>,
receiving_chain: Option<&[u8; 32]>,
dh_self_secret: &[u8; 32],
dh_self_public: &[u8; 32],
dh_remote: Option<&[u8; 32]>,
msg_send: u32,
msg_recv: u32,
prev_chain_len: u32,
) -> Result<(), StorageError>;
/// Delete a session and all its data.
fn delete_session(&self, session_id: &SessionId) -> Result<bool, StorageError>;
/// Check if a session exists.
fn session_exists(&self, session_id: &SessionId) -> Result<bool, StorageError>;
/// List all session IDs.
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError>; fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError>;
} }
/// Complete state loaded from storage.
#[derive(Debug, Clone)]
pub struct StoredState {
pub root_key: [u8; 32],
pub sending_chain: Option<[u8; 32]>,
pub receiving_chain: Option<[u8; 32]>,
pub dh_self_secret: [u8; 32],
pub dh_self_public: [u8; 32],
pub dh_remote: Option<[u8; 32]>,
pub msg_send: u32,
pub msg_recv: u32,
pub prev_chain_len: u32,
pub skipped_keys: Vec<SkippedKeyEntry>,
}
/// A skipped key entry from storage.
#[derive(Debug, Clone)]
pub struct SkippedKeyEntry {
pub dh_public: [u8; 32],
pub msg_num: u32,
pub message_key: [u8; 32],
}

View File

@ -1,225 +0,0 @@
//! Serializable types for ratchet state storage.
use std::collections::HashMap;
use double_ratchets::state::RatchetState;
use double_ratchets::hkdf::HkdfInfo;
use double_ratchets::InstallationKeyPair;
use serde::{Deserialize, Serialize};
use x25519_dalek::PublicKey;
use crate::error::StorageError;
/// A skipped message key entry for storage.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SkippedKey {
/// The public key associated with this skipped message.
pub public_key: [u8; 32],
/// The message number.
pub msg_num: u32,
/// The 32-byte message key.
pub message_key: [u8; 32],
}
/// Serializable version of `RatchetState`.
///
/// This struct stores all keys as raw byte arrays for easy serialization
/// and database storage. Use `from_ratchet_state()` and `to_ratchet_state()`
/// for conversion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorableRatchetState {
/// The current root key (32 bytes).
pub root_key: [u8; 32],
/// The current sending chain key, if any (32 bytes).
pub sending_chain: Option<[u8; 32]>,
/// The current receiving chain key, if any (32 bytes).
pub receiving_chain: Option<[u8; 32]>,
/// Our DH secret key (32 bytes).
///
/// **Security**: This should be encrypted before storage.
pub dh_self_secret: [u8; 32],
/// Our DH public key (32 bytes).
pub dh_self_public: [u8; 32],
/// Remote party's DH public key, if known (32 bytes).
pub dh_remote: Option<[u8; 32]>,
/// Number of messages sent in the current sending chain.
pub msg_send: u32,
/// Number of messages received in the current receiving chain.
pub msg_recv: u32,
/// Length of the previous sending chain.
pub prev_chain_len: u32,
/// Skipped message keys for out-of-order message handling.
pub skipped_keys: Vec<SkippedKey>,
/// Domain identifier for HKDF info.
pub domain_id: String,
}
impl StorableRatchetState {
/// Convert a `RatchetState` into a `StorableRatchetState`.
///
/// # Type Parameters
///
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
///
/// # Arguments
///
/// * `state` - The ratchet state to convert.
/// * `domain_id` - A string identifier for the domain (used to reconstruct the correct domain type).
pub fn from_ratchet_state<D: HkdfInfo>(state: &RatchetState<D>, domain_id: &str) -> Self {
let skipped_keys: Vec<SkippedKey> = state
.skipped_keys
.iter()
.map(|((pub_key, msg_num), msg_key)| SkippedKey {
public_key: *pub_key.as_bytes(),
msg_num: *msg_num,
message_key: *msg_key,
})
.collect();
StorableRatchetState {
root_key: state.root_key,
sending_chain: state.sending_chain,
receiving_chain: state.receiving_chain,
dh_self_secret: state.dh_self.secret_bytes(),
dh_self_public: *state.dh_self.public().as_bytes(),
dh_remote: state.dh_remote.map(|pk| *pk.as_bytes()),
msg_send: state.msg_send,
msg_recv: state.msg_recv,
prev_chain_len: state.prev_chain_len,
skipped_keys,
domain_id: domain_id.to_string(),
}
}
/// Convert this `StorableRatchetState` back into a `RatchetState`.
///
/// # Type Parameters
///
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
///
/// # Returns
///
/// * `Ok(RatchetState)` on success.
/// * `Err(StorageError)` if key reconstruction fails.
pub fn to_ratchet_state<D: HkdfInfo>(&self) -> Result<RatchetState<D>, StorageError> {
// Reconstruct the keypair
let dh_self = InstallationKeyPair::from_bytes(self.dh_self_secret, self.dh_self_public)
.map_err(|e| StorageError::KeyReconstruction(e.to_string()))?;
// Reconstruct skipped keys HashMap
let skipped_keys: HashMap<(PublicKey, u32), [u8; 32]> = self
.skipped_keys
.iter()
.map(|sk| {
let pub_key = PublicKey::from(sk.public_key);
((pub_key, sk.msg_num), sk.message_key)
})
.collect();
Ok(RatchetState::from_parts(
self.root_key,
self.sending_chain,
self.receiving_chain,
dh_self,
self.dh_remote.map(PublicKey::from),
self.msg_send,
self.msg_recv,
self.prev_chain_len,
skipped_keys,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use double_ratchets::hkdf::DefaultDomain;
#[test]
fn test_roundtrip_sender_state() {
// Create a sender state
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let state: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
// Convert to storable and back
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
// Verify fields match
assert_eq!(state.root_key, restored.root_key);
assert_eq!(state.sending_chain, restored.sending_chain);
assert_eq!(state.receiving_chain, restored.receiving_chain);
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
assert_eq!(state.dh_remote.map(|pk| *pk.as_bytes()), restored.dh_remote.map(|pk| *pk.as_bytes()));
assert_eq!(state.msg_send, restored.msg_send);
assert_eq!(state.msg_recv, restored.msg_recv);
assert_eq!(state.prev_chain_len, restored.prev_chain_len);
}
#[test]
fn test_roundtrip_receiver_state() {
// Create a receiver state
let keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let state: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, keypair);
// Convert to storable and back
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
// Verify fields match
assert_eq!(state.root_key, restored.root_key);
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
assert!(restored.dh_remote.is_none());
}
#[test]
fn test_roundtrip_with_skipped_keys() {
// Create states and exchange messages to generate skipped keys
let bob_keypair = InstallationKeyPair::generate();
let shared_secret = [0x42u8; 32];
let mut alice: RatchetState<DefaultDomain> =
RatchetState::init_sender(shared_secret, *bob_keypair.public());
let mut bob: RatchetState<DefaultDomain> =
RatchetState::init_receiver(shared_secret, bob_keypair);
// Alice sends multiple messages
let mut messages = vec![];
for i in 0..3 {
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
messages.push((ct, header));
}
// Bob receives them out of order to create skipped keys
bob.decrypt_message(&messages[0].0, messages[0].1.clone()).unwrap();
bob.decrypt_message(&messages[2].0, messages[2].1.clone()).unwrap();
// Message 1 key is now in skipped_keys
assert!(!bob.skipped_keys.is_empty());
// Convert to storable and back
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
// Verify skipped keys are preserved
assert_eq!(bob.skipped_keys.len(), restored.skipped_keys.len());
// The restored state should be able to decrypt the skipped message
let mut restored = restored;
let pt = restored.decrypt_message(&messages[1].0, messages[1].1.clone()).unwrap();
assert_eq!(pt, b"Message 1");
}
}