2021-03-09 12:48:15 +00:00
package common
import (
"bytes"
"context"
2023-11-09 20:36:57 +00:00
"crypto/ecdsa"
2021-03-09 12:48:15 +00:00
"database/sql"
"encoding/gob"
2023-10-30 11:53:37 +00:00
"strings"
2021-03-09 12:48:15 +00:00
"time"
"github.com/status-im/status-go/eth-node/crypto"
"github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/protocol/protobuf"
)
// RawMessageConfirmation mirrors one row of the raw_message_confirmations
// table: it links a sent datasync message to the application message it
// carried and to the receiver it was addressed to.
type RawMessageConfirmation struct {
	// DataSyncID identifies the datasync message that was sent.
	DataSyncID []byte
	// MessageID identifies the application message carried by that datasync message.
	MessageID []byte
	// PublicKey is the receiver's public key in compressed form.
	PublicKey []byte
	// ConfirmedAt is the unix time, in seconds, at which the message was
	// confirmed. It corresponds to the confirmed_at column, which
	// MarkAsConfirmed treats as "not yet confirmed" while it is 0.
	ConfirmedAt int64
}
// RawMessagesPersistence bundles the SQL access for raw messages, their
// delivery confirmations, hash-ratchet encrypted messages and message
// segments on top of a single *sql.DB handle.
type RawMessagesPersistence struct {
	db *sql.DB // handle every query in this type runs against
}
// NewRawMessagesPersistence returns a RawMessagesPersistence that issues all
// of its queries against db. The handle is stored as given.
func NewRawMessagesPersistence(db *sql.DB) *RawMessagesPersistence {
	p := RawMessagesPersistence{db: db}
	return &p
}
func ( db RawMessagesPersistence ) SaveRawMessage ( message * RawMessage ) error {
2023-06-19 13:10:21 +00:00
tx , err := db . db . BeginTx ( context . Background ( ) , & sql . TxOptions { } )
if err != nil {
return err
}
defer func ( ) {
if err == nil {
err = tx . Commit ( )
return
}
// don't shadow original error
_ = tx . Rollback ( )
} ( )
2021-03-09 12:48:15 +00:00
var pubKeys [ ] [ ] byte
for _ , pk := range message . Recipients {
pubKeys = append ( pubKeys , crypto . CompressPubkey ( pk ) )
}
// Encode recipients
var encodedRecipients bytes . Buffer
encoder := gob . NewEncoder ( & encodedRecipients )
if err := encoder . Encode ( pubKeys ) ; err != nil {
return err
}
2023-06-19 13:10:21 +00:00
// If the message is not sent, we check whether there's a record
// in the database already and preserve the state
if ! message . Sent {
oldMessage , err := db . rawMessageByID ( tx , message . ID )
if err != nil && err != sql . ErrNoRows {
return err
}
if oldMessage != nil {
message . Sent = oldMessage . Sent
}
}
_ , err = tx . Exec ( `
2021-03-09 12:48:15 +00:00
INSERT INTO
raw_messages
(
id ,
local_chat_id ,
last_sent ,
send_count ,
sent ,
message_type ,
resend_automatically ,
recipients ,
skip_encryption ,
send_push_notification ,
skip_group_message_wrap ,
send_on_personal_topic ,
payload
)
VALUES ( ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? , ? ) ` ,
message . ID ,
message . LocalChatID ,
message . LastSent ,
message . SendCount ,
message . Sent ,
message . MessageType ,
message . ResendAutomatically ,
encodedRecipients . Bytes ( ) ,
2023-11-08 18:05:33 +00:00
message . SkipEncryptionLayer ,
2021-03-09 12:48:15 +00:00
message . SendPushNotification ,
message . SkipGroupMessageWrap ,
message . SendOnPersonalTopic ,
message . Payload )
return err
}
func ( db RawMessagesPersistence ) RawMessageByID ( id string ) ( * RawMessage , error ) {
2023-06-19 13:10:21 +00:00
tx , err := db . db . BeginTx ( context . Background ( ) , & sql . TxOptions { } )
if err != nil {
return nil , err
}
defer func ( ) {
if err == nil {
err = tx . Commit ( )
return
}
// don't shadow original error
_ = tx . Rollback ( )
} ( )
return db . rawMessageByID ( tx , id )
}
func ( db RawMessagesPersistence ) rawMessageByID ( tx * sql . Tx , id string ) ( * RawMessage , error ) {
2021-03-09 12:48:15 +00:00
var rawPubKeys [ ] [ ] byte
var encodedRecipients [ ] byte
var skipGroupMessageWrap sql . NullBool
var sendOnPersonalTopic sql . NullBool
message := & RawMessage { }
2023-06-19 13:10:21 +00:00
err := tx . QueryRow ( `
2021-03-09 12:48:15 +00:00
SELECT
id ,
local_chat_id ,
last_sent ,
send_count ,
sent ,
message_type ,
resend_automatically ,
recipients ,
skip_encryption ,
send_push_notification ,
skip_group_message_wrap ,
send_on_personal_topic ,
payload
FROM
raw_messages
WHERE
id = ? ` ,
id ,
) . Scan (
& message . ID ,
& message . LocalChatID ,
& message . LastSent ,
& message . SendCount ,
& message . Sent ,
& message . MessageType ,
& message . ResendAutomatically ,
& encodedRecipients ,
2023-11-08 18:05:33 +00:00
& message . SkipEncryptionLayer ,
2021-03-09 12:48:15 +00:00
& message . SendPushNotification ,
& skipGroupMessageWrap ,
& sendOnPersonalTopic ,
& message . Payload ,
)
if err != nil {
return nil , err
}
if rawPubKeys != nil {
// Restore recipients
decoder := gob . NewDecoder ( bytes . NewBuffer ( encodedRecipients ) )
err = decoder . Decode ( & rawPubKeys )
if err != nil {
return nil , err
}
for _ , pkBytes := range rawPubKeys {
pubkey , err := crypto . UnmarshalPubkey ( pkBytes )
if err != nil {
return nil , err
}
message . Recipients = append ( message . Recipients , pubkey )
}
}
if skipGroupMessageWrap . Valid {
message . SkipGroupMessageWrap = skipGroupMessageWrap . Bool
}
if sendOnPersonalTopic . Valid {
message . SendOnPersonalTopic = sendOnPersonalTopic . Bool
}
return message , nil
}
// RawMessagesIDsByType returns the IDs of all raw messages whose message_type
// equals t.
//
// The result is an empty, non-nil slice when nothing matches. On error, the
// IDs collected so far are returned together with the error.
func (db RawMessagesPersistence) RawMessagesIDsByType(t protobuf.ApplicationMetadataMessage_Type) ([]string, error) {
	ids := []string{}
	rows, err := db.db.Query(`
		SELECT
			id
		FROM
			raw_messages
		WHERE
			message_type = ?`,
		t)
	if err != nil {
		return ids, err
	}
	defer rows.Close()

	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return ids, err
		}
		ids = append(ids, id)
	}
	// rows.Next returns false on an iteration error too. Without this check
	// a truncated list would be reported as complete.
	if err := rows.Err(); err != nil {
		return ids, err
	}
	return ids, nil
}
// MarkAsConfirmed sets confirmed_at to the current unix time (seconds) on
// every still-unconfirmed (confirmed_at = 0) raw_message_confirmations row
// with the given dataSyncID. It then inspects all confirmation rows that share
// the message_id of that datasync ID.
//
// If atLeastOne is true, the message ID is returned as soon as any of those
// rows is confirmed. Otherwise it is returned only when all of them are
// confirmed. A nil messageID means the message cannot be considered
// confirmed yet.
//
// Everything runs in one transaction, which is rolled back on error. On any
// error, including a failed commit, messageID is nil.
func (db RawMessagesPersistence) MarkAsConfirmed(dataSyncID []byte, atLeastOne bool) (messageID types.HexBytes, err error) {
	tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return nil, err
	}
	defer func() {
		if err == nil {
			if err = tx.Commit(); err != nil {
				messageID = nil
			}
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	confirmedAt := time.Now().Unix()
	_, err = tx.Exec(`UPDATE raw_message_confirmations SET confirmed_at = ? WHERE datasync_id = ? AND confirmed_at = 0`, confirmedAt, dataSyncID)
	if err != nil {
		return nil, err
	}

	// Select every confirmation row belonging to the same message as dataSyncID.
	rows, err := tx.Query(`SELECT message_id,confirmed_at FROM raw_message_confirmations WHERE message_id = (SELECT message_id FROM raw_message_confirmations WHERE datasync_id = ? LIMIT 1)`, dataSyncID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	allConfirmed := true
	for rows.Next() {
		var rowConfirmedAt int64
		if err = rows.Scan(&messageID, &rowConfirmedAt); err != nil {
			return nil, err
		}
		confirmed := rowConfirmedAt > 0
		if atLeastOne && confirmed {
			// We return, as at least one was confirmed
			return messageID, nil
		}
		allConfirmed = allConfirmed && confirmed
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}

	if !allConfirmed {
		return nil, nil
	}
	return messageID, nil
}
// InsertPendingConfirmation records that the datasync message
// confirmation.DataSyncID carried confirmation.MessageID for the receiver
// confirmation.PublicKey.
//
// Only those three columns are written. confirmation.ConfirmedAt is ignored,
// and confirmed_at is left to the column's default.
func (db RawMessagesPersistence) InsertPendingConfirmation(confirmation *RawMessageConfirmation) error {
	const query = `INSERT INTO raw_message_confirmations
		(datasync_id, message_id, public_key)
		VALUES
		(?, ?, ?)`

	_, err := db.db.Exec(query, confirmation.DataSyncID, confirmation.MessageID, confirmation.PublicKey)
	return err
}
2022-09-21 16:05:29 +00:00
2023-10-12 15:45:23 +00:00
func ( db RawMessagesPersistence ) SaveHashRatchetMessage ( groupID [ ] byte , keyID [ ] byte , m * types . Message ) error {
2023-10-29 08:04:01 +00:00
_ , err := db . db . Exec ( ` INSERT INTO hash_ratchet_encrypted_messages(hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding, group_id, key_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` , m . Hash , m . Sig , m . TTL , m . Timestamp , types . TopicTypeToByteArray ( m . Topic ) , m . Payload , m . Dst , m . P2P , m . Padding , groupID , keyID )
2022-09-21 16:05:29 +00:00
return err
}
2023-10-12 15:45:23 +00:00
func ( db RawMessagesPersistence ) GetHashRatchetMessages ( keyID [ ] byte ) ( [ ] * types . Message , error ) {
2022-09-21 16:05:29 +00:00
var messages [ ] * types . Message
2023-10-29 08:04:01 +00:00
rows , err := db . db . Query ( ` SELECT hash, sig, TTL, timestamp, topic, payload, dst, p2p, padding FROM hash_ratchet_encrypted_messages WHERE key_id = ? ` , keyID )
2022-09-21 16:05:29 +00:00
if err != nil {
return nil , err
}
for rows . Next ( ) {
var topic [ ] byte
message := & types . Message { }
2023-10-29 08:04:01 +00:00
err := rows . Scan ( & message . Hash , & message . Sig , & message . TTL , & message . Timestamp , & topic , & message . Payload , & message . Dst , & message . P2P , & message . Padding )
2022-09-21 16:05:29 +00:00
if err != nil {
return nil , err
}
message . Topic = types . BytesToTopic ( topic )
messages = append ( messages , message )
}
return messages , nil
}
2023-10-30 11:53:37 +00:00
// DeleteHashRatchetMessages removes the hash-ratchet messages whose hash is one
// of ids, in a single DELETE ... IN (...) statement. It does nothing when ids
// is empty.
func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error {
	count := len(ids)
	if count == 0 {
		return nil
	}

	args := make([]interface{}, count)
	for i := range ids {
		args[i] = ids[i]
	}

	// Only "?" placeholders are concatenated; the values are bound as arguments.
	placeholders := strings.Repeat("?, ", count-1) + "?"
	_, err := db.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN ("+placeholders+")", args...) // nolint: gosec
	return err
}
2023-11-09 20:36:57 +00:00
// IsMessageAlreadyCompleted reports whether message_segments_completed holds
// at least one row for hash. The sender key is not part of this check.
func (db *RawMessagesPersistence) IsMessageAlreadyCompleted(hash []byte) (bool, error) {
	var count int
	row := db.db.QueryRow("SELECT COUNT(*) FROM message_segments_completed WHERE hash = ?", hash)
	if err := row.Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}
2023-11-15 11:01:02 +00:00
func ( db * RawMessagesPersistence ) SaveMessageSegment ( segment * protobuf . SegmentMessage , sigPubKey * ecdsa . PublicKey , timestamp int64 ) error {
2023-11-09 20:36:57 +00:00
sigPubKeyBlob := crypto . CompressPubkey ( sigPubKey )
2023-11-15 11:01:02 +00:00
_ , err := db . db . Exec ( "INSERT INTO message_segments (hash, segment_index, segments_count, sig_pub_key, payload, timestamp) VALUES (?, ?, ?, ?, ?, ?)" ,
segment . EntireMessageHash , segment . Index , segment . SegmentsCount , sigPubKeyBlob , segment . Payload , timestamp )
2023-11-09 20:36:57 +00:00
return err
}
// GetMessageSegments returns the stored segments for hash whose sig_pub_key
// matches the compressed form of sigPubKey, ordered by segment index. It
// returns a nil slice when nothing matches.
func (db *RawMessagesPersistence) GetMessageSegments(hash []byte, sigPubKey *ecdsa.PublicKey) ([]*protobuf.SegmentMessage, error) {
	compressedKey := crypto.CompressPubkey(sigPubKey)
	rows, err := db.db.Query("SELECT hash, segment_index, segments_count, payload FROM message_segments WHERE hash = ? AND sig_pub_key = ? ORDER BY segment_index", hash, compressedKey)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []*protobuf.SegmentMessage
	for rows.Next() {
		s := &protobuf.SegmentMessage{}
		if scanErr := rows.Scan(&s.EntireMessageHash, &s.Index, &s.SegmentsCount, &s.Payload); scanErr != nil {
			return nil, scanErr
		}
		result = append(result, s)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}
2023-11-15 11:01:02 +00:00
// RemoveMessageSegmentsOlderThan deletes every message_segments row whose
// timestamp is strictly less than timestamp.
func (db *RawMessagesPersistence) RemoveMessageSegmentsOlderThan(timestamp int64) error {
	if _, err := db.db.Exec("DELETE FROM message_segments WHERE timestamp < ?", timestamp); err != nil {
		return err
	}
	return nil
}
func ( db * RawMessagesPersistence ) CompleteMessageSegments ( hash [ ] byte , sigPubKey * ecdsa . PublicKey , timestamp int64 ) error {
2023-11-09 20:36:57 +00:00
tx , err := db . db . BeginTx ( context . Background ( ) , & sql . TxOptions { } )
if err != nil {
return err
}
defer func ( ) {
if err == nil {
err = tx . Commit ( )
return
}
// don't shadow original error
_ = tx . Rollback ( )
} ( )
sigPubKeyBlob := crypto . CompressPubkey ( sigPubKey )
_ , err = tx . Exec ( "DELETE FROM message_segments WHERE hash = ? AND sig_pub_key = ?" , hash , sigPubKeyBlob )
if err != nil {
return err
}
2023-11-15 11:01:02 +00:00
_ , err = tx . Exec ( "INSERT INTO message_segments_completed (hash, sig_pub_key, timestamp) VALUES (?,?,?)" , hash , sigPubKeyBlob , timestamp )
2023-11-09 20:36:57 +00:00
if err != nil {
return err
}
return err
}
2023-11-15 11:01:02 +00:00
// RemoveMessageSegmentsCompletedOlderThan deletes every
// message_segments_completed row whose timestamp is strictly less than
// timestamp.
func (db *RawMessagesPersistence) RemoveMessageSegmentsCompletedOlderThan(timestamp int64) error {
	if _, err := db.db.Exec("DELETE FROM message_segments_completed WHERE timestamp < ?", timestamp); err != nil {
		return err
	}
	return nil
}