diff --git a/protocol/messenger_builder_test.go b/protocol/messenger_builder_test.go index 1db3c9e79..08d665787 100644 --- a/protocol/messenger_builder_test.go +++ b/protocol/messenger_builder_test.go @@ -25,7 +25,9 @@ type testMessengerConfig struct { logger *zap.Logger unhandledMessagesTracker *unhandledMessagesTracker - extraOptions []Option + messagesOrderController *MessagesOrderController + + extraOptions []Option } func (tmc *testMessengerConfig) complete() error { @@ -99,6 +101,10 @@ func newTestMessenger(waku types.Waku, config testMessengerConfig) (*Messenger, m.unhandledMessagesTracker = config.unhandledMessagesTracker.addMessage } + if config.messagesOrderController != nil { + m.retrievedMessagesIteratorFactory = config.messagesOrderController.newMessagesIterator + } + err = m.Init() if err != nil { return nil, err diff --git a/protocol/messenger_messages_order_controller_test.go b/protocol/messenger_messages_order_controller_test.go new file mode 100644 index 000000000..5652ceaaa --- /dev/null +++ b/protocol/messenger_messages_order_controller_test.go @@ -0,0 +1,130 @@ +package protocol + +import ( + "sort" + "sync" + + "github.com/status-im/status-go/eth-node/types" + "github.com/status-im/status-go/protocol/transport" +) + +type messagesOrderType int + +const ( + messagesOrderRandom messagesOrderType = iota + messagesOrderAsPosted + messagesOrderReversed +) + +type MessagesOrderController struct { + order messagesOrderType + messagesInPostOrder [][]byte + mutex sync.RWMutex + quit chan struct{} + quitOnce sync.Once +} + +func NewMessagesOrderController(order messagesOrderType) *MessagesOrderController { + return &MessagesOrderController{ + order: order, + messagesInPostOrder: [][]byte{}, + mutex: sync.RWMutex{}, + quit: make(chan struct{}), + } +} + +func (m *MessagesOrderController) Start(c chan *PostMessageSubscription) { + go func() { + for { + select { + case sub, more := <-c: + if !more { + return + } + m.mutex.Lock() + m.messagesInPostOrder = append(m.messagesInPostOrder, sub.id) + m.mutex.Unlock() + + case <-m.quit: + return + } + } + }() +} + +func (m *MessagesOrderController) Stop() { + m.quitOnce.Do(func() { + close(m.quit) + }) +} + +func (m *MessagesOrderController) newMessagesIterator(chatWithMessages map[transport.Filter][]*types.Message) MessagesIterator { + switch m.order { + case messagesOrderAsPosted, messagesOrderReversed: + return &messagesIterator{chatWithMessages: m.sort(chatWithMessages, m.order)} + } + + return NewDefaultMessagesIterator(chatWithMessages) +} + +func buildIndexMap(messages [][]byte) map[string]int { + indexMap := make(map[string]int) + for i, hash := range messages { + hashStr := string(hash) + indexMap[hashStr] = i + } + return indexMap +} + +func (m *MessagesOrderController) sort(chatWithMessages map[transport.Filter][]*types.Message, order messagesOrderType) []*chatWithMessage { + allMessages := make([]*chatWithMessage, 0) + for chat, messages := range chatWithMessages { + for _, message := range messages { + allMessages = append(allMessages, &chatWithMessage{chat: chat, message: message}) + } + } + + m.mutex.RLock() + indexMap := buildIndexMap(m.messagesInPostOrder) + m.mutex.RUnlock() + + sort.SliceStable(allMessages, func(i, j int) bool { + indexI, okI := indexMap[string(allMessages[i].message.Hash)] + indexJ, okJ := indexMap[string(allMessages[j].message.Hash)] + + if okI && okJ { + if order == messagesOrderReversed { + return indexI > indexJ + } + return indexI < indexJ + } + + return !okI && okJ // keep messages with unknown hashes at the end + }) + + return allMessages +} + +type chatWithMessage struct { + chat transport.Filter + message *types.Message +} + +type messagesIterator struct { + chatWithMessages []*chatWithMessage + currentIndex int +} + +func (it *messagesIterator) HasNext() bool { + return it.currentIndex < len(it.chatWithMessages) +} + +func (it *messagesIterator) Next() (transport.Filter, []*types.Message) { + if it.HasNext() { + m := it.chatWithMessages[it.currentIndex] + it.currentIndex++ + return m.chat, []*types.Message{m.message} + } + + return transport.Filter{}, nil +}