Use reflect.Value as ParsedMessage type and handle type case

This commit is contained in:
Andrea Maria Piana 2020-07-27 12:13:22 +02:00
parent de79f2ced0
commit 2bf1991190
No known key found for this signature in database
GPG Key ID: AA6CCA6DE0E06424
7 changed files with 51 additions and 44 deletions

View File

@ -1,4 +1,4 @@
package protocol package common
import ( import (
"crypto/ecdsa" "crypto/ecdsa"

View File

@ -265,19 +265,22 @@ func (p *MessageProcessor) EncodeMembershipUpdate(
group *v1protocol.Group, group *v1protocol.Group,
chatEntity ChatEntity, chatEntity ChatEntity,
) ([]byte, error) { ) ([]byte, error) {
m := chatEntity.GetProtobuf().(*protobuf.ChatMessage)
e := chatEntity.GetProtobuf().(*protobuf.EmojiReaction)
if m == nil && e == nil {
return nil, errors.New("chat entity must be of type protobuf.ChatMessage or protobuf.EmojiReaction")
}
message := v1protocol.MembershipUpdateMessage{ message := v1protocol.MembershipUpdateMessage{
ChatID: group.ChatID(), ChatID: group.ChatID(),
Events: group.Events(), Events: group.Events(),
Message: m,
EmojiReaction: e,
} }
if chatEntity != nil {
chatEntityProtobuf := chatEntity.GetProtobuf()
switch chatEntityProtobuf.(type) {
case *protobuf.ChatMessage:
message.Message = chatEntityProtobuf.(*protobuf.ChatMessage)
case *protobuf.EmojiReaction:
message.EmojiReaction = chatEntityProtobuf.(*protobuf.EmojiReaction)
}
}
encodedMessage, err := v1protocol.EncodeMembershipUpdateMessage(message) encodedMessage, err := v1protocol.EncodeMembershipUpdateMessage(message)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to encode membership update message") return nil, errors.Wrap(err, "failed to encode membership update message")

View File

@ -565,7 +565,7 @@ func (m *MessageHandler) HandleDeclineRequestTransaction(messageState *ReceivedM
return m.handleCommandMessage(messageState, oldMessage) return m.handleCommandMessage(messageState, oldMessage)
} }
func (m *MessageHandler) matchChatEntity(chatEntity ChatEntity, chats map[string]*Chat, timesource TimeSource) (*Chat, error) { func (m *MessageHandler) matchChatEntity(chatEntity common.ChatEntity, chats map[string]*Chat, timesource TimeSource) (*Chat, error) {
if chatEntity.GetSigPubKey() == nil { if chatEntity.GetSigPubKey() == nil {
m.logger.Error("public key can't be empty") m.logger.Error("public key can't be empty")
return nil, errors.New("received a chatEntity with empty public key") return nil, errors.New("received a chatEntity with empty public key")
@ -581,7 +581,7 @@ func (m *MessageHandler) matchChatEntity(chatEntity ChatEntity, chats map[string
return nil, errors.New("received a public chatEntity from non-existing chat") return nil, errors.New("received a public chatEntity from non-existing chat")
} }
return chat, nil return chat, nil
case chatEntity.GetMessageType() == protobuf.MessageType_ONE_TO_ONE && common.IsPubKeyEqual(message.SigPubKey, &m.identity.PublicKey): case chatEntity.GetMessageType() == protobuf.MessageType_ONE_TO_ONE && common.IsPubKeyEqual(chatEntity.GetSigPubKey(), &m.identity.PublicKey):
// It's a private message coming from us so we rely on Message.ChatID // It's a private message coming from us so we rely on Message.ChatID
// If chat does not exist, it should be created to support multidevice synchronization. // If chat does not exist, it should be created to support multidevice synchronization.
chatID := chatEntity.GetChatId() chatID := chatEntity.GetChatId()
@ -692,3 +692,7 @@ func (m *MessageHandler) HandleEmojiReaction(state *ReceivedMessageState, pbEmoj
return nil return nil
} }
func (m *MessageHandler) HandleEmojiReactionRetraction(state *ReceivedMessageState, pbEmojiR protobuf.EmojiReactionRetraction) error {
return nil
}

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/ecdsa" "crypto/ecdsa"
"database/sql" "database/sql"
"github.com/duo-labs/webauthn.io/logger"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"os" "os"
@ -1852,11 +1851,11 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
if msg.ParsedMessage != nil { if msg.ParsedMessage != nil {
logger.Debug("Handling parsed message") logger.Debug("Handling parsed message")
switch msg.ParsedMessage.(type) { switch msg.ParsedMessage.Interface().(type) {
case protobuf.MembershipUpdateMessage: case protobuf.MembershipUpdateMessage:
logger.Debug("Handling MembershipUpdateMessage") logger.Debug("Handling MembershipUpdateMessage")
rawMembershipUpdate := msg.ParsedMessage.(protobuf.MembershipUpdateMessage) rawMembershipUpdate := msg.ParsedMessage.Interface().(protobuf.MembershipUpdateMessage)
err = m.handler.HandleMembershipUpdate(messageState, messageState.AllChats[rawMembershipUpdate.ChatId], rawMembershipUpdate, m.systemMessagesTranslations) err = m.handler.HandleMembershipUpdate(messageState, messageState.AllChats[rawMembershipUpdate.ChatId], rawMembershipUpdate, m.systemMessagesTranslations)
if err != nil { if err != nil {
@ -1866,7 +1865,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
case protobuf.ChatMessage: case protobuf.ChatMessage:
logger.Debug("Handling ChatMessage") logger.Debug("Handling ChatMessage")
messageState.CurrentMessageState.Message = msg.ParsedMessage.(protobuf.ChatMessage) messageState.CurrentMessageState.Message = msg.ParsedMessage.Interface().(protobuf.ChatMessage)
err = m.handler.HandleChatMessage(messageState) err = m.handler.HandleChatMessage(messageState)
if err != nil { if err != nil {
logger.Warn("failed to handle ChatMessage", zap.Error(err)) logger.Warn("failed to handle ChatMessage", zap.Error(err))
@ -1878,7 +1877,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
logger.Warn("not coming from us, ignoring") logger.Warn("not coming from us, ignoring")
continue continue
} }
p := msg.ParsedMessage.(protobuf.PairInstallation) p := msg.ParsedMessage.Interface().(protobuf.PairInstallation)
logger.Debug("Handling PairInstallation", zap.Any("message", p)) logger.Debug("Handling PairInstallation", zap.Any("message", p))
err = m.handler.HandlePairInstallation(messageState, p) err = m.handler.HandlePairInstallation(messageState, p)
if err != nil { if err != nil {
@ -1892,7 +1891,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
p := msg.ParsedMessage.(protobuf.SyncInstallationContact) p := msg.ParsedMessage.Interface().(protobuf.SyncInstallationContact)
logger.Debug("Handling SyncInstallationContact", zap.Any("message", p)) logger.Debug("Handling SyncInstallationContact", zap.Any("message", p))
err = m.handler.HandleSyncInstallationContact(messageState, p) err = m.handler.HandleSyncInstallationContact(messageState, p)
if err != nil { if err != nil {
@ -1906,7 +1905,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
p := msg.ParsedMessage.(protobuf.SyncInstallationPublicChat) p := msg.ParsedMessage.Interface().(protobuf.SyncInstallationPublicChat)
logger.Debug("Handling SyncInstallationPublicChat", zap.Any("message", p)) logger.Debug("Handling SyncInstallationPublicChat", zap.Any("message", p))
err = m.handler.HandleSyncInstallationPublicChat(messageState, p) err = m.handler.HandleSyncInstallationPublicChat(messageState, p)
if err != nil { if err != nil {
@ -1915,7 +1914,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
case protobuf.RequestAddressForTransaction: case protobuf.RequestAddressForTransaction:
command := msg.ParsedMessage.(protobuf.RequestAddressForTransaction) command := msg.ParsedMessage.Interface().(protobuf.RequestAddressForTransaction)
logger.Debug("Handling RequestAddressForTransaction", zap.Any("message", command)) logger.Debug("Handling RequestAddressForTransaction", zap.Any("message", command))
err = m.handler.HandleRequestAddressForTransaction(messageState, command) err = m.handler.HandleRequestAddressForTransaction(messageState, command)
if err != nil { if err != nil {
@ -1924,7 +1923,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
case protobuf.SendTransaction: case protobuf.SendTransaction:
command := msg.ParsedMessage.(protobuf.SendTransaction) command := msg.ParsedMessage.Interface().(protobuf.SendTransaction)
logger.Debug("Handling SendTransaction", zap.Any("message", command)) logger.Debug("Handling SendTransaction", zap.Any("message", command))
err = m.handler.HandleSendTransaction(messageState, command) err = m.handler.HandleSendTransaction(messageState, command)
if err != nil { if err != nil {
@ -1933,7 +1932,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
case protobuf.AcceptRequestAddressForTransaction: case protobuf.AcceptRequestAddressForTransaction:
command := msg.ParsedMessage.(protobuf.AcceptRequestAddressForTransaction) command := msg.ParsedMessage.Interface().(protobuf.AcceptRequestAddressForTransaction)
logger.Debug("Handling AcceptRequestAddressForTransaction") logger.Debug("Handling AcceptRequestAddressForTransaction")
err = m.handler.HandleAcceptRequestAddressForTransaction(messageState, command) err = m.handler.HandleAcceptRequestAddressForTransaction(messageState, command)
if err != nil { if err != nil {
@ -1942,7 +1941,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
case protobuf.DeclineRequestAddressForTransaction: case protobuf.DeclineRequestAddressForTransaction:
command := msg.ParsedMessage.(protobuf.DeclineRequestAddressForTransaction) command := msg.ParsedMessage.Interface().(protobuf.DeclineRequestAddressForTransaction)
logger.Debug("Handling DeclineRequestAddressForTransaction") logger.Debug("Handling DeclineRequestAddressForTransaction")
err = m.handler.HandleDeclineRequestAddressForTransaction(messageState, command) err = m.handler.HandleDeclineRequestAddressForTransaction(messageState, command)
if err != nil { if err != nil {
@ -1951,7 +1950,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
case protobuf.DeclineRequestTransaction: case protobuf.DeclineRequestTransaction:
command := msg.ParsedMessage.(protobuf.DeclineRequestTransaction) command := msg.ParsedMessage.Interface().(protobuf.DeclineRequestTransaction)
logger.Debug("Handling DeclineRequestTransaction") logger.Debug("Handling DeclineRequestTransaction")
err = m.handler.HandleDeclineRequestTransaction(messageState, command) err = m.handler.HandleDeclineRequestTransaction(messageState, command)
if err != nil { if err != nil {
@ -1960,7 +1959,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
} }
case protobuf.RequestTransaction: case protobuf.RequestTransaction:
command := msg.ParsedMessage.(protobuf.RequestTransaction) command := msg.ParsedMessage.Interface().(protobuf.RequestTransaction)
logger.Debug("Handling RequestTransaction") logger.Debug("Handling RequestTransaction")
err = m.handler.HandleRequestTransaction(messageState, command) err = m.handler.HandleRequestTransaction(messageState, command)
if err != nil { if err != nil {
@ -1970,7 +1969,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
case protobuf.ContactUpdate: case protobuf.ContactUpdate:
logger.Debug("Handling ContactUpdate") logger.Debug("Handling ContactUpdate")
contactUpdate := msg.ParsedMessage.(protobuf.ContactUpdate) contactUpdate := msg.ParsedMessage.Interface().(protobuf.ContactUpdate)
err = m.handler.HandleContactUpdate(messageState, contactUpdate) err = m.handler.HandleContactUpdate(messageState, contactUpdate)
if err != nil { if err != nil {
logger.Warn("failed to handle ContactUpdate", zap.Error(err)) logger.Warn("failed to handle ContactUpdate", zap.Error(err))
@ -1982,7 +1981,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
logger.Debug("Handling PushNotificationQuery") logger.Debug("Handling PushNotificationQuery")
if err := m.pushNotificationServer.HandlePushNotificationQuery(publicKey, msg.ID, msg.ParsedMessage.(protobuf.PushNotificationQuery)); err != nil { if err := m.pushNotificationServer.HandlePushNotificationQuery(publicKey, msg.ID, msg.ParsedMessage.Interface().(protobuf.PushNotificationQuery)); err != nil {
logger.Warn("failed to handle PushNotificationQuery", zap.Error(err)) logger.Warn("failed to handle PushNotificationQuery", zap.Error(err))
} }
// We continue in any case, no changes to messenger // We continue in any case, no changes to messenger
@ -1993,7 +1992,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
logger.Debug("Handling PushNotificationRegistrationResponse") logger.Debug("Handling PushNotificationRegistrationResponse")
if err := m.pushNotificationClient.HandlePushNotificationRegistrationResponse(publicKey, msg.ParsedMessage.(protobuf.PushNotificationRegistrationResponse)); err != nil { if err := m.pushNotificationClient.HandlePushNotificationRegistrationResponse(publicKey, msg.ParsedMessage.Interface().(protobuf.PushNotificationRegistrationResponse)); err != nil {
logger.Warn("failed to handle PushNotificationRegistrationResponse", zap.Error(err)) logger.Warn("failed to handle PushNotificationRegistrationResponse", zap.Error(err))
} }
// We continue in any case, no changes to messenger // We continue in any case, no changes to messenger
@ -2004,7 +2003,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
logger.Debug("Handling PushNotificationResponse") logger.Debug("Handling PushNotificationResponse")
if err := m.pushNotificationClient.HandlePushNotificationResponse(publicKey, msg.ParsedMessage.(protobuf.PushNotificationResponse)); err != nil { if err := m.pushNotificationClient.HandlePushNotificationResponse(publicKey, msg.ParsedMessage.Interface().(protobuf.PushNotificationResponse)); err != nil {
logger.Warn("failed to handle PushNotificationResponse", zap.Error(err)) logger.Warn("failed to handle PushNotificationResponse", zap.Error(err))
} }
// We continue in any case, no changes to messenger // We continue in any case, no changes to messenger
@ -2016,7 +2015,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
logger.Debug("Handling PushNotificationQueryResponse") logger.Debug("Handling PushNotificationQueryResponse")
if err := m.pushNotificationClient.HandlePushNotificationQueryResponse(publicKey, msg.ParsedMessage.(protobuf.PushNotificationQueryResponse)); err != nil { if err := m.pushNotificationClient.HandlePushNotificationQueryResponse(publicKey, msg.ParsedMessage.Interface().(protobuf.PushNotificationQueryResponse)); err != nil {
logger.Warn("failed to handle PushNotificationQueryResponse", zap.Error(err)) logger.Warn("failed to handle PushNotificationQueryResponse", zap.Error(err))
} }
// We continue in any case, no changes to messenger // We continue in any case, no changes to messenger
@ -2028,14 +2027,14 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
logger.Debug("Handling PushNotificationRequest") logger.Debug("Handling PushNotificationRequest")
if err := m.pushNotificationServer.HandlePushNotificationRequest(publicKey, msg.ParsedMessage.(protobuf.PushNotificationRequest)); err != nil { if err := m.pushNotificationServer.HandlePushNotificationRequest(publicKey, msg.ParsedMessage.Interface().(protobuf.PushNotificationRequest)); err != nil {
logger.Warn("failed to handle PushNotificationRequest", zap.Error(err)) logger.Warn("failed to handle PushNotificationRequest", zap.Error(err))
} }
// We continue in any case, no changes to messenger // We continue in any case, no changes to messenger
continue continue
case protobuf.EmojiReaction: case protobuf.EmojiReaction:
logger.Debug("Handling EmojiReaction") logger.Debug("Handling EmojiReaction")
err = m.handler.HandleEmojiReaction(messageState, msg.ParsedMessage.(protobuf.EmojiReaction)) err = m.handler.HandleEmojiReaction(messageState, msg.ParsedMessage.Interface().(protobuf.EmojiReaction))
if err != nil { if err != nil {
logger.Warn("failed to handle EmojiReaction", zap.Error(err)) logger.Warn("failed to handle EmojiReaction", zap.Error(err))
continue continue
@ -2043,7 +2042,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
case protobuf.EmojiReactionRetraction: case protobuf.EmojiReactionRetraction:
logger.Debug("Handling EmojiReactionRetraction") logger.Debug("Handling EmojiReactionRetraction")
err = m.handler.HandleEmojiReactionRetraction(messageState, msg.ParsedMessage.(protobuf.EmojiReactionRetraction)) err = m.handler.HandleEmojiReactionRetraction(messageState, msg.ParsedMessage.Interface().(protobuf.EmojiReactionRetraction))
if err != nil { if err != nil {
logger.Warn("failed to handle EmojiReactionRetraction", zap.Error(err)) logger.Warn("failed to handle EmojiReactionRetraction", zap.Error(err))
continue continue
@ -2056,14 +2055,14 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte
continue continue
} }
logger.Debug("Handling PushNotificationRegistration") logger.Debug("Handling PushNotificationRegistration")
if err := m.pushNotificationServer.HandlePushNotificationRegistration(publicKey, msg.ParsedMessage.([]byte)); err != nil { if err := m.pushNotificationServer.HandlePushNotificationRegistration(publicKey, msg.ParsedMessage.Interface().([]byte)); err != nil {
logger.Warn("failed to handle PushNotificationRegistration", zap.Error(err)) logger.Warn("failed to handle PushNotificationRegistration", zap.Error(err))
} }
// We continue in any case, no changes to messenger // We continue in any case, no changes to messenger
continue continue
} }
logger.Debug("message not handled", zap.Any("messageType", reflect.TypeOf(msg.ParsedMessage))) logger.Debug("message not handled", zap.Any("messageType", reflect.TypeOf(msg.ParsedMessage.Interface())))
} }
} }
@ -3340,7 +3339,7 @@ func (m *Messenger) SendEmojiReactionRetraction(ctx context.Context, emojiReacti
return &response, nil return &response, nil
} }
func (m *Messenger) encodeChatEntity(chat *Chat, message ChatEntity) ([]byte, error) { func (m *Messenger) encodeChatEntity(chat *Chat, message common.ChatEntity) ([]byte, error) {
var encodedMessage []byte var encodedMessage []byte
var err error var err error
l := m.logger.With(zap.String("site", "Send"), zap.String("chatID", chat.ID)) l := m.logger.With(zap.String("site", "Send"), zap.String("chatID", chat.ID))

View File

@ -2087,7 +2087,7 @@ func (s *MessengerSuite) TestMessageJSON() {
From: "from-field", From: "from-field",
} }
expectedJSON := `{"id":"test-1","whisperTimestamp":0,"from":"from-field","alias":"alias","identicon":"","seen":false,"quotedMessage":null,"rtl":false,"parsedText":null,"lineCount":0,"text":"test-1","chatId":"remote-chat-id","localChatId":"local-chat-id","clock":1,"replace":"","responseTo":"","ensName":"","sticker":null,"emojiReaction":null,"emojiReactionRetraction":null,"commandParameters":null,"timestamp":0,"contentType":0,"messageType":0}` expectedJSON := `{"id":"test-1","whisperTimestamp":0,"from":"from-field","alias":"alias","identicon":"","seen":false,"quotedMessage":null,"rtl":false,"parsedText":null,"lineCount":0,"text":"test-1","chatId":"remote-chat-id","localChatId":"local-chat-id","clock":1,"replace":"","responseTo":"","ensName":"","sticker":null,"commandParameters":null,"timestamp":0,"contentType":0,"messageType":0}`
messageJSON, err := json.Marshal(message) messageJSON, err := json.Marshal(message)
s.Require().NoError(err) s.Require().NoError(err)

View File

@ -74,13 +74,12 @@ func (m *MembershipUpdateMessage) ToProtobuf() (*protobuf.MembershipUpdateMessag
Events: rawEvents, Events: rawEvents,
} }
// If message is not piggybacking anything, that's a valid case and we just return
switch { switch {
case m.Message != nil: case m.Message != nil:
mUM.ChatEntity = &protobuf.MembershipUpdateMessage_Message{Message: m.Message} mUM.ChatEntity = &protobuf.MembershipUpdateMessage_Message{Message: m.Message}
case m.EmojiReaction != nil: case m.EmojiReaction != nil:
mUM.ChatEntity = &protobuf.MembershipUpdateMessage_EmojiReaction{EmojiReaction: m.EmojiReaction} mUM.ChatEntity = &protobuf.MembershipUpdateMessage_EmojiReaction{EmojiReaction: m.EmojiReaction}
default:
return nil, errors.New("neither Message or EmojiReaction is set")
} }
return mUM, nil return mUM, nil

View File

@ -26,7 +26,7 @@ type StatusMessage struct {
// Type is the type of application message contained // Type is the type of application message contained
Type protobuf.ApplicationMetadataMessage_Type `json:"-"` Type protobuf.ApplicationMetadataMessage_Type `json:"-"`
// ParsedMessage is the parsed message by the application layer, i.e the output // ParsedMessage is the parsed message by the application layer, i.e the output
ParsedMessage interface{} `json:"-"` ParsedMessage *reflect.Value `json:"-"`
// TransportPayload is the payload as received from the transport layer // TransportPayload is the payload as received from the transport layer
TransportPayload []byte `json:"-"` TransportPayload []byte `json:"-"`
@ -236,7 +236,8 @@ func (m *StatusMessage) HandleApplication() error {
return m.unmarshalProtobufData(new(protobuf.EmojiReactionRetraction)) return m.unmarshalProtobufData(new(protobuf.EmojiReactionRetraction))
case protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REGISTRATION: case protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REGISTRATION:
// This message is a bit different as it's encrypted, so we pass it straight through // This message is a bit different as it's encrypted, so we pass it straight through
m.ParsedMessage = m.DecryptedPayload v := reflect.ValueOf(m.DecryptedPayload)
m.ParsedMessage = &v
return nil return nil
} }
return nil return nil
@ -257,7 +258,8 @@ func (m *StatusMessage) unmarshalProtobufData(pb proto.Message) error {
log.Printf("[message::DecodeMessage] could not decode %T: %#x, err: %v", pb, m.Hash, err.Error()) log.Printf("[message::DecodeMessage] could not decode %T: %#x, err: %v", pb, m.Hash, err.Error())
} else { } else {
rv = reflect.ValueOf(ptr) rv = reflect.ValueOf(ptr)
m.ParsedMessage = rv.Elem() elem := rv.Elem()
m.ParsedMessage = &elem
return nil return nil
} }