diff --git a/protocol/communities/manager.go b/protocol/communities/manager.go index 1bbbf35ff..a635d3f8a 100644 --- a/protocol/communities/manager.go +++ b/protocol/communities/manager.go @@ -3694,17 +3694,20 @@ func (m *Manager) GetRequestToJoinIDByPkAndCommunityID(pk *ecdsa.PublicKey, comm } func (m *Manager) GetCommunityRequestToJoinClock(pk *ecdsa.PublicKey, communityID string) (uint64, error) { - request, err := m.persistence.GetRequestToJoinByPkAndCommunityID(common.PubkeyToHex(pk), []byte(communityID)) + communityIDBytes, err := types.DecodeHex(communityID) + if err != nil { + return 0, err + } + + joinClock, err := m.persistence.GetRequestToJoinClockByPkAndCommunityID(common.PubkeyToHex(pk), communityIDBytes) + if errors.Is(err, sql.ErrNoRows) { return 0, nil } else if err != nil { return 0, err } - if request == nil || request.State != RequestToJoinStateAccepted { - return 0, nil - } - return request.Clock, nil + return joinClock, nil } func (m *Manager) GetRequestToJoinByPkAndCommunityID(pk *ecdsa.PublicKey, communityID []byte) (*RequestToJoin, error) { diff --git a/protocol/communities/persistence.go b/protocol/communities/persistence.go index 81d5c493b..60ac19015 100644 --- a/protocol/communities/persistence.go +++ b/protocol/communities/persistence.go @@ -862,6 +862,16 @@ func (p *Persistence) GetNumberOfPendingRequestsToJoin(communityID types.HexByte return count, nil } +func (p *Persistence) GetRequestToJoinClockByPkAndCommunityID(pk string, communityID types.HexBytes) (uint64, error) { + var clock uint64 + + err := p.db.QueryRow(` + SELECT clock + FROM communities_requests_to_join + WHERE public_key = ? AND community_id = ?`, pk, communityID).Scan(&clock) + return clock, err +} + func (p *Persistence) GetRequestToJoinByPkAndCommunityID(pk string, communityID []byte) (*RequestToJoin, error) { request := &RequestToJoin{} err := p.db.QueryRow(`SELECT id,public_key,clock,ens_name,customization_color,chat_id,community_id,state FROM communities_requests_to_join WHERE public_key = ? AND community_id = ?`, pk, communityID).Scan(&request.ID, &request.PublicKey, &request.Clock, &request.ENSName, &request.CustomizationColor, &request.ChatID, &request.CommunityID, &request.State) diff --git a/protocol/communities_messenger_test.go b/protocol/communities_messenger_test.go index a47fdefb5..8031a1266 100644 --- a/protocol/communities_messenger_test.go +++ b/protocol/communities_messenger_test.go @@ -4407,3 +4407,83 @@ func (s *MessengerCommunitiesSuite) TestOpenAndNotJoinedCommunityNewChannelIsNot s.Require().Len(chat.Members, 2) } } + +func (s *MessengerCommunitiesSuite) sendMention(sender *Messenger, chatID string) *common.Message { + ctx := context.Background() + messageToSend := common.NewMessage() + messageToSend.ChatId = chatID + messageToSend.ContentType = protobuf.ChatMessage_TEXT_PLAIN + messageToSend.Text = "Hello @" + common.EveryoneMentionTag + + response, err := sender.SendChatMessage(ctx, messageToSend) + s.Require().NoError(err) + s.Require().Len(response.Messages(), 1) + s.Require().True(response.Messages()[0].Mentioned) + return response.Messages()[0] +} + +func (s *MessengerCommunitiesSuite) TestAliceDoesNotReceiveMentionWhenSpectating() { + // GIVEN: Create an open community + community, communityChat := s.createCommunity() + community, err := s.owner.GetCommunityByID(community.ID()) + s.Require().NoError(err) + s.Require().Len(community.Chats(), 1) + s.Require().False(community.Encrypted()) + + // Alice SPECTATES the community + advertiseCommunityTo(&s.Suite, community, s.owner, s.alice) + _, err = s.alice.SpectateCommunity(community.ID()) + s.Require().NoError(err) + + aliceCommunity, err := s.alice.GetCommunityByID(community.ID()) + s.Require().NoError(err) + s.Require().Contains(aliceCommunity.ChatIDs(), communityChat.ID) + + // Bob JOINS the community + advertiseCommunityTo(&s.Suite, community, s.owner, s.bob) + request := &requests.RequestToJoinCommunity{CommunityID: community.ID()} + joinCommunity(&s.Suite, community, s.owner, s.bob, request, "") + + // Check Alice gets the updated community + _, err = WaitOnMessengerResponse( + s.alice, + func(r *MessengerResponse) bool { + return len(r.Communities()) > 0 && r.Communities()[0].MembersCount() == 2 + }, + "no community updates for Alice", + ) + s.Require().NoError(err) + + // WHEN: Bob sends a message to a channel with mention + sentMessage := s.sendMention(s.bob, communityChat.ID) + + // THEN: Check Alice gets the message, but no activity center notification + _, err = WaitOnMessengerResponse( + s.alice, + func(r *MessengerResponse) bool { + return len(r.Messages()) == 1 && len(r.ActivityCenterNotifications()) == 0 && + r.Messages()[0].ID == sentMessage.ID + }, + "no message for Alice", + ) + s.Require().NoError(err) + + // Alice joins community + request = &requests.RequestToJoinCommunity{CommunityID: community.ID()} + joinCommunity(&s.Suite, community, s.owner, s.alice, request, "") + + // Bob sends a message with mention + sentMessage = s.sendMention(s.bob, communityChat.ID) + + // Check Alice gets the message and activity center notification + _, err = WaitOnMessengerResponse( + s.alice, + func(r *MessengerResponse) bool { + return len(r.Messages()) == 1 && len(r.ActivityCenterNotifications()) == 1 && + r.Messages()[0].ID == sentMessage.ID && r.ActivityCenterNotifications()[0].Message.ID == sentMessage.ID && + r.ActivityCenterNotifications()[0].Type == ActivityCenterNotificationTypeMention + }, + "no message for Alice", + ) + s.Require().NoError(err) +} diff --git a/protocol/messenger.go b/protocol/messenger.go index 3949871cd..21b58c7fc 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -3542,6 +3542,28 @@ func (r *ReceivedMessageState) updateExistingActivityCenterNotification(publicKe return nil } +// function returns if the community is joined before the clock +func (m *Messenger) isCommunityJoinedBeforeClock(publicKey ecdsa.PublicKey, communityID string, clock uint64) (bool, error) { + community, err := m.communitiesManager.GetByIDString(communityID) + if err != nil { + return false, err + } + + if !community.Joined() || clock < uint64(community.JoinedAt()) { + joinedClock, err := m.communitiesManager.GetCommunityRequestToJoinClock(&publicKey, communityID) + if err != nil { + return false, err + } + + // no request to join, or request to join is after the message + if joinedClock == 0 || clock < joinedClock { + return false, nil + } + return true, nil + } + return true, nil +} + // addNewActivityCenterNotification takes a common.Message and generates a new ActivityCenterNotification and appends it to the // []Response.ActivityCenterNotifications if the message is m.New func (r *ReceivedMessageState) addNewActivityCenterNotification(publicKey ecdsa.PublicKey, m *Messenger, message *common.Message, responseTo *common.Message) error { @@ -3560,13 +3582,9 @@ func (r *ReceivedMessageState) addNewActivityCenterNotification(publicKey ecdsa. } if chat.CommunityChat() { - joinedClock, err := m.communitiesManager.GetCommunityRequestToJoinClock(&publicKey, message.CommunityID) - if err != nil { - return err - } - // Ignore mentions & replies in community before joining - if message.Clock < joinedClock { + ok, err := m.isCommunityJoinedBeforeClock(publicKey, chat.CommunityID, message.Clock) + if err != nil || !ok { return nil } }