From f04e6ae7aa3ba58793569b04734072d14af573b4 Mon Sep 17 00:00:00 2001
From: pablo
Date: Fri, 27 Jun 2025 08:50:48 +0300
Subject: [PATCH] fix: add ref object and serializable
---
ratelimit/ratelimit.nim | 44 ++++++++++++++++++++--------------------
tests/test_ratelimit.nim | 11 ++++++++++
2 files changed, 33 insertions(+), 22 deletions(-)
diff --git a/ratelimit/ratelimit.nim b/ratelimit/ratelimit.nim
index 18b3a1f..d634d82 100644
--- a/ratelimit/ratelimit.nim
+++ b/ratelimit/ratelimit.nim
@@ -2,24 +2,27 @@ import std/[times, deques]
import chronos
type
+ Serializable* = concept x
+ x.toBytes() is seq[byte]
+
MessagePriority* = enum
Critical = 0
Normal = 1
Optional = 2
- QueuedMessage*[T] = object
+ QueuedMessage*[T: Serializable] = ref object of RootObj
messageId*: string
msg*: T
priority*: MessagePriority
queuedAt*: float
- MessageSender*[T] = proc(messageId: string, msg: T): Future[bool] {.async.}
+ MessageSender*[T: Serializable] = proc(messageId: string, msg: T): Future[bool] {.async.}
- RateLimitManager*[T] = ref object
+ RateLimitManager*[T: Serializable] = ref object
messageCount*: int = 100 # Default to 100 messages
epochDurationSec*: int = 600 # Default to 10 minutes
currentCount*: int
- currentEpoch*: int64
+ lastResetTime*: float
criticalQueue*: Deque[QueuedMessage[T]]
normalQueue*: Deque[QueuedMessage[T]]
optionalQueue*: Deque[QueuedMessage[T]]
@@ -27,15 +30,12 @@ type
isRunning*: bool
sendTask*: Future[void]
-proc getCurrentEpoch(epochDurationSec: int): int64 =
- int64(epochTime() / float(epochDurationSec))
-
-proc newRateLimitManager*[T](messageCount: int, epochDurationSec: int, sender: MessageSender[T]): RateLimitManager[T] =
+proc newRateLimitManager*[T: Serializable](messageCount: int, epochDurationSec: int, sender: MessageSender[T]): RateLimitManager[T] =
RateLimitManager[T](
messageCount: messageCount,
epochDurationSec: epochDurationSec,
currentCount: 0,
- currentEpoch: getCurrentEpoch(epochDurationSec),
+ lastResetTime: epochTime(),
criticalQueue: initDeque[QueuedMessage[T]](),
normalQueue: initDeque[QueuedMessage[T]](),
optionalQueue: initDeque[QueuedMessage[T]](),
@@ -43,18 +43,18 @@ proc newRateLimitManager*[T](messageCount: int, epochDurationSec: int, sender: M
isRunning: false
)
-proc updateEpochIfNeeded[T](manager: RateLimitManager[T]) =
- let newEpoch = getCurrentEpoch(manager.epochDurationSec)
- if newEpoch > manager.currentEpoch:
- manager.currentEpoch = newEpoch
+proc updateEpochIfNeeded[T: Serializable](manager: RateLimitManager[T]) =
+ let now = epochTime()
+ if now - manager.lastResetTime >= float(manager.epochDurationSec):
+ manager.lastResetTime = now
manager.currentCount = 0
-proc getUsagePercent[T](manager: RateLimitManager[T]): float =
+proc getUsagePercent[T: Serializable](manager: RateLimitManager[T]): float =
if manager.messageCount == 0:
return 1.0
float(manager.currentCount) / float(manager.messageCount)
-proc queueForSend*[T](manager: RateLimitManager[T], messageId: string, msg: T, priority: MessagePriority) =
+proc queueForSend*[T: Serializable](manager: RateLimitManager[T], messageId: string, msg: T, priority: MessagePriority) =
manager.updateEpochIfNeeded()
let queuedMsg = QueuedMessage[T](
@@ -94,7 +94,7 @@ proc queueForSend*[T](manager: RateLimitManager[T], messageId: string, msg: T, p
of Optional:
manager.optionalQueue.addLast(queuedMsg)
-proc getNextMessage[T](manager: RateLimitManager[T]): QueuedMessage[T] =
+proc getNextMessage[T: Serializable](manager: RateLimitManager[T]): QueuedMessage[T] =
# Priority order: Critical -> Normal -> Optional
if manager.criticalQueue.len > 0:
return manager.criticalQueue.popFirst()
@@ -105,10 +105,10 @@ proc getNextMessage[T](manager: RateLimitManager[T]): QueuedMessage[T] =
else:
raise newException(ValueError, "No messages in queue")
-proc hasMessages[T](manager: RateLimitManager[T]): bool =
+proc hasMessages[T: Serializable](manager: RateLimitManager[T]): bool =
manager.criticalQueue.len > 0 or manager.normalQueue.len > 0 or manager.optionalQueue.len > 0
-proc sendLoop*[T](manager: RateLimitManager[T]): Future[void] {.async.} =
+proc sendLoop*[T: Serializable](manager: RateLimitManager[T]): Future[void] {.async.} =
manager.isRunning = true
while manager.isRunning:
@@ -153,11 +153,11 @@ proc sendLoop*[T](manager: RateLimitManager[T]): Future[void] {.async.} =
except:
await sleepAsync(chronos.seconds(1)) # Wait longer on error
-proc start*[T](manager: RateLimitManager[T]): Future[void] {.async.} =
+proc start*[T: Serializable](manager: RateLimitManager[T]): Future[void] {.async.} =
if not manager.isRunning:
manager.sendTask = manager.sendLoop()
-proc stop*[T](manager: RateLimitManager[T]): Future[void] {.async.} =
+proc stop*[T: Serializable](manager: RateLimitManager[T]): Future[void] {.async.} =
if manager.isRunning:
manager.isRunning = false
if not manager.sendTask.isNil:
@@ -167,7 +167,7 @@ proc stop*[T](manager: RateLimitManager[T]): Future[void] {.async.} =
except CancelledError:
discard
-proc getQueueStatus*[T](manager: RateLimitManager[T]): tuple[critical: int, normal: int, optional: int, total: int] =
+proc getQueueStatus*[T: Serializable](manager: RateLimitManager[T]): tuple[critical: int, normal: int, optional: int, total: int] =
(
critical: manager.criticalQueue.len,
normal: manager.normalQueue.len,
@@ -175,7 +175,7 @@ proc getQueueStatus*[T](manager: RateLimitManager[T]): tuple[critical: int, norm
total: manager.criticalQueue.len + manager.normalQueue.len + manager.optionalQueue.len
)
-proc getCurrentQuota*[T](manager: RateLimitManager[T]): tuple[total: int, used: int, remaining: int] =
+proc getCurrentQuota*[T: Serializable](manager: RateLimitManager[T]): tuple[total: int, used: int, remaining: int] =
(
total: manager.messageCount,
used: manager.currentCount,
diff --git a/tests/test_ratelimit.nim b/tests/test_ratelimit.nim
index 0598b3f..e5f7e5e 100644
--- a/tests/test_ratelimit.nim
+++ b/tests/test_ratelimit.nim
@@ -12,6 +12,17 @@ type
content: string
id: int
+# Implement Serializable for test types
+proc toBytes*(s: string): seq[byte] =
+ cast[seq[byte]](s)
+
+proc toBytes*(msg: TestMessage): seq[byte] =
+ result = toBytes(msg.content)
+ result.add(cast[seq[byte]](@[byte(msg.id)]))
+
+proc toBytes*(i: int): seq[byte] =
+ cast[seq[byte]](@[byte(i)])
+
suite "Rate Limit Manager":
setup:
## Given