// consul/agent/grpc-internal/balancer/balancer_test.go

package balancer
import (
"context"
"fmt"
"math/rand"
"net"
"net/url"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
)
func TestBalancer(t *testing.T) {
t.Run("remains pinned to the same server", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder.Register()
conn := dial(t, target, balancerBuilder)
client := testservice.NewSimpleClient(conn)
var serverName string
for i := 0; i < 5; i++ {
rsp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
if i == 0 {
serverName = rsp.ServerName
} else {
require.Equal(t, serverName, rsp.ServerName)
}
}
var pinnedServer, otherServer *server
switch serverName {
case server1.name:
pinnedServer, otherServer = server1, server2
case server2.name:
pinnedServer, otherServer = server2, server1
}
require.Equal(t, 1,
pinnedServer.openConnections(),
"pinned server should have 1 connection",
)
require.Zero(t,
otherServer.openConnections(),
"other server should have no connections",
)
})
t.Run("switches server on-error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder.Register()
conn := dial(t, target, balancerBuilder)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now, and which we should switch to.
rsp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
var initialServer, otherServer *server
switch rsp.ServerName {
case server1.name:
initialServer, otherServer = server1, server2
case server2.name:
initialServer, otherServer = server2, server1
}
// Next request should fail (we don't have retries configured).
initialServer.err = status.Error(codes.ResourceExhausted, "rate limit exceeded")
_, err = client.Something(ctx, &testservice.Req{})
require.Error(t, err)
// Following request should succeed (against the other server).
rsp, err = client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, otherServer.name, rsp.ServerName)
retry.Run(t, func(r *retry.R) {
require.Zero(r,
initialServer.openConnections(),
"connection to previous server should have been torn down",
)
})
})
t.Run("rebalance changes the server", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder.Register()
// Provide a custom prioritizer that causes Rebalance to choose whichever
// server didn't get our first request.
var otherServer *server
balancerBuilder.shuffler = func(addrs []resolver.Address) {
sort.Slice(addrs, func(a, b int) bool {
return addrs[a].Addr == otherServer.addr
})
}
conn := dial(t, target, balancerBuilder)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now.
rsp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
var initialServer *server
switch rsp.ServerName {
case server1.name:
initialServer, otherServer = server1, server2
case server2.name:
initialServer, otherServer = server2, server1
}
// Trigger a rebalance.
targetURL, err := url.Parse(target)
require.NoError(t, err)
balancerBuilder.Rebalance(resolver.Target{URL: *targetURL})
// Following request should hit the other server.
rsp, err = client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, otherServer.name, rsp.ServerName)
retry.Run(t, func(r *retry.R) {
require.Zero(r,
initialServer.openConnections(),
"connection to previous server should have been torn down",
)
})
})
t.Run("resolver removes the server", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, res := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(t.Name(), testutil.Logger(t))
balancerBuilder.Register()
conn := dial(t, target, balancerBuilder)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now.
rsp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
var initialServer, otherServer *server
switch rsp.ServerName {
case server1.name:
initialServer, otherServer = server1, server2
case server2.name:
initialServer, otherServer = server2, server1
}
// Remove the server's address.
res.UpdateState(resolver.State{
Addresses: []resolver.Address{
{Addr: otherServer.addr},
},
})
// Following request should hit the other server.
rsp, err = client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, otherServer.name, rsp.ServerName)
retry.Run(t, func(r *retry.R) {
require.Zero(r,
initialServer.openConnections(),
"connection to previous server should have been torn down",
)
})
// Remove the other server too.
res.UpdateState(resolver.State{
Addresses: []resolver.Address{},
})
_, err = client.Something(ctx, &testservice.Req{})
require.Error(t, err)
require.Contains(t, err.Error(), "resolver produced no addresses")
retry.Run(t, func(r *retry.R) {
require.Zero(r,
otherServer.openConnections(),
"connection to other server should have been torn down",
)
})
})
}
// stubResolver registers a manual gRPC resolver under a freshly generated
// scheme, seeded with the addresses of the given servers, and unregisters it
// again when the test finishes.
//
// It returns the dial target ("<scheme>://") and the resolver, which the
// caller can use to push further state updates.
func stubResolver(t *testing.T, servers ...*server) (string, *manual.Resolver) {
	t.Helper()

	addrs := make([]resolver.Address, 0, len(servers))
	for _, srv := range servers {
		addrs = append(addrs, resolver.Address{Addr: srv.addr})
	}

	// The scheme combines a timestamp with a random number so that each call
	// registers under its own name.
	scheme := fmt.Sprintf("consul-%d-%d", time.Now().UnixNano(), rand.Int())

	builder := manual.NewBuilderWithScheme(scheme)
	builder.InitialState(resolver.State{Addresses: addrs})
	resolver.Register(builder)
	t.Cleanup(func() {
		resolver.UnregisterForTesting(scheme)
	})

	target := scheme + "://"
	return target, builder
}
// runServer starts a gRPC test server with the given name on an OS-assigned
// loopback port and returns its handle. The returned server is registered both
// as the Simple service implementation and as the gRPC stats handler, which is
// how it tracks its open connections. The server is stopped when the test
// finishes.
func runServer(t *testing.T, name string) *server {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	srv := &server{name: name, addr: listener.Addr().String()}

	grpcServer := grpc.NewServer(grpc.StatsHandler(srv))
	testservice.RegisterSimpleServer(grpcServer, srv)
	go grpcServer.Serve(listener)

	// Guard Stop with a sync.Once so shutdown may be invoked more than once
	// (it is also registered as a test cleanup).
	var stopOnce sync.Once
	srv.shutdown = func() {
		stopOnce.Do(grpcServer.Stop)
	}
	t.Cleanup(srv.shutdown)

	return srv
}
// server is the in-process gRPC test server started by runServer. It serves
// the testservice Simple service and is also installed as the gRPC stats
// handler so it can count its open connections.
type server struct {
// name is returned as ServerName in responses from Something.
name string
// addr is the listener address the server was started on (set by runServer).
addr string
// err, when non-nil, is returned by Something instead of a response.
err error
// c is the number of currently open connections; it is updated atomically by
// HandleConn and read by openConnections.
c int32
// shutdown stops the gRPC server; runServer wraps it in a sync.Once.
shutdown func()
}
// openConnections reports how many connections are currently open to the
// server, according to the counter maintained by HandleConn.
func (s *server) openConnections() int {
	open := atomic.LoadInt32(&s.c)
	return int(open)
}
// HandleRPC is a no-op; RPC-level stats are not recorded by this test server.
func (*server) HandleRPC(_ context.Context, _ stats.RPCStats) {
}
// TagConn returns the supplied context unchanged; no connection tags are
// attached.
func (*server) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
	return ctx
}
// TagRPC returns the supplied context unchanged; no RPC tags are attached.
func (*server) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
	return ctx
}
// HandleConn maintains the open-connection counter: it adds one when a
// connection begins, subtracts one when a connection ends, and ignores every
// other connection stats event.
func (s *server) HandleConn(_ context.Context, cs stats.ConnStats) {
	var delta int32
	switch cs.(type) {
	case *stats.ConnBegin:
		delta = 1
	case *stats.ConnEnd:
		delta = -1
	default:
		return
	}
	atomic.AddInt32(&s.c, delta)
}
// Flow is the streaming method of the Simple service. This test server sends
// nothing on the stream and returns nil immediately.
func (*server) Flow(_ *testservice.Req, _ testservice.Simple_FlowServer) error {
	return nil
}
// Something is the unary method of the Simple service. If an error has been
// injected via the err field it is returned as-is; otherwise the response
// carries this server's name so callers can tell which server answered.
//
// NOTE(review): err is assigned directly by tests while requests may be
// served concurrently; this looks unsynchronized — confirm under -race.
func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp, error) {
	if injected := s.err; injected != nil {
		return nil, injected
	}

	resp := &testservice.Resp{ServerName: s.name}
	return resp, nil
}
// dial opens an insecure gRPC client connection to target whose default
// service config selects the load balancer registered under builder.Name().
// The test fails immediately if dialing returns an error. The connection is
// closed when the test finishes; a close error is logged rather than failing
// the test.
func dial(t *testing.T, target string, builder *Builder) *grpc.ClientConn {
	t.Helper()

	conn, err := grpc.Dial(
		target,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(
			fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, builder.Name()),
		),
	)
	// Check the error before registering the cleanup: if dialing failed, the
	// cleanup would otherwise call Close on a nil connection and the resulting
	// panic would hide the real dial error.
	require.NoError(t, err)

	t.Cleanup(func() {
		if err := conn.Close(); err != nil {
			t.Logf("error closing connection: %v", err)
		}
	})

	return conn
}