mirror of
https://github.com/logos-messaging/libchat.git
synced 2026-02-10 08:53:08 +00:00
feat: memory and sqlcipher storage
This commit is contained in:
parent
fc76453f4c
commit
34a03275cc
231
Cargo.lock
generated
231
Cargo.lock
generated
@ -12,6 +12,24 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
|
||||
|
||||
[[package]]
|
||||
name = "blake2"
|
||||
version = "0.10.6"
|
||||
@ -30,6 +48,16 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.53"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "755d2fce177175ffca841e9a06afdb2c4ab0f593d53b4dee48147dfaade85932"
|
||||
dependencies = [
|
||||
"find-msvc-tools",
|
||||
"shlex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.4"
|
||||
@ -141,12 +169,60 @@ dependencies = [
|
||||
"x25519-dalek",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "double-ratchets-storage"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"chacha20poly1305",
|
||||
"double-ratchets",
|
||||
"rand",
|
||||
"rusqlite",
|
||||
"serde",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"x25519-dalek",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fallible-iterator"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
|
||||
|
||||
[[package]]
|
||||
name = "fallible-streaming-iterator"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||
|
||||
[[package]]
|
||||
name = "fiat-crypto"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
|
||||
|
||||
[[package]]
|
||||
name = "find-msvc-tools"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db"
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
@ -168,6 +244,36 @@ dependencies = [
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.3.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasip2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af"
|
||||
dependencies = [
|
||||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hkdf"
|
||||
version = "0.12.4"
|
||||
@ -201,6 +307,23 @@ version = "0.2.178"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091"
|
||||
|
||||
[[package]]
|
||||
name = "libsqlite3-sys"
|
||||
version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
|
||||
|
||||
[[package]]
|
||||
name = "logos-chat"
|
||||
version = "0.1.0"
|
||||
@ -208,12 +331,24 @@ dependencies = [
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
|
||||
|
||||
[[package]]
|
||||
name = "pkg-config"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||
|
||||
[[package]]
|
||||
name = "poly1305"
|
||||
version = "0.8.0"
|
||||
@ -252,6 +387,12 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "r-efi"
|
||||
version = "5.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
@ -279,7 +420,21 @@ version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"getrandom 0.2.16",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rusqlite"
|
||||
version = "0.32.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"fallible-iterator",
|
||||
"fallible-streaming-iterator",
|
||||
"hashlink",
|
||||
"libsqlite3-sys",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -291,6 +446,19 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.27"
|
||||
@ -327,6 +495,18 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.6.1"
|
||||
@ -344,6 +524,19 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.24.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"getrandom 0.3.4",
|
||||
"once_cell",
|
||||
"rustix",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.17"
|
||||
@ -386,6 +579,12 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vcpkg"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.5"
|
||||
@ -398,6 +597,36 @@ version = "0.11.1+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b"
|
||||
|
||||
[[package]]
|
||||
name = "wasip2"
|
||||
version = "1.0.2+wasi-0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5"
|
||||
dependencies = [
|
||||
"wit-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-link"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.61.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc"
|
||||
dependencies = [
|
||||
"windows-link",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.51.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5"
|
||||
|
||||
[[package]]
|
||||
name = "x25519-dalek"
|
||||
version = "2.0.1"
|
||||
|
||||
@ -5,4 +5,5 @@ resolver = "3"
|
||||
members = [
|
||||
"conversations",
|
||||
"double-ratchets",
|
||||
"double-ratchets-storage",
|
||||
]
|
||||
|
||||
21
double-ratchets-storage/Cargo.toml
Normal file
21
double-ratchets-storage/Cargo.toml
Normal file
@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "double-ratchets-storage"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[features]
|
||||
default = ["sqlite"]
|
||||
sqlite = ["rusqlite/bundled"]
|
||||
sqlcipher = ["rusqlite/bundled-sqlcipher"]
|
||||
|
||||
[dependencies]
|
||||
double-ratchets = { path = "../double-ratchets" }
|
||||
x25519-dalek = { version = "2.0.1", features = ["static_secrets"] }
|
||||
rusqlite = { version = "0.32", optional = true }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
thiserror = "2"
|
||||
chacha20poly1305 = "0.10"
|
||||
rand = "0.8"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
46
double-ratchets-storage/src/error.rs
Normal file
46
double-ratchets-storage/src/error.rs
Normal file
@ -0,0 +1,46 @@
|
||||
//! Error types for the storage module.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur during storage operations.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum StorageError {
|
||||
/// Database operation failed.
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] rusqlite::Error),
|
||||
|
||||
/// Field-level encryption failed.
|
||||
#[error("encryption failed: {0}")]
|
||||
Encryption(String),
|
||||
|
||||
/// Field-level decryption failed.
|
||||
#[error("decryption failed: {0}")]
|
||||
Decryption(String),
|
||||
|
||||
/// Stored state is corrupted or invalid.
|
||||
#[error("corrupted state: {0}")]
|
||||
CorruptedState(String),
|
||||
|
||||
/// Session was not found in storage.
|
||||
#[error("session not found: {}", hex::encode(.session_id))]
|
||||
SessionNotFound {
|
||||
/// The session ID that was not found.
|
||||
session_id: [u8; 32],
|
||||
},
|
||||
|
||||
/// I/O operation failed.
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// Key reconstruction failed.
|
||||
#[error("key reconstruction failed: {0}")]
|
||||
KeyReconstruction(String),
|
||||
}
|
||||
|
||||
/// Helper module for hex encoding (used in error messages).
|
||||
mod hex {
|
||||
pub fn encode(bytes: &[u8]) -> String {
|
||||
bytes.iter().map(|b| format!("{:02x}", b)).collect()
|
||||
}
|
||||
}
|
||||
178
double-ratchets-storage/src/lib.rs
Normal file
178
double-ratchets-storage/src/lib.rs
Normal file
@ -0,0 +1,178 @@
|
||||
//! Persistent storage for Double Ratchet state.
|
||||
//!
|
||||
//! This crate provides storage backends for persisting [`RatchetState`](double_ratchets::RatchetState)
|
||||
//! across application restarts. It includes:
|
||||
//!
|
||||
//! - [`MemoryStorage`] - In-memory storage for testing
|
||||
//! - [`SqliteStorage`] - SQLite storage with field-level encryption (requires `sqlite` feature)
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - `sqlite` (default) - Enables SQLite storage with `rusqlite/bundled`
|
||||
//! - `sqlcipher` - Enables SQLCipher full-database encryption (mutually exclusive with `sqlite`)
|
||||
//!
|
||||
//! # Security
|
||||
//!
|
||||
//! Private keys (`dh_self_secret`) are always encrypted with ChaCha20Poly1305 before storage,
|
||||
//! even when using plain SQLite. For additional security, enable the `sqlcipher` feature
|
||||
//! for full database encryption.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use double_ratchets::hkdf::DefaultDomain;
|
||||
//! use double_ratchets::state::RatchetState;
|
||||
//! use double_ratchets::InstallationKeyPair;
|
||||
//! use double_ratchets_storage::{
|
||||
//! RatchetStorage, SqliteStorage, StorableRatchetState,
|
||||
//! };
|
||||
//!
|
||||
//! // Create a ratchet state
|
||||
//! let bob_keypair = InstallationKeyPair::generate();
|
||||
//! let shared_secret = [0x42u8; 32];
|
||||
//! let state: RatchetState<DefaultDomain> =
|
||||
//! RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
//!
|
||||
//! // Open storage
|
||||
//! let encryption_key = [0u8; 32]; // Use proper key derivation!
|
||||
//! let storage = SqliteStorage::open("ratchets.db", encryption_key).unwrap();
|
||||
//!
|
||||
//! // Save state
|
||||
//! let session_id = [1u8; 32];
|
||||
//! let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
||||
//! storage.save(&session_id, &storable).unwrap();
|
||||
//!
|
||||
//! // Load state
|
||||
//! let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
//! let restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
|
||||
//! ```
|
||||
|
||||
pub mod error;
|
||||
pub mod memory;
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
pub mod sqlite;
|
||||
pub mod traits;
|
||||
pub mod types;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use error::StorageError;
|
||||
pub use memory::MemoryStorage;
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
pub use sqlite::{EncryptionKey, SqliteStorage};
|
||||
pub use traits::{RatchetStorage, SessionId};
|
||||
pub use types::{SkippedKey, StorableRatchetState};
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
use double_ratchets::state::RatchetState;
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
|
||||
/// Integration test: full encryption/decryption cycle with storage
|
||||
#[test]
|
||||
fn test_full_conversation_with_storage_roundtrip() {
|
||||
// Setup Alice and Bob
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
let storage = MemoryStorage::new();
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
// Alice sends a message
|
||||
let (ct1, header1) = alice.encrypt_message(b"Hello Bob!");
|
||||
|
||||
// Save Alice's state
|
||||
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default");
|
||||
storage.save(&alice_session, &alice_storable).unwrap();
|
||||
|
||||
// Bob receives the message
|
||||
let pt1 = bob.decrypt_message(&ct1, header1).unwrap();
|
||||
assert_eq!(pt1, b"Hello Bob!");
|
||||
|
||||
// Save Bob's state
|
||||
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
storage.save(&bob_session, &bob_storable).unwrap();
|
||||
|
||||
// Simulate restart: load states from storage
|
||||
let alice_loaded = storage.load(&alice_session).unwrap().unwrap();
|
||||
let bob_loaded = storage.load(&bob_session).unwrap().unwrap();
|
||||
|
||||
let mut alice_restored: RatchetState<DefaultDomain> =
|
||||
alice_loaded.to_ratchet_state().unwrap();
|
||||
let mut bob_restored: RatchetState<DefaultDomain> =
|
||||
bob_loaded.to_ratchet_state().unwrap();
|
||||
|
||||
// Bob replies
|
||||
let (ct2, header2) = bob_restored.encrypt_message(b"Hi Alice!");
|
||||
let pt2 = alice_restored.decrypt_message(&ct2, header2).unwrap();
|
||||
assert_eq!(pt2, b"Hi Alice!");
|
||||
|
||||
// Alice sends another message
|
||||
let (ct3, header3) = alice_restored.encrypt_message(b"How are you?");
|
||||
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap();
|
||||
assert_eq!(pt3, b"How are you?");
|
||||
}
|
||||
|
||||
/// Integration test: verify SQLite storage with encryption works
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlcipher"))]
|
||||
#[test]
|
||||
fn test_sqlite_integration() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db_path = dir.path().join("integration_test.db");
|
||||
let key = [0x42u8; 32];
|
||||
|
||||
// Setup
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
let alice_session = [0xAA; 32];
|
||||
let bob_session = [0xBB; 32];
|
||||
|
||||
// Exchange messages
|
||||
let (ct1, header1) = alice.encrypt_message(b"Message 1");
|
||||
bob.decrypt_message(&ct1, header1).unwrap();
|
||||
|
||||
let (ct2, header2) = bob.encrypt_message(b"Response 1");
|
||||
alice.decrypt_message(&ct2, header2).unwrap();
|
||||
|
||||
// Save both states
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
||||
|
||||
let alice_storable = StorableRatchetState::from_ratchet_state(&alice, "default");
|
||||
let bob_storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
|
||||
storage.save(&alice_session, &alice_storable).unwrap();
|
||||
storage.save(&bob_session, &bob_storable).unwrap();
|
||||
}
|
||||
|
||||
// Reopen database (simulating restart)
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
||||
|
||||
let alice_loaded = storage.load(&alice_session).unwrap().unwrap();
|
||||
let bob_loaded = storage.load(&bob_session).unwrap().unwrap();
|
||||
|
||||
let mut alice_restored: RatchetState<DefaultDomain> =
|
||||
alice_loaded.to_ratchet_state().unwrap();
|
||||
let mut bob_restored: RatchetState<DefaultDomain> =
|
||||
bob_loaded.to_ratchet_state().unwrap();
|
||||
|
||||
// Continue conversation
|
||||
let (ct3, header3) = alice_restored.encrypt_message(b"Message 2");
|
||||
let pt3 = bob_restored.decrypt_message(&ct3, header3).unwrap();
|
||||
assert_eq!(pt3, b"Message 2");
|
||||
}
|
||||
}
|
||||
}
|
||||
216
double-ratchets-storage/src/memory.rs
Normal file
216
double-ratchets-storage/src/memory.rs
Normal file
@ -0,0 +1,216 @@
|
||||
//! In-memory storage implementation for testing.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use crate::error::StorageError;
|
||||
use crate::traits::{RatchetStorage, SessionId};
|
||||
use crate::types::StorableRatchetState;
|
||||
|
||||
/// In-memory storage backend for testing purposes.
|
||||
///
|
||||
/// This implementation stores ratchet states in a `HashMap` wrapped in a `RwLock`
|
||||
/// for thread-safe access. Data is not persisted across process restarts.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use double_ratchets_storage::{MemoryStorage, RatchetStorage};
|
||||
///
|
||||
/// let storage = MemoryStorage::new();
|
||||
/// assert!(storage.list_sessions().unwrap().is_empty());
|
||||
/// ```
|
||||
pub struct MemoryStorage {
|
||||
states: RwLock<HashMap<SessionId, StorableRatchetState>>,
|
||||
}
|
||||
|
||||
impl MemoryStorage {
|
||||
/// Create a new empty in-memory storage.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
states: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of stored sessions.
|
||||
pub fn len(&self) -> usize {
|
||||
self.states.read().unwrap().len()
|
||||
}
|
||||
|
||||
/// Check if the storage is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.states.read().unwrap().is_empty()
|
||||
}
|
||||
|
||||
/// Clear all stored sessions.
|
||||
pub fn clear(&self) {
|
||||
self.states.write().unwrap().clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RatchetStorage for MemoryStorage {
|
||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
|
||||
let mut states = self.states.write().unwrap();
|
||||
states.insert(*session_id, state.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
|
||||
let states = self.states.read().unwrap();
|
||||
Ok(states.get(session_id).cloned())
|
||||
}
|
||||
|
||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
let mut states = self.states.write().unwrap();
|
||||
Ok(states.remove(session_id).is_some())
|
||||
}
|
||||
|
||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
let states = self.states.read().unwrap();
|
||||
Ok(states.contains_key(session_id))
|
||||
}
|
||||
|
||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
|
||||
let states = self.states.read().unwrap();
|
||||
Ok(states.keys().copied().collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
use double_ratchets::state::RatchetState;
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
|
||||
fn create_test_state() -> StorableRatchetState {
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let state: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
StorableRatchetState::from_ratchet_state(&state, "default")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
let state = create_test_state();
|
||||
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
let loaded = storage.load(&session_id).unwrap();
|
||||
|
||||
assert!(loaded.is_some());
|
||||
let loaded = loaded.unwrap();
|
||||
assert_eq!(loaded.root_key, state.root_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_nonexistent() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
let loaded = storage.load(&session_id).unwrap();
|
||||
assert!(loaded.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
let state = create_test_state();
|
||||
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
assert!(storage.exists(&session_id).unwrap());
|
||||
|
||||
let deleted = storage.delete(&session_id).unwrap();
|
||||
assert!(deleted);
|
||||
assert!(!storage.exists(&session_id).unwrap());
|
||||
|
||||
// Deleting again should return false
|
||||
let deleted = storage.delete(&session_id).unwrap();
|
||||
assert!(!deleted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exists() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
assert!(!storage.exists(&session_id).unwrap());
|
||||
|
||||
let state = create_test_state();
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
|
||||
assert!(storage.exists(&session_id).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_sessions() {
|
||||
let storage = MemoryStorage::new();
|
||||
|
||||
assert!(storage.list_sessions().unwrap().is_empty());
|
||||
|
||||
let state = create_test_state();
|
||||
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
|
||||
|
||||
for id in &session_ids {
|
||||
storage.save(id, &state).unwrap();
|
||||
}
|
||||
|
||||
let mut listed = storage.list_sessions().unwrap();
|
||||
listed.sort();
|
||||
let mut expected = session_ids.clone();
|
||||
expected.sort();
|
||||
|
||||
assert_eq!(listed, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_overwrite() {
|
||||
let storage = MemoryStorage::new();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Create first state
|
||||
let bob_keypair1 = InstallationKeyPair::generate();
|
||||
let state1: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
|
||||
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
|
||||
|
||||
// Create second state with different root
|
||||
let bob_keypair2 = InstallationKeyPair::generate();
|
||||
let state2: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
|
||||
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
|
||||
|
||||
// Save first, then overwrite with second
|
||||
storage.save(&session_id, &storable1).unwrap();
|
||||
storage.save(&session_id, &storable2).unwrap();
|
||||
|
||||
// Should have the second state
|
||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
assert_eq!(loaded.root_key, storable2.root_key);
|
||||
assert_ne!(loaded.root_key, storable1.root_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear() {
|
||||
let storage = MemoryStorage::new();
|
||||
let state = create_test_state();
|
||||
|
||||
for i in 0..5 {
|
||||
storage.save(&[i; 32], &state).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(storage.len(), 5);
|
||||
|
||||
storage.clear();
|
||||
assert!(storage.is_empty());
|
||||
}
|
||||
}
|
||||
746
double-ratchets-storage/src/sqlite.rs
Normal file
746
double-ratchets-storage/src/sqlite.rs
Normal file
@ -0,0 +1,746 @@
|
||||
//! SQLite storage implementation with field-level encryption.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use chacha20poly1305::{
|
||||
aead::{Aead, KeyInit},
|
||||
ChaCha20Poly1305, Nonce,
|
||||
};
|
||||
use rand::RngCore;
|
||||
use rusqlite::{params, Connection, OptionalExtension};
|
||||
|
||||
use crate::error::StorageError;
|
||||
use crate::traits::{RatchetStorage, SessionId};
|
||||
use crate::types::{SkippedKey, StorableRatchetState};
|
||||
|
||||
/// Field encryption key type (32 bytes for ChaCha20Poly1305).
|
||||
pub type EncryptionKey = [u8; 32];
|
||||
|
||||
/// SQLite storage backend with field-level encryption for secrets.
|
||||
///
|
||||
/// This implementation stores ratchet states in SQLite with:
|
||||
/// - Field-level encryption for private keys using ChaCha20Poly1305
|
||||
/// - WAL mode for better concurrent performance
|
||||
/// - Foreign keys and cascading deletes for data integrity
|
||||
/// - Atomic transactions to prevent partial writes
|
||||
///
|
||||
/// # Security
|
||||
///
|
||||
/// The `dh_self_secret` field is encrypted with the provided encryption key.
|
||||
/// For additional security, consider using SQLCipher for full database encryption
|
||||
/// via the `open_encrypted` method (requires `sqlcipher` feature).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use double_ratchets_storage::SqliteStorage;
|
||||
///
|
||||
/// let key = [0u8; 32]; // Use a proper key derivation function
|
||||
/// let storage = SqliteStorage::open("ratchets.db", key).unwrap();
|
||||
/// ```
|
||||
pub struct SqliteStorage {
|
||||
conn: Mutex<Connection>,
|
||||
encryption_key: EncryptionKey,
|
||||
}
|
||||
|
||||
impl SqliteStorage {
|
||||
/// Open or create a SQLite database with field-level encryption.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the database file.
|
||||
/// * `encryption_key` - 32-byte key for field-level encryption.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(SqliteStorage)` on success.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
pub fn open<P: AsRef<Path>>(path: P, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||
let conn = Connection::open(path)?;
|
||||
Self::initialize(conn, encryption_key)
|
||||
}
|
||||
|
||||
/// Create an in-memory SQLite database (for testing).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `encryption_key` - 32-byte key for field-level encryption.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(SqliteStorage)` on success.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
pub fn open_in_memory(encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||
let conn = Connection::open_in_memory()?;
|
||||
Self::initialize(conn, encryption_key)
|
||||
}
|
||||
|
||||
/// Open or create a SQLCipher-encrypted database.
|
||||
///
|
||||
/// This method requires the `sqlcipher` feature to be enabled.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the database file.
|
||||
/// * `db_password` - Password for SQLCipher database encryption.
|
||||
/// * `field_key` - 32-byte key for additional field-level encryption.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(SqliteStorage)` on success.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
#[cfg(feature = "sqlcipher")]
|
||||
pub fn open_encrypted<P: AsRef<Path>>(
|
||||
path: P,
|
||||
db_password: &str,
|
||||
field_key: EncryptionKey,
|
||||
) -> Result<Self, StorageError> {
|
||||
let conn = Connection::open(path)?;
|
||||
|
||||
// Set SQLCipher key
|
||||
conn.pragma_update(None, "key", db_password)?;
|
||||
|
||||
Self::initialize(conn, field_key)
|
||||
}
|
||||
|
||||
fn initialize(conn: Connection, encryption_key: EncryptionKey) -> Result<Self, StorageError> {
|
||||
// Enable WAL mode for better performance
|
||||
conn.pragma_update(None, "journal_mode", "WAL")?;
|
||||
|
||||
// Enable foreign keys
|
||||
conn.pragma_update(None, "foreign_keys", "ON")?;
|
||||
|
||||
// Create tables
|
||||
conn.execute_batch(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS ratchet_states (
|
||||
session_id BLOB PRIMARY KEY,
|
||||
root_key BLOB NOT NULL,
|
||||
sending_chain BLOB,
|
||||
receiving_chain BLOB,
|
||||
dh_self_secret_encrypted BLOB NOT NULL,
|
||||
dh_self_secret_nonce BLOB NOT NULL,
|
||||
dh_self_public BLOB NOT NULL,
|
||||
dh_remote BLOB,
|
||||
msg_send INTEGER NOT NULL,
|
||||
msg_recv INTEGER NOT NULL,
|
||||
prev_chain_len INTEGER NOT NULL,
|
||||
domain_id TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS skipped_keys (
|
||||
id INTEGER PRIMARY KEY,
|
||||
session_id BLOB NOT NULL REFERENCES ratchet_states(session_id) ON DELETE CASCADE,
|
||||
public_key BLOB NOT NULL,
|
||||
msg_num INTEGER NOT NULL,
|
||||
message_key BLOB NOT NULL,
|
||||
UNIQUE(session_id, public_key, msg_num)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_skipped_keys_session ON skipped_keys(session_id);
|
||||
"#,
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
encryption_key,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encrypt a 32-byte secret using ChaCha20Poly1305.
|
||||
fn encrypt_secret(&self, secret: &[u8; 32]) -> Result<(Vec<u8>, [u8; 12]), StorageError> {
|
||||
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
|
||||
|
||||
let mut nonce_bytes = [0u8; 12];
|
||||
rand::thread_rng().fill_bytes(&mut nonce_bytes);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, secret.as_ref())
|
||||
.map_err(|e| StorageError::Encryption(e.to_string()))?;
|
||||
|
||||
Ok((ciphertext, nonce_bytes))
|
||||
}
|
||||
|
||||
/// Decrypt a secret using ChaCha20Poly1305.
|
||||
fn decrypt_secret(&self, ciphertext: &[u8], nonce: &[u8; 12]) -> Result<[u8; 32], StorageError> {
|
||||
let cipher = ChaCha20Poly1305::new((&self.encryption_key).into());
|
||||
let nonce = Nonce::from_slice(nonce);
|
||||
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, ciphertext)
|
||||
.map_err(|e| StorageError::Decryption(e.to_string()))?;
|
||||
|
||||
plaintext
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("decrypted secret has wrong length".to_string()))
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64
|
||||
}
|
||||
}
|
||||
|
||||
impl RatchetStorage for SqliteStorage {
|
||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError> {
|
||||
let (encrypted_secret, nonce) = self.encrypt_secret(&state.dh_self_secret)?;
|
||||
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let tx = conn.unchecked_transaction()?;
|
||||
|
||||
let now = Self::current_timestamp();
|
||||
|
||||
// Check if session exists to determine created_at
|
||||
let exists: bool = tx.query_row(
|
||||
"SELECT 1 FROM ratchet_states WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|_| Ok(true),
|
||||
).optional()?.unwrap_or(false);
|
||||
|
||||
if exists {
|
||||
// Update existing session
|
||||
tx.execute(
|
||||
r#"
|
||||
UPDATE ratchet_states SET
|
||||
root_key = ?,
|
||||
sending_chain = ?,
|
||||
receiving_chain = ?,
|
||||
dh_self_secret_encrypted = ?,
|
||||
dh_self_secret_nonce = ?,
|
||||
dh_self_public = ?,
|
||||
dh_remote = ?,
|
||||
msg_send = ?,
|
||||
msg_recv = ?,
|
||||
prev_chain_len = ?,
|
||||
domain_id = ?,
|
||||
updated_at = ?
|
||||
WHERE session_id = ?
|
||||
"#,
|
||||
params![
|
||||
state.root_key.as_slice(),
|
||||
state.sending_chain.as_ref().map(|c| c.as_slice()),
|
||||
state.receiving_chain.as_ref().map(|c| c.as_slice()),
|
||||
encrypted_secret.as_slice(),
|
||||
nonce.as_slice(),
|
||||
state.dh_self_public.as_slice(),
|
||||
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
|
||||
state.msg_send,
|
||||
state.msg_recv,
|
||||
state.prev_chain_len,
|
||||
&state.domain_id,
|
||||
now,
|
||||
session_id.as_slice(),
|
||||
],
|
||||
)?;
|
||||
|
||||
// Delete existing skipped keys
|
||||
tx.execute(
|
||||
"DELETE FROM skipped_keys WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
)?;
|
||||
} else {
|
||||
// Insert new session
|
||||
tx.execute(
|
||||
r#"
|
||||
INSERT INTO ratchet_states (
|
||||
session_id, root_key, sending_chain, receiving_chain,
|
||||
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
|
||||
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
params![
|
||||
session_id.as_slice(),
|
||||
state.root_key.as_slice(),
|
||||
state.sending_chain.as_ref().map(|c| c.as_slice()),
|
||||
state.receiving_chain.as_ref().map(|c| c.as_slice()),
|
||||
encrypted_secret.as_slice(),
|
||||
nonce.as_slice(),
|
||||
state.dh_self_public.as_slice(),
|
||||
state.dh_remote.as_ref().map(|pk| pk.as_slice()),
|
||||
state.msg_send,
|
||||
state.msg_recv,
|
||||
state.prev_chain_len,
|
||||
&state.domain_id,
|
||||
now,
|
||||
now,
|
||||
],
|
||||
)?;
|
||||
}
|
||||
|
||||
// Insert skipped keys
|
||||
for sk in &state.skipped_keys {
|
||||
tx.execute(
|
||||
r#"
|
||||
INSERT INTO skipped_keys (session_id, public_key, msg_num, message_key)
|
||||
VALUES (?, ?, ?, ?)
|
||||
"#,
|
||||
params![
|
||||
session_id.as_slice(),
|
||||
sk.public_key.as_slice(),
|
||||
sk.msg_num,
|
||||
sk.message_key.as_slice(),
|
||||
],
|
||||
)?;
|
||||
}
|
||||
|
||||
tx.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
let row = conn
|
||||
.query_row(
|
||||
r#"
|
||||
SELECT root_key, sending_chain, receiving_chain,
|
||||
dh_self_secret_encrypted, dh_self_secret_nonce, dh_self_public,
|
||||
dh_remote, msg_send, msg_recv, prev_chain_len, domain_id
|
||||
FROM ratchet_states WHERE session_id = ?
|
||||
"#,
|
||||
[session_id.as_slice()],
|
||||
|row| {
|
||||
Ok((
|
||||
row.get::<_, Vec<u8>>(0)?,
|
||||
row.get::<_, Option<Vec<u8>>>(1)?,
|
||||
row.get::<_, Option<Vec<u8>>>(2)?,
|
||||
row.get::<_, Vec<u8>>(3)?,
|
||||
row.get::<_, Vec<u8>>(4)?,
|
||||
row.get::<_, Vec<u8>>(5)?,
|
||||
row.get::<_, Option<Vec<u8>>>(6)?,
|
||||
row.get::<_, u32>(7)?,
|
||||
row.get::<_, u32>(8)?,
|
||||
row.get::<_, u32>(9)?,
|
||||
row.get::<_, String>(10)?,
|
||||
))
|
||||
},
|
||||
)
|
||||
.optional()?;
|
||||
|
||||
let Some((
|
||||
root_key_bytes,
|
||||
sending_chain_bytes,
|
||||
receiving_chain_bytes,
|
||||
encrypted_secret,
|
||||
nonce_bytes,
|
||||
dh_self_public_bytes,
|
||||
dh_remote_bytes,
|
||||
msg_send,
|
||||
msg_recv,
|
||||
prev_chain_len,
|
||||
domain_id,
|
||||
)) = row
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Decrypt the secret
|
||||
let nonce: [u8; 12] = nonce_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid nonce length".to_string()))?;
|
||||
let dh_self_secret = self.decrypt_secret(&encrypted_secret, &nonce)?;
|
||||
|
||||
// Convert byte vectors to arrays
|
||||
let root_key: [u8; 32] = root_key_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid root_key length".to_string()))?;
|
||||
let dh_self_public: [u8; 32] = dh_self_public_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid dh_self_public length".to_string()))?;
|
||||
|
||||
let sending_chain = sending_chain_bytes
|
||||
.map(|b| {
|
||||
b.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid sending_chain length".to_string()))
|
||||
})
|
||||
.transpose()?;
|
||||
let receiving_chain = receiving_chain_bytes
|
||||
.map(|b| {
|
||||
b.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid receiving_chain length".to_string()))
|
||||
})
|
||||
.transpose()?;
|
||||
let dh_remote = dh_remote_bytes
|
||||
.map(|b| {
|
||||
b.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid dh_remote length".to_string()))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
// Load skipped keys
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT public_key, msg_num, message_key FROM skipped_keys WHERE session_id = ?",
|
||||
)?;
|
||||
let skipped_keys: Vec<SkippedKey> = stmt
|
||||
.query_map([session_id.as_slice()], |row| {
|
||||
let pk_bytes: Vec<u8> = row.get(0)?;
|
||||
let msg_num: u32 = row.get(1)?;
|
||||
let mk_bytes: Vec<u8> = row.get(2)?;
|
||||
Ok((pk_bytes, msg_num, mk_bytes))
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.map(|(pk_bytes, msg_num, mk_bytes)| {
|
||||
let public_key: [u8; 32] = pk_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid skipped public_key length".to_string()))?;
|
||||
let message_key: [u8; 32] = mk_bytes
|
||||
.try_into()
|
||||
.map_err(|_| StorageError::CorruptedState("invalid skipped message_key length".to_string()))?;
|
||||
Ok(SkippedKey {
|
||||
public_key,
|
||||
msg_num,
|
||||
message_key,
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>, StorageError>>()?;
|
||||
|
||||
Ok(Some(StorableRatchetState {
|
||||
root_key,
|
||||
sending_chain,
|
||||
receiving_chain,
|
||||
dh_self_secret,
|
||||
dh_self_public,
|
||||
dh_remote,
|
||||
msg_send,
|
||||
msg_recv,
|
||||
prev_chain_len,
|
||||
skipped_keys,
|
||||
domain_id,
|
||||
}))
|
||||
}
|
||||
|
||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let changes = conn.execute(
|
||||
"DELETE FROM ratchet_states WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
)?;
|
||||
Ok(changes > 0)
|
||||
}
|
||||
|
||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let exists: bool = conn
|
||||
.query_row(
|
||||
"SELECT 1 FROM ratchet_states WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|_| Ok(true),
|
||||
)
|
||||
.optional()?
|
||||
.unwrap_or(false);
|
||||
Ok(exists)
|
||||
}
|
||||
|
||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let mut stmt = conn.prepare("SELECT session_id FROM ratchet_states")?;
|
||||
let sessions: Vec<SessionId> = stmt
|
||||
.query_map([], |row| {
|
||||
let bytes: Vec<u8> = row.get(0)?;
|
||||
Ok(bytes)
|
||||
})?
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.into_iter()
|
||||
.filter_map(|bytes| bytes.try_into().ok())
|
||||
.collect();
|
||||
Ok(sessions)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
use double_ratchets::state::RatchetState;
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
|
||||
fn create_test_storage() -> SqliteStorage {
|
||||
let key = [0x42u8; 32];
|
||||
SqliteStorage::open_in_memory(key).unwrap()
|
||||
}
|
||||
|
||||
fn create_test_state() -> StorableRatchetState {
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let state: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
StorableRatchetState::from_ratchet_state(&state, "default")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
let state = create_test_state();
|
||||
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
let loaded = storage.load(&session_id).unwrap();
|
||||
|
||||
assert!(loaded.is_some());
|
||||
let loaded = loaded.unwrap();
|
||||
assert_eq!(loaded.root_key, state.root_key);
|
||||
assert_eq!(loaded.dh_self_public, state.dh_self_public);
|
||||
// Secret should be decrypted correctly
|
||||
assert_eq!(loaded.dh_self_secret, state.dh_self_secret);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_nonexistent() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
let loaded = storage.load(&session_id).unwrap();
|
||||
assert!(loaded.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
let state = create_test_state();
|
||||
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
assert!(storage.exists(&session_id).unwrap());
|
||||
|
||||
let deleted = storage.delete(&session_id).unwrap();
|
||||
assert!(deleted);
|
||||
assert!(!storage.exists(&session_id).unwrap());
|
||||
|
||||
// Deleting again should return false
|
||||
let deleted = storage.delete(&session_id).unwrap();
|
||||
assert!(!deleted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exists() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
assert!(!storage.exists(&session_id).unwrap());
|
||||
|
||||
let state = create_test_state();
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
|
||||
assert!(storage.exists(&session_id).unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_sessions() {
|
||||
let storage = create_test_storage();
|
||||
|
||||
assert!(storage.list_sessions().unwrap().is_empty());
|
||||
|
||||
let state = create_test_state();
|
||||
let session_ids: Vec<SessionId> = (0..3).map(|i| [i; 32]).collect();
|
||||
|
||||
for id in &session_ids {
|
||||
storage.save(id, &state).unwrap();
|
||||
}
|
||||
|
||||
let mut listed = storage.list_sessions().unwrap();
|
||||
listed.sort();
|
||||
let mut expected = session_ids.clone();
|
||||
expected.sort();
|
||||
|
||||
assert_eq!(listed, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_overwrite() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Create first state
|
||||
let bob_keypair1 = InstallationKeyPair::generate();
|
||||
let state1: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender([0x42u8; 32], *bob_keypair1.public());
|
||||
let storable1 = StorableRatchetState::from_ratchet_state(&state1, "default");
|
||||
|
||||
// Create second state with different root
|
||||
let bob_keypair2 = InstallationKeyPair::generate();
|
||||
let state2: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender([0x43u8; 32], *bob_keypair2.public());
|
||||
let storable2 = StorableRatchetState::from_ratchet_state(&state2, "default");
|
||||
|
||||
// Save first, then overwrite with second
|
||||
storage.save(&session_id, &storable1).unwrap();
|
||||
storage.save(&session_id, &storable2).unwrap();
|
||||
|
||||
// Should have the second state
|
||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
assert_eq!(loaded.root_key, storable2.root_key);
|
||||
assert_ne!(loaded.root_key, storable1.root_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skipped_keys_storage() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Create states and generate skipped keys
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
// Alice sends multiple messages
|
||||
let mut messages = vec![];
|
||||
for i in 0..3 {
|
||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
||||
messages.push((ct, header));
|
||||
}
|
||||
|
||||
// Bob receives out of order to create skipped keys
|
||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
||||
.unwrap();
|
||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
||||
.unwrap();
|
||||
|
||||
assert!(!bob.skipped_keys.is_empty());
|
||||
|
||||
// Save and reload
|
||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
storage.save(&session_id, &storable).unwrap();
|
||||
|
||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
assert_eq!(loaded.skipped_keys.len(), storable.skipped_keys.len());
|
||||
|
||||
// Restore and verify we can decrypt the skipped message
|
||||
let mut restored: RatchetState<DefaultDomain> = loaded.to_ratchet_state().unwrap();
|
||||
let pt = restored
|
||||
.decrypt_message(&messages[1].0, messages[1].1.clone())
|
||||
.unwrap();
|
||||
assert_eq!(pt, b"Message 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encryption_uses_different_nonces() {
|
||||
let storage = create_test_storage();
|
||||
let state = create_test_state();
|
||||
|
||||
// Save the same state twice with different session IDs
|
||||
storage.save(&[1u8; 32], &state).unwrap();
|
||||
storage.save(&[2u8; 32], &state).unwrap();
|
||||
|
||||
// Both should load correctly (encryption with different nonces)
|
||||
let loaded1 = storage.load(&[1u8; 32]).unwrap().unwrap();
|
||||
let loaded2 = storage.load(&[2u8; 32]).unwrap().unwrap();
|
||||
|
||||
assert_eq!(loaded1.dh_self_secret, loaded2.dh_self_secret);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cascade_delete_skipped_keys() {
|
||||
let storage = create_test_storage();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Create a state with skipped keys
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
let mut messages = vec![];
|
||||
for i in 0..3 {
|
||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
||||
messages.push((ct, header));
|
||||
}
|
||||
|
||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone())
|
||||
.unwrap();
|
||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone())
|
||||
.unwrap();
|
||||
|
||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
storage.save(&session_id, &storable).unwrap();
|
||||
|
||||
// Verify skipped keys exist
|
||||
{
|
||||
let conn = storage.conn.lock().unwrap();
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.unwrap();
|
||||
assert!(count > 0);
|
||||
}
|
||||
|
||||
// Delete session
|
||||
storage.delete(&session_id).unwrap();
|
||||
|
||||
// Verify skipped keys were also deleted (cascade)
|
||||
{
|
||||
let conn = storage.conn.lock().unwrap();
|
||||
let count: i64 = conn
|
||||
.query_row(
|
||||
"SELECT COUNT(*) FROM skipped_keys WHERE session_id = ?",
|
||||
[session_id.as_slice()],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(count, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_storage() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
let key = [0x42u8; 32];
|
||||
|
||||
let state = create_test_state();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Save in one instance
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
}
|
||||
|
||||
// Load in another instance
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key).unwrap();
|
||||
let loaded = storage.load(&session_id).unwrap().unwrap();
|
||||
assert_eq!(loaded.root_key, state.root_key);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrong_key_fails_decryption() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db_path = dir.path().join("test.db");
|
||||
let key1 = [0x42u8; 32];
|
||||
let key2 = [0x43u8; 32];
|
||||
|
||||
let state = create_test_state();
|
||||
let session_id = [1u8; 32];
|
||||
|
||||
// Save with key1
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key1).unwrap();
|
||||
storage.save(&session_id, &state).unwrap();
|
||||
}
|
||||
|
||||
// Try to load with key2 - should fail decryption
|
||||
{
|
||||
let storage = SqliteStorage::open(&db_path, key2).unwrap();
|
||||
let result = storage.load(&session_id);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
74
double-ratchets-storage/src/traits.rs
Normal file
74
double-ratchets-storage/src/traits.rs
Normal file
@ -0,0 +1,74 @@
|
||||
//! Storage trait definitions.
|
||||
|
||||
use crate::error::StorageError;
|
||||
use crate::types::StorableRatchetState;
|
||||
|
||||
/// A 32-byte session identifier.
|
||||
pub type SessionId = [u8; 32];
|
||||
|
||||
/// Abstract storage interface for ratchet states.
|
||||
///
|
||||
/// Implementations must be thread-safe (`Send + Sync`).
|
||||
pub trait RatchetStorage: Send + Sync {
|
||||
/// Save a ratchet state for the given session.
|
||||
///
|
||||
/// If a state already exists for this session, it will be overwritten.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - Unique identifier for the session.
|
||||
/// * `state` - The ratchet state to store.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(())` on success.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn save(&self, session_id: &SessionId, state: &StorableRatchetState) -> Result<(), StorageError>;
|
||||
|
||||
/// Load a ratchet state for the given session.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - Unique identifier for the session.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Some(state))` if found.
|
||||
/// * `Ok(None)` if not found.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn load(&self, session_id: &SessionId) -> Result<Option<StorableRatchetState>, StorageError>;
|
||||
|
||||
/// Delete a ratchet state for the given session.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - Unique identifier for the session.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(true)` if the session existed and was deleted.
|
||||
/// * `Ok(false)` if the session did not exist.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn delete(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
||||
|
||||
/// Check if a session exists in storage.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - Unique identifier for the session.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(true)` if the session exists.
|
||||
/// * `Ok(false)` if the session does not exist.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn exists(&self, session_id: &SessionId) -> Result<bool, StorageError>;
|
||||
|
||||
/// List all session IDs in storage.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(Vec<SessionId>)` containing all session IDs.
|
||||
/// * `Err(StorageError)` on failure.
|
||||
fn list_sessions(&self) -> Result<Vec<SessionId>, StorageError>;
|
||||
}
|
||||
225
double-ratchets-storage/src/types.rs
Normal file
225
double-ratchets-storage/src/types.rs
Normal file
@ -0,0 +1,225 @@
|
||||
//! Serializable types for ratchet state storage.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use double_ratchets::state::RatchetState;
|
||||
use double_ratchets::hkdf::HkdfInfo;
|
||||
use double_ratchets::InstallationKeyPair;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use x25519_dalek::PublicKey;
|
||||
|
||||
use crate::error::StorageError;
|
||||
|
||||
/// A skipped message key entry for storage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct SkippedKey {
|
||||
/// The public key associated with this skipped message.
|
||||
pub public_key: [u8; 32],
|
||||
/// The message number.
|
||||
pub msg_num: u32,
|
||||
/// The 32-byte message key.
|
||||
pub message_key: [u8; 32],
|
||||
}
|
||||
|
||||
/// Serializable version of `RatchetState`.
|
||||
///
|
||||
/// This struct stores all keys as raw byte arrays for easy serialization
|
||||
/// and database storage. Use `from_ratchet_state()` and `to_ratchet_state()`
|
||||
/// for conversion.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StorableRatchetState {
|
||||
/// The current root key (32 bytes).
|
||||
pub root_key: [u8; 32],
|
||||
|
||||
/// The current sending chain key, if any (32 bytes).
|
||||
pub sending_chain: Option<[u8; 32]>,
|
||||
|
||||
/// The current receiving chain key, if any (32 bytes).
|
||||
pub receiving_chain: Option<[u8; 32]>,
|
||||
|
||||
/// Our DH secret key (32 bytes).
|
||||
///
|
||||
/// **Security**: This should be encrypted before storage.
|
||||
pub dh_self_secret: [u8; 32],
|
||||
|
||||
/// Our DH public key (32 bytes).
|
||||
pub dh_self_public: [u8; 32],
|
||||
|
||||
/// Remote party's DH public key, if known (32 bytes).
|
||||
pub dh_remote: Option<[u8; 32]>,
|
||||
|
||||
/// Number of messages sent in the current sending chain.
|
||||
pub msg_send: u32,
|
||||
|
||||
/// Number of messages received in the current receiving chain.
|
||||
pub msg_recv: u32,
|
||||
|
||||
/// Length of the previous sending chain.
|
||||
pub prev_chain_len: u32,
|
||||
|
||||
/// Skipped message keys for out-of-order message handling.
|
||||
pub skipped_keys: Vec<SkippedKey>,
|
||||
|
||||
/// Domain identifier for HKDF info.
|
||||
pub domain_id: String,
|
||||
}
|
||||
|
||||
impl StorableRatchetState {
|
||||
/// Convert a `RatchetState` into a `StorableRatchetState`.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The ratchet state to convert.
|
||||
/// * `domain_id` - A string identifier for the domain (used to reconstruct the correct domain type).
|
||||
pub fn from_ratchet_state<D: HkdfInfo>(state: &RatchetState<D>, domain_id: &str) -> Self {
|
||||
let skipped_keys: Vec<SkippedKey> = state
|
||||
.skipped_keys
|
||||
.iter()
|
||||
.map(|((pub_key, msg_num), msg_key)| SkippedKey {
|
||||
public_key: *pub_key.as_bytes(),
|
||||
msg_num: *msg_num,
|
||||
message_key: *msg_key,
|
||||
})
|
||||
.collect();
|
||||
|
||||
StorableRatchetState {
|
||||
root_key: state.root_key,
|
||||
sending_chain: state.sending_chain,
|
||||
receiving_chain: state.receiving_chain,
|
||||
dh_self_secret: state.dh_self.secret_bytes(),
|
||||
dh_self_public: *state.dh_self.public().as_bytes(),
|
||||
dh_remote: state.dh_remote.map(|pk| *pk.as_bytes()),
|
||||
msg_send: state.msg_send,
|
||||
msg_recv: state.msg_recv,
|
||||
prev_chain_len: state.prev_chain_len,
|
||||
skipped_keys,
|
||||
domain_id: domain_id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert this `StorableRatchetState` back into a `RatchetState`.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// * `D` - The HKDF domain type implementing `HkdfInfo`.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(RatchetState)` on success.
|
||||
/// * `Err(StorageError)` if key reconstruction fails.
|
||||
pub fn to_ratchet_state<D: HkdfInfo>(&self) -> Result<RatchetState<D>, StorageError> {
|
||||
// Reconstruct the keypair
|
||||
let dh_self = InstallationKeyPair::from_bytes(self.dh_self_secret, self.dh_self_public)
|
||||
.map_err(|e| StorageError::KeyReconstruction(e.to_string()))?;
|
||||
|
||||
// Reconstruct skipped keys HashMap
|
||||
let skipped_keys: HashMap<(PublicKey, u32), [u8; 32]> = self
|
||||
.skipped_keys
|
||||
.iter()
|
||||
.map(|sk| {
|
||||
let pub_key = PublicKey::from(sk.public_key);
|
||||
((pub_key, sk.msg_num), sk.message_key)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(RatchetState::from_parts(
|
||||
self.root_key,
|
||||
self.sending_chain,
|
||||
self.receiving_chain,
|
||||
dh_self,
|
||||
self.dh_remote.map(PublicKey::from),
|
||||
self.msg_send,
|
||||
self.msg_recv,
|
||||
self.prev_chain_len,
|
||||
skipped_keys,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use double_ratchets::hkdf::DefaultDomain;
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_sender_state() {
|
||||
// Create a sender state
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let state: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
|
||||
// Convert to storable and back
|
||||
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
||||
|
||||
// Verify fields match
|
||||
assert_eq!(state.root_key, restored.root_key);
|
||||
assert_eq!(state.sending_chain, restored.sending_chain);
|
||||
assert_eq!(state.receiving_chain, restored.receiving_chain);
|
||||
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
|
||||
assert_eq!(state.dh_remote.map(|pk| *pk.as_bytes()), restored.dh_remote.map(|pk| *pk.as_bytes()));
|
||||
assert_eq!(state.msg_send, restored.msg_send);
|
||||
assert_eq!(state.msg_recv, restored.msg_recv);
|
||||
assert_eq!(state.prev_chain_len, restored.prev_chain_len);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_receiver_state() {
|
||||
// Create a receiver state
|
||||
let keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let state: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, keypair);
|
||||
|
||||
// Convert to storable and back
|
||||
let storable = StorableRatchetState::from_ratchet_state(&state, "default");
|
||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
||||
|
||||
// Verify fields match
|
||||
assert_eq!(state.root_key, restored.root_key);
|
||||
assert_eq!(state.dh_self.public().as_bytes(), restored.dh_self.public().as_bytes());
|
||||
assert!(restored.dh_remote.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_with_skipped_keys() {
|
||||
// Create states and exchange messages to generate skipped keys
|
||||
let bob_keypair = InstallationKeyPair::generate();
|
||||
let shared_secret = [0x42u8; 32];
|
||||
let mut alice: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_sender(shared_secret, *bob_keypair.public());
|
||||
let mut bob: RatchetState<DefaultDomain> =
|
||||
RatchetState::init_receiver(shared_secret, bob_keypair);
|
||||
|
||||
// Alice sends multiple messages
|
||||
let mut messages = vec![];
|
||||
for i in 0..3 {
|
||||
let (ct, header) = alice.encrypt_message(&format!("Message {}", i).into_bytes());
|
||||
messages.push((ct, header));
|
||||
}
|
||||
|
||||
// Bob receives them out of order to create skipped keys
|
||||
bob.decrypt_message(&messages[0].0, messages[0].1.clone()).unwrap();
|
||||
bob.decrypt_message(&messages[2].0, messages[2].1.clone()).unwrap();
|
||||
// Message 1 key is now in skipped_keys
|
||||
|
||||
assert!(!bob.skipped_keys.is_empty());
|
||||
|
||||
// Convert to storable and back
|
||||
let storable = StorableRatchetState::from_ratchet_state(&bob, "default");
|
||||
let restored: RatchetState<DefaultDomain> = storable.to_ratchet_state().unwrap();
|
||||
|
||||
// Verify skipped keys are preserved
|
||||
assert_eq!(bob.skipped_keys.len(), restored.skipped_keys.len());
|
||||
|
||||
// The restored state should be able to decrypt the skipped message
|
||||
let mut restored = restored;
|
||||
let pt = restored.decrypt_message(&messages[1].0, messages[1].1.clone()).unwrap();
|
||||
assert_eq!(pt, b"Message 1");
|
||||
}
|
||||
}
|
||||
@ -23,4 +23,37 @@ impl InstallationKeyPair {
|
||||
pub fn public(&self) -> &PublicKey {
|
||||
&self.public
|
||||
}
|
||||
|
||||
/// Export the secret key as raw bytes for serialization/storage.
|
||||
///
|
||||
/// # Security Warning
|
||||
///
|
||||
/// The returned bytes contain the private key material. Handle with care
|
||||
/// and ensure proper encryption when storing.
|
||||
pub fn secret_bytes(&self) -> [u8; 32] {
|
||||
self.secret.to_bytes()
|
||||
}
|
||||
|
||||
/// Reconstruct a keypair from raw secret and public key bytes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `secret` - The 32-byte secret key.
|
||||
/// * `public` - The 32-byte public key.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Ok(InstallationKeyPair)` if the keys are valid and consistent.
|
||||
/// * `Err(&'static str)` if the public key doesn't match the secret key.
|
||||
pub fn from_bytes(secret: [u8; 32], public: [u8; 32]) -> Result<Self, &'static str> {
|
||||
let secret = StaticSecret::from(secret);
|
||||
let expected_public = PublicKey::from(&secret);
|
||||
let public = PublicKey::from(public);
|
||||
|
||||
if expected_public.as_bytes() != public.as_bytes() {
|
||||
return Err("public key does not match secret key");
|
||||
}
|
||||
|
||||
Ok(Self { secret, public })
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,6 +59,46 @@ impl Header {
|
||||
}
|
||||
|
||||
impl<D: HkdfInfo> RatchetState<D> {
|
||||
/// Reconstruct a `RatchetState` from its component parts.
|
||||
///
|
||||
/// This is primarily used for deserialization from storage.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `root_key` - The current root key.
|
||||
/// * `sending_chain` - The current sending chain key, if any.
|
||||
/// * `receiving_chain` - The current receiving chain key, if any.
|
||||
/// * `dh_self` - Our DH key pair.
|
||||
/// * `dh_remote` - Remote party's DH public key, if known.
|
||||
/// * `msg_send` - Number of messages sent in current sending chain.
|
||||
/// * `msg_recv` - Number of messages received in current receiving chain.
|
||||
/// * `prev_chain_len` - Length of the previous sending chain.
|
||||
/// * `skipped_keys` - Map of skipped message keys.
|
||||
pub fn from_parts(
|
||||
root_key: RootKey,
|
||||
sending_chain: Option<ChainKey>,
|
||||
receiving_chain: Option<ChainKey>,
|
||||
dh_self: InstallationKeyPair,
|
||||
dh_remote: Option<PublicKey>,
|
||||
msg_send: u32,
|
||||
msg_recv: u32,
|
||||
prev_chain_len: u32,
|
||||
skipped_keys: HashMap<(PublicKey, u32), MessageKey>,
|
||||
) -> Self {
|
||||
Self {
|
||||
root_key,
|
||||
sending_chain,
|
||||
receiving_chain,
|
||||
dh_self,
|
||||
dh_remote,
|
||||
msg_send,
|
||||
msg_recv,
|
||||
prev_chain_len,
|
||||
skipped_keys,
|
||||
_domain: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes the party that sends the first message.
|
||||
///
|
||||
/// Performs the initial Diffie-Hellman computation with the remote public key
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user