Merge pull request #8231 from hashicorp/dnephin/unembed-HTTPServer-Server

agent/http: un-embed the http.Server
This commit is contained in:
Daniel Nephin 2020-07-09 17:42:33 -04:00 committed by GitHub
commit f22f3d300d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 89 additions and 67 deletions

View File

@ -1150,49 +1150,28 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
l = tls.NewListener(l, tlscfg) l = tls.NewListener(l, tlscfg)
} }
httpServer := &http.Server{
Addr: l.Addr().String(),
TLSConfig: tlscfg,
}
srv := &HTTPServer{ srv := &HTTPServer{
Server: &http.Server{ Server: httpServer,
Addr: l.Addr().String(),
TLSConfig: tlscfg,
},
ln: l, ln: l,
agent: a, agent: a,
denylist: NewDenylist(a.config.HTTPBlockEndpoints), denylist: NewDenylist(a.config.HTTPBlockEndpoints),
proto: proto, proto: proto,
} }
srv.Server.Handler = srv.handler(a.config.EnableDebug) httpServer.Handler = srv.handler(a.config.EnableDebug)
// Load the connlimit helper into the server // Load the connlimit helper into the server
connLimitFn := a.httpConnLimiter.HTTPConnStateFuncWithDefault429Handler(10 * time.Millisecond) connLimitFn := a.httpConnLimiter.HTTPConnStateFuncWithDefault429Handler(10 * time.Millisecond)
if proto == "https" { if proto == "https" {
// Enforce TLS handshake timeout if err := setupHTTPS(httpServer, connLimitFn, a.config.HTTPSHandshakeTimeout); err != nil {
srv.Server.ConnState = func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateNew:
// Set deadline to prevent slow send before TLS handshake or first
// byte of request.
conn.SetReadDeadline(time.Now().Add(a.config.HTTPSHandshakeTimeout))
case http.StateActive:
// Clear read deadline. We should maybe set read timeouts more
// generally but that's a bigger task as some HTTP endpoints may
// stream large requests and responses (e.g. snapshot) so we can't
// set sensible blanket timeouts here.
conn.SetReadDeadline(time.Time{})
}
// Pass through to conn limit. This is OK because we didn't change
// state (i.e. Close conn).
connLimitFn(conn, state)
}
// This will enable upgrading connections to HTTP/2 as
// part of TLS negotiation.
err = http2.ConfigureServer(srv.Server, nil)
if err != nil {
return err return err
} }
} else { } else {
srv.Server.ConnState = connLimitFn httpServer.ConnState = connLimitFn
} }
ln = append(ln, l) ln = append(ln, l)
@ -1216,6 +1195,33 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
return servers, nil return servers, nil
} }
// setupHTTPS adds HTTP/2 support, ConnState, and a connection handshake timeout
// to the http.Server.
func setupHTTPS(server *http.Server, connState func(net.Conn, http.ConnState), timeout time.Duration) error {
// Enforce TLS handshake timeout
server.ConnState = func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateNew:
// Set deadline to prevent slow send before TLS handshake or first
// byte of request.
conn.SetReadDeadline(time.Now().Add(timeout))
case http.StateActive:
// Clear read deadline. We should maybe set read timeouts more
// generally but that's a bigger task as some HTTP endpoints may
// stream large requests and responses (e.g. snapshot) so we can't
// set sensible blanket timeouts here.
conn.SetReadDeadline(time.Time{})
}
// Pass through to conn limit. This is OK because we didn't change
// state (i.e. Close conn).
connState(conn, state)
}
// This will enable upgrading connections to HTTP/2 as
// part of TLS negotiation.
return http2.ConfigureServer(server, nil)
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used so dead TCP connections eventually go away. // connections. It's used so dead TCP connections eventually go away.
type tcpKeepAliveListener struct { type tcpKeepAliveListener struct {
@ -1262,7 +1268,7 @@ func (a *Agent) serveHTTP(srv *HTTPServer) error {
go func() { go func() {
defer a.wgServers.Done() defer a.wgServers.Done()
notif <- srv.ln.Addr() notif <- srv.ln.Addr()
err := srv.Serve(srv.ln) err := srv.Server.Serve(srv.ln)
if err != nil && err != http.ErrServerClosed { if err != nil && err != http.ErrServerClosed {
a.logger.Error("error closing server", "error", err) a.logger.Error("error closing server", "error", err)
} }
@ -2110,7 +2116,7 @@ func (a *Agent) ShutdownEndpoints() {
) )
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
srv.Shutdown(ctx) srv.Server.Shutdown(ctx)
if ctx.Err() == context.DeadlineExceeded { if ctx.Err() == context.DeadlineExceeded {
a.logger.Warn("Timeout stopping server", a.logger.Warn("Timeout stopping server",
"protocol", strings.ToUpper(srv.proto), "protocol", strings.ToUpper(srv.proto),

View File

@ -4465,7 +4465,8 @@ func TestAgent_Monitor(t *testing.T) {
req = req.WithContext(cancelCtx) req = req.WithContext(cancelCtx)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
go a.srv.Handler.ServeHTTP(resp, req) handler := a.srv.handler(true)
go handler.ServeHTTP(resp, req)
args := &structs.ServiceDefinition{ args := &structs.ServiceDefinition{
Name: "monitor", Name: "monitor",

View File

@ -81,7 +81,8 @@ func (e ForbiddenError) Error() string {
// HTTPServer provides an HTTP api for an agent. // HTTPServer provides an HTTP api for an agent.
type HTTPServer struct { type HTTPServer struct {
*http.Server // TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods
Server *http.Server
ln net.Listener ln net.Listener
agent *Agent agent *Agent
denylist *Denylist denylist *Denylist

View File

@ -133,7 +133,7 @@ func TestHTTPAPI_OptionMethod_OSS(t *testing.T) {
uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path) uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
req, _ := http.NewRequest("OPTIONS", uri, nil) req, _ := http.NewRequest("OPTIONS", uri, nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
allMethods := append([]string{"OPTIONS"}, methods...) allMethods := append([]string{"OPTIONS"}, methods...)
if resp.Code != http.StatusOK { if resp.Code != http.StatusOK {
@ -175,7 +175,7 @@ func TestHTTPAPI_AllowedNets_OSS(t *testing.T) {
req, _ := http.NewRequest(method, uri, nil) req, _ := http.NewRequest(method, uri, nil)
req.RemoteAddr = "192.168.1.2:5555" req.RemoteAddr = "192.168.1.2:5555"
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path) require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path)
}) })

View File

@ -3,6 +3,7 @@ package agent
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -129,7 +130,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) {
} }
} }
func TestHTTPServer_H2(t *testing.T) { func TestSetupHTTPServer_HTTP2(t *testing.T) {
t.Parallel() t.Parallel()
// Fire up an agent with TLS enabled. // Fire up an agent with TLS enabled.
@ -161,24 +162,37 @@ func TestHTTPServer_H2(t *testing.T) {
if err := http2.ConfigureTransport(transport); err != nil { if err := http2.ConfigureTransport(transport); err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
hc := &http.Client{ httpClient := &http.Client{Transport: transport}
Transport: transport,
}
// Hook a handler that echoes back the protocol. // Hook a handler that echoes back the protocol.
handler := func(resp http.ResponseWriter, req *http.Request) { handler := func(resp http.ResponseWriter, req *http.Request) {
resp.WriteHeader(http.StatusOK) resp.WriteHeader(http.StatusOK)
fmt.Fprint(resp, req.Proto) fmt.Fprint(resp, req.Proto)
} }
w, ok := a.srv.Handler.(*wrappedMux)
if !ok { // Create an httpServer to be configured with setupHTTPS, and add our
t.Fatalf("handler is not expected type") // custom handler.
} httpServer := &http.Server{}
w.mux.HandleFunc("/echo", handler) noopConnState := func(net.Conn, http.ConnState) {}
err = setupHTTPS(httpServer, noopConnState, time.Second)
require.NoError(t, err)
srvHandler := a.srv.handler(true)
mux, ok := srvHandler.(*wrappedMux)
require.True(t, ok, "expected a *wrappedMux, got %T", handler)
mux.mux.HandleFunc("/echo", handler)
httpServer.Handler = mux
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tlsListener := tls.NewListener(listener, a.tlsConfigurator.IncomingHTTPSConfig())
go httpServer.Serve(tlsListener)
defer httpServer.Shutdown(context.Background())
// Call it and make sure we see HTTP/2. // Call it and make sure we see HTTP/2.
url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String()) url := fmt.Sprintf("https://%s/echo", listener.Addr().String())
resp, err := hc.Get(url) resp, err := httpClient.Get(url)
if err != nil { if err != nil {
t.Fatalf("err: %v", err) t.Fatalf("err: %v", err)
} }
@ -195,9 +209,9 @@ func TestHTTPServer_H2(t *testing.T) {
// some other endpoint, but configure an API client and make a call // some other endpoint, but configure an API client and make a call
// just as a sanity check. // just as a sanity check.
cfg := &api.Config{ cfg := &api.Config{
Address: a.srv.ln.Addr().String(), Address: listener.Addr().String(),
Scheme: "https", Scheme: "https",
HttpClient: hc, HttpClient: httpClient,
} }
client, err := api.NewClient(cfg) client, err := api.NewClient(cfg)
if err != nil { if err != nil {
@ -333,7 +347,7 @@ func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
if got, want := resp.Code, http.StatusBadRequest; got != want { if got, want := resp.Code, http.StatusBadRequest; got != want {
t.Fatalf("bad response code got %d want %d", got, want) t.Fatalf("bad response code got %d want %d", got, want)
} }
@ -352,7 +366,7 @@ func TestHTTPAPI_Allow_Nonprintable_Characters_With_Flag(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
// Key doesn't actually exist so we should get 404 // Key doesn't actually exist so we should get 404
if got, want := resp.Code, http.StatusNotFound; got != want { if got, want := resp.Code, http.StatusNotFound; got != want {
t.Fatalf("bad response code got %d want %d", got, want) t.Fatalf("bad response code got %d want %d", got, want)
@ -490,14 +504,14 @@ func TestAcceptEncodingGzip(t *testing.T) {
// negotiation, but since this call doesn't go through a real // negotiation, but since this call doesn't go through a real
// transport, the header has to be set manually // transport, the header has to be set manually
req.Header["Accept-Encoding"] = []string{"gzip"} req.Header["Accept-Encoding"] = []string{"gzip"}
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, 200, resp.Code) require.Equal(t, 200, resp.Code)
require.Equal(t, "", resp.Header().Get("Content-Encoding")) require.Equal(t, "", resp.Header().Get("Content-Encoding"))
resp = httptest.NewRecorder() resp = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/v1/kv/long", nil) req, _ = http.NewRequest("GET", "/v1/kv/long", nil)
req.Header["Accept-Encoding"] = []string{"gzip"} req.Header["Accept-Encoding"] = []string{"gzip"}
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, 200, resp.Code) require.Equal(t, 200, resp.Code)
require.Equal(t, "gzip", resp.Header().Get("Content-Encoding")) require.Equal(t, "gzip", resp.Header().Get("Content-Encoding"))
} }
@ -811,35 +825,35 @@ func TestParseWait(t *testing.T) {
} }
} }
func TestPProfHandlers_EnableDebug(t *testing.T) { func TestHTTPServer_PProfHandlers_EnableDebug(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t) a := NewTestAgent(t, ``)
a := NewTestAgent(t, "enable_debug = true")
defer a.Shutdown() defer a.Shutdown()
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil) req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil)
a.srv.Handler.ServeHTTP(resp, req) httpServer := &HTTPServer{agent: a.Agent}
httpServer.handler(true).ServeHTTP(resp, req)
require.Equal(http.StatusOK, resp.Code) require.Equal(t, http.StatusOK, resp.Code)
} }
func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) { func TestHTTPServer_PProfHandlers_DisableDebugNoACLs(t *testing.T) {
t.Parallel() t.Parallel()
require := require.New(t) a := NewTestAgent(t, ``)
a := NewTestAgent(t, "enable_debug = false")
defer a.Shutdown() defer a.Shutdown()
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil) req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil)
a.srv.Handler.ServeHTTP(resp, req) httpServer := &HTTPServer{agent: a.Agent}
httpServer.handler(false).ServeHTTP(resp, req)
require.Equal(http.StatusUnauthorized, resp.Code) require.Equal(t, http.StatusUnauthorized, resp.Code)
} }
func TestPProfHandlers_ACLs(t *testing.T) { func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) {
t.Parallel() t.Parallel()
assert := assert.New(t) assert := assert.New(t)
dc1 := "dc1" dc1 := "dc1"
@ -904,7 +918,7 @@ func TestPProfHandlers_ACLs(t *testing.T) {
t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) { t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) {
req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil) req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
assert.Equal(c.code, resp.Code) assert.Equal(c.code, resp.Code)
}) })
} }
@ -1192,7 +1206,7 @@ func TestEnableWebUI(t *testing.T) {
req, _ := http.NewRequest("GET", "/ui/", nil) req, _ := http.NewRequest("GET", "/ui/", nil)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req) a.srv.handler(true).ServeHTTP(resp, req)
if resp.Code != 200 { if resp.Code != 200 {
t.Fatalf("should handle ui") t.Fatalf("should handle ui")
} }

View File

@ -384,7 +384,7 @@ func (a *TestAgent) HTTPAddr() string {
if a.srv == nil { if a.srv == nil {
return "" return ""
} }
return a.srv.Addr return a.srv.Server.Addr
} }
func (a *TestAgent) SegmentAddr(name string) string { func (a *TestAgent) SegmentAddr(name string) string {

View File

@ -43,7 +43,7 @@ func TestUiIndex(t *testing.T) {
// Register node // Register node
req, _ := http.NewRequest("GET", "/ui/my-file", nil) req, _ := http.NewRequest("GET", "/ui/my-file", nil)
req.URL.Scheme = "http" req.URL.Scheme = "http"
req.URL.Host = a.srv.Addr req.URL.Host = a.srv.Server.Addr
// Make the request // Make the request
client := cleanhttp.DefaultClient() client := cleanhttp.DefaultClient()