From 92d62a7c381648da7fd3ce67ea1fabc2231c9725 Mon Sep 17 00:00:00 2001 From: kaichao Date: Sat, 10 Aug 2024 20:05:51 +0800 Subject: [PATCH] chore: refactor sender api (#1187) --- waku/v2/api/publish/message_check.go | 21 ++- waku/v2/api/publish/message_check_test.go | 2 +- waku/v2/api/publish/message_sender.go | 170 +++++++++++++++++++++ waku/v2/api/publish/message_sender_test.go | 123 +++++++++++++++ waku/v2/api/publish/rate_limiting.go | 15 +- 5 files changed, 319 insertions(+), 12 deletions(-) create mode 100644 waku/v2/api/publish/message_sender.go create mode 100644 waku/v2/api/publish/message_sender_test.go diff --git a/waku/v2/api/publish/message_check.go b/waku/v2/api/publish/message_check.go index a7b16a57..a60a8d91 100644 --- a/waku/v2/api/publish/message_check.go +++ b/waku/v2/api/publish/message_check.go @@ -23,6 +23,13 @@ const DefaultMessageExpiredPerid = 10 // in seconds type MessageSentCheckOption func(*MessageSentCheck) error +type ISentCheck interface { + Start() + Add(topic string, messageID common.Hash, sentTime uint32) + DeleteByMessageIDs(messageIDs []common.Hash) + SetStorePeerID(peerID peer.ID) +} + // MessageSentCheck tracks the outgoing messages and check against store node // if the message sent time has passed the `messageSentPeriod`, the message id will be includes for the next query // if the message keeps missing after `messageExpiredPerid`, the message id will be expired @@ -30,8 +37,8 @@ type MessageSentCheck struct { messageIDs map[string]map[common.Hash]uint32 messageIDsMu sync.RWMutex storePeerID peer.ID - MessageStoredChan chan common.Hash - MessageExpiredChan chan common.Hash + messageStoredChan chan common.Hash + messageExpiredChan chan common.Hash ctx context.Context store *store.WakuStore timesource timesource.Timesource @@ -43,12 +50,12 @@ type MessageSentCheck struct { } // NewMessageSentCheck creates a new instance of MessageSentCheck with default parameters -func NewMessageSentCheck(ctx context.Context, store *store.WakuStore, timesource timesource.Timesource, logger *zap.Logger) *MessageSentCheck { +func NewMessageSentCheck(ctx context.Context, store *store.WakuStore, timesource timesource.Timesource, msgStoredChan chan common.Hash, msgExpiredChan chan common.Hash, logger *zap.Logger) *MessageSentCheck { return &MessageSentCheck{ messageIDs: make(map[string]map[common.Hash]uint32), messageIDsMu: sync.RWMutex{}, - MessageStoredChan: make(chan common.Hash, 1000), - MessageExpiredChan: make(chan common.Hash, 1000), + messageStoredChan: msgStoredChan, + messageExpiredChan: msgExpiredChan, ctx: ctx, store: store, timesource: timesource, @@ -232,12 +239,12 @@ func (m *MessageSentCheck) messageHashBasedQuery(ctx context.Context, hashes []c if found { ackHashes = append(ackHashes, hash) - m.MessageStoredChan <- hash + m.messageStoredChan <- hash } if !found && uint32(m.timesource.Now().Unix()) > relayTime[i]+m.messageExpiredPerid { missedHashes = append(missedHashes, hash) - m.MessageExpiredChan <- hash + m.messageExpiredChan <- hash } } diff --git a/waku/v2/api/publish/message_check_test.go b/waku/v2/api/publish/message_check_test.go index 12947258..ef53f4d3 100644 --- a/waku/v2/api/publish/message_check_test.go +++ b/waku/v2/api/publish/message_check_test.go @@ -10,7 +10,7 @@ import ( func TestAddAndDelete(t *testing.T) { ctx := context.TODO() - messageSentCheck := NewMessageSentCheck(ctx, nil, nil, nil) + messageSentCheck := NewMessageSentCheck(ctx, nil, nil, nil, nil, nil) messageSentCheck.Add("topic", [32]byte{1}, 1) messageSentCheck.Add("topic", [32]byte{2}, 2) diff --git a/waku/v2/api/publish/message_sender.go b/waku/v2/api/publish/message_sender.go new file mode 100644 index 00000000..479d894a --- /dev/null +++ b/waku/v2/api/publish/message_sender.go @@ -0,0 +1,170 @@ +package publish + +import ( + "context" + "errors" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/waku-org/go-waku/waku/v2/protocol" + "github.com/waku-org/go-waku/waku/v2/protocol/lightpush" + "github.com/waku-org/go-waku/waku/v2/protocol/relay" + "go.uber.org/zap" + "golang.org/x/time/rate" +) + +const DefaultPeersToPublishForLightpush = 2 +const DefaultPublishingLimiterRate = rate.Limit(2) +const DefaultPublishingLimitBurst = 4 + +type PublishMethod int + +const ( + LightPush PublishMethod = iota + Relay + UnknownMethod +) + +func (pm PublishMethod) String() string { + switch pm { + case LightPush: + return "LightPush" + case Relay: + return "Relay" + default: + return "Unknown" + } +} + +type MessageSender struct { + publishMethod PublishMethod + lightPush *lightpush.WakuLightPush + relay *relay.WakuRelay + messageSentCheck ISentCheck + rateLimiter *PublishRateLimiter + logger *zap.Logger +} + +type Request struct { + ctx context.Context + envelope *protocol.Envelope + publishMethod PublishMethod +} + +func NewRequest(ctx context.Context, envelope *protocol.Envelope) *Request { + return &Request{ + ctx: ctx, + envelope: envelope, + publishMethod: UnknownMethod, + } +} + +func (r *Request) WithPublishMethod(publishMethod PublishMethod) *Request { + r.publishMethod = publishMethod + return r +} + +func NewMessageSender(publishMethod PublishMethod, lightPush *lightpush.WakuLightPush, relay *relay.WakuRelay, logger *zap.Logger) (*MessageSender, error) { + if publishMethod == UnknownMethod { + return nil, errors.New("publish method is required") + } + return &MessageSender{ + publishMethod: publishMethod, + lightPush: lightPush, + relay: relay, + rateLimiter: NewPublishRateLimiter(DefaultPublishingLimiterRate, DefaultPublishingLimitBurst), + logger: logger, + }, nil +} + +func (ms *MessageSender) WithMessageSentCheck(messageSentCheck ISentCheck) *MessageSender { + ms.messageSentCheck = messageSentCheck + return ms +} + +func (ms *MessageSender) WithRateLimiting(rateLimiter *PublishRateLimiter) *MessageSender { + ms.rateLimiter = rateLimiter + return ms +} + +func (ms *MessageSender) Send(req *Request) error { + logger := ms.logger.With( + zap.Stringer("envelopeHash", req.envelope.Hash()), + zap.String("pubsubTopic", req.envelope.PubsubTopic()), + zap.String("contentTopic", req.envelope.Message().ContentTopic), + zap.Int64("timestamp", req.envelope.Message().GetTimestamp()), + ) + + if ms.rateLimiter != nil { + if err := ms.rateLimiter.Check(req.ctx, logger); err != nil { + return err + } + } + + publishMethod := req.publishMethod + if publishMethod == UnknownMethod { + publishMethod = ms.publishMethod + } + + switch publishMethod { + case LightPush: + if ms.lightPush == nil { + return errors.New("lightpush is not available") + } + logger.Info("publishing message via lightpush") + _, err := ms.lightPush.Publish( + req.ctx, + req.envelope.Message(), + lightpush.WithPubSubTopic(req.envelope.PubsubTopic()), + lightpush.WithMaxPeers(DefaultPeersToPublishForLightpush), + ) + if err != nil { + return err + } + case Relay: + if ms.relay == nil { + return errors.New("relay is not available") + } + peerCnt := len(ms.relay.PubSub().ListPeers(req.envelope.PubsubTopic())) + logger.Info("publishing message via relay", zap.Int("peerCnt", peerCnt)) + _, err := ms.relay.Publish(req.ctx, req.envelope.Message(), relay.WithPubSubTopic(req.envelope.PubsubTopic())) + if err != nil { + return err + } + default: + return errors.New("unknown publish method") + } + + if ms.messageSentCheck != nil && !req.envelope.Message().GetEphemeral() { + ms.messageSentCheck.Add( + req.envelope.PubsubTopic(), + common.BytesToHash(req.envelope.Hash().Bytes()), + uint32(req.envelope.Message().GetTimestamp()/int64(time.Second)), + ) + } + + return nil +} + +func (ms *MessageSender) Start() { + if ms.messageSentCheck != nil { + go ms.messageSentCheck.Start() + } +} + +func (ms *MessageSender) PublishMethod() PublishMethod { + return ms.publishMethod +} + +func (ms *MessageSender) MessagesDelivered(messageIDs []common.Hash) { + if ms.messageSentCheck != nil { + ms.messageSentCheck.DeleteByMessageIDs(messageIDs) + } +} + +func (ms *MessageSender) SetStorePeerID(peerID peer.ID) { + if ms.messageSentCheck != nil { + ms.messageSentCheck.SetStorePeerID(peerID) + } +} diff --git a/waku/v2/api/publish/message_sender_test.go b/waku/v2/api/publish/message_sender_test.go new file mode 100644 index 00000000..d6945c8c --- /dev/null +++ b/waku/v2/api/publish/message_sender_test.go @@ -0,0 +1,123 @@ +package publish + +import ( + "context" + "crypto/rand" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + "github.com/waku-org/go-waku/tests" + "github.com/waku-org/go-waku/waku/v2/protocol" + "github.com/waku-org/go-waku/waku/v2/protocol/pb" + "github.com/waku-org/go-waku/waku/v2/protocol/relay" + "github.com/waku-org/go-waku/waku/v2/timesource" + "github.com/waku-org/go-waku/waku/v2/utils" +) + +type MockMessageSentCheck struct { + Messages map[string]map[common.Hash]uint32 +} + +func (m *MockMessageSentCheck) Add(topic string, messageID common.Hash, time uint32) { + if m.Messages[topic] == nil { + m.Messages[topic] = make(map[common.Hash]uint32) + } + m.Messages[topic][messageID] = time +} + +func (m *MockMessageSentCheck) DeleteByMessageIDs(messageIDs []common.Hash) { +} + +func (m *MockMessageSentCheck) SetStorePeerID(peerID peer.ID) { +} + +func (m *MockMessageSentCheck) Start() { +} + +func TestNewSenderWithUnknownMethod(t *testing.T) { + sender, err := NewMessageSender(UnknownMethod, nil, nil, nil) + require.NotNil(t, err) + require.Nil(t, sender) +} + +func TestNewSenderWithRelay(t *testing.T) { + _, relayNode := createRelayNode(t) + err := relayNode.Start(context.Background()) + require.Nil(t, err) + defer relayNode.Stop() + sender, err := NewMessageSender(Relay, nil, relayNode, utils.Logger()) + require.Nil(t, err) + require.NotNil(t, sender) + require.Nil(t, sender.messageSentCheck) + require.Equal(t, Relay, sender.publishMethod) + + msg := &pb.WakuMessage{ + Payload: []byte{1, 2, 3}, + Timestamp: utils.GetUnixEpoch(), + ContentTopic: "test-content-topic", + } + envelope := protocol.NewEnvelope(msg, *utils.GetUnixEpoch(), "test-pubsub-topic") + req := NewRequest(context.TODO(), envelope) + err = sender.Send(req) + require.Nil(t, err) +} + +func TestNewSenderWithRelayAndMessageSentCheck(t *testing.T) { + _, relayNode := createRelayNode(t) + err := relayNode.Start(context.Background()) + require.Nil(t, err) + defer relayNode.Stop() + sender, err := NewMessageSender(Relay, nil, relayNode, utils.Logger()) + + check := &MockMessageSentCheck{Messages: make(map[string]map[common.Hash]uint32)} + sender.WithMessageSentCheck(check) + require.Nil(t, err) + require.NotNil(t, sender) + require.NotNil(t, sender.messageSentCheck) + require.Equal(t, Relay, sender.publishMethod) + + msg := &pb.WakuMessage{ + Payload: []byte{1, 2, 3}, + Timestamp: utils.GetUnixEpoch(), + ContentTopic: "test-content-topic", + } + envelope := protocol.NewEnvelope(msg, *utils.GetUnixEpoch(), "test-pubsub-topic") + req := NewRequest(context.TODO(), envelope) + + require.Equal(t, 0, len(check.Messages)) + + err = sender.Send(req) + require.Nil(t, err) + require.Equal(t, 1, len(check.Messages)) + require.Equal( + t, + uint32(msg.GetTimestamp()/int64(time.Second)), + check.Messages["test-pubsub-topic"][common.BytesToHash(envelope.Hash().Bytes())], + ) +} + +func TestNewSenderWithLightPush(t *testing.T) { + sender, err := NewMessageSender(LightPush, nil, nil, nil) + require.Nil(t, err) + require.NotNil(t, sender) + require.Equal(t, LightPush, sender.publishMethod) +} + +func createRelayNode(t *testing.T) (host.Host, *relay.WakuRelay) { + port, err := tests.FindFreePort(t, "", 5) + require.NoError(t, err) + host, err := tests.MakeHost(context.Background(), port, rand.Reader) + require.NoError(t, err) + bcaster := relay.NewBroadcaster(10) + relay := relay.NewWakuRelay(bcaster, 0, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger()) + relay.SetHost(host) + err = bcaster.Start(context.Background()) + require.NoError(t, err) + + return host, relay +} diff --git a/waku/v2/api/publish/rate_limiting.go b/waku/v2/api/publish/rate_limiting.go index 4322413b..a0bddcbd 100644 --- a/waku/v2/api/publish/rate_limiting.go +++ b/waku/v2/api/publish/rate_limiting.go @@ -26,12 +26,19 @@ func NewPublishRateLimiter(r rate.Limit, b int) *PublishRateLimiter { // ThrottlePublishFn is used to decorate a PublishFn so rate limiting is applied func (p *PublishRateLimiter) ThrottlePublishFn(ctx context.Context, publishFn PublishFn) PublishFn { return func(envelope *protocol.Envelope, logger *zap.Logger) error { - if err := p.limiter.Wait(ctx); err != nil { - if !errors.Is(err, context.Canceled) { - logger.Error("could not send message (limiter)", zap.Error(err)) - } + if err := p.Check(ctx, logger); err != nil { return err } return publishFn(envelope, logger) } } + +func (p *PublishRateLimiter) Check(ctx context.Context, logger *zap.Logger) error { + if err := p.limiter.Wait(ctx); err != nil { + if !errors.Is(err, context.Canceled) { + logger.Error("could not send message (limiter)", zap.Error(err)) + } + return err + } + return nil +}