From 2bf19911907324da1bb8a760cb8a4ed1737e215c Mon Sep 17 00:00:00 2001 From: Andrea Maria Piana Date: Mon, 27 Jul 2020 12:13:22 +0200 Subject: [PATCH] Use reflect.Value as ParsedMessage type and handle type case --- protocol/{ => common}/chat_entity.go | 2 +- protocol/common/message_processor.go | 25 +++++++------ protocol/message_handler.go | 8 +++- protocol/messenger.go | 47 ++++++++++++------------ protocol/messenger_test.go | 2 +- protocol/v1/membership_update_message.go | 3 +- protocol/v1/status_message.go | 8 ++-- 7 files changed, 51 insertions(+), 44 deletions(-) rename protocol/{ => common}/chat_entity.go (95%) diff --git a/protocol/chat_entity.go b/protocol/common/chat_entity.go similarity index 95% rename from protocol/chat_entity.go rename to protocol/common/chat_entity.go index d369cdda6..78e2dc4da 100644 --- a/protocol/chat_entity.go +++ b/protocol/common/chat_entity.go @@ -1,4 +1,4 @@ -package protocol +package common import ( "crypto/ecdsa" diff --git a/protocol/common/message_processor.go b/protocol/common/message_processor.go index 5b3475c5d..bacbb131e 100644 --- a/protocol/common/message_processor.go +++ b/protocol/common/message_processor.go @@ -265,19 +265,22 @@ func (p *MessageProcessor) EncodeMembershipUpdate( group *v1protocol.Group, chatEntity ChatEntity, ) ([]byte, error) { - m := chatEntity.GetProtobuf().(*protobuf.ChatMessage) - e := chatEntity.GetProtobuf().(*protobuf.EmojiReaction) - - if m == nil && e == nil { - return nil, errors.New("chat entity must be of type protobuf.ChatMessage or protobuf.EmojiReaction") - } - message := v1protocol.MembershipUpdateMessage{ - ChatID: group.ChatID(), - Events: group.Events(), - Message: m, - EmojiReaction: e, + ChatID: group.ChatID(), + Events: group.Events(), } + + if chatEntity != nil { + chatEntityProtobuf := chatEntity.GetProtobuf() + switch chatEntityProtobuf.(type) { + case *protobuf.ChatMessage: + message.Message = chatEntityProtobuf.(*protobuf.ChatMessage) + case *protobuf.EmojiReaction: + message.EmojiReaction = chatEntityProtobuf.(*protobuf.EmojiReaction) + + } + } + encodedMessage, err := v1protocol.EncodeMembershipUpdateMessage(message) if err != nil { return nil, errors.Wrap(err, "failed to encode membership update message") diff --git a/protocol/message_handler.go b/protocol/message_handler.go index f1de98e9c..5ca27bf02 100644 --- a/protocol/message_handler.go +++ b/protocol/message_handler.go @@ -565,7 +565,7 @@ func (m *MessageHandler) HandleDeclineRequestTransaction(messageState *ReceivedM return m.handleCommandMessage(messageState, oldMessage) } -func (m *MessageHandler) matchChatEntity(chatEntity ChatEntity, chats map[string]*Chat, timesource TimeSource) (*Chat, error) { +func (m *MessageHandler) matchChatEntity(chatEntity common.ChatEntity, chats map[string]*Chat, timesource TimeSource) (*Chat, error) { if chatEntity.GetSigPubKey() == nil { m.logger.Error("public key can't be empty") return nil, errors.New("received a chatEntity with empty public key") @@ -581,7 +581,7 @@ func (m *MessageHandler) matchChatEntity(chatEntity ChatEntity, chats map[string return nil, errors.New("received a public chatEntity from non-existing chat") } return chat, nil - case chatEntity.GetMessageType() == protobuf.MessageType_ONE_TO_ONE && common.IsPubKeyEqual(message.SigPubKey, &m.identity.PublicKey): + case chatEntity.GetMessageType() == protobuf.MessageType_ONE_TO_ONE && common.IsPubKeyEqual(chatEntity.GetSigPubKey(), &m.identity.PublicKey): // It's a private message coming from us so we rely on Message.ChatID // If chat does not exist, it should be created to support multidevice synchronization. chatID := chatEntity.GetChatId() @@ -692,3 +692,7 @@ func (m *MessageHandler) HandleEmojiReaction(state *ReceivedMessageState, pbEmoj return nil } + +func (m *MessageHandler) HandleEmojiReactionRetraction(state *ReceivedMessageState, pbEmojiR protobuf.EmojiReactionRetraction) error { + return nil +} diff --git a/protocol/messenger.go b/protocol/messenger.go index f69681431..392d84b41 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -4,7 +4,6 @@ import ( "context" "crypto/ecdsa" "database/sql" - "github.com/duo-labs/webauthn.io/logger" "io/ioutil" "math/rand" "os" @@ -1852,11 +1851,11 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte if msg.ParsedMessage != nil { logger.Debug("Handling parsed message") - switch msg.ParsedMessage.(type) { + switch msg.ParsedMessage.Interface().(type) { case protobuf.MembershipUpdateMessage: logger.Debug("Handling MembershipUpdateMessage") - rawMembershipUpdate := msg.ParsedMessage.(protobuf.MembershipUpdateMessage) + rawMembershipUpdate := msg.ParsedMessage.Interface().(protobuf.MembershipUpdateMessage) err = m.handler.HandleMembershipUpdate(messageState, messageState.AllChats[rawMembershipUpdate.ChatId], rawMembershipUpdate, m.systemMessagesTranslations) if err != nil { @@ -1866,7 +1865,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte case protobuf.ChatMessage: logger.Debug("Handling ChatMessage") - messageState.CurrentMessageState.Message = msg.ParsedMessage.(protobuf.ChatMessage) + messageState.CurrentMessageState.Message = msg.ParsedMessage.Interface().(protobuf.ChatMessage) err = m.handler.HandleChatMessage(messageState) if err != nil { logger.Warn("failed to handle ChatMessage", zap.Error(err)) @@ -1878,7 +1877,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte logger.Warn("not coming from us, ignoring") continue } - p := msg.ParsedMessage.(protobuf.PairInstallation) + p := msg.ParsedMessage.Interface().(protobuf.PairInstallation) logger.Debug("Handling PairInstallation", zap.Any("message", p)) err = m.handler.HandlePairInstallation(messageState, p) if err != nil { @@ -1892,7 +1891,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } - p := msg.ParsedMessage.(protobuf.SyncInstallationContact) + p := msg.ParsedMessage.Interface().(protobuf.SyncInstallationContact) logger.Debug("Handling SyncInstallationContact", zap.Any("message", p)) err = m.handler.HandleSyncInstallationContact(messageState, p) if err != nil { @@ -1906,7 +1905,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } - p := msg.ParsedMessage.(protobuf.SyncInstallationPublicChat) + p := msg.ParsedMessage.Interface().(protobuf.SyncInstallationPublicChat) logger.Debug("Handling SyncInstallationPublicChat", zap.Any("message", p)) err = m.handler.HandleSyncInstallationPublicChat(messageState, p) if err != nil { @@ -1915,7 +1914,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte } case protobuf.RequestAddressForTransaction: - command := msg.ParsedMessage.(protobuf.RequestAddressForTransaction) + command := msg.ParsedMessage.Interface().(protobuf.RequestAddressForTransaction) logger.Debug("Handling RequestAddressForTransaction", zap.Any("message", command)) err = m.handler.HandleRequestAddressForTransaction(messageState, command) if err != nil { @@ -1924,7 +1923,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte } case protobuf.SendTransaction: - command := msg.ParsedMessage.(protobuf.SendTransaction) + command := msg.ParsedMessage.Interface().(protobuf.SendTransaction) logger.Debug("Handling SendTransaction", zap.Any("message", command)) err = m.handler.HandleSendTransaction(messageState, command) if err != nil { @@ -1933,7 +1932,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte } case protobuf.AcceptRequestAddressForTransaction: - command := msg.ParsedMessage.(protobuf.AcceptRequestAddressForTransaction) + command := msg.ParsedMessage.Interface().(protobuf.AcceptRequestAddressForTransaction) logger.Debug("Handling AcceptRequestAddressForTransaction") err = m.handler.HandleAcceptRequestAddressForTransaction(messageState, command) if err != nil { @@ -1942,7 +1941,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte } case protobuf.DeclineRequestAddressForTransaction: - command := msg.ParsedMessage.(protobuf.DeclineRequestAddressForTransaction) + command := msg.ParsedMessage.Interface().(protobuf.DeclineRequestAddressForTransaction) logger.Debug("Handling DeclineRequestAddressForTransaction") err = m.handler.HandleDeclineRequestAddressForTransaction(messageState, command) if err != nil { @@ -1951,7 +1950,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte } case protobuf.DeclineRequestTransaction: - command := msg.ParsedMessage.(protobuf.DeclineRequestTransaction) + command := msg.ParsedMessage.Interface().(protobuf.DeclineRequestTransaction) logger.Debug("Handling DeclineRequestTransaction") err = m.handler.HandleDeclineRequestTransaction(messageState, command) if err != nil { @@ -1960,7 +1959,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte } case protobuf.RequestTransaction: - command := msg.ParsedMessage.(protobuf.RequestTransaction) + command := msg.ParsedMessage.Interface().(protobuf.RequestTransaction) logger.Debug("Handling RequestTransaction") err = m.handler.HandleRequestTransaction(messageState, command) if err != nil { @@ -1970,7 +1969,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte case protobuf.ContactUpdate: logger.Debug("Handling ContactUpdate") - contactUpdate := msg.ParsedMessage.(protobuf.ContactUpdate) + contactUpdate := msg.ParsedMessage.Interface().(protobuf.ContactUpdate) err = m.handler.HandleContactUpdate(messageState, contactUpdate) if err != nil { logger.Warn("failed to handle ContactUpdate", zap.Error(err)) @@ -1982,7 +1981,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } logger.Debug("Handling PushNotificationQuery") - if err := m.pushNotificationServer.HandlePushNotificationQuery(publicKey, msg.ID, msg.ParsedMessage.(protobuf.PushNotificationQuery)); err != nil { + if err := m.pushNotificationServer.HandlePushNotificationQuery(publicKey, msg.ID, msg.ParsedMessage.Interface().(protobuf.PushNotificationQuery)); err != nil { logger.Warn("failed to handle PushNotificationQuery", zap.Error(err)) } // We continue in any case, no changes to messenger @@ -1993,7 +1992,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } logger.Debug("Handling PushNotificationRegistrationResponse") - if err := m.pushNotificationClient.HandlePushNotificationRegistrationResponse(publicKey, msg.ParsedMessage.(protobuf.PushNotificationRegistrationResponse)); err != nil { + if err := m.pushNotificationClient.HandlePushNotificationRegistrationResponse(publicKey, msg.ParsedMessage.Interface().(protobuf.PushNotificationRegistrationResponse)); err != nil { logger.Warn("failed to handle PushNotificationRegistrationResponse", zap.Error(err)) } // We continue in any case, no changes to messenger @@ -2004,7 +2003,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } logger.Debug("Handling PushNotificationResponse") - if err := m.pushNotificationClient.HandlePushNotificationResponse(publicKey, msg.ParsedMessage.(protobuf.PushNotificationResponse)); err != nil { + if err := m.pushNotificationClient.HandlePushNotificationResponse(publicKey, msg.ParsedMessage.Interface().(protobuf.PushNotificationResponse)); err != nil { logger.Warn("failed to handle PushNotificationResponse", zap.Error(err)) } // We continue in any case, no changes to messenger @@ -2016,7 +2015,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } logger.Debug("Handling PushNotificationQueryResponse") - if err := m.pushNotificationClient.HandlePushNotificationQueryResponse(publicKey, msg.ParsedMessage.(protobuf.PushNotificationQueryResponse)); err != nil { + if err := m.pushNotificationClient.HandlePushNotificationQueryResponse(publicKey, msg.ParsedMessage.Interface().(protobuf.PushNotificationQueryResponse)); err != nil { logger.Warn("failed to handle PushNotificationQueryResponse", zap.Error(err)) } // We continue in any case, no changes to messenger @@ -2028,14 +2027,14 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } logger.Debug("Handling PushNotificationRequest") - if err := m.pushNotificationServer.HandlePushNotificationRequest(publicKey, msg.ParsedMessage.(protobuf.PushNotificationRequest)); err != nil { + if err := m.pushNotificationServer.HandlePushNotificationRequest(publicKey, msg.ParsedMessage.Interface().(protobuf.PushNotificationRequest)); err != nil { logger.Warn("failed to handle PushNotificationRequest", zap.Error(err)) } // We continue in any case, no changes to messenger continue case protobuf.EmojiReaction: logger.Debug("Handling EmojiReaction") - err = m.handler.HandleEmojiReaction(messageState, msg.ParsedMessage.(protobuf.EmojiReaction)) + err = m.handler.HandleEmojiReaction(messageState, msg.ParsedMessage.Interface().(protobuf.EmojiReaction)) if err != nil { logger.Warn("failed to handle EmojiReaction", zap.Error(err)) continue @@ -2043,7 +2042,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte case protobuf.EmojiReactionRetraction: logger.Debug("Handling EmojiReactionRetraction") - err = m.handler.HandleEmojiReactionRetraction(messageState, msg.ParsedMessage.(protobuf.EmojiReactionRetraction)) + err = m.handler.HandleEmojiReactionRetraction(messageState, msg.ParsedMessage.Interface().(protobuf.EmojiReactionRetraction)) if err != nil { logger.Warn("failed to handle EmojiReactionRetraction", zap.Error(err)) continue @@ -2056,14 +2055,14 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } logger.Debug("Handling PushNotificationRegistration") - if err := m.pushNotificationServer.HandlePushNotificationRegistration(publicKey, msg.ParsedMessage.([]byte)); err != nil { + if err := m.pushNotificationServer.HandlePushNotificationRegistration(publicKey, msg.ParsedMessage.Interface().([]byte)); err != nil { logger.Warn("failed to handle PushNotificationRegistration", zap.Error(err)) } // We continue in any case, no changes to messenger continue } - logger.Debug("message not handled", zap.Any("messageType", reflect.TypeOf(msg.ParsedMessage))) + logger.Debug("message not handled", zap.Any("messageType", reflect.TypeOf(msg.ParsedMessage.Interface()))) } } @@ -3340,7 +3339,7 @@ func (m *Messenger) SendEmojiReactionRetraction(ctx context.Context, emojiReacti return &response, nil } -func (m *Messenger) encodeChatEntity(chat *Chat, message ChatEntity) ([]byte, error) { +func (m *Messenger) encodeChatEntity(chat *Chat, message common.ChatEntity) ([]byte, error) { var encodedMessage []byte var err error l := m.logger.With(zap.String("site", "Send"), zap.String("chatID", chat.ID)) diff --git a/protocol/messenger_test.go b/protocol/messenger_test.go index 511039645..d180184db 100644 --- a/protocol/messenger_test.go +++ b/protocol/messenger_test.go @@ -2087,7 +2087,7 @@ func (s *MessengerSuite) TestMessageJSON() { From: "from-field", } - expectedJSON := `{"id":"test-1","whisperTimestamp":0,"from":"from-field","alias":"alias","identicon":"","seen":false,"quotedMessage":null,"rtl":false,"parsedText":null,"lineCount":0,"text":"test-1","chatId":"remote-chat-id","localChatId":"local-chat-id","clock":1,"replace":"","responseTo":"","ensName":"","sticker":null,"emojiReaction":null,"emojiReactionRetraction":null,"commandParameters":null,"timestamp":0,"contentType":0,"messageType":0}` + expectedJSON := `{"id":"test-1","whisperTimestamp":0,"from":"from-field","alias":"alias","identicon":"","seen":false,"quotedMessage":null,"rtl":false,"parsedText":null,"lineCount":0,"text":"test-1","chatId":"remote-chat-id","localChatId":"local-chat-id","clock":1,"replace":"","responseTo":"","ensName":"","sticker":null,"commandParameters":null,"timestamp":0,"contentType":0,"messageType":0}` messageJSON, err := json.Marshal(message) s.Require().NoError(err) diff --git a/protocol/v1/membership_update_message.go b/protocol/v1/membership_update_message.go index 0daec84c6..b65862484 100644 --- a/protocol/v1/membership_update_message.go +++ b/protocol/v1/membership_update_message.go @@ -74,13 +74,12 @@ func (m *MembershipUpdateMessage) ToProtobuf() (*protobuf.MembershipUpdateMessag Events: rawEvents, } + // If message is not piggybacking anything, that's a valid case and we just return switch { case m.Message != nil: mUM.ChatEntity = &protobuf.MembershipUpdateMessage_Message{Message: m.Message} case m.EmojiReaction != nil: mUM.ChatEntity = &protobuf.MembershipUpdateMessage_EmojiReaction{EmojiReaction: m.EmojiReaction} - default: - return nil, errors.New("neither Message or EmojiReaction is set") } return mUM, nil diff --git a/protocol/v1/status_message.go b/protocol/v1/status_message.go index bd565a2ec..0aa317d91 100644 --- a/protocol/v1/status_message.go +++ b/protocol/v1/status_message.go @@ -26,7 +26,7 @@ type StatusMessage struct { // Type is the type of application message contained Type protobuf.ApplicationMetadataMessage_Type `json:"-"` // ParsedMessage is the parsed message by the application layer, i.e the output - ParsedMessage interface{} `json:"-"` + ParsedMessage *reflect.Value `json:"-"` // TransportPayload is the payload as received from the transport layer TransportPayload []byte `json:"-"` @@ -236,7 +236,8 @@ func (m *StatusMessage) HandleApplication() error { return m.unmarshalProtobufData(new(protobuf.EmojiReactionRetraction)) case protobuf.ApplicationMetadataMessage_PUSH_NOTIFICATION_REGISTRATION: // This message is a bit different as it's encrypted, so we pass it straight through - m.ParsedMessage = m.DecryptedPayload + v := reflect.ValueOf(m.DecryptedPayload) + m.ParsedMessage = &v return nil } return nil @@ -257,7 +258,8 @@ func (m *StatusMessage) unmarshalProtobufData(pb proto.Message) error { log.Printf("[message::DecodeMessage] could not decode %T: %#x, err: %v", pb, m.Hash, err.Error()) } else { rv = reflect.ValueOf(ptr) - m.ParsedMessage = rv.Elem() + elem := rv.Elem() + m.ParsedMessage = &elem return nil }