package api import ( crand "crypto/rand" "crypto/tls" "fmt" "io/ioutil" "net/http" "os" "path/filepath" "reflect" "runtime" "testing" "time" "github.com/hashicorp/consul/testutil" "strings" ) type configCallback func(c *Config) func makeClient(t *testing.T) (*Client, *testutil.TestServer) { return makeClientWithConfig(t, nil, nil) } func makeACLClient(t *testing.T) (*Client, *testutil.TestServer) { return makeClientWithConfig(t, func(clientConfig *Config) { clientConfig.Token = "root" }, func(serverConfig *testutil.TestServerConfig) { serverConfig.ACLMasterToken = "root" serverConfig.ACLDatacenter = "dc1" serverConfig.ACLDefaultPolicy = "deny" }) } func makeClientWithConfig( t *testing.T, cb1 configCallback, cb2 testutil.ServerConfigCallback) (*Client, *testutil.TestServer) { // Make client config conf := DefaultConfig() if cb1 != nil { cb1(conf) } // Create server server, err := testutil.NewTestServerConfig(cb2) if err != nil { t.Fatal(err) } conf.Address = server.HTTPAddr // Create client client, err := NewClient(conf) if err != nil { t.Fatalf("err: %v", err) } return client, server } func testKey() string { buf := make([]byte, 16) if _, err := crand.Read(buf); err != nil { panic(fmt.Errorf("Failed to read random bytes: %v", err)) } return fmt.Sprintf("%08x-%04x-%04x-%04x-%12x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:16]) } func TestDefaultConfig_env(t *testing.T) { t.Parallel() addr := "1.2.3.4:5678" token := "abcd1234" auth := "username:password" os.Setenv(HTTPAddrEnvName, addr) defer os.Setenv(HTTPAddrEnvName, "") os.Setenv(HTTPTokenEnvName, token) defer os.Setenv(HTTPTokenEnvName, "") os.Setenv(HTTPAuthEnvName, auth) defer os.Setenv(HTTPAuthEnvName, "") os.Setenv(HTTPSSLEnvName, "1") defer os.Setenv(HTTPSSLEnvName, "") os.Setenv(HTTPCAFile, "ca.pem") defer os.Setenv(HTTPCAFile, "") os.Setenv(HTTPCAPath, "certs/") defer os.Setenv(HTTPCAPath, "") os.Setenv(HTTPClientCert, "client.crt") defer os.Setenv(HTTPClientCert, "") os.Setenv(HTTPClientKey, "client.key") defer os.Setenv(HTTPClientKey, "") os.Setenv(HTTPTLSServerName, "consul.test") defer os.Setenv(HTTPTLSServerName, "") os.Setenv(HTTPSSLVerifyEnvName, "0") defer os.Setenv(HTTPSSLVerifyEnvName, "") for i, config := range []*Config{DefaultConfig(), DefaultNonPooledConfig()} { if config.Address != addr { t.Errorf("expected %q to be %q", config.Address, addr) } if config.Token != token { t.Errorf("expected %q to be %q", config.Token, token) } if config.HttpAuth == nil { t.Fatalf("expected HttpAuth to be enabled") } if config.HttpAuth.Username != "username" { t.Errorf("expected %q to be %q", config.HttpAuth.Username, "username") } if config.HttpAuth.Password != "password" { t.Errorf("expected %q to be %q", config.HttpAuth.Password, "password") } if config.Scheme != "https" { t.Errorf("expected %q to be %q", config.Scheme, "https") } if config.TLSConfig.CAFile != "ca.pem" { t.Errorf("expected %q to be %q", config.TLSConfig.CAFile, "ca.pem") } if config.TLSConfig.CAPath != "certs/" { t.Errorf("expected %q to be %q", config.TLSConfig.CAPath, "certs/") } if config.TLSConfig.CertFile != "client.crt" { t.Errorf("expected %q to be %q", config.TLSConfig.CertFile, "client.crt") } if config.TLSConfig.KeyFile != "client.key" { t.Errorf("expected %q to be %q", config.TLSConfig.KeyFile, "client.key") } if config.TLSConfig.Address != "consul.test" { t.Errorf("expected %q to be %q", config.TLSConfig.Address, "consul.test") } if !config.TLSConfig.InsecureSkipVerify { t.Errorf("expected SSL verification to be off") } // Use keep alives as a check for whether pooling is on or off. if pooled := i == 0; pooled { if config.HttpClient.Transport.(*http.Transport).DisableKeepAlives != false { t.Errorf("expected keep alives to be enabled") } } else { if config.HttpClient.Transport.(*http.Transport).DisableKeepAlives != true { t.Errorf("expected keep alives to be disabled") } } } } func TestSetupTLSConfig(t *testing.T) { // A default config should result in a clean default client config. tlsConfig := &TLSConfig{} cc, err := SetupTLSConfig(tlsConfig) if err != nil { t.Fatalf("err: %v", err) } expected := &tls.Config{RootCAs: cc.RootCAs} if !reflect.DeepEqual(cc, expected) { t.Fatalf("bad: \n%v, \n%v", cc, expected) } // Try some address variations with and without ports. tlsConfig.Address = "127.0.0.1" cc, err = SetupTLSConfig(tlsConfig) if err != nil { t.Fatalf("err: %v", err) } expected.ServerName = "127.0.0.1" if !reflect.DeepEqual(cc, expected) { t.Fatalf("bad: %v", cc) } tlsConfig.Address = "127.0.0.1:80" cc, err = SetupTLSConfig(tlsConfig) if err != nil { t.Fatalf("err: %v", err) } expected.ServerName = "127.0.0.1" if !reflect.DeepEqual(cc, expected) { t.Fatalf("bad: %v", cc) } tlsConfig.Address = "demo.consul.io:80" cc, err = SetupTLSConfig(tlsConfig) if err != nil { t.Fatalf("err: %v", err) } expected.ServerName = "demo.consul.io" if !reflect.DeepEqual(cc, expected) { t.Fatalf("bad: %v", cc) } tlsConfig.Address = "[2001:db8:a0b:12f0::1]" cc, err = SetupTLSConfig(tlsConfig) if err != nil { t.Fatalf("err: %v", err) } expected.ServerName = "[2001:db8:a0b:12f0::1]" if !reflect.DeepEqual(cc, expected) { t.Fatalf("bad: %v", cc) } tlsConfig.Address = "[2001:db8:a0b:12f0::1]:80" cc, err = SetupTLSConfig(tlsConfig) if err != nil { t.Fatalf("err: %v", err) } expected.ServerName = "2001:db8:a0b:12f0::1" if !reflect.DeepEqual(cc, expected) { t.Fatalf("bad: %v", cc) } // Skip verification. tlsConfig.InsecureSkipVerify = true cc, err = SetupTLSConfig(tlsConfig) if err != nil { t.Fatalf("err: %v", err) } expected.InsecureSkipVerify = true if !reflect.DeepEqual(cc, expected) { t.Fatalf("bad: %v", cc) } // Make a new config that hits all the file parsers. tlsConfig = &TLSConfig{ CertFile: "../test/hostname/Alice.crt", KeyFile: "../test/hostname/Alice.key", CAFile: "../test/hostname/CertAuth.crt", } cc, err = SetupTLSConfig(tlsConfig) if err != nil { t.Fatalf("err: %v", err) } if len(cc.Certificates) != 1 { t.Fatalf("missing certificate: %v", cc.Certificates) } if cc.RootCAs == nil { t.Fatalf("didn't load root CAs") } // Use a directory to load the certs instead cc, err = SetupTLSConfig(&TLSConfig{ CAPath: "../test/ca_path", }) if err != nil { t.Fatalf("err: %v", err) } if len(cc.RootCAs.Subjects()) != 2 { t.Fatalf("didn't load root CAs") } } func TestClientTLSOptions(t *testing.T) { t.Parallel() _, s := makeClientWithConfig(t, nil, func(conf *testutil.TestServerConfig) { conf.CAFile = "../test/client_certs/rootca.crt" conf.CertFile = "../test/client_certs/server.crt" conf.KeyFile = "../test/client_certs/server.key" conf.VerifyIncoming = true }) defer s.Stop() // Set up a client without certs clientWithoutCerts, err := NewClient(&Config{ Address: s.HTTPSAddr, Scheme: "https", TLSConfig: TLSConfig{ Address: "consul.test", CAFile: "../test/client_certs/rootca.crt", }, }) if err != nil { t.Fatal(err) } // Should fail _, err = clientWithoutCerts.Agent().Self() if err == nil || !strings.Contains(err.Error(), "bad certificate") { t.Fatal(err) } // Set up a client with valid certs clientWithCerts, err := NewClient(&Config{ Address: s.HTTPSAddr, Scheme: "https", TLSConfig: TLSConfig{ Address: "consul.test", CAFile: "../test/client_certs/rootca.crt", CertFile: "../test/client_certs/client.crt", KeyFile: "../test/client_certs/client.key", }, }) if err != nil { t.Fatal(err) } // Should succeed _, err = clientWithCerts.Agent().Self() if err != nil { t.Fatal(err) } } func TestSetQueryOptions(t *testing.T) { t.Parallel() c, s := makeClient(t) defer s.Stop() r := c.newRequest("GET", "/v1/kv/foo") q := &QueryOptions{ Datacenter: "foo", AllowStale: true, RequireConsistent: true, WaitIndex: 1000, WaitTime: 100 * time.Second, Token: "12345", Near: "nodex", } r.setQueryOptions(q) if r.params.Get("dc") != "foo" { t.Fatalf("bad: %v", r.params) } if _, ok := r.params["stale"]; !ok { t.Fatalf("bad: %v", r.params) } if _, ok := r.params["consistent"]; !ok { t.Fatalf("bad: %v", r.params) } if r.params.Get("index") != "1000" { t.Fatalf("bad: %v", r.params) } if r.params.Get("wait") != "100000ms" { t.Fatalf("bad: %v", r.params) } if r.header.Get("X-Consul-Token") != "12345" { t.Fatalf("bad: %v", r.header) } if r.params.Get("near") != "nodex" { t.Fatalf("bad: %v", r.params) } } func TestSetWriteOptions(t *testing.T) { t.Parallel() c, s := makeClient(t) defer s.Stop() r := c.newRequest("GET", "/v1/kv/foo") q := &WriteOptions{ Datacenter: "foo", Token: "23456", } r.setWriteOptions(q) if r.params.Get("dc") != "foo" { t.Fatalf("bad: %v", r.params) } if r.header.Get("X-Consul-Token") != "23456" { t.Fatalf("bad: %v", r.header) } } func TestRequestToHTTP(t *testing.T) { t.Parallel() c, s := makeClient(t) defer s.Stop() r := c.newRequest("DELETE", "/v1/kv/foo") q := &QueryOptions{ Datacenter: "foo", } r.setQueryOptions(q) req, err := r.toHTTP() if err != nil { t.Fatalf("err: %v", err) } if req.Method != "DELETE" { t.Fatalf("bad: %v", req) } if req.URL.RequestURI() != "/v1/kv/foo?dc=foo" { t.Fatalf("bad: %v", req) } } func TestParseQueryMeta(t *testing.T) { t.Parallel() resp := &http.Response{ Header: make(map[string][]string), } resp.Header.Set("X-Consul-Index", "12345") resp.Header.Set("X-Consul-LastContact", "80") resp.Header.Set("X-Consul-KnownLeader", "true") resp.Header.Set("X-Consul-Translate-Addresses", "true") qm := &QueryMeta{} if err := parseQueryMeta(resp, qm); err != nil { t.Fatalf("err: %v", err) } if qm.LastIndex != 12345 { t.Fatalf("Bad: %v", qm) } if qm.LastContact != 80*time.Millisecond { t.Fatalf("Bad: %v", qm) } if !qm.KnownLeader { t.Fatalf("Bad: %v", qm) } if !qm.AddressTranslationEnabled { t.Fatalf("Bad: %v", qm) } } func TestAPI_UnixSocket(t *testing.T) { t.Parallel() if runtime.GOOS == "windows" { t.SkipNow() } tempDir, err := ioutil.TempDir("", "consul") if err != nil { t.Fatalf("err: %s", err) } defer os.RemoveAll(tempDir) socket := filepath.Join(tempDir, "test.sock") c, s := makeClientWithConfig(t, func(c *Config) { c.Address = "unix://" + socket }, func(c *testutil.TestServerConfig) { c.Addresses = &testutil.TestAddressConfig{ HTTP: "unix://" + socket, } }) defer s.Stop() agent := c.Agent() info, err := agent.Self() if err != nil { t.Fatalf("err: %s", err) } if info["Config"]["NodeName"] == "" { t.Fatalf("bad: %v", info) } } func TestAPI_durToMsec(t *testing.T) { if ms := durToMsec(0); ms != "0ms" { t.Fatalf("bad: %s", ms) } if ms := durToMsec(time.Millisecond); ms != "1ms" { t.Fatalf("bad: %s", ms) } if ms := durToMsec(time.Microsecond); ms != "1ms" { t.Fatalf("bad: %s", ms) } if ms := durToMsec(5 * time.Millisecond); ms != "5ms" { t.Fatalf("bad: %s", ms) } } func TestAPI_IsServerError(t *testing.T) { if IsServerError(nil) { t.Fatalf("should not be a server error") } if IsServerError(fmt.Errorf("not the error you are looking for")) { t.Fatalf("should not be a server error") } if !IsServerError(fmt.Errorf(serverError)) { t.Fatalf("should be a server error") } }