diff --git a/waku/v2/api/publish/rln_rate_limiting.go b/waku/v2/api/publish/rln_rate_limiting.go index 11a469bf..59f00ab0 100644 --- a/waku/v2/api/publish/rln_rate_limiting.go +++ b/waku/v2/api/publish/rln_rate_limiting.go @@ -11,7 +11,7 @@ import ( var ErrRateLimited = errors.New("rate limit exceeded") -const RlnLimiterCapacity = 100 +const RlnLimiterCapacity = 600 const RlnLimiterRefillInterval = 10 * time.Minute // RlnRateLimiter is used to rate limit the outgoing messages, @@ -22,15 +22,23 @@ type RlnRateLimiter struct { tokens int refillInterval time.Duration lastRefill time.Time + updateCh chan BucketUpdate +} + +// BucketUpdate includes the information that need to be persisted in database. +type BucketUpdate struct { + RemainingTokens int + LastRefill time.Time } // NewRlnPublishRateLimiter creates a new rate limiter, starts with a full capacity bucket. -func NewRlnRateLimiter(capacity int, refillInterval time.Duration) *RlnRateLimiter { +func NewRlnRateLimiter(capacity int, refillInterval time.Duration, availableTokens int, lastRefill time.Time, updateCh chan BucketUpdate) *RlnRateLimiter { return &RlnRateLimiter{ capacity: capacity, - tokens: capacity, // Start with a full bucket + tokens: availableTokens, // Start with a full bucket in the first run, then track the remaining tokens in storage refillInterval: refillInterval, - lastRefill: time.Now(), + lastRefill: lastRefill, + updateCh: updateCh, } } @@ -42,19 +50,26 @@ func (rl *RlnRateLimiter) Allow() bool { // Refill tokens if the refill interval has passed now := time.Now() if now.Sub(rl.lastRefill) >= rl.refillInterval { - rl.tokens = rl.capacity // Refill the bucket + rl.tokens = rl.capacity rl.lastRefill = now + rl.sendUpdate() } // Check if there are tokens available if rl.tokens > 0 { rl.tokens-- + rl.sendUpdate() return true } return false } +// sendUpdate sends the latest token state to the update channel. +func (rl *RlnRateLimiter) sendUpdate() { + rl.updateCh <- BucketUpdate{RemainingTokens: rl.tokens, LastRefill: rl.lastRefill} +} + func (rl *RlnRateLimiter) Check(ctx context.Context, logger *zap.Logger) error { if rl.Allow() { return nil diff --git a/waku/v2/api/publish/rln_rate_limiting_test.go b/waku/v2/api/publish/rln_rate_limiting_test.go index f91f6ca6..c10a899a 100644 --- a/waku/v2/api/publish/rln_rate_limiting_test.go +++ b/waku/v2/api/publish/rln_rate_limiting_test.go @@ -2,6 +2,7 @@ package publish import ( "context" + "sync" "testing" "time" @@ -10,17 +11,50 @@ import ( ) func TestRlnRateLimit(t *testing.T) { - r := NewRlnRateLimiter(3, 5*time.Second) + updateCh := make(chan BucketUpdate, 10) + refillTime := time.Now() + capacity := 3 + r := NewRlnRateLimiter(capacity, 5*time.Second, capacity, refillTime, updateCh) l := utils.Logger() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sleepDuration := 6 * time.Second + var mu sync.Mutex + go func(ctx context.Context, ch chan BucketUpdate) { + usedToken := 0 + for { + select { + case update := <-ch: + mu.Lock() + if update.LastRefill != refillTime { + usedToken = 0 + require.WithinDuration(t, refillTime.Add(sleepDuration), update.LastRefill, time.Second, "Last refill timestamp is incorrect") + require.Equal(t, update.RemainingTokens, capacity) + continue + } + usedToken++ + require.Equal(t, update.RemainingTokens, capacity-usedToken) + mu.Unlock() + case <-ctx.Done(): + return + } + } + }(ctx, updateCh) + + for i := 0; i < capacity; i++ { + require.NoError(t, r.Check(context.Background(), l)) + } + require.ErrorIs(t, r.Check(context.Background(), l), ErrRateLimited) + + time.Sleep(sleepDuration) + for i := 0; i < 3; i++ { require.NoError(t, r.Check(context.Background(), l)) } require.ErrorIs(t, r.Check(context.Background(), l), ErrRateLimited) - time.Sleep(6 * time.Second) - for i := 0; i < 3; i++ { - require.NoError(t, r.Check(context.Background(), l)) - } - require.ErrorIs(t, r.Check(context.Background(), l), ErrRateLimited) + // wait for goroutine to finish + time.Sleep(time.Second) }