Delete processed messages & add tests

This commit adds a test for out of order messages, which were only
implicitly tested.
It also deletes them after being processed, otherwise they would be
reproceessed each time a message was sent
This commit is contained in:
Andrea Maria Piana 2023-10-30 11:53:37 +00:00
parent d47b5733c0
commit 85f8c92cde
4 changed files with 163 additions and 0 deletions

View File

@ -775,6 +775,7 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
return nil, nil, err return nil, nil, err
} }
var processedIds [][]byte
for _, message := range messages { for _, message := range messages {
var statusMessage v1protocol.StatusMessage var statusMessage v1protocol.StatusMessage
err := statusMessage.HandleTransport(message) err := statusMessage.HandleTransport(message)
@ -785,7 +786,9 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
err = s.handleEncryptionLayer(context.Background(), &statusMessage) err = s.handleEncryptionLayer(context.Background(), &statusMessage)
if err != nil { if err != nil {
hlogger.Debug("failed to handle an encryption message", zap.Error(err)) hlogger.Debug("failed to handle an encryption message", zap.Error(err))
continue
} }
processedIds = append(processedIds, message.Hash)
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync) stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
if err != nil { if err != nil {
@ -798,6 +801,12 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
acks = append(acks, as...) acks = append(acks, as...)
} }
} }
err = s.persistence.DeleteHashRatchetMessages(processedIds)
if err != nil {
s.logger.Warn("failed to delete hash ratchet messages", zap.Error(err))
return nil, nil, err
}
} }
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync) stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)

View File

@ -225,3 +225,82 @@ func (s *MessageSenderSuite) TestHandleDecodedMessagesDatasyncEncrypted() {
s.Require().Equal(encodedPayload, decodedMessages[0].UnwrappedPayload) s.Require().Equal(encodedPayload, decodedMessages[0].UnwrappedPayload)
s.Require().Equal(protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, decodedMessages[0].Type) s.Require().Equal(protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, decodedMessages[0].Type)
} }
func (s *MessageSenderSuite) TestHandleOutOfOrderHashRatchet() {
groupID := []byte("group-id")
senderKey, err := crypto.GenerateKey()
s.Require().NoError(err)
encodedPayload, err := proto.Marshal(&s.testMessage)
s.Require().NoError(err)
// Create sender encryption protocol.
senderDatabase, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
s.Require().NoError(err)
err = sqlite.Migrate(senderDatabase)
s.Require().NoError(err)
senderEncryptionProtocol := encryption.New(
senderDatabase,
"installation-2",
s.logger,
)
ratchet, err := senderEncryptionProtocol.GenerateHashRatchetKey(groupID)
s.Require().NoError(err)
ratchets := []*encryption.HashRatchetKeyCompatibility{ratchet}
hashRatchetKeyExchangeMessage, err := senderEncryptionProtocol.BuildHashRatchetKeyExchangeMessage(senderKey, &s.sender.identity.PublicKey, groupID, ratchets)
s.Require().NoError(err)
encryptedPayload1, err := proto.Marshal(hashRatchetKeyExchangeMessage.Message)
s.Require().NoError(err)
wrappedPayload2, err := v1protocol.WrapMessageV1(encodedPayload, protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, senderKey)
s.Require().NoError(err)
messageSpec2, err := senderEncryptionProtocol.BuildHashRatchetMessage(
groupID,
wrappedPayload2,
)
s.Require().NoError(err)
encryptedPayload2, err := proto.Marshal(messageSpec2.Message)
s.Require().NoError(err)
message := &types.Message{}
message.Sig = crypto.FromECDSAPub(&senderKey.PublicKey)
message.Hash = []byte{0x1}
message.Payload = encryptedPayload2
_, _, err = s.sender.HandleMessages(message)
s.Require().NoError(err)
keyID, err := ratchet.GetKeyID()
s.Require().NoError(err)
msgs, err := s.sender.persistence.GetHashRatchetMessages(keyID)
s.Require().NoError(err)
s.Require().Len(msgs, 1)
message = &types.Message{}
message.Sig = crypto.FromECDSAPub(&senderKey.PublicKey)
message.Hash = []byte{0x2}
message.Payload = encryptedPayload1
decodedMessages2, _, err := s.sender.HandleMessages(message)
s.Require().NoError(err)
s.Require().NotNil(decodedMessages2)
// It should have 2 messages, the key exchange and the one from the database
s.Require().Len(decodedMessages2, 2)
// it deletes the messages after being processed
msgs, err = s.sender.persistence.GetHashRatchetMessages(keyID)
s.Require().NoError(err)
s.Require().Len(msgs, 0)
}

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/gob" "encoding/gob"
"strings"
"time" "time"
"github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/eth-node/crypto"
@ -319,3 +320,19 @@ func (db RawMessagesPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.
return messages, nil return messages, nil
} }
func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error {
if len(ids) == 0 {
return nil
}
idsArgs := make([]interface{}, 0, len(ids))
for _, id := range ids {
idsArgs = append(idsArgs, id)
}
inVector := strings.Repeat("?, ", len(ids)-1) + "?"
_, err := db.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN ("+inVector+")", idsArgs...) // nolint: gosec
return err
}

View File

@ -1747,3 +1747,61 @@ func TestCountActiveChattersInCommunity(t *testing.T) {
checker(7, 1) checker(7, 1)
checker(8, 0) checker(8, 0)
} }
func TestDeleteHashRatchetMessage(t *testing.T) {
db, err := openTestDB()
require.NoError(t, err)
p := newSQLitePersistence(db)
groupID := []byte("group-id")
keyID := []byte("key-id")
message1 := &types.Message{
Hash: []byte{1},
Sig: []byte{2},
TTL: 1,
Timestamp: 2,
Payload: []byte{3},
}
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message1))
message2 := &types.Message{
Hash: []byte{2},
Sig: []byte{2},
TTL: 1,
Topic: types.BytesToTopic([]byte{5}),
Timestamp: 2,
Payload: []byte{3},
Dst: []byte{4},
P2P: true,
}
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message2))
message3 := &types.Message{
Hash: []byte{3},
Sig: []byte{2},
TTL: 1,
Topic: types.BytesToTopic([]byte{5}),
Timestamp: 2,
Payload: []byte{3},
Dst: []byte{4},
P2P: true,
}
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message3))
fetchedMessages, err := p.GetHashRatchetMessages(keyID)
require.NoError(t, err)
require.NotNil(t, fetchedMessages)
require.Len(t, fetchedMessages, 3)
require.NoError(t, p.DeleteHashRatchetMessages([][]byte{[]byte{1}, []byte{2}}))
fetchedMessages, err = p.GetHashRatchetMessages(keyID)
require.NoError(t, err)
require.NotNil(t, fetchedMessages)
require.Len(t, fetchedMessages, 1)
}