diff --git a/protocol/pushnotificationserver/persistence.go b/protocol/pushnotificationserver/persistence.go index d0638c6ae..b8de8a04f 100644 --- a/protocol/pushnotificationserver/persistence.go +++ b/protocol/pushnotificationserver/persistence.go @@ -16,9 +16,14 @@ type Persistence interface { GetPushNotificationRegistrationByPublicKeyAndInstallationID(publicKey []byte, installationID string) (*protobuf.PushNotificationRegistration, error) // GetPushNotificationRegistrationByPublicKey retrieve all the push notification registrations from storage given a public key GetPushNotificationRegistrationByPublicKeys(publicKeys [][]byte) ([]*PushNotificationIDAndRegistration, error) - //GetPushNotificationRegistrationPublicKeys return all the public keys stored + // GetPushNotificationRegistrationPublicKeys return all the public keys stored GetPushNotificationRegistrationPublicKeys() ([][]byte, error) + //GetPushNotificationRegistrationVersion returns the latest version or 0 for a given pk and installationID + GetPushNotificationRegistrationVersion(publicKey []byte, installationID string) (uint64, error) + // UnregisterPushNotificationRegistration unregister a given pk/installationID + UnregisterPushNotificationRegistration(publicKey []byte, installationID string, version uint64) error + // DeletePushNotificationRegistration deletes a push notification registration from storage given a public key and installation id DeletePushNotificationRegistration(publicKey []byte, installationID string) error // SavePushNotificationRegistration saves a push notification option to the db @@ -39,8 +44,7 @@ func NewSQLitePersistence(db *sql.DB) Persistence { func (p *SQLitePersistence) GetPushNotificationRegistrationByPublicKeyAndInstallationID(publicKey []byte, installationID string) (*protobuf.PushNotificationRegistration, error) { var marshaledRegistration []byte - var version uint64 - err := p.db.QueryRow(`SELECT version,registration FROM push_notification_server_registrations WHERE public_key = ? AND installation_id = ?`, publicKey, installationID).Scan(&version, &marshaledRegistration) + err := p.db.QueryRow(`SELECT registration FROM push_notification_server_registrations WHERE public_key = ? AND installation_id = ? AND registration IS NOT NULL`, publicKey, installationID).Scan(&marshaledRegistration) if err == sql.ErrNoRows { return nil, nil @@ -48,21 +52,25 @@ func (p *SQLitePersistence) GetPushNotificationRegistrationByPublicKeyAndInstall return nil, err } - registration := &protobuf.PushNotificationRegistration{ - InstallationId: installationID, - Version: version, - } - - if marshaledRegistration == nil { - return registration, nil - } - + registration := &protobuf.PushNotificationRegistration{} if err := proto.Unmarshal(marshaledRegistration, registration); err != nil { return nil, err } return registration, nil } +func (p *SQLitePersistence) GetPushNotificationRegistrationVersion(publicKey []byte, installationID string) (uint64, error) { + registration, err := p.GetPushNotificationRegistrationByPublicKeyAndInstallationID(publicKey, installationID) + if err != nil { + return 0, err + } + + if registration == nil { + return 0, nil + } + return registration.Version, nil +} + type PushNotificationIDAndRegistration struct { ID []byte Registration *protobuf.PushNotificationRegistration @@ -78,7 +86,7 @@ func (p *SQLitePersistence) GetPushNotificationRegistrationByPublicKeys(publicKe inVector := strings.Repeat("?, ", len(publicKeys)-1) + "?" - rows, err := p.db.Query(`SELECT public_key,registration FROM push_notification_server_registrations WHERE public_key IN (`+inVector+`)`, publicKeyArgs...) // nolint: gosec + rows, err := p.db.Query(`SELECT public_key,registration FROM push_notification_server_registrations WHERE registration IS NOT NULL AND public_key IN (`+inVector+`)`, publicKeyArgs...) // nolint: gosec if err != nil { return nil, err } @@ -138,6 +146,11 @@ func (p *SQLitePersistence) SavePushNotificationRegistration(publicKey []byte, r return err } +func (p *SQLitePersistence) UnregisterPushNotificationRegistration(publicKey []byte, installationID string, version uint64) error { + _, err := p.db.Exec(`UPDATE push_notification_server_registrations SET registration = NULL, version = ? WHERE public_key = ? AND installation_id = ?`, version, publicKey, installationID) + return err +} + func (p *SQLitePersistence) DeletePushNotificationRegistration(publicKey []byte, installationID string) error { _, err := p.db.Exec(`DELETE FROM push_notification_server_registrations WHERE public_key = ? AND installation_id = ?`, publicKey, installationID) return err diff --git a/protocol/pushnotificationserver/server.go b/protocol/pushnotificationserver/server.go index cad6a4331..8549c9c3e 100644 --- a/protocol/pushnotificationserver/server.go +++ b/protocol/pushnotificationserver/server.go @@ -233,12 +233,12 @@ func (s *Server) validateRegistration(publicKey *ecdsa.PublicKey, payload []byte return nil, ErrMalformedPushNotificationRegistrationInstallationID } - previousRegistration, err := s.persistence.GetPushNotificationRegistrationByPublicKeyAndInstallationID(common.HashPublicKey(publicKey), registration.InstallationId) + previousVersion, err := s.persistence.GetPushNotificationRegistrationVersion(common.HashPublicKey(publicKey), registration.InstallationId) if err != nil { return nil, err } - if previousRegistration != nil && registration.Version <= previousRegistration.Version { + if registration.Version <= previousVersion { return nil, ErrInvalidPushNotificationRegistrationVersion } @@ -290,6 +290,7 @@ func (s *Server) buildPushNotificationQueryResponse(query *protobuf.PushNotifica for _, idAndResponse := range registrations { registration := idAndResponse.Registration + info := &protobuf.PushNotificationQueryInfo{ PublicKey: idAndResponse.ID, Grant: registration.Grant, @@ -336,7 +337,7 @@ func (s *Server) buildPushNotificationRequestResponseAndSendNotification(request if err != nil { s.config.Logger.Error("failed to retrieve registration", zap.Error(err)) report.Error = protobuf.PushNotificationReport_UNKNOWN_ERROR_TYPE - } else if registration == nil || registration.AccessToken == "" { + } else if registration == nil { s.config.Logger.Warn("empty registration") report.Error = protobuf.PushNotificationReport_NOT_REGISTERED } else if registration.AccessToken != pn.AccessToken { @@ -404,11 +405,7 @@ func (s *Server) buildPushNotificationRegistrationResponse(publicKey *ecdsa.Publ if registration.Unregister { s.config.Logger.Info("unregistering client") // We save an empty registration, only keeping version and installation-id - emptyRegistration := &protobuf.PushNotificationRegistration{ - Version: registration.Version, - InstallationId: registration.InstallationId, - } - if err := s.persistence.SavePushNotificationRegistration(common.HashPublicKey(publicKey), emptyRegistration); err != nil { + if err := s.persistence.UnregisterPushNotificationRegistration(common.HashPublicKey(publicKey), registration.InstallationId, registration.Version); err != nil { response.Error = protobuf.PushNotificationRegistrationResponse_INTERNAL_ERROR s.config.Logger.Error("failed to unregister ", zap.Error(err)) return response