mirror of https://github.com/status-im/consul.git
agent/http: un-embed the HTTPServer
The embedded HTTPServer struct is not used by the large HTTPServer struct. It is used by tests and the agent. This change is a small first step in the process of removing that field. The eventual goal is to reduce the scope of HTTPServer making it easier to test, and split into separate packages.
This commit is contained in:
parent
db387eccd6
commit
a5e45defb1
|
@ -1151,24 +1151,25 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
|
|||
l = tls.NewListener(l, tlscfg)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: l.Addr().String(),
|
||||
TLSConfig: tlscfg,
|
||||
}
|
||||
srv := &HTTPServer{
|
||||
Server: &http.Server{
|
||||
Addr: l.Addr().String(),
|
||||
TLSConfig: tlscfg,
|
||||
},
|
||||
Server: httpServer,
|
||||
ln: l,
|
||||
agent: a,
|
||||
denylist: NewDenylist(a.config.HTTPBlockEndpoints),
|
||||
proto: proto,
|
||||
}
|
||||
srv.Server.Handler = srv.handler(a.config.EnableDebug)
|
||||
httpServer.Handler = srv.handler(a.config.EnableDebug)
|
||||
|
||||
// Load the connlimit helper into the server
|
||||
connLimitFn := a.httpConnLimiter.HTTPConnStateFunc()
|
||||
|
||||
if proto == "https" {
|
||||
// Enforce TLS handshake timeout
|
||||
srv.Server.ConnState = func(conn net.Conn, state http.ConnState) {
|
||||
httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
|
||||
switch state {
|
||||
case http.StateNew:
|
||||
// Set deadline to prevent slow send before TLS handshake or first
|
||||
|
@ -1188,12 +1189,12 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) {
|
|||
|
||||
// This will enable upgrading connections to HTTP/2 as
|
||||
// part of TLS negotiation.
|
||||
err = http2.ConfigureServer(srv.Server, nil)
|
||||
err = http2.ConfigureServer(httpServer, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
srv.Server.ConnState = connLimitFn
|
||||
httpServer.ConnState = connLimitFn
|
||||
}
|
||||
|
||||
ln = append(ln, l)
|
||||
|
@ -1263,7 +1264,7 @@ func (a *Agent) serveHTTP(srv *HTTPServer) error {
|
|||
go func() {
|
||||
defer a.wgServers.Done()
|
||||
notif <- srv.ln.Addr()
|
||||
err := srv.Serve(srv.ln)
|
||||
err := srv.Server.Serve(srv.ln)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
a.logger.Error("error closing server", "error", err)
|
||||
}
|
||||
|
@ -2111,7 +2112,7 @@ func (a *Agent) ShutdownEndpoints() {
|
|||
)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
srv.Shutdown(ctx)
|
||||
srv.Server.Shutdown(ctx)
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
a.logger.Warn("Timeout stopping server",
|
||||
"protocol", strings.ToUpper(srv.proto),
|
||||
|
|
|
@ -4465,7 +4465,8 @@ func TestAgent_Monitor(t *testing.T) {
|
|||
req = req.WithContext(cancelCtx)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
go a.srv.Handler.ServeHTTP(resp, req)
|
||||
handler := a.srv.handler(true)
|
||||
go handler.ServeHTTP(resp, req)
|
||||
|
||||
args := &structs.ServiceDefinition{
|
||||
Name: "monitor",
|
||||
|
|
|
@ -81,7 +81,8 @@ func (e ForbiddenError) Error() string {
|
|||
|
||||
// HTTPServer provides an HTTP api for an agent.
|
||||
type HTTPServer struct {
|
||||
*http.Server
|
||||
// TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods
|
||||
Server *http.Server
|
||||
ln net.Listener
|
||||
agent *Agent
|
||||
denylist *Denylist
|
||||
|
|
|
@ -133,7 +133,7 @@ func TestHTTPAPI_OptionMethod_OSS(t *testing.T) {
|
|||
uri := fmt.Sprintf("http://%s%s", a.HTTPAddr(), path)
|
||||
req, _ := http.NewRequest("OPTIONS", uri, nil)
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
a.srv.handler(true).ServeHTTP(resp, req)
|
||||
allMethods := append([]string{"OPTIONS"}, methods...)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
|
@ -175,7 +175,7 @@ func TestHTTPAPI_AllowedNets_OSS(t *testing.T) {
|
|||
req, _ := http.NewRequest(method, uri, nil)
|
||||
req.RemoteAddr = "192.168.1.2:5555"
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
a.srv.handler(true).ServeHTTP(resp, req)
|
||||
|
||||
require.Equal(t, http.StatusForbidden, resp.Code, "%s %s", method, path)
|
||||
})
|
||||
|
|
|
@ -129,7 +129,7 @@ func TestHTTPServer_UnixSocket_FileExists(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHTTPServer_H2(t *testing.T) {
|
||||
func TestHTTPServer_HTTP2(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Fire up an agent with TLS enabled.
|
||||
|
@ -161,16 +161,15 @@ func TestHTTPServer_H2(t *testing.T) {
|
|||
if err := http2.ConfigureTransport(transport); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
hc := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
httpClient := &http.Client{Transport: transport}
|
||||
|
||||
// Hook a handler that echoes back the protocol.
|
||||
handler := func(resp http.ResponseWriter, req *http.Request) {
|
||||
resp.WriteHeader(http.StatusOK)
|
||||
fmt.Fprint(resp, req.Proto)
|
||||
}
|
||||
w, ok := a.srv.Handler.(*wrappedMux)
|
||||
|
||||
w, ok := a.srv.Server.Handler.(*wrappedMux)
|
||||
if !ok {
|
||||
t.Fatalf("handler is not expected type")
|
||||
}
|
||||
|
@ -178,7 +177,7 @@ func TestHTTPServer_H2(t *testing.T) {
|
|||
|
||||
// Call it and make sure we see HTTP/2.
|
||||
url := fmt.Sprintf("https://%s/echo", a.srv.ln.Addr().String())
|
||||
resp, err := hc.Get(url)
|
||||
resp, err := httpClient.Get(url)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
@ -197,7 +196,7 @@ func TestHTTPServer_H2(t *testing.T) {
|
|||
cfg := &api.Config{
|
||||
Address: a.srv.ln.Addr().String(),
|
||||
Scheme: "https",
|
||||
HttpClient: hc,
|
||||
HttpClient: httpClient,
|
||||
}
|
||||
client, err := api.NewClient(cfg)
|
||||
if err != nil {
|
||||
|
@ -333,7 +332,7 @@ func TestHTTPAPI_Ban_Nonprintable_Characters(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
a.srv.handler(true).ServeHTTP(resp, req)
|
||||
if got, want := resp.Code, http.StatusBadRequest; got != want {
|
||||
t.Fatalf("bad response code got %d want %d", got, want)
|
||||
}
|
||||
|
@ -352,7 +351,7 @@ func TestHTTPAPI_Allow_Nonprintable_Characters_With_Flag(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
a.srv.handler(true).ServeHTTP(resp, req)
|
||||
// Key doesn't actually exist so we should get 404
|
||||
if got, want := resp.Code, http.StatusNotFound; got != want {
|
||||
t.Fatalf("bad response code got %d want %d", got, want)
|
||||
|
@ -490,14 +489,14 @@ func TestAcceptEncodingGzip(t *testing.T) {
|
|||
// negotiation, but since this call doesn't go through a real
|
||||
// transport, the header has to be set manually
|
||||
req.Header["Accept-Encoding"] = []string{"gzip"}
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
a.srv.handler(true).ServeHTTP(resp, req)
|
||||
require.Equal(t, 200, resp.Code)
|
||||
require.Equal(t, "", resp.Header().Get("Content-Encoding"))
|
||||
|
||||
resp = httptest.NewRecorder()
|
||||
req, _ = http.NewRequest("GET", "/v1/kv/long", nil)
|
||||
req.Header["Accept-Encoding"] = []string{"gzip"}
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
a.srv.handler(true).ServeHTTP(resp, req)
|
||||
require.Equal(t, 200, resp.Code)
|
||||
require.Equal(t, "gzip", resp.Header().Get("Content-Encoding"))
|
||||
}
|
||||
|
@ -811,35 +810,35 @@ func TestParseWait(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPProfHandlers_EnableDebug(t *testing.T) {
|
||||
func TestHTTPServer_PProfHandlers_EnableDebug(t *testing.T) {
|
||||
t.Parallel()
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t, "enable_debug = true")
|
||||
a := NewTestAgent(t, ``)
|
||||
defer a.Shutdown()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/debug/pprof/profile?seconds=1", nil)
|
||||
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
httpServer := &HTTPServer{agent: a.Agent}
|
||||
httpServer.handler(true).ServeHTTP(resp, req)
|
||||
|
||||
require.Equal(http.StatusOK, resp.Code)
|
||||
require.Equal(t, http.StatusOK, resp.Code)
|
||||
}
|
||||
|
||||
func TestPProfHandlers_DisableDebugNoACLs(t *testing.T) {
|
||||
func TestHTTPServer_PProfHandlers_DisableDebugNoACLs(t *testing.T) {
|
||||
t.Parallel()
|
||||
require := require.New(t)
|
||||
a := NewTestAgent(t, "enable_debug = false")
|
||||
a := NewTestAgent(t, ``)
|
||||
defer a.Shutdown()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/debug/pprof/profile", nil)
|
||||
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
httpServer := &HTTPServer{agent: a.Agent}
|
||||
httpServer.handler(false).ServeHTTP(resp, req)
|
||||
|
||||
require.Equal(http.StatusUnauthorized, resp.Code)
|
||||
require.Equal(t, http.StatusUnauthorized, resp.Code)
|
||||
}
|
||||
|
||||
func TestPProfHandlers_ACLs(t *testing.T) {
|
||||
func TestHTTPServer_PProfHandlers_ACLs(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert := assert.New(t)
|
||||
dc1 := "dc1"
|
||||
|
@ -904,7 +903,7 @@ func TestPProfHandlers_ACLs(t *testing.T) {
|
|||
t.Run(fmt.Sprintf("case %d (%#v)", i, c), func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", fmt.Sprintf("%s?token=%s", c.endpoint, c.token), nil)
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
a.srv.handler(true).ServeHTTP(resp, req)
|
||||
assert.Equal(c.code, resp.Code)
|
||||
})
|
||||
}
|
||||
|
@ -1192,7 +1191,7 @@ func TestEnableWebUI(t *testing.T) {
|
|||
|
||||
req, _ := http.NewRequest("GET", "/ui/", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
a.srv.Handler.ServeHTTP(resp, req)
|
||||
a.srv.handler(true).ServeHTTP(resp, req)
|
||||
if resp.Code != 200 {
|
||||
t.Fatalf("should handle ui")
|
||||
}
|
||||
|
|
|
@ -384,7 +384,7 @@ func (a *TestAgent) HTTPAddr() string {
|
|||
if a.srv == nil {
|
||||
return ""
|
||||
}
|
||||
return a.srv.Addr
|
||||
return a.srv.Server.Addr
|
||||
}
|
||||
|
||||
func (a *TestAgent) SegmentAddr(name string) string {
|
||||
|
|
|
@ -43,7 +43,7 @@ func TestUiIndex(t *testing.T) {
|
|||
// Register node
|
||||
req, _ := http.NewRequest("GET", "/ui/my-file", nil)
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = a.srv.Addr
|
||||
req.URL.Host = a.srv.Server.Addr
|
||||
|
||||
// Make the request
|
||||
client := cleanhttp.DefaultClient()
|
||||
|
|
Loading…
Reference in New Issue