Add confirm messages processed by ID (#1375)
Currently PFS messages are decrypted and therefore modified before being passed to the client. This make IDs computation difficult, as we pass the whole object to the client and expect the object be passed back once confirmed. This changes the behavior allowing confirmation by ID, which is passed to the client instead of the raw object. This is a breaking change, but status-react is already forward compatible.
This commit is contained in:
parent
72906ac655
commit
81d8ca82a2
|
@ -660,7 +660,6 @@
|
|||
revision = "ad98a36ba0da87206e3378c556abbfeaeaa98668"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:b9b9f43a8a410d633e6547f89e830926741070941f2243d4d3a0bb154f565c9e"
|
||||
name = "github.com/mr-tron/base58"
|
||||
packages = ["base58"]
|
||||
|
@ -783,12 +782,12 @@
|
|||
version = "v1.1"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:41fb72d7a71f37f1f9c766d965178636ecda21b429b1f2e3fff42cfc31279751"
|
||||
digest = "1:821730d8591dde31a0d53b286b5fb9fb7f02f601e0f32a6b759c97934587c0a7"
|
||||
name = "github.com/status-im/doubleratchet"
|
||||
packages = ["."]
|
||||
pruneopts = "NUT"
|
||||
revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4"
|
||||
version = "v1.0.0"
|
||||
revision = "f2aeb83683d753011e9cffda29be0bf6524e1b65"
|
||||
version = "v2.0.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:f3044238fc5d70eca12cc181b1e6d5270570d85a7b7046686381e618a783a7d6"
|
||||
|
|
|
@ -293,7 +293,7 @@
|
|||
|
||||
[[constraint]]
|
||||
name = "github.com/status-im/doubleratchet"
|
||||
version = "=v1.0.0"
|
||||
version = "=v2.0.0"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/status-im/migrate"
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
"github.com/status-im/status-go/mailserver"
|
||||
"github.com/status-im/status-go/services/shhext/chat"
|
||||
"github.com/status-im/status-go/services/shhext/dedup"
|
||||
"github.com/status-im/status-go/services/shhext/mailservers"
|
||||
whisper "github.com/status-im/whisper/whisperv6"
|
||||
)
|
||||
|
@ -368,7 +369,7 @@ func (api *PublicAPI) SyncMessages(ctx context.Context, r SyncMessagesRequest) (
|
|||
}
|
||||
|
||||
// GetNewFilterMessages is a prototype method with deduplication
|
||||
func (api *PublicAPI) GetNewFilterMessages(filterID string) ([]*whisper.Message, error) {
|
||||
func (api *PublicAPI) GetNewFilterMessages(filterID string) ([]dedup.DeduplicateMessage, error) {
|
||||
msgs, err := api.publicAPI.GetFilterMessages(filterID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -378,9 +379,9 @@ func (api *PublicAPI) GetNewFilterMessages(filterID string) ([]*whisper.Message,
|
|||
|
||||
if api.service.pfsEnabled {
|
||||
// Attempt to decrypt message, otherwise leave unchanged
|
||||
for _, msg := range dedupMessages {
|
||||
for _, dedupMessage := range dedupMessages {
|
||||
|
||||
if err := api.processPFSMessage(msg); err != nil {
|
||||
if err := api.processPFSMessage(dedupMessage); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -395,6 +396,16 @@ func (api *PublicAPI) ConfirmMessagesProcessed(messages []*whisper.Message) erro
|
|||
return api.service.deduplicator.AddMessages(messages)
|
||||
}
|
||||
|
||||
// ConfirmMessagesProcessedByID is a method to confirm that messages was consumed by
|
||||
// the client side.
|
||||
func (api *PublicAPI) ConfirmMessagesProcessedByID(messageIDs [][]byte) error {
|
||||
if err := api.service.protocol.ConfirmMessagesProcessed(messageIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return api.service.deduplicator.AddMessageByID(messageIDs)
|
||||
}
|
||||
|
||||
// SendPublicMessage sends a public chat message to the underlying transport
|
||||
func (api *PublicAPI) SendPublicMessage(ctx context.Context, msg chat.SendPublicMessageRPC) (hexutil.Bytes, error) {
|
||||
privateKey, err := api.service.w.GetPrivateKey(msg.Sig)
|
||||
|
@ -494,7 +505,8 @@ func (api *PublicAPI) SendPairingMessage(ctx context.Context, msg chat.SendDirec
|
|||
return response, nil
|
||||
}
|
||||
|
||||
func (api *PublicAPI) processPFSMessage(msg *whisper.Message) error {
|
||||
func (api *PublicAPI) processPFSMessage(dedupMessage dedup.DeduplicateMessage) error {
|
||||
msg := dedupMessage.Message
|
||||
|
||||
privateKeyID := api.service.w.SelectedKeyPairID()
|
||||
if privateKeyID == "" {
|
||||
|
@ -511,7 +523,7 @@ func (api *PublicAPI) processPFSMessage(msg *whisper.Message) error {
|
|||
return err
|
||||
}
|
||||
|
||||
response, err := api.service.protocol.HandleMessage(privateKey, publicKey, msg.Payload)
|
||||
response, err := api.service.protocol.HandleMessage(privateKey, publicKey, msg.Payload, dedupMessage.DedupID)
|
||||
|
||||
switch err {
|
||||
case nil:
|
||||
|
|
|
@ -3,17 +3,17 @@ package chat
|
|||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
ecrypto "github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
dr "github.com/status-im/doubleratchet"
|
||||
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/status-im/status-go/services/shhext/chat/crypto"
|
||||
)
|
||||
|
||||
|
@ -23,11 +23,17 @@ var ErrDeviceNotFound = errors.New("device not found")
|
|||
// If we have no bundles, we use a constant so that the message can reach any device.
|
||||
const noInstallationID = "none"
|
||||
|
||||
type ConfirmationData struct {
|
||||
header *dr.MessageHeader
|
||||
drInfo *RatchetInfo
|
||||
}
|
||||
|
||||
// EncryptionService defines a service that is responsible for the encryption aspect of the protocol.
|
||||
type EncryptionService struct {
|
||||
log log.Logger
|
||||
persistence PersistenceService
|
||||
config EncryptionServiceConfig
|
||||
messageIDs map[string]*ConfirmationData
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -68,6 +74,7 @@ func NewEncryptionService(p PersistenceService, config EncryptionServiceConfig)
|
|||
persistence: p,
|
||||
config: config,
|
||||
mutex: sync.Mutex{},
|
||||
messageIDs: make(map[string]*ConfirmationData),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -94,6 +101,36 @@ func (s *EncryptionService) getDRSession(id []byte) (dr.Session, error) {
|
|||
|
||||
}
|
||||
|
||||
func confirmationIDString(id []byte) string {
|
||||
return hex.EncodeToString(id)
|
||||
}
|
||||
|
||||
// ConfirmMessagesProcessed confirms and deletes message keys for the given messages
|
||||
func (s *EncryptionService) ConfirmMessagesProcessed(messageIDs [][]byte) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
for _, idByte := range messageIDs {
|
||||
id := confirmationIDString(idByte)
|
||||
confirmationData, ok := s.messageIDs[id]
|
||||
if !ok {
|
||||
s.log.Warn("Could not confirm message", "messageID", id)
|
||||
continue
|
||||
}
|
||||
|
||||
// Load session from store first
|
||||
session, err := s.getDRSession(confirmationData.drInfo.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := session.DeleteMk(confirmationData.header.DH, confirmationData.header.N); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateBundle retrieves or creates an X3DH bundle given a private key
|
||||
func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle, error) {
|
||||
ourIdentityKeyC := ecrypto.CompressPubkey(&privateKey.PublicKey)
|
||||
|
@ -229,7 +266,7 @@ func (s *EncryptionService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey,
|
|||
}
|
||||
|
||||
// DecryptPayload decrypts the payload of a DirectMessageProtocol, given an identity private key and the sender's public key
|
||||
func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, theirIdentityKey *ecdsa.PublicKey, theirInstallationID string, msgs map[string]*DirectMessageProtocol) ([]byte, error) {
|
||||
func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, theirIdentityKey *ecdsa.PublicKey, theirInstallationID string, msgs map[string]*DirectMessageProtocol, messageID []byte) ([]byte, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
|
@ -301,6 +338,12 @@ func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, thei
|
|||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
confirmationData := &ConfirmationData{
|
||||
header: &drMessage.Header,
|
||||
drInfo: drInfo,
|
||||
}
|
||||
s.messageIDs[confirmationIDString(messageID)] = confirmationData
|
||||
|
||||
return s.decryptUsingDR(theirIdentityKey, drInfo, drMessage)
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
var cleartext = []byte("hello")
|
||||
var aliceInstallationID = "1"
|
||||
var bobInstallationID = "2"
|
||||
var defaultMessageID = []byte("default")
|
||||
|
||||
func TestEncryptionServiceTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(EncryptionServiceTestSuite))
|
||||
|
@ -118,7 +119,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
|
|||
s.NotEqual(cyphertext1, cleartext, "It encrypts the payload correctly")
|
||||
|
||||
// On the receiver side, we should be able to decrypt using our private key and the ephemeral just sent
|
||||
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1)
|
||||
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using DH")
|
||||
|
||||
|
@ -133,7 +134,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
|
|||
s.NotEqual(cyphertext1, cyphertext2, "It does not re-use the symmetric key")
|
||||
s.NotEqual(ephemeralKey1, ephemeralKey2, "It does not re-use the ephemeral key")
|
||||
|
||||
decryptedPayload2, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse2)
|
||||
decryptedPayload2, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse2, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(cleartext, decryptedPayload2, "It correctly decrypts the payload using DH")
|
||||
}
|
||||
|
@ -185,7 +186,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() {
|
|||
s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain")
|
||||
|
||||
// Bob is able to decrypt it using the bundle
|
||||
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1)
|
||||
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using X3DH")
|
||||
}
|
||||
|
@ -249,7 +250,7 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() {
|
|||
s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain")
|
||||
|
||||
// Bob is able to decrypt it using the bundle
|
||||
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse)
|
||||
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
|
||||
|
@ -293,7 +294,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Bob receives the message
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse)
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Bob replies to the message
|
||||
|
@ -301,7 +302,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives the message
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, encryptionResponse)
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, encryptionResponse, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// We send another message using the bundle
|
||||
|
@ -332,7 +333,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
|
|||
s.Equal(uint32(1), drHeader.GetPn(), "It adds the correct length of the message chain")
|
||||
|
||||
// Bob is able to decrypt it using the bundle
|
||||
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse)
|
||||
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
|
||||
|
@ -380,7 +381,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives the message
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Bob sends a message
|
||||
|
@ -393,7 +394,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
|
|||
|
||||
// Alice receives the message, we should have maxSkip + 1 keys in the db, but
|
||||
// we should not throw an error
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2)
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
@ -435,11 +436,18 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives the message
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID)
|
||||
s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err)
|
||||
}
|
||||
|
||||
func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
|
||||
config := DefaultEncryptionServiceConfig("none")
|
||||
// Set MaxKeep and MaxSkip to an high value so it does not interfere
|
||||
config.MaxKeep = 100000
|
||||
config.MaxSkip = 100000
|
||||
|
||||
s.initDatabases(&config)
|
||||
|
||||
bobText := []byte("text")
|
||||
|
||||
bobKey, err := crypto.GenerateKey()
|
||||
|
@ -466,38 +474,37 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
|
|||
|
||||
// We create just enough messages so that the first key should be deleted
|
||||
|
||||
nMessages := s.alice.config.MaxMessageKeysPerSession + s.alice.config.MaxMessageKeysPerSession/s.alice.config.MaxSkip + 2
|
||||
nMessages := s.alice.config.MaxMessageKeysPerSession
|
||||
messages := make([]map[string]*DirectMessageProtocol, nMessages)
|
||||
for i := 0; i < nMessages; i++ {
|
||||
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||
s.Require().NoError(err)
|
||||
|
||||
messages[i] = m
|
||||
|
||||
// We decrypt some messages otherwise we hit maxSkip limit
|
||||
if i%s.alice.config.MaxSkip == 0 {
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Another message to trigger the deletion
|
||||
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// We decrypt the first message, and it should fail
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0], defaultMessageID)
|
||||
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
|
||||
|
||||
// We decrypt the second message, and it should be decrypted
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[2])
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1], defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *EncryptionServiceTestSuite) TestMaxKeep() {
|
||||
config := DefaultEncryptionServiceConfig("none")
|
||||
// Set MaxMessageKeysPerSession to an high value so it does not interfere
|
||||
config.MaxMessageKeysPerSession = 100000
|
||||
|
||||
s.initDatabases(&config)
|
||||
|
||||
bobText := []byte("text")
|
||||
|
||||
bobKey, err := crypto.GenerateKey()
|
||||
|
@ -530,18 +537,21 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
if i != 0 && i != 1 {
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
|
||||
messageID := []byte(fmt.Sprintf("%d", i))
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m, messageID)
|
||||
s.Require().NoError(err)
|
||||
err = s.alice.ConfirmMessagesProcessed([][]byte{messageID})
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// We decrypt the first message, and it should fail, as it should have been removed
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0])
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0], defaultMessageID)
|
||||
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
|
||||
|
||||
// We decrypt the second message, and it should be decrypted
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1], defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
@ -590,11 +600,11 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Bob receives the message
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage1)
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage1, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives the message
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Bob replies to the message
|
||||
|
@ -606,11 +616,11 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives the message
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2)
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Bob receives the message
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage2)
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage2, defaultMessageID)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
@ -658,7 +668,7 @@ func receiver(
|
|||
i := 0
|
||||
|
||||
for payload := range input {
|
||||
actualCleartext, err := s.DecryptPayload(privateKey, publicKey, installationID, payload)
|
||||
actualCleartext, err := s.DecryptPayload(privateKey, publicKey, installationID, payload, defaultMessageID)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
|
@ -765,7 +775,7 @@ func (s *EncryptionServiceTestSuite) TestBundleNotExisting() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Bob receives the message, and returns a bundlenotfound error
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage)
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage, defaultMessageID)
|
||||
s.Require().Error(err)
|
||||
s.Equal(ErrSessionNotFound, err)
|
||||
}
|
||||
|
@ -798,7 +808,7 @@ func (s *EncryptionServiceTestSuite) TestDeviceNotIncluded() {
|
|||
s.Require().NoError(err)
|
||||
|
||||
// Bob receives the message, and returns a bundlenotfound error
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage)
|
||||
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage, defaultMessageID)
|
||||
s.Require().Error(err)
|
||||
s.Equal(ErrDeviceNotFound, err)
|
||||
}
|
||||
|
@ -866,3 +876,88 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
|
|||
s.Equal(bobBundle2.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId())
|
||||
|
||||
}
|
||||
|
||||
func (s *EncryptionServiceTestSuite) TestMessageConfirmation() {
|
||||
bobText1 := []byte("bob text 1")
|
||||
|
||||
bobKey, err := crypto.GenerateKey()
|
||||
s.Require().NoError(err)
|
||||
|
||||
aliceKey, err := crypto.GenerateKey()
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Create a bundle
|
||||
bobBundle, err := s.bob.CreateBundle(bobKey)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// We add bob bundle
|
||||
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Create a bundle
|
||||
aliceBundle, err := s.alice.CreateBundle(aliceKey)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// We add alice bundle
|
||||
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Bob sends a message
|
||||
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1)
|
||||
s.Require().NoError(err)
|
||||
bobMessage1ID := []byte("bob-message-1-id")
|
||||
|
||||
// Alice receives the message once
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives the message twice
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice confirms the message
|
||||
err = s.alice.ConfirmMessagesProcessed([][]byte{bobMessage1ID})
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice decrypts it again, it should fail
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID)
|
||||
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
|
||||
|
||||
// Bob sends a message
|
||||
bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1)
|
||||
s.Require().NoError(err)
|
||||
bobMessage2ID := []byte("bob-message-2-id")
|
||||
|
||||
// Bob sends a message
|
||||
bobMessage3, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1)
|
||||
s.Require().NoError(err)
|
||||
bobMessage3ID := []byte("bob-message-3-id")
|
||||
|
||||
// Alice receives message 3 once
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives message 3 twice
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives message 2 once
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice receives message 2 twice
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice confirms the messages
|
||||
err = s.alice.ConfirmMessagesProcessed([][]byte{bobMessage2ID, bobMessage3ID})
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Alice decrypts it again, it should fail
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID)
|
||||
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
|
||||
|
||||
// Alice decrypts it again, it should fail
|
||||
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID)
|
||||
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
|
||||
}
|
||||
|
|
|
@ -124,8 +124,13 @@ func (p *ProtocolService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*B
|
|||
return p.encryption.GetPublicBundle(theirIdentityKey)
|
||||
}
|
||||
|
||||
// ConfirmMessagesProcessed confirms and deletes message keys for the given messages
|
||||
func (p *ProtocolService) ConfirmMessagesProcessed(messageIDs [][]byte) error {
|
||||
return p.encryption.ConfirmMessagesProcessed(messageIDs)
|
||||
}
|
||||
|
||||
// HandleMessage unmarshals a message and processes it, decrypting it if it is a 1:1 message.
|
||||
func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, payload []byte) ([]byte, error) {
|
||||
func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, payload []byte, messageID []byte) ([]byte, error) {
|
||||
if p.encryption == nil {
|
||||
return nil, errors.New("encryption service not initialized")
|
||||
}
|
||||
|
@ -167,7 +172,7 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu
|
|||
|
||||
// Decrypt message
|
||||
if directMessage := protocolMessage.GetDirectMessage(); directMessage != nil {
|
||||
message, err := p.encryption.DecryptPayload(myIdentityKey, theirPublicKey, protocolMessage.GetInstallationId(), directMessage)
|
||||
message, err := p.encryption.DecryptPayload(myIdentityKey, theirPublicKey, protocolMessage.GetInstallationId(), directMessage, messageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ func (s *ProtocolServiceTestSuite) TestBuildAndReadDirectMessage() {
|
|||
s.NoError(err)
|
||||
|
||||
// Bob is able to decrypt the message
|
||||
unmarshaledMsg, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, marshaledMsg)
|
||||
unmarshaledMsg, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, marshaledMsg, []byte("message-id"))
|
||||
s.NoError(err)
|
||||
|
||||
s.NotNil(unmarshaledMsg)
|
||||
|
|
|
@ -22,7 +22,7 @@ func newCache(db *leveldb.DB) *cache {
|
|||
}
|
||||
|
||||
func (d *cache) Has(filterID string, message *whisper.Message) (bool, error) {
|
||||
has, err := d.db.Has(d.keyToday(filterID, message), nil)
|
||||
has, err := d.db.Has(d.KeyToday(filterID, message), nil)
|
||||
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
@ -38,7 +38,22 @@ func (d *cache) Put(filterID string, messages []*whisper.Message) error {
|
|||
batch := leveldb.Batch{}
|
||||
|
||||
for _, msg := range messages {
|
||||
batch.Put(d.keyToday(filterID, msg), []byte{})
|
||||
batch.Put(d.KeyToday(filterID, msg), []byte{})
|
||||
}
|
||||
|
||||
err := d.db.Write(&batch, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return d.cleanOldEntries()
|
||||
}
|
||||
|
||||
func (d *cache) PutIDs(messageIDs [][]byte) error {
|
||||
batch := leveldb.Batch{}
|
||||
|
||||
for _, id := range messageIDs {
|
||||
batch.Put(id, []byte{})
|
||||
}
|
||||
|
||||
err := d.db.Write(&batch, nil)
|
||||
|
@ -78,7 +93,7 @@ func (d *cache) keyYesterday(filterID string, message *whisper.Message) []byte {
|
|||
return prefixedKey(d.yesterdayDateString(), filterID, message)
|
||||
}
|
||||
|
||||
func (d *cache) keyToday(filterID string, message *whisper.Message) []byte {
|
||||
func (d *cache) KeyToday(filterID string, message *whisper.Message) []byte {
|
||||
return prefixedKey(d.todayDateString(), filterID, message)
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,11 @@ type Deduplicator struct {
|
|||
log log.Logger
|
||||
}
|
||||
|
||||
type DeduplicateMessage struct {
|
||||
DedupID []byte `json:"id"`
|
||||
Message *whisper.Message `json:"message"`
|
||||
}
|
||||
|
||||
// NewDeduplicator creates a new deduplicator
|
||||
func NewDeduplicator(keyPairProvider keyPairProvider, db *leveldb.DB) *Deduplicator {
|
||||
return &Deduplicator{
|
||||
|
@ -30,15 +35,19 @@ func NewDeduplicator(keyPairProvider keyPairProvider, db *leveldb.DB) *Deduplica
|
|||
// Deduplicate receives a list of whisper messages and
|
||||
// returns the list of the messages that weren't filtered previously for the
|
||||
// specified filter.
|
||||
func (d *Deduplicator) Deduplicate(messages []*whisper.Message) []*whisper.Message {
|
||||
result := make([]*whisper.Message, 0)
|
||||
func (d *Deduplicator) Deduplicate(messages []*whisper.Message) []DeduplicateMessage {
|
||||
result := make([]DeduplicateMessage, 0)
|
||||
selectedKeyPairID := d.keyPairProvider.SelectedKeyPairID()
|
||||
|
||||
for _, message := range messages {
|
||||
if has, err := d.cache.Has(d.keyPairProvider.SelectedKeyPairID(), message); !has {
|
||||
if has, err := d.cache.Has(selectedKeyPairID, message); !has {
|
||||
if err != nil {
|
||||
d.log.Error("error while deduplicating messages: search cache failed", "err", err)
|
||||
}
|
||||
result = append(result, message)
|
||||
result = append(result, DeduplicateMessage{
|
||||
DedupID: d.cache.KeyToday(selectedKeyPairID, message),
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,3 +59,9 @@ func (d *Deduplicator) Deduplicate(messages []*whisper.Message) []*whisper.Messa
|
|||
func (d *Deduplicator) AddMessages(messages []*whisper.Message) error {
|
||||
return d.cache.Put(d.keyPairProvider.SelectedKeyPairID(), messages)
|
||||
}
|
||||
|
||||
// AddMessageByID adds a message to the deduplicator DB, so it will be filtered
|
||||
// out.
|
||||
func (d *Deduplicator) AddMessageByID(messageIDs [][]byte) error {
|
||||
return d.cache.PutIDs(messageIDs)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,9 @@ type Session interface {
|
|||
|
||||
// RatchetDecrypt is called to AEAD-decrypt messages.
|
||||
RatchetDecrypt(m Message, associatedData []byte) ([]byte, error)
|
||||
|
||||
//DeleteMk remove a message key from the database
|
||||
DeleteMk(Key, uint32) error
|
||||
}
|
||||
|
||||
type sessionState struct {
|
||||
|
@ -101,6 +104,11 @@ func (s *sessionState) RatchetEncrypt(plaintext, ad []byte) (Message, error) {
|
|||
return Message{h, ct}, nil
|
||||
}
|
||||
|
||||
// DeleteMk deletes a message key
|
||||
func (s *sessionState) DeleteMk(dh Key, n uint32) error {
|
||||
return s.MkSkipped.DeleteMk(dh, uint(n))
|
||||
}
|
||||
|
||||
// RatchetDecrypt is called to decrypt messages.
|
||||
func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
||||
// Is the message one of the skipped?
|
||||
|
@ -114,7 +122,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("can't decrypt skipped message: %s", err)
|
||||
}
|
||||
_ = s.MkSkipped.DeleteMk(m.Header.DH, uint(m.Header.N))
|
||||
if err := s.store(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -149,11 +156,20 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
|
|||
return nil, fmt.Errorf("can't decrypt: %s", err)
|
||||
}
|
||||
|
||||
// Append current key, waiting for confirmation
|
||||
skippedKeys := append(skippedKeys1, skippedKeys2...)
|
||||
skippedKeys = append(skippedKeys, skippedKey{
|
||||
key: sc.DHr,
|
||||
nr: uint(m.Header.N),
|
||||
mk: mk,
|
||||
seq: sc.KeysCount,
|
||||
})
|
||||
|
||||
// Increment the number of keys
|
||||
sc.KeysCount++
|
||||
|
||||
// Apply changes.
|
||||
if err := s.applyChanges(sc, s.id, append(skippedKeys1, skippedKeys2...)); err != nil {
|
||||
if err := s.applyChanges(sc, s.id, skippedKeys); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -75,7 +75,6 @@ type ReceivedMessage struct {
|
|||
|
||||
SymKeyHash common.Hash // The Keccak256Hash of the key
|
||||
EnvelopeHash common.Hash // Message envelope hash to act as a unique id
|
||||
History bool
|
||||
}
|
||||
|
||||
func isMessageSigned(flags byte) bool {
|
||||
|
|
Loading…
Reference in New Issue