fix: process empty albumId in albumMessages (#4874) (#4898)

* fix: process empty albumId in albumMessages
* fix: right `prepareMessage` for empty album
This commit is contained in:
Igor Sirotin 2024-03-08 13:48:22 +00:00 committed by GitHub
parent 6c792a0e73
commit c3e7d3823f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 138 additions and 104 deletions

View File

@ -648,6 +648,9 @@ func (db sqlitePersistence) messageByID(tx *sql.Tx, id string) (*common.Message,
}
func (db sqlitePersistence) albumMessages(chatID, albumID string) ([]*common.Message, error) {
if albumID == "" {
return nil, nil
}
query := db.buildMessagesQuery("WHERE m1.album_id = ? and m1.local_chat_id = ?")
rows, err := db.db.Query(query, albumID, chatID)
if err != nil {

View File

@ -4207,9 +4207,8 @@ func (m *Messenger) MessageByChatID(chatID, cursor string, limit int) ([]*common
}
if m.httpServer != nil {
for idx := range msgs {
err = m.prepareMessage(msgs[idx], m.httpServer)
for _, msg := range msgs {
err = m.prepareMessage(msg, m.httpServer)
if err != nil {
return nil, "", err
}
@ -4220,15 +4219,15 @@ func (m *Messenger) MessageByChatID(chatID, cursor string, limit int) ([]*common
}
func (m *Messenger) prepareMessages(messages map[string]*common.Message) error {
if m.httpServer != nil {
if m.httpServer == nil {
return nil
}
for idx := range messages {
err := m.prepareMessage(messages[idx], m.httpServer)
if err != nil {
return err
}
}
}
return nil
}
@ -4265,19 +4264,23 @@ func (m *Messenger) prepareMessage(msg *common.Message, s *server.MediaServer) e
}
if quotedMessage.ChatMessage != nil {
image := quotedMessage.ChatMessage.GetImage()
albumID := quotedMessage.ChatMessage.GetImage().AlbumId
if image != nil && image.GetAlbumId() != "" {
albumMessages, err := m.persistence.albumMessages(quotedMessage.LocalChatID, albumID)
if err != nil {
return err
}
var quotedImages = extractQuotedImages(albumMessages, s)
if quotedImagesJSON, err := json.Marshal(quotedImages); err == nil {
msg.QuotedMessage.AlbumImages = quotedImagesJSON
} else {
quotedImages := extractQuotedImages(albumMessages, s)
quotedImagesJSON, err := json.Marshal(quotedImages)
if err != nil {
return err
}
msg.QuotedMessage.AlbumImages = quotedImagesJSON
}
}
}
if msg.QuotedMessage != nil && msg.QuotedMessage.ContentType == int64(protobuf.ChatMessage_AUDIO) {

View File

@ -3,119 +3,169 @@ package protocol
import (
"encoding/json"
"fmt"
"testing"
"github.com/stretchr/testify/suite"
"github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/server"
)
func (s *MessengerSuite) setUpTestDatabase() (string, *sqlitePersistence) {
func TestMessengerPrepareMessage(t *testing.T) {
suite.Run(t, new(TestMessengerPrepareMessageSuite))
}
type TestMessengerPrepareMessageSuite struct {
MessengerBaseTestSuite
chatID string
p *sqlitePersistence
}
func (s *TestMessengerPrepareMessageSuite) SetupTest() {
s.MessengerBaseTestSuite.SetupTest()
s.chatID, s.p = s.setUpTestDatabase()
}
func (s *TestMessengerPrepareMessageSuite) setUpTestDatabase() (string, *sqlitePersistence) {
chat := CreatePublicChat("test-chat", s.m.transport)
err := s.m.SaveChat(chat)
s.NoError(err)
db, err := openTestDB()
s.NoError(err)
p := newSQLitePersistence(db)
p := newSQLitePersistence(db)
return chat.ID, p
}
func (s *MessengerSuite) Test_WHEN_MessageContainsImage_Then_preparedMessageAddsAlbumImageWithImageGeneratedLink() {
chatID, p := s.setUpTestDatabase()
message1 := &common.Message{
ID: "id-1",
LocalChatID: chatID,
func (s *TestMessengerPrepareMessageSuite) generateTextMessage(ID string, From string, Clock uint64, responseTo string) *common.Message {
return &common.Message{
ID: ID,
From: From,
LocalChatID: s.chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-1",
Clock: uint64(1),
ContentType: protobuf.ChatMessage_IMAGE,
Payload: &protobuf.ChatMessage_Image{
Text: RandomLettersString(5),
Clock: Clock,
ResponseTo: responseTo,
},
}
}
func (s *TestMessengerPrepareMessageSuite) testMessageContainsImage(testAlbum bool) {
message1 := s.generateTextMessage("id-1", "1", 1, "")
message1.ContentType = protobuf.ChatMessage_IMAGE
message1.Payload = &protobuf.ChatMessage_Image{
Image: &protobuf.ImageMessage{
Format: 1,
Payload: []byte("some-payload"),
Payload: RandomBytes(10),
},
},
},
From: "1",
}
message2 := &common.Message{
ID: "id-2",
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-2",
Clock: uint64(2),
ResponseTo: "id-1",
},
From: "2",
}
message2 := s.generateTextMessage("id-2", "2", 2, message1.ID)
messages := []*common.Message{message1, message2}
err := s.m.SaveMessages([]*common.Message{message1, message2})
var message3 *common.Message
if testAlbum {
albumID := RandomLettersString(5)
message1.GetImage().AlbumId = albumID
message3 = s.generateTextMessage("id-3", "1", 0, "")
message3.ContentType = protobuf.ChatMessage_IMAGE
message3.Payload = &protobuf.ChatMessage_Image{
Image: &protobuf.ImageMessage{
Format: 1,
Payload: RandomBytes(10),
AlbumId: albumID,
},
}
messages = append(messages, message3)
}
err := s.m.SaveMessages(messages)
s.Require().NoError(err)
err = p.SaveMessages(messages)
err = s.p.SaveMessages(messages)
s.Require().NoError(err)
mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil)
s.Require().NoError(err)
s.Require().NoError(mediaServer.Start())
retrievedMessages, _, err := p.MessageByChatID(chatID, "", 10)
retrievedMessages, _, err := s.p.MessageByChatID(s.chatID, "", 10)
s.Require().NoError(err)
s.Require().Equal("id-2", retrievedMessages[0].ID)
s.Require().Equal("id-1", retrievedMessages[0].ResponseTo)
if testAlbum {
s.Require().Len(retrievedMessages, 3)
} else {
s.Require().Len(retrievedMessages, 2)
}
s.Require().Equal(message2.ID, retrievedMessages[0].ID)
s.Require().Equal(message1.ID, retrievedMessages[0].ResponseTo)
err = s.m.prepareMessage(retrievedMessages[0], mediaServer)
s.Require().NoError(err)
expectedURL := fmt.Sprintf(`["https://Localhost:%d/messages/images?messageId=id-1"]`, mediaServer.GetPort())
mediaServerImageLink := func(messageID string) string {
return fmt.Sprintf(`https://Localhost:%d/messages/images?messageId=%s`,
mediaServer.GetPort(),
messageID)
}
s.Require().Equal(json.RawMessage(expectedURL), retrievedMessages[0].QuotedMessage.AlbumImages)
if testAlbum {
expectedJSON := fmt.Sprintf(`["%s","%s"]`,
mediaServerImageLink(message1.ID),
mediaServerImageLink(message3.ID),
)
s.Require().Equal(json.RawMessage(expectedJSON), retrievedMessages[0].QuotedMessage.AlbumImages)
} else {
expectedURL := mediaServerImageLink(message1.ID)
s.Require().Equal(expectedURL, retrievedMessages[0].QuotedMessage.ImageLocalURL)
}
}
func (s *MessengerSuite) Test_WHEN_NoQuotedMessage_THEN_RetrievedMessageDoesNotContainQuotedMessage() {
chatID, p := s.setUpTestDatabase()
message1 := &common.Message{
ID: "id-1",
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-1",
Clock: uint64(1),
func (s *TestMessengerPrepareMessageSuite) Test_WHEN_MessageContainsImage_THEN_preparedMessageAddsAlbumImageWithImageGeneratedLink() {
testCases := []struct {
name string
album bool
}{
{
name: "single image",
album: false,
},
{
name: "album",
album: true,
},
From: "1",
}
message2 := &common.Message{
ID: "id-2",
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-2",
Clock: uint64(2),
},
From: "2",
for _, tc := range testCases {
s.Run(tc.name, func() {
s.testMessageContainsImage(tc.album)
})
}
}
func (s *TestMessengerPrepareMessageSuite) Test_WHEN_NoQuotedMessage_THEN_RetrievedMessageDoesNotContainQuotedMessage() {
message1 := s.generateTextMessage("id-1", "1", 1, "")
message2 := s.generateTextMessage("id-2", "2", 2, "")
messages := []*common.Message{message1, message2}
err := s.m.SaveMessages([]*common.Message{message1, message2})
s.Require().NoError(err)
err = p.SaveMessages(messages)
err = s.p.SaveMessages(messages)
s.Require().NoError(err)
mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil)
s.Require().NoError(err)
s.Require().NoError(mediaServer.Start())
retrievedMessages, _, err := p.MessageByChatID(chatID, "", 10)
retrievedMessages, _, err := s.p.MessageByChatID(s.chatID, "", 10)
s.Require().NoError(err)
s.Require().Equal("id-2", retrievedMessages[0].ID)
s.Require().Equal("", retrievedMessages[0].ResponseTo)
s.Require().Equal(message2.ID, retrievedMessages[0].ID)
s.Require().Empty(retrievedMessages[0].ResponseTo)
err = s.m.prepareMessage(retrievedMessages[0], mediaServer)
s.Require().NoError(err)
@ -123,47 +173,25 @@ func (s *MessengerSuite) Test_WHEN_NoQuotedMessage_THEN_RetrievedMessageDoesNotC
s.Require().Equal((*common.QuotedMessage)(nil), retrievedMessages[0].QuotedMessage)
}
func (s *MessengerSuite) Test_WHEN_QuotedMessageDoesNotContainsImage_THEN_RetrievedMessageContainsNoImages() {
chatID, p := s.setUpTestDatabase()
message1 := &common.Message{
ID: "id-1",
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-1",
Clock: uint64(1),
},
From: "1",
}
message2 := &common.Message{
ID: "id-2",
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-2",
Clock: uint64(2),
ResponseTo: "id-1",
},
From: "2",
}
func (s *TestMessengerPrepareMessageSuite) Test_WHEN_QuotedMessageDoesNotContainsImage_THEN_RetrievedMessageContainsNoImages() {
message1 := s.generateTextMessage("id-1", "1", 1, "")
message2 := s.generateTextMessage("id-2", "2", 2, message1.ID)
messages := []*common.Message{message1, message2}
err := s.m.SaveMessages([]*common.Message{message1, message2})
s.Require().NoError(err)
err = p.SaveMessages(messages)
err = s.p.SaveMessages(messages)
s.Require().NoError(err)
mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil)
s.Require().NoError(err)
s.Require().NoError(mediaServer.Start())
retrievedMessages, _, err := p.MessageByChatID(chatID, "", 10)
retrievedMessages, _, err := s.p.MessageByChatID(s.chatID, "", 10)
s.Require().NoError(err)
s.Require().Equal("id-2", retrievedMessages[0].ID)
s.Require().Equal("id-1", retrievedMessages[0].ResponseTo)
s.Require().Equal(message2.ID, retrievedMessages[0].ID)
s.Require().Equal(message1.ID, retrievedMessages[0].ResponseTo)
err = s.m.prepareMessage(retrievedMessages[0], mediaServer)
s.Require().NoError(err)