diff --git a/.changelog/21930.txt b/.changelog/21930.txt new file mode 100644 index 0000000000..bfcf2748f0 --- /dev/null +++ b/.changelog/21930.txt @@ -0,0 +1,3 @@ +```release-note:security +api: Enforces strict content-type header validation to protect against XSS vulnerability. +``` \ No newline at end of file diff --git a/agent/http.go b/agent/http.go index 506377074a..eb6e186cd8 100644 --- a/agent/http.go +++ b/agent/http.go @@ -6,7 +6,6 @@ package agent import ( "encoding/json" "fmt" - "github.com/hashicorp/go-hclog" "io" "net" "net/http" @@ -20,6 +19,8 @@ import ( "sync/atomic" "time" + "github.com/hashicorp/go-hclog" + "github.com/NYTimes/gziphandler" "github.com/armon/go-metrics" "github.com/armon/go-metrics/prometheus" @@ -348,16 +349,24 @@ func withRemoteAddrHandler(next http.Handler) http.Handler { }) } -// Injects content type explicitly if not already set into response to prevent XSS +// ensureContentTypeHeader injects content-type explicitly if not already set into response to prevent XSS func ensureContentTypeHeader(next http.Handler, logger hclog.Logger) http.Handler { - return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { next.ServeHTTP(resp, req) - val := resp.Header().Get(contentTypeHeader) - if val == "" { - resp.Header().Set(contentTypeHeader, plainContentType) - logger.Debug("warning: content-type header not explicitly set.", "request-path", req.URL) + contentType := api.GetContentType(req) + + if req != nil { + logger.Debug("warning: request content-type is not supported", "request-path", req.URL) + req.Header.Set(contentTypeHeader, contentType) + } + + if resp != nil { + respContentType := resp.Header().Get(contentTypeHeader) + if respContentType == "" || respContentType != contentType { + logger.Debug("warning: response content-type header not explicitly set.", "request-path", req.URL) + resp.Header().Set(contentTypeHeader, contentType) + } } }) } diff --git a/agent/http_test.go b/agent/http_test.go index 497789f689..73a599546a 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -617,7 +617,6 @@ func TestHTTPAPI_DefaultACLPolicy(t *testing.T) { }) } } - func TestHTTPAPIResponseHeaders(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") @@ -646,6 +645,87 @@ func TestHTTPAPIResponseHeaders(t *testing.T) { requireHasHeadersSet(t, a, "/", "text/plain; charset=utf-8") } +func TestHTTPAPIValidateContentTypeHeaders(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + t.Parallel() + type testcase struct { + name string + endpoint string + method string + requestBody io.Reader + expectedContentType string + } + + cases := []testcase{ + { + name: "snapshot endpoint expect non-default content type", + method: http.MethodPut, + endpoint: "/v1/snapshot", + requestBody: bytes.NewBuffer([]byte("test")), + expectedContentType: "application/octet-stream", + }, + { + name: "kv endpoint expect non-default content type", + method: http.MethodPut, + endpoint: "/v1/kv", + requestBody: bytes.NewBuffer([]byte("test")), + expectedContentType: "application/octet-stream", + }, + { + name: "event/fire endpoint expect default content type", + method: http.MethodPut, + endpoint: "/v1/event/fire", + requestBody: bytes.NewBuffer([]byte("test")), + expectedContentType: "application/octet-stream", + }, + { + name: "peering/token endpoint expect default content type", + method: http.MethodPost, + endpoint: "/v1/peering/token", + requestBody: bytes.NewBuffer([]byte("test")), + expectedContentType: "application/json", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + a := NewTestAgent(t, "") + defer a.Shutdown() + + requireContentTypeHeadersSet(t, a, tc.method, tc.endpoint, tc.requestBody, tc.expectedContentType) + }) + } +} + +func requireContentTypeHeadersSet(t *testing.T, a *TestAgent, method, path string, body io.Reader, contentType string) { + t.Helper() + + resp := httptest.NewRecorder() + req, _ := http.NewRequest(method, path, body) + a.enableDebug.Store(true) + + a.srv.handler().ServeHTTP(resp, req) + + reqHdrs := req.Header + respHdrs := resp.Header() + + // require request content-type + require.NotEmpty(t, reqHdrs.Get("Content-Type")) + require.Equal(t, contentType, reqHdrs.Get("Content-Type"), + "Request Header Content-Type value incorrect") + + // require response content-type + require.NotEmpty(t, respHdrs.Get("Content-Type")) + require.Equal(t, contentType, respHdrs.Get("Content-Type"), + "Response Header Content-Type value incorrect") +} + func requireHasHeadersSet(t *testing.T, a *TestAgent, path string, contentType string) { t.Helper() @@ -663,7 +743,7 @@ func requireHasHeadersSet(t *testing.T, a *TestAgent, path string, contentType s "X-XSS-Protection header value incorrect") require.Equal(t, contentType, hdrs.Get("Content-Type"), - "") + "Response Content-Type header value incorrect") } func TestUIResponseHeaders(t *testing.T) { @@ -704,7 +784,7 @@ func TestErrorContentTypeHeaderSet(t *testing.T) { `) defer a.Shutdown() - requireHasHeadersSet(t, a, "/fake-path-doesn't-exist", "text/plain; charset=utf-8") + requireHasHeadersSet(t, a, "/fake-path-doesn't-exist", "application/json") } func TestAcceptEncodingGzip(t *testing.T) { diff --git a/api/api.go b/api/api.go index d4d853d5d4..27af1ea569 100644 --- a/api/api.go +++ b/api/api.go @@ -1087,8 +1087,23 @@ func (c *Client) doRequest(r *request) (time.Duration, *http.Response, error) { if err != nil { return 0, nil, err } + + contentType := GetContentType(req) + + if req != nil { + req.Header.Set(contentTypeHeader, contentType) + } + start := time.Now() resp, err := c.config.HttpClient.Do(req) + + if resp != nil { + respContentType := resp.Header.Get(contentTypeHeader) + if respContentType == "" || respContentType != contentType { + resp.Header.Set(contentTypeHeader, contentType) + } + } + diff := time.Since(start) return diff, resp, err } diff --git a/api/api_test.go b/api/api_test.go index e8a03f7218..9a3ed7374c 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -935,11 +935,11 @@ func TestAPI_Headers(t *testing.T) { _, _, err = kv.Get("test-headers", nil) require.NoError(t, err) - require.Equal(t, "", request.Header.Get("Content-Type")) + require.Equal(t, "application/json", request.Header.Get("Content-Type")) _, err = kv.Delete("test-headers", nil) require.NoError(t, err) - require.Equal(t, "", request.Header.Get("Content-Type")) + require.Equal(t, "application/json", request.Header.Get("Content-Type")) err = c.Snapshot().Restore(nil, strings.NewReader("foo")) require.Error(t, err) diff --git a/api/content_type.go b/api/content_type.go new file mode 100644 index 0000000000..37c8cf60aa --- /dev/null +++ b/api/content_type.go @@ -0,0 +1,81 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package api + +import ( + "net/http" + "strings" +) + +const ( + contentTypeHeader = "Content-Type" + plainContentType = "text/plain; charset=utf-8" + octetStream = "application/octet-stream" + jsonContentType = "application/json" // Default content type +) + +// ContentTypeRule defines a rule for determining the content type of an HTTP request. +// This rule is based on the combination of the HTTP path, method, and the desired content type. +type ContentTypeRule struct { + path string + httpMethod string + contentType string +} + +var ContentTypeRules = []ContentTypeRule{ + { + path: "/v1/snapshot", + httpMethod: http.MethodPut, + contentType: octetStream, + }, + { + path: "/v1/kv", + httpMethod: http.MethodPut, + contentType: octetStream, + }, + { + path: "/v1/event/fire", + httpMethod: http.MethodPut, + contentType: octetStream, + }, +} + +// GetContentType returns the content type for a request +// This function isused as routing logic or middleware to determine and enforce +// the appropriate content type for HTTP requests. +func GetContentType(req *http.Request) string { + reqContentType := req.Header.Get(contentTypeHeader) + + if isIndexPage(req) { + return plainContentType + } + + // For GET, DELETE, or internal API paths, ensure a valid Content-Type is returned. + if req.Method == http.MethodGet || req.Method == http.MethodDelete || strings.HasPrefix(req.URL.Path, "/v1/internal") { + if reqContentType == "" { + // Default to JSON Content-Type if no Content-Type is provided. + return jsonContentType + } + // Return the provided Content-Type if it exists. + return reqContentType + } + + for _, rule := range ContentTypeRules { + if matchesRule(req, rule) { + return rule.contentType + } + } + return jsonContentType +} + +// matchesRule checks if a request matches a content type rule +func matchesRule(req *http.Request, rule ContentTypeRule) bool { + return strings.HasPrefix(req.URL.Path, rule.path) && + (rule.httpMethod == "" || req.Method == rule.httpMethod) +} + +// isIndexPage checks if the request is for the index page +func isIndexPage(req *http.Request) bool { + return req.URL.Path == "/" || req.URL.Path == "/ui" +}