consul/agent/grpc-internal/balancer/balancer_test.go
hashicorp-copywrite[bot] 5fb9df1640
[COMPLIANCE] License changes (#18443)
* Adding explicit MPL license for sub-package

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at <Blog URL>, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
2023-08-11 09:12:13 -04:00

340 lines
8.9 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package balancer
import (
"context"
"fmt"
"math/rand"
"net"
"net/url"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/consul/agent/grpc-middleware/testutil/testservice"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
)
func TestBalancer(t *testing.T) {
t.Run("remains pinned to the same server", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target)
client := testservice.NewSimpleClient(conn)
var serverName string
for i := 0; i < 5; i++ {
rsp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
if i == 0 {
serverName = rsp.ServerName
} else {
require.Equal(t, serverName, rsp.ServerName)
}
}
var pinnedServer, otherServer *server
switch serverName {
case server1.name:
pinnedServer, otherServer = server1, server2
case server2.name:
pinnedServer, otherServer = server2, server1
}
require.Equal(t, 1,
pinnedServer.openConnections(),
"pinned server should have 1 connection",
)
require.Zero(t,
otherServer.openConnections(),
"other server should have no connections",
)
})
t.Run("switches server on-error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now, and which we should switch to.
rsp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
var initialServer, otherServer *server
switch rsp.ServerName {
case server1.name:
initialServer, otherServer = server1, server2
case server2.name:
initialServer, otherServer = server2, server1
}
// Next request should fail (we don't have retries configured).
initialServer.err = status.Error(codes.ResourceExhausted, "rate limit exceeded")
_, err = client.Something(ctx, &testservice.Req{})
require.Error(t, err)
// Following request should succeed (against the other server).
rsp, err = client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, otherServer.name, rsp.ServerName)
retry.Run(t, func(r *retry.R) {
require.Zero(r,
initialServer.openConnections(),
"connection to previous server should have been torn down",
)
})
})
t.Run("rebalance changes the server", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, authority, _ := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
// Provide a custom prioritizer that causes Rebalance to choose whichever
// server didn't get our first request.
var otherServer *server
balancerBuilder.shuffler = func(addrs []resolver.Address) {
sort.Slice(addrs, func(a, b int) bool {
return addrs[a].Addr == otherServer.addr
})
}
conn := dial(t, target)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now.
rsp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
var initialServer *server
switch rsp.ServerName {
case server1.name:
initialServer, otherServer = server1, server2
case server2.name:
initialServer, otherServer = server2, server1
}
// Trigger a rebalance.
targetURL, err := url.Parse(target)
require.NoError(t, err)
balancerBuilder.Rebalance(resolver.Target{URL: *targetURL})
// Following request should hit the other server.
rsp, err = client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, otherServer.name, rsp.ServerName)
retry.Run(t, func(r *retry.R) {
require.Zero(r,
initialServer.openConnections(),
"connection to previous server should have been torn down",
)
})
})
t.Run("resolver removes the server", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
server1 := runServer(t, "server1")
server2 := runServer(t, "server2")
target, authority, res := stubResolver(t, server1, server2)
balancerBuilder := NewBuilder(authority, testutil.Logger(t))
balancerBuilder.Register()
t.Cleanup(balancerBuilder.Deregister)
conn := dial(t, target)
client := testservice.NewSimpleClient(conn)
// Figure out which server we're talking to now.
rsp, err := client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
var initialServer, otherServer *server
switch rsp.ServerName {
case server1.name:
initialServer, otherServer = server1, server2
case server2.name:
initialServer, otherServer = server2, server1
}
// Remove the server's address.
res.UpdateState(resolver.State{
Addresses: []resolver.Address{
{Addr: otherServer.addr},
},
})
// Following request should hit the other server.
rsp, err = client.Something(ctx, &testservice.Req{})
require.NoError(t, err)
require.Equal(t, otherServer.name, rsp.ServerName)
retry.Run(t, func(r *retry.R) {
require.Zero(r,
initialServer.openConnections(),
"connection to previous server should have been torn down",
)
})
// Remove the other server too.
res.UpdateState(resolver.State{
Addresses: []resolver.Address{},
})
_, err = client.Something(ctx, &testservice.Req{})
require.Error(t, err)
require.Contains(t, err.Error(), "resolver produced no addresses")
retry.Run(t, func(r *retry.R) {
require.Zero(r,
otherServer.openConnections(),
"connection to other server should have been torn down",
)
})
})
}
func stubResolver(t *testing.T, servers ...*server) (string, string, *manual.Resolver) {
t.Helper()
addresses := make([]resolver.Address, len(servers))
for idx, s := range servers {
addresses[idx] = resolver.Address{Addr: s.addr}
}
scheme := fmt.Sprintf("consul-%d-%d", time.Now().UnixNano(), rand.Int())
r := manual.NewBuilderWithScheme(scheme)
r.InitialState(resolver.State{Addresses: addresses})
resolver.Register(r)
t.Cleanup(func() { resolver.UnregisterForTesting(scheme) })
authority, err := uuid.GenerateUUID()
require.NoError(t, err)
return fmt.Sprintf("%s://%s", scheme, authority), authority, r
}
func runServer(t *testing.T, name string) *server {
t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s := &server{
name: name,
addr: lis.Addr().String(),
}
gs := grpc.NewServer(
grpc.StatsHandler(s),
)
testservice.RegisterSimpleServer(gs, s)
go gs.Serve(lis)
var once sync.Once
s.shutdown = func() { once.Do(gs.Stop) }
t.Cleanup(s.shutdown)
return s
}
type server struct {
name string
addr string
err error
c int32
shutdown func()
}
func (s *server) openConnections() int { return int(atomic.LoadInt32(&s.c)) }
func (*server) HandleRPC(context.Context, stats.RPCStats) {}
func (*server) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx }
func (*server) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx }
func (s *server) HandleConn(_ context.Context, cs stats.ConnStats) {
switch cs.(type) {
case *stats.ConnBegin:
atomic.AddInt32(&s.c, 1)
case *stats.ConnEnd:
atomic.AddInt32(&s.c, -1)
}
}
func (*server) Flow(*testservice.Req, testservice.Simple_FlowServer) error { return nil }
func (s *server) Something(context.Context, *testservice.Req) (*testservice.Resp, error) {
if s.err != nil {
return nil, s.err
}
return &testservice.Resp{ServerName: s.name}, nil
}
func dial(t *testing.T, target string) *grpc.ClientConn {
conn, err := grpc.Dial(
target,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(
fmt.Sprintf(`{"loadBalancingConfig":[{"%s":{}}]}`, BuilderName),
),
)
t.Cleanup(func() {
if err := conn.Close(); err != nil {
t.Logf("error closing connection: %v", err)
}
})
require.NoError(t, err)
return conn
}