consul/agent/rpc/middleware/interceptors.go

209 lines
5.7 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
[COMPLIANCE] License changes (#18443) * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Adding explicit MPL license for sub-package This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository. * Updating the license from MPL to Business Source License Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at <Blog URL>, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl. * add missing license headers * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 * Update copyright file headers to BUSL-1.1 --------- Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
2023-08-11 09:12:13 -04:00
// SPDX-License-Identifier: BUSL-1.1
package middleware
import (
"net"
"reflect"
"strconv"
"strings"
"time"
"github.com/armon/go-metrics"
"github.com/armon/go-metrics/prometheus"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/consul-net-rpc/net/rpc"
rpcRate "github.com/hashicorp/consul/agent/consul/rate"
)
const (
	// RPCTypeInternal is the rpc_type label for "RPC" requests that originate
	// from internal operations running on the cluster leader. Strictly these
	// are raft.Apply operations rather than RPCs, but they affect blocking
	// queries and streaming subscriptions in the same way, so they are tracked
	// by the same metric and logs.
	//
	// What is really being measured is a "cluster operation"; the historical
	// term for that is "RPC", so the name is kept.
	RPCTypeInternal = "internal"

	// RPCTypeNetRPC is the rpc_type label for requests served over net/rpc;
	// GetNetRPCInterceptor passes it to Record.
	RPCTypeNetRPC = "net/rpc"
)

// metricRPCRequest is the metric key Record reports call durations under, and
// requestLogName is that key joined with underscores ("rpc_server_call"), used
// as the message of the per-request trace log line.
var (
	metricRPCRequest = []string{"rpc", "server", "call"}
	requestLogName   = strings.Join(metricRPCRequest, "_")
)
// OneTwelveRPCSummary declares the Prometheus summary for the metric keyed by
// metricRPCRequest (the samples emitted by RequestRecorder.Record), together
// with its help text.
var OneTwelveRPCSummary = []prometheus.SummaryDefinition{{
	Name: metricRPCRequest,
	Help: "Measures the time an RPC service call takes to make in milliseconds. Labels mark which RPC method was called and metadata about the call.",
}}
// RequestRecorder measures served requests: for each one it emits a latency
// sample through RecorderFunc and writes the same labels as a trace-level log
// line through Logger. Construct it with NewRequestRecorder.
type RequestRecorder struct {
	// Logger receives one Trace line per recorded request.
	Logger hclog.Logger

	// RecorderFunc emits the metric sample: metric key, elapsed time in
	// milliseconds, and labels. NewRequestRecorder sets it to
	// metrics.AddSampleWithLabels.
	RecorderFunc func(key []string, val float32, labels []metrics.Label)

	// serverIsLeader reports whether this server is currently the leader.
	// It may be nil, in which case the "leader" label is "unreported".
	serverIsLeader func() bool

	// localDC is compared with a request's target datacenter to pick the
	// "locality" label ("local" when equal, "forwarded" otherwise).
	localDC string
}
// NewRequestRecorder returns a RequestRecorder that logs through logger and
// reports samples with metrics.AddSampleWithLabels. isLeader is consulted on
// every request to fill the "leader" label (nil yields "unreported"), and
// localDC decides whether a request is labeled "local" or "forwarded".
func NewRequestRecorder(logger hclog.Logger, isLeader func() bool, localDC string) *RequestRecorder {
	rec := new(RequestRecorder)
	rec.Logger = logger
	rec.RecorderFunc = metrics.AddSampleWithLabels
	rec.serverIsLeader = isLeader
	rec.localDC = localDC
	return rec
}
// Record reports one completed request. It emits a sample of the
// metricRPCRequest metric whose value is the time elapsed since start, in
// milliseconds, and logs the same labels at trace level.
//
// Labels always include the method name, whether the response errored, the
// request's read/write type, the rpc type and this server's leadership. Read
// queries and datacenter-targeted requests add more labels (see
// addOptionalLabels).
func (r *RequestRecorder) Record(requestName string, rpcType string, start time.Time, request interface{}, respErrored bool) {
	// Take the duration first so label computation is not counted in it.
	// Whole microseconds are scaled to fractional milliseconds. float32 can
	// represent any int64 magnitude (math.MaxInt64 < math.MaxFloat32), so the
	// conversion cannot overflow. It is only exact to the microsecond below
	// 2^24 µs (about 16.8 s), which is fine for a latency metric.
	elapsedMs := float32(time.Since(start).Microseconds()) / 1000

	base := []metrics.Label{
		{Name: "method", Value: requestName},
		{Name: "errored", Value: strconv.FormatBool(respErrored)},
		{Name: "request_type", Value: requestType(request)},
		{Name: "rpc_type", Value: rpcType},
		{Name: "leader", Value: r.getServerLeadership()},
	}
	labels := r.addOptionalLabels(request, base)

	r.RecorderFunc(metricRPCRequest, elapsedMs, labels)
	r.Logger.Trace(requestLogName, flattenLabels(labels)...)
}
// flattenLabels converts labels into an alternating name/value argument list
// (name1, value1, name2, value2, ...), the key/value form Record passes to
// hclog's Trace.
func flattenLabels(labels []metrics.Label) []interface{} {
	// Record calls this for every served request, so allocate the result once
	// at its final size instead of growing it append by append.
	labelArr := make([]interface{}, 0, 2*len(labels))
	for _, label := range labels {
		labelArr = append(labelArr, label.Name, label.Value)
	}
	return labelArr
}
// addOptionalLabels appends labels that only apply to some requests and
// returns the extended slice.
//
// Requests implementing readQuery get "allow_stale" and "blocking" (true when
// a minimum query index is set). Requests implementing targetDC get
// "target_datacenter" and "locality": "local" when the target matches this
// server's datacenter, "forwarded" otherwise.
func (r *RequestRecorder) addOptionalLabels(request interface{}, labels []metrics.Label) []metrics.Label {
	if q, isRead := request.(readQuery); isRead {
		stale := strconv.FormatBool(q.AllowStaleRead())
		blocking := strconv.FormatBool(q.GetMinQueryIndex() > 0)
		labels = append(labels,
			metrics.Label{Name: "allow_stale", Value: stale},
			metrics.Label{Name: "blocking", Value: blocking},
		)
	}

	dcReq, hasDC := request.(targetDC)
	if !hasDC {
		return labels
	}

	target := dcReq.RequestDatacenter()
	locality := "forwarded"
	if target == r.localDC {
		locality = "local"
	}
	return append(labels,
		metrics.Label{Name: "target_datacenter", Value: target},
		metrics.Label{Name: "locality", Value: locality},
	)
}
// requestType classifies req for the "request_type" label. It returns "read"
// or "write" according to req's IsRead method, or "unreported" when req does
// not implement IsRead.
func requestType(req interface{}) string {
	classifier, ok := req.(interface{ IsRead() bool })
	switch {
	case !ok:
		// Every request type is expected to implement IsRead. Reporting the
		// gap explicitly, instead of guessing "read" or "write", keeps
		// unclassified requests visible in the metrics.
		return "unreported"
	case classifier.IsRead():
		return "read"
	default:
		return "write"
	}
}
// getServerLeadership returns the value of the "leader" label: "true" or
// "false" from serverIsLeader, or "unreported" when no leadership check was
// provided.
func (r *RequestRecorder) getServerLeadership() string {
	if r.serverIsLeader == nil {
		// No leadership check was plumbed through to this recorder, so
		// report that rather than guessing.
		return "unreported"
	}
	return strconv.FormatBool(r.serverIsLeader())
}
type (
	// readQuery is implemented by read requests that carry blocking-query and
	// stale-read options. addOptionalLabels uses it to add the "allow_stale"
	// and "blocking" labels.
	readQuery interface {
		GetMinQueryIndex() uint64
		AllowStaleRead() bool
	}

	// targetDC is implemented by requests that name the datacenter they are
	// addressed to. addOptionalLabels uses it to add the "target_datacenter"
	// and "locality" labels.
	targetDC interface {
		RequestDatacenter() string
	}
)
// GetNetRPCInterceptor returns a net/rpc service-call interceptor that runs the
// handler and then records the call through recorder. The method name, start
// time, decoded argument and whether the handler returned an error are passed
// along, with the rpc type set to RPCTypeNetRPC.
func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterceptor {
	intercept := func(method string, args, reply reflect.Value, invoke func() error) {
		began := time.Now()
		callErr := invoke()
		// Recording happens only after invoke returns normally; a handler
		// that panics is not recorded here.
		recorder.Record(method, RPCTypeNetRPC, began, args.Interface(), callErr != nil)
	}
	return intercept
}
// GetNetRPCRateLimitingInterceptor returns a net/rpc pre-body interceptor that
// asks requestLimitsHandler whether the call may proceed. The operation is
// built from the method name, the caller's address, and the type/category
// found for that method in rpcRateLimitSpecs. Any panic raised while doing so
// is recovered and converted to the returned error by panicHandler.
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler RecoveryHandlerFunc) rpc.PreBodyInterceptor {
	return func(reqServiceMethod string, sourceAddr net.Addr) (retErr error) {
		defer func() {
			if recovered := recover(); recovered != nil {
				retErr = panicHandler(recovered)
			}
		}()

		spec := rpcRateLimitSpecs[reqServiceMethod]
		op := rpcRate.Operation{
			Name:       reqServiceMethod,
			SourceAddr: sourceAddr,
			Type:       spec.Type,
			Category:   spec.Category,
		}

		// net/rpc cannot carry a structured error (e.g. "retry" versus
		// "retry elsewhere"), so the rate limiter's error string is the only
		// signal the client receives.
		return requestLimitsHandler.Allow(op)
	}
}
// ChainedRPCPreBodyInterceptor composes interceptors into a single
// rpc.PreBodyInterceptor. The result runs them in order and returns the first
// non-nil error, skipping the rest; it returns nil when all of them pass.
//
// A single interceptor is returned as-is. Calling this with no interceptors
// is treated as a programming error and panics.
func ChainedRPCPreBodyInterceptor(chain ...rpc.PreBodyInterceptor) rpc.PreBodyInterceptor {
	switch len(chain) {
	case 0:
		panic("don't call this with zero interceptors")
	case 1:
		return chain[0]
	}

	// When a caller passes an existing slice as `chain...`, the variadic
	// parameter aliases that slice. Copy it so later writes to the caller's
	// slice cannot change which interceptors this chain runs.
	interceptors := append([]rpc.PreBodyInterceptor(nil), chain...)

	return func(reqServiceMethod string, sourceAddr net.Addr) error {
		for _, interceptor := range interceptors {
			if err := interceptor(reqServiceMethod, sourceAddr); err != nil {
				return err
			}
		}
		return nil
	}
}