diff --git a/protocol/messenger.go b/protocol/messenger.go index e075231b4..0414866c4 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -3064,9 +3064,9 @@ func (m *Messenger) EnableSendingPushNotifications() error { } // RegisterForPushNotification register deviceToken with any push notification server enabled -func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) ([]*push_notification_client.PushNotificationServer, error) { +func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) error { if m.pushNotificationClient == nil { - return nil, errors.New("push notification client not enabled") + return errors.New("push notification client not enabled") } var contactIDs []*ecdsa.PublicKey @@ -3095,6 +3095,20 @@ func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToke return m.pushNotificationClient.Register(deviceToken, contactIDs, mutedChatIDs) } +func (m *Messenger) RegisteredForPushNotifications() (bool, error) { + if m.pushNotificationClient == nil { + return false, errors.New("no push notification client") + } + return m.pushNotificationClient.Registered() +} + +func (m *Messenger) GetPushNotificationServers() ([]*push_notification_client.PushNotificationServer, error) { + if m.pushNotificationClient == nil { + return nil, errors.New("no push notification client") + } + return m.pushNotificationClient.GetServers() +} + func (m *Messenger) StartPushNotificationServer() error { if m.pushNotificationServer == nil { pushNotificationServerPersistence := push_notification_server.NewSQLitePersistence(m.database) diff --git a/protocol/push_notification_client/client.go b/protocol/push_notification_client/client.go index 3a706685e..9839c557d 100644 --- a/protocol/push_notification_client/client.go +++ b/protocol/push_notification_client/client.go @@ -487,7 +487,7 @@ func nextServerRetry(server *PushNotificationServer) int64 { // We calculate if it's too early to retry, by exponentially backing off func shouldRetryRegisteringWithServer(server *PushNotificationServer) bool { - return time.Now().Unix() < nextServerRetry(server) + return time.Now().Unix() > nextServerRetry(server) } func (c *Client) resetServers() error { @@ -558,7 +558,7 @@ func (c *Client) registerWithServer(registration *protobuf.PushNotificationRegis func (c *Client) registrationLoop() error { for { - c.config.Logger.Info("runing registration loop") + c.config.Logger.Info("running registration loop") servers, err := c.persistence.GetServers() if err != nil { c.config.Logger.Error("failed retrieving servers, quitting registration loop", zap.Error(err)) @@ -571,39 +571,44 @@ func (c *Client) registrationLoop() error { var nonRegisteredServers []*PushNotificationServer for _, server := range servers { - if server.Registered { + if !server.Registered { nonRegisteredServers = append(nonRegisteredServers, server) } - if len(nonRegisteredServers) == 0 { - c.config.Logger.Debug("registered with all servers, quitting registration loop") - return nil - } + } + if len(nonRegisteredServers) == 0 { + c.config.Logger.Debug("registered with all servers, quitting registration loop") + return nil + } - var lowestNextRetry int64 + c.config.Logger.Info("Trying to register with", zap.Int("servers", len(nonRegisteredServers))) - for _, server := range nonRegisteredServers { - if shouldRetryRegisteringWithServer(server) { - err := c.registerWithServer(c.lastPushNotificationRegistration, server) - if err != nil { - return err - } + var lowestNextRetry int64 + + for _, server := range nonRegisteredServers { + nR := nextServerRetry(server) + c.config.Logger.Info("Next retry", zap.Int64("now", time.Now().Unix()), zap.Int64("next", nR)) + if shouldRetryRegisteringWithServer(server) { + c.config.Logger.Info("registering with server", zap.Any("server", server)) + err := c.registerWithServer(c.lastPushNotificationRegistration, server) + if err != nil { + return err } - nextRetry := nextServerRetry(server) - if lowestNextRetry == 0 || nextRetry < lowestNextRetry { - lowestNextRetry = nextRetry - } - + } + nextRetry := nextServerRetry(server) + if lowestNextRetry == 0 || nextRetry < lowestNextRetry { + lowestNextRetry = nextRetry } - nextRetry := lowestNextRetry - time.Now().Unix() - waitFor := time.Duration(nextRetry) - select { + } - case <-time.After(waitFor * time.Second): - case <-c.registrationLoopQuitChan: - return nil + nextRetry := lowestNextRetry - time.Now().Unix() + waitFor := time.Duration(nextRetry) + select { + + case <-time.After(waitFor * time.Second): + case <-c.registrationLoopQuitChan: + return nil - } } } } @@ -643,73 +648,44 @@ func (c *Client) SaveLastPushNotificationRegistration(registration *protobuf.Pus return nil } -func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) ([]*PushNotificationServer, error) { +func (c *Client) Registered() (bool, error) { + servers, err := c.persistence.GetServers() + if err != nil { + return false, err + } + + for _, s := range servers { + if !s.Registered { + return false, nil + } + } + + return true, nil +} + +func (c *Client) GetServers() ([]*PushNotificationServer, error) { + return c.persistence.GetServers() +} + +func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) error { // stop registration loop c.stopRegistrationLoop() c.DeviceToken = deviceToken - servers, err := c.persistence.GetServers() - if err != nil { - return nil, err - } - - if len(servers) == 0 { - return nil, errors.New("no servers to register with") - } registration, err := c.buildPushNotificationRegistrationMessage(contactIDs, mutedChatIDs) if err != nil { - return nil, err + return err } err = c.SaveLastPushNotificationRegistration(registration, contactIDs) if err != nil { - return nil, err + return err } - var serverPublicKeys []*ecdsa.PublicKey - for _, server := range servers { - err := c.registerWithServer(registration, server) - if err != nil { - return nil, err - } - serverPublicKeys = append(serverPublicKeys, server.PublicKey) - } + c.startRegistrationLoop() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - - // This code polls the database for server registrations, giving up - // after 5 seconds - for { - select { - case <-c.quit: - return servers, nil - case <-ctx.Done(): - c.config.Logger.Info("could not register all servers") - // start registration loop - c.startRegistrationLoop() - return servers, nil - case <-time.After(200 * time.Millisecond): - servers, err = c.persistence.GetServersByPublicKey(serverPublicKeys) - if err != nil { - return nil, err - } - - allRegistered := true - for _, server := range servers { - allRegistered = allRegistered && server.Registered - } - - // If any of the servers we haven't registered yet, continue - if !allRegistered { - continue - } - - // all have registered,cancel context and return - cancel() - return servers, nil - } - } + return nil } // HandlePushNotificationRegistrationResponse should check whether the response was successful or not, retry if necessary otherwise store the result in the database diff --git a/protocol/push_notification_test.go b/protocol/push_notification_test.go index 95961ff28..910794894 100644 --- a/protocol/push_notification_test.go +++ b/protocol/push_notification_test.go @@ -3,6 +3,7 @@ package protocol import ( "context" "crypto/ecdsa" + "errors" "io/ioutil" "os" "testing" @@ -114,11 +115,9 @@ func (s *MessengerPushNotificationSuite) newPushNotificationServer(shh types.Whi } func (s *MessengerPushNotificationSuite) TestReceivePushNotification() { - errChan := make(chan error) bob1DeviceToken := "token-1" bob2DeviceToken := "token-2" - var bob1Servers, bob2Servers []*push_notification_client.PushNotificationServer bob1 := s.m bob2 := s.newMessengerWithKey(s.shh, s.m.identity) @@ -130,10 +129,7 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() { err := bob1.AddPushNotificationServer(context.Background(), &server.identity.PublicKey) s.Require().NoError(err) - go func() { - bob1Servers, err = bob1.RegisterForPushNotifications(context.Background(), bob1DeviceToken) - errChan <- err - }() + err = bob1.RegisterForPushNotifications(context.Background(), bob1DeviceToken) // Receive message, reply // TODO: find a better way to handle this waiting @@ -163,21 +159,28 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() { _, err = bob1.RetrieveAll() s.Require().NoError(err) + // Pull servers and check we registered + err = tt.RetryWithBackOff(func() error { + registered, err := bob1.RegisteredForPushNotifications() + if err != nil { + return err + } + if !registered { + return errors.New("not registered") + } + return nil + }) // Make sure we receive it - err = <-errChan s.Require().NoError(err) - s.Require().NotNil(bob1Servers) - s.Require().Len(bob1Servers, 1) - s.Require().True(bob1Servers[0].Registered) + bob1Servers, err := bob1.GetPushNotificationServers() + s.Require().NoError(err) // Register bob2 err = bob2.AddPushNotificationServer(context.Background(), &server.identity.PublicKey) s.Require().NoError(err) - go func() { - bob2Servers, err = bob2.RegisterForPushNotifications(context.Background(), bob2DeviceToken) - errChan <- err - }() + err = bob2.RegisterForPushNotifications(context.Background(), bob2DeviceToken) + s.Require().NoError(err) // Receive message, reply // TODO: find a better way to handle this waiting @@ -207,12 +210,20 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() { _, err = bob2.RetrieveAll() s.Require().NoError(err) + err = tt.RetryWithBackOff(func() error { + registered, err := bob2.RegisteredForPushNotifications() + if err != nil { + return err + } + if !registered { + return errors.New("not registered") + } + return nil + }) // Make sure we receive it - err = <-errChan s.Require().NoError(err) - s.Require().NotNil(bob2Servers) - s.Require().Len(bob2Servers, 1) - s.Require().True(bob2Servers[0].Registered) + bob2Servers, err := bob2.GetPushNotificationServers() + s.Require().NoError(err) err = alice.pushNotificationClient.QueryPushNotificationInfo(&bob2.identity.PublicKey) s.Require().NoError(err) diff --git a/services/ext/api.go b/services/ext/api.go index 468f01dab..62c9181e8 100644 --- a/services/ext/api.go +++ b/services/ext/api.go @@ -417,10 +417,10 @@ func (api *PublicAPI) StopPushNotificationServer() error { // PushNotification client -func (api *PublicAPI) RegisterForPushNotifications(ctx context.Context, deviceToken string) ([]*push_notification_client.PushNotificationServer, error) { +func (api *PublicAPI) RegisterForPushNotifications(ctx context.Context, deviceToken string) error { err := api.service.accountsDB.SaveSetting("remote-push-notifications-enabled", true) if err != nil { - return nil, err + return err } return api.service.messenger.RegisterForPushNotifications(ctx, deviceToken) } @@ -459,6 +459,14 @@ func (api *PublicAPI) AddPushNotificationServer(ctx context.Context, publicKeyBy return api.service.messenger.AddPushNotificationServer(ctx, publicKey) } +func (api *PublicAPI) GetPushNotificationServers() ([]*push_notification_client.PushNotificationServer, error) { + return api.service.messenger.GetPushNotificationServers() +} + +func (api *PublicAPI) RegisteredForPushNotifications() (bool, error) { + return api.service.messenger.RegisteredForPushNotifications() +} + // Echo is a method for testing purposes. func (api *PublicAPI) Echo(ctx context.Context, message string) (string, error) { return message, nil