From 4b244c6ced1bd7ae4c61b084b52af52cea759dcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=A1clav=20Pavl=C3=ADn?= Date: Wed, 29 May 2024 13:40:31 +0200 Subject: [PATCH] add per IP rate limitting --- go.mod | 4 +- go.sum | 2 - telemetry/ratelimiter.go | 91 +++++++++++++++++++++++++++++++++++ telemetry/ratelimiter_test.go | 52 ++++++++++++++++++++ telemetry/server.go | 39 ++++++++++++--- 5 files changed, 178 insertions(+), 10 deletions(-) create mode 100644 telemetry/ratelimiter.go create mode 100644 telemetry/ratelimiter_test.go diff --git a/go.mod b/go.mod index 1666c1d..8b80665 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,11 @@ module github.com/status-im/dev-telemetry go 1.15 require ( - github.com/go-bindata/go-bindata v3.1.2+incompatible // indirect github.com/golang-migrate/migrate/v4 v4.15.2 - github.com/gorilla/mux v1.8.0 + github.com/gorilla/mux v1.8.0 // indirect github.com/lib/pq v1.10.3 github.com/robfig/cron/v3 v3.0.1 github.com/stretchr/testify v1.8.1 go.uber.org/zap v1.27.0 + golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 ) diff --git a/go.sum b/go.sum index cc4b06e..a9337e8 100644 --- a/go.sum +++ b/go.sum @@ -408,8 +408,6 @@ github.com/garyburd/redigo v0.0.0-20150301180006-535138d7bcd7/go.mod h1:NR3MbYis github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-bindata/go-bindata v3.1.2+incompatible h1:5vjJMVhowQdPzjE1LdxyFF7YFTXg5IgGVW4gBr5IbvE= -github.com/go-bindata/go-bindata v3.1.2+incompatible/go.mod h1:xK8Dsgwmeed+BBsSy2XTopBn/8uK2HWuGSnA11C3Joo= github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= diff --git a/telemetry/ratelimiter.go b/telemetry/ratelimiter.go new file mode 100644 index 0000000..404c8e8 --- /dev/null +++ b/telemetry/ratelimiter.go @@ -0,0 +1,91 @@ +package telemetry + +import ( + "context" + "sync" + "time" + + "golang.org/x/time/rate" +) + +const DEFAULT_CLEANUP_TIME = 1 * time.Hour + +type Limiter struct { + limiter *rate.Limiter + lastUsed time.Time +} + +type RateLimiter struct { + limiters map[string]*Limiter + lock *sync.RWMutex + r rate.Limit + b int +} + +func NewRateLimiter(ctx context.Context, r rate.Limit, b int) *RateLimiter { + return NewRateLimiterWithCleanup(ctx, r, b, DEFAULT_CLEANUP_TIME) +} + +func NewRateLimiterWithCleanup(ctx context.Context, r rate.Limit, b int, cleanupTime time.Duration) *RateLimiter { + rl := &RateLimiter{ + limiters: make(map[string]*Limiter), + lock: &sync.RWMutex{}, + r: r, + b: b, + } + + go rl.cleanup(ctx, cleanupTime) + + return rl +} + +func (rl *RateLimiter) GetLimiter(ip string) *rate.Limiter { + rl.lock.Lock() + + limiter, ok := rl.limiters[ip] + if !ok { + rl.lock.Unlock() + return rl.AddIP(ip) + } + + limiter.lastUsed = time.Now() + + rl.lock.Unlock() + return limiter.limiter +} + +func (rl *RateLimiter) AddIP(ip string) *rate.Limiter { + rl.lock.Lock() + defer rl.lock.Unlock() + + limiter := rate.NewLimiter(rl.r, rl.b) + rl.limiters[ip] = &Limiter{limiter: limiter, lastUsed: time.Now()} + + return limiter +} + +func (rl *RateLimiter) RemoveIP(ip string) { + rl.lock.Lock() + defer rl.lock.Unlock() + + delete(rl.limiters, ip) +} + +func (rl *RateLimiter) cleanup(ctx context.Context, cleanupEvery time.Duration) { + t := time.NewTicker(cleanupEvery) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + return + case now := <-t.C: + for ip := range rl.limiters { + limiter := rl.limiters[ip] + if limiter.lastUsed.Add(2 * time.Second).Before(now) { + rl.RemoveIP(ip) + } + } + } + } +} diff --git a/telemetry/ratelimiter_test.go b/telemetry/ratelimiter_test.go new file mode 100644 index 0000000..9176ffc --- /dev/null +++ b/telemetry/ratelimiter_test.go @@ -0,0 +1,52 @@ +package telemetry + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" +) + +func TestRateLimit(t *testing.T) { + ctx := context.Background() + defer ctx.Done() + rl := NewRateLimiter(ctx, 1, 1) + + ip1 := "1.1.1.1" + + limiter := rl.GetLimiter(ip1) + require.True(t, limiter.Allow()) + + limiter = rl.GetLimiter(ip1) + require.False(t, limiter.Allow()) + + time.Sleep(1 * time.Second) + limiter = rl.GetLimiter(ip1) + require.True(t, limiter.Allow()) + + ip2 := "2.2.2.2:8080" + limiter = rl.GetLimiter(ip2) + require.True(t, limiter.Allow()) + + limiter = rl.GetLimiter(ip2) + require.False(t, limiter.Allow()) +} + +func TestRateLimitCleanup(t *testing.T) { + ctx := context.Background() + defer ctx.Done() + rl := NewRateLimiterWithCleanup(ctx, rate.Limit(1/time.Hour), 1, 100*time.Millisecond) + + ip1 := "1.1.1.1" + + limiter := rl.GetLimiter(ip1) + require.True(t, limiter.Allow()) + require.False(t, limiter.Allow()) + + time.Sleep(3 * time.Second) + + limiter2 := rl.GetLimiter(ip1) + require.True(t, limiter2.Allow()) +} diff --git a/telemetry/server.go b/telemetry/server.go index c6146ed..69f2804 100644 --- a/telemetry/server.go +++ b/telemetry/server.go @@ -1,6 +1,7 @@ package telemetry import ( + "context" "crypto/sha256" "database/sql" "encoding/hex" @@ -12,19 +13,30 @@ import ( "github.com/gorilla/mux" "go.uber.org/zap" + "golang.org/x/time/rate" +) + +const ( + RATE_LIMIT = rate.Limit(10) + BURST = 1 ) type Server struct { - Router *mux.Router - DB *sql.DB - logger *zap.Logger + Router *mux.Router + DB *sql.DB + logger *zap.Logger + rateLimiter RateLimiter + ctx context.Context } func NewServer(db *sql.DB, logger *zap.Logger) *Server { + ctx := context.Background() server := &Server{ - Router: mux.NewRouter().StrictSlash(true), - DB: db, - logger: logger, + Router: mux.NewRouter().StrictSlash(true), + DB: db, + logger: logger, + rateLimiter: *NewRateLimiter(ctx, RATE_LIMIT, BURST), + ctx: ctx, } server.Router.HandleFunc("/protocol-stats", server.createProtocolStats).Methods("POST") @@ -32,6 +44,7 @@ func NewServer(db *sql.DB, logger *zap.Logger) *Server { server.Router.HandleFunc("/received-envelope", server.createReceivedEnvelope).Methods("POST") server.Router.HandleFunc("/update-envelope", server.updateEnvelope).Methods("POST") server.Router.HandleFunc("/health", handleHealthCheck).Methods("GET") + server.Router.Use(server.rateLimit) return server } @@ -207,6 +220,20 @@ func (s *Server) createProtocolStats(w http.ResponseWriter, r *http.Request) { ) } +func (s *Server) rateLimit(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + limiter := s.rateLimiter.GetLimiter(r.RemoteAddr) + // Do stuff here + if !limiter.Allow() { + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + return + } + // Call the next handler, which can be another middleware in the chain, or the final handler. + next.ServeHTTP(w, r) + }) +} + func (s *Server) Start(port int) { s.logger.Info("Starting server", zap.Int("port", port)) log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), s.Router))