Make sure request is originating from us before marking it as processed

This commit fixes one source of flakyness in the tests, which was an
actual bug.

If 1 device is registering with a push notification server, if there's
another device with the same public key, both would mark themselves as
registered, while maybe only one has been actually registered.

To fix this, we keep track of the request ids we send (in memory for
now), and only mark it as registered if the request was originating on
this device.
This commit is contained in:
Andrea Maria Piana 2021-09-28 11:25:13 +01:00
parent 902b97be06
commit 5c55ab5264
2 changed files with 23 additions and 8 deletions

View File

@ -1 +1 @@
0.88.5
0.88.6

View File

@ -192,6 +192,11 @@ type Client struct {
// registrationSubscriptions is a list of chan of client subscribed to the registration event
registrationSubscriptions []chan struct{}
// pendingRegistrations is a map of pending registrations.
// in theory we should store them in the database, but for now we can keep them in memory at
// the cost of having to register multiple times in case the program stops
pendingRegistrations map[string]bool
}
func New(persistence *Persistence, config *Config, sender *common.MessageSender, messagePersistence MessagePersistence) *Client {
@ -201,6 +206,7 @@ func New(persistence *Persistence, config *Config, sender *common.MessageSender,
messageSender: sender,
messagePersistence: messagePersistence,
persistence: persistence,
pendingRegistrations: make(map[string]bool),
reader: rand.Reader,
}
}
@ -399,6 +405,14 @@ func (c *Client) Register(deviceToken, apnTopic string, tokenType protobuf.PushN
func (c *Client) HandlePushNotificationRegistrationResponse(publicKey *ecdsa.PublicKey, response protobuf.PushNotificationRegistrationResponse) error {
c.config.Logger.Debug("received push notification registration response", zap.Any("response", response))
if len(response.RequestId) == 0 {
return errors.New("empty requestId")
}
if !c.pendingRegistrations[hex.EncodeToString(response.RequestId)] {
return errors.New("not for one of our requests")
}
// Not successful ignore for now
if !response.Success {
return errors.New("response was not successful")
@ -1278,6 +1292,8 @@ func (c *Client) registerWithServer(registration *protobuf.PushNotificationRegis
if err != nil {
return err
}
c.pendingRegistrations[hex.EncodeToString(common.Shake256(encryptedRegistration))] = true
return nil
}
@ -1318,7 +1334,6 @@ func (c *Client) SendNotification(publicKey *ecdsa.PublicKey, installationIDs []
// one info per installation id, grouped by server
actionableInfos := make(map[string][]*PushNotificationInfo)
c.config.Logger.Info("INFOS", zap.Any("info", info))
for _, i := range info {
if !installationIDsMap[i.InstallationID] {