status-go/whisper/rate_limiter.go
Andrea Maria Piana cec7c52505 Don't leak goroutines on decorator
When an error occurs (or the peer disconnects), we return from decorator
as soon as an error is published to the chan.
There are, though, still 5 or 6 goroutines that want to write to that
channel, and at least 3 or 4 of them will hang, stuck
publishing on the chan.
This is probably not the cause of
https://github.com/status-im/infra-eth-cluster/issues/39, but it
fills up the stack trace with hung goroutines.
2021-01-26 17:58:25 +01:00

268 lines
6.1 KiB
Go

package whisper
import (
	"bytes"
	"errors"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/ethereum/go-ethereum/p2p"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/tsenart/tb"
)
// runLoop is the signature of the per-peer message loop that
// PeerRateLimiter.decorate runs. It is called with the peer and the
// p2p.MsgReadWriter it should use; the error it returns is one of the
// candidate results of decorate.
type runLoop func(p *Peer, rw p2p.MsgReadWriter) error
// RateLimiterHandler reacts to rate-limit violations detected by
// PeerRateLimiter.decorate. When either method returns a non-nil error, the
// goroutine reading from the peer stops and reports that error as the result
// of decorate (unless another result was reported first).
type RateLimiterHandler interface {
// ExceedPeerLimit is called each time the per-peer-ID throttle reports a halt.
ExceedPeerLimit() error
// ExceedIPLimit is called each time the per-IP throttle reports a halt.
ExceedIPLimit() error
}
// MetricsRateLimiterHandler is a RateLimiterHandler that only records
// violations in the rateLimitsExceeded metric. It never returns an error.
type MetricsRateLimiterHandler struct{}

// ExceedPeerLimit increments rateLimitsExceeded for the "peer_id" label value
// and returns nil.
func (MetricsRateLimiterHandler) ExceedPeerLimit() error {
	byPeerID := rateLimitsExceeded.WithLabelValues("peer_id")
	byPeerID.Inc()
	return nil
}

// ExceedIPLimit increments rateLimitsExceeded for the "ip" label value and
// returns nil.
func (MetricsRateLimiterHandler) ExceedIPLimit() error {
	byIP := rateLimitsExceeded.WithLabelValues("ip")
	byIP.Inc()
	return nil
}
// RateLimits contains information about rate limit settings.
// It is exchanged using rateLimitingCode packet or in the handshake.
type RateLimits struct {
	IPLimits     uint64 // messages per second from a single IP (default 0, no limits)
	PeerIDLimits uint64 // messages per second from a single peer ID (default 0, no limits)
	TopicLimits  uint64 // messages per second from a single topic (default 0, no limits)
}

// IsZero reports whether r equals the zero value of RateLimits, i.e. every
// limit is 0.
func (r RateLimits) IsZero() bool {
	var unset RateLimits
	return r == unset
}
// ErrRateLimitExceeded is returned by DropPeerRateLimiterHandler once a limit
// has been exceeded at least Tolerance times.
var ErrRateLimitExceeded = errors.New("rate limit has been exceeded")

// DropPeerRateLimiterHandler is a RateLimiterHandler that counts violations
// and starts returning ErrRateLimitExceeded once a counter reaches Tolerance.
//
// The peer-ID and IP counters are independent and are never reset. They are
// updated atomically because a handler registered on a PeerRateLimiter is
// invoked from the reader goroutine that decorate starts for each call, so
// several goroutines may call these methods at the same time.
type DropPeerRateLimiterHandler struct {
	// Tolerance is the number of times a limit must be exceeded
	// in order to drop a peer. Values <= 0 mean the handler never
	// returns an error.
	Tolerance int64

	// Accessed only through sync/atomic. They are int64 fields at 8-byte
	// offsets, so they stay 64-bit aligned when the struct is allocated on
	// its own.
	peerLimitExceeds int64
	ipLimitExceeds   int64
}

// ExceedPeerLimit records one peer-ID limit violation. It returns
// ErrRateLimitExceeded when Tolerance is positive and the number of recorded
// peer-ID violations has reached it, and nil otherwise.
func (h *DropPeerRateLimiterHandler) ExceedPeerLimit() error {
	return h.recordExceed(&h.peerLimitExceeds)
}

// ExceedIPLimit records one IP limit violation. It returns
// ErrRateLimitExceeded when Tolerance is positive and the number of recorded
// IP violations has reached it, and nil otherwise.
func (h *DropPeerRateLimiterHandler) ExceedIPLimit() error {
	return h.recordExceed(&h.ipLimitExceeds)
}

// recordExceed atomically increments counter and compares the new value
// against Tolerance.
func (h *DropPeerRateLimiterHandler) recordExceed(counter *int64) error {
	seen := atomic.AddInt64(counter, 1)
	if h.Tolerance > 0 && seen >= h.Tolerance {
		return ErrRateLimitExceeded
	}
	return nil
}
// PeerRateLimiterConfig holds the settings NewPeerRateLimiter copies into a
// PeerRateLimiter.
type PeerRateLimiterConfig struct {
// LimitPerSecIP is the rate passed to the per-IP throttler.
// A value of 0 disables IP throttling.
LimitPerSecIP int64
// LimitPerSecPeerID is the rate passed to the per-peer-ID throttler.
LimitPerSecPeerID int64
// WhitelistedIPs lists IP strings that are never throttled by IP.
WhitelistedIPs []string
// WhitelistedPeerIDs lists node IDs that are never throttled by peer ID.
WhitelistedPeerIDs []enode.ID
}
// defaultPeerRateLimiterConfig is the configuration NewPeerRateLimiter uses
// when it is called with a nil config: 10 per second per IP, 5 per second per
// peer ID, and no whitelists.
var defaultPeerRateLimiterConfig = PeerRateLimiterConfig{
LimitPerSecIP: 10,
LimitPerSecPeerID: 5,
WhitelistedIPs: nil,
WhitelistedPeerIDs: nil,
}
// PeerRateLimiter throttles incoming packets per IP and per peer ID and
// notifies its handlers whenever a throttle reports a halt. Instances are
// created with NewPeerRateLimiter.
type PeerRateLimiter struct {
// peerIDThrottler is keyed by the string form of the peer's enode.ID.
peerIDThrottler *tb.Throttler
// ipThrottler is keyed by the peer's IP string.
ipThrottler *tb.Throttler
// limitPerSecIP and limitPerSecPeerID are the rates passed to the
// respective throttler's Halt call.
limitPerSecIP int64
limitPerSecPeerID int64
// whitelistedPeerIDs and whitelistedIPs are exempt from throttling.
whitelistedPeerIDs []enode.ID
whitelistedIPs []string
// handlers are called, in order, each time a throttle reports a halt.
handlers []RateLimiterHandler
}
// NewPeerRateLimiter builds a PeerRateLimiter from cfg and registers the
// given handlers. A nil cfg selects a copy of defaultPeerRateLimiterConfig.
// Both throttlers are created with a 100ms frequency argument. The whitelist
// slices are stored as passed, not copied.
func NewPeerRateLimiter(cfg *PeerRateLimiterConfig, handlers ...RateLimiterHandler) *PeerRateLimiter {
	const throttlerFreq = time.Millisecond * 100

	if cfg == nil {
		defaults := defaultPeerRateLimiterConfig
		cfg = &defaults
	}

	limiter := new(PeerRateLimiter)
	limiter.peerIDThrottler = tb.NewThrottler(throttlerFreq)
	limiter.ipThrottler = tb.NewThrottler(throttlerFreq)
	limiter.limitPerSecIP = cfg.LimitPerSecIP
	limiter.limitPerSecPeerID = cfg.LimitPerSecPeerID
	limiter.whitelistedPeerIDs = cfg.WhitelistedPeerIDs
	limiter.whitelistedIPs = cfg.WhitelistedIPs
	limiter.handlers = handlers
	return limiter
}
// decorate runs runLoop for peer p behind a rate-limiting relay and returns
// the first result reported by runLoop or by one of the relay goroutines.
//
// Three goroutines are started:
//   - one reads packets from rw, applies the IP and peer-ID throttles, and
//     writes the packets to one end of an in-memory p2p.MsgPipe;
//   - one reads from that same pipe end and writes the packets to rw;
//   - one executes runLoop against the other end of the pipe.
//
// When a throttle reports a halt, every handler is notified in order. The
// packet is still forwarded unless a handler returns an error, in which case
// the reading goroutine stops and that error is reported.
//
// Causes are wrapped with %w so callers can match them with errors.Is or
// errors.As (for example ErrRateLimitExceeded returned by a handler). Both
// pipe ends are closed when decorate returns.
func (r *PeerRateLimiter) decorate(p *Peer, rw p2p.MsgReadWriter, runLoop runLoop) error {
	in, out := p2p.MsgPipe()
	defer in.Close()
	defer out.Close()

	// Capacity one: the first reported result is kept for the receive at the
	// end of this function.
	errC := make(chan error, 1)

	// report publishes err without ever blocking. Once decorate has returned
	// nobody receives from errC any more, so a blocking send would leave the
	// calling goroutine stuck forever; later results are dropped instead.
	report := func(err error) {
		select {
		case errC <- err:
		default:
		}
	}

	// Read from the original reader and write to the message pipe.
	go func() {
		for {
			packet, err := rw.ReadMsg()
			if err != nil {
				report(fmt.Errorf("failed to read packet: %w", err))
				return
			}

			rateLimitsProcessed.Inc()

			var ip string
			if p != nil && p.peer != nil {
				ip = p.peer.Node().IP().String()
			}
			if halted := r.throttleIP(ip); halted {
				for _, h := range r.handlers {
					if err := h.ExceedIPLimit(); err != nil {
						report(fmt.Errorf("exceed rate limit by IP: %w", err))
						return
					}
				}
			}

			var peerID []byte
			if p != nil {
				peerID = p.ID()
			}
			if halted := r.throttlePeer(peerID); halted {
				for _, h := range r.handlers {
					if err := h.ExceedPeerLimit(); err != nil {
						report(fmt.Errorf("exceeded rate limit by peer: %w", err))
						return
					}
				}
			}

			if err := in.WriteMsg(packet); err != nil {
				report(fmt.Errorf("failed to write packet to pipe: %w", err))
				return
			}
		}
	}()

	// Read from the message pipe and write to the original writer.
	go func() {
		for {
			packet, err := in.ReadMsg()
			if err != nil {
				report(fmt.Errorf("failed to read packet from pipe: %w", err))
				return
			}
			if err := rw.WriteMsg(packet); err != nil {
				report(fmt.Errorf("failed to write packet: %w", err))
				return
			}
		}
	}()

	// Run the wrapped loop against the other pipe end. Its result is reported
	// as-is, including a nil error.
	go func() {
		report(runLoop(p, out))
	}()

	return <-errC
}
// throttleIP decides whether a message coming from ip should be treated as
// exceeding the per-IP rate limit.
//
// It returns false without consulting the throttler when limitPerSecIP is 0
// or when ip is whitelisted. Otherwise it returns the result of the IP
// throttler's Halt call for one message at a rate of limitPerSecIP.
func (r *PeerRateLimiter) throttleIP(ip string) bool {
	unlimited := r.limitPerSecIP == 0
	if unlimited || stringSliceContains(r.whitelistedIPs, ip) {
		return false
	}
	return r.ipThrottler.Halt(ip, 1, r.limitPerSecIP)
}
// throttlePeer decides whether a message coming from the peer identified by
// peerID should be treated as exceeding the per-peer-ID rate limit.
//
// It returns false without consulting the throttler when limitPerSecPeerID is
// 0 or when the peer ID is whitelisted. Otherwise it returns the result of
// the peer-ID throttler's Halt call for one message at a rate of
// limitPerSecPeerID.
//
// The guard checks limitPerSecPeerID, not limitPerSecIP: the two limits are
// configured independently, so a zero IP limit must not disable peer
// throttling, and a zero peer limit must not be handed to the throttler as a
// rate.
func (r *PeerRateLimiter) throttlePeer(peerID []byte) bool {
	if r.limitPerSecPeerID == 0 {
		return false
	}

	// copy transfers at most len(id) bytes; a shorter (or nil) peerID leaves
	// the remaining bytes of id zero.
	var id enode.ID
	copy(id[:], peerID)

	if enodeIDSliceContains(r.whitelistedPeerIDs, id) {
		return false
	}
	return r.peerIDThrottler.Halt(id.String(), 1, r.limitPerSecPeerID)
}
// stringSliceContains reports whether searched is an element of s.
// A nil or empty slice contains nothing.
func stringSliceContains(s []string, searched string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] == searched {
			return true
		}
	}
	return false
}
// enodeIDSliceContains reports whether s holds an ID whose bytes equal those
// of searched. A nil or empty slice contains nothing.
func enodeIDSliceContains(s []enode.ID, searched enode.ID) bool {
	want := searched.Bytes()
	for i := range s {
		if bytes.Equal(s[i].Bytes(), want) {
			return true
		}
	}
	return false
}