diff --git a/protocol/pushnotificationserver/persistence.go b/protocol/pushnotificationserver/persistence.go index b8de8a04f..7ba19d99b 100644 --- a/protocol/pushnotificationserver/persistence.go +++ b/protocol/pushnotificationserver/persistence.go @@ -60,15 +60,16 @@ func (p *SQLitePersistence) GetPushNotificationRegistrationByPublicKeyAndInstall } func (p *SQLitePersistence) GetPushNotificationRegistrationVersion(publicKey []byte, installationID string) (uint64, error) { - registration, err := p.GetPushNotificationRegistrationByPublicKeyAndInstallationID(publicKey, installationID) - if err != nil { + var version uint64 + err := p.db.QueryRow(`SELECT version FROM push_notification_server_registrations WHERE public_key = ? AND installation_id = ?`, publicKey, installationID).Scan(&version) + + if err == sql.ErrNoRows { + return 0, nil + } else if err != nil { return 0, err } - if registration == nil { - return 0, nil - } - return registration.Version, nil + return version, nil } type PushNotificationIDAndRegistration struct { diff --git a/protocol/pushnotificationserver/server_test.go b/protocol/pushnotificationserver/server_test.go index 77613ce50..84e2245f4 100644 --- a/protocol/pushnotificationserver/server_test.go +++ b/protocol/pushnotificationserver/server_test.go @@ -500,16 +500,16 @@ func (s *ServerSuite) TestPushNotificationHandleRegistration() { response = s.server.buildPushNotificationRegistrationResponse(&s.key.PublicKey, cyphertext) s.Require().NotNil(response) s.Require().True(response.Success) + s.Require().Equal(common.Shake256(cyphertext), response.RequestId) // Check is gone from the db retrievedRegistration, err = s.persistence.GetPushNotificationRegistrationByPublicKeyAndInstallationID(common.HashPublicKey(&s.key.PublicKey), s.installationID) s.Require().NoError(err) - s.Require().NotNil(retrievedRegistration) - s.Require().Empty(retrievedRegistration.AccessToken) - s.Require().Empty(retrievedRegistration.DeviceToken) - s.Require().Equal(uint64(2), retrievedRegistration.Version) - s.Require().Equal(s.installationID, retrievedRegistration.InstallationId) - s.Require().Equal(common.Shake256(cyphertext), response.RequestId) + s.Require().Nil(retrievedRegistration) + // Check version is mantained + version, err := s.persistence.GetPushNotificationRegistrationVersion(common.HashPublicKey(&s.key.PublicKey), s.installationID) + s.Require().NoError(err) + s.Require().Equal(uint64(2), version) } func (s *ServerSuite) TestbuildPushNotificationQueryResponseNoFiltering() {