diff --git a/protocol/common/message_processor_test.go b/protocol/common/message_processor_test.go index 59705a6f6..b9d0c1f8e 100644 --- a/protocol/common/message_processor_test.go +++ b/protocol/common/message_processor_test.go @@ -17,7 +17,6 @@ import ( "github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/eth-node/types" "github.com/status-im/status-go/protocol/encryption" - "github.com/status-im/status-go/protocol/encryption/multidevice" "github.com/status-im/status-go/protocol/encryption/sharedsecret" "github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/sqlite" @@ -63,13 +62,11 @@ func (s *MessageProcessorSuite) SetupTest() { database, err := sqlite.Open(filepath.Join(s.tmpDir, "processor-test.sql"), "some-key") s.Require().NoError(err) - onNewInstallations := func([]*multidevice.Installation) {} onNewSharedSecret := func([]*sharedsecret.Secret) {} onSendContactCode := func(*encryption.ProtocolMessageSpec) {} encryptionProtocol := encryption.New( database, "installation-1", - onNewInstallations, onNewSharedSecret, onSendContactCode, s.logger, @@ -206,7 +203,6 @@ func (s *MessageProcessorSuite) TestHandleDecodedMessagesDatasyncEncrypted() { senderEncryptionProtocol := encryption.New( senderDatabase, "installation-2", - func([]*multidevice.Installation) {}, func([]*sharedsecret.Secret) {}, func(*encryption.ProtocolMessageSpec) {}, s.logger, diff --git a/protocol/encryption/encryption_multi_device_test.go b/protocol/encryption/encryption_multi_device_test.go index 4d23ad37e..f32c053aa 100644 --- a/protocol/encryption/encryption_multi_device_test.go +++ b/protocol/encryption/encryption_multi_device_test.go @@ -65,7 +65,6 @@ func setupUser(user string, s *EncryptionServiceMultiDeviceSuite, n int) error { protocol := New( db, installationID, - func(s []*multidevice.Installation) {}, func(s []*sharedsecret.Secret) {}, func(*ProtocolMessageSpec) {}, s.logger.With(zap.String("user", user)), diff --git a/protocol/encryption/encryption_test.go b/protocol/encryption/encryption_test.go index c60d76a98..262eadf79 100644 --- a/protocol/encryption/encryption_test.go +++ b/protocol/encryption/encryption_test.go @@ -19,7 +19,6 @@ import ( "github.com/status-im/status-go/eth-node/crypto" - "github.com/status-im/status-go/protocol/encryption/multidevice" "github.com/status-im/status-go/protocol/encryption/sharedsecret" ) @@ -57,7 +56,6 @@ func (s *EncryptionServiceTestSuite) initDatabases(config encryptorConfig) { db, aliceInstallationID, config, - func(s []*multidevice.Installation) {}, func(s []*sharedsecret.Secret) {}, func(*ProtocolMessageSpec) {}, s.logger.With(zap.String("user", "alice")), @@ -70,7 +68,6 @@ func (s *EncryptionServiceTestSuite) initDatabases(config encryptorConfig) { db, bobInstallationID, config, - func(s []*multidevice.Installation) {}, func(s []*sharedsecret.Secret) {}, func(*ProtocolMessageSpec) {}, s.logger.With(zap.String("user", "bob")), diff --git a/protocol/encryption/protocol.go b/protocol/encryption/protocol.go index ac2d5076c..4b7af228b 100644 --- a/protocol/encryption/protocol.go +++ b/protocol/encryption/protocol.go @@ -24,6 +24,7 @@ const ( sharedSecretNegotiationVersion = 1 partitionedTopicMinVersion = 1 defaultMinVersion = 0 + subscriptionsChannelSize = 100 ) type PartitionTopicMode int @@ -66,12 +67,12 @@ func (p *ProtocolMessageSpec) PartitionedTopicMode() PartitionTopicMode { } type Protocol struct { - encryptor *encryptor - secret *sharedsecret.SharedSecret - multidevice *multidevice.Multidevice - publisher *publisher.Publisher + encryptor *encryptor + secret *sharedsecret.SharedSecret + multidevice *multidevice.Multidevice + publisher *publisher.Publisher + subscriptions *Subscriptions - onAddedBundlesHandler func([]*multidevice.Installation) onNewSharedSecretHandler func([]*sharedsecret.Secret) onSendContactCodeHandler func(*ProtocolMessageSpec) @@ -87,7 +88,6 @@ var ( func New( db *sql.DB, installationID string, - addedBundlesHandler func([]*multidevice.Installation), onNewSharedSecretHandler func([]*sharedsecret.Secret), onSendContactCodeHandler func(*ProtocolMessageSpec), logger *zap.Logger, @@ -96,7 +96,6 @@ func New( db, installationID, defaultEncryptorConfig(installationID, logger), - addedBundlesHandler, onNewSharedSecretHandler, onSendContactCodeHandler, logger, @@ -109,7 +108,6 @@ func NewWithEncryptorConfig( db *sql.DB, installationID string, encryptorConfig encryptorConfig, - addedBundlesHandler func([]*multidevice.Installation), onNewSharedSecretHandler func([]*sharedsecret.Secret), onSendContactCodeHandler func(*ProtocolMessageSpec), logger *zap.Logger, @@ -123,18 +121,30 @@ func NewWithEncryptorConfig( InstallationID: installationID, }), publisher: publisher.New(logger), - onAddedBundlesHandler: addedBundlesHandler, onNewSharedSecretHandler: onNewSharedSecretHandler, onSendContactCodeHandler: onSendContactCodeHandler, logger: logger.With(zap.Namespace("Protocol")), } } -func (p *Protocol) Start(myIdentity *ecdsa.PrivateKey) error { +type Subscriptions struct { + NewInstallations chan []*multidevice.Installation + NewSharedSecret chan []*sharedsecret.Secret + SendContactCode <-chan struct{} + Quit chan struct{} +} + +func (p *Protocol) Start(myIdentity *ecdsa.PrivateKey) (*Subscriptions, error) { // Propagate currently cached shared secrets. secrets, err := p.secret.All() if err != nil { - return errors.Wrap(err, "failed to get all secrets") + return nil, errors.Wrap(err, "failed to get all secrets") + } + p.subscriptions = &Subscriptions{ + NewInstallations: make(chan []*multidevice.Installation, subscriptionsChannelSize), + NewSharedSecret: make(chan []*sharedsecret.Secret, subscriptionsChannelSize), + SendContactCode: p.publisher.Start(), + Quit: make(chan struct{}), } p.onNewSharedSecretHandler(secrets) @@ -155,6 +165,14 @@ func (p *Protocol) Start(myIdentity *ecdsa.PrivateKey) error { } }() + return p.subscriptions, nil +} + +func (p *Protocol) Stop() error { + p.publisher.Stop() + if p.subscriptions != nil { + close(p.subscriptions.Quit) + } return nil } @@ -433,7 +451,14 @@ func (p *Protocol) HandleMessage( return nil, err } - p.onAddedBundlesHandler(addedBundles) + // Publish without blocking if channel is full + if p.subscriptions != nil { + select { + case p.subscriptions.NewInstallations <- addedBundles: + default: + p.logger.Warn("new installations channel full, dropping message") + } + } } // Check if it's a public message diff --git a/protocol/encryption/protocol_test.go b/protocol/encryption/protocol_test.go index 7fba7301e..57882010b 100644 --- a/protocol/encryption/protocol_test.go +++ b/protocol/encryption/protocol_test.go @@ -14,7 +14,6 @@ import ( "github.com/status-im/status-go/eth-node/crypto" - "github.com/status-im/status-go/protocol/encryption/multidevice" "github.com/status-im/status-go/protocol/encryption/sharedsecret" ) @@ -44,7 +43,6 @@ func (s *ProtocolServiceTestSuite) SetupTest() { s.Require().NoError(err) bobDBKey := "bob" - addedBundlesHandler := func(addedBundles []*multidevice.Installation) {} onNewSharedSecretHandler := func(secret []*sharedsecret.Secret) {} db, err := sqlite.Open(s.aliceDBPath.Name(), aliceDBKey) @@ -52,7 +50,6 @@ func (s *ProtocolServiceTestSuite) SetupTest() { s.alice = New( db, "1", - addedBundlesHandler, onNewSharedSecretHandler, func(*ProtocolMessageSpec) {}, s.logger.With(zap.String("user", "alice")), @@ -63,7 +60,6 @@ func (s *ProtocolServiceTestSuite) SetupTest() { s.bob = New( db, "2", - addedBundlesHandler, onNewSharedSecretHandler, func(*ProtocolMessageSpec) {}, s.logger.With(zap.String("user", "bob")), @@ -190,7 +186,7 @@ func (s *ProtocolServiceTestSuite) TestPropagatingSavedSharedSecretsOnStart() { secretResponse = secret } - err = s.alice.Start(aliceKey) + _, err = s.alice.Start(aliceKey) s.NoError(err) s.Require().NotNil(secretResponse) diff --git a/protocol/encryption/publisher/publisher.go b/protocol/encryption/publisher/publisher.go index af42e38c6..bbde27b00 100644 --- a/protocol/encryption/publisher/publisher.go +++ b/protocol/encryption/publisher/publisher.go @@ -55,6 +55,10 @@ func (p *Publisher) Start() <-chan struct{} { } func (p *Publisher) Stop() { + // If hasn't started, ignore + if p.quit == nil { + return + } select { case _, ok := <-p.quit: if !ok { diff --git a/protocol/messenger.go b/protocol/messenger.go index a2ae5b8aa..b8b752ce1 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -127,17 +127,6 @@ func NewMessenger( } } - onNewInstallationsHandler := func(installations []*multidevice.Installation) { - - for _, installation := range installations { - if installation.Identity == contactIDFromPublicKey(&messenger.identity.PublicKey) { - if _, ok := messenger.allInstallations[installation.ID]; !ok { - messenger.allInstallations[installation.ID] = installation - messenger.modifiedInstallations[installation.ID] = true - } - } - } - } // Set default config fields. onNewSharedSecretHandler := func(secrets []*sharedsecret.Secret) { filters, err := messenger.handleSharedSecrets(secrets) @@ -232,7 +221,6 @@ func NewMessenger( encryptionProtocol := encryption.New( database, installationID, - onNewInstallationsHandler, onNewSharedSecretHandler, c.onSendContactCodeHandler, logger, @@ -297,6 +285,7 @@ func NewMessenger( shutdownTasks: []func() error{ database.Close, pushNotificationClient.Stop, + encryptionProtocol.Stop, transp.ResetFilters, transp.Stop, func() error { processor.Stop(); return nil }, @@ -328,7 +317,39 @@ func (m *Messenger) Start() error { } } - return m.encryptor.Start(m.identity) + subscriptions, err := m.encryptor.Start(m.identity) + if err != nil { + return err + } + m.handleEncryptionLayerSubscriptions(subscriptions) + return nil +} + +func (m *Messenger) handleNewInstallations(installations []*multidevice.Installation) { + for _, installation := range installations { + if installation.Identity == contactIDFromPublicKey(&m.identity.PublicKey) { + if _, ok := m.allInstallations[installation.ID]; !ok { + m.allInstallations[installation.ID] = installation + m.modifiedInstallations[installation.ID] = true + } + } + } +} + +func (m *Messenger) handleEncryptionLayerSubscriptions(subscriptions *encryption.Subscriptions) { + go func() { + for { + select { + case newInstallations := <-subscriptions.NewInstallations: + m.logger.Debug("handling new installations") + m.handleNewInstallations(newInstallations) + case <-subscriptions.Quit: + m.logger.Debug("quitting encryption subscription loop") + return + + } + } + }() } // Init analyzes chats and contacts in order to setup filters diff --git a/protocol/messenger_installations_test.go b/protocol/messenger_installations_test.go index 41cf7e948..dec7ce458 100644 --- a/protocol/messenger_installations_test.go +++ b/protocol/messenger_installations_test.go @@ -50,6 +50,12 @@ func (s *MessengerInstallationSuite) SetupTest() { s.m = s.newMessenger(s.shh) s.privateKey = s.m.identity + // We start the messenger in order to receive installations + s.Require().NoError(s.m.Start()) +} + +func (s *MessengerInstallationSuite) TearDownTest() { + s.Require().NoError(s.m.Shutdown()) } func (s *MessengerInstallationSuite) newMessengerWithKey(shh types.Waku, privateKey *ecdsa.PrivateKey) *Messenger {