consul/agent/rpc/middleware/interceptors_test.go

316 lines
8.1 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
[COMPLIANCE] License changes (#18443) * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at <Blog URL>, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
2023-08-11 13:12:13 +00:00
// SPDX-License-Identifier: BUSL-1.1
package middleware
import (
"errors"
"net"
"net/netip"
"strings"
"sync"
"testing"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/consul/agent/consul/rate"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// obs is a single recorded observation: everything a test wants to assert on
// after the recorder function has been invoked.
type obs struct {
	// key is the metric key handed to the recorder function.
	key []string
	// elapsed is the measured value handed to the recorder function.
	elapsed float32
	// labels are the metric labels handed to the recorder function.
	labels []metrics.Label
}
// recorderStore is an in-memory stand-in for a metrics sink. It remembers
// every observation passed to the RecorderFunc so tests can read it back.
type recorderStore struct {
	// lock serializes access to store; put and get both take it.
	lock sync.Mutex
	// store maps a flattened lookup key to the observation saved under it.
	store map[string]obs
}
// put saves o under a lookup key formed by concatenating the elements of key
// followed by the value of o's first label. A later get must be called with
// the same elements (key plus that label value) to find the observation.
//
// If o carries no labels, it is stored under the concatenation of key alone
// instead of panicking on an out-of-range index.
func (rs *recorderStore) put(key []string, o obs) {
	// Build the lookup key in a fresh slice. Appending directly to key could
	// write into spare capacity of the caller's backing array and silently
	// modify a slice the caller still owns.
	parts := make([]string, 0, len(key)+1)
	parts = append(parts, key...)
	if len(o.labels) > 0 {
		parts = append(parts, o.labels[0].Value)
	}
	actualKey := strings.Join(parts, "")

	rs.lock.Lock()
	defer rs.lock.Unlock()
	rs.store[actualKey] = o
}
// get returns the observation stored under the concatenation of key's
// elements. When nothing was stored under that key the zero obs is returned.
func (rs *recorderStore) get(key []string) obs {
	lookup := strings.Join(key, "")

	rs.lock.Lock()
	defer rs.lock.Unlock()

	found := rs.store[lookup]
	return found
}
var (
	// store is the shared sink that every test in this file records into.
	store = recorderStore{store: make(map[string]obs)}

	// simpleRecorderFunc captures its arguments as an obs and saves it in the
	// shared store.
	simpleRecorderFunc = func(key []string, val float32, labels []metrics.Label) {
		store.put(key, obs{
			key:     key,
			elapsed: val,
			labels:  labels,
		})
	}
)
// Stub request types used as the request argument in the recorder tests.
type (
	// readRequest only reports that it is a read.
	readRequest struct{}
	// writeRequest only reports that it is not a read.
	writeRequest struct{}
	// readReqWithTD is a read that also exposes a datacenter, a minimum
	// query index and a stale-read flag.
	readReqWithTD struct{}
	// writeReqWithTD is a write that also exposes a datacenter.
	writeReqWithTD struct{}
)

// IsRead always reports true.
func (readRequest) IsRead() bool { return true }

// IsRead always reports false.
func (writeRequest) IsRead() bool { return false }

// IsRead always reports true.
func (readReqWithTD) IsRead() bool { return true }

// RequestDatacenter always returns "dc3".
func (readReqWithTD) RequestDatacenter() string { return "dc3" }

// GetMinQueryIndex always returns 1.
func (readReqWithTD) GetMinQueryIndex() uint64 { return 1 }

// AllowStaleRead always reports false.
func (readReqWithTD) AllowStaleRead() bool { return false }

// IsRead always reports false.
func (writeReqWithTD) IsRead() bool { return false }

// RequestDatacenter always returns "dc2".
func (writeReqWithTD) RequestDatacenter() string { return "dc2" }
// testCase is one row of the TestRequestRecorder table.
type testCase struct {
	// name is used as the subtest name.
	name string
	// description is meant for human friendliness
	description string
	// requestName is encouraged to be unique across tests to
	// avoid lock contention. It is passed to Record as the request name.
	requestName string
	// requestI is the request value passed to Record.
	requestI interface{}
	// rpcType is the RPC type string passed to Record.
	rpcType string
	// errored is the errored flag passed to Record.
	errored bool
	// isLeader, when set, becomes the recorder's serverIsLeader callback.
	isLeader func() bool
	// dc becomes the recorder's localDC.
	dc string
	// the first element in expectedLabels should be the method name;
	// its value is also used to look the observation up in the store.
	expectedLabels []metrics.Label
}
// testCases is the table driven through TestRequestRecorder. Each entry pairs
// the inputs handed to Record with the exact labels expected to be recorded.
var testCases = []testCase{
	{
		name:        "simple ok",
		description: "This is a simple happy path test case. We check for pass through and normal request processing",
		requestName: "A.B",
		requestI:    struct{}{},
		rpcType:     RPCTypeInternal,
		errored:     false,
		dc:          "dc1",
		expectedLabels: []metrics.Label{
			{Name: "method", Value: "A.B"},
			{Name: "errored", Value: "false"},
			{Name: "request_type", Value: "unreported"},
			{Name: "rpc_type", Value: RPCTypeInternal},
			{Name: "leader", Value: "unreported"},
		},
	},
	{
		name:        "simple ok errored",
		description: "Checks that the errored value is populated right.",
		requestName: "A.C",
		requestI:    struct{}{},
		rpcType:     "test",
		errored:     true,
		dc:          "dc1",
		expectedLabels: []metrics.Label{
			{Name: "method", Value: "A.C"},
			{Name: "errored", Value: "true"},
			{Name: "request_type", Value: "unreported"},
			{Name: "rpc_type", Value: "test"},
			{Name: "leader", Value: "unreported"},
		},
	},
	{
		name:        "read request, rpc type internal",
		description: "Checks for read request interface parsing",
		requestName: "B.C",
		requestI:    readRequest{},
		rpcType:     RPCTypeInternal,
		errored:     false,
		dc:          "dc1",
		expectedLabels: []metrics.Label{
			{Name: "method", Value: "B.C"},
			{Name: "errored", Value: "false"},
			{Name: "request_type", Value: "read"},
			{Name: "rpc_type", Value: RPCTypeInternal},
			{Name: "leader", Value: "unreported"},
		},
	},
	{
		name:        "write request, rpc type net/rpc",
		description: "Checks for write request interface, different RPC type",
		requestName: "D.E",
		requestI:    writeRequest{},
		rpcType:     RPCTypeNetRPC,
		errored:     false,
		dc:          "dc1",
		expectedLabels: []metrics.Label{
			{Name: "method", Value: "D.E"},
			{Name: "errored", Value: "false"},
			{Name: "request_type", Value: "write"},
			{Name: "rpc_type", Value: RPCTypeNetRPC},
			{Name: "leader", Value: "unreported"},
		},
	},
	{
		name:        "read request with blocking stale and target dc",
		description: "Checks for locality, blocking status and target dc",
		requestName: "E.F",
		requestI:    readReqWithTD{},
		rpcType:     RPCTypeNetRPC,
		errored:     false,
		dc:          "dc1",
		expectedLabels: []metrics.Label{
			{Name: "method", Value: "E.F"},
			{Name: "errored", Value: "false"},
			{Name: "request_type", Value: "read"},
			{Name: "rpc_type", Value: RPCTypeNetRPC},
			{Name: "leader", Value: "unreported"},
			{Name: "allow_stale", Value: "false"},
			{Name: "blocking", Value: "true"},
			{Name: "target_datacenter", Value: "dc3"},
			{Name: "locality", Value: "forwarded"},
		},
	},
	{
		name:        "write request with TD, locality local",
		description: "Checks for write request with local forwarding and target dc",
		requestName: "F.G",
		requestI:    writeReqWithTD{},
		rpcType:     RPCTypeNetRPC,
		errored:     false,
		dc:          "dc2",
		expectedLabels: []metrics.Label{
			{Name: "method", Value: "F.G"},
			{Name: "errored", Value: "false"},
			{Name: "request_type", Value: "write"},
			{Name: "rpc_type", Value: RPCTypeNetRPC},
			{Name: "leader", Value: "unreported"},
			{Name: "target_datacenter", Value: "dc2"},
			{Name: "locality", Value: "local"},
		},
	},
	{
		name:        "is leader",
		description: "checks for is leader",
		requestName: "G.H",
		requestI:    struct{}{},
		rpcType:     "test",
		errored:     false,
		isLeader:    func() bool { return true },
		expectedLabels: []metrics.Label{
			{Name: "method", Value: "G.H"},
			{Name: "errored", Value: "false"},
			{Name: "request_type", Value: "unreported"},
			{Name: "rpc_type", Value: "test"},
			{Name: "leader", Value: "true"},
		},
	},
	{
		name:        "is not leader",
		description: "checks for is not leader",
		requestName: "H.I",
		requestI:    struct{}{},
		rpcType:     "test",
		errored:     false,
		isLeader:    func() bool { return false },
		expectedLabels: []metrics.Label{
			{Name: "method", Value: "H.I"},
			{Name: "errored", Value: "false"},
			{Name: "request_type", Value: "unreported"},
			{Name: "rpc_type", Value: "test"},
			{Name: "leader", Value: "false"},
		},
	},
}
// TestRequestRecorder goes over all the parsing and reporting that RequestRecorder
// is expected to perform.
//
// For every entry in testCases it builds a RequestRecorder wired to
// simpleRecorderFunc, calls Record once, reads the resulting observation back
// from the shared store and compares its key, elapsed value and labels with
// the expectations of the test case.
func TestRequestRecorder(t *testing.T) {
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			r := RequestRecorder{
				Logger:         hclog.NewInterceptLogger(&hclog.LoggerOptions{}),
				RecorderFunc:   simpleRecorderFunc,
				serverIsLeader: tc.isLeader,
				localDC:        tc.dc,
			}

			start := time.Now()
			r.Record(tc.requestName, tc.rpcType, start, tc.requestI, tc.errored)

			// Cap the capacity before appending so the lookup key is always
			// built in a new backing array and the package-level
			// metricRPCRequest slice can never be written to.
			base := metricRPCRequest[:len(metricRPCRequest):len(metricRPCRequest)]
			key := append(base, tc.expectedLabels[0].Value)

			o := store.get(key)

			// testify takes (expected, actual); keeping that order makes the
			// failure output label the two values correctly.
			require.Equal(t, metricRPCRequest, o.key)
			// The recorded value cannot exceed the time elapsed since start,
			// expressed here in milliseconds.
			require.LessOrEqual(t, o.elapsed, float32(time.Since(start).Microseconds())/1000)
			require.Equal(t, tc.expectedLabels, o.labels)
		})
	}
}
// TestGetNetRPCRateLimitingInterceptor drives the interceptor returned by
// GetNetRPCRateLimitingInterceptor against a mocked limiter and checks what
// it returns when Allow succeeds, returns an error, or panics.
func TestGetNetRPCRateLimitingInterceptor(t *testing.T) {
	const rpcMethod = "Status.Leader"

	var (
		mockLimiter = rate.NewMockRequestLimitsHandler(t)
		nullLogger  = hclog.NewNullLogger()
		interceptor = GetNetRPCRateLimitingInterceptor(mockLimiter, NewPanicHandler(nullLogger))
		sourceAddr  = net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678"))
	)

	t.Run("allow operation", func(t *testing.T) {
		// Allow returns nil, so the interceptor must not report an error.
		mockLimiter.On("Allow", mock.Anything).Return(nil).Once()

		got := interceptor(rpcMethod, sourceAddr)

		require.NoError(t, got)
	})

	t.Run("allow returns error", func(t *testing.T) {
		// The error message from Allow is expected to come back unchanged.
		mockLimiter.On("Allow", mock.Anything).Return(errors.New("uh oh")).Once()

		got := interceptor(rpcMethod, sourceAddr)

		require.Error(t, got)
		require.Equal(t, "uh oh", got.Error())
	})

	t.Run("allow panics", func(t *testing.T) {
		// A panic inside Allow is expected to surface as an error with a
		// fixed message rather than propagate to the caller.
		mockLimiter.On("Allow", mock.Anything).Panic("uh oh").Once()

		got := interceptor(rpcMethod, sourceAddr)

		require.Error(t, got)
		require.Equal(t, "rpc: panic serving request", got.Error())
	})
}