Register asynchronously

This commit is contained in:
Andrea Maria Piana 2020-07-16 10:36:17 +02:00
parent 7a54d9b1c9
commit 026e16533f
No known key found for this signature in database
GPG Key ID: AA6CCA6DE0E06424
4 changed files with 110 additions and 101 deletions

View File

@ -3064,9 +3064,9 @@ func (m *Messenger) EnableSendingPushNotifications() error {
} }
// RegisterForPushNotification register deviceToken with any push notification server enabled // RegisterForPushNotification register deviceToken with any push notification server enabled
func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) ([]*push_notification_client.PushNotificationServer, error) { func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) error {
if m.pushNotificationClient == nil { if m.pushNotificationClient == nil {
return nil, errors.New("push notification client not enabled") return errors.New("push notification client not enabled")
} }
var contactIDs []*ecdsa.PublicKey var contactIDs []*ecdsa.PublicKey
@ -3095,6 +3095,20 @@ func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToke
return m.pushNotificationClient.Register(deviceToken, contactIDs, mutedChatIDs) return m.pushNotificationClient.Register(deviceToken, contactIDs, mutedChatIDs)
} }
func (m *Messenger) RegisteredForPushNotifications() (bool, error) {
if m.pushNotificationClient == nil {
return false, errors.New("no push notification client")
}
return m.pushNotificationClient.Registered()
}
func (m *Messenger) GetPushNotificationServers() ([]*push_notification_client.PushNotificationServer, error) {
if m.pushNotificationClient == nil {
return nil, errors.New("no push notification client")
}
return m.pushNotificationClient.GetServers()
}
func (m *Messenger) StartPushNotificationServer() error { func (m *Messenger) StartPushNotificationServer() error {
if m.pushNotificationServer == nil { if m.pushNotificationServer == nil {
pushNotificationServerPersistence := push_notification_server.NewSQLitePersistence(m.database) pushNotificationServerPersistence := push_notification_server.NewSQLitePersistence(m.database)

View File

@ -487,7 +487,7 @@ func nextServerRetry(server *PushNotificationServer) int64 {
// We calculate if it's too early to retry, by exponentially backing off // We calculate if it's too early to retry, by exponentially backing off
func shouldRetryRegisteringWithServer(server *PushNotificationServer) bool { func shouldRetryRegisteringWithServer(server *PushNotificationServer) bool {
return time.Now().Unix() < nextServerRetry(server) return time.Now().Unix() > nextServerRetry(server)
} }
func (c *Client) resetServers() error { func (c *Client) resetServers() error {
@ -558,7 +558,7 @@ func (c *Client) registerWithServer(registration *protobuf.PushNotificationRegis
func (c *Client) registrationLoop() error { func (c *Client) registrationLoop() error {
for { for {
c.config.Logger.Info("runing registration loop") c.config.Logger.Info("running registration loop")
servers, err := c.persistence.GetServers() servers, err := c.persistence.GetServers()
if err != nil { if err != nil {
c.config.Logger.Error("failed retrieving servers, quitting registration loop", zap.Error(err)) c.config.Logger.Error("failed retrieving servers, quitting registration loop", zap.Error(err))
@ -571,39 +571,44 @@ func (c *Client) registrationLoop() error {
var nonRegisteredServers []*PushNotificationServer var nonRegisteredServers []*PushNotificationServer
for _, server := range servers { for _, server := range servers {
if server.Registered { if !server.Registered {
nonRegisteredServers = append(nonRegisteredServers, server) nonRegisteredServers = append(nonRegisteredServers, server)
} }
if len(nonRegisteredServers) == 0 { }
c.config.Logger.Debug("registered with all servers, quitting registration loop") if len(nonRegisteredServers) == 0 {
return nil c.config.Logger.Debug("registered with all servers, quitting registration loop")
} return nil
}
var lowestNextRetry int64 c.config.Logger.Info("Trying to register with", zap.Int("servers", len(nonRegisteredServers)))
for _, server := range nonRegisteredServers { var lowestNextRetry int64
if shouldRetryRegisteringWithServer(server) {
err := c.registerWithServer(c.lastPushNotificationRegistration, server) for _, server := range nonRegisteredServers {
if err != nil { nR := nextServerRetry(server)
return err c.config.Logger.Info("Next retry", zap.Int64("now", time.Now().Unix()), zap.Int64("next", nR))
} if shouldRetryRegisteringWithServer(server) {
c.config.Logger.Info("registering with server", zap.Any("server", server))
err := c.registerWithServer(c.lastPushNotificationRegistration, server)
if err != nil {
return err
} }
nextRetry := nextServerRetry(server) }
if lowestNextRetry == 0 || nextRetry < lowestNextRetry { nextRetry := nextServerRetry(server)
lowestNextRetry = nextRetry if lowestNextRetry == 0 || nextRetry < lowestNextRetry {
} lowestNextRetry = nextRetry
} }
nextRetry := lowestNextRetry - time.Now().Unix() }
waitFor := time.Duration(nextRetry)
select {
case <-time.After(waitFor * time.Second): nextRetry := lowestNextRetry - time.Now().Unix()
case <-c.registrationLoopQuitChan: waitFor := time.Duration(nextRetry)
return nil select {
case <-time.After(waitFor * time.Second):
case <-c.registrationLoopQuitChan:
return nil
}
} }
} }
} }
@ -643,73 +648,44 @@ func (c *Client) SaveLastPushNotificationRegistration(registration *protobuf.Pus
return nil return nil
} }
func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) ([]*PushNotificationServer, error) { func (c *Client) Registered() (bool, error) {
servers, err := c.persistence.GetServers()
if err != nil {
return false, err
}
for _, s := range servers {
if !s.Registered {
return false, nil
}
}
return true, nil
}
func (c *Client) GetServers() ([]*PushNotificationServer, error) {
return c.persistence.GetServers()
}
func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) error {
// stop registration loop // stop registration loop
c.stopRegistrationLoop() c.stopRegistrationLoop()
c.DeviceToken = deviceToken c.DeviceToken = deviceToken
servers, err := c.persistence.GetServers()
if err != nil {
return nil, err
}
if len(servers) == 0 {
return nil, errors.New("no servers to register with")
}
registration, err := c.buildPushNotificationRegistrationMessage(contactIDs, mutedChatIDs) registration, err := c.buildPushNotificationRegistrationMessage(contactIDs, mutedChatIDs)
if err != nil { if err != nil {
return nil, err return err
} }
err = c.SaveLastPushNotificationRegistration(registration, contactIDs) err = c.SaveLastPushNotificationRegistration(registration, contactIDs)
if err != nil { if err != nil {
return nil, err return err
} }
var serverPublicKeys []*ecdsa.PublicKey c.startRegistrationLoop()
for _, server := range servers {
err := c.registerWithServer(registration, server)
if err != nil {
return nil, err
}
serverPublicKeys = append(serverPublicKeys, server.PublicKey)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) return nil
// This code polls the database for server registrations, giving up
// after 5 seconds
for {
select {
case <-c.quit:
return servers, nil
case <-ctx.Done():
c.config.Logger.Info("could not register all servers")
// start registration loop
c.startRegistrationLoop()
return servers, nil
case <-time.After(200 * time.Millisecond):
servers, err = c.persistence.GetServersByPublicKey(serverPublicKeys)
if err != nil {
return nil, err
}
allRegistered := true
for _, server := range servers {
allRegistered = allRegistered && server.Registered
}
// If any of the servers we haven't registered yet, continue
if !allRegistered {
continue
}
// all have registered,cancel context and return
cancel()
return servers, nil
}
}
} }
// HandlePushNotificationRegistrationResponse should check whether the response was successful or not, retry if necessary otherwise store the result in the database // HandlePushNotificationRegistrationResponse should check whether the response was successful or not, retry if necessary otherwise store the result in the database

View File

@ -3,6 +3,7 @@ package protocol
import ( import (
"context" "context"
"crypto/ecdsa" "crypto/ecdsa"
"errors"
"io/ioutil" "io/ioutil"
"os" "os"
"testing" "testing"
@ -114,11 +115,9 @@ func (s *MessengerPushNotificationSuite) newPushNotificationServer(shh types.Whi
} }
func (s *MessengerPushNotificationSuite) TestReceivePushNotification() { func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
errChan := make(chan error)
bob1DeviceToken := "token-1" bob1DeviceToken := "token-1"
bob2DeviceToken := "token-2" bob2DeviceToken := "token-2"
var bob1Servers, bob2Servers []*push_notification_client.PushNotificationServer
bob1 := s.m bob1 := s.m
bob2 := s.newMessengerWithKey(s.shh, s.m.identity) bob2 := s.newMessengerWithKey(s.shh, s.m.identity)
@ -130,10 +129,7 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
err := bob1.AddPushNotificationServer(context.Background(), &server.identity.PublicKey) err := bob1.AddPushNotificationServer(context.Background(), &server.identity.PublicKey)
s.Require().NoError(err) s.Require().NoError(err)
go func() { err = bob1.RegisterForPushNotifications(context.Background(), bob1DeviceToken)
bob1Servers, err = bob1.RegisterForPushNotifications(context.Background(), bob1DeviceToken)
errChan <- err
}()
// Receive message, reply // Receive message, reply
// TODO: find a better way to handle this waiting // TODO: find a better way to handle this waiting
@ -163,21 +159,28 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
_, err = bob1.RetrieveAll() _, err = bob1.RetrieveAll()
s.Require().NoError(err) s.Require().NoError(err)
// Pull servers and check we registered
err = tt.RetryWithBackOff(func() error {
registered, err := bob1.RegisteredForPushNotifications()
if err != nil {
return err
}
if !registered {
return errors.New("not registered")
}
return nil
})
// Make sure we receive it // Make sure we receive it
err = <-errChan
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(bob1Servers) bob1Servers, err := bob1.GetPushNotificationServers()
s.Require().Len(bob1Servers, 1) s.Require().NoError(err)
s.Require().True(bob1Servers[0].Registered)
// Register bob2 // Register bob2
err = bob2.AddPushNotificationServer(context.Background(), &server.identity.PublicKey) err = bob2.AddPushNotificationServer(context.Background(), &server.identity.PublicKey)
s.Require().NoError(err) s.Require().NoError(err)
go func() { err = bob2.RegisterForPushNotifications(context.Background(), bob2DeviceToken)
bob2Servers, err = bob2.RegisterForPushNotifications(context.Background(), bob2DeviceToken) s.Require().NoError(err)
errChan <- err
}()
// Receive message, reply // Receive message, reply
// TODO: find a better way to handle this waiting // TODO: find a better way to handle this waiting
@ -207,12 +210,20 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
_, err = bob2.RetrieveAll() _, err = bob2.RetrieveAll()
s.Require().NoError(err) s.Require().NoError(err)
err = tt.RetryWithBackOff(func() error {
registered, err := bob2.RegisteredForPushNotifications()
if err != nil {
return err
}
if !registered {
return errors.New("not registered")
}
return nil
})
// Make sure we receive it // Make sure we receive it
err = <-errChan
s.Require().NoError(err) s.Require().NoError(err)
s.Require().NotNil(bob2Servers) bob2Servers, err := bob2.GetPushNotificationServers()
s.Require().Len(bob2Servers, 1) s.Require().NoError(err)
s.Require().True(bob2Servers[0].Registered)
err = alice.pushNotificationClient.QueryPushNotificationInfo(&bob2.identity.PublicKey) err = alice.pushNotificationClient.QueryPushNotificationInfo(&bob2.identity.PublicKey)
s.Require().NoError(err) s.Require().NoError(err)

View File

@ -417,10 +417,10 @@ func (api *PublicAPI) StopPushNotificationServer() error {
// PushNotification client // PushNotification client
func (api *PublicAPI) RegisterForPushNotifications(ctx context.Context, deviceToken string) ([]*push_notification_client.PushNotificationServer, error) { func (api *PublicAPI) RegisterForPushNotifications(ctx context.Context, deviceToken string) error {
err := api.service.accountsDB.SaveSetting("remote-push-notifications-enabled", true) err := api.service.accountsDB.SaveSetting("remote-push-notifications-enabled", true)
if err != nil { if err != nil {
return nil, err return err
} }
return api.service.messenger.RegisterForPushNotifications(ctx, deviceToken) return api.service.messenger.RegisterForPushNotifications(ctx, deviceToken)
} }
@ -459,6 +459,14 @@ func (api *PublicAPI) AddPushNotificationServer(ctx context.Context, publicKeyBy
return api.service.messenger.AddPushNotificationServer(ctx, publicKey) return api.service.messenger.AddPushNotificationServer(ctx, publicKey)
} }
func (api *PublicAPI) GetPushNotificationServers() ([]*push_notification_client.PushNotificationServer, error) {
return api.service.messenger.GetPushNotificationServers()
}
func (api *PublicAPI) RegisteredForPushNotifications() (bool, error) {
return api.service.messenger.RegisteredForPushNotifications()
}
// Echo is a method for testing purposes. // Echo is a method for testing purposes.
func (api *PublicAPI) Echo(ctx context.Context, message string) (string, error) { func (api *PublicAPI) Echo(ctx context.Context, message string) (string, error) {
return message, nil return message, nil