chore: refactor sender api (#1187)

This commit is contained in:
kaichao 2024-08-10 20:05:51 +08:00 committed by GitHub
parent 3eab289abb
commit 92d62a7c38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 319 additions and 12 deletions

View File

@ -23,6 +23,13 @@ const DefaultMessageExpiredPerid = 10 // in seconds
type MessageSentCheckOption func(*MessageSentCheck) error
type ISentCheck interface {
Start()
Add(topic string, messageID common.Hash, sentTime uint32)
DeleteByMessageIDs(messageIDs []common.Hash)
SetStorePeerID(peerID peer.ID)
}
// MessageSentCheck tracks the outgoing messages and check against store node
// if the message sent time has passed the `messageSentPeriod`, the message id will be includes for the next query
// if the message keeps missing after `messageExpiredPerid`, the message id will be expired
@ -30,8 +37,8 @@ type MessageSentCheck struct {
messageIDs map[string]map[common.Hash]uint32
messageIDsMu sync.RWMutex
storePeerID peer.ID
MessageStoredChan chan common.Hash
MessageExpiredChan chan common.Hash
messageStoredChan chan common.Hash
messageExpiredChan chan common.Hash
ctx context.Context
store *store.WakuStore
timesource timesource.Timesource
@ -43,12 +50,12 @@ type MessageSentCheck struct {
}
// NewMessageSentCheck creates a new instance of MessageSentCheck with default parameters
func NewMessageSentCheck(ctx context.Context, store *store.WakuStore, timesource timesource.Timesource, logger *zap.Logger) *MessageSentCheck {
func NewMessageSentCheck(ctx context.Context, store *store.WakuStore, timesource timesource.Timesource, msgStoredChan chan common.Hash, msgExpiredChan chan common.Hash, logger *zap.Logger) *MessageSentCheck {
return &MessageSentCheck{
messageIDs: make(map[string]map[common.Hash]uint32),
messageIDsMu: sync.RWMutex{},
MessageStoredChan: make(chan common.Hash, 1000),
MessageExpiredChan: make(chan common.Hash, 1000),
messageStoredChan: msgStoredChan,
messageExpiredChan: msgExpiredChan,
ctx: ctx,
store: store,
timesource: timesource,
@ -232,12 +239,12 @@ func (m *MessageSentCheck) messageHashBasedQuery(ctx context.Context, hashes []c
if found {
ackHashes = append(ackHashes, hash)
m.MessageStoredChan <- hash
m.messageStoredChan <- hash
}
if !found && uint32(m.timesource.Now().Unix()) > relayTime[i]+m.messageExpiredPerid {
missedHashes = append(missedHashes, hash)
m.MessageExpiredChan <- hash
m.messageExpiredChan <- hash
}
}

View File

@ -10,7 +10,7 @@ import (
func TestAddAndDelete(t *testing.T) {
ctx := context.TODO()
messageSentCheck := NewMessageSentCheck(ctx, nil, nil, nil)
messageSentCheck := NewMessageSentCheck(ctx, nil, nil, nil, nil, nil)
messageSentCheck.Add("topic", [32]byte{1}, 1)
messageSentCheck.Add("topic", [32]byte{2}, 2)

View File

@ -0,0 +1,170 @@
package publish
import (
"context"
"errors"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/waku-org/go-waku/waku/v2/protocol"
"github.com/waku-org/go-waku/waku/v2/protocol/lightpush"
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
"go.uber.org/zap"
"golang.org/x/time/rate"
)
const DefaultPeersToPublishForLightpush = 2
const DefaultPublishingLimiterRate = rate.Limit(2)
const DefaultPublishingLimitBurst = 4
type PublishMethod int
const (
LightPush PublishMethod = iota
Relay
UnknownMethod
)
func (pm PublishMethod) String() string {
switch pm {
case LightPush:
return "LightPush"
case Relay:
return "Relay"
default:
return "Unknown"
}
}
type MessageSender struct {
publishMethod PublishMethod
lightPush *lightpush.WakuLightPush
relay *relay.WakuRelay
messageSentCheck ISentCheck
rateLimiter *PublishRateLimiter
logger *zap.Logger
}
type Request struct {
ctx context.Context
envelope *protocol.Envelope
publishMethod PublishMethod
}
func NewRequest(ctx context.Context, envelope *protocol.Envelope) *Request {
return &Request{
ctx: ctx,
envelope: envelope,
publishMethod: UnknownMethod,
}
}
func (r *Request) WithPublishMethod(publishMethod PublishMethod) *Request {
r.publishMethod = publishMethod
return r
}
func NewMessageSender(publishMethod PublishMethod, lightPush *lightpush.WakuLightPush, relay *relay.WakuRelay, logger *zap.Logger) (*MessageSender, error) {
if publishMethod == UnknownMethod {
return nil, errors.New("publish method is required")
}
return &MessageSender{
publishMethod: publishMethod,
lightPush: lightPush,
relay: relay,
rateLimiter: NewPublishRateLimiter(DefaultPublishingLimiterRate, DefaultPublishingLimitBurst),
logger: logger,
}, nil
}
func (ms *MessageSender) WithMessageSentCheck(messageSentCheck ISentCheck) *MessageSender {
ms.messageSentCheck = messageSentCheck
return ms
}
func (ms *MessageSender) WithRateLimiting(rateLimiter *PublishRateLimiter) *MessageSender {
ms.rateLimiter = rateLimiter
return ms
}
func (ms *MessageSender) Send(req *Request) error {
logger := ms.logger.With(
zap.Stringer("envelopeHash", req.envelope.Hash()),
zap.String("pubsubTopic", req.envelope.PubsubTopic()),
zap.String("contentTopic", req.envelope.Message().ContentTopic),
zap.Int64("timestamp", req.envelope.Message().GetTimestamp()),
)
if ms.rateLimiter != nil {
if err := ms.rateLimiter.Check(req.ctx, logger); err != nil {
return err
}
}
publishMethod := req.publishMethod
if publishMethod == UnknownMethod {
publishMethod = ms.publishMethod
}
switch publishMethod {
case LightPush:
if ms.lightPush == nil {
return errors.New("lightpush is not available")
}
logger.Info("publishing message via lightpush")
_, err := ms.lightPush.Publish(
req.ctx,
req.envelope.Message(),
lightpush.WithPubSubTopic(req.envelope.PubsubTopic()),
lightpush.WithMaxPeers(DefaultPeersToPublishForLightpush),
)
if err != nil {
return err
}
case Relay:
if ms.relay == nil {
return errors.New("relay is not available")
}
peerCnt := len(ms.relay.PubSub().ListPeers(req.envelope.PubsubTopic()))
logger.Info("publishing message via relay", zap.Int("peerCnt", peerCnt))
_, err := ms.relay.Publish(req.ctx, req.envelope.Message(), relay.WithPubSubTopic(req.envelope.PubsubTopic()))
if err != nil {
return err
}
default:
return errors.New("unknown publish method")
}
if ms.messageSentCheck != nil && !req.envelope.Message().GetEphemeral() {
ms.messageSentCheck.Add(
req.envelope.PubsubTopic(),
common.BytesToHash(req.envelope.Hash().Bytes()),
uint32(req.envelope.Message().GetTimestamp()/int64(time.Second)),
)
}
return nil
}
func (ms *MessageSender) Start() {
if ms.messageSentCheck != nil {
go ms.messageSentCheck.Start()
}
}
func (ms *MessageSender) PublishMethod() PublishMethod {
return ms.publishMethod
}
func (ms *MessageSender) MessagesDelivered(messageIDs []common.Hash) {
if ms.messageSentCheck != nil {
ms.messageSentCheck.DeleteByMessageIDs(messageIDs)
}
}
func (ms *MessageSender) SetStorePeerID(peerID peer.ID) {
if ms.messageSentCheck != nil {
ms.messageSentCheck.SetStorePeerID(peerID)
}
}

View File

@ -0,0 +1,123 @@
package publish
import (
"context"
"crypto/rand"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"github.com/waku-org/go-waku/tests"
"github.com/waku-org/go-waku/waku/v2/protocol"
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
"github.com/waku-org/go-waku/waku/v2/timesource"
"github.com/waku-org/go-waku/waku/v2/utils"
)
type MockMessageSentCheck struct {
Messages map[string]map[common.Hash]uint32
}
func (m *MockMessageSentCheck) Add(topic string, messageID common.Hash, time uint32) {
if m.Messages[topic] == nil {
m.Messages[topic] = make(map[common.Hash]uint32)
}
m.Messages[topic][messageID] = time
}
func (m *MockMessageSentCheck) DeleteByMessageIDs(messageIDs []common.Hash) {
}
func (m *MockMessageSentCheck) SetStorePeerID(peerID peer.ID) {
}
func (m *MockMessageSentCheck) Start() {
}
func TestNewSenderWithUnknownMethod(t *testing.T) {
sender, err := NewMessageSender(UnknownMethod, nil, nil, nil)
require.NotNil(t, err)
require.Nil(t, sender)
}
func TestNewSenderWithRelay(t *testing.T) {
_, relayNode := createRelayNode(t)
err := relayNode.Start(context.Background())
require.Nil(t, err)
defer relayNode.Stop()
sender, err := NewMessageSender(Relay, nil, relayNode, utils.Logger())
require.Nil(t, err)
require.NotNil(t, sender)
require.Nil(t, sender.messageSentCheck)
require.Equal(t, Relay, sender.publishMethod)
msg := &pb.WakuMessage{
Payload: []byte{1, 2, 3},
Timestamp: utils.GetUnixEpoch(),
ContentTopic: "test-content-topic",
}
envelope := protocol.NewEnvelope(msg, *utils.GetUnixEpoch(), "test-pubsub-topic")
req := NewRequest(context.TODO(), envelope)
err = sender.Send(req)
require.Nil(t, err)
}
func TestNewSenderWithRelayAndMessageSentCheck(t *testing.T) {
_, relayNode := createRelayNode(t)
err := relayNode.Start(context.Background())
require.Nil(t, err)
defer relayNode.Stop()
sender, err := NewMessageSender(Relay, nil, relayNode, utils.Logger())
check := &MockMessageSentCheck{Messages: make(map[string]map[common.Hash]uint32)}
sender.WithMessageSentCheck(check)
require.Nil(t, err)
require.NotNil(t, sender)
require.NotNil(t, sender.messageSentCheck)
require.Equal(t, Relay, sender.publishMethod)
msg := &pb.WakuMessage{
Payload: []byte{1, 2, 3},
Timestamp: utils.GetUnixEpoch(),
ContentTopic: "test-content-topic",
}
envelope := protocol.NewEnvelope(msg, *utils.GetUnixEpoch(), "test-pubsub-topic")
req := NewRequest(context.TODO(), envelope)
require.Equal(t, 0, len(check.Messages))
err = sender.Send(req)
require.Nil(t, err)
require.Equal(t, 1, len(check.Messages))
require.Equal(
t,
uint32(msg.GetTimestamp()/int64(time.Second)),
check.Messages["test-pubsub-topic"][common.BytesToHash(envelope.Hash().Bytes())],
)
}
func TestNewSenderWithLightPush(t *testing.T) {
sender, err := NewMessageSender(LightPush, nil, nil, nil)
require.Nil(t, err)
require.NotNil(t, sender)
require.Equal(t, LightPush, sender.publishMethod)
}
func createRelayNode(t *testing.T) (host.Host, *relay.WakuRelay) {
port, err := tests.FindFreePort(t, "", 5)
require.NoError(t, err)
host, err := tests.MakeHost(context.Background(), port, rand.Reader)
require.NoError(t, err)
bcaster := relay.NewBroadcaster(10)
relay := relay.NewWakuRelay(bcaster, 0, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger())
relay.SetHost(host)
err = bcaster.Start(context.Background())
require.NoError(t, err)
return host, relay
}

View File

@ -26,12 +26,19 @@ func NewPublishRateLimiter(r rate.Limit, b int) *PublishRateLimiter {
// ThrottlePublishFn is used to decorate a PublishFn so rate limiting is applied
func (p *PublishRateLimiter) ThrottlePublishFn(ctx context.Context, publishFn PublishFn) PublishFn {
return func(envelope *protocol.Envelope, logger *zap.Logger) error {
if err := p.Check(ctx, logger); err != nil {
return err
}
return publishFn(envelope, logger)
}
}
func (p *PublishRateLimiter) Check(ctx context.Context, logger *zap.Logger) error {
if err := p.limiter.Wait(ctx); err != nil {
if !errors.Is(err, context.Canceled) {
logger.Error("could not send message (limiter)", zap.Error(err))
}
return err
}
return publishFn(envelope, logger)
}
return nil
}