393 lines
9.1 KiB
Go
393 lines
9.1 KiB
Go
package protocol
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
var (
|
|
errRecordNotFound = errors.New("record not found")
|
|
)
|
|
|
|
func (db sqlitePersistence) tableUserMessagesLegacyAllFields() string {
|
|
return `id,
|
|
whisper_timestamp,
|
|
source,
|
|
destination,
|
|
content,
|
|
content_type,
|
|
username,
|
|
timestamp,
|
|
chat_id,
|
|
retry_count,
|
|
message_type,
|
|
message_status,
|
|
clock_value,
|
|
show,
|
|
seen,
|
|
outgoing_status,
|
|
reply_to`
|
|
}
|
|
|
|
func (db sqlitePersistence) tableUserMessagesLegacyAllFieldsJoin() string {
|
|
return `m1.id,
|
|
m1.whisper_timestamp,
|
|
m1.source,
|
|
m1.destination,
|
|
m1.content,
|
|
m1.content_type,
|
|
m1.username,
|
|
m1.timestamp,
|
|
m1.chat_id,
|
|
m1.retry_count,
|
|
m1.message_type,
|
|
m1.message_status,
|
|
m1.clock_value,
|
|
m1.show,
|
|
m1.seen,
|
|
m1.outgoing_status,
|
|
m1.reply_to,
|
|
m2.source,
|
|
m2.content,
|
|
c.alias,
|
|
c.identicon`
|
|
}
|
|
|
|
func (db sqlitePersistence) tableUserMessagesLegacyAllFieldsCount() int {
|
|
return strings.Count(db.tableUserMessagesLegacyAllFields(), ",") + 1
|
|
}
|
|
|
|
type scanner interface {
|
|
Scan(dest ...interface{}) error
|
|
}
|
|
|
|
func (db sqlitePersistence) tableUserMessagesLegacyScanAllFields(row scanner, message *Message, others ...interface{}) error {
|
|
var quotedContent sql.NullString
|
|
var quotedFrom sql.NullString
|
|
var alias sql.NullString
|
|
var identicon sql.NullString
|
|
args := []interface{}{
|
|
&message.ID,
|
|
&message.WhisperTimestamp,
|
|
&message.From, // source in table
|
|
&message.To, // destination in table
|
|
&message.Content,
|
|
&message.ContentType,
|
|
&message.Alias,
|
|
&message.Timestamp,
|
|
&message.ChatID,
|
|
&message.RetryCount,
|
|
&message.MessageType,
|
|
&message.MessageStatus,
|
|
&message.ClockValue,
|
|
&message.Show,
|
|
&message.Seen,
|
|
&message.OutgoingStatus,
|
|
&message.ReplyTo,
|
|
"edFrom,
|
|
"edContent,
|
|
&alias,
|
|
&identicon,
|
|
}
|
|
err := row.Scan(append(args, others...)...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if quotedContent.Valid {
|
|
message.QuotedMessage = &QuotedMessage{
|
|
From: quotedFrom.String,
|
|
Content: quotedContent.String,
|
|
}
|
|
}
|
|
message.Alias = alias.String
|
|
message.Identicon = identicon.String
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db sqlitePersistence) tableUserMessagesLegacyAllValues(message *Message) []interface{} {
|
|
return []interface{}{
|
|
message.ID,
|
|
message.WhisperTimestamp,
|
|
message.From, // source in table
|
|
message.To, // destination in table
|
|
message.Content,
|
|
message.ContentType,
|
|
message.Alias,
|
|
message.Timestamp,
|
|
message.ChatID,
|
|
message.RetryCount,
|
|
message.MessageType,
|
|
message.MessageStatus,
|
|
message.ClockValue,
|
|
message.Show,
|
|
message.Seen,
|
|
message.OutgoingStatus,
|
|
message.ReplyTo,
|
|
}
|
|
}
|
|
|
|
func (db sqlitePersistence) MessageByID(id string) (*Message, error) {
|
|
var message Message
|
|
|
|
allFields := db.tableUserMessagesLegacyAllFieldsJoin()
|
|
row := db.db.QueryRow(
|
|
fmt.Sprintf(`
|
|
SELECT
|
|
%s
|
|
FROM
|
|
user_messages_legacy m1
|
|
LEFT JOIN
|
|
user_messages_legacy m2
|
|
ON
|
|
m1.reply_to = m2.id
|
|
|
|
LEFT JOIN
|
|
contacts c
|
|
ON
|
|
m1.source = c.id
|
|
WHERE
|
|
m1.id = ?
|
|
`, allFields),
|
|
id,
|
|
)
|
|
err := db.tableUserMessagesLegacyScanAllFields(row, &message)
|
|
switch err {
|
|
case sql.ErrNoRows:
|
|
return nil, errRecordNotFound
|
|
case nil:
|
|
return &message, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
func (db sqlitePersistence) MessagesExist(ids []string) (map[string]bool, error) {
|
|
result := make(map[string]bool)
|
|
if len(ids) == 0 {
|
|
return result, nil
|
|
}
|
|
|
|
idsArgs := make([]interface{}, 0, len(ids))
|
|
for _, id := range ids {
|
|
idsArgs = append(idsArgs, id)
|
|
}
|
|
|
|
inVector := strings.Repeat("?, ", len(ids)-1) + "?"
|
|
query := fmt.Sprintf(`SELECT id FROM user_messages_legacy WHERE id IN (%s)`, inVector)
|
|
rows, err := db.db.Query(query, idsArgs...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var id string
|
|
err := rows.Scan(&id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result[id] = true
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// MessageByChatID returns all messages for a given chatID in descending order.
|
|
// Ordering is accomplished using two concatenated values: ClockValue and ID.
|
|
// These two values are also used to compose a cursor which is returned to the result.
|
|
func (db sqlitePersistence) MessageByChatID(chatID string, currCursor string, limit int) ([]*Message, string, error) {
|
|
cursorWhere := ""
|
|
if currCursor != "" {
|
|
cursorWhere = "AND cursor <= ?"
|
|
}
|
|
allFields := db.tableUserMessagesLegacyAllFieldsJoin()
|
|
args := []interface{}{chatID}
|
|
if currCursor != "" {
|
|
args = append(args, currCursor)
|
|
}
|
|
// Build a new column `cursor` at the query time by having a fixed-sized clock value at the beginning
|
|
// concatenated with message ID. Results are sorted using this new column.
|
|
// This new column values can also be returned as a cursor for subsequent requests.
|
|
rows, err := db.db.Query(
|
|
fmt.Sprintf(`
|
|
SELECT
|
|
%s,
|
|
substr('0000000000000000000000000000000000000000000000000000000000000000' || m1.clock_value, -64, 64) || m1.id as cursor
|
|
FROM
|
|
user_messages_legacy m1
|
|
LEFT JOIN
|
|
user_messages_legacy m2
|
|
ON
|
|
m1.reply_to = m2.id
|
|
|
|
LEFT JOIN
|
|
contacts c
|
|
ON
|
|
|
|
m1.source = c.id
|
|
WHERE
|
|
m1.chat_id = ? %s
|
|
ORDER BY cursor DESC
|
|
LIMIT ?
|
|
`, allFields, cursorWhere),
|
|
append(args, limit+1)..., // take one more to figure our whether a cursor should be returned
|
|
)
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var (
|
|
result []*Message
|
|
cursors []string
|
|
)
|
|
for rows.Next() {
|
|
var (
|
|
message Message
|
|
cursor string
|
|
)
|
|
if err := db.tableUserMessagesLegacyScanAllFields(rows, &message, &cursor); err != nil {
|
|
return nil, "", err
|
|
}
|
|
result = append(result, &message)
|
|
cursors = append(cursors, cursor)
|
|
}
|
|
|
|
var newCursor string
|
|
if len(result) > limit {
|
|
newCursor = cursors[limit]
|
|
result = result[:limit]
|
|
}
|
|
return result, newCursor, nil
|
|
}
|
|
|
|
func (db sqlitePersistence) SaveMessagesLegacy(messages []*Message) (err error) {
|
|
tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
// don't shadow original error
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
allFields := db.tableUserMessagesLegacyAllFields()
|
|
valuesVector := strings.Repeat("?, ", db.tableUserMessagesLegacyAllFieldsCount()-1) + "?"
|
|
query := fmt.Sprintf(`INSERT INTO user_messages_legacy(%s) VALUES (%s)`, allFields, valuesVector)
|
|
stmt, err := tx.Prepare(query)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
for _, msg := range messages {
|
|
_, err = stmt.Exec(db.tableUserMessagesLegacyAllValues(msg)...)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (db sqlitePersistence) DeleteMessage(id string) error {
|
|
_, err := db.db.Exec(`DELETE FROM user_messages_legacy WHERE id = ?`, id)
|
|
return err
|
|
}
|
|
|
|
func (db sqlitePersistence) DeleteMessagesByChatID(id string) error {
|
|
_, err := db.db.Exec(`DELETE FROM user_messages_legacy WHERE chat_id = ?`, id)
|
|
return err
|
|
}
|
|
|
|
func (db sqlitePersistence) MarkMessagesSeen(ids ...string) error {
|
|
idsArgs := make([]interface{}, 0, len(ids))
|
|
for _, id := range ids {
|
|
idsArgs = append(idsArgs, id)
|
|
}
|
|
|
|
inVector := strings.Repeat("?, ", len(ids)-1) + "?"
|
|
_, err := db.db.Exec(
|
|
fmt.Sprintf(`
|
|
UPDATE user_messages_legacy
|
|
SET seen = 1
|
|
WHERE id IN (%s)
|
|
`, inVector),
|
|
idsArgs...)
|
|
return err
|
|
}
|
|
|
|
func (db sqlitePersistence) UpdateMessageOutgoingStatus(id string, newOutgoingStatus string) error {
|
|
_, err := db.db.Exec(`
|
|
UPDATE user_messages_legacy
|
|
SET outgoing_status = ?
|
|
WHERE id = ?
|
|
`, newOutgoingStatus, id)
|
|
return err
|
|
}
|
|
|
|
// BlockContact updates a contact, deletes all the messages and 1-to-1 chat, updates the unread messages count and returns a map with the new count
|
|
func (db sqlitePersistence) BlockContact(contact Contact) (chats []*Chat, err error) {
|
|
tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
// don't shadow original error
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
// Delete messages
|
|
_, err = tx.Exec(
|
|
`DELETE
|
|
FROM user_messages_legacy
|
|
WHERE source = ?`,
|
|
contact.ID,
|
|
)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Update contact
|
|
err = db.SaveContact(contact, tx)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Delete one-to-one chat
|
|
_, err = tx.Exec("DELETE FROM chats WHERE id = ?", contact.ID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Recalculate denormalized fields
|
|
_, err = tx.Exec(`
|
|
UPDATE chats
|
|
SET
|
|
unviewed_message_count = (SELECT COUNT(1) FROM user_messages_legacy WHERE seen = 0 AND chat_id = chats.id),
|
|
last_message_content = (SELECT content from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1),
|
|
last_message_timestamp = (SELECT timestamp from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1),
|
|
last_message_clock_value = (SELECT clock_value from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1),
|
|
last_message_content_type = (SELECT content_type from user_messages_legacy WHERE chat_id = chats.id ORDER BY clock_value DESC LIMIT 1)
|
|
`)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// return the updated chats
|
|
chats, err = db.chats(tx)
|
|
return
|
|
}
|