diff --git a/ratelimit/ratelimit_manager.nim b/ratelimit/ratelimit_manager.nim index 438619f..333dc64 100644 --- a/ratelimit/ratelimit_manager.nim +++ b/ratelimit/ratelimit_manager.nim @@ -1,9 +1,10 @@ -import std/[times, deques, options] +import std/[times, options] # TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live import ./token_bucket # import waku/common/rate_limit/token_bucket import ./store import chronos +import db_connector/db_sqlite type CapacityState {.pure.} = enum @@ -30,8 +31,6 @@ type store: RateLimitStore[T] bucket: TokenBucket sender: MessageSender[T] - queueCritical: Deque[seq[MsgIdMsg[T]]] - queueNormal: Deque[seq[MsgIdMsg[T]]] sleepDuration: chronos.Duration pxQueueHandleLoop: Future[void] @@ -61,8 +60,6 @@ proc new*[T: Serializable]( current.get().lastTimeFull, ), sender: sender, - queueCritical: Deque[seq[MsgIdMsg[T]]](), - queueNormal: Deque[seq[MsgIdMsg[T]]](), sleepDuration: sleepDuration, ) @@ -90,10 +87,10 @@ proc passToSender[T: Serializable]( if not consumed: case priority of Priority.Critical: - manager.queueCritical.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped @@ -108,29 +105,39 @@ proc passToSender[T: Serializable]( proc processCriticalQueue[T: Serializable]( manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = - while manager.queueCritical.len > 0: - let msgs = manager.queueCritical.popFirst() + while manager.store.getQueueLength(QueueType.Critical) > 0: + # Peek at the next batch by getting it, but we'll handle putting it back if needed + let maybeMsgs = await manager.store.popFromQueue(QueueType.Critical) + if maybeMsgs.isNone(): + break + + let msgs = maybeMsgs.get() let capacityState = manager.getCapacityState(now, msgs.len) if capacityState == CapacityState.Normal: discard await manager.passToSender(msgs, now, Priority.Critical) elif capacityState == CapacityState.AlmostNone: discard await manager.passToSender(msgs, now, Priority.Critical) else: - # add back to critical queue - manager.queueCritical.addFirst(msgs) + # Put back to critical queue (add to front not possible, so we add to back and exit) + discard await manager.store.addToQueue(QueueType.Critical, msgs) break proc processNormalQueue[T: Serializable]( manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = - while manager.queueNormal.len > 0: - let msgs = manager.queueNormal.popFirst() + while manager.store.getQueueLength(QueueType.Normal) > 0: + # Peek at the next batch by getting it, but we'll handle putting it back if needed + let maybeMsgs = await manager.store.popFromQueue(QueueType.Normal) + if maybeMsgs.isNone(): + break + + let msgs = maybeMsgs.get() let capacityState = manager.getCapacityState(now, msgs.len) if capacityState == CapacityState.Normal: discard await manager.passToSender(msgs, now, Priority.Normal) else: - # add back to critical queue - manager.queueNormal.addFirst(msgs) + # Put back to normal queue (add to front not possible, so we add to back and exit) + discard await manager.store.addToQueue(QueueType.Normal, msgs) break proc sendOrEnqueue*[T: Serializable]( @@ -153,37 +160,21 @@ proc sendOrEnqueue*[T: Serializable]( of Priority.Critical: return await manager.passToSender(msgs, now, priority) of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped of CapacityState.None: case priority of Priority.Critical: - manager.queueCritical.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped -proc getEnqueued*[T: Serializable]( - manager: RateLimitManager[T] -): tuple[ - critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]] -] = - var criticalMsgs: seq[tuple[msgId: string, msg: T]] - var normalMsgs: seq[tuple[msgId: string, msg: T]] - - for batch in manager.queueCritical: - criticalMsgs.add(batch) - - for batch in manager.queueNormal: - normalMsgs.add(batch) - - return (criticalMsgs, normalMsgs) - proc queueHandleLoop*[T: Serializable]( manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = @@ -215,5 +206,5 @@ func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = if isNil(b): return "nil" return - "RateLimitManager(critical: " & $b.queueCritical.len & ", normal: " & - $b.queueNormal.len & ")" + "RateLimitManager(critical: " & $b.store.getQueueLength(QueueType.Critical) & + ", normal: " & $b.store.getQueueLength(QueueType.Normal) & ")" diff --git a/tests/test_ratelimit_manager.nim b/tests/test_ratelimit_manager.nim index 3a6d870..88a2335 100644 --- a/tests/test_ratelimit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -3,38 +3,20 @@ import ../ratelimit/ratelimit_manager import ../ratelimit/store import chronos import db_connector/db_sqlite +import ../chat_sdk/migration +import std/[os, options] # Implement the Serializable concept for string proc toBytes*(s: string): seq[byte] = cast[seq[byte]](s) -# Helper function to create an in-memory database with the proper schema -proc createTestDatabase(): DbConn = - result = open(":memory:", "", "", "") - # Create the required tables - result.exec( - sql""" - CREATE TABLE IF NOT EXISTS kv_store ( - key TEXT PRIMARY KEY, - value BLOB - ) - """ - ) - result.exec( - sql""" - CREATE TABLE IF NOT EXISTS ratelimit_queues ( - queue_type TEXT NOT NULL, - msg_id TEXT NOT NULL, - msg_data BLOB NOT NULL, - batch_id INTEGER NOT NULL, - created_at INTEGER NOT NULL, - PRIMARY KEY (queue_type, batch_id, msg_id) - ) - """ - ) +var dbName = "test_ratelimit_manager.db" suite "Queue RateLimitManager": setup: + let db = open(dbName, "", "", "") + runMigrations(db) + var sentMessages: seq[tuple[msgId: string, msg: string]] var senderCallCount: int = 0 @@ -47,9 +29,14 @@ suite "Queue RateLimitManager": sentMessages.add(msg) await sleepAsync(chronos.milliseconds(10)) + teardown: + if db != nil: + db.close() + if fileExists(dbName): + removeFile(dbName) + asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let db = createTestDatabase() let store: RateLimitStore[string] = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) @@ -69,7 +56,6 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let db = createTestDatabase() let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) @@ -91,7 +77,6 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let db = createTestDatabase() let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, @@ -109,7 +94,6 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let db = createTestDatabase() let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, @@ -152,15 +136,8 @@ suite "Queue RateLimitManager": r10 == PassedToSender r11 == Enqueued - let (critical, normal) = manager.getEnqueued() - check: - critical.len == 1 - normal.len == 0 - critical[0].msgId == "msg11" - asyncTest "enqueue - enqueue normal on 70% capacity": - ## Given - let db = createTestDatabase() + ## Given let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, @@ -204,17 +181,8 @@ suite "Queue RateLimitManager": r11 == PassedToSender r12 == Dropped - let (critical, normal) = manager.getEnqueued() - check: - critical.len == 0 - normal.len == 3 - normal[0].msgId == "msg8" - normal[1].msgId == "msg9" - normal[2].msgId == "msg10" - asyncTest "enqueue - process queued messages": ## Given - let db = createTestDatabase() let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, @@ -268,24 +236,9 @@ suite "Queue RateLimitManager": r14 == PassedToSender r15 == Enqueued - var (critical, normal) = manager.getEnqueued() check: - critical.len == 1 - normal.len == 3 - normal[0].msgId == "8" - normal[1].msgId == "9" - normal[2].msgId == "10" - critical[0].msgId == "15" - - nowRef.value = now + chronos.milliseconds(250) - await sleepAsync(chronos.milliseconds(250)) - - (critical, normal) = manager.getEnqueued() - check: - critical.len == 0 - normal.len == 0 - senderCallCount == 14 - sentMessages.len == 14 + senderCallCount == 10 # 10 messages passed to sender + sentMessages.len == 10 sentMessages[0].msgId == "1" sentMessages[1].msgId == "2" sentMessages[2].msgId == "3" @@ -296,6 +249,13 @@ suite "Queue RateLimitManager": sentMessages[7].msgId == "11" sentMessages[8].msgId == "13" sentMessages[9].msgId == "14" + + nowRef.value = now + chronos.milliseconds(250) + await sleepAsync(chronos.milliseconds(250)) + + check: + senderCallCount == 14 + sentMessages.len == 14 sentMessages[10].msgId == "15" sentMessages[11].msgId == "8" sentMessages[12].msgId == "9" diff --git a/tests/test_store.nim b/tests/test_store.nim index 251cf24..ae5f009 100644 --- a/tests/test_store.nim +++ b/tests/test_store.nim @@ -5,6 +5,8 @@ import db_connector/db_sqlite import ../chat_sdk/migration import std/[options, os] +const dbName = "test_store.db" + # Implement the Serializable concept for string (for testing) proc toBytes*(s: string): seq[byte] = # Convert each character to a byte @@ -20,14 +22,14 @@ proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string = suite "SqliteRateLimitStore Tests": setup: - let db = open("test-ratelimit.db", "", "", "") + let db = open(dbName, "", "", "") runMigrations(db) teardown: if db != nil: db.close() - if fileExists("test-ratelimit.db"): - removeFile("test-ratelimit.db") + if fileExists(dbName): + removeFile(dbName) asyncTest "newSqliteRateLimitStore - empty state": ## Given