diff --git a/command/agent/command_test.go b/command/agent/command_test.go index bcd6be807c..d365b338ed 100644 --- a/command/agent/command_test.go +++ b/command/agent/command_test.go @@ -6,7 +6,6 @@ import ( "io/ioutil" "log" "os" - "strings" "testing" "github.com/hashicorp/consul/testutil" @@ -166,7 +165,7 @@ func TestRetryJoinWanFail(t *testing.T) { } } -func TestSetupAgent_UnixSocket_Fails(t *testing.T) { +func TestSetupAgent_RPCUnixSocket_FileExists(t *testing.T) { conf := nextConfig() tmpDir, err := ioutil.TempDir("", "consul") if err != nil { @@ -185,8 +184,7 @@ func TestSetupAgent_UnixSocket_Fails(t *testing.T) { conf.Server = true conf.Bootstrap = true - // Set socket address to an existing file. Consul should fail to - // start and return an error. + // Set socket address to an existing file. conf.Addresses.RPC = "unix://" + socketPath shutdownCh := make(chan struct{}) @@ -200,12 +198,17 @@ func TestSetupAgent_UnixSocket_Fails(t *testing.T) { logWriter := NewLogWriter(512) logOutput := new(bytes.Buffer) - // Ensure we got an error mentioning the socket file - err = cmd.setupAgent(conf, logOutput, logWriter) - if err == nil { - t.Fatalf("should have failed") + // Ensure the server is created + if err := cmd.setupAgent(conf, logOutput, logWriter); err != nil { + t.Fatalf("err: %s", err) } - if !strings.Contains(err.Error(), socketPath) { - t.Fatalf("expected socket file error, got: %q", err) + + // Ensure the file was replaced by the socket + fi, err := os.Stat(socketPath) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode()&os.ModeSocket == 0 { + t.Fatalf("expected socket to replace file") } } diff --git a/command/agent/http_test.go b/command/agent/http_test.go index 4677d5ba53..0f723ffd62 100644 --- a/command/agent/http_test.go +++ b/command/agent/http_test.go @@ -132,11 +132,17 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) { defer os.RemoveAll(dir) // Try to start the server with the same path anyways. - if servers, err := NewHTTPServers(agent, conf, agent.logOutput); err == nil { - for _, server := range servers { - server.Shutdown() - } - t.Fatalf("expected socket binding error") + if _, err := NewHTTPServers(agent, conf, agent.logOutput); err != nil { + t.Fatalf("err: %s", err) + } + + // Ensure the file was replaced by the socket + fi, err = os.Stat(socket) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode()&os.ModeSocket == 0 { + t.Fatalf("expected socket to replace file") } } diff --git a/command/agent/util_test.go b/command/agent/util_test.go index 92b89d26f4..b140f477d2 100644 --- a/command/agent/util_test.go +++ b/command/agent/util_test.go @@ -1,6 +1,8 @@ package agent import ( + "io/ioutil" + "os" "testing" "time" ) @@ -39,3 +41,59 @@ func TestStringHash(t *testing.T) { t.Fatalf("bad: %s", out) } } + +func TestSetFilePermissions(t *testing.T) { + tempFile, err := ioutil.TempFile("", "consul") + if err != nil { + t.Fatalf("err: %s", err) + } + path := tempFile.Name() + defer os.Remove(path) + + // Bad UID fails + if err := setFilePermissions(path, map[string]string{"uid": "%"}); err == nil { + t.Fatalf("should fail") + } + + // Bad GID fails + if err := setFilePermissions(path, map[string]string{"gid": "%"}); err == nil { + t.Fatalf("should fail") + } + + // Bad mode fails + if err := setFilePermissions(path, map[string]string{"mode": "%"}); err == nil { + t.Fatalf("should fail") + } + + // Allows omitting user/group/mode + if err := setFilePermissions(path, map[string]string{}); err != nil { + t.Fatalf("err: %s", err) + } + + // Doesn't change mode if not given + if err := os.Chmod(path, 0700); err != nil { + t.Fatalf("err: %s", err) + } + if err := setFilePermissions(path, map[string]string{}); err != nil { + t.Fatalf("err: %s", err) + } + fi, err := os.Stat(path) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode().String() != "-rwx------" { + t.Fatalf("bad: %s", fi.Mode()) + } + + // Changes mode if given + if err := setFilePermissions(path, map[string]string{"mode": "0777"}); err != nil { + t.Fatalf("err: %s", err) + } + fi, err = os.Stat(path) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode().String() != "-rwxrwxrwx" { + t.Fatalf("bad: %s", fi.Mode()) + } +}