diff --git a/protocol/edit_message.go b/protocol/edit_message.go index 98b4a5160..b58667c68 100644 --- a/protocol/edit_message.go +++ b/protocol/edit_message.go @@ -13,6 +13,8 @@ import ( type EditMessage struct { protobuf.EditMessage + ID string `json:"id",omitempty"` + // From is a public key of the author of the edit reaction. From string `json:"from,omitempty"` diff --git a/protocol/message_handler.go b/protocol/message_handler.go index 2fb827e07..2340f505f 100644 --- a/protocol/message_handler.go +++ b/protocol/message_handler.go @@ -512,7 +512,49 @@ func (m *MessageHandler) handleWrappedCommunityDescriptionMessage(payload []byte return m.communitiesManager.HandleWrappedCommunityDescriptionMessage(payload) } -func (m *MessageHandler) HandleEditMessage(state *ReceivedMessageState) error { +func (m *Messenger) HandleEditMessage(response *MessengerResponse, editMessage EditMessage) error { + messageID := editMessage.MessageId + // Check if it's already in the response + originalMessage := response.GetMessage(messageID) + // otherwise pull from database + if originalMessage == nil { + var err error + originalMessage, err = m.persistence.MessageByID(messageID) + + if err != nil { + return err + } + } + + // We don't have the original message, save the edited message + if originalMessage == nil { + return m.persistence.SaveEdit(editMessage) + } + + chat, ok := m.allChats.Load(originalMessage.LocalChatID) + if !ok { + return errors.New("chat not found") + } + + // Check edit is valid + if originalMessage.From != editMessage.From { + return errors.New("invalid edit, not the right author") + } + + // Check that edit should be applied + if originalMessage.EditedAt >= editMessage.Clock { + return m.persistence.SaveEdit(editMessage) + } + + // Update message and return it + err := m.applyEditMessage(&editMessage.EditMessage, originalMessage) + if err != nil { + return err + } + + response.AddMessage(originalMessage) + response.AddChat(chat) + return nil } @@ -1114,37 +1156,6 @@ func (m *MessageHandler) HandleChatIdentity(state *ReceivedMessageState, ci prot return nil } -func (m *MessageHandler) handleEditedMessage(state *ReceivedMessageState, message *protobuf.EditMessage) (bool, error) { - /* - originalMessageID := message.OriginalMessageId - // Check if it's already in the response - originalMessage := state.Response.GetMessage(originalMessageID) - // otherwise pull from database - if originalMessage == nil { - originalMessage, err := m.persistence.MessageByID(originalMessageID) - - if err != nil { - return false, err - } - } - - // We don't have the original message, save the edited message - if originalMessage == nil { - // Save edit and return - //m.persistence.SaveMessageEdit() - return false, nil - - } - - // Check edit is valid - - // Check that edit should be applied - - // Update message and return it */ - - return true, nil -} - func (m *MessageHandler) checkForEdits(message *common.Message) error { // Check for any pending edit // If any pending edits are available and valid, apply them diff --git a/protocol/message_persistence.go b/protocol/message_persistence.go index 452dc6036..5e9762a3f 100644 --- a/protocol/message_persistence.go +++ b/protocol/message_persistence.go @@ -326,7 +326,7 @@ func (db sqlitePersistence) tableUserMessagesAllValues(message *common.Message) command.CommandState, command.Signature, message.Replace, - message.EditedAt, + int64(message.EditedAt), message.RTL, message.LineCount, message.ResponseTo, @@ -1427,6 +1427,11 @@ func (db sqlitePersistence) deactivateChat(chat *Chat, currentClockValue uint64, return db.clearHistory(chat, currentClockValue, tx, true) } +func (db sqlitePersistence) SaveEdit(editMessage EditMessage) error { + _, err := db.db.Exec(`INSERT INTO user_messages_edits (clock, chat_id, message_id, source, text, id) VALUES(?,?,?,?,?,?)`, editMessage.Clock, editMessage.ChatId, editMessage.MessageId, editMessage.Text, editMessage.From, editMessage.ID) + return err +} + func (db sqlitePersistence) clearHistory(chat *Chat, currentClockValue uint64, tx *sql.Tx, deactivate bool) error { // Set deleted at clock value if it's not a public chat so that // old messages will be discarded, or if it's a straight clear history diff --git a/protocol/messenger.go b/protocol/messenger.go index 131061b4a..042d2f9f8 100644 --- a/protocol/messenger.go +++ b/protocol/messenger.go @@ -2460,6 +2460,7 @@ func (r *ReceivedMessageState) addNewActivityCenterNotification(publicKey ecdsa. } func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filter][]*types.Message) (*MessengerResponse, error) { + response := &MessengerResponse{} messageState := &ReceivedMessageState{ AllChats: m.allChats, AllContacts: m.allContacts, @@ -2469,7 +2470,7 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte ExistingMessagesMap: make(map[string]bool), EmojiReactions: make(map[string]*EmojiReaction), GroupChatInvitations: make(map[string]*GroupChatInvitation), - Response: &MessengerResponse{}, + Response: response, Timesource: m.getTimesource(), } @@ -2563,6 +2564,22 @@ func (m *Messenger) handleRetrievedMessages(chatWithMessages map[transport.Filte continue } + case protobuf.EditMessage: + logger.Debug("Handling EditMessage") + editProto := msg.ParsedMessage.Interface().(protobuf.EditMessage) + editMessage := EditMessage{ + EditMessage: editProto, + From: contact.ID, + ID: messageID, + SigPubKey: publicKey, + } + err = m.HandleEditMessage(response, editMessage) + if err != nil { + logger.Warn("failed to handle EditMessage", zap.Error(err)) + allMessagesProcessed = false + continue + } + case protobuf.PinMessage: pinMessage := msg.ParsedMessage.Interface().(protobuf.PinMessage) err = m.handler.HandlePinMessage(messageState, pinMessage) diff --git a/protocol/messenger_edit_message_test.go b/protocol/messenger_edit_message_test.go index bce9ebf7b..6b8c647aa 100644 --- a/protocol/messenger_edit_message_test.go +++ b/protocol/messenger_edit_message_test.go @@ -11,6 +11,7 @@ import ( gethbridge "github.com/status-im/status-go/eth-node/bridge/geth" "github.com/status-im/status-go/eth-node/crypto" "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/requests" "github.com/status-im/status-go/protocol/tt" "github.com/status-im/status-go/waku" @@ -121,5 +122,101 @@ func (s *MessengerEditMessageSuite) TestEditMessage() { } _, err = s.m.EditMessage(context.Background(), editedMessage) - s.Require().Equal(ErrInvalidEditAuthor, err.Error()) + s.Require().Equal(ErrInvalidEditAuthor, err) +} + +func (s *MessengerEditMessageSuite) TestEditMessageEdgeCases() { + theirMessenger := s.newMessenger() + _, err := theirMessenger.Start() + s.Require().NoError(err) + + theirChat := CreateOneToOneChat("Their 1TO1", &s.privateKey.PublicKey, s.m.transport) + err = theirMessenger.SaveChat(theirChat) + s.Require().NoError(err) + + ourChat := CreateOneToOneChat("Our 1TO1", &theirMessenger.identity.PublicKey, s.m.transport) + err = s.m.SaveChat(ourChat) + s.Require().NoError(err) + + inputMessage := buildTestMessage(*theirChat) + sendResponse, err := theirMessenger.SendChatMessage(context.Background(), inputMessage) + s.NoError(err) + s.Require().Len(sendResponse.Messages(), 1) + + response, err := WaitOnMessengerResponse( + s.m, + func(r *MessengerResponse) bool { return len(r.messages) > 0 }, + "no messages", + ) + s.Require().NoError(err) + s.Require().Len(response.Chats(), 1) + s.Require().Len(response.Messages(), 1) + + chat := response.Chats()[0] + editedMessage := sendResponse.Messages()[0] + + newContactKey, err := crypto.GenerateKey() + s.Require().NoError(err) + wrongContact, err := BuildContactFromPublicKey(&newContactKey.PublicKey) + s.Require().NoError(err) + + editMessage := EditMessage{ + EditMessage: protobuf.EditMessage{ + Clock: editedMessage.Clock + 1, + Text: "some text", + MessageId: editedMessage.ID, + ChatId: chat.ID, + }, + From: wrongContact.ID, + } + + response = &MessengerResponse{} + + err = s.m.HandleEditMessage(response, editMessage) + // It should error as the user can't edit this message + s.Require().Error(err) + + // Edit with a newer clock value + + response = &MessengerResponse{} + + contact, err := BuildContactFromPublicKey(&theirMessenger.identity.PublicKey) + s.Require().NoError(err) + + editMessage = EditMessage{ + EditMessage: protobuf.EditMessage{ + Clock: editedMessage.Clock + 2, + Text: "some text", + MessageId: editedMessage.ID, + ChatId: chat.ID, + }, + From: contact.ID, + } + + err = s.m.HandleEditMessage(response, editMessage) + // It should error as the user can't edit this message + s.Require().NoError(err) + // It save the edit + s.Require().Len(response.Messages(), 1) + + editedMessage = response.Messages()[0] + + // In-between edit + editMessage = EditMessage{ + EditMessage: protobuf.EditMessage{ + Clock: editedMessage.Clock + 1, + Text: "some other text", + MessageId: editedMessage.ID, + ChatId: chat.ID, + }, + From: contact.ID, + } + + response = &MessengerResponse{} + + err = s.m.HandleEditMessage(response, editMessage) + // It should error as the user can't edit this message + s.Require().NoError(err) + // It discards the edit + s.Require().Len(response.Messages(), 0) } diff --git a/protocol/messenger_messages.go b/protocol/messenger_messages.go index ac0875311..bbaf1c489 100644 --- a/protocol/messenger_messages.go +++ b/protocol/messenger_messages.go @@ -45,19 +45,6 @@ func (m *Messenger) EditMessage(ctx context.Context, request *requests.EditMessa clock, _ := chat.NextClockAndTimestamp(m.getTimesource()) - message.Text = request.Text - message.EditedAt = clock - - err = message.PrepareContent(common.PubkeyToHex(&m.identity.PublicKey)) - if err != nil { - return nil, err - } - - err = m.persistence.SaveMessages([]*common.Message{message}) - if err != nil { - return nil, err - } - editMessage := &EditMessage{} editMessage.Text = request.Text @@ -65,6 +52,11 @@ func (m *Messenger) EditMessage(ctx context.Context, request *requests.EditMessa editMessage.MessageId = request.ID.String() editMessage.Clock = clock + err = m.applyEditMessage(&editMessage.EditMessage, message) + if err != nil { + return nil, err + } + encodedMessage, err := m.encodeChatEntity(chat, editMessage) if err != nil { return nil, err @@ -86,3 +78,15 @@ func (m *Messenger) EditMessage(ctx context.Context, request *requests.EditMessa return response, nil } + +func (m *Messenger) applyEditMessage(editMessage *protobuf.EditMessage, message *common.Message) error { + message.Text = editMessage.Text + message.EditedAt = editMessage.Clock + + err := message.PrepareContent(common.PubkeyToHex(&m.identity.PublicKey)) + if err != nil { + return err + } + + return m.persistence.SaveMessages([]*common.Message{message}) +} diff --git a/protocol/migrations/migrations.go b/protocol/migrations/migrations.go index 00f5f8b48..e125e53f4 100644 --- a/protocol/migrations/migrations.go +++ b/protocol/migrations/migrations.go @@ -36,7 +36,7 @@ // 1622464518_set_synced_to_from.up.sql (105B) // 1622464519_add_chat_description.up.sql (93B) // 1622622253_add_pinned_by_to_pin_messages.up.sql (52B) -// 1622722745_add_original_message_id.up.sql (254B) +// 1622722745_add_original_message_id.up.sql (273B) // 1623938329_add_author_activity_center_notification_field.up.sql (66B) // README.md (554B) // doc.go (850B) @@ -828,7 +828,7 @@ func _1622622253_add_pinned_by_to_pin_messagesUpSql() (*asset, error) { return a, nil } -var __1622722745_add_original_message_idUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x74\xcc\x41\x0a\xc2\x30\x10\x85\xe1\x7d\x4e\xf1\x96\x0a\xde\xa0\xab\xb1\x1d\x54\x88\x29\x84\xd4\x6d\x28\xcd\xa0\x45\xa5\xd0\x49\xc1\xe3\x4b\x41\x17\x42\x5c\x7f\xef\xfd\x64\x03\x7b\x04\xda\x5b\xc6\xa2\x32\xc7\xa7\xa8\xf6\x57\x51\x50\xd3\xa0\x6e\x6d\x77\x76\x90\x34\x66\x49\xb1\xcf\x38\xb9\xc0\x07\xf6\x95\x31\xb5\x67\x0a\x5c\x7a\xc6\x75\xae\xd8\x18\x60\x78\x4c\xc3\xfd\x7b\x82\x6b\x03\x5c\x67\xed\x6e\x95\x5b\x9f\xe3\x98\x70\x21\x5f\x1f\xe9\xd7\x3e\xa1\x7f\xac\xd3\x32\x0f\x52\xa4\x2c\xaf\x5c\x84\x42\xca\x6c\x2b\xf3\x0e\x00\x00\xff\xff\x78\xbe\xc2\xbb\xfe\x00\x00\x00") +var __1622722745_add_original_message_idUpSql = []byte("\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\xff\x74\xcc\xc1\x8a\x83\x30\x14\x85\xe1\x7d\x9e\xe2\x2c\x15\xe6\x0d\x5c\xdd\xd1\xcb\x8c\x4c\x8c\x43\x88\x05\x57\x41\xcc\xa5\x95\xb6\x08\x26\x42\x1f\xbf\x08\xed\xa2\xa0\xeb\xef\x9c\x9f\xb4\x63\x0b\x47\xdf\x9a\xb1\x46\x59\xfc\x5d\x62\x1c\xce\x12\x41\x55\x85\xb2\xd5\x5d\x63\x20\x61\x4a\x12\xfc\x90\x50\x1b\xc7\x3f\x6c\x0b\xa5\x4a\xcb\xe4\x78\xef\xe9\xb7\x79\x44\xa6\x80\xf1\x36\x8f\xd7\xf7\x09\xa6\x75\x30\x9d\xd6\x5f\x9b\x5c\x86\xe4\xa7\x80\x13\xd9\xf2\x97\x3e\xed\x15\x3a\xe2\x38\xaf\xcb\x28\xbb\x94\xe4\x91\x76\xe1\x20\xf5\x6f\xeb\x86\x6c\x8f\x3f\xee\xb3\x29\xe4\x2a\x2f\xd4\x33\x00\x00\xff\xff\x94\x32\x5e\xe6\x11\x01\x00\x00") func _1622722745_add_original_message_idUpSqlBytes() ([]byte, error) { return bindataRead( @@ -843,8 +843,8 @@ func _1622722745_add_original_message_idUpSql() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "1622722745_add_original_message_id.up.sql", size: 254, mode: os.FileMode(0644), modTime: time.Unix(1624368035, 0)} - a := &asset{bytes: bytes, info: info, digest: [32]uint8{0x12, 0x7e, 0x5e, 0x10, 0xbe, 0xe6, 0xdf, 0xb7, 0xbe, 0xce, 0x67, 0xcf, 0x63, 0xae, 0x4, 0x80, 0xab, 0xc3, 0x74, 0x9, 0x3b, 0x6b, 0x48, 0xa9, 0xd0, 0x79, 0xbe, 0x2d, 0xb7, 0x0, 0x5, 0xfc}} + info := bindataFileInfo{name: "1622722745_add_original_message_id.up.sql", size: 273, mode: os.FileMode(0644), modTime: time.Unix(1624368044, 0)} + a := &asset{bytes: bytes, info: info, digest: [32]uint8{0xa8, 0x47, 0x48, 0x84, 0x7a, 0x2f, 0x30, 0x5c, 0x33, 0xa4, 0x42, 0xfb, 0x7d, 0xe1, 0xa6, 0x46, 0x9d, 0x20, 0x19, 0x99, 0x56, 0xbb, 0x9f, 0xd, 0xe4, 0x6b, 0x99, 0x29, 0xe5, 0xef, 0xef, 0x58}} return a, nil } diff --git a/protocol/migrations/sqlite/1622722745_add_original_message_id.up.sql b/protocol/migrations/sqlite/1622722745_add_original_message_id.up.sql index e939ee52a..5b8505a6f 100644 --- a/protocol/migrations/sqlite/1622722745_add_original_message_id.up.sql +++ b/protocol/migrations/sqlite/1622722745_add_original_message_id.up.sql @@ -6,5 +6,6 @@ CREATE TABLE user_messages_edits ( message_id VARCHAR NOT NULL, source VARCHAR NOT NULL, text VARCHAR NOT NULL, - id VARCHAR NOT NULL + id VARCHAR NOT NULL, + PRIMARY KEY(id) );