mirror of
https://github.com/status-im/consul.git
synced 2025-01-10 22:06:20 +00:00
feat: panic handler in rpc rate limit interceptor (#16022)
* feat: handle panic in rpc rate limit interceptor * test: additional test cases to rpc rate limiting interceptor * refactor: remove unused listener
This commit is contained in:
parent
e0f4f6c152
commit
f4f62b5da6
@ -490,7 +490,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom
|
||||
}
|
||||
|
||||
rpcServerOpts := []func(*rpc.Server){
|
||||
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter)),
|
||||
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter, middleware.NewPanicHandler(s.logger))),
|
||||
}
|
||||
|
||||
if flat.GetNetRPCInterceptorFunc != nil {
|
||||
|
@ -160,9 +160,16 @@ func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterc
|
||||
}
|
||||
}
|
||||
|
||||
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler) rpc.PreBodyInterceptor {
|
||||
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler RecoveryHandlerFunc) rpc.PreBodyInterceptor {
|
||||
|
||||
return func(reqServiceMethod string, sourceAddr net.Addr) (retErr error) {
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
retErr = panicHandler(r)
|
||||
}
|
||||
}()
|
||||
|
||||
return func(reqServiceMethod string, sourceAddr net.Addr) error {
|
||||
op := rpcRate.Operation{
|
||||
Name: reqServiceMethod,
|
||||
SourceAddr: sourceAddr,
|
||||
|
@ -1,13 +1,18 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-metrics"
|
||||
"github.com/hashicorp/consul/agent/consul/rate"
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -266,3 +271,42 @@ func TestRequestRecorder(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNetRPCRateLimitingInterceptor(t *testing.T) {
|
||||
limiter := rate.NewMockRequestLimitsHandler(t)
|
||||
|
||||
logger := hclog.NewNullLogger()
|
||||
rateLimitInterceptor := GetNetRPCRateLimitingInterceptor(limiter, NewPanicHandler(logger))
|
||||
|
||||
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678"))
|
||||
|
||||
t.Run("allow operation", func(t *testing.T) {
|
||||
limiter.On("Allow", mock.Anything).
|
||||
Return(nil).
|
||||
Once()
|
||||
|
||||
err := rateLimitInterceptor("Status.Leader", addr)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("allow returns error", func(t *testing.T) {
|
||||
limiter.On("Allow", mock.Anything).
|
||||
Return(errors.New("uh oh")).
|
||||
Once()
|
||||
|
||||
err := rateLimitInterceptor("Status.Leader", addr)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "uh oh", err.Error())
|
||||
})
|
||||
|
||||
t.Run("allow panics", func(t *testing.T) {
|
||||
limiter.On("Allow", mock.Anything).
|
||||
Panic("uh oh").
|
||||
Once()
|
||||
|
||||
err := rateLimitInterceptor("Status.Leader", addr)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "rpc: panic serving request", err.Error())
|
||||
})
|
||||
}
|
||||
|
24
agent/rpc/middleware/recovery.go
Normal file
24
agent/rpc/middleware/recovery.go
Normal file
@ -0,0 +1,24 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
)
|
||||
|
||||
// NewPanicHandler returns a RecoveryHandlerFunc type function
|
||||
// to handle panic in RPC server's handlers.
|
||||
func NewPanicHandler(logger hclog.Logger) RecoveryHandlerFunc {
|
||||
return func(p interface{}) (err error) {
|
||||
// Log the panic and the stack trace of the Goroutine that caused the panic.
|
||||
stacktrace := hclog.Stacktrace()
|
||||
logger.Error("panic serving rpc request",
|
||||
"panic", p,
|
||||
"stack", stacktrace,
|
||||
)
|
||||
|
||||
return fmt.Errorf("rpc: panic serving request")
|
||||
}
|
||||
}
|
||||
|
||||
type RecoveryHandlerFunc func(p interface{}) (err error)
|
Loading…
x
Reference in New Issue
Block a user