Add ExternalServices Trait

This commit is contained in:
Jazz Turner-Baggs 2026-05-12 17:30:22 -07:00
parent 4c6286234b
commit 6dc027124f
No known key found for this signature in database
4 changed files with 104 additions and 92 deletions

View File

@ -12,40 +12,56 @@ pub use group_v1::GroupV1Convo;
pub type ConversationIdRef<'a> = &'a str; pub type ConversationIdRef<'a> = &'a str;
pub type ConversationId = String; pub type ConversationId = String;
pub struct ServiceContext<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService> { /// A trait which bundles all the external service traits into a single scope.
pub identity_provider: IP, /// This allows for a single bound to be used internally, and cuts down on
pub ds: DS, /// the clutter
pub rs: RS, pub trait ExternalServices: Debug {
type IP: IdentityProvider;
type DS: DeliveryService;
type RS: RegistrationService;
}
#[derive(Debug)]
pub struct ServiceContext<S: ExternalServices> {
pub(crate) identity_provider: S::IP,
pub(crate) ds: S::DS,
pub(crate) rs: S::RS,
}
impl<S: ExternalServices> ServiceContext<S> {
pub fn new(identity_provider: S::IP, ds: S::DS, rs: S::RS) -> Self {
ServiceContext {
identity_provider,
ds,
rs,
}
}
} }
pub trait Id: Debug { pub trait Id: Debug {
fn id(&self) -> ConversationIdRef<'_>; fn id(&self) -> ConversationIdRef<'_>;
} }
pub trait BaseConvo<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService>: pub trait BaseConvo<S: ExternalServices>: Id + Debug {
Id + Debug fn init(&self, service_ctx: &mut ServiceContext<S>) -> Result<(), ChatError>;
{
fn init(&self, service_ctx: &mut ServiceContext<IP, DS, RS>) -> Result<(), ChatError>;
fn send_content( fn send_content(
&mut self, &mut self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
content: &[u8], content: &[u8],
) -> Result<(), ChatError>; ) -> Result<(), ChatError>;
fn handle_frame( fn handle_frame(
&mut self, &mut self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
enc_payload: EncryptedPayload, enc_payload: EncryptedPayload,
) -> Result<Option<ContentData>, ChatError>; ) -> Result<Option<ContentData>, ChatError>;
} }
pub trait BaseGroupConvo<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService>: pub trait BaseGroupConvo<S: ExternalServices>: BaseConvo<S> {
BaseConvo<IP, DS, RS>
{
fn add_member( fn add_member(
&mut self, &mut self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
members: &[&AccountId], members: &[&AccountId],
) -> Result<(), ChatError>; ) -> Result<(), ChatError>;
} }

View File

@ -11,7 +11,7 @@ use openmls::prelude::tls_codec::Deserialize;
use openmls::prelude::*; use openmls::prelude::*;
use crate::AccountId; use crate::AccountId;
use crate::conversation::{ConversationIdRef, ServiceContext}; use crate::conversation::{ConversationIdRef, ExternalServices, ServiceContext};
use crate::inbox_v2::{MlsIdentityProvider, MlsProvider}; use crate::inbox_v2::{MlsIdentityProvider, MlsProvider};
use crate::{ use crate::{
AddressedEncryptedPayload, ContentData, DeliveryService, IdentityProvider, RegistrationService, AddressedEncryptedPayload, ContentData, DeliveryService, IdentityProvider, RegistrationService,
@ -128,15 +128,12 @@ where
} }
} }
impl<IP, MP, DS, RS> BaseConvo<IP, DS, RS> for GroupV1Convo<MP> impl<S, MP> BaseConvo<S> for GroupV1Convo<MP>
where where
IP: IdentityProvider, S: ExternalServices,
MP: MlsProvider, MP: MlsProvider,
DS: DeliveryService,
RS: RegistrationService,
// KP: RegistrationService,
{ {
fn init(&self, service_ctx: &mut super::ServiceContext<IP, DS, RS>) -> Result<(), ChatError> { fn init(&self, service_ctx: &mut super::ServiceContext<S>) -> Result<(), ChatError> {
// Configure the delivery service to listen for the required delivery addresses. // Configure the delivery service to listen for the required delivery addresses.
service_ctx service_ctx
@ -153,7 +150,7 @@ where
fn send_content( fn send_content(
&mut self, &mut self,
service_ctx: &mut super::ServiceContext<IP, DS, RS>, service_ctx: &mut super::ServiceContext<S>,
content: &[u8], content: &[u8],
) -> Result<(), ChatError> { ) -> Result<(), ChatError> {
let signer = MlsIdentityProvider(&service_ctx.identity_provider); let signer = MlsIdentityProvider(&service_ctx.identity_provider);
@ -182,7 +179,7 @@ where
fn handle_frame( fn handle_frame(
&mut self, &mut self,
_service_ctx: &mut super::ServiceContext<IP, DS, RS>, _service_ctx: &mut super::ServiceContext<S>,
encoded_payload: EncryptedPayload, encoded_payload: EncryptedPayload,
) -> Result<Option<ContentData>, ChatError> { ) -> Result<Option<ContentData>, ChatError> {
let bytes = match encoded_payload.encryption { let bytes = match encoded_payload.encryption {
@ -231,12 +228,10 @@ where
} }
} }
impl<IP, MP, DS, RS> BaseGroupConvo<IP, DS, RS> for GroupV1Convo<MP> impl<S, MP> BaseGroupConvo<S> for GroupV1Convo<MP>
where where
IP: IdentityProvider, S: ExternalServices,
MP: MlsProvider, MP: MlsProvider,
DS: DeliveryService,
RS: RegistrationService,
{ {
// add_members returns: // add_members returns:
// commit — the Commit message Alice broadcasts to all members // commit — the Commit message Alice broadcasts to all members
@ -244,7 +239,7 @@ where
// _group_info — used for external joins; ignore for now // _group_info — used for external joins; ignore for now
fn add_member( fn add_member(
&mut self, &mut self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
members: &[&AccountId], members: &[&AccountId],
) -> Result<(), ChatError> { ) -> Result<(), ChatError> {
let mls_provider = &*self.mls_provider.borrow(); let mls_provider = &*self.mls_provider.borrow();
@ -304,19 +299,15 @@ where
} }
impl<MP: MlsProvider> GroupV1Convo<MP> { impl<MP: MlsProvider> GroupV1Convo<MP> {
fn key_package_for_account< fn key_package_for_account<S: ExternalServices>(
IP: IdentityProvider,
DS: DeliveryService,
RS: RegistrationService,
>(
&self, &self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
ident: &AccountId, ident: &AccountId,
) -> Result<KeyPackage, ChatError> { ) -> Result<KeyPackage, ChatError> {
let retrieved_bytes = service_ctx let retrieved_bytes = service_ctx
.rs .rs
.retrieve(ident) .retrieve(ident)
.map_err(|e: RS::Error| ChatError::Generic(e.to_string()))?; .map_err(|e| ChatError::Generic(e.to_string()))?;
// dbg!(ctx.contact_registry()); // dbg!(ctx.contact_registry());
let Some(keypkg_bytes) = retrieved_bytes else { let Some(keypkg_bytes) = retrieved_bytes else {

View File

@ -1,8 +1,11 @@
use std::cell::RefMut; use std::cell::RefMut;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Debug;
use std::{cell::RefCell, rc::Rc}; use std::{cell::RefCell, rc::Rc};
use crate::conversation::{BaseGroupConvo, ConversationId, ConversationIdRef, Id, ServiceContext}; use crate::conversation::{
BaseGroupConvo, ConversationId, ConversationIdRef, ExternalServices, Id, ServiceContext,
};
use crate::inbox_v2::InboxV2; use crate::inbox_v2::InboxV2;
use crate::{AccountId, errors::ChatError}; use crate::{AccountId, errors::ChatError};
@ -14,16 +17,14 @@ use prost::Message;
use storage::ChatStore; use storage::ChatStore;
#[derive(Debug)] #[derive(Debug)]
enum ConvoTypeOwned<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService> { enum ConvoTypeOwned<S: ExternalServices> {
// Pairwise(Box<dyn BaseConvo<IP, DS, RS>>), // Pairwise(Box<dyn BaseConvo<S>>),
Group(Box<dyn BaseGroupConvo<IP, DS, RS>>), Group(Box<dyn BaseGroupConvo<S>>),
} }
impl<IP, DS, RS> Id for ConvoTypeOwned<IP, DS, RS> impl<S> Id for ConvoTypeOwned<S>
where where
IP: IdentityProvider, S: ExternalServices,
DS: DeliveryService,
RS: RegistrationService,
{ {
fn id(&self) -> crate::conversation::ConversationIdRef<'_> { fn id(&self) -> crate::conversation::ConversationIdRef<'_> {
match self { match self {
@ -33,21 +34,14 @@ where
} }
} }
pub struct GroupConvo< pub struct GroupConvo<S: ExternalServices, CS: ChatStore> {
IP: IdentityProvider, client: Rc<RefCell<InnerClient<S, CS>>>,
DS: DeliveryService,
RS: RegistrationService,
CS: ChatStore,
> {
client: Rc<RefCell<InnerClient<IP, DS, RS, CS>>>,
convo_id: ConversationId, convo_id: ConversationId,
} }
impl<IP, DS, RS, CS> GroupConvo<IP, DS, RS, CS> impl<S, CS> GroupConvo<S, CS>
where where
IP: IdentityProvider + 'static, S: ExternalServices,
DS: DeliveryService + 'static,
RS: RegistrationService + 'static,
CS: ChatStore + 'static, CS: ChatStore + 'static,
{ {
pub fn send_content(&self, content: &[u8]) -> Result<(), ChatError> { pub fn send_content(&self, content: &[u8]) -> Result<(), ChatError> {
@ -56,20 +50,34 @@ where
} }
} }
// This allows the ExternalServices trait to be converted from a tuple.
// This is used in CoreClient to convert from the explicit impls to a
// ExternalServices bundle, which means it does not have to be exposed externally.
impl<IP, DS, RS> ExternalServices for (IP, DS, RS)
where
IP: IdentityProvider + Debug,
DS: DeliveryService + Debug,
RS: RegistrationService + Debug,
{
type IP = IP;
type DS = DS;
type RS = RS;
}
pub struct CoreClient< pub struct CoreClient<
IP: IdentityProvider, IP: IdentityProvider,
DS: DeliveryService, DS: DeliveryService,
RS: RegistrationService, RS: RegistrationService,
CS: ChatStore, CS: ChatStore,
> { > {
inner: Rc<RefCell<InnerClient<IP, DS, RS, CS>>>, inner: Rc<RefCell<InnerClient<(IP, DS, RS), CS>>>,
} }
impl<IP, DS, RS, CS> CoreClient<IP, DS, RS, CS> impl<IP, DS, RS, CS> CoreClient<IP, DS, RS, CS>
where where
IP: IdentityProvider + 'static, IP: IdentityProvider,
DS: DeliveryService + 'static, DS: DeliveryService,
RS: RegistrationService + 'static, RS: RegistrationService,
CS: ChatStore + 'static, CS: ChatStore + 'static,
{ {
pub fn new(account: IP, delivery: DS, registration: RS, store: CS) -> Result<Self, ChatError> { pub fn new(account: IP, delivery: DS, registration: RS, store: CS) -> Result<Self, ChatError> {
@ -90,7 +98,7 @@ where
pub fn create_group_convo( pub fn create_group_convo(
&self, &self,
participants: &[&AccountId], participants: &[&AccountId],
) -> Result<GroupConvo<IP, DS, RS, CS>, ChatError> { ) -> Result<GroupConvo<(IP, DS, RS), CS>, ChatError> {
let convo_id = self.inner.borrow_mut().create_group_convo(participants)?; let convo_id = self.inner.borrow_mut().create_group_convo(participants)?;
Ok(GroupConvo { Ok(GroupConvo {
client: self.inner.clone(), client: self.inner.clone(),
@ -114,7 +122,7 @@ where
self.inner.borrow_mut().handle_payload(payload) self.inner.borrow_mut().handle_payload(payload)
} }
pub fn convo(&self, convo_id: ConversationIdRef) -> Option<GroupConvo<IP, DS, RS, CS>> { pub fn convo(&self, convo_id: ConversationIdRef) -> Option<GroupConvo<(IP, DS, RS), CS>> {
let client = self.inner.clone(); let client = self.inner.clone();
if !client.borrow().has_conversation(convo_id) { if !client.borrow().has_conversation(convo_id) {
@ -128,36 +136,32 @@ where
} }
} }
struct InnerClient< struct InnerClient<S: ExternalServices, CS: ChatStore> {
IP: IdentityProvider, service_ctx: ServiceContext<S>,
DS: DeliveryService,
RS: RegistrationService,
CS: ChatStore,
> {
service_ctx: ServiceContext<IP, DS, RS>,
_store: Rc<RefCell<CS>>, _store: Rc<RefCell<CS>>,
pq_inbox: InboxV2<CS>, pq_inbox: InboxV2<CS>,
// Cache of loaded conversations // Cache of loaded conversations
cached_convos: HashMap<String, ConvoTypeOwned<IP, DS, RS>>, cached_convos: HashMap<String, ConvoTypeOwned<S>>,
} }
impl<IP, DS, RS, CS> InnerClient<IP, DS, RS, CS> impl<S, CS> InnerClient<S, CS>
where where
IP: IdentityProvider + 'static, S: ExternalServices,
DS: DeliveryService + 'static,
RS: RegistrationService + 'static,
CS: ChatStore + 'static, CS: ChatStore + 'static,
{ {
pub fn new(account: IP, delivery: DS, registration: RS, store: CS) -> Result<Self, ChatError> { pub fn new(
account: S::IP,
delivery: S::DS,
registration: S::RS,
store: CS,
) -> Result<Self, ChatError> {
// Services for sharing with Converastions/Inboxes // Services for sharing with Converastions/Inboxes
let mut service_ctx = ServiceContext { // let mut service_ctx: ServiceContext<S> = ServiceContext::new(account, delivery, registration);
identity_provider: account, let mut service_ctx: ServiceContext<S> =
ds: delivery, ServiceContext::new(account, delivery, registration);
rs: registration,
};
// let contact_registry = Rc::new(RefCell::new(registration)); // let contact_registry = Rc::new(RefCell::new(registration));
let _store = Rc::new(RefCell::new(store)); let _store = Rc::new(RefCell::new(store));
@ -179,7 +183,7 @@ where
}) })
} }
pub fn ds(&mut self) -> &mut DS { pub fn ds(&mut self) -> &mut S::DS {
&mut self.service_ctx.ds &mut self.service_ctx.ds
} }
@ -190,7 +194,7 @@ where
pub fn create_group_convo(&mut self, participants: &[&AccountId]) -> Result<String, ChatError> { pub fn create_group_convo(&mut self, participants: &[&AccountId]) -> Result<String, ChatError> {
let convo = self.pq_inbox.create_group_v1(&mut self.service_ctx)?; let convo = self.pq_inbox.create_group_v1(&mut self.service_ctx)?;
let mut convo: Box<dyn BaseGroupConvo<IP, DS, RS>> = Box::new(convo); let mut convo: Box<dyn BaseGroupConvo<S>> = Box::new(convo);
convo.init(&mut self.service_ctx)?; convo.init(&mut self.service_ctx)?;
convo.add_member(&mut self.service_ctx, participants)?; convo.add_member(&mut self.service_ctx, participants)?;
@ -242,7 +246,7 @@ where
// Dispatch encrypted payload to Inbox, and register the created Conversation // Dispatch encrypted payload to Inbox, and register the created Conversation
fn dispatch_to_inbox2(&mut self, payload: &[u8]) -> Result<Option<ContentData>, ChatError> { fn dispatch_to_inbox2(&mut self, payload: &[u8]) -> Result<Option<ContentData>, ChatError> {
if let Some(convo) = self.pq_inbox.handle_frame(&mut self.service_ctx, payload)? { if let Some(convo) = self.pq_inbox.handle_frame(&mut self.service_ctx, payload)? {
let convo: Box<dyn BaseGroupConvo<IP, DS, RS>> = Box::new(convo); let convo: Box<dyn BaseGroupConvo<S>> = Box::new(convo);
self.register_convo(ConvoTypeOwned::Group(convo))?; self.register_convo(ConvoTypeOwned::Group(convo))?;
} }
Ok(None) Ok(None)
@ -267,7 +271,7 @@ where
convo.handle_frame(&mut self.service_ctx, enc_payload) convo.handle_frame(&mut self.service_ctx, enc_payload)
} }
fn register_convo(&mut self, convo: ConvoTypeOwned<IP, DS, RS>) -> Result<(), ChatError> { fn register_convo(&mut self, convo: ConvoTypeOwned<S>) -> Result<(), ChatError> {
let res = self.cached_convos.insert(convo.id().to_string(), convo); let res = self.cached_convos.insert(convo.id().to_string(), convo);
match res { match res {

View File

@ -20,6 +20,7 @@ use crate::DeliveryService;
use crate::IdentityProvider; use crate::IdentityProvider;
use crate::RegistrationService; use crate::RegistrationService;
use crate::conversation::BaseConvo; use crate::conversation::BaseConvo;
use crate::conversation::ExternalServices;
use crate::conversation::ServiceContext; use crate::conversation::ServiceContext;
use crate::conversation::{GroupV1Convo, Id}; use crate::conversation::{GroupV1Convo, Id};
use crate::utils::{blake2b_hex, hash_size}; use crate::utils::{blake2b_hex, hash_size};
@ -166,8 +167,8 @@ pub struct InboxV2<CS> {
} }
impl<CS: ChatStore> InboxV2<CS> { impl<CS: ChatStore> InboxV2<CS> {
pub fn new<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService>( pub fn new<S: ExternalServices>(
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
_store: Rc<RefCell<CS>>, _store: Rc<RefCell<CS>>,
) -> Self { ) -> Self {
// Avoid referencing a temporary value by caching it. // Avoid referencing a temporary value by caching it.
@ -193,9 +194,9 @@ impl<CS: ChatStore> InboxV2<CS> {
} }
/// Submit MlsKeypackage to registration service /// Submit MlsKeypackage to registration service
pub fn register<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService>( pub fn register<S: ExternalServices>(
&self, &self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
) -> Result<(), ChatError> { ) -> Result<(), ChatError> {
let mls_ident = MlsIdentityProvider(&service_ctx.identity_provider); let mls_ident = MlsIdentityProvider(&service_ctx.identity_provider);
let keypackage_bytes = self let keypackage_bytes = self
@ -213,9 +214,9 @@ impl<CS: ChatStore> InboxV2<CS> {
.map_err(ChatError::generic) .map_err(ChatError::generic)
} }
pub fn create_group_v1<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService>( pub fn create_group_v1<S: ExternalServices>(
&self, &self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
) -> Result<GroupV1Convo<MlsEphemeralPqProvider>, ChatError> { ) -> Result<GroupV1Convo<MlsEphemeralPqProvider>, ChatError> {
let mls_ident = MlsIdentityProvider(&service_ctx.identity_provider); let mls_ident = MlsIdentityProvider(&service_ctx.identity_provider);
GroupV1Convo::new(mls_ident, self.mls_provider.clone()) GroupV1Convo::new(mls_ident, self.mls_provider.clone())
@ -247,9 +248,9 @@ impl<CS: ChatStore> InboxV2<CS> {
} }
impl<CS: ChatStore> InboxV2<CS> { impl<CS: ChatStore> InboxV2<CS> {
pub fn handle_frame<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService>( pub fn handle_frame<S: ExternalServices>(
&self, &self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
payload_bytes: &[u8], payload_bytes: &[u8],
) -> Result<Option<GroupV1Convo<MlsEphemeralPqProvider>>, ChatError> { ) -> Result<Option<GroupV1Convo<MlsEphemeralPqProvider>>, ChatError> {
let inbox_frame = InboxV2Frame::decode(payload_bytes)?; let inbox_frame = InboxV2Frame::decode(payload_bytes)?;
@ -265,9 +266,9 @@ impl<CS: ChatStore> InboxV2<CS> {
} }
} }
fn handle_heavy_invite<IP: IdentityProvider, DS: DeliveryService, RS: RegistrationService>( fn handle_heavy_invite<S: ExternalServices>(
&self, &self,
service_ctx: &mut ServiceContext<IP, DS, RS>, service_ctx: &mut ServiceContext<S>,
invite: GroupV1HeavyInvite, invite: GroupV1HeavyInvite,
) -> Result<GroupV1Convo<MlsEphemeralPqProvider>, ChatError> { ) -> Result<GroupV1Convo<MlsEphemeralPqProvider>, ChatError> {
let (msg_in, _rest) = MlsMessageIn::tls_deserialize_bytes(invite.welcome_bytes.as_slice())?; let (msg_in, _rest) = MlsMessageIn::tls_deserialize_bytes(invite.welcome_bytes.as_slice())?;