Handle registration with server

This commit is contained in:
Andrea Maria Piana 2020-07-10 09:45:40 +02:00
parent 7e16f940de
commit 1c379984cb
No known key found for this signature in database
GPG Key ID: AA6CCA6DE0E06424
11 changed files with 149 additions and 56 deletions

View File

@ -1950,7 +1950,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
}
logger.Debug("Handling PushNotificationRegistrationResponse")
// TODO: Compare DST with Identity
if err := m.pushNotificationClient.HandlePushNotificationRegistrationResponse(msg.ParsedMessage.(protobuf.PushNotificationRegistrationResponse)); err != nil {
if err := m.pushNotificationClient.HandlePushNotificationRegistrationResponse(publicKey, msg.ParsedMessage.(protobuf.PushNotificationRegistrationResponse)); err != nil {
logger.Warn("failed to handle PushNotificationRegistrationResponse", zap.Error(err))
}
// We continue in any case, no changes to messenger
@ -3021,7 +3021,7 @@ func (m *Messenger) AddPushNotificationServer(ctx context.Context, publicKey *ec
}
// RegisterForPushNotification register deviceToken with any push notification server enabled
func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) ([]string, error) {
func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) ([]*push_notification_client.PushNotificationServer, error) {
if m.pushNotificationClient == nil {
return nil, errors.New("push notification client not enabled")
}

View File

@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS push_notification_client_servers (
public_key BLOB NOT NULL,
registered BOOLEAN DEFAULT FALSE,
registered_at INT NOT NULL DEFAULT 0,
access_token TEXT,
UNIQUE(public_key) ON CONFLICT REPLACE
);

View File

@ -3,6 +3,7 @@ package push_notification_client
import (
"crypto/ecdsa"
"database/sql"
"strings"
"github.com/status-im/status-go/eth-node/crypto"
)
@ -28,13 +29,13 @@ func (p *Persistence) SentFor(publicKey *ecdsa.PublicKey, messageID []byte) erro
}
func (p *Persistence) UpsertServer(server *PushNotificationServer) error {
_, err := p.db.Exec(`INSERT INTO push_notification_client_servers (public_key, registered, registered_at) VALUES (?,?,?)`, crypto.CompressPubkey(server.publicKey), server.registered, server.registeredAt)
_, err := p.db.Exec(`INSERT INTO push_notification_client_servers (public_key, registered, registered_at, access_token) VALUES (?,?,?,?)`, crypto.CompressPubkey(server.PublicKey), server.Registered, server.RegisteredAt, server.AccessToken)
return err
}
func (p *Persistence) GetServers() ([]*PushNotificationServer, error) {
rows, err := p.db.Query(`SELECT public_key, registered, registered_at FROM push_notification_client_servers`)
rows, err := p.db.Query(`SELECT public_key, registered, registered_at,access_token FROM push_notification_client_servers`)
if err != nil {
return nil, err
}
@ -42,7 +43,7 @@ func (p *Persistence) GetServers() ([]*PushNotificationServer, error) {
for rows.Next() {
server := &PushNotificationServer{}
var key []byte
err := rows.Scan(&key, &server.registered, &server.registeredAt)
err := rows.Scan(&key, &server.Registered, &server.RegisteredAt, &server.AccessToken)
if err != nil {
return nil, err
}
@ -50,7 +51,37 @@ func (p *Persistence) GetServers() ([]*PushNotificationServer, error) {
if err != nil {
return nil, err
}
server.publicKey = parsedKey
server.PublicKey = parsedKey
servers = append(servers, server)
}
return servers, nil
}
func (p *Persistence) GetServersByPublicKey(keys []*ecdsa.PublicKey) ([]*PushNotificationServer, error) {
keyArgs := make([]interface{}, 0, len(keys))
for _, key := range keys {
keyArgs = append(keyArgs, crypto.CompressPubkey(key))
}
inVector := strings.Repeat("?, ", len(keys)-1) + "?"
rows, err := p.db.Query(`SELECT public_key, registered, registered_at,access_token FROM push_notification_client_servers WHERE public_key IN (`+inVector+")", keyArgs...) //nolint: gosec
if err != nil {
return nil, err
}
var servers []*PushNotificationServer
for rows.Next() {
server := &PushNotificationServer{}
var key []byte
err := rows.Scan(&key, &server.Registered, &server.RegisteredAt, &server.AccessToken)
if err != nil {
return nil, err
}
parsedKey, err := crypto.DecompressPubkey(key)
if err != nil {
return nil, err
}
server.PublicKey = parsedKey
servers = append(servers, server)
}
return servers, nil

View File

@ -39,11 +39,13 @@ func (s *SQLitePersistenceSuite) TearDownTest() {
func (s *SQLitePersistenceSuite) TestSaveAndRetrieveServer() {
key, err := crypto.GenerateKey()
s.Require().NoError(err)
accessToken := "token"
server := &PushNotificationServer{
publicKey: &key.PublicKey,
registered: true,
registeredAt: 1,
PublicKey: &key.PublicKey,
Registered: true,
RegisteredAt: 1,
AccessToken: accessToken,
}
s.Require().NoError(s.persistence.UpsertServer(server))
@ -52,12 +54,13 @@ func (s *SQLitePersistenceSuite) TestSaveAndRetrieveServer() {
s.Require().NoError(err)
s.Require().Len(retrievedServers, 1)
s.Require().True(retrievedServers[0].registered)
s.Require().Equal(int64(1), retrievedServers[0].registeredAt)
s.Require().True(common.IsPubKeyEqual(retrievedServers[0].publicKey, &key.PublicKey))
s.Require().True(retrievedServers[0].Registered)
s.Require().Equal(int64(1), retrievedServers[0].RegisteredAt)
s.Require().True(common.IsPubKeyEqual(retrievedServers[0].PublicKey, &key.PublicKey))
s.Require().Equal(accessToken, retrievedServers[0].AccessToken)
server.registered = false
server.registeredAt = 2
server.Registered = false
server.RegisteredAt = 2
s.Require().NoError(s.persistence.UpsertServer(server))
@ -65,7 +68,7 @@ func (s *SQLitePersistenceSuite) TestSaveAndRetrieveServer() {
s.Require().NoError(err)
s.Require().Len(retrievedServers, 1)
s.Require().False(retrievedServers[0].registered)
s.Require().Equal(int64(2), retrievedServers[0].registeredAt)
s.Require().True(common.IsPubKeyEqual(retrievedServers[0].publicKey, &key.PublicKey))
s.Require().False(retrievedServers[0].Registered)
s.Require().Equal(int64(2), retrievedServers[0].RegisteredAt)
s.Require().True(common.IsPubKeyEqual(retrievedServers[0].PublicKey, &key.PublicKey))
}

View File

@ -26,9 +26,10 @@ const encryptedPayloadKeyLength = 16
const accessTokenKeyLength = 16
type PushNotificationServer struct {
publicKey *ecdsa.PublicKey
registered bool
registeredAt int64
PublicKey *ecdsa.PublicKey
Registered bool
RegisteredAt int64
AccessToken string
}
type PushNotificationInfo struct {
@ -78,21 +79,18 @@ type Client struct {
//messageProcessor is a message processor used to send and being notified of messages
messageProcessor *common.MessageProcessor
//pushNotificationRegistrationResponses is a channel that listens to pushNotificationResponse
pushNotificationRegistrationResponses chan *protobuf.PushNotificationRegistrationResponse
//pushNotificationQueryResponses is a channel that listens to pushNotificationResponse
pushNotificationQueryResponses chan *protobuf.PushNotificationQueryResponse
}
func New(persistence *Persistence, config *Config, processor *common.MessageProcessor) *Client {
return &Client{
quit: make(chan struct{}),
config: config,
pushNotificationRegistrationResponses: make(chan *protobuf.PushNotificationRegistrationResponse),
pushNotificationQueryResponses: make(chan *protobuf.PushNotificationQueryResponse),
messageProcessor: processor,
persistence: persistence,
reader: rand.Reader}
quit: make(chan struct{}),
config: config,
pushNotificationQueryResponses: make(chan *protobuf.PushNotificationQueryResponse),
messageProcessor: processor,
persistence: persistence,
reader: rand.Reader}
}
func (c *Client) Start() error {
@ -198,7 +196,7 @@ func (p *Client) buildPushNotificationRegistrationMessage(contactIDs []*ecdsa.Pu
return options, nil
}
func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) ([]string, error) {
func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) ([]*PushNotificationServer, error) {
c.DeviceToken = deviceToken
servers, err := c.persistence.GetServers()
if err != nil {
@ -219,9 +217,21 @@ func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mut
return nil, err
}
var serverPublicKeys []*ecdsa.PublicKey
for _, server := range servers {
encryptedRegistration, err := c.encryptRegistration(server.publicKey, marshaledRegistration)
// Reset server registration data
server.Registered = false
server.RegisteredAt = 0
server.AccessToken = registration.AccessToken
serverPublicKeys = append(serverPublicKeys, server.PublicKey)
if err := c.persistence.UpsertServer(server); err != nil {
return nil, err
}
// Dispatch message
encryptedRegistration, err := c.encryptRegistration(server.PublicKey, marshaledRegistration)
if err != nil {
return nil, err
}
@ -230,33 +240,71 @@ func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mut
MessageType: protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REGISTRATION,
}
_, err = c.messageProcessor.SendPrivate(context.Background(), server.publicKey, rawMessage)
_, err = c.messageProcessor.SendPrivate(context.Background(), server.PublicKey, rawMessage)
// Send message and wait for reply
if err != nil {
return nil, err
}
}
// TODO: this needs to wait for all the registrations, probably best to poll the database
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// This code polls the database for server registrations, giving up
// after 5 seconds
for {
select {
case <-c.quit:
return nil, nil
case <-time.After(5 * time.Second):
return nil, errors.New("no registration response received")
case <-c.pushNotificationRegistrationResponses:
return nil, nil
return servers, nil
case <-ctx.Done():
c.config.Logger.Debug("Context done")
return servers, nil
case <-time.After(200 * time.Millisecond):
servers, err = c.persistence.GetServersByPublicKey(serverPublicKeys)
if err != nil {
return nil, err
}
allRegistered := true
for _, server := range servers {
allRegistered = allRegistered && server.Registered
}
// If any of the servers we haven't registered yet, continue
if !allRegistered {
continue
}
// all have registered,cancel context and return
cancel()
return servers, nil
}
}
}
// HandlePushNotificationRegistrationResponse should check whether the response was successful or not, retry if necessary otherwise store the result in the database
func (c *Client) HandlePushNotificationRegistrationResponse(response protobuf.PushNotificationRegistrationResponse) error {
func (c *Client) HandlePushNotificationRegistrationResponse(publicKey *ecdsa.PublicKey, response protobuf.PushNotificationRegistrationResponse) error {
c.config.Logger.Debug("received push notification registration response", zap.Any("response", response))
select {
case c.pushNotificationRegistrationResponses <- &response:
default:
return errors.New("could not process push notification registration response")
// TODO: handle non successful response and match request id
// Not successful ignore for now
if !response.Success {
return errors.New("response was not successful")
}
return nil
servers, err := c.persistence.GetServersByPublicKey([]*ecdsa.PublicKey{publicKey})
if err != nil {
return err
}
// We haven't registered with this server
if len(servers) != 1 {
return errors.New("not registered with this server, ignoring")
}
server := servers[0]
server.Registered = true
server.RegisteredAt = time.Now().Unix()
return c.persistence.UpsertServer(server)
}
// HandlePushNotificationAdvertisement should store any info related to push notifications
@ -289,13 +337,13 @@ func (c *Client) AddPushNotificationServer(publicKey *ecdsa.PublicKey) error {
}
for _, server := range currentServers {
if common.IsPubKeyEqual(server.publicKey, publicKey) {
if common.IsPubKeyEqual(server.PublicKey, publicKey) {
return errors.New("push notification server already added")
}
}
return c.persistence.UpsertServer(&PushNotificationServer{
publicKey: publicKey,
PublicKey: publicKey,
})
}

View File

@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"github.com/status-im/status-go/eth-node/crypto"
"github.com/status-im/status-go/eth-node/crypto/ecies"
"github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/stretchr/testify/require"
)
@ -22,7 +23,7 @@ func TestBuildPushNotificationRegisterMessage(t *testing.T) {
// build chat lish hashes
var mutedChatListHashes [][]byte
for _, chatID := range mutedChatList {
mutedChatListHashes = append(mutedChatListHashes, shake256(chatID))
mutedChatListHashes = append(mutedChatListHashes, common.Shake256([]byte(chatID)))
}
identity, err := crypto.GenerateKey()

View File

@ -257,6 +257,9 @@ func (s *Server) HandlePushNotificationRegistration(publicKey *ecdsa.PublicKey,
}
func (s *Server) listenToPublicKeyQueryTopic(hashedPublicKey []byte) error {
if s.messageProcessor == nil {
return nil
}
encodedPublicKey := hex.EncodeToString(hashedPublicKey)
return s.messageProcessor.JoinPublic(encodedPublicKey)
}

View File

@ -14,6 +14,7 @@ import (
"github.com/status-im/status-go/protocol/common"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/sqlite"
"github.com/status-im/status-go/protocol/tt"
)
func TestServerSuite(t *testing.T) {
@ -55,6 +56,7 @@ func (s *ServerSuite) SetupTest() {
config := &Config{
Identity: identity,
Logger: tt.MustCreateTestLogger(),
}
s.server = New(config, s.persistence, nil)

View File

@ -118,7 +118,7 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
bob1DeviceToken := "token-1"
bob2DeviceToken := "token-2"
var bob1AccessTokens, bob2AccessTokens []string
var bob1Servers, bob2Servers []*push_notification_client.PushNotificationServer
bob1 := s.m
bob2 := s.newMessengerWithKey(s.shh, s.m.identity)
@ -130,7 +130,7 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
s.Require().NoError(err)
go func() {
bob1AccessTokens, err = bob1.RegisterForPushNotifications(context.Background(), bob1DeviceToken)
bob1Servers, err = bob1.RegisterForPushNotifications(context.Background(), bob1DeviceToken)
errChan <- err
}()
@ -165,14 +165,16 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
// Make sure we receive it
err = <-errChan
s.Require().NoError(err)
s.Require().NotNil(bob1AccessTokens)
s.Require().NotNil(bob1Servers)
s.Require().Len(bob1Servers, 1)
s.Require().True(bob1Servers[0].Registered)
// Register bob2
err = bob2.AddPushNotificationServer(context.Background(), &server.identity.PublicKey)
s.Require().NoError(err)
go func() {
bob2AccessTokens, err = bob2.RegisterForPushNotifications(context.Background(), bob2DeviceToken)
bob2Servers, err = bob2.RegisterForPushNotifications(context.Background(), bob2DeviceToken)
errChan <- err
}()
@ -207,7 +209,9 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
// Make sure we receive it
err = <-errChan
s.Require().NoError(err)
s.Require().NotNil(bob2AccessTokens)
s.Require().NotNil(bob2Servers)
s.Require().Len(bob2Servers, 1)
s.Require().True(bob2Servers[0].Registered)
var info []*push_notification_client.PushNotificationInfo
go func() {
@ -251,7 +255,7 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
var bob1Info, bob2Info *push_notification_client.PushNotificationInfo
if info[0].AccessToken == bob1AccessTokens[0] {
if info[0].AccessToken == bob1Servers[0].AccessToken {
bob1Info = info[0]
bob2Info = info[1]
} else {
@ -262,14 +266,14 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
s.Require().NotNil(bob1Info)
s.Require().Equal(bob1Info, &push_notification_client.PushNotificationInfo{
InstallationID: bob1.installationID,
AccessToken: bob1DeviceToken,
AccessToken: bob1Servers[0].AccessToken,
PublicKey: &bob1.identity.PublicKey,
})
s.Require().NotNil(bob2Info)
s.Require().Equal(bob2Info, &push_notification_client.PushNotificationInfo{
InstallationID: bob2.installationID,
AccessToken: bob2DeviceToken,
AccessToken: bob2Servers[0].AccessToken,
PublicKey: &bob1.identity.PublicKey,
})