diff --git a/.changelog/17270.txt b/.changelog/17270.txt new file mode 100644 index 0000000000..b9bd52888e --- /dev/null +++ b/.changelog/17270.txt @@ -0,0 +1,3 @@ +```release-note:bug +grpc: ensure grpc resolver correctly uses lan/wan addresses on servers +``` diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index 419867a4eb..3f7542d090 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -504,11 +504,15 @@ func newClient(t *testing.T, config *Config) *Client { return client } -func newTestResolverConfig(t *testing.T, suffix string) resolver.Config { +func newTestResolverConfig(t *testing.T, suffix string, dc, agentType string) resolver.Config { n := t.Name() s := strings.Replace(n, "/", "", -1) s = strings.Replace(s, "_", "", -1) - return resolver.Config{Authority: strings.ToLower(s) + "-" + suffix} + return resolver.Config{ + Datacenter: dc, + AgentType: agentType, + Authority: strings.ToLower(s) + "-" + suffix, + } } func newDefaultDeps(t *testing.T, c *Config) Deps { @@ -523,7 +527,7 @@ func newDefaultDeps(t *testing.T, c *Config) Deps { tls, err := tlsutil.NewConfigurator(c.TLSConfig, logger) require.NoError(t, err, "failed to create tls configuration") - resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter)) + resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, c.NodeName+"-"+c.Datacenter, c.Datacenter, "server")) resolver.Register(resolverBuilder) t.Cleanup(func() { resolver.Deregister(resolverBuilder.Authority()) diff --git a/agent/consul/server_serf.go b/agent/consul/server_serf.go index bb17235b14..1dc6c25b1c 100644 --- a/agent/consul/server_serf.go +++ b/agent/consul/server_serf.go @@ -23,6 +23,7 @@ import ( "github.com/hashicorp/consul/lib" libserf "github.com/hashicorp/consul/lib/serf" "github.com/hashicorp/consul/logging" + "github.com/hashicorp/consul/types" ) const ( @@ -359,6 +360,7 @@ func (s *Server) lanNodeJoin(me serf.MemberEvent) { // Update server lookup s.serverLookup.AddServer(serverMeta) + s.router.AddServer(types.AreaLAN, serverMeta) // If we're still expecting to bootstrap, may need to handle this. if s.config.BootstrapExpect != 0 { @@ -380,6 +382,7 @@ func (s *Server) lanNodeUpdate(me serf.MemberEvent) { // Update server lookup s.serverLookup.AddServer(serverMeta) + s.router.AddServer(types.AreaLAN, serverMeta) } } @@ -518,5 +521,6 @@ func (s *Server) lanNodeFailed(me serf.MemberEvent) { // Update id to address map s.serverLookup.RemoveServer(serverMeta) + s.router.RemoveServer(types.AreaLAN, serverMeta) } } diff --git a/agent/consul/subscribe_backend_test.go b/agent/consul/subscribe_backend_test.go index 01c8184518..833f049c97 100644 --- a/agent/consul/subscribe_backend_test.go +++ b/agent/consul/subscribe_backend_test.go @@ -382,7 +382,10 @@ func newClientWithGRPCPlumbing(t *testing.T, ops ...func(*Config)) (*Client, *re } resolverBuilder := resolver.NewServerResolverBuilder(newTestResolverConfig(t, - "client."+config.Datacenter+"."+string(config.NodeID))) + "client."+config.Datacenter+"."+string(config.NodeID), + config.Datacenter, + "client", + )) resolver.Register(resolverBuilder) t.Cleanup(func() { diff --git a/agent/grpc-internal/client_test.go b/agent/grpc-internal/client_test.go index ad6906a958..a3b99e78ad 100644 --- a/agent/grpc-internal/client_test.go +++ b/agent/grpc-internal/client_test.go @@ -38,8 +38,8 @@ func TestNewDialer_WithTLSWrapper(t *testing.T) { require.NoError(t, err) t.Cleanup(logError(t, lis.Close)) - builder := resolver.NewServerResolverBuilder(newConfig(t)) - builder.AddServer(types.AreaWAN, &metadata.Server{ + builder := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server")) + builder.AddServer(types.AreaLAN, &metadata.Server{ Name: "server-1", ID: "ID1", Datacenter: "dc1", @@ -89,7 +89,7 @@ func TestNewDialer_WithALPNWrapper(t *testing.T) { p.Wait() }() - builder := resolver.NewServerResolverBuilder(newConfig(t)) + builder := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server")) builder.AddServer(types.AreaWAN, &metadata.Server{ Name: "server-1", ID: "ID1", @@ -144,7 +144,7 @@ func TestNewDialer_WithALPNWrapper(t *testing.T) { func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { // if this test is failing because of expired certificates // use the procedure in test/CA-GENERATION.md - res := resolver.NewServerResolverBuilder(newConfig(t)) + res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server")) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) @@ -162,9 +162,17 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler(t *testing.T) { srv := newSimpleTestServer(t, "server-1", "dc1", tlsConf) md := srv.Metadata() - res.AddServer(types.AreaWAN, md) + res.AddServer(types.AreaLAN, md) t.Cleanup(srv.shutdown) + { + // Put a duplicate instance of this on the WAN that will + // fail if we accidentally use it. + srv := newPanicTestServer(t, hclog.Default(), "server-1", "dc1", nil) + res.AddServer(types.AreaWAN, srv.Metadata()) + t.Cleanup(srv.shutdown) + } + pool := NewClientConnPool(ClientConnPoolConfig{ Servers: res, TLSWrapper: TLSWrapper(tlsConf.OutgoingRPCWrapper()), @@ -192,7 +200,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) // use the procedure in test/CA-GENERATION.md gwAddr := ipaddr.FormatAddressPort("127.0.0.1", freeport.GetOne(t)) - res := resolver.NewServerResolverBuilder(newConfig(t)) + res := resolver.NewServerResolverBuilder(newConfig(t, "dc2", "server")) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) @@ -268,7 +276,7 @@ func TestNewDialer_IntegrationWithTLSEnabledHandler_viaMeshGateway(t *testing.T) func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { count := 4 - res := resolver.NewServerResolverBuilder(newConfig(t)) + res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server")) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) pool := NewClientConnPool(ClientConnPoolConfig{ @@ -280,9 +288,18 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) - srv := newSimpleTestServer(t, name, "dc1", nil) - res.AddServer(types.AreaWAN, srv.Metadata()) - t.Cleanup(srv.shutdown) + { + srv := newSimpleTestServer(t, name, "dc1", nil) + res.AddServer(types.AreaLAN, srv.Metadata()) + t.Cleanup(srv.shutdown) + } + { + // Put a duplicate instance of this on the WAN that will + // fail if we accidentally use it. + srv := newPanicTestServer(t, hclog.Default(), name, "dc1", nil) + res.AddServer(types.AreaWAN, srv.Metadata()) + t.Cleanup(srv.shutdown) + } } conn, err := pool.ClientConn("dc1") @@ -295,7 +312,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { first, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) - res.RemoveServer(types.AreaWAN, &metadata.Server{ID: first.ServerName, Datacenter: "dc1"}) + res.RemoveServer(types.AreaLAN, &metadata.Server{ID: first.ServerName, Datacenter: "dc1"}) resp, err := client.Something(ctx, &testservice.Req{}) require.NoError(t, err) @@ -304,7 +321,7 @@ func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { count := 3 - res := resolver.NewServerResolverBuilder(newConfig(t)) + res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server")) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) pool := NewClientConnPool(ClientConnPoolConfig{ @@ -317,10 +334,19 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { var servers []testServer for i := 0; i < count; i++ { name := fmt.Sprintf("server-%d", i) - srv := newSimpleTestServer(t, name, "dc1", nil) - res.AddServer(types.AreaWAN, srv.Metadata()) - servers = append(servers, srv) - t.Cleanup(srv.shutdown) + { + srv := newSimpleTestServer(t, name, "dc1", nil) + res.AddServer(types.AreaLAN, srv.Metadata()) + servers = append(servers, srv) + t.Cleanup(srv.shutdown) + } + { + // Put a duplicate instance of this on the WAN that will + // fail if we accidentally use it. + srv := newPanicTestServer(t, hclog.Default(), name, "dc1", nil) + res.AddServer(types.AreaWAN, srv.Metadata()) + t.Cleanup(srv.shutdown) + } } // Set the leader address to the first server. @@ -347,19 +373,24 @@ func TestClientConnPool_ForwardToLeader_Failover(t *testing.T) { require.Equal(t, resp.ServerName, servers[1].name) } -func newConfig(t *testing.T) resolver.Config { +func newConfig(t *testing.T, dc, agentType string) resolver.Config { n := t.Name() s := strings.Replace(n, "/", "", -1) s = strings.Replace(s, "_", "", -1) - return resolver.Config{Authority: strings.ToLower(s)} + return resolver.Config{ + Datacenter: dc, + AgentType: agentType, + Authority: strings.ToLower(s), + } } func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { dcs := []string{"dc1", "dc2", "dc3"} - res := resolver.NewServerResolverBuilder(newConfig(t)) + res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server")) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) + pool := NewClientConnPool(ClientConnPoolConfig{ Servers: res, UseTLSForDC: useTLSForDcAlwaysTrue, @@ -370,7 +401,16 @@ func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { for _, dc := range dcs { name := "server-0-" + dc srv := newSimpleTestServer(t, name, dc, nil) - res.AddServer(types.AreaWAN, srv.Metadata()) + if dc == "dc1" { + res.AddServer(types.AreaLAN, srv.Metadata()) + // Put a duplicate instance of this on the WAN that will + // fail if we accidentally use it. + srvBad := newPanicTestServer(t, hclog.Default(), name, dc, nil) + res.AddServer(types.AreaWAN, srvBad.Metadata()) + t.Cleanup(srvBad.shutdown) + } else { + res.AddServer(types.AreaWAN, srv.Metadata()) + } t.Cleanup(srv.shutdown) } diff --git a/agent/grpc-internal/handler_test.go b/agent/grpc-internal/handler_test.go index e39f93be39..80c026113d 100644 --- a/agent/grpc-internal/handler_test.go +++ b/agent/grpc-internal/handler_test.go @@ -31,12 +31,12 @@ func TestHandler_PanicRecoveryInterceptor(t *testing.T) { Output: &buf, }) - res := resolver.NewServerResolverBuilder(newConfig(t)) + res := resolver.NewServerResolverBuilder(newConfig(t, "dc1", "server")) bb := balancer.NewBuilder(res.Authority(), testutil.Logger(t)) registerWithGRPC(t, res, bb) srv := newPanicTestServer(t, logger, "server-1", "dc1", nil) - res.AddServer(types.AreaWAN, srv.Metadata()) + res.AddServer(types.AreaLAN, srv.Metadata()) t.Cleanup(srv.shutdown) pool := NewClientConnPool(ClientConnPoolConfig{ diff --git a/agent/grpc-internal/resolver/resolver.go b/agent/grpc-internal/resolver/resolver.go index f4d1e95ea0..bb2224ee08 100644 --- a/agent/grpc-internal/resolver/resolver.go +++ b/agent/grpc-internal/resolver/resolver.go @@ -18,25 +18,45 @@ import ( // ServerResolvers updated when changes occur. type ServerResolverBuilder struct { cfg Config + // leaderResolver is used to track the address of the leader in the local DC. leaderResolver leaderResolver + // servers is an index of Servers by area and Server.ID. The map contains server IDs // for all datacenters. servers map[types.AreaID]map[string]*metadata.Server + // resolvers is an index of connections to the serverResolver which manages // addresses of servers for that connection. + // + // this is only applicable for non-leader conn types resolvers map[resolver.ClientConn]*serverResolver + // lock for all stateful fields (excludes config which is immutable). lock sync.RWMutex } type Config struct { + // Datacenter is the datacenter of this agent. + Datacenter string + + // AgentType is either 'server' or 'client' and is required. + AgentType string + // Authority used to query the server. Defaults to "". Used to support // parallel testing because gRPC registers resolvers globally. Authority string } func NewServerResolverBuilder(cfg Config) *ServerResolverBuilder { + if cfg.Datacenter == "" { + panic("ServerResolverBuilder needs Config.Datacenter to be nonempty") + } + switch cfg.AgentType { + case "server", "client": + default: + panic("ServerResolverBuilder needs Config.AgentType to be either server or client") + } return &ServerResolverBuilder{ cfg: cfg, servers: make(map[types.AreaID]map[string]*metadata.Server), @@ -56,6 +76,7 @@ func (s *ServerResolverBuilder) ServerForGlobalAddr(globalAddr string) (*metadat } } } + return nil, fmt.Errorf("failed to find Consul server for global address %q", globalAddr) } @@ -67,12 +88,12 @@ func (s *ServerResolverBuilder) Build(target resolver.Target, cc resolver.Client // If there's already a resolver for this connection, return it. // TODO(streaming): how would this happen since we already cache connections in ClientConnPool? - if resolver, ok := s.resolvers[cc]; ok { - return resolver, nil - } if cc == s.leaderResolver.clientConn { return s.leaderResolver, nil } + if resolver, ok := s.resolvers[cc]; ok { + return resolver, nil + } //nolint:staticcheck serverType, datacenter, err := parseEndpoint(target.Endpoint) @@ -119,6 +140,10 @@ func (s *ServerResolverBuilder) Authority() string { // AddServer updates the resolvers' states to include the new server's address. func (s *ServerResolverBuilder) AddServer(areaID types.AreaID, server *metadata.Server) { + if s.shouldIgnoreServer(areaID, server) { + return + } + s.lock.Lock() defer s.lock.Unlock() @@ -155,6 +180,10 @@ func DCPrefix(datacenter, suffix string) string { // RemoveServer updates the resolvers' states with the given server removed. func (s *ServerResolverBuilder) RemoveServer(areaID types.AreaID, server *metadata.Server) { + if s.shouldIgnoreServer(areaID, server) { + return + } + s.lock.Lock() defer s.lock.Unlock() @@ -176,14 +205,48 @@ func (s *ServerResolverBuilder) RemoveServer(areaID types.AreaID, server *metada } } +// shouldIgnoreServer is used to contextually decide if a particular kind of +// server should be accepted into a given area. +// +// On client agents it's pretty easy: clients only participate in the standard +// LAN, so we only accept servers from the LAN. +// +// On server agents it's a little less obvious. This resolver is ultimately +// used to have servers dial other servers. If a server is going to cross +// between datacenters (using traditional federation) then we want to use the +// WAN addresses for them, but if a server is going to dial a sibling server in +// the same datacenter we want it to use the LAN addresses always. To achieve +// that here we simply never allow WAN servers for our current datacenter to be +// added into the resolver, letting only the LAN instances through. +func (s *ServerResolverBuilder) shouldIgnoreServer(areaID types.AreaID, server *metadata.Server) bool { + if s.cfg.AgentType == "client" && areaID != types.AreaLAN { + return true + } + + if s.cfg.AgentType == "server" && + server.Datacenter == s.cfg.Datacenter && + areaID != types.AreaLAN { + return true + } + + return false +} + // getDCAddrs returns a list of the server addresses for the given datacenter. // This method requires that lock is held for reads. func (s *ServerResolverBuilder) getDCAddrs(dc string) []resolver.Address { + lanRequest := (s.cfg.Datacenter == dc) + var ( addrs []resolver.Address keptServerIDs = make(map[string]struct{}) ) - for _, areaServers := range s.servers { + for areaID, areaServers := range s.servers { + if (areaID == types.AreaLAN) != lanRequest { + // LAN requests only look at LAN data. WAN requests only look at + // WAN data. + continue + } for _, server := range areaServers { if server.Datacenter != dc { continue diff --git a/agent/grpc-internal/resolver/resolver_test.go b/agent/grpc-internal/resolver/resolver_test.go new file mode 100644 index 0000000000..ab6e403d7d --- /dev/null +++ b/agent/grpc-internal/resolver/resolver_test.go @@ -0,0 +1,195 @@ +package resolver + +import ( + "fmt" + "net" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/types" +) + +func TestServerResolverBuilder(t *testing.T) { + const agentDC = "dc1" + + type testcase struct { + name string + agentType string // server/client + serverType string // server/leader + requestDC string + expectLAN bool + } + + run := func(t *testing.T, tc testcase) { + rs := NewServerResolverBuilder(newConfig(t, agentDC, tc.agentType)) + + endpoint := "" + if tc.serverType == "leader" { + endpoint = "leader.local" + } else { + endpoint = tc.serverType + "." + tc.requestDC + } + + cc := &fakeClientConn{} + _, err := rs.Build(resolver.Target{ + Scheme: "consul", + Authority: rs.Authority(), + Endpoint: endpoint, + }, cc, resolver.BuildOptions{}) + require.NoError(t, err) + + for i := 0; i < 3; i++ { + dc := fmt.Sprintf("dc%d", i+1) + for j := 0; j < 3; j++ { + wanIP := fmt.Sprintf("127.1.%d.%d", i+1, j+10) + name := fmt.Sprintf("%s-server-%d", dc, j+1) + wanMeta := newServerMeta(name, dc, wanIP, true) + + if tc.agentType == "server" { + rs.AddServer(types.AreaWAN, wanMeta) + } + + if dc == agentDC { + // register LAN/WAN pairs for the same instances + lanIP := fmt.Sprintf("127.0.%d.%d", i+1, j+10) + lanMeta := newServerMeta(name, dc, lanIP, false) + rs.AddServer(types.AreaLAN, lanMeta) + + if j == 0 { + rs.UpdateLeaderAddr(dc, lanIP) + } + } + } + } + + if tc.serverType == "leader" { + assert.Len(t, cc.state.Addresses, 1) + } else { + assert.Len(t, cc.state.Addresses, 3) + } + + for _, addr := range cc.state.Addresses { + addrPrefix := tc.requestDC + "-" + if tc.expectLAN { + addrPrefix += "127.0." + } else { + addrPrefix += "127.1." + } + assert.True(t, strings.HasPrefix(addr.Addr, addrPrefix), + "%q does not start with %q (returned WAN for LAN request)", addr.Addr, addrPrefix) + + if tc.expectLAN { + assert.False(t, strings.Contains(addr.ServerName, ".dc"), + "%q ends with datacenter suffix (returned WAN for LAN request)", addr.ServerName) + } else { + assert.True(t, strings.HasSuffix(addr.ServerName, "."+tc.requestDC), + "%q does not end with %q", addr.ServerName, "."+tc.requestDC) + } + } + } + + cases := []testcase{ + { + name: "server requesting local servers", + agentType: "server", + serverType: "server", + requestDC: agentDC, + expectLAN: true, + }, + { + name: "server requesting remote servers in dc2", + agentType: "server", + serverType: "server", + requestDC: "dc2", + expectLAN: false, + }, + { + name: "server requesting remote servers in dc3", + agentType: "server", + serverType: "server", + requestDC: "dc3", + expectLAN: false, + }, + // --------------- + { + name: "server requesting local leader", + agentType: "server", + serverType: "leader", + requestDC: agentDC, + expectLAN: true, + }, + // --------------- + { + name: "client requesting local server", + agentType: "client", + serverType: "server", + requestDC: agentDC, + expectLAN: true, + }, + { + name: "client requesting local leader", + agentType: "client", + serverType: "leader", + requestDC: agentDC, + expectLAN: true, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + run(t, tc) + }) + } +} + +func newServerMeta(name, dc, ip string, wan bool) *metadata.Server { + fullname := name + if wan { + fullname = name + "." + dc + } + return &metadata.Server{ + ID: name, + Name: fullname, + ShortName: name, + Datacenter: dc, + Addr: &net.IPAddr{IP: net.ParseIP(ip)}, + UseTLS: false, + } +} + +func newConfig(t *testing.T, dc, agentType string) Config { + n := t.Name() + s := strings.Replace(n, "/", "", -1) + s = strings.Replace(s, "_", "", -1) + return Config{ + Datacenter: dc, + AgentType: agentType, + Authority: strings.ToLower(s), + } +} + +// fakeClientConn implements resolver.ClientConn for tests +type fakeClientConn struct { + state resolver.State +} + +var _ resolver.ClientConn = (*fakeClientConn)(nil) + +func (f *fakeClientConn) UpdateState(state resolver.State) error { + f.state = state + return nil +} + +func (*fakeClientConn) ReportError(error) {} +func (*fakeClientConn) NewAddress(addresses []resolver.Address) {} +func (*fakeClientConn) NewServiceConfig(serviceConfig string) {} +func (*fakeClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult { + return nil +} diff --git a/agent/peering_endpoint_test.go b/agent/peering_endpoint_test.go index 1a5ad4b352..99d6cb6c36 100644 --- a/agent/peering_endpoint_test.go +++ b/agent/peering_endpoint_test.go @@ -4,6 +4,7 @@ package agent import ( + "bufio" "bytes" "context" "encoding/base64" @@ -12,19 +13,208 @@ import ( "io" "net/http" "net/http/httptest" + "strconv" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + gpeer "google.golang.org/grpc/peer" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/api" "github.com/hashicorp/consul/proto/private/pbpeering" + "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/consul/sdk/testutil/retry" "github.com/hashicorp/consul/testrpc" ) +func TestHTTP_Peering_Integration(t *testing.T) { + // This is a full-stack integration test of the gRPC (internal) stack. We + // use peering CRUD b/c that is one of the few endpoints exposed over gRPC + // (internal). + + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + // We advertise a wan address we are not using, so that incidental attempts + // to use it will loudly fail. + const ip = "192.0.2.2" + + connectivityConfig := ` +ports { serf_wan = -1 } +bind_addr = "0.0.0.0" +client_addr = "0.0.0.0" +advertise_addr = "127.0.0.1" +advertise_addr_wan = "` + ip + `" ` + + var ( + buf1, buf2, buf3 bytes.Buffer + testLog = testutil.NewLogBuffer(t) + + log1 = io.MultiWriter(testLog, &buf1) + log2 = io.MultiWriter(testLog, &buf2) + log3 = io.MultiWriter(testLog, &buf3) + ) + + a1 := StartTestAgent(t, TestAgent{LogOutput: log1, HCL: ` + server = true + bootstrap = false + bootstrap_expect = 3 + ` + connectivityConfig}) + t.Cleanup(func() { a1.Shutdown() }) + + a2 := StartTestAgent(t, TestAgent{LogOutput: log2, HCL: ` + server = true + bootstrap = false + bootstrap_expect = 3 + ` + connectivityConfig}) + t.Cleanup(func() { a2.Shutdown() }) + + a3 := StartTestAgent(t, TestAgent{LogOutput: log3, HCL: ` + server = true + bootstrap = false + bootstrap_expect = 3 + ` + connectivityConfig}) + t.Cleanup(func() { a3.Shutdown() }) + + { // join a2 to a1 + addr := fmt.Sprintf("127.0.0.1:%d", a2.Config.SerfPortLAN) + _, err := a1.JoinLAN([]string{addr}, nil) + require.NoError(t, err) + } + { // join a3 to a1 + addr := fmt.Sprintf("127.0.0.1:%d", a3.Config.SerfPortLAN) + _, err := a1.JoinLAN([]string{addr}, nil) + require.NoError(t, err) + } + + testrpc.WaitForLeader(t, a1.RPC, "dc1") + testrpc.WaitForActiveCARoot(t, a1.RPC, "dc1", nil) + + testrpc.WaitForTestAgent(t, a1.RPC, "dc1") + testrpc.WaitForTestAgent(t, a2.RPC, "dc1") + testrpc.WaitForTestAgent(t, a3.RPC, "dc1") + + retry.Run(t, func(r *retry.R) { + require.Len(r, a1.LANMembersInAgentPartition(), 3) + require.Len(r, a2.LANMembersInAgentPartition(), 3) + require.Len(r, a3.LANMembersInAgentPartition(), 3) + }) + + type testcase struct { + agent *TestAgent + peerName string + prevCount int + } + + checkPeeringList := func(t *testing.T, a *TestAgent, expect int) { + req, err := http.NewRequest("GET", "/v1/peerings", nil) + require.NoError(t, err) + + resp := httptest.NewRecorder() + a.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + var apiResp []*api.Peering + require.NoError(t, json.NewDecoder(resp.Body).Decode(&apiResp)) + + require.Len(t, apiResp, expect) + } + + testConn := func(t *testing.T, conn *grpc.ClientConn, peers map[string]int) { + rpcClientPeering := pbpeering.NewPeeringServiceClient(conn) + + peer := &gpeer.Peer{} + _, err := rpcClientPeering.PeeringList( + context.Background(), + &pbpeering.PeeringListRequest{}, + grpc.Peer(peer), + ) + require.NoError(t, err) + + peers[peer.Addr.String()]++ + } + + var ( + standardPeers = make(map[string]int) + leaderPeers = make(map[string]int) + ) + runOnce := func(t *testing.T, tc testcase) { + conn, err := tc.agent.baseDeps.GRPCConnPool.ClientConn("dc1") + require.NoError(t, err) + testConn(t, conn, standardPeers) + + leaderConn, err := tc.agent.baseDeps.GRPCConnPool.ClientConnLeader() + require.NoError(t, err) + testConn(t, leaderConn, leaderPeers) + + checkPeeringList(t, tc.agent, tc.prevCount) + + body := &pbpeering.GenerateTokenRequest{ + PeerName: tc.peerName, + } + + bodyBytes, err := json.Marshal(body) + require.NoError(t, err) + + req, err := http.NewRequest("POST", "/v1/peering/token", bytes.NewReader(bodyBytes)) + require.NoError(t, err) + + resp := httptest.NewRecorder() + tc.agent.srv.h.ServeHTTP(resp, req) + require.Equal(t, http.StatusOK, resp.Code, "expected 200, got %d: %v", resp.Code, resp.Body.String()) + + var r pbpeering.GenerateTokenResponse + require.NoError(t, json.NewDecoder(resp.Body).Decode(&r)) + + checkPeeringList(t, tc.agent, tc.prevCount+1) + } + + // Try the procedure on all agents to force N-1 of them to leader-forward. + cases := []testcase{ + {agent: a1, peerName: "peer-1", prevCount: 0}, + {agent: a2, peerName: "peer-2", prevCount: 1}, + {agent: a3, peerName: "peer-3", prevCount: 2}, + } + + for i, tc := range cases { + tc := tc + testutil.RunStep(t, "server-"+strconv.Itoa(i+1), func(t *testing.T) { + runOnce(t, tc) + }) + } + + testutil.RunStep(t, "ensure we got the right mixture of responses", func(t *testing.T) { + assert.Len(t, standardPeers, 3) + + // Each server talks to a single leader. + assert.Len(t, leaderPeers, 1) + for p, n := range leaderPeers { + assert.Equal(t, 3, n, "peer %q expected 3 uses", p) + } + }) + + testutil.RunStep(t, "no server experienced the server resolution error", func(t *testing.T) { + // Check them all for the bad error + const grpcError = `failed to find Consul server for global address` + + var buf bytes.Buffer + buf.ReadFrom(&buf1) + buf.ReadFrom(&buf2) + buf.ReadFrom(&buf3) + + scan := bufio.NewScanner(&buf) + for scan.Scan() { + line := scan.Text() + require.NotContains(t, line, grpcError) + } + require.NoError(t, scan.Err()) + }) +} + func TestHTTP_Peering_GenerateToken(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") diff --git a/agent/rpc/peering/service_test.go b/agent/rpc/peering/service_test.go index 3830cbdfac..405fe1a737 100644 --- a/agent/rpc/peering/service_test.go +++ b/agent/rpc/peering/service_test.go @@ -11,6 +11,7 @@ import ( "net" "os" "path" + "strings" "testing" "time" @@ -1669,6 +1670,17 @@ type testingServer struct { PublicGRPCAddr string } +func newConfig(t *testing.T, dc, agentType string) resolver.Config { + n := t.Name() + s := strings.Replace(n, "/", "", -1) + s = strings.Replace(s, "_", "", -1) + return resolver.Config{ + Datacenter: dc, + AgentType: agentType, + Authority: strings.ToLower(s), + } +} + // TODO(peering): remove duplication between this and agent/consul tests func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps { t.Helper() @@ -1683,7 +1695,7 @@ func newDefaultDeps(t *testing.T, c *consul.Config) consul.Deps { require.NoError(t, err, "failed to create tls configuration") r := router.NewRouter(logger, c.Datacenter, fmt.Sprintf("%s.%s", c.NodeName, c.Datacenter), nil) - builder := resolver.NewServerResolverBuilder(resolver.Config{}) + builder := resolver.NewServerResolverBuilder(newConfig(t, c.Datacenter, "client")) resolver.Register(builder) connPool := &pool.ConnPool{ diff --git a/agent/setup.go b/agent/setup.go index 07ab3852de..a4520e3cfc 100644 --- a/agent/setup.go +++ b/agent/setup.go @@ -120,7 +120,14 @@ func NewBaseDeps(configLoader ConfigLoader, logOut io.Writer, providedLogger hcl d.ViewStore = submatview.NewStore(d.Logger.Named("viewstore")) d.ConnPool = newConnPool(cfg, d.Logger, d.TLSConfigurator) + agentType := "client" + if cfg.ServerMode { + agentType = "server" + } + resolverBuilder := resolver.NewServerResolverBuilder(resolver.Config{ + AgentType: agentType, + Datacenter: cfg.Datacenter, // Set the authority to something sufficiently unique so any usage in // tests would be self-isolating in the global resolver map, while also // not incurring a huge penalty for non-test code.