From d6b3c1a7263fb21782d29fe7544c178b27f04037 Mon Sep 17 00:00:00 2001 From: kaichaosun Date: Thu, 5 Feb 2026 15:40:30 +0800 Subject: [PATCH] feat: consistent chat ids between parties --- conversations/src/chat.rs | 231 ++++++++++++++++++++++++------ conversations/src/common.rs | 20 ++- conversations/src/dm/privatev1.rs | 28 ++-- conversations/src/errors.rs | 8 -- conversations/src/inbox/inbox.rs | 81 ++++++++--- conversations/src/storage/db.rs | 20 +++ 6 files changed, 312 insertions(+), 76 deletions(-) diff --git a/conversations/src/chat.rs b/conversations/src/chat.rs index b8727ce..454e830 100644 --- a/conversations/src/chat.rs +++ b/conversations/src/chat.rs @@ -7,6 +7,7 @@ use std::collections::HashMap; use std::rc::Rc; use double_ratchets::storage::RatchetStorage; +use prost::Message; use crate::{ common::{Chat, HasChatId, InboundMessageHandler}, @@ -14,6 +15,7 @@ use crate::{ errors::ChatError, identity::Identity, inbox::{Inbox, Introduction}, + proto, storage::{ChatRecord, ChatStorage, StorageError}, types::{AddressedEnvelope, ContentData}, }; @@ -243,48 +245,79 @@ impl ChatManager { /// Returns the decrypted content if successful. /// Any new chats or state changes are automatically persisted. pub fn handle_incoming(&mut self, payload: &[u8]) -> Result { - // Create storage for potential new conversation - let ratchet_storage = self.create_ratchet_storage()?; - - // Try to handle as inbox message (new chat invitation) - match self.inbox.handle_frame(ratchet_storage, payload) { - Ok((chat, content_data)) => { - let chat_id = chat.id().to_string(); - - // Persist the new chat metadata - let chat_record = ChatRecord { - chat_id: chat_id.clone(), - chat_type: "private_v1".to_string(), - remote_public_key: None, // Would need to extract from handshake - remote_address: "unknown".to_string(), - created_at: crate::utils::timestamp_millis() as i64, - }; - self.storage.save_chat(&chat_record)?; - - // TODO: Persist ratchet state for incoming chats - // This requires modifying InboundMessageHandler to return PrivateV1Convo - // or adding downcast support. For now, new chats from inbox won't persist - // their ratchet state until next send_message call. - - // Return first content if any, otherwise empty - if let Some(first) = content_data.into_iter().next() { - return Ok(first); - } - - Ok(ContentData { - conversation_id: chat_id, - data: vec![], - }) - } - Err(_) => { - // Not an inbox message, try existing chats - // For now, return placeholder - would need to route to correct chat - Ok(ContentData { - conversation_id: "unknown".into(), - data: vec![], - }) + // Try to decode as an envelope + if let Ok(envelope) = proto::EnvelopeV1::decode(payload) { + let chat_id = &envelope.conversation_hint; + + // Check if we have this chat - if so, route to it for decryption + if !chat_id.is_empty() && self.chat_exists(chat_id)? { + return self.receive_message(chat_id, &envelope.payload); } + + // We don't have this chat - try to handle as inbox handshake + // Pass the conversation_hint so both parties use the same chat ID + return self.handle_inbox_handshake(chat_id, &envelope.payload); } + + // Not a valid envelope - generate a new chat ID (for backwards compatibility) + let new_chat_id = crate::utils::generate_chat_id(); + self.handle_inbox_handshake(&new_chat_id, payload) + } + + /// Handle an inbox handshake to establish a new chat. + fn handle_inbox_handshake(&mut self, conversation_hint: &str, payload: &[u8]) -> Result { + let ratchet_storage = self.create_ratchet_storage()?; + let result = self.inbox.handle_frame(ratchet_storage, conversation_hint, payload)?; + + let chat_id = result.convo.id().to_string(); + + // Persist the new chat metadata + let chat_record = ChatRecord { + chat_id: chat_id.clone(), + chat_type: "private_v1".to_string(), + remote_public_key: Some(result.remote_public_key), + remote_address: hex::encode(result.remote_public_key), + created_at: crate::utils::timestamp_millis() as i64, + }; + self.storage.save_chat(&chat_record)?; + + // Store the conversation in memory cache + // (ratchet state is already persisted by RatchetSession) + self.chats.insert(chat_id.clone(), result.convo); + + Ok(ContentData { + conversation_id: chat_id, + data: result.initial_content.unwrap_or_default(), + }) + } + + /// Receive and decrypt a message for an existing chat. + /// + /// The payload should be the raw encrypted payload bytes. + pub fn receive_message( + &mut self, + chat_id: &str, + payload: &[u8], + ) -> Result { + // Ensure the chat is loaded + self.ensure_chat_loaded(chat_id)?; + + let chat = self + .chats + .get_mut(chat_id) + .ok_or_else(|| ChatManagerError::ChatNotFound(chat_id.to_string()))?; + + // Decode and decrypt the payload + let encrypted_payload = proto::EncryptedPayload::decode(payload) + .map_err(|e| ChatManagerError::Chat(ChatError::Protocol(format!("failed to decode: {}", e))))?; + + let frame = chat.decrypt(encrypted_payload)?; + let content = PrivateV1Convo::extract_content(&frame).unwrap_or_default(); + + Ok(ContentData { + conversation_id: chat_id.to_string(), + data: content, + }) } /// Get a reference to an active chat. @@ -472,4 +505,122 @@ mod tests { assert!(alice2.chats.contains_key(&chat_id)); } } + + #[test] + fn test_full_message_roundtrip() { + let mut alice = ChatManager::in_memory().unwrap(); + let mut bob = ChatManager::in_memory().unwrap(); + + // Bob creates an intro bundle and shares it with Alice + let bob_intro = bob.create_intro_bundle().unwrap(); + + // Alice starts a chat with Bob and sends "Hello!" + let (alice_chat_id, envelopes) = alice + .start_private_chat(&bob_intro, "Hello Bob!") + .unwrap(); + + // Verify Alice has the chat + assert!(alice.chat_exists(&alice_chat_id).unwrap()); + assert_eq!(alice.list_chats().len(), 1); + + // Simulate network delivery: Bob receives the envelope + let envelope = envelopes.first().unwrap(); + let content = bob.handle_incoming(&envelope.data).unwrap(); + + // Bob should have received the message + assert_eq!(content.data, b"Hello Bob!"); + + // Bob should now have a chat + assert_eq!(bob.list_chats().len(), 1); + let bob_chat_id = bob.list_chats().first().unwrap().clone(); + + // Bob replies to Alice + let bob_reply_envelopes = bob.send_message(&bob_chat_id, b"Hi Alice!").unwrap(); + assert!(!bob_reply_envelopes.is_empty()); + + // Alice receives Bob's reply + let bob_reply = bob_reply_envelopes.first().unwrap(); + let alice_received = alice.handle_incoming(&bob_reply.data).unwrap(); + + assert_eq!(alice_received.data, b"Hi Alice!"); + assert_eq!(alice_received.conversation_id, alice_chat_id); + + // Continue the conversation - Alice sends another message + let alice_envelopes = alice.send_message(&alice_chat_id, b"How are you?").unwrap(); + let alice_msg = alice_envelopes.first().unwrap(); + let bob_received = bob.handle_incoming(&alice_msg.data).unwrap(); + + assert_eq!(bob_received.data, b"How are you?"); + + // Bob replies again + let bob_envelopes = bob.send_message(&bob_chat_id, b"I'm good, thanks!").unwrap(); + let bob_msg = bob_envelopes.first().unwrap(); + let alice_received2 = alice.handle_incoming(&bob_msg.data).unwrap(); + + assert_eq!(alice_received2.data, b"I'm good, thanks!"); + } + + #[test] + fn test_message_persistence_across_sessions() { + use tempfile::tempdir; + + let dir = tempdir().unwrap(); + let alice_db = dir.path().join("alice.db"); + let bob_db = dir.path().join("bob.db"); + + let alice_chat_id; + let bob_chat_id; + let bob_intro; + + // Phase 1: Establish chat + { + let mut alice = ChatManager::open(StorageConfig::File( + alice_db.to_str().unwrap().to_string(), + )) + .unwrap(); + let mut bob = + ChatManager::open(StorageConfig::File(bob_db.to_str().unwrap().to_string())) + .unwrap(); + + bob_intro = bob.create_intro_bundle().unwrap(); + let (chat_id, envelopes) = alice.start_private_chat(&bob_intro, "Initial").unwrap(); + alice_chat_id = chat_id; + + // Bob receives + let envelope = envelopes.first().unwrap(); + let content = bob.handle_incoming(&envelope.data).unwrap(); + assert_eq!(content.data, b"Initial"); + bob_chat_id = bob.list_chats().first().unwrap().clone(); + } + // Both dropped - simulates app restart + + // Phase 2: Continue conversation after restart + { + let mut alice = ChatManager::open(StorageConfig::File( + alice_db.to_str().unwrap().to_string(), + )) + .unwrap(); + let mut bob = + ChatManager::open(StorageConfig::File(bob_db.to_str().unwrap().to_string())) + .unwrap(); + + // Both should have persisted chats + assert!(alice.list_stored_chats().unwrap().contains(&alice_chat_id)); + assert!(bob.list_stored_chats().unwrap().contains(&bob_chat_id)); + + // Alice sends a message (chat loads from storage) + let envelopes = alice.send_message(&alice_chat_id, b"After restart").unwrap(); + + // Bob receives (chat loads from storage) + let envelope = envelopes.first().unwrap(); + let content = bob.handle_incoming(&envelope.data).unwrap(); + assert_eq!(content.data, b"After restart"); + + // Bob replies + let bob_envelopes = bob.send_message(&bob_chat_id, b"Still works!").unwrap(); + let bob_msg = bob_envelopes.first().unwrap(); + let alice_received = alice.handle_incoming(&bob_msg.data).unwrap(); + assert_eq!(alice_received.data, b"Still works!"); + } + } } diff --git a/conversations/src/common.rs b/conversations/src/common.rs index 9714a68..6f0cfad 100644 --- a/conversations/src/common.rs +++ b/conversations/src/common.rs @@ -1,7 +1,8 @@ use std::fmt::Debug; +use crate::dm::privatev1::PrivateV1Convo; pub use crate::errors::ChatError; -use crate::types::{AddressedEncryptedPayload, ContentData}; +use crate::types::AddressedEncryptedPayload; use double_ratchets::storage::RatchetStorage; pub type ChatId<'a> = &'a str; @@ -10,12 +11,27 @@ pub trait HasChatId: Debug { fn id(&self) -> ChatId<'_>; } +/// Result of handling an incoming inbox message (new chat invitation). +pub struct InboxHandleResult { + /// The newly created conversation. + pub convo: PrivateV1Convo, + /// The remote party's public key (for storage/display). + pub remote_public_key: [u8; 32], + /// Decrypted initial message content, if any. + pub initial_content: Option>, +} + pub trait InboundMessageHandler { + /// Handle an incoming inbox frame. + /// + /// `conversation_hint` is the sender's conversation ID from the envelope, + /// which should be used as the shared conversation ID for this chat. fn handle_frame( &mut self, storage: RatchetStorage, + conversation_hint: &str, encoded_payload: &[u8], - ) -> Result<(Box, Vec), ChatError>; + ) -> Result; } pub trait Chat: HasChatId + Debug { diff --git a/conversations/src/dm/privatev1.rs b/conversations/src/dm/privatev1.rs index 4c22918..bfa82b7 100644 --- a/conversations/src/dm/privatev1.rs +++ b/conversations/src/dm/privatev1.rs @@ -13,7 +13,7 @@ use x25519_dalek::PublicKey; use crate::{ common::{Chat, ChatId, HasChatId}, - errors::{ChatError, EncryptionError}, + errors::ChatError, proto, types::AddressedEncryptedPayload, utils::timestamp_millis, @@ -86,18 +86,17 @@ impl PrivateV1Convo { }) } - fn decrypt(&mut self, payload: EncryptedPayload) -> Result { + /// Decrypt an incoming encrypted payload. + pub fn decrypt(&mut self, payload: EncryptedPayload) -> Result { // Validate and extract the encryption header or return errors let dr_header = if let Some(enc) = payload.encryption { if let proto::Encryption::Doubleratchet(dr) = enc { dr } else { - return Err(EncryptionError::Decryption( - "incorrect encryption type".into(), - )); + return Err(ChatError::Protocol("incorrect encryption type".into())); } } else { - return Err(EncryptionError::Decryption("missing payload".into())); + return Err(ChatError::Protocol("missing payload".into())); }; // Turn the bytes into a PublicKey @@ -105,7 +104,7 @@ impl PrivateV1Convo { .dh .to_vec() .try_into() - .map_err(|_| EncryptionError::Decryption("invalid public key length".into()))?; + .map_err(|_| ChatError::InvalidKeyLength)?; let dh_pub = PublicKey::from(byte_arr); // Build the Header that DR impl expects @@ -118,9 +117,18 @@ impl PrivateV1Convo { // Decrypt into Frame let content_bytes = self .session - .decrypt_message(&dr_header.ciphertext, header) - .map_err(|e| EncryptionError::Decryption(e.to_string()))?; - Ok(PrivateV1Frame::decode(content_bytes.as_slice()).unwrap()) + .decrypt_message(&dr_header.ciphertext, header)?; + + PrivateV1Frame::decode(content_bytes.as_slice()) + .map_err(|e| ChatError::Protocol(format!("failed to decode frame: {}", e))) + } + + /// Extract content bytes from a decrypted frame. + pub fn extract_content(frame: &PrivateV1Frame) -> Option> { + match &frame.frame_type { + Some(FrameType::Content(bytes)) => Some(bytes.to_vec()), + _ => None, + } } } diff --git a/conversations/src/errors.rs b/conversations/src/errors.rs index ec2e766..098e39f 100644 --- a/conversations/src/errors.rs +++ b/conversations/src/errors.rs @@ -25,11 +25,3 @@ pub enum ChatError { #[error("session error: {0}")] Session(#[from] double_ratchets::SessionError), } - -#[derive(Error, Debug)] -pub enum EncryptionError { - #[error("encryption: {0}")] - Encryption(String), - #[error("decryption: {0}")] - Decryption(String), -} diff --git a/conversations/src/inbox/inbox.rs b/conversations/src/inbox/inbox.rs index 89a6e6e..06969da 100644 --- a/conversations/src/inbox/inbox.rs +++ b/conversations/src/inbox/inbox.rs @@ -8,14 +8,14 @@ use std::rc::Rc; use crypto::{PrekeyBundle, SecretKey}; use double_ratchets::storage::RatchetStorage; -use crate::common::{Chat, ChatId, HasChatId, InboundMessageHandler}; +use crate::common::{Chat, ChatId, HasChatId, InboundMessageHandler, InboxHandleResult}; use crate::dm::privatev1::PrivateV1Convo; use crate::errors::ChatError; use crate::identity::Identity; use crate::identity::{PublicKey, StaticSecret}; use crate::inbox::handshake::InboxHandshake; use crate::proto::{self, CopyBytes}; -use crate::types::{AddressedEncryptedPayload, ContentData}; +use crate::types::AddressedEncryptedPayload; use crate::utils::generate_chat_id; use super::Introduction; @@ -231,10 +231,11 @@ impl InboundMessageHandler for Inbox { fn handle_frame( &mut self, storage: RatchetStorage, + conversation_hint: &str, message: &[u8], - ) -> Result<(Box, Vec), ChatError> { - if message.len() == 0 { - return Err(ChatError::Protocol("Example error".into())); + ) -> Result { + if message.is_empty() { + return Err(ChatError::Protocol("empty message".into())); } let handshake = Self::extract_payload(proto::EncryptedPayload::decode(message)?)?; @@ -243,23 +244,49 @@ impl InboundMessageHandler for Inbox { .header .ok_or(ChatError::UnexpectedPayload("InboxV1Header".into()))?; - // Get Ephemeral key used by the initator + // Extract the remote party's public key + let remote_public_key: [u8; 32] = header + .initiator_static + .as_ref() + .try_into() + .map_err(|_| ChatError::InvalidKeyLength)?; + + // Get Ephemeral key used by the initiator let key_index = hex::encode(header.responder_ephemeral.as_ref()); let ephemeral_key = self.lookup_ephemeral_key(&key_index)?; // Perform handshake and decrypt frame let (seed_key, frame) = self.perform_handshake(ephemeral_key, header, handshake.payload)?; - match frame.frame_type.unwrap() { - proto::inbox_v1_frame::FrameType::InvitePrivateV1(_invite_private_v1) => { - // Generate unique chat ID for the responder - let chat_id = generate_chat_id(); + match frame.frame_type.ok_or(ChatError::Protocol("missing frame type".into()))? { + proto::inbox_v1_frame::FrameType::InvitePrivateV1(invite) => { + // Use the sender's conversation_hint as the shared chat ID + let chat_id = conversation_hint.to_string(); let installation_keypair = double_ratchets::InstallationKeyPair::from(ephemeral_key.clone()); - let convo = PrivateV1Convo::new_responder(storage, chat_id, seed_key, installation_keypair)?; + let mut convo = PrivateV1Convo::new_responder( + storage, + chat_id, + seed_key, + installation_keypair, + )?; - // TODO: Update PrivateV1 Constructor with DR, initial_message - Ok((Box::new(convo), vec![])) + // Decrypt the initial message if present + let initial_content = if let Some(encrypted_payload) = invite.initial_message { + let frame = convo.decrypt(encrypted_payload)?; + PrivateV1Convo::extract_content(&frame) + } else { + None + }; + + // Consume the ephemeral key after successful handshake + self.consume_ephemeral_key(&key_index); + + Ok(InboxHandleResult { + convo, + remote_public_key, + initial_content, + }) } } } @@ -282,10 +309,13 @@ mod tests { let storage_receiver = RatchetStorage::in_memory().unwrap(); let (bundle, _secret) = raya_inbox.create_bundle(); - let (_, payloads) = saro_inbox + let (saro_convo, payloads) = saro_inbox .invite_to_private_convo(storage_sender, &bundle.into(), "hello".into()) .unwrap(); + // The initiator's conversation ID becomes the shared conversation_hint + let conversation_hint = saro_convo.id().to_string(); + let payload = payloads .get(0) .expect("RemoteInbox::invite_to_private_convo did not generate any payloads"); @@ -294,11 +324,30 @@ mod tests { payload.data.encode(&mut buf).unwrap(); // Test handle_frame with valid payload - let result = raya_inbox.handle_frame(storage_receiver, &buf); + let result = raya_inbox.handle_frame(storage_receiver, &conversation_hint, &buf); assert!( result.is_ok(), - "handle_frame should accept valid encrypted payloads" + "handle_frame should accept valid encrypted payloads: {:?}", + result.err() + ); + + // Verify we got the decrypted initial message + let handle_result = result.unwrap(); + assert_eq!( + handle_result.initial_content, + Some(b"hello".to_vec()), + "should decrypt initial message" + ); + + // Verify remote public key was extracted + assert_eq!(handle_result.remote_public_key.len(), 32); + + // Verify both parties have the same conversation ID + assert_eq!( + handle_result.convo.id(), + saro_convo.id(), + "both parties should share the same conversation ID" ); } } diff --git a/conversations/src/storage/db.rs b/conversations/src/storage/db.rs index 9cb2abb..dd8652f 100644 --- a/conversations/src/storage/db.rs +++ b/conversations/src/storage/db.rs @@ -178,6 +178,26 @@ impl ChatStorage { Ok(exists) } + /// Finds a chat by remote address. + /// Returns the chat_id if found, None otherwise. + #[allow(dead_code)] + pub fn find_chat_by_remote_address( + &self, + remote_address: &str, + ) -> Result, StorageError> { + let mut stmt = self + .db + .connection() + .prepare("SELECT chat_id FROM chats WHERE remote_address = ?1 LIMIT 1")?; + + let mut rows = stmt.query(params![remote_address])?; + if let Some(row) = rows.next()? { + Ok(Some(row.get(0)?)) + } else { + Ok(None) + } + } + /// Deletes a chat record. /// Note: Ratchet state must be deleted separately via RatchetStorage. pub fn delete_chat(&mut self, chat_id: &str) -> Result<(), StorageError> {