diff --git a/agent/proxy/daemon.go b/agent/proxy/daemon.go index d5fd30256d..a930c978bd 100644 --- a/agent/proxy/daemon.go +++ b/agent/proxy/daemon.go @@ -38,11 +38,16 @@ type Daemon struct { // a file. Logger *log.Logger + // For tests, they can set this to change the default duration to wait + // for a graceful quit. + gracefulWait time.Duration + // process is the started process - lock sync.Mutex - stopped bool - stopCh chan struct{} - process *os.Process + lock sync.Mutex + stopped bool + stopCh chan struct{} + exitedCh chan struct{} + process *os.Process } // Start starts the daemon and keeps it running. @@ -64,17 +69,21 @@ func (p *Daemon) Start() error { // Setup our stop channel stopCh := make(chan struct{}) + exitedCh := make(chan struct{}) p.stopCh = stopCh + p.exitedCh = exitedCh // Start the loop. - go p.keepAlive(stopCh) + go p.keepAlive(stopCh, exitedCh) return nil } // keepAlive starts and keeps the configured process alive until it // is stopped via Stop. -func (p *Daemon) keepAlive(stopCh <-chan struct{}) { +func (p *Daemon) keepAlive(stopCh <-chan struct{}, exitedCh chan<- struct{}) { + defer close(exitedCh) + p.lock.Lock() process := p.process p.lock.Unlock() @@ -196,24 +205,42 @@ func (p *Daemon) start() (*os.Process, error) { // then this returns no error. func (p *Daemon) Stop() error { p.lock.Lock() - defer p.lock.Unlock() // If we're already stopped or never started, then no problem. if p.stopped || p.process == nil { // In the case we never even started, calling Stop makes it so // that we can't ever start in the future, either, so mark this. p.stopped = true + p.lock.Unlock() return nil } // Note that we've stopped p.stopped = true close(p.stopCh) + process := p.process + p.lock.Unlock() - err := p.process.Signal(os.Interrupt) + gracefulWait := p.gracefulWait + if gracefulWait == 0 { + gracefulWait = 5 * time.Second + } - return err - //return p.Command.Process.Kill() + // First, try a graceful stop + err := process.Signal(os.Interrupt) + if err == nil { + select { + case <-p.exitedCh: + // Success! + return nil + + case <-time.After(gracefulWait): + // Interrupt didn't work + } + } + + // Graceful didn't work, forcibly kill + return process.Kill() } // Equal implements Proxy to check for equality. diff --git a/agent/proxy/daemon_test.go b/agent/proxy/daemon_test.go index a1638b2665..32acde636d 100644 --- a/agent/proxy/daemon_test.go +++ b/agent/proxy/daemon_test.go @@ -6,6 +6,7 @@ import ( "os/exec" "path/filepath" "testing" + "time" "github.com/hashicorp/consul/testutil/retry" "github.com/hashicorp/go-uuid" @@ -99,6 +100,48 @@ func TestDaemonRestart(t *testing.T) { waitFile() } +func TestDaemonStop_kill(t *testing.T) { + t.Parallel() + + require := require.New(t) + td, closer := testTempDir(t) + defer closer() + + path := filepath.Join(td, "file") + + d := &Daemon{ + Command: helperProcess("stop-kill", path), + ProxyToken: "hello", + Logger: testLogger, + gracefulWait: 200 * time.Millisecond, + } + require.NoError(d.Start()) + + // Wait for the file to exist + retry.Run(t, func(r *retry.R) { + _, err := os.Stat(path) + if err == nil { + return + } + + r.Fatalf("error: %s", err) + }) + + // Stop the process + require.NoError(d.Stop()) + + // State the file so that we can get the mtime + fi, err := os.Stat(path) + require.NoError(err) + mtime := fi.ModTime() + + // The mtime shouldn't change + time.Sleep(100 * time.Millisecond) + fi, err = os.Stat(path) + require.NoError(err) + require.Equal(mtime, fi.ModTime()) +} + func TestDaemonEqual(t *testing.T) { cases := []struct { Name string diff --git a/agent/proxy/proxy_test.go b/agent/proxy/proxy_test.go index 11994b1bfd..71cfd4ebc8 100644 --- a/agent/proxy/proxy_test.go +++ b/agent/proxy/proxy_test.go @@ -120,6 +120,24 @@ func TestHelperProcess(t *testing.T) { } } + case "stop-kill": + // Setup listeners so it is ignored + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + defer signal.Stop(ch) + + path := args[0] + data := []byte(os.Getenv(EnvProxyToken)) + for { + if err := ioutil.WriteFile(path, data, 0644); err != nil { + t.Fatalf("err: %s", err) + } + time.Sleep(25 * time.Millisecond) + } + + // Run forever + <-make(chan struct{}) + default: fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd) os.Exit(2)