Delete processed messages & add tests
This commit adds a test for out of order messages, which were only implicitly tested. It also deletes them after being processed, otherwise they would be reproceessed each time a message was sent
This commit is contained in:
parent
d47b5733c0
commit
85f8c92cde
|
@ -775,6 +775,7 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var processedIds [][]byte
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
var statusMessage v1protocol.StatusMessage
|
var statusMessage v1protocol.StatusMessage
|
||||||
err := statusMessage.HandleTransport(message)
|
err := statusMessage.HandleTransport(message)
|
||||||
|
@ -785,7 +786,9 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
|
||||||
err = s.handleEncryptionLayer(context.Background(), &statusMessage)
|
err = s.handleEncryptionLayer(context.Background(), &statusMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
hlogger.Debug("failed to handle an encryption message", zap.Error(err))
|
hlogger.Debug("failed to handle an encryption message", zap.Error(err))
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
processedIds = append(processedIds, message.Hash)
|
||||||
|
|
||||||
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
|
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -798,6 +801,12 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
|
||||||
acks = append(acks, as...)
|
acks = append(acks, as...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
err = s.persistence.DeleteHashRatchetMessages(processedIds)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("failed to delete hash ratchet messages", zap.Error(err))
|
||||||
|
return nil, nil, err
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
|
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
|
||||||
|
|
|
@ -225,3 +225,82 @@ func (s *MessageSenderSuite) TestHandleDecodedMessagesDatasyncEncrypted() {
|
||||||
s.Require().Equal(encodedPayload, decodedMessages[0].UnwrappedPayload)
|
s.Require().Equal(encodedPayload, decodedMessages[0].UnwrappedPayload)
|
||||||
s.Require().Equal(protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, decodedMessages[0].Type)
|
s.Require().Equal(protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, decodedMessages[0].Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *MessageSenderSuite) TestHandleOutOfOrderHashRatchet() {
|
||||||
|
groupID := []byte("group-id")
|
||||||
|
senderKey, err := crypto.GenerateKey()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
encodedPayload, err := proto.Marshal(&s.testMessage)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
// Create sender encryption protocol.
|
||||||
|
senderDatabase, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
|
||||||
|
s.Require().NoError(err)
|
||||||
|
err = sqlite.Migrate(senderDatabase)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
senderEncryptionProtocol := encryption.New(
|
||||||
|
senderDatabase,
|
||||||
|
"installation-2",
|
||||||
|
s.logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
ratchet, err := senderEncryptionProtocol.GenerateHashRatchetKey(groupID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
ratchets := []*encryption.HashRatchetKeyCompatibility{ratchet}
|
||||||
|
|
||||||
|
hashRatchetKeyExchangeMessage, err := senderEncryptionProtocol.BuildHashRatchetKeyExchangeMessage(senderKey, &s.sender.identity.PublicKey, groupID, ratchets)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
encryptedPayload1, err := proto.Marshal(hashRatchetKeyExchangeMessage.Message)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
wrappedPayload2, err := v1protocol.WrapMessageV1(encodedPayload, protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, senderKey)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
messageSpec2, err := senderEncryptionProtocol.BuildHashRatchetMessage(
|
||||||
|
groupID,
|
||||||
|
wrappedPayload2,
|
||||||
|
)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
encryptedPayload2, err := proto.Marshal(messageSpec2.Message)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
message := &types.Message{}
|
||||||
|
message.Sig = crypto.FromECDSAPub(&senderKey.PublicKey)
|
||||||
|
message.Hash = []byte{0x1}
|
||||||
|
message.Payload = encryptedPayload2
|
||||||
|
|
||||||
|
_, _, err = s.sender.HandleMessages(message)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
keyID, err := ratchet.GetKeyID()
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
msgs, err := s.sender.persistence.GetHashRatchetMessages(keyID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().Len(msgs, 1)
|
||||||
|
|
||||||
|
message = &types.Message{}
|
||||||
|
message.Sig = crypto.FromECDSAPub(&senderKey.PublicKey)
|
||||||
|
message.Hash = []byte{0x2}
|
||||||
|
message.Payload = encryptedPayload1
|
||||||
|
|
||||||
|
decodedMessages2, _, err := s.sender.HandleMessages(message)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
s.Require().NotNil(decodedMessages2)
|
||||||
|
|
||||||
|
// It should have 2 messages, the key exchange and the one from the database
|
||||||
|
s.Require().Len(decodedMessages2, 2)
|
||||||
|
|
||||||
|
// it deletes the messages after being processed
|
||||||
|
msgs, err = s.sender.persistence.GetHashRatchetMessages(keyID)
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
|
s.Require().Len(msgs, 0)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/status-im/status-go/eth-node/crypto"
|
"github.com/status-im/status-go/eth-node/crypto"
|
||||||
|
@ -319,3 +320,19 @@ func (db RawMessagesPersistence) GetHashRatchetMessages(keyID []byte) ([]*types.
|
||||||
|
|
||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db RawMessagesPersistence) DeleteHashRatchetMessages(ids [][]byte) error {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idsArgs := make([]interface{}, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
idsArgs = append(idsArgs, id)
|
||||||
|
}
|
||||||
|
inVector := strings.Repeat("?, ", len(ids)-1) + "?"
|
||||||
|
|
||||||
|
_, err := db.db.Exec("DELETE FROM hash_ratchet_encrypted_messages WHERE hash IN ("+inVector+")", idsArgs...) // nolint: gosec
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -1747,3 +1747,61 @@ func TestCountActiveChattersInCommunity(t *testing.T) {
|
||||||
checker(7, 1)
|
checker(7, 1)
|
||||||
checker(8, 0)
|
checker(8, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteHashRatchetMessage(t *testing.T) {
|
||||||
|
db, err := openTestDB()
|
||||||
|
require.NoError(t, err)
|
||||||
|
p := newSQLitePersistence(db)
|
||||||
|
|
||||||
|
groupID := []byte("group-id")
|
||||||
|
keyID := []byte("key-id")
|
||||||
|
|
||||||
|
message1 := &types.Message{
|
||||||
|
Hash: []byte{1},
|
||||||
|
Sig: []byte{2},
|
||||||
|
TTL: 1,
|
||||||
|
Timestamp: 2,
|
||||||
|
Payload: []byte{3},
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message1))
|
||||||
|
|
||||||
|
message2 := &types.Message{
|
||||||
|
Hash: []byte{2},
|
||||||
|
Sig: []byte{2},
|
||||||
|
TTL: 1,
|
||||||
|
Topic: types.BytesToTopic([]byte{5}),
|
||||||
|
Timestamp: 2,
|
||||||
|
Payload: []byte{3},
|
||||||
|
Dst: []byte{4},
|
||||||
|
P2P: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message2))
|
||||||
|
|
||||||
|
message3 := &types.Message{
|
||||||
|
Hash: []byte{3},
|
||||||
|
Sig: []byte{2},
|
||||||
|
TTL: 1,
|
||||||
|
Topic: types.BytesToTopic([]byte{5}),
|
||||||
|
Timestamp: 2,
|
||||||
|
Payload: []byte{3},
|
||||||
|
Dst: []byte{4},
|
||||||
|
P2P: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, p.SaveHashRatchetMessage(groupID, keyID, message3))
|
||||||
|
|
||||||
|
fetchedMessages, err := p.GetHashRatchetMessages(keyID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, fetchedMessages)
|
||||||
|
require.Len(t, fetchedMessages, 3)
|
||||||
|
|
||||||
|
require.NoError(t, p.DeleteHashRatchetMessages([][]byte{[]byte{1}, []byte{2}}))
|
||||||
|
|
||||||
|
fetchedMessages, err = p.GetHashRatchetMessages(keyID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, fetchedMessages)
|
||||||
|
require.Len(t, fetchedMessages, 1)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue