// status-go/protocol/common/raw_messages_persistence.go
// (listing metadata: 500 lines, 13 KiB, Go)

package common
import (
"bytes"
"context"
"crypto/ecdsa"
"database/sql"
"encoding/gob"
"errors"
"strings"
"time"
"github.com/status-im/status-go/eth-node/crypto"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/protocol/protobuf"
)
// RawMessageConfirmation describes one pending delivery confirmation: a
// datasync message that carried an application message to a single receiver.
// InsertPendingConfirmation stores DataSyncID, MessageID and PublicKey in the
// raw_message_confirmations table; MarkAsConfirmed later sets confirmed_at.
type RawMessageConfirmation struct {
	// DataSyncID is the ID of the datasync message sent
	DataSyncID []byte
	// MessageID is the message id of the message
	MessageID []byte
	// PublicKey is the compressed receiver public key
	PublicKey []byte
	// ConfirmedAt is the unix timestamp in seconds of when the message was
	// confirmed. NOTE(review): InsertPendingConfirmation does not write this
	// field, so its value is ignored on insert.
	ConfirmedAt int64
}
// RawMessagesPersistence stores raw messages, their delivery confirmations,
// hash-ratchet encrypted messages and message segments in the SQL database it
// wraps.
type RawMessagesPersistence struct {
	db *sql.DB
}

// NewRawMessagesPersistence returns a RawMessagesPersistence backed by db.
// The handle is stored as-is; no connection or schema check happens here.
func NewRawMessagesPersistence(db *sql.DB) *RawMessagesPersistence {
	p := new(RawMessagesPersistence)
	p.db = db
	return p
}
// SaveRawMessage inserts message into the raw_messages table.
//
// Recipients are stored as a gob-encoded list of compressed public keys, and
// the sender private key (if any) as its raw bytes. When message.Sent is false
// and a row with the same ID already exists, the stored Sent flag is copied
// onto message before inserting. This mutates the caller's message, so
// re-saving a message does not reset one that was already sent.
//
// The lookup and the insert run in a single transaction. It is committed only
// when every step succeeds, and rolled back otherwise. A commit failure is
// returned to the caller.
func (db RawMessagesPersistence) SaveRawMessage(message *RawMessage) (err error) {
	// Do the purely in-memory encoding before opening the transaction, so a
	// failure here never reaches the commit-on-success defer below.
	var pubKeys [][]byte
	for _, pk := range message.Recipients {
		pubKeys = append(pubKeys, crypto.CompressPubkey(pk))
	}
	var encodedRecipients bytes.Buffer
	if err = gob.NewEncoder(&encodedRecipients).Encode(pubKeys); err != nil {
		return err
	}

	var sender []byte
	if message.Sender != nil {
		sender = crypto.FromECDSA(message.Sender)
	}

	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return err
	}
	// err is the named result, so every return path is seen here and a
	// commit failure is propagated to the caller.
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	// If the message is not sent, we check whether there's a record
	// in the database already and preserve the state
	if !message.Sent {
		var oldMessage *RawMessage
		oldMessage, err = db.rawMessageByID(tx, message.ID)
		switch {
		case err == nil:
			if oldMessage != nil {
				message.Sent = oldMessage.Sent
			}
		case errors.Is(err, sql.ErrNoRows):
			// Nothing stored yet, so there is no state to preserve.
			err = nil
		default:
			return err
		}
	}

	_, err = tx.Exec(`
INSERT INTO
raw_messages
(
id,
local_chat_id,
last_sent,
send_count,
sent,
message_type,
recipients,
skip_encryption,
send_push_notification,
skip_group_message_wrap,
send_on_personal_topic,
payload,
sender,
community_id,
resend_type,
pubsub_topic,
hash_ratchet_group_id,
community_key_ex_msg_type,
resend_method
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		message.ID,
		message.LocalChatID,
		message.LastSent,
		message.SendCount,
		message.Sent,
		message.MessageType,
		encodedRecipients.Bytes(),
		message.SkipEncryptionLayer,
		message.SendPushNotification,
		message.SkipGroupMessageWrap,
		message.SendOnPersonalTopic,
		message.Payload,
		sender,
		message.CommunityID,
		message.ResendType,
		message.PubsubTopic,
		message.HashRatchetGroupID,
		message.CommunityKeyExMsgType,
		message.ResendMethod,
	)
	return err
}
// RawMessageByID loads the raw_messages row with the given id.
//
// A missing row is reported as sql.ErrNoRows (the Scan error from
// rawMessageByID is passed through unchanged). The read runs in its own
// transaction, which is rolled back on any error. A commit failure is
// returned to the caller together with a nil message.
func (db RawMessagesPersistence) RawMessageByID(id string) (message *RawMessage, err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return nil, err
	}
	// Named results let the defer see the lookup's error and report a
	// commit failure; with unnamed results both were invisible here.
	defer func() {
		if err == nil {
			err = tx.Commit()
			if err != nil {
				message = nil
			}
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()
	return db.rawMessageByID(tx, id)
}
// rawMessageByID reads the raw_messages row with the given id through tx.
//
// The Scan error is returned unchanged, so a missing row yields
// sql.ErrNoRows. The recipients column is decoded from the gob-encoded list
// of compressed public keys written by SaveRawMessage, and the stored sender
// bytes (if any) are turned back into a private key. NULL values in
// skip_group_message_wrap / send_on_personal_topic read as false.
func (db RawMessagesPersistence) rawMessageByID(tx *sql.Tx, id string) (*RawMessage, error) {
	msg := &RawMessage{}
	var (
		recipientsBlob []byte
		senderKey      []byte
		skipWrap       sql.NullBool
		personalTopic  sql.NullBool
	)

	row := tx.QueryRow(`
SELECT
id,
local_chat_id,
last_sent,
send_count,
sent,
message_type,
recipients,
skip_encryption,
send_push_notification,
skip_group_message_wrap,
send_on_personal_topic,
payload,
sender,
community_id,
resend_type,
pubsub_topic,
hash_ratchet_group_id,
community_key_ex_msg_type,
resend_method
FROM
raw_messages
WHERE
id = ?`,
		id,
	)
	if err := row.Scan(
		&msg.ID,
		&msg.LocalChatID,
		&msg.LastSent,
		&msg.SendCount,
		&msg.Sent,
		&msg.MessageType,
		&recipientsBlob,
		&msg.SkipEncryptionLayer,
		&msg.SendPushNotification,
		&skipWrap,
		&personalTopic,
		&msg.Payload,
		&senderKey,
		&msg.CommunityID,
		&msg.ResendType,
		&msg.PubsubTopic,
		&msg.HashRatchetGroupID,
		&msg.CommunityKeyExMsgType,
		&msg.ResendMethod,
	); err != nil {
		return nil, err
	}

	// A NULL recipients column scans to nil: leave Recipients unset.
	if recipientsBlob != nil {
		var compressed [][]byte
		if err := gob.NewDecoder(bytes.NewBuffer(recipientsBlob)).Decode(&compressed); err != nil {
			return nil, err
		}
		for _, raw := range compressed {
			key, err := crypto.DecompressPubkey(raw)
			if err != nil {
				return nil, err
			}
			msg.Recipients = append(msg.Recipients, key)
		}
	}

	msg.SkipGroupMessageWrap = skipWrap.Valid && skipWrap.Bool
	msg.SendOnPersonalTopic = personalTopic.Valid && personalTopic.Bool

	if senderKey != nil {
		privateKey, err := crypto.ToECDSA(senderKey)
		if err != nil {
			return nil, err
		}
		msg.Sender = privateKey
	}
	return msg, nil
}
// RawMessagesIDsByType returns the ids of all raw_messages rows whose
// message_type equals t. When no rows match, the result is an empty, non-nil
// slice. On error, the ids collected so far are returned alongside the error.
func (db RawMessagesPersistence) RawMessagesIDsByType(t protobuf.ApplicationMetadataMessage_Type) ([]string, error) {
	ids := []string{}
	rows, err := db.db.Query(`
SELECT
id
FROM
raw_messages
WHERE
message_type = ?`,
		t)
	if err != nil {
		return ids, err
	}
	defer rows.Close()
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return ids, err
		}
		ids = append(ids, id)
	}
	// rows.Next returns false both on exhaustion and on failure; only
	// rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return ids, err
	}
	return ids, nil
}
// MarkAsConfirmed marks all the messages with dataSyncID as confirmed and returns
// the messageID that can be considered confirmed.
//
// It sets confirmed_at to the current unix time in seconds on every
// raw_message_confirmations row with this datasync_id that is still
// unconfirmed (confirmed_at = 0). It then inspects all confirmation rows
// sharing the message_id that dataSyncID belongs to:
//   - if atLeastOne is set, that message_id is returned as soon as at least
//     one of its rows is confirmed;
//   - otherwise it is returned only when every one of its rows is confirmed.
//
// nil is returned when the condition is not met, or when dataSyncID has no
// confirmation rows. The update and the read share one transaction, which is
// rolled back on error. Whenever err is non-nil, messageID is nil.
func (db RawMessagesPersistence) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID types.HexBytes, err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return nil, err
	}
	defer func() {
		if err == nil {
			err = tx.Commit()
		} else {
			// don't shadow original error
			_ = tx.Rollback()
		}
		if err != nil {
			// Never hand back a half-scanned or uncommitted result.
			messageID = nil
		}
	}()

	now := time.Now().Unix()
	_, err = tx.Exec(`UPDATE raw_message_confirmations SET confirmed_at = ? WHERE datasync_id = ? AND confirmed_at = 0`, now, dataSyncID)
	if err != nil {
		return nil, err
	}

	// Fetch every confirmation row (confirmed or not) that shares a message_id
	// with dataSyncID.
	rows, err := tx.Query(`SELECT message_id,confirmed_at FROM raw_message_confirmations WHERE message_id = (SELECT message_id FROM raw_message_confirmations WHERE datasync_id = ? LIMIT 1)`, dataSyncID)
	if err != nil {
		return nil, err
	}
	// Registered after the commit defer, so it runs first (LIFO): the rows
	// are closed before the transaction is committed or rolled back.
	defer rows.Close()

	allConfirmed := true
	for rows.Next() {
		var rowConfirmedAt int64
		if err = rows.Scan(&messageID, &rowConfirmedAt); err != nil {
			return nil, err
		}
		confirmed := rowConfirmedAt > 0
		if atLeastOne && confirmed {
			// We return, as at least one was confirmed
			return messageID, nil
		}
		allConfirmed = allConfirmed && confirmed
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	if !allConfirmed {
		return nil, nil
	}
	return messageID, nil
}
// InsertPendingConfirmation records that the datasync message
// confirmation.DataSyncID, carrying confirmation.MessageID, was sent to the
// receiver confirmation.PublicKey. confirmation.ConfirmedAt is not written.
// NOTE(review): confirmed_at presumably defaults to 0 in the schema, which is
// the "unconfirmed" value MarkAsConfirmed looks for; confirm against the
// migration.
func (db RawMessagesPersistence) InsertPendingConfirmation(confirmation *RawMessageConfirmation) error {
	const query = `INSERT INTO raw_message_confirmations
(datasync_id, message_id, public_key)
VALUES
(?,?,?)`
	args := []interface{}{
		confirmation.DataSyncID,
		confirmation.MessageID,
		confirmation.PublicKey,
	}
	if _, err := db.db.Exec(query, args...); err != nil {
		return err
	}
	return nil
}
// SaveHashRatchetMessage stores m in hash_ratchet_encrypted_messages, tagged
// with the hash-ratchet groupID and keyID. The topic is stored in its byte
// array form.
func (db RawMessagesPersistence) SaveHashRatchetMessage(groupID []byte, keyID []byte, m *types.Message) error {
	topic := types.TopicTypeToByteArray(m.Topic)
	_, err := db.db.Exec(
		`INSERT INTO hash_ratchet_encrypted_messages(hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding, group_id, key_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		m.Hash, m.Sig, m.TTL, m.Timestamp, topic, m.Payload, m.Dst, m.P2P, m.Padding,
		groupID, keyID,
	)
	return err
}
// GetHashRatchetMessages returns every message stored in
// hash_ratchet_encrypted_messages under keyID. The result is nil when there
// are none. The stored topic bytes are converted back into a topic.
func (db RawMessagesPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.Message, error) {
	rows, err := db.db.Query(`SELECT hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding FROM hash_ratchet_encrypted_messages WHERE key_id = ?`, keyID)
	if err != nil {
		return nil, err
	}
	// Without Close, an early return (e.g. on a Scan error) would keep the
	// underlying connection checked out of the pool.
	defer rows.Close()

	var messages []*types.Message
	for rows.Next() {
		var topic []byte
		message := &types.Message{}
		if err := rows.Scan(&message.Hash, &message.Sig, &message.TTL, &message.Timestamp, &topic, &message.Payload, &message.Dst, &message.P2P, &message.Padding); err != nil {
			return nil, err
		}
		message.Topic = types.BytesToTopic(topic)
		messages = append(messages, message)
	}
	// Distinguish "no more rows" from an iteration failure.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}
// GetHashRatchetMessagesCountForGroup returns how many messages are stored in
// hash_ratchet_encrypted_messages for groupID. sql.ErrNoRows is treated as a
// count of zero. Any other error is returned with a zero count.
// (COUNT(*) without GROUP BY should always yield one row, so the ErrNoRows
// case is defensive.)
func (db RawMessagesPersistence) GetHashRatchetMessagesCountForGroup(groupID []byte) (int, error) {
	var n int
	row := db.db.QueryRow(`SELECT count(*) FROM hash_ratchet_encrypted_messages WHERE group_id = ?`, groupID)
	switch err := row.Scan(&n); {
	case err == nil:
		return n, nil
	case errors.Is(err, sql.ErrNoRows):
		return 0, nil
	default:
		return 0, err
	}
}
// DeleteHashRatchetMessages deletes, in a single statement, every
// hash_ratchet_encrypted_messages row whose hash is one of ids. An empty ids is
// a no-op.
// NOTE(review): a very large ids slice may exceed the driver's limit on bound
// parameters; confirm expected sizes.
func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error {
	n := len(ids)
	if n == 0 {
		return nil
	}
	args := make([]interface{}, n)
	for i := range ids {
		args[i] = ids[i]
	}
	// Only "?" placeholders are spliced into the SQL text; the ids themselves
	// are passed as bound arguments.
	placeholders := strings.Repeat("?, ", n-1) + "?"
	query := "DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN (" + placeholders + ")"
	_, err := db.db.Exec(query, args...) // nolint: gosec
	return err
}
// DeleteHashRatchetMessagesOlderThan removes hash_ratchet_encrypted_messages
// rows whose timestamp is strictly less than timestamp.
func (db *RawMessagesPersistence) DeleteHashRatchetMessagesOlderThan(timestamp int64) error {
	const query = "DELETE FROM hash_ratchet_encrypted_messages WHERE timestamp < ?"
	if _, err := db.db.Exec(query, timestamp); err != nil {
		return err
	}
	return nil
}
// IsMessageAlreadyCompleted reports whether message_segments_completed holds at
// least one row for hash, regardless of which signer's key it was recorded
// with (sig_pub_key is not filtered on).
func (db *RawMessagesPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) {
	var count int
	row := db.db.QueryRow("SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash)
	if err := row.Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}
// SaveMessageSegment stores one segment in message_segments. It is keyed by
// the hash of the entire message and the signer's compressed public key, and
// saved together with the given timestamp.
func (db *RawMessagesPersistence) SaveMessageSegment(segment *SegmentMessage, sigPubKey *ecdsa.PublicKey, timestamp int64) error {
	signer := crypto.CompressPubkey(sigPubKey)
	_, err := db.db.Exec(
		"INSERT INTO message_segments (hash, segment_index, segments_count, parity_segment_index, parity_segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
		segment.EntireMessageHash,
		segment.Index,
		segment.SegmentsCount,
		segment.ParitySegmentIndex,
		segment.ParitySegmentsCount,
		signer,
		segment.Payload,
		timestamp,
	)
	return err
}
// GetMessageSegments returns the stored segments of the message identified by
// hash that were signed by sigPubKey. The result is nil if there are none.
//
// Rows with a non-zero segments_count come first. Within that, rows are
// ordered by segment_index and then parity_segment_index, both ascending.
func (db *RawMessagesPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*SegmentMessage, error) {
	signer := crypto.CompressPubkey(sigPubKey)
	rows, err := db.db.Query(`
SELECT
hash, segment_index, segments_count, parity_segment_index, parity_segments_count, payload
FROM
message_segments
WHERE
hash = ? AND sig_pub_key = ?
ORDER BY
(segments_count = 0) ASC, -- Prioritize segments_count > 0
segment_index ASC,
parity_segment_index ASC`,
		hash, signer)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []*SegmentMessage
	for rows.Next() {
		s := &SegmentMessage{SegmentMessage: &protobuf.SegmentMessage{}}
		if err := rows.Scan(&s.EntireMessageHash, &s.Index, &s.SegmentsCount, &s.ParitySegmentIndex, &s.ParitySegmentsCount, &s.Payload); err != nil {
			return nil, err
		}
		out = append(out, s)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// RemoveMessageSegmentsOlderThan deletes message_segments rows whose timestamp
// is strictly less than timestamp.
func (db *RawMessagesPersistence) RemoveMessageSegmentsOlderThan(timestamp int64) error {
	if _, err := db.db.Exec("DELETE FROM message_segments WHERE timestamp < ?", timestamp); err != nil {
		return err
	}
	return nil
}
// CompleteMessageSegments records that the message identified by hash and
// signed by sigPubKey is complete. It deletes that message's rows from
// message_segments and adds a (hash, signer, timestamp) row to
// message_segments_completed.
//
// Both statements run in one transaction, so either both take effect or
// neither does. A commit failure is returned to the caller.
func (db *RawMessagesPersistence) CompleteMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey, timestamp int64) (err error) {
	// Compute the key blob before opening the transaction, so nothing can fail
	// between BeginTx and the deferred commit/rollback handler.
	sigPubKeyBlob := crypto.CompressPubkey(sigPubKey)

	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return err
	}
	// err is the named result, so a commit failure reaches the caller instead
	// of being assigned to a local that is then discarded.
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	if _, err = tx.Exec("DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?", hash, sigPubKeyBlob); err != nil {
		return err
	}
	_, err = tx.Exec("INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?,?,?)", hash, sigPubKeyBlob, timestamp)
	return err
}
// RemoveMessageSegmentsCompletedOlderThan deletes message_segments_completed
// rows whose timestamp is strictly less than timestamp.
func (db *RawMessagesPersistence) RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error {
	const query = "DELETE FROM message_segments_completed WHERE timestamp < ?"
	_, err := db.db.Exec(query, timestamp)
	return err
}
// UpdateRawMessageSent sets the sent flag of the raw_messages row with the
// given id. Updating an id that matches no row is not reported as an error.
func (db RawMessagesPersistence) UpdateRawMessageSent(id string, sent bool) error {
	if _, err := db.db.Exec("UPDATE raw_messages SET sent = ? WHERE id = ?", sent, id); err != nil {
		return err
	}
	return nil
}
// UpdateRawMessageLastSent stores lastSent in the last_sent column of the
// raw_messages row with the given id. Updating an id that matches no row is
// not reported as an error.
func (db RawMessagesPersistence) UpdateRawMessageLastSent(id string, lastSent uint64) error {
	const stmt = "UPDATE raw_messages SET last_sent = ? WHERE id = ?"
	_, err := db.db.Exec(stmt, lastSent, id)
	return err
}