mirror of
https://github.com/logos-messaging/libchat.git
synced 2026-02-10 08:53:08 +00:00
chore: refactor
This commit is contained in:
parent
34a03275cc
commit
9f968fc80d
@ -1,10 +1,13 @@
|
||||
//! Persistent storage for Double Ratchet state.
|
||||
//!
|
||||
//! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState)
|
||||
//! across application restarts. It includes:
|
||||
//! with automatic field-level persistence on each ratchet operation.
|
||||
//!
|
||||
//! - [`MemoryStorage`] - In-memory storage for testing
|
||||
//! - [`SqliteStorage`] - SQLite storage with field-level encryption (requires `sqlite` feature)
|
||||
//! # Main API
|
||||
//!
|
||||
//! - [`PersistentRatchet`] - Wrapper that auto-persists state changes during encrypt/decrypt
|
||||
//! - [`SqliteRatchetStore`] - SQLite backend with field-level encryption
|
||||
//! - [`RatchetStore`] - Trait for implementing custom storage backends
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
@ -13,166 +16,233 @@
|
||||
//!
|
||||
//! # Security
|
||||
//!
|
||||
//! Private keys (`dh_self_secret`) are always encrypted with ChaCha20Poly1305 before storage,
|
||||
//! even when using plain SQLite. For additional security, enable the `sqlcipher` feature
|
||||
//! for full database encryption.
|
||||
//! Private keys (`dh_self_secret`) and message keys are always encrypted with ChaCha20Poly1305
|
||||
//! before storage, even when using plain SQLite. For additional security, enable the `sqlcipher`
|
||||
//! feature for full database encryption.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use std::sync::Arc;
|
||||
//! use double_ratchets::hkdf::DefaultDomain;
|
||||
//! use double_ratchets::state::RatchetState;
|
||||
//! use double_ratchets::InstallationKeyPair;
|
||||
//! use double_ratchets_storage::{
|
||||
//! RatchetStorage, SqliteStorage, StorableRatchetState,
|
||||
//! };
|
||||
//!
|
||||
//! // Create a ratchet state
|
||||
//! let bob_keypair = InstallationKeyPair::generate();
|
||||
//! let shared_secret = [0x42u8; 32];
|
||||
//! let state: RatchetState<DefaultDomain> =
|
||||
//! RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
//! use double_ratchets_storage::{PersistentRatchet, RatchetStore, SqliteRatchetStore};
|
||||
//!
|
||||
//! // Open storage
|
||||
//! let encryption_key = [0u8; 32]; // Use proper key derivation!
|
||||
//! let storage = SqliteStorage::open("ratchets.db", encryption_key).unwrap();
|
||||
//! let store: Arc<dyn RatchetStore> =
|
||||
//! Arc::new(SqliteRatchetStore::open("ratchets.db", encryption_key).unwrap());
|
||||
//!
|
||||
//! // Save state
|
||||
//! // Initialize sender
|
||||
//! let bob_keypair = InstallationKeyPair::generate();
|
||||
//! let shared_secret = [0x42u8; 32];
|
||||
//! let session_id = [1u8; 32];
|
||||
//! let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
||||
//! storage.save(&session_id, &storable).unwrap();
|
||||
//!
|
||||
//! // Load state
|
||||
//! let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
//! let restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
|
||||
//! let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
//! Arc::clone(&store),
|
||||
//! session_id,
|
||||
//! shared_secret,
|
||||
//! *bob_keypair.public(),
|
||||
//! ).unwrap();
|
||||
//!
|
||||
//! // Encrypt - state is automatically persisted
|
||||
//! let (ciphertext, header) = alice.encrypt_message(b"Hello!").unwrap();
|
||||
//!
|
||||
//! // Later: load from storage
|
||||
//! let alice_restored: PersistentRatchet<DefaultDomain> =
|
||||
//! PersistentRatchet::load(store, session_id).unwrap().unwrap();
|
||||
//! ```
|
||||
|
||||
pub mod error;
|
||||
pub mod memory;
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
pub mod persistent;
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
pub mod sqlite;
|
||||
pub mod traits;
|
||||
pub mod types;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use error::StorageError;
|
||||
pub use memory::MemoryStorage;
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
pub use sqlite::{EncryptionKey, SqliteStorage};
|
||||
pub use traits::{RatchetStorage, SessionId};
|
||||
pub use types::{SkippedKey, StorableRatchetState};
|
||||
pub use persistent::{PersistentRatchet, PersistentRatchetError};
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
pub use sqlite::{EncryptionKey, SqliteRatchetStore};
|
||||
pub use traits::{RatchetStore, SessionId, SkippedKeyEntry, StoredState};
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(all(test, any(feature = "sqlite", feature = "sqlcipher")))]
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
use double_ratchets::state::RatchetState;
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Integration test: full encryption/decryption cycle with storage
|
||||
/// Integration test: full conversation with auto-persist
|
||||
#[test]
|
||||
fn test_full_conversation_with_storage_roundtrip() {
|
||||
// Setup Alice and Bob
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
let storage = MemoryStorage::new();
|
||||
fn test_full_conversation_with_auto_persist() {
|
||||
let store: Arc<dyn RatchetStore> =
|
||||
Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap());
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
// Alice sends a message
|
||||
let (ct1, header1) = alice.encrypt_message(b"Hello Bob!");
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
|
||||
// Save Alice's state
|
||||
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default");
|
||||
storage.save(&alice_session, &alice_storable).unwrap();
|
||||
// Initialize both parties - state is auto-persisted
|
||||
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
Arc::clone(&store),
|
||||
alice_session,
|
||||
shared_secret,
|
||||
*bob_keypair.public(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Bob receives the message
|
||||
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
|
||||
Arc::clone(&store),
|
||||
bob_session,
|
||||
shared_secret,
|
||||
bob_keypair,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Alice sends - state auto-persisted
|
||||
let (ct1, header1) = alice.encrypt_message(b"Hello Bob!").unwrap();
|
||||
let pt1 = bob.decrypt_message(&ct1, header1).unwrap();
|
||||
assert_eq!(pt1, b"Hello Bob!");
|
||||
|
||||
// Save Bob's state
|
||||
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
storage.save(&bob_session, &bob_storable).unwrap();
|
||||
// Verify state was persisted
|
||||
assert!(store.session_exists(&alice_session).unwrap());
|
||||
assert!(store.session_exists(&bob_session).unwrap());
|
||||
|
||||
// Simulate restart: load states from storage
|
||||
let alice_loaded = storage.load(&alice_session).unwrap().unwrap();
|
||||
let bob_loaded = storage.load(&bob_session).unwrap().unwrap();
|
||||
|
||||
let mut alice_restored: RatchetState<DefaultDomain> =
|
||||
alice_loaded.to_ratchet_state().unwrap();
|
||||
let mut bob_restored: RatchetState<DefaultDomain> =
|
||||
bob_loaded.to_ratchet_state().unwrap();
|
||||
|
||||
// Bob replies
|
||||
let (ct2, header2) = bob_restored.encrypt_message(b"Hi Alice!");
|
||||
let pt2 = alice_restored.decrypt_message(&ct2, header2).unwrap();
|
||||
// Bob replies - state auto-persisted
|
||||
let (ct2, header2) = bob.encrypt_message(b"Hi Alice!").unwrap();
|
||||
let pt2 = alice.decrypt_message(&ct2, header2).unwrap();
|
||||
assert_eq!(pt2, b"Hi Alice!");
|
||||
|
||||
// Alice sends another message
|
||||
let (ct3, header3) = alice_restored.encrypt_message(b"How are you?");
|
||||
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap();
|
||||
// Alice sends another - state auto-persisted
|
||||
let (ct3, header3) = alice.encrypt_message(b"How are you?").unwrap();
|
||||
let pt3 = bob.decrypt_message(&ct3, header3).unwrap();
|
||||
assert_eq!(pt3, b"How are you?");
|
||||
}
|
||||
|
||||
/// Integration test: verify SQLite storage with encryption works
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
/// Integration test: verify SQLite storage with file persistence
|
||||
#[test]
|
||||
fn test_sqlite_integration() {
|
||||
fn test_sqlite_file_persistence() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db_path = dir.path().join("integration_test.db");
|
||||
let key = [0x42u8; 32];
|
||||
|
||||
// Setup
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let bob_pub = *bob_keypair.public();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
// Exchange messages
|
||||
let (ct1, header1) = alice.encrypt_message(b"Message 1");
|
||||
bob.decrypt_message(&ct1, header1).unwrap();
|
||||
|
||||
let (ct2, header2) = bob.encrypt_message(b"Response 1");
|
||||
alice.decrypt_message(&ct2, header2).unwrap();
|
||||
|
||||
// Save both states
|
||||
// First session: exchange messages
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
||||
let store: Arc<dyn RatchetStore> =
|
||||
Arc::new(SqliteRatchetStore::open(&db_path, key).unwrap());
|
||||
|
||||
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default");
|
||||
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
Arc::clone(&store),
|
||||
alice_session,
|
||||
shared_secret,
|
||||
bob_pub,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
storage.save(&alice_session, &alice_storable).unwrap();
|
||||
storage.save(&bob_session, &bob_storable).unwrap();
|
||||
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
|
||||
Arc::clone(&store),
|
||||
bob_session,
|
||||
shared_secret,
|
||||
bob_keypair,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (ct1, h1) = alice.encrypt_message(b"Message 1").unwrap();
|
||||
bob.decrypt_message(&ct1, h1).unwrap();
|
||||
|
||||
let (ct2, h2) = bob.encrypt_message(b"Response 1").unwrap();
|
||||
alice.decrypt_message(&ct2, h2).unwrap();
|
||||
}
|
||||
|
||||
// Reopen database (simulating restart)
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
||||
let store: Arc<dyn RatchetStore> =
|
||||
Arc::new(SqliteRatchetStore::open(&db_path, key).unwrap());
|
||||
|
||||
let alice_loaded = storage.load(&alice_session).unwrap().unwrap();
|
||||
let bob_loaded = storage.load(&bob_session).unwrap().unwrap();
|
||||
let mut alice: PersistentRatchet<DefaultDomain> =
|
||||
PersistentRatchet::load(Arc::clone(&store), alice_session)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut alice_restored: RatchetState<DefaultDomain> =
|
||||
alice_loaded.to_ratchet_state().unwrap();
|
||||
let mut bob_restored: RatchetState<DefaultDomain> =
|
||||
bob_loaded.to_ratchet_state().unwrap();
|
||||
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||
PersistentRatchet::load(Arc::clone(&store), bob_session)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
// Continue conversation
|
||||
let (ct3, header3) = alice_restored.encrypt_message(b"Message 2");
|
||||
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap();
|
||||
let (ct3, h3) = alice.encrypt_message(b"Message 2").unwrap();
|
||||
let pt3 = bob.decrypt_message(&ct3, h3).unwrap();
|
||||
assert_eq!(pt3, b"Message 2");
|
||||
}
|
||||
}
|
||||
|
||||
/// Integration test: out-of-order messages with skipped keys
|
||||
#[test]
|
||||
fn test_out_of_order_messages_persisted() {
|
||||
let store: Arc<dyn RatchetStore> =
|
||||
Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap());
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
|
||||
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
Arc::clone(&store),
|
||||
alice_session,
|
||||
shared_secret,
|
||||
*bob_keypair.public(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
|
||||
Arc::clone(&store),
|
||||
bob_session,
|
||||
shared_secret,
|
||||
bob_keypair,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Alice sends 4 messages
|
||||
let mut messages = vec![];
|
||||
for i in 0..4 {
|
||||
let (ct, h) = alice
|
||||
.encrypt_message(format!("Message {}", i).as_bytes())
|
||||
.unwrap();
|
||||
messages.push((ct, h));
|
||||
}
|
||||
|
||||
// Bob receives 0, 2, 3 (skipping 1)
|
||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
||||
.unwrap();
|
||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
||||
.unwrap();
|
||||
bob.decrypt_message(&messages[3].0, messages[3].1.clone())
|
||||
.unwrap();
|
||||
|
||||
// Verify skipped key is persisted
|
||||
let bob_state = store.load_state(&bob_session).unwrap().unwrap();
|
||||
assert_eq!(bob_state.skipped_keys.len(), 1);
|
||||
assert_eq!(bob_state.skipped_keys[0].msg_num, 1);
|
||||
|
||||
// Now receive the skipped message
|
||||
let pt1 = bob
|
||||
.decrypt_message(&messages[1].0, messages[1].1.clone())
|
||||
.unwrap();
|
||||
assert_eq!(pt1, b"Message 1");
|
||||
|
||||
// Skipped key should be removed from storage
|
||||
let bob_state = store.load_state(&bob_session).unwrap().unwrap();
|
||||
assert!(bob_state.skipped_keys.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,216 +0,0 @@
|
||||
//! In-memory storage implementation for testing.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use crate::error::StorageError;
|
||||
use crate::traits::{RatchetStorage, SessionId};
|
||||
use crate::types::StorableRatchetState;
|
||||
|
||||
/// In-memory storage backend for testing purposes.
|
||||
///
|
||||
/// This implementation stores ratchet states in a `HashMap` wrapped in a `RwLock`
|
||||
/// for thread-safe access. Data is not persisted across process restarts.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use double_ratchets_storage::{MemoryStorage, RatchetStorage};
|
||||
///
|
||||
/// let storage = MemoryStorage::new();
|
||||
/// assert!(storage.list_sessions().unwrap().is_empty());
|
||||
/// ```
|
||||
pub struct MemoryStorage {
|
||||
states: RwLock<HashMap<SessionId, StorableRatchetState>>,
|
||||
}
|
||||
|
||||
impl MemoryStorage {
|
||||
/// Create a new empty in-memory storage.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
states: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of stored sessions.
|
||||
pub fn len(&self) -> usize {
|
||||
self.states.read().unwrap().len()
|
||||
}
|
||||
|
||||
/// Check if the storage is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.states.read().unwrap().is_empty()
|
||||
}
|
||||
|
||||
/// Clear all stored sessions.
|
||||
pub fn clear(&self) {
|
||||
self.states.write().unwrap().clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RatchetStorage for MemoryStorage {
|
||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
|
||||
let mut states = self.states.write().unwrap();
|
||||
states.insert(*session_id, state.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
|
||||
let states = self.states.read().unwrap();
|
||||
Ok(states.get(session_id).cloned())
|
||||
}
|
||||
|
||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
let mut states = self.states.write().unwrap();
|
||||
Ok(states.remove(session_id).is_some())
|
||||
}
|
||||
|
||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
let states = self.states.read().unwrap();
|
||||
Ok(states.contains_key(session_id))
|
||||
}
|
||||
|
||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
|
||||
let states = self.states.read().unwrap();
|
||||
Ok(states.keys().copied().collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
use double_ratchets::state::RatchetState;
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
|
||||
fn create_test_state() -> StorableRatchetState {
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let state: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
StorableRatchetState::from_ratchet_state(&state, "default")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
let state = create_test_state();
|
||||
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
let loaded = storage.load(&session_id).unwrap();
|
||||
|
||||
assert!(loaded.is_some());
|
||||
let loaded = loaded.unwrap();
|
||||
assert_eq!(loaded.root_key, state.root_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_nonexistent() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
let loaded = storage.load(&session_id).unwrap();
|
||||
assert!(loaded.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
let state = create_test_state();
|
||||
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
assert!(storage.exists(&session_id).unwrap());
|
||||
|
||||
let deleted = storage.delete(&session_id).unwrap();
|
||||
assert!(deleted);
|
||||
assert!(!storage.exists(&session_id).unwrap());
|
||||
|
||||
// Deleting again should return false
|
||||
let deleted = storage.delete(&session_id).unwrap();
|
||||
assert!(!deleted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exists() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
assert!(!storage.exists(&session_id).unwrap());
|
||||
|
||||
let state = create_test_state();
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
|
||||
assert!(storage.exists(&session_id).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_sessions() {
|
||||
let storage = MemoryStorage::new();
|
||||
|
||||
assert!(storage.list_sessions().unwrap().is_empty());
|
||||
|
||||
let state = create_test_state();
|
||||
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
|
||||
|
||||
for id in &session_ids {
|
||||
storage.save(id, &state).unwrap();
|
||||
}
|
||||
|
||||
let mut listed = storage.list_sessions().unwrap();
|
||||
listed.sort();
|
||||
let mut expected = session_ids.clone();
|
||||
expected.sort();
|
||||
|
||||
assert_eq!(listed, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_overwrite() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Create first state
|
||||
let bob_keypair1 = InstallationKeyPair::generate();
|
||||
let state1: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
|
||||
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
|
||||
|
||||
// Create second state with different root
|
||||
let bob_keypair2 = InstallationKeyPair::generate();
|
||||
let state2: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
|
||||
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
|
||||
|
||||
// Save first, then overwrite with second
|
||||
storage.save(&session_id, &storable1).unwrap();
|
||||
storage.save(&session_id, &storable2).unwrap();
|
||||
|
||||
// Should have the second state
|
||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
assert_eq!(loaded.root_key, storable2.root_key);
|
||||
assert_ne!(loaded.root_key, storable1.root_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear() {
|
||||
let storage = MemoryStorage::new();
|
||||
let state = create_test_state();
|
||||
|
||||
for i in 0..5 {
|
||||
storage.save(&[i; 32], &state).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(storage.len(), 5);
|
||||
|
||||
storage.clear();
|
||||
assert!(storage.is_empty());
|
||||
}
|
||||
}
|
||||
554
double-ratchets-storage/src/persistent.rs
Normal file
554
double-ratchets-storage/src/persistent.rs
Normal file
@ -0,0 +1,554 @@
|
||||
//! Persistent ratchet wrapper that auto-saves state changes.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use double_ratchets::errors::RatchetError;
|
||||
use double_ratchets::hkdf::HkdfInfo;
|
||||
use double_ratchets::state::{Header, RatchetState};
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
use x25519_dalek::PublicKey;
|
||||
|
||||
use crate::error::StorageError;
|
||||
use crate::traits::{RatchetStore, SessionId};
|
||||
|
||||
/// A wrapper around `RatchetState` that automatically persists state changes.
|
||||
///
|
||||
/// This wrapper intercepts `encrypt_message` and `decrypt_message` calls,
|
||||
/// delegates to the underlying `RatchetState`, and then persists the changed
|
||||
/// fields to storage.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use double_ratchets::hkdf::DefaultDomain;
|
||||
/// use double_ratchets::InstallationKeyPair;
|
||||
/// use double_ratchets_storage::{PersistentRatchet, SqliteRatchetStore};
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// let store = Arc::new(SqliteRatchetStore::open_in_memory([0u8; 32]).unwrap());
|
||||
/// let session_id = [1u8; 32];
|
||||
/// let bob_pub = InstallationKeyPair::generate().public().clone();
|
||||
/// let shared_secret = [0x42u8; 32];
|
||||
///
|
||||
/// let mut ratchet: PersistentRatchet<DefaultDomain> =
|
||||
/// PersistentRatchet::init_sender(store, session_id, shared_secret, bob_pub).unwrap();
|
||||
///
|
||||
/// let (ciphertext, header) = ratchet.encrypt_message(b"Hello!").unwrap();
|
||||
/// ```
|
||||
pub struct PersistentRatchet<D: HkdfInfo> {
|
||||
state: RatchetState<D>,
|
||||
store: Arc<dyn RatchetStore>,
|
||||
session_id: SessionId,
|
||||
}
|
||||
|
||||
impl<D: HkdfInfo> PersistentRatchet<D> {
|
||||
/// Initialize as the sender (first to send a message).
|
||||
///
|
||||
/// Creates a new ratchet state and persists it to storage.
|
||||
pub fn init_sender(
|
||||
store: Arc<dyn RatchetStore>,
|
||||
session_id: SessionId,
|
||||
shared_secret: [u8; 32],
|
||||
remote_pub: PublicKey,
|
||||
) -> Result<Self, StorageError> {
|
||||
let state = RatchetState::<D>::init_sender(shared_secret, remote_pub);
|
||||
|
||||
// Persist initial state
|
||||
store.init_session(
|
||||
&session_id,
|
||||
&state.root_key,
|
||||
state.sending_chain.as_ref(),
|
||||
state.receiving_chain.as_ref(),
|
||||
&state.dh_self.secret_bytes(),
|
||||
state.dh_self.public().as_bytes(),
|
||||
state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
|
||||
state.msg_send,
|
||||
state.msg_recv,
|
||||
state.prev_chain_len,
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
state,
|
||||
store,
|
||||
session_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialize as the receiver (first to receive a message).
|
||||
///
|
||||
/// Creates a new ratchet state and persists it to storage.
|
||||
pub fn init_receiver(
|
||||
store: Arc<dyn RatchetStore>,
|
||||
session_id: SessionId,
|
||||
shared_secret: [u8; 32],
|
||||
dh_self: InstallationKeyPair,
|
||||
) -> Result<Self, StorageError> {
|
||||
let state = RatchetState::<D>::init_receiver(shared_secret, dh_self);
|
||||
|
||||
// Persist initial state
|
||||
store.init_session(
|
||||
&session_id,
|
||||
&state.root_key,
|
||||
state.sending_chain.as_ref(),
|
||||
state.receiving_chain.as_ref(),
|
||||
&state.dh_self.secret_bytes(),
|
||||
state.dh_self.public().as_bytes(),
|
||||
state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
|
||||
state.msg_send,
|
||||
state.msg_recv,
|
||||
state.prev_chain_len,
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
state,
|
||||
store,
|
||||
session_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Load an existing session from storage.
|
||||
pub fn load(
|
||||
store: Arc<dyn RatchetStore>,
|
||||
session_id: SessionId,
|
||||
) -> Result<Option<Self>, StorageError> {
|
||||
let Some(stored) = store.load_state(&session_id)? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Reconstruct the keypair
|
||||
let dh_self = InstallationKeyPair::from_bytes(stored.dh_self_secret, stored.dh_self_public)
|
||||
.map_err(|e| StorageError::KeyReconstruction(e.to_string()))?;
|
||||
|
||||
// Reconstruct skipped keys
|
||||
let mut skipped_keys = std::collections::HashMap::new();
|
||||
for entry in stored.skipped_keys {
|
||||
let pub_key = PublicKey::from(entry.dh_public);
|
||||
skipped_keys.insert((pub_key, entry.msg_num), entry.message_key);
|
||||
}
|
||||
|
||||
let state = RatchetState::<D>::from_parts(
|
||||
stored.root_key,
|
||||
stored.sending_chain,
|
||||
stored.receiving_chain,
|
||||
dh_self,
|
||||
stored.dh_remote.map(PublicKey::from),
|
||||
stored.msg_send,
|
||||
stored.msg_recv,
|
||||
stored.prev_chain_len,
|
||||
skipped_keys,
|
||||
);
|
||||
|
||||
Ok(Some(Self {
|
||||
state,
|
||||
store,
|
||||
session_id,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Encrypt a message and persist state changes.
|
||||
///
|
||||
/// This may trigger a DH ratchet if the sending direction changed.
|
||||
/// All state changes are persisted to storage after encryption.
|
||||
pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<(Vec<u8>, Header), StorageError> {
|
||||
// Check if we'll do a DH ratchet (no sending chain)
|
||||
let will_ratchet = self.state.sending_chain.is_none();
|
||||
|
||||
// Perform encryption
|
||||
let (ciphertext, header) = self.state.encrypt_message(plaintext);
|
||||
|
||||
// Persist changes
|
||||
if will_ratchet {
|
||||
// DH ratchet happened: root_key, sending_chain, dh_self all changed
|
||||
self.store.store_root_and_chains(
|
||||
&self.session_id,
|
||||
&self.state.root_key,
|
||||
self.state.sending_chain.as_ref(),
|
||||
self.state.receiving_chain.as_ref(),
|
||||
)?;
|
||||
self.store.store_dh_self(
|
||||
&self.session_id,
|
||||
&self.state.dh_self.secret_bytes(),
|
||||
self.state.dh_self.public().as_bytes(),
|
||||
)?;
|
||||
} else {
|
||||
// Only sending chain changed
|
||||
self.store.store_root_and_chains(
|
||||
&self.session_id,
|
||||
&self.state.root_key,
|
||||
self.state.sending_chain.as_ref(),
|
||||
self.state.receiving_chain.as_ref(),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Counters always change
|
||||
self.store.store_counters(
|
||||
&self.session_id,
|
||||
self.state.msg_send,
|
||||
self.state.msg_recv,
|
||||
self.state.prev_chain_len,
|
||||
)?;
|
||||
|
||||
Ok((ciphertext, header))
|
||||
}
|
||||
|
||||
/// Decrypt a message and persist state changes.
|
||||
///
|
||||
/// Handles DH ratcheting, skipped messages, and replay protection.
|
||||
/// All state changes are persisted to storage after decryption.
|
||||
pub fn decrypt_message(
|
||||
&mut self,
|
||||
ciphertext_with_nonce: &[u8],
|
||||
header: Header,
|
||||
) -> Result<Vec<u8>, PersistentRatchetError> {
|
||||
// Track skipped keys before decryption
|
||||
let skipped_before: std::collections::HashSet<_> = self
|
||||
.state
|
||||
.skipped_keys
|
||||
.keys()
|
||||
.map(|(pk, n)| (*pk.as_bytes(), *n))
|
||||
.collect();
|
||||
|
||||
// Check if we'll do a DH ratchet
|
||||
let will_ratchet = self.state.dh_remote.as_ref() != Some(&header.dh_pub);
|
||||
|
||||
// Perform decryption
|
||||
let plaintext = self
|
||||
.state
|
||||
.decrypt_message(ciphertext_with_nonce, header)
|
||||
.map_err(PersistentRatchetError::Ratchet)?;
|
||||
|
||||
// Track skipped keys after decryption
|
||||
let skipped_after: std::collections::HashSet<_> = self
|
||||
.state
|
||||
.skipped_keys
|
||||
.keys()
|
||||
.map(|(pk, n)| (*pk.as_bytes(), *n))
|
||||
.collect();
|
||||
|
||||
// Persist changes
|
||||
if will_ratchet {
|
||||
// DH ratchet happened
|
||||
self.store
|
||||
.store_root_and_chains(
|
||||
&self.session_id,
|
||||
&self.state.root_key,
|
||||
self.state.sending_chain.as_ref(),
|
||||
self.state.receiving_chain.as_ref(),
|
||||
)
|
||||
.map_err(PersistentRatchetError::Storage)?;
|
||||
self.store
|
||||
.store_dh_remote(
|
||||
&self.session_id,
|
||||
self.state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
|
||||
)
|
||||
.map_err(PersistentRatchetError::Storage)?;
|
||||
} else {
|
||||
// Only receiving chain changed
|
||||
self.store
|
||||
.store_root_and_chains(
|
||||
&self.session_id,
|
||||
&self.state.root_key,
|
||||
self.state.sending_chain.as_ref(),
|
||||
self.state.receiving_chain.as_ref(),
|
||||
)
|
||||
.map_err(PersistentRatchetError::Storage)?;
|
||||
}
|
||||
|
||||
// Counters
|
||||
self.store
|
||||
.store_counters(
|
||||
&self.session_id,
|
||||
self.state.msg_send,
|
||||
self.state.msg_recv,
|
||||
self.state.prev_chain_len,
|
||||
)
|
||||
.map_err(PersistentRatchetError::Storage)?;
|
||||
|
||||
// Handle skipped keys changes
|
||||
// New skipped keys (added during skip_message_keys)
|
||||
for (pk_bytes, msg_num) in skipped_after.difference(&skipped_before) {
|
||||
let pk = PublicKey::from(*pk_bytes);
|
||||
if let Some(key) = self.state.skipped_keys.get(&(pk, *msg_num)) {
|
||||
self.store
|
||||
.add_skipped_key(&self.session_id, pk_bytes, *msg_num, key)
|
||||
.map_err(PersistentRatchetError::Storage)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Removed skipped keys (used for decryption)
|
||||
for (pk_bytes, msg_num) in skipped_before.difference(&skipped_after) {
|
||||
self.store
|
||||
.remove_skipped_key(&self.session_id, pk_bytes, *msg_num)
|
||||
.map_err(PersistentRatchetError::Storage)?;
|
||||
}
|
||||
|
||||
Ok(plaintext)
|
||||
}
|
||||
|
||||
/// Get a reference to the underlying state (read-only).
|
||||
pub fn state(&self) -> &RatchetState<D> {
|
||||
&self.state
|
||||
}
|
||||
|
||||
/// Get the session ID.
|
||||
pub fn session_id(&self) -> &SessionId {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
/// Delete this session from storage.
|
||||
pub fn delete(self) -> Result<bool, StorageError> {
|
||||
self.store.delete_session(&self.session_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Error type for persistent ratchet operations.
|
||||
#[derive(Debug)]
|
||||
pub enum PersistentRatchetError {
|
||||
/// Storage operation failed.
|
||||
Storage(StorageError),
|
||||
/// Ratchet operation failed.
|
||||
Ratchet(RatchetError),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PersistentRatchetError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Storage(e) => write!(f, "storage error: {}", e),
|
||||
Self::Ratchet(e) => write!(f, "ratchet error: {:?}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PersistentRatchetError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Storage(e) => Some(e),
|
||||
Self::Ratchet(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<StorageError> for PersistentRatchetError {
|
||||
fn from(e: StorageError) -> Self {
|
||||
Self::Storage(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RatchetError> for PersistentRatchetError {
|
||||
fn from(e: RatchetError) -> Self {
|
||||
Self::Ratchet(e)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::sqlite::SqliteRatchetStore;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
|
||||
fn test_store() -> Arc<dyn RatchetStore> {
|
||||
Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_roundtrip() {
|
||||
let store = test_store();
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
|
||||
// Initialize both parties
|
||||
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
Arc::clone(&store),
|
||||
alice_session,
|
||||
shared_secret,
|
||||
*bob_keypair.public(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
|
||||
.unwrap();
|
||||
|
||||
// Alice sends a message
|
||||
let (ct, header) = alice.encrypt_message(b"Hello Bob!").unwrap();
|
||||
let pt = bob.decrypt_message(&ct, header).unwrap();
|
||||
assert_eq!(pt, b"Hello Bob!");
|
||||
|
||||
// Verify state was persisted
|
||||
let alice_loaded = store.load_state(&alice_session).unwrap().unwrap();
|
||||
assert_eq!(alice_loaded.msg_send, 1);
|
||||
|
||||
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
|
||||
assert_eq!(bob_loaded.msg_recv, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_and_continue() {
|
||||
let store = test_store();
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
|
||||
// Initialize and exchange one message
|
||||
{
|
||||
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
Arc::clone(&store),
|
||||
alice_session,
|
||||
shared_secret,
|
||||
*bob_keypair.public(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
|
||||
Arc::clone(&store),
|
||||
bob_session,
|
||||
shared_secret,
|
||||
bob_keypair,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (ct, header) = alice.encrypt_message(b"Message 1").unwrap();
|
||||
bob.decrypt_message(&ct, header).unwrap();
|
||||
}
|
||||
|
||||
// Load from storage and continue
|
||||
{
|
||||
let mut alice: PersistentRatchet<DefaultDomain> =
|
||||
PersistentRatchet::load(Arc::clone(&store), alice_session)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||
PersistentRatchet::load(Arc::clone(&store), bob_session)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
// Bob replies
|
||||
let (ct, header) = bob.encrypt_message(b"Reply from Bob").unwrap();
|
||||
let pt = alice.decrypt_message(&ct, header).unwrap();
|
||||
assert_eq!(pt, b"Reply from Bob");
|
||||
|
||||
// Alice sends another
|
||||
let (ct2, header2) = alice.encrypt_message(b"Message 2").unwrap();
|
||||
let pt2 = bob.decrypt_message(&ct2, header2).unwrap();
|
||||
assert_eq!(pt2, b"Message 2");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skipped_keys_persisted() {
|
||||
let store = test_store();
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
|
||||
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
Arc::clone(&store),
|
||||
alice_session,
|
||||
shared_secret,
|
||||
*bob_keypair.public(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
|
||||
.unwrap();
|
||||
|
||||
// Alice sends 3 messages
|
||||
let mut messages = vec![];
|
||||
for i in 0..3 {
|
||||
let (ct, header) = alice
|
||||
.encrypt_message(format!("Message {}", i).as_bytes())
|
||||
.unwrap();
|
||||
messages.push((ct, header));
|
||||
}
|
||||
|
||||
// Bob receives them out of order: 0, 2 (skipping 1)
|
||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
||||
.unwrap();
|
||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
||||
.unwrap();
|
||||
|
||||
// Check skipped key was persisted
|
||||
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
|
||||
assert_eq!(bob_loaded.skipped_keys.len(), 1);
|
||||
assert_eq!(bob_loaded.skipped_keys[0].msg_num, 1);
|
||||
|
||||
// Now receive the skipped message
|
||||
bob.decrypt_message(&messages[1].0, messages[1].1.clone())
|
||||
.unwrap();
|
||||
|
||||
// Skipped key should be removed
|
||||
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
|
||||
assert!(bob_loaded.skipped_keys.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dh_ratchet_persisted() {
|
||||
let store = test_store();
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
|
||||
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
Arc::clone(&store),
|
||||
alice_session,
|
||||
shared_secret,
|
||||
*bob_keypair.public(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
|
||||
.unwrap();
|
||||
|
||||
// Alice sends
|
||||
let (ct1, h1) = alice.encrypt_message(b"Hello").unwrap();
|
||||
bob.decrypt_message(&ct1, h1).unwrap();
|
||||
|
||||
// Get Bob's initial DH public
|
||||
let bob_initial_pub = bob.state().dh_self.public().as_bytes().clone();
|
||||
|
||||
// Bob replies (triggers DH ratchet)
|
||||
let (ct2, h2) = bob.encrypt_message(b"Hi").unwrap();
|
||||
alice.decrypt_message(&ct2, h2).unwrap();
|
||||
|
||||
// Bob's DH key should have changed
|
||||
let bob_new_pub = bob.state().dh_self.public().as_bytes().clone();
|
||||
assert_ne!(bob_initial_pub, bob_new_pub);
|
||||
|
||||
// Verify persisted
|
||||
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
|
||||
assert_eq!(bob_loaded.dh_self_public, bob_new_pub);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete_session() {
|
||||
let store = test_store();
|
||||
let session_id = [0x11; 32];
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
|
||||
let ratchet: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||
Arc::clone(&store),
|
||||
session_id,
|
||||
[0x42u8; 32],
|
||||
*bob_keypair.public(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(store.session_exists(&session_id).unwrap());
|
||||
|
||||
ratchet.delete().unwrap();
|
||||
|
||||
assert!(!store.session_exists(&session_id).unwrap());
|
||||
}
|
||||
}
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use chacha20poly1305::{
|
||||
aead::{Aead, KeyInit},
|
||||
@ -12,85 +11,37 @@ use rand::RngCore;
|
||||
use rusqlite::{params, Connection, OptionalExtension};
|
||||
|
||||
use crate::error::StorageError;
|
||||
use crate::traits::{RatchetStorage, SessionId};
|
||||
use crate::types::{SkippedKey, StorableRatchetState};
|
||||
use crate::traits::{RatchetStore, SessionId, SkippedKeyEntry, StoredState};
|
||||
|
||||
/// Field encryption key type (32 bytes for ChaCha20Poly1305).
|
||||
pub type EncryptionKey = [u8; 32];
|
||||
|
||||
/// SQLite storage backend with field-level encryption for secrets.
|
||||
/// SQLite storage with field-level encryption for secrets.
|
||||
///
|
||||
/// This implementation stores ratchet states in SQLite with:
|
||||
/// - Field-level encryption for private keys using ChaCha20Poly1305
|
||||
/// - WAL mode for better concurrent performance
|
||||
/// - Foreign keys and cascading deletes for data integrity
|
||||
/// - Atomic transactions to prevent partial writes
|
||||
/// Schema:
|
||||
/// - `sessions`: Core ratchet state (one row per session)
|
||||
/// - `skipped_keys`: Skipped message keys (many per session)
|
||||
///
|
||||
/// # Security
|
||||
///
|
||||
/// The `dh_self_secret` field is encrypted with the provided encryption key.
|
||||
/// For additional security, consider using SQLCipher for full database encryption
|
||||
/// via the `open_encrypted` method (requires `sqlcipher` feature).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use double_ratchets_storage::SqliteStorage;
|
||||
///
|
||||
/// let key = [0u8; 32]; // Use a proper key derivation function
|
||||
/// let storage = SqliteStorage::open("ratchets.db", key).unwrap();
|
||||
/// ```
|
||||
pub struct SqliteStorage {
|
||||
/// Encrypted fields: dh_self_secret, skipped message_key
|
||||
pub struct SqliteRatchetStore {
|
||||
conn: Mutex<Connection>,
|
||||
encryption_key: EncryptionKey,
|
||||
}
|
||||
|
||||
impl SqliteStorage {
|
||||
/// Open or create a SQLite database with field-level encryption.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the database file.
|
||||
/// * `encryption_key` - 32-byte key for field-level encryption.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(SqliteStorage)` on success.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
impl SqliteRatchetStore {
|
||||
/// Open or create a SQLite database.
|
||||
pub fn open<P: AsRef<Path>>(path: P, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||
let conn = Connection::open(path)?;
|
||||
Self::initialize(conn, encryption_key)
|
||||
}
|
||||
|
||||
/// Create an in-memory SQLite database (for testing).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `encryption_key` - 32-byte key for field-level encryption.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(SqliteStorage)` on success.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
/// Create an in-memory database (for testing).
|
||||
pub fn open_in_memory(encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||
let conn = Connection::open_in_memory()?;
|
||||
Self::initialize(conn, encryption_key)
|
||||
}
|
||||
|
||||
/// Open or create a SQLCipher-encrypted database.
|
||||
///
|
||||
/// This method requires the `sqlcipher` feature to be enabled.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the database file.
|
||||
/// * `db_password` - Password for SQLCipher database encryption.
|
||||
/// * `field_key` - 32-byte key for additional field-level encryption.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(SqliteStorage)` on success.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
/// Open with SQLCipher full-database encryption.
|
||||
#[cfg(feature = "sqlcipher")]
|
||||
pub fn open_encrypted<P: AsRef<Path>>(
|
||||
path: P,
|
||||
@ -98,50 +49,39 @@ impl SqliteStorage {
|
||||
field_key: EncryptionKey,
|
||||
) -> Result<Self, StorageError> {
|
||||
let conn = Connection::open(path)?;
|
||||
|
||||
// Set SQLCipher key
|
||||
conn.pragma_update(None, "key", db_password)?;
|
||||
|
||||
Self::initialize(conn, field_key)
|
||||
}
|
||||
|
||||
fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||
// Enable WAL mode for better performance
|
||||
conn.pragma_update(None, "journal_mode", "WAL")?;
|
||||
|
||||
// Enable foreign keys
|
||||
conn.pragma_update(None, "foreign_keys", "ON")?;
|
||||
|
||||
// Create tables
|
||||
conn.execute_batch(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS ratchet_states (
|
||||
session_id BLOB PRIMARY KEY,
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id BLOB PRIMARY KEY NOT NULL,
|
||||
root_key BLOB NOT NULL,
|
||||
sending_chain BLOB,
|
||||
receiving_chain BLOB,
|
||||
dh_self_secret_encrypted BLOB NOT NULL,
|
||||
dh_self_secret_nonce BLOB NOT NULL,
|
||||
dh_self_public BLOB NOT NULL,
|
||||
dh_secret_enc BLOB NOT NULL,
|
||||
dh_secret_nonce BLOB NOT NULL,
|
||||
dh_public BLOB NOT NULL,
|
||||
dh_remote BLOB,
|
||||
msg_send INTEGER NOT NULL,
|
||||
msg_recv INTEGER NOT NULL,
|
||||
prev_chain_len INTEGER NOT NULL,
|
||||
domain_id TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
prev_chain_len INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skipped_keys (
|
||||
id INTEGER PRIMARY KEY,
|
||||
session_id BLOB NOT NULL REFERENCES ratchet_states(session_id) ON DELETE CASCADE,
|
||||
public_key BLOB NOT NULL,
|
||||
session_id BLOB NOT NULL,
|
||||
dh_public BLOB NOT NULL,
|
||||
msg_num INTEGER NOT NULL,
|
||||
message_key BLOB NOT NULL,
|
||||
UNIQUE(session_id, public_key, msg_num)
|
||||
message_key_enc BLOB NOT NULL,
|
||||
message_key_nonce BLOB NOT NULL,
|
||||
PRIMARY KEY (session_id, dh_public, msg_num),
|
||||
FOREIGN KEY (session_id) REFERENCES sessions(session_id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_skipped_keys_session ON skipped_keys(session_id);
|
||||
"#,
|
||||
)?;
|
||||
|
||||
@ -151,23 +91,20 @@ impl SqliteStorage {
|
||||
})
|
||||
}
|
||||
|
||||
/// Encrypt a 32-byte secret using ChaCha20Poly1305.
|
||||
fn encrypt_secret(&self, secret: &[u8; 32]) -> Result<(Vec<u8>, [u8; 12]), StorageError> {
|
||||
fn encrypt(&self, plaintext: &[u8; 32]) -> Result<(Vec<u8>, [u8; 12]), StorageError> {
|
||||
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
|
||||
|
||||
let mut nonce_bytes = [0u8; 12];
|
||||
rand::thread_rng().fill_bytes(&mut nonce_bytes);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, secret.as_ref())
|
||||
.encrypt(nonce, plaintext.as_ref())
|
||||
.map_err(|e| StorageError::Encryption(e.to_string()))?;
|
||||
|
||||
Ok((ciphertext, nonce_bytes))
|
||||
}
|
||||
|
||||
/// Decrypt a secret using ChaCha20Poly1305.
|
||||
fn decrypt_secret(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> {
|
||||
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> {
|
||||
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
|
||||
let nonce = Nonce::from_slice(nonce);
|
||||
|
||||
@ -177,135 +114,110 @@ impl SqliteStorage {
|
||||
|
||||
plaintext
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("decrypted secret has wrong length".to_string()))
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64
|
||||
.map_err(|_| StorageError::CorruptedState("wrong decrypted length".into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl RatchetStorage for SqliteStorage {
|
||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
|
||||
let (encrypted_secret, nonce) = self.encrypt_secret(&state.dh_self_secret)?;
|
||||
|
||||
impl RatchetStore for SqliteRatchetStore {
|
||||
fn store_root_and_chains(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
root_key: &[u8; 32],
|
||||
sending_chain: Option<&[u8; 32]>,
|
||||
receiving_chain: Option<&[u8; 32]>,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let tx = conn.unchecked_transaction()?;
|
||||
|
||||
let now = Self::current_timestamp();
|
||||
|
||||
// Check if session exists to determine created_at
|
||||
let exists: bool = tx.query_row(
|
||||
"SELECT 1 FROM ratchet_states WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|_| Ok(true),
|
||||
).optional()?.unwrap_or(false);
|
||||
|
||||
if exists {
|
||||
// Update existing session
|
||||
tx.execute(
|
||||
r#"
|
||||
UPDATE ratchet_states SET
|
||||
root_key = ?,
|
||||
sending_chain = ?,
|
||||
receiving_chain = ?,
|
||||
dh_self_secret_encrypted = ?,
|
||||
dh_self_secret_nonce = ?,
|
||||
dh_self_public = ?,
|
||||
dh_remote = ?,
|
||||
msg_send = ?,
|
||||
msg_recv = ?,
|
||||
prev_chain_len = ?,
|
||||
domain_id = ?,
|
||||
updated_at = ?
|
||||
WHERE session_id = ?
|
||||
"#,
|
||||
params![
|
||||
state.root_key.as_slice(),
|
||||
state.sending_chain.as_ref().map(|c| c.as_slice()),
|
||||
state.receiving_chain.as_ref().map(|c| c.as_slice()),
|
||||
encrypted_secret.as_slice(),
|
||||
nonce.as_slice(),
|
||||
state.dh_self_public.as_slice(),
|
||||
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
|
||||
state.msg_send,
|
||||
state.msg_recv,
|
||||
state.prev_chain_len,
|
||||
&state.domain_id,
|
||||
now,
|
||||
session_id.as_slice(),
|
||||
],
|
||||
)?;
|
||||
|
||||
// Delete existing skipped keys
|
||||
tx.execute(
|
||||
"DELETE FROM skipped_keys WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
)?;
|
||||
} else {
|
||||
// Insert new session
|
||||
tx.execute(
|
||||
r#"
|
||||
INSERT INTO ratchet_states (
|
||||
session_id, root_key, sending_chain, receiving_chain,
|
||||
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
|
||||
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
params![
|
||||
session_id.as_slice(),
|
||||
state.root_key.as_slice(),
|
||||
state.sending_chain.as_ref().map(|c| c.as_slice()),
|
||||
state.receiving_chain.as_ref().map(|c| c.as_slice()),
|
||||
encrypted_secret.as_slice(),
|
||||
nonce.as_slice(),
|
||||
state.dh_self_public.as_slice(),
|
||||
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
|
||||
state.msg_send,
|
||||
state.msg_recv,
|
||||
state.prev_chain_len,
|
||||
&state.domain_id,
|
||||
now,
|
||||
now,
|
||||
],
|
||||
)?;
|
||||
}
|
||||
|
||||
// Insert skipped keys
|
||||
for sk in &state.skipped_keys {
|
||||
tx.execute(
|
||||
r#"
|
||||
INSERT INTO skipped_keys (session_id, public_key, msg_num, message_key)
|
||||
VALUES (?, ?, ?, ?)
|
||||
"#,
|
||||
params![
|
||||
session_id.as_slice(),
|
||||
sk.public_key.as_slice(),
|
||||
sk.msg_num,
|
||||
sk.message_key.as_slice(),
|
||||
],
|
||||
)?;
|
||||
}
|
||||
|
||||
tx.commit()?;
|
||||
conn.execute(
|
||||
"UPDATE sessions SET root_key = ?, sending_chain = ?, receiving_chain = ? WHERE session_id = ?",
|
||||
params![
|
||||
root_key.as_slice(),
|
||||
sending_chain.map(|c| c.as_slice()),
|
||||
receiving_chain.map(|c| c.as_slice()),
|
||||
session_id.as_slice(),
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
|
||||
fn store_dh_self(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
secret: &[u8; 32],
|
||||
public: &[u8; 32],
|
||||
) -> Result<(), StorageError> {
|
||||
let (enc, nonce) = self.encrypt(secret)?;
|
||||
let conn = self.conn.lock().unwrap();
|
||||
conn.execute(
|
||||
"UPDATE sessions SET dh_secret_enc = ?, dh_secret_nonce = ?, dh_public = ? WHERE session_id = ?",
|
||||
params![enc.as_slice(), nonce.as_slice(), public.as_slice(), session_id.as_slice()],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn store_dh_remote(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
remote: Option<&[u8; 32]>,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
conn.execute(
|
||||
"UPDATE sessions SET dh_remote = ? WHERE session_id = ?",
|
||||
params![remote.map(|r| r.as_slice()), session_id.as_slice()],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn store_counters(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
msg_send: u32,
|
||||
msg_recv: u32,
|
||||
prev_chain_len: u32,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
conn.execute(
|
||||
"UPDATE sessions SET msg_send = ?, msg_recv = ?, prev_chain_len = ? WHERE session_id = ?",
|
||||
params![msg_send, msg_recv, prev_chain_len, session_id.as_slice()],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add_skipped_key(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
dh_public: &[u8; 32],
|
||||
msg_num: u32,
|
||||
message_key: &[u8; 32],
|
||||
) -> Result<(), StorageError> {
|
||||
let (enc, nonce) = self.encrypt(message_key)?;
|
||||
let conn = self.conn.lock().unwrap();
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO skipped_keys (session_id, dh_public, msg_num, message_key_enc, message_key_nonce) VALUES (?, ?, ?, ?, ?)",
|
||||
params![session_id.as_slice(), dh_public.as_slice(), msg_num, enc.as_slice(), nonce.as_slice()],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_skipped_key(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
dh_public: &[u8; 32],
|
||||
msg_num: u32,
|
||||
) -> Result<(), StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
conn.execute(
|
||||
"DELETE FROM skipped_keys WHERE session_id = ? AND dh_public = ? AND msg_num = ?",
|
||||
params![session_id.as_slice(), dh_public.as_slice(), msg_num],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_state(&self, session_id: &SessionId) -> Result<Option<StoredState>, StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
let row = conn
|
||||
.query_row(
|
||||
r#"
|
||||
SELECT root_key, sending_chain, receiving_chain,
|
||||
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
|
||||
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id
|
||||
FROM ratchet_states WHERE session_id = ?
|
||||
"#,
|
||||
"SELECT root_key, sending_chain, receiving_chain, dh_secret_enc, dh_secret_nonce, dh_public, dh_remote, msg_send, msg_recv, prev_chain_len FROM sessions WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|row| {
|
||||
Ok((
|
||||
@ -319,91 +231,64 @@ impl RatchetStorage for SqliteStorage {
|
||||
row.get::<_, u32>(7)?,
|
||||
row.get::<_, u32>(8)?,
|
||||
row.get::<_, u32>(9)?,
|
||||
row.get::<_, String>(10)?,
|
||||
))
|
||||
},
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
let Some((
|
||||
root_key_bytes,
|
||||
sending_chain_bytes,
|
||||
receiving_chain_bytes,
|
||||
encrypted_secret,
|
||||
nonce_bytes,
|
||||
dh_self_public_bytes,
|
||||
dh_remote_bytes,
|
||||
msg_send,
|
||||
msg_recv,
|
||||
prev_chain_len,
|
||||
domain_id,
|
||||
)) = row
|
||||
else {
|
||||
root_key_bytes, sending_bytes, receiving_bytes,
|
||||
secret_enc, secret_nonce, dh_pub_bytes, dh_remote_bytes,
|
||||
msg_send, msg_recv, prev_chain_len,
|
||||
)) = row else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Decrypt the secret
|
||||
let nonce: [u8; 12] = nonce_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid nonce length".to_string()))?;
|
||||
let dh_self_secret = self.decrypt_secret(&encrypted_secret, &nonce)?;
|
||||
let nonce: [u8; 12] = secret_nonce.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid nonce".into()))?;
|
||||
let dh_self_secret = self.decrypt(&secret_enc, &nonce)?;
|
||||
|
||||
// Convert byte vectors to arrays
|
||||
let root_key: [u8; 32] = root_key_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid root_key length".to_string()))?;
|
||||
let dh_self_public: [u8; 32] = dh_self_public_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid dh_self_public length".to_string()))?;
|
||||
let root_key: [u8; 32] = root_key_bytes.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid root_key".into()))?;
|
||||
let dh_self_public: [u8; 32] = dh_pub_bytes.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid dh_public".into()))?;
|
||||
|
||||
let sending_chain = sending_chain_bytes
|
||||
.map(|b| {
|
||||
b.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid sending_chain length".to_string()))
|
||||
})
|
||||
let sending_chain = sending_bytes
|
||||
.map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid sending_chain".into())))
|
||||
.transpose()?;
|
||||
let receiving_chain = receiving_chain_bytes
|
||||
.map(|b| {
|
||||
b.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid receiving_chain length".to_string()))
|
||||
})
|
||||
let receiving_chain = receiving_bytes
|
||||
.map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid receiving_chain".into())))
|
||||
.transpose()?;
|
||||
let dh_remote = dh_remote_bytes
|
||||
.map(|b| {
|
||||
b.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid dh_remote length".to_string()))
|
||||
})
|
||||
.map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid dh_remote".into())))
|
||||
.transpose()?;
|
||||
|
||||
// Load skipped keys
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT public_key, msg_num, message_key FROM skipped_keys WHERE session_id = ?",
|
||||
"SELECT dh_public, msg_num, message_key_enc, message_key_nonce FROM skipped_keys WHERE session_id = ?",
|
||||
)?;
|
||||
let skipped_keys: Vec<SkippedKey> = stmt
|
||||
.query_map([session_id.as_slice()], |row| {
|
||||
let pk_bytes: Vec<u8> = row.get(0)?;
|
||||
let msg_num: u32 = row.get(1)?;
|
||||
let mk_bytes: Vec<u8> = row.get(2)?;
|
||||
Ok((pk_bytes, msg_num, mk_bytes))
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.map(|(pk_bytes, msg_num, mk_bytes)| {
|
||||
let public_key: [u8; 32] = pk_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid skipped public_key length".to_string()))?;
|
||||
let message_key: [u8; 32] = mk_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid skipped message_key length".to_string()))?;
|
||||
Ok(SkippedKey {
|
||||
public_key,
|
||||
msg_num,
|
||||
message_key,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, StorageError>>()?;
|
||||
let mut skipped_keys = Vec::new();
|
||||
let rows = stmt.query_map([session_id.as_slice()], |row| {
|
||||
Ok((
|
||||
row.get::<_, Vec<u8>>(0)?,
|
||||
row.get::<_, u32>(1)?,
|
||||
row.get::<_, Vec<u8>>(2)?,
|
||||
row.get::<_, Vec<u8>>(3)?,
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(Some(StorableRatchetState {
|
||||
for row in rows {
|
||||
let (pk_bytes, msg_num, key_enc, key_nonce) = row?;
|
||||
let dh_public: [u8; 32] = pk_bytes.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid skipped dh_public".into()))?;
|
||||
let nonce: [u8; 12] = key_nonce.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid skipped nonce".into()))?;
|
||||
let message_key = self.decrypt(&key_enc, &nonce)?;
|
||||
|
||||
skipped_keys.push(SkippedKeyEntry { dh_public, msg_num, message_key });
|
||||
}
|
||||
|
||||
Ok(Some(StoredState {
|
||||
root_key,
|
||||
sending_chain,
|
||||
receiving_chain,
|
||||
@ -414,24 +299,72 @@ impl RatchetStorage for SqliteStorage {
|
||||
msg_recv,
|
||||
prev_chain_len,
|
||||
skipped_keys,
|
||||
domain_id,
|
||||
}))
|
||||
}
|
||||
|
||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
fn init_session(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
root_key: &[u8; 32],
|
||||
sending_chain: Option<&[u8; 32]>,
|
||||
receiving_chain: Option<&[u8; 32]>,
|
||||
dh_self_secret: &[u8; 32],
|
||||
dh_self_public: &[u8; 32],
|
||||
dh_remote: Option<&[u8; 32]>,
|
||||
msg_send: u32,
|
||||
msg_recv: u32,
|
||||
prev_chain_len: u32,
|
||||
) -> Result<(), StorageError> {
|
||||
let (secret_enc, secret_nonce) = self.encrypt(dh_self_secret)?;
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
INSERT INTO sessions (session_id, root_key, sending_chain, receiving_chain, dh_secret_enc, dh_secret_nonce, dh_public, dh_remote, msg_send, msg_recv, prev_chain_len)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(session_id) DO UPDATE SET
|
||||
root_key = excluded.root_key,
|
||||
sending_chain = excluded.sending_chain,
|
||||
receiving_chain = excluded.receiving_chain,
|
||||
dh_secret_enc = excluded.dh_secret_enc,
|
||||
dh_secret_nonce = excluded.dh_secret_nonce,
|
||||
dh_public = excluded.dh_public,
|
||||
dh_remote = excluded.dh_remote,
|
||||
msg_send = excluded.msg_send,
|
||||
msg_recv = excluded.msg_recv,
|
||||
prev_chain_len = excluded.prev_chain_len
|
||||
"#,
|
||||
params![
|
||||
session_id.as_slice(),
|
||||
root_key.as_slice(),
|
||||
sending_chain.map(|c| c.as_slice()),
|
||||
receiving_chain.map(|c| c.as_slice()),
|
||||
secret_enc.as_slice(),
|
||||
secret_nonce.as_slice(),
|
||||
dh_self_public.as_slice(),
|
||||
dh_remote.map(|r| r.as_slice()),
|
||||
msg_send,
|
||||
msg_recv,
|
||||
prev_chain_len,
|
||||
],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_session(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let changes = conn.execute(
|
||||
"DELETE FROM ratchet_states WHERE session_id = ?",
|
||||
"DELETE FROM sessions WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
)?;
|
||||
Ok(changes > 0)
|
||||
}
|
||||
|
||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
fn session_exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let exists: bool = conn
|
||||
.query_row(
|
||||
"SELECT 1 FROM ratchet_states WHERE session_id = ?",
|
||||
"SELECT 1 FROM sessions WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|_| Ok(true),
|
||||
)
|
||||
@ -442,15 +375,11 @@ impl RatchetStorage for SqliteStorage {
|
||||
|
||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let mut stmt = conn.prepare("SELECT session_id FROM ratchet_states")?;
|
||||
let sessions: Vec<SessionId> = stmt
|
||||
.query_map([], |row| {
|
||||
let bytes: Vec<u8> = row.get(0)?;
|
||||
Ok(bytes)
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.filter_map(|bytes| bytes.try_into().ok())
|
||||
let mut stmt = conn.prepare("SELECT session_id FROM sessions")?;
|
||||
let sessions = stmt
|
||||
.query_map([], |row| row.get::<_, Vec<u8>>(0))?
|
||||
.filter_map(|r| r.ok())
|
||||
.filter_map(|v| v.try_into().ok())
|
||||
.collect();
|
||||
Ok(sessions)
|
||||
}
|
||||
@ -459,288 +388,99 @@ impl RatchetStorage for SqliteStorage {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
use double_ratchets::state::RatchetState;
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
|
||||
fn create_test_storage() -> SqliteStorage {
|
||||
let key = [0x42u8; 32];
|
||||
SqliteStorage::open_in_memory(key).unwrap()
|
||||
}
|
||||
|
||||
fn create_test_state() -> StorableRatchetState {
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let state: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
StorableRatchetState::from_ratchet_state(&state, "default")
|
||||
fn test_store() -> SqliteRatchetStore {
|
||||
SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
let state = create_test_state();
|
||||
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
let loaded = storage.load(&session_id).unwrap();
|
||||
|
||||
assert!(loaded.is_some());
|
||||
let loaded = loaded.unwrap();
|
||||
assert_eq!(loaded.root_key, state.root_key);
|
||||
assert_eq!(loaded.dh_self_public, state.dh_self_public);
|
||||
// Secret should be decrypted correctly
|
||||
assert_eq!(loaded.dh_self_secret, state.dh_self_secret);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_nonexistent() {
|
||||
let storage = create_test_storage();
|
||||
fn test_init_and_load() {
|
||||
let store = test_store();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
let loaded = storage.load(&session_id).unwrap();
|
||||
assert!(loaded.is_none());
|
||||
store.init_session(
|
||||
&session_id,
|
||||
&[0xAA; 32], // root
|
||||
Some(&[0xBB; 32]), // sending
|
||||
None, // receiving
|
||||
&[0xCC; 32], // secret
|
||||
&[0xDD; 32], // public
|
||||
Some(&[0xEE; 32]), // remote
|
||||
5, 3, 2,
|
||||
).unwrap();
|
||||
|
||||
let state = store.load_state(&session_id).unwrap().unwrap();
|
||||
assert_eq!(state.root_key, [0xAA; 32]);
|
||||
assert_eq!(state.sending_chain, Some([0xBB; 32]));
|
||||
assert_eq!(state.receiving_chain, None);
|
||||
assert_eq!(state.dh_self_secret, [0xCC; 32]);
|
||||
assert_eq!(state.dh_self_public, [0xDD; 32]);
|
||||
assert_eq!(state.dh_remote, Some([0xEE; 32]));
|
||||
assert_eq!(state.msg_send, 5);
|
||||
assert_eq!(state.msg_recv, 3);
|
||||
assert_eq!(state.prev_chain_len, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
let state = create_test_state();
|
||||
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
assert!(storage.exists(&session_id).unwrap());
|
||||
|
||||
let deleted = storage.delete(&session_id).unwrap();
|
||||
assert!(deleted);
|
||||
assert!(!storage.exists(&session_id).unwrap());
|
||||
|
||||
// Deleting again should return false
|
||||
let deleted = storage.delete(&session_id).unwrap();
|
||||
assert!(!deleted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exists() {
|
||||
let storage = create_test_storage();
|
||||
fn test_update_fields() {
|
||||
let store = test_store();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
assert!(!storage.exists(&session_id).unwrap());
|
||||
store.init_session(
|
||||
&session_id, &[0xAA; 32], None, None,
|
||||
&[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
|
||||
).unwrap();
|
||||
|
||||
let state = create_test_state();
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
// Update root and chains
|
||||
store.store_root_and_chains(&session_id, &[0x11; 32], Some(&[0x22; 32]), Some(&[0x33; 32])).unwrap();
|
||||
|
||||
assert!(storage.exists(&session_id).unwrap());
|
||||
let state = store.load_state(&session_id).unwrap().unwrap();
|
||||
assert_eq!(state.root_key, [0x11; 32]);
|
||||
assert_eq!(state.sending_chain, Some([0x22; 32]));
|
||||
assert_eq!(state.receiving_chain, Some([0x33; 32]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_sessions() {
|
||||
let storage = create_test_storage();
|
||||
|
||||
assert!(storage.list_sessions().unwrap().is_empty());
|
||||
|
||||
let state = create_test_state();
|
||||
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
|
||||
|
||||
for id in &session_ids {
|
||||
storage.save(id, &state).unwrap();
|
||||
}
|
||||
|
||||
let mut listed = storage.list_sessions().unwrap();
|
||||
listed.sort();
|
||||
let mut expected = session_ids.clone();
|
||||
expected.sort();
|
||||
|
||||
assert_eq!(listed, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_overwrite() {
|
||||
let storage = create_test_storage();
|
||||
fn test_skipped_keys() {
|
||||
let store = test_store();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Create first state
|
||||
let bob_keypair1 = InstallationKeyPair::generate();
|
||||
let state1: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
|
||||
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
|
||||
store.init_session(
|
||||
&session_id, &[0xAA; 32], None, None,
|
||||
&[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
|
||||
).unwrap();
|
||||
|
||||
// Create second state with different root
|
||||
let bob_keypair2 = InstallationKeyPair::generate();
|
||||
let state2: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
|
||||
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
|
||||
// Add skipped keys
|
||||
store.add_skipped_key(&session_id, &[0x11; 32], 5, &[0xAB; 32]).unwrap();
|
||||
store.add_skipped_key(&session_id, &[0x11; 32], 6, &[0xCD; 32]).unwrap();
|
||||
|
||||
// Save first, then overwrite with second
|
||||
storage.save(&session_id, &storable1).unwrap();
|
||||
storage.save(&session_id, &storable2).unwrap();
|
||||
let state = store.load_state(&session_id).unwrap().unwrap();
|
||||
assert_eq!(state.skipped_keys.len(), 2);
|
||||
|
||||
// Should have the second state
|
||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
assert_eq!(loaded.root_key, storable2.root_key);
|
||||
assert_ne!(loaded.root_key, storable1.root_key);
|
||||
// Remove one
|
||||
store.remove_skipped_key(&session_id, &[0x11; 32], 5).unwrap();
|
||||
|
||||
let state = store.load_state(&session_id).unwrap().unwrap();
|
||||
assert_eq!(state.skipped_keys.len(), 1);
|
||||
assert_eq!(state.skipped_keys[0].msg_num, 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skipped_keys_storage() {
|
||||
let storage = create_test_storage();
|
||||
fn test_delete_cascade() {
|
||||
let store = test_store();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Create states and generate skipped keys
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
store.init_session(
|
||||
&session_id, &[0xAA; 32], None, None,
|
||||
&[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
|
||||
).unwrap();
|
||||
store.add_skipped_key(&session_id, &[0x11; 32], 5, &[0xAB; 32]).unwrap();
|
||||
|
||||
// Alice sends multiple messages
|
||||
let mut messages = vec![];
|
||||
for i in 0..3 {
|
||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
||||
messages.push((ct, header));
|
||||
}
|
||||
assert!(store.session_exists(&session_id).unwrap());
|
||||
|
||||
// Bob receives out of order to create skipped keys
|
||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
||||
.unwrap();
|
||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
||||
.unwrap();
|
||||
store.delete_session(&session_id).unwrap();
|
||||
|
||||
assert!(!bob.skipped_keys.is_empty());
|
||||
|
||||
// Save and reload
|
||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
storage.save(&session_id, &storable).unwrap();
|
||||
|
||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
assert_eq!(loaded.skipped_keys.len(), storable.skipped_keys.len());
|
||||
|
||||
// Restore and verify we can decrypt the skipped message
|
||||
let mut restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
|
||||
let pt = restored
|
||||
.decrypt_message(&messages[1].0, messages[1].1.clone())
|
||||
.unwrap();
|
||||
assert_eq!(pt, b"Message 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encryption_uses_different_nonces() {
|
||||
let storage = create_test_storage();
|
||||
let state = create_test_state();
|
||||
|
||||
// Save the same state twice with different session IDs
|
||||
storage.save(&[1u8; 32], &state).unwrap();
|
||||
storage.save(&[2u8; 32], &state).unwrap();
|
||||
|
||||
// Both should load correctly (encryption with different nonces)
|
||||
let loaded1 = storage.load(&[1u8; 32]).unwrap().unwrap();
|
||||
let loaded2 = storage.load(&[2u8; 32]).unwrap().unwrap();
|
||||
|
||||
assert_eq!(loaded1.dh_self_secret, loaded2.dh_self_secret);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cascade_delete_skipped_keys() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Create a state with skipped keys
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
let mut messages = vec![];
|
||||
for i in 0..3 {
|
||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
||||
messages.push((ct, header));
|
||||
}
|
||||
|
||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
||||
.unwrap();
|
||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
||||
.unwrap();
|
||||
|
||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
storage.save(&session_id, &storable).unwrap();
|
||||
|
||||
// Verify skipped keys exist
|
||||
{
|
||||
let conn = storage.conn.lock().unwrap();
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.unwrap();
|
||||
assert!(count > 0);
|
||||
}
|
||||
|
||||
// Delete session
|
||||
storage.delete(&session_id).unwrap();
|
||||
|
||||
// Verify skipped keys were also deleted (cascade)
|
||||
{
|
||||
let conn = storage.conn.lock().unwrap();
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(count, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_storage() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
let key = [0x42u8; 32];
|
||||
|
||||
let state = create_test_state();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Save in one instance
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
}
|
||||
|
||||
// Load in another instance
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
assert_eq!(loaded.root_key, state.root_key);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrong_key_fails_decryption() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
let key1 = [0x42u8; 32];
|
||||
let key2 = [0x43u8; 32];
|
||||
|
||||
let state = create_test_state();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Save with key1
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key1).unwrap();
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
}
|
||||
|
||||
// Try to load with key2 - should fail decryption
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key2).unwrap();
|
||||
let result = storage.load(&session_id);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
assert!(!store.session_exists(&session_id).unwrap());
|
||||
// Skipped keys should be gone too (cascade)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,74 +1,112 @@
|
||||
//! Storage trait definitions.
|
||||
//! Storage trait for field-level ratchet state persistence.
|
||||
|
||||
use crate::error::StorageError;
|
||||
use crate::types::StorableRatchetState;
|
||||
|
||||
/// A 32-byte session identifier.
|
||||
pub type SessionId = [u8; 32];
|
||||
|
||||
/// Abstract storage interface for ratchet states.
|
||||
/// Field-level storage interface for ratchet state.
|
||||
///
|
||||
/// Implementations must be thread-safe (`Send + Sync`).
|
||||
pub trait RatchetStorage: Send + Sync {
|
||||
/// Save a ratchet state for the given session.
|
||||
///
|
||||
/// If a state already exists for this session, it will be overwritten.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - Unique identifier for the session.
|
||||
/// * `state` - The ratchet state to store.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(())` on success.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError>;
|
||||
/// This trait provides granular storage operations that are called automatically
|
||||
/// during ratchet operations. Each method persists only the fields that changed.
|
||||
pub trait RatchetStore: Send + Sync {
|
||||
/// Store the root key and chain keys after a DH ratchet step.
|
||||
fn store_root_and_chains(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
root_key: &[u8; 32],
|
||||
sending_chain: Option<&[u8; 32]>,
|
||||
receiving_chain: Option<&[u8; 32]>,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// Load a ratchet state for the given session.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - Unique identifier for the session.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Some(state))` if found.
|
||||
/// * `Ok(None)` if not found.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError>;
|
||||
/// Store our DH keypair (secret encrypted, public plaintext).
|
||||
fn store_dh_self(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
secret: &[u8; 32],
|
||||
public: &[u8; 32],
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// Delete a ratchet state for the given session.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - Unique identifier for the session.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(true)` if the session existed and was deleted.
|
||||
/// * `Ok(false)` if the session did not exist.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
||||
/// Store the remote party's DH public key.
|
||||
fn store_dh_remote(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
remote: Option<&[u8; 32]>,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// Check if a session exists in storage.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - Unique identifier for the session.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(true)` if the session exists.
|
||||
/// * `Ok(false)` if the session does not exist.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
||||
/// Store message counters.
|
||||
fn store_counters(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
msg_send: u32,
|
||||
msg_recv: u32,
|
||||
prev_chain_len: u32,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// List all session IDs in storage.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Vec<SessionId>)` containing all session IDs.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
/// Add a skipped message key.
|
||||
fn add_skipped_key(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
dh_public: &[u8; 32],
|
||||
msg_num: u32,
|
||||
message_key: &[u8; 32],
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// Remove a skipped message key (after use).
|
||||
fn remove_skipped_key(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
dh_public: &[u8; 32],
|
||||
msg_num: u32,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// Load all state for a session. Returns None if session doesn't exist.
|
||||
fn load_state(&self, session_id: &SessionId) -> Result<Option<StoredState>, StorageError>;
|
||||
|
||||
/// Initialize a new session with all fields.
|
||||
fn init_session(
|
||||
&self,
|
||||
session_id: &SessionId,
|
||||
root_key: &[u8; 32],
|
||||
sending_chain: Option<&[u8; 32]>,
|
||||
receiving_chain: Option<&[u8; 32]>,
|
||||
dh_self_secret: &[u8; 32],
|
||||
dh_self_public: &[u8; 32],
|
||||
dh_remote: Option<&[u8; 32]>,
|
||||
msg_send: u32,
|
||||
msg_recv: u32,
|
||||
prev_chain_len: u32,
|
||||
) -> Result<(), StorageError>;
|
||||
|
||||
/// Delete a session and all its data.
|
||||
fn delete_session(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
||||
|
||||
/// Check if a session exists.
|
||||
fn session_exists(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
||||
|
||||
/// List all session IDs.
|
||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError>;
|
||||
}
|
||||
|
||||
/// Complete state loaded from storage.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StoredState {
|
||||
pub root_key: [u8; 32],
|
||||
pub sending_chain: Option<[u8; 32]>,
|
||||
pub receiving_chain: Option<[u8; 32]>,
|
||||
pub dh_self_secret: [u8; 32],
|
||||
pub dh_self_public: [u8; 32],
|
||||
pub dh_remote: Option<[u8; 32]>,
|
||||
pub msg_send: u32,
|
||||
pub msg_recv: u32,
|
||||
pub prev_chain_len: u32,
|
||||
pub skipped_keys: Vec<SkippedKeyEntry>,
|
||||
}
|
||||
|
||||
/// A skipped key entry from storage.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SkippedKeyEntry {
|
||||
pub dh_public: [u8; 32],
|
||||
pub msg_num: u32,
|
||||
pub message_key: [u8; 32],
|
||||
}
|
||||
|
||||
@ -1,225 +0,0 @@
|
||||
//! Serializable types for ratchet state storage.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use double_ratchets::state::RatchetState;
|
||||
use double_ratchets::hkdf::HkdfInfo;
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use x25519_dalek::PublicKey;
|
||||
|
||||
use crate::error::StorageError;
|
||||
|
||||
/// A skipped message key entry for storage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct SkippedKey {
|
||||
/// The public key associated with this skipped message.
|
||||
pub public_key: [u8; 32],
|
||||
/// The message number.
|
||||
pub msg_num: u32,
|
||||
/// The 32-byte message key.
|
||||
pub message_key: [u8; 32],
|
||||
}
|
||||
|
||||
/// Serializable version of `RatchetState`.
|
||||
///
|
||||
/// This struct stores all keys as raw byte arrays for easy serialization
|
||||
/// and database storage. Use `from_ratchet_state()` and `to_ratchet_state()`
|
||||
/// for conversion.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StorableRatchetState {
|
||||
/// The current root key (32 bytes).
|
||||
pub root_key: [u8; 32],
|
||||
|
||||
/// The current sending chain key, if any (32 bytes).
|
||||
pub sending_chain: Option<[u8; 32]>,
|
||||
|
||||
/// The current receiving chain key, if any (32 bytes).
|
||||
pub receiving_chain: Option<[u8; 32]>,
|
||||
|
||||
/// Our DH secret key (32 bytes).
|
||||
///
|
||||
/// **Security**: This should be encrypted before storage.
|
||||
pub dh_self_secret: [u8; 32],
|
||||
|
||||
/// Our DH public key (32 bytes).
|
||||
pub dh_self_public: [u8; 32],
|
||||
|
||||
/// Remote party's DH public key, if known (32 bytes).
|
||||
pub dh_remote: Option<[u8; 32]>,
|
||||
|
||||
/// Number of messages sent in the current sending chain.
|
||||
pub msg_send: u32,
|
||||
|
||||
/// Number of messages received in the current receiving chain.
|
||||
pub msg_recv: u32,
|
||||
|
||||
/// Length of the previous sending chain.
|
||||
pub prev_chain_len: u32,
|
||||
|
||||
/// Skipped message keys for out-of-order message handling.
|
||||
pub skipped_keys: Vec<SkippedKey>,
|
||||
|
||||
/// Domain identifier for HKDF info.
|
||||
pub domain_id: String,
|
||||
}
|
||||
|
||||
impl StorableRatchetState {
|
||||
/// Convert a `RatchetState` into a `StorableRatchetState`.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The ratchet state to convert.
|
||||
/// * `domain_id` - A string identifier for the domain (used to reconstruct the correct domain type).
|
||||
pub fn from_ratchet_state<D: HkdfInfo>(state: &RatchetState<D>, domain_id: &str) -> Self {
|
||||
let skipped_keys: Vec<SkippedKey> = state
|
||||
.skipped_keys
|
||||
.iter()
|
||||
.map(|((pub_key, msg_num), msg_key)| SkippedKey {
|
||||
public_key: *pub_key.as_bytes(),
|
||||
msg_num: *msg_num,
|
||||
message_key: *msg_key,
|
||||
})
|
||||
.collect();
|
||||
|
||||
StorableRatchetState {
|
||||
root_key: state.root_key,
|
||||
sending_chain: state.sending_chain,
|
||||
receiving_chain: state.receiving_chain,
|
||||
dh_self_secret: state.dh_self.secret_bytes(),
|
||||
dh_self_public: *state.dh_self.public().as_bytes(),
|
||||
dh_remote: state.dh_remote.map(|pk| *pk.as_bytes()),
|
||||
msg_send: state.msg_send,
|
||||
msg_recv: state.msg_recv,
|
||||
prev_chain_len: state.prev_chain_len,
|
||||
skipped_keys,
|
||||
domain_id: domain_id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert this `StorableRatchetState` back into a `RatchetState`.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(RatchetState)` on success.
|
||||
/// * `Err(StorageError)` if key reconstruction fails.
|
||||
pub fn to_ratchet_state<D: HkdfInfo>(&self) -> Result<RatchetState<D>, StorageError> {
|
||||
// Reconstruct the keypair
|
||||
let dh_self = InstallationKeyPair::from_bytes(self.dh_self_secret, self.dh_self_public)
|
||||
.map_err(|e| StorageError::KeyReconstruction(e.to_string()))?;
|
||||
|
||||
// Reconstruct skipped keys HashMap
|
||||
let skipped_keys: HashMap<(PublicKey, u32), [u8; 32]> = self
|
||||
.skipped_keys
|
||||
.iter()
|
||||
.map(|sk| {
|
||||
let pub_key = PublicKey::from(sk.public_key);
|
||||
((pub_key, sk.msg_num), sk.message_key)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(RatchetState::from_parts(
|
||||
self.root_key,
|
||||
self.sending_chain,
|
||||
self.receiving_chain,
|
||||
dh_self,
|
||||
self.dh_remote.map(PublicKey::from),
|
||||
self.msg_send,
|
||||
self.msg_recv,
|
||||
self.prev_chain_len,
|
||||
skipped_keys,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_sender_state() {
|
||||
// Create a sender state
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let state: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
|
||||
// Convert to storable and back
|
||||
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
||||
|
||||
// Verify fields match
|
||||
assert_eq!(state.root_key, restored.root_key);
|
||||
assert_eq!(state.sending_chain, restored.sending_chain);
|
||||
assert_eq!(state.receiving_chain, restored.receiving_chain);
|
||||
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
|
||||
assert_eq!(state.dh_remote.map(|pk| *pk.as_bytes()), restored.dh_remote.map(|pk| *pk.as_bytes()));
|
||||
assert_eq!(state.msg_send, restored.msg_send);
|
||||
assert_eq!(state.msg_recv, restored.msg_recv);
|
||||
assert_eq!(state.prev_chain_len, restored.prev_chain_len);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_receiver_state() {
|
||||
// Create a receiver state
|
||||
let keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let state: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, keypair);
|
||||
|
||||
// Convert to storable and back
|
||||
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
||||
|
||||
// Verify fields match
|
||||
assert_eq!(state.root_key, restored.root_key);
|
||||
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
|
||||
assert!(restored.dh_remote.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_with_skipped_keys() {
|
||||
// Create states and exchange messages to generate skipped keys
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
// Alice sends multiple messages
|
||||
let mut messages = vec![];
|
||||
for i in 0..3 {
|
||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
||||
messages.push((ct, header));
|
||||
}
|
||||
|
||||
// Bob receives them out of order to create skipped keys
|
||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone()).unwrap();
|
||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone()).unwrap();
|
||||
// Message 1 key is now in skipped_keys
|
||||
|
||||
assert!(!bob.skipped_keys.is_empty());
|
||||
|
||||
// Convert to storable and back
|
||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
||||
|
||||
// Verify skipped keys are preserved
|
||||
assert_eq!(bob.skipped_keys.len(), restored.skipped_keys.len());
|
||||
|
||||
// The restored state should be able to decrypt the skipped message
|
||||
let mut restored = restored;
|
||||
let pt = restored.decrypt_message(&messages[1].0, messages[1].1.clone()).unwrap();
|
||||
assert_eq!(pt, b"Message 1");
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user