diff --git a/agent/consul/stream/event_publisher.go b/agent/consul/stream/event_publisher.go index 444b117c53..9dfb8bf9e5 100644 --- a/agent/consul/stream/event_publisher.go +++ b/agent/consul/stream/event_publisher.go @@ -103,8 +103,7 @@ func (e *EventPublisher) handleUpdates(ctx context.Context) { for { select { case <-ctx.Done(): - // TODO: also close all subscriptions so the subscribers are moved - // to the new publisher? + e.subscriptions.closeAll() return case update := <-e.publishCh: e.sendEvents(update) @@ -249,6 +248,17 @@ func (s *subscriptions) unsubscribe(req *SubscribeRequest) func() { } } +func (s *subscriptions) closeAll() { + s.lock.Lock() + defer s.lock.Unlock() + + for _, byRequest := range s.byToken { + for _, sub := range byRequest { + sub.forceClose() + } + } +} + func (e *EventPublisher) getSnapshotLocked(req *SubscribeRequest, topicHead *bufferItem) (*eventSnapshot, error) { topicSnaps, ok := e.snapCache[req.Topic] if !ok { diff --git a/agent/consul/stream/event_publisher_test.go b/agent/consul/stream/event_publisher_test.go index 4a8c6542ef..4deeb1503e 100644 --- a/agent/consul/stream/event_publisher_test.go +++ b/agent/consul/stream/event_publisher_test.go @@ -111,3 +111,45 @@ func assertNoResult(t *testing.T, eventCh <-chan subNextResult) { case <-time.After(100 * time.Millisecond): } } + +func TestEventPublisher_ShutdownClosesSubscriptions(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + handlers := newTestSnapshotHandlers() + fn := func(req *SubscribeRequest, buf SnapshotAppender) (uint64, error) { + return 0, nil + } + handlers[intTopic(22)] = fn + handlers[intTopic(33)] = fn + + publisher := NewEventPublisher(ctx, handlers, time.Second) + + sub1, err := publisher.Subscribe(&SubscribeRequest{Topic: intTopic(22)}) + require.NoError(t, err) + defer sub1.Unsubscribe() + + sub2, err := publisher.Subscribe(&SubscribeRequest{Topic: intTopic(33)}) + require.NoError(t, err) + defer sub2.Unsubscribe() + + cancel() // Shutdown + + err = consumeSub(context.Background(), sub1) + require.Equal(t, err, ErrSubscriptionClosed) + + _, err = sub2.Next(context.Background()) + require.Equal(t, err, ErrSubscriptionClosed) +} + +func consumeSub(ctx context.Context, sub *Subscription) error { + for { + events, err := sub.Next(ctx) + switch { + case err != nil: + return err + case len(events) == 1 && events[0].IsEndOfSnapshot(): + continue + } + } +} diff --git a/agent/consul/stream/subscription.go b/agent/consul/stream/subscription.go index c2177468ff..e4a24cc1eb 100644 --- a/agent/consul/stream/subscription.go +++ b/agent/consul/stream/subscription.go @@ -82,7 +82,7 @@ func (s *Subscription) Next(ctx context.Context) ([]Event, error) { } s.currentItem = next - events := s.filter(next.Events) + events := filter(s.req.Key, next.Events) if len(events) == 0 { continue } @@ -90,34 +90,34 @@ func (s *Subscription) Next(ctx context.Context) ([]Event, error) { } } -// TODO: test cases for this method -func (s *Subscription) filter(events []Event) []Event { - if s.req.Key == "" || len(events) == 0 { +// filter events to only those that match the key exactly. +func filter(key string, events []Event) []Event { + if key == "" || len(events) == 0 { return events } - allMatch := true + var count int for _, e := range events { - if s.req.Key != e.Key { - allMatch = false - break + if key == e.Key { + count++ } } // Only allocate a new slice if some events need to be filtered out - if allMatch { + switch count { + case 0: + return nil + case len(events): return events } - // FIXME: this will over-allocate. We could get a count from the previous range - // over events. - events = make([]Event, 0, len(events)) + result := make([]Event, 0, count) for _, e := range events { - if s.req.Key == e.Key { - events = append(events, e) + if key == e.Key { + result = append(result, e) } } - return events + return result } // Close the subscription. Subscribers will receive an error when they call Next, diff --git a/agent/consul/stream/subscription_test.go b/agent/consul/stream/subscription_test.go index 36a60dc482..84e941a3bf 100644 --- a/agent/consul/stream/subscription_test.go +++ b/agent/consul/stream/subscription_test.go @@ -148,3 +148,49 @@ func publishTestEvent(index uint64, b *eventBuffer, key string) { } b.Append([]Event{e}) } + +func TestFilter_NoKey(t *testing.T) { + events := make([]Event, 0, 5) + events = append(events, Event{Key: "One"}, Event{Key: "Two"}) + + actual := filter("", events) + require.Equal(t, events, actual) + + // test that a new array was not allocated + require.Equal(t, cap(actual), 5) +} + +func TestFilter_WithKey_AllEventsMatch(t *testing.T) { + events := make([]Event, 0, 5) + events = append(events, Event{Key: "Same"}, Event{Key: "Same"}) + + actual := filter("Same", events) + require.Equal(t, events, actual) + + // test that a new array was not allocated + require.Equal(t, cap(actual), 5) +} + +func TestFilter_WithKey_SomeEventsMatch(t *testing.T) { + events := make([]Event, 0, 5) + events = append(events, Event{Key: "Same"}, Event{Key: "Other"}, Event{Key: "Same"}) + + actual := filter("Same", events) + expected := []Event{{Key: "Same"}, {Key: "Same"}} + require.Equal(t, expected, actual) + + // test that a new array was allocated with the correct size + require.Equal(t, cap(actual), 2) +} + +func TestFilter_WithKey_NoEventsMatch(t *testing.T) { + events := make([]Event, 0, 5) + events = append(events, Event{Key: "Same"}, Event{Key: "Same"}) + + actual := filter("Other", events) + var expected []Event + require.Equal(t, expected, actual) + + // test that no array was allocated + require.Equal(t, cap(actual), 0) +}