fix: process empty albumId in albumMessages (#4874) (#4898)

* fix: process empty albumId in albumMessages
* fix: right `prepareMessage` for empty album
This commit is contained in:
Igor Sirotin 2024-03-08 13:48:22 +00:00 committed by GitHub
parent 6c792a0e73
commit c3e7d3823f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 138 additions and 104 deletions

View File

@ -648,6 +648,9 @@ func (db sqlitePersistence) messageByID(tx *sql.Tx, id string) (*common.Message,
} }
func (db sqlitePersistence) albumMessages(chatID, albumID string) ([]*common.Message, error) { func (db sqlitePersistence) albumMessages(chatID, albumID string) ([]*common.Message, error) {
if albumID == "" {
return nil, nil
}
query := db.buildMessagesQuery("WHERE m1.album_id = ? and m1.local_chat_id = ?") query := db.buildMessagesQuery("WHERE m1.album_id = ? and m1.local_chat_id = ?")
rows, err := db.db.Query(query, albumID, chatID) rows, err := db.db.Query(query, albumID, chatID)
if err != nil { if err != nil {

View File

@ -4207,9 +4207,8 @@ func (m *Messenger) MessageByChatID(chatID, cursor string, limit int) ([]*common
} }
if m.httpServer != nil { if m.httpServer != nil {
for idx := range msgs { for _, msg := range msgs {
err = m.prepareMessage(msgs[idx], m.httpServer) err = m.prepareMessage(msg, m.httpServer)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -4220,13 +4219,13 @@ func (m *Messenger) MessageByChatID(chatID, cursor string, limit int) ([]*common
} }
func (m *Messenger) prepareMessages(messages map[string]*common.Message) error { func (m *Messenger) prepareMessages(messages map[string]*common.Message) error {
if m.httpServer != nil { if m.httpServer == nil {
for idx := range messages { return nil
err := m.prepareMessage(messages[idx], m.httpServer) }
for idx := range messages {
if err != nil { err := m.prepareMessage(messages[idx], m.httpServer)
return err if err != nil {
} return err
} }
} }
return nil return nil
@ -4265,18 +4264,22 @@ func (m *Messenger) prepareMessage(msg *common.Message, s *server.MediaServer) e
} }
if quotedMessage.ChatMessage != nil { if quotedMessage.ChatMessage != nil {
image := quotedMessage.ChatMessage.GetImage()
albumID := quotedMessage.ChatMessage.GetImage().AlbumId albumID := quotedMessage.ChatMessage.GetImage().AlbumId
albumMessages, err := m.persistence.albumMessages(quotedMessage.LocalChatID, albumID)
if err != nil {
return err
}
var quotedImages = extractQuotedImages(albumMessages, s) if image != nil && image.GetAlbumId() != "" {
albumMessages, err := m.persistence.albumMessages(quotedMessage.LocalChatID, albumID)
if err != nil {
return err
}
quotedImages := extractQuotedImages(albumMessages, s)
quotedImagesJSON, err := json.Marshal(quotedImages)
if err != nil {
return err
}
if quotedImagesJSON, err := json.Marshal(quotedImages); err == nil {
msg.QuotedMessage.AlbumImages = quotedImagesJSON msg.QuotedMessage.AlbumImages = quotedImagesJSON
} else {
return err
} }
} }
} }

View File

@ -3,119 +3,169 @@ package protocol
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"testing"
"github.com/stretchr/testify/suite"
"github.com/status-im/status-go/protocol/common" "github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/server" "github.com/status-im/status-go/server"
) )
func (s *MessengerSuite) setUpTestDatabase() (string, *sqlitePersistence) { func TestMessengerPrepareMessage(t *testing.T) {
suite.Run(t, new(TestMessengerPrepareMessageSuite))
}
type TestMessengerPrepareMessageSuite struct {
MessengerBaseTestSuite
chatID string
p *sqlitePersistence
}
func (s *TestMessengerPrepareMessageSuite) SetupTest() {
s.MessengerBaseTestSuite.SetupTest()
s.chatID, s.p = s.setUpTestDatabase()
}
func (s *TestMessengerPrepareMessageSuite) setUpTestDatabase() (string, *sqlitePersistence) {
chat := CreatePublicChat("test-chat", s.m.transport) chat := CreatePublicChat("test-chat", s.m.transport)
err := s.m.SaveChat(chat) err := s.m.SaveChat(chat)
s.NoError(err) s.NoError(err)
db, err := openTestDB() db, err := openTestDB()
s.NoError(err) s.NoError(err)
p := newSQLitePersistence(db)
p := newSQLitePersistence(db)
return chat.ID, p return chat.ID, p
} }
func (s *MessengerSuite) Test_WHEN_MessageContainsImage_Then_preparedMessageAddsAlbumImageWithImageGeneratedLink() { func (s *TestMessengerPrepareMessageSuite) generateTextMessage(ID string, From string, Clock uint64, responseTo string) *common.Message {
chatID, p := s.setUpTestDatabase() return &common.Message{
ID: ID,
message1 := &common.Message{ From: From,
ID: "id-1", LocalChatID: s.chatID,
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{ ChatMessage: &protobuf.ChatMessage{
Text: "content-1", Text: RandomLettersString(5),
Clock: uint64(1), Clock: Clock,
ContentType: protobuf.ChatMessage_IMAGE, ResponseTo: responseTo,
Payload: &protobuf.ChatMessage_Image{
Image: &protobuf.ImageMessage{
Format: 1,
Payload: []byte("some-payload"),
},
},
}, },
From: "1",
} }
message2 := &common.Message{ }
ID: "id-2",
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-2",
Clock: uint64(2),
ResponseTo: "id-1",
},
From: "2", func (s *TestMessengerPrepareMessageSuite) testMessageContainsImage(testAlbum bool) {
message1 := s.generateTextMessage("id-1", "1", 1, "")
message1.ContentType = protobuf.ChatMessage_IMAGE
message1.Payload = &protobuf.ChatMessage_Image{
Image: &protobuf.ImageMessage{
Format: 1,
Payload: RandomBytes(10),
},
} }
message2 := s.generateTextMessage("id-2", "2", 2, message1.ID)
messages := []*common.Message{message1, message2} messages := []*common.Message{message1, message2}
err := s.m.SaveMessages([]*common.Message{message1, message2}) var message3 *common.Message
if testAlbum {
albumID := RandomLettersString(5)
message1.GetImage().AlbumId = albumID
message3 = s.generateTextMessage("id-3", "1", 0, "")
message3.ContentType = protobuf.ChatMessage_IMAGE
message3.Payload = &protobuf.ChatMessage_Image{
Image: &protobuf.ImageMessage{
Format: 1,
Payload: RandomBytes(10),
AlbumId: albumID,
},
}
messages = append(messages, message3)
}
err := s.m.SaveMessages(messages)
s.Require().NoError(err) s.Require().NoError(err)
err = p.SaveMessages(messages) err = s.p.SaveMessages(messages)
s.Require().NoError(err) s.Require().NoError(err)
mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil) mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NoError(mediaServer.Start()) s.Require().NoError(mediaServer.Start())
retrievedMessages, _, err := p.MessageByChatID(chatID, "", 10) retrievedMessages, _, err := s.p.MessageByChatID(s.chatID, "", 10)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal("id-2", retrievedMessages[0].ID) if testAlbum {
s.Require().Equal("id-1", retrievedMessages[0].ResponseTo) s.Require().Len(retrievedMessages, 3)
} else {
s.Require().Len(retrievedMessages, 2)
}
s.Require().Equal(message2.ID, retrievedMessages[0].ID)
s.Require().Equal(message1.ID, retrievedMessages[0].ResponseTo)
err = s.m.prepareMessage(retrievedMessages[0], mediaServer) err = s.m.prepareMessage(retrievedMessages[0], mediaServer)
s.Require().NoError(err) s.Require().NoError(err)
expectedURL := fmt.Sprintf(`["https://Localhost:%d/messages/images?messageId=id-1"]`, mediaServer.GetPort()) mediaServerImageLink := func(messageID string) string {
return fmt.Sprintf(`https://Localhost:%d/messages/images?messageId=%s`,
mediaServer.GetPort(),
messageID)
}
s.Require().Equal(json.RawMessage(expectedURL), retrievedMessages[0].QuotedMessage.AlbumImages) if testAlbum {
expectedJSON := fmt.Sprintf(`["%s","%s"]`,
mediaServerImageLink(message1.ID),
mediaServerImageLink(message3.ID),
)
s.Require().Equal(json.RawMessage(expectedJSON), retrievedMessages[0].QuotedMessage.AlbumImages)
} else {
expectedURL := mediaServerImageLink(message1.ID)
s.Require().Equal(expectedURL, retrievedMessages[0].QuotedMessage.ImageLocalURL)
}
} }
func (s *MessengerSuite) Test_WHEN_NoQuotedMessage_THEN_RetrievedMessageDoesNotContainQuotedMessage() { func (s *TestMessengerPrepareMessageSuite) Test_WHEN_MessageContainsImage_THEN_preparedMessageAddsAlbumImageWithImageGeneratedLink() {
chatID, p := s.setUpTestDatabase() testCases := []struct {
name string
message1 := &common.Message{ album bool
ID: "id-1", }{
LocalChatID: chatID, {
ChatMessage: &protobuf.ChatMessage{ name: "single image",
Text: "content-1", album: false,
Clock: uint64(1), },
{
name: "album",
album: true,
}, },
From: "1",
} }
message2 := &common.Message{ for _, tc := range testCases {
ID: "id-2", s.Run(tc.name, func() {
LocalChatID: chatID, s.testMessageContainsImage(tc.album)
ChatMessage: &protobuf.ChatMessage{ })
Text: "content-2",
Clock: uint64(2),
},
From: "2",
} }
}
func (s *TestMessengerPrepareMessageSuite) Test_WHEN_NoQuotedMessage_THEN_RetrievedMessageDoesNotContainQuotedMessage() {
message1 := s.generateTextMessage("id-1", "1", 1, "")
message2 := s.generateTextMessage("id-2", "2", 2, "")
messages := []*common.Message{message1, message2} messages := []*common.Message{message1, message2}
err := s.m.SaveMessages([]*common.Message{message1, message2}) err := s.m.SaveMessages([]*common.Message{message1, message2})
s.Require().NoError(err) s.Require().NoError(err)
err = p.SaveMessages(messages) err = s.p.SaveMessages(messages)
s.Require().NoError(err) s.Require().NoError(err)
mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil) mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NoError(mediaServer.Start()) s.Require().NoError(mediaServer.Start())
retrievedMessages, _, err := p.MessageByChatID(chatID, "", 10) retrievedMessages, _, err := s.p.MessageByChatID(s.chatID, "", 10)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal("id-2", retrievedMessages[0].ID) s.Require().Equal(message2.ID, retrievedMessages[0].ID)
s.Require().Equal("", retrievedMessages[0].ResponseTo) s.Require().Empty(retrievedMessages[0].ResponseTo)
err = s.m.prepareMessage(retrievedMessages[0], mediaServer) err = s.m.prepareMessage(retrievedMessages[0], mediaServer)
s.Require().NoError(err) s.Require().NoError(err)
@ -123,47 +173,25 @@ func (s *MessengerSuite) Test_WHEN_NoQuotedMessage_THEN_RetrievedMessageDoesNotC
s.Require().Equal((*common.QuotedMessage)(nil), retrievedMessages[0].QuotedMessage) s.Require().Equal((*common.QuotedMessage)(nil), retrievedMessages[0].QuotedMessage)
} }
func (s *MessengerSuite) Test_WHEN_QuotedMessageDoesNotContainsImage_THEN_RetrievedMessageContainsNoImages() { func (s *TestMessengerPrepareMessageSuite) Test_WHEN_QuotedMessageDoesNotContainsImage_THEN_RetrievedMessageContainsNoImages() {
chatID, p := s.setUpTestDatabase() message1 := s.generateTextMessage("id-1", "1", 1, "")
message2 := s.generateTextMessage("id-2", "2", 2, message1.ID)
message1 := &common.Message{
ID: "id-1",
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-1",
Clock: uint64(1),
},
From: "1",
}
message2 := &common.Message{
ID: "id-2",
LocalChatID: chatID,
ChatMessage: &protobuf.ChatMessage{
Text: "content-2",
Clock: uint64(2),
ResponseTo: "id-1",
},
From: "2",
}
messages := []*common.Message{message1, message2} messages := []*common.Message{message1, message2}
err := s.m.SaveMessages([]*common.Message{message1, message2}) err := s.m.SaveMessages([]*common.Message{message1, message2})
s.Require().NoError(err) s.Require().NoError(err)
err = p.SaveMessages(messages) err = s.p.SaveMessages(messages)
s.Require().NoError(err) s.Require().NoError(err)
mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil) mediaServer, err := server.NewMediaServer(s.m.database, nil, nil, nil)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NoError(mediaServer.Start()) s.Require().NoError(mediaServer.Start())
retrievedMessages, _, err := p.MessageByChatID(chatID, "", 10) retrievedMessages, _, err := s.p.MessageByChatID(s.chatID, "", 10)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal("id-2", retrievedMessages[0].ID) s.Require().Equal(message2.ID, retrievedMessages[0].ID)
s.Require().Equal("id-1", retrievedMessages[0].ResponseTo) s.Require().Equal(message1.ID, retrievedMessages[0].ResponseTo)
err = s.m.prepareMessage(retrievedMessages[0], mediaServer) err = s.m.prepareMessage(retrievedMessages[0], mediaServer)
s.Require().NoError(err) s.Require().NoError(err)