Move header methods from config to client

This commit is contained in:
Conor Mongey 2021-01-07 23:48:53 +00:00
parent eb65e59741
commit 7a368bd2b6
No known key found for this signature in database
GPG Key ID: 5C886ACC44EB17C0
2 changed files with 62 additions and 13 deletions

View File

@ -14,6 +14,7 @@ import (
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-cleanhttp"
@ -314,8 +315,6 @@ type Config struct {
Namespace string Namespace string
TLSConfig TLSConfig TLSConfig TLSConfig
Header http.Header
} }
// TLSConfig is used to generate a TLSClientConfig that's useful for talking to // TLSConfig is used to generate a TLSClientConfig that's useful for talking to
@ -550,9 +549,48 @@ func (c *Config) GenerateEnv() []string {
// Client provides a client to the Consul API // Client provides a client to the Consul API
type Client struct { type Client struct {
modifyLock sync.RWMutex
headers http.Header
config Config config Config
} }
// Headers gets the current set of headers used for requests. This returns a
// copy; to modify it call AddHeader or SetHeaders.
func (c *Client) Headers() http.Header {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
if c.headers == nil {
return nil
}
ret := make(http.Header)
for k, v := range c.headers {
for _, val := range v {
ret[k] = append(ret[k], val)
}
}
return ret
}
// AddHeader allows a single header key/value pair to be added
// in a race-safe fashion.
func (c *Client) AddHeader(key, value string) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.headers.Add(key, value)
}
// SetHeaders clears all previous headers and uses only the given
// ones going forward.
func (c *Client) SetHeaders(headers http.Header) {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.headers = headers
}
// NewClient returns a new client // NewClient returns a new client
func NewClient(config *Config) (*Client, error) { func NewClient(config *Config) (*Client, error) {
// bootstrap the config // bootstrap the config
@ -642,7 +680,7 @@ func NewClient(config *Config) (*Client, error) {
config.Token = defConfig.Token config.Token = defConfig.Token
} }
return &Client{config: *config}, nil return &Client{config: *config, headers: make(http.Header)}, nil
} }
// NewHttpClient returns an http client configured with the given Transport and TLS // NewHttpClient returns an http client configured with the given Transport and TLS
@ -855,12 +893,9 @@ func (c *Client) newRequest(method, path string) *request {
Path: path, Path: path,
}, },
params: make(map[string][]string), params: make(map[string][]string),
header: make(http.Header), header: c.Headers(),
} }
if c.config.Header != nil {
r.header = c.config.Header
}
if c.config.Datacenter != "" { if c.config.Datacenter != "" {
r.params.Set("dc", c.config.Datacenter) r.params.Set("dc", c.config.Datacenter)
} }

View File

@ -810,17 +810,31 @@ func TestAPI_SetWriteOptions(t *testing.T) {
func TestAPI_Headers(t *testing.T) { func TestAPI_Headers(t *testing.T) {
t.Parallel() t.Parallel()
c, s := makeClientWithConfig(t, func(c *Config) { c, s := makeClient(t)
c.Header = http.Header{
"Hello": []string{"World"},
}
}, nil)
defer s.Stop() defer s.Stop()
if len(c.Headers()) != 0 {
t.Fatalf("expected headers to be empty: %v", c.Headers())
}
c.AddHeader("Hello", "World")
r := c.newRequest("GET", "/v1/kv/foo") r := c.newRequest("GET", "/v1/kv/foo")
if r.header.Get("Hello") != "World" { if r.header.Get("Hello") != "World" {
t.Fatalf("bad: %v", r.header) t.Fatalf("Hello header not set : %v", r.header)
}
c.SetHeaders(http.Header{
"Auth": []string{"Token"},
})
r = c.newRequest("GET", "/v1/kv/foo")
if r.header.Get("Hello") != "" {
t.Fatalf("Hello header should not be set: %v", r.header)
}
if r.header.Get("Auth") != "Token" {
t.Fatalf("Auth header not set: %v", r.header)
} }
} }