From 6f8c3025f3ed8e053f2665b2bc313970a34b495e Mon Sep 17 00:00:00 2001 From: Andrea Maria Piana Date: Wed, 15 Jul 2020 09:23:31 +0200 Subject: [PATCH] Retrieve and add push notification registration --- protocol/push_notification_client/client.go | 169 ++++++++++++------ .../push_notification_client/client_test.go | 21 +++ .../migrations/migrations.go | 8 +- .../sql/1593601729_initial_schema.up.sql | 6 + .../push_notification_client/persistence.go | 32 ++++ .../persistence_test.go | 28 +++ 6 files changed, 203 insertions(+), 61 deletions(-) diff --git a/protocol/push_notification_client/client.go b/protocol/push_notification_client/client.go index fdc75dc61..b2a895fa9 100644 --- a/protocol/push_notification_client/client.go +++ b/protocol/push_notification_client/client.go @@ -84,8 +84,8 @@ type Client struct { quit chan struct{} config *Config - // lastPushNotificationVersion is the latest known push notification version - lastPushNotificationVersion uint64 + // lastPushNotificationRegistration is the latest known push notification version + lastPushNotificationRegistration *protobuf.PushNotificationRegistration // AccessToken is the access token that is currently being used AccessToken string @@ -109,11 +109,7 @@ func New(persistence *Persistence, config *Config, processor *common.MessageProc reader: rand.Reader} } -func (c *Client) Start() error { - if c.messageProcessor == nil { - return errors.New("can't start, missing message processor") - } - +func (c *Client) subscribeForSentMessages() { go func() { subscription := c.messageProcessor.Subscribe() for { @@ -132,6 +128,33 @@ func (c *Client) Start() error { } } }() + +} + +func (c *Client) loadLastPushNotificationRegistration() error { + lastRegistration, err := c.persistence.GetLastPushNotificationRegistration() + if err != nil { + return err + } + if lastRegistration == nil { + lastRegistration = &protobuf.PushNotificationRegistration{} + } + c.lastPushNotificationRegistration = lastRegistration + return nil + +} + +func (c *Client) Start() error { + if c.messageProcessor == nil { + return errors.New("can't start, missing message processor") + } + + err := c.loadLastPushNotificationRegistration() + if err != nil { + return err + } + c.subscribeForSentMessages() + return nil } @@ -371,26 +394,95 @@ func (p *Client) allowedUserList(token []byte, contactIDs []*ecdsa.PublicKey) ([ return encryptedTokens, nil } -func (p *Client) buildPushNotificationRegistrationMessage(contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) (*protobuf.PushNotificationRegistration, error) { - token := uuid.New().String() - allowedUserList, err := p.allowedUserList([]byte(token), contactIDs) +func (p *Client) getToken() string { + return uuid.New().String() + +} +func (c *Client) getVersion() uint64 { + if c.lastPushNotificationRegistration == nil { + return 1 + } + return c.lastPushNotificationRegistration.Version + 1 +} + +func (c *Client) buildPushNotificationRegistrationMessage(contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) (*protobuf.PushNotificationRegistration, error) { + token := c.getToken() + allowedUserList, err := c.allowedUserList([]byte(token), contactIDs) if err != nil { return nil, err } options := &protobuf.PushNotificationRegistration{ AccessToken: token, - TokenType: p.config.TokenType, - Version: p.lastPushNotificationVersion + 1, - InstallationId: p.config.InstallationID, - Token: p.DeviceToken, - Enabled: p.config.RemoteNotificationsEnabled, - BlockedChatList: p.mutedChatIDsHashes(mutedChatIDs), + TokenType: c.config.TokenType, + Version: c.getVersion(), + InstallationId: c.config.InstallationID, + Token: c.DeviceToken, + Enabled: c.config.RemoteNotificationsEnabled, + BlockedChatList: c.mutedChatIDsHashes(mutedChatIDs), AllowedUserList: allowedUserList, } return options, nil } +// shouldRefreshToken tells us whether we should pull a new token, that's only necessary when a contact is removed +func (c *Client) shouldRefreshToken(oldContactIDs, newContactIDs []*ecdsa.PublicKey) bool { + newContactIDsMap := make(map[string]bool) + for _, pk := range newContactIDs { + newContactIDsMap[types.EncodeHex(crypto.FromECDSAPub(pk))] = true + } + + for _, pk := range oldContactIDs { + if ok := newContactIDsMap[types.EncodeHex(crypto.FromECDSAPub(pk))]; !ok { + return true + } + + } + return false +} + +func (c *Client) registerWithServer(registration *protobuf.PushNotificationRegistration, server *PushNotificationServer) error { + // Reset server registration data + server.Registered = false + server.RegisteredAt = 0 + server.AccessToken = registration.AccessToken + + if err := c.persistence.UpsertServer(server); err != nil { + return err + } + + grant, err := c.buildGrantSignature(server.PublicKey, registration.AccessToken) + if err != nil { + c.config.Logger.Error("failed to build grant", zap.Error(err)) + return err + } + + registration.Grant = grant + + marshaledRegistration, err := proto.Marshal(registration) + if err != nil { + return err + } + + // Dispatch message + encryptedRegistration, err := c.encryptRegistration(server.PublicKey, marshaledRegistration) + if err != nil { + return err + } + rawMessage := &common.RawMessage{ + Payload: encryptedRegistration, + MessageType: protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REGISTRATION, + } + + _, err = c.messageProcessor.SendPrivate(context.Background(), server.PublicKey, rawMessage) + + if err != nil { + return err + } + return nil + +} + func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) ([]*PushNotificationServer, error) { c.DeviceToken = deviceToken servers, err := c.persistence.GetServers() @@ -409,47 +501,11 @@ func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mut var serverPublicKeys []*ecdsa.PublicKey for _, server := range servers { - - // Reset server registration data - server.Registered = false - server.RegisteredAt = 0 - server.AccessToken = registration.AccessToken + err := c.registerWithServer(registration, server) + if err != nil { + return nil, err + } serverPublicKeys = append(serverPublicKeys, server.PublicKey) - - if err := c.persistence.UpsertServer(server); err != nil { - return nil, err - } - - grant, err := c.buildGrantSignature(server.PublicKey, registration.AccessToken) - if err != nil { - c.config.Logger.Error("failed to build grant", zap.Error(err)) - return nil, err - } - - c.config.Logger.Info("GRANT2", zap.Binary("GRANT", grant)) - registration.Grant = grant - - marshaledRegistration, err := proto.Marshal(registration) - if err != nil { - return nil, err - } - - // Dispatch message - encryptedRegistration, err := c.encryptRegistration(server.PublicKey, marshaledRegistration) - if err != nil { - return nil, err - } - rawMessage := &common.RawMessage{ - Payload: encryptedRegistration, - MessageType: protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REGISTRATION, - } - - _, err = c.messageProcessor.SendPrivate(context.Background(), server.PublicKey, rawMessage) - - if err != nil { - return nil, err - } - } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -538,7 +594,6 @@ func (c *Client) buildGrantSignature(serverPublicKey *ecdsa.PublicKey, accessTok func (c *Client) handleGrant(clientPublicKey *ecdsa.PublicKey, serverPublicKey *ecdsa.PublicKey, grant []byte, accessToken string) error { signatureMaterial := c.buildGrantSignatureMaterial(clientPublicKey, serverPublicKey, accessToken) - c.config.Logger.Info("GRANT", zap.Binary("GRANT", grant)) extractedPublicKey, err := crypto.SigToPub(signatureMaterial, grant) if err != nil { return err diff --git a/protocol/push_notification_client/client_test.go b/protocol/push_notification_client/client_test.go index 852ea6ec1..0306fe044 100644 --- a/protocol/push_notification_client/client_test.go +++ b/protocol/push_notification_client/client_test.go @@ -156,3 +156,24 @@ func (s *ClientSuite) TestNotifyOnMessageID() { s.Require().NoError(err) s.Require().False(response) } + +func (s *ClientSuite) TestShouldRefreshToken() { + key1, err := crypto.GenerateKey() + s.Require().NoError(err) + key2, err := crypto.GenerateKey() + s.Require().NoError(err) + key3, err := crypto.GenerateKey() + s.Require().NoError(err) + key4, err := crypto.GenerateKey() + s.Require().NoError(err) + + // Contacts are added + s.Require().False(s.client.shouldRefreshToken([]*ecdsa.PublicKey{&key1.PublicKey, &key2.PublicKey}, []*ecdsa.PublicKey{&key1.PublicKey, &key2.PublicKey, &key3.PublicKey, &key4.PublicKey})) + + // everything the same + s.Require().False(s.client.shouldRefreshToken([]*ecdsa.PublicKey{&key1.PublicKey, &key2.PublicKey}, []*ecdsa.PublicKey{&key2.PublicKey, &key1.PublicKey})) + + // A contact is removed + s.Require().True(s.client.shouldRefreshToken([]*ecdsa.PublicKey{&key1.PublicKey, &key2.PublicKey}, []*ecdsa.PublicKey{&key2.PublicKey})) + +} diff --git a/protocol/push_notification_client/migrations/migrations.go b/protocol/push_notification_client/migrations/migrations.go index 7752d8ed8..e4c2c0cd3 100644 --- a/protocol/push_notification_client/migrations/migrations.go +++ b/protocol/push_notification_client/migrations/migrations.go @@ -1,7 +1,7 @@ // Code generated by go-bindata. DO NOT EDIT. // sources: // 1593601729_initial_schema.down.sql (144B) -// 1593601729_initial_schema.up.sql (1.284kB) +// 1593601729_initial_schema.up.sql (1.474kB) // doc.go (382B) package migrations @@ -91,7 +91,7 @@ func _1593601729_initial_schemaDownSql() (*asset, error) { return a, nil } -var __1593601729_initial_schemaUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xac\x92\x4f\x8f\xaa\x30\x14\xc5\xf7\x7c\x8a\xbb\x94\x84\xc5\xdb\xbb\x02\x2c\x2f\x24\x4d\xfb\x9e\x96\xc4\x5d\xd3\x29\x57\x6d\x64\xc0\x69\xab\x19\xbf\xfd\xc4\xff\x22\x0e\x4e\x1c\x37\x24\xb4\xa7\xf7\xde\xdf\x3d\x27\x1d\x93\x58\x10\x10\x71\x42\x09\xe4\x19\x30\x2e\x80\x4c\xf3\x89\x98\xc0\x6a\xed\x16\xb2\x6e\xbc\x99\x19\xad\xbc\x69\x6a\xa9\x2b\x83\xb5\x97\x0e\xed\x06\xad\x83\x41\x00\xb0\x5a\xbf\x55\x46\xcb\x25\x6e\x21\xa1\x3c\xd9\xbf\x67\x05\xa5\x51\x00\x60\x71\x6e\x9c\x47\x8b\x25\x24\x9c\x53\x12\x33\x18\x91\x2c\x2e\xa8\x80\x2c\xa6\x13\xd2\xd6\x48\xe5\x21\x67\xe2\x5c\xe1\xac\xfd\xb3\xd3\x29\xad\xd1\x39\xe9\x9b\x25\xd6\x20\xc8\x54\xec\x0e\x0b\x96\xff\x2f\xc8\xe0\x32\x43\x08\x9c\x41\xca\x59\x46\xf3\x54\xc0\x98\xfc\xa3\x71\x4a\x82\x70\x18\x04\xcf\x70\x7e\xac\xd1\x1a\x7c\xcc\x79\xd0\x75\x00\x4e\x57\x5b\x69\xca\xee\xa3\xce\xec\xd1\x49\xfb\x5a\x08\x53\xcf\x9a\x87\x04\x07\x47\x65\x9f\xc4\xd4\xce\xab\xaa\x3a\xd4\x36\xe5\xde\x83\x96\xa0\xe3\xd0\x4d\x16\xbc\x35\xb8\xb9\xbf\xa5\xee\x2e\x6e\xdb\x45\xdd\x11\x5f\xbb\x26\x6f\x95\x5e\x62\x29\xdf\xd1\x39\x35\x3f\x9a\x7e\xfc\xb9\xeb\x9f\x5e\x28\x7f\x77\x0f\xa7\x4a\xdf\x73\x5e\xca\xb6\x19\xf2\xbf\x8c\x8f\x49\x00\xf0\x2c\x84\xdb\x7d\xae\x2f\x1e\x63\xfc\xca\xf2\x7d\xbf\x9f\x70\x46\xd0\xe3\x6d\xd8\x26\xce\xd9\x88\x4c\xc1\x94\x9f\xb2\x37\xd3\xd7\x61\xe5\xac\x3f\xff\x7d\xc9\x0a\x87\xc1\x57\x00\x00\x00\xff\xff\xca\x86\xd6\x11\x04\x05\x00\x00") +var __1593601729_initial_schemaUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\xac\x94\xc1\x8e\x82\x30\x10\x86\xef\x3c\xc5\x1c\x25\xe1\xb0\x77\x4f\x80\x65\x43\xd2\xb4\xbb\x5a\x12\x6f\x0d\x5b\xaa\x36\xb2\xe0\xb6\xd5\xac\x6f\xbf\x41\x44\x51\x14\x36\xea\x85\x84\xf6\xef\x74\xbe\xf9\x3b\x13\x4e\x91\xcf\x10\x30\x3f\xc0\x08\xe2\x08\x08\x65\x80\xe6\xf1\x8c\xcd\x60\xb3\x35\x2b\x5e\x94\x56\x2d\x94\x48\xad\x2a\x0b\x2e\x72\x25\x0b\xcb\x8d\xd4\x3b\xa9\x0d\x8c\x1c\x80\xcd\xf6\x2b\x57\x82\xaf\xe5\x1e\x02\x4c\x83\xc3\x79\x92\x60\xec\x39\x00\x5a\x2e\x95\xb1\x52\xcb\x0c\x02\x4a\x31\xf2\x09\x4c\x50\xe4\x27\x98\x41\xe4\xe3\x19\xba\xd4\xf0\xd4\x42\x4c\xd8\x29\xc2\x49\xfb\x56\xe9\x52\x21\xa4\x31\xdc\x96\x6b\x59\x00\x43\x73\x56\x2d\x26\x24\xfe\x4c\xd0\xe8\x9c\x83\x0b\x94\x40\x48\x49\x84\xe3\x90\xc1\x14\x7d\x60\x3f\x44\x8e\x3b\x76\x9c\x47\x38\x7f\xb6\x52\x2b\x39\xcc\x59\xeb\x3a\x00\xcd\xd6\x9e\xab\xac\x7b\xa8\x93\xbb\xd7\x68\x5f\x0b\xa1\x8a\x45\x39\x48\x50\x3b\xca\xfb\x24\xaa\x30\x36\xcd\xf3\x3a\xb6\xca\x0e\x1e\x5c\x08\x3a\x0e\x5d\xbd\x05\xab\x95\xdc\xdd\xae\x52\xb7\x16\xd7\xd7\x79\xdd\x14\x5f\x5b\x26\xab\x53\xb1\x96\x19\xff\x96\xc6\xa4\xcb\xa3\xe9\xc7\x9f\x9b\xfe\x89\x55\x6a\x6f\xd6\xa1\x89\x74\x9f\xf3\x1c\xf6\x92\x21\x7e\x27\x74\x8a\x1c\x80\x47\x21\x4c\xf5\x69\x6f\x0c\x63\x3c\x65\xf9\xe1\xbe\xff\x70\x7a\xd0\xe3\xad\xfb\x04\x71\x3d\x3e\x74\x0b\xb6\x19\x29\xf5\x5a\x17\x0a\xc0\xec\x0b\xbb\x92\x56\x89\x8a\xe9\xfe\xc4\x39\x61\xb4\xf5\x83\x8f\x2e\x26\x13\x34\x07\x95\xfd\xf2\xde\x8e\x6c\xb7\x1a\x25\xfd\xdd\xdb\xd7\x17\xee\xd8\xf9\x0b\x00\x00\xff\xff\xd9\x16\xce\x6d\xc2\x05\x00\x00") func _1593601729_initial_schemaUpSqlBytes() ([]byte, error) { return bindataRead( @@ -106,8 +106,8 @@ func _1593601729_initial_schemaUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1593601729_initial_schema.up.sql", size: 1284, mode: os.FileMode(0644), modTime: time.Unix(1594393629, 0)} - a := &asset{bytes: bytes, info: info, digest: [32]uint8{0x77, 0x6e, 0xee, 0xca, 0x95, 0x4b, 0xf0, 0x61, 0x6, 0xc, 0xad, 0x52, 0xe1, 0x81, 0x1a, 0x2d, 0xff, 0x4f, 0xb4, 0x2f, 0x6b, 0x56, 0xdf, 0x5c, 0xa7, 0xe8, 0x56, 0xaa, 0xe5, 0x6e, 0x7f, 0xb8}} + info := bindataFileInfo{name: "1593601729_initial_schema.up.sql", size: 1474, mode: os.FileMode(0644), modTime: time.Unix(1594797308, 0)} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0x3b, 0xc0, 0x6a, 0xde, 0x9a, 0xc, 0x25, 0xf3, 0x24, 0xbd, 0x44, 0xd5, 0x38, 0xd4, 0x65, 0xe5, 0x51, 0xf1, 0x2b, 0x3e, 0x84, 0x99, 0x65, 0x39, 0xb6, 0x2a, 0x79, 0x4b, 0xc5, 0x9d, 0x25, 0x75}} return a, nil } diff --git a/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql b/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql index afb671961..92050ee9a 100644 --- a/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql +++ b/protocol/push_notification_client/migrations/sql/1593601729_initial_schema.up.sql @@ -37,4 +37,10 @@ CREATE TABLE IF NOT EXISTS push_notification_client_sent_notifications ( UNIQUE(message_id, public_key, installation_id) ); +CREATE TABLE IF NOT EXISTS push_notification_client_registrations ( + registration BLOB NOT NULL, + synthetic_id INT NOT NULL DEFAULT 0, + UNIQUE(synthetic_id) ON CONFLICT REPLACE +); + CREATE INDEX idx_push_notification_client_info_public_key ON push_notification_client_info(public_key, installation_id); diff --git a/protocol/push_notification_client/persistence.go b/protocol/push_notification_client/persistence.go index 60e265184..f3e6c3dbd 100644 --- a/protocol/push_notification_client/persistence.go +++ b/protocol/push_notification_client/persistence.go @@ -7,7 +7,10 @@ import ( "strings" "time" + "github.com/golang/protobuf/proto" + "github.com/status-im/status-go/eth-node/crypto" + "github.com/status-im/status-go/protocol/protobuf" ) type Persistence struct { @@ -18,6 +21,35 @@ func NewPersistence(db *sql.DB) *Persistence { return &Persistence{db: db} } +func (p *Persistence) GetLastPushNotificationRegistration() (*protobuf.PushNotificationRegistration, error) { + var registrationBytes []byte + err := p.db.QueryRow(`SELECT registration FROM push_notification_client_registrations LIMIT 1`).Scan(®istrationBytes) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + + registration := &protobuf.PushNotificationRegistration{} + + err = proto.Unmarshal(registrationBytes, registration) + if err != nil { + return nil, err + } + + return registration, nil +} + +func (p *Persistence) SaveLastPushNotificationRegistration(registration *protobuf.PushNotificationRegistration) error { + marshaledRegistration, err := proto.Marshal(registration) + if err != nil { + return err + } + _, err = p.db.Exec(`INSERT INTO push_notification_client_registrations (registration) VALUES (?)`, marshaledRegistration) + return err + +} + func (p *Persistence) TrackPushNotification(chatID string, messageID []byte) error { trackedAt := time.Now().Unix() _, err := p.db.Exec(`INSERT INTO push_notification_client_tracked_messages (chat_id, message_id, tracked_at) VALUES (?,?,?)`, chatID, messageID, trackedAt) diff --git a/protocol/push_notification_client/persistence_test.go b/protocol/push_notification_client/persistence_test.go index f81699240..709ce1d6b 100644 --- a/protocol/push_notification_client/persistence_test.go +++ b/protocol/push_notification_client/persistence_test.go @@ -5,10 +5,12 @@ import ( "os" "testing" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/suite" "github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/protocol/common" + "github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/sqlite" ) @@ -138,3 +140,29 @@ func (s *SQLitePersistenceSuite) TestSaveAndRetrieveInfo() { s.Require().Len(retrievedInfos, 2) } + +func (s *SQLitePersistenceSuite) TestSaveAndRetrieveRegistration() { + // Try with nil first + retrievedRegistration, err := s.persistence.GetLastPushNotificationRegistration() + s.Require().NoError(err) + s.Require().Nil(retrievedRegistration) + + // Save & retrieve registration + registration := &protobuf.PushNotificationRegistration{ + AccessToken: "test", + Version: 3, + } + + s.Require().NoError(s.persistence.SaveLastPushNotificationRegistration(registration)) + retrievedRegistration, err = s.persistence.GetLastPushNotificationRegistration() + s.Require().NoError(err) + s.Require().True(proto.Equal(registration, retrievedRegistration)) + + // Override and retrieve + + registration.Version = 5 + s.Require().NoError(s.persistence.SaveLastPushNotificationRegistration(registration)) + retrievedRegistration, err = s.persistence.GetLastPushNotificationRegistration() + s.Require().NoError(err) + s.Require().True(proto.Equal(registration, retrievedRegistration)) +}