mirror of
https://github.com/logos-messaging/libchat.git
synced 2026-02-10 08:53:08 +00:00
chore: refactor
This commit is contained in:
parent
34a03275cc
commit
9f968fc80d
@ -1,10 +1,13 @@
|
|||||||
//! Persistent storage for Double Ratchet state.
|
//! Persistent storage for Double Ratchet state.
|
||||||
//!
|
//!
|
||||||
//! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState)
|
//! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState)
|
||||||
//! across application restarts. It includes:
|
//! with automatic field-level persistence on each ratchet operation.
|
||||||
//!
|
//!
|
||||||
//! - [`MemoryStorage`] - In-memory storage for testing
|
//! # Main API
|
||||||
//! - [`SqliteStorage`] - SQLite storage with field-level encryption (requires `sqlite` feature)
|
//!
|
||||||
|
//! - [`PersistentRatchet`] - Wrapper that auto-persists state changes during encrypt/decrypt
|
||||||
|
//! - [`SqliteRatchetStore`] - SQLite backend with field-level encryption
|
||||||
|
//! - [`RatchetStore`] - Trait for implementing custom storage backends
|
||||||
//!
|
//!
|
||||||
//! # Features
|
//! # Features
|
||||||
//!
|
//!
|
||||||
@ -13,166 +16,233 @@
|
|||||||
//!
|
//!
|
||||||
//! # Security
|
//! # Security
|
||||||
//!
|
//!
|
||||||
//! Private keys (`dh_self_secret`) are always encrypted with ChaCha20Poly1305 before storage,
|
//! Private keys (`dh_self_secret`) and message keys are always encrypted with ChaCha20Poly1305
|
||||||
//! even when using plain SQLite. For additional security, enable the `sqlcipher` feature
|
//! before storage, even when using plain SQLite. For additional security, enable the `sqlcipher`
|
||||||
//! for full database encryption.
|
//! feature for full database encryption.
|
||||||
//!
|
//!
|
||||||
//! # Example
|
//! # Example
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
|
//! use std::sync::Arc;
|
||||||
//! use double_ratchets::hkdf::DefaultDomain;
|
//! use double_ratchets::hkdf::DefaultDomain;
|
||||||
//! use double_ratchets::state::RatchetState;
|
|
||||||
//! use double_ratchets::InstallationKeyPair;
|
//! use double_ratchets::InstallationKeyPair;
|
||||||
//! use double_ratchets_storage::{
|
//! use double_ratchets_storage::{PersistentRatchet, RatchetStore, SqliteRatchetStore};
|
||||||
//! RatchetStorage, SqliteStorage, StorableRatchetState,
|
|
||||||
//! };
|
|
||||||
//!
|
|
||||||
//! // Create a ratchet state
|
|
||||||
//! let bob_keypair = InstallationKeyPair::generate();
|
|
||||||
//! let shared_secret = [0x42u8; 32];
|
|
||||||
//! let state: RatchetState<DefaultDomain> =
|
|
||||||
//! RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
|
||||||
//!
|
//!
|
||||||
//! // Open storage
|
//! // Open storage
|
||||||
//! let encryption_key = [0u8; 32]; // Use proper key derivation!
|
//! let encryption_key = [0u8; 32]; // Use proper key derivation!
|
||||||
//! let storage = SqliteStorage::open("ratchets.db", encryption_key).unwrap();
|
//! let store: Arc<dyn RatchetStore> =
|
||||||
|
//! Arc::new(SqliteRatchetStore::open("ratchets.db", encryption_key).unwrap());
|
||||||
//!
|
//!
|
||||||
//! // Save state
|
//! // Initialize sender
|
||||||
|
//! let bob_keypair = InstallationKeyPair::generate();
|
||||||
|
//! let shared_secret = [0x42u8; 32];
|
||||||
//! let session_id = [1u8; 32];
|
//! let session_id = [1u8; 32];
|
||||||
//! let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
|
||||||
//! storage.save(&session_id, &storable).unwrap();
|
|
||||||
//!
|
//!
|
||||||
//! // Load state
|
//! let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
//! let loaded = storage.load(&session_id).unwrap().unwrap();
|
//! Arc::clone(&store),
|
||||||
//! let restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
|
//! session_id,
|
||||||
|
//! shared_secret,
|
||||||
|
//! *bob_keypair.public(),
|
||||||
|
//! ).unwrap();
|
||||||
|
//!
|
||||||
|
//! // Encrypt - state is automatically persisted
|
||||||
|
//! let (ciphertext, header) = alice.encrypt_message(b"Hello!").unwrap();
|
||||||
|
//!
|
||||||
|
//! // Later: load from storage
|
||||||
|
//! let alice_restored: PersistentRatchet<DefaultDomain> =
|
||||||
|
//! PersistentRatchet::load(store, session_id).unwrap().unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod memory;
|
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||||
|
pub mod persistent;
|
||||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||||
pub mod sqlite;
|
pub mod sqlite;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
pub mod types;
|
|
||||||
|
|
||||||
// Re-exports for convenience
|
// Re-exports for convenience
|
||||||
pub use error::StorageError;
|
pub use error::StorageError;
|
||||||
pub use memory::MemoryStorage;
|
|
||||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||||
pub use sqlite::{EncryptionKey, SqliteStorage};
|
pub use persistent::{PersistentRatchet, PersistentRatchetError};
|
||||||
pub use traits::{RatchetStorage, SessionId};
|
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||||
pub use types::{SkippedKey, StorableRatchetState};
|
pub use sqlite::{EncryptionKey, SqliteRatchetStore};
|
||||||
|
pub use traits::{RatchetStore, SessionId, SkippedKeyEntry, StoredState};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(all(test, any(feature = "sqlite", feature = "sqlcipher")))]
|
||||||
mod integration_tests {
|
mod integration_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use double_ratchets::hkdf::DefaultDomain;
|
use double_ratchets::hkdf::DefaultDomain;
|
||||||
use double_ratchets::state::RatchetState;
|
|
||||||
use double_ratchets::InstallationKeyPair;
|
use double_ratchets::InstallationKeyPair;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Integration test: full encryption/decryption cycle with storage
|
/// Integration test: full conversation with auto-persist
|
||||||
#[test]
|
#[test]
|
||||||
fn test_full_conversation_with_storage_roundtrip() {
|
fn test_full_conversation_with_auto_persist() {
|
||||||
// Setup Alice and Bob
|
let store: Arc<dyn RatchetStore> =
|
||||||
let bob_keypair = InstallationKeyPair::generate();
|
Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap());
|
||||||
let shared_secret = [0x42u8; 32];
|
|
||||||
|
|
||||||
let mut alice: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
|
||||||
let mut bob: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
|
||||||
|
|
||||||
let storage = MemoryStorage::new();
|
|
||||||
let alice_session = [0xAA; 32];
|
let alice_session = [0xAA; 32];
|
||||||
let bob_session = [0xBB; 32];
|
let bob_session = [0xBB; 32];
|
||||||
|
|
||||||
// Alice sends a message
|
let bob_keypair = InstallationKeyPair::generate();
|
||||||
let (ct1, header1) = alice.encrypt_message(b"Hello Bob!");
|
let shared_secret = [0x42u8; 32];
|
||||||
|
|
||||||
// Save Alice's state
|
// Initialize both parties - state is auto-persisted
|
||||||
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default");
|
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
storage.save(&alice_session, &alice_storable).unwrap();
|
Arc::clone(&store),
|
||||||
|
alice_session,
|
||||||
|
shared_secret,
|
||||||
|
*bob_keypair.public(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Bob receives the message
|
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
|
||||||
|
Arc::clone(&store),
|
||||||
|
bob_session,
|
||||||
|
shared_secret,
|
||||||
|
bob_keypair,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Alice sends - state auto-persisted
|
||||||
|
let (ct1, header1) = alice.encrypt_message(b"Hello Bob!").unwrap();
|
||||||
let pt1 = bob.decrypt_message(&ct1, header1).unwrap();
|
let pt1 = bob.decrypt_message(&ct1, header1).unwrap();
|
||||||
assert_eq!(pt1, b"Hello Bob!");
|
assert_eq!(pt1, b"Hello Bob!");
|
||||||
|
|
||||||
// Save Bob's state
|
// Verify state was persisted
|
||||||
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
assert!(store.session_exists(&alice_session).unwrap());
|
||||||
storage.save(&bob_session, &bob_storable).unwrap();
|
assert!(store.session_exists(&bob_session).unwrap());
|
||||||
|
|
||||||
// Simulate restart: load states from storage
|
// Bob replies - state auto-persisted
|
||||||
let alice_loaded = storage.load(&alice_session).unwrap().unwrap();
|
let (ct2, header2) = bob.encrypt_message(b"Hi Alice!").unwrap();
|
||||||
let bob_loaded = storage.load(&bob_session).unwrap().unwrap();
|
let pt2 = alice.decrypt_message(&ct2, header2).unwrap();
|
||||||
|
|
||||||
let mut alice_restored: RatchetState<DefaultDomain> =
|
|
||||||
alice_loaded.to_ratchet_state().unwrap();
|
|
||||||
let mut bob_restored: RatchetState<DefaultDomain> =
|
|
||||||
bob_loaded.to_ratchet_state().unwrap();
|
|
||||||
|
|
||||||
// Bob replies
|
|
||||||
let (ct2, header2) = bob_restored.encrypt_message(b"Hi Alice!");
|
|
||||||
let pt2 = alice_restored.decrypt_message(&ct2, header2).unwrap();
|
|
||||||
assert_eq!(pt2, b"Hi Alice!");
|
assert_eq!(pt2, b"Hi Alice!");
|
||||||
|
|
||||||
// Alice sends another message
|
// Alice sends another - state auto-persisted
|
||||||
let (ct3, header3) = alice_restored.encrypt_message(b"How are you?");
|
let (ct3, header3) = alice.encrypt_message(b"How are you?").unwrap();
|
||||||
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap();
|
let pt3 = bob.decrypt_message(&ct3, header3).unwrap();
|
||||||
assert_eq!(pt3, b"How are you?");
|
assert_eq!(pt3, b"How are you?");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Integration test: verify SQLite storage with encryption works
|
/// Integration test: verify SQLite storage with file persistence
|
||||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sqlite_integration() {
|
fn test_sqlite_file_persistence() {
|
||||||
let dir = tempfile::tempdir().unwrap();
|
let dir = tempfile::tempdir().unwrap();
|
||||||
let db_path = dir.path().join("integration_test.db");
|
let db_path = dir.path().join("integration_test.db");
|
||||||
let key = [0x42u8; 32];
|
let key = [0x42u8; 32];
|
||||||
|
|
||||||
// Setup
|
|
||||||
let bob_keypair = InstallationKeyPair::generate();
|
let bob_keypair = InstallationKeyPair::generate();
|
||||||
|
let bob_pub = *bob_keypair.public();
|
||||||
let shared_secret = [0x42u8; 32];
|
let shared_secret = [0x42u8; 32];
|
||||||
let mut alice: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
|
||||||
let mut bob: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
|
||||||
|
|
||||||
let alice_session = [0xAA; 32];
|
let alice_session = [0xAA; 32];
|
||||||
let bob_session = [0xBB; 32];
|
let bob_session = [0xBB; 32];
|
||||||
|
|
||||||
// Exchange messages
|
// First session: exchange messages
|
||||||
let (ct1, header1) = alice.encrypt_message(b"Message 1");
|
|
||||||
bob.decrypt_message(&ct1, header1).unwrap();
|
|
||||||
|
|
||||||
let (ct2, header2) = bob.encrypt_message(b"Response 1");
|
|
||||||
alice.decrypt_message(&ct2, header2).unwrap();
|
|
||||||
|
|
||||||
// Save both states
|
|
||||||
{
|
{
|
||||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
let store: Arc<dyn RatchetStore> =
|
||||||
|
Arc::new(SqliteRatchetStore::open(&db_path, key).unwrap());
|
||||||
|
|
||||||
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default");
|
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
Arc::clone(&store),
|
||||||
|
alice_session,
|
||||||
|
shared_secret,
|
||||||
|
bob_pub,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
storage.save(&alice_session, &alice_storable).unwrap();
|
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
|
||||||
storage.save(&bob_session, &bob_storable).unwrap();
|
Arc::clone(&store),
|
||||||
|
bob_session,
|
||||||
|
shared_secret,
|
||||||
|
bob_keypair,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let (ct1, h1) = alice.encrypt_message(b"Message 1").unwrap();
|
||||||
|
bob.decrypt_message(&ct1, h1).unwrap();
|
||||||
|
|
||||||
|
let (ct2, h2) = bob.encrypt_message(b"Response 1").unwrap();
|
||||||
|
alice.decrypt_message(&ct2, h2).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reopen database (simulating restart)
|
// Reopen database (simulating restart)
|
||||||
{
|
{
|
||||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
let store: Arc<dyn RatchetStore> =
|
||||||
|
Arc::new(SqliteRatchetStore::open(&db_path, key).unwrap());
|
||||||
|
|
||||||
let alice_loaded = storage.load(&alice_session).unwrap().unwrap();
|
let mut alice: PersistentRatchet<DefaultDomain> =
|
||||||
let bob_loaded = storage.load(&bob_session).unwrap().unwrap();
|
PersistentRatchet::load(Arc::clone(&store), alice_session)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let mut alice_restored: RatchetState<DefaultDomain> =
|
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||||
alice_loaded.to_ratchet_state().unwrap();
|
PersistentRatchet::load(Arc::clone(&store), bob_session)
|
||||||
let mut bob_restored: RatchetState<DefaultDomain> =
|
.unwrap()
|
||||||
bob_loaded.to_ratchet_state().unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Continue conversation
|
// Continue conversation
|
||||||
let (ct3, header3) = alice_restored.encrypt_message(b"Message 2");
|
let (ct3, h3) = alice.encrypt_message(b"Message 2").unwrap();
|
||||||
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap();
|
let pt3 = bob.decrypt_message(&ct3, h3).unwrap();
|
||||||
assert_eq!(pt3, b"Message 2");
|
assert_eq!(pt3, b"Message 2");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Integration test: out-of-order messages with skipped keys
|
||||||
|
#[test]
|
||||||
|
fn test_out_of_order_messages_persisted() {
|
||||||
|
let store: Arc<dyn RatchetStore> =
|
||||||
|
Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap());
|
||||||
|
let alice_session = [0xAA; 32];
|
||||||
|
let bob_session = [0xBB; 32];
|
||||||
|
|
||||||
|
let bob_keypair = InstallationKeyPair::generate();
|
||||||
|
let shared_secret = [0x42u8; 32];
|
||||||
|
|
||||||
|
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
|
Arc::clone(&store),
|
||||||
|
alice_session,
|
||||||
|
shared_secret,
|
||||||
|
*bob_keypair.public(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
|
||||||
|
Arc::clone(&store),
|
||||||
|
bob_session,
|
||||||
|
shared_secret,
|
||||||
|
bob_keypair,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Alice sends 4 messages
|
||||||
|
let mut messages = vec![];
|
||||||
|
for i in 0..4 {
|
||||||
|
let (ct, h) = alice
|
||||||
|
.encrypt_message(format!("Message {}", i).as_bytes())
|
||||||
|
.unwrap();
|
||||||
|
messages.push((ct, h));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bob receives 0, 2, 3 (skipping 1)
|
||||||
|
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
||||||
|
.unwrap();
|
||||||
|
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
||||||
|
.unwrap();
|
||||||
|
bob.decrypt_message(&messages[3].0, messages[3].1.clone())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Verify skipped key is persisted
|
||||||
|
let bob_state = store.load_state(&bob_session).unwrap().unwrap();
|
||||||
|
assert_eq!(bob_state.skipped_keys.len(), 1);
|
||||||
|
assert_eq!(bob_state.skipped_keys[0].msg_num, 1);
|
||||||
|
|
||||||
|
// Now receive the skipped message
|
||||||
|
let pt1 = bob
|
||||||
|
.decrypt_message(&messages[1].0, messages[1].1.clone())
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(pt1, b"Message 1");
|
||||||
|
|
||||||
|
// Skipped key should be removed from storage
|
||||||
|
let bob_state = store.load_state(&bob_session).unwrap().unwrap();
|
||||||
|
assert!(bob_state.skipped_keys.is_empty());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,216 +0,0 @@
|
|||||||
//! In-memory storage implementation for testing.
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::RwLock;
|
|
||||||
|
|
||||||
use crate::error::StorageError;
|
|
||||||
use crate::traits::{RatchetStorage, SessionId};
|
|
||||||
use crate::types::StorableRatchetState;
|
|
||||||
|
|
||||||
/// In-memory storage backend for testing purposes.
|
|
||||||
///
|
|
||||||
/// This implementation stores ratchet states in a `HashMap` wrapped in a `RwLock`
|
|
||||||
/// for thread-safe access. Data is not persisted across process restarts.
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// use double_ratchets_storage::{MemoryStorage, RatchetStorage};
|
|
||||||
///
|
|
||||||
/// let storage = MemoryStorage::new();
|
|
||||||
/// assert!(storage.list_sessions().unwrap().is_empty());
|
|
||||||
/// ```
|
|
||||||
pub struct MemoryStorage {
|
|
||||||
states: RwLock<HashMap<SessionId, StorableRatchetState>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MemoryStorage {
|
|
||||||
/// Create a new empty in-memory storage.
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
states: RwLock::new(HashMap::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the number of stored sessions.
|
|
||||||
pub fn len(&self) -> usize {
|
|
||||||
self.states.read().unwrap().len()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the storage is empty.
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.states.read().unwrap().is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Clear all stored sessions.
|
|
||||||
pub fn clear(&self) {
|
|
||||||
self.states.write().unwrap().clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for MemoryStorage {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RatchetStorage for MemoryStorage {
|
|
||||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
|
|
||||||
let mut states = self.states.write().unwrap();
|
|
||||||
states.insert(*session_id, state.clone());
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
|
|
||||||
let states = self.states.read().unwrap();
|
|
||||||
Ok(states.get(session_id).cloned())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
|
||||||
let mut states = self.states.write().unwrap();
|
|
||||||
Ok(states.remove(session_id).is_some())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
|
||||||
let states = self.states.read().unwrap();
|
|
||||||
Ok(states.contains_key(session_id))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
|
|
||||||
let states = self.states.read().unwrap();
|
|
||||||
Ok(states.keys().copied().collect())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use double_ratchets::hkdf::DefaultDomain;
|
|
||||||
use double_ratchets::state::RatchetState;
|
|
||||||
use double_ratchets::InstallationKeyPair;
|
|
||||||
|
|
||||||
fn create_test_state() -> StorableRatchetState {
|
|
||||||
let bob_keypair = InstallationKeyPair::generate();
|
|
||||||
let shared_secret = [0x42u8; 32];
|
|
||||||
let state: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
|
||||||
StorableRatchetState::from_ratchet_state(&state, "default")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_save_and_load() {
|
|
||||||
let storage = MemoryStorage::new();
|
|
||||||
let session_id = [1u8; 32];
|
|
||||||
let state = create_test_state();
|
|
||||||
|
|
||||||
storage.save(&session_id, &state).unwrap();
|
|
||||||
let loaded = storage.load(&session_id).unwrap();
|
|
||||||
|
|
||||||
assert!(loaded.is_some());
|
|
||||||
let loaded = loaded.unwrap();
|
|
||||||
assert_eq!(loaded.root_key, state.root_key);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_load_nonexistent() {
|
|
||||||
let storage = MemoryStorage::new();
|
|
||||||
let session_id = [1u8; 32];
|
|
||||||
|
|
||||||
let loaded = storage.load(&session_id).unwrap();
|
|
||||||
assert!(loaded.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_delete() {
|
|
||||||
let storage = MemoryStorage::new();
|
|
||||||
let session_id = [1u8; 32];
|
|
||||||
let state = create_test_state();
|
|
||||||
|
|
||||||
storage.save(&session_id, &state).unwrap();
|
|
||||||
assert!(storage.exists(&session_id).unwrap());
|
|
||||||
|
|
||||||
let deleted = storage.delete(&session_id).unwrap();
|
|
||||||
assert!(deleted);
|
|
||||||
assert!(!storage.exists(&session_id).unwrap());
|
|
||||||
|
|
||||||
// Deleting again should return false
|
|
||||||
let deleted = storage.delete(&session_id).unwrap();
|
|
||||||
assert!(!deleted);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_exists() {
|
|
||||||
let storage = MemoryStorage::new();
|
|
||||||
let session_id = [1u8; 32];
|
|
||||||
|
|
||||||
assert!(!storage.exists(&session_id).unwrap());
|
|
||||||
|
|
||||||
let state = create_test_state();
|
|
||||||
storage.save(&session_id, &state).unwrap();
|
|
||||||
|
|
||||||
assert!(storage.exists(&session_id).unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_list_sessions() {
|
|
||||||
let storage = MemoryStorage::new();
|
|
||||||
|
|
||||||
assert!(storage.list_sessions().unwrap().is_empty());
|
|
||||||
|
|
||||||
let state = create_test_state();
|
|
||||||
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
|
|
||||||
|
|
||||||
for id in &session_ids {
|
|
||||||
storage.save(id, &state).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut listed = storage.list_sessions().unwrap();
|
|
||||||
listed.sort();
|
|
||||||
let mut expected = session_ids.clone();
|
|
||||||
expected.sort();
|
|
||||||
|
|
||||||
assert_eq!(listed, expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_overwrite() {
|
|
||||||
let storage = MemoryStorage::new();
|
|
||||||
let session_id = [1u8; 32];
|
|
||||||
|
|
||||||
// Create first state
|
|
||||||
let bob_keypair1 = InstallationKeyPair::generate();
|
|
||||||
let state1: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
|
|
||||||
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
|
|
||||||
|
|
||||||
// Create second state with different root
|
|
||||||
let bob_keypair2 = InstallationKeyPair::generate();
|
|
||||||
let state2: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
|
|
||||||
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
|
|
||||||
|
|
||||||
// Save first, then overwrite with second
|
|
||||||
storage.save(&session_id, &storable1).unwrap();
|
|
||||||
storage.save(&session_id, &storable2).unwrap();
|
|
||||||
|
|
||||||
// Should have the second state
|
|
||||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
|
||||||
assert_eq!(loaded.root_key, storable2.root_key);
|
|
||||||
assert_ne!(loaded.root_key, storable1.root_key);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_clear() {
|
|
||||||
let storage = MemoryStorage::new();
|
|
||||||
let state = create_test_state();
|
|
||||||
|
|
||||||
for i in 0..5 {
|
|
||||||
storage.save(&[i; 32], &state).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(storage.len(), 5);
|
|
||||||
|
|
||||||
storage.clear();
|
|
||||||
assert!(storage.is_empty());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
554
double-ratchets-storage/src/persistent.rs
Normal file
554
double-ratchets-storage/src/persistent.rs
Normal file
@ -0,0 +1,554 @@
|
|||||||
|
//! Persistent ratchet wrapper that auto-saves state changes.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use double_ratchets::errors::RatchetError;
|
||||||
|
use double_ratchets::hkdf::HkdfInfo;
|
||||||
|
use double_ratchets::state::{Header, RatchetState};
|
||||||
|
use double_ratchets::InstallationKeyPair;
|
||||||
|
use x25519_dalek::PublicKey;
|
||||||
|
|
||||||
|
use crate::error::StorageError;
|
||||||
|
use crate::traits::{RatchetStore, SessionId};
|
||||||
|
|
||||||
|
/// A wrapper around `RatchetState` that automatically persists state changes.
|
||||||
|
///
|
||||||
|
/// This wrapper intercepts `encrypt_message` and `decrypt_message` calls,
|
||||||
|
/// delegates to the underlying `RatchetState`, and then persists the changed
|
||||||
|
/// fields to storage.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use double_ratchets::hkdf::DefaultDomain;
|
||||||
|
/// use double_ratchets::InstallationKeyPair;
|
||||||
|
/// use double_ratchets_storage::{PersistentRatchet, SqliteRatchetStore};
|
||||||
|
/// use std::sync::Arc;
|
||||||
|
///
|
||||||
|
/// let store = Arc::new(SqliteRatchetStore::open_in_memory([0u8; 32]).unwrap());
|
||||||
|
/// let session_id = [1u8; 32];
|
||||||
|
/// let bob_pub = InstallationKeyPair::generate().public().clone();
|
||||||
|
/// let shared_secret = [0x42u8; 32];
|
||||||
|
///
|
||||||
|
/// let mut ratchet: PersistentRatchet<DefaultDomain> =
|
||||||
|
/// PersistentRatchet::init_sender(store, session_id, shared_secret, bob_pub).unwrap();
|
||||||
|
///
|
||||||
|
/// let (ciphertext, header) = ratchet.encrypt_message(b"Hello!").unwrap();
|
||||||
|
/// ```
|
||||||
|
pub struct PersistentRatchet<D: HkdfInfo> {
|
||||||
|
state: RatchetState<D>,
|
||||||
|
store: Arc<dyn RatchetStore>,
|
||||||
|
session_id: SessionId,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<D: HkdfInfo> PersistentRatchet<D> {
|
||||||
|
/// Initialize as the sender (first to send a message).
|
||||||
|
///
|
||||||
|
/// Creates a new ratchet state and persists it to storage.
|
||||||
|
pub fn init_sender(
|
||||||
|
store: Arc<dyn RatchetStore>,
|
||||||
|
session_id: SessionId,
|
||||||
|
shared_secret: [u8; 32],
|
||||||
|
remote_pub: PublicKey,
|
||||||
|
) -> Result<Self, StorageError> {
|
||||||
|
let state = RatchetState::<D>::init_sender(shared_secret, remote_pub);
|
||||||
|
|
||||||
|
// Persist initial state
|
||||||
|
store.init_session(
|
||||||
|
&session_id,
|
||||||
|
&state.root_key,
|
||||||
|
state.sending_chain.as_ref(),
|
||||||
|
state.receiving_chain.as_ref(),
|
||||||
|
&state.dh_self.secret_bytes(),
|
||||||
|
state.dh_self.public().as_bytes(),
|
||||||
|
state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
|
||||||
|
state.msg_send,
|
||||||
|
state.msg_recv,
|
||||||
|
state.prev_chain_len,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
state,
|
||||||
|
store,
|
||||||
|
session_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialize as the receiver (first to receive a message).
|
||||||
|
///
|
||||||
|
/// Creates a new ratchet state and persists it to storage.
|
||||||
|
pub fn init_receiver(
|
||||||
|
store: Arc<dyn RatchetStore>,
|
||||||
|
session_id: SessionId,
|
||||||
|
shared_secret: [u8; 32],
|
||||||
|
dh_self: InstallationKeyPair,
|
||||||
|
) -> Result<Self, StorageError> {
|
||||||
|
let state = RatchetState::<D>::init_receiver(shared_secret, dh_self);
|
||||||
|
|
||||||
|
// Persist initial state
|
||||||
|
store.init_session(
|
||||||
|
&session_id,
|
||||||
|
&state.root_key,
|
||||||
|
state.sending_chain.as_ref(),
|
||||||
|
state.receiving_chain.as_ref(),
|
||||||
|
&state.dh_self.secret_bytes(),
|
||||||
|
state.dh_self.public().as_bytes(),
|
||||||
|
state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
|
||||||
|
state.msg_send,
|
||||||
|
state.msg_recv,
|
||||||
|
state.prev_chain_len,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
state,
|
||||||
|
store,
|
||||||
|
session_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load an existing session from storage.
|
||||||
|
pub fn load(
|
||||||
|
store: Arc<dyn RatchetStore>,
|
||||||
|
session_id: SessionId,
|
||||||
|
) -> Result<Option<Self>, StorageError> {
|
||||||
|
let Some(stored) = store.load_state(&session_id)? else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Reconstruct the keypair
|
||||||
|
let dh_self = InstallationKeyPair::from_bytes(stored.dh_self_secret, stored.dh_self_public)
|
||||||
|
.map_err(|e| StorageError::KeyReconstruction(e.to_string()))?;
|
||||||
|
|
||||||
|
// Reconstruct skipped keys
|
||||||
|
let mut skipped_keys = std::collections::HashMap::new();
|
||||||
|
for entry in stored.skipped_keys {
|
||||||
|
let pub_key = PublicKey::from(entry.dh_public);
|
||||||
|
skipped_keys.insert((pub_key, entry.msg_num), entry.message_key);
|
||||||
|
}
|
||||||
|
|
||||||
|
let state = RatchetState::<D>::from_parts(
|
||||||
|
stored.root_key,
|
||||||
|
stored.sending_chain,
|
||||||
|
stored.receiving_chain,
|
||||||
|
dh_self,
|
||||||
|
stored.dh_remote.map(PublicKey::from),
|
||||||
|
stored.msg_send,
|
||||||
|
stored.msg_recv,
|
||||||
|
stored.prev_chain_len,
|
||||||
|
skipped_keys,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Some(Self {
|
||||||
|
state,
|
||||||
|
store,
|
||||||
|
session_id,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encrypt a message and persist state changes.
|
||||||
|
///
|
||||||
|
/// This may trigger a DH ratchet if the sending direction changed.
|
||||||
|
/// All state changes are persisted to storage after encryption.
|
||||||
|
pub fn encrypt_message(&mut self, plaintext: &[u8]) -> Result<(Vec<u8>, Header), StorageError> {
|
||||||
|
// Check if we'll do a DH ratchet (no sending chain)
|
||||||
|
let will_ratchet = self.state.sending_chain.is_none();
|
||||||
|
|
||||||
|
// Perform encryption
|
||||||
|
let (ciphertext, header) = self.state.encrypt_message(plaintext);
|
||||||
|
|
||||||
|
// Persist changes
|
||||||
|
if will_ratchet {
|
||||||
|
// DH ratchet happened: root_key, sending_chain, dh_self all changed
|
||||||
|
self.store.store_root_and_chains(
|
||||||
|
&self.session_id,
|
||||||
|
&self.state.root_key,
|
||||||
|
self.state.sending_chain.as_ref(),
|
||||||
|
self.state.receiving_chain.as_ref(),
|
||||||
|
)?;
|
||||||
|
self.store.store_dh_self(
|
||||||
|
&self.session_id,
|
||||||
|
&self.state.dh_self.secret_bytes(),
|
||||||
|
self.state.dh_self.public().as_bytes(),
|
||||||
|
)?;
|
||||||
|
} else {
|
||||||
|
// Only sending chain changed
|
||||||
|
self.store.store_root_and_chains(
|
||||||
|
&self.session_id,
|
||||||
|
&self.state.root_key,
|
||||||
|
self.state.sending_chain.as_ref(),
|
||||||
|
self.state.receiving_chain.as_ref(),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Counters always change
|
||||||
|
self.store.store_counters(
|
||||||
|
&self.session_id,
|
||||||
|
self.state.msg_send,
|
||||||
|
self.state.msg_recv,
|
||||||
|
self.state.prev_chain_len,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok((ciphertext, header))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypt a message and persist state changes.
|
||||||
|
///
|
||||||
|
/// Handles DH ratcheting, skipped messages, and replay protection.
|
||||||
|
/// All state changes are persisted to storage after decryption.
|
||||||
|
pub fn decrypt_message(
|
||||||
|
&mut self,
|
||||||
|
ciphertext_with_nonce: &[u8],
|
||||||
|
header: Header,
|
||||||
|
) -> Result<Vec<u8>, PersistentRatchetError> {
|
||||||
|
// Track skipped keys before decryption
|
||||||
|
let skipped_before: std::collections::HashSet<_> = self
|
||||||
|
.state
|
||||||
|
.skipped_keys
|
||||||
|
.keys()
|
||||||
|
.map(|(pk, n)| (*pk.as_bytes(), *n))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Check if we'll do a DH ratchet
|
||||||
|
let will_ratchet = self.state.dh_remote.as_ref() != Some(&header.dh_pub);
|
||||||
|
|
||||||
|
// Perform decryption
|
||||||
|
let plaintext = self
|
||||||
|
.state
|
||||||
|
.decrypt_message(ciphertext_with_nonce, header)
|
||||||
|
.map_err(PersistentRatchetError::Ratchet)?;
|
||||||
|
|
||||||
|
// Track skipped keys after decryption
|
||||||
|
let skipped_after: std::collections::HashSet<_> = self
|
||||||
|
.state
|
||||||
|
.skipped_keys
|
||||||
|
.keys()
|
||||||
|
.map(|(pk, n)| (*pk.as_bytes(), *n))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Persist changes
|
||||||
|
if will_ratchet {
|
||||||
|
// DH ratchet happened
|
||||||
|
self.store
|
||||||
|
.store_root_and_chains(
|
||||||
|
&self.session_id,
|
||||||
|
&self.state.root_key,
|
||||||
|
self.state.sending_chain.as_ref(),
|
||||||
|
self.state.receiving_chain.as_ref(),
|
||||||
|
)
|
||||||
|
.map_err(PersistentRatchetError::Storage)?;
|
||||||
|
self.store
|
||||||
|
.store_dh_remote(
|
||||||
|
&self.session_id,
|
||||||
|
self.state.dh_remote.as_ref().map(|pk| pk.as_bytes()),
|
||||||
|
)
|
||||||
|
.map_err(PersistentRatchetError::Storage)?;
|
||||||
|
} else {
|
||||||
|
// Only receiving chain changed
|
||||||
|
self.store
|
||||||
|
.store_root_and_chains(
|
||||||
|
&self.session_id,
|
||||||
|
&self.state.root_key,
|
||||||
|
self.state.sending_chain.as_ref(),
|
||||||
|
self.state.receiving_chain.as_ref(),
|
||||||
|
)
|
||||||
|
.map_err(PersistentRatchetError::Storage)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Counters
|
||||||
|
self.store
|
||||||
|
.store_counters(
|
||||||
|
&self.session_id,
|
||||||
|
self.state.msg_send,
|
||||||
|
self.state.msg_recv,
|
||||||
|
self.state.prev_chain_len,
|
||||||
|
)
|
||||||
|
.map_err(PersistentRatchetError::Storage)?;
|
||||||
|
|
||||||
|
// Handle skipped keys changes
|
||||||
|
// New skipped keys (added during skip_message_keys)
|
||||||
|
for (pk_bytes, msg_num) in skipped_after.difference(&skipped_before) {
|
||||||
|
let pk = PublicKey::from(*pk_bytes);
|
||||||
|
if let Some(key) = self.state.skipped_keys.get(&(pk, *msg_num)) {
|
||||||
|
self.store
|
||||||
|
.add_skipped_key(&self.session_id, pk_bytes, *msg_num, key)
|
||||||
|
.map_err(PersistentRatchetError::Storage)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Removed skipped keys (used for decryption)
|
||||||
|
for (pk_bytes, msg_num) in skipped_before.difference(&skipped_after) {
|
||||||
|
self.store
|
||||||
|
.remove_skipped_key(&self.session_id, pk_bytes, *msg_num)
|
||||||
|
.map_err(PersistentRatchetError::Storage)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a reference to the underlying state (read-only).
|
||||||
|
pub fn state(&self) -> &RatchetState<D> {
|
||||||
|
&self.state
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the session ID.
|
||||||
|
pub fn session_id(&self) -> &SessionId {
|
||||||
|
&self.session_id
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete this session from storage.
|
||||||
|
pub fn delete(self) -> Result<bool, StorageError> {
|
||||||
|
self.store.delete_session(&self.session_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error type for persistent ratchet operations.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum PersistentRatchetError {
|
||||||
|
/// Storage operation failed.
|
||||||
|
Storage(StorageError),
|
||||||
|
/// Ratchet operation failed.
|
||||||
|
Ratchet(RatchetError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for PersistentRatchetError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Self::Storage(e) => write!(f, "storage error: {}", e),
|
||||||
|
Self::Ratchet(e) => write!(f, "ratchet error: {:?}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for PersistentRatchetError {
|
||||||
|
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||||
|
match self {
|
||||||
|
Self::Storage(e) => Some(e),
|
||||||
|
Self::Ratchet(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<StorageError> for PersistentRatchetError {
|
||||||
|
fn from(e: StorageError) -> Self {
|
||||||
|
Self::Storage(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<RatchetError> for PersistentRatchetError {
|
||||||
|
fn from(e: RatchetError) -> Self {
|
||||||
|
Self::Ratchet(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::sqlite::SqliteRatchetStore;
|
||||||
|
use double_ratchets::hkdf::DefaultDomain;
|
||||||
|
|
||||||
|
fn test_store() -> Arc<dyn RatchetStore> {
|
||||||
|
Arc::new(SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_basic_roundtrip() {
|
||||||
|
let store = test_store();
|
||||||
|
let alice_session = [0xAA; 32];
|
||||||
|
let bob_session = [0xBB; 32];
|
||||||
|
|
||||||
|
let bob_keypair = InstallationKeyPair::generate();
|
||||||
|
let shared_secret = [0x42u8; 32];
|
||||||
|
|
||||||
|
// Initialize both parties
|
||||||
|
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
|
Arc::clone(&store),
|
||||||
|
alice_session,
|
||||||
|
shared_secret,
|
||||||
|
*bob_keypair.public(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||||
|
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Alice sends a message
|
||||||
|
let (ct, header) = alice.encrypt_message(b"Hello Bob!").unwrap();
|
||||||
|
let pt = bob.decrypt_message(&ct, header).unwrap();
|
||||||
|
assert_eq!(pt, b"Hello Bob!");
|
||||||
|
|
||||||
|
// Verify state was persisted
|
||||||
|
let alice_loaded = store.load_state(&alice_session).unwrap().unwrap();
|
||||||
|
assert_eq!(alice_loaded.msg_send, 1);
|
||||||
|
|
||||||
|
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
|
||||||
|
assert_eq!(bob_loaded.msg_recv, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_load_and_continue() {
|
||||||
|
let store = test_store();
|
||||||
|
let alice_session = [0xAA; 32];
|
||||||
|
let bob_session = [0xBB; 32];
|
||||||
|
|
||||||
|
let bob_keypair = InstallationKeyPair::generate();
|
||||||
|
let shared_secret = [0x42u8; 32];
|
||||||
|
|
||||||
|
// Initialize and exchange one message
|
||||||
|
{
|
||||||
|
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
|
Arc::clone(&store),
|
||||||
|
alice_session,
|
||||||
|
shared_secret,
|
||||||
|
*bob_keypair.public(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut bob: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_receiver(
|
||||||
|
Arc::clone(&store),
|
||||||
|
bob_session,
|
||||||
|
shared_secret,
|
||||||
|
bob_keypair,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let (ct, header) = alice.encrypt_message(b"Message 1").unwrap();
|
||||||
|
bob.decrypt_message(&ct, header).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load from storage and continue
|
||||||
|
{
|
||||||
|
let mut alice: PersistentRatchet<DefaultDomain> =
|
||||||
|
PersistentRatchet::load(Arc::clone(&store), alice_session)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||||
|
PersistentRatchet::load(Arc::clone(&store), bob_session)
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Bob replies
|
||||||
|
let (ct, header) = bob.encrypt_message(b"Reply from Bob").unwrap();
|
||||||
|
let pt = alice.decrypt_message(&ct, header).unwrap();
|
||||||
|
assert_eq!(pt, b"Reply from Bob");
|
||||||
|
|
||||||
|
// Alice sends another
|
||||||
|
let (ct2, header2) = alice.encrypt_message(b"Message 2").unwrap();
|
||||||
|
let pt2 = bob.decrypt_message(&ct2, header2).unwrap();
|
||||||
|
assert_eq!(pt2, b"Message 2");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_skipped_keys_persisted() {
|
||||||
|
let store = test_store();
|
||||||
|
let alice_session = [0xAA; 32];
|
||||||
|
let bob_session = [0xBB; 32];
|
||||||
|
|
||||||
|
let bob_keypair = InstallationKeyPair::generate();
|
||||||
|
let shared_secret = [0x42u8; 32];
|
||||||
|
|
||||||
|
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
|
Arc::clone(&store),
|
||||||
|
alice_session,
|
||||||
|
shared_secret,
|
||||||
|
*bob_keypair.public(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||||
|
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Alice sends 3 messages
|
||||||
|
let mut messages = vec![];
|
||||||
|
for i in 0..3 {
|
||||||
|
let (ct, header) = alice
|
||||||
|
.encrypt_message(format!("Message {}", i).as_bytes())
|
||||||
|
.unwrap();
|
||||||
|
messages.push((ct, header));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bob receives them out of order: 0, 2 (skipping 1)
|
||||||
|
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
||||||
|
.unwrap();
|
||||||
|
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Check skipped key was persisted
|
||||||
|
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
|
||||||
|
assert_eq!(bob_loaded.skipped_keys.len(), 1);
|
||||||
|
assert_eq!(bob_loaded.skipped_keys[0].msg_num, 1);
|
||||||
|
|
||||||
|
// Now receive the skipped message
|
||||||
|
bob.decrypt_message(&messages[1].0, messages[1].1.clone())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Skipped key should be removed
|
||||||
|
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
|
||||||
|
assert!(bob_loaded.skipped_keys.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_dh_ratchet_persisted() {
|
||||||
|
let store = test_store();
|
||||||
|
let alice_session = [0xAA; 32];
|
||||||
|
let bob_session = [0xBB; 32];
|
||||||
|
|
||||||
|
let bob_keypair = InstallationKeyPair::generate();
|
||||||
|
let shared_secret = [0x42u8; 32];
|
||||||
|
|
||||||
|
let mut alice: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
|
Arc::clone(&store),
|
||||||
|
alice_session,
|
||||||
|
shared_secret,
|
||||||
|
*bob_keypair.public(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut bob: PersistentRatchet<DefaultDomain> =
|
||||||
|
PersistentRatchet::init_receiver(Arc::clone(&store), bob_session, shared_secret, bob_keypair)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Alice sends
|
||||||
|
let (ct1, h1) = alice.encrypt_message(b"Hello").unwrap();
|
||||||
|
bob.decrypt_message(&ct1, h1).unwrap();
|
||||||
|
|
||||||
|
// Get Bob's initial DH public
|
||||||
|
let bob_initial_pub = bob.state().dh_self.public().as_bytes().clone();
|
||||||
|
|
||||||
|
// Bob replies (triggers DH ratchet)
|
||||||
|
let (ct2, h2) = bob.encrypt_message(b"Hi").unwrap();
|
||||||
|
alice.decrypt_message(&ct2, h2).unwrap();
|
||||||
|
|
||||||
|
// Bob's DH key should have changed
|
||||||
|
let bob_new_pub = bob.state().dh_self.public().as_bytes().clone();
|
||||||
|
assert_ne!(bob_initial_pub, bob_new_pub);
|
||||||
|
|
||||||
|
// Verify persisted
|
||||||
|
let bob_loaded = store.load_state(&bob_session).unwrap().unwrap();
|
||||||
|
assert_eq!(bob_loaded.dh_self_public, bob_new_pub);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_delete_session() {
|
||||||
|
let store = test_store();
|
||||||
|
let session_id = [0x11; 32];
|
||||||
|
let bob_keypair = InstallationKeyPair::generate();
|
||||||
|
|
||||||
|
let ratchet: PersistentRatchet<DefaultDomain> = PersistentRatchet::init_sender(
|
||||||
|
Arc::clone(&store),
|
||||||
|
session_id,
|
||||||
|
[0x42u8; 32],
|
||||||
|
*bob_keypair.public(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(store.session_exists(&session_id).unwrap());
|
||||||
|
|
||||||
|
ratchet.delete().unwrap();
|
||||||
|
|
||||||
|
assert!(!store.session_exists(&session_id).unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
|
||||||
|
|
||||||
use chacha20poly1305::{
|
use chacha20poly1305::{
|
||||||
aead::{Aead, KeyInit},
|
aead::{Aead, KeyInit},
|
||||||
@ -12,85 +11,37 @@ use rand::RngCore;
|
|||||||
use rusqlite::{params, Connection, OptionalExtension};
|
use rusqlite::{params, Connection, OptionalExtension};
|
||||||
|
|
||||||
use crate::error::StorageError;
|
use crate::error::StorageError;
|
||||||
use crate::traits::{RatchetStorage, SessionId};
|
use crate::traits::{RatchetStore, SessionId, SkippedKeyEntry, StoredState};
|
||||||
use crate::types::{SkippedKey, StorableRatchetState};
|
|
||||||
|
|
||||||
/// Field encryption key type (32 bytes for ChaCha20Poly1305).
|
/// Field encryption key type (32 bytes for ChaCha20Poly1305).
|
||||||
pub type EncryptionKey = [u8; 32];
|
pub type EncryptionKey = [u8; 32];
|
||||||
|
|
||||||
/// SQLite storage backend with field-level encryption for secrets.
|
/// SQLite storage with field-level encryption for secrets.
|
||||||
///
|
///
|
||||||
/// This implementation stores ratchet states in SQLite with:
|
/// Schema:
|
||||||
/// - Field-level encryption for private keys using ChaCha20Poly1305
|
/// - `sessions`: Core ratchet state (one row per session)
|
||||||
/// - WAL mode for better concurrent performance
|
/// - `skipped_keys`: Skipped message keys (many per session)
|
||||||
/// - Foreign keys and cascading deletes for data integrity
|
|
||||||
/// - Atomic transactions to prevent partial writes
|
|
||||||
///
|
///
|
||||||
/// # Security
|
/// Encrypted fields: dh_self_secret, skipped message_key
|
||||||
///
|
pub struct SqliteRatchetStore {
|
||||||
/// The `dh_self_secret` field is encrypted with the provided encryption key.
|
|
||||||
/// For additional security, consider using SQLCipher for full database encryption
|
|
||||||
/// via the `open_encrypted` method (requires `sqlcipher` feature).
|
|
||||||
///
|
|
||||||
/// # Example
|
|
||||||
///
|
|
||||||
/// ```no_run
|
|
||||||
/// use double_ratchets_storage::SqliteStorage;
|
|
||||||
///
|
|
||||||
/// let key = [0u8; 32]; // Use a proper key derivation function
|
|
||||||
/// let storage = SqliteStorage::open("ratchets.db", key).unwrap();
|
|
||||||
/// ```
|
|
||||||
pub struct SqliteStorage {
|
|
||||||
conn: Mutex<Connection>,
|
conn: Mutex<Connection>,
|
||||||
encryption_key: EncryptionKey,
|
encryption_key: EncryptionKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SqliteStorage {
|
impl SqliteRatchetStore {
|
||||||
/// Open or create a SQLite database with field-level encryption.
|
/// Open or create a SQLite database.
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `path` - Path to the database file.
|
|
||||||
/// * `encryption_key` - 32-byte key for field-level encryption.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `Ok(SqliteStorage)` on success.
|
|
||||||
/// * `Err(StorageError)` on failure.
|
|
||||||
pub fn open<P: AsRef<Path>>(path: P, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
pub fn open<P: AsRef<Path>>(path: P, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||||
let conn = Connection::open(path)?;
|
let conn = Connection::open(path)?;
|
||||||
Self::initialize(conn, encryption_key)
|
Self::initialize(conn, encryption_key)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create an in-memory SQLite database (for testing).
|
/// Create an in-memory database (for testing).
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `encryption_key` - 32-byte key for field-level encryption.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `Ok(SqliteStorage)` on success.
|
|
||||||
/// * `Err(StorageError)` on failure.
|
|
||||||
pub fn open_in_memory(encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
pub fn open_in_memory(encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||||
let conn = Connection::open_in_memory()?;
|
let conn = Connection::open_in_memory()?;
|
||||||
Self::initialize(conn, encryption_key)
|
Self::initialize(conn, encryption_key)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Open or create a SQLCipher-encrypted database.
|
/// Open with SQLCipher full-database encryption.
|
||||||
///
|
|
||||||
/// This method requires the `sqlcipher` feature to be enabled.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `path` - Path to the database file.
|
|
||||||
/// * `db_password` - Password for SQLCipher database encryption.
|
|
||||||
/// * `field_key` - 32-byte key for additional field-level encryption.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `Ok(SqliteStorage)` on success.
|
|
||||||
/// * `Err(StorageError)` on failure.
|
|
||||||
#[cfg(feature = "sqlcipher")]
|
#[cfg(feature = "sqlcipher")]
|
||||||
pub fn open_encrypted<P: AsRef<Path>>(
|
pub fn open_encrypted<P: AsRef<Path>>(
|
||||||
path: P,
|
path: P,
|
||||||
@ -98,50 +49,39 @@ impl SqliteStorage {
|
|||||||
field_key: EncryptionKey,
|
field_key: EncryptionKey,
|
||||||
) -> Result<Self, StorageError> {
|
) -> Result<Self, StorageError> {
|
||||||
let conn = Connection::open(path)?;
|
let conn = Connection::open(path)?;
|
||||||
|
|
||||||
// Set SQLCipher key
|
|
||||||
conn.pragma_update(None, "key", db_password)?;
|
conn.pragma_update(None, "key", db_password)?;
|
||||||
|
|
||||||
Self::initialize(conn, field_key)
|
Self::initialize(conn, field_key)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||||
// Enable WAL mode for better performance
|
|
||||||
conn.pragma_update(None, "journal_mode", "WAL")?;
|
conn.pragma_update(None, "journal_mode", "WAL")?;
|
||||||
|
|
||||||
// Enable foreign keys
|
|
||||||
conn.pragma_update(None, "foreign_keys", "ON")?;
|
conn.pragma_update(None, "foreign_keys", "ON")?;
|
||||||
|
|
||||||
// Create tables
|
|
||||||
conn.execute_batch(
|
conn.execute_batch(
|
||||||
r#"
|
r#"
|
||||||
CREATE TABLE IF NOT EXISTS ratchet_states (
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
session_id BLOB PRIMARY KEY,
|
session_id BLOB PRIMARY KEY NOT NULL,
|
||||||
root_key BLOB NOT NULL,
|
root_key BLOB NOT NULL,
|
||||||
sending_chain BLOB,
|
sending_chain BLOB,
|
||||||
receiving_chain BLOB,
|
receiving_chain BLOB,
|
||||||
dh_self_secret_encrypted BLOB NOT NULL,
|
dh_secret_enc BLOB NOT NULL,
|
||||||
dh_self_secret_nonce BLOB NOT NULL,
|
dh_secret_nonce BLOB NOT NULL,
|
||||||
dh_self_public BLOB NOT NULL,
|
dh_public BLOB NOT NULL,
|
||||||
dh_remote BLOB,
|
dh_remote BLOB,
|
||||||
msg_send INTEGER NOT NULL,
|
msg_send INTEGER NOT NULL,
|
||||||
msg_recv INTEGER NOT NULL,
|
msg_recv INTEGER NOT NULL,
|
||||||
prev_chain_len INTEGER NOT NULL,
|
prev_chain_len INTEGER NOT NULL
|
||||||
domain_id TEXT NOT NULL,
|
|
||||||
created_at INTEGER NOT NULL,
|
|
||||||
updated_at INTEGER NOT NULL
|
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS skipped_keys (
|
CREATE TABLE IF NOT EXISTS skipped_keys (
|
||||||
id INTEGER PRIMARY KEY,
|
session_id BLOB NOT NULL,
|
||||||
session_id BLOB NOT NULL REFERENCES ratchet_states(session_id) ON DELETE CASCADE,
|
dh_public BLOB NOT NULL,
|
||||||
public_key BLOB NOT NULL,
|
|
||||||
msg_num INTEGER NOT NULL,
|
msg_num INTEGER NOT NULL,
|
||||||
message_key BLOB NOT NULL,
|
message_key_enc BLOB NOT NULL,
|
||||||
UNIQUE(session_id, public_key, msg_num)
|
message_key_nonce BLOB NOT NULL,
|
||||||
|
PRIMARY KEY (session_id, dh_public, msg_num),
|
||||||
|
FOREIGN KEY (session_id) REFERENCES sessions(session_id) ON DELETE CASCADE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_skipped_keys_session ON skipped_keys(session_id);
|
|
||||||
"#,
|
"#,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
@ -151,23 +91,20 @@ impl SqliteStorage {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Encrypt a 32-byte secret using ChaCha20Poly1305.
|
fn encrypt(&self, plaintext: &[u8; 32]) -> Result<(Vec<u8>, [u8; 12]), StorageError> {
|
||||||
fn encrypt_secret(&self, secret: &[u8; 32]) -> Result<(Vec<u8>, [u8; 12]), StorageError> {
|
|
||||||
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
|
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
|
||||||
|
|
||||||
let mut nonce_bytes = [0u8; 12];
|
let mut nonce_bytes = [0u8; 12];
|
||||||
rand::thread_rng().fill_bytes(&mut nonce_bytes);
|
rand::thread_rng().fill_bytes(&mut nonce_bytes);
|
||||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||||
|
|
||||||
let ciphertext = cipher
|
let ciphertext = cipher
|
||||||
.encrypt(nonce, secret.as_ref())
|
.encrypt(nonce, plaintext.as_ref())
|
||||||
.map_err(|e| StorageError::Encryption(e.to_string()))?;
|
.map_err(|e| StorageError::Encryption(e.to_string()))?;
|
||||||
|
|
||||||
Ok((ciphertext, nonce_bytes))
|
Ok((ciphertext, nonce_bytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decrypt a secret using ChaCha20Poly1305.
|
fn decrypt(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> {
|
||||||
fn decrypt_secret(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> {
|
|
||||||
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
|
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
|
||||||
let nonce = Nonce::from_slice(nonce);
|
let nonce = Nonce::from_slice(nonce);
|
||||||
|
|
||||||
@ -177,135 +114,110 @@ impl SqliteStorage {
|
|||||||
|
|
||||||
plaintext
|
plaintext
|
||||||
.try_into()
|
.try_into()
|
||||||
.map_err(|_| StorageError::CorruptedState("decrypted secret has wrong length".to_string()))
|
.map_err(|_| StorageError::CorruptedState("wrong decrypted length".into()))
|
||||||
}
|
|
||||||
|
|
||||||
fn current_timestamp() -> i64 {
|
|
||||||
SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.unwrap()
|
|
||||||
.as_secs() as i64
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RatchetStorage for SqliteStorage {
|
impl RatchetStore for SqliteRatchetStore {
|
||||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
|
fn store_root_and_chains(
|
||||||
let (encrypted_secret, nonce) = self.encrypt_secret(&state.dh_self_secret)?;
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
root_key: &[u8; 32],
|
||||||
|
sending_chain: Option<&[u8; 32]>,
|
||||||
|
receiving_chain: Option<&[u8; 32]>,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
let conn = self.conn.lock().unwrap();
|
let conn = self.conn.lock().unwrap();
|
||||||
let tx = conn.unchecked_transaction()?;
|
conn.execute(
|
||||||
|
"UPDATE sessions SET root_key = ?, sending_chain = ?, receiving_chain = ? WHERE session_id = ?",
|
||||||
let now = Self::current_timestamp();
|
params![
|
||||||
|
root_key.as_slice(),
|
||||||
// Check if session exists to determine created_at
|
sending_chain.map(|c| c.as_slice()),
|
||||||
let exists: bool = tx.query_row(
|
receiving_chain.map(|c| c.as_slice()),
|
||||||
"SELECT 1 FROM ratchet_states WHERE session_id = ?",
|
session_id.as_slice(),
|
||||||
[session_id.as_slice()],
|
],
|
||||||
|_| Ok(true),
|
)?;
|
||||||
).optional()?.unwrap_or(false);
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
// Update existing session
|
|
||||||
tx.execute(
|
|
||||||
r#"
|
|
||||||
UPDATE ratchet_states SET
|
|
||||||
root_key = ?,
|
|
||||||
sending_chain = ?,
|
|
||||||
receiving_chain = ?,
|
|
||||||
dh_self_secret_encrypted = ?,
|
|
||||||
dh_self_secret_nonce = ?,
|
|
||||||
dh_self_public = ?,
|
|
||||||
dh_remote = ?,
|
|
||||||
msg_send = ?,
|
|
||||||
msg_recv = ?,
|
|
||||||
prev_chain_len = ?,
|
|
||||||
domain_id = ?,
|
|
||||||
updated_at = ?
|
|
||||||
WHERE session_id = ?
|
|
||||||
"#,
|
|
||||||
params![
|
|
||||||
state.root_key.as_slice(),
|
|
||||||
state.sending_chain.as_ref().map(|c| c.as_slice()),
|
|
||||||
state.receiving_chain.as_ref().map(|c| c.as_slice()),
|
|
||||||
encrypted_secret.as_slice(),
|
|
||||||
nonce.as_slice(),
|
|
||||||
state.dh_self_public.as_slice(),
|
|
||||||
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
|
|
||||||
state.msg_send,
|
|
||||||
state.msg_recv,
|
|
||||||
state.prev_chain_len,
|
|
||||||
&state.domain_id,
|
|
||||||
now,
|
|
||||||
session_id.as_slice(),
|
|
||||||
],
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Delete existing skipped keys
|
|
||||||
tx.execute(
|
|
||||||
"DELETE FROM skipped_keys WHERE session_id = ?",
|
|
||||||
[session_id.as_slice()],
|
|
||||||
)?;
|
|
||||||
} else {
|
|
||||||
// Insert new session
|
|
||||||
tx.execute(
|
|
||||||
r#"
|
|
||||||
INSERT INTO ratchet_states (
|
|
||||||
session_id, root_key, sending_chain, receiving_chain,
|
|
||||||
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
|
|
||||||
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id,
|
|
||||||
created_at, updated_at
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
"#,
|
|
||||||
params![
|
|
||||||
session_id.as_slice(),
|
|
||||||
state.root_key.as_slice(),
|
|
||||||
state.sending_chain.as_ref().map(|c| c.as_slice()),
|
|
||||||
state.receiving_chain.as_ref().map(|c| c.as_slice()),
|
|
||||||
encrypted_secret.as_slice(),
|
|
||||||
nonce.as_slice(),
|
|
||||||
state.dh_self_public.as_slice(),
|
|
||||||
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
|
|
||||||
state.msg_send,
|
|
||||||
state.msg_recv,
|
|
||||||
state.prev_chain_len,
|
|
||||||
&state.domain_id,
|
|
||||||
now,
|
|
||||||
now,
|
|
||||||
],
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert skipped keys
|
|
||||||
for sk in &state.skipped_keys {
|
|
||||||
tx.execute(
|
|
||||||
r#"
|
|
||||||
INSERT INTO skipped_keys (session_id, public_key, msg_num, message_key)
|
|
||||||
VALUES (?, ?, ?, ?)
|
|
||||||
"#,
|
|
||||||
params![
|
|
||||||
session_id.as_slice(),
|
|
||||||
sk.public_key.as_slice(),
|
|
||||||
sk.msg_num,
|
|
||||||
sk.message_key.as_slice(),
|
|
||||||
],
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
tx.commit()?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
|
fn store_dh_self(
|
||||||
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
secret: &[u8; 32],
|
||||||
|
public: &[u8; 32],
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
let (enc, nonce) = self.encrypt(secret)?;
|
||||||
|
let conn = self.conn.lock().unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE sessions SET dh_secret_enc = ?, dh_secret_nonce = ?, dh_public = ? WHERE session_id = ?",
|
||||||
|
params![enc.as_slice(), nonce.as_slice(), public.as_slice(), session_id.as_slice()],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn store_dh_remote(
|
||||||
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
remote: Option<&[u8; 32]>,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
let conn = self.conn.lock().unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE sessions SET dh_remote = ? WHERE session_id = ?",
|
||||||
|
params![remote.map(|r| r.as_slice()), session_id.as_slice()],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn store_counters(
|
||||||
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
msg_send: u32,
|
||||||
|
msg_recv: u32,
|
||||||
|
prev_chain_len: u32,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
let conn = self.conn.lock().unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE sessions SET msg_send = ?, msg_recv = ?, prev_chain_len = ? WHERE session_id = ?",
|
||||||
|
params![msg_send, msg_recv, prev_chain_len, session_id.as_slice()],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_skipped_key(
|
||||||
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
dh_public: &[u8; 32],
|
||||||
|
msg_num: u32,
|
||||||
|
message_key: &[u8; 32],
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
let (enc, nonce) = self.encrypt(message_key)?;
|
||||||
|
let conn = self.conn.lock().unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO skipped_keys (session_id, dh_public, msg_num, message_key_enc, message_key_nonce) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
params![session_id.as_slice(), dh_public.as_slice(), msg_num, enc.as_slice(), nonce.as_slice()],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_skipped_key(
|
||||||
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
dh_public: &[u8; 32],
|
||||||
|
msg_num: u32,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
let conn = self.conn.lock().unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM skipped_keys WHERE session_id = ? AND dh_public = ? AND msg_num = ?",
|
||||||
|
params![session_id.as_slice(), dh_public.as_slice(), msg_num],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_state(&self, session_id: &SessionId) -> Result<Option<StoredState>, StorageError> {
|
||||||
let conn = self.conn.lock().unwrap();
|
let conn = self.conn.lock().unwrap();
|
||||||
|
|
||||||
let row = conn
|
let row = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
r#"
|
"SELECT root_key, sending_chain, receiving_chain, dh_secret_enc, dh_secret_nonce, dh_public, dh_remote, msg_send, msg_recv, prev_chain_len FROM sessions WHERE session_id = ?",
|
||||||
SELECT root_key, sending_chain, receiving_chain,
|
|
||||||
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
|
|
||||||
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id
|
|
||||||
FROM ratchet_states WHERE session_id = ?
|
|
||||||
"#,
|
|
||||||
[session_id.as_slice()],
|
[session_id.as_slice()],
|
||||||
|row| {
|
|row| {
|
||||||
Ok((
|
Ok((
|
||||||
@ -319,91 +231,64 @@ impl RatchetStorage for SqliteStorage {
|
|||||||
row.get::<_, u32>(7)?,
|
row.get::<_, u32>(7)?,
|
||||||
row.get::<_, u32>(8)?,
|
row.get::<_, u32>(8)?,
|
||||||
row.get::<_, u32>(9)?,
|
row.get::<_, u32>(9)?,
|
||||||
row.get::<_, String>(10)?,
|
|
||||||
))
|
))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.optional()?;
|
.optional()?;
|
||||||
|
|
||||||
let Some((
|
let Some((
|
||||||
root_key_bytes,
|
root_key_bytes, sending_bytes, receiving_bytes,
|
||||||
sending_chain_bytes,
|
secret_enc, secret_nonce, dh_pub_bytes, dh_remote_bytes,
|
||||||
receiving_chain_bytes,
|
msg_send, msg_recv, prev_chain_len,
|
||||||
encrypted_secret,
|
)) = row else {
|
||||||
nonce_bytes,
|
|
||||||
dh_self_public_bytes,
|
|
||||||
dh_remote_bytes,
|
|
||||||
msg_send,
|
|
||||||
msg_recv,
|
|
||||||
prev_chain_len,
|
|
||||||
domain_id,
|
|
||||||
)) = row
|
|
||||||
else {
|
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Decrypt the secret
|
let nonce: [u8; 12] = secret_nonce.try_into()
|
||||||
let nonce: [u8; 12] = nonce_bytes
|
.map_err(|_| StorageError::CorruptedState("invalid nonce".into()))?;
|
||||||
.try_into()
|
let dh_self_secret = self.decrypt(&secret_enc, &nonce)?;
|
||||||
.map_err(|_| StorageError::CorruptedState("invalid nonce length".to_string()))?;
|
|
||||||
let dh_self_secret = self.decrypt_secret(&encrypted_secret, &nonce)?;
|
|
||||||
|
|
||||||
// Convert byte vectors to arrays
|
let root_key: [u8; 32] = root_key_bytes.try_into()
|
||||||
let root_key: [u8; 32] = root_key_bytes
|
.map_err(|_| StorageError::CorruptedState("invalid root_key".into()))?;
|
||||||
.try_into()
|
let dh_self_public: [u8; 32] = dh_pub_bytes.try_into()
|
||||||
.map_err(|_| StorageError::CorruptedState("invalid root_key length".to_string()))?;
|
.map_err(|_| StorageError::CorruptedState("invalid dh_public".into()))?;
|
||||||
let dh_self_public: [u8; 32] = dh_self_public_bytes
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| StorageError::CorruptedState("invalid dh_self_public length".to_string()))?;
|
|
||||||
|
|
||||||
let sending_chain = sending_chain_bytes
|
let sending_chain = sending_bytes
|
||||||
.map(|b| {
|
.map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid sending_chain".into())))
|
||||||
b.try_into()
|
|
||||||
.map_err(|_| StorageError::CorruptedState("invalid sending_chain length".to_string()))
|
|
||||||
})
|
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
let receiving_chain = receiving_chain_bytes
|
let receiving_chain = receiving_bytes
|
||||||
.map(|b| {
|
.map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid receiving_chain".into())))
|
||||||
b.try_into()
|
|
||||||
.map_err(|_| StorageError::CorruptedState("invalid receiving_chain length".to_string()))
|
|
||||||
})
|
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
let dh_remote = dh_remote_bytes
|
let dh_remote = dh_remote_bytes
|
||||||
.map(|b| {
|
.map(|b| b.try_into().map_err(|_| StorageError::CorruptedState("invalid dh_remote".into())))
|
||||||
b.try_into()
|
|
||||||
.map_err(|_| StorageError::CorruptedState("invalid dh_remote length".to_string()))
|
|
||||||
})
|
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
|
||||||
// Load skipped keys
|
// Load skipped keys
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT public_key, msg_num, message_key FROM skipped_keys WHERE session_id = ?",
|
"SELECT dh_public, msg_num, message_key_enc, message_key_nonce FROM skipped_keys WHERE session_id = ?",
|
||||||
)?;
|
)?;
|
||||||
let skipped_keys: Vec<SkippedKey> = stmt
|
let mut skipped_keys = Vec::new();
|
||||||
.query_map([session_id.as_slice()], |row| {
|
let rows = stmt.query_map([session_id.as_slice()], |row| {
|
||||||
let pk_bytes: Vec<u8> = row.get(0)?;
|
Ok((
|
||||||
let msg_num: u32 = row.get(1)?;
|
row.get::<_, Vec<u8>>(0)?,
|
||||||
let mk_bytes: Vec<u8> = row.get(2)?;
|
row.get::<_, u32>(1)?,
|
||||||
Ok((pk_bytes, msg_num, mk_bytes))
|
row.get::<_, Vec<u8>>(2)?,
|
||||||
})?
|
row.get::<_, Vec<u8>>(3)?,
|
||||||
.collect::<Result<Vec<_>, _>>()?
|
))
|
||||||
.into_iter()
|
})?;
|
||||||
.map(|(pk_bytes, msg_num, mk_bytes)| {
|
|
||||||
let public_key: [u8; 32] = pk_bytes
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| StorageError::CorruptedState("invalid skipped public_key length".to_string()))?;
|
|
||||||
let message_key: [u8; 32] = mk_bytes
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| StorageError::CorruptedState("invalid skipped message_key length".to_string()))?;
|
|
||||||
Ok(SkippedKey {
|
|
||||||
public_key,
|
|
||||||
msg_num,
|
|
||||||
message_key,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect::<Result<Vec<_>, StorageError>>()?;
|
|
||||||
|
|
||||||
Ok(Some(StorableRatchetState {
|
for row in rows {
|
||||||
|
let (pk_bytes, msg_num, key_enc, key_nonce) = row?;
|
||||||
|
let dh_public: [u8; 32] = pk_bytes.try_into()
|
||||||
|
.map_err(|_| StorageError::CorruptedState("invalid skipped dh_public".into()))?;
|
||||||
|
let nonce: [u8; 12] = key_nonce.try_into()
|
||||||
|
.map_err(|_| StorageError::CorruptedState("invalid skipped nonce".into()))?;
|
||||||
|
let message_key = self.decrypt(&key_enc, &nonce)?;
|
||||||
|
|
||||||
|
skipped_keys.push(SkippedKeyEntry { dh_public, msg_num, message_key });
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(StoredState {
|
||||||
root_key,
|
root_key,
|
||||||
sending_chain,
|
sending_chain,
|
||||||
receiving_chain,
|
receiving_chain,
|
||||||
@ -414,24 +299,72 @@ impl RatchetStorage for SqliteStorage {
|
|||||||
msg_recv,
|
msg_recv,
|
||||||
prev_chain_len,
|
prev_chain_len,
|
||||||
skipped_keys,
|
skipped_keys,
|
||||||
domain_id,
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
fn init_session(
|
||||||
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
root_key: &[u8; 32],
|
||||||
|
sending_chain: Option<&[u8; 32]>,
|
||||||
|
receiving_chain: Option<&[u8; 32]>,
|
||||||
|
dh_self_secret: &[u8; 32],
|
||||||
|
dh_self_public: &[u8; 32],
|
||||||
|
dh_remote: Option<&[u8; 32]>,
|
||||||
|
msg_send: u32,
|
||||||
|
msg_recv: u32,
|
||||||
|
prev_chain_len: u32,
|
||||||
|
) -> Result<(), StorageError> {
|
||||||
|
let (secret_enc, secret_nonce) = self.encrypt(dh_self_secret)?;
|
||||||
|
let conn = self.conn.lock().unwrap();
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
r#"
|
||||||
|
INSERT INTO sessions (session_id, root_key, sending_chain, receiving_chain, dh_secret_enc, dh_secret_nonce, dh_public, dh_remote, msg_send, msg_recv, prev_chain_len)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(session_id) DO UPDATE SET
|
||||||
|
root_key = excluded.root_key,
|
||||||
|
sending_chain = excluded.sending_chain,
|
||||||
|
receiving_chain = excluded.receiving_chain,
|
||||||
|
dh_secret_enc = excluded.dh_secret_enc,
|
||||||
|
dh_secret_nonce = excluded.dh_secret_nonce,
|
||||||
|
dh_public = excluded.dh_public,
|
||||||
|
dh_remote = excluded.dh_remote,
|
||||||
|
msg_send = excluded.msg_send,
|
||||||
|
msg_recv = excluded.msg_recv,
|
||||||
|
prev_chain_len = excluded.prev_chain_len
|
||||||
|
"#,
|
||||||
|
params![
|
||||||
|
session_id.as_slice(),
|
||||||
|
root_key.as_slice(),
|
||||||
|
sending_chain.map(|c| c.as_slice()),
|
||||||
|
receiving_chain.map(|c| c.as_slice()),
|
||||||
|
secret_enc.as_slice(),
|
||||||
|
secret_nonce.as_slice(),
|
||||||
|
dh_self_public.as_slice(),
|
||||||
|
dh_remote.map(|r| r.as_slice()),
|
||||||
|
msg_send,
|
||||||
|
msg_recv,
|
||||||
|
prev_chain_len,
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn delete_session(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||||
let conn = self.conn.lock().unwrap();
|
let conn = self.conn.lock().unwrap();
|
||||||
let changes = conn.execute(
|
let changes = conn.execute(
|
||||||
"DELETE FROM ratchet_states WHERE session_id = ?",
|
"DELETE FROM sessions WHERE session_id = ?",
|
||||||
[session_id.as_slice()],
|
[session_id.as_slice()],
|
||||||
)?;
|
)?;
|
||||||
Ok(changes > 0)
|
Ok(changes > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
fn session_exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||||
let conn = self.conn.lock().unwrap();
|
let conn = self.conn.lock().unwrap();
|
||||||
let exists: bool = conn
|
let exists: bool = conn
|
||||||
.query_row(
|
.query_row(
|
||||||
"SELECT 1 FROM ratchet_states WHERE session_id = ?",
|
"SELECT 1 FROM sessions WHERE session_id = ?",
|
||||||
[session_id.as_slice()],
|
[session_id.as_slice()],
|
||||||
|_| Ok(true),
|
|_| Ok(true),
|
||||||
)
|
)
|
||||||
@ -442,15 +375,11 @@ impl RatchetStorage for SqliteStorage {
|
|||||||
|
|
||||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
|
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
|
||||||
let conn = self.conn.lock().unwrap();
|
let conn = self.conn.lock().unwrap();
|
||||||
let mut stmt = conn.prepare("SELECT session_id FROM ratchet_states")?;
|
let mut stmt = conn.prepare("SELECT session_id FROM sessions")?;
|
||||||
let sessions: Vec<SessionId> = stmt
|
let sessions = stmt
|
||||||
.query_map([], |row| {
|
.query_map([], |row| row.get::<_, Vec<u8>>(0))?
|
||||||
let bytes: Vec<u8> = row.get(0)?;
|
.filter_map(|r| r.ok())
|
||||||
Ok(bytes)
|
.filter_map(|v| v.try_into().ok())
|
||||||
})?
|
|
||||||
.collect::<Result<Vec<_>, _>>()?
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|bytes| bytes.try_into().ok())
|
|
||||||
.collect();
|
.collect();
|
||||||
Ok(sessions)
|
Ok(sessions)
|
||||||
}
|
}
|
||||||
@ -459,288 +388,99 @@ impl RatchetStorage for SqliteStorage {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use double_ratchets::hkdf::DefaultDomain;
|
|
||||||
use double_ratchets::state::RatchetState;
|
|
||||||
use double_ratchets::InstallationKeyPair;
|
|
||||||
|
|
||||||
fn create_test_storage() -> SqliteStorage {
|
fn test_store() -> SqliteRatchetStore {
|
||||||
let key = [0x42u8; 32];
|
SqliteRatchetStore::open_in_memory([0x42u8; 32]).unwrap()
|
||||||
SqliteStorage::open_in_memory(key).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_test_state() -> StorableRatchetState {
|
|
||||||
let bob_keypair = InstallationKeyPair::generate();
|
|
||||||
let shared_secret = [0x42u8; 32];
|
|
||||||
let state: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
|
||||||
StorableRatchetState::from_ratchet_state(&state, "default")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_save_and_load() {
|
fn test_init_and_load() {
|
||||||
let storage = create_test_storage();
|
let store = test_store();
|
||||||
let session_id = [1u8; 32];
|
|
||||||
let state = create_test_state();
|
|
||||||
|
|
||||||
storage.save(&session_id, &state).unwrap();
|
|
||||||
let loaded = storage.load(&session_id).unwrap();
|
|
||||||
|
|
||||||
assert!(loaded.is_some());
|
|
||||||
let loaded = loaded.unwrap();
|
|
||||||
assert_eq!(loaded.root_key, state.root_key);
|
|
||||||
assert_eq!(loaded.dh_self_public, state.dh_self_public);
|
|
||||||
// Secret should be decrypted correctly
|
|
||||||
assert_eq!(loaded.dh_self_secret, state.dh_self_secret);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_load_nonexistent() {
|
|
||||||
let storage = create_test_storage();
|
|
||||||
let session_id = [1u8; 32];
|
let session_id = [1u8; 32];
|
||||||
|
|
||||||
let loaded = storage.load(&session_id).unwrap();
|
store.init_session(
|
||||||
assert!(loaded.is_none());
|
&session_id,
|
||||||
|
&[0xAA; 32], // root
|
||||||
|
Some(&[0xBB; 32]), // sending
|
||||||
|
None, // receiving
|
||||||
|
&[0xCC; 32], // secret
|
||||||
|
&[0xDD; 32], // public
|
||||||
|
Some(&[0xEE; 32]), // remote
|
||||||
|
5, 3, 2,
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
|
let state = store.load_state(&session_id).unwrap().unwrap();
|
||||||
|
assert_eq!(state.root_key, [0xAA; 32]);
|
||||||
|
assert_eq!(state.sending_chain, Some([0xBB; 32]));
|
||||||
|
assert_eq!(state.receiving_chain, None);
|
||||||
|
assert_eq!(state.dh_self_secret, [0xCC; 32]);
|
||||||
|
assert_eq!(state.dh_self_public, [0xDD; 32]);
|
||||||
|
assert_eq!(state.dh_remote, Some([0xEE; 32]));
|
||||||
|
assert_eq!(state.msg_send, 5);
|
||||||
|
assert_eq!(state.msg_recv, 3);
|
||||||
|
assert_eq!(state.prev_chain_len, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_delete() {
|
fn test_update_fields() {
|
||||||
let storage = create_test_storage();
|
let store = test_store();
|
||||||
let session_id = [1u8; 32];
|
|
||||||
let state = create_test_state();
|
|
||||||
|
|
||||||
storage.save(&session_id, &state).unwrap();
|
|
||||||
assert!(storage.exists(&session_id).unwrap());
|
|
||||||
|
|
||||||
let deleted = storage.delete(&session_id).unwrap();
|
|
||||||
assert!(deleted);
|
|
||||||
assert!(!storage.exists(&session_id).unwrap());
|
|
||||||
|
|
||||||
// Deleting again should return false
|
|
||||||
let deleted = storage.delete(&session_id).unwrap();
|
|
||||||
assert!(!deleted);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_exists() {
|
|
||||||
let storage = create_test_storage();
|
|
||||||
let session_id = [1u8; 32];
|
let session_id = [1u8; 32];
|
||||||
|
|
||||||
assert!(!storage.exists(&session_id).unwrap());
|
store.init_session(
|
||||||
|
&session_id, &[0xAA; 32], None, None,
|
||||||
|
&[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
|
||||||
|
).unwrap();
|
||||||
|
|
||||||
let state = create_test_state();
|
// Update root and chains
|
||||||
storage.save(&session_id, &state).unwrap();
|
store.store_root_and_chains(&session_id, &[0x11; 32], Some(&[0x22; 32]), Some(&[0x33; 32])).unwrap();
|
||||||
|
|
||||||
assert!(storage.exists(&session_id).unwrap());
|
let state = store.load_state(&session_id).unwrap().unwrap();
|
||||||
|
assert_eq!(state.root_key, [0x11; 32]);
|
||||||
|
assert_eq!(state.sending_chain, Some([0x22; 32]));
|
||||||
|
assert_eq!(state.receiving_chain, Some([0x33; 32]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_list_sessions() {
|
fn test_skipped_keys() {
|
||||||
let storage = create_test_storage();
|
let store = test_store();
|
||||||
|
|
||||||
assert!(storage.list_sessions().unwrap().is_empty());
|
|
||||||
|
|
||||||
let state = create_test_state();
|
|
||||||
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
|
|
||||||
|
|
||||||
for id in &session_ids {
|
|
||||||
storage.save(id, &state).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut listed = storage.list_sessions().unwrap();
|
|
||||||
listed.sort();
|
|
||||||
let mut expected = session_ids.clone();
|
|
||||||
expected.sort();
|
|
||||||
|
|
||||||
assert_eq!(listed, expected);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_overwrite() {
|
|
||||||
let storage = create_test_storage();
|
|
||||||
let session_id = [1u8; 32];
|
let session_id = [1u8; 32];
|
||||||
|
|
||||||
// Create first state
|
store.init_session(
|
||||||
let bob_keypair1 = InstallationKeyPair::generate();
|
&session_id, &[0xAA; 32], None, None,
|
||||||
let state1: RatchetState<DefaultDomain> =
|
&[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
|
||||||
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
|
).unwrap();
|
||||||
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
|
|
||||||
|
|
||||||
// Create second state with different root
|
// Add skipped keys
|
||||||
let bob_keypair2 = InstallationKeyPair::generate();
|
store.add_skipped_key(&session_id, &[0x11; 32], 5, &[0xAB; 32]).unwrap();
|
||||||
let state2: RatchetState<DefaultDomain> =
|
store.add_skipped_key(&session_id, &[0x11; 32], 6, &[0xCD; 32]).unwrap();
|
||||||
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
|
|
||||||
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
|
|
||||||
|
|
||||||
// Save first, then overwrite with second
|
let state = store.load_state(&session_id).unwrap().unwrap();
|
||||||
storage.save(&session_id, &storable1).unwrap();
|
assert_eq!(state.skipped_keys.len(), 2);
|
||||||
storage.save(&session_id, &storable2).unwrap();
|
|
||||||
|
|
||||||
// Should have the second state
|
// Remove one
|
||||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
store.remove_skipped_key(&session_id, &[0x11; 32], 5).unwrap();
|
||||||
assert_eq!(loaded.root_key, storable2.root_key);
|
|
||||||
assert_ne!(loaded.root_key, storable1.root_key);
|
let state = store.load_state(&session_id).unwrap().unwrap();
|
||||||
|
assert_eq!(state.skipped_keys.len(), 1);
|
||||||
|
assert_eq!(state.skipped_keys[0].msg_num, 6);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_skipped_keys_storage() {
|
fn test_delete_cascade() {
|
||||||
let storage = create_test_storage();
|
let store = test_store();
|
||||||
let session_id = [1u8; 32];
|
let session_id = [1u8; 32];
|
||||||
|
|
||||||
// Create states and generate skipped keys
|
store.init_session(
|
||||||
let bob_keypair = InstallationKeyPair::generate();
|
&session_id, &[0xAA; 32], None, None,
|
||||||
let shared_secret = [0x42u8; 32];
|
&[0xCC; 32], &[0xDD; 32], None, 0, 0, 0,
|
||||||
let mut alice: RatchetState<DefaultDomain> =
|
).unwrap();
|
||||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
store.add_skipped_key(&session_id, &[0x11; 32], 5, &[0xAB; 32]).unwrap();
|
||||||
let mut bob: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
|
||||||
|
|
||||||
// Alice sends multiple messages
|
assert!(store.session_exists(&session_id).unwrap());
|
||||||
let mut messages = vec![];
|
|
||||||
for i in 0..3 {
|
|
||||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
|
||||||
messages.push((ct, header));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bob receives out of order to create skipped keys
|
store.delete_session(&session_id).unwrap();
|
||||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
|
||||||
.unwrap();
|
|
||||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(!bob.skipped_keys.is_empty());
|
assert!(!store.session_exists(&session_id).unwrap());
|
||||||
|
// Skipped keys should be gone too (cascade)
|
||||||
// Save and reload
|
|
||||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
|
||||||
storage.save(&session_id, &storable).unwrap();
|
|
||||||
|
|
||||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
|
||||||
assert_eq!(loaded.skipped_keys.len(), storable.skipped_keys.len());
|
|
||||||
|
|
||||||
// Restore and verify we can decrypt the skipped message
|
|
||||||
let mut restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
|
|
||||||
let pt = restored
|
|
||||||
.decrypt_message(&messages[1].0, messages[1].1.clone())
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(pt, b"Message 1");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_encryption_uses_different_nonces() {
|
|
||||||
let storage = create_test_storage();
|
|
||||||
let state = create_test_state();
|
|
||||||
|
|
||||||
// Save the same state twice with different session IDs
|
|
||||||
storage.save(&[1u8; 32], &state).unwrap();
|
|
||||||
storage.save(&[2u8; 32], &state).unwrap();
|
|
||||||
|
|
||||||
// Both should load correctly (encryption with different nonces)
|
|
||||||
let loaded1 = storage.load(&[1u8; 32]).unwrap().unwrap();
|
|
||||||
let loaded2 = storage.load(&[2u8; 32]).unwrap().unwrap();
|
|
||||||
|
|
||||||
assert_eq!(loaded1.dh_self_secret, loaded2.dh_self_secret);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_cascade_delete_skipped_keys() {
|
|
||||||
let storage = create_test_storage();
|
|
||||||
let session_id = [1u8; 32];
|
|
||||||
|
|
||||||
// Create a state with skipped keys
|
|
||||||
let bob_keypair = InstallationKeyPair::generate();
|
|
||||||
let shared_secret = [0x42u8; 32];
|
|
||||||
let mut alice: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
|
||||||
let mut bob: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
|
||||||
|
|
||||||
let mut messages = vec![];
|
|
||||||
for i in 0..3 {
|
|
||||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
|
||||||
messages.push((ct, header));
|
|
||||||
}
|
|
||||||
|
|
||||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
|
||||||
.unwrap();
|
|
||||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
|
||||||
storage.save(&session_id, &storable).unwrap();
|
|
||||||
|
|
||||||
// Verify skipped keys exist
|
|
||||||
{
|
|
||||||
let conn = storage.conn.lock().unwrap();
|
|
||||||
let count: i64 = conn
|
|
||||||
.query_row(
|
|
||||||
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
|
|
||||||
[session_id.as_slice()],
|
|
||||||
|row| row.get(0),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
assert!(count > 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete session
|
|
||||||
storage.delete(&session_id).unwrap();
|
|
||||||
|
|
||||||
// Verify skipped keys were also deleted (cascade)
|
|
||||||
{
|
|
||||||
let conn = storage.conn.lock().unwrap();
|
|
||||||
let count: i64 = conn
|
|
||||||
.query_row(
|
|
||||||
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
|
|
||||||
[session_id.as_slice()],
|
|
||||||
|row| row.get(0),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(count, 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_file_storage() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let db_path = dir.path().join("test.db");
|
|
||||||
let key = [0x42u8; 32];
|
|
||||||
|
|
||||||
let state = create_test_state();
|
|
||||||
let session_id = [1u8; 32];
|
|
||||||
|
|
||||||
// Save in one instance
|
|
||||||
{
|
|
||||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
|
||||||
storage.save(&session_id, &state).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load in another instance
|
|
||||||
{
|
|
||||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
|
||||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
|
||||||
assert_eq!(loaded.root_key, state.root_key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_wrong_key_fails_decryption() {
|
|
||||||
let dir = tempfile::tempdir().unwrap();
|
|
||||||
let db_path = dir.path().join("test.db");
|
|
||||||
let key1 = [0x42u8; 32];
|
|
||||||
let key2 = [0x43u8; 32];
|
|
||||||
|
|
||||||
let state = create_test_state();
|
|
||||||
let session_id = [1u8; 32];
|
|
||||||
|
|
||||||
// Save with key1
|
|
||||||
{
|
|
||||||
let storage = SqliteStorage::open(&db_path, key1).unwrap();
|
|
||||||
storage.save(&session_id, &state).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to load with key2 - should fail decryption
|
|
||||||
{
|
|
||||||
let storage = SqliteStorage::open(&db_path, key2).unwrap();
|
|
||||||
let result = storage.load(&session_id);
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,74 +1,112 @@
|
|||||||
//! Storage trait definitions.
|
//! Storage trait for field-level ratchet state persistence.
|
||||||
|
|
||||||
use crate::error::StorageError;
|
use crate::error::StorageError;
|
||||||
use crate::types::StorableRatchetState;
|
|
||||||
|
|
||||||
/// A 32-byte session identifier.
|
/// A 32-byte session identifier.
|
||||||
pub type SessionId = [u8; 32];
|
pub type SessionId = [u8; 32];
|
||||||
|
|
||||||
/// Abstract storage interface for ratchet states.
|
/// Field-level storage interface for ratchet state.
|
||||||
///
|
///
|
||||||
/// Implementations must be thread-safe (`Send + Sync`).
|
/// This trait provides granular storage operations that are called automatically
|
||||||
pub trait RatchetStorage: Send + Sync {
|
/// during ratchet operations. Each method persists only the fields that changed.
|
||||||
/// Save a ratchet state for the given session.
|
pub trait RatchetStore: Send + Sync {
|
||||||
///
|
/// Store the root key and chain keys after a DH ratchet step.
|
||||||
/// If a state already exists for this session, it will be overwritten.
|
fn store_root_and_chains(
|
||||||
///
|
&self,
|
||||||
/// # Arguments
|
session_id: &SessionId,
|
||||||
///
|
root_key: &[u8; 32],
|
||||||
/// * `session_id` - Unique identifier for the session.
|
sending_chain: Option<&[u8; 32]>,
|
||||||
/// * `state` - The ratchet state to store.
|
receiving_chain: Option<&[u8; 32]>,
|
||||||
///
|
) -> Result<(), StorageError>;
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `Ok(())` on success.
|
|
||||||
/// * `Err(StorageError)` on failure.
|
|
||||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError>;
|
|
||||||
|
|
||||||
/// Load a ratchet state for the given session.
|
/// Store our DH keypair (secret encrypted, public plaintext).
|
||||||
///
|
fn store_dh_self(
|
||||||
/// # Arguments
|
&self,
|
||||||
///
|
session_id: &SessionId,
|
||||||
/// * `session_id` - Unique identifier for the session.
|
secret: &[u8; 32],
|
||||||
///
|
public: &[u8; 32],
|
||||||
/// # Returns
|
) -> Result<(), StorageError>;
|
||||||
///
|
|
||||||
/// * `Ok(Some(state))` if found.
|
|
||||||
/// * `Ok(None)` if not found.
|
|
||||||
/// * `Err(StorageError)` on failure.
|
|
||||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError>;
|
|
||||||
|
|
||||||
/// Delete a ratchet state for the given session.
|
/// Store the remote party's DH public key.
|
||||||
///
|
fn store_dh_remote(
|
||||||
/// # Arguments
|
&self,
|
||||||
///
|
session_id: &SessionId,
|
||||||
/// * `session_id` - Unique identifier for the session.
|
remote: Option<&[u8; 32]>,
|
||||||
///
|
) -> Result<(), StorageError>;
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `Ok(true)` if the session existed and was deleted.
|
|
||||||
/// * `Ok(false)` if the session did not exist.
|
|
||||||
/// * `Err(StorageError)` on failure.
|
|
||||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
|
||||||
|
|
||||||
/// Check if a session exists in storage.
|
/// Store message counters.
|
||||||
///
|
fn store_counters(
|
||||||
/// # Arguments
|
&self,
|
||||||
///
|
session_id: &SessionId,
|
||||||
/// * `session_id` - Unique identifier for the session.
|
msg_send: u32,
|
||||||
///
|
msg_recv: u32,
|
||||||
/// # Returns
|
prev_chain_len: u32,
|
||||||
///
|
) -> Result<(), StorageError>;
|
||||||
/// * `Ok(true)` if the session exists.
|
|
||||||
/// * `Ok(false)` if the session does not exist.
|
|
||||||
/// * `Err(StorageError)` on failure.
|
|
||||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
|
||||||
|
|
||||||
/// List all session IDs in storage.
|
/// Add a skipped message key.
|
||||||
///
|
fn add_skipped_key(
|
||||||
/// # Returns
|
&self,
|
||||||
///
|
session_id: &SessionId,
|
||||||
/// * `Ok(Vec<SessionId>)` containing all session IDs.
|
dh_public: &[u8; 32],
|
||||||
/// * `Err(StorageError)` on failure.
|
msg_num: u32,
|
||||||
|
message_key: &[u8; 32],
|
||||||
|
) -> Result<(), StorageError>;
|
||||||
|
|
||||||
|
/// Remove a skipped message key (after use).
|
||||||
|
fn remove_skipped_key(
|
||||||
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
dh_public: &[u8; 32],
|
||||||
|
msg_num: u32,
|
||||||
|
) -> Result<(), StorageError>;
|
||||||
|
|
||||||
|
/// Load all state for a session. Returns None if session doesn't exist.
|
||||||
|
fn load_state(&self, session_id: &SessionId) -> Result<Option<StoredState>, StorageError>;
|
||||||
|
|
||||||
|
/// Initialize a new session with all fields.
|
||||||
|
fn init_session(
|
||||||
|
&self,
|
||||||
|
session_id: &SessionId,
|
||||||
|
root_key: &[u8; 32],
|
||||||
|
sending_chain: Option<&[u8; 32]>,
|
||||||
|
receiving_chain: Option<&[u8; 32]>,
|
||||||
|
dh_self_secret: &[u8; 32],
|
||||||
|
dh_self_public: &[u8; 32],
|
||||||
|
dh_remote: Option<&[u8; 32]>,
|
||||||
|
msg_send: u32,
|
||||||
|
msg_recv: u32,
|
||||||
|
prev_chain_len: u32,
|
||||||
|
) -> Result<(), StorageError>;
|
||||||
|
|
||||||
|
/// Delete a session and all its data.
|
||||||
|
fn delete_session(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
||||||
|
|
||||||
|
/// Check if a session exists.
|
||||||
|
fn session_exists(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
||||||
|
|
||||||
|
/// List all session IDs.
|
||||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError>;
|
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Complete state loaded from storage.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StoredState {
|
||||||
|
pub root_key: [u8; 32],
|
||||||
|
pub sending_chain: Option<[u8; 32]>,
|
||||||
|
pub receiving_chain: Option<[u8; 32]>,
|
||||||
|
pub dh_self_secret: [u8; 32],
|
||||||
|
pub dh_self_public: [u8; 32],
|
||||||
|
pub dh_remote: Option<[u8; 32]>,
|
||||||
|
pub msg_send: u32,
|
||||||
|
pub msg_recv: u32,
|
||||||
|
pub prev_chain_len: u32,
|
||||||
|
pub skipped_keys: Vec<SkippedKeyEntry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A skipped key entry from storage.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SkippedKeyEntry {
|
||||||
|
pub dh_public: [u8; 32],
|
||||||
|
pub msg_num: u32,
|
||||||
|
pub message_key: [u8; 32],
|
||||||
|
}
|
||||||
|
|||||||
@ -1,225 +0,0 @@
|
|||||||
//! Serializable types for ratchet state storage.
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use double_ratchets::state::RatchetState;
|
|
||||||
use double_ratchets::hkdf::HkdfInfo;
|
|
||||||
use double_ratchets::InstallationKeyPair;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use x25519_dalek::PublicKey;
|
|
||||||
|
|
||||||
use crate::error::StorageError;
|
|
||||||
|
|
||||||
/// A skipped message key entry for storage.
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
||||||
pub struct SkippedKey {
|
|
||||||
/// The public key associated with this skipped message.
|
|
||||||
pub public_key: [u8; 32],
|
|
||||||
/// The message number.
|
|
||||||
pub msg_num: u32,
|
|
||||||
/// The 32-byte message key.
|
|
||||||
pub message_key: [u8; 32],
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Serializable version of `RatchetState`.
|
|
||||||
///
|
|
||||||
/// This struct stores all keys as raw byte arrays for easy serialization
|
|
||||||
/// and database storage. Use `from_ratchet_state()` and `to_ratchet_state()`
|
|
||||||
/// for conversion.
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct StorableRatchetState {
|
|
||||||
/// The current root key (32 bytes).
|
|
||||||
pub root_key: [u8; 32],
|
|
||||||
|
|
||||||
/// The current sending chain key, if any (32 bytes).
|
|
||||||
pub sending_chain: Option<[u8; 32]>,
|
|
||||||
|
|
||||||
/// The current receiving chain key, if any (32 bytes).
|
|
||||||
pub receiving_chain: Option<[u8; 32]>,
|
|
||||||
|
|
||||||
/// Our DH secret key (32 bytes).
|
|
||||||
///
|
|
||||||
/// **Security**: This should be encrypted before storage.
|
|
||||||
pub dh_self_secret: [u8; 32],
|
|
||||||
|
|
||||||
/// Our DH public key (32 bytes).
|
|
||||||
pub dh_self_public: [u8; 32],
|
|
||||||
|
|
||||||
/// Remote party's DH public key, if known (32 bytes).
|
|
||||||
pub dh_remote: Option<[u8; 32]>,
|
|
||||||
|
|
||||||
/// Number of messages sent in the current sending chain.
|
|
||||||
pub msg_send: u32,
|
|
||||||
|
|
||||||
/// Number of messages received in the current receiving chain.
|
|
||||||
pub msg_recv: u32,
|
|
||||||
|
|
||||||
/// Length of the previous sending chain.
|
|
||||||
pub prev_chain_len: u32,
|
|
||||||
|
|
||||||
/// Skipped message keys for out-of-order message handling.
|
|
||||||
pub skipped_keys: Vec<SkippedKey>,
|
|
||||||
|
|
||||||
/// Domain identifier for HKDF info.
|
|
||||||
pub domain_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StorableRatchetState {
|
|
||||||
/// Convert a `RatchetState` into a `StorableRatchetState`.
|
|
||||||
///
|
|
||||||
/// # Type Parameters
|
|
||||||
///
|
|
||||||
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `state` - The ratchet state to convert.
|
|
||||||
/// * `domain_id` - A string identifier for the domain (used to reconstruct the correct domain type).
|
|
||||||
pub fn from_ratchet_state<D: HkdfInfo>(state: &RatchetState<D>, domain_id: &str) -> Self {
|
|
||||||
let skipped_keys: Vec<SkippedKey> = state
|
|
||||||
.skipped_keys
|
|
||||||
.iter()
|
|
||||||
.map(|((pub_key, msg_num), msg_key)| SkippedKey {
|
|
||||||
public_key: *pub_key.as_bytes(),
|
|
||||||
msg_num: *msg_num,
|
|
||||||
message_key: *msg_key,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
StorableRatchetState {
|
|
||||||
root_key: state.root_key,
|
|
||||||
sending_chain: state.sending_chain,
|
|
||||||
receiving_chain: state.receiving_chain,
|
|
||||||
dh_self_secret: state.dh_self.secret_bytes(),
|
|
||||||
dh_self_public: *state.dh_self.public().as_bytes(),
|
|
||||||
dh_remote: state.dh_remote.map(|pk| *pk.as_bytes()),
|
|
||||||
msg_send: state.msg_send,
|
|
||||||
msg_recv: state.msg_recv,
|
|
||||||
prev_chain_len: state.prev_chain_len,
|
|
||||||
skipped_keys,
|
|
||||||
domain_id: domain_id.to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert this `StorableRatchetState` back into a `RatchetState`.
|
|
||||||
///
|
|
||||||
/// # Type Parameters
|
|
||||||
///
|
|
||||||
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// * `Ok(RatchetState)` on success.
|
|
||||||
/// * `Err(StorageError)` if key reconstruction fails.
|
|
||||||
pub fn to_ratchet_state<D: HkdfInfo>(&self) -> Result<RatchetState<D>, StorageError> {
|
|
||||||
// Reconstruct the keypair
|
|
||||||
let dh_self = InstallationKeyPair::from_bytes(self.dh_self_secret, self.dh_self_public)
|
|
||||||
.map_err(|e| StorageError::KeyReconstruction(e.to_string()))?;
|
|
||||||
|
|
||||||
// Reconstruct skipped keys HashMap
|
|
||||||
let skipped_keys: HashMap<(PublicKey, u32), [u8; 32]> = self
|
|
||||||
.skipped_keys
|
|
||||||
.iter()
|
|
||||||
.map(|sk| {
|
|
||||||
let pub_key = PublicKey::from(sk.public_key);
|
|
||||||
((pub_key, sk.msg_num), sk.message_key)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(RatchetState::from_parts(
|
|
||||||
self.root_key,
|
|
||||||
self.sending_chain,
|
|
||||||
self.receiving_chain,
|
|
||||||
dh_self,
|
|
||||||
self.dh_remote.map(PublicKey::from),
|
|
||||||
self.msg_send,
|
|
||||||
self.msg_recv,
|
|
||||||
self.prev_chain_len,
|
|
||||||
skipped_keys,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use double_ratchets::hkdf::DefaultDomain;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_roundtrip_sender_state() {
|
|
||||||
// Create a sender state
|
|
||||||
let bob_keypair = InstallationKeyPair::generate();
|
|
||||||
let shared_secret = [0x42u8; 32];
|
|
||||||
let state: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
|
||||||
|
|
||||||
// Convert to storable and back
|
|
||||||
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
|
||||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
|
||||||
|
|
||||||
// Verify fields match
|
|
||||||
assert_eq!(state.root_key, restored.root_key);
|
|
||||||
assert_eq!(state.sending_chain, restored.sending_chain);
|
|
||||||
assert_eq!(state.receiving_chain, restored.receiving_chain);
|
|
||||||
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
|
|
||||||
assert_eq!(state.dh_remote.map(|pk| *pk.as_bytes()), restored.dh_remote.map(|pk| *pk.as_bytes()));
|
|
||||||
assert_eq!(state.msg_send, restored.msg_send);
|
|
||||||
assert_eq!(state.msg_recv, restored.msg_recv);
|
|
||||||
assert_eq!(state.prev_chain_len, restored.prev_chain_len);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_roundtrip_receiver_state() {
|
|
||||||
// Create a receiver state
|
|
||||||
let keypair = InstallationKeyPair::generate();
|
|
||||||
let shared_secret = [0x42u8; 32];
|
|
||||||
let state: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_receiver(shared_secret, keypair);
|
|
||||||
|
|
||||||
// Convert to storable and back
|
|
||||||
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
|
||||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
|
||||||
|
|
||||||
// Verify fields match
|
|
||||||
assert_eq!(state.root_key, restored.root_key);
|
|
||||||
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
|
|
||||||
assert!(restored.dh_remote.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_roundtrip_with_skipped_keys() {
|
|
||||||
// Create states and exchange messages to generate skipped keys
|
|
||||||
let bob_keypair = InstallationKeyPair::generate();
|
|
||||||
let shared_secret = [0x42u8; 32];
|
|
||||||
let mut alice: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
|
||||||
let mut bob: RatchetState<DefaultDomain> =
|
|
||||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
|
||||||
|
|
||||||
// Alice sends multiple messages
|
|
||||||
let mut messages = vec![];
|
|
||||||
for i in 0..3 {
|
|
||||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
|
||||||
messages.push((ct, header));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bob receives them out of order to create skipped keys
|
|
||||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone()).unwrap();
|
|
||||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone()).unwrap();
|
|
||||||
// Message 1 key is now in skipped_keys
|
|
||||||
|
|
||||||
assert!(!bob.skipped_keys.is_empty());
|
|
||||||
|
|
||||||
// Convert to storable and back
|
|
||||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
|
||||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
|
||||||
|
|
||||||
// Verify skipped keys are preserved
|
|
||||||
assert_eq!(bob.skipped_keys.len(), restored.skipped_keys.len());
|
|
||||||
|
|
||||||
// The restored state should be able to decrypt the skipped message
|
|
||||||
let mut restored = restored;
|
|
||||||
let pt = restored.decrypt_message(&messages[1].0, messages[1].1.clone()).unwrap();
|
|
||||||
assert_eq!(pt, b"Message 1");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Loading…
x
Reference in New Issue
Block a user