// Mirror of https://github.com/status-im/consul.git (Go, 316 lines, 8.1 KiB).
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package middleware
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"net/netip"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/armon/go-metrics"
|
|
"github.com/hashicorp/consul/agent/consul/rate"
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// obs holds all the things we want to assert on that we recorded correctly in our tests.
|
|
type obs struct {
|
|
key []string
|
|
elapsed float32
|
|
labels []metrics.Label
|
|
}
|
|
|
|
// recorderStore acts as an in-mem mock storage for all the RequestRecorder.Record() RecorderFunc calls.
|
|
type recorderStore struct {
|
|
lock sync.Mutex
|
|
store map[string]obs
|
|
}
|
|
|
|
func (rs *recorderStore) put(key []string, o obs) {
|
|
rs.lock.Lock()
|
|
defer rs.lock.Unlock()
|
|
|
|
actualKey := strings.Join(append(key, o.labels[0].Value), "")
|
|
rs.store[actualKey] = o
|
|
}
|
|
|
|
func (rs *recorderStore) get(key []string) obs {
|
|
rs.lock.Lock()
|
|
defer rs.lock.Unlock()
|
|
|
|
actualKey := strings.Join(key, "")
|
|
return rs.store[actualKey]
|
|
}
|
|
|
|
var store = recorderStore{store: make(map[string]obs)}
|
|
var simpleRecorderFunc = func(key []string, val float32, labels []metrics.Label) {
|
|
o := obs{key: key, elapsed: val, labels: labels}
|
|
store.put(key, o)
|
|
}
|
|
|
|
// Request fixtures used as the request argument in the test cases. Each type
// exposes a different set of methods with fixed return values.
type (
	// readRequest only reports that it is a read.
	readRequest struct{}
	// writeRequest only reports that it is not a read.
	writeRequest struct{}
	// readReqWithTD is a read that also reports a datacenter, a minimum
	// query index and a stale-read setting.
	readReqWithTD struct{}
	// writeReqWithTD is a non-read that also reports a datacenter.
	writeReqWithTD struct{}
)

// IsRead always returns true.
func (readRequest) IsRead() bool { return true }

// IsRead always returns false.
func (writeRequest) IsRead() bool { return false }

// IsRead always returns true.
func (readReqWithTD) IsRead() bool { return true }

// RequestDatacenter always returns "dc3".
func (readReqWithTD) RequestDatacenter() string { return "dc3" }

// GetMinQueryIndex always returns 1.
func (readReqWithTD) GetMinQueryIndex() uint64 { return 1 }

// AllowStaleRead always returns false.
func (readReqWithTD) AllowStaleRead() bool { return false }

// IsRead always returns false.
func (writeReqWithTD) IsRead() bool { return false }

// RequestDatacenter always returns "dc2".
func (writeReqWithTD) RequestDatacenter() string { return "dc2" }
|
|
|
|
type testCase struct {
|
|
name string
|
|
// description is meant for human friendliness
|
|
description string
|
|
// requestName is encouraged to be unique across tests to
|
|
// avoid lock contention
|
|
requestName string
|
|
requestI interface{}
|
|
rpcType string
|
|
errored bool
|
|
isLeader func() bool
|
|
dc string
|
|
// the first element in expectedLabels should be the method name
|
|
expectedLabels []metrics.Label
|
|
}
|
|
|
|
var testCases = []testCase{
|
|
{
|
|
name: "simple ok",
|
|
description: "This is a simple happy path test case. We check for pass through and normal request processing",
|
|
requestName: "A.B",
|
|
requestI: struct{}{},
|
|
rpcType: RPCTypeInternal,
|
|
errored: false,
|
|
dc: "dc1",
|
|
expectedLabels: []metrics.Label{
|
|
{Name: "method", Value: "A.B"},
|
|
{Name: "errored", Value: "false"},
|
|
{Name: "request_type", Value: "unreported"},
|
|
{Name: "rpc_type", Value: RPCTypeInternal},
|
|
{Name: "leader", Value: "unreported"},
|
|
},
|
|
},
|
|
{
|
|
name: "simple ok errored",
|
|
description: "Checks that the errored value is populated right.",
|
|
requestName: "A.C",
|
|
requestI: struct{}{},
|
|
rpcType: "test",
|
|
errored: true,
|
|
dc: "dc1",
|
|
expectedLabels: []metrics.Label{
|
|
{Name: "method", Value: "A.C"},
|
|
{Name: "errored", Value: "true"},
|
|
{Name: "request_type", Value: "unreported"},
|
|
{Name: "rpc_type", Value: "test"},
|
|
{Name: "leader", Value: "unreported"},
|
|
},
|
|
},
|
|
{
|
|
name: "read request, rpc type internal",
|
|
description: "Checks for read request interface parsing",
|
|
requestName: "B.C",
|
|
requestI: readRequest{},
|
|
rpcType: RPCTypeInternal,
|
|
errored: false,
|
|
dc: "dc1",
|
|
expectedLabels: []metrics.Label{
|
|
{Name: "method", Value: "B.C"},
|
|
{Name: "errored", Value: "false"},
|
|
{Name: "request_type", Value: "read"},
|
|
{Name: "rpc_type", Value: RPCTypeInternal},
|
|
{Name: "leader", Value: "unreported"},
|
|
},
|
|
},
|
|
{
|
|
name: "write request, rpc type net/rpc",
|
|
description: "Checks for write request interface, different RPC type",
|
|
requestName: "D.E",
|
|
requestI: writeRequest{},
|
|
rpcType: RPCTypeNetRPC,
|
|
errored: false,
|
|
dc: "dc1",
|
|
expectedLabels: []metrics.Label{
|
|
{Name: "method", Value: "D.E"},
|
|
{Name: "errored", Value: "false"},
|
|
{Name: "request_type", Value: "write"},
|
|
{Name: "rpc_type", Value: RPCTypeNetRPC},
|
|
{Name: "leader", Value: "unreported"},
|
|
},
|
|
},
|
|
{
|
|
name: "read request with blocking stale and target dc",
|
|
description: "Checks for locality, blocking status and target dc",
|
|
requestName: "E.F",
|
|
requestI: readReqWithTD{},
|
|
rpcType: RPCTypeNetRPC,
|
|
errored: false,
|
|
dc: "dc1",
|
|
expectedLabels: []metrics.Label{
|
|
{Name: "method", Value: "E.F"},
|
|
{Name: "errored", Value: "false"},
|
|
{Name: "request_type", Value: "read"},
|
|
{Name: "rpc_type", Value: RPCTypeNetRPC},
|
|
{Name: "leader", Value: "unreported"},
|
|
{Name: "allow_stale", Value: "false"},
|
|
{Name: "blocking", Value: "true"},
|
|
{Name: "target_datacenter", Value: "dc3"},
|
|
{Name: "locality", Value: "forwarded"},
|
|
},
|
|
},
|
|
{
|
|
name: "write request with TD, locality local",
|
|
description: "Checks for write request with local forwarding and target dc",
|
|
requestName: "F.G",
|
|
requestI: writeReqWithTD{},
|
|
rpcType: RPCTypeNetRPC,
|
|
errored: false,
|
|
dc: "dc2",
|
|
expectedLabels: []metrics.Label{
|
|
{Name: "method", Value: "F.G"},
|
|
{Name: "errored", Value: "false"},
|
|
{Name: "request_type", Value: "write"},
|
|
{Name: "rpc_type", Value: RPCTypeNetRPC},
|
|
{Name: "leader", Value: "unreported"},
|
|
{Name: "target_datacenter", Value: "dc2"},
|
|
{Name: "locality", Value: "local"},
|
|
},
|
|
},
|
|
{
|
|
name: "is leader",
|
|
description: "checks for is leader",
|
|
requestName: "G.H",
|
|
requestI: struct{}{},
|
|
rpcType: "test",
|
|
errored: false,
|
|
isLeader: func() bool {
|
|
return true
|
|
},
|
|
expectedLabels: []metrics.Label{
|
|
{Name: "method", Value: "G.H"},
|
|
{Name: "errored", Value: "false"},
|
|
{Name: "request_type", Value: "unreported"},
|
|
{Name: "rpc_type", Value: "test"},
|
|
{Name: "leader", Value: "true"},
|
|
},
|
|
},
|
|
{
|
|
name: "is not leader",
|
|
description: "checks for is not leader",
|
|
requestName: "H.I",
|
|
requestI: struct{}{},
|
|
rpcType: "test",
|
|
errored: false,
|
|
isLeader: func() bool {
|
|
return false
|
|
},
|
|
expectedLabels: []metrics.Label{
|
|
{Name: "method", Value: "H.I"},
|
|
{Name: "errored", Value: "false"},
|
|
{Name: "request_type", Value: "unreported"},
|
|
{Name: "rpc_type", Value: "test"},
|
|
{Name: "leader", Value: "false"},
|
|
},
|
|
},
|
|
}
|
|
|
|
// TestRequestRecorder goes over all the parsing and reporting that RequestRecorder
|
|
// is expected to perform.
|
|
func TestRequestRecorder(t *testing.T) {
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
|
|
r := RequestRecorder{
|
|
Logger: hclog.NewInterceptLogger(&hclog.LoggerOptions{}),
|
|
RecorderFunc: simpleRecorderFunc,
|
|
serverIsLeader: tc.isLeader,
|
|
localDC: tc.dc,
|
|
}
|
|
|
|
start := time.Now()
|
|
r.Record(tc.requestName, tc.rpcType, start, tc.requestI, tc.errored)
|
|
|
|
key := append(metricRPCRequest, tc.expectedLabels[0].Value)
|
|
o := store.get(key)
|
|
|
|
require.Equal(t, o.key, metricRPCRequest)
|
|
require.LessOrEqual(t, o.elapsed, float32(time.Now().Sub(start).Microseconds())/1000)
|
|
require.Equal(t, o.labels, tc.expectedLabels)
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetNetRPCRateLimitingInterceptor(t *testing.T) {
|
|
limiter := rate.NewMockRequestLimitsHandler(t)
|
|
|
|
logger := hclog.NewNullLogger()
|
|
rateLimitInterceptor := GetNetRPCRateLimitingInterceptor(limiter, NewPanicHandler(logger))
|
|
|
|
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678"))
|
|
|
|
t.Run("allow operation", func(t *testing.T) {
|
|
limiter.On("Allow", mock.Anything).
|
|
Return(nil).
|
|
Once()
|
|
|
|
err := rateLimitInterceptor("Status.Leader", addr)
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
t.Run("allow returns error", func(t *testing.T) {
|
|
limiter.On("Allow", mock.Anything).
|
|
Return(errors.New("uh oh")).
|
|
Once()
|
|
|
|
err := rateLimitInterceptor("Status.Leader", addr)
|
|
require.Error(t, err)
|
|
require.Equal(t, "uh oh", err.Error())
|
|
})
|
|
|
|
t.Run("allow panics", func(t *testing.T) {
|
|
limiter.On("Allow", mock.Anything).
|
|
Panic("uh oh").
|
|
Once()
|
|
|
|
err := rateLimitInterceptor("Status.Leader", addr)
|
|
|
|
require.Error(t, err)
|
|
require.Equal(t, "rpc: panic serving request", err.Error())
|
|
})
|
|
}
|