Delete processed messages & add tests
This commit adds a test for out of order messages, which were only implicitly tested. It also deletes them after being processed, otherwise they would be reproceessed each time a message was sent
This commit is contained in:
parent
d47b5733c0
commit
85f8c92cde
|
@ -775,6 +775,7 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
var processedIds [][]byte
|
||||
for _, message := range messages {
|
||||
var statusMessage v1protocol.StatusMessage
|
||||
err := statusMessage.HandleTransport(message)
|
||||
|
@ -785,7 +786,9 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
|
|||
err = s.handleEncryptionLayer(context.Background(), &statusMessage)
|
||||
if err != nil {
|
||||
hlogger.Debug("failed to handle an encryption message", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
processedIds = append(processedIds, message.Hash)
|
||||
|
||||
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
|
||||
if err != nil {
|
||||
|
@ -798,6 +801,12 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
|
|||
acks = append(acks, as...)
|
||||
}
|
||||
}
|
||||
err = s.persistence.DeleteHashRatchetMessages(processedIds)
|
||||
if err != nil {
|
||||
s.logger.Warn("failed to delete hash ratchet messages", zap.Error(err))
|
||||
return nil, nil, err
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
|
||||
|
|
|
@ -225,3 +225,82 @@ func (s *MessageSenderSuite) TestHandleDecodedMessagesDatasyncEncrypted() {
|
|||
s.Require().Equal(encodedPayload, decodedMessages[0].UnwrappedPayload)
|
||||
s.Require().Equal(protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, decodedMessages[0].Type)
|
||||
}
|
||||
|
||||
func (s *MessageSenderSuite) TestHandleOutOfOrderHashRatchet() {
|
||||
groupID := []byte("group-id")
|
||||
senderKey, err := crypto.GenerateKey()
|
||||
s.Require().NoError(err)
|
||||
|
||||
encodedPayload, err := proto.Marshal(&s.testMessage)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Create sender encryption protocol.
|
||||
senderDatabase, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
|
||||
s.Require().NoError(err)
|
||||
err = sqlite.Migrate(senderDatabase)
|
||||
s.Require().NoError(err)
|
||||
|
||||
senderEncryptionProtocol := encryption.New(
|
||||
senderDatabase,
|
||||
"installation-2",
|
||||
s.logger,
|
||||
)
|
||||
|
||||
ratchet, err := senderEncryptionProtocol.GenerateHashRatchetKey(groupID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
ratchets := []*encryption.HashRatchetKeyCompatibility{ratchet}
|
||||
|
||||
hashRatchetKeyExchangeMessage, err := senderEncryptionProtocol.BuildHashRatchetKeyExchangeMessage(senderKey, &s.sender.identity.PublicKey, groupID, ratchets)
|
||||
s.Require().NoError(err)
|
||||
|
||||
encryptedPayload1, err := proto.Marshal(hashRatchetKeyExchangeMessage.Message)
|
||||
s.Require().NoError(err)
|
||||
|
||||
wrappedPayload2, err := v1protocol.WrapMessageV1(encodedPayload, protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, senderKey)
|
||||
s.Require().NoError(err)
|
||||
|
||||
messageSpec2, err := senderEncryptionProtocol.BuildHashRatchetMessage(
|
||||
groupID,
|
||||
wrappedPayload2,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
|
||||
encryptedPayload2, err := proto.Marshal(messageSpec2.Message)
|
||||
s.Require().NoError(err)
|
||||
|
||||
message := &types.Message{}
|
||||
message.Sig = crypto.FromECDSAPub(&senderKey.PublicKey)
|
||||
message.Hash = []byte{0x1}
|
||||
message.Payload = encryptedPayload2
|
||||
|
||||
_, _, err = s.sender.HandleMessages(message)
|
||||
s.Require().NoError(err)
|
||||
|
||||
keyID, err := ratchet.GetKeyID()
|
||||
s.Require().NoError(err)
|
||||
|
||||
msgs, err := s.sender.persistence.GetHashRatchetMessages(keyID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().Len(msgs, 1)
|
||||
|
||||
message = &types.Message{}
|
||||
message.Sig = crypto.FromECDSAPub(&senderKey.PublicKey)
|
||||
message.Hash = []byte{0x2}
|
||||
message.Payload = encryptedPayload1
|
||||
|
||||
decodedMessages2, _, err := s.sender.HandleMessages(message)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(decodedMessages2)
|
||||
|
||||
// It should have 2 messages, the key exchange and the one from the database
|
||||
s.Require().Len(decodedMessages2, 2)
|
||||
|
||||
// it deletes the messages after being processed
|
||||
msgs, err = s.sender.persistence.GetHashRatchetMessages(keyID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().Len(msgs, 0)
|
||||
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"encoding/gob"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/status-im/status-go/eth-node/crypto"
|
||||
|
@ -319,3 +320,19 @@ func (db RawMessagesPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.
|
|||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
idsArgs := make([]interface{}, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
idsArgs = append(idsArgs, id)
|
||||
}
|
||||
inVector := strings.Repeat("?, ", len(ids)-1) + "?"
|
||||
|
||||
_, err := db.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN ("+inVector+")", idsArgs...) // nolint: gosec
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -1747,3 +1747,61 @@ func TestCountActiveChattersInCommunity(t *testing.T) {
|
|||
checker(7, 1)
|
||||
checker(8, 0)
|
||||
}
|
||||
|
||||
func TestDeleteHashRatchetMessage(t *testing.T) {
|
||||
db, err := openTestDB()
|
||||
require.NoError(t, err)
|
||||
p := newSQLitePersistence(db)
|
||||
|
||||
groupID := []byte("group-id")
|
||||
keyID := []byte("key-id")
|
||||
|
||||
message1 := &types.Message{
|
||||
Hash: []byte{1},
|
||||
Sig: []byte{2},
|
||||
TTL: 1,
|
||||
Timestamp: 2,
|
||||
Payload: []byte{3},
|
||||
}
|
||||
|
||||
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message1))
|
||||
|
||||
message2 := &types.Message{
|
||||
Hash: []byte{2},
|
||||
Sig: []byte{2},
|
||||
TTL: 1,
|
||||
Topic: types.BytesToTopic([]byte{5}),
|
||||
Timestamp: 2,
|
||||
Payload: []byte{3},
|
||||
Dst: []byte{4},
|
||||
P2P: true,
|
||||
}
|
||||
|
||||
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message2))
|
||||
|
||||
message3 := &types.Message{
|
||||
Hash: []byte{3},
|
||||
Sig: []byte{2},
|
||||
TTL: 1,
|
||||
Topic: types.BytesToTopic([]byte{5}),
|
||||
Timestamp: 2,
|
||||
Payload: []byte{3},
|
||||
Dst: []byte{4},
|
||||
P2P: true,
|
||||
}
|
||||
|
||||
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message3))
|
||||
|
||||
fetchedMessages, err := p.GetHashRatchetMessages(keyID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fetchedMessages)
|
||||
require.Len(t, fetchedMessages, 3)
|
||||
|
||||
require.NoError(t, p.DeleteHashRatchetMessages([][]byte{[]byte{1}, []byte{2}}))
|
||||
|
||||
fetchedMessages, err = p.GetHashRatchetMessages(keyID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, fetchedMessages)
|
||||
require.Len(t, fetchedMessages, 1)
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue