diff --git a/waku/v2/protocol/relay/waku_relay.go b/waku/v2/protocol/relay/waku_relay.go index a62c1304..edb75e49 100644 --- a/waku/v2/protocol/relay/waku_relay.go +++ b/waku/v2/protocol/relay/waku_relay.go @@ -8,6 +8,7 @@ import ( "sync" "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" "go.opencensus.io/stats" "go.opencensus.io/tag" @@ -29,6 +30,18 @@ const WakuRelayID_v200 = protocol.ID("/vac/waku/relay/2.0.0") var DefaultWakuTopic string = waku_proto.DefaultPubsubTopic().String() +type cacheItem struct { + msg *pb.WakuMessage + id string + pubsubTopic string +} + +type msgReq struct { + id string + pubsubTopic string + ch chan *pb.WakuMessage +} + type WakuRelay struct { host host.Host opts []pubsub.Option @@ -49,6 +62,14 @@ type WakuRelay struct { // TODO: convert to concurrent maps subscriptions map[string][]*Subscription subscriptionsMutex sync.Mutex + + msgCache map[string]*pb.WakuMessage + msgCacheCh chan cacheItem + deleteCacheCh chan string + getMsgCh chan msgReq + + cancel context.CancelFunc + wg sync.WaitGroup } func msgIdFn(pmsg *pubsub_pb.Message) string { @@ -65,6 +86,7 @@ func NewWakuRelay(h host.Host, bcaster v2.Broadcaster, minPeersToPublish int, ti w.subscriptions = make(map[string][]*Subscription) w.bcaster = bcaster w.minPeersToPublish = minPeersToPublish + w.wg = sync.WaitGroup{} w.log = log.Named("relay") // default options required by WakuRelay @@ -91,6 +113,18 @@ func NewWakuRelay(h host.Host, bcaster v2.Broadcaster, minPeersToPublish int, ti } func (w *WakuRelay) Start(ctx context.Context) error { + w.wg.Wait() + ctx, cancel := context.WithCancel(ctx) + w.cancel = cancel + + w.msgCacheCh = make(chan cacheItem, 1000) + w.deleteCacheCh = make(chan string, 1000) + w.msgCache = make(map[string]*pb.WakuMessage, 1000) + w.getMsgCh = make(chan msgReq, 1000) + + w.wg.Add(1) + go w.cacheWorker(ctx) + ps, err := pubsub.NewGossipSub(ctx, w.host, w.opts...) if err != nil { return err @@ -101,6 +135,56 @@ func (w *WakuRelay) Start(ctx context.Context) error { return nil } +func (w *WakuRelay) getMessageFromCache(topic string, id string) *pb.WakuMessage { + resultCh := make(chan *pb.WakuMessage, 1) + defer close(resultCh) + + w.getMsgCh <- msgReq{ + id: id, + pubsubTopic: topic, + ch: resultCh, + } + + return <-resultCh +} + +func (w *WakuRelay) AddToCache(pubsubTopic string, id string, msg *pb.WakuMessage) { + w.msgCacheCh <- cacheItem{ + msg: msg, + id: id, + pubsubTopic: pubsubTopic, + } +} + +func (w *WakuRelay) cacheWorker(ctx context.Context) { + defer w.wg.Done() + + deleteCounter := 0 + for { + select { + case <-ctx.Done(): + return + case item := <-w.msgCacheCh: + w.msgCache[item.pubsubTopic+item.id] = item.msg + case req := <-w.getMsgCh: + key := req.pubsubTopic + req.id + req.ch <- w.msgCache[key] + w.deleteCacheCh <- key + case key := <-w.deleteCacheCh: + delete(w.msgCache, key) + deleteCounter++ + // Shrink msg cache to avoid oom + if deleteCounter > 1000 { + newMsgCache := make(map[string]*pb.WakuMessage, 1000) + for k, v := range w.msgCache { + newMsgCache[k] = v + } + w.msgCache = newMsgCache + } + } + } +} + // PubSub returns the implementation of the pubsub system func (w *WakuRelay) PubSub() *pubsub.PubSub { return w.pubsub @@ -139,6 +223,20 @@ func (w *WakuRelay) upsertTopic(topic string) (*pubsub.Topic, error) { return pubSubTopic, nil } +func (w *WakuRelay) validatorFactory(pubsubTopic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { + return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool { + msg := new(pb.WakuMessage) + err := proto.Unmarshal(message.Data, msg) + if err != nil { + return false + } + + w.AddToCache(pubsubTopic, message.ID, msg) + + return true + } +} + func (w *WakuRelay) subscribe(topic string) (subs *pubsub.Subscription, err error) { sub, ok := w.relaySubs[topic] if !ok { @@ -147,6 +245,11 @@ func (w *WakuRelay) subscribe(topic string) (subs *pubsub.Subscription, err erro return nil, err } + err = w.pubsub.RegisterTopicValidator(topic, w.validatorFactory(topic)) + if err != nil { + return nil, err + } + sub, err = pubSubTopic.Subscribe() if err != nil { return nil, err @@ -204,7 +307,21 @@ func (w *WakuRelay) Publish(ctx context.Context, message *pb.WakuMessage) ([]byt // Stop unmounts the relay protocol and stops all subscriptions func (w *WakuRelay) Stop() { + if w.cancel == nil { + return // Not started + } + w.host.RemoveStreamHandler(WakuRelayID_v200) + + w.cancel() + w.wg.Wait() + + close(w.msgCacheCh) + close(w.deleteCacheCh) + close(w.getMsgCh) + + w.msgCache = nil + w.subscriptionsMutex.Lock() defer w.subscriptionsMutex.Unlock() @@ -228,10 +345,7 @@ func (w *WakuRelay) EnoughPeersToPublishToTopic(topic string) bool { // SubscribeToTopic returns a Subscription to receive messages from a pubsub topic func (w *WakuRelay) SubscribeToTopic(ctx context.Context, topic string) (*Subscription, error) { - // Subscribes to a PubSub topic. - // NOTE The data field SHOULD be decoded as a WakuMessage. sub, err := w.subscribe(topic) - if err != nil { return nil, err } @@ -314,7 +428,7 @@ func (w *WakuRelay) nextMessage(ctx context.Context, sub *pubsub.Subscription) < return msgChannel } -func (w *WakuRelay) subscribeToTopic(ctx context.Context, t string, subscription *Subscription, sub *pubsub.Subscription) { +func (w *WakuRelay) subscribeToTopic(ctx context.Context, pubsubTopic string, subscription *Subscription, sub *pubsub.Subscription) { ctx, err := tag.New(ctx, tag.Insert(metrics.KeyType, "relay")) if err != nil { w.log.Error("creating tag map", zap.Error(err)) @@ -322,7 +436,6 @@ func (w *WakuRelay) subscribeToTopic(ctx context.Context, t string, subscription } subChannel := w.nextMessage(ctx, sub) - for { select { case <-subscription.quit: @@ -339,20 +452,16 @@ func (w *WakuRelay) subscribeToTopic(ctx context.Context, t string, subscription } close(subscription.C) - }(t) + }(pubsubTopic) // TODO: if there are no more relay subscriptions, close the pubsub subscription case msg := <-subChannel: if msg == nil { return } stats.Record(ctx, metrics.Messages.M(1)) - wakuMessage := &pb.WakuMessage{} - if err := proto.Unmarshal(msg.Data, wakuMessage); err != nil { - w.log.Error("decoding message", zap.Error(err)) - return - } - envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), string(t)) + wakuMessage := w.getMessageFromCache(pubsubTopic, msg.ID) + envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), string(pubsubTopic)) w.log.Debug("waku.relay received", logging.HexString("hash", envelope.Hash())) diff --git a/waku/v2/protocol/relay/waku_relay_test.go b/waku/v2/protocol/relay/waku_relay_test.go index 08eac0fd..8568e32f 100644 --- a/waku/v2/protocol/relay/waku_relay_test.go +++ b/waku/v2/protocol/relay/waku_relay_test.go @@ -36,6 +36,7 @@ func TestWakuRelay(t *testing.T) { require.Equal(t, testTopic, topics[0]) ctx, cancel := context.WithCancel(context.Background()) + go func() { defer cancel() diff --git a/waku/v2/protocol/rln/waku_rln_relay.go b/waku/v2/protocol/rln/waku_rln_relay.go index da724e69..02d47976 100644 --- a/waku/v2/protocol/rln/waku_rln_relay.go +++ b/waku/v2/protocol/rln/waku_rln_relay.go @@ -372,6 +372,9 @@ func (r *WakuRLNRelay) addValidator( zap.Binary("payload", wakuMessage.Payload), zap.Any("proof", wakuMessage.RateLimitProof), ) + + relay.AddToCache(pubsubTopic, message.ID, wakuMessage) + return true case MessageValidationResult_Invalid: r.log.Debug("message could not be verified", @@ -404,6 +407,9 @@ func (r *WakuRLNRelay) addValidator( } } + // In case there's a topic validator registered + _ = relay.PubSub().UnregisterTopicValidator(pubsubTopic) + return relay.PubSub().RegisterTopicValidator(pubsubTopic, validator) }