mirror of https://github.com/status-im/go-waku.git
249 lines
8.4 KiB
Go
249 lines
8.4 KiB
Go
package publish
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ethereum/go-ethereum/common"
|
|
"github.com/ethereum/go-ethereum/common/hexutil"
|
|
"github.com/libp2p/go-libp2p/core/peer"
|
|
"github.com/waku-org/go-waku/waku/v2/protocol"
|
|
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
|
|
"github.com/waku-org/go-waku/waku/v2/protocol/store"
|
|
"github.com/waku-org/go-waku/waku/v2/timesource"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Default tuning values applied by NewMessageSentCheck.
const (
	// DefaultMaxHashQueryLength is the default upper bound on the number of
	// message hashes collected per pubsub topic for a single store query.
	DefaultMaxHashQueryLength = 100

	// DefaultHashQueryInterval is the default pause between two rounds of
	// store queries.
	DefaultHashQueryInterval = 3 * time.Second

	// DefaultMessageSentPeriod is the default number of seconds that must
	// elapse after a message was sent before it is included in a store query.
	DefaultMessageSentPeriod = 3 // in seconds

	// DefaultMessageExpiredPerid is the default number of seconds after which
	// a message that is still missing from the store is reported as expired.
	// The misspelling ("Perid") is part of the exported name.
	DefaultMessageExpiredPerid = 10 // in seconds
)
|
|
|
|
// MessageSentCheckOption is a functional option that adjusts a
// MessageSentCheck in place. The options defined in this file only assign a
// single field and always return a nil error.
type MessageSentCheckOption func(*MessageSentCheck) error
|
|
|
|
// MessageSentCheck tracks the outgoing messages and check against store node
|
|
// if the message sent time has passed the `messageSentPeriod`, the message id will be includes for the next query
|
|
// if the message keeps missing after `messageExpiredPerid`, the message id will be expired
|
|
type MessageSentCheck struct {
|
|
messageIDs map[string]map[common.Hash]uint32
|
|
messageIDsMu sync.RWMutex
|
|
storePeerID peer.ID
|
|
MessageStoredChan chan common.Hash
|
|
MessageExpiredChan chan common.Hash
|
|
ctx context.Context
|
|
store *store.WakuStore
|
|
timesource timesource.Timesource
|
|
logger *zap.Logger
|
|
maxHashQueryLength uint64
|
|
hashQueryInterval time.Duration
|
|
messageSentPeriod uint32
|
|
messageExpiredPerid uint32
|
|
}
|
|
|
|
// NewMessageSentCheck creates a new instance of MessageSentCheck with default parameters
|
|
func NewMessageSentCheck(ctx context.Context, store *store.WakuStore, timesource timesource.Timesource, logger *zap.Logger) *MessageSentCheck {
|
|
return &MessageSentCheck{
|
|
messageIDs: make(map[string]map[common.Hash]uint32),
|
|
messageIDsMu: sync.RWMutex{},
|
|
MessageStoredChan: make(chan common.Hash, 1000),
|
|
MessageExpiredChan: make(chan common.Hash, 1000),
|
|
ctx: ctx,
|
|
store: store,
|
|
timesource: timesource,
|
|
logger: logger,
|
|
maxHashQueryLength: DefaultMaxHashQueryLength,
|
|
hashQueryInterval: DefaultHashQueryInterval,
|
|
messageSentPeriod: DefaultMessageSentPeriod,
|
|
messageExpiredPerid: DefaultMessageExpiredPerid,
|
|
}
|
|
}
|
|
|
|
// WithMaxHashQueryLength sets the maximum number of message hashes to query in one request
|
|
func WithMaxHashQueryLength(count uint64) MessageSentCheckOption {
|
|
return func(params *MessageSentCheck) error {
|
|
params.maxHashQueryLength = count
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// WithHashQueryInterval sets the interval to query the store node
|
|
func WithHashQueryInterval(interval time.Duration) MessageSentCheckOption {
|
|
return func(params *MessageSentCheck) error {
|
|
params.hashQueryInterval = interval
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// WithMessageSentPeriod sets the delay period to query the store node after message is published
|
|
func WithMessageSentPeriod(period uint32) MessageSentCheckOption {
|
|
return func(params *MessageSentCheck) error {
|
|
params.messageSentPeriod = period
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// WithMessageExpiredPerid sets the period that a message is considered expired
|
|
func WithMessageExpiredPerid(period uint32) MessageSentCheckOption {
|
|
return func(params *MessageSentCheck) error {
|
|
params.messageExpiredPerid = period
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Add adds a message for message sent check
|
|
func (m *MessageSentCheck) Add(topic string, messageID common.Hash, sentTime uint32) {
|
|
m.messageIDsMu.Lock()
|
|
defer m.messageIDsMu.Unlock()
|
|
|
|
if _, ok := m.messageIDs[topic]; !ok {
|
|
m.messageIDs[topic] = make(map[common.Hash]uint32)
|
|
}
|
|
m.messageIDs[topic][messageID] = sentTime
|
|
}
|
|
|
|
// DeleteByMessageIDs deletes the message ids from the message sent check, used by scenarios like message acked with MVDS
|
|
func (m *MessageSentCheck) DeleteByMessageIDs(messageIDs []common.Hash) {
|
|
m.messageIDsMu.Lock()
|
|
defer m.messageIDsMu.Unlock()
|
|
|
|
for pubsubTopic, subMsgs := range m.messageIDs {
|
|
for _, hash := range messageIDs {
|
|
delete(subMsgs, hash)
|
|
if len(subMsgs) == 0 {
|
|
delete(m.messageIDs, pubsubTopic)
|
|
} else {
|
|
m.messageIDs[pubsubTopic] = subMsgs
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SetStorePeerID sets the peer id of store node
|
|
func (m *MessageSentCheck) SetStorePeerID(peerID peer.ID) {
|
|
m.storePeerID = peerID
|
|
}
|
|
|
|
// CheckIfMessagesStored checks if the tracked outgoing messages are stored periodically
|
|
func (m *MessageSentCheck) CheckIfMessagesStored() {
|
|
ticker := time.NewTicker(m.hashQueryInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-m.ctx.Done():
|
|
m.logger.Debug("stop the look for message stored check")
|
|
return
|
|
case <-ticker.C:
|
|
m.messageIDsMu.Lock()
|
|
m.logger.Debug("running loop for messages stored check", zap.Any("messageIds", m.messageIDs))
|
|
pubsubTopics := make([]string, 0, len(m.messageIDs))
|
|
pubsubMessageIds := make([][]common.Hash, 0, len(m.messageIDs))
|
|
pubsubMessageTime := make([][]uint32, 0, len(m.messageIDs))
|
|
for pubsubTopic, subMsgs := range m.messageIDs {
|
|
var queryMsgIds []common.Hash
|
|
var queryMsgTime []uint32
|
|
for msgID, sendTime := range subMsgs {
|
|
if uint64(len(queryMsgIds)) >= m.maxHashQueryLength {
|
|
break
|
|
}
|
|
// message is sent 5 seconds ago, check if it's stored
|
|
if uint32(m.timesource.Now().Unix()) > sendTime+m.messageSentPeriod {
|
|
queryMsgIds = append(queryMsgIds, msgID)
|
|
queryMsgTime = append(queryMsgTime, sendTime)
|
|
}
|
|
}
|
|
m.logger.Debug("store query for message hashes", zap.Any("queryMsgIds", queryMsgIds), zap.String("pubsubTopic", pubsubTopic))
|
|
if len(queryMsgIds) > 0 {
|
|
pubsubTopics = append(pubsubTopics, pubsubTopic)
|
|
pubsubMessageIds = append(pubsubMessageIds, queryMsgIds)
|
|
pubsubMessageTime = append(pubsubMessageTime, queryMsgTime)
|
|
}
|
|
}
|
|
m.messageIDsMu.Unlock()
|
|
|
|
pubsubProcessedMessages := make([][]common.Hash, len(pubsubTopics))
|
|
for i, pubsubTopic := range pubsubTopics {
|
|
processedMessages := m.messageHashBasedQuery(m.ctx, pubsubMessageIds[i], pubsubMessageTime[i], pubsubTopic)
|
|
pubsubProcessedMessages[i] = processedMessages
|
|
}
|
|
|
|
m.messageIDsMu.Lock()
|
|
for i, pubsubTopic := range pubsubTopics {
|
|
subMsgs, ok := m.messageIDs[pubsubTopic]
|
|
if !ok {
|
|
continue
|
|
}
|
|
for _, hash := range pubsubProcessedMessages[i] {
|
|
delete(subMsgs, hash)
|
|
if len(subMsgs) == 0 {
|
|
delete(m.messageIDs, pubsubTopic)
|
|
} else {
|
|
m.messageIDs[pubsubTopic] = subMsgs
|
|
}
|
|
}
|
|
}
|
|
m.logger.Debug("messages for next store hash query", zap.Any("messageIds", m.messageIDs))
|
|
m.messageIDsMu.Unlock()
|
|
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *MessageSentCheck) messageHashBasedQuery(ctx context.Context, hashes []common.Hash, relayTime []uint32, pubsubTopic string) []common.Hash {
|
|
selectedPeer := m.storePeerID
|
|
if selectedPeer == "" {
|
|
m.logger.Error("no store peer id available", zap.String("pubsubTopic", pubsubTopic))
|
|
return []common.Hash{}
|
|
}
|
|
|
|
var opts []store.RequestOption
|
|
requestID := protocol.GenerateRequestID()
|
|
opts = append(opts, store.WithRequestID(requestID))
|
|
opts = append(opts, store.WithPeer(selectedPeer))
|
|
opts = append(opts, store.WithPaging(false, m.maxHashQueryLength))
|
|
opts = append(opts, store.IncludeData(false))
|
|
|
|
messageHashes := make([]pb.MessageHash, len(hashes))
|
|
for i, hash := range hashes {
|
|
messageHashes[i] = pb.ToMessageHash(hash.Bytes())
|
|
}
|
|
|
|
m.logger.Debug("store.queryByHash request", zap.String("requestID", hexutil.Encode(requestID)), zap.Stringer("peerID", selectedPeer), zap.Any("messageHashes", messageHashes))
|
|
|
|
result, err := m.store.QueryByHash(ctx, messageHashes, opts...)
|
|
if err != nil {
|
|
m.logger.Error("store.queryByHash failed", zap.String("requestID", hexutil.Encode(requestID)), zap.Stringer("peerID", selectedPeer), zap.Error(err))
|
|
return []common.Hash{}
|
|
}
|
|
|
|
m.logger.Debug("store.queryByHash result", zap.String("requestID", hexutil.Encode(requestID)), zap.Int("messages", len(result.Messages())))
|
|
|
|
var ackHashes []common.Hash
|
|
var missedHashes []common.Hash
|
|
for i, hash := range hashes {
|
|
found := false
|
|
for _, msg := range result.Messages() {
|
|
if bytes.Equal(msg.GetMessageHash(), hash.Bytes()) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if found {
|
|
ackHashes = append(ackHashes, hash)
|
|
m.MessageStoredChan <- hash
|
|
}
|
|
|
|
if !found && uint32(m.timesource.Now().Unix()) > relayTime[i]+m.messageExpiredPerid {
|
|
missedHashes = append(missedHashes, hash)
|
|
m.MessageExpiredChan <- hash
|
|
}
|
|
}
|
|
|
|
m.logger.Debug("ack message hashes", zap.Any("ackHashes", ackHashes))
|
|
m.logger.Debug("missed message hashes", zap.Any("missedHashes", missedHashes))
|
|
|
|
return append(ackHashes, missedHashes...)
|
|
}
|