From 4359e3811424b062d2e751045f99be98cc6bd4eb Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Fri, 6 Aug 2021 14:24:12 -0400 Subject: [PATCH] debug: restore cancel on SigInt Some previous changes broke interrupting the debug command on SIGINT. This change restores the original behaviour by passing a context to requests. Since a new API client function was required to pass the context, I had it also return an io.ReadCloser, so that output can be streamed to files instead of fully buffering in process memory. --- api/debug.go | 22 +++++++++ command/commands_oss.go | 2 +- command/debug/debug.go | 92 +++++++++++++++++++++++-------------- command/debug/debug_test.go | 24 +++++----- command/registry.go | 1 + 5 files changed, 93 insertions(+), 48 deletions(-) diff --git a/api/debug.go b/api/debug.go index 56dcc9bcd2..2e7bb94b5b 100644 --- a/api/debug.go +++ b/api/debug.go @@ -1,7 +1,9 @@ package api import ( + "context" "fmt" + "io" "io/ioutil" "strconv" ) @@ -70,6 +72,26 @@ func (d *Debug) Profile(seconds int) ([]byte, error) { return body, nil } +// PProf returns a pprof profile for the specified number of seconds. The caller +// is responsible for closing the returned io.ReadCloser once all bytes are read. 
+func (d *Debug) PProf(ctx context.Context, name string, seconds int) (io.ReadCloser, error) { + r := d.c.newRequest("GET", "/debug/pprof/"+name) + r.ctx = ctx + + // Capture a profile for the specified number of seconds + r.params.Set("seconds", strconv.Itoa(seconds)) + + _, resp, err := d.c.doRequest(r) + if err != nil { + return nil, fmt.Errorf("error making request: %s", err) + } + + if resp.StatusCode != 200 { + return nil, generateUnexpectedResponseCodeError(resp) + } + return resp.Body, nil +} + // Trace returns an execution trace func (d *Debug) Trace(seconds int) ([]byte, error) { r := d.c.newRequest("GET", "/debug/pprof/trace") diff --git a/command/commands_oss.go b/command/commands_oss.go index cfa8a7e621..92b49e4106 100644 --- a/command/commands_oss.go +++ b/command/commands_oss.go @@ -166,7 +166,7 @@ func init() { Register("connect envoy pipe-bootstrap", func(ui cli.Ui) (cli.Command, error) { return pipebootstrap.New(ui), nil }) Register("connect expose", func(ui cli.Ui) (cli.Command, error) { return expose.New(ui), nil }) Register("connect redirect-traffic", func(ui cli.Ui) (cli.Command, error) { return redirecttraffic.New(ui), nil }) - Register("debug", func(ui cli.Ui) (cli.Command, error) { return debug.New(ui, MakeShutdownCh()), nil }) + Register("debug", func(ui cli.Ui) (cli.Command, error) { return debug.New(ui), nil }) Register("event", func(ui cli.Ui) (cli.Command, error) { return event.New(ui), nil }) Register("exec", func(ui cli.Ui) (cli.Command, error) { return exec.New(ui, MakeShutdownCh()), nil }) Register("force-leave", func(ui cli.Ui) (cli.Command, error) { return forceleave.New(ui), nil }) diff --git a/command/debug/debug.go b/command/debug/debug.go index ccb2607889..869161dadd 100644 --- a/command/debug/debug.go +++ b/command/debug/debug.go @@ -12,8 +12,10 @@ import ( "io" "io/ioutil" "os" + "os/signal" "path/filepath" "strings" + "syscall" "time" "golang.org/x/sync/errgroup" @@ -55,7 +57,7 @@ const ( debugProtocolVersion = 1 ) -func 
New(ui cli.Ui, shutdownCh <-chan struct{}) *cmd { +func New(ui cli.Ui) *cmd { ui = &cli.PrefixedUi{ OutputPrefix: "==> ", InfoPrefix: " ", @@ -63,7 +65,7 @@ func New(ui cli.Ui, shutdownCh <-chan struct{}) *cmd { Ui: ui, } - c := &cmd{UI: ui, shutdownCh: shutdownCh} + c := &cmd{UI: ui} c.init() return c } @@ -74,8 +76,6 @@ type cmd struct { http *flags.HTTPFlags help string - shutdownCh <-chan struct{} - // flags interval time.Duration duration time.Duration @@ -136,6 +136,9 @@ func (c *cmd) init() { } func (c *cmd) Run(args []string) int { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + if err := c.flags.Parse(args); err != nil { c.UI.Error(fmt.Sprintf("Error parsing flags: %s", err)) return 1 @@ -197,8 +200,12 @@ func (c *cmd) Run(args []string) int { // Capture dynamic information from the target agent, blocking for duration if c.captureTarget(targetMetrics) || c.captureTarget(targetLogs) || c.captureTarget(targetProfiles) { g := new(errgroup.Group) - g.Go(c.captureInterval) - g.Go(c.captureLongRunning) + g.Go(func() error { + return c.captureInterval(ctx) + }) + g.Go(func() error { + return c.captureLongRunning(ctx) + }) err = g.Wait() if err != nil { c.UI.Error(fmt.Sprintf("Error encountered during collection: %v", err)) @@ -333,7 +340,7 @@ func writeJSONFile(filename string, content interface{}) error { // captureInterval blocks for the duration of the command // specified by the duration flag, capturing the dynamic // targets at the interval specified -func (c *cmd) captureInterval() error { +func (c *cmd) captureInterval(ctx context.Context) error { intervalChn := time.NewTicker(c.interval) defer intervalChn.Stop() durationChn := time.After(c.duration) @@ -358,7 +365,7 @@ func (c *cmd) captureInterval() error { case <-durationChn: intervalChn.Stop() return nil - case <-c.shutdownCh: + case <-ctx.Done(): return errors.New("stopping collection due to shutdown signal") } } @@ -395,7 +402,7 @@ func (c 
*cmd) createTimestampDir(timestamp int64) (string, error) { return timestampDir, nil } -func (c *cmd) captureLongRunning() error { +func (c *cmd) captureLongRunning(ctx context.Context) error { timestamp := time.Now().Local().Unix() timestampDir, err := c.createTimestampDir(timestamp) @@ -411,24 +418,27 @@ func (c *cmd) captureLongRunning() error { } if c.captureTarget(targetProfiles) { g.Go(func() error { - return c.captureProfile(s, timestampDir) + // use ctx without a timeout to allow the profile to finish sending + return c.captureProfile(ctx, s, timestampDir) }) g.Go(func() error { - return c.captureTrace(s, timestampDir) + // use ctx without a timeout to allow the trace to finish sending + return c.captureTrace(ctx, s, timestampDir) }) } if c.captureTarget(targetLogs) { g.Go(func() error { - return c.captureLogs(timestampDir) + ctx, cancel := context.WithTimeout(ctx, c.duration) + defer cancel() + return c.captureLogs(ctx, timestampDir) }) } if c.captureTarget(targetMetrics) { - // TODO: pass in context from caller - ctx, cancel := context.WithTimeout(context.Background(), c.duration) - defer cancel() - g.Go(func() error { + + ctx, cancel := context.WithTimeout(ctx, c.duration) + defer cancel() return c.captureMetrics(ctx, timestampDir) }) } @@ -445,22 +455,38 @@ func (c *cmd) captureGoRoutines(timestampDir string) error { return ioutil.WriteFile(fmt.Sprintf("%s/goroutine.prof", timestampDir), gr, 0644) } -func (c *cmd) captureTrace(s float64, timestampDir string) error { - trace, err := c.client.Debug().Trace(int(s)) - if err != nil { - return fmt.Errorf("failed to collect trace: %w", err) - } - - return ioutil.WriteFile(fmt.Sprintf("%s/trace.out", timestampDir), trace, 0644) -} - -func (c *cmd) captureProfile(s float64, timestampDir string) error { - prof, err := c.client.Debug().Profile(int(s)) +func (c *cmd) captureTrace(ctx context.Context, s float64, timestampDir string) error { + prof, err := c.client.Debug().PProf(ctx, "trace", int(s)) if err != nil { 
return fmt.Errorf("failed to collect cpu profile: %w", err) } + defer prof.Close() - return ioutil.WriteFile(fmt.Sprintf("%s/profile.prof", timestampDir), prof, 0644) + r := bufio.NewReader(prof) + fh, err := os.Create(fmt.Sprintf("%s/trace.out", timestampDir)) + if err != nil { + return err + } + defer fh.Close() + _, err = r.WriteTo(fh) + return err +} + +func (c *cmd) captureProfile(ctx context.Context, s float64, timestampDir string) error { + prof, err := c.client.Debug().PProf(ctx, "profile", int(s)) + if err != nil { + return fmt.Errorf("failed to collect cpu profile: %w", err) + } + defer prof.Close() + + r := bufio.NewReader(prof) + fh, err := os.Create(fmt.Sprintf("%s/profile.prof", timestampDir)) + if err != nil { + return err + } + defer fh.Close() + _, err = r.WriteTo(fh) + return err } func (c *cmd) captureHeap(timestampDir string) error { @@ -472,15 +498,11 @@ func (c *cmd) captureHeap(timestampDir string) error { return ioutil.WriteFile(fmt.Sprintf("%s/heap.prof", timestampDir), heap, 0644) } -func (c *cmd) captureLogs(timestampDir string) error { - endLogChn := make(chan struct{}) - timeIsUp := time.After(c.duration) - logCh, err := c.client.Agent().Monitor("DEBUG", endLogChn, nil) +func (c *cmd) captureLogs(ctx context.Context, timestampDir string) error { + logCh, err := c.client.Agent().Monitor("DEBUG", ctx.Done(), nil) if err != nil { return err } - // Close the log stream - defer close(endLogChn) // Create the log file for writing f, err := os.Create(fmt.Sprintf("%s/%s", timestampDir, "consul.log")) @@ -498,7 +520,7 @@ func (c *cmd) captureLogs(timestampDir string) error { if _, err = f.WriteString(log + "\n"); err != nil { return err } - case <-timeIsUp: + case <-ctx.Done(): return nil } } diff --git a/command/debug/debug_test.go b/command/debug/debug_test.go index 215475706c..22820e5cfb 100644 --- a/command/debug/debug_test.go +++ b/command/debug/debug_test.go @@ -26,7 +26,7 @@ import ( ) func TestDebugCommand_Help_TextContainsNoTabs(t 
*testing.T) { - if strings.ContainsRune(New(cli.NewMockUi(), nil).Help(), '\t') { + if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') { t.Fatal("help has tabs") } } @@ -46,7 +46,7 @@ func TestDebugCommand(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := fmt.Sprintf("%s/debug", testDir) @@ -92,7 +92,7 @@ func TestDebugCommand_Archive(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := fmt.Sprintf("%s/debug", testDir) @@ -137,7 +137,7 @@ func TestDebugCommand_Archive(t *testing.T) { func TestDebugCommand_ArgsBad(t *testing.T) { ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) args := []string{"foo", "bad"} @@ -153,7 +153,7 @@ func TestDebugCommand_ArgsBad(t *testing.T) { func TestDebugCommand_InvalidFlags(t *testing.T) { ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := "" @@ -186,7 +186,7 @@ func TestDebugCommand_OutputPathBad(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := "" @@ -219,7 +219,7 @@ func TestDebugCommand_OutputPathExists(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := fmt.Sprintf("%s/debug", testDir) @@ -304,7 +304,7 @@ func TestDebugCommand_CaptureTargets(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := fmt.Sprintf("%s/debug-%s", testDir, name) @@ -387,7 +387,7 @@ func TestDebugCommand_CaptureLogs(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := 
fmt.Sprintf("%s/debug-%s", testDir, name) @@ -480,7 +480,7 @@ func TestDebugCommand_ProfilesExist(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := fmt.Sprintf("%s/debug", testDir) @@ -548,7 +548,7 @@ func TestDebugCommand_Prepare_ValidateTiming(t *testing.T) { for name, tc := range cases { t.Run(name, func(t *testing.T) { ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) args := []string{ "-duration=" + tc.duration, @@ -579,7 +579,7 @@ func TestDebugCommand_DebugDisabled(t *testing.T) { testrpc.WaitForLeader(t, a.RPC, "dc1") ui := cli.NewMockUi() - cmd := New(ui, nil) + cmd := New(ui) cmd.validateTiming = false outputPath := fmt.Sprintf("%s/debug", testDir) diff --git a/command/registry.go b/command/registry.go index a96818c2f3..b400a92dfb 100644 --- a/command/registry.go +++ b/command/registry.go @@ -47,6 +47,7 @@ var registry map[string]Factory // MakeShutdownCh returns a channel that can be used for shutdown notifications // for commands. This channel will send a message for every interrupt or SIGTERM // received. +// Deprecated: use signal.NotifyContext func MakeShutdownCh() <-chan struct{} { resultCh := make(chan struct{}) signalCh := make(chan os.Signal, 4)