fix(communities)_: validate if RawMessage from DB is valid before sending it

This commit is contained in:
Mykhailo Prakhov 2024-07-12 12:26:16 +02:00
parent 4d2d20cff4
commit 30fee0cfd3
5 changed files with 204 additions and 70 deletions

View File

@ -8,7 +8,7 @@ import (
"time"
"github.com/cenkalti/backoff/v3"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/status-im/status-go/eth-node/types"
@ -55,11 +55,12 @@ func (s *MessengerRawMessageResendTest) SetupTest() {
signal.SetMobileSignalHandler(nil)
exchangeNodeConfig := &wakuv2.Config{
Port: 0,
EnableDiscV5: true,
EnablePeerExchangeServer: true,
ClusterID: 16,
DefaultShardPubsubTopic: shard.DefaultShardPubsubTopic(),
Port: 0,
EnableDiscV5: true,
EnablePeerExchangeServer: true,
ClusterID: 16,
DefaultShardPubsubTopic: shard.DefaultShardPubsubTopic(),
EnableStoreConfirmationForMessagesSent: false,
}
s.exchangeBootNode, err = wakuv2.New(nil, "", exchangeNodeConfig, s.logger.Named("pxServerNode"), nil, nil, nil, nil)
s.Require().NoError(err)
@ -208,14 +209,9 @@ func (s *MessengerRawMessageResendTest) setCreateAccountRequest(displayName, roo
}
}
// TestMessageSent tests if ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN is in state `sent` without resending
func (s *MessengerRawMessageResendTest) TestMessageSent() {
ids, err := s.bobMessenger.RawMessagesIDsByType(protobuf.ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN)
s.Require().NoError(err)
s.Require().Len(ids, 1)
err = tt.RetryWithBackOff(func() error {
rawMessage, err := s.bobMessenger.RawMessageByID(ids[0])
func (s *MessengerRawMessageResendTest) waitForMessageSent(messageID string) {
err := tt.RetryWithBackOff(func() error {
rawMessage, err := s.bobMessenger.RawMessageByID(messageID)
s.Require().NoError(err)
s.Require().NotNil(rawMessage)
if rawMessage.SendCount > 0 {
@ -226,16 +222,31 @@ func (s *MessengerRawMessageResendTest) TestMessageSent() {
s.Require().NoError(err)
}
// TestMessageSent tests if ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN is in state `sent` without resending
func (s *MessengerRawMessageResendTest) TestMessageSent() {
ids, err := s.bobMessenger.RawMessagesIDsByType(protobuf.ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN)
s.Require().NoError(err)
s.Require().Len(ids, 1)
s.waitForMessageSent(ids[0])
}
// TestMessageResend tests if ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN is resent
func (s *MessengerRawMessageResendTest) TestMessageResend() {
ids, err := s.bobMessenger.RawMessagesIDsByType(protobuf.ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN)
s.Require().NoError(err)
s.Require().Len(ids, 1)
// wait for Sent status for already sent message to make sure that sent message was delivered
// before testing resend
s.waitForMessageSent(ids[0])
rawMessage, err := s.bobMessenger.RawMessageByID(ids[0])
s.Require().NoError(err)
s.Require().NotNil(rawMessage)
s.Require().NoError(s.bobMessenger.UpdateRawMessageSent(rawMessage.ID, false))
s.Require().NoError(s.bobMessenger.UpdateRawMessageLastSent(rawMessage.ID, 0))
err = tt.RetryWithBackOff(func() error {
rawMessage, err := s.bobMessenger.RawMessageByID(ids[0])
s.Require().NoError(err)
@ -255,6 +266,47 @@ func (s *MessengerRawMessageResendTest) TestMessageResend() {
}, s.aliceMessenger)
}
func (s *MessengerRawMessageResendTest) TestInvalidRawMessageToWatchDoesNotProduceResendLoop() {
ids, err := s.bobMessenger.RawMessagesIDsByType(protobuf.ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN)
s.Require().NoError(err)
s.Require().Len(ids, 1)
s.waitForMessageSent(ids[0])
rawMessage, err := s.bobMessenger.RawMessageByID(ids[0])
s.Require().NoError(err)
requestToJoinProto := &protobuf.CommunityRequestToJoin{}
err = proto.Unmarshal(rawMessage.Payload, requestToJoinProto)
s.Require().NoError(err)
requestToJoinProto.DisplayName = "invalid_ID"
payload, err := proto.Marshal(requestToJoinProto)
s.Require().NoError(err)
rawMessage.Payload = payload
_, err = s.bobMessenger.AddRawMessageToWatch(rawMessage)
s.Require().Error(err, common.ErrModifiedRawMessage)
// simulate storing msg with modified payload, but old message ID
_, err = s.bobMessenger.UpsertRawMessageToWatch(rawMessage)
s.Require().NoError(err)
s.Require().NoError(s.bobMessenger.UpdateRawMessageSent(rawMessage.ID, false))
s.Require().NoError(s.bobMessenger.UpdateRawMessageLastSent(rawMessage.ID, 0))
// check counter increased for invalid message to escape the loop
err = tt.RetryWithBackOff(func() error {
rawMessage, err := s.bobMessenger.RawMessageByID(ids[0])
s.Require().NoError(err)
s.Require().NotNil(rawMessage)
if rawMessage.SendCount < 2 {
return errors.New("message ApplicationMetadataMessage_COMMUNITY_REQUEST_TO_JOIN was not resent yet")
}
return nil
})
s.Require().NoError(err)
}
// To be removed in https://github.com/status-im/status-go/issues/4437
func advertiseCommunityToUserOldWay(s *suite.Suite, community *communities.Community, alice *protocol.Messenger, bob *protocol.Messenger) {
chat := protocol.CreateOneToOneChat(bob.IdentityPublicKeyString(), bob.IdentityPublicKey(), bob.GetTransport())

View File

@ -3,3 +3,4 @@ package common
import "errors"
var ErrRecordNotFound = errors.New("record not found")
var ErrModifiedRawMessage = errors.New("modified rawMessage")

View File

@ -213,7 +213,9 @@ func (s *MessageSender) SendPubsubTopicKey(
return nil, err
}
rawMessage.ID = types.EncodeHex(messageID)
if err = s.setMessageID(messageID, rawMessage); err != nil {
return nil, err
}
// Notify before dispatching, otherwise the dispatch subscription might happen
// earlier than the scheduled
@ -247,12 +249,14 @@ func (s *MessageSender) SendGroup(
}
// Calculate messageID first and set on raw message
wrappedMessage, err := s.wrapMessageV1(&rawMessage)
messageID, err := s.getMessageID(&rawMessage)
if err != nil {
return nil, errors.Wrap(err, "failed to wrap message")
return nil, err
}
if err = s.setMessageID(messageID, &rawMessage); err != nil {
return nil, err
}
messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage)
rawMessage.ID = types.EncodeHex(messageID)
// We call it only once, and we nil the function after so it doesn't get called again
if rawMessage.BeforeDispatch != nil {
@ -278,10 +282,42 @@ func (s *MessageSender) getMessageID(rawMessage *RawMessage) (types.HexBytes, er
}
messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage)
return messageID, nil
}
func (s *MessageSender) ValidateRawMessage(rawMessage *RawMessage) error {
id, err := s.getMessageID(rawMessage)
if err != nil {
return err
}
messageID := types.EncodeHex(id)
return s.validateMessageID(messageID, rawMessage)
}
func (s *MessageSender) validateMessageID(messageID string, rawMessage *RawMessage) error {
if len(rawMessage.ID) > 0 && rawMessage.ID != messageID {
s.logger.Error("failed to validate message ID, RawMessage content was modified",
zap.String("prevID", rawMessage.ID),
zap.String("newID", messageID),
zap.Any("contentType", rawMessage.MessageType))
return ErrModifiedRawMessage
}
return nil
}
func (s *MessageSender) setMessageID(messageID types.HexBytes, rawMessage *RawMessage) error {
msgID := types.EncodeHex(messageID)
if err := s.validateMessageID(msgID, rawMessage); err != nil {
return err
}
rawMessage.ID = msgID
return nil
}
func ShouldCommunityMessageBeEncrypted(msgType protobuf.ApplicationMetadataMessage_Type) bool {
return msgType == protobuf.ApplicationMetadataMessage_CHAT_MESSAGE ||
msgType == protobuf.ApplicationMetadataMessage_EDIT_MESSAGE ||
@ -308,7 +344,10 @@ func (s *MessageSender) sendCommunity(
if err != nil {
return nil, err
}
rawMessage.ID = types.EncodeHex(messageID)
if err = s.setMessageID(messageID, rawMessage); err != nil {
return nil, err
}
if rawMessage.BeforeDispatch != nil {
if err := rawMessage.BeforeDispatch(rawMessage); err != nil {
@ -418,7 +457,11 @@ func (s *MessageSender) sendPrivate(
}
messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage)
rawMessage.ID = types.EncodeHex(messageID)
if err = s.setMessageID(messageID, rawMessage); err != nil {
return nil, err
}
if rawMessage.BeforeDispatch != nil {
if err := rawMessage.BeforeDispatch(rawMessage); err != nil {
return nil, err
@ -479,7 +522,7 @@ func (s *MessageSender) sendPrivate(
}
s.logger.Debug("sent-message: private skipProtocolLayer",
zap.Strings("recipient", PubkeysToHex(rawMessage.Recipients)),
zap.String("recipient", PubkeyToHex(recipient)),
zap.String("messageID", messageID.String()),
zap.String("messageType", "private"),
zap.Any("contentType", rawMessage.MessageType),
@ -499,7 +542,7 @@ func (s *MessageSender) sendPrivate(
}
s.logger.Debug("sent-message: private without datasync",
zap.Strings("recipient", PubkeysToHex(rawMessage.Recipients)),
zap.String("recipient", PubkeyToHex(recipient)),
zap.String("messageID", messageID.String()),
zap.Any("contentType", rawMessage.MessageType),
zap.String("messageType", "private"),
@ -723,7 +766,10 @@ func (s *MessageSender) SendPublic(
newMessage.PubsubTopic = rawMessage.PubsubTopic
messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage)
rawMessage.ID = types.EncodeHex(messageID)
if err = s.setMessageID(messageID, &rawMessage); err != nil {
return nil, err
}
if rawMessage.BeforeDispatch != nil {
if err := rawMessage.BeforeDispatch(&rawMessage); err != nil {

View File

@ -1497,7 +1497,7 @@ func (m *Messenger) RequestToJoinCommunity(request *requests.RequestToJoinCommun
return nil, err
}
rawMessage := common.RawMessage{
rawMessage := &common.RawMessage{
Payload: payload,
CommunityID: community.ID(),
ResendType: common.ResendTypeRawMessage,
@ -1506,30 +1506,32 @@ func (m *Messenger) RequestToJoinCommunity(request *requests.RequestToJoinCommun
PubsubTopic: shard.DefaultNonProtectedPubsubTopic(),
}
_, err = m.SendMessageToControlNode(community, &rawMessage)
_, err = m.SendMessageToControlNode(community, rawMessage)
if err != nil {
return nil, err
}
if _, err = m.AddRawMessageToWatch(rawMessage); err != nil {
return nil, err
}
if !community.AutoAccept() {
privilegedMembers := community.GetFilteredPrivilegedMembers(map[string]struct{}{})
privilegedMembersSorted := community.GetFilteredPrivilegedMembers(map[string]struct{}{m.IdentityPublicKeyString(): {}})
privMembersArray := []*ecdsa.PublicKey{}
for _, member := range privilegedMembers[protobuf.CommunityMember_ROLE_OWNER] {
rawMessage.Recipients = append(rawMessage.Recipients, member)
_, err := m.sender.SendPrivate(context.Background(), member, &rawMessage)
if err != nil {
return nil, err
}
}
for _, member := range privilegedMembers[protobuf.CommunityMember_ROLE_TOKEN_MASTER] {
rawMessage.Recipients = append(rawMessage.Recipients, member)
_, err := m.sender.SendPrivate(context.Background(), member, &rawMessage)
if err != nil {
return nil, err
}
if rawMessage.ResendMethod != common.ResendMethodSendPrivate {
privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_OWNER]...)
}
// don't send revealed addresses to admins
privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_TOKEN_MASTER]...)
privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_ADMIN]...)
rawMessage.ResendMethod = common.ResendMethodSendPrivate
rawMessage.ID = ""
rawMessage.Recipients = privMembersArray
// don't send revealed addresses to privileged members
// tokenMaster and owner without community private key will receive them from control node
requestToJoinProto.RevealedAccounts = make([]*protobuf.RevealedAccount, 0)
payload, err = proto.Marshal(requestToJoinProto)
if err != nil {
@ -1537,17 +1539,18 @@ func (m *Messenger) RequestToJoinCommunity(request *requests.RequestToJoinCommun
}
rawMessage.Payload = payload
for _, member := range privilegedMembers[protobuf.CommunityMember_ROLE_ADMIN] {
rawMessage.Recipients = append(rawMessage.Recipients, member)
_, err := m.sender.SendPrivate(context.Background(), member, &rawMessage)
for _, member := range rawMessage.Recipients {
_, err := m.sender.SendPrivate(context.Background(), member, rawMessage)
if err != nil {
return nil, err
}
}
}
if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil {
return nil, err
if len(rawMessage.Recipients) > 0 {
if _, err = m.AddRawMessageToWatch(rawMessage); err != nil {
return nil, err
}
}
}
response := &MessengerResponse{}
@ -1681,7 +1684,7 @@ func (m *Messenger) EditSharedAddressesForCommunity(request *requests.EditShared
return nil, err
}
if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil {
if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil {
return nil, err
}
@ -1724,8 +1727,9 @@ func (m *Messenger) PublishTokenActionToPrivilegedMembers(communityID []byte, ch
allRecipients := privilegedMembers[protobuf.CommunityMember_ROLE_OWNER]
allRecipients = append(allRecipients, privilegedMembers[protobuf.CommunityMember_ROLE_TOKEN_MASTER]...)
rawMessage.Recipients = allRecipients
for _, recipient := range allRecipients {
for _, recipient := range rawMessage.Recipients {
_, err := m.sender.SendPrivate(context.Background(), recipient, &rawMessage)
if err != nil {
return err
@ -1733,8 +1737,7 @@ func (m *Messenger) PublishTokenActionToPrivilegedMembers(communityID []byte, ch
}
if len(allRecipients) > 0 {
rawMessage.Recipients = allRecipients
if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil {
if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil {
return err
}
}
@ -1885,22 +1888,47 @@ func (m *Messenger) CancelRequestToJoinCommunity(ctx context.Context, request *r
return nil, err
}
// NOTE: rawMessage.ID is generated from payload + sender + messageType
// rawMessage.ID will be the same for control node and privileged members, but for
// community without owner token resend type is different
// in order not to override msg to control node by message for privileged members,
// we skip storing the same message for privileged members
avoidDuplicateWatchingForPrivilegedMembers := community.AutoAccept() || rawMessage.ResendMethod != common.ResendMethodSendPrivate
if avoidDuplicateWatchingForPrivilegedMembers {
if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil {
return nil, err
}
}
if !community.AutoAccept() {
// send cancelation to community admins also
rawMessage.Payload = payload
rawMessage.ResendMethod = common.ResendMethodSendPrivate
privilegedMembers := community.GetPrivilegedMembers()
for _, privilegedMember := range privilegedMembers {
rawMessage.Recipients = append(rawMessage.Recipients, privilegedMember)
privilegedMembersSorted := community.GetFilteredPrivilegedMembers(map[string]struct{}{m.IdentityPublicKeyString(): {}})
privMembersArray := privilegedMembersSorted[protobuf.CommunityMember_ROLE_TOKEN_MASTER]
privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_ADMIN]...)
if !avoidDuplicateWatchingForPrivilegedMembers {
// control node was added to the recipients during 'SendMessageToControlNode'
rawMessage.Recipients = append(rawMessage.Recipients, privMembersArray...)
} else {
privMembersArray = append(privMembersArray, privilegedMembersSorted[protobuf.CommunityMember_ROLE_OWNER]...)
rawMessage.Recipients = privMembersArray
}
for _, privilegedMember := range privMembersArray {
_, err := m.sender.SendPrivate(context.Background(), privilegedMember, &rawMessage)
if err != nil {
return nil, err
}
}
}
if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil {
return nil, err
if !avoidDuplicateWatchingForPrivilegedMembers {
if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil {
return nil, err
}
}
}
response := &MessengerResponse{}
@ -2012,7 +2040,7 @@ func (m *Messenger) acceptRequestToJoinCommunity(requestToJoin *communities.Requ
return nil, err
}
if _, err = m.UpsertRawMessageToWatch(rawMessage); err != nil {
if _, err = m.AddRawMessageToWatch(rawMessage); err != nil {
return nil, err
}
}
@ -2200,7 +2228,7 @@ func (m *Messenger) LeaveCommunity(communityID types.HexBytes) (*MessengerRespon
return nil, err
}
if _, err = m.UpsertRawMessageToWatch(&rawMessage); err != nil {
if _, err = m.AddRawMessageToWatch(&rawMessage); err != nil {
return nil, err
}
}
@ -3686,7 +3714,7 @@ func (m *Messenger) sendSharedAddressToControlNode(receiver *ecdsa.PublicKey, co
return nil, err
}
_, err = m.UpsertRawMessageToWatch(&rawMessage)
_, err = m.AddRawMessageToWatch(&rawMessage)
return requestToJoin, err
}

View File

@ -65,12 +65,7 @@ func (m *Messenger) processMessageID(id string) (*common.RawMessage, bool, error
return nil, false, errors.Wrap(err, "Can't get raw message by ID")
}
shouldResend, err := m.shouldResendMessage(rawMessage, m.getTimesource())
if err != nil {
m.logger.Error("Can't check if message should be resent", zap.Error(err))
return rawMessage, false, err
}
shouldResend := m.shouldResendMessage(rawMessage, m.getTimesource())
if !shouldResend {
return rawMessage, false, nil
}
@ -137,15 +132,16 @@ func (m *Messenger) handleOtherResendMethods(rawMessage *common.RawMessage) (boo
return true, m.reSendRawMessage(context.Background(), rawMessage.ID)
}
func (m *Messenger) shouldResendMessage(message *common.RawMessage, t common.TimeSource) (bool, error) {
func (m *Messenger) shouldResendMessage(message *common.RawMessage, t common.TimeSource) bool {
if m.featureFlags.ResendRawMessagesDisabled {
return false, nil
return false
}
//exponential backoff depends on how many attempts to send message already made
power := math.Pow(2, float64(message.SendCount-1))
backoff := uint64(power) * uint64(m.config.messageResendMinDelay.Milliseconds())
backoffElapsed := t.GetCurrentTime() > (message.LastSent + backoff)
return backoffElapsed, nil
return backoffElapsed
}
// pull a message from the database and send it again
@ -184,6 +180,17 @@ func (m *Messenger) UpsertRawMessageToWatch(rawMessage *common.RawMessage) (*com
return rawMessage, nil
}
// AddRawMessageToWatch check if RawMessage is correct and insert the rawMessage to the database
// relate watch method: Messenger#watchExpiredMessages
func (m *Messenger) AddRawMessageToWatch(rawMessage *common.RawMessage) (*common.RawMessage, error) {
if err := m.sender.ValidateRawMessage(rawMessage); err != nil {
m.logger.Error("Can't add raw message to watch", zap.String("messageID", rawMessage.ID), zap.Error(err))
return nil, err
}
return m.UpsertRawMessageToWatch(rawMessage)
}
func (m *Messenger) upsertRawMessageToWatch(rawMessage *common.RawMessage) {
_, err := m.UpsertRawMessageToWatch(rawMessage)
if err != nil {