diff --git a/protocol/message_persistence.go b/protocol/message_persistence.go index 133f35a89..f482ac649 100644 --- a/protocol/message_persistence.go +++ b/protocol/message_persistence.go @@ -43,6 +43,9 @@ var basicInsertDiscordMessageAuthorQuery = `INSERT OR REPLACE INTO discord_messa var cursor = "substr('0000000000000000000000000000000000000000000000000000000000000000' || m1.clock_value, -64, 64) || m1.id" var cursorField = cursor + " as cursor" +var caseSensitiveSearchCond = "(m1.text LIKE '%' || ? || '%' OR bm.content LIKE '%' || ? || '%' OR dm.content LIKE '%' || ? || '%')" +var caseInsensitiveSearchCond = "(LOWER(m1.text) LIKE LOWER('%' || ? || '%') OR LOWER(bm.content) LIKE LOWER('%' || ? || '%') OR LOWER(dm.content) LIKE LOWER('%' || ? || '%'))" + func (db sqlitePersistence) buildMessagesQueryWithAdditionalFields(additionalSelectFields, whereAndTheRest string) string { allFields := db.tableUserMessagesAllFieldsJoin() if additionalSelectFields != "" { @@ -1002,9 +1005,9 @@ func (db sqlitePersistence) AllMessageByChatIDWhichMatchTerm(chatID string, sear searchCond := "" if caseSensitive { - searchCond = "AND m1.text LIKE '%' || ? || '%'" + searchCond = "AND " + caseSensitiveSearchCond } else { - searchCond = "AND LOWER(m1.text) LIKE LOWER('%' || ? || '%')" + searchCond = "AND " + caseInsensitiveSearchCond } where := fmt.Sprintf(` @@ -1015,7 +1018,10 @@ func (db sqlitePersistence) AllMessageByChatIDWhichMatchTerm(chatID string, sear query := db.buildMessagesQueryWithAdditionalFields(cursorField, where) rows, err := db.db.Query( query, - chatID, searchTerm, + chatID, + searchTerm, + searchTerm, + searchTerm, ) if err != nil { @@ -1051,9 +1057,9 @@ func (db sqlitePersistence) AllMessagesFromChatsAndCommunitiesWhichMatchTerm(com searchCond := "" if caseSensitive { - searchCond = "m1.text LIKE '%' || ? || '%'" + searchCond = caseSensitiveSearchCond } else { - searchCond = "LOWER(m1.text) LIKE LOWER('%' || ? || '%')" + searchCond = caseInsensitiveSearchCond } finalCond := "AND %s AND %s" @@ -1072,6 +1078,8 @@ func (db sqlitePersistence) AllMessagesFromChatsAndCommunitiesWhichMatchTerm(com parameters = append(parameters, chatIds...) parameters = append(parameters, communityIds...) parameters = append(parameters, searchTerm) + parameters = append(parameters, searchTerm) + parameters = append(parameters, searchTerm) idsArgs := make([]interface{}, 0, len(parameters)) for _, param := range parameters { diff --git a/protocol/messenger_bridge_message_test.go b/protocol/messenger_bridge_message_test.go index 3e561e93d..81e2b89dd 100644 --- a/protocol/messenger_bridge_message_test.go +++ b/protocol/messenger_bridge_message_test.go @@ -100,3 +100,77 @@ func (s *BridgeMessageSuite) TestSendBridgeMessage() { s.Require().Equal(receivedBridgeMessagePayload.MessageID, "456") s.Require().Equal(receivedBridgeMessagePayload.ParentMessageID, "789") } + +func (s *BridgeMessageSuite) TestSearchForDiscordMessages() { + //send bridged message + chat := CreatePublicChat("test-chat", s.m.transport) + err := s.m.SaveChat(chat) + s.NoError(err) + + bridgeMessage := buildTestMessage(*chat) + bridgeMessage.ContentType = protobuf.ChatMessage_BRIDGE_MESSAGE + bridgeMessage.Payload = &protobuf.ChatMessage_BridgeMessage{ + BridgeMessage: &protobuf.BridgeMessage{ + BridgeName: "discord", + UserName: "user1", + UserAvatar: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADIAAAAyCAIAAACRXR/mAAAAjklEQVR4nOzXwQmFMBAAUZXUYh32ZB32ZB02sxYQQSZGsod55/91WFgSS0RM+SyjA56ZRZhFmEWYRRT6h+M6G16zrxv6fdJpmUWYRbxsYr13dKfanpN0WmYRZhGzXz6AWYRZRIfbaX26fT9Jk07LLMIsosPt9I/dTDotswizCG+nhFmEWYRZhFnEHQAA///z1CFkYamgfQAAAABJRU5ErkJggg==", + UserID: "123", + Content: "bridged discord message", + MessageID: "456", + ParentMessageID: "789", + }, + } + + _, err = s.m.SendChatMessage(context.Background(), bridgeMessage) + s.NoError(err) + + // Search for the message + messages, err := s.m.AllMessageByChatIDWhichMatchTerm(chat.ID, "bridged", true) + s.NoError(err) + s.Require().Len(messages, 1) + + //send discord import message + discordMessage := &protobuf.DiscordMessage{ + Id: "discordMessageID", + Type: "Default", + Timestamp: "123456", + Content: "discord import message", + Author: &protobuf.DiscordMessageAuthor{ + Id: "2", + }, + Reference: &protobuf.DiscordMessageReference{}, + } + + err = s.m.persistence.SaveDiscordMessage(discordMessage) + s.NoError(err) + bridgeMessage = buildTestMessage(*chat) + bridgeMessage.ContentType = protobuf.ChatMessage_DISCORD_MESSAGE + bridgeMessage.Payload = &protobuf.ChatMessage_DiscordMessage{ + DiscordMessage: discordMessage, + } + + _, err = s.m.SendChatMessage(context.Background(), bridgeMessage) + s.NoError(err) + + // Search for the message + messages, err = s.m.AllMessageByChatIDWhichMatchTerm(chat.ID, "import", true) + s.NoError(err) + s.Require().Len(messages, 1) + + // Search for discord messages + messages, err = s.m.AllMessageByChatIDWhichMatchTerm(chat.ID, "discord", true) + s.NoError(err) + s.Require().Len(messages, 2) + + // Search for discord messages using AllMessagesFromChatsAndCommunitiesWhichMatchTerm + chatIDs := make([]string, 1) + chatIDs = append(chatIDs, chat.ID) + messages, err = s.m.AllMessagesFromChatsAndCommunitiesWhichMatchTerm(make([]string, 0), chatIDs, "discord", true) + s.NoError(err) + s.Require().Len(messages, 2) + + // Same with case insensitive + messages, err = s.m.AllMessagesFromChatsAndCommunitiesWhichMatchTerm(make([]string, 0), chatIDs, "discord", false) + s.NoError(err) + s.Require().Len(messages, 2) +}