Fix chats reseting when a Sync signal comes for communities (#3539)

* fix(community): stop re-joining comm when receiving a sync community msg

Fixes an issue with chats being reset. Since joining a community resaves the chats with the synced default value, it resets the sate of the chats, losing the unread messages, the muted state and more.
The solution is to block the re-joining of the community. In the case of the sync, we catch that error and just continue on.

* fix(import): fix HandleImport not saving the chat

Doesn't change much, but it could have caused issues in the future, so since we might have modified the chat, we make sure to save them
Also adds a test

* fix tests
This commit is contained in:
Jonathan Rainville 2023-05-29 13:57:05 -04:00 committed by GitHub
parent 589cc965e3
commit a6285cc827
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 167 additions and 26 deletions

View File

@ -159,7 +159,7 @@ func main() {
} }
logger.Info("GOT community", "comm", chat) logger.Info("GOT community", "comm", chat)
response, err := messenger.JoinCommunity(context.Background(), community.ID()) response, err := messenger.JoinCommunity(context.Background(), community.ID(), false)
if err != nil { if err != nil {
logger.Error("failed to join community", "err", err) logger.Error("failed to join community", "err", err)
} }

View File

@ -7,6 +7,7 @@ var ErrCategoryNotFound = errors.New("category not found")
var ErrNoChangeInPosition = errors.New("no change in category position") var ErrNoChangeInPosition = errors.New("no change in category position")
var ErrChatAlreadyAssigned = errors.New("chat already assigned to a category") var ErrChatAlreadyAssigned = errors.New("chat already assigned to a category")
var ErrOrgNotFound = errors.New("community not found") var ErrOrgNotFound = errors.New("community not found")
var ErrOrgAlreadyJoined = errors.New("community already joined")
var ErrChatAlreadyExists = errors.New("chat already exists") var ErrChatAlreadyExists = errors.New("chat already exists")
var ErrCategoryAlreadyExists = errors.New("category already exists") var ErrCategoryAlreadyExists = errors.New("category already exists")
var ErrCantRequestAccess = errors.New("can't request access") var ErrCantRequestAccess = errors.New("can't request access")

View File

@ -1916,7 +1916,7 @@ func (m *Manager) HandleWrappedCommunityDescriptionMessage(payload []byte) (*Com
return m.HandleCommunityDescriptionMessage(signer, description, payload) return m.HandleCommunityDescriptionMessage(signer, description, payload)
} }
func (m *Manager) JoinCommunity(id types.HexBytes) (*Community, error) { func (m *Manager) JoinCommunity(id types.HexBytes, forceJoin bool) (*Community, error) {
community, err := m.GetByID(id) community, err := m.GetByID(id)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1924,6 +1924,10 @@ func (m *Manager) JoinCommunity(id types.HexBytes) (*Community, error) {
if community == nil { if community == nil {
return nil, ErrOrgNotFound return nil, ErrOrgNotFound
} }
if !forceJoin && community.Joined() {
// Nothing to do, we are already joined
return community, ErrOrgAlreadyJoined
}
community.Join() community.Join()
err = m.persistence.SaveCommunity(community) err = m.persistence.SaveCommunity(community)
if err != nil { if err != nil {

View File

@ -13,6 +13,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/protobuf/proto"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/zap" "go.uber.org/zap"
@ -32,7 +33,9 @@ import (
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/requests" "github.com/status-im/status-go/protocol/requests"
"github.com/status-im/status-go/protocol/sqlite" "github.com/status-im/status-go/protocol/sqlite"
"github.com/status-im/status-go/protocol/transport"
"github.com/status-im/status-go/protocol/tt" "github.com/status-im/status-go/protocol/tt"
v1protocol "github.com/status-im/status-go/protocol/v1"
"github.com/status-im/status-go/waku" "github.com/status-im/status-go/waku"
) )
@ -347,7 +350,7 @@ func (s *MessengerCommunitiesSuite) TestJoinCommunity() {
s.Require().Equal(community.IDString(), response.Messages()[0].CommunityID) s.Require().Equal(community.IDString(), response.Messages()[0].CommunityID)
// We join the org // We join the org
response, err = s.alice.JoinCommunity(ctx, community.ID()) response, err = s.alice.JoinCommunity(ctx, community.ID(), false)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(response) s.Require().NotNil(response)
s.Require().Len(response.Communities(), 1) s.Require().Len(response.Communities(), 1)
@ -716,7 +719,7 @@ func (s *MessengerCommunitiesSuite) TestPostToCommunityChat() {
ctx := context.Background() ctx := context.Background()
// We join the org // We join the org
response, err = s.alice.JoinCommunity(ctx, community.ID()) response, err = s.alice.JoinCommunity(ctx, community.ID(), false)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(response) s.Require().NotNil(response)
s.Require().Len(response.Communities(), 1) s.Require().Len(response.Communities(), 1)
@ -2983,7 +2986,7 @@ func (s *MessengerCommunitiesSuite) TestSyncCommunity_Leave() {
} }
// alice joins the community // alice joins the community
mr, err = s.alice.JoinCommunity(context.Background(), community.ID()) mr, err = s.alice.JoinCommunity(context.Background(), community.ID(), false)
s.Require().NoError(err, "s.alice.JoinCommunity") s.Require().NoError(err, "s.alice.JoinCommunity")
s.Require().NotNil(mr) s.Require().NotNil(mr)
s.Len(mr.Communities(), 1) s.Len(mr.Communities(), 1)
@ -3309,3 +3312,130 @@ func (s *MessengerCommunitiesSuite) TestCommunityBanUserRequesToJoin() {
s.Require().ErrorContains(err, "can't request access") s.Require().ErrorContains(err, "can't request access")
} }
func (s *MessengerCommunitiesSuite) TestHandleImport() {
description := &requests.CreateCommunity{
Membership: protobuf.CommunityPermissions_INVITATION_ONLY,
Name: "status",
Color: "#ffffff",
Description: "status community description",
}
// Create a community
response, err := s.bob.CreateCommunity(description, true)
s.Require().NoError(err)
s.Require().NotNil(response)
s.Require().Len(response.Communities(), 1)
s.Require().Len(response.Communities()[0].Chats(), 1)
s.Require().Len(response.Chats(), 1)
community := response.Communities()[0]
// Create chat
orgChat := &protobuf.CommunityChat{
Permissions: &protobuf.CommunityPermissions{
Access: protobuf.CommunityPermissions_NO_MEMBERSHIP,
},
Identity: &protobuf.ChatIdentity{
DisplayName: "status-core",
Description: "status-core community chat",
},
}
response, err = s.bob.CreateCommunityChat(community.ID(), orgChat)
s.Require().NoError(err)
s.Require().NotNil(response)
s.Require().Len(response.Communities(), 1)
s.Require().Len(response.Communities()[0].Chats(), 2)
s.Require().Len(response.Chats(), 1)
response, err = s.bob.InviteUsersToCommunity(
&requests.InviteUsersToCommunity{
CommunityID: community.ID(),
Users: []types.HexBytes{common.PubkeyToHexBytes(&s.alice.identity.PublicKey)},
},
)
s.Require().NoError(err)
s.Require().NotNil(response)
s.Require().Len(response.Communities(), 1)
community = response.Communities()[0]
s.Require().True(community.HasMember(&s.alice.identity.PublicKey))
// Pull message and make sure org is received
err = tt.RetryWithBackOff(func() error {
response, err = s.alice.RetrieveAll()
if err != nil {
return err
}
if len(response.Communities()) == 0 {
return errors.New("community not received")
}
return nil
})
s.Require().NoError(err)
communities, err := s.alice.Communities()
s.Require().NoError(err)
s.Require().Len(communities, 2)
s.Require().Len(response.Communities(), 1)
communityID := response.Communities()[0].ID()
s.Require().Equal(communityID, community.ID())
ctx := context.Background()
// We join the org
response, err = s.alice.JoinCommunity(ctx, community.ID(), false)
s.Require().NoError(err)
s.Require().NotNil(response)
s.Require().Len(response.Communities(), 1)
s.Require().Len(response.Communities()[0].Chats(), 2)
s.Require().True(response.Communities()[0].Joined())
s.Require().Len(response.Chats(), 2)
chatID := response.Chats()[1].ID
// Check that there are no messages in the chat at first
chat, err := s.alice.persistence.Chat(chatID)
s.Require().NoError(err)
s.Require().NotNil(chat)
s.Require().Equal(0, int(chat.UnviewedMessagesCount))
// Create an message that will be imported
testMessage := protobuf.ChatMessage{
Text: "abc123",
ChatId: chatID,
ContentType: protobuf.ChatMessage_TEXT_PLAIN,
MessageType: protobuf.MessageType_COMMUNITY_CHAT,
Clock: 1,
Timestamp: 1,
}
encodedPayload, err := proto.Marshal(&testMessage)
s.Require().NoError(err)
wrappedPayload, err := v1protocol.WrapMessageV1(
encodedPayload,
protobuf.ApplicationMetadataMessage_CHAT_MESSAGE,
s.bob.identity,
)
s.Require().NoError(err)
message := &types.Message{}
message.Sig = crypto.FromECDSAPub(&s.bob.identity.PublicKey)
message.Payload = wrappedPayload
filter := s.alice.transport.FilterByChatID(chatID)
importedMessages := make(map[transport.Filter][]*types.Message, 0)
importedMessages[*filter] = append(importedMessages[*filter], message)
// Import that message
err = s.alice.handleImportedMessages(importedMessages)
s.Require().NoError(err)
// Get the chat again and see that there is still no unread message because we don't count import messages
chat, err = s.alice.persistence.Chat(chatID)
s.Require().NoError(err)
s.Require().NotNil(chat)
s.Require().Equal(0, int(chat.UnviewedMessagesCount))
}

View File

@ -3367,6 +3367,13 @@ func (m *Messenger) handleImportedMessages(messagesToHandle map[transport.Filter
} }
} }
} }
// Save chats if they were modified
if len(messageState.Response.chats) > 0 {
err := m.saveChats(messageState.Response.Chats())
if err != nil {
return err
}
}
return nil return nil
} }
@ -4346,7 +4353,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
// Process any community changes // Process any community changes
for _, changes := range messageState.Response.CommunityChanges { for _, changes := range messageState.Response.CommunityChanges {
if changes.ShouldMemberJoin { if changes.ShouldMemberJoin {
response, err := m.joinCommunity(context.TODO(), changes.Community.ID()) response, err := m.joinCommunity(context.TODO(), changes.Community.ID(), false)
if err != nil { if err != nil {
logger.Error("cannot join community", zap.Error(err)) logger.Error("cannot join community", zap.Error(err))
continue continue

View File

@ -158,7 +158,7 @@ func (s *MessengerActivityCenterMessageSuite) TestEveryoneMentionTag() {
s.Require().NoError(err) s.Require().NoError(err)
// Alice joins the community // Alice joins the community
response, err = alice.JoinCommunity(context.Background(), community.ID()) response, err = alice.JoinCommunity(context.Background(), community.ID(), false)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(response) s.Require().NotNil(response)
s.Require().Len(response.Communities(), 1) s.Require().Len(response.Communities(), 1)

View File

@ -436,8 +436,8 @@ func (m *Messenger) initCommunitySettings(communityID types.HexBytes) (*communit
return communitySettings, nil return communitySettings, nil
} }
func (m *Messenger) JoinCommunity(ctx context.Context, communityID types.HexBytes) (*MessengerResponse, error) { func (m *Messenger) JoinCommunity(ctx context.Context, communityID types.HexBytes, forceJoin bool) (*MessengerResponse, error) {
mr, err := m.joinCommunity(ctx, communityID) mr, err := m.joinCommunity(ctx, communityID, forceJoin)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -452,12 +452,12 @@ func (m *Messenger) JoinCommunity(ctx context.Context, communityID types.HexByte
return mr, nil return mr, nil
} }
func (m *Messenger) joinCommunity(ctx context.Context, communityID types.HexBytes) (*MessengerResponse, error) { func (m *Messenger) joinCommunity(ctx context.Context, communityID types.HexBytes, forceJoin bool) (*MessengerResponse, error) {
logger := m.logger.Named("joinCommunity") logger := m.logger.Named("joinCommunity")
response := &MessengerResponse{} response := &MessengerResponse{}
community, err := m.communitiesManager.JoinCommunity(communityID) community, err := m.communitiesManager.JoinCommunity(communityID, forceJoin)
if err != nil { if err != nil {
logger.Debug("m.communitiesManager.JoinCommunity error", zap.Error(err)) logger.Debug("m.communitiesManager.JoinCommunity error", zap.Error(err))
return nil, err return nil, err
@ -1429,7 +1429,7 @@ func (m *Messenger) ImportCommunity(ctx context.Context, key *ecdsa.PrivateKey)
return nil, err return nil, err
} }
response, err := m.JoinCommunity(ctx, community.ID()) response, err := m.JoinCommunity(ctx, community.ID(), true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2097,8 +2097,8 @@ func (m *Messenger) handleSyncCommunity(messageState *ReceivedMessageState, sync
if !pending { if !pending {
var mr *MessengerResponse var mr *MessengerResponse
if syncCommunity.Joined { if syncCommunity.Joined {
mr, err = m.joinCommunity(context.Background(), syncCommunity.Id) mr, err = m.joinCommunity(context.Background(), syncCommunity.Id, false)
if err != nil { if err != nil && err != communities.ErrOrgAlreadyJoined {
logger.Debug("m.joinCommunity error", zap.Error(err)) logger.Debug("m.joinCommunity error", zap.Error(err))
return err return err
} }
@ -2109,10 +2109,12 @@ func (m *Messenger) handleSyncCommunity(messageState *ReceivedMessageState, sync
return err return err
} }
} }
err = messageState.Response.Merge(mr) if mr != nil {
if err != nil { err = messageState.Response.Merge(mr)
logger.Debug("messageState.Response.Merge error", zap.Error(err)) if err != nil {
return err logger.Debug("messageState.Response.Merge error", zap.Error(err))
return err
}
} }
} }

View File

@ -175,7 +175,7 @@ func (s *MessengerDeleteMessageForEveryoneSuite) inviteAndJoin(community *commun
}, "community not received") }, "community not received")
s.Require().NoError(err) s.Require().NoError(err)
response, err = target.JoinCommunity(context.Background(), community.ID()) response, err = target.JoinCommunity(context.Background(), community.ID(), false)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(response) s.Require().NotNil(response)
s.Require().Len(response.Communities(), 1) s.Require().Len(response.Communities(), 1)

View File

@ -1435,7 +1435,7 @@ func (m *Messenger) HandleCommunityRequestToJoinResponse(state *ReceivedMessageS
} }
if requestToJoinResponseProto.Accepted { if requestToJoinResponseProto.Accepted {
response, err := m.JoinCommunity(context.Background(), requestToJoinResponseProto.CommunityId) response, err := m.JoinCommunity(context.Background(), requestToJoinResponseProto.CommunityId, false)
if err != nil { if err != nil {
return err return err
} }
@ -1570,12 +1570,10 @@ func (m *Messenger) HandleEditMessage(state *ReceivedMessageState, editMessage E
return err return err
} }
needToSaveChat := false
if chat.LastMessage != nil && chat.LastMessage.ID == editedMessage.ID { if chat.LastMessage != nil && chat.LastMessage.ID == editedMessage.ID {
chat.LastMessage = editedMessage chat.LastMessage = editedMessage
err := m.saveChat(chat) needToSaveChat = true
if err != nil {
return err
}
} }
responseTo, err := m.persistence.MessageByID(editedMessage.ResponseTo) responseTo, err := m.persistence.MessageByID(editedMessage.ResponseTo)
@ -1590,7 +1588,6 @@ func (m *Messenger) HandleEditMessage(state *ReceivedMessageState, editMessage E
editedMessageHasMentions := editedMessage.Mentioned editedMessageHasMentions := editedMessage.Mentioned
needToSaveChat := false
if editedMessageHasMentions && !originalMessageMentioned && !editedMessage.Seen { if editedMessageHasMentions && !originalMessageMentioned && !editedMessage.Seen {
// Increase unviewed count when the edited message has a mention and didn't have one before // Increase unviewed count when the edited message has a mention and didn't have one before
chat.UnviewedMentionsCount++ chat.UnviewedMentionsCount++

View File

@ -402,7 +402,7 @@ func (api *PublicAPI) SpectateCommunity(parent context.Context, communityID type
// JoinCommunity joins a community with the given ID // JoinCommunity joins a community with the given ID
func (api *PublicAPI) JoinCommunity(parent context.Context, communityID types.HexBytes) (*protocol.MessengerResponse, error) { func (api *PublicAPI) JoinCommunity(parent context.Context, communityID types.HexBytes) (*protocol.MessengerResponse, error) {
return api.service.messenger.JoinCommunity(parent, communityID) return api.service.messenger.JoinCommunity(parent, communityID, false)
} }
// LeaveCommunity leaves a commuity with the given ID // LeaveCommunity leaves a commuity with the given ID