Add confirm messages processed by ID (#1375)

Currently PFS messages are decrypted and therefore modified before being
passed to the client. This make IDs computation difficult, as we pass
the whole object to the client and expect the object be passed back once
confirmed.
This changes the behavior allowing confirmation by ID, which is passed
to the client instead of the raw object.
This is a breaking change, but status-react is already forward
compatible.
This commit is contained in:
Andrea Maria Piana 2019-02-19 13:58:42 +01:00 committed by GitHub
parent 72906ac655
commit 81d8ca82a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 258 additions and 59 deletions

7
Gopkg.lock generated
View File

@ -660,7 +660,6 @@
revision = "ad98a36ba0da87206e3378c556abbfeaeaa98668"
[[projects]]
branch = "master"
digest = "1:b9b9f43a8a410d633e6547f89e830926741070941f2243d4d3a0bb154f565c9e"
name = "github.com/mr-tron/base58"
packages = ["base58"]
@ -783,12 +782,12 @@
version = "v1.1"
[[projects]]
digest = "1:41fb72d7a71f37f1f9c766d965178636ecda21b429b1f2e3fff42cfc31279751"
digest = "1:821730d8591dde31a0d53b286b5fb9fb7f02f601e0f32a6b759c97934587c0a7"
name = "github.com/status-im/doubleratchet"
packages = ["."]
pruneopts = "NUT"
revision = "4dcb6cba284ae9f97129e2a98b9277f629d9dbc4"
version = "v1.0.0"
revision = "f2aeb83683d753011e9cffda29be0bf6524e1b65"
version = "v2.0.0"
[[projects]]
digest = "1:f3044238fc5d70eca12cc181b1e6d5270570d85a7b7046686381e618a783a7d6"

View File

@ -293,7 +293,7 @@
[[constraint]]
name = "github.com/status-im/doubleratchet"
version = "=v1.0.0"
version = "=v2.0.0"
[[constraint]]
name = "github.com/status-im/migrate"

View File

@ -1 +1 @@
0.22.0-beta.0
0.23.0-beta.0

View File

@ -18,6 +18,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/status-im/status-go/mailserver"
"github.com/status-im/status-go/services/shhext/chat"
"github.com/status-im/status-go/services/shhext/dedup"
"github.com/status-im/status-go/services/shhext/mailservers"
whisper "github.com/status-im/whisper/whisperv6"
)
@ -368,7 +369,7 @@ func (api *PublicAPI) SyncMessages(ctx context.Context, r SyncMessagesRequest) (
}
// GetNewFilterMessages is a prototype method with deduplication
func (api *PublicAPI) GetNewFilterMessages(filterID string) ([]*whisper.Message, error) {
func (api *PublicAPI) GetNewFilterMessages(filterID string) ([]dedup.DeduplicateMessage, error) {
msgs, err := api.publicAPI.GetFilterMessages(filterID)
if err != nil {
return nil, err
@ -378,9 +379,9 @@ func (api *PublicAPI) GetNewFilterMessages(filterID string) ([]*whisper.Message,
if api.service.pfsEnabled {
// Attempt to decrypt message, otherwise leave unchanged
for _, msg := range dedupMessages {
for _, dedupMessage := range dedupMessages {
if err := api.processPFSMessage(msg); err != nil {
if err := api.processPFSMessage(dedupMessage); err != nil {
return nil, err
}
}
@ -395,6 +396,16 @@ func (api *PublicAPI) ConfirmMessagesProcessed(messages []*whisper.Message) erro
return api.service.deduplicator.AddMessages(messages)
}
// ConfirmMessagesProcessedByID is a method to confirm that messages was consumed by
// the client side.
func (api *PublicAPI) ConfirmMessagesProcessedByID(messageIDs [][]byte) error {
if err := api.service.protocol.ConfirmMessagesProcessed(messageIDs); err != nil {
return err
}
return api.service.deduplicator.AddMessageByID(messageIDs)
}
// SendPublicMessage sends a public chat message to the underlying transport
func (api *PublicAPI) SendPublicMessage(ctx context.Context, msg chat.SendPublicMessageRPC) (hexutil.Bytes, error) {
privateKey, err := api.service.w.GetPrivateKey(msg.Sig)
@ -494,7 +505,8 @@ func (api *PublicAPI) SendPairingMessage(ctx context.Context, msg chat.SendDirec
return response, nil
}
func (api *PublicAPI) processPFSMessage(msg *whisper.Message) error {
func (api *PublicAPI) processPFSMessage(dedupMessage dedup.DeduplicateMessage) error {
msg := dedupMessage.Message
privateKeyID := api.service.w.SelectedKeyPairID()
if privateKeyID == "" {
@ -511,7 +523,7 @@ func (api *PublicAPI) processPFSMessage(msg *whisper.Message) error {
return err
}
response, err := api.service.protocol.HandleMessage(privateKey, publicKey, msg.Payload)
response, err := api.service.protocol.HandleMessage(privateKey, publicKey, msg.Payload, dedupMessage.DedupID)
switch err {
case nil:

View File

@ -3,17 +3,17 @@ package chat
import (
"bytes"
"crypto/ecdsa"
"encoding/hex"
"errors"
"fmt"
"sync"
"time"
ecrypto "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ethereum/go-ethereum/log"
dr "github.com/status-im/doubleratchet"
"sync"
"time"
"github.com/status-im/status-go/services/shhext/chat/crypto"
)
@ -23,11 +23,17 @@ var ErrDeviceNotFound = errors.New("device not found")
// If we have no bundles, we use a constant so that the message can reach any device.
const noInstallationID = "none"
type ConfirmationData struct {
header *dr.MessageHeader
drInfo *RatchetInfo
}
// EncryptionService defines a service that is responsible for the encryption aspect of the protocol.
type EncryptionService struct {
log log.Logger
persistence PersistenceService
config EncryptionServiceConfig
messageIDs map[string]*ConfirmationData
mutex sync.Mutex
}
@ -68,6 +74,7 @@ func NewEncryptionService(p PersistenceService, config EncryptionServiceConfig)
persistence: p,
config: config,
mutex: sync.Mutex{},
messageIDs: make(map[string]*ConfirmationData),
}
}
@ -94,6 +101,36 @@ func (s *EncryptionService) getDRSession(id []byte) (dr.Session, error) {
}
func confirmationIDString(id []byte) string {
return hex.EncodeToString(id)
}
// ConfirmMessagesProcessed confirms and deletes message keys for the given messages
func (s *EncryptionService) ConfirmMessagesProcessed(messageIDs [][]byte) error {
s.mutex.Lock()
defer s.mutex.Unlock()
for _, idByte := range messageIDs {
id := confirmationIDString(idByte)
confirmationData, ok := s.messageIDs[id]
if !ok {
s.log.Warn("Could not confirm message", "messageID", id)
continue
}
// Load session from store first
session, err := s.getDRSession(confirmationData.drInfo.ID)
if err != nil {
return err
}
if err := session.DeleteMk(confirmationData.header.DH, confirmationData.header.N); err != nil {
return err
}
}
return nil
}
// CreateBundle retrieves or creates an X3DH bundle given a private key
func (s *EncryptionService) CreateBundle(privateKey *ecdsa.PrivateKey) (*Bundle, error) {
ourIdentityKeyC := ecrypto.CompressPubkey(&privateKey.PublicKey)
@ -229,7 +266,7 @@ func (s *EncryptionService) ProcessPublicBundle(myIdentityKey *ecdsa.PrivateKey,
}
// DecryptPayload decrypts the payload of a DirectMessageProtocol, given an identity private key and the sender's public key
func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, theirIdentityKey *ecdsa.PublicKey, theirInstallationID string, msgs map[string]*DirectMessageProtocol) ([]byte, error) {
func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, theirIdentityKey *ecdsa.PublicKey, theirInstallationID string, msgs map[string]*DirectMessageProtocol, messageID []byte) ([]byte, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@ -301,6 +338,12 @@ func (s *EncryptionService) DecryptPayload(myIdentityKey *ecdsa.PrivateKey, thei
return nil, ErrSessionNotFound
}
confirmationData := &ConfirmationData{
header: &drMessage.Header,
drInfo: drInfo,
}
s.messageIDs[confirmationIDString(messageID)] = confirmationData
return s.decryptUsingDR(theirIdentityKey, drInfo, drMessage)
}

View File

@ -19,6 +19,7 @@ import (
var cleartext = []byte("hello")
var aliceInstallationID = "1"
var bobInstallationID = "2"
var defaultMessageID = []byte("default")
func TestEncryptionServiceTestSuite(t *testing.T) {
suite.Run(t, new(EncryptionServiceTestSuite))
@ -118,7 +119,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
s.NotEqual(cyphertext1, cleartext, "It encrypts the payload correctly")
// On the receiver side, we should be able to decrypt using our private key and the ephemeral just sent
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1)
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID)
s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using DH")
@ -133,7 +134,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
s.NotEqual(cyphertext1, cyphertext2, "It does not re-use the symmetric key")
s.NotEqual(ephemeralKey1, ephemeralKey2, "It does not re-use the ephemeral key")
decryptedPayload2, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse2)
decryptedPayload2, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse2, defaultMessageID)
s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload2, "It correctly decrypts the payload using DH")
}
@ -185,7 +186,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() {
s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain")
// Bob is able to decrypt it using the bundle
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1)
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse1, defaultMessageID)
s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using X3DH")
}
@ -249,7 +250,7 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() {
s.Equal(uint32(0), drHeader.GetPn(), "It adds the correct length of the message chain")
// Bob is able to decrypt it using the bundle
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse)
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID)
s.Require().NoError(err)
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
@ -293,7 +294,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
s.Require().NoError(err)
// Bob receives the message
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse)
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID)
s.Require().NoError(err)
// Bob replies to the message
@ -301,7 +302,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
s.Require().NoError(err)
// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, encryptionResponse)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, encryptionResponse, defaultMessageID)
s.Require().NoError(err)
// We send another message using the bundle
@ -332,7 +333,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
s.Equal(uint32(1), drHeader.GetPn(), "It adds the correct length of the message chain")
// Bob is able to decrypt it using the bundle
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse)
decryptedPayload1, err := s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, encryptionResponse, defaultMessageID)
s.Require().NoError(err)
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH")
@ -380,7 +381,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
s.Require().NoError(err)
// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID)
s.Require().NoError(err)
// Bob sends a message
@ -393,7 +394,7 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeys() {
// Alice receives the message, we should have maxSkip + 1 keys in the db, but
// we should not throw an error
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, defaultMessageID)
s.Require().NoError(err)
}
@ -435,11 +436,18 @@ func (s *EncryptionServiceTestSuite) TestMaxSkipKeysError() {
s.Require().NoError(err)
// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID)
s.Require().Equal(errors.New("can't skip current chain message keys: too many messages"), err)
}
func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
config := DefaultEncryptionServiceConfig("none")
// Set MaxKeep and MaxSkip to an high value so it does not interfere
config.MaxKeep = 100000
config.MaxSkip = 100000
s.initDatabases(&config)
bobText := []byte("text")
bobKey, err := crypto.GenerateKey()
@ -466,38 +474,37 @@ func (s *EncryptionServiceTestSuite) TestMaxMessageKeysPerSession() {
// We create just enough messages so that the first key should be deleted
nMessages := s.alice.config.MaxMessageKeysPerSession + s.alice.config.MaxMessageKeysPerSession/s.alice.config.MaxSkip + 2
nMessages := s.alice.config.MaxMessageKeysPerSession
messages := make([]map[string]*DirectMessageProtocol, nMessages)
for i := 0; i < nMessages; i++ {
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
messages[i] = m
// We decrypt some messages otherwise we hit maxSkip limit
if i%s.alice.config.MaxSkip == 0 {
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
s.Require().NoError(err)
}
}
// Another message to trigger the deletion
m, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText)
s.Require().NoError(err)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m, defaultMessageID)
s.Require().NoError(err)
// We decrypt the first message, and it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0], defaultMessageID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// We decrypt the second message, and it should be decrypted
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[2])
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1], defaultMessageID)
s.Require().NoError(err)
}
func (s *EncryptionServiceTestSuite) TestMaxKeep() {
config := DefaultEncryptionServiceConfig("none")
// Set MaxMessageKeysPerSession to an high value so it does not interfere
config.MaxMessageKeysPerSession = 100000
s.initDatabases(&config)
bobText := []byte("text")
bobKey, err := crypto.GenerateKey()
@ -530,18 +537,21 @@ func (s *EncryptionServiceTestSuite) TestMaxKeep() {
s.Require().NoError(err)
if i != 0 && i != 1 {
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m)
messageID := []byte(fmt.Sprintf("%d", i))
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, m, messageID)
s.Require().NoError(err)
err = s.alice.ConfirmMessagesProcessed([][]byte{messageID})
s.Require().NoError(err)
}
}
// We decrypt the first message, and it should fail, as it should have been removed
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0])
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[0], defaultMessageID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// We decrypt the second message, and it should be decrypted
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1])
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, messages[1], defaultMessageID)
s.Require().NoError(err)
}
@ -590,11 +600,11 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() {
s.Require().NoError(err)
// Bob receives the message
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage1)
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage1, defaultMessageID)
s.Require().NoError(err)
// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, defaultMessageID)
s.Require().NoError(err)
// Bob replies to the message
@ -606,11 +616,11 @@ func (s *EncryptionServiceTestSuite) TestConcurrentBundles() {
s.Require().NoError(err)
// Alice receives the message
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2)
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, defaultMessageID)
s.Require().NoError(err)
// Bob receives the message
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage2)
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage2, defaultMessageID)
s.Require().NoError(err)
}
@ -658,7 +668,7 @@ func receiver(
i := 0
for payload := range input {
actualCleartext, err := s.DecryptPayload(privateKey, publicKey, installationID, payload)
actualCleartext, err := s.DecryptPayload(privateKey, publicKey, installationID, payload, defaultMessageID)
if err != nil {
errChan <- err
return
@ -765,7 +775,7 @@ func (s *EncryptionServiceTestSuite) TestBundleNotExisting() {
s.Require().NoError(err)
// Bob receives the message, and returns a bundlenotfound error
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage)
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage, defaultMessageID)
s.Require().Error(err)
s.Equal(ErrSessionNotFound, err)
}
@ -798,7 +808,7 @@ func (s *EncryptionServiceTestSuite) TestDeviceNotIncluded() {
s.Require().NoError(err)
// Bob receives the message, and returns a bundlenotfound error
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage)
_, err = s.bob.DecryptPayload(bobKey, &aliceKey.PublicKey, aliceInstallationID, aliceMessage, defaultMessageID)
s.Require().Error(err)
s.Equal(ErrDeviceNotFound, err)
}
@ -866,3 +876,88 @@ func (s *EncryptionServiceTestSuite) TestRefreshedBundle() {
s.Equal(bobBundle2.GetSignedPreKeys()[bobInstallationID].GetSignedPreKey(), x3dhHeader2.GetId())
}
func (s *EncryptionServiceTestSuite) TestMessageConfirmation() {
bobText1 := []byte("bob text 1")
bobKey, err := crypto.GenerateKey()
s.Require().NoError(err)
aliceKey, err := crypto.GenerateKey()
s.Require().NoError(err)
// Create a bundle
bobBundle, err := s.bob.CreateBundle(bobKey)
s.Require().NoError(err)
// We add bob bundle
_, err = s.alice.ProcessPublicBundle(aliceKey, bobBundle)
s.Require().NoError(err)
// Create a bundle
aliceBundle, err := s.alice.CreateBundle(aliceKey)
s.Require().NoError(err)
// We add alice bundle
_, err = s.bob.ProcessPublicBundle(bobKey, aliceBundle)
s.Require().NoError(err)
// Bob sends a message
bobMessage1, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1)
s.Require().NoError(err)
bobMessage1ID := []byte("bob-message-1-id")
// Alice receives the message once
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID)
s.Require().NoError(err)
// Alice receives the message twice
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID)
s.Require().NoError(err)
// Alice confirms the message
err = s.alice.ConfirmMessagesProcessed([][]byte{bobMessage1ID})
s.Require().NoError(err)
// Alice decrypts it again, it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage1, bobMessage1ID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// Bob sends a message
bobMessage2, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1)
s.Require().NoError(err)
bobMessage2ID := []byte("bob-message-2-id")
// Bob sends a message
bobMessage3, err := s.bob.EncryptPayload(&aliceKey.PublicKey, bobKey, bobText1)
s.Require().NoError(err)
bobMessage3ID := []byte("bob-message-3-id")
// Alice receives message 3 once
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID)
s.Require().NoError(err)
// Alice receives message 3 twice
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID)
s.Require().NoError(err)
// Alice receives message 2 once
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID)
s.Require().NoError(err)
// Alice receives message 2 twice
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID)
s.Require().NoError(err)
// Alice confirms the messages
err = s.alice.ConfirmMessagesProcessed([][]byte{bobMessage2ID, bobMessage3ID})
s.Require().NoError(err)
// Alice decrypts it again, it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage3, bobMessage3ID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
// Alice decrypts it again, it should fail
_, err = s.alice.DecryptPayload(aliceKey, &bobKey.PublicKey, bobInstallationID, bobMessage2, bobMessage2ID)
s.Require().Equal(errors.New("can't skip current chain message keys: bad until: probably an out-of-order message that was deleted"), err)
}

View File

@ -124,8 +124,13 @@ func (p *ProtocolService) GetPublicBundle(theirIdentityKey *ecdsa.PublicKey) (*B
return p.encryption.GetPublicBundle(theirIdentityKey)
}
// ConfirmMessagesProcessed confirms and deletes message keys for the given messages
func (p *ProtocolService) ConfirmMessagesProcessed(messageIDs [][]byte) error {
return p.encryption.ConfirmMessagesProcessed(messageIDs)
}
// HandleMessage unmarshals a message and processes it, decrypting it if it is a 1:1 message.
func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, payload []byte) ([]byte, error) {
func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPublicKey *ecdsa.PublicKey, payload []byte, messageID []byte) ([]byte, error) {
if p.encryption == nil {
return nil, errors.New("encryption service not initialized")
}
@ -167,7 +172,7 @@ func (p *ProtocolService) HandleMessage(myIdentityKey *ecdsa.PrivateKey, theirPu
// Decrypt message
if directMessage := protocolMessage.GetDirectMessage(); directMessage != nil {
message, err := p.encryption.DecryptPayload(myIdentityKey, theirPublicKey, protocolMessage.GetInstallationId(), directMessage)
message, err := p.encryption.DecryptPayload(myIdentityKey, theirPublicKey, protocolMessage.GetInstallationId(), directMessage, messageID)
if err != nil {
return nil, err
}

View File

@ -120,7 +120,7 @@ func (s *ProtocolServiceTestSuite) TestBuildAndReadDirectMessage() {
s.NoError(err)
// Bob is able to decrypt the message
unmarshaledMsg, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, marshaledMsg)
unmarshaledMsg, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, marshaledMsg, []byte("message-id"))
s.NoError(err)
s.NotNil(unmarshaledMsg)

View File

@ -22,7 +22,7 @@ func newCache(db *leveldb.DB) *cache {
}
func (d *cache) Has(filterID string, message *whisper.Message) (bool, error) {
has, err := d.db.Has(d.keyToday(filterID, message), nil)
has, err := d.db.Has(d.KeyToday(filterID, message), nil)
if err != nil {
return false, err
@ -38,7 +38,22 @@ func (d *cache) Put(filterID string, messages []*whisper.Message) error {
batch := leveldb.Batch{}
for _, msg := range messages {
batch.Put(d.keyToday(filterID, msg), []byte{})
batch.Put(d.KeyToday(filterID, msg), []byte{})
}
err := d.db.Write(&batch, nil)
if err != nil {
return err
}
return d.cleanOldEntries()
}
func (d *cache) PutIDs(messageIDs [][]byte) error {
batch := leveldb.Batch{}
for _, id := range messageIDs {
batch.Put(id, []byte{})
}
err := d.db.Write(&batch, nil)
@ -78,7 +93,7 @@ func (d *cache) keyYesterday(filterID string, message *whisper.Message) []byte {
return prefixedKey(d.yesterdayDateString(), filterID, message)
}
func (d *cache) keyToday(filterID string, message *whisper.Message) []byte {
func (d *cache) KeyToday(filterID string, message *whisper.Message) []byte {
return prefixedKey(d.todayDateString(), filterID, message)
}

View File

@ -18,6 +18,11 @@ type Deduplicator struct {
log log.Logger
}
type DeduplicateMessage struct {
DedupID []byte `json:"id"`
Message *whisper.Message `json:"message"`
}
// NewDeduplicator creates a new deduplicator
func NewDeduplicator(keyPairProvider keyPairProvider, db *leveldb.DB) *Deduplicator {
return &Deduplicator{
@ -30,15 +35,19 @@ func NewDeduplicator(keyPairProvider keyPairProvider, db *leveldb.DB) *Deduplica
// Deduplicate receives a list of whisper messages and
// returns the list of the messages that weren't filtered previously for the
// specified filter.
func (d *Deduplicator) Deduplicate(messages []*whisper.Message) []*whisper.Message {
result := make([]*whisper.Message, 0)
func (d *Deduplicator) Deduplicate(messages []*whisper.Message) []DeduplicateMessage {
result := make([]DeduplicateMessage, 0)
selectedKeyPairID := d.keyPairProvider.SelectedKeyPairID()
for _, message := range messages {
if has, err := d.cache.Has(d.keyPairProvider.SelectedKeyPairID(), message); !has {
if has, err := d.cache.Has(selectedKeyPairID, message); !has {
if err != nil {
d.log.Error("error while deduplicating messages: search cache failed", "err", err)
}
result = append(result, message)
result = append(result, DeduplicateMessage{
DedupID: d.cache.KeyToday(selectedKeyPairID, message),
Message: message,
})
}
}
@ -50,3 +59,9 @@ func (d *Deduplicator) Deduplicate(messages []*whisper.Message) []*whisper.Messa
func (d *Deduplicator) AddMessages(messages []*whisper.Message) error {
return d.cache.Put(d.keyPairProvider.SelectedKeyPairID(), messages)
}
// AddMessageByID adds a message to the deduplicator DB, so it will be filtered
// out.
func (d *Deduplicator) AddMessageByID(messageIDs [][]byte) error {
return d.cache.PutIDs(messageIDs)
}

View File

@ -10,6 +10,9 @@ type Session interface {
// RatchetDecrypt is called to AEAD-decrypt messages.
RatchetDecrypt(m Message, associatedData []byte) ([]byte, error)
//DeleteMk remove a message key from the database
DeleteMk(Key, uint32) error
}
type sessionState struct {
@ -101,6 +104,11 @@ func (s *sessionState) RatchetEncrypt(plaintext, ad []byte) (Message, error) {
return Message{h, ct}, nil
}
// DeleteMk deletes a message key
func (s *sessionState) DeleteMk(dh Key, n uint32) error {
return s.MkSkipped.DeleteMk(dh, uint(n))
}
// RatchetDecrypt is called to decrypt messages.
func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
// Is the message one of the skipped?
@ -114,7 +122,6 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
if err != nil {
return nil, fmt.Errorf("can't decrypt skipped message: %s", err)
}
_ = s.MkSkipped.DeleteMk(m.Header.DH, uint(m.Header.N))
if err := s.store(); err != nil {
return nil, err
}
@ -149,11 +156,20 @@ func (s *sessionState) RatchetDecrypt(m Message, ad []byte) ([]byte, error) {
return nil, fmt.Errorf("can't decrypt: %s", err)
}
// Append current key, waiting for confirmation
skippedKeys := append(skippedKeys1, skippedKeys2...)
skippedKeys = append(skippedKeys, skippedKey{
key: sc.DHr,
nr: uint(m.Header.N),
mk: mk,
seq: sc.KeysCount,
})
// Increment the number of keys
sc.KeysCount++
// Apply changes.
if err := s.applyChanges(sc, s.id, append(skippedKeys1, skippedKeys2...)); err != nil {
if err := s.applyChanges(sc, s.id, skippedKeys); err != nil {
return nil, err
}

View File

@ -75,7 +75,6 @@ type ReceivedMessage struct {
SymKeyHash common.Hash // The Keccak256Hash of the key
EnvelopeHash common.Hash // Message envelope hash to act as a unique id
History bool
}
func isMessageSigned(flags byte) bool {