diff --git a/protocol/common/message_sender.go b/protocol/common/message_sender.go index 5398c1a66..d48eec5c5 100644 --- a/protocol/common/message_sender.go +++ b/protocol/common/message_sender.go @@ -4,6 +4,7 @@ import ( "context" "crypto/ecdsa" "database/sql" + "math/rand" "sync" "time" @@ -33,8 +34,9 @@ const ( whisperLargeSizePoW = 0.000002 // largeSizeInBytes is when should we be using a lower POW. // Roughly this is 50KB - largeSizeInBytes = 50000 - whisperPoWTime = 5 + largeSizeInBytes = 50000 + whisperPoWTime = 5 + maxMessageSenderEphemeralKeys = 3 ) // RekeyCompatibility indicates whether we should be sending @@ -1245,15 +1247,37 @@ func (s *MessageSender) JoinPublic(id string) (*transport.Filter, error) { return s.transport.JoinPublic(id) } -// AddEphemeralKey adds an ephemeral key that we will be listening to -// note that we never removed them from now, as waku/whisper does not -// recalculate topics on removal, so effectively there's no benefit. -// On restart they will be gone. -func (s *MessageSender) AddEphemeralKey(privateKey *ecdsa.PrivateKey) (*transport.Filter, error) { +func (s *MessageSender) getRandomEphemeralKey() *ecdsa.PrivateKey { + k := rand.Intn(len(s.ephemeralKeys)) //nolint: gosec + for _, key := range s.ephemeralKeys { + if k == 0 { + return key + } + k-- + } + return nil +} + +func (s *MessageSender) GetEphemeralKey() (*ecdsa.PrivateKey, error) { s.ephemeralKeysMutex.Lock() + if len(s.ephemeralKeys) >= maxMessageSenderEphemeralKeys { + s.ephemeralKeysMutex.Unlock() + return s.getRandomEphemeralKey(), nil + } + privateKey, err := crypto.GenerateKey() + if err != nil { + s.ephemeralKeysMutex.Unlock() + return nil, err + } + s.ephemeralKeys[types.EncodeHex(crypto.FromECDSAPub(&privateKey.PublicKey))] = privateKey s.ephemeralKeysMutex.Unlock() - return s.transport.LoadKeyFilters(privateKey) + _, err = s.transport.LoadKeyFilters(privateKey) + if err != nil { + return nil, err + } + + return privateKey, nil } func MessageSpecToWhisper(spec *encryption.ProtocolMessageSpec) (*types.NewMessage, error) { diff --git a/protocol/common/message_sender_test.go b/protocol/common/message_sender_test.go index f03938b10..09a299fe5 100644 --- a/protocol/common/message_sender_test.go +++ b/protocol/common/message_sender_test.go @@ -361,3 +361,20 @@ func (s *MessageSenderSuite) TestHandleSegmentMessages() { _, err = s.sender.HandleMessages(message) s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted) } + +func (s *MessageSenderSuite) TestGetEphemeralKey() { + keyMap := make(map[string]bool) + for i := 0; i < maxMessageSenderEphemeralKeys; i++ { + key, err := s.sender.GetEphemeralKey() + s.Require().NoError(err) + s.Require().NotNil(key) + keyMap[PubkeyToHex(&key.PublicKey)] = true + } + s.Require().Len(keyMap, maxMessageSenderEphemeralKeys) + // Add one more + key, err := s.sender.GetEphemeralKey() + s.Require().NoError(err) + s.Require().NotNil(key) + + s.Require().True(keyMap[PubkeyToHex(&key.PublicKey)]) +} diff --git a/protocol/pushnotificationclient/client.go b/protocol/pushnotificationclient/client.go index 0c9f11145..b292047ff 100644 --- a/protocol/pushnotificationclient/client.go +++ b/protocol/pushnotificationclient/client.go @@ -1375,12 +1375,7 @@ func (c *Client) SendNotification(publicKey *ecdsa.PublicKey, installationIDs [] c.config.Logger.Debug("actionable info", zap.Int("count", len(actionableInfos))) - // add ephemeral key and listen to it - ephemeralKey, err := crypto.GenerateKey() - if err != nil { - return nil, err - } - _, err = c.messageSender.AddEphemeralKey(ephemeralKey) + ephemeralKey, err := c.messageSender.GetEphemeralKey() if err != nil { return nil, err } @@ -1688,7 +1683,7 @@ func (c *Client) queryPushNotificationInfo(publicKey *ecdsa.PublicKey) error { return err } - ephemeralKey, err := crypto.GenerateKey() + ephemeralKey, err := c.messageSender.GetEphemeralKey() if err != nil { return err } @@ -1701,11 +1696,6 @@ func (c *Client) queryPushNotificationInfo(publicKey *ecdsa.PublicKey) error { MessageType: protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_QUERY, } - _, err = c.messageSender.AddEphemeralKey(ephemeralKey) - if err != nil { - return err - } - // this is the topic of message encodedPublicKey := hex.EncodeToString(hashedPublicKey) messageID, err := c.messageSender.SendPublic(context.Background(), encodedPublicKey, rawMessage)