diff --git a/api/api.go b/api/api.go index a2a9e89b68..d67963c1b7 100644 --- a/api/api.go +++ b/api/api.go @@ -168,6 +168,9 @@ type Config struct { // Datacenter to use. If not provided, the default agent datacenter is used. Datacenter string + // Transport is the Transport to use for the http client. + Transport *http.Transport + // HttpClient is the client to use. Default will be // used if not provided. HttpClient *http.Client @@ -239,9 +242,7 @@ func defaultConfig(transportFn func() *http.Transport) *Config { config := &Config{ Address: "127.0.0.1:8500", Scheme: "http", - HttpClient: &http.Client{ - Transport: transportFn(), - }, + Transport: transportFn(), } if addr := os.Getenv(HTTPAddrEnvName); addr != "" { @@ -364,6 +365,10 @@ func NewClient(config *Config) (*Client, error) { config.Scheme = defConfig.Scheme } + if config.Transport == nil { + config.Transport = defConfig.Transport + } + if config.HttpClient == nil { config.HttpClient = defConfig.HttpClient } @@ -392,17 +397,14 @@ func NewClient(config *Config) (*Client, error) { config.TLSConfig.InsecureSkipVerify = defConfig.TLSConfig.InsecureSkipVerify } - tlsClientConfig, err := SetupTLSConfig(&config.TLSConfig) - - // We don't expect this to fail given that we aren't - // parsing any of the input, but we panic just in case - // since this doesn't have an error return. - if err != nil { - return nil, err + if config.HttpClient == nil { + var err error + config.HttpClient, err = NewHttpClient(config.Transport, config.TLSConfig) + if err != nil { + return nil, err + } } - config.HttpClient.Transport.(*http.Transport).TLSClientConfig = tlsClientConfig - parts := strings.SplitN(config.Address, "://", 2) if len(parts) == 2 { switch parts[0] { @@ -429,6 +431,23 @@ func NewClient(config *Config) (*Client, error) { return client, nil } +// NewHttpClient returns an http client configured with the given Transport and TLS +// config. +func NewHttpClient(transport *http.Transport, tlsConf TLSConfig) (*http.Client, error) { + tlsClientConfig, err := SetupTLSConfig(&tlsConf) + + if err != nil { + return nil, err + } + + transport.TLSClientConfig = tlsClientConfig + client := &http.Client{ + Transport: transport, + } + + return client, nil +} + // request is used to help build up a request type request struct { config *Config diff --git a/api/api_test.go b/api/api_test.go index 8e6a40ee5d..df39f85256 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -10,11 +10,11 @@ import ( "path/filepath" "reflect" "runtime" + "strings" "testing" "time" "github.com/hashicorp/consul/testutil" - "strings" ) type configCallback func(c *Config) @@ -140,11 +140,11 @@ func TestDefaultConfig_env(t *testing.T) { // Use keep alives as a check for whether pooling is on or off. if pooled := i == 0; pooled { - if config.HttpClient.Transport.(*http.Transport).DisableKeepAlives != false { + if config.Transport.DisableKeepAlives != false { t.Errorf("expected keep alives to be enabled") } } else { - if config.HttpClient.Transport.(*http.Transport).DisableKeepAlives != true { + if config.Transport.DisableKeepAlives != true { t.Errorf("expected keep alives to be disabled") } }