feat: handling replies for bridged messages
BridgeMessage struct has MessageID and ParentMessageID. MessageID keeps the original Discord message ID. ParentMessageID keeps the original Discord parent message ID (response_to). When the new bridge message is received, corresponding status message response_to field is updated. Issue #13258
This commit is contained in:
parent
a3ad05db58
commit
9db149d4f6
|
@ -49,6 +49,7 @@ type QuotedMessage struct {
|
|||
DeletedForMe bool `json:"deletedForMe,omitempty"`
|
||||
|
||||
DiscordMessage *protobuf.DiscordMessage `json:"discordMessage,omitempty"`
|
||||
BridgeMessage *protobuf.BridgeMessage `json:"bridgeMessage,omitempty"`
|
||||
}
|
||||
|
||||
type CommandState int
|
||||
|
|
|
@ -34,6 +34,8 @@ LEFT JOIN discord_message_authors m2_dm_author
|
|||
ON m2_dm.author_id = m2_dm_author.id
|
||||
LEFT JOIN bridge_messages bm
|
||||
ON m1.id = bm.user_messages_id
|
||||
LEFT JOIN bridge_messages bm_response
|
||||
ON m2.id = bm_response.user_messages_id
|
||||
`
|
||||
|
||||
var basicInsertDiscordMessageAuthorQuery = `INSERT OR REPLACE INTO discord_message_authors(id,name,discriminator,nickname,avatar_url, avatar_image_payload) VALUES (?,?,?,?,?,?)`
|
||||
|
@ -205,7 +207,12 @@ func (db sqlitePersistence) tableUserMessagesAllFieldsJoin() string {
|
|||
COALESCE(bm.user_id, ""),
|
||||
COALESCE(bm.content, ""),
|
||||
COALESCE(bm.message_id, ""),
|
||||
COALESCE(bm.parent_message_id, "")`
|
||||
COALESCE(bm.parent_message_id, ""),
|
||||
COALESCE(bm_response.bridge_name, ""),
|
||||
COALESCE(bm_response.user_name, ""),
|
||||
COALESCE(bm_response.user_avatar, ""),
|
||||
COALESCE(bm_response.user_id, ""),
|
||||
COALESCE(bm_response.content, "")`
|
||||
}
|
||||
|
||||
func (db sqlitePersistence) tableUserMessagesAllFieldsCount() int {
|
||||
|
@ -256,6 +263,8 @@ func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message
|
|||
}
|
||||
bridgeMessage := &protobuf.BridgeMessage{}
|
||||
|
||||
quotedBridgeMessage := &protobuf.BridgeMessage{}
|
||||
|
||||
quotedDiscordMessage := &protobuf.DiscordMessage{
|
||||
Author: &protobuf.DiscordMessageAuthor{},
|
||||
}
|
||||
|
@ -354,6 +363,11 @@ func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message
|
|||
&bridgeMessage.Content,
|
||||
&bridgeMessage.MessageID,
|
||||
&bridgeMessage.ParentMessageID,
|
||||
"edBridgeMessage.BridgeName,
|
||||
"edBridgeMessage.UserName,
|
||||
"edBridgeMessage.UserAvatar,
|
||||
"edBridgeMessage.UserID,
|
||||
"edBridgeMessage.Content,
|
||||
}
|
||||
err := row.Scan(append(args, others...)...)
|
||||
if err != nil {
|
||||
|
@ -407,6 +421,9 @@ func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message
|
|||
if message.QuotedMessage.ContentType == int64(protobuf.ChatMessage_DISCORD_MESSAGE) {
|
||||
message.QuotedMessage.DiscordMessage = quotedDiscordMessage
|
||||
}
|
||||
if message.QuotedMessage.ContentType == int64(protobuf.ChatMessage_BRIDGE_MESSAGE) {
|
||||
message.QuotedMessage.BridgeMessage = quotedBridgeMessage
|
||||
}
|
||||
}
|
||||
}
|
||||
message.Alias = alias.String
|
||||
|
@ -1194,6 +1211,9 @@ func (db sqlitePersistence) PinnedMessageByChatIDs(chatIDs []string, currCursor
|
|||
LEFT JOIN bridge_messages bm
|
||||
ON m1.id = bm.user_messages_id
|
||||
|
||||
LEFT JOIN bridge_messages bm_response
|
||||
ON m2.id = bm_response.user_messages_id
|
||||
|
||||
WHERE
|
||||
pm.pinned = 1
|
||||
AND NOT(m1.hide) AND m1.local_chat_id IN %s %s
|
||||
|
@ -1532,6 +1552,22 @@ func (db sqlitePersistence) SaveMessages(messages []*common.Message) (err error)
|
|||
|
||||
if msg.ContentType == protobuf.ChatMessage_BRIDGE_MESSAGE {
|
||||
err = db.saveBridgeMessage(tx, msg.GetBridgeMessage(), msg.ID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// handle replies
|
||||
err = db.findAndUpdateReplies(tx, msg.GetBridgeMessage().MessageID, msg.ID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
parentMessageID := msg.GetBridgeMessage().ParentMessageID
|
||||
if parentMessageID != "" {
|
||||
err = db.findAndUpdateRepliedTo(tx, parentMessageID, msg.ID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@ -2910,3 +2946,82 @@ func (db sqlitePersistence) GetCommunityMemberAllMessagesID(member string, commu
|
|||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Finds status messages id which are replies for bridgeMessageID
|
||||
func (db sqlitePersistence) findStatusMessageIdsReplies(tx *sql.Tx, bridgeMessageID string) ([]string, error) {
|
||||
rows, err := tx.Query(`SELECT user_messages_id FROM bridge_messages WHERE parent_message_id = ?`, bridgeMessageID)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var statusMessageIDs []string
|
||||
for rows.Next() {
|
||||
var statusMessageID string
|
||||
err = rows.Scan(&statusMessageID)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
statusMessageIDs = append(statusMessageIDs, statusMessageID)
|
||||
}
|
||||
return statusMessageIDs, nil
|
||||
}
|
||||
|
||||
// Finds status messages id which are replies for bridgeMessageID
|
||||
func (db sqlitePersistence) findStatusMessageIdsRepliedTo(tx *sql.Tx, parentMessageID string) (string, error) {
|
||||
rows, err := tx.Query(`SELECT user_messages_id FROM bridge_messages WHERE message_id = ?`, parentMessageID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
var statusMessageID string
|
||||
err = rows.Scan(&statusMessageID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return statusMessageID, nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (db sqlitePersistence) updateStatusMessagesWithResponse(tx *sql.Tx, statusMessagesToUpdate []string, responseValue string) error {
|
||||
sql := "UPDATE user_messages SET response_to = ? WHERE id IN (?" + strings.Repeat(",?", len(statusMessagesToUpdate)-1) + ")"
|
||||
stmt, err := tx.Prepare(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
args := make([]interface{}, 0, len(statusMessagesToUpdate)+1)
|
||||
args = append(args, responseValue)
|
||||
for _, msgToUpdate := range statusMessagesToUpdate {
|
||||
args = append(args, msgToUpdate)
|
||||
}
|
||||
_, err = stmt.Exec(args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Finds if there are any messages that are replies to that message (in case replies were received earlier)
|
||||
func (db sqlitePersistence) findAndUpdateReplies(tx *sql.Tx, bridgeMessageID string, statusMessageID string) error {
|
||||
replyMessageIds, err := db.findStatusMessageIdsReplies(tx, bridgeMessageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(replyMessageIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
return db.updateStatusMessagesWithResponse(tx, replyMessageIds, statusMessageID)
|
||||
}
|
||||
|
||||
func (db sqlitePersistence) findAndUpdateRepliedTo(tx *sql.Tx, discordParentMessageID string, statusMessageID string) error {
|
||||
repliedMessageID, err := db.findStatusMessageIdsRepliedTo(tx, discordParentMessageID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if repliedMessageID == "" {
|
||||
return nil
|
||||
}
|
||||
return db.updateStatusMessagesWithResponse(tx, []string{statusMessageID}, repliedMessageID)
|
||||
}
|
||||
|
|
|
@ -1911,6 +1911,85 @@ func TestSaveBridgeMessage(t *testing.T) {
|
|||
require.Equal(t, "789", retrievedMessages[0].GetBridgeMessage().ParentMessageID)
|
||||
}
|
||||
|
||||
func insertMinimalBridgeMessage(p *sqlitePersistence, messageID string, bridgeMessageID string, bridgeMessageParentID string) error {
|
||||
|
||||
bridgeMessage := &protobuf.BridgeMessage{
|
||||
BridgeName: "discord",
|
||||
UserName: "joe",
|
||||
Content: "abc",
|
||||
UserAvatar: "...",
|
||||
UserID: "123",
|
||||
MessageID: bridgeMessageID,
|
||||
ParentMessageID: bridgeMessageParentID,
|
||||
}
|
||||
|
||||
return p.SaveMessages([]*common.Message{{
|
||||
ID: messageID,
|
||||
LocalChatID: testPublicChatID,
|
||||
From: testPK,
|
||||
ChatMessage: &protobuf.ChatMessage{
|
||||
Text: "some-text",
|
||||
ContentType: protobuf.ChatMessage_BRIDGE_MESSAGE,
|
||||
ChatId: testPublicChatID,
|
||||
Payload: &protobuf.ChatMessage_BridgeMessage{
|
||||
BridgeMessage: bridgeMessage,
|
||||
},
|
||||
},
|
||||
}})
|
||||
}
|
||||
|
||||
func messageResponseTo(p *sqlitePersistence, messageID string) (string, error) {
|
||||
var responseTo string
|
||||
err := p.db.QueryRow("SELECT response_to FROM user_messages WHERE id = ?", messageID).Scan(&responseTo)
|
||||
return responseTo, err
|
||||
}
|
||||
|
||||
func TestBridgeMessageReplies(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
require.NoError(t, err)
|
||||
p := newSQLitePersistence(db)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
err = insertMinimalBridgeMessage(p, "111", "1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = insertMinimalBridgeMessage(p, "222", "2", "1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// "333 is not delivered yet"
|
||||
|
||||
// this is a reply to a message which was not delivered yet
|
||||
err = insertMinimalBridgeMessage(p, "444", "4", "3")
|
||||
require.NoError(t, err)
|
||||
|
||||
// status message "222" should have reply_to = "111"
|
||||
responseTo, err := messageResponseTo(p, "222")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "111", responseTo)
|
||||
|
||||
responseTo, err = messageResponseTo(p, "111")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", responseTo)
|
||||
|
||||
responseTo, err = messageResponseTo(p, "444")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", responseTo)
|
||||
|
||||
// receiving message for which "444" is replied to
|
||||
err = insertMinimalBridgeMessage(p, "333", "3", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
responseTo, err = messageResponseTo(p, "333")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", responseTo)
|
||||
|
||||
// now 444 is replied to 333
|
||||
responseTo, err = messageResponseTo(p, "444")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "333", responseTo)
|
||||
}
|
||||
|
||||
func TestGetCommunityMemberAllNonDeletedMessages(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
require.NoError(t, err)
|
||||
|
|
Loading…
Reference in New Issue