Add convo cache

This commit is contained in:
Jazz Turner-Baggs 2026-06-13 08:51:54 -07:00
parent 2bbc83f6c3
commit 3245534790
No known key found for this signature in database
2 changed files with 149 additions and 29 deletions

View File

@ -1,8 +1,9 @@
use crate::causal_history::{CausalHistoryStore, MissingMessage};
use crate::conversation::{ConversationIdRef, GroupV1Convo, GroupV2Convo, PrivateV1Convo};
use crate::service_context::{ExternalServices, ServiceContext};
use crate::{DeliveryService, IdentityProvider, RegistrationService, WakeupService};
use crate::{
conversation::{Convo, GroupConvo, GroupV1Convo, PrivateV1Convo},
conversation::{Convo, GroupConvo},
errors::ChatError,
inbox::Inbox,
inbox_v2::{InboxV2, MlsEphemeralPqProvider, MlsIdentityProvider},
@ -10,8 +11,9 @@ use crate::{
proto::{EncryptedPayload, EnvelopeV1, Message},
};
use crypto::{Identity, PublicKey};
use openmls::prelude::GroupId;
use openmls::group::GroupId;
use shared_traits::IdentIdRef;
use std::collections::HashMap;
use storage::{ChatStore, ConversationKind, ConversationStore};
pub use crate::conversation::ConversationId;
@ -27,6 +29,8 @@ pub struct Core<S: ExternalServices> {
services: ServiceContext<S>,
inbox: Inbox,
pq_inbox: InboxV2,
// Cache of loaded conversations
cached_convos: HashMap<String, ConvoTypeOwned<S>>,
}
// Constructors live on the `(DS, RS, CS)` form: `S` can't be inferred backwards
@ -131,6 +135,7 @@ where
},
inbox,
pq_inbox,
cached_convos: HashMap::new(),
})
}
}
@ -219,7 +224,13 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
kind: ConversationKind::GroupV1,
})?;
convo.add_member(&mut self.services, participants)?;
Ok(convo.id().to_string())
let convo_id = convo.id().to_string();
self.register_convo(ConvoTypeOwned::Group(Box::new(convo)))?;
Ok(convo_id)
}
pub fn create_group_convo_v2(
&mut self,
participants: &[IdentIdRef],
@ -238,6 +249,8 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
convo.add_member(&mut self.services, participants)?;
let convo_id = convo.id().to_string();
self.register_convo(ConvoTypeOwned::Group(Box::new(convo)))?;
Ok(convo_id)
}
@ -247,13 +260,38 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
convo_id: &str,
members: &[IdentIdRef],
) -> Result<(), ChatError> {
let mut convo = self.load_group_convo(convo_id)?;
convo.add_member(&mut self.services, members)
if self.cached_convos.contains_key(convo_id) {
let convo = self
.cached_convos
.get_mut(convo_id)
.ok_or_else(|| ChatError::NoConvo(convo_id.to_string()))?;
match convo {
ConvoTypeOwned::Group(group_convo) => {
group_convo.add_member(&mut self.services, members)
}
}
} else {
let mut convo = self.load_group_convo(convo_id)?;
convo.add_member(&mut self.services, members)
}
}
pub fn list_conversations(&self) -> Result<Vec<ConversationId>, ChatError> {
// Check Legacy load_convo store
let records = self.services.store.load_conversations()?;
Ok(records.into_iter().map(|r| r.local_convo_id).collect())
let mut convos: Vec<ConversationId> =
records.into_iter().map(|r| r.local_convo_id).collect();
// Add cached mls convos
for convo in self.cached_convos.keys() {
convos.push(convo.to_string());
}
// Conversations may use both storage mechanisms.
// Remove duplicates
convos.dedup();
Ok(convos)
}
pub fn take_missing_messages(&self) -> Vec<MissingMessage> {
@ -262,8 +300,16 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
/// Encrypt and publish `content` to an existing conversation.
pub fn send_content(&mut self, convo_id: &str, content: &[u8]) -> Result<(), ChatError> {
let mut convo = self.load_convo(convo_id)?;
convo.send_content(&mut self.services, content)
if self.cached_convos.contains_key(convo_id) {
let convo = self
.cached_convos
.get_mut(convo_id)
.ok_or_else(|| ChatError::NoConvo(convo_id.to_string()))?;
convo.send_content(&mut self.services, content)
} else {
let mut convo = self.load_convo(convo_id)?;
convo.send_content(&mut self.services, content)
}
}
// Decode bytes and send to protocol for processing.
@ -276,6 +322,9 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
match convo_id {
c if c == self.inbox.id() => self.dispatch_to_inbox(&env.payload).map(Into::into),
c if c == self.pq_inbox.id() => self.dispatch_to_inbox2(&env.payload).map(Into::into),
c if self.cached_convos.contains_key(&c) => {
self.dispatch_to_convo(&c, &env.payload).map(Into::into)
}
c if self.services.store.has_conversation(&c)? => {
self.dispatch_to_convo(&c, &env.payload).map(Into::into)
}
@ -295,8 +344,22 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
}
// Dispatch encrypted payload to the post-quantum inbox.
fn dispatch_to_inbox2(&mut self, payload: &[u8]) -> Result<InboxOutcome, ChatError> {
self.pq_inbox.handle_frame(payload, &mut self.services)
fn dispatch_to_inbox2(&mut self, payload: &[u8]) -> Result<PayloadOutcome, ChatError> {
if let Some(convo) = self.pq_inbox.handle_frame(&mut self.services, payload)? {
let convo_id = convo.id().to_string();
// Cache convos created by InboxV2
self.register_convo(ConvoTypeOwned::Group(convo))?;
Ok(PayloadOutcome::Inbox(InboxOutcome {
new_conversation: crate::NewConversation {
convo_id: convo_id,
class: crate::ConversationClass::Group,
},
initial: None,
}))
} else {
Ok(PayloadOutcome::Empty)
}
}
// Dispatch encrypted payload to its corresponding conversation.
@ -306,8 +369,20 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
enc_payload_bytes: &[u8],
) -> Result<ConvoOutcome, ChatError> {
let enc_payload = EncryptedPayload::decode(enc_payload_bytes)?;
let mut convo = self.load_convo(convo_id)?;
convo.handle_frame(&mut self.services, enc_payload)
if self.cached_convos.contains_key(convo_id) {
let convo_type = self
.cached_convos
.get_mut(convo_id)
.ok_or_else(|| ChatError::NoConvo(convo_id.to_string()))?;
convo_type.handle_frame(&mut self.services, enc_payload)
} else {
let mut convo = self.load_convo(convo_id)?;
convo.handle_frame(&mut self.services, enc_payload)
}
}
pub fn wakeup(&mut self, convo_id: ConversationIdRef) -> Result<(), ChatError> {
info!(convos = ?self.cached_convos.keys().collect::<Vec<_>>(), id = ?self.services.mls_identity.id(), "Cached Convos");
@ -330,6 +405,13 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
convo.wakeup(&mut self.services)
}
fn register_convo(&mut self, convo: ConvoTypeOwned<S>) -> Result<(), ChatError> {
let res = self.cached_convos.insert(convo.id().to_string(), convo);
match res {
Some(_) => Err(ChatError::generic("Convo already exists. Cannot save")),
None => Ok(()),
}
}
/// Rebuilds a conversation from storage — the one site that branches on
@ -386,3 +468,45 @@ impl<'a, S: ExternalServices + 'static> Core<S> {
.ok_or_else(|| ChatError::NoConvo(convo_id.into()))
}
}
#[derive(Debug)]
enum ConvoTypeOwned<S: ExternalServices> {
// Pairwise(Box<dyn BaseConvo<S>>),
Group(Box<dyn GroupConvo<S>>),
}
impl<'a, S: ExternalServices> ConvoTypeOwned<S> {
pub fn id(&'a self) -> ConversationIdRef<'a> {
match self {
ConvoTypeOwned::Group(group_convo) => group_convo.id(),
}
}
}
impl<S: ExternalServices> Convo<S> for ConvoTypeOwned<S> {
fn send_content(
&mut self,
cx: &mut ServiceContext<S>,
content: &[u8],
) -> Result<(), ChatError> {
match self {
ConvoTypeOwned::Group(group_convo) => group_convo.send_content(cx, content),
}
}
fn handle_frame(
&mut self,
cx: &mut ServiceContext<S>,
enc: EncryptedPayload,
) -> Result<ConvoOutcome, ChatError> {
match self {
ConvoTypeOwned::Group(group_convo) => group_convo.handle_frame(cx, enc),
}
}
fn wakeup(&mut self, service_ctx: &mut ServiceContext<S>) -> Result<(), ChatError> {
match self {
ConvoTypeOwned::Group(group_convo) => group_convo.wakeup(service_ctx),
}
}
}

View File

@ -3,29 +3,31 @@ mod mls_provider;
use crypto::Ed25519VerifyingKey;
pub use identity::MlsIdentityProvider;
pub(crate) use mls_provider::MlsEphemeralPqProvider;
use shared_traits::IdentId;
use shared_traits::IdentIdRef;
use chat_proto::logoschat::envelope::EnvelopeV1;
use crypto::Ed25519VerifyingKey;
use de_mls::protos::de_mls::messages::v1::MemberWelcome;
use openmls::prelude::tls_codec::Serialize;
use openmls::prelude::*;
use prost::{Message, Oneof};
use std::cell::RefCell;
use storage::{ConversationKind, ConversationMeta, ConversationStore};
pub use identity::MlsIdentityProvider;
pub(crate) use mls_provider::MlsEphemeralPqProvider;
use crate::ChatError;
use crate::DeliveryService;
use crate::IdentityProvider;
use crate::RegistrationService;
use crate::conversation::ConversationId;
use crate::conversation::GroupConvo;
use crate::conversation::GroupV1Convo;
use crate::outcomes::{ConversationClass, InboxOutcome, NewConversation};
use crate::conversation::GroupV2Convo;
use crate::service_context::{ExternalServices, ServiceContext};
use crate::utils::{blake2b_hex, hash_size};
use crate::{
AccountAuthority, AccountDirectory, AddressedEnvelope, SignedDeviceBundle,
encode_bundle_payload,
};
use crate::{IdentId, IdentIdRef, IdentityProvider};
// Define unique Identifiers derivations used in InboxV2
fn delivery_address_for(ident_id: IdentIdRef) -> String {
@ -174,9 +176,9 @@ impl InboxV2 {
fn handle_heavy_invite<S: ExternalServices>(
&self,
invite: GroupV1HeavyInvite,
cx: &mut ServiceContext<S>,
) -> Result<InboxOutcome, ChatError> {
invite: GroupV1HeavyInvite,
) -> Result<GroupV1Convo, ChatError> {
let (msg_in, _rest) = MlsMessageIn::tls_deserialize_bytes(invite.welcome_bytes.as_slice())?;
let MlsMessageBodyIn::Welcome(welcome) = msg_in.extract() else {
@ -187,15 +189,9 @@ impl InboxV2 {
};
let convo = GroupV1Convo::new_from_welcome(cx, welcome)?;
let convo_id: ConversationId = convo.id().to_string();
self.persist_convo(&convo, cx)?;
Ok(InboxOutcome {
new_conversation: NewConversation {
convo_id,
class: ConversationClass::Group,
},
initial: None,
})
Ok(convo)
}
fn create_keypackage<S: ExternalServices>(