status-go/protocol/message_persistence.go

838 lines
19 KiB
Go
Raw Normal View History

package protocol
2019-08-06 23:50:13 +02:00
import (
"context"
2019-08-06 23:50:13 +02:00
"database/sql"
"encoding/json"
2019-08-06 23:50:13 +02:00
"fmt"
"strings"
2020-01-02 10:10:19 +01:00
"github.com/status-im/status-go/protocol/protobuf"
2019-08-06 23:50:13 +02:00
"github.com/pkg/errors"
)
// errRecordNotFound is the error the single-row lookups in this file return
// in place of sql.ErrNoRows when no matching row exists.
var errRecordNotFound = errors.New("record not found")
func (db sqlitePersistence) tableUserMessagesAllFields() string {
2019-08-06 23:50:13 +02:00
return `id,
whisper_timestamp,
source,
text,
2019-08-06 23:50:13 +02:00
content_type,
username,
timestamp,
chat_id,
local_chat_id,
2019-08-06 23:50:13 +02:00
message_type,
clock_value,
seen,
outgoing_status,
parsed_text,
sticker_pack,
sticker_hash,
image_payload,
image_type,
image_base64,
audio_payload,
audio_type,
2020-06-23 16:30:39 +02:00
audio_duration_ms,
audio_base64,
command_id,
command_value,
command_from,
command_address,
command_contract,
command_transaction_hash,
command_state,
command_signature,
replace_message,
rtl,
line_count,
response_to`
}
func (db sqlitePersistence) tableUserMessagesAllFieldsJoin() string {
return `m1.id,
m1.whisper_timestamp,
m1.source,
m1.text,
m1.content_type,
m1.username,
m1.timestamp,
m1.chat_id,
m1.local_chat_id,
m1.message_type,
m1.clock_value,
m1.seen,
m1.outgoing_status,
m1.parsed_text,
m1.sticker_pack,
m1.sticker_hash,
m1.image_base64,
COALESCE(m1.audio_duration_ms,0),
m1.audio_base64,
m1.command_id,
m1.command_value,
m1.command_from,
m1.command_address,
m1.command_contract,
m1.command_transaction_hash,
m1.command_state,
m1.command_signature,
m1.replace_message,
m1.rtl,
m1.line_count,
m1.response_to,
m2.source,
m2.text,
m2.image_base64,
2020-06-23 16:30:39 +02:00
m2.audio_duration_ms,
m2.audio_base64,
c.alias,
c.identicon`
2019-08-06 23:50:13 +02:00
}
func (db sqlitePersistence) tableUserMessagesAllFieldsCount() int {
return strings.Count(db.tableUserMessagesAllFields(), ",") + 1
2019-08-06 23:50:13 +02:00
}
// scanner is the single method shared by the *sql.Row values returned from
// QueryRow and the *sql.Rows values returned from Query. It lets
// tableUserMessagesScanAllFields decode a message from either source.
type scanner interface {
	Scan(...interface{}) error
}
func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message *Message, others ...interface{}) error {
var quotedText sql.NullString
var quotedFrom sql.NullString
var quotedImage sql.NullString
var quotedAudio sql.NullString
2020-06-23 16:30:39 +02:00
var quotedAudioDuration sql.NullInt64
var alias sql.NullString
var identicon sql.NullString
sticker := &protobuf.StickerMessage{}
command := &CommandParameters{}
2020-06-23 16:30:39 +02:00
audio := &protobuf.AudioMessage{}
2019-08-06 23:50:13 +02:00
args := []interface{}{
&message.ID,
&message.WhisperTimestamp,
&message.From, // source in table
&message.Text,
2019-08-06 23:50:13 +02:00
&message.ContentType,
&message.Alias,
2019-08-06 23:50:13 +02:00
&message.Timestamp,
&message.ChatId,
&message.LocalChatID,
2019-08-06 23:50:13 +02:00
&message.MessageType,
&message.Clock,
2019-08-06 23:50:13 +02:00
&message.Seen,
&message.OutgoingStatus,
&message.ParsedText,
&sticker.Pack,
&sticker.Hash,
&message.Base64Image,
2020-06-23 16:30:39 +02:00
&audio.DurationMs,
&message.Base64Audio,
&command.ID,
&command.Value,
&command.From,
&command.Address,
&command.Contract,
&command.TransactionHash,
&command.CommandState,
&command.Signature,
&message.Replace,
&message.RTL,
&message.LineCount,
&message.ResponseTo,
&quotedFrom,
&quotedText,
&quotedImage,
2020-06-23 16:30:39 +02:00
&quotedAudioDuration,
&quotedAudio,
&alias,
&identicon,
}
err := row.Scan(append(args, others...)...)
if err != nil {
return err
}
if quotedText.Valid {
message.QuotedMessage = &QuotedMessage{
2020-06-23 16:30:39 +02:00
From: quotedFrom.String,
Text: quotedText.String,
Base64Image: quotedImage.String,
AudioDurationMs: uint64(quotedAudioDuration.Int64),
Base64Audio: quotedAudio.String,
}
2019-08-06 23:50:13 +02:00
}
message.Alias = alias.String
message.Identicon = identicon.String
2020-07-10 15:20:18 +01:00
switch message.ContentType {
case protobuf.ChatMessage_STICKER:
message.Payload = &protobuf.ChatMessage_Sticker{Sticker: sticker}
case protobuf.ChatMessage_AUDIO:
message.Payload = &protobuf.ChatMessage_Audio{Audio: audio}
2020-06-23 16:30:39 +02:00
2020-07-10 15:20:18 +01:00
case protobuf.ChatMessage_TRANSACTION_COMMAND:
message.CommandParameters = command
}
return nil
2019-08-06 23:50:13 +02:00
}
func (db sqlitePersistence) tableUserMessagesAllValues(message *Message) ([]interface{}, error) {
sticker := message.GetSticker()
if sticker == nil {
sticker = &protobuf.StickerMessage{}
}
image := message.GetImage()
if image == nil {
image = &protobuf.ImageMessage{}
}
audio := message.GetAudio()
if audio == nil {
audio = &protobuf.AudioMessage{}
}
command := message.CommandParameters
if command == nil {
command = &CommandParameters{}
}
2019-08-06 23:50:13 +02:00
return []interface{}{
message.ID,
message.WhisperTimestamp,
message.From, // source in table
message.Text,
2019-08-06 23:50:13 +02:00
message.ContentType,
message.Alias,
2019-08-06 23:50:13 +02:00
message.Timestamp,
message.ChatId,
message.LocalChatID,
2019-08-06 23:50:13 +02:00
message.MessageType,
message.Clock,
2019-08-06 23:50:13 +02:00
message.Seen,
message.OutgoingStatus,
message.ParsedText,
sticker.Pack,
sticker.Hash,
image.Payload,
image.Type,
message.Base64Image,
audio.Payload,
audio.Type,
2020-06-23 16:30:39 +02:00
audio.DurationMs,
message.Base64Audio,
command.ID,
command.Value,
command.From,
command.Address,
command.Contract,
command.TransactionHash,
command.CommandState,
command.Signature,
message.Replace,
message.RTL,
message.LineCount,
message.ResponseTo,
}, nil
2019-08-06 23:50:13 +02:00
}
func (db sqlitePersistence) messageByID(tx *sql.Tx, id string) (*Message, error) {
var err error
if tx == nil {
tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return nil, err
}
defer func() {
if err == nil {
err = tx.Commit()
return
}
// don't shadow original error
_ = tx.Rollback()
}()
}
2019-08-06 23:50:13 +02:00
var message Message
allFields := db.tableUserMessagesAllFieldsJoin()
row := tx.QueryRow(
2019-08-06 23:50:13 +02:00
fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
2019-08-06 23:50:13 +02:00
WHERE
m1.id = ?
2019-08-06 23:50:13 +02:00
`, allFields),
id,
)
err = db.tableUserMessagesScanAllFields(row, &message)
2019-08-06 23:50:13 +02:00
switch err {
case sql.ErrNoRows:
return nil, errRecordNotFound
case nil:
return &message, nil
default:
return nil, err
}
}
2020-01-17 13:39:09 +01:00
func (db sqlitePersistence) MessageByCommandID(chatID, id string) (*Message, error) {
var message Message
allFields := db.tableUserMessagesAllFieldsJoin()
row := db.db.QueryRow(
fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE
m1.command_id = ?
2020-01-17 13:39:09 +01:00
AND
m1.local_chat_id = ?
ORDER BY m1.clock_value DESC
LIMIT 1
`, allFields),
id,
2020-01-17 13:39:09 +01:00
chatID,
)
err := db.tableUserMessagesScanAllFields(row, &message)
switch err {
case sql.ErrNoRows:
return nil, errRecordNotFound
case nil:
return &message, nil
default:
return nil, err
}
}
// MessageByID returns the message with the given id, including its quoted
// message and author contact data. It returns errRecordNotFound if no message
// has that id.
func (db sqlitePersistence) MessageByID(id string) (*Message, error) {
	// A nil transaction makes messageByID open, and then commit or roll back,
	// its own transaction around the lookup.
	message, err := db.messageByID(nil, id)
	return message, err
}
func (db sqlitePersistence) MessagesExist(ids []string) (map[string]bool, error) {
result := make(map[string]bool)
if len(ids) == 0 {
2019-08-06 23:50:13 +02:00
return result, nil
}
idsArgs := make([]interface{}, 0, len(ids))
for _, id := range ids {
idsArgs = append(idsArgs, id)
}
inVector := strings.Repeat("?, ", len(ids)-1) + "?"
query := "SELECT id FROM user_messages WHERE id IN (" + inVector + ")" // nolint: gosec
rows, err := db.db.Query(query, idsArgs...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var id string
err := rows.Scan(&id)
if err != nil {
return nil, err
}
result[id] = true
}
return result, nil
2019-08-06 23:50:13 +02:00
}
// MessagesByIDs returns the non-hidden messages whose ids are in ids, with
// their quoted message and author contact data. Unknown ids and hidden
// messages are skipped. The result order is unspecified (the query has no
// ORDER BY) and does not follow the order of ids. An empty ids slice returns
// (nil, nil), and so does a query matching no rows.
func (db sqlitePersistence) MessagesByIDs(ids []string) ([]*Message, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	idsArgs := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		idsArgs = append(idsArgs, id)
	}

	allFields := db.tableUserMessagesAllFieldsJoin()
	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
	// nolint: gosec
	rows, err := db.db.Query(fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE NOT(m1.hide) AND m1.id IN (%s)`, allFields, inVector), idsArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []*Message
	for rows.Next() {
		var message Message
		if err := db.tableUserMessagesScanAllFields(rows, &message); err != nil {
			return nil, err
		}
		result = append(result, &message)
	}
	// Report iteration errors rather than returning a partial result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
2019-08-06 23:50:13 +02:00
// MessageByChatID returns all messages for a given chatID in descending order.
// Ordering is accomplished using two concatenated values: ClockValue and ID.
// These two values are also used to compose a cursor which is returned to the result.
func (db sqlitePersistence) MessageByChatID(chatID string, currCursor string, limit int) ([]*Message, string, error) {
cursorWhere := ""
if currCursor != "" {
cursorWhere = "AND cursor <= ?"
}
allFields := db.tableUserMessagesAllFieldsJoin()
2019-08-06 23:50:13 +02:00
args := []interface{}{chatID}
if currCursor != "" {
args = append(args, currCursor)
}
// Build a new column `cursor` at the query time by having a fixed-sized clock value at the beginning
// concatenated with message ID. Results are sorted using this new column.
2019-08-06 23:50:13 +02:00
// This new column values can also be returned as a cursor for subsequent requests.
rows, err := db.db.Query(
fmt.Sprintf(`
SELECT
%s,
substr('0000000000000000000000000000000000000000000000000000000000000000' || m1.clock_value, -64, 64) || m1.id as cursor
2019-08-06 23:50:13 +02:00
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
2019-08-06 23:50:13 +02:00
WHERE
NOT(m1.hide) AND m1.local_chat_id = ? %s
2019-08-06 23:50:13 +02:00
ORDER BY cursor DESC
LIMIT ?
`, allFields, cursorWhere),
append(args, limit+1)..., // take one more to figure our whether a cursor should be returned
)
if err != nil {
return nil, "", err
}
defer rows.Close()
var (
result []*Message
cursors []string
)
for rows.Next() {
var (
message Message
cursor string
)
if err := db.tableUserMessagesScanAllFields(rows, &message, &cursor); err != nil {
2019-08-06 23:50:13 +02:00
return nil, "", err
}
result = append(result, &message)
cursors = append(cursors, cursor)
}
var newCursor string
if len(result) > limit {
newCursor = cursors[limit]
result = result[:limit]
}
return result, newCursor, nil
}
2020-07-27 17:57:01 +02:00
// EmojiReactionsByChatID returns the emoji reactions for the queried messages, up to a maximum of 100, as it's a potentially unbound number.
// NOTE: This is not completely accurate, as the messages in the database might have change since the last call to `MessageByChatID`.
func (db sqlitePersistence) EmojiReactionsByChatID(chatID string, currCursor string, limit int) ([]*EmojiReaction, error) {
cursorWhere := ""
if currCursor != "" {
2020-07-28 07:41:50 +02:00
cursorWhere = "AND substr('0000000000000000000000000000000000000000000000000000000000000000' || m.clock_value, -64, 64) || m.id <= ?"
2020-07-27 17:57:01 +02:00
}
args := []interface{}{chatID, chatID}
2020-07-27 17:57:01 +02:00
if currCursor != "" {
args = append(args, currCursor)
}
2020-07-28 07:41:50 +02:00
args = append(args, limit)
// NOTE: We match against local_chat_id for security reasons.
// As a user could potentially send an emoji reaction for a one to
// one/group chat that has no access to.
2020-07-29 14:00:51 +02:00
// We also limit the number of emoji to a reasonable number (1000)
// for now, as we don't want the client to choke on this.
// The issue is that your own emoji might not be returned in such cases,
// allowing the user to react to a post multiple times.
// Jakubgs: Returning the whole list seems like a real overkill.
// This will get very heavy in threads that have loads of reactions on loads of messages.
// A more sensible response would just include a count and a bool telling you if you are in the list.
2020-07-28 10:02:51 +02:00
// nolint: gosec
2020-07-28 07:41:50 +02:00
query := fmt.Sprintf(`
2020-07-27 17:57:01 +02:00
SELECT
e.clock_value,
e.source,
e.emoji_id,
e.message_id,
e.chat_id,
e.local_chat_id,
2020-07-27 17:57:01 +02:00
e.retracted
FROM
emoji_reactions e
WHERE NOT(e.retracted)
AND
e.local_chat_id = ?
AND
2020-07-27 17:57:01 +02:00
e.message_id IN
(SELECT id FROM user_messages m WHERE NOT(m.hide) AND m.local_chat_id = ? %s
ORDER BY substr('0000000000000000000000000000000000000000000000000000000000000000' || m.clock_value, -64, 64) || m.id DESC LIMIT ?)
2020-07-29 14:00:51 +02:00
LIMIT 1000
2020-07-28 07:41:50 +02:00
`, cursorWhere)
rows, err := db.db.Query(
query,
args...,
2020-07-27 17:57:01 +02:00
)
if err != nil {
return nil, err
}
defer rows.Close()
var result []*EmojiReaction
for rows.Next() {
var emojiReaction EmojiReaction
err := rows.Scan(&emojiReaction.Clock,
&emojiReaction.From,
&emojiReaction.Type,
&emojiReaction.MessageId,
&emojiReaction.ChatId,
&emojiReaction.LocalChatID,
2020-07-27 17:57:01 +02:00
&emojiReaction.Retracted)
if err != nil {
return nil, err
}
result = append(result, &emojiReaction)
}
return result, nil
}
func (db sqlitePersistence) SaveMessages(messages []*Message) (err error) {
tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
2019-08-06 23:50:13 +02:00
if err != nil {
return
2019-08-06 23:50:13 +02:00
}
defer func() {
if err == nil {
err = tx.Commit()
return
2019-08-06 23:50:13 +02:00
}
// don't shadow original error
_ = tx.Rollback()
}()
2019-08-06 23:50:13 +02:00
allFields := db.tableUserMessagesAllFields()
valuesVector := strings.Repeat("?, ", db.tableUserMessagesAllFieldsCount()-1) + "?"
query := "INSERT INTO user_messages(" + allFields + ") VALUES (" + valuesVector + ")" // nolint: gosec
stmt, err := tx.Prepare(query)
2019-08-06 23:50:13 +02:00
if err != nil {
return
2019-08-06 23:50:13 +02:00
}
for _, msg := range messages {
var allValues []interface{}
allValues, err = db.tableUserMessagesAllValues(msg)
if err != nil {
return
}
_, err = stmt.Exec(allValues...)
if err != nil {
return
2019-08-06 23:50:13 +02:00
}
}
return
2019-08-06 23:50:13 +02:00
}
// DeleteMessage permanently removes the row with the given id from
// user_messages. The number of affected rows is not checked, so deleting an
// unknown id is not an error.
func (db sqlitePersistence) DeleteMessage(id string) error {
	const query = `DELETE FROM user_messages WHERE id = ?`
	if _, err := db.db.Exec(query, id); err != nil {
		return err
	}
	return nil
}
func (db sqlitePersistence) HideMessage(id string) error {
_, err := db.db.Exec(`UPDATE user_messages SET hide = 1 WHERE id = ?`, id)
2019-08-06 23:50:13 +02:00
return err
}
// DeleteMessagesByChatID permanently removes every message whose
// local_chat_id equals id. The chats row itself is not touched.
func (db sqlitePersistence) DeleteMessagesByChatID(id string) error {
	const query = `DELETE FROM user_messages WHERE local_chat_id = ?`
	if _, err := db.db.Exec(query, id); err != nil {
		return err
	}
	return nil
}
2020-02-26 13:31:48 +01:00
// MarkAllRead marks every message with local_chat_id chatID as seen and resets
// that chat's unviewed_message_count to 0, atomically in one transaction.
//
// The error result is named so the deferred handler can return a failed
// commit to the caller instead of silently dropping it.
func (db sqlitePersistence) MarkAllRead(chatID string) (err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return err
	}
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	if _, err = tx.Exec(`UPDATE user_messages SET seen = 1 WHERE local_chat_id = ?`, chatID); err != nil {
		return err
	}
	_, err = tx.Exec(`UPDATE chats SET unviewed_message_count = 0 WHERE id = ?`, chatID)
	return err
}
func (db sqlitePersistence) MarkMessagesSeen(chatID string, ids []string) (uint64, error) {
tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return 0, err
}
defer func() {
if err == nil {
err = tx.Commit()
return
}
// don't shadow original error
_ = tx.Rollback()
}()
2019-08-06 23:50:13 +02:00
idsArgs := make([]interface{}, 0, len(ids))
for _, id := range ids {
idsArgs = append(idsArgs, id)
}
inVector := strings.Repeat("?, ", len(ids)-1) + "?"
q := "UPDATE user_messages SET seen = 1 WHERE id IN (" + inVector + ")" // nolint: gosec
_, err = tx.Exec(q, idsArgs...)
if err != nil {
return 0, err
}
var count uint64
row := tx.QueryRow("SELECT changes();")
if err := row.Scan(&count); err != nil {
return 0, err
}
// Update denormalized count
_, err = tx.Exec(
`UPDATE chats
SET unviewed_message_count =
(SELECT COUNT(1)
FROM user_messages
WHERE local_chat_id = ? AND seen = 0)
WHERE id = ?`, chatID, chatID)
return count, err
2019-08-06 23:50:13 +02:00
}
// UpdateMessageOutgoingStatus stores newOutgoingStatus in the outgoing_status
// column of the message with the given id. The number of affected rows is not
// checked, so an unknown id is not an error.
func (db sqlitePersistence) UpdateMessageOutgoingStatus(id string, newOutgoingStatus string) error {
	const query = `
UPDATE user_messages
SET outgoing_status = ?
WHERE id = ?
`
	if _, err := db.db.Exec(query, newOutgoingStatus, id); err != nil {
		return err
	}
	return nil
}
// BlockContact blocks contact in a single transaction. It:
//   - deletes every message whose source is the contact;
//   - saves the contact;
//   - deletes the one-to-one chat whose id is the contact id;
//   - recomputes unviewed_message_count for all chats;
//   - refreshes last_message of each chat returned by db.chats from its newest
//     remaining message, or clears it when none is left.
//
// It returns those updated chats.
//
// The results are named so that any returned error, including one produced
// inside the per-chat loop, makes the deferred handler roll back. Without
// this, a partially applied block could be committed. A failed commit is also
// reported to the caller.
func (db sqlitePersistence) BlockContact(contact *Contact) (chats []*Chat, err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return nil, err
	}
	defer func() {
		if err == nil {
			if err = tx.Commit(); err != nil {
				chats = nil
			}
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	// Delete messages
	if _, err = tx.Exec(`DELETE FROM user_messages WHERE source = ?`, contact.ID); err != nil {
		return nil, err
	}

	// Update contact
	if err = db.SaveContact(contact, tx); err != nil {
		return nil, err
	}

	// Delete one-to-one chat
	if _, err = tx.Exec("DELETE FROM chats WHERE id = ?", contact.ID); err != nil {
		return nil, err
	}

	// Recalculate denormalized fields
	_, err = tx.Exec(`UPDATE chats SET unviewed_message_count = (SELECT COUNT(1) FROM user_messages WHERE seen = 0 AND local_chat_id = chats.id)`)
	if err != nil {
		return nil, err
	}

	// return the updated chats
	if chats, err = db.chats(tx); err != nil {
		return nil, err
	}

	for _, c := range chats {
		var lastMessageID string
		row := tx.QueryRow(`SELECT id FROM user_messages WHERE local_chat_id = ? ORDER BY clock_value DESC LIMIT 1`, c.ID)
		err = row.Scan(&lastMessageID)
		switch err {
		case nil:
			var message *Message
			if message, err = db.messageByID(tx, lastMessageID); err != nil {
				return nil, err
			}
			if message != nil {
				var encodedMessage []byte
				if encodedMessage, err = json.Marshal(message); err != nil {
					return nil, err
				}
				if _, err = tx.Exec(`UPDATE chats SET last_message = ? WHERE id = ?`, encodedMessage, c.ID); err != nil {
					return nil, err
				}
				c.LastMessage = message
			}
		case sql.ErrNoRows:
			// No message left in this chat: reset LastMessage.
			if _, err = tx.Exec(`UPDATE chats SET last_message = NULL WHERE id = ?`, c.ID); err != nil {
				return nil, err
			}
			c.LastMessage = nil
		default:
			return nil, err
		}
	}
	return chats, nil
}
2020-07-22 00:42:55 +01:00
func (db sqlitePersistence) SaveEmojiReaction(emojiReaction *EmojiReaction) (err error) {
query := "INSERT INTO emoji_reactions(id,clock_value,source,emoji_id,message_id,chat_id,local_chat_id,retracted) VALUES (?,?,?,?,?,?,?,?)"
stmt, err := db.db.Prepare(query)
2020-07-22 00:42:55 +01:00
if err != nil {
return
}
_, err = stmt.Exec(
emojiReaction.ID(),
2020-07-22 00:42:55 +01:00
emojiReaction.Clock,
emojiReaction.From,
emojiReaction.Type,
emojiReaction.MessageId,
emojiReaction.ChatId,
emojiReaction.LocalChatID,
2020-07-22 00:42:55 +01:00
emojiReaction.Retracted,
)
2020-07-22 00:42:55 +01:00
return
}
2020-07-22 01:21:05 +01:00
func (db sqlitePersistence) EmojiReactionByID(id string) (*EmojiReaction, error) {
row := db.db.QueryRow(
`SELECT
clock_value,
source,
emoji_id,
message_id,
chat_id,
local_chat_id,
retracted
2020-07-22 01:21:05 +01:00
FROM
emoji_reactions
WHERE
emoji_reactions.id = ?
`, id)
2020-07-22 01:21:05 +01:00
emojiReaction := new(EmojiReaction)
err := row.Scan(&emojiReaction.Clock,
2020-07-22 01:21:05 +01:00
&emojiReaction.From,
&emojiReaction.Type,
&emojiReaction.MessageId,
&emojiReaction.ChatId,
&emojiReaction.LocalChatID,
2020-07-22 01:21:05 +01:00
&emojiReaction.Retracted,
)
switch err {
case sql.ErrNoRows:
return nil, errRecordNotFound
case nil:
return emojiReaction, nil
default:
return nil, err
}
}