diff --git a/waku/v2/protocol/filter/waku_filter.go b/waku/v2/protocol/filter/waku_filter.go index d0eff0e6..f8021356 100644 --- a/waku/v2/protocol/filter/waku_filter.go +++ b/waku/v2/protocol/filter/waku_filter.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "fmt" + "sync" logging "github.com/ipfs/go-log" "github.com/libp2p/go-libp2p-core/host" @@ -65,10 +66,12 @@ type ( WakuFilter struct { ctx context.Context h host.Host - subscribers []Subscriber isFullNode bool pushHandler MessagePushHandler MsgC chan *protocol.Envelope + + subscriberMutex sync.Mutex + subscribers []Subscriber } ) @@ -150,16 +153,21 @@ func (wf *WakuFilter) onRequest(s network.Stream) { // We're on a full node. // This is a filter request coming from a light node. if filterRPCRequest.Request.Subscribe { + wf.subscriberMutex.Lock() + defer wf.subscriberMutex.Unlock() + subscriber := Subscriber{peer: s.Conn().RemotePeer(), requestId: filterRPCRequest.RequestId, filter: *filterRPCRequest.Request} wf.subscribers = append(wf.subscribers, subscriber) log.Info("filter full node, add a filter subscriber: ", subscriber.peer) - stats.Record(wf.ctx, metrics.FilterSubscriptions.M(int64(len(wf.subscribers)))) } else { peerId := s.Conn().RemotePeer() log.Info("filter full node, remove a filter subscriber: ", peerId.Pretty()) contentFilters := filterRPCRequest.Request.ContentFilters var peerIdsToRemove []peer.ID + + wf.subscriberMutex.Lock() + defer wf.subscriberMutex.Unlock() for _, subscriber := range wf.subscribers { if subscriber.peer != peerId { continue diff --git a/waku/v2/protocol/filter/waku_filter_test.go b/waku/v2/protocol/filter/waku_filter_test.go index 86279b3e..daebb59e 100644 --- a/waku/v2/protocol/filter/waku_filter_test.go +++ b/waku/v2/protocol/filter/waku_filter_test.go @@ -62,6 +62,9 @@ func makeWakuFilter(t *testing.T, filters Filters) (*WakuFilter, host.Host) { // Node2 send a succesful message with topic B // Node1 doesn't receive the message func TestWakuFilter(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // Test can't exceed 10 seconds + defer cancel() + var filters = make(Filters) var testTopic relay.Topic = "/waku/2/go/filter/test" testContentTopic := "TopicA" @@ -79,7 +82,7 @@ func TestWakuFilter(t *testing.T) { } } - node2Filter := NewWakuFilter(context.Background(), host2, true, filterHandler) + node2Filter := NewWakuFilter(ctx, host2, true, filterHandler) broadcaster.Register(node2Filter.MsgC) host1.Peerstore().AddAddr(host2.ID(), tests.GetHostAddress(host2), peerstore.PermanentAddrTTL) @@ -90,7 +93,7 @@ func TestWakuFilter(t *testing.T) { Topic: string(testTopic), ContentTopics: []string{testContentTopic}, } - sub, err := node1.Subscribe(context.Background(), *contentFilter, []FilterSubscribeOption{WithAutomaticPeerSelection()}...) + sub, err := node1.Subscribe(ctx, *contentFilter, WithPeer(node2Filter.h.ID())) require.NoError(t, err) // Sleep to make sure the filter is subscribed @@ -112,12 +115,7 @@ func TestWakuFilter(t *testing.T) { require.Equal(t, contentFilter.ContentTopics[0], env.Message().GetContentTopic()) }() - _, err = node2.Publish(context.Background(), &pb.WakuMessage{ - Payload: []byte{1}, - Version: 0, - ContentTopic: testContentTopic, - Timestamp: 0, - }, &testTopic) + _, err = node2.Publish(ctx, tests.CreateWakuMessage(testContentTopic, 0), &testTopic) require.NoError(t, err) wg.Wait() @@ -127,18 +125,36 @@ func TestWakuFilter(t *testing.T) { select { case <-ch: require.Fail(t, "should not receive another message") - case <-time.After(3 * time.Second): + case <-time.After(1 * time.Second): defer wg.Done() + case <-ctx.Done(): + require.Fail(t, "test exceeded allocated time") } }() - _, err = node2.Publish(context.Background(), &pb.WakuMessage{ - Payload: []byte{1}, - Version: 0, - ContentTopic: "TopicB", - Timestamp: 0, - }, &testTopic) + _, err = node2.Publish(ctx, tests.CreateWakuMessage("TopicB", 1), &testTopic) require.NoError(t, err) wg.Wait() + + wg.Add(1) + go func() { + select { + case <-ch: + require.Fail(t, "should not receive another message") + case <-time.After(1 * time.Second): + defer wg.Done() + case <-ctx.Done(): + require.Fail(t, "test exceeded allocated time") + } + }() + + err = node1.Unsubscribe(ctx, *contentFilter, node2Filter.h.ID()) + require.NoError(t, err) + + time.Sleep(1 * time.Second) + + _, err = node2.Publish(ctx, tests.CreateWakuMessage(testContentTopic, 2), &testTopic) + require.NoError(t, err) + wg.Wait() }