From 9cc73622ed22e90aac0ca0bd03e9e70ec3ae97e1 Mon Sep 17 00:00:00 2001 From: kaichao Date: Fri, 10 Apr 2026 08:33:58 +0800 Subject: [PATCH] Move double ratchet storage operations to PrivateV1 (#82) * feat: move private store out of context * feat: move convo store to private v1 * feat: clean context * chore: params postion * chore: use git exclue for justfile --- core/conversations/src/context.rs | 71 ++++++----------- core/conversations/src/conversation.rs | 6 +- .../src/conversation/privatev1.rs | 76 +++++++++++++++---- core/conversations/src/inbox/handler.rs | 33 ++++---- 4 files changed, 108 insertions(+), 78 deletions(-) diff --git a/core/conversations/src/context.rs b/core/conversations/src/context.rs index 3415c88..761101c 100644 --- a/core/conversations/src/context.rs +++ b/core/conversations/src/context.rs @@ -2,8 +2,7 @@ use std::sync::Arc; use std::{cell::RefCell, rc::Rc}; use crypto::Identity; -use double_ratchets::{RatchetState, restore_ratchet_state}; -use storage::{ChatStore, ConversationKind, ConversationMeta}; +use storage::{ChatStore, ConversationKind}; use crate::{ conversation::{Conversation, ConversationId, Convo, Id, PrivateV1Convo}, @@ -18,18 +17,18 @@ pub use crate::inbox::Introduction; // This is the main entry point to the conversations api. // Ctx manages lifetimes of objects to process and generate payloads. -pub struct Context { +pub struct Context { _identity: Rc, - inbox: Inbox, - store: Rc>, + inbox: Inbox, + store: Rc>, } -impl Context { +impl Context { /// Opens or creates a Context with the given storage configuration. /// /// If an identity exists in storage, it will be restored. /// Otherwise, a new identity will be created with the given name and saved. - pub fn new_from_store(name: impl Into, store: T) -> Result { + pub fn new_from_store(name: impl Into, store: S) -> Result { let name = name.into(); let store = Rc::new(RefCell::new(store)); @@ -43,7 +42,7 @@ impl Context { }; let identity = Rc::new(identity); - let inbox = Inbox::new(Rc::clone(&identity), Rc::clone(&store)); + let inbox = Inbox::new(Rc::clone(&store), Rc::clone(&identity)); Ok(Self { _identity: identity, @@ -55,7 +54,7 @@ impl Context { /// Creates a new in-memory Context (for testing). /// /// Uses in-memory SQLite database. Each call creates a new isolated database. - pub fn new_with_name(name: impl Into, chat_store: T) -> Self { + pub fn new_with_name(name: impl Into, chat_store: S) -> Self { let name = name.into(); let identity = Identity::new(&name); let chat_store = Rc::new(RefCell::new(chat_store)); @@ -65,7 +64,7 @@ impl Context { .expect("in-memory storage should not fail"); let identity = Rc::new(identity); - let inbox = Inbox::new(Rc::clone(&identity), Rc::clone(&chat_store)); + let inbox = Inbox::new(Rc::clone(&chat_store), Rc::clone(&identity)); Self { _identity: identity, @@ -83,18 +82,18 @@ impl Context { remote_bundle: &Introduction, content: &[u8], ) -> Result<(ConversationIdOwned, Vec), ChatError> { - let (convo, payloads) = self + let (mut convo, payloads) = self .inbox - .invite_to_private_convo(remote_bundle, content) + .invite_to_private_convo(remote_bundle, content, Rc::clone(&self.store)) .unwrap_or_else(|_| todo!("Log/Surface Error")); - let remote_id = Inbox::::inbox_identifier_for_key(*remote_bundle.installation_key()); + let remote_id = Inbox::::inbox_identifier_for_key(*remote_bundle.installation_key()); let payload_bytes = payloads .into_iter() .map(|p| p.into_envelope(remote_id.clone())) .collect(); - let convo_id = self.persist_convo(&convo)?; + let convo_id = convo.persist()?; Ok((convo_id, payload_bytes)) } @@ -117,7 +116,6 @@ impl Context { Conversation::Private(mut convo) => { let payloads = convo.send_message(content)?; let remote_id = convo.remote_id(); - convo.save_ratchet_state::(&mut *self.store.borrow_mut())?; Ok(payloads .into_iter() @@ -146,11 +144,13 @@ impl Context { &mut self, enc_payload: EncryptedPayload, ) -> Result, ChatError> { - let public_key_hex = Inbox::::extract_ephemeral_key_hex(&enc_payload)?; - let (convo, content) = self.inbox.handle_frame(enc_payload, &public_key_hex)?; + let public_key_hex = Inbox::::extract_ephemeral_key_hex(&enc_payload)?; + let (convo, content) = + self.inbox + .handle_frame(enc_payload, &public_key_hex, Rc::clone(&self.store))?; match convo { - Conversation::Private(convo) => self.persist_convo(&convo)?, + Conversation::Private(mut convo) => convo.persist()?, }; self.store @@ -170,7 +170,6 @@ impl Context { match convo { Conversation::Private(mut convo) => { let result = convo.handle_frame(enc_payload)?; - convo.save_ratchet_state(&mut *self.store.borrow_mut())?; Ok(result) } } @@ -181,8 +180,8 @@ impl Context { Ok(intro.into()) } - /// Loads a conversation from DB by constructing it from metadata + ratchet state. - fn load_convo(&self, convo_id: ConversationId) -> Result { + /// Loads a conversation from DB by constructing it from metadata. + fn load_convo(&self, convo_id: ConversationId) -> Result, ChatError> { let record = self .store .borrow() @@ -191,21 +190,12 @@ impl Context { match record.kind { ConversationKind::PrivateV1 => { - let dr_record = self - .store - .borrow() - .load_ratchet_state(&record.local_convo_id)?; - let skipped_keys = self - .store - .borrow() - .load_skipped_keys(&record.local_convo_id)?; - let dr_state: RatchetState = restore_ratchet_state(dr_record, skipped_keys); - - Ok(Conversation::Private(PrivateV1Convo::new( + let private_convo = PrivateV1Convo::new( + self.store.clone(), record.local_convo_id, record.remote_convo_id, - dr_state, - ))) + )?; + Ok(Conversation::Private(private_convo)) } ConversationKind::Unknown(_) => Err(ChatError::BadBundleValue(format!( "unsupported conversation type: {}", @@ -213,18 +203,6 @@ impl Context { ))), } } - - /// Persists a conversation's metadata and ratchet state to DB. - fn persist_convo(&mut self, convo: &PrivateV1Convo) -> Result { - let convo_info = ConversationMeta { - local_convo_id: convo.id().to_string(), - remote_convo_id: convo.remote_id(), - kind: convo.convo_type(), - }; - self.store.borrow_mut().save_conversation(&convo_info)?; - convo.save_ratchet_state(&mut *self.store.borrow_mut())?; - Ok(Arc::from(convo.id())) - } } #[cfg(test)] @@ -347,7 +325,6 @@ mod tests { let content = alice.handle_payload(&payload.data).unwrap().unwrap(); let alice_convo_id = content.conversation_id; - // Exchange a few messages to advance ratchet state let payloads = alice.send_content(&alice_convo_id, b"reply 1").unwrap(); let payload = payloads.first().unwrap(); bob.handle_payload(&payload.data).unwrap().unwrap(); diff --git a/core/conversations/src/conversation.rs b/core/conversations/src/conversation.rs index 2c058dd..1580d78 100644 --- a/core/conversations/src/conversation.rs +++ b/core/conversations/src/conversation.rs @@ -4,7 +4,7 @@ use crate::types::{AddressedEncryptedPayload, ContentData}; use chat_proto::logoschat::encryption::EncryptedPayload; use std::fmt::Debug; use std::sync::Arc; -use storage::ConversationKind; +use storage::{ConversationKind, ConversationStore, RatchetStore}; pub use crate::errors::ChatError; pub use privatev1::PrivateV1Convo; @@ -36,6 +36,6 @@ pub trait Convo: Id + Debug { fn convo_type(&self) -> ConversationKind; } -pub enum Conversation { - Private(PrivateV1Convo), +pub enum Conversation { + Private(PrivateV1Convo), } diff --git a/core/conversations/src/conversation/privatev1.rs b/core/conversations/src/conversation/privatev1.rs index f2f8a22..b7736d8 100644 --- a/core/conversations/src/conversation/privatev1.rs +++ b/core/conversations/src/conversation/privatev1.rs @@ -7,12 +7,13 @@ use chat_proto::logoschat::{ encryption::{Doubleratchet, EncryptedPayload, encrypted_payload::Encryption}, }; use crypto::{PrivateKey, PublicKey, SymmetricKey32}; -use double_ratchets::{Header, InstallationKeyPair, RatchetState}; +use double_ratchets::{Header, InstallationKeyPair, RatchetState, restore_ratchet_state}; use prost::{Message, bytes::Bytes}; -use std::fmt::Debug; -use storage::ConversationKind; +use std::{cell::RefCell, fmt::Debug, rc::Rc, sync::Arc}; +use storage::{ConversationKind, ConversationMeta, ConversationStore}; use crate::{ + context::ConversationIdOwned, conversation::{ChatError, ConversationId, Convo, Id}, errors::EncryptionError, proto, @@ -55,23 +56,37 @@ impl BaseConvoId { } } -pub struct PrivateV1Convo { +pub struct PrivateV1Convo { local_convo_id: String, remote_convo_id: String, dr_state: RatchetState, + store: Rc>, } -impl PrivateV1Convo { +impl PrivateV1Convo { /// Reconstructs a PrivateV1Convo from persisted metadata and ratchet state. - pub fn new(local_convo_id: String, remote_convo_id: String, dr_state: RatchetState) -> Self { - Self { + pub fn new( + store: Rc>, + local_convo_id: String, + remote_convo_id: String, + ) -> Result { + let dr_record = store.borrow().load_ratchet_state(&local_convo_id)?; + let skipped_keys = store.borrow().load_skipped_keys(&local_convo_id)?; + let dr_state: RatchetState = restore_ratchet_state(dr_record, skipped_keys); + + Ok(Self { local_convo_id, remote_convo_id, dr_state, - } + store, + }) } - pub fn new_initiator(seed_key: SymmetricKey32, remote: PublicKey) -> Self { + pub fn new_initiator( + store: Rc>, + seed_key: SymmetricKey32, + remote: PublicKey, + ) -> Self { let base_convo_id = BaseConvoId::new(&seed_key); let local_convo_id = base_convo_id.id_for_participant(Role::Initiator); let remote_convo_id = base_convo_id.id_for_participant(Role::Responder); @@ -86,10 +101,15 @@ impl PrivateV1Convo { local_convo_id, remote_convo_id, dr_state, + store, } } - pub fn new_responder(seed_key: SymmetricKey32, dh_self: &PrivateKey) -> Self { + pub fn new_responder( + store: Rc>, + seed_key: SymmetricKey32, + dh_self: &PrivateKey, + ) -> Self { let base_convo_id = BaseConvoId::new(&seed_key); let local_convo_id = base_convo_id.id_for_participant(Role::Responder); let remote_convo_id = base_convo_id.id_for_participant(Role::Initiator); @@ -105,6 +125,7 @@ impl PrivateV1Convo { local_convo_id, remote_convo_id, dr_state, + store, } } @@ -169,6 +190,18 @@ impl PrivateV1Convo { }) } + /// Persists a conversation's metadata and ratchet state to DB. + pub fn persist(&mut self) -> Result { + let convo_info = ConversationMeta { + local_convo_id: self.id().to_string(), + remote_convo_id: self.remote_id(), + kind: self.convo_type(), + }; + self.store.borrow_mut().save_conversation(&convo_info)?; + self.save_ratchet_state(&mut *self.store.borrow_mut())?; + Ok(Arc::from(self.id())) + } + pub fn save_ratchet_state(&self, storage: &mut T) -> Result<(), ChatError> { let record = to_ratchet_record(&self.dr_state); let skipped_keys = to_skipped_key_records(&self.dr_state.skipped_keys()); @@ -177,13 +210,13 @@ impl PrivateV1Convo { } } -impl Id for PrivateV1Convo { +impl Id for PrivateV1Convo { fn id(&self) -> ConversationId<'_> { &self.local_convo_id } } -impl Convo for PrivateV1Convo { +impl Convo for PrivateV1Convo { fn send_message( &mut self, content: &[u8], @@ -197,6 +230,8 @@ impl Convo for PrivateV1Convo { let data = self.encrypt(frame); + self.save_ratchet_state::(&mut *self.store.borrow_mut())?; + Ok(vec![AddressedEncryptedPayload { delivery_address: "delivery_address".into(), data, @@ -216,6 +251,8 @@ impl Convo for PrivateV1Convo { return Err(ChatError::ProtocolExpectation("None", "Some".into())); }; + self.save_ratchet_state(&mut *self.store.borrow_mut())?; + // Handle FrameTypes let output = match frame_type { FrameType::Content(bytes) => self.handle_content(bytes.into()), @@ -234,7 +271,7 @@ impl Convo for PrivateV1Convo { } } -impl Debug for PrivateV1Convo { +impl Debug for PrivateV1Convo { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PrivateV1Convo") .field("dr_state", &"******") @@ -245,6 +282,7 @@ impl Debug for PrivateV1Convo { #[cfg(test)] mod tests { use crypto::PrivateKey; + use sqlite::{ChatStorage, StorageConfig}; use super::*; @@ -253,14 +291,22 @@ mod tests { let saro = PrivateKey::random(); let raya = PrivateKey::random(); + let saro_storage = Rc::new(RefCell::new( + ChatStorage::new(StorageConfig::InMemory).unwrap(), + )); + + let raya_storage = Rc::new(RefCell::new( + ChatStorage::new(StorageConfig::InMemory).unwrap(), + )); + let pub_raya = PublicKey::from(&raya); let seed_key = saro.diffie_hellman(&pub_raya).DANGER_to_bytes(); let seed_key_saro = SymmetricKey32::from(seed_key); let seed_key_raya = SymmetricKey32::from(seed_key); let send_content_bytes = vec![0, 2, 4, 6, 8]; - let mut sr_convo = PrivateV1Convo::new_initiator(seed_key_saro, pub_raya); - let mut rs_convo = PrivateV1Convo::new_responder(seed_key_raya, &raya); + let mut sr_convo = PrivateV1Convo::new_initiator(saro_storage, seed_key_saro, pub_raya); + let mut rs_convo = PrivateV1Convo::new_responder(raya_storage, seed_key_raya, &raya); let send_frame = PrivateV1Frame { conversation_id: "_".into(), diff --git a/core/conversations/src/inbox/handler.rs b/core/conversations/src/inbox/handler.rs index 395f9e3..9b90ac3 100644 --- a/core/conversations/src/inbox/handler.rs +++ b/core/conversations/src/inbox/handler.rs @@ -5,7 +5,7 @@ use prost::bytes::Bytes; use rand_core::OsRng; use std::cell::RefCell; use std::rc::Rc; -use storage::EphemeralKeyStore; +use storage::{ConversationStore, EphemeralKeyStore, RatchetStore}; use crypto::{PrekeyBundle, SymmetricKey32}; @@ -39,7 +39,7 @@ impl std::fmt::Debug for Inbox { } impl Inbox { - pub fn new(ident: Rc, store: Rc>) -> Self { + pub fn new(store: Rc>, ident: Rc) -> Self { let local_convo_id = Self::inbox_identifier_for_key(ident.public_key()); Self { ident, @@ -64,11 +64,12 @@ impl Inbox { Ok(intro) } - pub fn invite_to_private_convo( + pub fn invite_to_private_convo( &self, remote_bundle: &Introduction, initial_message: &[u8], - ) -> Result<(PrivateV1Convo, Vec), ChatError> { + private_store: Rc>, + ) -> Result<(PrivateV1Convo, Vec), ChatError> { let mut rng = OsRng; let pkb = PrekeyBundle { @@ -81,7 +82,8 @@ impl Inbox { let (seed_key, ephemeral_pub) = InboxHandshake::perform_as_initiator(self.ident.secret(), &pkb, &mut rng); - let mut convo = PrivateV1Convo::new_initiator(seed_key, *remote_bundle.ephemeral_key()); + let mut convo = + PrivateV1Convo::new_initiator(private_store, seed_key, *remote_bundle.ephemeral_key()); let mut payloads = convo.send_message(initial_message)?; @@ -119,11 +121,12 @@ impl Inbox { /// Handles an incoming inbox frame. The caller must provide the ephemeral private key /// looked up from storage. Returns the created conversation and optional content data. - pub fn handle_frame( + pub fn handle_frame( &self, enc_payload: EncryptedPayload, public_key_hex: &str, - ) -> Result<(Conversation, Option), ChatError> { + private_store: Rc>, + ) -> Result<(Conversation, Option), ChatError> { let ephemeral_key = self .store .borrow() @@ -142,7 +145,8 @@ impl Inbox { match frame.frame_type.unwrap() { proto::inbox_v1_frame::FrameType::InvitePrivateV1(_invite_private_v1) => { - let mut convo = PrivateV1Convo::new_responder(seed_key, &ephemeral_key); + let mut convo = + PrivateV1Convo::new_responder(private_store, seed_key, &ephemeral_key); let Some(enc_payload) = _invite_private_v1.initial_message else { return Err(ChatError::Protocol("missing initial encpayload".into())); @@ -260,26 +264,29 @@ mod tests { #[test] fn test_invite_privatev1_roundtrip() { - let storage = Rc::new(RefCell::new( + let saro_storage = Rc::new(RefCell::new( + ChatStorage::new(StorageConfig::InMemory).unwrap(), + )); + let raya_storage = Rc::new(RefCell::new( ChatStorage::new(StorageConfig::InMemory).unwrap(), )); let saro_ident = Identity::new("saro"); - let saro_inbox = Inbox::new(saro_ident.into(), Rc::clone(&storage)); + let saro_inbox = Inbox::new(Rc::clone(&saro_storage), saro_ident.into()); let raya_ident = Identity::new("raya"); - let raya_inbox = Inbox::new(raya_ident.into(), Rc::clone(&storage)); + let raya_inbox = Inbox::new(Rc::clone(&raya_storage), raya_ident.into()); let bundle = raya_inbox.create_intro_bundle().unwrap(); let (_, mut payloads) = saro_inbox - .invite_to_private_convo(&bundle, "hello".as_bytes()) + .invite_to_private_convo(&bundle, "hello".as_bytes(), saro_storage) .unwrap(); let payload = payloads.remove(0); let key_hex = Inbox::::extract_ephemeral_key_hex(&payload.data).unwrap(); - let result = raya_inbox.handle_frame(payload.data, &key_hex); + let result = raya_inbox.handle_frame(payload.data, &key_hex, raya_storage); assert!( result.is_ok(),