diff --git a/protocol/common/message.go b/protocol/common/message.go index bb73d4bcb..365264260 100644 --- a/protocol/common/message.go +++ b/protocol/common/message.go @@ -49,6 +49,7 @@ type QuotedMessage struct { DeletedForMe bool `json:"deletedForMe,omitempty"` DiscordMessage *protobuf.DiscordMessage `json:"discordMessage,omitempty"` + BridgeMessage *protobuf.BridgeMessage `json:"bridgeMessage,omitempty"` } type CommandState int diff --git a/protocol/message_persistence.go b/protocol/message_persistence.go index 1d4c8c8d2..8c3d5d466 100644 --- a/protocol/message_persistence.go +++ b/protocol/message_persistence.go @@ -34,6 +34,8 @@ LEFT JOIN discord_message_authors m2_dm_author ON m2_dm.author_id = m2_dm_author.id LEFT JOIN bridge_messages bm ON m1.id = bm.user_messages_id +LEFT JOIN bridge_messages bm_response +ON m2.id = bm_response.user_messages_id ` var basicInsertDiscordMessageAuthorQuery = `INSERT OR REPLACE INTO discord_message_authors(id,name,discriminator,nickname,avatar_url, avatar_image_payload) VALUES (?,?,?,?,?,?)` @@ -205,7 +207,12 @@ func (db sqlitePersistence) tableUserMessagesAllFieldsJoin() string { COALESCE(bm.user_id, ""), COALESCE(bm.content, ""), COALESCE(bm.message_id, ""), - COALESCE(bm.parent_message_id, "")` + COALESCE(bm.parent_message_id, ""), + COALESCE(bm_response.bridge_name, ""), + COALESCE(bm_response.user_name, ""), + COALESCE(bm_response.user_avatar, ""), + COALESCE(bm_response.user_id, ""), + COALESCE(bm_response.content, "")` } func (db sqlitePersistence) tableUserMessagesAllFieldsCount() int { @@ -256,6 +263,8 @@ func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message } bridgeMessage := &protobuf.BridgeMessage{} + quotedBridgeMessage := &protobuf.BridgeMessage{} + quotedDiscordMessage := &protobuf.DiscordMessage{ Author: &protobuf.DiscordMessageAuthor{}, } @@ -354,6 +363,11 @@ func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message &bridgeMessage.Content, &bridgeMessage.MessageID, &bridgeMessage.ParentMessageID, + "edBridgeMessage.BridgeName, + "edBridgeMessage.UserName, + "edBridgeMessage.UserAvatar, + "edBridgeMessage.UserID, + "edBridgeMessage.Content, } err := row.Scan(append(args, others...)...) if err != nil { @@ -407,6 +421,9 @@ func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message if message.QuotedMessage.ContentType == int64(protobuf.ChatMessage_DISCORD_MESSAGE) { message.QuotedMessage.DiscordMessage = quotedDiscordMessage } + if message.QuotedMessage.ContentType == int64(protobuf.ChatMessage_BRIDGE_MESSAGE) { + message.QuotedMessage.BridgeMessage = quotedBridgeMessage + } } } message.Alias = alias.String @@ -1194,6 +1211,9 @@ func (db sqlitePersistence) PinnedMessageByChatIDs(chatIDs []string, currCursor LEFT JOIN bridge_messages bm ON m1.id = bm.user_messages_id + LEFT JOIN bridge_messages bm_response + ON m2.id = bm_response.user_messages_id + WHERE pm.pinned = 1 AND NOT(m1.hide) AND m1.local_chat_id IN %s %s @@ -1532,6 +1552,22 @@ func (db sqlitePersistence) SaveMessages(messages []*common.Message) (err error) if msg.ContentType == protobuf.ChatMessage_BRIDGE_MESSAGE { err = db.saveBridgeMessage(tx, msg.GetBridgeMessage(), msg.ID) + if err != nil { + return + } + // handle replies + err = db.findAndUpdateReplies(tx, msg.GetBridgeMessage().MessageID, msg.ID) + if err != nil { + return + } + parentMessageID := msg.GetBridgeMessage().ParentMessageID + if parentMessageID != "" { + err = db.findAndUpdateRepliedTo(tx, parentMessageID, msg.ID) + if err != nil { + return + } + } + } } return @@ -2910,3 +2946,82 @@ func (db sqlitePersistence) GetCommunityMemberAllMessagesID(member string, commu return result, nil } + +// Finds status messages id which are replies for bridgeMessageID +func (db sqlitePersistence) findStatusMessageIdsReplies(tx *sql.Tx, bridgeMessageID string) ([]string, error) { + rows, err := tx.Query(`SELECT user_messages_id FROM bridge_messages WHERE parent_message_id = ?`, bridgeMessageID) + if err != nil { + return []string{}, err + } + defer rows.Close() + + var statusMessageIDs []string + for rows.Next() { + var statusMessageID string + err = rows.Scan(&statusMessageID) + if err != nil { + return []string{}, err + } + statusMessageIDs = append(statusMessageIDs, statusMessageID) + } + return statusMessageIDs, nil +} + +// Finds status messages id which are replies for bridgeMessageID +func (db sqlitePersistence) findStatusMessageIdsRepliedTo(tx *sql.Tx, parentMessageID string) (string, error) { + rows, err := tx.Query(`SELECT user_messages_id FROM bridge_messages WHERE message_id = ?`, parentMessageID) + if err != nil { + return "", err + } + defer rows.Close() + + if rows.Next() { + var statusMessageID string + err = rows.Scan(&statusMessageID) + if err != nil { + return "", err + } + return statusMessageID, nil + } + return "", nil +} + +func (db sqlitePersistence) updateStatusMessagesWithResponse(tx *sql.Tx, statusMessagesToUpdate []string, responseValue string) error { + sql := "UPDATE user_messages SET response_to = ? WHERE id IN (?" + strings.Repeat(",?", len(statusMessagesToUpdate)-1) + ")" + stmt, err := tx.Prepare(sql) + if err != nil { + return err + } + defer stmt.Close() + + args := make([]interface{}, 0, len(statusMessagesToUpdate)+1) + args = append(args, responseValue) + for _, msgToUpdate := range statusMessagesToUpdate { + args = append(args, msgToUpdate) + } + _, err = stmt.Exec(args...) + return err +} + +// Finds if there are any messages that are replies to that message (in case replies were received earlier) +func (db sqlitePersistence) findAndUpdateReplies(tx *sql.Tx, bridgeMessageID string, statusMessageID string) error { + replyMessageIds, err := db.findStatusMessageIdsReplies(tx, bridgeMessageID) + if err != nil { + return err + } + if len(replyMessageIds) == 0 { + return nil + } + return db.updateStatusMessagesWithResponse(tx, replyMessageIds, statusMessageID) +} + +func (db sqlitePersistence) findAndUpdateRepliedTo(tx *sql.Tx, discordParentMessageID string, statusMessageID string) error { + repliedMessageID, err := db.findStatusMessageIdsRepliedTo(tx, discordParentMessageID) + if err != nil { + return err + } + if repliedMessageID == "" { + return nil + } + return db.updateStatusMessagesWithResponse(tx, []string{statusMessageID}, repliedMessageID) +} diff --git a/protocol/persistence_test.go b/protocol/persistence_test.go index cf055b1ba..b3bd0b40a 100644 --- a/protocol/persistence_test.go +++ b/protocol/persistence_test.go @@ -1911,6 +1911,85 @@ func TestSaveBridgeMessage(t *testing.T) { require.Equal(t, "789", retrievedMessages[0].GetBridgeMessage().ParentMessageID) } +func insertMinimalBridgeMessage(p *sqlitePersistence, messageID string, bridgeMessageID string, bridgeMessageParentID string) error { + + bridgeMessage := &protobuf.BridgeMessage{ + BridgeName: "discord", + UserName: "joe", + Content: "abc", + UserAvatar: "data:image/png;base64,iVBO...", + UserID: "123", + MessageID: bridgeMessageID, + ParentMessageID: bridgeMessageParentID, + } + + return p.SaveMessages([]*common.Message{{ + ID: messageID, + LocalChatID: testPublicChatID, + From: testPK, + ChatMessage: &protobuf.ChatMessage{ + Text: "some-text", + ContentType: protobuf.ChatMessage_BRIDGE_MESSAGE, + ChatId: testPublicChatID, + Payload: &protobuf.ChatMessage_BridgeMessage{ + BridgeMessage: bridgeMessage, + }, + }, + }}) +} + +func messageResponseTo(p *sqlitePersistence, messageID string) (string, error) { + var responseTo string + err := p.db.QueryRow("SELECT response_to FROM user_messages WHERE id = ?", messageID).Scan(&responseTo) + return responseTo, err +} + +func TestBridgeMessageReplies(t *testing.T) { + db, err := openTestDB() + require.NoError(t, err) + p := newSQLitePersistence(db) + + require.NoError(t, err) + + err = insertMinimalBridgeMessage(p, "111", "1", "") + require.NoError(t, err) + + err = insertMinimalBridgeMessage(p, "222", "2", "1") + require.NoError(t, err) + + // "333 is not delivered yet" + + // this is a reply to a message which was not delivered yet + err = insertMinimalBridgeMessage(p, "444", "4", "3") + require.NoError(t, err) + + // status message "222" should have reply_to = "111" + responseTo, err := messageResponseTo(p, "222") + require.NoError(t, err) + require.Equal(t, "111", responseTo) + + responseTo, err = messageResponseTo(p, "111") + require.NoError(t, err) + require.Equal(t, "", responseTo) + + responseTo, err = messageResponseTo(p, "444") + require.NoError(t, err) + require.Equal(t, "", responseTo) + + // receiving message for which "444" is replied to + err = insertMinimalBridgeMessage(p, "333", "3", "") + require.NoError(t, err) + + responseTo, err = messageResponseTo(p, "333") + require.NoError(t, err) + require.Equal(t, "", responseTo) + + // now 444 is replied to 333 + responseTo, err = messageResponseTo(p, "444") + require.NoError(t, err) + require.Equal(t, "333", responseTo) +} + func TestGetCommunityMemberAllNonDeletedMessages(t *testing.T) { db, err := openTestDB() require.NoError(t, err)