agent/http: un-embed the HTTPServer

The embedded HTTPServer struct is not used by the large HTTPServer
struct. It is used by tests and the agent. This change is a small first
step in the process of removing that field.

The eventual goal is to reduce the scope of HTTPServer making it easier
to test, and split into separate packages.
This commit is contained in:
Daniel Nephin 2020-07-02 16:47:54 -04:00
parent db387eccd6
commit a5e45defb1
7 changed files with 42 additions and 40 deletions

View File

@ -1151,24 +1151,25 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
l = tls.NewListener(l, tlscfg)
}
httpServer := &http.Server{
Addr: l.Addr().String(),
TLSConfig: tlscfg,
}
srv := &HTTPServer{
Server: &http.Server{
Addr: l.Addr().String(),
TLSConfig: tlscfg,
},
Server: httpServer,
ln: l,
agent: a,
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
proto: proto,
}
srv.Server.Handler = srv.handler(a.config.EnableDebug)
httpServer.Handler = srv.handler(a.config.EnableDebug)
// Load the connlimit helper into the server
connLimitFn := a.httpConnLimiter.HTTPConnStateFunc()
if proto == "https" {
// Enforce TLS handshake timeout
srv.Server.ConnState = func(conn net.Conn, state http.ConnState) {
httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateNew:
// Set deadline to prevent slow send before TLS handshake or first
@ -1188,12 +1189,12 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
// This will enable upgrading connections to HTTP/2 as
// part of TLS negotiation.
err = http2.ConfigureServer(srv.Server, nil)
err = http2.ConfigureServer(httpServer, nil)
if err != nil {
return err
}
} else {
srv.Server.ConnState = connLimitFn
httpServer.ConnState = connLimitFn
}
ln = append(ln, l)
@ -1263,7 +1264,7 @@ func (a *Agent) serveHTTP(srv *HTTPServer) error {
go func() {
defer a.wgServers.Done()
notif <- srv.ln.Addr()
err := srv.Serve(srv.ln)
err := srv.Server.Serve(srv.ln)
if err != nil && err != http.ErrServerClosed {
a.logger.Error("error closing server", "error", err)
}
@ -2111,7 +2112,7 @@ func (a *Agent) ShutdownEndpoints() {
)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
srv.Shutdown(ctx)
srv.Server.Shutdown(ctx)
if ctx.Err() == context.DeadlineExceeded {
a.logger.Warn("Timeout stopping server",
"protocol", strings.ToUpper(srv.proto),

View File

@ -4465,7 +4465,8 @@ func TestAgent_Monitor(t *testing.T) {
req = req.WithContext(cancelCtx)
resp := httptest.NewRecorder()
go a.srv.Handler.ServeHTTP(resp, req)
handler := a.srv.handler(true)
go handler.ServeHTTP(resp, req)
args := &structs.ServiceDefinition{
Name: "monitor",

View File

@ -81,7 +81,8 @@ func (e ForbiddenError) Error() string {
// HTTPServer provides an HTTP api for an agent.
type HTTPServer struct {
*http.Server
// TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods
Server *http.Server
ln net.Listener
agent *Agent
denylist *Denylist

View File

@ -133,7 +133,7 @@ func TestHTTPAPI_OptionMethod_OSS(t *testing.T) {
uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
req, _ := http.NewRequest("OPTIONS", uri, nil)
resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req)
a.srv.handler(true).ServeHTTP(resp, req)
allMethods := append([]string{"OPTIONS"}, methods...)
if resp.Code != http.StatusOK {
@ -175,7 +175,7 @@ func TestHTTPAPI_AllowedNets_OSS(t *testing.T) {
req, _ := http.NewRequest(method, uri, nil)
req.RemoteAddr = "192.168.1.2:5555"
resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req)
a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path)
})

View File

@ -129,7 +129,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) {
}
}
func TestHTTPServer_H2(t *testing.T) {
func TestHTTPServer_HTTP2(t *testing.T) {
t.Parallel()
// Fire up an agent with TLS enabled.
@ -161,16 +161,15 @@ func TestHTTPServer_H2(t *testing.T) {
if err := http2.ConfigureTransport(transport); err != nil {
t.Fatalf("err: %v", err)
}
hc := &http.Client{
Transport: transport,
}
httpClient := &http.Client{Transport: transport}
// Hook a handler that echoes back the protocol.
handler := func(resp http.ResponseWriter, req *http.Request) {
resp.WriteHeader(http.StatusOK)
fmt.Fprint(resp, req.Proto)
}
w, ok := a.srv.Handler.(*wrappedMux)
w, ok := a.srv.Server.Handler.(*wrappedMux)
if !ok {
t.Fatalf("handler is not expected type")
}
@ -178,7 +177,7 @@ func TestHTTPServer_H2(t *testing.T) {
// Call it and make sure we see HTTP/2.
url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String())
resp, err := hc.Get(url)
resp, err := httpClient.Get(url)
if err != nil {
t.Fatalf("err: %v", err)
}
@ -197,7 +196,7 @@ func TestHTTPServer_H2(t *testing.T) {
cfg := &api.Config{
Address: a.srv.ln.Addr().String(),
Scheme: "https",
HttpClient: hc,
HttpClient: httpClient,
}
client, err := api.NewClient(cfg)
if err != nil {
@ -333,7 +332,7 @@ func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) {
t.Fatal(err)
}
resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req)
a.srv.handler(true).ServeHTTP(resp, req)
if got, want := resp.Code, http.StatusBadRequest; got != want {
t.Fatalf("bad response code got %d want %d", got, want)
}
@ -352,7 +351,7 @@ func TestHTTPAPI_Allow_Nonprintable_Characters_With_Flag(t *testing.T) {
t.Fatal(err)
}
resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req)
a.srv.handler(true).ServeHTTP(resp, req)
// Key doesn't actually exist so we should get 404
if got, want := resp.Code, http.StatusNotFound; got != want {
t.Fatalf("bad response code got %d want %d", got, want)
@ -490,14 +489,14 @@ func TestAcceptEncodingGzip(t *testing.T) {
// negotiation, but since this call doesn't go through a real
// transport, the header has to be set manually
req.Header["Accept-Encoding"] = []string{"gzip"}
a.srv.Handler.ServeHTTP(resp, req)
a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, 200, resp.Code)
require.Equal(t, "", resp.Header().Get("Content-Encoding"))
resp = httptest.NewRecorder()
req, _ = http.NewRequest("GET", "/v1/kv/long", nil)
req.Header["Accept-Encoding"] = []string{"gzip"}
a.srv.Handler.ServeHTTP(resp, req)
a.srv.handler(true).ServeHTTP(resp, req)
require.Equal(t, 200, resp.Code)
require.Equal(t, "gzip", resp.Header().Get("Content-Encoding"))
}
@ -811,35 +810,35 @@ func TestParseWait(t *testing.T) {
}
}
func TestPProfHandlers_EnableDebug(t *testing.T) {
func TestHTTPServer_PProfHandlers_EnableDebug(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "enable_debug = true")
a := NewTestAgent(t, ``)
defer a.Shutdown()
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil)
a.srv.Handler.ServeHTTP(resp, req)
httpServer := &HTTPServer{agent: a.Agent}
httpServer.handler(true).ServeHTTP(resp, req)
require.Equal(http.StatusOK, resp.Code)
require.Equal(t, http.StatusOK, resp.Code)
}
func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) {
func TestHTTPServer_PProfHandlers_DisableDebugNoACLs(t *testing.T) {
t.Parallel()
require := require.New(t)
a := NewTestAgent(t, "enable_debug = false")
a := NewTestAgent(t, ``)
defer a.Shutdown()
resp := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil)
a.srv.Handler.ServeHTTP(resp, req)
httpServer := &HTTPServer{agent: a.Agent}
httpServer.handler(false).ServeHTTP(resp, req)
require.Equal(http.StatusUnauthorized, resp.Code)
require.Equal(t, http.StatusUnauthorized, resp.Code)
}
func TestPProfHandlers_ACLs(t *testing.T) {
func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) {
t.Parallel()
assert := assert.New(t)
dc1 := "dc1"
@ -904,7 +903,7 @@ func TestPProfHandlers_ACLs(t *testing.T) {
t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) {
req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil)
resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req)
a.srv.handler(true).ServeHTTP(resp, req)
assert.Equal(c.code, resp.Code)
})
}
@ -1192,7 +1191,7 @@ func TestEnableWebUI(t *testing.T) {
req, _ := http.NewRequest("GET", "/ui/", nil)
resp := httptest.NewRecorder()
a.srv.Handler.ServeHTTP(resp, req)
a.srv.handler(true).ServeHTTP(resp, req)
if resp.Code != 200 {
t.Fatalf("should handle ui")
}

View File

@ -384,7 +384,7 @@ func (a *TestAgent) HTTPAddr() string {
if a.srv == nil {
return ""
}
return a.srv.Addr
return a.srv.Server.Addr
}
func (a *TestAgent) SegmentAddr(name string) string {

View File

@ -43,7 +43,7 @@ func TestUiIndex(t *testing.T) {
// Register node
req, _ := http.NewRequest("GET", "/ui/my-file", nil)
req.URL.Scheme = "http"
req.URL.Host = a.srv.Addr
req.URL.Host = a.srv.Server.Addr
// Make the request
client := cleanhttp.DefaultClient()