diff --git a/command/agent/agent_test.go b/command/agent/agent_test.go index 8240b4854d..10ee331b8c 100644 --- a/command/agent/agent_test.go +++ b/command/agent/agent_test.go @@ -80,11 +80,11 @@ func makeAgentKeyring(t *testing.T, conf *Config, key string) (string, *Agent) { conf.DataDir = dir - fileLAN := filepath.Join(dir, SerfLANKeyring) + fileLAN := filepath.Join(dir, serfLANKeyring) if err := testutil.InitKeyring(fileLAN, key); err != nil { t.Fatalf("err: %s", err) } - fileWAN := filepath.Join(dir, SerfWANKeyring) + fileWAN := filepath.Join(dir, serfWANKeyring) if err := testutil.InitKeyring(fileWAN, key); err != nil { t.Fatalf("err: %s", err) } diff --git a/command/agent/command.go b/command/agent/command.go index 620389767e..7bb9496189 100644 --- a/command/agent/command.go +++ b/command/agent/command.go @@ -156,20 +156,16 @@ func (c *Command) readConfig() *Config { } fileLAN := filepath.Join(config.DataDir, serfLANKeyring) - if _, err := os.Stat(fileLAN); err != nil { - initKeyring(fileLAN, config.EncryptKey) - } else { - c.Ui.Error(fmt.Sprintf("WARNING: %s exists, not using key: %s", - fileLAN, config.EncryptKey)) + if err := initKeyring(fileLAN, config.EncryptKey); err != nil { + c.Ui.Error(fmt.Sprintf("Error initializing keyring: %s", err)) + return nil } if config.Server { fileWAN := filepath.Join(config.DataDir, serfWANKeyring) - if _, err := os.Stat(fileWAN); err != nil { - initKeyring(fileWAN, config.EncryptKey) - } else { - c.Ui.Error(fmt.Sprintf("WARNING: %s exists, not using key: %s", - fileWAN, config.EncryptKey)) + if err := initKeyring(fileWAN, config.EncryptKey); err != nil { + c.Ui.Error(fmt.Sprintf("Error initializing keyring: %s", err)) + return nil } } } diff --git a/command/agent/command_test.go b/command/agent/command_test.go index 9c1bf4db5d..703d476ac9 100644 --- a/command/agent/command_test.go +++ b/command/agent/command_test.go @@ -1,12 +1,16 @@ package agent import ( +<<<<<<< HEAD "fmt" "io/ioutil" "log" "os" "path/filepath" "strings" +======= + "github.com/mitchellh/cli" +>>>>>>> agent: -encrypt appends to keyring if one exists "testing" "github.com/hashicorp/consul/testutil" diff --git a/command/agent/config_test.go b/command/agent/config_test.go index 52c0c58f76..13e8634ce4 100644 --- a/command/agent/config_test.go +++ b/command/agent/config_test.go @@ -1030,64 +1030,3 @@ func TestReadConfigPaths_dir(t *testing.T) { t.Fatalf("bad: %#v", config) } } - -func TestKeyringFileExists(t *testing.T) { - tempDir, err := ioutil.TempDir("", "consul") - if err != nil { - t.Fatalf("err: %s", err) - } - defer os.RemoveAll(tempDir) - - fileLAN := filepath.Join(tempDir, SerfLANKeyring) - fileWAN := filepath.Join(tempDir, SerfWANKeyring) - - if err := os.MkdirAll(filepath.Dir(fileLAN), 0700); err != nil { - t.Fatalf("err: %s", err) - } - if err := os.MkdirAll(filepath.Dir(fileWAN), 0700); err != nil { - t.Fatalf("err: %s", err) - } - - config := &Config{DataDir: tempDir, Server: true} - - // Returns false if we are a server and no keyring files present - if config.keyringFileExists() { - t.Fatalf("should return false") - } - - // Returns false if we are a client and no keyring files present - config.Server = false - if config.keyringFileExists() { - t.Fatalf("should return false") - } - - // Returns true if we are a client and the lan file exists - if err := ioutil.WriteFile(fileLAN, nil, 0600); err != nil { - t.Fatalf("err: %s", err) - } - if !config.keyringFileExists() { - t.Fatalf("should return true") - } - - // Returns true if we are a server and only the lan file exists - config.Server = true - if !config.keyringFileExists() { - t.Fatalf("should return true") - } - - // Returns true if we are a server and both files exist - if err := ioutil.WriteFile(fileWAN, nil, 0600); err != nil { - t.Fatalf("err: %s", err) - } - if !config.keyringFileExists() { - t.Fatalf("should return true") - } - - // Returns true if we are a server and only the wan file exists - if err := os.Remove(fileLAN); err != nil { - t.Fatalf("err: %s", err) - } - if !config.keyringFileExists() { - t.Fatalf("should return true") - } -} diff --git a/command/agent/keyring.go b/command/agent/keyring.go index d4253b5e1c..524968605d 100644 --- a/command/agent/keyring.go +++ b/command/agent/keyring.go @@ -20,11 +20,28 @@ const ( // initKeyring will create a keyring file at a given path. func initKeyring(path, key string) error { + var keys []string + if _, err := base64.StdEncoding.DecodeString(key); err != nil { return fmt.Errorf("Invalid key: %s", err) } - keys := []string{key} + if _, err := os.Stat(path); err == nil { + content, err := ioutil.ReadFile(path) + if err != nil { + return err + } + if err := json.Unmarshal(content, &keys); err != nil { + return err + } + for _, existing := range keys { + if key == existing { + return nil + } + } + } + + keys = append(keys, key) keyringBytes, err := json.Marshal(keys) if err != nil { return err @@ -34,10 +51,6 @@ func initKeyring(path, key string) error { return err } - if _, err := os.Stat(path); err == nil { - return fmt.Errorf("File already exists: %s", path) - } - fh, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) if err != nil { return err diff --git a/command/agent/keyring_test.go b/command/agent/keyring_test.go index 7807e53761..734a67dc14 100644 --- a/command/agent/keyring_test.go +++ b/command/agent/keyring_test.go @@ -1,7 +1,12 @@ package agent import ( + "bytes" + "encoding/json" + "io/ioutil" "os" + "path/filepath" + "strings" "testing" ) @@ -69,3 +74,70 @@ func TestAgent_LoadKeyrings(t *testing.T) { t.Fatalf("keyring should not be loaded") } } + +func TestAgent_InitKeyring(t *testing.T) { + key1 := "tbLJg26ZJyJ9pK3qhc9jig==" + key2 := "4leC33rgtXKIVUr9Nr0snQ==" + + dir, err := ioutil.TempDir("", "consul") + if err != nil { + t.Fatalf("err: %s", err) + } + defer os.RemoveAll(dir) + + file := filepath.Join(dir, "keyring") + + // First initialize the keyring + if err := initKeyring(file, key1); err != nil { + t.Fatalf("err: %s", err) + } + + content1, err := ioutil.ReadFile(file) + if err != nil { + t.Fatalf("err: %s", err) + } + if !strings.Contains(string(content1), key1) { + t.Fatalf("bad: %s", content1) + } + if strings.Contains(string(content1), key2) { + t.Fatalf("bad: %s", content1) + } + + // Now initialize again with the same key + if err := initKeyring(file, key1); err != nil { + t.Fatalf("err: %s", err) + } + + content2, err := ioutil.ReadFile(file) + if err != nil { + t.Fatalf("err: %s", err) + } + if !bytes.Equal(content1, content2) { + t.Fatalf("bad: %s", content2) + } + + // Initialize an existing keyring with a new key + if err := initKeyring(file, key2); err != nil { + t.Fatalf("err: %s", err) + } + + content3, err := ioutil.ReadFile(file) + if err != nil { + t.Fatalf("err: %s", err) + } + if !strings.Contains(string(content3), key1) { + t.Fatalf("bad: %s", content3) + } + if !strings.Contains(string(content3), key2) { + t.Fatalf("bad: %s", content3) + } + + // Unmarshal and make sure that key1 is still primary + var keys []string + if err := json.Unmarshal(content3, &keys); err != nil { + t.Fatalf("err: %s", err) + } + if keys[0] != key1 { + t.Fatalf("bad: %#v", keys) + } +} diff --git a/command/agent/rpc_client_test.go b/command/agent/rpc_client_test.go index 518423a69a..2eed8f8a87 100644 --- a/command/agent/rpc_client_test.go +++ b/command/agent/rpc_client_test.go @@ -334,7 +334,7 @@ func TestRPCClientInstallKey(t *testing.T) { func TestRPCClientUseKey(t *testing.T) { key1 := "tbLJg26ZJyJ9pK3qhc9jig==" key2 := "xAEZ3uVHRMZD9GcYMZaRQw==" - conf := Config{EncryptKey: key1} + conf := Config{EncryptKey: key1, Server: true} p1 := testRPCClientWithConfig(t, &conf) defer p1.Close() diff --git a/command/keyring_test.go b/command/keyring_test.go index 7e975b6ccc..cc75d9797e 100644 --- a/command/keyring_test.go +++ b/command/keyring_test.go @@ -129,8 +129,8 @@ func TestKeyringCommandRun_initKeyring(t *testing.T) { t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String()) } - fileLAN := filepath.Join(tempDir, agent.SerfLANKeyring) - fileWAN := filepath.Join(tempDir, agent.SerfWANKeyring) + fileLAN := filepath.Join(tempDir, agent.serfLANKeyring) + fileWAN := filepath.Join(tempDir, agent.serfWANKeyring) if _, err := os.Stat(fileLAN); err != nil { t.Fatalf("err: %s", err) } diff --git a/consul/rpc.go b/consul/rpc.go index 6fcb07d820..f400b2be95 100644 --- a/consul/rpc.go +++ b/consul/rpc.go @@ -229,6 +229,9 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{ func (s *Server) globalRPC(method string, args interface{}, reply structs.CompoundResponse) error { + if reply == nil { + return fmt.Errorf("nil reply struct") + } rlen := len(s.remoteConsuls) if rlen < 2 { return nil diff --git a/consul/serf_test.go b/consul/serf_test.go index b30b4fe821..07225a7293 100644 --- a/consul/serf_test.go +++ b/consul/serf_test.go @@ -1,8 +1,6 @@ package consul import ( - "fmt" - "os" "testing" ) @@ -21,25 +19,3 @@ func TestUserEventNames(t *testing.T) { t.Fatalf("bad: %v", raw) } } - -func TestKeyringRPCError(t *testing.T) { - dir1, s1 := testServerDC(t, "dc1") - defer os.RemoveAll(dir1) - defer s1.Shutdown() - - dir2, s2 := testServerDC(t, "dc2") - defer os.RemoveAll(dir2) - defer s2.Shutdown() - - // Try to join - addr := fmt.Sprintf("127.0.0.1:%d", - s1.config.SerfWANConfig.MemberlistConfig.BindPort) - if _, err := s2.JoinWAN([]string{addr}); err != nil { - t.Fatalf("err: %v", err) - } - - // RPC error from remote datacenter is returned - if err := s1.keyringRPC("Bad.Method", nil, nil); err == nil { - t.Fatalf("bad") - } -} diff --git a/consul/server_test.go b/consul/server_test.go index 76b7d4ed44..50627837e6 100644 --- a/consul/server_test.go +++ b/consul/server_test.go @@ -6,9 +6,11 @@ import ( "io/ioutil" "net" "os" + "strings" "testing" "time" + "github.com/hashicorp/consul/consul/structs" "github.com/hashicorp/consul/testutil" ) @@ -471,5 +473,43 @@ func TestServer_BadExpect(t *testing.T) { }, func(err error) { t.Fatalf("should have 0 peers: %v", err) }) - +} + +func TestServer_globalRPC(t *testing.T) { + dir1, s1 := testServerDC(t, "dc1") + defer os.RemoveAll(dir1) + defer s1.Shutdown() + + dir2, s2 := testServerDC(t, "dc2") + defer os.RemoveAll(dir2) + defer s2.Shutdown() + + // Try to join + addr := fmt.Sprintf("127.0.0.1:%d", + s1.config.SerfLANConfig.MemberlistConfig.BindPort) + if _, err := s2.JoinLAN([]string{addr}); err != nil { + t.Fatalf("err: %v", err) + } + + testutil.WaitForLeader(t, s1.RPC, "dc1") + + // Check that replies from each DC come in + resp := &structs.KeyringResponses{} + args := &structs.KeyringRequest{Operation: structs.KeyringList} + if err := s1.globalRPC("Internal.KeyringOperation", args, resp); err != nil { + t.Fatalf("err: %s", err) + } + if len(resp.Responses) != 3 { + t.Fatalf("bad: %#v", resp.Responses) + } + + // Check that error from remote DC is returned + resp = &structs.KeyringResponses{} + err := s1.globalRPC("Bad.Method", nil, resp) + if err == nil { + t.Fatalf("should have errored") + } + if !strings.Contains(err.Error(), "Bad.Method") { + t.Fatalf("unexpcted error: %s", err) + } }