2023-03-28 18:39:22 +00:00
|
|
|
// Copyright (c) HashiCorp, Inc.
|
2023-08-11 13:12:13 +00:00
|
|
|
// SPDX-License-Identifier: BUSL-1.1
|
2023-03-28 18:39:22 +00:00
|
|
|
|
2020-07-02 17:31:47 +00:00
|
|
|
package agent
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2020-11-18 17:22:07 +00:00
|
|
|
"fmt"
|
2020-07-02 17:31:47 +00:00
|
|
|
"net"
|
2020-11-18 17:22:07 +00:00
|
|
|
"net/http"
|
2020-07-02 17:31:47 +00:00
|
|
|
"sync"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/hashicorp/go-hclog"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
)
|
|
|
|
|
|
|
|
// apiServers is a wrapper around errgroup.Group for managing go routines for
|
|
|
|
// long running agent components (ex: http server, dns server). If any of the
|
|
|
|
// servers fail, the failed channel will be closed, which will cause the agent
|
|
|
|
// to be shutdown instead of running in a degraded state.
|
|
|
|
//
|
|
|
|
// This struct exists as a shim for using errgroup.Group without making major
|
|
|
|
// changes to Agent. In the future it may be removed and replaced with more
|
|
|
|
// direct usage of errgroup.Group.
|
|
|
|
type apiServers struct {
|
|
|
|
logger hclog.Logger
|
|
|
|
group *errgroup.Group
|
|
|
|
servers []apiServer
|
|
|
|
// failed channel is closed when the first server goroutines exit with a
|
|
|
|
// non-nil error.
|
|
|
|
failed <-chan struct{}
|
|
|
|
}
|
|
|
|
|
|
|
|
// apiServer describes a single long-running listener managed by apiServers.
type apiServer struct {
	// Protocol names what this server speaks: one of dns, http, https.
	Protocol string

	// Addr is the address the server listens on.
	Addr net.Addr

	// Run is invoked on its own goroutine to serve requests. Returning a
	// non-nil error causes the apiServers failed channel to be closed.
	Run func() error

	// Shutdown is called to stop the server.
	Shutdown func(context.Context) error
}
|
|
|
|
|
|
|
|
// NewAPIServers returns an empty apiServers that is ready to Start servers.
|
|
|
|
func NewAPIServers(logger hclog.Logger) *apiServers {
|
|
|
|
group, ctx := errgroup.WithContext(context.TODO())
|
|
|
|
return &apiServers{
|
|
|
|
logger: logger,
|
|
|
|
group: group,
|
|
|
|
failed: ctx.Done(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *apiServers) Start(srv apiServer) {
|
|
|
|
srv.logger(s.logger).Info("Starting server")
|
|
|
|
s.servers = append(s.servers, srv)
|
|
|
|
s.group.Go(srv.Run)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s apiServer) logger(base hclog.Logger) hclog.Logger {
|
|
|
|
return base.With(
|
|
|
|
"protocol", s.Protocol,
|
|
|
|
"address", s.Addr.String(),
|
|
|
|
"network", s.Addr.Network())
|
|
|
|
}
|
|
|
|
|
|
|
|
// Shutdown all the servers and log any errors as warning. Each server is given
|
|
|
|
// 1 second, or until ctx is cancelled, to shutdown gracefully.
|
|
|
|
func (s *apiServers) Shutdown(ctx context.Context) {
|
|
|
|
shutdownGroup := new(sync.WaitGroup)
|
|
|
|
|
|
|
|
for i := range s.servers {
|
|
|
|
server := s.servers[i]
|
|
|
|
shutdownGroup.Add(1)
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
defer shutdownGroup.Done()
|
|
|
|
logger := server.logger(s.logger)
|
|
|
|
logger.Info("Stopping server")
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
|
|
|
defer cancel()
|
|
|
|
if err := server.Shutdown(ctx); err != nil {
|
|
|
|
logger.Warn("Failed to stop server")
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|
|
|
|
s.servers = nil
|
|
|
|
shutdownGroup.Wait()
|
|
|
|
}
|
|
|
|
|
|
|
|
// WaitForShutdown waits until all server goroutines have exited. Shutdown
|
|
|
|
// must be called before WaitForShutdown, otherwise it will block forever.
|
|
|
|
func (s *apiServers) WaitForShutdown() error {
|
|
|
|
return s.group.Wait()
|
|
|
|
}
|
2020-11-18 17:22:07 +00:00
|
|
|
|
|
|
|
func newAPIServerHTTP(proto string, l net.Listener, httpServer *http.Server) apiServer {
|
|
|
|
return apiServer{
|
|
|
|
Protocol: proto,
|
|
|
|
Addr: l.Addr(),
|
|
|
|
Shutdown: httpServer.Shutdown,
|
|
|
|
Run: func() error {
|
|
|
|
err := httpServer.Serve(l)
|
|
|
|
if err == nil || err == http.ErrServerClosed {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
return fmt.Errorf("%s server %s failed: %w", proto, l.Addr(), err)
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|