consul/agent/grpc/server_test.go

177 lines
3.9 KiB
Go

package grpc
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/resolver"
"github.com/hashicorp/consul/agent/grpc/internal/testservice"
"github.com/hashicorp/consul/agent/metadata"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/tlsutil"
)
type testServer struct {
addr net.Addr
name string
dc string
shutdown func()
rpc *fakeRPCListener
}
func (s testServer) Metadata() *metadata.Server {
return &metadata.Server{
ID: s.name,
Datacenter: s.dc,
Addr: s.addr,
UseTLS: s.rpc.tlsConf != nil,
}
}
func newTestServer(t *testing.T, name string, dc string) testServer {
addr := &net.IPAddr{IP: net.ParseIP("127.0.0.1")}
handler := NewHandler(addr, func(server *grpc.Server) {
testservice.RegisterSimpleServer(server, &simple{name: name, dc: dc})
})
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
rpc := &fakeRPCListener{t: t, handler: handler}
g := errgroup.Group{}
g.Go(func() error {
if err := rpc.listen(lis); err != nil {
return fmt.Errorf("fake rpc listen error: %w", err)
}
return nil
})
g.Go(func() error {
if err := handler.Run(); err != nil {
return fmt.Errorf("grpc server error: %w", err)
}
return nil
})
return testServer{
addr: lis.Addr(),
name: name,
dc: dc,
rpc: rpc,
shutdown: func() {
rpc.shutdown = true
if err := lis.Close(); err != nil {
t.Logf("listener closed with error: %v", err)
}
if err := handler.Shutdown(); err != nil {
t.Logf("grpc server shutdown: %v", err)
}
if err := g.Wait(); err != nil {
t.Log(err)
}
},
}
}
type simple struct {
name string
dc string
}
func (s *simple) Flow(_ *testservice.Req, flow testservice.Simple_FlowServer) error {
for flow.Context().Err() == nil {
resp := &testservice.Resp{ServerName: "one", Datacenter: s.dc}
if err := flow.Send(resp); err != nil {
return err
}
time.Sleep(time.Millisecond)
}
return nil
}
func (s *simple) Something(_ context.Context, _ *testservice.Req) (*testservice.Resp, error) {
return &testservice.Resp{ServerName: s.name, Datacenter: s.dc}, nil
}
// fakeRPCListener mimics agent/consul.Server.listen to handle the RPCType byte.
// In the future we should be able to refactor Server and extract this RPC
// handling logic so that we don't need to use a fake.
// For now, since this logic is in agent/consul, we can't easily use Server.listen
// so we fake it.
type fakeRPCListener struct {
t *testing.T
handler *Handler
shutdown bool
tlsConf *tlsutil.Configurator
tlsConnEstablished int32
}
func (f *fakeRPCListener) listen(listener net.Listener) error {
for {
conn, err := listener.Accept()
if err != nil {
if f.shutdown {
return nil
}
return err
}
go f.handleConn(conn)
}
}
func (f *fakeRPCListener) handleConn(conn net.Conn) {
buf := make([]byte, 1)
if _, err := conn.Read(buf); err != nil {
if err != io.EOF {
fmt.Println("ERROR", err.Error())
}
conn.Close()
return
}
typ := pool.RPCType(buf[0])
switch typ {
case pool.RPCGRPC:
f.handler.Handle(conn)
return
case pool.RPCTLS:
// occasionally we see a test client connecting to an rpc listener that
// was created as part of another test, despite none of the tests running
// in parallel.
// Maybe some strange grpc behaviour? I'm not sure.
if f.tlsConf == nil {
fmt.Println("ERROR: tls is not configured")
conn.Close()
return
}
atomic.AddInt32(&f.tlsConnEstablished, 1)
conn = tls.Server(conn, f.tlsConf.IncomingRPCConfig())
f.handleConn(conn)
default:
fmt.Println("ERROR: unexpected byte", typ)
conn.Close()
}
}
func registerWithGRPC(t *testing.T, b resolver.Builder) {
resolver.Register(b)
t.Cleanup(func() {
resolver.UnregisterForTesting(b.Scheme())
})
}