mirror of https://github.com/status-im/consul.git
feat: panic handler in rpc rate limit interceptor (#16022)
* feat: handle panic in rpc rate limit interceptor * test: additional test cases to rpc rate limiting interceptor * refactor: remove unused listener
This commit is contained in:
parent
e0f4f6c152
commit
f4f62b5da6
|
@ -490,7 +490,7 @@ func NewServer(config *Config, flat Deps, externalGRPCServer *grpc.Server, incom
|
||||||
}
|
}
|
||||||
|
|
||||||
rpcServerOpts := []func(*rpc.Server){
|
rpcServerOpts := []func(*rpc.Server){
|
||||||
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter)),
|
rpc.WithPreBodyInterceptor(middleware.GetNetRPCRateLimitingInterceptor(s.incomingRPCLimiter, middleware.NewPanicHandler(s.logger))),
|
||||||
}
|
}
|
||||||
|
|
||||||
if flat.GetNetRPCInterceptorFunc != nil {
|
if flat.GetNetRPCInterceptorFunc != nil {
|
||||||
|
|
|
@ -160,9 +160,16 @@ func GetNetRPCInterceptor(recorder *RequestRecorder) rpc.ServerServiceCallInterc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler) rpc.PreBodyInterceptor {
|
func GetNetRPCRateLimitingInterceptor(requestLimitsHandler rpcRate.RequestLimitsHandler, panicHandler RecoveryHandlerFunc) rpc.PreBodyInterceptor {
|
||||||
|
|
||||||
|
return func(reqServiceMethod string, sourceAddr net.Addr) (retErr error) {
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
retErr = panicHandler(r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return func(reqServiceMethod string, sourceAddr net.Addr) error {
|
|
||||||
op := rpcRate.Operation{
|
op := rpcRate.Operation{
|
||||||
Name: reqServiceMethod,
|
Name: reqServiceMethod,
|
||||||
SourceAddr: sourceAddr,
|
SourceAddr: sourceAddr,
|
||||||
|
|
|
@ -1,13 +1,18 @@
|
||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/armon/go-metrics"
|
"github.com/armon/go-metrics"
|
||||||
|
"github.com/hashicorp/consul/agent/consul/rate"
|
||||||
"github.com/hashicorp/go-hclog"
|
"github.com/hashicorp/go-hclog"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -266,3 +271,42 @@ func TestRequestRecorder(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetNetRPCRateLimitingInterceptor(t *testing.T) {
|
||||||
|
limiter := rate.NewMockRequestLimitsHandler(t)
|
||||||
|
|
||||||
|
logger := hclog.NewNullLogger()
|
||||||
|
rateLimitInterceptor := GetNetRPCRateLimitingInterceptor(limiter, NewPanicHandler(logger))
|
||||||
|
|
||||||
|
addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("1.2.3.4:5678"))
|
||||||
|
|
||||||
|
t.Run("allow operation", func(t *testing.T) {
|
||||||
|
limiter.On("Allow", mock.Anything).
|
||||||
|
Return(nil).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
err := rateLimitInterceptor("Status.Leader", addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allow returns error", func(t *testing.T) {
|
||||||
|
limiter.On("Allow", mock.Anything).
|
||||||
|
Return(errors.New("uh oh")).
|
||||||
|
Once()
|
||||||
|
|
||||||
|
err := rateLimitInterceptor("Status.Leader", addr)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "uh oh", err.Error())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allow panics", func(t *testing.T) {
|
||||||
|
limiter.On("Allow", mock.Anything).
|
||||||
|
Panic("uh oh").
|
||||||
|
Once()
|
||||||
|
|
||||||
|
err := rateLimitInterceptor("Status.Leader", addr)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Equal(t, "rpc: panic serving request", err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-hclog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewPanicHandler returns a RecoveryHandlerFunc type function
|
||||||
|
// to handle panic in RPC server's handlers.
|
||||||
|
func NewPanicHandler(logger hclog.Logger) RecoveryHandlerFunc {
|
||||||
|
return func(p interface{}) (err error) {
|
||||||
|
// Log the panic and the stack trace of the Goroutine that caused the panic.
|
||||||
|
stacktrace := hclog.Stacktrace()
|
||||||
|
logger.Error("panic serving rpc request",
|
||||||
|
"panic", p,
|
||||||
|
"stack", stacktrace,
|
||||||
|
)
|
||||||
|
|
||||||
|
return fmt.Errorf("rpc: panic serving request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type RecoveryHandlerFunc func(p interface{}) (err error)
|
Loading…
Reference in New Issue