fix: verify matching `requestId` before removing subscription (#280)

* Check we are removing content topics for the correct request
* Verify request id matches before removing peer as well
This commit is contained in:
Elise Alix 2022-08-03 09:35:25 -04:00 committed by GitHub
parent 546416a9d5
commit 924acf67d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 13 deletions

View File

@ -114,14 +114,14 @@ func (sub *Subscribers) FlagAsFailure(peerID peer.ID) {
} }
} }
func (sub *Subscribers) RemoveContentFilters(peerID peer.ID, contentFilters []*pb.FilterRequest_ContentFilter) { func (sub *Subscribers) RemoveContentFilters(peerID peer.ID, requestId string, contentFilters []*pb.FilterRequest_ContentFilter) {
sub.Lock() sub.Lock()
defer sub.Unlock() defer sub.Unlock()
var peerIdsToRemove []peer.ID var peerIdsToRemove []peer.ID
for subIndex, subscriber := range sub.subscribers { for subIndex, subscriber := range sub.subscribers {
if subscriber.peer != peerID { if subscriber.peer != peerID || subscriber.requestId != requestId {
continue continue
} }
@ -148,11 +148,10 @@ func (sub *Subscribers) RemoveContentFilters(peerID peer.ID, contentFilters []*p
// if no more content filters left // if no more content filters left
for _, peerId := range peerIdsToRemove { for _, peerId := range peerIdsToRemove {
for i, s := range sub.subscribers { for i, s := range sub.subscribers {
if s.peer == peerId { if s.peer == peerId && s.requestId == requestId {
l := len(sub.subscribers) - 1 l := len(sub.subscribers) - 1
sub.subscribers[l], sub.subscribers[i] = sub.subscribers[i], sub.subscribers[l] sub.subscribers[i] = sub.subscribers[l]
sub.subscribers = sub.subscribers[:l] sub.subscribers = sub.subscribers[:l]
break
} }
} }
} }

View File

@ -28,13 +28,14 @@ func firstSubscriber(subs *Subscribers, contentTopic string) *Subscriber {
func TestAppend(t *testing.T) { func TestAppend(t *testing.T) {
subs := NewSubscribers(10 * time.Second) subs := NewSubscribers(10 * time.Second)
peerId := createPeerId(t) peerId := createPeerId(t)
requestId := "request_1"
contentTopic := "topic1" contentTopic := "topic1"
request := pb.FilterRequest{ request := pb.FilterRequest{
Subscribe: true, Subscribe: true,
Topic: TOPIC, Topic: TOPIC,
ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}}, ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}},
} }
subs.Append(Subscriber{peerId, "request_1", request}) subs.Append(Subscriber{peerId, requestId, request})
sub := firstSubscriber(subs, contentTopic) sub := firstSubscriber(subs, contentTopic)
assert.NotNil(t, sub) assert.NotNil(t, sub)
@ -43,14 +44,15 @@ func TestAppend(t *testing.T) {
func TestRemove(t *testing.T) { func TestRemove(t *testing.T) {
subs := NewSubscribers(10 * time.Second) subs := NewSubscribers(10 * time.Second)
peerId := createPeerId(t) peerId := createPeerId(t)
requestId := "request_1"
contentTopic := "topic1" contentTopic := "topic1"
request := pb.FilterRequest{ request := pb.FilterRequest{
Subscribe: true, Subscribe: true,
Topic: TOPIC, Topic: TOPIC,
ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}}, ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}},
} }
subs.Append(Subscriber{peerId, "request_1", request}) subs.Append(Subscriber{peerId, requestId, request})
subs.RemoveContentFilters(peerId, request.ContentFilters) subs.RemoveContentFilters(peerId, requestId, request.ContentFilters)
sub := firstSubscriber(subs, contentTopic) sub := firstSubscriber(subs, contentTopic)
assert.Nil(t, sub) assert.Nil(t, sub)
@ -59,6 +61,7 @@ func TestRemove(t *testing.T) {
func TestRemovePartial(t *testing.T) { func TestRemovePartial(t *testing.T) {
subs := NewSubscribers(10 * time.Second) subs := NewSubscribers(10 * time.Second)
peerId := createPeerId(t) peerId := createPeerId(t)
requestId := "request_1"
topic1 := "topic1" topic1 := "topic1"
topic2 := "topic2" topic2 := "topic2"
request := pb.FilterRequest{ request := pb.FilterRequest{
@ -66,25 +69,76 @@ func TestRemovePartial(t *testing.T) {
Topic: TOPIC, Topic: TOPIC,
ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic1}, {ContentTopic: topic2}}, ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic1}, {ContentTopic: topic2}},
} }
subs.Append(Subscriber{peerId, "request_1", request}) subs.Append(Subscriber{peerId, requestId, request})
subs.RemoveContentFilters(peerId, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic1}}) subs.RemoveContentFilters(peerId, requestId, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic1}})
sub := firstSubscriber(subs, topic2) sub := firstSubscriber(subs, topic2)
assert.NotNil(t, sub) assert.NotNil(t, sub)
assert.Len(t, sub.filter.ContentFilters, 1) assert.Len(t, sub.filter.ContentFilters, 1)
} }
func TestRemoveDuplicateSubscriptions(t *testing.T) {
subs := NewSubscribers(10 * time.Second)
peerId := createPeerId(t)
topic := "topic"
requestId1 := "request_1"
requestId2 := "request_2"
request1 := pb.FilterRequest{
Subscribe: true,
Topic: TOPIC,
ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}},
}
request2 := pb.FilterRequest{
Subscribe: true,
Topic: TOPIC,
ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}},
}
subs.Append(Subscriber{peerId, requestId1, request1})
subs.Append(Subscriber{peerId, requestId2, request2})
subs.RemoveContentFilters(peerId, requestId2, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}})
subs.RemoveContentFilters(peerId, requestId1, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}})
sub := firstSubscriber(subs, topic)
assert.Nil(t, sub)
}
func TestRemoveDuplicateSubscriptionsPartial(t *testing.T) {
subs := NewSubscribers(10 * time.Second)
peerId := createPeerId(t)
topic := "topic"
requestId1 := "request_1"
requestId2 := "request_2"
request1 := pb.FilterRequest{
Subscribe: true,
Topic: TOPIC,
ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}},
}
request2 := pb.FilterRequest{
Subscribe: true,
Topic: TOPIC,
ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}},
}
subs.Append(Subscriber{peerId, requestId1, request1})
subs.Append(Subscriber{peerId, requestId2, request2})
subs.RemoveContentFilters(peerId, requestId1, []*pb.FilterRequest_ContentFilter{{ContentTopic: topic}})
sub := firstSubscriber(subs, topic)
assert.NotNil(t, sub)
assert.Equal(t, sub.requestId, requestId2)
}
func TestRemoveBogus(t *testing.T) { func TestRemoveBogus(t *testing.T) {
subs := NewSubscribers(10 * time.Second) subs := NewSubscribers(10 * time.Second)
peerId := createPeerId(t) peerId := createPeerId(t)
requestId := "request_1"
contentTopic := "topic1" contentTopic := "topic1"
request := pb.FilterRequest{ request := pb.FilterRequest{
Subscribe: true, Subscribe: true,
Topic: TOPIC, Topic: TOPIC,
ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}}, ContentFilters: []*pb.FilterRequest_ContentFilter{{ContentTopic: contentTopic}},
} }
subs.Append(Subscriber{peerId, "request_1", request}) subs.Append(Subscriber{peerId, requestId, request})
subs.RemoveContentFilters(peerId, []*pb.FilterRequest_ContentFilter{{ContentTopic: "does not exist"}, {ContentTopic: contentTopic}}) subs.RemoveContentFilters(peerId, requestId, []*pb.FilterRequest_ContentFilter{{ContentTopic: "does not exist"}, {ContentTopic: contentTopic}})
sub := firstSubscriber(subs, contentTopic) sub := firstSubscriber(subs, contentTopic)
assert.Nil(t, sub) assert.Nil(t, sub)

View File

@ -132,7 +132,7 @@ func (wf *WakuFilter) onRequest(s network.Stream) {
stats.Record(wf.ctx, metrics.FilterSubscriptions.M(int64(len))) stats.Record(wf.ctx, metrics.FilterSubscriptions.M(int64(len)))
} else { } else {
peerId := s.Conn().RemotePeer() peerId := s.Conn().RemotePeer()
wf.subscribers.RemoveContentFilters(peerId, filterRPCRequest.Request.ContentFilters) wf.subscribers.RemoveContentFilters(peerId, filterRPCRequest.RequestId, filterRPCRequest.Request.ContentFilters)
logger.Info("removing subscriber") logger.Info("removing subscriber")
stats.Record(wf.ctx, metrics.FilterSubscriptions.M(int64(wf.subscribers.Length()))) stats.Record(wf.ctx, metrics.FilterSubscriptions.M(int64(wf.subscribers.Length())))