diff --git a/protocol/messenger.go b/protocol/messenger.go index 0ac1dfae1..182497222 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -3071,16 +3071,12 @@ func (m *Messenger) EnableSendingPushNotifications() error { return nil } -// RegisterForPushNotification register deviceToken with any push notification server enabled -func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) error { - if m.pushNotificationClient == nil { - return errors.New("push notification client not enabled") - } - +func (m *Messenger) addedContactsAndMutedChatIDs() ([]*ecdsa.PublicKey, []string) { var contactIDs []*ecdsa.PublicKey var mutedChatIDs []string m.mutex.Lock() + defer m.mutex.Unlock() for _, contact := range m.allContacts { if contact.IsAdded() { pk, err := contact.PublicKey() @@ -3099,7 +3095,16 @@ func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToke } } - m.mutex.Unlock() + return contactIDs, mutedChatIDs +} + +// RegisterForPushNotification register deviceToken with any push notification server enabled +func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) error { + if m.pushNotificationClient == nil { + return errors.New("push notification client not enabled") + } + + contactIDs, mutedChatIDs := m.addedContactsAndMutedChatIDs() return m.pushNotificationClient.Register(deviceToken, contactIDs, mutedChatIDs) } @@ -3115,7 +3120,8 @@ func (m *Messenger) EnablePushNotificationsFromContactsOnly() error { return errors.New("no push notification client") } - return m.pushNotificationClient.EnablePushNotificationsFromContactsOnly() + contactIDs, mutedChatIDs := m.addedContactsAndMutedChatIDs() + return m.pushNotificationClient.EnablePushNotificationsFromContactsOnly(contactIDs, mutedChatIDs) } func (m *Messenger) DisablePushNotificationsFromContactsOnly() error { @@ -3123,7 +3129,8 @@ func (m *Messenger) DisablePushNotificationsFromContactsOnly() error { return errors.New("no push notification client") } - return m.pushNotificationClient.DisablePushNotificationsFromContactsOnly() + contactIDs, mutedChatIDs := m.addedContactsAndMutedChatIDs() + return m.pushNotificationClient.DisablePushNotificationsFromContactsOnly(contactIDs, mutedChatIDs) } func (m *Messenger) GetPushNotificationServers() ([]*push_notification_client.PushNotificationServer, error) { diff --git a/protocol/push_notification_client/client.go b/protocol/push_notification_client/client.go index cda929a05..46ee030a0 100644 --- a/protocol/push_notification_client/client.go +++ b/protocol/push_notification_client/client.go @@ -75,9 +75,9 @@ type Config struct { // RemoteNotificationsEnabled is whether we should register with a remote server for push notifications RemoteNotificationsEnabled bool - // AllowyFromContactsOnly indicates whether we should be receiving push notifications + // allowyFromContactsOnly indicates whether we should be receiving push notifications // only from contacts - AllowFromContactsOnly bool + allowFromContactsOnly bool // InstallationID is the installation-id for this device InstallationID string @@ -101,8 +101,8 @@ type Client struct { // AccessToken is the access token that is currently being used AccessToken string - // DeviceToken is the device token for this device - DeviceToken string + // deviceToken is the device token for this device + deviceToken string // randomReader only used for testing so we have deterministic encryption reader io.Reader @@ -156,6 +156,7 @@ func (c *Client) loadLastPushNotificationRegistration() error { } c.lastContactIDs = lastContactIDs c.lastPushNotificationRegistration = lastRegistration + c.deviceToken = lastRegistration.Token return nil } @@ -407,10 +408,30 @@ func (p *Client) encryptToken(publicKey *ecdsa.PublicKey, token []byte) ([]byte, return encryptedToken, nil } -func (p *Client) allowedUserList(token []byte, contactIDs []*ecdsa.PublicKey) ([][]byte, error) { +func (p *Client) decryptToken(publicKey *ecdsa.PublicKey, token []byte) ([]byte, error) { + sharedKey, err := ecies.ImportECDSA(p.config.Identity).GenerateShared( + ecies.ImportECDSAPublic(publicKey), + accessTokenKeyLength, + accessTokenKeyLength, + ) + if err != nil { + return nil, err + } + decryptedToken, err := common.Decrypt(token, sharedKey) + if err != nil { + return nil, err + } + return decryptedToken, nil +} + +func (c *Client) allowedUserList(token []byte, contactIDs []*ecdsa.PublicKey) ([][]byte, error) { + // If we allow everyone, don't set the list + if !c.config.allowFromContactsOnly { + return nil, nil + } var encryptedTokens [][]byte for _, publicKey := range contactIDs { - encryptedToken, err := p.encryptToken(publicKey, token) + encryptedToken, err := c.encryptToken(publicKey, token) if err != nil { return nil, err } @@ -449,7 +470,7 @@ func (c *Client) buildPushNotificationRegistrationMessage(contactIDs []*ecdsa.Pu TokenType: c.config.TokenType, Version: c.getVersion(), InstallationId: c.config.InstallationID, - Token: c.DeviceToken, + Token: c.deviceToken, Enabled: c.config.RemoteNotificationsEnabled, BlockedChatList: c.mutedChatIDsHashes(mutedChatIDs), AllowedUserList: allowedUserList, @@ -672,7 +693,7 @@ func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mut // stop registration loop c.stopRegistrationLoop() - c.DeviceToken = deviceToken + c.deviceToken = deviceToken registration, err := c.buildPushNotificationRegistrationMessage(contactIDs, mutedChatIDs) if err != nil { @@ -752,6 +773,18 @@ func (c *Client) handleGrant(clientPublicKey *ecdsa.PublicKey, serverPublicKey * return nil } +func (c *Client) handleAllowedUserList(publicKey *ecdsa.PublicKey, allowedUserList [][]byte) string { + for _, encryptedToken := range allowedUserList { + token, err := c.decryptToken(publicKey, encryptedToken) + if err != nil { + c.config.Logger.Warn("could not decrypt token", zap.Error(err)) + continue + } + return string(token) + } + return "" +} + // HandlePushNotificationQueryResponse should update the data in the database for a given user func (c *Client) HandlePushNotificationQueryResponse(serverPublicKey *ecdsa.PublicKey, response protobuf.PushNotificationQueryResponse) error { @@ -775,6 +808,18 @@ func (c *Client) HandlePushNotificationQueryResponse(serverPublicKey *ecdsa.Publ continue } + accessToken := info.AccessToken + + if len(info.AllowedUserList) != 0 { + accessToken = c.handleAllowedUserList(publicKey, info.AllowedUserList) + + } + + if len(accessToken) == 0 { + c.config.Logger.Info("not in the allowed users list") + continue + } + // We check the user has allowed this server to store this particular // access token, otherwise anyone could reply with a fake token // and receive notifications for a user @@ -881,13 +926,19 @@ func (c *Client) DisableSending() { c.config.SendEnabled = false } -func (c *Client) EnablePushNotificationsFromContactsOnly() error { - c.config.AllowFromContactsOnly = true +func (c *Client) EnablePushNotificationsFromContactsOnly(contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) error { + c.config.allowFromContactsOnly = true + if c.lastPushNotificationRegistration != nil { + return c.Register(c.deviceToken, contactIDs, mutedChatIDs) + } return nil } -func (c *Client) DisablePushNotificationsFromContactsOnly() error { - c.config.AllowFromContactsOnly = false +func (c *Client) DisablePushNotificationsFromContactsOnly(contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) error { + c.config.allowFromContactsOnly = false + if c.lastPushNotificationRegistration != nil { + return c.Register(c.deviceToken, contactIDs, mutedChatIDs) + } return nil } diff --git a/protocol/push_notification_client/client_test.go b/protocol/push_notification_client/client_test.go index 0306fe044..56b60a70c 100644 --- a/protocol/push_notification_client/client_test.go +++ b/protocol/push_notification_client/client_test.go @@ -79,6 +79,49 @@ func (s *ClientSuite) TestBuildPushNotificationRegisterMessage() { // Get token expectedUUID := uuid.New().String() + // Reset random generator + uuid.SetRand(rand.New(rand.NewSource(seed))) + + s.client.deviceToken = myDeviceToken + // Set reader + s.client.reader = bytes.NewReader([]byte(expectedUUID)) + + options := &protobuf.PushNotificationRegistration{ + Version: 1, + AccessToken: expectedUUID, + Token: myDeviceToken, + InstallationId: s.installationID, + Enabled: true, + BlockedChatList: mutedChatListHashes, + } + + actualMessage, err := s.client.buildPushNotificationRegistrationMessage(contactIDs, mutedChatList) + s.Require().NoError(err) + + s.Require().Equal(options, actualMessage) +} + +func (s *ClientSuite) TestBuildPushNotificationRegisterMessageAllowFromContactsOnly() { + myDeviceToken := "device-token" + mutedChatList := []string{"a", "b"} + + // build chat lish hashes + var mutedChatListHashes [][]byte + for _, chatID := range mutedChatList { + mutedChatListHashes = append(mutedChatListHashes, common.Shake256([]byte(chatID))) + } + + contactKey, err := crypto.GenerateKey() + s.Require().NoError(err) + contactIDs := []*ecdsa.PublicKey{&contactKey.PublicKey} + + // Set random generator for uuid + var seed int64 = 1 + uuid.SetRand(rand.New(rand.NewSource(seed))) + + // Get token + expectedUUID := uuid.New().String() + // set up reader reader := bytes.NewReader([]byte(expectedUUID)) @@ -95,7 +138,8 @@ func (s *ClientSuite) TestBuildPushNotificationRegisterMessage() { // Reset random generator uuid.SetRand(rand.New(rand.NewSource(seed))) - s.client.DeviceToken = myDeviceToken + s.client.config.allowFromContactsOnly = true + s.client.deviceToken = myDeviceToken // Set reader s.client.reader = bytes.NewReader([]byte(expectedUUID))