consul/agent/grpc-middleware/rate.go

69 lines
2.0 KiB
Go
Raw Normal View History

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package middleware
import (
"context"
"errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"google.golang.org/grpc/tap"
recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/hashicorp/consul/agent/consul/rate"
)
// ServerRateLimiterMiddleware implements a ServerInHandle function to perform
// RPC rate limiting at the cheapest possible point (before the full request has
// been decoded).
2023-01-04 16:07:02 +00:00
func ServerRateLimiterMiddleware(limiter rate.RequestLimitsHandler, panicHandler recovery.RecoveryHandlerFunc, logger Logger) tap.ServerInHandle {
return func(ctx context.Context, info *tap.Info) (_ context.Context, retErr error) {
// This function is called before unary and stream RPC interceptors, so we
// must handle our own panics here.
defer func() {
if r := recover(); r != nil {
retErr = panicHandler(r)
}
}()
// Do not rate-limit the xDS service, it handles its own limiting.
if info.FullMethodName == "/envoy.service.discovery.v3.AggregatedDiscoveryService/DeltaAggregatedResources" {
return ctx, nil
}
peer, ok := peer.FromContext(ctx)
if !ok {
// This should never happen!
return ctx, status.Error(codes.Internal, "gRPC rate limit middleware unable to read peer")
}
operationSpec, ok := rpcRateLimitSpecs[info.FullMethodName]
2023-01-04 16:07:02 +00:00
if !ok {
logger.Warn("failed to determine which rate limit to apply to RPC", "rpc", info.FullMethodName)
return ctx, nil
}
err := limiter.Allow(rate.Operation{
Name: info.FullMethodName,
SourceAddr: peer.Addr,
Type: operationSpec.Type,
Category: operationSpec.Category,
})
switch {
case err == nil:
return ctx, nil
case errors.Is(err, rate.ErrRetryElsewhere):
return ctx, status.Error(codes.ResourceExhausted, err.Error())
case errors.Is(err, rate.ErrRetryLater):
return ctx, status.Error(codes.Unavailable, err.Error())
default:
return ctx, status.Error(codes.Internal, err.Error())
}
}
}