move contact code to subscription

This commit is contained in:
Andrea Maria Piana 2020-07-31 14:22:05 +02:00
parent b557a64612
commit 541756c777
No known key found for this signature in database
GPG Key ID: AA6CCA6DE0E06424
12 changed files with 154 additions and 169 deletions

View File

@ -18,6 +18,7 @@ import (
"github.com/status-im/status-go/protocol/datasync" "github.com/status-im/status-go/protocol/datasync"
datasyncpeer "github.com/status-im/status-go/protocol/datasync/peer" datasyncpeer "github.com/status-im/status-go/protocol/datasync/peer"
"github.com/status-im/status-go/protocol/encryption" "github.com/status-im/status-go/protocol/encryption"
"github.com/status-im/status-go/protocol/encryption/sharedsecret"
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/transport" "github.com/status-im/status-go/protocol/transport"
v1protocol "github.com/status-im/status-go/protocol/v1" v1protocol "github.com/status-im/status-go/protocol/v1"
@ -60,6 +61,9 @@ type MessageProcessor struct {
scheduledMessagesSubscriptions []chan<- *RawMessage scheduledMessagesSubscriptions []chan<- *RawMessage
featureFlags FeatureFlags featureFlags FeatureFlags
// handleSharedSecrets is a callback that is called every time a new shared secret is negotiated
handleSharedSecrets func([]*sharedsecret.Secret) error
} }
func NewMessageProcessor( func NewMessageProcessor(
@ -110,9 +114,14 @@ func (p *MessageProcessor) Stop() {
for _, c := range p.sentMessagesSubscriptions { for _, c := range p.sentMessagesSubscriptions {
close(c) close(c)
} }
p.sentMessagesSubscriptions = nil
p.datasync.Stop() // idempotent op p.datasync.Stop() // idempotent op
} }
func (p *MessageProcessor) SetHandleSharedSecrets(handler func([]*sharedsecret.Secret) error) {
p.handleSharedSecrets = handler
}
// SendPrivate takes encoded data, encrypts it and sends through the wire. // SendPrivate takes encoded data, encrypts it and sends through the wire.
func (p *MessageProcessor) SendPrivate( func (p *MessageProcessor) SendPrivate(
ctx context.Context, ctx context.Context,
@ -203,7 +212,7 @@ func (p *MessageProcessor) sendPrivate(
} else if rawMessage.SkipEncryption { } else if rawMessage.SkipEncryption {
// When SkipEncryption is set we don't pass the message to the encryption layer // When SkipEncryption is set we don't pass the message to the encryption layer
messageIDs := [][]byte{messageID} messageIDs := [][]byte{messageID}
hash, newMessage, err := p.sendRawMessage(ctx, recipient, wrappedMessage, messageIDs) hash, newMessage, err := p.sendPrivateRawMessage(ctx, recipient, wrappedMessage, messageIDs)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to send a message spec") return nil, errors.Wrap(err, "failed to send a message spec")
} }
@ -216,6 +225,16 @@ func (p *MessageProcessor) sendPrivate(
return nil, errors.Wrap(err, "failed to encrypt message") return nil, errors.Wrap(err, "failed to encrypt message")
} }
// The shared secret needs to be handle before we send a message
// otherwise the topic might not be set up before we receive a message
if p.handleSharedSecrets != nil {
err := p.handleSharedSecrets([]*sharedsecret.Secret{messageSpec.SharedSecret})
if err != nil {
return nil, err
}
}
messageIDs := [][]byte{messageID} messageIDs := [][]byte{messageID}
hash, newMessage, err := p.sendMessageSpec(ctx, recipient, messageSpec, messageIDs) hash, newMessage, err := p.sendMessageSpec(ctx, recipient, messageSpec, messageIDs)
if err != nil { if err != nil {
@ -305,12 +324,24 @@ func (p *MessageProcessor) SendPublic(
return nil, errors.Wrap(err, "failed to wrap message") return nil, errors.Wrap(err, "failed to wrap message")
} }
newMessage := &types.NewMessage{ var newMessage *types.NewMessage
if !rawMessage.SkipEncryption {
messageSpec, err := p.protocol.BuildPublicMessage(p.identity, wrappedMessage)
if err != nil {
return nil, errors.Wrap(err, "failed to wrap a public message in the encryption layer")
}
newMessage, err = MessageSpecToWhisper(messageSpec)
if err != nil {
return nil, err
}
} else {
newMessage = &types.NewMessage{
TTL: whisperTTL, TTL: whisperTTL,
Payload: wrappedMessage, Payload: wrappedMessage,
PowTarget: calculatePoW(wrappedMessage), PowTarget: calculatePoW(wrappedMessage),
PowTime: whisperPoWTime, PowTime: whisperPoWTime,
} }
}
messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage) messageID := v1protocol.MessageID(&rawMessage.Sender.PublicKey, wrappedMessage)
rawMessage.ID = types.EncodeHex(messageID) rawMessage.ID = types.EncodeHex(messageID)
@ -479,6 +510,16 @@ func (p *MessageProcessor) sendDataSync(ctx context.Context, publicKey *ecdsa.Pu
return errors.Wrap(err, "failed to encrypt message") return errors.Wrap(err, "failed to encrypt message")
} }
// The shared secret needs to be handle before we send a message
// otherwise the topic might not be set up before we receive a message
if p.handleSharedSecrets != nil {
err := p.handleSharedSecrets([]*sharedsecret.Secret{messageSpec.SharedSecret})
if err != nil {
return err
}
}
hash, newMessage, err := p.sendMessageSpec(ctx, publicKey, messageSpec, messageIDs) hash, newMessage, err := p.sendMessageSpec(ctx, publicKey, messageSpec, messageIDs)
if err != nil { if err != nil {
return err return err
@ -489,8 +530,8 @@ func (p *MessageProcessor) sendDataSync(ctx context.Context, publicKey *ecdsa.Pu
return nil return nil
} }
// sendRawMessage sends a message not wrapped in an encryption layer // sendPrivateRawMessage sends a message not wrapped in an encryption layer
func (p *MessageProcessor) sendRawMessage(ctx context.Context, publicKey *ecdsa.PublicKey, payload []byte, messageIDs [][]byte) ([]byte, *types.NewMessage, error) { func (p *MessageProcessor) sendPrivateRawMessage(ctx context.Context, publicKey *ecdsa.PublicKey, payload []byte, messageIDs [][]byte) ([]byte, *types.NewMessage, error) {
newMessage := &types.NewMessage{ newMessage := &types.NewMessage{
TTL: whisperTTL, TTL: whisperTTL,
Payload: payload, Payload: payload,
@ -517,11 +558,11 @@ func (p *MessageProcessor) sendMessageSpec(ctx context.Context, publicKey *ecdsa
var hash []byte var hash []byte
switch { // process shared secret
case messageSpec.SharedSecret != nil: if messageSpec.AgreedSecret {
logger.Debug("sending using shared secret") logger.Debug("sending using shared secret")
hash, err = p.transport.SendPrivateWithSharedSecret(ctx, newMessage, publicKey, messageSpec.SharedSecret) hash, err = p.transport.SendPrivateWithSharedSecret(ctx, newMessage, publicKey, messageSpec.SharedSecret.Key)
default: } else {
logger.Debug("sending partitioned topic") logger.Debug("sending partitioned topic")
hash, err = p.transport.SendPrivateWithPartitioned(ctx, newMessage, publicKey) hash, err = p.transport.SendPrivateWithPartitioned(ctx, newMessage, publicKey)
} }

View File

@ -61,11 +61,9 @@ func (s *MessageProcessorSuite) SetupTest() {
database, err := sqlite.Open(filepath.Join(s.tmpDir, "processor-test.sql"), "some-key") database, err := sqlite.Open(filepath.Join(s.tmpDir, "processor-test.sql"), "some-key")
s.Require().NoError(err) s.Require().NoError(err)
onSendContactCode := func(*encryption.ProtocolMessageSpec) {}
encryptionProtocol := encryption.New( encryptionProtocol := encryption.New(
database, database,
"installation-1", "installation-1",
onSendContactCode,
s.logger, s.logger,
) )
@ -200,7 +198,6 @@ func (s *MessageProcessorSuite) TestHandleDecodedMessagesDatasyncEncrypted() {
senderEncryptionProtocol := encryption.New( senderEncryptionProtocol := encryption.New(
senderDatabase, senderDatabase,
"installation-2", "installation-2",
func(*encryption.ProtocolMessageSpec) {},
s.logger, s.logger,
) )

View File

@ -64,7 +64,6 @@ func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error {
protocol := New( protocol := New(
db, db,
installationID, installationID,
func(*ProtocolMessageSpec) {},
s.logger.With(zap.String("user", user)), s.logger.With(zap.String("user", user)),
) )
s.services[user].services[i] = protocol s.services[user].services[i] = protocol

View File

@ -54,7 +54,6 @@ func (s *EncryptionServiceTestSuite) initDatabases(config encryptorConfig) {
db, db,
aliceInstallationID, aliceInstallationID,
config, config,
func(*ProtocolMessageSpec) {},
s.logger.With(zap.String("user", "alice")), s.logger.With(zap.String("user", "alice")),
) )
@ -65,7 +64,6 @@ func (s *EncryptionServiceTestSuite) initDatabases(config encryptorConfig) {
db, db,
bobInstallationID, bobInstallationID,
config, config,
func(*ProtocolMessageSpec) {},
s.logger.With(zap.String("user", "bob")), s.logger.With(zap.String("user", "bob")),
) )
} }
@ -123,7 +121,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
// On the receiver side, we should be able to decrypt using our private key and the ephemeral just sent // On the receiver side, we should be able to decrypt using our private key and the ephemeral just sent
decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID) decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using DH") s.Equal(cleartext, decryptedPayload1.DecryptedMessage, "It correctly decrypts the payload using DH")
// The next message will not be re-using the same key // The next message will not be re-using the same key
response2, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext) response2, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, cleartext)
@ -140,7 +138,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadNoBundle() {
decryptedPayload2, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response2.Message, defaultMessageID) decryptedPayload2, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response2.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload2, "It correctly decrypts the payload using DH") s.Equal(cleartext, decryptedPayload2.DecryptedMessage, "It correctly decrypts the payload using DH")
} }
// Alice has Bob's bundle // Alice has Bob's bundle
@ -194,7 +192,7 @@ func (s *EncryptionServiceTestSuite) TestEncryptPayloadBundle() {
// Bob is able to decrypt it using the bundle // Bob is able to decrypt it using the bundle
decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID) decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response1.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext, decryptedPayload1, "It correctly decrypts the payload using X3DH") s.Equal(cleartext, decryptedPayload1.DecryptedMessage, "It correctly decrypts the payload using X3DH")
} }
// Alice has Bob's bundle // Alice has Bob's bundle
@ -260,7 +258,7 @@ func (s *EncryptionServiceTestSuite) TestConsequentMessagesBundle() {
decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID) decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH") s.Equal(cleartext2, decryptedPayload1.DecryptedMessage, "It correctly decrypts the payload using X3DH")
} }
// Alice has Bob's bundle // Alice has Bob's bundle
@ -344,7 +342,7 @@ func (s *EncryptionServiceTestSuite) TestConversation() {
decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID) decryptedPayload1, err := s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, response.Message, defaultMessageID)
s.Require().NoError(err) s.Require().NoError(err)
s.Equal(cleartext2, decryptedPayload1, "It correctly decrypts the payload using X3DH") s.Equal(cleartext2, decryptedPayload1.DecryptedMessage, "It correctly decrypts the payload using X3DH")
} }
// Previous implementation allowed max maxSkip keys in the same receiving chain // Previous implementation allowed max maxSkip keys in the same receiving chain
@ -679,7 +677,7 @@ func receiver(
errChan <- err errChan <- err
return return
} }
if !reflect.DeepEqual(actualCleartext, cleartext) { if !reflect.DeepEqual(actualCleartext.DecryptedMessage, cleartext) {
errChan <- errors.New("Decrypted value does not match") errChan <- errors.New("Decrypted value does not match")
return return
} }

View File

@ -24,7 +24,6 @@ const (
sharedSecretNegotiationVersion = 1 sharedSecretNegotiationVersion = 1
partitionedTopicMinVersion = 1 partitionedTopicMinVersion = 1
defaultMinVersion = 0 defaultMinVersion = 0
subscriptionsChannelSize = 100
) )
type PartitionTopicMode int type PartitionTopicMode int
@ -39,7 +38,9 @@ type ProtocolMessageSpec struct {
// Installations is the targeted devices // Installations is the targeted devices
Installations []*multidevice.Installation Installations []*multidevice.Installation
// SharedSecret is a shared secret established among the installations // SharedSecret is a shared secret established among the installations
SharedSecret []byte SharedSecret *sharedsecret.Secret
// AgreedSecret indicates whether the shared secret has been agreed
AgreedSecret bool
// Public means that the spec contains a public wrapped message // Public means that the spec contains a public wrapped message
Public bool Public bool
} }
@ -73,8 +74,6 @@ type Protocol struct {
publisher *publisher.Publisher publisher *publisher.Publisher
subscriptions *Subscriptions subscriptions *Subscriptions
onSendContactCodeHandler func(*ProtocolMessageSpec)
logger *zap.Logger logger *zap.Logger
} }
@ -87,14 +86,12 @@ var (
func New( func New(
db *sql.DB, db *sql.DB,
installationID string, installationID string,
onSendContactCodeHandler func(*ProtocolMessageSpec),
logger *zap.Logger, logger *zap.Logger,
) *Protocol { ) *Protocol {
return NewWithEncryptorConfig( return NewWithEncryptorConfig(
db, db,
installationID, installationID,
defaultEncryptorConfig(installationID, logger), defaultEncryptorConfig(installationID, logger),
onSendContactCodeHandler,
logger, logger,
) )
} }
@ -105,7 +102,6 @@ func NewWithEncryptorConfig(
db *sql.DB, db *sql.DB,
installationID string, installationID string,
encryptorConfig encryptorConfig, encryptorConfig encryptorConfig,
onSendContactCodeHandler func(*ProtocolMessageSpec),
logger *zap.Logger, logger *zap.Logger,
) *Protocol { ) *Protocol {
return &Protocol{ return &Protocol{
@ -117,14 +113,12 @@ func NewWithEncryptorConfig(
InstallationID: installationID, InstallationID: installationID,
}), }),
publisher: publisher.New(logger), publisher: publisher.New(logger),
onSendContactCodeHandler: onSendContactCodeHandler,
logger: logger.With(zap.Namespace("Protocol")), logger: logger.With(zap.Namespace("Protocol")),
} }
} }
type Subscriptions struct { type Subscriptions struct {
NewInstallations chan []*multidevice.Installation SharedSecrets []*sharedsecret.Secret
NewSharedSecrets chan []*sharedsecret.Secret
SendContactCode <-chan struct{} SendContactCode <-chan struct{}
Quit chan struct{} Quit chan struct{}
} }
@ -136,32 +130,10 @@ func (p *Protocol) Start(myIdentity *ecdsa.PrivateKey) (*Subscriptions, error) {
return nil, errors.Wrap(err, "failed to get all secrets") return nil, errors.Wrap(err, "failed to get all secrets")
} }
p.subscriptions = &Subscriptions{ p.subscriptions = &Subscriptions{
NewInstallations: make(chan []*multidevice.Installation, subscriptionsChannelSize), SharedSecrets: secrets,
NewSharedSecrets: make(chan []*sharedsecret.Secret, subscriptionsChannelSize),
SendContactCode: p.publisher.Start(), SendContactCode: p.publisher.Start(),
Quit: make(chan struct{}), Quit: make(chan struct{}),
} }
if len(secrets) > 0 {
p.publishNewSharedSecrets(secrets)
}
// Handle Publisher system messages.
publisherCh := p.publisher.Start()
go func() {
for range publisherCh {
messageSpec, err := p.buildContactCodeMessage(myIdentity)
if err != nil {
p.logger.Error("failed to build contact code message",
zap.String("site", "Start"),
zap.Error(err))
continue
}
p.onSendContactCodeHandler(messageSpec)
}
}()
return p.subscriptions, nil return p.subscriptions, nil
} }
@ -212,12 +184,6 @@ func (p *Protocol) BuildPublicMessage(myIdentityKey *ecdsa.PrivateKey, payload [
return &ProtocolMessageSpec{Message: message, Public: true}, nil return &ProtocolMessageSpec{Message: message, Public: true}, nil
} }
// buildContactCodeMessage creates a contact code message. It's a public message
// without any data but it carries bundle information.
func (p *Protocol) buildContactCodeMessage(myIdentityKey *ecdsa.PrivateKey) (*ProtocolMessageSpec, error) {
return p.BuildPublicMessage(myIdentityKey, nil)
}
// BuildDirectMessage returns a 1:1 chat message and optionally a negotiated topic given the user identity private key, the recipient's public key, and a payload // BuildDirectMessage returns a 1:1 chat message and optionally a negotiated topic given the user identity private key, the recipient's public key, and a payload
func (p *Protocol) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, payload []byte) (*ProtocolMessageSpec, error) { func (p *Protocol) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey, payload []byte) (*ProtocolMessageSpec, error) {
logger := p.logger.With( logger := p.logger.With(
@ -268,18 +234,12 @@ func (p *Protocol) BuildDirectMessage(myIdentityKey *ecdsa.PrivateKey, publicKey
zap.Bool("has-shared-secret", sharedSecret != nil), zap.Bool("has-shared-secret", sharedSecret != nil),
zap.Bool("agreed", agreed)) zap.Bool("agreed", agreed))
// Publish shared secrets
if sharedSecret != nil {
p.publishNewSharedSecrets([]*sharedsecret.Secret{sharedSecret})
}
spec := &ProtocolMessageSpec{ spec := &ProtocolMessageSpec{
SharedSecret: sharedSecret,
AgreedSecret: agreed,
Message: message, Message: message,
Installations: installations, Installations: installations,
} }
if agreed {
spec.SharedSecret = sharedSecret.Key
}
return spec, nil return spec, nil
} }
@ -425,24 +385,10 @@ func (p *Protocol) ConfirmMessageProcessed(messageID []byte) error {
return p.encryptor.ConfirmMessageProcessed(messageID) return p.encryptor.ConfirmMessageProcessed(messageID)
} }
func (p *Protocol) publishNewInstallations(installations []*multidevice.Installation) { type DecryptMessageResponse struct {
if p.subscriptions != nil { DecryptedMessage []byte
select { Installations []*multidevice.Installation
case p.subscriptions.NewInstallations <- installations: SharedSecrets []*sharedsecret.Secret
default:
p.logger.Warn("new installations channel full, dropping message")
}
}
}
func (p *Protocol) publishNewSharedSecrets(secrets []*sharedsecret.Secret) {
if p.subscriptions != nil {
select {
case p.subscriptions.NewSharedSecrets <- secrets:
default:
p.logger.Warn("new sharedsecrets channel full, dropping message")
}
}
} }
// HandleMessage unmarshals a message and processes it, decrypting it if it is a 1:1 message. // HandleMessage unmarshals a message and processes it, decrypting it if it is a 1:1 message.
@ -451,8 +397,9 @@ func (p *Protocol) HandleMessage(
theirPublicKey *ecdsa.PublicKey, theirPublicKey *ecdsa.PublicKey,
protocolMessage *ProtocolMessage, protocolMessage *ProtocolMessage,
messageID []byte, messageID []byte,
) ([]byte, error) { ) (*DecryptMessageResponse, error) {
logger := p.logger.With(zap.String("site", "HandleMessage")) logger := p.logger.With(zap.String("site", "HandleMessage"))
response := &DecryptMessageResponse{}
logger.Debug("received a protocol message", zap.Binary("sender-public-key", crypto.FromECDSAPub(theirPublicKey)), zap.Binary("message-id", messageID)) logger.Debug("received a protocol message", zap.Binary("sender-public-key", crypto.FromECDSAPub(theirPublicKey)), zap.Binary("message-id", messageID))
@ -463,20 +410,19 @@ func (p *Protocol) HandleMessage(
// Process bundles // Process bundles
for _, bundle := range protocolMessage.GetBundles() { for _, bundle := range protocolMessage.GetBundles() {
// Should we stop processing if the bundle cannot be verified? // Should we stop processing if the bundle cannot be verified?
addedBundles, err := p.ProcessPublicBundle(myIdentityKey, bundle) newInstallations, err := p.ProcessPublicBundle(myIdentityKey, bundle)
if err != nil { if err != nil {
return nil, err return nil, err
} }
response.Installations = newInstallations
// Publish without blocking if channel is full
p.publishNewInstallations(addedBundles)
} }
// Check if it's a public message // Check if it's a public message
if publicMessage := protocolMessage.GetPublicMessage(); publicMessage != nil { if publicMessage := protocolMessage.GetPublicMessage(); publicMessage != nil {
logger.Debug("received a public message in direct message") logger.Debug("received a public message in direct message")
// Nothing to do, as already in cleartext // Nothing to do, as already in cleartext
return publicMessage, nil response.DecryptedMessage = publicMessage
return response, nil
} }
// Decrypt message // Decrypt message
@ -504,9 +450,10 @@ func (p *Protocol) HandleMessage(
return nil, err return nil, err
} }
p.publishNewSharedSecrets([]*sharedsecret.Secret{sharedSecret}) response.SharedSecrets = []*sharedsecret.Secret{sharedSecret}
} }
return message, nil response.DecryptedMessage = message
return response, nil
} }
// Return error // Return error

View File

@ -4,7 +4,6 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
"time"
"github.com/status-im/status-go/protocol/tt" "github.com/status-im/status-go/protocol/tt"
@ -14,7 +13,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/eth-node/crypto"
"github.com/status-im/status-go/protocol/encryption/sharedsecret"
) )
func TestProtocolServiceTestSuite(t *testing.T) { func TestProtocolServiceTestSuite(t *testing.T) {
@ -48,7 +46,6 @@ func (s *ProtocolServiceTestSuite) SetupTest() {
s.alice = New( s.alice = New(
db, db,
"1", "1",
func(*ProtocolMessageSpec) {},
s.logger.With(zap.String("user", "alice")), s.logger.With(zap.String("user", "alice")),
) )
@ -57,7 +54,6 @@ func (s *ProtocolServiceTestSuite) SetupTest() {
s.bob = New( s.bob = New(
db, db,
"2", "2",
func(*ProtocolMessageSpec) {},
s.logger.With(zap.String("user", "bob")), s.logger.With(zap.String("user", "bob")),
) )
} }
@ -134,7 +130,6 @@ func (s *ProtocolServiceTestSuite) TestBuildAndReadDirectMessage() {
} }
func (s *ProtocolServiceTestSuite) TestSecretNegotiation() { func (s *ProtocolServiceTestSuite) TestSecretNegotiation() {
var secretResponse []*sharedsecret.Secret
bobKey, err := crypto.GenerateKey() bobKey, err := crypto.GenerateKey()
s.NoError(err) s.NoError(err)
aliceKey, err := crypto.GenerateKey() aliceKey, err := crypto.GenerateKey()
@ -142,12 +137,13 @@ func (s *ProtocolServiceTestSuite) TestSecretNegotiation() {
payload := []byte("test") payload := []byte("test")
subscriptions, err := s.bob.Start(bobKey) _, err = s.bob.Start(bobKey)
s.Require().NoError(err) s.Require().NoError(err)
msgSpec, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, payload) msgSpec, err := s.alice.BuildDirectMessage(aliceKey, &bobKey.PublicKey, payload)
s.NoError(err) s.NoError(err)
s.NotNil(msgSpec, "It creates a message spec") s.NotNil(msgSpec, "It creates a message spec")
s.Require().NotNil(msgSpec.SharedSecret)
bundle := msgSpec.Message.GetBundles()[0] bundle := msgSpec.Message.GetBundles()[0]
s.Require().NotNil(bundle) s.Require().NotNil(bundle)
@ -163,19 +159,10 @@ func (s *ProtocolServiceTestSuite) TestSecretNegotiation() {
_, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, msgSpec.Message, []byte("message-id")) _, err = s.bob.HandleMessage(bobKey, &aliceKey.PublicKey, msgSpec.Message, []byte("message-id"))
s.NoError(err) s.NoError(err)
select {
case <-time.After(2 * time.Second):
case secretResponse = <-subscriptions.NewSharedSecrets:
}
s.Require().NotNil(secretResponse)
s.Require().NoError(s.bob.Stop()) s.Require().NoError(s.bob.Stop())
} }
func (s *ProtocolServiceTestSuite) TestPropagatingSavedSharedSecretsOnStart() { func (s *ProtocolServiceTestSuite) TestPropagatingSavedSharedSecretsOnStart() {
var secretResponse []*sharedsecret.Secret
aliceKey, err := crypto.GenerateKey() aliceKey, err := crypto.GenerateKey()
s.NoError(err) s.NoError(err)
bobKey, err := crypto.GenerateKey() bobKey, err := crypto.GenerateKey()
@ -188,10 +175,7 @@ func (s *ProtocolServiceTestSuite) TestPropagatingSavedSharedSecretsOnStart() {
subscriptions, err := s.alice.Start(aliceKey) subscriptions, err := s.alice.Start(aliceKey)
s.Require().NoError(err) s.Require().NoError(err)
select { secretResponse := subscriptions.SharedSecrets
case <-time.After(2 * time.Second):
case secretResponse = <-subscriptions.NewSharedSecrets:
}
s.Require().NotNil(secretResponse) s.Require().NotNil(secretResponse)
s.Require().Len(secretResponse, 1) s.Require().Len(secretResponse, 1)

View File

@ -46,7 +46,7 @@ func (p *Publisher) Start() <-chan struct{} {
logger.Info("starting publisher") logger.Info("starting publisher")
p.notifyCh = make(chan struct{}) p.notifyCh = make(chan struct{}, 100)
p.quit = make(chan struct{}) p.quit = make(chan struct{})
go p.tickerLoop() go p.tickerLoop()
@ -105,7 +105,11 @@ func (p *Publisher) notify() error {
return errNotEnoughTimePassed return errNotEnoughTimePassed
} }
p.notifyCh <- struct{}{} select {
case p.notifyCh <- struct{}{}:
default:
p.logger.Warn("publisher channel full, dropping message")
}
p.persistence.setLastPublished(now) p.persistence.setLastPublished(now)
return nil return nil

View File

@ -128,27 +128,6 @@ func NewMessenger(
} }
} }
if c.onSendContactCodeHandler == nil {
c.onSendContactCodeHandler = func(messageSpec *encryption.ProtocolMessageSpec) {
slogger := logger.With(zap.String("site", "onSendContactCodeHandler"))
slogger.Debug("received a SendContactCode request")
newMessage, err := common.MessageSpecToWhisper(messageSpec)
if err != nil {
slogger.Warn("failed to convert spec to Whisper message", zap.Error(err))
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
chatName := transport.ContactCodeTopic(&messenger.identity.PublicKey)
_, err = messenger.transport.SendPublic(ctx, newMessage, chatName)
if err != nil {
slogger.Warn("failed to send a contact code", zap.Error(err))
}
}
}
if c.systemMessagesTranslations == nil { if c.systemMessagesTranslations == nil {
c.systemMessagesTranslations = defaultSystemMessagesTranslations c.systemMessagesTranslations = defaultSystemMessagesTranslations
} }
@ -210,7 +189,6 @@ func NewMessenger(
encryptionProtocol := encryption.New( encryptionProtocol := encryption.New(
database, database,
installationID, installationID,
c.onSendContactCodeHandler,
logger, logger,
) )
@ -306,15 +284,43 @@ func (m *Messenger) Start() error {
} }
} }
// set shared secret handles
m.processor.SetHandleSharedSecrets(m.handleSharedSecrets)
subscriptions, err := m.encryptor.Start(m.identity) subscriptions, err := m.encryptor.Start(m.identity)
if err != nil { if err != nil {
return err return err
} }
// handle stored shared secrets
err = m.handleSharedSecrets(subscriptions.SharedSecrets)
if err != nil {
return err
}
m.handleEncryptionLayerSubscriptions(subscriptions) m.handleEncryptionLayerSubscriptions(subscriptions)
return nil return nil
} }
func (m *Messenger) handleSharedSecrets(secrets []*sharedsecret.Secret) ([]*transport.Filter, error) { // handleSendContactCode sends a public message wrapped in the encryption
// layer, which will propagate our bundle
func (m *Messenger) handleSendContactCode() error {
contactCodeTopic := transport.ContactCodeTopic(&m.identity.PublicKey)
rawMessage := common.RawMessage{
LocalChatID: contactCodeTopic,
Payload: nil,
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := m.processor.SendPublic(ctx, contactCodeTopic, rawMessage)
if err != nil {
m.logger.Warn("failed to send a contact code", zap.Error(err))
}
return err
}
// handleSharedSecrets process the negotiated secrets received from the encryption layer
func (m *Messenger) handleSharedSecrets(secrets []*sharedsecret.Secret) error {
logger := m.logger.With(zap.String("site", "handleSharedSecrets")) logger := m.logger.With(zap.String("site", "handleSharedSecrets"))
var result []*transport.Filter var result []*transport.Filter
for _, secret := range secrets { for _, secret := range secrets {
@ -325,14 +331,19 @@ func (m *Messenger) handleSharedSecrets(secrets []*sharedsecret.Secret) ([]*tran
} }
filter, err := m.transport.ProcessNegotiatedSecret(fSecret) filter, err := m.transport.ProcessNegotiatedSecret(fSecret)
if err != nil { if err != nil {
return nil, err return err
} }
result = append(result, filter) result = append(result, filter)
} }
return result, nil if m.config.onNegotiatedFilters != nil {
m.config.onNegotiatedFilters(result)
} }
func (m *Messenger) handleNewInstallations(installations []*multidevice.Installation) { return nil
}
// handleInstallations adds the installations in the installations map
func (m *Messenger) handleInstallations(installations []*multidevice.Installation) {
for _, installation := range installations { for _, installation := range installations {
if installation.Identity == contactIDFromPublicKey(&m.identity.PublicKey) { if installation.Identity == contactIDFromPublicKey(&m.identity.PublicKey) {
if _, ok := m.allInstallations[installation.ID]; !ok { if _, ok := m.allInstallations[installation.ID]; !ok {
@ -348,20 +359,11 @@ func (m *Messenger) handleEncryptionLayerSubscriptions(subscriptions *encryption
go func() { go func() {
for { for {
select { select {
case secrets := <-subscriptions.NewSharedSecrets: case <-subscriptions.SendContactCode:
m.logger.Debug("handling new shared secrets") if err := m.handleSendContactCode(); err != nil {
filters, err := m.handleSharedSecrets(secrets) m.logger.Error("failed to publish contact code", zap.Error(err))
if err != nil {
m.logger.Warn("failed to process secrets", zap.Error(err))
continue
} }
if m.config.onNegotiatedFilters != nil {
m.config.onNegotiatedFilters(filters)
}
case newInstallations := <-subscriptions.NewInstallations:
m.logger.Debug("handling new installations")
m.handleNewInstallations(newInstallations)
case <-subscriptions.Quit: case <-subscriptions.Quit:
m.logger.Debug("quitting encryption subscription loop") m.logger.Debug("quitting encryption subscription loop")
return return
@ -1835,6 +1837,13 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
for _, msg := range statusMessages { for _, msg := range statusMessages {
publicKey := msg.SigPubKey() publicKey := msg.SigPubKey()
m.handleInstallations(msg.Installations)
err := m.handleSharedSecrets(msg.SharedSecrets)
if err != nil {
// log and continue, non-critical error
logger.Warn("failed to handle shared secrets")
}
// Check for messages from blocked users // Check for messages from blocked users
senderID := contactIDFromPublicKey(publicKey) senderID := contactIDFromPublicKey(publicKey)
if _, ok := messageState.AllContacts[senderID]; ok && messageState.AllContacts[senderID].IsBlocked() { if _, ok := messageState.AllContacts[senderID]; ok && messageState.AllContacts[senderID].IsBlocked() {

View File

@ -6,7 +6,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
"github.com/status-im/status-go/protocol/common" "github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/encryption"
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/pushnotificationclient" "github.com/status-im/status-go/protocol/pushnotificationclient"
"github.com/status-im/status-go/protocol/pushnotificationserver" "github.com/status-im/status-go/protocol/pushnotificationserver"
@ -18,8 +17,6 @@ type config struct {
// as otherwise the client is not notified of a new filter and // as otherwise the client is not notified of a new filter and
// won't be pulling messages from mailservers until it reloads the chats/filters // won't be pulling messages from mailservers until it reloads the chats/filters
onNegotiatedFilters func([]*transport.Filter) onNegotiatedFilters func([]*transport.Filter)
// DEPRECATED: no need to expose it
onSendContactCodeHandler func(*encryption.ProtocolMessageSpec)
// systemMessagesTranslations holds translations for system-messages // systemMessagesTranslations holds translations for system-messages
systemMessagesTranslations map[protobuf.MembershipUpdateEvent_EventType]string systemMessagesTranslations map[protobuf.MembershipUpdateEvent_EventType]string

View File

@ -1405,7 +1405,7 @@ func (s *MessengerSuite) TestContactPersistenceUpdate() {
} }
func (s *MessengerSuite) TestSharedSecretHandler() { func (s *MessengerSuite) TestSharedSecretHandler() {
_, err := s.m.handleSharedSecrets(nil) err := s.m.handleSharedSecrets(nil)
s.NoError(err) s.NoError(err)
} }

View File

@ -506,7 +506,7 @@ func (s *ServerSuite) TestPushNotificationHandleRegistration() {
retrievedRegistration, err = s.persistence.GetPushNotificationRegistrationByPublicKeyAndInstallationID(common.HashPublicKey(&s.key.PublicKey), s.installationID) retrievedRegistration, err = s.persistence.GetPushNotificationRegistrationByPublicKeyAndInstallationID(common.HashPublicKey(&s.key.PublicKey), s.installationID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Nil(retrievedRegistration) s.Require().Nil(retrievedRegistration)
// Check version is mantained // Check version is maintained
version, err := s.persistence.GetPushNotificationRegistrationVersion(common.HashPublicKey(&s.key.PublicKey), s.installationID) version, err := s.persistence.GetPushNotificationRegistrationVersion(common.HashPublicKey(&s.key.PublicKey), s.installationID)
s.Require().NoError(err) s.Require().NoError(err)
s.Require().Equal(uint64(2), version) s.Require().Equal(uint64(2), version)

View File

@ -14,6 +14,8 @@ import (
"github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/eth-node/types"
"github.com/status-im/status-go/protocol/datasync" "github.com/status-im/status-go/protocol/datasync"
"github.com/status-im/status-go/protocol/encryption" "github.com/status-im/status-go/protocol/encryption"
"github.com/status-im/status-go/protocol/encryption/multidevice"
"github.com/status-im/status-go/protocol/encryption/sharedsecret"
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
) )
@ -45,6 +47,11 @@ type StatusMessage struct {
TransportLayerSigPubKey *ecdsa.PublicKey `json:"-"` TransportLayerSigPubKey *ecdsa.PublicKey `json:"-"`
// ApplicationMetadataLayerPubKey contains the public key provided by the application metadata layer // ApplicationMetadataLayerPubKey contains the public key provided by the application metadata layer
ApplicationMetadataLayerSigPubKey *ecdsa.PublicKey `json:"-"` ApplicationMetadataLayerSigPubKey *ecdsa.PublicKey `json:"-"`
// Installations is the new installations returned by the encryption layer
Installations []*multidevice.Installation
// SharedSecret is the shared secret returned by the encryption layer
SharedSecrets []*sharedsecret.Secret
} }
// Temporary JSON marshaling for those messages that are not yet processed // Temporary JSON marshaling for those messages that are not yet processed
@ -117,7 +124,7 @@ func (m *StatusMessage) HandleEncryption(myKey *ecdsa.PrivateKey, senderKey *ecd
return errors.Wrap(err, "failed to unmarshal ProtocolMessage") return errors.Wrap(err, "failed to unmarshal ProtocolMessage")
} }
payload, err := enc.HandleMessage( response, err := enc.HandleMessage(
myKey, myKey,
senderKey, senderKey,
&protocolMessage, &protocolMessage,
@ -128,7 +135,9 @@ func (m *StatusMessage) HandleEncryption(myKey *ecdsa.PrivateKey, senderKey *ecd
return errors.Wrap(err, "failed to handle Encryption message") return errors.Wrap(err, "failed to handle Encryption message")
} }
m.DecryptedPayload = payload m.DecryptedPayload = response.DecryptedMessage
m.Installations = response.Installations
m.SharedSecrets = response.SharedSecrets
return nil return nil
} }