status-go/protocol/message_persistence.go

957 lines
21 KiB
Go

package protocol
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/pkg/errors"
)
var (
// errRecordNotFound is returned by the single-row lookups in this file when
// the underlying query reports sql.ErrNoRows.
errRecordNotFound = errors.New("record not found")
)
// tableUserMessagesAllFields returns the comma-separated list of user_messages
// columns used to build the INSERT statement in SaveMessages. The order must
// match the values returned by tableUserMessagesAllValues, and the number of
// commas drives the placeholder count (see tableUserMessagesAllFieldsCount).
func (db sqlitePersistence) tableUserMessagesAllFields() string {
return `id,
whisper_timestamp,
source,
text,
content_type,
username,
timestamp,
chat_id,
local_chat_id,
message_type,
clock_value,
seen,
outgoing_status,
parsed_text,
sticker_pack,
sticker_hash,
image_payload,
image_type,
image_base64,
audio_payload,
audio_type,
audio_duration_ms,
audio_base64,
mentions,
command_id,
command_value,
command_from,
command_address,
command_contract,
command_transaction_hash,
command_state,
command_signature,
replace_message,
rtl,
line_count,
response_to`
}
// tableUserMessagesAllFieldsJoin returns the column list selected by the read
// queries in this file. Columns are qualified with the aliases those queries
// define: m1 is the message itself, m2 is the message joined on
// m1.response_to = m2.id, and c is the contacts row joined on m1.source = c.id.
// The order must match the destinations in tableUserMessagesScanAllFields.
func (db sqlitePersistence) tableUserMessagesAllFieldsJoin() string {
return `m1.id,
m1.whisper_timestamp,
m1.source,
m1.text,
m1.content_type,
m1.username,
m1.timestamp,
m1.chat_id,
m1.local_chat_id,
m1.message_type,
m1.clock_value,
m1.seen,
m1.outgoing_status,
m1.parsed_text,
m1.sticker_pack,
m1.sticker_hash,
m1.image_base64,
COALESCE(m1.audio_duration_ms,0),
m1.audio_base64,
m1.mentions,
m1.command_id,
m1.command_value,
m1.command_from,
m1.command_address,
m1.command_contract,
m1.command_transaction_hash,
m1.command_state,
m1.command_signature,
m1.replace_message,
m1.rtl,
m1.line_count,
m1.response_to,
m2.source,
m2.text,
m2.parsed_text,
m2.image_base64,
m2.audio_duration_ms,
m2.audio_base64,
c.alias,
c.identicon`
}
// tableUserMessagesAllFieldsCount reports how many columns
// tableUserMessagesAllFields lists, derived from its comma separators.
func (db sqlitePersistence) tableUserMessagesAllFieldsCount() int {
	columns := db.tableUserMessagesAllFields()
	separators := strings.Count(columns, ",")
	return separators + 1
}
// scanner is the Scan method shared by the *sql.Row and *sql.Rows values that
// the queries in this file hand to tableUserMessagesScanAllFields.
type scanner interface {
Scan(dest ...interface{}) error
}
// tableUserMessagesScanAllFields reads one row selected with
// tableUserMessagesAllFieldsJoin into message. Any extra trailing columns of
// the query (for example a computed cursor) are scanned into others.
// It returns the error from row.Scan or from decoding the mentions JSON.
func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message *common.Message, others ...interface{}) error {
// The m2.* (quoted message) and c.* (contact) columns come from LEFT JOINs,
// so they are scanned into nullable holders.
var quotedText sql.NullString
var quotedParsedText []byte
var quotedFrom sql.NullString
var quotedImage sql.NullString
var quotedAudio sql.NullString
var quotedAudioDuration sql.NullInt64
var serializedMentions []byte
var alias sql.NullString
var identicon sql.NullString
sticker := &protobuf.StickerMessage{}
command := &common.CommandParameters{}
audio := &protobuf.AudioMessage{}
// Destinations are positional: the order below must mirror the column order
// of tableUserMessagesAllFieldsJoin.
args := []interface{}{
&message.ID,
&message.WhisperTimestamp,
&message.From, // source in table
&message.Text,
&message.ContentType,
&message.Alias, // m1.username; overwritten below with the contact alias
&message.Timestamp,
&message.ChatId,
&message.LocalChatID,
&message.MessageType,
&message.Clock,
&message.Seen,
&message.OutgoingStatus,
&message.ParsedText,
&sticker.Pack,
&sticker.Hash,
&message.Base64Image,
&audio.DurationMs,
&message.Base64Audio,
&serializedMentions,
&command.ID,
&command.Value,
&command.From,
&command.Address,
&command.Contract,
&command.TransactionHash,
&command.CommandState,
&command.Signature,
&message.Replace,
&message.RTL,
&message.LineCount,
&message.ResponseTo,
&quotedFrom,
&quotedText,
&quotedParsedText,
&quotedImage,
&quotedAudioDuration,
&quotedAudio,
&alias,
&identicon,
}
err := row.Scan(append(args, others...)...)
if err != nil {
return err
}
// A quoted message is attached only when the joined m2.text is non-NULL.
if quotedText.Valid {
message.QuotedMessage = &common.QuotedMessage{
From: quotedFrom.String,
Text: quotedText.String,
ParsedText: quotedParsedText,
Base64Image: quotedImage.String,
AudioDurationMs: uint64(quotedAudioDuration.Int64),
Base64Audio: quotedAudio.String,
}
}
// The contact's alias and identicon replace whatever was scanned above;
// both are empty strings when no contact row matched.
message.Alias = alias.String
message.Identicon = identicon.String
if serializedMentions != nil {
err := json.Unmarshal(serializedMentions, &message.Mentions)
if err != nil {
return err
}
}
// Attach the type-specific data that matches the stored content type.
switch message.ContentType {
case protobuf.ChatMessage_STICKER:
message.Payload = &protobuf.ChatMessage_Sticker{Sticker: sticker}
case protobuf.ChatMessage_AUDIO:
message.Payload = &protobuf.ChatMessage_Audio{Audio: audio}
case protobuf.ChatMessage_TRANSACTION_COMMAND:
message.CommandParameters = command
}
return nil
}
// tableUserMessagesAllValues flattens message into the positional arguments
// for the INSERT built from tableUserMessagesAllFields; both lists must stay
// in the same order. Absent sticker, image, audio or command data is stored as
// zero values, and mentions are stored as JSON (nil when there are none).
func (db sqlitePersistence) tableUserMessagesAllValues(message *common.Message) ([]interface{}, error) {
	stickerMsg := message.GetSticker()
	if stickerMsg == nil {
		stickerMsg = new(protobuf.StickerMessage)
	}

	imageMsg := message.GetImage()
	if imageMsg == nil {
		imageMsg = new(protobuf.ImageMessage)
	}

	audioMsg := message.GetAudio()
	if audioMsg == nil {
		audioMsg = new(protobuf.AudioMessage)
	}

	cmd := message.CommandParameters
	if cmd == nil {
		cmd = new(common.CommandParameters)
	}

	var mentionsJSON []byte
	if len(message.Mentions) > 0 {
		encoded, err := json.Marshal(message.Mentions)
		if err != nil {
			return nil, err
		}
		mentionsJSON = encoded
	}

	values := []interface{}{
		message.ID,
		message.WhisperTimestamp,
		message.From, // source in table
		message.Text,
		message.ContentType,
		message.Alias,
		message.Timestamp,
		message.ChatId,
		message.LocalChatID,
		message.MessageType,
		message.Clock,
		message.Seen,
		message.OutgoingStatus,
		message.ParsedText,
		stickerMsg.Pack,
		stickerMsg.Hash,
		imageMsg.Payload,
		imageMsg.Type,
		message.Base64Image,
		audioMsg.Payload,
		audioMsg.Type,
		audioMsg.DurationMs,
		message.Base64Audio,
		mentionsJSON,
		cmd.ID,
		cmd.Value,
		cmd.From,
		cmd.Address,
		cmd.Contract,
		cmd.TransactionHash,
		cmd.CommandState,
		cmd.Signature,
		message.Replace,
		message.RTL,
		message.LineCount,
		message.ResponseTo,
	}
	return values, nil
}
// messageByID loads the message with the given id, together with the message
// it responds to and the sender's contact data.
//
// When tx is nil the lookup runs in a transaction opened here, which is
// committed on success and rolled back on any error. When tx is non-nil the
// caller's transaction is used and left open.
//
// It returns errRecordNotFound when no row matches. The results are named so
// that the deferred commit can report its own failure to the caller.
func (db sqlitePersistence) messageByID(tx *sql.Tx, id string) (message *common.Message, err error) {
	if tx == nil {
		tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{})
		if err != nil {
			return nil, err
		}
		defer func() {
			if err == nil {
				// Surface a failed commit instead of dropping it, and do not
				// hand back a message alongside an error.
				if err = tx.Commit(); err != nil {
					message = nil
				}
				return
			}
			// don't shadow original error
			_ = tx.Rollback()
		}()
	}

	var loaded common.Message
	allFields := db.tableUserMessagesAllFieldsJoin()
	row := tx.QueryRow(
		fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE
m1.id = ?
`, allFields),
		id,
	)
	err = db.tableUserMessagesScanAllFields(row, &loaded)
	switch err {
	case sql.ErrNoRows:
		return nil, errRecordNotFound
	case nil:
		return &loaded, nil
	default:
		return nil, err
	}
}
// MessageByCommandID returns the message with the highest clock_value among
// those in the chat chatID (matched on local_chat_id) whose command_id is id.
// It returns errRecordNotFound when no such message exists.
func (db sqlitePersistence) MessageByCommandID(chatID, id string) (*common.Message, error) {
	query := fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE
m1.command_id = ?
AND
m1.local_chat_id = ?
ORDER BY m1.clock_value DESC
LIMIT 1
`, db.tableUserMessagesAllFieldsJoin())

	found := new(common.Message)
	err := db.tableUserMessagesScanAllFields(db.db.QueryRow(query, id, chatID), found)
	if err == sql.ErrNoRows {
		return nil, errRecordNotFound
	}
	if err != nil {
		return nil, err
	}
	return found, nil
}
// MessageByID returns the message with the given id, or errRecordNotFound when
// it does not exist.
func (db sqlitePersistence) MessageByID(id string) (*common.Message, error) {
	// A nil transaction makes messageByID open (and finish) its own.
	var noTx *sql.Tx
	return db.messageByID(noTx, id)
}
// MessagesExist reports which of the given message ids are present in
// user_messages. The returned map contains an entry (set to true) only for ids
// that exist; missing ids are simply absent. An empty ids slice yields an
// empty map without touching the database.
func (db sqlitePersistence) MessagesExist(ids []string) (map[string]bool, error) {
	result := make(map[string]bool, len(ids))
	if len(ids) == 0 {
		return result, nil
	}

	idsArgs := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		idsArgs = append(idsArgs, id)
	}

	// One placeholder per id; the ids themselves are always bound as arguments.
	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
	query := "SELECT id FROM user_messages WHERE id IN (" + inVector + ")" // nolint: gosec
	rows, err := db.db.Query(query, idsArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		result[id] = true
	}
	// Next returns false on both exhaustion and failure; only Err tells them
	// apart, so check it to avoid returning a silently truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// MessagesByIDs returns the non-hidden messages whose id is in ids, each with
// its quoted message and contact data. The query has no ORDER BY, so the
// result order is whatever the database returns. An empty ids slice yields
// (nil, nil).
func (db sqlitePersistence) MessagesByIDs(ids []string) ([]*common.Message, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	idsArgs := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		idsArgs = append(idsArgs, id)
	}

	allFields := db.tableUserMessagesAllFieldsJoin()
	// One placeholder per id; the ids themselves are always bound as arguments.
	inVector := strings.Repeat("?, ", len(ids)-1) + "?"

	// nolint: gosec
	rows, err := db.db.Query(fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE NOT(m1.hide) AND m1.id IN (%s)`, allFields, inVector), idsArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []*common.Message
	for rows.Next() {
		var message common.Message
		if err := db.tableUserMessagesScanAllFields(rows, &message); err != nil {
			return nil, err
		}
		result = append(result, &message)
	}
	// Next returns false on both exhaustion and failure; only Err tells them
	// apart, so check it to avoid returning a silently truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// MessageByChatID returns up to limit non-hidden messages of the chat whose
// local_chat_id is chatID, in descending order.
// Ordering is accomplished using two concatenated values: ClockValue and ID.
// These two values are also used to compose the cursor. When currCursor is
// non-empty, only messages whose cursor is less than or equal to it are
// returned. The returned cursor is non-empty only when more messages remain,
// and it is the cursor of the first message that was left out.
func (db sqlitePersistence) MessageByChatID(chatID string, currCursor string, limit int) ([]*common.Message, string, error) {
	cursorWhere := ""
	args := []interface{}{chatID}
	if currCursor != "" {
		cursorWhere = "AND cursor <= ?"
		args = append(args, currCursor)
	}
	allFields := db.tableUserMessagesAllFieldsJoin()

	// Build a new column `cursor` at the query time by having a fixed-sized clock value at the beginning
	// concatenated with message ID. Results are sorted using this new column.
	// This new column values can also be returned as a cursor for subsequent requests.
	rows, err := db.db.Query(
		fmt.Sprintf(`
SELECT
%s,
substr('0000000000000000000000000000000000000000000000000000000000000000' || m1.clock_value, -64, 64) || m1.id as cursor
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE
NOT(m1.hide) AND m1.local_chat_id = ? %s
ORDER BY cursor DESC
LIMIT ?
`, allFields, cursorWhere),
		append(args, limit+1)..., // take one more to figure out whether a cursor should be returned
	)
	if err != nil {
		return nil, "", err
	}
	defer rows.Close()

	var (
		result  []*common.Message
		cursors []string
	)
	for rows.Next() {
		var (
			message common.Message
			cursor  string
		)
		if err := db.tableUserMessagesScanAllFields(rows, &message, &cursor); err != nil {
			return nil, "", err
		}
		result = append(result, &message)
		cursors = append(cursors, cursor)
	}
	// Next returns false on both exhaustion and failure; only Err tells them
	// apart, so check it to avoid returning a silently truncated page.
	if err := rows.Err(); err != nil {
		return nil, "", err
	}

	var newCursor string
	if len(result) > limit {
		// The extra row is not returned; its cursor is where the next page starts.
		newCursor = cursors[limit]
		result = result[:limit]
	}
	return result, newCursor, nil
}
// EmojiReactionsByChatID returns the non-retracted emoji reactions for the
// messages selected by the same chatID/currCursor/limit window that
// MessageByChatID pages over, up to a maximum of 1000, as it's a potentially
// unbound number.
// NOTE: This is not completely accurate, as the messages in the database might
// have changed since the last call to `MessageByChatID`.
func (db sqlitePersistence) EmojiReactionsByChatID(chatID string, currCursor string, limit int) ([]*EmojiReaction, error) {
	cursorWhere := ""
	// chatID is bound twice: once for the reactions and once for the message subquery.
	args := []interface{}{chatID, chatID}
	if currCursor != "" {
		cursorWhere = "AND substr('0000000000000000000000000000000000000000000000000000000000000000' || m.clock_value, -64, 64) || m.id <= ?"
		args = append(args, currCursor)
	}
	args = append(args, limit)

	// NOTE: We match against local_chat_id for security reasons.
	// As a user could potentially send an emoji reaction for a one to
	// one/group chat that they have no access to.
	// We also limit the number of emoji to a reasonable number (1000)
	// for now, as we don't want the client to choke on this.
	// The issue is that your own emoji might not be returned in such cases,
	// allowing the user to react to a post multiple times.
	// Jakubgs: Returning the whole list seems like a real overkill.
	// This will get very heavy in threads that have loads of reactions on loads of messages.
	// A more sensible response would just include a count and a bool telling you if you are in the list.
	// nolint: gosec
	query := fmt.Sprintf(`
SELECT
e.clock_value,
e.source,
e.emoji_id,
e.message_id,
e.chat_id,
e.local_chat_id,
e.retracted
FROM
emoji_reactions e
WHERE NOT(e.retracted)
AND
e.local_chat_id = ?
AND
e.message_id IN
(SELECT id FROM user_messages m WHERE NOT(m.hide) AND m.local_chat_id = ? %s
ORDER BY substr('0000000000000000000000000000000000000000000000000000000000000000' || m.clock_value, -64, 64) || m.id DESC LIMIT ?)
LIMIT 1000
`, cursorWhere)

	rows, err := db.db.Query(
		query,
		args...,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []*EmojiReaction
	for rows.Next() {
		var emojiReaction EmojiReaction
		err := rows.Scan(&emojiReaction.Clock,
			&emojiReaction.From,
			&emojiReaction.Type,
			&emojiReaction.MessageId,
			&emojiReaction.ChatId,
			&emojiReaction.LocalChatID,
			&emojiReaction.Retracted)
		if err != nil {
			return nil, err
		}
		result = append(result, &emojiReaction)
	}
	// Next returns false on both exhaustion and failure; only Err tells them
	// apart, so check it to avoid returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// SaveMessages inserts all messages into user_messages inside one transaction.
// The transaction is committed when every insert succeeds and rolled back on
// the first error, which is returned. A commit failure is returned as well.
func (db sqlitePersistence) SaveMessages(messages []*common.Message) (err error) {
	var tx *sql.Tx
	if tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{}); err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// don't shadow original error
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	columns := db.tableUserMessagesAllFields()
	// One "?" per column listed above.
	placeholders := strings.Repeat("?, ", db.tableUserMessagesAllFieldsCount()-1) + "?"
	insert := "INSERT INTO user_messages(" + columns + ") VALUES (" + placeholders + ")" // nolint: gosec

	// The statement is prepared on the transaction and reused for every message.
	var stmt *sql.Stmt
	if stmt, err = tx.Prepare(insert); err != nil {
		return err
	}

	for i := range messages {
		var rowValues []interface{}
		if rowValues, err = db.tableUserMessagesAllValues(messages[i]); err != nil {
			return err
		}
		if _, err = stmt.Exec(rowValues...); err != nil {
			return err
		}
	}
	return nil
}
// DeleteMessage removes the user_messages row with the given id.
func (db sqlitePersistence) DeleteMessage(id string) error {
	if _, err := db.db.Exec(`DELETE FROM user_messages WHERE id = ?`, id); err != nil {
		return err
	}
	return nil
}
// HideMessage sets the hide flag on the user_messages row with the given id.
// The read queries in this file that filter on NOT(hide) then skip it.
func (db sqlitePersistence) HideMessage(id string) error {
	if _, err := db.db.Exec(`UPDATE user_messages SET hide = 1 WHERE id = ?`, id); err != nil {
		return err
	}
	return nil
}
// DeleteMessagesByChatID removes every user_messages row whose local_chat_id
// equals id.
func (db sqlitePersistence) DeleteMessagesByChatID(id string) error {
	if _, err := db.db.Exec(`DELETE FROM user_messages WHERE local_chat_id = ?`, id); err != nil {
		return err
	}
	return nil
}
// MarkAllRead marks every not-yet-seen message of the chat chatID as seen and
// resets that chat's unviewed_message_count to 0, both in one transaction.
// The result is named so that the deferred commit can report its own failure
// to the caller.
func (db sqlitePersistence) MarkAllRead(chatID string) (err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return err
	}
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	_, err = tx.Exec(`UPDATE user_messages SET seen = 1 WHERE local_chat_id = ? AND seen != 1`, chatID)
	if err != nil {
		return err
	}

	_, err = tx.Exec(`UPDATE chats SET unviewed_message_count = 0 WHERE id = ?`, chatID)
	return err
}
// MarkMessagesSeen marks the messages with the given ids as seen and then
// recomputes unviewed_message_count for the chat chatID, all in one
// transaction. It returns the number of rows changed by the UPDATE, as
// reported by SQLite's changes().
//
// An empty ids slice is a no-op returning (0, nil); building the IN (...)
// list for it would otherwise panic in strings.Repeat.
//
// The results are named so that every error return is visible to the deferred
// function (which must roll back, not commit) and so that a failed commit is
// reported to the caller.
func (db sqlitePersistence) MarkMessagesSeen(chatID string, ids []string) (count uint64, err error) {
	if len(ids) == 0 {
		return 0, nil
	}

	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return 0, err
	}
	defer func() {
		if err == nil {
			if err = tx.Commit(); err != nil {
				count = 0
			}
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	idsArgs := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		idsArgs = append(idsArgs, id)
	}

	// One placeholder per id; the ids themselves are always bound as arguments.
	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
	q := "UPDATE user_messages SET seen = 1 WHERE id IN (" + inVector + ")" // nolint: gosec
	_, err = tx.Exec(q, idsArgs...)
	if err != nil {
		return 0, err
	}

	// Must run right after the UPDATE: changes() reports the rows touched by
	// the most recent statement.
	row := tx.QueryRow("SELECT changes();")
	if err = row.Scan(&count); err != nil {
		return 0, err
	}

	// Update denormalized count
	_, err = tx.Exec(
		`UPDATE chats
SET unviewed_message_count =
(SELECT COUNT(1)
FROM user_messages
WHERE local_chat_id = ? AND seen = 0)
WHERE id = ?`, chatID, chatID)
	return count, err
}
// UpdateMessageOutgoingStatus stores newOutgoingStatus in the outgoing_status
// column of the message with the given id.
func (db sqlitePersistence) UpdateMessageOutgoingStatus(id string, newOutgoingStatus string) error {
	const update = `
UPDATE user_messages
SET outgoing_status = ?
WHERE id = ?
`
	if _, err := db.db.Exec(update, newOutgoingStatus, id); err != nil {
		return err
	}
	return nil
}
// BlockContact performs the following steps in a single transaction:
//   - deletes every message whose source is contact.ID,
//   - persists contact through SaveContact,
//   - deletes the chat whose id is contact.ID (the one-to-one chat),
//   - recomputes unviewed_message_count for all remaining chats,
//   - refreshes last_message of every remaining chat (NULL when the chat has
//     no messages left).
//
// It returns the updated chats. On any error the transaction is rolled back
// and nil chats are returned.
//
// The results are named so that every error return is visible to the deferred
// function (which must roll back, not commit) and so that a failed commit is
// reported to the caller.
func (db sqlitePersistence) BlockContact(contact *Contact) (chats []*Chat, err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return nil, err
	}
	defer func() {
		if err == nil {
			if err = tx.Commit(); err != nil {
				chats = nil
			}
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	// Delete messages
	_, err = tx.Exec(
		`DELETE
FROM user_messages
WHERE source = ?`,
		contact.ID,
	)
	if err != nil {
		return nil, err
	}

	// Update contact
	err = db.SaveContact(contact, tx)
	if err != nil {
		return nil, err
	}

	// Delete one-to-one chat
	_, err = tx.Exec("DELETE FROM chats WHERE id = ?", contact.ID)
	if err != nil {
		return nil, err
	}

	// Recalculate denormalized fields
	_, err = tx.Exec(`
UPDATE chats
SET
unviewed_message_count = (SELECT COUNT(1) FROM user_messages WHERE seen = 0 AND local_chat_id = chats.id)`)
	if err != nil {
		return nil, err
	}

	// return the updated chats
	chats, err = db.chats(tx)
	if err != nil {
		return nil, err
	}

	for _, c := range chats {
		var lastMessageID string
		row := tx.QueryRow(`SELECT id FROM user_messages WHERE local_chat_id = ? ORDER BY clock_value DESC LIMIT 1`, c.ID)
		switch err := row.Scan(&lastMessageID); err {
		case nil:
			message, err := db.messageByID(tx, lastMessageID)
			if err != nil {
				return nil, err
			}
			if message != nil {
				encodedMessage, err := json.Marshal(message)
				if err != nil {
					return nil, err
				}
				_, err = tx.Exec(`UPDATE chats SET last_message = ? WHERE id = ?`, encodedMessage, c.ID)
				if err != nil {
					return nil, err
				}
				c.LastMessage = message
			}
		case sql.ErrNoRows:
			// Reset LastMessage
			_, err = tx.Exec(`UPDATE chats SET last_message = NULL WHERE id = ?`, c.ID)
			if err != nil {
				return nil, err
			}
			c.LastMessage = nil
		default:
			return nil, err
		}
	}

	return chats, nil
}
// SaveEmojiReaction inserts emojiReaction into emoji_reactions, keyed by
// emojiReaction.ID().
//
// The insert goes through Exec directly: the statement is used only once, and
// a statement prepared on the database handle stays open until it is
// explicitly closed.
func (db sqlitePersistence) SaveEmojiReaction(emojiReaction *EmojiReaction) (err error) {
	query := "INSERT INTO emoji_reactions(id,clock_value,source,emoji_id,message_id,chat_id,local_chat_id,retracted) VALUES (?,?,?,?,?,?,?,?)"
	_, err = db.db.Exec(
		query,
		emojiReaction.ID(),
		emojiReaction.Clock,
		emojiReaction.From,
		emojiReaction.Type,
		emojiReaction.MessageId,
		emojiReaction.ChatId,
		emojiReaction.LocalChatID,
		emojiReaction.Retracted,
	)
	return err
}
// EmojiReactionByID returns the emoji reaction stored under id, or
// errRecordNotFound when there is none.
func (db sqlitePersistence) EmojiReactionByID(id string) (*EmojiReaction, error) {
	var reaction EmojiReaction
	err := db.db.QueryRow(
		`SELECT
clock_value,
source,
emoji_id,
message_id,
chat_id,
local_chat_id,
retracted
FROM
emoji_reactions
WHERE
emoji_reactions.id = ?
`, id).Scan(
		&reaction.Clock,
		&reaction.From,
		&reaction.Type,
		&reaction.MessageId,
		&reaction.ChatId,
		&reaction.LocalChatID,
		&reaction.Retracted,
	)
	if err == sql.ErrNoRows {
		return nil, errRecordNotFound
	}
	if err != nil {
		return nil, err
	}
	return &reaction, nil
}
// SaveInvitation inserts invitation into group_chat_invitations, keyed by
// invitation.ID().
//
// The insert goes through Exec directly: the statement is used only once, and
// a statement prepared on the database handle stays open until it is
// explicitly closed.
func (db sqlitePersistence) SaveInvitation(invitation *GroupChatInvitation) (err error) {
	query := "INSERT INTO group_chat_invitations(id,source,chat_id,message,state,clock) VALUES (?,?,?,?,?,?)"
	_, err = db.db.Exec(
		query,
		invitation.ID(),
		invitation.From,
		invitation.ChatId,
		invitation.IntroductionMessage,
		invitation.State,
		invitation.Clock,
	)
	return err
}
// GetGroupChatInvitations returns every row of group_chat_invitations, read
// inside a transaction that is committed on success and rolled back on error.
// The query has no ORDER BY, so the result order is whatever the database
// returns.
func (db sqlitePersistence) GetGroupChatInvitations() (rst []*GroupChatInvitation, err error) {
	tx, err := db.db.Begin()
	if err != nil {
		return
	}
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		_ = tx.Rollback()
	}()

	bRows, err := tx.Query(`SELECT
source,
chat_id,
message,
state,
clock
FROM
group_chat_invitations`)
	if err != nil {
		return
	}
	defer bRows.Close()

	for bRows.Next() {
		invitation := GroupChatInvitation{}
		err = bRows.Scan(
			&invitation.From,
			&invitation.ChatId,
			&invitation.IntroductionMessage,
			&invitation.State,
			&invitation.Clock)
		if err != nil {
			return nil, err
		}
		rst = append(rst, &invitation)
	}
	// Next returns false on both exhaustion and failure; only Err tells them
	// apart, so check it to avoid returning a silently truncated list.
	if err = bRows.Err(); err != nil {
		return nil, err
	}
	return rst, nil
}
// InvitationByID returns the group chat invitation stored under id, or
// errRecordNotFound when there is none.
func (db sqlitePersistence) InvitationByID(id string) (*GroupChatInvitation, error) {
	var invitation GroupChatInvitation
	err := db.db.QueryRow(
		`SELECT
source,
chat_id,
message,
state,
clock
FROM
group_chat_invitations
WHERE
group_chat_invitations.id = ?
`, id).Scan(
		&invitation.From,
		&invitation.ChatId,
		&invitation.IntroductionMessage,
		&invitation.State,
		&invitation.Clock,
	)
	if err == sql.ErrNoRows {
		return nil, errRecordNotFound
	}
	if err != nil {
		return nil, err
	}
	return &invitation, nil
}