diff --git a/agent/consul/subscribe_backend_test.go b/agent/consul/subscribe_backend_test.go
index fe6a957324..de80a2088e 100644
--- a/agent/consul/subscribe_backend_test.go
+++ b/agent/consul/subscribe_backend_test.go
@@ -2,15 +2,16 @@ package consul
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"strings"
-	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/require"
+	"golang.org/x/sync/errgroup"
 	gogrpc "google.golang.org/grpc"
 
 	grpc "github.com/hashicorp/consul/agent/grpc"
@@ -328,31 +329,18 @@ func TestSubscribeBackend_IntegrationWithServer_DeliversAllMessages(t *testing.T
 	// Now start a whole bunch of streamers in parallel to maximise chance of
 	// catching a race.
 	n := 5
-	var wg sync.WaitGroup
+	var g errgroup.Group
 	var updateCount uint64
-	// Buffered error chan so that workers can exit and terminate wg without
-	// blocking on send. We collect errors this way since t isn't thread safe.
-	errCh := make(chan error, n)
 	for i := 0; i < n; i++ {
 		i := i
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			verifyMonotonicStreamUpdates(ctx, t, streamClient, i, &updateCount, errCh)
-		}()
+		g.Go(func() error {
+			return verifyMonotonicStreamUpdates(ctx, t, streamClient, i, &updateCount)
+		})
 	}
 
 	// Wait until all subscribers have verified the first bunch of updates all got
 	// delivered.
-	wg.Wait()
-
-	close(errCh)
-
-	// Require that none of them errored. Since we closed the chan above this loop
-	// should terminate immediately if no errors were buffered.
-	for err := range errCh {
-		require.NoError(t, err)
-	}
+	require.NoError(t, g.Wait())
 
 	// Sanity check that at least some non-snapshot messages were delivered. We
 	// can't know exactly how many because it's timing dependent based on when
@@ -394,70 +382,62 @@ type testLogger interface {
 	Logf(format string, args ...interface{})
 }
 
-func verifyMonotonicStreamUpdates(ctx context.Context, logger testLogger, client pbsubscribe.StateChangeSubscriptionClient, i int, updateCount *uint64, errCh chan<- error) {
+func verifyMonotonicStreamUpdates(ctx context.Context, logger testLogger, client pbsubscribe.StateChangeSubscriptionClient, i int, updateCount *uint64) error {
 	req := &pbsubscribe.SubscribeRequest{Topic: pbsubscribe.Topic_ServiceHealth, Key: "redis"}
 	streamHandle, err := client.Subscribe(ctx, req)
-	if err != nil {
-		if strings.Contains(err.Error(), "context deadline exceeded") ||
-			strings.Contains(err.Error(), "context canceled") {
-			logger.Logf("subscriber %05d: context cancelled before loop")
-			return
-		}
-		errCh <- err
-		return
+	switch {
+	case errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded):
+		logger.Logf("subscriber %05d: context cancelled before loop", i)
+		return nil
+	case err != nil:
+		return err
 	}
 
 	snapshotDone := false
 	expectPort := int32(0)
 	for {
 		event, err := streamHandle.Recv()
-		if err == io.EOF {
-			break
-		}
-		if err != nil {
-			if strings.Contains(err.Error(), "context deadline exceeded") ||
-				strings.Contains(err.Error(), "context canceled") {
-				break
-			}
-			errCh <- err
-			return
+		switch {
+		case err == io.EOF:
+			return nil
+		case errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded):
+			return nil
+		case err != nil:
+			return err
 		}
 
 		switch {
 		case event.GetEndOfSnapshot():
 			snapshotDone = true
-			logger.Logf("subscriber %05d: snapshot done, expect next port to be %d", i, expectPort)
+			logger.Logf("subscriber %05d: snapshot done at index %d, expect next port to be %d", i, event.Index, expectPort)
 		case snapshotDone:
 			// Verify we get all updates in order
 			svc, err := svcOrErr(event)
 			if err != nil {
-				errCh <- err
-				return
+				return err
 			}
 			if expectPort != svc.Port {
-				errCh <- fmt.Errorf("subscriber %05d: missed %d update(s)!", i, svc.Port-expectPort)
-				return
+				return fmt.Errorf("subscriber %05d: at index %d: expected port %d, got %d",
+					i, event.Index, expectPort, svc.Port)
 			}
 			atomic.AddUint64(updateCount, 1)
-			logger.Logf("subscriber %05d: got event with correct port=%d", i, expectPort)
+			logger.Logf("subscriber %05d: got event with correct port=%d at index %d", i, expectPort, event.Index)
 			expectPort++
 		default:
-			// This is a snapshot update. Check if it's an update for the canary
-			// instance that got applied before our snapshot was sent (likely)
+			// snapshot events
 			svc, err := svcOrErr(event)
 			if err != nil {
-				errCh <- err
-				return
+				return err
 			}
 			if svc.ID == "redis-canary" {
 				// Update the expected port we see in the next update to be one more
 				// than the port in the snapshot.
 				expectPort = svc.Port + 1
-				logger.Logf("subscriber %05d: saw canary in snapshot with port %d", i, svc.Port)
+				logger.Logf("subscriber %05d: saw canary in snapshot with port %d at index %d", i, svc.Port, event.Index)
 			}
 		}
 
 		if expectPort > 100 {
-			return
+			return nil
 		}
 	}
 }
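---

For reference, a minimal standalone sketch of the errgroup pattern this change adopts (illustrative only, not part of the patch): each worker returns its error from the closure passed to g.Go, and g.Wait blocks until every goroutine has returned, reporting the first non-nil error. This replaces the buffered error channel the old code used because, as its comment noted, t isn't thread safe. The worker bodies below are made up for illustration.

	package main

	import (
		"fmt"

		"golang.org/x/sync/errgroup"
	)

	func main() {
		var g errgroup.Group
		for i := 0; i < 5; i++ {
			i := i // capture the loop variable for the closure (pre-Go 1.22 semantics)
			g.Go(func() error {
				if i == 3 {
					return fmt.Errorf("worker %d failed", i)
				}
				return nil
			})
		}
		// Wait blocks until all goroutines started with Go return,
		// then yields the first non-nil error (or nil).
		if err := g.Wait(); err != nil {
			fmt.Println(err) // prints: worker 3 failed
		}
	}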