diff --git a/protocol/pushnotificationserver/persistence.go b/protocol/pushnotificationserver/persistence.go index 0f31fb178..d0638c6ae 100644 --- a/protocol/pushnotificationserver/persistence.go +++ b/protocol/pushnotificationserver/persistence.go @@ -39,7 +39,8 @@ func NewSQLitePersistence(db *sql.DB) Persistence { func (p *SQLitePersistence) GetPushNotificationRegistrationByPublicKeyAndInstallationID(publicKey []byte, installationID string) (*protobuf.PushNotificationRegistration, error) { var marshaledRegistration []byte - err := p.db.QueryRow(`SELECT registration FROM push_notification_server_registrations WHERE public_key = ? AND installation_id = ?`, publicKey, installationID).Scan(&marshaledRegistration) + var version uint64 + err := p.db.QueryRow(`SELECT version,registration FROM push_notification_server_registrations WHERE public_key = ? AND installation_id = ?`, publicKey, installationID).Scan(&version, &marshaledRegistration) if err == sql.ErrNoRows { return nil, nil @@ -47,7 +48,14 @@ func (p *SQLitePersistence) GetPushNotificationRegistrationByPublicKeyAndInstall return nil, err } - registration := &protobuf.PushNotificationRegistration{} + registration := &protobuf.PushNotificationRegistration{ + InstallationId: installationID, + Version: version, + } + + if marshaledRegistration == nil { + return registration, nil + } if err := proto.Unmarshal(marshaledRegistration, registration); err != nil { return nil, err @@ -86,6 +94,10 @@ func (p *SQLitePersistence) GetPushNotificationRegistrationByPublicKeys(publicKe } registration := &protobuf.PushNotificationRegistration{} + // Skip if there's no registration + if marshaledRegistration == nil { + continue + } if err := proto.Unmarshal(marshaledRegistration, registration); err != nil { return nil, err @@ -97,7 +109,7 @@ func (p *SQLitePersistence) GetPushNotificationRegistrationByPublicKeys(publicKe } func (p *SQLitePersistence) GetPushNotificationRegistrationPublicKeys() ([][]byte, error) { - rows, err := p.db.Query(`SELECT public_key FROM push_notification_server_registrations`) + rows, err := p.db.Query(`SELECT public_key FROM push_notification_server_registrations WHERE registration IS NOT NULL`) if err != nil { return nil, err } diff --git a/protocol/pushnotificationserver/server.go b/protocol/pushnotificationserver/server.go index 92adb5c07..cad6a4331 100644 --- a/protocol/pushnotificationserver/server.go +++ b/protocol/pushnotificationserver/server.go @@ -336,11 +336,10 @@ func (s *Server) buildPushNotificationRequestResponseAndSendNotification(request if err != nil { s.config.Logger.Error("failed to retrieve registration", zap.Error(err)) report.Error = protobuf.PushNotificationReport_UNKNOWN_ERROR_TYPE - } else if registration == nil { + } else if registration == nil || registration.AccessToken == "" { s.config.Logger.Warn("empty registration") report.Error = protobuf.PushNotificationReport_NOT_REGISTERED } else if registration.AccessToken != pn.AccessToken { - s.config.Logger.Warn("invalid access token") report.Error = protobuf.PushNotificationReport_WRONG_TOKEN } else { // For now we just assume that the notification will be successful @@ -403,6 +402,7 @@ func (s *Server) buildPushNotificationRegistrationResponse(publicKey *ecdsa.Publ } if registration.Unregister { + s.config.Logger.Info("unregistering client") // We save an empty registration, only keeping version and installation-id emptyRegistration := &protobuf.PushNotificationRegistration{ Version: registration.Version, diff --git a/services/ext/api.go b/services/ext/api.go index 32fcb2a77..cb612cf60 100644 --- a/services/ext/api.go +++ b/services/ext/api.go @@ -444,7 +444,7 @@ func (api *PublicAPI) RegisterForPushNotifications(ctx context.Context, deviceTo return api.service.messenger.RegisterForPushNotifications(ctx, deviceToken, apnTopic, tokenType) } -func (api *PublicAPI) UnregisterForPushNotifications(ctx context.Context) error { +func (api *PublicAPI) UnregisterFromPushNotifications(ctx context.Context) error { err := api.service.accountsDB.SaveSetting("remote-push-notifications-enabled?", false) if err != nil { return err