From f04e6ae7aa3ba58793569b04734072d14af573b4 Mon Sep 17 00:00:00 2001 From: pablo Date: Fri, 27 Jun 2025 08:50:48 +0300 Subject: [PATCH] fix: add ref object and serializable --- ratelimit/ratelimit.nim | 44 ++++++++++++++++++++-------------------- tests/test_ratelimit.nim | 11 ++++++++++ 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/ratelimit/ratelimit.nim b/ratelimit/ratelimit.nim index 18b3a1f..d634d82 100644 --- a/ratelimit/ratelimit.nim +++ b/ratelimit/ratelimit.nim @@ -2,24 +2,27 @@ import std/[times, deques] import chronos type + Serializable* = concept x + x.toBytes() is seq[byte] + MessagePriority* = enum Critical = 0 Normal = 1 Optional = 2 - QueuedMessage*[T] = object + QueuedMessage*[T: Serializable] = ref object of RootObj messageId*: string msg*: T priority*: MessagePriority queuedAt*: float - MessageSender*[T] = proc(messageId: string, msg: T): Future[bool] {.async.} + MessageSender*[T: Serializable] = proc(messageId: string, msg: T): Future[bool] {.async.} - RateLimitManager*[T] = ref object + RateLimitManager*[T: Serializable] = ref object messageCount*: int = 100 # Default to 100 messages epochDurationSec*: int = 600 # Default to 10 minutes currentCount*: int - currentEpoch*: int64 + lastResetTime*: float criticalQueue*: Deque[QueuedMessage[T]] normalQueue*: Deque[QueuedMessage[T]] optionalQueue*: Deque[QueuedMessage[T]] @@ -27,15 +30,12 @@ type isRunning*: bool sendTask*: Future[void] -proc getCurrentEpoch(epochDurationSec: int): int64 = - int64(epochTime() / float(epochDurationSec)) - -proc newRateLimitManager*[T](messageCount: int, epochDurationSec: int, sender: MessageSender[T]): RateLimitManager[T] = +proc newRateLimitManager*[T: Serializable](messageCount: int, epochDurationSec: int, sender: MessageSender[T]): RateLimitManager[T] = RateLimitManager[T]( messageCount: messageCount, epochDurationSec: epochDurationSec, currentCount: 0, - currentEpoch: getCurrentEpoch(epochDurationSec), + lastResetTime: epochTime(), criticalQueue: initDeque[QueuedMessage[T]](), normalQueue: initDeque[QueuedMessage[T]](), optionalQueue: initDeque[QueuedMessage[T]](), @@ -43,18 +43,18 @@ proc newRateLimitManager*[T](messageCount: int, epochDurationSec: int, sender: M isRunning: false ) -proc updateEpochIfNeeded[T](manager: RateLimitManager[T]) = - let newEpoch = getCurrentEpoch(manager.epochDurationSec) - if newEpoch > manager.currentEpoch: - manager.currentEpoch = newEpoch +proc updateEpochIfNeeded[T: Serializable](manager: RateLimitManager[T]) = + let now = epochTime() + if now - manager.lastResetTime >= float(manager.epochDurationSec): + manager.lastResetTime = now manager.currentCount = 0 -proc getUsagePercent[T](manager: RateLimitManager[T]): float = +proc getUsagePercent[T: Serializable](manager: RateLimitManager[T]): float = if manager.messageCount == 0: return 1.0 float(manager.currentCount) / float(manager.messageCount) -proc queueForSend*[T](manager: RateLimitManager[T], messageId: string, msg: T, priority: MessagePriority) = +proc queueForSend*[T: Serializable](manager: RateLimitManager[T], messageId: string, msg: T, priority: MessagePriority) = manager.updateEpochIfNeeded() let queuedMsg = QueuedMessage[T]( @@ -94,7 +94,7 @@ proc queueForSend*[T](manager: RateLimitManager[T], messageId: string, msg: T, p of Optional: manager.optionalQueue.addLast(queuedMsg) -proc getNextMessage[T](manager: RateLimitManager[T]): QueuedMessage[T] = +proc getNextMessage[T: Serializable](manager: RateLimitManager[T]): QueuedMessage[T] = # Priority order: Critical -> Normal -> Optional if manager.criticalQueue.len > 0: return manager.criticalQueue.popFirst() @@ -105,10 +105,10 @@ proc getNextMessage[T](manager: RateLimitManager[T]): QueuedMessage[T] = else: raise newException(ValueError, "No messages in queue") -proc hasMessages[T](manager: RateLimitManager[T]): bool = +proc hasMessages[T: Serializable](manager: RateLimitManager[T]): bool = manager.criticalQueue.len > 0 or manager.normalQueue.len > 0 or manager.optionalQueue.len > 0 -proc sendLoop*[T](manager: RateLimitManager[T]): Future[void] {.async.} = +proc sendLoop*[T: Serializable](manager: RateLimitManager[T]): Future[void] {.async.} = manager.isRunning = true while manager.isRunning: @@ -153,11 +153,11 @@ proc sendLoop*[T](manager: RateLimitManager[T]): Future[void] {.async.} = except: await sleepAsync(chronos.seconds(1)) # Wait longer on error -proc start*[T](manager: RateLimitManager[T]): Future[void] {.async.} = +proc start*[T: Serializable](manager: RateLimitManager[T]): Future[void] {.async.} = if not manager.isRunning: manager.sendTask = manager.sendLoop() -proc stop*[T](manager: RateLimitManager[T]): Future[void] {.async.} = +proc stop*[T: Serializable](manager: RateLimitManager[T]): Future[void] {.async.} = if manager.isRunning: manager.isRunning = false if not manager.sendTask.isNil: @@ -167,7 +167,7 @@ proc stop*[T](manager: RateLimitManager[T]): Future[void] {.async.} = except CancelledError: discard -proc getQueueStatus*[T](manager: RateLimitManager[T]): tuple[critical: int, normal: int, optional: int, total: int] = +proc getQueueStatus*[T: Serializable](manager: RateLimitManager[T]): tuple[critical: int, normal: int, optional: int, total: int] = ( critical: manager.criticalQueue.len, normal: manager.normalQueue.len, @@ -175,7 +175,7 @@ proc getQueueStatus*[T](manager: RateLimitManager[T]): tuple[critical: int, norm total: manager.criticalQueue.len + manager.normalQueue.len + manager.optionalQueue.len ) -proc getCurrentQuota*[T](manager: RateLimitManager[T]): tuple[total: int, used: int, remaining: int] = +proc getCurrentQuota*[T: Serializable](manager: RateLimitManager[T]): tuple[total: int, used: int, remaining: int] = ( total: manager.messageCount, used: manager.currentCount, diff --git a/tests/test_ratelimit.nim b/tests/test_ratelimit.nim index 0598b3f..e5f7e5e 100644 --- a/tests/test_ratelimit.nim +++ b/tests/test_ratelimit.nim @@ -12,6 +12,17 @@ type content: string id: int +# Implement Serializable for test types +proc toBytes*(s: string): seq[byte] = + cast[seq[byte]](s) + +proc toBytes*(msg: TestMessage): seq[byte] = + result = toBytes(msg.content) + result.add(cast[seq[byte]](@[byte(msg.id)])) + +proc toBytes*(i: int): seq[byte] = + cast[seq[byte]](@[byte(i)]) + suite "Rate Limit Manager": setup: ## Given