diff --git a/agent/agent.go b/agent/agent.go index 0c639da9a7..20486af125 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -259,10 +259,12 @@ type Agent struct { // dnsServer provides the DNS API dnsServers []*DNSServer - // httpServers provides the HTTP API on various endpoints - httpServers []*HTTPServer + // apiServers listening for connections. If any of these server goroutines + // fail, the agent will be shutdown. + apiServers *apiServers // wgServers is the wait group for all HTTP and DNS servers + // TODO: remove once dnsServers are handled by apiServers wgServers sync.WaitGroup // watchPlans tracks all the currently-running watch plans for the @@ -375,6 +377,9 @@ func New(bd BaseDeps) (*Agent, error) { a.loadTokens(a.config) a.loadEnterpriseTokens(a.config) + // TODO: pass in a fully populated apiServers into Agent.New + a.apiServers = NewAPIServers(a.logger) + return &a, nil } @@ -580,10 +585,7 @@ func (a *Agent) Start(ctx context.Context) error { // Start HTTP and HTTPS servers. for _, srv := range servers { - if err := a.serveHTTP(srv); err != nil { - return err - } - a.httpServers = append(a.httpServers, srv) + a.apiServers.Start(srv) } // Start gRPC server. @@ -605,6 +607,12 @@ func (a *Agent) Start(ctx context.Context) error { return nil } +// Failed returns a channel which is closed when the first server goroutine exits +// with a non-nil error. +func (a *Agent) Failed() <-chan struct{} { + return a.apiServers.failed +} + func (a *Agent) listenAndServeGRPC() error { if len(a.config.GRPCAddrs) < 1 { return nil @@ -737,14 +745,16 @@ func (a *Agent) startListeners(addrs []net.Addr) ([]net.Listener, error) { // // This approach should ultimately be refactored to the point where we just // start the server and any error should trigger a proper shutdown of the agent. -func (a *Agent) listenHTTP() ([]*HTTPServer, error) { +func (a *Agent) listenHTTP() ([]apiServer, error) { var ln []net.Listener - var servers []*HTTPServer + var servers []apiServer + start := func(proto string, addrs []net.Addr) error { listeners, err := a.startListeners(addrs) if err != nil { return err } + ln = append(ln, listeners...) for _, l := range listeners { var tlscfg *tls.Config @@ -754,18 +764,15 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { l = tls.NewListener(l, tlscfg) } + srv := &HTTPServer{ + agent: a, + denylist: NewDenylist(a.config.HTTPBlockEndpoints), + } httpServer := &http.Server{ Addr: l.Addr().String(), TLSConfig: tlscfg, + Handler: srv.handler(a.config.EnableDebug), } - srv := &HTTPServer{ - Server: httpServer, - ln: l, - agent: a, - denylist: NewDenylist(a.config.HTTPBlockEndpoints), - proto: proto, - } - httpServer.Handler = srv.handler(a.config.EnableDebug) // Load the connlimit helper into the server connLimitFn := a.httpConnLimiter.HTTPConnStateFuncWithDefault429Handler(10 * time.Millisecond) @@ -778,27 +785,39 @@ func (a *Agent) listenHTTP() ([]*HTTPServer, error) { httpServer.ConnState = connLimitFn } - ln = append(ln, l) - servers = append(servers, srv) + servers = append(servers, apiServer{ + Protocol: proto, + Addr: l.Addr(), + Shutdown: httpServer.Shutdown, + Run: func() error { + err := httpServer.Serve(l) + if err == nil || err == http.ErrServerClosed { + return nil + } + return fmt.Errorf("%s server %s failed: %w", proto, l.Addr(), err) + }, + }) } return nil } if err := start("http", a.config.HTTPAddrs); err != nil { - for _, l := range ln { - l.Close() - } + closeListeners(ln) return nil, err } if err := start("https", a.config.HTTPSAddrs); err != nil { - for _, l := range ln { - l.Close() - } + closeListeners(ln) return nil, err } return servers, nil } +func closeListeners(lns []net.Listener) { + for _, l := range lns { + l.Close() + } +} + // setupHTTPS adds HTTP/2 support, ConnState, and a connection handshake timeout // to the http.Server. func setupHTTPS(server *http.Server, connState func(net.Conn, http.ConnState), timeout time.Duration) error { @@ -860,43 +879,6 @@ func (a *Agent) listenSocket(path string) (net.Listener, error) { return l, nil } -func (a *Agent) serveHTTP(srv *HTTPServer) error { - // https://github.com/golang/go/issues/20239 - // - // In go.8.1 there is a race between Serve and Shutdown. If - // Shutdown is called before the Serve go routine was scheduled then - // the Serve go routine never returns. This deadlocks the agent - // shutdown for some tests since it will wait forever. - notif := make(chan net.Addr) - a.wgServers.Add(1) - go func() { - defer a.wgServers.Done() - notif <- srv.ln.Addr() - err := srv.Server.Serve(srv.ln) - if err != nil && err != http.ErrServerClosed { - a.logger.Error("error closing server", "error", err) - } - }() - - select { - case addr := <-notif: - if srv.proto == "https" { - a.logger.Info("Started HTTPS server", - "address", addr.String(), - "network", addr.Network(), - ) - } else { - a.logger.Info("Started HTTP server", - "address", addr.String(), - "network", addr.Network(), - ) - } - return nil - case <-time.After(time.Second): - return fmt.Errorf("agent: timeout starting HTTP servers") - } -} - // stopAllWatches stops all the currently running watches func (a *Agent) stopAllWatches() { for _, wp := range a.watchPlans { @@ -1395,13 +1377,12 @@ func (a *Agent) ShutdownAgent() error { // ShutdownEndpoints terminates the HTTP and DNS servers. Should be // preceded by ShutdownAgent. +// TODO: remove this method, move to ShutdownAgent func (a *Agent) ShutdownEndpoints() { a.shutdownLock.Lock() defer a.shutdownLock.Unlock() - if len(a.dnsServers) == 0 && len(a.httpServers) == 0 { - return - } + ctx := context.TODO() for _, srv := range a.dnsServers { if srv.Server != nil { @@ -1415,27 +1396,11 @@ func (a *Agent) ShutdownEndpoints() { } a.dnsServers = nil - for _, srv := range a.httpServers { - a.logger.Info("Stopping server", - "protocol", strings.ToUpper(srv.proto), - "address", srv.ln.Addr().String(), - "network", srv.ln.Addr().Network(), - ) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - srv.Server.Shutdown(ctx) - if ctx.Err() == context.DeadlineExceeded { - a.logger.Warn("Timeout stopping server", - "protocol", strings.ToUpper(srv.proto), - "address", srv.ln.Addr().String(), - "network", srv.ln.Addr().Network(), - ) - } - } - a.httpServers = nil - + a.apiServers.Shutdown(ctx) a.logger.Info("Waiting for endpoints to shut down") - a.wgServers.Wait() + if err := a.apiServers.WaitForShutdown(); err != nil { + a.logger.Error(err.Error()) + } a.logger.Info("Endpoints down") } diff --git a/agent/agent_test.go b/agent/agent_test.go index 479421f591..472f1b652d 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1917,7 +1917,7 @@ func TestAgent_HTTPCheck_EnableAgentTLSForChecks(t *testing.T) { Status: api.HealthCritical, } - url := fmt.Sprintf("https://%s/v1/agent/self", a.srv.ln.Addr().String()) + url := fmt.Sprintf("https://%s/v1/agent/self", a.HTTPAddr()) chk := &structs.CheckType{ HTTP: url, Interval: 20 * time.Millisecond, diff --git a/agent/apiserver.go b/agent/apiserver.go new file mode 100644 index 0000000000..27087829a6 --- /dev/null +++ b/agent/apiserver.go @@ -0,0 +1,94 @@ +package agent + +import ( + "context" + "net" + "sync" + "time" + + "github.com/hashicorp/go-hclog" + "golang.org/x/sync/errgroup" +) + +// apiServers is a wrapper around errgroup.Group for managing go routines for +// long running agent components (ex: http server, dns server). If any of the +// servers fail, the failed channel will be closed, which will cause the agent +// to be shutdown instead of running in a degraded state. +// +// This struct exists as a shim for using errgroup.Group without making major +// changes to Agent. In the future it may be removed and replaced with more +// direct usage of errgroup.Group. +type apiServers struct { + logger hclog.Logger + group *errgroup.Group + servers []apiServer + // failed channel is closed when the first server goroutines exit with a + // non-nil error. + failed <-chan struct{} +} + +type apiServer struct { + // Protocol supported by this server. One of: dns, http, https + Protocol string + // Addr the server is listening on + Addr net.Addr + // Run will be called in a goroutine to run the server. When any Run exits + // with a non-nil error, the failed channel will be closed. + Run func() error + // Shutdown function used to stop the server + Shutdown func(context.Context) error +} + +// NewAPIServers returns an empty apiServers that is ready to Start servers. +func NewAPIServers(logger hclog.Logger) *apiServers { + group, ctx := errgroup.WithContext(context.TODO()) + return &apiServers{ + logger: logger, + group: group, + failed: ctx.Done(), + } +} + +func (s *apiServers) Start(srv apiServer) { + srv.logger(s.logger).Info("Starting server") + s.servers = append(s.servers, srv) + s.group.Go(srv.Run) +} + +func (s apiServer) logger(base hclog.Logger) hclog.Logger { + return base.With( + "protocol", s.Protocol, + "address", s.Addr.String(), + "network", s.Addr.Network()) +} + +// Shutdown all the servers and log any errors as warning. Each server is given +// 1 second, or until ctx is cancelled, to shutdown gracefully. +func (s *apiServers) Shutdown(ctx context.Context) { + shutdownGroup := new(sync.WaitGroup) + + for i := range s.servers { + server := s.servers[i] + shutdownGroup.Add(1) + + go func() { + defer shutdownGroup.Done() + logger := server.logger(s.logger) + logger.Info("Stopping server") + + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + if err := server.Shutdown(ctx); err != nil { + logger.Warn("Failed to stop server") + } + }() + } + s.servers = nil + shutdownGroup.Wait() +} + +// WaitForShutdown waits until all server goroutines have exited. Shutdown +// must be called before WaitForShutdown, otherwise it will block forever. +func (s *apiServers) WaitForShutdown() error { + return s.group.Wait() +} diff --git a/agent/apiserver_test.go b/agent/apiserver_test.go new file mode 100644 index 0000000000..72f8c6d651 --- /dev/null +++ b/agent/apiserver_test.go @@ -0,0 +1,65 @@ +package agent + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/require" +) + +func TestAPIServers_WithServiceRunError(t *testing.T) { + servers := NewAPIServers(hclog.New(nil)) + + server1, chErr1 := newAPIServerStub() + server2, _ := newAPIServerStub() + + t.Run("Start", func(t *testing.T) { + servers.Start(server1) + servers.Start(server2) + + select { + case <-servers.failed: + t.Fatalf("expected servers to still be running") + case <-time.After(5 * time.Millisecond): + } + }) + + err := fmt.Errorf("oops, I broke") + + t.Run("server exit non-nil error", func(t *testing.T) { + chErr1 <- err + + select { + case <-servers.failed: + case <-time.After(time.Second): + t.Fatalf("expected failed channel to be closed") + } + }) + + t.Run("shutdown remaining services", func(t *testing.T) { + servers.Shutdown(context.Background()) + require.Equal(t, err, servers.WaitForShutdown()) + }) +} + +func newAPIServerStub() (apiServer, chan error) { + chErr := make(chan error) + return apiServer{ + Protocol: "http", + Addr: &net.TCPAddr{ + IP: net.ParseIP("127.0.0.11"), + Port: 5505, + }, + Run: func() error { + return <-chErr + }, + Shutdown: func(ctx context.Context) error { + close(chErr) + return nil + }, + }, chErr +} diff --git a/agent/http.go b/agent/http.go index bac6c172c2..dc9438230d 100644 --- a/agent/http.go +++ b/agent/http.go @@ -80,16 +80,14 @@ func (e ForbiddenError) Error() string { } // HTTPServer provides an HTTP api for an agent. +// +// TODO: rename this struct to something more appropriate. It is an http.Handler, +// request router or multiplexer, but it is not a Server. type HTTPServer struct { - // TODO(dnephin): remove Server field, it is not used by any of the HTTPServer methods - Server *http.Server - ln net.Listener agent *Agent denylist *Denylist - - // proto is filled by the agent to "http" or "https". - proto string } + type templatedFile struct { templated *bytes.Reader name string diff --git a/agent/http_test.go b/agent/http_test.go index 6574b89180..36ecf387b3 100644 --- a/agent/http_test.go +++ b/agent/http_test.go @@ -1353,7 +1353,7 @@ func TestHTTPServer_HandshakeTimeout(t *testing.T) { // Connect to it with a plain TCP client that doesn't attempt to send HTTP or // complete a TLS handshake. - conn, err := net.Dial("tcp", a.srv.ln.Addr().String()) + conn, err := net.Dial("tcp", a.HTTPAddr()) require.NoError(t, err) defer conn.Close() @@ -1413,7 +1413,7 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) { }) defer a.Shutdown() - addr := a.srv.ln.Addr() + addr := a.HTTPAddr() assertConn := func(conn net.Conn, wantOpen bool) { retry.Run(t, func(r *retry.R) { @@ -1433,21 +1433,21 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) { } // Connect to the server with bare TCP - conn1, err := net.DialTimeout("tcp", addr.String(), time.Second) + conn1, err := net.DialTimeout("tcp", addr, time.Second) require.NoError(t, err) defer conn1.Close() assertConn(conn1, true) // Two conns should succeed - conn2, err := net.DialTimeout("tcp", addr.String(), time.Second) + conn2, err := net.DialTimeout("tcp", addr, time.Second) require.NoError(t, err) defer conn2.Close() assertConn(conn2, true) // Third should succeed negotiating TCP handshake... - conn3, err := net.DialTimeout("tcp", addr.String(), time.Second) + conn3, err := net.DialTimeout("tcp", addr, time.Second) require.NoError(t, err) defer conn3.Close() @@ -1460,7 +1460,7 @@ func TestRPC_HTTPSMaxConnsPerClient(t *testing.T) { require.NoError(t, a.reloadConfigInternal(&newCfg)) // Now another conn should be allowed - conn4, err := net.DialTimeout("tcp", addr.String(), time.Second) + conn4, err := net.DialTimeout("tcp", addr, time.Second) require.NoError(t, err) defer conn4.Close() diff --git a/agent/testagent.go b/agent/testagent.go index 8f05b6ed47..fa3508ffab 100644 --- a/agent/testagent.go +++ b/agent/testagent.go @@ -73,8 +73,7 @@ type TestAgent struct { // It is valid after Start(). dns *DNSServer - // srv is a reference to the first started HTTP endpoint. - // It is valid after Start(). + // srv is an HTTPServer that may be used to test http endpoints. srv *HTTPServer // overrides is an hcl config source to use to override otherwise @@ -213,6 +212,8 @@ func (a *TestAgent) Start(t *testing.T) (err error) { // Start the anti-entropy syncer a.Agent.StartSync() + a.srv = &HTTPServer{agent: agent, denylist: NewDenylist(a.config.HTTPBlockEndpoints)} + if err := a.waitForUp(); err != nil { a.Shutdown() t.Logf("Error while waiting for test agent to start: %v", err) @@ -220,7 +221,6 @@ func (a *TestAgent) Start(t *testing.T) (err error) { } a.dns = a.dnsServers[0] - a.srv = a.httpServers[0] return nil } @@ -233,7 +233,7 @@ func (a *TestAgent) waitForUp() error { var retErr error var out structs.IndexedNodes for ; !time.Now().After(deadline); time.Sleep(timer.Wait) { - if len(a.httpServers) == 0 { + if len(a.apiServers.servers) == 0 { retErr = fmt.Errorf("waiting for server") continue // fail, try again } @@ -262,7 +262,7 @@ func (a *TestAgent) waitForUp() error { } else { req := httptest.NewRequest("GET", "/v1/agent/self", nil) resp := httptest.NewRecorder() - _, err := a.httpServers[0].AgentSelf(resp, req) + _, err := a.srv.AgentSelf(resp, req) if acl.IsErrPermissionDenied(err) || resp.Code == 403 { // permission denied is enough to show that the client is // connected to the servers as it would get a 503 if @@ -313,10 +313,13 @@ func (a *TestAgent) DNSAddr() string { } func (a *TestAgent) HTTPAddr() string { - if a.srv == nil { - return "" + var srv apiServer + for _, srv = range a.Agent.apiServers.servers { + if srv.Protocol == "http" { + break + } } - return a.srv.Server.Addr + return srv.Addr.String() } func (a *TestAgent) SegmentAddr(name string) string { diff --git a/agent/ui_endpoint_test.go b/agent/ui_endpoint_test.go index 876d4a97c8..4640bcfebb 100644 --- a/agent/ui_endpoint_test.go +++ b/agent/ui_endpoint_test.go @@ -41,7 +41,7 @@ func TestUiIndex(t *testing.T) { // Register node req, _ := http.NewRequest("GET", "/ui/my-file", nil) req.URL.Scheme = "http" - req.URL.Host = a.srv.Server.Addr + req.URL.Host = a.HTTPAddr() // Make the request client := cleanhttp.DefaultClient() diff --git a/command/agent/agent.go b/command/agent/agent.go index 7da6613066..1e06ef90be 100644 --- a/command/agent/agent.go +++ b/command/agent/agent.go @@ -288,6 +288,9 @@ func (c *cmd) run(args []string) int { case err := <-agent.RetryJoinCh(): c.logger.Error("Retry join failed", "error", err) return 1 + case <-agent.Failed(): + // The deferred Shutdown method will log the appropriate error + return 1 case <-agent.ShutdownCh(): // agent is already down! return 0