diff --git a/agent/agent.go b/agent/agent.go index 76fd923d4a..6a2a6fea04 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -797,18 +797,7 @@ func (a *Agent) listenHTTP() ([]apiServer, error) { httpServer.ConnState = connLimitFn } - servers = append(servers, apiServer{ - Protocol: proto, - Addr: l.Addr(), - Shutdown: httpServer.Shutdown, - Run: func() error { - err := httpServer.Serve(l) - if err == nil || err == http.ErrServerClosed { - return nil - } - return fmt.Errorf("%s server %s failed: %w", proto, l.Addr(), err) - }, - }) + servers = append(servers, newAPIServerHTTP(proto, l, httpServer)) } return nil } diff --git a/agent/agent_test.go b/agent/agent_test.go index a31a207c19..8d7d418ab0 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "strconv" @@ -23,12 +24,23 @@ import ( "github.com/golang/protobuf/jsonpb" "github.com/google/tcpproxy" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/serf/coordinate" + "github.com/hashicorp/serf/serf" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "golang.org/x/time/rate" + "gopkg.in/square/go-jose.v2/jwt" + "github.com/hashicorp/consul/agent/cache" cachetype "github.com/hashicorp/consul/agent/cache-types" "github.com/hashicorp/consul/agent/checks" "github.com/hashicorp/consul/agent/config" "github.com/hashicorp/consul/agent/connect" + "github.com/hashicorp/consul/agent/consul" "github.com/hashicorp/consul/agent/structs" + "github.com/hashicorp/consul/agent/token" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest" "github.com/hashicorp/consul/ipaddr" @@ -38,13 +50,8 @@ import ( "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" + "github.com/hashicorp/consul/tlsutil" "github.com/hashicorp/consul/types" - "github.com/hashicorp/serf/coordinate" - "github.com/hashicorp/serf/serf" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/time/rate" - "gopkg.in/square/go-jose.v2/jwt" ) func getService(a *TestAgent, id string) *structs.NodeService { @@ -4705,3 +4712,64 @@ func TestSharedRPCRouter(t *testing.T) { require.NotNil(t, mgr) require.NotNil(t, server) } + +func TestAgent_ListenHTTP_MultipleAddresses(t *testing.T) { + caConfig := tlsutil.Config{} + tlsConf, err := tlsutil.NewConfigurator(caConfig, hclog.New(nil)) + require.NoError(t, err) + bd := BaseDeps{ + Deps: consul.Deps{ + Logger: hclog.NewInterceptLogger(nil), + Tokens: new(token.Store), + TLSConfigurator: tlsConf, + }, + RuntimeConfig: &config.RuntimeConfig{ + HTTPAddrs: []net.Addr{ + &net.TCPAddr{IP: net.ParseIP("127.0.0.1")}, + &net.TCPAddr{IP: net.ParseIP("127.0.0.1")}, + }, + }, + Cache: cache.New(cache.Options{}), + } + agent, err := New(bd) + require.NoError(t, err) + + srvs, err := agent.listenHTTP() + require.NoError(t, err) + defer func() { + ctx := context.Background() + for _, srv := range srvs { + srv.Shutdown(ctx) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + g := new(errgroup.Group) + for _, s := range srvs { + g.Go(s.Run) + } + + require.Len(t, srvs, 2) + require.Len(t, uniqueAddrs(srvs), 2) + + client := &http.Client{} + for _, s := range srvs { + u := url.URL{Scheme: s.Protocol, Host: s.Addr.String()} + req, err := http.NewRequest(http.MethodGet, u.String(), nil) + require.NoError(t, err) + + resp, err := client.Do(req.WithContext(ctx)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + } +} + +func uniqueAddrs(srvs []apiServer) map[string]struct{} { + result := make(map[string]struct{}, len(srvs)) + for _, s := range srvs { + result[s.Addr.String()] = struct{}{} + } + return result +} diff --git a/agent/apiserver.go b/agent/apiserver.go index 27087829a6..044bf60412 100644 --- a/agent/apiserver.go +++ b/agent/apiserver.go @@ -2,7 +2,9 @@ package agent import ( "context" + "fmt" "net" + "net/http" "sync" "time" @@ -92,3 +94,18 @@ func (s *apiServers) Shutdown(ctx context.Context) { func (s *apiServers) WaitForShutdown() error { return s.group.Wait() } + +func newAPIServerHTTP(proto string, l net.Listener, httpServer *http.Server) apiServer { + return apiServer{ + Protocol: proto, + Addr: l.Addr(), + Shutdown: httpServer.Shutdown, + Run: func() error { + err := httpServer.Serve(l) + if err == nil || err == http.ErrServerClosed { + return nil + } + return fmt.Errorf("%s server %s failed: %w", proto, l.Addr(), err) + }, + } +}