Merge pull request #8818 from hashicorp/streaming/add-subscribe-service-batch-events

stream: handle batch events as a special case of Event
This commit is contained in:
Daniel Nephin 2020-10-07 21:25:32 -04:00 committed by GitHub
commit b103568e98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 208 additions and 202 deletions

View File

@ -5,10 +5,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/acl" "github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/consul/stream" "github.com/hashicorp/consul/agent/consul/stream"
"github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/agent/structs"
"github.com/stretchr/testify/require"
) )
func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) { func TestStore_IntegrationWithEventPublisher_ACLTokenUpdate(t *testing.T) {
@ -294,8 +295,8 @@ func TestStore_IntegrationWithEventPublisher_ACLRoleUpdate(t *testing.T) {
} }
type nextResult struct { type nextResult struct {
Events []stream.Event Event stream.Event
Err error Err error
} }
func testRunSub(sub *stream.Subscription) <-chan nextResult { func testRunSub(sub *stream.Subscription) <-chan nextResult {
@ -304,8 +305,8 @@ func testRunSub(sub *stream.Subscription) <-chan nextResult {
for { for {
es, err := sub.Next(context.TODO()) es, err := sub.Next(context.TODO())
eventCh <- nextResult{ eventCh <- nextResult{
Events: es, Event: es,
Err: err, Err: err,
} }
if err != nil { if err != nil {
return return
@ -320,8 +321,8 @@ func assertNoEvent(t *testing.T, eventCh <-chan nextResult) {
select { select {
case next := <-eventCh: case next := <-eventCh:
require.NoError(t, next.Err) require.NoError(t, next.Err)
require.Len(t, next.Events, 1) require.Len(t, next.Event, 1)
t.Fatalf("got unwanted event: %#v", next.Events[0].Payload) t.Fatalf("got unwanted event: %#v", next.Event.Payload)
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
} }
} }
@ -331,8 +332,7 @@ func assertEvent(t *testing.T, eventCh <-chan nextResult) *stream.Event {
select { select {
case next := <-eventCh: case next := <-eventCh:
require.NoError(t, next.Err) require.NoError(t, next.Err)
require.Len(t, next.Events, 1) return &next.Event
return &next.Events[0]
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
t.Fatalf("no event after 100ms") t.Fatalf("no event after 100ms")
} }
@ -362,7 +362,7 @@ func assertReset(t *testing.T, eventCh <-chan nextResult, allowEOS bool) {
select { select {
case next := <-eventCh: case next := <-eventCh:
if allowEOS { if allowEOS {
if next.Err == nil && len(next.Events) == 1 && next.Events[0].IsEndOfSnapshot() { if next.Err == nil && next.Event.IsEndOfSnapshot() {
continue continue
} }
} }

View File

@ -19,6 +19,51 @@ type Event struct {
Payload interface{} Payload interface{}
} }
// Len returns the number of events contained within this event. If the Payload
// is a []Event, the length of that slice is returned. Otherwise 1 is returned.
func (e Event) Len() int {
if batch, ok := e.Payload.([]Event); ok {
return len(batch)
}
return 1
}
// Filter returns an Event filtered to only those Events where f returns true.
// If the second return value is false, every Event was removed by the filter.
func (e Event) Filter(f func(Event) bool) (Event, bool) {
batch, ok := e.Payload.([]Event)
if !ok {
return e, f(e)
}
// To avoid extra allocations, iterate over the list of events first and
// get a count of the total desired size. This trades off some extra cpu
// time in the worse case (when not all items match the filter), for
// fewer memory allocations.
var size int
for idx := range batch {
if f(batch[idx]) {
size++
}
}
if len(batch) == size || size == 0 {
return e, size != 0
}
filtered := make([]Event, 0, size)
for idx := range batch {
event := batch[idx]
if f(event) {
filtered = append(filtered, event)
}
}
if len(filtered) == 0 {
return e, false
}
e.Payload = filtered
return e, true
}
// IsEndOfSnapshot returns true if this is a framing event that indicates the // IsEndOfSnapshot returns true if this is a framing event that indicates the
// snapshot has completed. Subsequent events from Subscription.Next will be // snapshot has completed. Subsequent events from Subscription.Next will be
// streamed as they occur. // streamed as they occur.

View File

@ -32,13 +32,11 @@ func TestEventPublisher_SubscribeWithIndex0(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvents(t, eventCh) next := getNextEvent(t, eventCh)
expected := []Event{testSnapshotEvent} require.Equal(t, testSnapshotEvent, next)
require.Equal(t, expected, next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.Len(t, next, 1) require.True(t, next.IsEndOfSnapshot())
require.True(t, next[0].IsEndOfSnapshot())
assertNoResult(t, eventCh) assertNoResult(t, eventCh)
@ -50,8 +48,8 @@ func TestEventPublisher_SubscribeWithIndex0(t *testing.T) {
publisher.Publish(events) publisher.Publish(events)
// Subscriber should see the published event // Subscriber should see the published event
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
expected = []Event{{Payload: "the-published-event-payload", Key: "sub-key", Topic: testTopic}} expected := Event{Payload: "the-published-event-payload", Key: "sub-key", Topic: testTopic}
require.Equal(t, expected, next) require.Equal(t, expected, next)
} }
@ -80,8 +78,8 @@ func runSubscription(ctx context.Context, sub *Subscription) <-chan eventOrErr {
for { for {
es, err := sub.Next(ctx) es, err := sub.Next(ctx)
eventCh <- eventOrErr{ eventCh <- eventOrErr{
Events: es, Event: es,
Err: err, Err: err,
} }
if err != nil { if err != nil {
return return
@ -92,19 +90,19 @@ func runSubscription(ctx context.Context, sub *Subscription) <-chan eventOrErr {
} }
type eventOrErr struct { type eventOrErr struct {
Events []Event Event Event
Err error Err error
} }
func getNextEvents(t *testing.T, eventCh <-chan eventOrErr) []Event { func getNextEvent(t *testing.T, eventCh <-chan eventOrErr) Event {
t.Helper() t.Helper()
select { select {
case next := <-eventCh: case next := <-eventCh:
require.NoError(t, next.Err) require.NoError(t, next.Err)
return next.Events return next.Event
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
t.Fatalf("timeout waiting for event from subscription") t.Fatalf("timeout waiting for event from subscription")
return nil return Event{}
} }
} }
@ -113,8 +111,7 @@ func assertNoResult(t *testing.T, eventCh <-chan eventOrErr) {
select { select {
case next := <-eventCh: case next := <-eventCh:
require.NoError(t, next.Err) require.NoError(t, next.Err)
require.Len(t, next.Events, 1) t.Fatalf("received unexpected event: %#v", next.Event.Payload)
t.Fatalf("received unexpected event: %#v", next.Events[0].Payload)
case <-time.After(25 * time.Millisecond): case <-time.After(25 * time.Millisecond):
} }
} }
@ -152,11 +149,11 @@ func TestEventPublisher_ShutdownClosesSubscriptions(t *testing.T) {
func consumeSub(ctx context.Context, sub *Subscription) error { func consumeSub(ctx context.Context, sub *Subscription) error {
for { for {
events, err := sub.Next(ctx) event, err := sub.Next(ctx)
switch { switch {
case err != nil: case err != nil:
return err return err
case len(events) == 1 && events[0].IsEndOfSnapshot(): case event.IsEndOfSnapshot():
continue continue
} }
} }
@ -183,28 +180,25 @@ func TestEventPublisher_SubscribeWithIndex0_FromCache(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvents(t, eventCh) next := getNextEvent(t, eventCh)
expected := []Event{testSnapshotEvent} require.Equal(t, testSnapshotEvent, next)
require.Equal(t, expected, next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.Len(t, next, 1) require.True(t, next.IsEndOfSnapshot())
require.True(t, next[0].IsEndOfSnapshot())
// Now subscriber should block waiting for updates // Now subscriber should block waiting for updates
assertNoResult(t, eventCh) assertNoResult(t, eventCh)
events := []Event{{ expected := Event{
Topic: testTopic, Topic: testTopic,
Key: "sub-key", Key: "sub-key",
Payload: "the-published-event-payload", Payload: "the-published-event-payload",
Index: 3, Index: 3,
}} }
publisher.Publish(events) publisher.Publish([]Event{expected})
// Subscriber should see the published event // Subscriber should see the published event
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
expected = []Event{events[0]}
require.Equal(t, expected, next) require.Equal(t, expected, next)
} }
@ -228,14 +222,12 @@ func TestEventPublisher_SubscribeWithIndexNotZero_CanResume(t *testing.T) {
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvents(t, eventCh) next := getNextEvent(t, eventCh)
expected := []Event{testSnapshotEvent} require.Equal(t, testSnapshotEvent, next)
require.Equal(t, expected, next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.Len(t, next, 1) require.True(t, next.IsEndOfSnapshot())
require.True(t, next[0].IsEndOfSnapshot()) require.Equal(t, uint64(1), next.Index)
require.Equal(t, uint64(1), next[0].Index)
}) })
runStep(t, "resume the subscription", func(t *testing.T) { runStep(t, "resume the subscription", func(t *testing.T) {
@ -255,8 +247,8 @@ func TestEventPublisher_SubscribeWithIndexNotZero_CanResume(t *testing.T) {
} }
publisher.publishEvent([]Event{expected}) publisher.publishEvent([]Event{expected})
next := getNextEvents(t, eventCh) next := getNextEvent(t, eventCh)
require.Equal(t, []Event{expected}, next) require.Equal(t, expected, next)
}) })
} }
@ -280,14 +272,12 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshot(t *testing.T) {
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvents(t, eventCh) next := getNextEvent(t, eventCh)
expected := []Event{testSnapshotEvent} require.Equal(t, testSnapshotEvent, next)
require.Equal(t, expected, next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.Len(t, next, 1) require.True(t, next.IsEndOfSnapshot())
require.True(t, next[0].IsEndOfSnapshot()) require.Equal(t, uint64(1), next.Index)
require.Equal(t, uint64(1), next[0].Index)
}) })
nextEvent := Event{ nextEvent := Event{
@ -308,14 +298,14 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshot(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvents(t, eventCh) next := getNextEvent(t, eventCh)
require.True(t, next[0].IsNewSnapshotToFollow(), next) require.True(t, next.IsNewSnapshotToFollow(), next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.Equal(t, testSnapshotEvent, next[0]) require.Equal(t, testSnapshotEvent, next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.True(t, next[0].IsEndOfSnapshot()) require.True(t, next.IsEndOfSnapshot())
}) })
} }
@ -339,14 +329,12 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshotFromCache(t *testin
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvents(t, eventCh) next := getNextEvent(t, eventCh)
expected := []Event{testSnapshotEvent} require.Equal(t, testSnapshotEvent, next)
require.Equal(t, expected, next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.Len(t, next, 1) require.True(t, next.IsEndOfSnapshot())
require.True(t, next[0].IsEndOfSnapshot()) require.Equal(t, uint64(1), next.Index)
require.Equal(t, uint64(1), next[0].Index)
}) })
nextEvent := Event{ nextEvent := Event{
@ -371,17 +359,17 @@ func TestEventPublisher_SubscribeWithIndexNotZero_NewSnapshotFromCache(t *testin
require.NoError(t, err) require.NoError(t, err)
eventCh := runSubscription(ctx, sub) eventCh := runSubscription(ctx, sub)
next := getNextEvents(t, eventCh) next := getNextEvent(t, eventCh)
require.True(t, next[0].IsNewSnapshotToFollow(), next) require.True(t, next.IsNewSnapshotToFollow(), next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.Equal(t, testSnapshotEvent, next[0]) require.Equal(t, testSnapshotEvent, next)
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.True(t, next[0].IsEndOfSnapshot()) require.True(t, next.IsEndOfSnapshot())
next = getNextEvents(t, eventCh) next = getNextEvent(t, eventCh)
require.Equal(t, nextEvent, next[0]) require.Equal(t, nextEvent, next)
}) })
} }

View File

@ -65,59 +65,56 @@ func newSubscription(req SubscribeRequest, item *bufferItem, unsub func()) *Subs
} }
} }
// Next returns the next set of events to deliver. It must only be called from a // Next returns the next Event to deliver. It must only be called from a
// single goroutine concurrently as it mutates the Subscription. // single goroutine concurrently as it mutates the Subscription.
func (s *Subscription) Next(ctx context.Context) ([]Event, error) { func (s *Subscription) Next(ctx context.Context) (Event, error) {
if atomic.LoadUint32(&s.state) == subscriptionStateClosed { if atomic.LoadUint32(&s.state) == subscriptionStateClosed {
return nil, ErrSubscriptionClosed return Event{}, ErrSubscriptionClosed
} }
for { for {
next, err := s.currentItem.Next(ctx, s.forceClosed) next, err := s.currentItem.Next(ctx, s.forceClosed)
switch { switch {
case err != nil && atomic.LoadUint32(&s.state) == subscriptionStateClosed: case err != nil && atomic.LoadUint32(&s.state) == subscriptionStateClosed:
return nil, ErrSubscriptionClosed return Event{}, ErrSubscriptionClosed
case err != nil: case err != nil:
return nil, err return Event{}, err
} }
s.currentItem = next s.currentItem = next
if len(next.Events) == 0 {
events := filter(s.req.Key, next.Events)
if len(events) == 0 {
continue continue
} }
return events, nil event, ok := filterByKey(s.req, next.Events)
if !ok {
continue
}
return event, nil
} }
} }
// filter events to only those that match the key exactly. func newEventFromBatch(req SubscribeRequest, events []Event) Event {
func filter(key string, events []Event) []Event { first := events[0]
if key == "" || len(events) == 0 { if len(events) == 1 {
return events return first
}
return Event{
Topic: req.Topic,
Key: req.Key,
Index: first.Index,
Payload: events,
}
}
func filterByKey(req SubscribeRequest, events []Event) (Event, bool) {
event := newEventFromBatch(req, events)
if req.Key == "" {
return event, true
} }
var count int fn := func(e Event) bool {
for _, e := range events { return req.Key == e.Key
if key == e.Key {
count++
}
} }
return event.Filter(fn)
// Only allocate a new slice if some events need to be filtered out
switch count {
case 0:
return nil
case len(events):
return events
}
result := make([]Event, 0, count)
for _, e := range events {
if key == e.Key {
result = append(result, e)
}
}
return result
} }
// Close the subscription. Subscribers will receive an error when they call Next, // Close the subscription. Subscribers will receive an error when they call Next,

View File

@ -36,8 +36,7 @@ func TestSubscription(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.True(t, elapsed < 200*time.Millisecond, require.True(t, elapsed < 200*time.Millisecond,
"Event should have been delivered immediately, took %s", elapsed) "Event should have been delivered immediately, took %s", elapsed)
require.Len(t, got, 1) require.Equal(t, index, got.Index)
require.Equal(t, index, got[0].Index)
// Schedule an event publish in a while // Schedule an event publish in a while
index++ index++
@ -54,8 +53,7 @@ func TestSubscription(t *testing.T) {
"Event should have been delivered after blocking 200ms, took %s", elapsed) "Event should have been delivered after blocking 200ms, took %s", elapsed)
require.True(t, elapsed < 2*time.Second, require.True(t, elapsed < 2*time.Second,
"Event should have been delivered after short time, took %s", elapsed) "Event should have been delivered after short time, took %s", elapsed)
require.Len(t, got, 1) require.Equal(t, index, got.Index)
require.Equal(t, index, got[0].Index)
// Event with wrong key should not be delivered. Deliver a good message right // Event with wrong key should not be delivered. Deliver a good message right
// so we don't have to block test thread forever or cancel func yet. // so we don't have to block test thread forever or cancel func yet.
@ -70,9 +68,8 @@ func TestSubscription(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.True(t, elapsed < 200*time.Millisecond, require.True(t, elapsed < 200*time.Millisecond,
"Event should have been delivered immediately, took %s", elapsed) "Event should have been delivered immediately, took %s", elapsed)
require.Len(t, got, 1) require.Equal(t, index, got.Index)
require.Equal(t, index, got[0].Index) require.Equal(t, "test", got.Key)
require.Equal(t, "test", got[0].Key)
// Cancelling the subscription context should unblock Next // Cancelling the subscription context should unblock Next
start = time.Now() start = time.Now()
@ -91,9 +88,7 @@ func TestSubscription(t *testing.T) {
func TestSubscription_Close(t *testing.T) { func TestSubscription_Close(t *testing.T) {
eb := newEventBuffer() eb := newEventBuffer()
index := uint64(100) index := uint64(100)
startHead := eb.Head() startHead := eb.Head()
// Start with an event in the buffer // Start with an event in the buffer
@ -115,8 +110,7 @@ func TestSubscription_Close(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.True(t, elapsed < 200*time.Millisecond, require.True(t, elapsed < 200*time.Millisecond,
"Event should have been delivered immediately, took %s", elapsed) "Event should have been delivered immediately, took %s", elapsed)
require.Len(t, got, 1) require.Equal(t, index, got.Index)
require.Equal(t, index, got[0].Index)
// Schedule a Close simulating the server deciding this subscroption // Schedule a Close simulating the server deciding this subscroption
// needs to reset (e.g. on ACL perm change). // needs to reset (e.g. on ACL perm change).
@ -149,46 +143,55 @@ func publishTestEvent(index uint64, b *eventBuffer, key string) {
func TestFilter_NoKey(t *testing.T) { func TestFilter_NoKey(t *testing.T) {
events := make([]Event, 0, 5) events := make([]Event, 0, 5)
events = append(events, Event{Key: "One"}, Event{Key: "Two"}) events = append(events, Event{Key: "One", Index: 102}, Event{Key: "Two"})
actual := filter("", events) req := SubscribeRequest{Topic: testTopic}
require.Equal(t, events, actual) actual, ok := filterByKey(req, events)
require.True(t, ok)
require.Equal(t, Event{Topic: testTopic, Index: 102, Payload: events}, actual)
// test that a new array was not allocated // test that a new array was not allocated
require.Equal(t, cap(actual), 5) require.Equal(t, cap(actual.Payload.([]Event)), 5)
} }
func TestFilter_WithKey_AllEventsMatch(t *testing.T) { func TestFilter_WithKey_AllEventsMatch(t *testing.T) {
events := make([]Event, 0, 5) events := make([]Event, 0, 5)
events = append(events, Event{Key: "Same"}, Event{Key: "Same"}) events = append(events, Event{Key: "Same", Index: 103}, Event{Key: "Same"})
actual := filter("Same", events) req := SubscribeRequest{Topic: testTopic, Key: "Same"}
require.Equal(t, events, actual) actual, ok := filterByKey(req, events)
require.True(t, ok)
expected := Event{Topic: testTopic, Index: 103, Key: "Same", Payload: events}
require.Equal(t, expected, actual)
// test that a new array was not allocated // test that a new array was not allocated
require.Equal(t, cap(actual), 5) require.Equal(t, 5, cap(actual.Payload.([]Event)))
} }
func TestFilter_WithKey_SomeEventsMatch(t *testing.T) { func TestFilter_WithKey_SomeEventsMatch(t *testing.T) {
events := make([]Event, 0, 5) events := make([]Event, 0, 5)
events = append(events, Event{Key: "Same"}, Event{Key: "Other"}, Event{Key: "Same"}) events = append(events, Event{Key: "Same", Index: 104}, Event{Key: "Other"}, Event{Key: "Same"})
actual := filter("Same", events) req := SubscribeRequest{Topic: testTopic, Key: "Same"}
expected := []Event{{Key: "Same"}, {Key: "Same"}} actual, ok := filterByKey(req, events)
require.True(t, ok)
expected := Event{
Topic: testTopic,
Index: 104,
Key: "Same",
Payload: []Event{{Key: "Same", Index: 104}, {Key: "Same"}},
}
require.Equal(t, expected, actual) require.Equal(t, expected, actual)
// test that a new array was allocated with the correct size // test that a new array was allocated with the correct size
require.Equal(t, cap(actual), 2) require.Equal(t, cap(actual.Payload.([]Event)), 2)
} }
func TestFilter_WithKey_NoEventsMatch(t *testing.T) { func TestFilter_WithKey_NoEventsMatch(t *testing.T) {
events := make([]Event, 0, 5) events := make([]Event, 0, 5)
events = append(events, Event{Key: "Same"}, Event{Key: "Same"}) events = append(events, Event{Key: "Same"}, Event{Key: "Same"})
actual := filter("Other", events) req := SubscribeRequest{Topic: testTopic, Key: "Other"}
var expected []Event _, ok := filterByKey(req, events)
require.Equal(t, expected, actual) require.False(t, ok)
// test that no array was allocated
require.Equal(t, cap(actual), 0)
} }

View File

@ -52,21 +52,17 @@ type eventLogger struct {
count uint64 count uint64
} }
func (l *eventLogger) Trace(e []stream.Event) { func (l *eventLogger) Trace(e stream.Event) {
if len(e) == 0 {
return
}
first := e[0]
switch { switch {
case first.IsEndOfSnapshot(): case e.IsEndOfSnapshot():
l.snapshotDone = true l.snapshotDone = true
l.logger.Trace("snapshot complete", "index", first.Index, "sent", l.count) l.logger.Trace("snapshot complete", "index", e.Index, "sent", l.count)
case first.IsNewSnapshotToFollow(): case e.IsNewSnapshotToFollow():
l.logger.Trace("starting new snapshot", "sent", l.count)
return return
case l.snapshotDone: case l.snapshotDone:
l.logger.Trace("sending events", "index", first.Index, "sent", l.count, "batch_size", len(e)) l.logger.Trace("sending events", "index", e.Index, "sent", l.count, "batch_size", e.Len())
} }
l.count += uint64(len(e)) l.count += uint64(e.Len())
} }

View File

@ -67,7 +67,7 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub
ctx := serverStream.Context() ctx := serverStream.Context()
elog := &eventLogger{logger: logger} elog := &eventLogger{logger: logger}
for { for {
events, err := sub.Next(ctx) event, err := sub.Next(ctx)
switch { switch {
case errors.Is(err, stream.ErrSubscriptionClosed): case errors.Is(err, stream.ErrSubscriptionClosed):
logger.Trace("subscription reset by server") logger.Trace("subscription reset by server")
@ -76,13 +76,14 @@ func (h *Server) Subscribe(req *pbsubscribe.SubscribeRequest, serverStream pbsub
return err return err
} }
events = filterStreamEvents(authz, events) var ok bool
if len(events) == 0 { event, ok = filterByAuth(authz, event)
if !ok {
continue continue
} }
elog.Trace(events) elog.Trace(event)
e := newEventFromStreamEvents(req, events) e := newEventFromStreamEvent(req, event)
if err := serverStream.Send(e); err != nil { if err := serverStream.Send(e); err != nil {
return err return err
} }
@ -126,68 +127,44 @@ func forwardToDC(
} }
} }
// filterStreamEvents to only those allowed by the acl token. // filterByAuth to only those Events allowed by the acl token.
func filterStreamEvents(authz acl.Authorizer, events []stream.Event) []stream.Event { func filterByAuth(authz acl.Authorizer, event stream.Event) (stream.Event, bool) {
// authz will be nil when ACLs are disabled // authz will be nil when ACLs are disabled
if authz == nil || len(events) == 0 { if authz == nil {
return events return event, true
} }
fn := func(e stream.Event) bool {
// Fast path for the common case of only 1 event since we can avoid slice return enforceACL(authz, e) == acl.Allow
// allocation in the hot path of every single update event delivered in vast
// majority of cases with this. Note that this is called _per event/item_ when
// sending snapshots which is a lot worse than being called once on regular
// result.
if len(events) == 1 {
if enforceACL(authz, events[0]) == acl.Allow {
return events
}
return nil
} }
return event.Filter(fn)
var filtered []stream.Event
for idx := range events {
event := events[idx]
if enforceACL(authz, event) == acl.Allow {
filtered = append(filtered, event)
}
}
return filtered
} }
func newEventFromStreamEvents(req *pbsubscribe.SubscribeRequest, events []stream.Event) *pbsubscribe.Event { func newEventFromStreamEvent(req *pbsubscribe.SubscribeRequest, event stream.Event) *pbsubscribe.Event {
e := &pbsubscribe.Event{ e := &pbsubscribe.Event{
Topic: req.Topic, Topic: req.Topic,
Key: req.Key, Key: req.Key,
Index: events[0].Index, Index: event.Index,
} }
switch {
if len(events) == 1 { case event.IsEndOfSnapshot():
event := events[0] e.Payload = &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}
// TODO: refactor so these are only checked once, instead of 3 times. return e
switch { case event.IsNewSnapshotToFollow():
case event.IsEndOfSnapshot(): e.Payload = &pbsubscribe.Event_NewSnapshotToFollow{NewSnapshotToFollow: true}
e.Payload = &pbsubscribe.Event_EndOfSnapshot{EndOfSnapshot: true}
return e
case event.IsNewSnapshotToFollow():
e.Payload = &pbsubscribe.Event_NewSnapshotToFollow{NewSnapshotToFollow: true}
return e
}
setPayload(e, event.Payload)
return e return e
} }
setPayload(e, event.Payload)
e.Payload = &pbsubscribe.Event_EventBatch{
EventBatch: &pbsubscribe.EventBatch{
Events: batchEventsFromEventSlice(events),
},
}
return e return e
} }
func setPayload(e *pbsubscribe.Event, payload interface{}) { func setPayload(e *pbsubscribe.Event, payload interface{}) {
switch p := payload.(type) { switch p := payload.(type) {
case []stream.Event:
e.Payload = &pbsubscribe.Event_EventBatch{
EventBatch: &pbsubscribe.EventBatch{
Events: batchEventsFromEventSlice(p),
},
}
case state.EventPayloadCheckServiceNode: case state.EventPayloadCheckServiceNode:
e.Payload = &pbsubscribe.Event_ServiceHealth{ e.Payload = &pbsubscribe.Event_ServiceHealth{
ServiceHealth: &pbsubscribe.ServiceHealthUpdate{ ServiceHealth: &pbsubscribe.ServiceHealthUpdate{