diff --git a/agent/grpc/client.go b/agent/grpc/client.go index e65e95a13c..71f16c7c31 100644 --- a/agent/grpc/client.go +++ b/agent/grpc/client.go @@ -54,14 +54,15 @@ func (c *ClientConnPool) ClientConn(datacenter string) (*grpc.ClientConn, error) return conn, nil } - conn, err := grpc.Dial(fmt.Sprintf("%s:///server.%s", c.servers.Scheme(), datacenter), + conn, err := grpc.Dial( + fmt.Sprintf("%s:///server.%s", c.servers.Scheme(), datacenter), // use WithInsecure mode here because we handle the TLS wrapping in the // custom dialer based on logic around whether the server has TLS enabled. grpc.WithInsecure(), grpc.WithContextDialer(c.dialer), grpc.WithDisableRetry(), // TODO: previously this statsHandler was shared with the Handler. Is that necessary? - grpc.WithStatsHandler(&statsHandler{}), + grpc.WithStatsHandler(newStatsHandler()), // nolint:staticcheck // there is no other supported alternative to WithBalancerName grpc.WithBalancerName("pick_first")) if err != nil { diff --git a/agent/grpc/client_test.go b/agent/grpc/client_test.go new file mode 100644 index 0000000000..d8ea50dd8b --- /dev/null +++ b/agent/grpc/client_test.go @@ -0,0 +1,92 @@ +package grpc + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/hashicorp/consul/agent/grpc/internal/testservice" + "github.com/hashicorp/consul/agent/grpc/resolver" + "github.com/hashicorp/consul/agent/metadata" + "github.com/stretchr/testify/require" +) + +func TestNewDialer(t *testing.T) { + // TODO: conn is closed on errors + // TODO: with TLS enabled +} + +func TestClientConnPool_IntegrationWithGRPCResolver_Failover(t *testing.T) { + count := 4 + cfg := resolver.Config{Datacenter: "dc1", Scheme: newScheme(t.Name())} + res := resolver.NewServerResolverBuilder(cfg, fakeNodes{num: count}) + resolver.RegisterWithGRPC(res) + pool := NewClientConnPool(res, nil) + + for i := 0; i < count; i++ { + name := fmt.Sprintf("server-%d", i) + srv := newTestServer(t, name, "dc1") + res.AddServer(srv.Metadata()) + t.Cleanup(srv.shutdown) + } + + conn, err := pool.ClientConn("dc1") + require.NoError(t, err) + client := testservice.NewSimpleClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + first, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + res.RemoveServer(&metadata.Server{ID: first.ServerName, Datacenter: "dc1"}) + + resp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.NotEqual(t, resp.ServerName, first.ServerName) +} + +func newScheme(n string) string { + s := strings.Replace(n, "/", "", -1) + s = strings.Replace(s, "_", "", -1) + return strings.ToLower(s) +} + +type fakeNodes struct { + num int +} + +func (n fakeNodes) NumNodes() int { + return n.num +} + +func TestClientConnPool_IntegrationWithGRPCResolver_MultiDC(t *testing.T) { + dcs := []string{"dc1", "dc2", "dc3"} + + cfg := resolver.Config{Datacenter: "dc1", Scheme: newScheme(t.Name())} + res := resolver.NewServerResolverBuilder(cfg, fakeNodes{num: 1}) + resolver.RegisterWithGRPC(res) + pool := NewClientConnPool(res, nil) + + for _, dc := range dcs { + name := "server-0-" + dc + srv := newTestServer(t, name, dc) + res.AddServer(srv.Metadata()) + t.Cleanup(srv.shutdown) + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + for _, dc := range dcs { + conn, err := pool.ClientConn(dc) + require.NoError(t, err) + client := testservice.NewSimpleClient(conn) + + resp, err := client.Something(ctx, &testservice.Req{}) + require.NoError(t, err) + require.Equal(t, resp.Datacenter, dc) + } +} diff --git a/agent/grpc/handler.go b/agent/grpc/handler.go index c3af7f38c4..c43c1ba1e2 100644 --- a/agent/grpc/handler.go +++ b/agent/grpc/handler.go @@ -21,10 +21,8 @@ func NewHandler(addr net.Addr) *Handler { // TODO(streaming): add gRPC services to srv here - return &Handler{ - srv: srv, - listener: &chanListener{addr: addr, conns: make(chan net.Conn)}, - } + lis := &chanListener{addr: addr, conns: make(chan net.Conn)} + return &Handler{srv: srv, listener: lis} } // Handler implements a handler for the rpc server listener, and the @@ -57,15 +55,26 @@ type chanListener struct { // Accept blocks until a connection is received from Handle, and then returns the // connection. Accept implements part of the net.Listener interface for grpc.Server. func (l *chanListener) Accept() (net.Conn, error) { - return <-l.conns, nil + select { + case c, ok := <-l.conns: + if !ok { + return nil, &net.OpError{ + Op: "accept", + Net: l.addr.Network(), + Addr: l.addr, + Err: fmt.Errorf("listener closed"), + } + } + return c, nil + } } func (l *chanListener) Addr() net.Addr { return l.addr } -// Close does nothing. The connections are managed by the caller. func (l *chanListener) Close() error { + close(l.conns) return nil } diff --git a/agent/grpc/resolver/resolver.go b/agent/grpc/resolver/resolver.go index 82e814ae09..3bf66b74c1 100644 --- a/agent/grpc/resolver/resolver.go +++ b/agent/grpc/resolver/resolver.go @@ -40,7 +40,8 @@ type ServerResolverBuilder struct { // scheme used to query the server. Defaults to consul. Used to support // parallel testing because gRPC registers resolvers globally. scheme string - // servers is an index of Servers by Server.ID + // servers is an index of Servers by Server.ID. The map contains server IDs + // for all datacenters, so it assumes the ID is globally unique. servers map[string]*metadata.Server // resolvers is an index of connections to the serverResolver which manages // addresses of servers for that connection. diff --git a/agent/grpc/server_test.go b/agent/grpc/server_test.go index b7843ff011..b4cb9c7834 100644 --- a/agent/grpc/server_test.go +++ b/agent/grpc/server_test.go @@ -2,11 +2,66 @@ package grpc import ( "context" + "fmt" + "io" + "net" + "testing" "time" "github.com/hashicorp/consul/agent/grpc/internal/testservice" + "github.com/hashicorp/consul/agent/metadata" + "github.com/hashicorp/consul/agent/pool" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) +type testServer struct { + addr net.Addr + name string + dc string + shutdown func() +} + +func (s testServer) Metadata() *metadata.Server { + return &metadata.Server{ID: s.name, Datacenter: s.dc, Addr: s.addr} +} + +func newTestServer(t *testing.T, name string, dc string) testServer { + addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")} + handler := NewHandler(addr) + + testservice.RegisterSimpleServer(handler.srv, &simple{name: name, dc: dc}) + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + rpc := &fakeRPCListener{t: t, handler: handler} + + g := errgroup.Group{} + g.Go(func() error { + return rpc.listen(lis) + }) + g.Go(func() error { + return handler.Run() + }) + return testServer{ + addr: lis.Addr(), + name: name, + dc: dc, + shutdown: func() { + if err := lis.Close(); err != nil { + t.Logf("listener closed with error: %v", err) + } + if err := handler.Shutdown(); err != nil { + t.Logf("grpc server shutdown: %v", err) + } + if err := g.Wait(); err != nil { + t.Logf("grpc server error: %v", err) + } + }, + } +} + type simple struct { name string dc string @@ -26,3 +81,45 @@ func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) er func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) { return &testservice.Resp{ServerName: s.name, Datacenter: s.dc}, nil } + +// fakeRPCListener mimics agent/consul.Server.listen to handle the RPCType byte. +// In the future we should be able to refactor Server and extract this RPC +// handling logic so that we don't need to use a fake. +// For now, since this logic is in agent/consul, we can't easily use Server.listen +// so we fake it. +type fakeRPCListener struct { + t *testing.T + handler *Handler +} + +func (f *fakeRPCListener) listen(listener net.Listener) error { + for { + conn, err := listener.Accept() + if err != nil { + return err + } + + go f.handleConn(conn) + } +} + +func (f *fakeRPCListener) handleConn(conn net.Conn) { + buf := make([]byte, 1) + + if _, err := conn.Read(buf); err != nil { + if err != io.EOF { + fmt.Println("ERROR", err.Error()) + } + conn.Close() + return + } + typ := pool.RPCType(buf[0]) + + if typ == pool.RPCGRPC { + f.handler.Handle(conn) + return + } + + fmt.Println("ERROR: unexpected byte", typ) + conn.Close() +}