diff --git a/api/messenger_raw_message_resend_test.go b/api/messenger_raw_message_resend_test.go index eed1a2d0c..b0c07b4dd 100644 --- a/api/messenger_raw_message_resend_test.go +++ b/api/messenger_raw_message_resend_test.go @@ -8,7 +8,7 @@ import ( "time" "github.com/cenkalti/backoff/v3" - + "github.com/golang/protobuf/proto" "go.uber.org/zap" "github.com/status-im/status-go/eth-node/types" @@ -55,11 +55,12 @@ func (s *MessengerRawMessageResendTest) SetupTest() { signal.SetMobileSignalHandler(nil) exchangeNodeConfig := &wakuv2.Config{ - Port: 0, - EnableDiscV5: true, - EnablePeerExchangeServer: true, - ClusterID: 16, - DefaultShardPubsubTopic: shard.DefaultShardPubsubTopic(), + Port: 0, + EnableDiscV5: true, + EnablePeerExchangeServer: true, + ClusterID: 16, + DefaultShardPubsubTopic: shard.DefaultShardPubsubTopic(), + EnableStoreConfirmationForMessagesSent: false, } s.exchangeBootNode, err = wakuv2.New(nil, "", exchangeNodeConfig, s.logger.Named("pxServerNode"), nil, nil, nil, nil) s.Require().NoError(err) @@ -208,14 +209,9 @@ func (s *MessengerRawMessageResendTest) setCreateAccountRequest(displayName, roo } } -// TestMessageSent tests if ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN is in state `sent` without resending -func (s *MessengerRawMessageResendTest) TestMessageSent() { - ids, err := s.bobMessenger.RawMessagesIDsByType(protobuf.ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN) - s.Require().NoError(err) - s.Require().Len(ids, 1) - - err = tt.RetryWithBackOff(func() error { - rawMessage, err := s.bobMessenger.RawMessageByID(ids[0]) +func (s *MessengerRawMessageResendTest) waitForMessageSent(messageID string) { + err := tt.RetryWithBackOff(func() error { + rawMessage, err := s.bobMessenger.RawMessageByID(messageID) s.Require().NoError(err) s.Require().NotNil(rawMessage) if rawMessage.SendCount > 0 { @@ -226,16 +222,31 @@ func (s *MessengerRawMessageResendTest) TestMessageSent() { s.Require().NoError(err) } +// TestMessageSent tests if ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN is in state `sent` without resending +func (s *MessengerRawMessageResendTest) TestMessageSent() { + ids, err := s.bobMessenger.RawMessagesIDsByType(protobuf.ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN) + s.Require().NoError(err) + s.Require().Len(ids, 1) + + s.waitForMessageSent(ids[0]) +} + // TestMessageResend tests if ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN is resent func (s *MessengerRawMessageResendTest) TestMessageResend() { ids, err := s.bobMessenger.RawMessagesIDsByType(protobuf.ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN) s.Require().NoError(err) s.Require().Len(ids, 1) + // wait for Sent status for already sent message to make sure that sent message was delivered + // before testing resend + s.waitForMessageSent(ids[0]) + rawMessage, err := s.bobMessenger.RawMessageByID(ids[0]) s.Require().NoError(err) s.Require().NotNil(rawMessage) + s.Require().NoError(s.bobMessenger.UpdateRawMessageSent(rawMessage.ID, false)) s.Require().NoError(s.bobMessenger.UpdateRawMessageLastSent(rawMessage.ID, 0)) + err = tt.RetryWithBackOff(func() error { rawMessage, err := s.bobMessenger.RawMessageByID(ids[0]) s.Require().NoError(err) @@ -255,6 +266,47 @@ func (s *MessengerRawMessageResendTest) TestMessageResend() { }, s.aliceMessenger) } +func (s *MessengerRawMessageResendTest) TestInvalidRawMessageToWatchDoesNotProduceResendLoop() { + ids, err := s.bobMessenger.RawMessagesIDsByType(protobuf.ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN) + s.Require().NoError(err) + s.Require().Len(ids, 1) + + s.waitForMessageSent(ids[0]) + + rawMessage, err := s.bobMessenger.RawMessageByID(ids[0]) + s.Require().NoError(err) + + requestToJoinProto := &protobuf.CommunityRequestToJoin{} + err = proto.Unmarshal(rawMessage.Payload, requestToJoinProto) + s.Require().NoError(err) + + requestToJoinProto.DisplayName = "invalid_ID" + payload, err := proto.Marshal(requestToJoinProto) + s.Require().NoError(err) + rawMessage.Payload = payload + + _, err = s.bobMessenger.AddRawMessageToWatch(rawMessage) + s.Require().Error(err, common.ErrModifiedRawMessage) + + // simulate storing msg with modified payload, but old message ID + _, err = s.bobMessenger.UpsertRawMessageToWatch(rawMessage) + s.Require().NoError(err) + s.Require().NoError(s.bobMessenger.UpdateRawMessageSent(rawMessage.ID, false)) + s.Require().NoError(s.bobMessenger.UpdateRawMessageLastSent(rawMessage.ID, 0)) + + // check counter increased for invalid message to escape the loop + err = tt.RetryWithBackOff(func() error { + rawMessage, err := s.bobMessenger.RawMessageByID(ids[0]) + s.Require().NoError(err) + s.Require().NotNil(rawMessage) + if rawMessage.SendCount < 2 { + return errors.New("message ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN was not resent yet") + } + return nil + }) + s.Require().NoError(err) +} + // To be removed in https://github.com/status-im/status-go/issues/4437 func advertiseCommunityToUserOldWay(s *suite.Suite, community *communities.Community, alice *protocol.Messenger, bob *protocol.Messenger) { chat := protocol.CreateOneToOneChat(bob.IdentityPublicKeyString(), bob.IdentityPublicKey(), bob.GetTransport()) diff --git a/protocol/common/errors.go b/protocol/common/errors.go index e40a89922..f9f0067f9 100644 --- a/protocol/common/errors.go +++ b/protocol/common/errors.go @@ -3,3 +3,4 @@ package common import "errors" var ErrRecordNotFound = errors.New("record not found") +var ErrModifiedRawMessage = errors.New("modified rawMessage") diff --git a/protocol/common/message_sender.go b/protocol/common/message_sender.go index d48eec5c5..479cd203e 100644 --- a/protocol/common/message_sender.go +++ b/protocol/common/message_sender.go @@ -213,7 +213,9 @@ func (s *MessageSender) SendPubsubTopicKey( return nil, err } - rawMessage.ID = types.EncodeHex(messageID) + if err = s.setMessageID(messageID, rawMessage); err != nil { + return nil, err + } // Notify before dispatching, otherwise the dispatch subscription might happen // earlier than the scheduled @@ -247,12 +249,14 @@ func (s *MessageSender) SendGroup( } // Calculate messageID first and set on raw message - wrappedMessage, err := s.wrapMessageV1(&rawMessage) + messageID, err := s.getMessageID(&rawMessage) if err != nil { - return nil, errors.Wrap(err, "failed to wrap message") + return nil, err + } + + if err = s.setMessageID(messageID, &rawMessage); err != nil { + return nil, err } - messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage) - rawMessage.ID = types.EncodeHex(messageID) // We call it only once, and we nil the function after so it doesn't get called again if rawMessage.BeforeDispatch != nil { @@ -278,10 +282,42 @@ func (s *MessageSender) getMessageID(rawMessage *RawMessage) (types.HexBytes, er } messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage) - return messageID, nil } +func (s *MessageSender) ValidateRawMessage(rawMessage *RawMessage) error { + id, err := s.getMessageID(rawMessage) + if err != nil { + return err + } + messageID := types.EncodeHex(id) + + return s.validateMessageID(messageID, rawMessage) + +} + +func (s *MessageSender) validateMessageID(messageID string, rawMessage *RawMessage) error { + if len(rawMessage.ID) > 0 && rawMessage.ID != messageID { + s.logger.Error("failed to validate message ID, RawMessage content was modified", + zap.String("prevID", rawMessage.ID), + zap.String("newID", messageID), + zap.Any("contentType", rawMessage.MessageType)) + return ErrModifiedRawMessage + } + return nil +} + +func (s *MessageSender) setMessageID(messageID types.HexBytes, rawMessage *RawMessage) error { + msgID := types.EncodeHex(messageID) + + if err := s.validateMessageID(msgID, rawMessage); err != nil { + return err + } + + rawMessage.ID = msgID + return nil +} + func ShouldCommunityMessageBeEncrypted(msgType protobuf.ApplicationMetadataMessage_Type) bool { return msgType == protobuf.ApplicationMetadataMessage_CHAT_MESSAGE || msgType == protobuf.ApplicationMetadataMessage_EDIT_MESSAGE || @@ -308,7 +344,10 @@ func (s *MessageSender) sendCommunity( if err != nil { return nil, err } - rawMessage.ID = types.EncodeHex(messageID) + + if err = s.setMessageID(messageID, rawMessage); err != nil { + return nil, err + } if rawMessage.BeforeDispatch != nil { if err := rawMessage.BeforeDispatch(rawMessage); err != nil { @@ -418,7 +457,11 @@ func (s *MessageSender) sendPrivate( } messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage) - rawMessage.ID = types.EncodeHex(messageID) + + if err = s.setMessageID(messageID, rawMessage); err != nil { + return nil, err + } + if rawMessage.BeforeDispatch != nil { if err := rawMessage.BeforeDispatch(rawMessage); err != nil { return nil, err @@ -479,7 +522,7 @@ func (s *MessageSender) sendPrivate( } s.logger.Debug("sent-message: private skipProtocolLayer", - zap.Strings("recipient", PubkeysToHex(rawMessage.Recipients)), + zap.String("recipient", PubkeyToHex(recipient)), zap.String("messageID", messageID.String()), zap.String("messageType", "private"), zap.Any("contentType", rawMessage.MessageType), @@ -499,7 +542,7 @@ func (s *MessageSender) sendPrivate( } s.logger.Debug("sent-message: private without datasync", - zap.Strings("recipient", PubkeysToHex(rawMessage.Recipients)), + zap.String("recipient", PubkeyToHex(recipient)), zap.String("messageID", messageID.String()), zap.Any("contentType", rawMessage.MessageType), zap.String("messageType", "private"), @@ -723,7 +766,10 @@ func (s *MessageSender) SendPublic( newMessage.PubsubTopic = rawMessage.PubsubTopic messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage) - rawMessage.ID = types.EncodeHex(messageID) + + if err = s.setMessageID(messageID, &rawMessage); err != nil { + return nil, err + } if rawMessage.BeforeDispatch != nil { if err := rawMessage.BeforeDispatch(&rawMessage); err != nil { diff --git a/protocol/messenger_communities.go b/protocol/messenger_communities.go index 413f0f740..aacffc442 100644 --- a/protocol/messenger_communities.go +++ b/protocol/messenger_communities.go @@ -1497,7 +1497,7 @@ func (m *Messenger) RequestToJoinCommunity(request *requests.RequestToJoinCommun return nil, err } - rawMessage := common.RawMessage{ + rawMessage := &common.RawMessage{ Payload: payload, CommunityID: community.ID(), ResendType: common.ResendTypeRawMessage, @@ -1506,30 +1506,32 @@ func (m *Messenger) RequestToJoinCommunity(request *requests.RequestToJoinCommun PubsubTopic: shard.DefaultNonProtectedPubsubTopic(), } - _, err = m.SendMessageToControlNode(community, &rawMessage) + _, err = m.SendMessageToControlNode(community, rawMessage) if err != nil { return nil, err } + if _, err = m.AddRawMessageToWatch(rawMessage); err != nil { + return nil, err + } + if !community.AutoAccept() { - privilegedMembers := community.GetFilteredPrivilegedMembers(map[string]struct{}{}) + privilegedMembersSorted := community.GetFilteredPrivilegedMembers(map[string]struct{}{m.IdentityPublicKeyString(): {}}) + privMembersArray := []*ecdsa.PublicKey{} - for _, member := range privilegedMembers[protobuf.CommunityMember_ROLE_OWNER] { - rawMessage.Recipients = append(rawMessage.Recipients, member) - _, err := m.sender.SendPrivate(context.Background(), member, &rawMessage) - if err != nil { - return nil, err - } - } - for _, member := range privilegedMembers[protobuf.CommunityMember_ROLE_TOKEN_MASTER] { - rawMessage.Recipients = append(rawMessage.Recipients, member) - _, err := m.sender.SendPrivate(context.Background(), member, &rawMessage) - if err != nil { - return nil, err - } + if rawMessage.ResendMethod != common.ResendMethodSendPrivate { + privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_OWNER]...) } - // don't send revealed addresses to admins + privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_TOKEN_MASTER]...) + privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_ADMIN]...) + + rawMessage.ResendMethod = common.ResendMethodSendPrivate + rawMessage.ID = "" + rawMessage.Recipients = privMembersArray + + // don't send revealed addresses to privileged members + // tokenMaster and owner without community private key will receive them from control node requestToJoinProto.RevealedAccounts = make([]*protobuf.RevealedAccount, 0) payload, err = proto.Marshal(requestToJoinProto) if err != nil { @@ -1537,17 +1539,18 @@ func (m *Messenger) RequestToJoinCommunity(request *requests.RequestToJoinCommun } rawMessage.Payload = payload - for _, member := range privilegedMembers[protobuf.CommunityMember_ROLE_ADMIN] { - rawMessage.Recipients = append(rawMessage.Recipients, member) - _, err := m.sender.SendPrivate(context.Background(), member, &rawMessage) + for _, member := range rawMessage.Recipients { + _, err := m.sender.SendPrivate(context.Background(), member, rawMessage) if err != nil { return nil, err } } - } - if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil { - return nil, err + if len(rawMessage.Recipients) > 0 { + if _, err = m.AddRawMessageToWatch(rawMessage); err != nil { + return nil, err + } + } } response := &MessengerResponse{} @@ -1681,7 +1684,7 @@ func (m *Messenger) EditSharedAddressesForCommunity(request *requests.EditShared return nil, err } - if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil { + if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil { return nil, err } @@ -1724,8 +1727,9 @@ func (m *Messenger) PublishTokenActionToPrivilegedMembers(communityID []byte, ch allRecipients := privilegedMembers[protobuf.CommunityMember_ROLE_OWNER] allRecipients = append(allRecipients, privilegedMembers[protobuf.CommunityMember_ROLE_TOKEN_MASTER]...) + rawMessage.Recipients = allRecipients - for _, recipient := range allRecipients { + for _, recipient := range rawMessage.Recipients { _, err := m.sender.SendPrivate(context.Background(), recipient, &rawMessage) if err != nil { return err @@ -1733,8 +1737,7 @@ func (m *Messenger) PublishTokenActionToPrivilegedMembers(communityID []byte, ch } if len(allRecipients) > 0 { - rawMessage.Recipients = allRecipients - if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil { + if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil { return err } } @@ -1885,22 +1888,47 @@ func (m *Messenger) CancelRequestToJoinCommunity(ctx context.Context, request *r return nil, err } + // NOTE: rawMessage.ID is generated from payload + sender + messageType + // rawMessage.ID will be the same for control node and privileged members, but for + // community without owner token resend type is different + // in order not to override msg to control node by message for privileged members, + // we skip storing the same message for privileged members + avoidDuplicateWatchingForPrivilegedMembers := community.AutoAccept() || rawMessage.ResendMethod != common.ResendMethodSendPrivate + if avoidDuplicateWatchingForPrivilegedMembers { + if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil { + return nil, err + } + } + if !community.AutoAccept() { // send cancelation to community admins also rawMessage.Payload = payload + rawMessage.ResendMethod = common.ResendMethodSendPrivate - privilegedMembers := community.GetPrivilegedMembers() - for _, privilegedMember := range privilegedMembers { - rawMessage.Recipients = append(rawMessage.Recipients, privilegedMember) + privilegedMembersSorted := community.GetFilteredPrivilegedMembers(map[string]struct{}{m.IdentityPublicKeyString(): {}}) + privMembersArray := privilegedMembersSorted[protobuf.CommunityMember_ROLE_TOKEN_MASTER] + privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_ADMIN]...) + + if !avoidDuplicateWatchingForPrivilegedMembers { + // control node was added to the recipients during 'SendMessageToControlNode' + rawMessage.Recipients = append(rawMessage.Recipients, privMembersArray...) + } else { + privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_OWNER]...) + rawMessage.Recipients = privMembersArray + } + + for _, privilegedMember := range privMembersArray { _, err := m.sender.SendPrivate(context.Background(), privilegedMember, &rawMessage) if err != nil { return nil, err } } - } - if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil { - return nil, err + if !avoidDuplicateWatchingForPrivilegedMembers { + if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil { + return nil, err + } + } } response := &MessengerResponse{} @@ -2012,7 +2040,7 @@ func (m *Messenger) acceptRequestToJoinCommunity(requestToJoin *communities.Requ return nil, err } - if _, err = m.UpsertRawMessageToWatch(rawMessage); err != nil { + if _, err = m.AddRawMessageToWatch(rawMessage); err != nil { return nil, err } } @@ -2200,7 +2228,7 @@ func (m *Messenger) LeaveCommunity(communityID types.HexBytes) (*MessengerRespon return nil, err } - if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil { + if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil { return nil, err } } @@ -3686,7 +3714,7 @@ func (m *Messenger) sendSharedAddressToControlNode(receiver *ecdsa.PublicKey, co return nil, err } - _, err = m.UpsertRawMessageToWatch(&rawMessage) + _, err = m.AddRawMessageToWatch(&rawMessage) return requestToJoin, err } diff --git a/protocol/messenger_raw_message_resend.go b/protocol/messenger_raw_message_resend.go index 80f08445f..b09d3aa1b 100644 --- a/protocol/messenger_raw_message_resend.go +++ b/protocol/messenger_raw_message_resend.go @@ -65,12 +65,7 @@ func (m *Messenger) processMessageID(id string) (*common.RawMessage, bool, error return nil, false, errors.Wrap(err, "Can't get raw message by ID") } - shouldResend, err := m.shouldResendMessage(rawMessage, m.getTimesource()) - if err != nil { - m.logger.Error("Can't check if message should be resent", zap.Error(err)) - return rawMessage, false, err - } - + shouldResend := m.shouldResendMessage(rawMessage, m.getTimesource()) if !shouldResend { return rawMessage, false, nil } @@ -137,15 +132,16 @@ func (m *Messenger) handleOtherResendMethods(rawMessage *common.RawMessage) (boo return true, m.reSendRawMessage(context.Background(), rawMessage.ID) } -func (m *Messenger) shouldResendMessage(message *common.RawMessage, t common.TimeSource) (bool, error) { +func (m *Messenger) shouldResendMessage(message *common.RawMessage, t common.TimeSource) bool { if m.featureFlags.ResendRawMessagesDisabled { - return false, nil + return false } //exponential backoff depends on how many attempts to send message already made power := math.Pow(2, float64(message.SendCount-1)) backoff := uint64(power) * uint64(m.config.messageResendMinDelay.Milliseconds()) backoffElapsed := t.GetCurrentTime() > (message.LastSent + backoff) - return backoffElapsed, nil + + return backoffElapsed } // pull a message from the database and send it again @@ -184,6 +180,17 @@ func (m *Messenger) UpsertRawMessageToWatch(rawMessage *common.RawMessage) (*com return rawMessage, nil } +// AddRawMessageToWatch check if RawMessage is correct and insert the rawMessage to the database +// relate watch method: Messenger#watchExpiredMessages +func (m *Messenger) AddRawMessageToWatch(rawMessage *common.RawMessage) (*common.RawMessage, error) { + if err := m.sender.ValidateRawMessage(rawMessage); err != nil { + m.logger.Error("Can't add raw message to watch", zap.String("messageID", rawMessage.ID), zap.Error(err)) + return nil, err + } + + return m.UpsertRawMessageToWatch(rawMessage) +} + func (m *Messenger) upsertRawMessageToWatch(rawMessage *common.RawMessage) { _, err := m.UpsertRawMessageToWatch(rawMessage) if err != nil {