From 9e7f1338f0e152e201456210ac82e06434bbba68 Mon Sep 17 00:00:00 2001 From: frank Date: Fri, 16 Jun 2023 23:09:37 +0800 Subject: [PATCH] refactor function diffText (#3615) --- VERSION | 2 +- protocol/messenger_mention.go | 240 ++++++++++++++-------------- protocol/messenger_mention_test.go | 245 ++++++++++++++++++++++++----- 3 files changed, 334 insertions(+), 153 deletions(-) diff --git a/VERSION b/VERSION index 4397dd171..c03605ef9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.158.0 +0.158.1 diff --git a/protocol/messenger_mention.go b/protocol/messenger_mention.go index 80e839202..820a666bd 100644 --- a/protocol/messenger_mention.go +++ b/protocol/messenger_mention.go @@ -139,10 +139,11 @@ type MentionState struct { AtSignIdx int AtIdxs []*AtIndexEntry MentionEnd int - PreviousText string // searched text - NewText *string // the matched username - Start int // position after the @ - End int // position of the end of newText + PreviousText string + NewText string + Start int + End int + operation textOperation } func (ms *MentionState) String() string { @@ -154,7 +155,7 @@ func (ms *MentionState) String() string { atIdxsStr += fmt.Sprintf("%+v", entry) } return fmt.Sprintf("MentionState{AtSignIdx: %d, AtIdxs: [%s], MentionEnd: %d, PreviousText: %q, NewText: %s, Start: %d, End: %d}", - ms.AtSignIdx, atIdxsStr, ms.MentionEnd, ms.PreviousText, *ms.NewText, ms.Start, ms.End) + ms.AtSignIdx, atIdxsStr, ms.MentionEnd, ms.PreviousText, ms.NewText, ms.Start, ms.End) } type ChatMentionContext struct { @@ -206,7 +207,7 @@ func (m *MentionManager) getChatMentionContext(chatID string) *ChatMentionContex } func (m *MentionManager) getMentionableUser(chatID string, pk string) (*MentionableUser, error) { - mentionableUsers, err := m.getMentionableUsers(chatID) + mentionableUsers, err := m.mentionableUserGetter.getMentionableUsers(chatID) if err != nil { return nil, err } @@ -283,7 +284,7 @@ func (m *MentionManager) ReplaceWithPublicKey(chatID, text string) (string, erro if chat == nil { return "", fmt.Errorf("chat not found when check mentions, chatID: %s", chatID) } - mentionableUsers, err := m.getMentionableUsers(chatID) + mentionableUsers, err := m.mentionableUserGetter.getMentionableUsers(chatID) if err != nil { return "", err } @@ -292,28 +293,29 @@ func (m *MentionManager) ReplaceWithPublicKey(chatID, text string) (string, erro return newText, nil } -func (m *MentionManager) OnChangeText(chatID, text string) (*ChatMentionContext, error) { +func (m *MentionManager) OnChangeText(chatID, fullText string) (*ChatMentionContext, error) { ctx := m.getChatMentionContext(chatID) - diff := diffText(ctx.PreviousText, text) + diff := diffText(ctx.PreviousText, fullText) if diff == nil { return ctx, nil } - ctx.PreviousText = text + ctx.PreviousText = fullText if ctx.MentionState == nil { ctx.MentionState = &MentionState{} } ctx.MentionState.PreviousText = diff.previousText - ctx.MentionState.NewText = &diff.newText + ctx.MentionState.NewText = diff.newText ctx.MentionState.Start = diff.start ctx.MentionState.End = diff.end + ctx.MentionState.operation = diff.operation ctx.MentionState.AtIdxs = calcAtIdxs(ctx.MentionState) m.logger.Debug("OnChangeText", zap.String("chatID", chatID), zap.Any("state", ctx.MentionState)) - return m.CalculateSuggestions(chatID, text) + return m.calculateSuggestions(chatID, fullText) } func (m *MentionManager) recheckAtIdxs(chatID string, text string, publicKey string) (*ChatMentionContext, error) { - user, err := m.getMentionableUser(chatID, publicKey) + user, err := m.mentionableUserGetter.getMentionableUser(chatID, publicKey) if err != nil { return nil, err } @@ -326,58 +328,61 @@ func (m *MentionManager) recheckAtIdxs(chatID string, text string, publicKey str return ctx, nil } -func (m *MentionManager) CalculateSuggestions(chatID, text string) (*ChatMentionContext, error) { +func (m *MentionManager) calculateSuggestions(chatID, fullText string) (*ChatMentionContext, error) { ctx := m.getChatMentionContext(chatID) mentionableUsers, err := m.mentionableUserGetter.getMentionableUsers(chatID) if err != nil { return nil, err } - m.logger.Debug("CalculateSuggestions", zap.String("chatID", chatID), zap.String("text", text), zap.Int("num of mentionable user", len(mentionableUsers))) + m.logger.Debug("calculateSuggestions", zap.String("chatID", chatID), zap.String("fullText", fullText), zap.Int("num of mentionable user", len(mentionableUsers))) - m.calculateSuggestions(chatID, text, mentionableUsers) + m.calculateSuggestionsWithMentionableUsers(chatID, fullText, mentionableUsers) return ctx, nil } -func (m *MentionManager) calculateSuggestions(chatID string, text string, mentionableUsers map[string]*MentionableUser) { +func (m *MentionManager) calculateSuggestionsWithMentionableUsers(chatID string, fullText string, mentionableUsers map[string]*MentionableUser) { ctx := m.getChatMentionContext(chatID) state := ctx.MentionState - newText := state.NewText - if newText == nil { - newText = &text - } if len(state.AtIdxs) == 0 { state.AtIdxs = nil ctx.MentionSuggestions = nil ctx.InputSegments = []InputSegment{{ Type: Text, - Value: text, + Value: fullText, }} return } - newAtIdxs := checkIdxForMentions(text, state.AtIdxs, mentionableUsers) - calculatedInput := calculateInput(text, newAtIdxs) - addition := state.Start <= state.End - var end int - if addition { - end = state.Start + len([]rune(*newText)) - } else { - end = state.Start - } - atSignIdx := lastIndexOf(text, charAtSign, state.End) - searchedText := strings.ToLower(subs(text, atSignIdx+1, end)) - m.logger.Debug("calculateSuggestions", zap.Int("atSignIdx", atSignIdx), zap.String("searchedText", searchedText), zap.String("text", text), zap.Any("state", state)) + newAtIndexEntries := checkIdxForMentions(fullText, state.AtIdxs, mentionableUsers) + calculatedInput := calculateInput(fullText, newAtIndexEntries) + var end int + switch state.operation { + case textOperationAdd: + end = state.Start + len([]rune(state.NewText)) + case textOperationDelete: + end = state.Start + case textOperationReplace: + end = state.Start + len([]rune(state.NewText)) + default: + m.logger.Error("calculateSuggestionsWithMentionableUsers: unknown textOperation", zap.String("chatID", chatID), zap.String("fullText", fullText), zap.Any("state", state)) + } + + atSignIdx := lastIndexOf(fullText, charAtSign, end) var suggestions map[string]*MentionableUser - if (atSignIdx <= state.Start && end-atSignIdx <= 100) || text[len(text)-1] == charAtSign[0] { - suggestions = getUserSuggestions(mentionableUsers, searchedText, -1) + if atSignIdx != -1 { + searchedText := strings.ToLower(subs(fullText, atSignIdx+1, end)) + m.logger.Debug("calculateSuggestionsWithMentionableUsers", zap.Int("atSignIdx", atSignIdx), zap.String("searchedText", searchedText), zap.String("fullText", fullText), zap.Any("state", state), zap.Int("end", end)) + if end-atSignIdx <= 100 { + suggestions = getUserSuggestions(mentionableUsers, searchedText, -1) + } } state.AtSignIdx = atSignIdx - state.AtIdxs = newAtIdxs + state.AtIdxs = newAtIndexEntries state.MentionEnd = end ctx.InputSegments = calculatedInput ctx.MentionSuggestions = suggestions @@ -403,7 +408,13 @@ func (m *MentionManager) SelectMention(chatID, text, primaryName, publicKey stri ctx.NewText = string(tr[:atSignIdx+1]) + primaryName + space + string(tr[mentionEnd:]) - return m.recheckAtIdxs(chatID, ctx.NewText, publicKey) + ctx, err := m.recheckAtIdxs(chatID, ctx.NewText, publicKey) + if err != nil { + return nil, err + } + ctx.PreviousText = ctx.NewText + m.clearSuggestions(chatID) + return ctx, nil } func (m *MentionManager) clearSuggestions(chatID string) { @@ -420,7 +431,7 @@ func (m *MentionManager) ClearMentions(chatID string) { } func (m *MentionManager) ToInputField(chatID, text string) (*ChatMentionContext, error) { - mentionableUsers, err := m.getMentionableUsers(chatID) + mentionableUsers, err := m.mentionableUserGetter.getMentionableUsers(chatID) if err != nil { return nil, err } @@ -821,25 +832,20 @@ type AtIndexEntry struct { Checked bool Mentioned bool - Mention bool NextAtIdx int } +func (e *AtIndexEntry) String() string { + return fmt.Sprintf("{From: %d, To: %d, Checked: %t, Mentioned: %t, NextAtIdx: %d}", e.From, e.To, e.Checked, e.Mentioned, e.NextAtIdx) +} + // implementation reference: https://github.com/status-im/status-react/blob/04d0252e013d9c67862e77a3467dd32c3abde934/src/status_im/chat/models/mentions.cljs#L433 func calcAtIdxs(state *MentionState) []*AtIndexEntry { - newIdxs := getAtSignIdxs(*state.NewText, state.Start) - newIdxCnt := len(newIdxs) - var lastNewIdx *int - if newIdxCnt > 0 { - idx := newIdxs[newIdxCnt-1] - lastNewIdx = &idx - } - newTextLen := len([]rune(*state.NewText)) - oldTextLen := len([]rune(state.PreviousText)) - oldEnd := state.Start + oldTextLen + newAtSignIndexes := getAtSignIdxs(state.NewText, state.Start) + newAtSignIndexesCount := len(newAtSignIndexes) if len(state.AtIdxs) == 0 { - result := make([]*AtIndexEntry, newIdxCnt) - for i, idx := range newIdxs { + result := make([]*AtIndexEntry, newAtSignIndexesCount) + for i, idx := range newAtSignIndexes { result[i] = &AtIndexEntry{ From: idx, Checked: false, @@ -848,6 +854,9 @@ func calcAtIdxs(state *MentionState) []*AtIndexEntry { return result } + newTextLen := len([]rune(state.NewText)) + oldTextLen := len([]rune(state.PreviousText)) + oldEnd := state.Start + oldTextLen diff := newTextLen - oldTextLen var keptAtIdxs []*AtIndexEntry for _, entry := range state.AtIdxs { @@ -870,21 +879,26 @@ func calcAtIdxs(state *MentionState) []*AtIndexEntry { } } - var newState []*AtIndexEntry + var newAtIndexEntries []*AtIndexEntry var added bool + var lastNewIdx *int + if newAtSignIndexesCount > 0 { + idx := newAtSignIndexes[newAtSignIndexesCount-1] + lastNewIdx = &idx + } for _, entry := range keptAtIdxs { if lastNewIdx != nil && entry.From > *lastNewIdx && !added { - newState = append(newState, makeAtIdxs(newIdxs)...) - newState = append(newState, entry) + newAtIndexEntries = append(newAtIndexEntries, makeAtIdxs(newAtSignIndexes)...) + newAtIndexEntries = append(newAtIndexEntries, entry) added = true } else { - newState = append(newState, entry) + newAtIndexEntries = append(newAtIndexEntries, entry) } } if !added { - newState = append(newState, makeAtIdxs(newIdxs)...) + newAtIndexEntries = append(newAtIndexEntries, makeAtIdxs(newAtSignIndexes)...) } - return newState + return newAtIndexEntries } func makeAtIdxs(idxs []int) []*AtIndexEntry { @@ -898,26 +912,28 @@ func makeAtIdxs(idxs []int) []*AtIndexEntry { return result } -func getAtSignIdxs(text string, start int) []int { - return getAtSignIdxsHelper(text, start, 0, []int{}) +// getAtSignIdxs returns the indexes of all @ signs in the text. +// delta is the offset of the text within the original text. +func getAtSignIdxs(text string, delta int) []int { + return getAtSignIdxsHelper(text, delta, 0, []int{}) } -func getAtSignIdxsHelper(text string, start int, from int, idxs []int) []int { +func getAtSignIdxsHelper(text string, delta int, from int, idxs []int) []int { tr := []rune(text) idx := strings.Index(string(tr[from:]), charAtSign) if idx != -1 { idx += from - idxs = append(idxs, start+idx) - return getAtSignIdxsHelper(text, start, idx+1, idxs) + idxs = append(idxs, delta+idx) + return getAtSignIdxsHelper(text, delta, idx+1, idxs) } return idxs } -func checkEntry(text string, entry *AtIndexEntry, mentionableUsers map[string]*MentionableUser) *AtIndexEntry { +func checkAtIndexEntry(fullText string, entry *AtIndexEntry, mentionableUsers map[string]*MentionableUser) *AtIndexEntry { if entry.Checked { return entry } - result := MatchMention(text+charAtSign, mentionableUsers, entry.From) + result := MatchMention(fullText+charAtSign, mentionableUsers, entry.From) if result != nil && result.Match != "" { return &AtIndexEntry{ From: entry.From, @@ -928,36 +944,34 @@ func checkEntry(text string, entry *AtIndexEntry, mentionableUsers map[string]*M } return &AtIndexEntry{ From: entry.From, - To: len([]rune(text)), + To: len([]rune(fullText)), Checked: true, - Mention: false, // Mention vs Mentioned? wrong spelling? } } -func checkIdxForMentions(text string, idxs []*AtIndexEntry, mentionableUsers map[string]*MentionableUser) []*AtIndexEntry { - var newIdxs []*AtIndexEntry - for _, entry := range idxs { - previousEntryIdx := len(newIdxs) - 1 - newEntry := checkEntry(text, entry, mentionableUsers) - if previousEntryIdx >= 0 && !newIdxs[previousEntryIdx].Mentioned { - newIdxs[previousEntryIdx].To = entry.From - 1 +func checkIdxForMentions(fullText string, currentAtIndexEntries []*AtIndexEntry, mentionableUsers map[string]*MentionableUser) []*AtIndexEntry { + var newIndexEntries []*AtIndexEntry + for _, entry := range currentAtIndexEntries { + previousEntryIdx := len(newIndexEntries) - 1 + newEntry := checkAtIndexEntry(fullText, entry, mentionableUsers) + if previousEntryIdx >= 0 && !newIndexEntries[previousEntryIdx].Mentioned { + newIndexEntries[previousEntryIdx].To = entry.From - 1 } if previousEntryIdx >= 0 { - newIdxs[previousEntryIdx].NextAtIdx = entry.From + newIndexEntries[previousEntryIdx].NextAtIdx = entry.From } - // simulate (dissoc new-entry :next-at-idx) newEntry.NextAtIdx = intUnknown - newIdxs = append(newIdxs, newEntry) + newIndexEntries = append(newIndexEntries, newEntry) } - if len(newIdxs) > 0 { - lastIdx := len(newIdxs) - 1 - if newIdxs[lastIdx].Mentioned { - return newIdxs + if len(newIndexEntries) > 0 { + lastIdx := len(newIndexEntries) - 1 + if newIndexEntries[lastIdx].Mentioned { + return newIndexEntries } - newIdxs[lastIdx].To = len([]rune(text)) - 1 - newIdxs[lastIdx].Checked = false - return newIdxs + newIndexEntries[lastIdx].To = len([]rune(fullText)) - 1 + newIndexEntries[lastIdx].Checked = false + return newIndexEntries } return nil @@ -1129,7 +1143,7 @@ func toInfo(inputSegments []InputSegment) *MentionState { AtIdxs: []*AtIndexEntry{}, MentionEnd: 0, PreviousText: "", - NewText: &newText, + NewText: newText, Start: intUnknown, } @@ -1157,8 +1171,7 @@ func toInfo(inputSegments []InputSegment) *MentionState { } state.MentionEnd += len(tr) - nt := string(tr[len(tr)-1]) - state.NewText = &nt + state.NewText = string(tr[len(tr)-1]) state.Start += len(tr) state.End += len(tr) } @@ -1201,26 +1214,20 @@ func reverse(r []rune) string { return string(r) } +type textOperation int + +const ( + textOperationAdd textOperation = iota + 1 + textOperationDelete + textOperationReplace +) + type TextDiff struct { previousText string - newText string // we always set it to empty if it's a delete operation - start int - end int -} - -// hasCommonCharSequence checks if str1 has a common character sequence with str2. -// It iterates through both strings and compares their characters one by one. -// The function returns true if all characters in str1 can be found in str2 in the same order, but not necessarily consecutively. -// This is helpful for determining if there is an insertion or deletion operation between two strings. -func hasCommonCharSequence(str1, str2 []rune) bool { - i, j := 0, 0 - for i < len(str1) && j < len(str2) { - if str1[i] == str2[j] { - i++ - } - j++ - } - return i == len(str1) + newText string // if add operation, newText is the added text; if replace operation, newText is the text used to replace the previousText + start int // start index of the operation relate to previousText + end int // end index of the operation relate to previousText, always the same as start if the operation is add, range: start<=end<=len(previousText)-1 + operation textOperation } func diffText(oldText, newText string) *TextDiff { @@ -1232,10 +1239,10 @@ func diffText(oldText, newText string) *TextDiff { oldLen := len(t1) newLen := len(t2) if oldLen == 0 { - return &TextDiff{previousText: oldText, newText: newText, start: 0, end: 0} + return &TextDiff{previousText: oldText, newText: newText, start: 0, end: 0, operation: textOperationAdd} } if newLen == 0 { - return &TextDiff{previousText: oldText, newText: "", start: 0, end: oldLen} + return &TextDiff{previousText: oldText, newText: "", start: 0, end: oldLen, operation: textOperationReplace} } // if we reach here, t1 and t2 are not empty @@ -1251,16 +1258,21 @@ func diffText(oldText, newText string) *TextDiff { } diff := &TextDiff{previousText: oldText, start: start} - if hasCommonCharSequence(t1, t2) { // is just a insert operation + if newLen > oldLen && (start == oldLen || oldEnd == 0 || start == oldEnd) { + diff.operation = textOperationAdd diff.end = start diff.newText = string(t2[start:newEnd]) + } else if newLen < oldLen && (start == newLen || newEnd == 0 || start == newEnd) { + diff.operation = textOperationDelete + diff.end = oldEnd - 1 } else { - diff.end = newEnd - if oldLen > newLen { - diff.end = oldEnd - } - if !hasCommonCharSequence(t2, t1) { // is not a delete operation - diff.newText = string(t2[start:diff.end]) + diff.operation = textOperationReplace + if start == 0 && oldEnd == oldLen { // full replace + diff.end = oldLen - 1 + diff.newText = newText + } else { // partial replace + diff.end = oldEnd - 1 + diff.newText = string(t2[start:newEnd]) } } return diff diff --git a/protocol/messenger_mention_test.go b/protocol/messenger_mention_test.go index ba2cae408..7451bfb1d 100644 --- a/protocol/messenger_mention_test.go +++ b/protocol/messenger_mention_test.go @@ -252,12 +252,11 @@ func TestGetAtSignIdxs(t *testing.T) { } func TestCalcAtIdxs(t *testing.T) { - newText := "@abc" state := MentionState{ AtIdxs: []*AtIndexEntry{ {From: 0, To: 3, Checked: false}, }, - NewText: &newText, + NewText: "@abc", PreviousText: "", Start: 0, } @@ -288,7 +287,7 @@ func TestToInfo(t *testing.T) { }, MentionEnd: 19, PreviousText: "", - NewText: &newText, + NewText: newText, Start: 18, End: 18, } @@ -505,6 +504,9 @@ func TestLastIndexOf(t *testing.T) { //at-sign-idx 0 text @t searched-text t start 2 end 2 new-text atSignIdx = lastIndexOf("@t", charAtSign, 2) require.Equal(t, 0, atSignIdx) + + atSignIdx = lastIndexOf("at", charAtSign, 3) + require.Equal(t, -1, atSignIdx) } func TestDiffText(t *testing.T) { @@ -521,6 +523,7 @@ func TestDiffText(t *testing.T) { end: 0, previousText: "", newText: "A", + operation: textOperationAdd, }, }, { @@ -531,6 +534,7 @@ func TestDiffText(t *testing.T) { end: 1, previousText: "A", newText: "b", + operation: textOperationAdd, }, }, { @@ -541,16 +545,18 @@ func TestDiffText(t *testing.T) { end: 2, previousText: "Ab", newText: "c", + operation: textOperationAdd, }, }, { - oldText: "Abc", - newText: "Ac", + oldText: "Ab", + newText: "cAb", expected: &TextDiff{ - start: 1, - end: 2, - previousText: "Abc", - newText: "", + start: 0, + end: 0, + previousText: "Ab", + newText: "c", + operation: textOperationAdd, }, }, { @@ -561,6 +567,7 @@ func TestDiffText(t *testing.T) { end: 1, previousText: "Ac", newText: "d", + operation: textOperationAdd, }, }, { @@ -571,6 +578,7 @@ func TestDiffText(t *testing.T) { end: 2, previousText: "Adc", newText: " ee ", + operation: textOperationAdd, }, }, { @@ -581,6 +589,62 @@ func TestDiffText(t *testing.T) { end: 1, previousText: "Ad ee c", newText: " fff ", + operation: textOperationAdd, + }, + }, + { + oldText: "Abc", + newText: "Ac", + expected: &TextDiff{ + start: 1, + end: 1, + previousText: "Abc", + newText: "", + operation: textOperationDelete, + }, + }, + { + oldText: "Abcd", + newText: "Ab", + expected: &TextDiff{ + start: 2, + end: 3, + previousText: "Abcd", + newText: "", + operation: textOperationDelete, + }, + }, + { + oldText: "Abcd", + newText: "bcd", + expected: &TextDiff{ + start: 0, + end: 0, + previousText: "Abcd", + newText: "", + operation: textOperationDelete, + }, + }, + { + oldText: "Abcd你好", + newText: "Abcd你", + expected: &TextDiff{ + start: 5, + end: 5, + previousText: "Abcd你好", + newText: "", + operation: textOperationDelete, + }, + }, + { + oldText: "Abcd你好", + newText: "Abcd", + expected: &TextDiff{ + start: 4, + end: 5, + previousText: "Abcd你好", + newText: "", + operation: textOperationDelete, }, }, { @@ -588,9 +652,10 @@ func TestDiffText(t *testing.T) { newText: " fff d ee c", expected: &TextDiff{ start: 0, - end: 1, + end: 0, previousText: "A fff d ee c", newText: "", + operation: textOperationDelete, }, }, { @@ -598,9 +663,10 @@ func TestDiffText(t *testing.T) { newText: " fffee c", expected: &TextDiff{ start: 4, - end: 7, + end: 6, previousText: " fff d ee c", newText: "", + operation: textOperationDelete, }, }, { @@ -608,6 +674,94 @@ func TestDiffText(t *testing.T) { newText: "abc", expected: nil, }, + { + oldText: "abc", + newText: "ghij", + expected: &TextDiff{ + start: 0, + end: 2, + previousText: "abc", + newText: "ghij", + operation: textOperationReplace, + }, + }, + { + oldText: "abc", + newText: "babcd", + expected: &TextDiff{ + start: 0, + end: 2, + previousText: "abc", + newText: "babcd", + operation: textOperationReplace, + }, + }, + { + oldText: "abc", + newText: "baebcd", + expected: &TextDiff{ + start: 0, + end: 2, + previousText: "abc", + newText: "baebcd", + operation: textOperationReplace, + }, + }, + { + oldText: "abc", + newText: "aefc", + expected: &TextDiff{ + start: 1, + end: 1, + previousText: "abc", + newText: "ef", + operation: textOperationReplace, + }, + }, + { + oldText: "abc", + newText: "adc", + expected: &TextDiff{ + start: 1, + end: 1, + previousText: "abc", + newText: "d", + operation: textOperationReplace, + }, + }, + { + oldText: "abc", + newText: "abd", + expected: &TextDiff{ + start: 2, + end: 2, + previousText: "abc", + newText: "d", + operation: textOperationReplace, + }, + }, + { + oldText: "abc", + newText: "cbc", + expected: &TextDiff{ + start: 0, + end: 0, + previousText: "abc", + newText: "c", + operation: textOperationReplace, + }, + }, + { + oldText: "abc", + newText: "ffbc", + expected: &TextDiff{ + start: 0, + end: 0, + previousText: "abc", + newText: "ff", + operation: textOperationReplace, + }, + }, } for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i+1), func(t *testing.T) { @@ -641,16 +795,19 @@ func TestMentionSuggestionCases(t *testing.T) { {"@u2", 1}, {"@u23", 0}, {"@u2", 1}, + {"@u2 abc", 0}, + {"@u2 abc @u3", 1}, + {"@u2 abc@u3", 0}, + {"@u2 abc@u3 ", 0}, + {"@u2 abc @u3", 1}, } for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i+1), func(t *testing.T) { - _, err := mentionManager.OnChangeText(chatID, tc.inputText) + ctx, err := mentionManager.OnChangeText(chatID, tc.inputText) require.NoError(t, err) - ctx, err := mentionManager.CalculateSuggestions(chatID, tc.inputText) - require.NoError(t, err) - require.Equal(t, tc.expectedSize, len(ctx.MentionSuggestions)) t.Logf("Input: %+v, MentionState:%+v, InputSegments:%+v\n", tc.inputText, ctx.MentionState, ctx.InputSegments) + require.Equal(t, tc.expectedSize, len(ctx.MentionSuggestions)) }) } } @@ -691,24 +848,19 @@ func TestMentionSuggestionSpecialChars(t *testing.T) { mentionableUserMap, chatID, mentionManager := setupMentionSuggestionTest(t, nil) testCases := []struct { - inputText string - expectedSize int - calculateSuggestion bool + inputText string + expectedSize int }{ - {"'", 0, false}, - {"‘", 0, true}, - {"‘@", len(mentionableUserMap), true}, + {"'", 0}, + {"‘", 0}, + {"‘@", len(mentionableUserMap)}, } for _, tc := range testCases { ctx, err := mentionManager.OnChangeText(chatID, tc.inputText) require.NoError(t, err) - if tc.calculateSuggestion { - ctx, err = mentionManager.CalculateSuggestions(chatID, tc.inputText) - require.NoError(t, err) - require.Equal(t, tc.expectedSize, len(ctx.MentionSuggestions)) - } t.Logf("Input: %+v, MentionState:%+v, InputSegments:%+v\n", tc.inputText, ctx.MentionState, ctx.InputSegments) + require.Equal(t, tc.expectedSize, len(ctx.MentionSuggestions)) } } @@ -723,13 +875,12 @@ func TestMentionSuggestionAtSignSpaceCases(t *testing.T) { }) testCases := []struct { - inputText string - expectedSize int - calculateSuggestion bool + inputText string + expectedSize int }{ - {"@", len(mentionableUserMap), true}, - {"@ ", 0, true}, - {"@ @", len(mentionableUserMap), true}, + {"@", len(mentionableUserMap)}, + {"@ ", 0}, + {"@ @", len(mentionableUserMap)}, } var ctx *ChatMentionContext @@ -738,12 +889,7 @@ func TestMentionSuggestionAtSignSpaceCases(t *testing.T) { ctx, err = mentionManager.OnChangeText(chatID, tc.inputText) require.NoError(t, err) t.Logf("After OnChangeText, Input: %+v, MentionState:%+v, InputSegments:%+v\n", tc.inputText, ctx.MentionState, ctx.InputSegments) - if tc.calculateSuggestion { - ctx, err = mentionManager.CalculateSuggestions(chatID, tc.inputText) - require.NoError(t, err) - require.Equal(t, tc.expectedSize, len(ctx.MentionSuggestions)) - t.Logf("After CalculateSuggestions, Input: %+v, MentionState:%+v, InputSegments:%+v\n", tc.inputText, ctx.MentionState, ctx.InputSegments) - } + require.Equal(t, tc.expectedSize, len(ctx.MentionSuggestions)) } require.Len(t, ctx.InputSegments, 3) require.Equal(t, Mention, ctx.InputSegments[0].Type) @@ -754,6 +900,29 @@ func TestMentionSuggestionAtSignSpaceCases(t *testing.T) { require.Equal(t, "@", ctx.InputSegments[2].Value) } +func TestSelectMention(t *testing.T) { + _, chatID, mentionManager := setupMentionSuggestionTest(t, nil) + + text := "@u2 abc" + ctx, err := mentionManager.OnChangeText(chatID, text) + require.NoError(t, err) + require.Equal(t, 0, len(ctx.MentionSuggestions)) + + ctx, err = mentionManager.OnChangeText(chatID, "@u abc") + require.NoError(t, err) + require.Equal(t, 3, len(ctx.MentionSuggestions)) + + ctx, err = mentionManager.SelectMention(chatID, "@u abc", "u2", "0xpk2") + require.NoError(t, err) + require.Equal(t, 0, len(ctx.MentionSuggestions)) + require.Equal(t, text, ctx.NewText) + require.Equal(t, text, ctx.PreviousText) + + ctx, err = mentionManager.OnChangeText(chatID, text) + require.NoError(t, err) + require.Equal(t, 0, len(ctx.MentionSuggestions)) +} + func setupMentionSuggestionTest(t *testing.T, mentionableUserMapInput map[string]*MentionableUser) (map[string]*MentionableUser, string, *MentionManager) { mentionableUserMap := mentionableUserMapInput if mentionableUserMap == nil {