status-go/protocol/message_persistence.go

708 lines
15 KiB
Go

package protocol
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/pkg/errors"
)
// errRecordNotFound is returned by the single-message lookups in this file
// when the query matches no row.
var errRecordNotFound = errors.New("record not found")
// tableUserMessagesAllFields returns the comma-separated list of user_messages
// columns used by SaveMessages to build its INSERT statement. The order must
// stay in sync with the values produced by tableUserMessagesAllValues.
func (db sqlitePersistence) tableUserMessagesAllFields() string {
	const columns = `id,
whisper_timestamp,
source,
text,
content_type,
username,
timestamp,
chat_id,
local_chat_id,
message_type,
clock_value,
seen,
outgoing_status,
parsed_text,
sticker_pack,
sticker_hash,
image_payload,
image_type,
image_base64,
audio_payload,
audio_type,
audio_duration_ms,
audio_base64,
command_id,
command_value,
command_from,
command_address,
command_contract,
command_transaction_hash,
command_state,
command_signature,
replace_message,
rtl,
line_count,
response_to`
	return columns
}
// tableUserMessagesAllFieldsJoin returns the column list used by the read
// queries in this file. Those queries alias the message row as m1, the message
// it responds to as m2 (joined on m1.response_to) and the sender's contact row
// as c (joined on m1.source). The order must stay in sync with the scan
// destinations in tableUserMessagesScanAllFields.
func (db sqlitePersistence) tableUserMessagesAllFieldsJoin() string {
	const columns = `m1.id,
m1.whisper_timestamp,
m1.source,
m1.text,
m1.content_type,
m1.username,
m1.timestamp,
m1.chat_id,
m1.local_chat_id,
m1.message_type,
m1.clock_value,
m1.seen,
m1.outgoing_status,
m1.parsed_text,
m1.sticker_pack,
m1.sticker_hash,
m1.image_base64,
COALESCE(m1.audio_duration_ms,0),
m1.audio_base64,
m1.command_id,
m1.command_value,
m1.command_from,
m1.command_address,
m1.command_contract,
m1.command_transaction_hash,
m1.command_state,
m1.command_signature,
m1.replace_message,
m1.rtl,
m1.line_count,
m1.response_to,
m2.source,
m2.text,
m2.image_base64,
m2.audio_duration_ms,
m2.audio_base64,
c.alias,
c.identicon`
	return columns
}
// tableUserMessagesAllFieldsCount returns how many columns are listed by
// tableUserMessagesAllFields, derived from the number of comma separators.
func (db sqlitePersistence) tableUserMessagesAllFieldsCount() int {
	columns := strings.Split(db.tableUserMessagesAllFields(), ",")
	return len(columns)
}
// scanner is the minimal row abstraction needed by
// tableUserMessagesScanAllFields; callers in this file pass both the result of
// QueryRow and the rows iterator returned by Query.
type scanner interface {
	// Scan copies the current row's columns into targets.
	Scan(targets ...interface{}) error
}
// tableUserMessagesScanAllFields scans one row selected with the column list
// from tableUserMessagesAllFieldsJoin into message.
//
// others are additional scan destinations for columns the caller selected
// after the joined field list (MessageByChatID passes its cursor this way).
//
// The quoted message is attached only when the joined m2.text column is
// non-NULL. Alias and Identicon are taken from the joined contacts row and are
// empty strings when that row is missing. The content-type specific payload
// (sticker, audio or transaction command) is rebuilt from the scanned columns.
//
// It returns whatever error row.Scan reports, including sql.ErrNoRows.
func (db sqlitePersistence) tableUserMessagesScanAllFields(row scanner, message *Message, others ...interface{}) error {
	var (
		quotedText          sql.NullString
		quotedFrom          sql.NullString
		quotedImage         sql.NullString
		quotedAudio         sql.NullString
		quotedAudioDuration sql.NullInt64
		alias               sql.NullString
		identicon           sql.NullString
	)

	sticker := &protobuf.StickerMessage{}
	command := &CommandParameters{}
	audio := &protobuf.AudioMessage{}

	args := []interface{}{
		&message.ID,
		&message.WhisperTimestamp,
		&message.From, // source in table
		&message.Text,
		&message.ContentType,
		&message.Alias, // username in table; overwritten below with the contact alias
		&message.Timestamp,
		&message.ChatId,
		&message.LocalChatID,
		&message.MessageType,
		&message.Clock,
		&message.Seen,
		&message.OutgoingStatus,
		&message.ParsedText,
		&sticker.Pack,
		&sticker.Hash,
		&message.Base64Image,
		&audio.DurationMs,
		&message.Base64Audio,
		&command.ID,
		&command.Value,
		&command.From,
		&command.Address,
		&command.Contract,
		&command.TransactionHash,
		&command.CommandState,
		&command.Signature,
		&message.Replace,
		&message.RTL,
		&message.LineCount,
		&message.ResponseTo,
		&quotedFrom,
		&quotedText,
		&quotedImage,
		&quotedAudioDuration,
		&quotedAudio,
		&alias,
		&identicon,
	}
	if err := row.Scan(append(args, others...)...); err != nil {
		return err
	}

	if quotedText.Valid {
		message.QuotedMessage = &QuotedMessage{
			From:            quotedFrom.String,
			Text:            quotedText.String,
			Base64Image:     quotedImage.String,
			AudioDurationMs: uint64(quotedAudioDuration.Int64),
			Base64Audio:     quotedAudio.String,
		}
	}
	message.Alias = alias.String
	message.Identicon = identicon.String

	// Each content type gets its own case. The audio payload used to be set
	// from inside the sticker case, where the content type can never be AUDIO,
	// so audio messages were loaded without their payload.
	switch message.ContentType {
	case protobuf.ChatMessage_STICKER:
		message.Payload = &protobuf.ChatMessage_Sticker{Sticker: sticker}
	case protobuf.ChatMessage_AUDIO:
		message.Payload = &protobuf.ChatMessage_Audio{Audio: audio}
	case protobuf.ChatMessage_TRANSACTION_COMMAND:
		message.CommandParameters = command
	}
	return nil
}
// tableUserMessagesAllValues flattens message into the positional arguments
// for the INSERT built from tableUserMessagesAllFields; the two must list the
// same columns in the same order.
//
// Missing sticker, image, audio or command parts are replaced by empty values
// so that their columns are still populated. The returned error is always nil.
func (db sqlitePersistence) tableUserMessagesAllValues(message *Message) ([]interface{}, error) {
	sticker := &protobuf.StickerMessage{}
	if s := message.GetSticker(); s != nil {
		sticker = s
	}

	image := &protobuf.ImageMessage{}
	if i := message.GetImage(); i != nil {
		image = i
	}

	audio := &protobuf.AudioMessage{}
	if a := message.GetAudio(); a != nil {
		audio = a
	}

	command := &CommandParameters{}
	if message.CommandParameters != nil {
		command = message.CommandParameters
	}

	values := []interface{}{
		message.ID,
		message.WhisperTimestamp,
		message.From, // source in table
		message.Text,
		message.ContentType,
		message.Alias,
		message.Timestamp,
		message.ChatId,
		message.LocalChatID,
		message.MessageType,
		message.Clock,
		message.Seen,
		message.OutgoingStatus,
		message.ParsedText,
		sticker.Pack,
		sticker.Hash,
		image.Payload,
		image.Type,
		message.Base64Image,
		audio.Payload,
		audio.Type,
		audio.DurationMs,
		message.Base64Audio,
		command.ID,
		command.Value,
		command.From,
		command.Address,
		command.Contract,
		command.TransactionHash,
		command.CommandState,
		command.Signature,
		message.Replace,
		message.RTL,
		message.LineCount,
		message.ResponseTo,
	}
	return values, nil
}
// messageByID loads the message with the given id, together with its quoted
// message and the sender's contact alias/identicon.
//
// When tx is nil the lookup runs in its own transaction, which is committed on
// success and rolled back on any error. When tx is non-nil the caller keeps
// ownership of it and nothing is committed or rolled back here.
//
// It returns errRecordNotFound when no row matches. The results are named so
// that a failure of the deferred Commit reaches the caller instead of being
// silently discarded.
func (db sqlitePersistence) messageByID(tx *sql.Tx, id string) (msg *Message, err error) {
	if tx == nil {
		tx, err = db.db.BeginTx(context.Background(), &sql.TxOptions{})
		if err != nil {
			return nil, err
		}
		defer func() {
			if err == nil {
				if err = tx.Commit(); err != nil {
					// Do not hand back a message alongside an error.
					msg = nil
				}
				return
			}
			// don't shadow original error
			_ = tx.Rollback()
		}()
	}

	var message Message

	allFields := db.tableUserMessagesAllFieldsJoin()
	row := tx.QueryRow(
		fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE
m1.id = ?
`, allFields),
		id,
	)
	err = db.tableUserMessagesScanAllFields(row, &message)
	switch err {
	case sql.ErrNoRows:
		return nil, errRecordNotFound
	case nil:
		return &message, nil
	default:
		return nil, err
	}
}
// MessageByCommandID returns the message in the chat identified by chatID
// (matched against local_chat_id) whose command_id equals id. When several
// rows match, the one with the highest clock_value is returned.
//
// It returns errRecordNotFound when no row matches.
func (db sqlitePersistence) MessageByCommandID(chatID, id string) (*Message, error) {
	query := fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE
m1.command_id = ?
AND
m1.local_chat_id = ?
ORDER BY m1.clock_value DESC
LIMIT 1
`, db.tableUserMessagesAllFieldsJoin())

	found := new(Message)
	err := db.tableUserMessagesScanAllFields(db.db.QueryRow(query, id, chatID), found)
	if err == sql.ErrNoRows {
		return nil, errRecordNotFound
	}
	if err != nil {
		return nil, err
	}
	return found, nil
}
// MessageByID returns the message with the given id. It delegates to
// messageByID without a caller-supplied transaction.
func (db sqlitePersistence) MessageByID(id string) (*Message, error) {
	var noTx *sql.Tx
	return db.messageByID(noTx, id)
}
// MessagesExist reports which of the given message ids are present in
// user_messages. The returned map contains an entry (set to true) only for ids
// that exist; missing ids are simply absent. An empty ids slice yields an
// empty map without touching the database.
func (db sqlitePersistence) MessagesExist(ids []string) (map[string]bool, error) {
	result := make(map[string]bool)
	if len(ids) == 0 {
		return result, nil
	}

	idsArgs := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		idsArgs = append(idsArgs, id)
	}

	// One placeholder per id; the ids themselves are always bound as arguments.
	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
	query := "SELECT id FROM user_messages WHERE id IN (" + inVector + ")" // nolint: gosec
	rows, err := db.db.Query(query, idsArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		result[id] = true
	}
	// Next also returns false when iteration fails; without this check a
	// partial result would be reported as complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// MessagesByIDs loads the non-hidden messages whose id is in ids, together
// with their quoted message and sender contact data. Ids that do not exist or
// refer to hidden messages are skipped. The query has no ORDER BY, so the
// result order is whatever the database returns. An empty ids slice yields
// (nil, nil) without touching the database.
func (db sqlitePersistence) MessagesByIDs(ids []string) ([]*Message, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	idsArgs := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		idsArgs = append(idsArgs, id)
	}

	allFields := db.tableUserMessagesAllFieldsJoin()
	// One placeholder per id; the ids themselves are always bound as arguments.
	inVector := strings.Repeat("?, ", len(ids)-1) + "?"

	// nolint: gosec
	rows, err := db.db.Query(fmt.Sprintf(`
SELECT
%s
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE NOT(m1.hide) AND m1.id IN (%s)`, allFields, inVector), idsArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []*Message
	for rows.Next() {
		var message Message
		if err := db.tableUserMessagesScanAllFields(rows, &message); err != nil {
			return nil, err
		}
		result = append(result, &message)
	}
	// Next also returns false when iteration fails; without this check a
	// partial result would be reported as complete.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// MessageByChatID returns up to limit non-hidden messages of the given chat in
// descending order.
// Ordering is accomplished using two concatenated values: ClockValue and ID.
// These two values are also used to compose a cursor which is returned with the result.
//
// currCursor is either empty (start from the newest message) or a cursor
// returned by a previous call; only messages whose cursor is <= currCursor are
// considered. The returned cursor is empty when there are no further messages.
func (db sqlitePersistence) MessageByChatID(chatID string, currCursor string, limit int) ([]*Message, string, error) {
	cursorWhere := ""
	args := []interface{}{chatID}
	if currCursor != "" {
		cursorWhere = "AND cursor <= ?"
		args = append(args, currCursor)
	}
	allFields := db.tableUserMessagesAllFieldsJoin()

	// Build a new column `cursor` at the query time by having a fixed-sized clock value at the beginning
	// concatenated with message ID. Results are sorted using this new column.
	// This new column values can also be returned as a cursor for subsequent requests.
	rows, err := db.db.Query(
		fmt.Sprintf(`
SELECT
%s,
substr('0000000000000000000000000000000000000000000000000000000000000000' || m1.clock_value, -64, 64) || m1.id as cursor
FROM
user_messages m1
LEFT JOIN
user_messages m2
ON
m1.response_to = m2.id
LEFT JOIN
contacts c
ON
m1.source = c.id
WHERE
NOT(m1.hide) AND m1.local_chat_id = ? %s
ORDER BY cursor DESC
LIMIT ?
`, allFields, cursorWhere),
		append(args, limit+1)..., // take one more to figure out whether a cursor should be returned
	)
	if err != nil {
		return nil, "", err
	}
	defer rows.Close()

	var (
		result  []*Message
		cursors []string
	)
	for rows.Next() {
		var (
			message Message
			cursor  string
		)
		if err := db.tableUserMessagesScanAllFields(rows, &message, &cursor); err != nil {
			return nil, "", err
		}
		result = append(result, &message)
		cursors = append(cursors, cursor)
	}
	// Next also returns false when iteration fails; without this check a
	// truncated page would be reported as complete.
	if err := rows.Err(); err != nil {
		return nil, "", err
	}

	// The extra row, if present, is not returned; its cursor marks where the
	// next page starts.
	var newCursor string
	if len(result) > limit {
		newCursor = cursors[limit]
		result = result[:limit]
	}
	return result, newCursor, nil
}
// SaveMessages inserts all given messages into user_messages inside a single
// transaction. The transaction is committed when every insert succeeds and
// rolled back on the first error, which is returned. The result is named so
// the deferred function can both inspect it and report a Commit failure.
func (db sqlitePersistence) SaveMessages(messages []*Message) (err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// don't shadow original error
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	// One placeholder per column listed by tableUserMessagesAllFields.
	placeholders := strings.Repeat("?, ", db.tableUserMessagesAllFieldsCount()-1) + "?"
	insert := "INSERT INTO user_messages(" + db.tableUserMessagesAllFields() + ") VALUES (" + placeholders + ")" // nolint: gosec

	stmt, err := tx.Prepare(insert)
	if err != nil {
		return err
	}

	for i := range messages {
		var values []interface{}
		if values, err = db.tableUserMessagesAllValues(messages[i]); err != nil {
			return err
		}
		if _, err = stmt.Exec(values...); err != nil {
			return err
		}
	}
	return nil
}
// DeleteMessage removes the row with the given id from user_messages.
func (db sqlitePersistence) DeleteMessage(id string) error {
	if _, err := db.db.Exec(`DELETE FROM user_messages WHERE id = ?`, id); err != nil {
		return err
	}
	return nil
}
// HideMessage sets the hide flag on the message with the given id. The row is
// kept, but MessagesByIDs and MessageByChatID filter hidden messages out.
func (db sqlitePersistence) HideMessage(id string) error {
	if _, err := db.db.Exec(`UPDATE user_messages SET hide = 1 WHERE id = ?`, id); err != nil {
		return err
	}
	return nil
}
// DeleteMessagesByChatID removes every message whose local_chat_id equals id.
func (db sqlitePersistence) DeleteMessagesByChatID(id string) error {
	if _, err := db.db.Exec(`DELETE FROM user_messages WHERE local_chat_id = ?`, id); err != nil {
		return err
	}
	return nil
}
// MarkAllRead marks every message of the given chat as seen and resets the
// chat's denormalized unviewed_message_count to zero. Both updates run in one
// transaction, which is rolled back if either fails.
//
// The result is named so that a failure of the deferred Commit reaches the
// caller instead of being silently discarded.
func (db sqlitePersistence) MarkAllRead(chatID string) (err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return err
	}
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	_, err = tx.Exec(`UPDATE user_messages SET seen = 1 WHERE local_chat_id = ?`, chatID)
	if err != nil {
		return err
	}

	_, err = tx.Exec(`UPDATE chats SET unviewed_message_count = 0 WHERE id = ?`, chatID)
	return err
}
// MarkMessagesSeen marks the messages with the given ids as seen and then
// recomputes the denormalized unviewed_message_count of the chat identified by
// chatID, all inside one transaction.
//
// It returns the number of messages that were switched from unseen to seen.
// An empty ids slice is a no-op that returns (0, nil); building the IN clause
// for it would otherwise panic in strings.Repeat.
//
// The results are named so that the deferred function sees every error
// (including the one from reading the change count) and so that a Commit
// failure reaches the caller. On any error the returned count is 0.
func (db sqlitePersistence) MarkMessagesSeen(chatID string, ids []string) (count uint64, err error) {
	if len(ids) == 0 {
		return 0, nil
	}

	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return 0, err
	}
	defer func() {
		if err == nil {
			if err = tx.Commit(); err != nil {
				count = 0
			}
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	idsArgs := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		idsArgs = append(idsArgs, id)
	}

	// One placeholder per id; the ids themselves are always bound as arguments.
	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
	// NOT(seen) restricts the update to rows that actually change, so that
	// changes() below does not count messages that were already seen.
	q := "UPDATE user_messages SET seen = 1 WHERE NOT(seen) AND id IN (" + inVector + ")" // nolint: gosec
	if _, err = tx.Exec(q, idsArgs...); err != nil {
		return 0, err
	}

	// Assign to the named result (no :=) so a failure here triggers a rollback.
	if err = tx.QueryRow("SELECT changes();").Scan(&count); err != nil {
		return 0, err
	}

	// Update denormalized count
	_, err = tx.Exec(
		`UPDATE chats
SET unviewed_message_count =
(SELECT COUNT(1)
FROM user_messages
WHERE local_chat_id = ? AND seen = 0)
WHERE id = ?`, chatID, chatID)
	if err != nil {
		return 0, err
	}
	return count, nil
}
// UpdateMessageOutgoingStatus sets the outgoing_status column of the message
// with the given id to newOutgoingStatus.
func (db sqlitePersistence) UpdateMessageOutgoingStatus(id string, newOutgoingStatus string) error {
	const update = `
UPDATE user_messages
SET outgoing_status = ?
WHERE id = ?
`
	if _, err := db.db.Exec(update, newOutgoingStatus, id); err != nil {
		return err
	}
	return nil
}
// BlockContact saves the given contact, deletes all messages sent by it and
// its one-to-one chat, recalculates the unviewed message count and last
// message of every remaining chat, and returns those updated chats.
//
// Everything runs in a single transaction. The results are named so that the
// deferred function sees every error raised below (including those from the
// per-chat loop) and rolls back instead of committing a half-applied change,
// and so that a Commit failure reaches the caller. On error the returned
// slice is nil.
func (db sqlitePersistence) BlockContact(contact *Contact) (chats []*Chat, err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return nil, err
	}
	defer func() {
		if err == nil {
			if err = tx.Commit(); err != nil {
				chats = nil
			}
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	// Delete messages
	_, err = tx.Exec(
		`DELETE
FROM user_messages
WHERE source = ?`,
		contact.ID,
	)
	if err != nil {
		return nil, err
	}

	// Update contact
	err = db.SaveContact(contact, tx)
	if err != nil {
		return nil, err
	}

	// Delete one-to-one chat
	_, err = tx.Exec("DELETE FROM chats WHERE id = ?", contact.ID)
	if err != nil {
		return nil, err
	}

	// Recalculate denormalized fields
	_, err = tx.Exec(`
UPDATE chats
SET
unviewed_message_count = (SELECT COUNT(1) FROM user_messages WHERE seen = 0 AND local_chat_id = chats.id)`)
	if err != nil {
		return nil, err
	}

	// return the updated chats
	chats, err = db.chats(tx)
	if err != nil {
		return nil, err
	}

	for _, c := range chats {
		var lastMessageID string
		row := tx.QueryRow(`SELECT id FROM user_messages WHERE local_chat_id = ? ORDER BY clock_value DESC LIMIT 1`, c.ID)
		err = row.Scan(&lastMessageID)
		switch err {
		case nil:
			var message *Message
			message, err = db.messageByID(tx, lastMessageID)
			if err != nil {
				return nil, err
			}
			if message != nil {
				var encodedMessage []byte
				encodedMessage, err = json.Marshal(message)
				if err != nil {
					return nil, err
				}
				_, err = tx.Exec(`UPDATE chats SET last_message = ? WHERE id = ?`, encodedMessage, c.ID)
				if err != nil {
					return nil, err
				}
				c.LastMessage = encodedMessage
			}
		case sql.ErrNoRows:
			// The chat has no messages left: reset LastMessage. The Exec
			// result also replaces the ErrNoRows held in err.
			_, err = tx.Exec(`UPDATE chats SET last_message = NULL WHERE id = ?`, c.ID)
			if err != nil {
				return nil, err
			}
			c.LastMessage = nil
		default:
			return nil, err
		}
	}

	return chats, nil
}