mirror of
https://github.com/status-im/status-go.git
synced 2025-01-24 21:49:54 +00:00
Handle registration with server
This commit is contained in:
parent
7e16f940de
commit
1c379984cb
@ -1950,7 +1950,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
|
||||
}
|
||||
logger.Debug("Handling PushNotificationRegistrationResponse")
|
||||
// TODO: Compare DST with Identity
|
||||
if err := m.pushNotificationClient.HandlePushNotificationRegistrationResponse(msg.ParsedMessage.(protobuf.PushNotificationRegistrationResponse)); err != nil {
|
||||
if err := m.pushNotificationClient.HandlePushNotificationRegistrationResponse(publicKey, msg.ParsedMessage.(protobuf.PushNotificationRegistrationResponse)); err != nil {
|
||||
logger.Warn("failed to handle PushNotificationRegistrationResponse", zap.Error(err))
|
||||
}
|
||||
// We continue in any case, no changes to messenger
|
||||
@ -3021,7 +3021,7 @@ func (m *Messenger) AddPushNotificationServer(ctx context.Context, publicKey *ec
|
||||
}
|
||||
|
||||
// RegisterForPushNotification register deviceToken with any push notification server enabled
|
||||
func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) ([]string, error) {
|
||||
func (m *Messenger) RegisterForPushNotifications(ctx context.Context, deviceToken string) ([]*push_notification_client.PushNotificationServer, error) {
|
||||
if m.pushNotificationClient == nil {
|
||||
return nil, errors.New("push notification client not enabled")
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS push_notification_client_servers (
|
||||
public_key BLOB NOT NULL,
|
||||
registered BOOLEAN DEFAULT FALSE,
|
||||
registered_at INT NOT NULL DEFAULT 0,
|
||||
access_token TEXT,
|
||||
UNIQUE(public_key) ON CONFLICT REPLACE
|
||||
);
|
||||
|
||||
|
@ -3,6 +3,7 @@ package push_notification_client
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/status-im/status-go/eth-node/crypto"
|
||||
)
|
||||
@ -28,13 +29,13 @@ func (p *Persistence) SentFor(publicKey *ecdsa.PublicKey, messageID []byte) erro
|
||||
}
|
||||
|
||||
func (p *Persistence) UpsertServer(server *PushNotificationServer) error {
|
||||
_, err := p.db.Exec(`INSERT INTO push_notification_client_servers (public_key, registered, registered_at) VALUES (?,?,?)`, crypto.CompressPubkey(server.publicKey), server.registered, server.registeredAt)
|
||||
_, err := p.db.Exec(`INSERT INTO push_notification_client_servers (public_key, registered, registered_at, access_token) VALUES (?,?,?,?)`, crypto.CompressPubkey(server.PublicKey), server.Registered, server.RegisteredAt, server.AccessToken)
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func (p *Persistence) GetServers() ([]*PushNotificationServer, error) {
|
||||
rows, err := p.db.Query(`SELECT public_key, registered, registered_at FROM push_notification_client_servers`)
|
||||
rows, err := p.db.Query(`SELECT public_key, registered, registered_at,access_token FROM push_notification_client_servers`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -42,7 +43,7 @@ func (p *Persistence) GetServers() ([]*PushNotificationServer, error) {
|
||||
for rows.Next() {
|
||||
server := &PushNotificationServer{}
|
||||
var key []byte
|
||||
err := rows.Scan(&key, &server.registered, &server.registeredAt)
|
||||
err := rows.Scan(&key, &server.Registered, &server.RegisteredAt, &server.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -50,7 +51,37 @@ func (p *Persistence) GetServers() ([]*PushNotificationServer, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.publicKey = parsedKey
|
||||
server.PublicKey = parsedKey
|
||||
servers = append(servers, server)
|
||||
}
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
func (p *Persistence) GetServersByPublicKey(keys []*ecdsa.PublicKey) ([]*PushNotificationServer, error) {
|
||||
|
||||
keyArgs := make([]interface{}, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
keyArgs = append(keyArgs, crypto.CompressPubkey(key))
|
||||
}
|
||||
|
||||
inVector := strings.Repeat("?, ", len(keys)-1) + "?"
|
||||
rows, err := p.db.Query(`SELECT public_key, registered, registered_at,access_token FROM push_notification_client_servers WHERE public_key IN (`+inVector+")", keyArgs...) //nolint: gosec
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var servers []*PushNotificationServer
|
||||
for rows.Next() {
|
||||
server := &PushNotificationServer{}
|
||||
var key []byte
|
||||
err := rows.Scan(&key, &server.Registered, &server.RegisteredAt, &server.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedKey, err := crypto.DecompressPubkey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
server.PublicKey = parsedKey
|
||||
servers = append(servers, server)
|
||||
}
|
||||
return servers, nil
|
||||
|
@ -39,11 +39,13 @@ func (s *SQLitePersistenceSuite) TearDownTest() {
|
||||
func (s *SQLitePersistenceSuite) TestSaveAndRetrieveServer() {
|
||||
key, err := crypto.GenerateKey()
|
||||
s.Require().NoError(err)
|
||||
accessToken := "token"
|
||||
|
||||
server := &PushNotificationServer{
|
||||
publicKey: &key.PublicKey,
|
||||
registered: true,
|
||||
registeredAt: 1,
|
||||
PublicKey: &key.PublicKey,
|
||||
Registered: true,
|
||||
RegisteredAt: 1,
|
||||
AccessToken: accessToken,
|
||||
}
|
||||
|
||||
s.Require().NoError(s.persistence.UpsertServer(server))
|
||||
@ -52,12 +54,13 @@ func (s *SQLitePersistenceSuite) TestSaveAndRetrieveServer() {
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().Len(retrievedServers, 1)
|
||||
s.Require().True(retrievedServers[0].registered)
|
||||
s.Require().Equal(int64(1), retrievedServers[0].registeredAt)
|
||||
s.Require().True(common.IsPubKeyEqual(retrievedServers[0].publicKey, &key.PublicKey))
|
||||
s.Require().True(retrievedServers[0].Registered)
|
||||
s.Require().Equal(int64(1), retrievedServers[0].RegisteredAt)
|
||||
s.Require().True(common.IsPubKeyEqual(retrievedServers[0].PublicKey, &key.PublicKey))
|
||||
s.Require().Equal(accessToken, retrievedServers[0].AccessToken)
|
||||
|
||||
server.registered = false
|
||||
server.registeredAt = 2
|
||||
server.Registered = false
|
||||
server.RegisteredAt = 2
|
||||
|
||||
s.Require().NoError(s.persistence.UpsertServer(server))
|
||||
|
||||
@ -65,7 +68,7 @@ func (s *SQLitePersistenceSuite) TestSaveAndRetrieveServer() {
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().Len(retrievedServers, 1)
|
||||
s.Require().False(retrievedServers[0].registered)
|
||||
s.Require().Equal(int64(2), retrievedServers[0].registeredAt)
|
||||
s.Require().True(common.IsPubKeyEqual(retrievedServers[0].publicKey, &key.PublicKey))
|
||||
s.Require().False(retrievedServers[0].Registered)
|
||||
s.Require().Equal(int64(2), retrievedServers[0].RegisteredAt)
|
||||
s.Require().True(common.IsPubKeyEqual(retrievedServers[0].PublicKey, &key.PublicKey))
|
||||
}
|
||||
|
@ -26,9 +26,10 @@ const encryptedPayloadKeyLength = 16
|
||||
const accessTokenKeyLength = 16
|
||||
|
||||
type PushNotificationServer struct {
|
||||
publicKey *ecdsa.PublicKey
|
||||
registered bool
|
||||
registeredAt int64
|
||||
PublicKey *ecdsa.PublicKey
|
||||
Registered bool
|
||||
RegisteredAt int64
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
type PushNotificationInfo struct {
|
||||
@ -78,21 +79,18 @@ type Client struct {
|
||||
//messageProcessor is a message processor used to send and being notified of messages
|
||||
|
||||
messageProcessor *common.MessageProcessor
|
||||
//pushNotificationRegistrationResponses is a channel that listens to pushNotificationResponse
|
||||
pushNotificationRegistrationResponses chan *protobuf.PushNotificationRegistrationResponse
|
||||
//pushNotificationQueryResponses is a channel that listens to pushNotificationResponse
|
||||
pushNotificationQueryResponses chan *protobuf.PushNotificationQueryResponse
|
||||
}
|
||||
|
||||
func New(persistence *Persistence, config *Config, processor *common.MessageProcessor) *Client {
|
||||
return &Client{
|
||||
quit: make(chan struct{}),
|
||||
config: config,
|
||||
pushNotificationRegistrationResponses: make(chan *protobuf.PushNotificationRegistrationResponse),
|
||||
pushNotificationQueryResponses: make(chan *protobuf.PushNotificationQueryResponse),
|
||||
messageProcessor: processor,
|
||||
persistence: persistence,
|
||||
reader: rand.Reader}
|
||||
quit: make(chan struct{}),
|
||||
config: config,
|
||||
pushNotificationQueryResponses: make(chan *protobuf.PushNotificationQueryResponse),
|
||||
messageProcessor: processor,
|
||||
persistence: persistence,
|
||||
reader: rand.Reader}
|
||||
}
|
||||
|
||||
func (c *Client) Start() error {
|
||||
@ -198,7 +196,7 @@ func (p *Client) buildPushNotificationRegistrationMessage(contactIDs []*ecdsa.Pu
|
||||
return options, nil
|
||||
}
|
||||
|
||||
func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) ([]string, error) {
|
||||
func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mutedChatIDs []string) ([]*PushNotificationServer, error) {
|
||||
c.DeviceToken = deviceToken
|
||||
servers, err := c.persistence.GetServers()
|
||||
if err != nil {
|
||||
@ -219,9 +217,21 @@ func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mut
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serverPublicKeys []*ecdsa.PublicKey
|
||||
for _, server := range servers {
|
||||
|
||||
encryptedRegistration, err := c.encryptRegistration(server.publicKey, marshaledRegistration)
|
||||
// Reset server registration data
|
||||
server.Registered = false
|
||||
server.RegisteredAt = 0
|
||||
server.AccessToken = registration.AccessToken
|
||||
serverPublicKeys = append(serverPublicKeys, server.PublicKey)
|
||||
|
||||
if err := c.persistence.UpsertServer(server); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Dispatch message
|
||||
encryptedRegistration, err := c.encryptRegistration(server.PublicKey, marshaledRegistration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -230,33 +240,71 @@ func (c *Client) Register(deviceToken string, contactIDs []*ecdsa.PublicKey, mut
|
||||
MessageType: protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REGISTRATION,
|
||||
}
|
||||
|
||||
_, err = c.messageProcessor.SendPrivate(context.Background(), server.publicKey, rawMessage)
|
||||
_, err = c.messageProcessor.SendPrivate(context.Background(), server.PublicKey, rawMessage)
|
||||
|
||||
// Send message and wait for reply
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
// TODO: this needs to wait for all the registrations, probably best to poll the database
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
// This code polls the database for server registrations, giving up
|
||||
// after 5 seconds
|
||||
for {
|
||||
select {
|
||||
case <-c.quit:
|
||||
return nil, nil
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, errors.New("no registration response received")
|
||||
case <-c.pushNotificationRegistrationResponses:
|
||||
return nil, nil
|
||||
return servers, nil
|
||||
case <-ctx.Done():
|
||||
c.config.Logger.Debug("Context done")
|
||||
return servers, nil
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
servers, err = c.persistence.GetServersByPublicKey(serverPublicKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allRegistered := true
|
||||
for _, server := range servers {
|
||||
allRegistered = allRegistered && server.Registered
|
||||
}
|
||||
|
||||
// If any of the servers we haven't registered yet, continue
|
||||
if !allRegistered {
|
||||
continue
|
||||
}
|
||||
|
||||
// all have registered,cancel context and return
|
||||
cancel()
|
||||
return servers, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandlePushNotificationRegistrationResponse should check whether the response was successful or not, retry if necessary otherwise store the result in the database
|
||||
func (c *Client) HandlePushNotificationRegistrationResponse(response protobuf.PushNotificationRegistrationResponse) error {
|
||||
func (c *Client) HandlePushNotificationRegistrationResponse(publicKey *ecdsa.PublicKey, response protobuf.PushNotificationRegistrationResponse) error {
|
||||
c.config.Logger.Debug("received push notification registration response", zap.Any("response", response))
|
||||
select {
|
||||
case c.pushNotificationRegistrationResponses <- &response:
|
||||
default:
|
||||
return errors.New("could not process push notification registration response")
|
||||
// TODO: handle non successful response and match request id
|
||||
// Not successful ignore for now
|
||||
if !response.Success {
|
||||
return errors.New("response was not successful")
|
||||
}
|
||||
return nil
|
||||
|
||||
servers, err := c.persistence.GetServersByPublicKey([]*ecdsa.PublicKey{publicKey})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// We haven't registered with this server
|
||||
if len(servers) != 1 {
|
||||
return errors.New("not registered with this server, ignoring")
|
||||
}
|
||||
|
||||
server := servers[0]
|
||||
server.Registered = true
|
||||
server.RegisteredAt = time.Now().Unix()
|
||||
|
||||
return c.persistence.UpsertServer(server)
|
||||
}
|
||||
|
||||
// HandlePushNotificationAdvertisement should store any info related to push notifications
|
||||
@ -289,13 +337,13 @@ func (c *Client) AddPushNotificationServer(publicKey *ecdsa.PublicKey) error {
|
||||
}
|
||||
|
||||
for _, server := range currentServers {
|
||||
if common.IsPubKeyEqual(server.publicKey, publicKey) {
|
||||
if common.IsPubKeyEqual(server.PublicKey, publicKey) {
|
||||
return errors.New("push notification server already added")
|
||||
}
|
||||
}
|
||||
|
||||
return c.persistence.UpsertServer(&PushNotificationServer{
|
||||
publicKey: publicKey,
|
||||
PublicKey: publicKey,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/status-im/status-go/eth-node/crypto"
|
||||
"github.com/status-im/status-go/eth-node/crypto/ecies"
|
||||
"github.com/status-im/status-go/protocol/common"
|
||||
"github.com/status-im/status-go/protocol/protobuf"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -22,7 +23,7 @@ func TestBuildPushNotificationRegisterMessage(t *testing.T) {
|
||||
// build chat lish hashes
|
||||
var mutedChatListHashes [][]byte
|
||||
for _, chatID := range mutedChatList {
|
||||
mutedChatListHashes = append(mutedChatListHashes, shake256(chatID))
|
||||
mutedChatListHashes = append(mutedChatListHashes, common.Shake256([]byte(chatID)))
|
||||
}
|
||||
|
||||
identity, err := crypto.GenerateKey()
|
||||
|
@ -257,6 +257,9 @@ func (s *Server) HandlePushNotificationRegistration(publicKey *ecdsa.PublicKey,
|
||||
}
|
||||
|
||||
func (s *Server) listenToPublicKeyQueryTopic(hashedPublicKey []byte) error {
|
||||
if s.messageProcessor == nil {
|
||||
return nil
|
||||
}
|
||||
encodedPublicKey := hex.EncodeToString(hashedPublicKey)
|
||||
return s.messageProcessor.JoinPublic(encodedPublicKey)
|
||||
}
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/status-im/status-go/protocol/common"
|
||||
"github.com/status-im/status-go/protocol/protobuf"
|
||||
"github.com/status-im/status-go/protocol/sqlite"
|
||||
"github.com/status-im/status-go/protocol/tt"
|
||||
)
|
||||
|
||||
func TestServerSuite(t *testing.T) {
|
||||
@ -55,6 +56,7 @@ func (s *ServerSuite) SetupTest() {
|
||||
|
||||
config := &Config{
|
||||
Identity: identity,
|
||||
Logger: tt.MustCreateTestLogger(),
|
||||
}
|
||||
|
||||
s.server = New(config, s.persistence, nil)
|
@ -118,7 +118,7 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
|
||||
|
||||
bob1DeviceToken := "token-1"
|
||||
bob2DeviceToken := "token-2"
|
||||
var bob1AccessTokens, bob2AccessTokens []string
|
||||
var bob1Servers, bob2Servers []*push_notification_client.PushNotificationServer
|
||||
|
||||
bob1 := s.m
|
||||
bob2 := s.newMessengerWithKey(s.shh, s.m.identity)
|
||||
@ -130,7 +130,7 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
|
||||
s.Require().NoError(err)
|
||||
|
||||
go func() {
|
||||
bob1AccessTokens, err = bob1.RegisterForPushNotifications(context.Background(), bob1DeviceToken)
|
||||
bob1Servers, err = bob1.RegisterForPushNotifications(context.Background(), bob1DeviceToken)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
@ -165,14 +165,16 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
|
||||
// Make sure we receive it
|
||||
err = <-errChan
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(bob1AccessTokens)
|
||||
s.Require().NotNil(bob1Servers)
|
||||
s.Require().Len(bob1Servers, 1)
|
||||
s.Require().True(bob1Servers[0].Registered)
|
||||
|
||||
// Register bob2
|
||||
err = bob2.AddPushNotificationServer(context.Background(), &server.identity.PublicKey)
|
||||
s.Require().NoError(err)
|
||||
|
||||
go func() {
|
||||
bob2AccessTokens, err = bob2.RegisterForPushNotifications(context.Background(), bob2DeviceToken)
|
||||
bob2Servers, err = bob2.RegisterForPushNotifications(context.Background(), bob2DeviceToken)
|
||||
errChan <- err
|
||||
}()
|
||||
|
||||
@ -207,7 +209,9 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
|
||||
// Make sure we receive it
|
||||
err = <-errChan
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(bob2AccessTokens)
|
||||
s.Require().NotNil(bob2Servers)
|
||||
s.Require().Len(bob2Servers, 1)
|
||||
s.Require().True(bob2Servers[0].Registered)
|
||||
|
||||
var info []*push_notification_client.PushNotificationInfo
|
||||
go func() {
|
||||
@ -251,7 +255,7 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
|
||||
|
||||
var bob1Info, bob2Info *push_notification_client.PushNotificationInfo
|
||||
|
||||
if info[0].AccessToken == bob1AccessTokens[0] {
|
||||
if info[0].AccessToken == bob1Servers[0].AccessToken {
|
||||
bob1Info = info[0]
|
||||
bob2Info = info[1]
|
||||
} else {
|
||||
@ -262,14 +266,14 @@ func (s *MessengerPushNotificationSuite) TestReceivePushNotification() {
|
||||
s.Require().NotNil(bob1Info)
|
||||
s.Require().Equal(bob1Info, &push_notification_client.PushNotificationInfo{
|
||||
InstallationID: bob1.installationID,
|
||||
AccessToken: bob1DeviceToken,
|
||||
AccessToken: bob1Servers[0].AccessToken,
|
||||
PublicKey: &bob1.identity.PublicKey,
|
||||
})
|
||||
|
||||
s.Require().NotNil(bob2Info)
|
||||
s.Require().Equal(bob2Info, &push_notification_client.PushNotificationInfo{
|
||||
InstallationID: bob2.installationID,
|
||||
AccessToken: bob2DeviceToken,
|
||||
AccessToken: bob2Servers[0].AccessToken,
|
||||
PublicKey: &bob1.identity.PublicKey,
|
||||
})
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user