diff --git a/protocol/common/message_sender.go b/protocol/common/message_sender.go index 4b589b6d5..31436d1c7 100644 --- a/protocol/common/message_sender.go +++ b/protocol/common/message_sender.go @@ -261,6 +261,11 @@ func (s *MessageSender) sendCommunity( rawMessage.ID = types.EncodeHex(messageID) messageIDs := [][]byte{messageID} + if rawMessage.BeforeDispatch != nil { + if err := rawMessage.BeforeDispatch(rawMessage); err != nil { + return nil, err + } + } // Notify before dispatching, otherwise the dispatch subscription might happen // earlier than the scheduled s.notifyOnScheduledMessage(nil, rawMessage) @@ -351,6 +356,12 @@ func (s *MessageSender) sendPrivate( messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage) rawMessage.ID = types.EncodeHex(messageID) + if rawMessage.BeforeDispatch != nil { + if err := rawMessage.BeforeDispatch(rawMessage); err != nil { + return nil, err + } + } + // Notify before dispatching, otherwise the dispatch subscription might happen // earlier than the scheduled s.notifyOnScheduledMessage(recipient, rawMessage) @@ -510,6 +521,12 @@ func (s *MessageSender) dispatchCommunityChatMessage(ctx context.Context, rawMes PowTime: whisperPoWTime, } + if rawMessage.BeforeDispatch != nil { + if err := rawMessage.BeforeDispatch(rawMessage); err != nil { + return nil, nil, err + } + } + // notify before dispatching s.notifyOnScheduledMessage(nil, rawMessage) @@ -564,6 +581,12 @@ func (s *MessageSender) SendPublic( messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage) rawMessage.ID = types.EncodeHex(messageID) + if rawMessage.BeforeDispatch != nil { + if err := rawMessage.BeforeDispatch(&rawMessage); err != nil { + return nil, err + } + } + // notify before dispatching s.notifyOnScheduledMessage(nil, &rawMessage) diff --git a/protocol/common/raw_message.go b/protocol/common/raw_message.go index 2fbbd5751..96209b33c 100644 --- a/protocol/common/raw_message.go +++ b/protocol/common/raw_message.go @@ -34,4 +34,5 @@ type RawMessage struct { CommunityID []byte CommunityKeyExMsgType CommKeyExMsgType Ephemeral bool + BeforeDispatch func(*RawMessage) error } diff --git a/protocol/messenger.go b/protocol/messenger.go index ad5dfbfed..4626bb885 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -2189,30 +2189,34 @@ func (m *Messenger) sendChatMessage(ctx context.Context, message *common.Message ResendAutomatically: true, } + // We want to save the raw message before dispatching it, to avoid race conditions + // since it might get dispatched and confirmed before it's saved. + // This is not the best solution, probably it would be better to split + // the sent status in a different table and join on query for messages, + // but that's a much larger change and it would require an expensive migration of clients + rawMessage.BeforeDispatch = func(rawMessage *common.RawMessage) error { + if rawMessage.Sent { + message.OutgoingStatus = common.OutgoingStatusSent + } + message.ID = rawMessage.ID + err = message.PrepareContent(common.PubkeyToHex(&m.identity.PublicKey)) + if err != nil { + return err + } + + err = chat.UpdateFromMessage(message, m.getTimesource()) + if err != nil { + return err + } + + return m.persistence.SaveMessages([]*common.Message{message}) + } + rawMessage, err = m.dispatchMessage(ctx, rawMessage) if err != nil { return nil, err } - if rawMessage.Sent { - message.OutgoingStatus = common.OutgoingStatusSent - } - message.ID = rawMessage.ID - err = message.PrepareContent(common.PubkeyToHex(&m.identity.PublicKey)) - if err != nil { - return nil, err - } - - err = chat.UpdateFromMessage(message, m.getTimesource()) - if err != nil { - return nil, err - } - - err = m.persistence.SaveMessages([]*common.Message{message}) - if err != nil { - return nil, err - } - msg, err := m.pullMessagesAndResponsesFromDB([]*common.Message{message}) if err != nil { return nil, err diff --git a/protocol/messenger_test.go b/protocol/messenger_test.go index 50ccb8e3a..bc7b26602 100644 --- a/protocol/messenger_test.go +++ b/protocol/messenger_test.go @@ -2354,6 +2354,12 @@ func (s *MessengerSuite) TestMessageSent() { s.True(rawMessage.Sent) } +func (s *MessengerSuite) TestProcessSentMessages() { + ids := []string{"a"} + err := s.m.processSentMessages(ids) + s.Require().NoError(err) +} + func (s *MessengerSuite) TestResendExpiredEmojis() { //send message chat := CreatePublicChat("test-chat", s.m.transport)