diff --git a/protocol/common/message_processor.go b/protocol/common/message_processor.go index 738b5ca00..5c4720fc4 100644 --- a/protocol/common/message_processor.go +++ b/protocol/common/message_processor.go @@ -438,8 +438,10 @@ func (p *MessageProcessor) sendMessageSpec(ctx context.Context, publicKey *ecdsa MessageIDs: messageIDs, } + logger.Debug("subscriptions", zap.Int("count", len(p.subscriptions))) // Publish on channels, drop if buffer is full for _, c := range p.subscriptions { + logger.Debug("sending on subscription") select { case c <- sentMessage: default: diff --git a/protocol/push_notification_client/client.go b/protocol/push_notification_client/client.go index 59f48680d..2f49ba45e 100644 --- a/protocol/push_notification_client/client.go +++ b/protocol/push_notification_client/client.go @@ -30,8 +30,14 @@ const encryptedPayloadKeyLength = 16 const accessTokenKeyLength = 16 const staleQueryTimeInSeconds = 86400 -// maxRetries is the maximum number of attempts we do before giving up registering with a server -const maxRetries int64 = 12 +// maxRegistrationRetries is the maximum number of attempts we do before giving up registering with a server +const maxRegistrationRetries int64 = 12 + +// maxPushNotificationRetries is the maximum number of attempts before we give up sending a push notification +const maxPushNotificationRetries int64 = 4 + +// pushNotificationBackoffTime is the step of the exponential backoff +const pushNotificationBackoffTime int64 = 2 // RegistrationBackoffTime is the step of the exponential backoff const RegistrationBackoffTime int64 = 15 @@ -70,7 +76,8 @@ type PushNotificationInfo struct { type SentNotification struct { PublicKey *ecdsa.PublicKey InstallationID string - SentAt int64 + LastTriedAt int64 + RetryCount int64 MessageID []byte Success bool Error protobuf.PushNotificationReport_ErrorType @@ -125,6 +132,9 @@ type Client struct { // registrationLoopQuitChan is a channel to indicate to the registration loop that should be terminating registrationLoopQuitChan chan struct{} + + // resendingLoopQuitChan is a channel to indicate to the send loop that shoudl be terminating + resendingLoopQuitChan chan struct{} } func New(persistence *Persistence, config *Config, processor *common.MessageProcessor) *Client { @@ -139,6 +149,7 @@ func New(persistence *Persistence, config *Config, processor *common.MessageProc func (c *Client) subscribeForSentMessages() { go func() { + c.config.Logger.Info("subscribing for messages") subscription := c.messageProcessor.Subscribe() for { select { @@ -181,12 +192,26 @@ func (c *Client) stopRegistrationLoop() { } } +func (c *Client) stopResendingLoop() { + // stop old registration loop + if c.resendingLoopQuitChan != nil { + close(c.resendingLoopQuitChan) + c.resendingLoopQuitChan = nil + } +} + func (c *Client) startRegistrationLoop() { c.stopRegistrationLoop() c.registrationLoopQuitChan = make(chan struct{}) go c.registrationLoop() } +func (c *Client) startResendingLoop() { + c.stopResendingLoop() + c.resendingLoopQuitChan = make(chan struct{}) + go c.resendingLoop() +} + func (c *Client) Start() error { if c.messageProcessor == nil { return errors.New("can't start, missing message processor") @@ -198,19 +223,19 @@ func (c *Client) Start() error { } c.subscribeForSentMessages() c.startRegistrationLoop() + c.startResendingLoop() return nil } func (c *Client) Stop() error { close(c.quit) - if c.registrationLoopQuitChan != nil { - close(c.registrationLoopQuitChan) - } + c.stopRegistrationLoop() + c.stopResendingLoop() return nil } -func (c *Client) queryNotificationInfo(publicKey *ecdsa.PublicKey) error { +func (c *Client) queryNotificationInfo(publicKey *ecdsa.PublicKey, force bool) error { // Check if we queried recently queriedAt, err := c.persistence.GetQueriedAt(publicKey) if err != nil { @@ -218,7 +243,7 @@ func (c *Client) queryNotificationInfo(publicKey *ecdsa.PublicKey) error { } // Naively query again if too much time has passed. // Here it might not be necessary - if time.Now().Unix()-queriedAt > staleQueryTimeInSeconds { + if force || time.Now().Unix()-queriedAt > staleQueryTimeInSeconds { c.config.Logger.Info("querying info") err := c.QueryPushNotificationInfo(publicKey) if err != nil { @@ -294,92 +319,21 @@ func (c *Client) HandleMessageSent(sentMessage *common.SentMessage) error { c.config.Logger.Info("actionable messages", zap.Any("message-ids", trackedMessageIDs), zap.Any("installation-ids", installationIDs)) - err := c.queryNotificationInfo(publicKey) + infos, err := c.sendNotification(publicKey, installationIDs, trackedMessageIDs[0]) if err != nil { return err } - c.config.Logger.Info("queried info") - // Retrieve infos - info, err := c.GetPushNotificationInfo(publicKey, installationIDs) - if err != nil { - c.config.Logger.Error("could not get pn info", zap.Error(err)) - return err - } - - // Naively dispatch to the first server for now - // This wait for an acknowledgement and try a different server after a timeout - // Also we sent a single notification for multiple message ids, need to check with UI what's the desired behavior - - // Sort by server so we tend to hit the same one - sort.Slice(info, func(i, j int) bool { - return info[i].ServerPublicKey.X.Cmp(info[j].ServerPublicKey.X) <= 0 - }) - - c.config.Logger.Info("retrieved info") - - installationIDsMap := make(map[string]bool) - // One info per installation id, grouped by server - actionableInfos := make(map[string][]*PushNotificationInfo) - for _, i := range info { - - c.config.Logger.Info("queried info", zap.String("id", i.InstallationID)) - if !installationIDsMap[i.InstallationID] { - serverKey := hex.EncodeToString(crypto.CompressPubkey(i.ServerPublicKey)) - actionableInfos[serverKey] = append(actionableInfos[serverKey], i) - installationIDsMap[i.InstallationID] = true - } - - } - - c.config.Logger.Info("actionable info", zap.Int("count", len(actionableInfos))) - - for _, infos := range actionableInfos { - var pushNotifications []*protobuf.PushNotification - for _, i := range infos { - // TODO: Add ChatID, message, public_key - pushNotifications = append(pushNotifications, &protobuf.PushNotification{ - AccessToken: i.AccessToken, - PublicKey: common.HashPublicKey(publicKey), - InstallationId: i.InstallationID, - }) - - } - request := &protobuf.PushNotificationRequest{ - MessageId: trackedMessageIDs[0], - Requests: pushNotifications, - } - serverPublicKey := infos[0].ServerPublicKey - - payload, err := proto.Marshal(request) - if err != nil { - return err - } - - rawMessage := &common.RawMessage{ - Payload: payload, - MessageType: protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REQUEST, - } - - // TODO: We should use the messageID for the response - _, err = c.messageProcessor.SendPrivate(context.Background(), serverPublicKey, rawMessage) - - if err != nil { - return err - } - - // Mark message as sent, this is at-most-once semantic - // for all messageIDs - for _, i := range infos { - for _, messageID := range trackedMessageIDs { - - c.config.Logger.Info("marking as sent ", zap.Binary("mid", messageID), zap.String("id", i.InstallationID)) - if err := c.notifiedOn(publicKey, i.InstallationID, messageID); err != nil { - return err - } + // Mark message as sent, this is at-most-once semantic + // for all messageIDs + for _, i := range infos { + for _, messageID := range trackedMessageIDs { + c.config.Logger.Info("marking as sent ", zap.Binary("mid", messageID), zap.String("id", i.InstallationID)) + if err := c.notifiedOn(publicKey, i.InstallationID, messageID); err != nil { + return err } - } + } } return nil @@ -399,9 +353,9 @@ func (c *Client) shouldNotifyOn(publicKey *ecdsa.PublicKey, installationID strin } func (c *Client) notifiedOn(publicKey *ecdsa.PublicKey, installationID string, messageID []byte) error { - return c.persistence.NotifiedOn(&SentNotification{ + return c.persistence.UpsertSentNotification(&SentNotification{ PublicKey: publicKey, - SentAt: time.Now().Unix(), + LastTriedAt: time.Now().Unix(), InstallationID: installationID, MessageID: messageID, }) @@ -532,11 +486,26 @@ func nextServerRetry(server *PushNotificationServer) int64 { return server.LastRetriedAt + RegistrationBackoffTime*server.RetryCount*int64(math.Exp2(float64(server.RetryCount))) } +func nextPushNotificationRetry(pn *SentNotification) int64 { + return pn.LastTriedAt + pushNotificationBackoffTime*pn.RetryCount*int64(math.Exp2(float64(pn.RetryCount))) +} + // We calculate if it's too early to retry, by exponentially backing off func shouldRetryRegisteringWithServer(server *PushNotificationServer) bool { + if server.RetryCount > maxRegistrationRetries { + return false + } return time.Now().Unix() > nextServerRetry(server) } +// We calculate if it's too early to retry, by exponentially backing off +func shouldRetryPushNotification(pn *SentNotification) bool { + if pn.RetryCount > maxPushNotificationRetries { + return false + } + return time.Now().Unix() > nextPushNotificationRetry(pn) +} + func (c *Client) resetServers() error { servers, err := c.persistence.GetServers() if err != nil { @@ -603,6 +572,153 @@ func (c *Client) registerWithServer(registration *protobuf.PushNotificationRegis return nil } +func (c *Client) sendNotification(publicKey *ecdsa.PublicKey, installationIDs []string, messageID []byte) ([]*PushNotificationInfo, error) { + err := c.queryNotificationInfo(publicKey, false) + if err != nil { + return nil, err + } + c.config.Logger.Info("queried info") + // Retrieve infos + info, err := c.GetPushNotificationInfo(publicKey, installationIDs) + if err != nil { + c.config.Logger.Error("could not get pn info", zap.Error(err)) + return nil, err + } + + // Naively dispatch to the first server for now + // This wait for an acknowledgement and try a different server after a timeout + // Also we sent a single notification for multiple message ids, need to check with UI what's the desired behavior + + // Sort by server so we tend to hit the same one + sort.Slice(info, func(i, j int) bool { + return info[i].ServerPublicKey.X.Cmp(info[j].ServerPublicKey.X) <= 0 + }) + + c.config.Logger.Info("retrieved info") + + installationIDsMap := make(map[string]bool) + // One info per installation id, grouped by server + actionableInfos := make(map[string][]*PushNotificationInfo) + for _, i := range info { + + c.config.Logger.Info("queried info", zap.String("id", i.InstallationID)) + if !installationIDsMap[i.InstallationID] { + serverKey := hex.EncodeToString(crypto.CompressPubkey(i.ServerPublicKey)) + actionableInfos[serverKey] = append(actionableInfos[serverKey], i) + installationIDsMap[i.InstallationID] = true + } + + } + + c.config.Logger.Info("actionable info", zap.Int("count", len(actionableInfos))) + + var actionedInfo []*PushNotificationInfo + for _, infos := range actionableInfos { + var pushNotifications []*protobuf.PushNotification + for _, i := range infos { + // TODO: Add ChatID, message, public_key + pushNotifications = append(pushNotifications, &protobuf.PushNotification{ + AccessToken: i.AccessToken, + PublicKey: common.HashPublicKey(publicKey), + InstallationId: i.InstallationID, + }) + + } + request := &protobuf.PushNotificationRequest{ + MessageId: messageID, + Requests: pushNotifications, + } + serverPublicKey := infos[0].ServerPublicKey + + payload, err := proto.Marshal(request) + if err != nil { + return nil, err + } + + rawMessage := &common.RawMessage{ + Payload: payload, + MessageType: protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REQUEST, + } + + // TODO: We should use the messageID for the response + _, err = c.messageProcessor.SendPrivate(context.Background(), serverPublicKey, rawMessage) + + if err != nil { + return nil, err + } + actionedInfo = append(actionedInfo, infos...) + } + return actionedInfo, nil +} + +func (c *Client) resendNotification(pn *SentNotification) error { + c.config.Logger.Info("resending notification", zap.Any("notification", pn)) + pn.RetryCount += 1 + pn.LastTriedAt = time.Now().Unix() + err := c.persistence.UpsertSentNotification(pn) + if err != nil { + return err + } + + // Re-fetch push notification info + err = c.queryNotificationInfo(pn.PublicKey, true) + if err != nil { + return err + } + + if err != nil { + c.config.Logger.Error("could not get pn info", zap.Error(err)) + return err + } + + _, err = c.sendNotification(pn.PublicKey, []string{pn.InstallationID}, pn.MessageID) + return err +} + +func (c *Client) resendingLoop() error { + for { + c.config.Logger.Info("running resending loop") + var lowestNextRetry int64 + + retriableNotifications, err := c.persistence.GetRetriablePushNotifications() + if err != nil { + c.config.Logger.Error("failed retrieving notifications, quitting resending loop", zap.Error(err)) + return err + } + + if len(retriableNotifications) == 0 { + c.config.Logger.Debug("no retriable notifications, quitting") + return nil + } + + for _, pn := range retriableNotifications { + nextRetry := nextPushNotificationRetry(pn) + c.config.Logger.Info("Next retry", zap.Int64("now", time.Now().Unix()), zap.Int64("next", nextRetry)) + if shouldRetryPushNotification(pn) { + c.config.Logger.Info("retrying pn", zap.Any("pn", pn)) + err := c.resendNotification(pn) + if err != nil { + return err + } + } + if lowestNextRetry == 0 || nextRetry < lowestNextRetry { + lowestNextRetry = nextRetry + } + } + + nextRetry := lowestNextRetry - time.Now().Unix() + waitFor := time.Duration(nextRetry) + select { + + case <-time.After(waitFor * time.Second): + case <-c.resendingLoopQuitChan: + return nil + + } + + } +} + func (c *Client) registrationLoop() error { for { c.config.Logger.Info("running registration loop") @@ -882,6 +998,9 @@ func (c *Client) HandlePushNotificationResponse(serverKey *ecdsa.PublicKey, resp return err } } + // Restart resending loop + c.stopResendingLoop() + c.startResendingLoop() return nil } diff --git a/protocol/push_notification_client/migrations/migrations.go b/protocol/push_notification_client/migrations/migrations.go index 431630032..2743a2282 100644 --- a/protocol/push_notification_client/migrations/migrations.go +++ b/protocol/push_notification_client/migrations/migrations.go @@ -1,7 +1,7 @@ // Code generated by go-bindata. DO NOT EDIT. // sources: // 1593601729_initial_schema.down.sql (144B) -// 1593601729_initial_schema.up.sql (1.709kB) +// 1593601729_initial_schema.up.sql (1.753kB) // doc.go (382B) package migrations @@ -91,7 +91,7 @@ func _1593601729_initial_schemaDownSql() (*asset, error) { return a, nil } -var __1593601729_initial_schemaUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xac\x54\xc1\x6e\xe2\x30\x10\xbd\xe7\x2b\xe6\x58\x24\x0e\x7b\xef\x29\x50\xb3\x8a\x64\x39\xbb\x60\x24\x6e\x96\xd7\x99\x36\x16\x59\xa7\x6b\x3b\xd5\xf2\xf7\x2b\x27\x90\x42\x9d\x75\xa4\x96\x0b\x12\x33\xcf\xa3\x79\x6f\x5e\xde\x7a\x4b\x72\x4e\x80\xe7\x2b\x4a\xa0\xd8\x00\x2b\x39\x90\x43\xb1\xe3\x3b\x78\xed\x5c\x2d\x4c\xeb\xf5\xb3\x56\xd2\xeb\xd6\x08\xd5\x68\x34\x5e\x38\xb4\x6f\x68\x1d\x3c\x64\x00\xaf\xdd\xaf\x46\x2b\x71\xc4\x13\xac\x68\xb9\xea\xdf\xb3\x3d\xa5\xcb\x0c\xc0\xe2\x8b\x76\x1e\x2d\x56\xb0\x2a\x4b\x4a\x72\x06\x4f\x64\x93\xef\x29\x87\x4d\x4e\x77\xe4\x16\x23\xa4\x87\x82\xf1\x71\xc2\x88\xfd\x16\x70\x8d\x74\x5e\x58\xf4\x56\xcf\x21\x03\xe8\x24\x54\xdb\x99\x14\x4a\x2a\x85\xce\x09\xdf\x1e\xd1\x00\x27\x07\x1e\x8a\x7b\x56\xfc\xdc\x93\x87\x77\x4e\x0b\x28\x19\xac\x4b\xb6\xa1\xc5\x9a\xc3\x96\xfc\xa0\xf9\x9a\x64\x8b\xc7\x2c\xfb\x8c\x6e\x7f\x3a\xb4\x1a\xe7\x75\x1b\x70\x11\xcd\x4b\xeb\x24\x74\x15\x3f\x8a\x76\x5f\x5e\xb0\xf7\x25\xa1\xcd\x73\x3b\xcb\x60\x70\x88\x48\x41\xb4\x71\x5e\x36\xcd\x30\x5b\x57\xfd\x0d\x6e\x00\xd1\x85\x3e\x78\x2b\x58\xe1\x6d\x5a\xa5\xe0\x4e\xdd\x9a\xa8\x1e\x6b\xf4\x71\x8d\x65\xbc\xfa\x7d\xe5\xf3\x56\xaa\x23\x56\xe2\x37\x3a\x27\x5f\xce\x66\x38\xff\x99\xbc\xab\xaa\xa5\x9f\xd4\xe7\x32\x69\x82\xff\x99\xe7\xfb\xd8\x5b\x0e\xc5\x77\x56\x6e\x49\x06\xf0\x59\x12\x2e\xfc\x5c\x37\xe6\x69\xa4\xac\x50\x4b\x57\x63\xf5\x35\xb7\xf4\x2b\x4d\x48\xe1\xba\xde\x46\x63\x00\x45\x69\x30\x26\x11\x5a\xdb\xda\x44\x62\x44\xa2\x2e\x21\x61\xa4\xc5\x17\xe4\x1d\x32\xd1\x5e\x29\x7b\xc9\xc9\xa1\x16\xcb\x03\xa0\x5a\xe3\xa5\x0a\x4e\x71\x7d\x7b\xa8\xba\x93\xf1\x35\x7a\xad\x82\x66\xff\xa7\x36\x92\xbb\xc6\xcf\xfa\xbe\x60\x4f\xe4\x00\xba\xfa\x2b\x92\x61\x71\x7d\xd7\x92\xa5\x83\x25\xf5\x69\x2e\x1e\xb3\x7f\x01\x00\x00\xff\xff\xed\x10\xc3\xcd\xad\x06\x00\x00") +var __1593601729_initial_schemaUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xac\x54\xc1\x6e\xe3\x20\x10\xbd\xfb\x2b\xe6\xd8\x48\x39\xec\xbd\x27\x27\x25\x2b\x4b\x08\xef\x26\x44\xca\x0d\xb1\x78\x5a\xa3\x78\x71\x17\x70\xb5\xf9\xfb\x15\x76\xe2\x26\xc5\x8b\xab\xb6\x17\x4b\x1e\x1e\xa3\x79\x6f\x1e\x6f\xbd\x25\x39\x27\xc0\xf3\x15\x25\x50\x6c\x80\x95\x1c\xc8\xa1\xd8\xf1\x1d\x3c\x77\xae\x16\xa6\xf5\xfa\x51\x2b\xe9\x75\x6b\x84\x6a\x34\x1a\x2f\x1c\xda\x17\xb4\x0e\xee\x32\x80\xe7\xee\x57\xa3\x95\x38\xe2\x09\x56\xb4\x5c\xf5\xf7\xd9\x9e\xd2\x65\x06\x60\xf1\x49\x3b\x8f\x16\x2b\x58\x95\x25\x25\x39\x83\x07\xb2\xc9\xf7\x94\xc3\x26\xa7\x3b\x72\x8b\x11\xd2\x43\xc1\xf8\xd8\x61\xc4\x7e\x0b\xb8\x46\x3a\x2f\x2c\x7a\xab\xe7\x90\x01\x74\x12\xaa\xed\x4c\x0a\x25\x95\x42\xe7\x84\x6f\x8f\x68\x80\x93\x03\x0f\xc5\x3d\x2b\x7e\xee\xc9\xdd\x2b\xa7\x05\x94\x0c\xd6\x25\xdb\xd0\x62\xcd\x61\x4b\x7e\xd0\x7c\x4d\xb2\xc5\x7d\x96\x7d\x44\xb7\x3f\x1d\x5a\x8d\xf3\xba\x0d\xb8\x88\xe6\xe5\xe8\x24\x74\x15\x5f\x8a\x66\x5f\x5e\xb0\x5f\x4b\x42\x9b\xc7\x76\x96\xc1\xe0\x10\x91\x82\x68\xe3\xbc\x6c\x9a\xa1\xb7\xae\xfa\x1d\xdc\x00\xa2\x0d\xbd\xf1\x56\xb0\xc2\xcb\xb4\x4a\xc1\x9d\xba\x35\x51\x3d\xd6\xe8\xed\x18\xcb\x78\xf4\xaf\x95\xcf\x5b\xa9\x8e\x58\x89\xdf\xe8\x9c\x7c\x3a\x9b\xe1\xfc\x33\xb9\x57\x55\x4b\x3f\xa9\xcf\xa5\xd3\x04\xff\x33\xcf\xd7\xb6\xb7\x1c\x8a\xef\xac\xdc\x92\x0c\xe0\xa3\x24\x5c\xf8\x5c\x1f\xcc\xd3\x48\x59\xa1\x96\xae\xc6\xea\x73\x6e\xe9\xf3\x61\x32\x1d\xde\x9f\x09\xae\xeb\x2d\x37\x86\x55\x84\x1a\x53\x0b\xad\x6d\x6d\xa2\x53\xb4\x80\x25\x24\x4c\xb7\xf8\xc4\x2a\x86\xfc\xb4\x57\x5b\xb8\x64\xea\x50\x8b\xa5\x04\x50\xad\xf1\x52\x05\x57\xb9\xfe\x78\xa8\xba\x93\xf1\x35\x7a\xad\x82\xbe\xff\xa7\x36\x92\xbb\xc6\xcf\xbe\x91\x82\x3d\x90\x03\xe8\xea\xaf\x48\x06\xcb\xb5\x07\x4a\x96\x0e\xa1\xd4\x33\x5e\xdc\x67\xff\x02\x00\x00\xff\xff\x39\x9f\x6b\x23\xd9\x06\x00\x00") func _1593601729_initial_schemaUpSqlBytes() ([]byte, error) { return bindataRead( @@ -106,8 +106,8 @@ func _1593601729_initial_schemaUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1593601729_initial_schema.up.sql", size: 1709, mode: os.FileMode(0644), modTime: time.Unix(1595237467, 0)} - a := &asset{bytes: bytes, info: info, digest: [32]uint8{0x35, 0x40, 0x6a, 0x4a, 0x45, 0x37, 0x37, 0x99, 0x97, 0x5, 0xb3, 0x43, 0x6, 0x43, 0xcc, 0x10, 0x32, 0xbc, 0x16, 0xcc, 0xe0, 0xfb, 0x3, 0xa8, 0xce, 0x6a, 0x6b, 0x39, 0xd4, 0xe0, 0xbe, 0xa4}} + info := bindataFileInfo{name: "1593601729_initial_schema.up.sql", size: 1753, mode: os.FileMode(0644), modTime: time.Unix(1595240420, 0)} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0x28, 0xbf, 0x64, 0xe0, 0x65, 0x53, 0xd3, 0x80, 0xf4, 0x46, 0xce, 0xd6, 0x23, 0x4e, 0xc5, 0x8f, 0x80, 0x4e, 0x91, 0xa7, 0x2e, 0x9, 0x3b, 0xf4, 0x5f, 0xa1, 0xff, 0xfc, 0x6e, 0x4, 0xa2, 0xe7}} return a, nil } diff --git a/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql b/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql index d4e4941e6..1af27e9de 100644 --- a/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql +++ b/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql @@ -37,7 +37,8 @@ CREATE TABLE IF NOT EXISTS push_notification_client_sent_notifications ( public_key BLOB NOT NULL, hashed_public_key BLOB NOT NULL, installation_id TEXT NOT NULL, - sent_at INT NOT NULL, + last_tried_at INT NOT NULL, + retry_count INT NOT NULL DEFAULT 0, success BOOLEAN NOT NULL DEFAULT FALSE, error INT NOT NULL DEFAULT 0, UNIQUE(message_id, public_key, installation_id) diff --git a/protocol/push_notification_client/persistence.go b/protocol/push_notification_client/persistence.go index 42fd128c9..bc657be94 100644 --- a/protocol/push_notification_client/persistence.go +++ b/protocol/push_notification_client/persistence.go @@ -271,8 +271,8 @@ func (p *Persistence) ShouldSendNotificationToAllInstallationIDs(publicKey *ecds return count == 0, nil } -func (p *Persistence) NotifiedOn(n *SentNotification) error { - _, err := p.db.Exec(`INSERT INTO push_notification_client_sent_notifications (public_key, installation_id, message_id, sent_at, hashed_public_key) VALUES (?, ?, ?, ?, ?)`, crypto.CompressPubkey(n.PublicKey), n.InstallationID, n.MessageID, n.SentAt, n.HashedPublicKey()) +func (p *Persistence) UpsertSentNotification(n *SentNotification) error { + _, err := p.db.Exec(`INSERT INTO push_notification_client_sent_notifications (public_key, installation_id, message_id, last_tried_at, retry_count, success, error, hashed_public_key) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, crypto.CompressPubkey(n.PublicKey), n.InstallationID, n.MessageID, n.LastTriedAt, n.RetryCount, n.Success, n.Error, n.HashedPublicKey()) return err } @@ -282,7 +282,7 @@ func (p *Persistence) GetSentNotification(hashedPublicKey []byte, installationID InstallationID: installationID, MessageID: messageID, } - err := p.db.QueryRow(`SELECT sent_at, error, success, public_key FROM push_notification_client_sent_notifications WHERE hashed_public_key = ?`, hashedPublicKey).Scan(&sentNotification.SentAt, &sentNotification.Error, &sentNotification.Success, &publicKeyBytes) + err := p.db.QueryRow(`SELECT retry_count, last_tried_at, error, success, public_key FROM push_notification_client_sent_notifications WHERE hashed_public_key = ?`, hashedPublicKey).Scan(&sentNotification.RetryCount, &sentNotification.LastTriedAt, &sentNotification.Error, &sentNotification.Success, &publicKeyBytes) if err != nil { return nil, err } @@ -302,6 +302,30 @@ func (p *Persistence) UpdateNotificationResponse(messageID []byte, response *pro return err } +func (p *Persistence) GetRetriablePushNotifications() ([]*SentNotification, error) { + var notifications []*SentNotification + rows, err := p.db.Query(`SELECT retry_count, last_tried_at, error, success, public_key, installation_id, message_id FROM push_notification_client_sent_notifications WHERE NOT success AND error = ?`, protobuf.PushNotificationReport_WRONG_TOKEN) + if err != nil { + return nil, err + } + + for rows.Next() { + var publicKeyBytes []byte + notification := &SentNotification{} + err = rows.Scan(¬ification.RetryCount, ¬ification.LastTriedAt, ¬ification.Error, ¬ification.Success, &publicKeyBytes, ¬ification.InstallationID, ¬ification.MessageID) + if err != nil { + return nil, err + } + publicKey, err := crypto.DecompressPubkey(publicKeyBytes) + if err != nil { + return nil, err + } + notification.PublicKey = publicKey + notifications = append(notifications, notification) + } + return notifications, err +} + func (p *Persistence) UpsertServer(server *PushNotificationServer) error { _, err := p.db.Exec(`INSERT INTO push_notification_client_servers (public_key, registered, registered_at, access_token, last_retried_at, retry_count) VALUES (?,?,?,?,?,?)`, crypto.CompressPubkey(server.PublicKey), server.Registered, server.RegisteredAt, server.AccessToken, server.LastRetriedAt, server.RetryCount) return err diff --git a/protocol/push_notification_client/persistence_test.go b/protocol/push_notification_client/persistence_test.go index 385f3cf14..8f522c767 100644 --- a/protocol/push_notification_client/persistence_test.go +++ b/protocol/push_notification_client/persistence_test.go @@ -222,15 +222,19 @@ func (s *SQLitePersistenceSuite) TestNotifiedOnAndUpdateNotificationResponse() { PublicKey: &key.PublicKey, InstallationID: installationID, MessageID: messageID, - SentAt: time.Now().Unix(), + LastTriedAt: time.Now().Unix(), } - s.Require().NoError(s.persistence.NotifiedOn(sentNotification)) + s.Require().NoError(s.persistence.UpsertSentNotification(sentNotification)) retrievedNotification, err := s.persistence.GetSentNotification(sentNotification.HashedPublicKey(), installationID, messageID) s.Require().NoError(err) s.Require().Equal(sentNotification, retrievedNotification) + retriableNotifications, err := s.persistence.GetRetriablePushNotifications() + s.Require().NoError(err) + s.Require().Len(retriableNotifications, 0) + response := &protobuf.PushNotificationReport{ Success: false, Error: protobuf.PushNotificationReport_WRONG_TOKEN, @@ -239,6 +243,10 @@ func (s *SQLitePersistenceSuite) TestNotifiedOnAndUpdateNotificationResponse() { } s.Require().NoError(s.persistence.UpdateNotificationResponse(messageID, response)) + // This notification should be retriable + retriableNotifications, err = s.persistence.GetRetriablePushNotifications() + s.Require().NoError(err) + s.Require().Len(retriableNotifications, 1) sentNotification.Error = protobuf.PushNotificationReport_WRONG_TOKEN @@ -262,6 +270,11 @@ func (s *SQLitePersistenceSuite) TestNotifiedOnAndUpdateNotificationResponse() { s.Require().NoError(err) s.Require().Equal(sentNotification, retrievedNotification) + // This notification should not be retriable + retriableNotifications, err = s.persistence.GetRetriablePushNotifications() + s.Require().NoError(err) + s.Require().Len(retriableNotifications, 0) + // Update with a unsuccessful notification, it should be ignored response = &protobuf.PushNotificationReport{ Success: false, diff --git a/protocol/push_notification_test.go b/protocol/push_notification_test.go index ddac6e8fe..0640470a8 100644 --- a/protocol/push_notification_test.go +++ b/protocol/push_notification_test.go @@ -3,6 +3,7 @@ package protocol import ( "context" "crypto/ecdsa" + "encoding/hex" "errors" "io/ioutil" "os" @@ -120,6 +121,9 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() { bob2 := s.newMessengerWithKey(s.shh, s.m.identity) server := s.newPushNotificationServer(s.shh) alice := s.newMessenger(s.shh) + // start alice and enable sending push notifications + s.Require().NoError(alice.Start()) + s.Require().NoError(alice.EnableSendingPushNotifications()) bobInstallationIDs := []string{bob1.installationID, bob2.installationID} // Register bob1 @@ -183,7 +187,12 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() { bob2Servers, err := bob2.GetPushNotificationServers() s.Require().NoError(err) - err = alice.pushNotificationClient.QueryPushNotificationInfo(&bob2.identity.PublicKey) + // Create one to one chat & send message + pkString := hex.EncodeToString(crypto.FromECDSAPub(&s.m.identity.PublicKey)) + chat := CreateOneToOneChat(pkString, &s.m.identity.PublicKey, alice.transport) + s.Require().NoError(alice.SaveChat(&chat)) + inputMessage := buildTestMessage(chat) + _, err = alice.SendChatMessage(context.Background(), inputMessage) s.Require().NoError(err) var info []*push_notification_client.PushNotificationInfo @@ -247,6 +256,9 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotificationFromContactO bob2 := s.newMessengerWithKey(s.shh, s.m.identity) server := s.newPushNotificationServer(s.shh) alice := s.newMessenger(s.shh) + // start alice and enable push notifications + s.Require().NoError(alice.Start()) + s.Require().NoError(alice.EnableSendingPushNotifications()) bobInstallationIDs := []string{bob.installationID, bob2.installationID} // Register bob @@ -294,7 +306,12 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotificationFromContactO bobServers, err := bob.GetPushNotificationServers() s.Require().NoError(err) - err = alice.pushNotificationClient.QueryPushNotificationInfo(&bob2.identity.PublicKey) + // Create one to one chat & send message + pkString := hex.EncodeToString(crypto.FromECDSAPub(&s.m.identity.PublicKey)) + chat := CreateOneToOneChat(pkString, &s.m.identity.PublicKey, alice.transport) + s.Require().NoError(alice.SaveChat(&chat)) + inputMessage := buildTestMessage(chat) + _, err = alice.SendChatMessage(context.Background(), inputMessage) s.Require().NoError(err) var info []*push_notification_client.PushNotificationInfo