consul/command/agent/http_test.go
2016-05-10 20:02:36 -07:00

702 lines
16 KiB
Go

package agent
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
"time"
"github.com/hashicorp/consul/consul/structs"
"github.com/hashicorp/consul/testutil"
"github.com/hashicorp/go-cleanhttp"
)
// makeHTTPServer builds a test HTTP server without customizing the
// configuration. It returns the directory handed back by makeAgent (callers
// remove it when done) together with the first HTTP server that was created.
func makeHTTPServer(t *testing.T) (string, *HTTPServer) {
	dir, srv := makeHTTPServerWithConfig(t, nil)
	return dir, srv
}
// makeHTTPServerWithConfig creates a test agent from nextConfig() and starts
// its HTTP servers. If cb is non-nil it is invoked with the configuration
// before the agent is created, so a test can adjust settings.
//
// Creating the HTTP servers is attempted up to three times, each time with a
// fresh configuration and agent. The agent and directory of a failed attempt
// are released before retrying so they do not leak for the rest of the test
// run. The test is failed if the last attempt errors or if no server is
// returned.
//
// It returns the directory produced by makeAgent (the caller is expected to
// remove it) and the first HTTP server.
func makeHTTPServerWithConfig(t *testing.T, cb func(c *Config)) (string, *HTTPServer) {
	const maxTries = 3

	for try := 1; ; try++ {
		conf := nextConfig()
		if cb != nil {
			cb(conf)
		}

		dir, agent := makeAgent(t, conf)
		servers, err := NewHTTPServers(agent, conf, agent.logOutput)
		if err != nil {
			if try < maxTries {
				// Tear down the failed attempt before trying again with
				// a new configuration.
				agent.Shutdown()
				os.RemoveAll(dir)
				continue
			}
			t.Fatalf("err: %v", err)
		}
		if len(servers) == 0 {
			// Constant format string: no fmt.Sprintf indirection needed.
			t.Fatalf("Failed to make HTTP server")
		}
		return dir, servers[0]
	}
}
// encodeReq JSON-encodes obj and returns the encoded bytes wrapped as an
// io.ReadCloser suitable for use as a request body.
//
// It panics if obj cannot be encoded. This is a test helper, and silently
// returning an empty body would hide a broken fixture behind an unrelated
// failure later in the test.
func encodeReq(obj interface{}) io.ReadCloser {
	buf := bytes.NewBuffer(nil)
	if err := json.NewEncoder(buf).Encode(obj); err != nil {
		panic(fmt.Sprintf("encodeReq: cannot encode %T: %v", obj, err))
	}
	return ioutil.NopCloser(buf)
}
func TestHTTPServer_UnixSocket(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
tempDir, err := ioutil.TempDir("", "consul")
if err != nil {
t.Fatalf("err: %s", err)
}
defer os.RemoveAll(tempDir)
socket := filepath.Join(tempDir, "test.sock")
dir, srv := makeHTTPServerWithConfig(t, func(c *Config) {
c.Addresses.HTTP = "unix://" + socket
// Only testing mode, since uid/gid might not be settable
// from test environment.
c.UnixSockets = UnixSocketConfig{}
c.UnixSockets.Perms = "0777"
})
defer os.RemoveAll(dir)
defer srv.Shutdown()
defer srv.agent.Shutdown()
// Ensure the socket was created
if _, err := os.Stat(socket); err != nil {
t.Fatalf("err: %s", err)
}
// Ensure the mode was set properly
fi, err := os.Stat(socket)
if err != nil {
t.Fatalf("err: %s", err)
}
if fi.Mode().String() != "Srwxrwxrwx" {
t.Fatalf("bad permissions: %s", fi.Mode())
}
// Ensure we can get a response from the socket.
path, _ := unixSocketAddr(srv.agent.config.Addresses.HTTP)
trans := cleanhttp.DefaultTransport()
trans.Dial = func(_, _ string) (net.Conn, error) {
return net.Dial("unix", path)
}
client := &http.Client{
Transport: trans,
}
// This URL doesn't look like it makes sense, but the scheme (http://) and
// the host (127.0.0.1) are required by the HTTP client library. In reality
// this will just use the custom dialer and talk to the socket.
resp, err := client.Get("http://127.0.0.1/v1/agent/self")
if err != nil {
t.Fatalf("err: %s", err)
}
defer resp.Body.Close()
if body, err := ioutil.ReadAll(resp.Body); err != nil || len(body) == 0 {
t.Fatalf("bad: %s %v", body, err)
}
}
// TestHTTPServer_UnixSocket_FileExists verifies that starting the HTTP server
// on a Unix socket path that is already occupied by a regular file succeeds
// and that the file is replaced by a socket. The test is skipped on Windows.
func TestHTTPServer_UnixSocket_FileExists(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.SkipNow()
	}

	tempDir, err := ioutil.TempDir("", "consul")
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer os.RemoveAll(tempDir)
	socket := filepath.Join(tempDir, "test.sock")

	// Create a regular file at the socket path
	if err := ioutil.WriteFile(socket, []byte("hello world"), 0644); err != nil {
		t.Fatalf("err: %s", err)
	}
	fi, err := os.Stat(socket)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if !fi.Mode().IsRegular() {
		t.Fatalf("not a regular file: %s", socket)
	}

	conf := nextConfig()
	conf.Addresses.HTTP = "unix://" + socket
	dir, agent := makeAgent(t, conf)
	defer os.RemoveAll(dir)

	// Try to start the server with the same path anyways.
	servers, err := NewHTTPServers(agent, conf, agent.logOutput)

	// Release everything that was started, even when the assertions below
	// fail. The defers run in reverse order: agent first, then the servers,
	// then the directory removal, the same sequence the other tests in this
	// file use.
	defer func() {
		for _, srv := range servers {
			srv.Shutdown()
		}
	}()
	defer agent.Shutdown()

	if err != nil {
		t.Fatalf("err: %s", err)
	}

	// Ensure the file was replaced by the socket
	fi, err = os.Stat(socket)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if fi.Mode()&os.ModeSocket == 0 {
		t.Fatalf("expected socket to replace file")
	}
}
// TestSetIndex checks that setIndex writes the X-Consul-Index header and that
// a second call leaves exactly one value for that header.
func TestSetIndex(t *testing.T) {
	rec := httptest.NewRecorder()

	setIndex(rec, 1000)
	if got := rec.Header().Get("X-Consul-Index"); got != "1000" {
		t.Fatalf("Bad: %v", got)
	}

	// Setting the index again must not accumulate header values.
	setIndex(rec, 2000)
	values := rec.Header()["X-Consul-Index"]
	if len(values) != 1 {
		t.Fatalf("bad: %#v", values)
	}
}
// TestSetKnownLeader checks that setKnownLeader renders the boolean as "true"
// or "false" in the X-Consul-KnownLeader header.
func TestSetKnownLeader(t *testing.T) {
	cases := []struct {
		known bool
		want  string
	}{
		{known: true, want: "true"},
		{known: false, want: "false"},
	}

	for _, tc := range cases {
		// Each case gets its own recorder so headers do not carry over.
		rec := httptest.NewRecorder()
		setKnownLeader(rec, tc.known)
		if got := rec.Header().Get("X-Consul-KnownLeader"); got != tc.want {
			t.Fatalf("Bad: %v", got)
		}
	}
}
// TestSetLastContact checks that a last-contact duration of 123456
// microseconds is reported as "123" in the X-Consul-LastContact header.
func TestSetLastContact(t *testing.T) {
	const lastContact = 123456 * time.Microsecond

	rec := httptest.NewRecorder()
	setLastContact(rec, lastContact)

	if got := rec.Header().Get("X-Consul-LastContact"); got != "123" {
		t.Fatalf("Bad: %v", got)
	}
}
// TestSetMeta checks that setMeta emits the index, known-leader and
// last-contact headers for a populated QueryMeta.
func TestSetMeta(t *testing.T) {
	meta := &structs.QueryMeta{
		Index:       1000,
		KnownLeader: true,
		LastContact: 123456 * time.Microsecond,
	}

	rec := httptest.NewRecorder()
	setMeta(rec, meta)

	// Headers are verified in the same order as they are listed here.
	expectations := []struct {
		header string
		want   string
	}{
		{"X-Consul-Index", "1000"},
		{"X-Consul-KnownLeader", "true"},
		{"X-Consul-LastContact", "123"},
	}
	for _, e := range expectations {
		if got := rec.Header().Get(e.header); got != e.want {
			t.Fatalf("Bad: %v", got)
		}
	}
}
// TestHTTPAPIResponseHeaders checks that headers configured through
// HTTPAPIResponseHeaders show up on a response produced by a wrapped handler.
func TestHTTPAPIResponseHeaders(t *testing.T) {
	dir, srv := makeHTTPServer(t)
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	srv.agent.config.HTTPAPIResponseHeaders = map[string]string{
		"Access-Control-Allow-Origin": "*",
		"X-XSS-Protection":            "1; mode=block",
	}

	// The handler itself does nothing; only the wrapper's headers matter.
	noop := func(http.ResponseWriter, *http.Request) (interface{}, error) {
		return nil, nil
	}

	rec := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/v1/agent/self", nil)
	srv.wrap(noop)(rec, req)

	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "*" {
		t.Fatalf("bad Access-Control-Allow-Origin: expected %q, got %q", "*", got)
	}
	if got := rec.Header().Get("X-XSS-Protection"); got != "1; mode=block" {
		t.Fatalf("bad X-XSS-Protection header: expected %q, got %q", "1; mode=block", got)
	}
}
// TestContentTypeIsJSON checks that a wrapped handler returning an object
// produces a response whose Content-Type is application/json.
func TestContentTypeIsJSON(t *testing.T) {
	dir, srv := makeHTTPServer(t)
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	// Return a stub DirEntry so that there is something to encode as JSON.
	stub := func(http.ResponseWriter, *http.Request) (interface{}, error) {
		return &structs.DirEntry{Key: "key"}, nil
	}

	rec := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/v1/kv/key", nil)
	srv.wrap(stub)(rec, req)

	if rec.Header().Get("Content-Type") != "application/json" {
		t.Fatalf("Content-Type header was not 'application/json'")
	}
}
// TestHTTP_wrap_obfuscateLog checks that token values passed in the query
// string never appear in the server log written while handling a wrapped
// request.
func TestHTTP_wrap_obfuscateLog(t *testing.T) {
	dir, srv := makeHTTPServer(t)
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	// Capture the server's log output in memory so it can be inspected.
	var logBuf bytes.Buffer
	srv.logger = log.New(&logBuf, "", log.LstdFlags)

	noop := func(http.ResponseWriter, *http.Request) (interface{}, error) {
		return nil, nil
	}

	rec := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/some/url?token=secret1&token=secret2", nil)
	srv.wrap(noop)(rec, req)

	// Neither token from the URL may show up in the log.
	if logged := logBuf.String(); strings.Contains(logged, "secret") {
		t.Fatalf("bad: %s", logged)
	}
}
// TestPrettyPrint runs the pretty-print check with the "pretty=1" query
// parameter.
func TestPrettyPrint(t *testing.T) {
	const query = "pretty=1"
	testPrettyPrint(query, t)
}
// TestPrettyPrintBare runs the pretty-print check with a bare "pretty" query
// parameter that carries no value.
func TestPrettyPrintBare(t *testing.T) {
	const query = "pretty"
	testPrettyPrint(query, t)
}
// testPrettyPrint issues a GET for /v1/kv/key with the given raw query string
// appended and checks that the wrapped handler's response body equals the
// indented JSON encoding of the returned entry followed by a newline.
func testPrettyPrint(pretty string, t *testing.T) {
	dir, srv := makeHTTPServer(t)
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	entry := &structs.DirEntry{Key: "key"}
	handler := func(http.ResponseWriter, *http.Request) (interface{}, error) {
		return entry, nil
	}

	rec := httptest.NewRecorder()
	req, _ := http.NewRequest("GET", "/v1/kv/key?"+pretty, nil)
	srv.wrap(handler)(rec, req)

	// NOTE(review): the indent string has to match whatever the server uses
	// when pretty-printing; confirm against the wrap implementation.
	want, _ := json.MarshalIndent(entry, "", " ")
	want = append(want, '\n')

	got, err := ioutil.ReadAll(rec.Body)
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	if !bytes.Equal(want, got) {
		t.Fatalf("bad: %q", string(got))
	}
}
// TestParseSource checks how parseSource fills in a QuerySource from the
// "near" and "dc" query parameters.
func TestParseSource(t *testing.T) {
	dir, srv := makeHTTPServer(t)
	defer os.RemoveAll(dir)
	defer srv.Shutdown()
	defer srv.agent.Shutdown()

	// parse builds a GET request for url and returns the QuerySource that
	// parseSource produced from a zero value.
	parse := func(url string) structs.QuerySource {
		req, err := http.NewRequest("GET", url, nil)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		var source structs.QuerySource
		srv.parseSource(req, &source)
		return source
	}

	// Default is the agent's DC and no node (the user didn't care, so give
	// them the cheapest possible query).
	if got := parse("/v1/catalog/nodes"); got.Datacenter != "dc1" || got.Node != "" {
		t.Fatalf("bad: %v", got)
	}

	// Adding the "near" parameter should set that node.
	if got := parse("/v1/catalog/nodes?near=bob"); got.Datacenter != "dc1" || got.Node != "bob" {
		t.Fatalf("bad: %v", got)
	}

	// Whatever dc parameter was given is followed so that the node is
	// looked up correctly on the receiving end.
	if got := parse("/v1/catalog/nodes?near=bob&dc=foo"); got.Datacenter != "foo" || got.Node != "bob" {
		t.Fatalf("bad: %v", got)
	}

	// The magic "_agent" node name resolves to the agent's local node name.
	if got := parse("/v1/catalog/nodes?near=_agent"); got.Datacenter != "dc1" || got.Node != srv.agent.config.NodeName {
		t.Fatalf("bad: %v", got)
	}
}
// TestParseWait checks that parseWait reads valid "wait" and "index" query
// parameters into MaxQueryTime and MinQueryIndex without reporting done.
func TestParseWait(t *testing.T) {
	req, err := http.NewRequest("GET", "/v1/catalog/nodes?wait=60s&index=1000", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	var opts structs.QueryOptions
	rec := httptest.NewRecorder()
	if parseWait(rec, req, &opts) {
		t.Fatalf("unexpected done")
	}

	if opts.MinQueryIndex != 1000 {
		t.Fatalf("Bad: %v", opts)
	}
	if opts.MaxQueryTime != 60*time.Second {
		t.Fatalf("Bad: %v", opts)
	}
}
// TestParseWait_InvalidTime checks that an unparsable "wait" value makes
// parseWait report done and respond with status 400.
func TestParseWait_InvalidTime(t *testing.T) {
	req, err := http.NewRequest("GET", "/v1/catalog/nodes?wait=60foo&index=1000", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	var opts structs.QueryOptions
	rec := httptest.NewRecorder()
	if !parseWait(rec, req, &opts) {
		t.Fatalf("expected done")
	}

	if rec.Code != http.StatusBadRequest {
		t.Fatalf("bad code: %v", rec.Code)
	}
}
// TestParseWait_InvalidIndex checks that an unparsable "index" value makes
// parseWait report done and respond with status 400.
func TestParseWait_InvalidIndex(t *testing.T) {
	req, err := http.NewRequest("GET", "/v1/catalog/nodes?wait=60s&index=foo", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	var opts structs.QueryOptions
	rec := httptest.NewRecorder()
	if !parseWait(rec, req, &opts) {
		t.Fatalf("expected done")
	}

	if rec.Code != http.StatusBadRequest {
		t.Fatalf("bad code: %v", rec.Code)
	}
}
// TestParseConsistency checks that the "stale" and "consistent" query
// parameters each set only their own flag in QueryOptions.
func TestParseConsistency(t *testing.T) {
	// The same recorder is used for both requests.
	rec := httptest.NewRecorder()

	// ?stale sets AllowStale only.
	staleReq, err := http.NewRequest("GET", "/v1/catalog/nodes?stale", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	var opts structs.QueryOptions
	if parseConsistency(rec, staleReq, &opts) {
		t.Fatalf("unexpected done")
	}
	if !opts.AllowStale {
		t.Fatalf("Bad: %v", opts)
	}
	if opts.RequireConsistent {
		t.Fatalf("Bad: %v", opts)
	}

	// ?consistent sets RequireConsistent only; start from fresh options.
	opts = structs.QueryOptions{}
	consistentReq, err := http.NewRequest("GET", "/v1/catalog/nodes?consistent", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if parseConsistency(rec, consistentReq, &opts) {
		t.Fatalf("unexpected done")
	}
	if opts.AllowStale {
		t.Fatalf("Bad: %v", opts)
	}
	if !opts.RequireConsistent {
		t.Fatalf("Bad: %v", opts)
	}
}
// TestParseConsistency_Invalid checks that passing both "stale" and
// "consistent" makes parseConsistency report done and respond with status
// 400.
func TestParseConsistency_Invalid(t *testing.T) {
	req, err := http.NewRequest("GET", "/v1/catalog/nodes?stale&consistent", nil)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	var opts structs.QueryOptions
	rec := httptest.NewRecorder()
	if !parseConsistency(rec, req, &opts) {
		t.Fatalf("expected done")
	}

	if rec.Code != http.StatusBadRequest {
		t.Fatalf("bad code: %v", rec.Code)
	}
}
// TestACLResolution checks the order in which parseToken resolves the ACL
// token: query string, then X-Consul-Token header, then the Atlas token (only
// on the SCADA server address), then the agent's configured token.
func TestACLResolution(t *testing.T) {
	// newReq builds a GET request for url, failing the test on error.
	newReq := func(url string) *http.Request {
		r, err := http.NewRequest("GET", url, nil)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		return r
	}

	// Request without token
	req := newReq("/v1/catalog/nodes")

	// Request with explicit token
	reqToken := newReq("/v1/catalog/nodes?token=foo")

	// Request with header token only
	reqHeaderToken := newReq("/v1/catalog/nodes")
	reqHeaderToken.Header.Add("X-Consul-Token", "bar")

	// Request with header and querystring tokens
	reqBothTokens := newReq("/v1/catalog/nodes?token=baz")
	reqBothTokens.Header.Add("X-Consul-Token", "zap")

	httpTest(t, func(srv *HTTPServer) {
		// token is shared by every step below, exactly one variable that
		// parseToken writes into on each call.
		var token string
		expect := func(r *http.Request, want string) {
			srv.parseToken(r, &token)
			if token != want {
				t.Fatalf("bad: %s", token)
			}
		}

		// Check when no token is set
		srv.agent.config.ACLToken = ""
		expect(req, "")

		// Check when ACLToken set
		srv.agent.config.ACLToken = "agent"
		expect(req, "agent")

		// Check when AtlasACLToken set, wrong server
		srv.agent.config.AtlasACLToken = "atlas"
		expect(req, "agent")

		// Check when AtlasACLToken set, correct server
		srv.addr = scadaHTTPAddr
		expect(req, "atlas")

		// Check when AtlasACLToken not set, correct server
		srv.agent.config.AtlasACLToken = ""
		expect(req, "agent")

		// Explicit token has highest precedence
		expect(reqToken, "foo")

		// Header token has precedence over agent token
		expect(reqHeaderToken, "bar")

		// Querystring token has precedence over header and agent tokens
		expect(reqBothTokens, "baz")
	})
}
// TestScadaHTTP checks that newScadaHttp returns a server bound to the given
// listener and the SCADA address, and that the pprof debug routes are not
// registered on its mux.
func TestScadaHTTP(t *testing.T) {
	// Create the agent
	dir, agent := makeAgent(t, nextConfig())
	defer os.RemoveAll(dir)
	defer agent.Shutdown()

	// Create a generic listener
	list, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatalf("err: %s", err)
	}
	defer list.Close()

	// Create the SCADA HTTP server
	scadaHttp := newScadaHttp(agent, list)

	// Returned server uses the listener and scada addr
	if scadaHttp.listener != list {
		t.Fatalf("bad listener: %#v", scadaHttp)
	}
	if scadaHttp.addr != scadaHTTPAddr {
		// Expected value first, observed value second, matching the
		// wording of the message.
		t.Fatalf("expected %v, got: %v", scadaHTTPAddr, scadaHttp.addr)
	}

	// Check that debug endpoints were not enabled. This will cause
	// the serve mux to panic if the routes are already handled.
	mockFn := func(w http.ResponseWriter, r *http.Request) {}
	scadaHttp.mux.HandleFunc("/debug/pprof/", mockFn)
	scadaHttp.mux.HandleFunc("/debug/pprof/cmdline", mockFn)
	scadaHttp.mux.HandleFunc("/debug/pprof/profile", mockFn)
	scadaHttp.mux.HandleFunc("/debug/pprof/symbol", mockFn)
}
func TestEnableWebUI(t *testing.T) {
httpTestWithConfig(t, func(s *HTTPServer) {
req, err := http.NewRequest("GET", "/ui/", nil)
if err != nil {
t.Fatalf("err: %v", err)
}
// Perform the request
resp := httptest.NewRecorder()
s.mux.ServeHTTP(resp, req)
// Check the result
if resp.Code != 200 {
t.Fatalf("should handle ui")
}
}, func(c *Config) {
c.EnableUi = true
})
}
// assertIndex fails the test immediately unless the recorded response carries
// an X-Consul-Index header whose value is neither empty nor "0".
func assertIndex(t *testing.T, resp *httptest.ResponseRecorder) {
	idx := resp.Header().Get("X-Consul-Index")
	if idx != "" && idx != "0" {
		return
	}
	t.Fatalf("Bad: %v", idx)
}
// checkIndex performs the same check as assertIndex but reports the problem
// as an error instead of failing a test: it returns a non-nil error when the
// X-Consul-Index header is missing or "0", and nil otherwise.
func checkIndex(resp *httptest.ResponseRecorder) error {
	switch idx := resp.Header().Get("X-Consul-Index"); idx {
	case "", "0":
		return fmt.Errorf("Bad: %v", idx)
	default:
		return nil
	}
}
// getIndex parses the X-Consul-Index header of the recorded response and
// returns it as a uint64. The test is failed if the header is missing or is
// not a valid unsigned base-10 integer.
//
// The value is parsed with strconv.ParseUint so the whole uint64 range is
// accepted on every platform and negative values are rejected instead of
// being wrapped around by an integer conversion.
func getIndex(t *testing.T, resp *httptest.ResponseRecorder) uint64 {
	header := resp.Header().Get("X-Consul-Index")
	if header == "" {
		t.Fatalf("Bad: %v", header)
	}
	val, err := strconv.ParseUint(header, 10, 64)
	if err != nil {
		t.Fatalf("Bad: %v", header)
	}
	return val
}
// httpTest runs f against a freshly created HTTP server that uses the
// unmodified test configuration. It is httpTestWithConfig with a nil
// configuration callback.
func httpTest(t *testing.T, f func(srv *HTTPServer)) {
	var noConfig func(c *Config)
	httpTestWithConfig(t, f, noConfig)
}
// httpTestWithConfig creates an HTTP server (letting cb adjust the
// configuration when it is non-nil), waits for a leader in "dc1", and then
// invokes f with the server. The agent, the server and the directory are
// cleaned up when f returns.
func httpTestWithConfig(t *testing.T, f func(srv *HTTPServer), cb func(c *Config)) {
	workDir, server := makeHTTPServerWithConfig(t, cb)

	// Deferred calls run in reverse order: agent, then server, then the
	// directory removal.
	defer os.RemoveAll(workDir)
	defer server.Shutdown()
	defer server.agent.Shutdown()

	// Do not hand the server to the test body before a leader is known.
	testutil.WaitForLeader(t, server.agent.RPC, "dc1")

	f(server)
}