From dd0082041c1ea069f7b09d5ec6fd864fbefe7f27 Mon Sep 17 00:00:00 2001
From: pablo
Date: Mon, 4 Aug 2025 11:31:44 +0300
Subject: [PATCH] fix: refactor
---
ratelimit/ratelimit_manager.nim | 60 +++++++---------
ratelimit/{store/sqlite.nim => store.nim} | 57 +++++++++------
ratelimit/store/memory.nim | 70 -------------------
ratelimit/store/store.nim | 20 ------
tests/test_ratelimit_manager.nim | 58 +++++++++++----
.../{test_sqlite_store.nim => test_store.nim} | 21 +++---
6 files changed, 117 insertions(+), 169 deletions(-)
rename ratelimit/{store/sqlite.nim => store.nim} (84%)
delete mode 100644 ratelimit/store/memory.nim
delete mode 100644 ratelimit/store/store.nim
rename tests/{test_sqlite_store.nim => test_store.nim} (92%)
diff --git a/ratelimit/ratelimit_manager.nim b/ratelimit/ratelimit_manager.nim
index dce9632..438619f 100644
--- a/ratelimit/ratelimit_manager.nim
+++ b/ratelimit/ratelimit_manager.nim
@@ -2,7 +2,7 @@ import std/[times, deques, options]
# TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live
import ./token_bucket
# import waku/common/rate_limit/token_bucket
-import ./store/store
+import ./store
import chronos
type
@@ -22,16 +22,12 @@ type
Normal
Optional
- Serializable* =
- concept x
- x.toBytes() is seq[byte]
-
MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T]
MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.}
- RateLimitManager*[T: Serializable, S: RateLimitStore] = ref object
- store: S
+ RateLimitManager*[T: Serializable] = ref object
+ store: RateLimitStore[T]
bucket: TokenBucket
sender: MessageSender[T]
queueCritical: Deque[seq[MsgIdMsg[T]]]
@@ -39,14 +35,14 @@ type
sleepDuration: chronos.Duration
pxQueueHandleLoop: Future[void]
-proc new*[T: Serializable, S: RateLimitStore](
- M: type[RateLimitManager[T, S]],
- store: S,
+proc new*[T: Serializable](
+ M: type[RateLimitManager[T]],
+ store: RateLimitStore[T],
sender: MessageSender[T],
capacity: int = 100,
duration: chronos.Duration = chronos.minutes(10),
sleepDuration: chronos.Duration = chronos.milliseconds(1000),
-): Future[RateLimitManager[T, S]] {.async.} =
+): Future[RateLimitManager[T]] {.async.} =
var current = await store.loadBucketState()
if current.isNone():
# initialize bucket state with full capacity
@@ -55,7 +51,7 @@ proc new*[T: Serializable, S: RateLimitStore](
)
discard await store.saveBucketState(current.get())
- return RateLimitManager[T, S](
+ return RateLimitManager[T](
store: store,
bucket: TokenBucket.new(
current.get().budgetCap,
@@ -70,8 +66,8 @@ proc new*[T: Serializable, S: RateLimitStore](
sleepDuration: sleepDuration,
)
-proc getCapacityState[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S], now: Moment, count: int = 1
+proc getCapacityState[T: Serializable](
+ manager: RateLimitManager[T], now: Moment, count: int = 1
): CapacityState =
let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now)
let countAfter = budget - count
@@ -83,8 +79,8 @@ proc getCapacityState[T: Serializable, S: RateLimitStore](
else:
return CapacityState.Normal
-proc passToSender[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S],
+proc passToSender[T: Serializable](
+ manager: RateLimitManager[T],
msgs: seq[tuple[msgId: string, msg: T]],
now: Moment,
priority: Priority,
@@ -109,8 +105,8 @@ proc passToSender[T: Serializable, S: RateLimitStore](
await manager.sender(msgs)
return SendResult.PassedToSender
-proc processCriticalQueue[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S], now: Moment
+proc processCriticalQueue[T: Serializable](
+ manager: RateLimitManager[T], now: Moment
): Future[void] {.async.} =
while manager.queueCritical.len > 0:
let msgs = manager.queueCritical.popFirst()
@@ -124,8 +120,8 @@ proc processCriticalQueue[T: Serializable, S: RateLimitStore](
manager.queueCritical.addFirst(msgs)
break
-proc processNormalQueue[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S], now: Moment
+proc processNormalQueue[T: Serializable](
+ manager: RateLimitManager[T], now: Moment
): Future[void] {.async.} =
while manager.queueNormal.len > 0:
let msgs = manager.queueNormal.popFirst()
@@ -137,8 +133,8 @@ proc processNormalQueue[T: Serializable, S: RateLimitStore](
manager.queueNormal.addFirst(msgs)
break
-proc sendOrEnqueue*[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S],
+proc sendOrEnqueue*[T: Serializable](
+ manager: RateLimitManager[T],
msgs: seq[tuple[msgId: string, msg: T]],
priority: Priority,
now: Moment = Moment.now(),
@@ -172,8 +168,8 @@ proc sendOrEnqueue*[T: Serializable, S: RateLimitStore](
of Priority.Optional:
return SendResult.Dropped
-proc getEnqueued*[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S]
+proc getEnqueued*[T: Serializable](
+ manager: RateLimitManager[T]
): tuple[
critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]]
] =
@@ -188,8 +184,8 @@ proc getEnqueued*[T: Serializable, S: RateLimitStore](
return (criticalMsgs, normalMsgs)
-proc queueHandleLoop*[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S],
+proc queueHandleLoop*[T: Serializable](
+ manager: RateLimitManager[T],
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
Moment.now(),
) {.async.} =
@@ -204,22 +200,18 @@ proc queueHandleLoop*[T: Serializable, S: RateLimitStore](
# configurable sleep duration for processing queued messages
await sleepAsync(manager.sleepDuration)
-proc start*[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S],
+proc start*[T: Serializable](
+ manager: RateLimitManager[T],
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
Moment.now(),
) {.async.} =
manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider)
-proc stop*[T: Serializable, S: RateLimitStore](
- manager: RateLimitManager[T, S]
-) {.async.} =
+proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} =
if not isNil(manager.pxQueueHandleLoop):
await manager.pxQueueHandleLoop.cancelAndWait()
-func `$`*[T: Serializable, S: RateLimitStore](
- b: RateLimitManager[T, S]
-): string {.inline.} =
+func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} =
if isNil(b):
return "nil"
return
diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store.nim
similarity index 84%
rename from ratelimit/store/sqlite.nim
rename to ratelimit/store.nim
index a3369e9..2618589 100644
--- a/ratelimit/store/sqlite.nim
+++ b/ratelimit/store.nim
@@ -1,7 +1,6 @@
import std/[times, strutils, json, options]
-import ./store
-import chronos
import db_connector/db_sqlite
+import chronos
# Generic deserialization function for basic types
proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string =
@@ -10,19 +9,31 @@ proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string =
for i, b in bytes:
result[i] = char(b)
-# SQLite Implementation
-type SqliteRateLimitStore*[T] = ref object
- db: DbConn
- dbPath: string
- criticalLength: int
- normalLength: int
- nextBatchId: int
+type
+ Serializable* =
+ concept x
+ x.toBytes() is seq[byte]
+
+ RateLimitStore*[T: Serializable] = ref object
+ db: DbConn
+ dbPath: string
+ criticalLength: int
+ normalLength: int
+ nextBatchId: int
+
+ BucketState* = object
+ budget*: int
+ budgetCap*: int
+ lastTimeFull*: Moment
+
+ QueueType* {.pure.} = enum
+ Critical = "critical"
+ Normal = "normal"
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
-proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] =
- result =
- SqliteRateLimitStore[T](db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
+proc new*[T: Serializable](M: type[RateLimitStore[T]], db: DbConn): M =
+ result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
# Initialize cached lengths from database
let criticalCount = db.getValue(
@@ -53,8 +64,10 @@ proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] =
else:
parseInt(maxBatch) + 1
-proc saveBucketState*[T](
- store: SqliteRateLimitStore[T], bucketState: BucketState
+ return result
+
+proc saveBucketState*[T: Serializable](
+ store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
try:
# Convert Moment to Unix seconds for storage
@@ -75,8 +88,8 @@ proc saveBucketState*[T](
except:
return false
-proc loadBucketState*[T](
- store: SqliteRateLimitStore[T]
+proc loadBucketState*[T: Serializable](
+ store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
let jsonStr =
store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
@@ -95,8 +108,8 @@ proc loadBucketState*[T](
)
)
-proc addToQueue*[T](
- store: SqliteRateLimitStore[T],
+proc addToQueue*[T: Serializable](
+ store: RateLimitStore[T],
queueType: QueueType,
msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
@@ -140,8 +153,8 @@ proc addToQueue*[T](
except:
return false
-proc popFromQueue*[T](
- store: SqliteRateLimitStore[T], queueType: QueueType
+proc popFromQueue*[T: Serializable](
+ store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
try:
let queueTypeStr = $queueType
@@ -201,7 +214,9 @@ proc popFromQueue*[T](
except:
return none(seq[tuple[msgId: string, msg: T]])
-proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int =
+proc getQueueLength*[T: Serializable](
+ store: RateLimitStore[T], queueType: QueueType
+): int =
case queueType
of QueueType.Critical:
return store.criticalLength
diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim
deleted file mode 100644
index e7e7c2f..0000000
--- a/ratelimit/store/memory.nim
+++ /dev/null
@@ -1,70 +0,0 @@
-import std/[times, options, deques, tables]
-import ./store
-import chronos
-
-# Memory Implementation
-type MemoryRateLimitStore*[T] = ref object
- bucketState: Option[BucketState]
- criticalQueue: Deque[seq[tuple[msgId: string, msg: T]]]
- normalQueue: Deque[seq[tuple[msgId: string, msg: T]]]
- criticalLength: int
- normalLength: int
-
-proc new*[T](M: type[MemoryRateLimitStore[T]]): M =
- return M(
- bucketState: none(BucketState),
- criticalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](),
- normalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](),
- criticalLength: 0,
- normalLength: 0
- )
-
-proc saveBucketState*[T](
- store: MemoryRateLimitStore[T], bucketState: BucketState
-): Future[bool] {.async.} =
- store.bucketState = some(bucketState)
- return true
-
-proc loadBucketState*[T](
- store: MemoryRateLimitStore[T]
-): Future[Option[BucketState]] {.async.} =
- return store.bucketState
-
-proc addToQueue*[T](
- store: MemoryRateLimitStore[T],
- queueType: QueueType,
- msgs: seq[tuple[msgId: string, msg: T]]
-): Future[bool] {.async.} =
- case queueType
- of QueueType.Critical:
- store.criticalQueue.addLast(msgs)
- inc store.criticalLength
- of QueueType.Normal:
- store.normalQueue.addLast(msgs)
- inc store.normalLength
- return true
-
-proc popFromQueue*[T](
- store: MemoryRateLimitStore[T],
- queueType: QueueType
-): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
- case queueType
- of QueueType.Critical:
- if store.criticalQueue.len > 0:
- dec store.criticalLength
- return some(store.criticalQueue.popFirst())
- of QueueType.Normal:
- if store.normalQueue.len > 0:
- dec store.normalLength
- return some(store.normalQueue.popFirst())
- return none(seq[tuple[msgId: string, msg: T]])
-
-proc getQueueLength*[T](
- store: MemoryRateLimitStore[T],
- queueType: QueueType
-): int =
- case queueType
- of QueueType.Critical:
- return store.criticalLength
- of QueueType.Normal:
- return store.normalLength
diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim
deleted file mode 100644
index 0f18eb1..0000000
--- a/ratelimit/store/store.nim
+++ /dev/null
@@ -1,20 +0,0 @@
-import std/[times, deques, options]
-import chronos
-
-type
- BucketState* = object
- budget*: int
- budgetCap*: int
- lastTimeFull*: Moment
-
- QueueType* {.pure.} = enum
- Critical = "critical"
- Normal = "normal"
-
- RateLimitStore* =
- concept s
- s.saveBucketState(BucketState) is Future[bool]
- s.loadBucketState() is Future[Option[BucketState]]
- s.addToQueue(QueueType, seq[tuple[msgId: string, msg: untyped]]) is Future[bool]
- s.popFromQueue(QueueType) is Future[Option[seq[tuple[msgId: string, msg: untyped]]]]
- s.getQueueLength(QueueType) is int
diff --git a/tests/test_ratelimit_manager.nim b/tests/test_ratelimit_manager.nim
index f6c4039..3a6d870 100644
--- a/tests/test_ratelimit_manager.nim
+++ b/tests/test_ratelimit_manager.nim
@@ -1,12 +1,38 @@
import testutils/unittests
import ../ratelimit/ratelimit_manager
-import ../ratelimit/store/memory
+import ../ratelimit/store
import chronos
+import db_connector/db_sqlite
# Implement the Serializable concept for string
proc toBytes*(s: string): seq[byte] =
cast[seq[byte]](s)
+# Helper function to create an in-memory database with the proper schema
+proc createTestDatabase(): DbConn =
+ result = open(":memory:", "", "", "")
+ # Create the required tables
+ result.exec(
+ sql"""
+ CREATE TABLE IF NOT EXISTS kv_store (
+ key TEXT PRIMARY KEY,
+ value BLOB
+ )
+ """
+ )
+ result.exec(
+ sql"""
+ CREATE TABLE IF NOT EXISTS ratelimit_queues (
+ queue_type TEXT NOT NULL,
+ msg_id TEXT NOT NULL,
+ msg_data BLOB NOT NULL,
+ batch_id INTEGER NOT NULL,
+ created_at INTEGER NOT NULL,
+ PRIMARY KEY (queue_type, batch_id, msg_id)
+ )
+ """
+ )
+
suite "Queue RateLimitManager":
setup:
var sentMessages: seq[tuple[msgId: string, msg: string]]
@@ -23,8 +49,9 @@ suite "Queue RateLimitManager":
asyncTest "sendOrEnqueue - immediate send when capacity available":
## Given
- let store: MemoryRateLimitStore = MemoryRateLimitStore.new()
- let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
+ let db = createTestDatabase()
+ let store: RateLimitStore[string] = RateLimitStore[string].new(db)
+ let manager = await RateLimitManager[string].new(
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
)
let testMsg = "Hello World"
@@ -42,8 +69,9 @@ suite "Queue RateLimitManager":
asyncTest "sendOrEnqueue - multiple messages":
## Given
- let store = MemoryRateLimitStore.new()
- let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
+ let db = createTestDatabase()
+ let store = RateLimitStore[string].new(db)
+ let manager = await RateLimitManager[string].new(
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
)
@@ -63,8 +91,9 @@ suite "Queue RateLimitManager":
asyncTest "start and stop - drop large batch":
## Given
- let store = MemoryRateLimitStore.new()
- let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
+ let db = createTestDatabase()
+ let store = RateLimitStore[string].new(db)
+ let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 2,
@@ -80,8 +109,9 @@ suite "Queue RateLimitManager":
asyncTest "enqueue - enqueue critical only when exceeded":
## Given
- let store = MemoryRateLimitStore.new()
- let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
+ let db = createTestDatabase()
+ let store = RateLimitStore[string].new(db)
+ let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
@@ -130,8 +160,9 @@ suite "Queue RateLimitManager":
asyncTest "enqueue - enqueue normal on 70% capacity":
## Given
- let store = MemoryRateLimitStore.new()
- let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
+ let db = createTestDatabase()
+ let store = RateLimitStore[string].new(db)
+ let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
@@ -183,8 +214,9 @@ suite "Queue RateLimitManager":
asyncTest "enqueue - process queued messages":
## Given
- let store = MemoryRateLimitStore.new()
- let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
+ let db = createTestDatabase()
+ let store = RateLimitStore[string].new(db)
+ let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
diff --git a/tests/test_sqlite_store.nim b/tests/test_store.nim
similarity index 92%
rename from tests/test_sqlite_store.nim
rename to tests/test_store.nim
index 90d764c..251cf24 100644
--- a/tests/test_sqlite_store.nim
+++ b/tests/test_store.nim
@@ -1,6 +1,5 @@
import testutils/unittests
-import ../ratelimit/store/sqlite
-import ../ratelimit/store/store
+import ../ratelimit/store
import chronos
import db_connector/db_sqlite
import ../chat_sdk/migration
@@ -32,7 +31,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "newSqliteRateLimitStore - empty state":
## Given
- let store = newSqliteRateLimitStore[string](db)
+ let store = RateLimitStore[string].new(db)
## When
let loadedState = await store.loadBucketState()
@@ -42,7 +41,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "saveBucketState and loadBucketState - state persistence":
## Given
- let store = newSqliteRateLimitStore[string](db)
+ let store = RateLimitStore[string].new(db)
let now = Moment.now()
echo "now: ", now.epochSeconds()
@@ -62,7 +61,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "queue operations - empty store":
## Given
- let store = newSqliteRateLimitStore[string](db)
+ let store = RateLimitStore[string].new(db)
## When/Then
check store.getQueueLength(QueueType.Critical) == 0
@@ -76,7 +75,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "addToQueue and popFromQueue - single batch":
## Given
- let store = newSqliteRateLimitStore[string](db)
+ let store = RateLimitStore[string].new(db)
let msgs = @[("msg1", "Hello"), ("msg2", "World")]
## When
@@ -103,7 +102,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "addToQueue and popFromQueue - multiple batches FIFO":
## Given
- let store = newSqliteRateLimitStore[string](db)
+ let store = RateLimitStore[string].new(db)
let batch1 = @[("msg1", "First")]
let batch2 = @[("msg2", "Second")]
let batch3 = @[("msg3", "Third")]
@@ -141,7 +140,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "queue isolation - critical and normal queues are separate":
## Given
- let store = newSqliteRateLimitStore[string](db)
+ let store = RateLimitStore[string].new(db)
let criticalMsgs = @[("crit1", "Critical Message")]
let normalMsgs = @[("norm1", "Normal Message")]
@@ -178,14 +177,14 @@ suite "SqliteRateLimitStore Tests":
let msgs = @[("persist1", "Persistent Message")]
block:
- let store1 = newSqliteRateLimitStore[string](db)
+ let store1 = RateLimitStore[string].new(db)
let addResult = await store1.addToQueue(QueueType.Critical, msgs)
check addResult == true
check store1.getQueueLength(QueueType.Critical) == 1
## When - Create new store instance
block:
- let store2 = newSqliteRateLimitStore[string](db)
+ let store2 = RateLimitStore[string].new(db)
## Then - Queue length should be restored from database
check store2.getQueueLength(QueueType.Critical) == 1
@@ -197,7 +196,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "large batch handling":
## Given
- let store = newSqliteRateLimitStore[string](db)
+ let store = RateLimitStore[string].new(db)
var largeBatch: seq[tuple[msgId: string, msg: string]]
for i in 1 .. 100: