diff --git a/waku/v2/protocol/filter/filter_subscribers.go b/waku/v2/protocol/filter/filter_subscribers.go index d83036e2..2095409e 100644 --- a/waku/v2/protocol/filter/filter_subscribers.go +++ b/waku/v2/protocol/filter/filter_subscribers.go @@ -114,14 +114,14 @@ func (sub *Subscribers) FlagAsFailure(peerID peer.ID) { } } -func (sub *Subscribers) RemoveContentFilters(peerID peer.ID, contentFilters []*pb.FilterRequest_ContentFilter) { +func (sub *Subscribers) RemoveContentFilters(peerID peer.ID, requestId string, contentFilters []*pb.FilterRequest_ContentFilter) { sub.Lock() defer sub.Unlock() var peerIdsToRemove []peer.ID for subIndex, subscriber := range sub.subscribers { - if subscriber.peer != peerID { + if subscriber.peer != peerID || subscriber.requestId != requestId { continue } @@ -148,11 +148,10 @@ func (sub *Subscribers) RemoveContentFilters(peerID peer.ID, contentFilters []*p // if no more content filters left for _, peerId := range peerIdsToRemove { for i, s := range sub.subscribers { - if s.peer == peerId { + if s.peer == peerId && s.requestId == requestId { l := len(sub.subscribers) - 1 - sub.subscribers[l], sub.subscribers[i] = sub.subscribers[i], sub.subscribers[l] + sub.subscribers[i] = sub.subscribers[l] sub.subscribers = sub.subscribers[:l] - break } } } diff --git a/waku/v2/protocol/filter/filter_subscribers_test.go b/waku/v2/protocol/filter/filter_subscribers_test.go index fa924f06..413a6b87 100644 --- a/waku/v2/protocol/filter/filter_subscribers_test.go +++ b/waku/v2/protocol/filter/filter_subscribers_test.go @@ -28,13 +28,14 @@ func firstSubscriber(subs *Subscribers, contentTopic string) *Subscriber { func TestAppend(t *testing.T) { subs := NewSubscribers(10 * time.Second) peerId := createPeerId(t) + requestId := "request_1" contentTopic := "topic1" request := pb.FilterRequest{ Subscribe: true, Topic: TOPIC, ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}}, } - subs.Append(Subscriber{peerId, "request_1", request}) + subs.Append(Subscriber{peerId, requestId, request}) sub := firstSubscriber(subs, contentTopic) assert.NotNil(t, sub) @@ -43,14 +44,15 @@ func TestAppend(t *testing.T) { func TestRemove(t *testing.T) { subs := NewSubscribers(10 * time.Second) peerId := createPeerId(t) + requestId := "request_1" contentTopic := "topic1" request := pb.FilterRequest{ Subscribe: true, Topic: TOPIC, ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}}, } - subs.Append(Subscriber{peerId, "request_1", request}) - subs.RemoveContentFilters(peerId, request.ContentFilters) + subs.Append(Subscriber{peerId, requestId, request}) + subs.RemoveContentFilters(peerId, requestId, request.ContentFilters) sub := firstSubscriber(subs, contentTopic) assert.Nil(t, sub) @@ -59,6 +61,7 @@ func TestRemove(t *testing.T) { func TestRemovePartial(t *testing.T) { subs := NewSubscribers(10 * time.Second) peerId := createPeerId(t) + requestId := "request_1" topic1 := "topic1" topic2 := "topic2" request := pb.FilterRequest{ @@ -66,25 +69,76 @@ func TestRemovePartial(t *testing.T) { Topic: TOPIC, ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic1}, {ContentTopic: topic2}}, } - subs.Append(Subscriber{peerId, "request_1", request}) - subs.RemoveContentFilters(peerId, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic1}}) + subs.Append(Subscriber{peerId, requestId, request}) + subs.RemoveContentFilters(peerId, requestId, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic1}}) sub := firstSubscriber(subs, topic2) assert.NotNil(t, sub) assert.Len(t, sub.filter.ContentFilters, 1) } +func TestRemoveDuplicateSubscriptions(t *testing.T) { + subs := NewSubscribers(10 * time.Second) + peerId := createPeerId(t) + topic := "topic" + requestId1 := "request_1" + requestId2 := "request_2" + request1 := pb.FilterRequest{ + Subscribe: true, + Topic: TOPIC, + ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}}, + } + request2 := pb.FilterRequest{ + Subscribe: true, + Topic: TOPIC, + ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}}, + } + subs.Append(Subscriber{peerId, requestId1, request1}) + subs.Append(Subscriber{peerId, requestId2, request2}) + subs.RemoveContentFilters(peerId, requestId2, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}}) + subs.RemoveContentFilters(peerId, requestId1, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}}) + + sub := firstSubscriber(subs, topic) + assert.Nil(t, sub) +} + +func TestRemoveDuplicateSubscriptionsPartial(t *testing.T) { + subs := NewSubscribers(10 * time.Second) + peerId := createPeerId(t) + topic := "topic" + requestId1 := "request_1" + requestId2 := "request_2" + request1 := pb.FilterRequest{ + Subscribe: true, + Topic: TOPIC, + ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}}, + } + request2 := pb.FilterRequest{ + Subscribe: true, + Topic: TOPIC, + ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}}, + } + subs.Append(Subscriber{peerId, requestId1, request1}) + subs.Append(Subscriber{peerId, requestId2, request2}) + subs.RemoveContentFilters(peerId, requestId1, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}}) + + sub := firstSubscriber(subs, topic) + assert.NotNil(t, sub) + assert.Equal(t, sub.requestId, requestId2) +} + func TestRemoveBogus(t *testing.T) { subs := NewSubscribers(10 * time.Second) peerId := createPeerId(t) + requestId := "request_1" contentTopic := "topic1" request := pb.FilterRequest{ Subscribe: true, Topic: TOPIC, ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}}, } - subs.Append(Subscriber{peerId, "request_1", request}) - subs.RemoveContentFilters(peerId, []*pb.FilterRequest_ContentFilter{{ContentTopic: "does not exist"}, {ContentTopic: contentTopic}}) + subs.Append(Subscriber{peerId, requestId, request}) + subs.RemoveContentFilters(peerId, requestId, []*pb.FilterRequest_ContentFilter{{ContentTopic: "does not exist"}, {ContentTopic: contentTopic}}) sub := firstSubscriber(subs, contentTopic) assert.Nil(t, sub) diff --git a/waku/v2/protocol/filter/waku_filter.go b/waku/v2/protocol/filter/waku_filter.go index 5b0d4ef7..c98bbee0 100644 --- a/waku/v2/protocol/filter/waku_filter.go +++ b/waku/v2/protocol/filter/waku_filter.go @@ -132,7 +132,7 @@ func (wf *WakuFilter) onRequest(s network.Stream) { stats.Record(wf.ctx, metrics.FilterSubscriptions.M(int64(len))) } else { peerId := s.Conn().RemotePeer() - wf.subscribers.RemoveContentFilters(peerId, filterRPCRequest.Request.ContentFilters) + wf.subscribers.RemoveContentFilters(peerId, filterRPCRequest.RequestId, filterRPCRequest.Request.ContentFilters) logger.Info("removing subscriber") stats.Record(wf.ctx, metrics.FilterSubscriptions.M(int64(wf.subscribers.Length())))