diff --git a/protocol/common/message_sender.go b/protocol/common/message_sender.go index f7013b8e6..75ca16beb 100644 --- a/protocol/common/message_sender.go +++ b/protocol/common/message_sender.go @@ -713,26 +713,30 @@ func (s *MessageSender) SendPublic( // unwrapDatasyncMessage tries to unwrap message as datasync one and in case of success // returns cloned messages with replaced payloads -func unwrapDatasyncMessage(m *v1protocol.StatusMessage, datasync *datasync.DataSync) ([]*v1protocol.StatusMessage, [][]byte, error) { - var statusMessages []*v1protocol.StatusMessage +func (s *MessageSender) unwrapDatasyncMessage(m *v1protocol.StatusMessage, response *handleMessageResponse) error { - payloads, acks, err := datasync.UnwrapPayloadsAndAcks( + datasyncMessage, err := s.datasync.Unwrap( m.SigPubKey(), m.EncryptionLayer.Payload, ) if err != nil { - return nil, nil, err + return err } - for _, payload := range payloads { + response.DatasyncAcks = append(response.DatasyncAcks, datasyncMessage.Acks...) + response.DatasyncOffers = append(response.DatasyncAcks, datasyncMessage.Offers...) + response.DatasyncRequests = append(response.DatasyncRequests, datasyncMessage.Requests...) + + for _, ds := range datasyncMessage.Messages { message, err := m.Clone() if err != nil { - return nil, nil, err + return err } - message.EncryptionLayer.Payload = payload - statusMessages = append(statusMessages, message) + message.EncryptionLayer.Payload = ds.Body + response.DatasyncMessages = append(response.DatasyncMessages, message) + } - return statusMessages, acks, nil + return nil } // HandleMessages expects a whisper message as input, and it will go through @@ -740,48 +744,43 @@ func unwrapDatasyncMessage(m *v1protocol.StatusMessage, datasync *datasync.DataS // layer message, or in case of Raw methods, the processing stops at the layer // before. // It returns an error only if the processing of required steps failed. -func (s *MessageSender) HandleMessages(wakuMessage *types.Message) ([]*v1protocol.StatusMessage, [][]byte, error) { +func (s *MessageSender) HandleMessages(wakuMessage *types.Message) (*HandleMessageResponse, error) { logger := s.logger.With(zap.String("site", "HandleMessages")) hlogger := logger.With(zap.String("hash", types.HexBytes(wakuMessage.Hash).String())) - var statusMessages []*v1protocol.StatusMessage - var acks [][]byte - response, err := s.handleMessage(wakuMessage) if err != nil { // Hash ratchet with a group id not found yet, save the message for future processing if err == encryption.ErrHashRatchetGroupIDNotFound && len(response.Message.EncryptionLayer.HashRatchetInfo) == 1 { info := response.Message.EncryptionLayer.HashRatchetInfo[0] - return nil, nil, s.persistence.SaveHashRatchetMessage(info.GroupID, info.KeyID, wakuMessage) + return nil, s.persistence.SaveHashRatchetMessage(info.GroupID, info.KeyID, wakuMessage) } // The current message segment has been successfully retrieved. // However, the collection of segments is not yet complete. if err == ErrMessageSegmentsIncomplete { - return nil, nil, nil + return nil, nil } - return nil, nil, err + return nil, err } - statusMessages = append(statusMessages, response.Messages()...) - acks = append(acks, response.DatasyncAcks...) // Process queued hash ratchet messages for _, hashRatchetInfo := range response.Message.EncryptionLayer.HashRatchetInfo { messages, err := s.persistence.GetHashRatchetMessages(hashRatchetInfo.KeyID) if err != nil { - return nil, nil, err + return nil, err } var processedIds [][]byte for _, message := range messages { - response, err := s.handleMessage(message) + r, err := s.handleMessage(message) if err != nil { hlogger.Debug("failed to handle hash ratchet message", zap.Error(err)) continue } - statusMessages = append(statusMessages, response.Messages()...) - acks = append(acks, response.DatasyncAcks...) + response.DatasyncMessages = append(response.toPublicResponse().StatusMessages, r.Messages()...) + response.DatasyncAcks = append(response.DatasyncAcks, r.DatasyncAcks...) processedIds = append(processedIds, message.Hash) } @@ -789,17 +788,35 @@ func (s *MessageSender) HandleMessages(wakuMessage *types.Message) ([]*v1protoco err = s.persistence.DeleteHashRatchetMessages(processedIds) if err != nil { s.logger.Warn("failed to delete hash ratchet messages", zap.Error(err)) - return nil, nil, err + return nil, err } } - return statusMessages, acks, nil + return response.toPublicResponse(), nil +} + +type HandleMessageResponse struct { + StatusMessages []*v1protocol.StatusMessage + DatasyncAcks [][]byte + DatasyncOffers [][]byte + DatasyncRequests [][]byte +} + +func (h *handleMessageResponse) toPublicResponse() *HandleMessageResponse { + return &HandleMessageResponse{ + StatusMessages: h.Messages(), + DatasyncAcks: h.DatasyncAcks, + DatasyncOffers: h.DatasyncOffers, + DatasyncRequests: h.DatasyncRequests, + } } type handleMessageResponse struct { Message *v1protocol.StatusMessage DatasyncMessages []*v1protocol.StatusMessage DatasyncAcks [][]byte + DatasyncOffers [][]byte + DatasyncRequests [][]byte } func (h *handleMessageResponse) Messages() []*v1protocol.StatusMessage { @@ -813,19 +830,21 @@ func (s *MessageSender) handleMessage(wakuMessage *types.Message) (*handleMessag logger := s.logger.With(zap.String("site", "handleMessage")) hlogger := logger.With(zap.ByteString("hash", wakuMessage.Hash)) + message := &v1protocol.StatusMessage{} + response := &handleMessageResponse{ - Message: &v1protocol.StatusMessage{}, + Message: message, DatasyncMessages: []*v1protocol.StatusMessage{}, DatasyncAcks: [][]byte{}, } - err := response.Message.HandleTransportLayer(wakuMessage) + err := message.HandleTransportLayer(wakuMessage) if err != nil { hlogger.Error("failed to handle transport layer message", zap.Error(err)) return nil, err } - err = s.handleSegmentationLayer(response.Message) + err = s.handleSegmentationLayer(message) if err != nil { hlogger.Debug("failed to handle segmentation layer message", zap.Error(err)) @@ -839,7 +858,7 @@ func (s *MessageSender) handleMessage(wakuMessage *types.Message) (*handleMessag } } - err = s.handleEncryptionLayer(context.Background(), response.Message) + err = s.handleEncryptionLayer(context.Background(), message) if err != nil { hlogger.Debug("failed to handle an encryption message", zap.Error(err)) @@ -850,12 +869,9 @@ func (s *MessageSender) handleMessage(wakuMessage *types.Message) (*handleMessag } if s.datasync != nil && s.datasyncEnabled { - datasyncMessages, as, err := unwrapDatasyncMessage(response.Message, s.datasync) + err := s.unwrapDatasyncMessage(message, response) if err != nil { hlogger.Debug("failed to handle datasync message", zap.Error(err)) - } else { - response.DatasyncMessages = append(response.DatasyncMessages, datasyncMessages...) - response.DatasyncAcks = append(response.DatasyncAcks, as...) } } diff --git a/protocol/common/message_sender_test.go b/protocol/common/message_sender_test.go index cb6bb0b3c..96b60cfaa 100644 --- a/protocol/common/message_sender_test.go +++ b/protocol/common/message_sender_test.go @@ -118,8 +118,9 @@ func (s *MessageSenderSuite) TestHandleDecodedMessagesWrapped() { message.Sig = crypto.FromECDSAPub(&relayerKey.PublicKey) message.Payload = wrappedPayload - decodedMessages, _, err := s.sender.HandleMessages(message) + response, err := s.sender.HandleMessages(message) s.Require().NoError(err) + decodedMessages := response.StatusMessages s.Require().Equal(1, len(decodedMessages)) s.Require().Equal(&authorKey.PublicKey, decodedMessages[0].SigPubKey()) @@ -152,8 +153,9 @@ func (s *MessageSenderSuite) TestHandleDecodedMessagesDatasync() { message.Sig = crypto.FromECDSAPub(&relayerKey.PublicKey) message.Payload = marshalledDataSyncMessage - decodedMessages, _, err := s.sender.HandleMessages(message) + response, err := s.sender.HandleMessages(message) s.Require().NoError(err) + decodedMessages := response.StatusMessages // We send two messages, the unwrapped one will be attributed to the relayer, while the wrapped one will be attributed to the author s.Require().Equal(1, len(decodedMessages)) @@ -217,8 +219,9 @@ func (s *MessageSenderSuite) TestHandleDecodedMessagesDatasyncEncrypted() { message.Sig = crypto.FromECDSAPub(&relayerKey.PublicKey) message.Payload = encryptedPayload - decodedMessages, _, err := s.sender.HandleMessages(message) + response, err := s.sender.HandleMessages(message) s.Require().NoError(err) + decodedMessages := response.StatusMessages // We send two messages, the unwrapped one will be attributed to the relayer, // while the wrapped one will be attributed to the author. @@ -277,7 +280,7 @@ func (s *MessageSenderSuite) TestHandleOutOfOrderHashRatchet() { message.Hash = []byte{0x1} message.Payload = encryptedPayload2 - _, _, err = s.sender.HandleMessages(message) + _, err = s.sender.HandleMessages(message) s.Require().NoError(err) keyID, err := ratchet.GetKeyID() @@ -293,8 +296,9 @@ func (s *MessageSenderSuite) TestHandleOutOfOrderHashRatchet() { message.Hash = []byte{0x2} message.Payload = encryptedPayload1 - decodedMessages2, _, err := s.sender.HandleMessages(message) + response, err := s.sender.HandleMessages(message) s.Require().NoError(err) + decodedMessages2 := response.StatusMessages s.Require().NotNil(decodedMessages2) // It should have 2 messages, the key exchange and the one from the database @@ -330,14 +334,16 @@ func (s *MessageSenderSuite) TestHandleSegmentMessages() { message.Payload = segmentedMessages[0].Payload // First segment is received, no messages are decoded - decodedMessages, _, err := s.sender.HandleMessages(message) + response, err := s.sender.HandleMessages(message) s.Require().NoError(err) - s.Require().Len(decodedMessages, 0) + s.Require().Nil(response) // Second (and final) segment is received, reassembled message is decoded message.Payload = segmentedMessages[1].Payload - decodedMessages, _, err = s.sender.HandleMessages(message) + response, err = s.sender.HandleMessages(message) s.Require().NoError(err) + + decodedMessages := response.StatusMessages s.Require().Len(decodedMessages, 1) s.Require().Equal(&authorKey.PublicKey, decodedMessages[0].SigPubKey()) s.Require().Equal(v1protocol.MessageID(&authorKey.PublicKey, wrappedPayload), decodedMessages[0].ApplicationLayer.ID) @@ -345,6 +351,6 @@ func (s *MessageSenderSuite) TestHandleSegmentMessages() { s.Require().Equal(protobuf.ApplicationMetadataMessage_CHAT_MESSAGE, decodedMessages[0].ApplicationLayer.Type) // Receiving another segment after the message has been reassembled is considered an error - _, _, err = s.sender.HandleMessages(message) + _, err = s.sender.HandleMessages(message) s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted) } diff --git a/protocol/datasync/datasync.go b/protocol/datasync/datasync.go index 528e5368e..790e03980 100644 --- a/protocol/datasync/datasync.go +++ b/protocol/datasync/datasync.go @@ -6,6 +6,7 @@ import ( "github.com/golang/protobuf/proto" datasyncnode "github.com/vacp2p/mvds/node" + "github.com/vacp2p/mvds/protobuf" datasyncproto "github.com/vacp2p/mvds/protobuf" datasynctransport "github.com/vacp2p/mvds/transport" "go.uber.org/zap" @@ -25,35 +26,25 @@ func New(node *datasyncnode.Node, transport *NodeTransport, sendingEnabled bool, return &DataSync{Node: node, NodeTransport: transport, sendingEnabled: sendingEnabled, logger: logger} } -// UnwrapPayloadsAndAcks tries to unwrap datasync message and return messages payloads -// and acknowledgements for previously sent messages -func (d *DataSync) UnwrapPayloadsAndAcks(sender *ecdsa.PublicKey, payload []byte) ([][]byte, [][]byte, error) { - var payloads [][]byte - var acks [][]byte +// Unwrap tries to unwrap datasync message and passes back the message to datasync in order to acknowledge any potential message and mark messages as acknowledged +func (d *DataSync) Unwrap(sender *ecdsa.PublicKey, payload []byte) (*protobuf.Payload, error) { logger := d.logger.With(zap.String("site", "Handle")) datasyncMessage, err := unwrap(payload) // If it failed to decode is not a protobuf message, if it successfully decoded but body is empty, is likedly a protobuf wrapped message if err != nil { logger.Debug("Unwrapping datasync message failed", zap.Error(err)) - return nil, nil, err + return nil, err } else if !datasyncMessage.IsValid() { - return nil, nil, errors.New("handling non-datasync message") + return nil, errors.New("handling non-datasync message") } else { logger.Debug("handling datasync message") - // datasync message - for _, message := range datasyncMessage.Messages { - payloads = append(payloads, message.Body) - } - - acks = append(acks, datasyncMessage.Acks...) - if d.sendingEnabled { d.add(sender, datasyncMessage) } } - return payloads, acks, nil + return &datasyncMessage, nil } func (d *DataSync) Stop() { diff --git a/protocol/messenger.go b/protocol/messenger.go index 9ba98e483..86450c50d 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -3545,11 +3545,12 @@ func (m *Messenger) handleImportedMessages(messagesToHandle map[transport.Filter for filter, messages := range messagesToHandle { for _, shhMessage := range messages { - statusMessages, _, err := m.sender.HandleMessages(shhMessage) + handleMessageResponse, err := m.sender.HandleMessages(shhMessage) if err != nil { logger.Info("failed to decode messages", zap.Error(err)) continue } + statusMessages := handleMessageResponse.StatusMessages for _, msg := range statusMessages { logger := logger.With(zap.String("message-id", msg.TransportLayer.Message.ThirdPartyID)) @@ -3697,7 +3698,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte } } - statusMessages, acks, err := m.sender.HandleMessages(shhMessage) + handleMessagesResponse, err := m.sender.HandleMessages(shhMessage) if err != nil { if m.telemetryClient != nil { go m.telemetryClient.UpdateEnvelopeProcessingError(shhMessage, err) @@ -3706,10 +3707,16 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } + if handleMessagesResponse == nil { + continue + } + + statusMessages := handleMessagesResponse.StatusMessages + if m.telemetryClient != nil { go m.telemetryClient.PushReceivedMessages(filter, shhMessage, statusMessages) } - m.markDeliveredMessages(acks) + m.markDeliveredMessages(handleMessagesResponse.DatasyncAcks) logger.Debug("processing messages further", zap.Int("count", len(statusMessages)))