diff --git a/protocol/push_notification_client/client.go b/protocol/push_notification_client/client.go index b2a895fa9..265a0f155 100644 --- a/protocol/push_notification_client/client.go +++ b/protocol/push_notification_client/client.go @@ -87,6 +87,9 @@ type Client struct { // lastPushNotificationRegistration is the latest known push notification version lastPushNotificationRegistration *protobuf.PushNotificationRegistration + // lastContactIDs is the latest contact ids array + lastContactIDs []*ecdsa.PublicKey + // AccessToken is the access token that is currently being used AccessToken string // DeviceToken is the device token for this device @@ -132,13 +135,14 @@ func (c *Client) subscribeForSentMessages() { } func (c *Client) loadLastPushNotificationRegistration() error { - lastRegistration, err := c.persistence.GetLastPushNotificationRegistration() + lastRegistration, lastContactIDs, err := c.persistence.GetLastPushNotificationRegistration() if err != nil { return err } if lastRegistration == nil { lastRegistration = &protobuf.PushNotificationRegistration{} } + c.lastContactIDs = lastContactIDs c.lastPushNotificationRegistration = lastRegistration return nil @@ -394,8 +398,13 @@ func (p *Client) allowedUserList(token []byte, contactIDs []*ecdsa.PublicKey) ([ return encryptedTokens, nil } -func (p *Client) getToken() string { - return uuid.New().String() +// getToken checks if we need to refresh the token +// and return a new one in that case +func (c *Client) getToken(contactIDs []*ecdsa.PublicKey) string { + if c.lastPushNotificationRegistration == nil || len(c.lastPushNotificationRegistration.AccessToken) == 0 || c.shouldRefreshToken(c.lastContactIDs, contactIDs) { + return uuid.New().String() + } + return c.lastPushNotificationRegistration.AccessToken } func (c *Client) getVersion() uint64 { @@ -406,7 +415,7 @@ func (c *Client) getVersion() uint64 { } func (c *Client) buildPushNotificationRegistrationMessage(contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) (*protobuf.PushNotificationRegistration, error) { - token := c.getToken() + token := c.getToken(contactIDs) allowedUserList, err := c.allowedUserList([]byte(token), contactIDs) if err != nil { return nil, err diff --git a/protocol/push_notification_client/migrations/migrations.go b/protocol/push_notification_client/migrations/migrations.go index e4c2c0cd3..6f5d47a54 100644 --- a/protocol/push_notification_client/migrations/migrations.go +++ b/protocol/push_notification_client/migrations/migrations.go @@ -1,7 +1,7 @@ // Code generated by go-bindata. DO NOT EDIT. // sources: // 1593601729_initial_schema.down.sql (144B) -// 1593601729_initial_schema.up.sql (1.474kB) +// 1593601729_initial_schema.up.sql (1.496kB) // doc.go (382B) package migrations @@ -91,7 +91,7 @@ func _1593601729_initial_schemaDownSql() (*asset, error) { return a, nil } -var __1593601729_initial_schemaUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xac\x94\xc1\x8e\x82\x30\x10\x86\xef\x3c\xc5\x1c\x25\xe1\xb0\x77\x4f\x80\x65\x43\xd2\xb4\xbb\x5a\x12\x6f\x0d\x5b\xaa\x36\xb2\xe0\xb6\xd5\xac\x6f\xbf\x41\x44\x51\x14\x36\xea\x85\x84\xf6\xef\x74\xbe\xf9\x3b\x13\x4e\x91\xcf\x10\x30\x3f\xc0\x08\xe2\x08\x08\x65\x80\xe6\xf1\x8c\xcd\x60\xb3\x35\x2b\x5e\x94\x56\x2d\x94\x48\xad\x2a\x0b\x2e\x72\x25\x0b\xcb\x8d\xd4\x3b\xa9\x0d\x8c\x1c\x80\xcd\xf6\x2b\x57\x82\xaf\xe5\x1e\x02\x4c\x83\xc3\x79\x92\x60\xec\x39\x00\x5a\x2e\x95\xb1\x52\xcb\x0c\x02\x4a\x31\xf2\x09\x4c\x50\xe4\x27\x98\x41\xe4\xe3\x19\xba\xd4\xf0\xd4\x42\x4c\xd8\x29\xc2\x49\xfb\x56\xe9\x52\x21\xa4\x31\xdc\x96\x6b\x59\x00\x43\x73\x56\x2d\x26\x24\xfe\x4c\xd0\xe8\x9c\x83\x0b\x94\x40\x48\x49\x84\xe3\x90\xc1\x14\x7d\x60\x3f\x44\x8e\x3b\x76\x9c\x47\x38\x7f\xb6\x52\x2b\x39\xcc\x59\xeb\x3a\x00\xcd\xd6\x9e\xab\xac\x7b\xa8\x93\xbb\xd7\x68\x5f\x0b\xa1\x8a\x45\x39\x48\x50\x3b\xca\xfb\x24\xaa\x30\x36\xcd\xf3\x3a\xb6\xca\x0e\x1e\x5c\x08\x3a\x0e\x5d\xbd\x05\xab\x95\xdc\xdd\xae\x52\xb7\x16\xd7\xd7\x79\xdd\x14\x5f\x5b\x26\xab\x53\xb1\x96\x19\xff\x96\xc6\xa4\xcb\xa3\xe9\xc7\x9f\x9b\xfe\x89\x55\x6a\x6f\xd6\xa1\x89\x74\x9f\xf3\x1c\xf6\x92\x21\x7e\x27\x74\x8a\x1c\x80\x47\x21\x4c\xf5\x69\x6f\x0c\x63\x3c\x65\xf9\xe1\xbe\xff\x70\x7a\xd0\xe3\xad\xfb\x04\x71\x3d\x3e\x74\x0b\xb6\x19\x29\xf5\x5a\x17\x0a\xc0\xec\x0b\xbb\x92\x56\x89\x8a\xe9\xfe\xc4\x39\x61\xb4\xf5\x83\x8f\x2e\x26\x13\x34\x07\x95\xfd\xf2\xde\x8e\x6c\xb7\x1a\x25\xfd\xdd\xdb\xd7\x17\xee\xd8\xf9\x0b\x00\x00\xff\xff\xd9\x16\xce\x6d\xc2\x05\x00\x00") +var __1593601729_initial_schemaUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xac\x94\xc1\x8e\x9b\x30\x10\x86\xef\x3c\xc5\x1c\x17\x89\x43\xef\x7b\x02\xd6\x54\x48\x96\xdd\x26\x46\xca\xcd\x72\x8d\x77\x63\x85\x9a\xd4\x76\xa2\xe6\xed\x2b\x20\x24\x24\x24\x50\x25\xb9\x20\x61\xff\x1e\xcf\x37\xbf\x67\xd2\x05\x8a\x19\x02\x16\x27\x18\x41\x9e\x01\xa1\x0c\xd0\x2a\x5f\xb2\x25\x6c\x77\x6e\xcd\x4d\xed\xf5\xa7\x96\xc2\xeb\xda\x70\x59\x69\x65\x3c\x77\xca\xee\x95\x75\xf0\x16\x00\x6c\x77\xbf\x2a\x2d\xf9\x46\x1d\x20\xc1\x34\x69\xcf\x93\x02\xe3\x28\x00\xb0\xea\x4b\x3b\xaf\xac\x2a\x21\xa1\x14\xa3\x98\xc0\x07\xca\xe2\x02\x33\xc8\x62\xbc\x44\x97\x1a\x2e\x3c\xe4\x84\x9d\x22\x9c\xb4\xdf\x1a\x9d\x90\x52\x39\xc7\x7d\xbd\x51\x06\x18\x5a\xb1\x66\xb1\x20\xf9\xcf\x02\xbd\x9d\x73\x08\x81\x12\x48\x29\xc9\x70\x9e\x32\x58\xa0\x1f\x38\x4e\x51\x10\xbe\x07\xc1\x23\x9c\x7f\x76\xca\x6a\x35\xcf\xd9\xe9\x46\x00\xfd\xd6\x81\xeb\x72\x7c\x68\x94\x7b\xd4\x6b\x5f\x0b\xa1\xcd\x67\x3d\x4b\xd0\x39\xca\xa7\x24\xda\x38\x2f\xaa\xaa\x8b\xad\xcb\xd6\x83\x0b\xc1\xc8\xa1\xab\xb7\xe0\xad\x56\xfb\xdb\x55\x1a\xd7\xe2\xfa\xba\x68\x9c\xe2\x6b\xcb\xe4\xad\x90\x1b\x55\xf2\xdf\xca\x39\xf1\x75\x34\xfd\xf8\x73\xd3\x3f\xb9\x16\xfe\x66\x1d\xfa\x48\xf7\x39\xcf\x61\x2f\x19\xf2\xef\x84\x2e\x50\x00\xf0\x28\x84\x6b\x3e\xc3\x8d\x79\x8c\xa7\x2c\x6f\xef\xfb\x1f\xce\x08\x26\xbc\x0d\x9f\x20\xee\xc6\x87\x1d\xc0\xf6\x23\xa5\x5b\x1b\x43\x01\xc8\xda\x78\x21\x1b\xf3\x5c\xbb\xdd\xad\xba\x83\xf1\x6b\xe5\xb5\x6c\x48\xef\xcf\xa1\x13\xdc\x50\x3f\xfb\x14\x73\xf2\x81\x56\xa0\xcb\xbf\x7c\xb2\x4f\x87\x0d\x48\xc9\x74\x4f\x4f\x75\x4b\xf8\x1e\xfc\x0b\x00\x00\xff\xff\x1c\x18\x75\x11\xd8\x05\x00\x00") func _1593601729_initial_schemaUpSqlBytes() ([]byte, error) { return bindataRead( @@ -106,8 +106,8 @@ func _1593601729_initial_schemaUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1593601729_initial_schema.up.sql", size: 1474, mode: os.FileMode(0644), modTime: time.Unix(1594797308, 0)} - a := &asset{bytes: bytes, info: info, digest: [32]uint8{0x3b, 0xc0, 0x6a, 0xde, 0x9a, 0xc, 0x25, 0xf3, 0x24, 0xbd, 0x44, 0xd5, 0x38, 0xd4, 0x65, 0xe5, 0x51, 0xf1, 0x2b, 0x3e, 0x84, 0x99, 0x65, 0x39, 0xb6, 0x2a, 0x79, 0x4b, 0xc5, 0x9d, 0x25, 0x75}} + info := bindataFileInfo{name: "1593601729_initial_schema.up.sql", size: 1496, mode: os.FileMode(0644), modTime: time.Unix(1594800672, 0)} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xe7, 0xb6, 0xaf, 0x11, 0x19, 0xf8, 0xc6, 0xed, 0x2f, 0xf5, 0x42, 0x54, 0xdd, 0x65, 0xf7, 0x39, 0xbc, 0x19, 0xff, 0x72, 0xa1, 0x38, 0x87, 0xfa, 0x6d, 0xd5, 0xe3, 0x6a, 0x49, 0x65, 0x3c, 0x49}} return a, nil } diff --git a/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql b/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql index 92050ee9a..c4c6b6999 100644 --- a/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql +++ b/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql @@ -39,6 +39,7 @@ CREATE TABLE IF NOT EXISTS push_notification_client_sent_notifications ( CREATE TABLE IF NOT EXISTS push_notification_client_registrations ( registration BLOB NOT NULL, + contact_ids BLOB, synthetic_id INT NOT NULL DEFAULT 0, UNIQUE(synthetic_id) ON CONFLICT REPLACE ); diff --git a/protocol/push_notification_client/persistence.go b/protocol/push_notification_client/persistence.go index f3e6c3dbd..4afe0f815 100644 --- a/protocol/push_notification_client/persistence.go +++ b/protocol/push_notification_client/persistence.go @@ -1,9 +1,11 @@ package push_notification_client import ( + "bytes" "context" "crypto/ecdsa" "database/sql" + "encoding/gob" "strings" "time" @@ -21,33 +23,59 @@ func NewPersistence(db *sql.DB) *Persistence { return &Persistence{db: db} } -func (p *Persistence) GetLastPushNotificationRegistration() (*protobuf.PushNotificationRegistration, error) { +func (p *Persistence) GetLastPushNotificationRegistration() (*protobuf.PushNotificationRegistration, []*ecdsa.PublicKey, error) { var registrationBytes []byte - err := p.db.QueryRow(`SELECT registration FROM push_notification_client_registrations LIMIT 1`).Scan(®istrationBytes) + var contactIDsBytes []byte + err := p.db.QueryRow(`SELECT registration,contact_ids FROM push_notification_client_registrations LIMIT 1`).Scan(®istrationBytes, &contactIDsBytes) if err == sql.ErrNoRows { - return nil, nil + return nil, nil, nil } else if err != nil { - return nil, err + return nil, nil, err + } + + var publicKeyBytes [][]byte + var contactIDs []*ecdsa.PublicKey + // Restore contactIDs + contactIDsDecoder := gob.NewDecoder(bytes.NewBuffer(contactIDsBytes)) + err = contactIDsDecoder.Decode(&publicKeyBytes) + if err != nil { + return nil, nil, err + } + for _, pkBytes := range publicKeyBytes { + pk, err := crypto.UnmarshalPubkey(pkBytes) + if err != nil { + return nil, nil, err + } + contactIDs = append(contactIDs, pk) } registration := &protobuf.PushNotificationRegistration{} err = proto.Unmarshal(registrationBytes, registration) if err != nil { - return nil, err + return nil, nil, err } - return registration, nil + return registration, contactIDs, nil } -func (p *Persistence) SaveLastPushNotificationRegistration(registration *protobuf.PushNotificationRegistration) error { +func (p *Persistence) SaveLastPushNotificationRegistration(registration *protobuf.PushNotificationRegistration, contactIDs []*ecdsa.PublicKey) error { + var encodedContactIDs bytes.Buffer + var contactIDsBytes [][]byte + for _, pk := range contactIDs { + contactIDsBytes = append(contactIDsBytes, crypto.FromECDSAPub(pk)) + } + pkEncoder := gob.NewEncoder(&encodedContactIDs) + if err := pkEncoder.Encode(contactIDsBytes); err != nil { + return err + } + marshaledRegistration, err := proto.Marshal(registration) if err != nil { return err } - _, err = p.db.Exec(`INSERT INTO push_notification_client_registrations (registration) VALUES (?)`, marshaledRegistration) + _, err = p.db.Exec(`INSERT INTO push_notification_client_registrations (registration,contact_ids) VALUES (?, ?)`, marshaledRegistration, encodedContactIDs.Bytes()) return err - } func (p *Persistence) TrackPushNotification(chatID string, messageID []byte) error { diff --git a/protocol/push_notification_client/persistence_test.go b/protocol/push_notification_client/persistence_test.go index 709ce1d6b..0cf2947c3 100644 --- a/protocol/push_notification_client/persistence_test.go +++ b/protocol/push_notification_client/persistence_test.go @@ -1,6 +1,7 @@ package push_notification_client import ( + "crypto/ecdsa" "io/ioutil" "os" "testing" @@ -143,9 +144,10 @@ func (s *SQLitePersistenceSuite) TestSaveAndRetrieveInfo() { func (s *SQLitePersistenceSuite) TestSaveAndRetrieveRegistration() { // Try with nil first - retrievedRegistration, err := s.persistence.GetLastPushNotificationRegistration() + retrievedRegistration, retrievedContactIDs, err := s.persistence.GetLastPushNotificationRegistration() s.Require().NoError(err) s.Require().Nil(retrievedRegistration) + s.Require().Nil(retrievedContactIDs) // Save & retrieve registration registration := &protobuf.PushNotificationRegistration{ @@ -153,16 +155,30 @@ func (s *SQLitePersistenceSuite) TestSaveAndRetrieveRegistration() { Version: 3, } - s.Require().NoError(s.persistence.SaveLastPushNotificationRegistration(registration)) - retrievedRegistration, err = s.persistence.GetLastPushNotificationRegistration() + key1, err := crypto.GenerateKey() + s.Require().NoError(err) + + key2, err := crypto.GenerateKey() + s.Require().NoError(err) + + key3, err := crypto.GenerateKey() + s.Require().NoError(err) + + publicKeys := []*ecdsa.PublicKey{&key1.PublicKey, &key2.PublicKey} + + s.Require().NoError(s.persistence.SaveLastPushNotificationRegistration(registration, publicKeys)) + retrievedRegistration, retrievedContactIDs, err = s.persistence.GetLastPushNotificationRegistration() s.Require().NoError(err) s.Require().True(proto.Equal(registration, retrievedRegistration)) + s.Require().Equal(publicKeys, retrievedContactIDs) // Override and retrieve registration.Version = 5 - s.Require().NoError(s.persistence.SaveLastPushNotificationRegistration(registration)) - retrievedRegistration, err = s.persistence.GetLastPushNotificationRegistration() + publicKeys = append(publicKeys, &key3.PublicKey) + s.Require().NoError(s.persistence.SaveLastPushNotificationRegistration(registration, publicKeys)) + retrievedRegistration, retrievedContactIDs, err = s.persistence.GetLastPushNotificationRegistration() s.Require().NoError(err) s.Require().True(proto.Equal(registration, retrievedRegistration)) + s.Require().Equal(publicKeys, retrievedContactIDs) }