// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package middleware import ( "errors" "net" "net/netip" "strings" "sync" "testing" "time" "github.com/armon/go-metrics" "github.com/hashicorp/consul/agent/consul/rate" "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) // obs holds all the things we want to assert on that we recorded correctly in our tests. type obs struct { key []string elapsed float32 labels []metrics.Label } // recorderStore acts as an in-mem mock storage for all the RequestRecorder.Record() RecorderFunc calls. type recorderStore struct { lock sync.Mutex store map[string]obs } func (rs *recorderStore) put(key []string, o obs) { rs.lock.Lock() defer rs.lock.Unlock() actualKey := strings.Join(append(key, o.labels[0].Value), "") rs.store[actualKey] = o } func (rs *recorderStore) get(key []string) obs { rs.lock.Lock() defer rs.lock.Unlock() actualKey := strings.Join(key, "") return rs.store[actualKey] } var store = recorderStore{store: make(map[string]obs)} var simpleRecorderFunc = func(key []string, val float32, labels []metrics.Label) { o := obs{key: key, elapsed: val, labels: labels} store.put(key, o) } type readRequest struct{} type writeRequest struct{} type readReqWithTD struct{} type writeReqWithTD struct{} func (rr readRequest) IsRead() bool { return true } func (wr writeRequest) IsRead() bool { return false } func (r readReqWithTD) IsRead() bool { return true } func (r readReqWithTD) RequestDatacenter() string { return "dc3" } func (r readReqWithTD) GetMinQueryIndex() uint64 { return 1 } func (r readReqWithTD) AllowStaleRead() bool { return false } func (w writeReqWithTD) IsRead() bool { return false } func (w writeReqWithTD) RequestDatacenter() string { return "dc2" } type testCase struct { name string // description is meant for human friendliness description string // requestName is encouraged to be unique across tests to // avoid lock contention requestName string requestI interface{} rpcType string errored bool isLeader func() bool dc string // the first element in expectedLabels should be the method name expectedLabels []metrics.Label } var testCases = []testCase{ { name: "simple ok", description: "This is a simple happy path test case. We check for pass through and normal request processing", requestName: "A.B", requestI: struct{}{}, rpcType: RPCTypeInternal, errored: false, dc: "dc1", expectedLabels: []metrics.Label{ {Name: "method", Value: "A.B"}, {Name: "errored", Value: "false"}, {Name: "request_type", Value: "unreported"}, {Name: "rpc_type", Value: RPCTypeInternal}, {Name: "leader", Value: "unreported"}, }, }, { name: "simple ok errored", description: "Checks that the errored value is populated right.", requestName: "A.C", requestI: struct{}{}, rpcType: "test", errored: true, dc: "dc1", expectedLabels: []metrics.Label{ {Name: "method", Value: "A.C"}, {Name: "errored", Value: "true"}, {Name: "request_type", Value: "unreported"}, {Name: "rpc_type", Value: "test"}, {Name: "leader", Value: "unreported"}, }, }, { name: "read request, rpc type internal", description: "Checks for read request interface parsing", requestName: "B.C", requestI: readRequest{}, rpcType: RPCTypeInternal, errored: false, dc: "dc1", expectedLabels: []metrics.Label{ {Name: "method", Value: "B.C"}, {Name: "errored", Value: "false"}, {Name: "request_type", Value: "read"}, {Name: "rpc_type", Value: RPCTypeInternal}, {Name: "leader", Value: "unreported"}, }, }, { name: "write request, rpc type net/rpc", description: "Checks for write request interface, different RPC type", requestName: "D.E", requestI: writeRequest{}, rpcType: RPCTypeNetRPC, errored: false, dc: "dc1", expectedLabels: []metrics.Label{ {Name: "method", Value: "D.E"}, {Name: "errored", Value: "false"}, {Name: "request_type", Value: "write"}, {Name: "rpc_type", Value: RPCTypeNetRPC}, {Name: "leader", Value: "unreported"}, }, }, { name: "read request with blocking stale and target dc", description: "Checks for locality, blocking status and target dc", requestName: "E.F", requestI: readReqWithTD{}, rpcType: RPCTypeNetRPC, errored: false, dc: "dc1", expectedLabels: []metrics.Label{ {Name: "method", Value: "E.F"}, {Name: "errored", Value: "false"}, {Name: "request_type", Value: "read"}, {Name: "rpc_type", Value: RPCTypeNetRPC}, {Name: "leader", Value: "unreported"}, {Name: "allow_stale", Value: "false"}, {Name: "blocking", Value: "true"}, {Name: "target_datacenter", Value: "dc3"}, {Name: "locality", Value: "forwarded"}, }, }, { name: "write request with TD, locality local", description: "Checks for write request with local forwarding and target dc", requestName: "F.G", requestI: writeReqWithTD{}, rpcType: RPCTypeNetRPC, errored: false, dc: "dc2", expectedLabels: []metrics.Label{ {Name: "method", Value: "F.G"}, {Name: "errored", Value: "false"}, {Name: "request_type", Value: "write"}, {Name: "rpc_type", Value: RPCTypeNetRPC}, {Name: "leader", Value: "unreported"}, {Name: "target_datacenter", Value: "dc2"}, {Name: "locality", Value: "local"}, }, }, { name: "is leader", description: "checks for is leader", requestName: "G.H", requestI: struct{}{}, rpcType: "test", errored: false, isLeader: func() bool { return true }, expectedLabels: []metrics.Label{ {Name: "method", Value: "G.H"}, {Name: "errored", Value: "false"}, {Name: "request_type", Value: "unreported"}, {Name: "rpc_type", Value: "test"}, {Name: "leader", Value: "true"}, }, }, { name: "is not leader", description: "checks for is not leader", requestName: "H.I", requestI: struct{}{}, rpcType: "test", errored: false, isLeader: func() bool { return false }, expectedLabels: []metrics.Label{ {Name: "method", Value: "H.I"}, {Name: "errored", Value: "false"}, {Name: "request_type", Value: "unreported"}, {Name: "rpc_type", Value: "test"}, {Name: "leader", Value: "false"}, }, }, } // TestRequestRecorder goes over all the parsing and reporting that RequestRecorder // is expected to perform. func TestRequestRecorder(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { r := RequestRecorder{ Logger: hclog.NewInterceptLogger(&hclog.LoggerOptions{}), RecorderFunc: simpleRecorderFunc, serverIsLeader: tc.isLeader, localDC: tc.dc, } start := time.Now() r.Record(tc.requestName, tc.rpcType, start, tc.requestI, tc.errored) key := append(metricRPCRequest, tc.expectedLabels[0].Value) o := store.get(key) require.Equal(t, o.key, metricRPCRequest) require.LessOrEqual(t, o.elapsed, float32(time.Now().Sub(start).Microseconds())/1000) require.Equal(t, o.labels, tc.expectedLabels) }) } } func TestGetNetRPCRateLimitingInterceptor(t *testing.T) { limiter := rate.NewMockRequestLimitsHandler(t) logger := hclog.NewNullLogger() rateLimitInterceptor := GetNetRPCRateLimitingInterceptor(limiter, NewPanicHandler(logger)) addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678")) t.Run("allow operation", func(t *testing.T) { limiter.On("Allow", mock.Anything). Return(nil). Once() err := rateLimitInterceptor("Status.Leader", addr) require.NoError(t, err) }) t.Run("allow returns error", func(t *testing.T) { limiter.On("Allow", mock.Anything). Return(errors.New("uh oh")). Once() err := rateLimitInterceptor("Status.Leader", addr) require.Error(t, err) require.Equal(t, "uh oh", err.Error()) }) t.Run("allow panics", func(t *testing.T) { limiter.On("Allow", mock.Anything). Panic("uh oh"). Once() err := rateLimitInterceptor("Status.Leader", addr) require.Error(t, err) require.Equal(t, "rpc: panic serving request", err.Error()) }) }