Test token invalidation
This commit is contained in:
parent
d775bb888a
commit
149fc5e3eb
|
@ -1121,6 +1121,14 @@ func (m *Messenger) isNewContact(contact *Contact) bool {
|
|||
return contact.IsAdded() && (!ok || !previousContact.IsAdded())
|
||||
}
|
||||
|
||||
func (m *Messenger) removedContact(contact *Contact) bool {
|
||||
previousContact, ok := m.allContacts[contact.ID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return previousContact.IsAdded() && !contact.IsAdded()
|
||||
}
|
||||
|
||||
func (m *Messenger) saveContact(contact *Contact) error {
|
||||
name, identicon, err := generateAliasAndIdenticon(contact.ID)
|
||||
if err != nil {
|
||||
|
@ -1137,6 +1145,9 @@ func (m *Messenger) saveContact(contact *Contact) error {
|
|||
}
|
||||
}
|
||||
|
||||
// We check if it should re-register
|
||||
shouldReregisterForPushNotifications := m.pushNotificationClient != nil && (m.isNewContact(contact) || m.removedContact(contact))
|
||||
|
||||
err = m.persistence.SaveContact(contact, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1144,6 +1155,16 @@ func (m *Messenger) saveContact(contact *Contact) error {
|
|||
|
||||
m.allContacts[contact.ID] = contact
|
||||
|
||||
// Reregister only when data has changed
|
||||
if shouldReregisterForPushNotifications {
|
||||
m.logger.Info("contact state changed, re-registering for push notification")
|
||||
contactIDs, mutedChatIDs := m.addedContactsAndMutedChatIDs()
|
||||
err := m.pushNotificationClient.Reregister(contactIDs, mutedChatIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -3087,8 +3108,6 @@ func (m *Messenger) addedContactsAndMutedChatIDs() ([]*ecdsa.PublicKey, []string
|
|||
var contactIDs []*ecdsa.PublicKey
|
||||
var mutedChatIDs []string
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for _, contact := range m.allContacts {
|
||||
if contact.IsAdded() {
|
||||
pk, err := contact.PublicKey()
|
||||
|
@ -3115,6 +3134,8 @@ func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToke
|
|||
if m.pushNotificationClient == nil {
|
||||
return errors.New("push notification client not enabled")
|
||||
}
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
contactIDs, mutedChatIDs := m.addedContactsAndMutedChatIDs()
|
||||
return m.pushNotificationClient.Register(deviceToken, contactIDs, mutedChatIDs)
|
||||
|
@ -3131,6 +3152,8 @@ func (m *Messenger) EnablePushNotificationsFromContactsOnly() error {
|
|||
if m.pushNotificationClient == nil {
|
||||
return errors.New("no push notification client")
|
||||
}
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
contactIDs, mutedChatIDs := m.addedContactsAndMutedChatIDs()
|
||||
return m.pushNotificationClient.EnablePushNotificationsFromContactsOnly(contactIDs, mutedChatIDs)
|
||||
|
@ -3140,6 +3163,8 @@ func (m *Messenger) DisablePushNotificationsFromContactsOnly() error {
|
|||
if m.pushNotificationClient == nil {
|
||||
return errors.New("no push notification client")
|
||||
}
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
contactIDs, mutedChatIDs := m.addedContactsAndMutedChatIDs()
|
||||
return m.pushNotificationClient.DisablePushNotificationsFromContactsOnly(contactIDs, mutedChatIDs)
|
||||
|
|
|
@ -424,6 +424,7 @@ func (c *Client) allowedUserList(token []byte, contactIDs []*ecdsa.PublicKey) ([
|
|||
// and return a new one in that case
|
||||
func (c *Client) getToken(contactIDs []*ecdsa.PublicKey) string {
|
||||
if c.lastPushNotificationRegistration == nil || len(c.lastPushNotificationRegistration.AccessToken) == 0 || c.shouldRefreshToken(c.lastContactIDs, contactIDs) {
|
||||
c.config.Logger.Info("refreshing access token")
|
||||
return uuid.New().String()
|
||||
}
|
||||
return c.lastPushNotificationRegistration.AccessToken
|
||||
|
@ -492,10 +493,7 @@ func nextPushNotificationRetry(pn *SentNotification) int64 {
|
|||
|
||||
// We calculate if it's too early to retry, by exponentially backing off
|
||||
func shouldRetryRegisteringWithServer(server *PushNotificationServer) bool {
|
||||
if server.RetryCount > maxRegistrationRetries {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() > nextServerRetry(server)
|
||||
return time.Now().Unix() >= nextServerRetry(server)
|
||||
}
|
||||
|
||||
// We calculate if it's too early to retry, by exponentially backing off
|
||||
|
@ -516,14 +514,13 @@ func (c *Client) resetServers() error {
|
|||
// Reset server registration data
|
||||
server.Registered = false
|
||||
server.RegisteredAt = 0
|
||||
server.RetryCount += 1
|
||||
server.RetryCount = 0
|
||||
server.LastRetriedAt = time.Now().Unix()
|
||||
server.AccessToken = ""
|
||||
|
||||
if err := c.persistence.UpsertServer(server); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -692,8 +689,8 @@ func (c *Client) resendingLoop() error {
|
|||
}
|
||||
|
||||
for _, pn := range retriableNotifications {
|
||||
nextRetry := nextPushNotificationRetry(pn)
|
||||
c.config.Logger.Info("Next retry", zap.Int64("now", time.Now().Unix()), zap.Int64("next", nextRetry))
|
||||
nR := nextPushNotificationRetry(pn)
|
||||
c.config.Logger.Info("Next retry", zap.Int64("now", time.Now().Unix()), zap.Int64("next", nR))
|
||||
if shouldRetryPushNotification(pn) {
|
||||
c.config.Logger.Info("retrying pn", zap.Any("pn", pn))
|
||||
err := c.resendNotification(pn)
|
||||
|
@ -701,6 +698,7 @@ func (c *Client) resendingLoop() error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
nextRetry := nextPushNotificationRetry(pn)
|
||||
if lowestNextRetry == 0 || nextRetry < lowestNextRetry {
|
||||
lowestNextRetry = nextRetry
|
||||
}
|
||||
|
@ -734,7 +732,7 @@ func (c *Client) registrationLoop() error {
|
|||
|
||||
var nonRegisteredServers []*PushNotificationServer
|
||||
for _, server := range servers {
|
||||
if !server.Registered {
|
||||
if !server.Registered && server.RetryCount < maxRegistrationRetries {
|
||||
nonRegisteredServers = append(nonRegisteredServers, server)
|
||||
}
|
||||
}
|
||||
|
@ -766,6 +764,7 @@ func (c *Client) registrationLoop() error {
|
|||
|
||||
nextRetry := lowestNextRetry - time.Now().Unix()
|
||||
waitFor := time.Duration(nextRetry)
|
||||
c.config.Logger.Info("Waiting for", zap.Any("wait for", waitFor))
|
||||
select {
|
||||
|
||||
case <-time.After(waitFor * time.Second):
|
||||
|
@ -830,10 +829,26 @@ func (c *Client) GetServers() ([]*PushNotificationServer, error) {
|
|||
return c.persistence.GetServers()
|
||||
}
|
||||
|
||||
func (c *Client) Reregister(contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) error {
|
||||
c.config.Logger.Info("re-registering")
|
||||
if len(c.deviceToken) == 0 {
|
||||
c.config.Logger.Info("no device token, not registering")
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.Register(c.deviceToken, contactIDs, mutedChatIDs)
|
||||
}
|
||||
|
||||
func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) error {
|
||||
// stop registration loop
|
||||
c.stopRegistrationLoop()
|
||||
|
||||
// Reset servers
|
||||
err := c.resetServers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.deviceToken = deviceToken
|
||||
|
||||
registration, err := c.buildPushNotificationRegistrationMessage(contactIDs, mutedChatIDs)
|
||||
|
|
|
@ -253,13 +253,12 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotificationFromContactO
|
|||
bobDeviceToken := "token-1"
|
||||
|
||||
bob := s.m
|
||||
bob2 := s.newMessengerWithKey(s.shh, s.m.identity)
|
||||
server := s.newPushNotificationServer(s.shh)
|
||||
alice := s.newMessenger(s.shh)
|
||||
// start alice and enable push notifications
|
||||
s.Require().NoError(alice.Start())
|
||||
s.Require().NoError(alice.EnableSendingPushNotifications())
|
||||
bobInstallationIDs := []string{bob.installationID, bob2.installationID}
|
||||
bobInstallationIDs := []string{bob.installationID}
|
||||
|
||||
// Register bob
|
||||
err := bob.AddPushNotificationServer(context.Background(), &server.identity.PublicKey)
|
||||
|
@ -348,3 +347,150 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotificationFromContactO
|
|||
s.Require().NotNil(retrievedNotificationInfo)
|
||||
s.Require().Len(retrievedNotificationInfo, 1)
|
||||
}
|
||||
|
||||
func (s *MessengerPushNotificationSuite) TestReceivePushNotificationRetries() {
|
||||
|
||||
bobDeviceToken := "token-1"
|
||||
|
||||
bob := s.m
|
||||
server := s.newPushNotificationServer(s.shh)
|
||||
alice := s.newMessenger(s.shh)
|
||||
// another contact to invalidate the token
|
||||
frank := s.newMessenger(s.shh)
|
||||
// start alice and enable push notifications
|
||||
s.Require().NoError(alice.Start())
|
||||
s.Require().NoError(alice.EnableSendingPushNotifications())
|
||||
bobInstallationIDs := []string{bob.installationID}
|
||||
|
||||
// Register bob
|
||||
err := bob.AddPushNotificationServer(context.Background(), &server.identity.PublicKey)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Add alice has a contact
|
||||
aliceContact := &Contact{
|
||||
ID: types.EncodeHex(crypto.FromECDSAPub(&alice.identity.PublicKey)),
|
||||
Name: "Some Contact",
|
||||
SystemTags: []string{contactAdded},
|
||||
}
|
||||
|
||||
err = bob.SaveContact(aliceContact)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Add frank has a contact
|
||||
frankContact := &Contact{
|
||||
ID: types.EncodeHex(crypto.FromECDSAPub(&frank.identity.PublicKey)),
|
||||
Name: "Some Contact",
|
||||
SystemTags: []string{contactAdded},
|
||||
}
|
||||
|
||||
err = bob.SaveContact(frankContact)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Enable from contacts only
|
||||
err = bob.EnablePushNotificationsFromContactsOnly()
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = bob.RegisterForPushNotifications(context.Background(), bobDeviceToken)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Pull servers and check we registered
|
||||
err = tt.RetryWithBackOff(func() error {
|
||||
_, err = server.RetrieveAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = bob.RetrieveAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registered, err := bob.RegisteredForPushNotifications()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !registered {
|
||||
return errors.New("not registered")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// Make sure we receive it
|
||||
s.Require().NoError(err)
|
||||
bobServers, err := bob.GetPushNotificationServers()
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Create one to one chat & send message
|
||||
pkString := hex.EncodeToString(crypto.FromECDSAPub(&s.m.identity.PublicKey))
|
||||
chat := CreateOneToOneChat(pkString, &s.m.identity.PublicKey, alice.transport)
|
||||
s.Require().NoError(alice.SaveChat(&chat))
|
||||
inputMessage := buildTestMessage(chat)
|
||||
_, err = alice.SendChatMessage(context.Background(), inputMessage)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// The message has been sent, but not received, now we remove a contact so that the token is invalidated
|
||||
frankContact = &Contact{
|
||||
ID: types.EncodeHex(crypto.FromECDSAPub(&frank.identity.PublicKey)),
|
||||
Name: "Some Contact",
|
||||
SystemTags: []string{},
|
||||
}
|
||||
err = bob.SaveContact(frankContact)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Re-registration should be triggered, pull from server and bob to check we are correctly registered
|
||||
// Pull servers and check we registered
|
||||
err = tt.RetryWithBackOff(func() error {
|
||||
_, err = server.RetrieveAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = bob.RetrieveAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registered, err := bob.RegisteredForPushNotifications()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !registered {
|
||||
return errors.New("not registered")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
newBobServers, err := bob.GetPushNotificationServers()
|
||||
s.Require().NoError(err)
|
||||
// Make sure access token is not the same
|
||||
s.Require().NotEqual(newBobServers[0].AccessToken, bobServers[0].AccessToken)
|
||||
|
||||
var info []*push_notification_client.PushNotificationInfo
|
||||
err = tt.RetryWithBackOff(func() error {
|
||||
_, err = server.RetrieveAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = alice.RetrieveAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info, err = alice.pushNotificationClient.GetPushNotificationInfo(&bob.identity.PublicKey, bobInstallationIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check we have replies for bob
|
||||
if len(info) != 1 {
|
||||
return errors.New("info not fetched")
|
||||
}
|
||||
return nil
|
||||
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NotNil(info)
|
||||
s.Require().Equal(bob.installationID, info[0].InstallationID)
|
||||
s.Require().Equal(newBobServers[0].AccessToken, info[0].AccessToken)
|
||||
s.Require().Equal(&bob.identity.PublicKey, info[0].PublicKey)
|
||||
|
||||
retrievedNotificationInfo, err := alice.pushNotificationClient.GetPushNotificationInfo(&bob.identity.PublicKey, bobInstallationIDs)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(retrievedNotificationInfo)
|
||||
s.Require().Len(retrievedNotificationInfo, 1)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue