diff --git a/agent/signal_unix.go b/agent/signal_unix.go new file mode 100644 index 0000000000..2768a55883 --- /dev/null +++ b/agent/signal_unix.go @@ -0,0 +1,10 @@ +// +build !windows + +package agent + +import ( + "os" + "syscall" +) + +var forwardSignals = []os.Signal{os.Interrupt, syscall.SIGTERM} diff --git a/agent/signal_windows.go b/agent/signal_windows.go new file mode 100644 index 0000000000..3e5b8d7248 --- /dev/null +++ b/agent/signal_windows.go @@ -0,0 +1,9 @@ +// +build windows + +package agent + +import ( + "os" +) + +var forwardSignals = []os.Signal{os.Interrupt} diff --git a/agent/util.go b/agent/util.go index 76670a7048..835b756cb7 100644 --- a/agent/util.go +++ b/agent/util.go @@ -98,7 +98,7 @@ GROUP: func ForwardSignals(cmd *exec.Cmd, logFn func(error), shutdownCh <-chan struct{}) { go func() { signalCh := make(chan os.Signal, 10) - signal.Notify(signalCh, os.Interrupt, os.Kill) + signal.Notify(signalCh, forwardSignals...) defer signal.Stop(signalCh) for { diff --git a/agent/util_test.go b/agent/util_test.go index f90589782e..49c2d5c32b 100644 --- a/agent/util_test.go +++ b/agent/util_test.go @@ -1,7 +1,11 @@ package agent import ( + "bufio" + "fmt" "os" + "os/exec" + "os/signal" "runtime" "testing" "time" @@ -119,3 +123,165 @@ func TestDurationFixer(t *testing.T) { // Ensure we only processed the intended fieldnames verify.Values(t, "", obj, expected) } + +// helperProcessSentinel is a sentinel value that is put as the first +// argument following "--" and is used to determine if TestHelperProcess +// should run. +const helperProcessSentinel = "GO_WANT_HELPER_PROCESS" + +// helperProcess returns an *exec.Cmd that can be used to execute the +// TestHelperProcess function below. This can be used to test multi-process +// interactions. +func helperProcess(s ...string) (*exec.Cmd, func()) { + cs := []string{"-test.run=TestHelperProcess", "--", helperProcessSentinel} + cs = append(cs, s...) + + cmd := exec.Command(os.Args[0], cs...) + destroy := func() { + if p := cmd.Process; p != nil { + p.Kill() + } + } + + return cmd, destroy +} + +// This is not a real test. This is just a helper process kicked off by tests +// using the helperProcess helper function. +func TestHelperProcess(t *testing.T) { + args := os.Args + for len(args) > 0 { + if args[0] == "--" { + args = args[1:] + break + } + + args = args[1:] + } + + if len(args) == 0 || args[0] != helperProcessSentinel { + return + } + + defer os.Exit(0) + args = args[1:] // strip sentinel value + cmd, args := args[0], args[1:] + + switch cmd { + case "parent-signal": + // This subcommand forwards signals to a child process subcommand "print-signal". + + limitProcessLifetime(2 * time.Minute) + + cmd, destroy := helperProcess("print-signal") + defer destroy() + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Start(); err != nil { + fmt.Fprintf(os.Stderr, "child process failed to start: %v\n", err) + os.Exit(1) + } + + doneCh := make(chan struct{}) + defer func() { close(doneCh) }() + logFn := func(err error) { + fmt.Fprintf(os.Stderr, "could not forward signal: %s\n", err) + os.Exit(1) + } + ForwardSignals(cmd, logFn, doneCh) + + if err := cmd.Wait(); err != nil { + fmt.Fprintf(os.Stderr, "unexpected error waiting for child: %v", err) + os.Exit(1) + } + + case "print-signal": + // This subcommand is instrumented to help verify signals are passed correctly. + + limitProcessLifetime(2 * time.Minute) + + ch := make(chan os.Signal, 10) + signal.Notify(ch) + defer signal.Stop(ch) + + fmt.Fprintf(os.Stdout, "ready\n") + + s := <-ch + + fmt.Fprintf(os.Stdout, "signal: %s\n", s) + + default: + fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd) + os.Exit(2) + } +} + +// limitProcessLifetime installs a background goroutine that self-exits after +// the specified duration elapses to prevent leaking processes from tests that +// may spawn them. +func limitProcessLifetime(dur time.Duration) { + go time.AfterFunc(dur, func() { + os.Exit(99) + }) +} + +func TestForwardSignals(t *testing.T) { + for _, s := range forwardSignals { + t.Run("signal-"+s.String(), func(t *testing.T) { + testForwardSignal(t, s) + }) + } +} + +func testForwardSignal(t *testing.T, s os.Signal) { + t.Helper() + + if s == os.Kill { + t.Fatalf("you can't forward SIGKILL") + } + + // Launch a child process which registers the forwarding signal handler + // under test and then that in turn launches a grand child process that is + // our test instrument. + cmd, destroy := helperProcess("parent-signal") + defer destroy() + + cmd.Stderr = os.Stderr + prc, err := cmd.StdoutPipe() + if err != nil { + t.Fatalf("could not open stdout pipe for child process: %v", err) + } + defer prc.Close() + + if err := cmd.Start(); err != nil { + t.Fatalf("child process failed to start: %v", err) + } + scan := bufio.NewScanner(prc) + + // Wait until the grandchild relays back to us that it's ready to receive + // signals. + expectLine(t, "ready", scan) + + // Relay our chosen signal down through the intermediary process. + if err := cmd.Process.Signal(s); err != nil { + t.Fatalf("signalling child failed: %v", err) + } + + // Verify that the signal we intended made it all the way to the grandchild. + expectLine(t, "signal: "+s.String(), scan) +} + +func expectLine(t *testing.T, expect string, scan *bufio.Scanner) { + if !scan.Scan() { + if scan.Err() != nil { + t.Fatalf("expected to read line %q but failed: %v", expect, scan.Err()) + } else { + t.Fatalf("expected to read line %q but got no line", expect) + } + } + + if line := scan.Text(); expect != line { + t.Fatalf("expected to read line %q but got %q", expect, line) + } +}