refactor: eliminate logic duplication in MessageSender.HandleMessages

This commit is contained in:
Patryk Osmaczko 2023-11-09 14:04:05 +01:00 committed by osmaczko
parent ab6cb85690
commit fa44e03ac2
1 changed files with 72 additions and 51 deletions

View File

@ -741,33 +741,28 @@ func unwrapDatasyncMessage(m *v1protocol.StatusMessage, datasync *datasync.DataS
// layer message, or in case of Raw methods, the processing stops at the layer
// before.
// It returns an error only if the processing of required steps failed.
func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol.StatusMessage, [][]byte, error) {
logger := s.logger.With(zap.String("site", "handleMessages"))
hlogger := logger.With(zap.ByteString("hash", shhMessage.Hash))
var statusMessage v1protocol.StatusMessage
func (s *MessageSender) HandleMessages(wakuMessage *types.Message) ([]*v1protocol.StatusMessage, [][]byte, error) {
logger := s.logger.With(zap.String("site", "HandleMessages"))
hlogger := logger.With(zap.ByteString("hash", wakuMessage.Hash))
var statusMessages []*v1protocol.StatusMessage
var acks [][]byte
err := statusMessage.HandleTransportLayer(shhMessage)
response, err := s.handleMessage(wakuMessage)
if err != nil {
hlogger.Error("failed to handle transport layer message", zap.Error(err))
// Hash ratchet with a group id not found yet, save the message for future processing
if err == encryption.ErrHashRatchetGroupIDNotFound && len(response.Message.EncryptionLayer.HashRatchetInfo) == 1 {
info := response.Message.EncryptionLayer.HashRatchetInfo[0]
return nil, nil, s.persistence.SaveHashRatchetMessage(info.GroupID, info.KeyID, wakuMessage)
}
return nil, nil, err
}
statusMessages = append(statusMessages, response.Messages()...)
acks = append(acks, response.DatasyncAcks...)
err = s.handleEncryptionLayer(context.Background(), &statusMessage)
if err != nil {
hlogger.Debug("failed to handle an encryption message", zap.Error(err))
}
// Hash ratchet with a group id not found yet
if err == encryption.ErrHashRatchetGroupIDNotFound && len(statusMessage.EncryptionLayer.HashRatchetInfo) == 1 {
info := statusMessage.EncryptionLayer.HashRatchetInfo[0]
err := s.persistence.SaveHashRatchetMessage(info.GroupID, info.KeyID, shhMessage)
return nil, nil, err
}
// Check if there are undecrypted message
for _, hashRatchetInfo := range statusMessage.EncryptionLayer.HashRatchetInfo {
// Process queued hash ratchet messages
for _, hashRatchetInfo := range response.Message.EncryptionLayer.HashRatchetInfo {
messages, err := s.persistence.GetHashRatchetMessages(hashRatchetInfo.KeyID)
if err != nil {
return nil, nil, err
@ -775,56 +770,82 @@ func (s *MessageSender) HandleMessages(shhMessage *types.Message) ([]*v1protocol
var processedIds [][]byte
for _, message := range messages {
var statusMessage v1protocol.StatusMessage
err := statusMessage.HandleTransportLayer(message)
response, err := s.handleMessage(message)
if err != nil {
hlogger.Error("failed to handle transport layer message", zap.Error(err))
return nil, nil, err
}
err = s.handleEncryptionLayer(context.Background(), &statusMessage)
if err != nil {
hlogger.Debug("failed to handle an encryption message", zap.Error(err))
hlogger.Debug("failed to handle hash ratchet message", zap.Error(err))
continue
}
statusMessages = append(statusMessages, response.Messages()...)
acks = append(acks, response.DatasyncAcks...)
processedIds = append(processedIds, message.Hash)
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
if err != nil {
hlogger.Debug("failed to handle datasync message", zap.Error(err))
//that wasn't a datasync message, so use the original payload
statusMessages = append(statusMessages, &statusMessage)
} else {
statusMessages = append(statusMessages, stms...)
acks = append(acks, as...)
}
}
err = s.persistence.DeleteHashRatchetMessages(processedIds)
if err != nil {
s.logger.Warn("failed to delete hash ratchet messages", zap.Error(err))
return nil, nil, err
}
}
stms, as, err := unwrapDatasyncMessage(&statusMessage, s.datasync)
if err != nil {
hlogger.Debug("failed to handle datasync message", zap.Error(err))
//that wasn't a datasync message, so use the original payload
statusMessages = append(statusMessages, &statusMessage)
} else {
statusMessages = append(statusMessages, stms...)
acks = append(acks, as...)
return statusMessages, acks, nil
}
type handleMessageResponse struct {
Message *v1protocol.StatusMessage
DatasyncMessages []*v1protocol.StatusMessage
DatasyncAcks [][]byte
}
func (h *handleMessageResponse) Messages() []*v1protocol.StatusMessage {
if len(h.DatasyncMessages) > 0 {
return h.DatasyncMessages
}
return []*v1protocol.StatusMessage{h.Message}
}
func (s *MessageSender) handleMessage(wakuMessage *types.Message) (*handleMessageResponse, error) {
logger := s.logger.With(zap.String("site", "handleMessage"))
hlogger := logger.With(zap.ByteString("hash", wakuMessage.Hash))
response := &handleMessageResponse{
Message: &v1protocol.StatusMessage{},
DatasyncMessages: []*v1protocol.StatusMessage{},
DatasyncAcks: [][]byte{},
}
for _, statusMessage := range statusMessages {
err := statusMessage.HandleApplicationLayer()
err := response.Message.HandleTransportLayer(wakuMessage)
if err != nil {
hlogger.Error("failed to handle transport layer message", zap.Error(err))
return nil, err
}
err = s.handleEncryptionLayer(context.Background(), response.Message)
if err != nil {
hlogger.Debug("failed to handle an encryption message", zap.Error(err))
// Hash ratchet with a group id not found yet, stop processing
if err == encryption.ErrHashRatchetGroupIDNotFound {
return response, err
}
}
datasyncMessages, as, err := unwrapDatasyncMessage(response.Message, s.datasync)
if err != nil {
hlogger.Debug("failed to handle datasync message", zap.Error(err))
} else {
response.DatasyncMessages = append(response.DatasyncMessages, datasyncMessages...)
response.DatasyncAcks = append(response.DatasyncAcks, as...)
}
for _, msg := range response.Messages() {
err := msg.HandleApplicationLayer()
if err != nil {
hlogger.Error("failed to handle application metadata layer message", zap.Error(err))
}
}
return statusMessages, acks, nil
return response, nil
}
// fetchDecryptionKey returns the private key associated with this public key, and returns true if it's an ephemeral key