From 240051b8b8aeb40d96a868a35634f6ede1aa8b59 Mon Sep 17 00:00:00 2001 From: kaichao Date: Tue, 6 Aug 2024 21:05:47 +0800 Subject: [PATCH] chore: move outgoing message check from status-go to go-waku (#1180) --- waku/v2/api/publish/message_check.go | 248 ++++++++++++++++++++++ waku/v2/api/publish/message_check_test.go | 33 +++ 2 files changed, 281 insertions(+) create mode 100644 waku/v2/api/publish/message_check.go create mode 100644 waku/v2/api/publish/message_check_test.go diff --git a/waku/v2/api/publish/message_check.go b/waku/v2/api/publish/message_check.go new file mode 100644 index 00000000..ef5148bd --- /dev/null +++ b/waku/v2/api/publish/message_check.go @@ -0,0 +1,248 @@ +package publish + +import ( + "bytes" + "context" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/waku-org/go-waku/waku/v2/protocol" + "github.com/waku-org/go-waku/waku/v2/protocol/pb" + "github.com/waku-org/go-waku/waku/v2/protocol/store" + "github.com/waku-org/go-waku/waku/v2/timesource" + "go.uber.org/zap" +) + +const DefaultMaxHashQueryLength = 100 +const DefaultHashQueryInterval = 3 * time.Second +const DefaultMessageSentPeriod = 3 // in seconds +const DefaultMessageExpiredPerid = 10 // in seconds + +type MessageSentCheckOption func(*MessageSentCheck) error + +// MessageSentCheck tracks the outgoing messages and check against store node +// if the message sent time has passed the `messageSentPeriod`, the message id will be includes for the next query +// if the message keeps missing after `messageExpiredPerid`, the message id will be expired +type MessageSentCheck struct { + messageIDs map[string]map[common.Hash]uint32 + messageIDsMu sync.RWMutex + storePeerID peer.ID + MessageStoredChan chan common.Hash + MessageExpiredChan chan common.Hash + ctx context.Context + store *store.WakuStore + timesource timesource.Timesource + logger *zap.Logger + maxHashQueryLength uint64 + hashQueryInterval time.Duration + messageSentPeriod uint32 + messageExpiredPerid uint32 +} + +// NewMessageSentCheck creates a new instance of MessageSentCheck with default parameters +func NewMessageSentCheck(ctx context.Context, store *store.WakuStore, timesource timesource.Timesource, logger *zap.Logger) *MessageSentCheck { + return &MessageSentCheck{ + messageIDs: make(map[string]map[common.Hash]uint32), + messageIDsMu: sync.RWMutex{}, + MessageStoredChan: make(chan common.Hash, 1000), + MessageExpiredChan: make(chan common.Hash, 1000), + ctx: ctx, + store: store, + timesource: timesource, + logger: logger, + maxHashQueryLength: DefaultMaxHashQueryLength, + hashQueryInterval: DefaultHashQueryInterval, + messageSentPeriod: DefaultMessageSentPeriod, + messageExpiredPerid: DefaultMessageExpiredPerid, + } +} + +// WithMaxHashQueryLength sets the maximum number of message hashes to query in one request +func WithMaxHashQueryLength(count uint64) MessageSentCheckOption { + return func(params *MessageSentCheck) error { + params.maxHashQueryLength = count + return nil + } +} + +// WithHashQueryInterval sets the interval to query the store node +func WithHashQueryInterval(interval time.Duration) MessageSentCheckOption { + return func(params *MessageSentCheck) error { + params.hashQueryInterval = interval + return nil + } +} + +// WithMessageSentPeriod sets the delay period to query the store node after message is published +func WithMessageSentPeriod(period uint32) MessageSentCheckOption { + return func(params *MessageSentCheck) error { + params.messageSentPeriod = period + return nil + } +} + +// WithMessageExpiredPerid sets the period that a message is considered expired +func WithMessageExpiredPerid(period uint32) MessageSentCheckOption { + return func(params *MessageSentCheck) error { + params.messageExpiredPerid = period + return nil + } +} + +// Add adds a message for message sent check +func (m *MessageSentCheck) Add(topic string, messageID common.Hash, sentTime uint32) { + m.messageIDsMu.Lock() + defer m.messageIDsMu.Unlock() + + if _, ok := m.messageIDs[topic]; !ok { + m.messageIDs[topic] = make(map[common.Hash]uint32) + } + m.messageIDs[topic][messageID] = sentTime +} + +// DeleteByMessageIDs deletes the message ids from the message sent check, used by scenarios like message acked with MVDS +func (m *MessageSentCheck) DeleteByMessageIDs(messageIDs []common.Hash) { + m.messageIDsMu.Lock() + defer m.messageIDsMu.Unlock() + + for pubsubTopic, subMsgs := range m.messageIDs { + for _, hash := range messageIDs { + delete(subMsgs, hash) + if len(subMsgs) == 0 { + delete(m.messageIDs, pubsubTopic) + } else { + m.messageIDs[pubsubTopic] = subMsgs + } + } + } +} + +// SetStorePeerID sets the peer id of store node +func (m *MessageSentCheck) SetStorePeerID(peerID peer.ID) { + m.storePeerID = peerID +} + +// CheckIfMessagesStored checks if the tracked outgoing messages are stored periodically +func (m *MessageSentCheck) CheckIfMessagesStored() { + ticker := time.NewTicker(m.hashQueryInterval) + defer ticker.Stop() + for { + select { + case <-m.ctx.Done(): + m.logger.Debug("stop the look for message stored check") + return + case <-ticker.C: + m.messageIDsMu.Lock() + m.logger.Debug("running loop for messages stored check", zap.Any("messageIds", m.messageIDs)) + pubsubTopics := make([]string, 0, len(m.messageIDs)) + pubsubMessageIds := make([][]common.Hash, 0, len(m.messageIDs)) + pubsubMessageTime := make([][]uint32, 0, len(m.messageIDs)) + for pubsubTopic, subMsgs := range m.messageIDs { + var queryMsgIds []common.Hash + var queryMsgTime []uint32 + for msgID, sendTime := range subMsgs { + if uint64(len(queryMsgIds)) >= m.maxHashQueryLength { + break + } + // message is sent 5 seconds ago, check if it's stored + if uint32(m.timesource.Now().Unix()) > sendTime+m.messageSentPeriod { + queryMsgIds = append(queryMsgIds, msgID) + queryMsgTime = append(queryMsgTime, sendTime) + } + } + m.logger.Debug("store query for message hashes", zap.Any("queryMsgIds", queryMsgIds), zap.String("pubsubTopic", pubsubTopic)) + if len(queryMsgIds) > 0 { + pubsubTopics = append(pubsubTopics, pubsubTopic) + pubsubMessageIds = append(pubsubMessageIds, queryMsgIds) + pubsubMessageTime = append(pubsubMessageTime, queryMsgTime) + } + } + m.messageIDsMu.Unlock() + + pubsubProcessedMessages := make([][]common.Hash, len(pubsubTopics)) + for i, pubsubTopic := range pubsubTopics { + processedMessages := m.messageHashBasedQuery(m.ctx, pubsubMessageIds[i], pubsubMessageTime[i], pubsubTopic) + pubsubProcessedMessages[i] = processedMessages + } + + m.messageIDsMu.Lock() + for i, pubsubTopic := range pubsubTopics { + subMsgs, ok := m.messageIDs[pubsubTopic] + if !ok { + continue + } + for _, hash := range pubsubProcessedMessages[i] { + delete(subMsgs, hash) + if len(subMsgs) == 0 { + delete(m.messageIDs, pubsubTopic) + } else { + m.messageIDs[pubsubTopic] = subMsgs + } + } + } + m.logger.Debug("messages for next store hash query", zap.Any("messageIds", m.messageIDs)) + m.messageIDsMu.Unlock() + + } + } +} + +func (m *MessageSentCheck) messageHashBasedQuery(ctx context.Context, hashes []common.Hash, relayTime []uint32, pubsubTopic string) []common.Hash { + selectedPeer := m.storePeerID + if selectedPeer == "" { + m.logger.Error("no store peer id available", zap.String("pubsubTopic", pubsubTopic)) + return []common.Hash{} + } + + var opts []store.RequestOption + requestID := protocol.GenerateRequestID() + opts = append(opts, store.WithRequestID(requestID)) + opts = append(opts, store.WithPeer(selectedPeer)) + opts = append(opts, store.WithPaging(false, m.maxHashQueryLength)) + opts = append(opts, store.IncludeData(false)) + + messageHashes := make([]pb.MessageHash, len(hashes)) + for i, hash := range hashes { + messageHashes[i] = pb.ToMessageHash(hash.Bytes()) + } + + m.logger.Debug("store.queryByHash request", zap.String("requestID", hexutil.Encode(requestID)), zap.Stringer("peerID", selectedPeer), zap.Any("messageHashes", messageHashes)) + + result, err := m.store.QueryByHash(ctx, messageHashes, opts...) + if err != nil { + m.logger.Error("store.queryByHash failed", zap.String("requestID", hexutil.Encode(requestID)), zap.Stringer("peerID", selectedPeer), zap.Error(err)) + return []common.Hash{} + } + + m.logger.Debug("store.queryByHash result", zap.String("requestID", hexutil.Encode(requestID)), zap.Int("messages", len(result.Messages()))) + + var ackHashes []common.Hash + var missedHashes []common.Hash + for i, hash := range hashes { + found := false + for _, msg := range result.Messages() { + if bytes.Equal(msg.GetMessageHash(), hash.Bytes()) { + found = true + break + } + } + + if found { + ackHashes = append(ackHashes, hash) + m.MessageStoredChan <- hash + } + + if !found && uint32(m.timesource.Now().Unix()) > relayTime[i]+m.messageExpiredPerid { + missedHashes = append(missedHashes, hash) + m.MessageExpiredChan <- hash + } + } + + m.logger.Debug("ack message hashes", zap.Any("ackHashes", ackHashes)) + m.logger.Debug("missed message hashes", zap.Any("missedHashes", missedHashes)) + + return append(ackHashes, missedHashes...) +} diff --git a/waku/v2/api/publish/message_check_test.go b/waku/v2/api/publish/message_check_test.go new file mode 100644 index 00000000..12947258 --- /dev/null +++ b/waku/v2/api/publish/message_check_test.go @@ -0,0 +1,33 @@ +package publish + +import ( + "context" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" +) + +func TestAddAndDelete(t *testing.T) { + ctx := context.TODO() + messageSentCheck := NewMessageSentCheck(ctx, nil, nil, nil) + + messageSentCheck.Add("topic", [32]byte{1}, 1) + messageSentCheck.Add("topic", [32]byte{2}, 2) + messageSentCheck.Add("topic", [32]byte{3}, 3) + messageSentCheck.Add("another-topic", [32]byte{4}, 4) + + require.Equal(t, uint32(1), messageSentCheck.messageIDs["topic"][[32]byte{1}]) + require.Equal(t, uint32(2), messageSentCheck.messageIDs["topic"][[32]byte{2}]) + require.Equal(t, uint32(3), messageSentCheck.messageIDs["topic"][[32]byte{3}]) + require.Equal(t, uint32(4), messageSentCheck.messageIDs["another-topic"][[32]byte{4}]) + + messageSentCheck.DeleteByMessageIDs([]common.Hash{[32]byte{1}, [32]byte{2}}) + require.NotNil(t, messageSentCheck.messageIDs["topic"]) + require.Equal(t, uint32(3), messageSentCheck.messageIDs["topic"][[32]byte{3}]) + + messageSentCheck.DeleteByMessageIDs([]common.Hash{[32]byte{3}}) + require.Nil(t, messageSentCheck.messageIDs["topic"]) + + require.Equal(t, uint32(4), messageSentCheck.messageIDs["another-topic"][[32]byte{4}]) +}