From f59588a970e6b215190bd32cac496dad94e82bc5 Mon Sep 17 00:00:00 2001 From: kaichao Date: Tue, 25 Mar 2025 17:52:50 +0800 Subject: [PATCH] feat: send bucket update when rate limit applied (#1277) --- waku/v2/api/publish/rln_rate_limiting.go | 27 +++++++--- waku/v2/api/publish/rln_rate_limiting_test.go | 50 ++++++++++++++++--- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/waku/v2/api/publish/rln_rate_limiting.go b/waku/v2/api/publish/rln_rate_limiting.go index 11a469bf..d0b52689 100644 --- a/waku/v2/api/publish/rln_rate_limiting.go +++ b/waku/v2/api/publish/rln_rate_limiting.go @@ -11,8 +11,8 @@ import ( var ErrRateLimited = errors.New("rate limit exceeded") -const RlnLimiterCapacity = 100 -const RlnLimiterRefillInterval = 10 * time.Minute +const DefaultRlnLimiterCapacity = 600 +const DefaultRlnLimiterRefillInterval = 10 * time.Minute // RlnRateLimiter is used to rate limit the outgoing messages, // The capacity and refillInterval comes from RLN contract configuration. @@ -22,15 +22,23 @@ type RlnRateLimiter struct { tokens int refillInterval time.Duration lastRefill time.Time + updateCh chan RlnRateLimitState +} + +// RlnRateLimitState includes the information that need to be persisted in database. +type RlnRateLimitState struct { + RemainingTokens int + LastRefill time.Time } // NewRlnPublishRateLimiter creates a new rate limiter, starts with a full capacity bucket. -func NewRlnRateLimiter(capacity int, refillInterval time.Duration) *RlnRateLimiter { +func NewRlnRateLimiter(capacity int, refillInterval time.Duration, state RlnRateLimitState, updateCh chan RlnRateLimitState) *RlnRateLimiter { return &RlnRateLimiter{ capacity: capacity, - tokens: capacity, // Start with a full bucket + tokens: state.RemainingTokens, refillInterval: refillInterval, - lastRefill: time.Now(), + lastRefill: state.LastRefill, + updateCh: updateCh, } } @@ -42,19 +50,26 @@ func (rl *RlnRateLimiter) Allow() bool { // Refill tokens if the refill interval has passed now := time.Now() if now.Sub(rl.lastRefill) >= rl.refillInterval { - rl.tokens = rl.capacity // Refill the bucket + rl.tokens = rl.capacity rl.lastRefill = now + rl.sendUpdate() } // Check if there are tokens available if rl.tokens > 0 { rl.tokens-- + rl.sendUpdate() return true } return false } +// sendUpdate sends the latest token state to the update channel. +func (rl *RlnRateLimiter) sendUpdate() { + rl.updateCh <- RlnRateLimitState{RemainingTokens: rl.tokens, LastRefill: rl.lastRefill} +} + func (rl *RlnRateLimiter) Check(ctx context.Context, logger *zap.Logger) error { if rl.Allow() { return nil diff --git a/waku/v2/api/publish/rln_rate_limiting_test.go b/waku/v2/api/publish/rln_rate_limiting_test.go index f91f6ca6..f62d52c1 100644 --- a/waku/v2/api/publish/rln_rate_limiting_test.go +++ b/waku/v2/api/publish/rln_rate_limiting_test.go @@ -2,6 +2,7 @@ package publish import ( "context" + "sync" "testing" "time" @@ -10,17 +11,54 @@ import ( ) func TestRlnRateLimit(t *testing.T) { - r := NewRlnRateLimiter(3, 5*time.Second) + updateCh := make(chan RlnRateLimitState, 10) + refillTime := time.Now() + capacity := 3 + state := RlnRateLimitState{ + RemainingTokens: capacity, + LastRefill: refillTime, + } + r := NewRlnRateLimiter(capacity, 5*time.Second, state, updateCh) l := utils.Logger() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sleepDuration := 6 * time.Second + var mu sync.Mutex + go func(ctx context.Context, ch chan RlnRateLimitState) { + usedToken := 0 + for { + select { + case update := <-ch: + mu.Lock() + if update.LastRefill != refillTime { + usedToken = 0 + require.WithinDuration(t, refillTime.Add(sleepDuration), update.LastRefill, time.Second, "Last refill timestamp is incorrect") + require.Equal(t, update.RemainingTokens, capacity) + continue + } + usedToken++ + require.Equal(t, update.RemainingTokens, capacity-usedToken) + mu.Unlock() + case <-ctx.Done(): + return + } + } + }(ctx, updateCh) + + for i := 0; i < capacity; i++ { + require.NoError(t, r.Check(context.Background(), l)) + } + require.ErrorIs(t, r.Check(context.Background(), l), ErrRateLimited) + + time.Sleep(sleepDuration) + for i := 0; i < 3; i++ { require.NoError(t, r.Check(context.Background(), l)) } require.ErrorIs(t, r.Check(context.Background(), l), ErrRateLimited) - time.Sleep(6 * time.Second) - for i := 0; i < 3; i++ { - require.NoError(t, r.Check(context.Background(), l)) - } - require.ErrorIs(t, r.Check(context.Background(), l), ErrRateLimited) + // wait for goroutine to finish + time.Sleep(time.Second) }