From dd0082041c1ea069f7b09d5ec6fd864fbefe7f27 Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 4 Aug 2025 11:31:44 +0300 Subject: [PATCH] fix: refactor --- ratelimit/ratelimit_manager.nim | 60 +++++++--------- ratelimit/{store/sqlite.nim => store.nim} | 57 +++++++++------ ratelimit/store/memory.nim | 70 ------------------- ratelimit/store/store.nim | 20 ------ tests/test_ratelimit_manager.nim | 58 +++++++++++---- .../{test_sqlite_store.nim => test_store.nim} | 21 +++--- 6 files changed, 117 insertions(+), 169 deletions(-) rename ratelimit/{store/sqlite.nim => store.nim} (84%) delete mode 100644 ratelimit/store/memory.nim delete mode 100644 ratelimit/store/store.nim rename tests/{test_sqlite_store.nim => test_store.nim} (92%) diff --git a/ratelimit/ratelimit_manager.nim b/ratelimit/ratelimit_manager.nim index dce9632..438619f 100644 --- a/ratelimit/ratelimit_manager.nim +++ b/ratelimit/ratelimit_manager.nim @@ -2,7 +2,7 @@ import std/[times, deques, options] # TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live import ./token_bucket # import waku/common/rate_limit/token_bucket -import ./store/store +import ./store import chronos type @@ -22,16 +22,12 @@ type Normal Optional - Serializable* = - concept x - x.toBytes() is seq[byte] - MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T] MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} - RateLimitManager*[T: Serializable, S: RateLimitStore] = ref object - store: S + RateLimitManager*[T: Serializable] = ref object + store: RateLimitStore[T] bucket: TokenBucket sender: MessageSender[T] queueCritical: Deque[seq[MsgIdMsg[T]]] @@ -39,14 +35,14 @@ type sleepDuration: chronos.Duration pxQueueHandleLoop: Future[void] -proc new*[T: Serializable, S: RateLimitStore]( - M: type[RateLimitManager[T, S]], - store: S, +proc new*[T: Serializable]( + M: type[RateLimitManager[T]], + store: RateLimitStore[T], sender: MessageSender[T], capacity: int = 100, duration: chronos.Duration = chronos.minutes(10), sleepDuration: chronos.Duration = chronos.milliseconds(1000), -): Future[RateLimitManager[T, S]] {.async.} = +): Future[RateLimitManager[T]] {.async.} = var current = await store.loadBucketState() if current.isNone(): # initialize bucket state with full capacity @@ -55,7 +51,7 @@ proc new*[T: Serializable, S: RateLimitStore]( ) discard await store.saveBucketState(current.get()) - return RateLimitManager[T, S]( + return RateLimitManager[T]( store: store, bucket: TokenBucket.new( current.get().budgetCap, @@ -70,8 +66,8 @@ proc new*[T: Serializable, S: RateLimitStore]( sleepDuration: sleepDuration, ) -proc getCapacityState[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], now: Moment, count: int = 1 +proc getCapacityState[T: Serializable]( + manager: RateLimitManager[T], now: Moment, count: int = 1 ): CapacityState = let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now) let countAfter = budget - count @@ -83,8 +79,8 @@ proc getCapacityState[T: Serializable, S: RateLimitStore]( else: return CapacityState.Normal -proc passToSender[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], +proc passToSender[T: Serializable]( + manager: RateLimitManager[T], msgs: seq[tuple[msgId: string, msg: T]], now: Moment, priority: Priority, @@ -109,8 +105,8 @@ proc passToSender[T: Serializable, S: RateLimitStore]( await manager.sender(msgs) return SendResult.PassedToSender -proc processCriticalQueue[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], now: Moment +proc processCriticalQueue[T: Serializable]( + manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = while manager.queueCritical.len > 0: let msgs = manager.queueCritical.popFirst() @@ -124,8 +120,8 @@ proc processCriticalQueue[T: Serializable, S: RateLimitStore]( manager.queueCritical.addFirst(msgs) break -proc processNormalQueue[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], now: Moment +proc processNormalQueue[T: Serializable]( + manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = while manager.queueNormal.len > 0: let msgs = manager.queueNormal.popFirst() @@ -137,8 +133,8 @@ proc processNormalQueue[T: Serializable, S: RateLimitStore]( manager.queueNormal.addFirst(msgs) break -proc sendOrEnqueue*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], +proc sendOrEnqueue*[T: Serializable]( + manager: RateLimitManager[T], msgs: seq[tuple[msgId: string, msg: T]], priority: Priority, now: Moment = Moment.now(), @@ -172,8 +168,8 @@ proc sendOrEnqueue*[T: Serializable, S: RateLimitStore]( of Priority.Optional: return SendResult.Dropped -proc getEnqueued*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S] +proc getEnqueued*[T: Serializable]( + manager: RateLimitManager[T] ): tuple[ critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]] ] = @@ -188,8 +184,8 @@ proc getEnqueued*[T: Serializable, S: RateLimitStore]( return (criticalMsgs, normalMsgs) -proc queueHandleLoop*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], +proc queueHandleLoop*[T: Serializable]( + manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = @@ -204,22 +200,18 @@ proc queueHandleLoop*[T: Serializable, S: RateLimitStore]( # configurable sleep duration for processing queued messages await sleepAsync(manager.sleepDuration) -proc start*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], +proc start*[T: Serializable]( + manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider) -proc stop*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S] -) {.async.} = +proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} = if not isNil(manager.pxQueueHandleLoop): await manager.pxQueueHandleLoop.cancelAndWait() -func `$`*[T: Serializable, S: RateLimitStore]( - b: RateLimitManager[T, S] -): string {.inline.} = +func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = if isNil(b): return "nil" return diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store.nim similarity index 84% rename from ratelimit/store/sqlite.nim rename to ratelimit/store.nim index a3369e9..2618589 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store.nim @@ -1,7 +1,6 @@ import std/[times, strutils, json, options] -import ./store -import chronos import db_connector/db_sqlite +import chronos # Generic deserialization function for basic types proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string = @@ -10,19 +9,31 @@ proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string = for i, b in bytes: result[i] = char(b) -# SQLite Implementation -type SqliteRateLimitStore*[T] = ref object - db: DbConn - dbPath: string - criticalLength: int - normalLength: int - nextBatchId: int +type + Serializable* = + concept x + x.toBytes() is seq[byte] + + RateLimitStore*[T: Serializable] = ref object + db: DbConn + dbPath: string + criticalLength: int + normalLength: int + nextBatchId: int + + BucketState* = object + budget*: int + budgetCap*: int + lastTimeFull*: Moment + + QueueType* {.pure.} = enum + Critical = "critical" + Normal = "normal" const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] = - result = - SqliteRateLimitStore[T](db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) +proc new*[T: Serializable](M: type[RateLimitStore[T]], db: DbConn): M = + result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) # Initialize cached lengths from database let criticalCount = db.getValue( @@ -53,8 +64,10 @@ proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] = else: parseInt(maxBatch) + 1 -proc saveBucketState*[T]( - store: SqliteRateLimitStore[T], bucketState: BucketState + return result + +proc saveBucketState*[T: Serializable]( + store: RateLimitStore[T], bucketState: BucketState ): Future[bool] {.async.} = try: # Convert Moment to Unix seconds for storage @@ -75,8 +88,8 @@ proc saveBucketState*[T]( except: return false -proc loadBucketState*[T]( - store: SqliteRateLimitStore[T] +proc loadBucketState*[T: Serializable]( + store: RateLimitStore[T] ): Future[Option[BucketState]] {.async.} = let jsonStr = store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) @@ -95,8 +108,8 @@ proc loadBucketState*[T]( ) ) -proc addToQueue*[T]( - store: SqliteRateLimitStore[T], +proc addToQueue*[T: Serializable]( + store: RateLimitStore[T], queueType: QueueType, msgs: seq[tuple[msgId: string, msg: T]], ): Future[bool] {.async.} = @@ -140,8 +153,8 @@ proc addToQueue*[T]( except: return false -proc popFromQueue*[T]( - store: SqliteRateLimitStore[T], queueType: QueueType +proc popFromQueue*[T: Serializable]( + store: RateLimitStore[T], queueType: QueueType ): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = try: let queueTypeStr = $queueType @@ -201,7 +214,9 @@ proc popFromQueue*[T]( except: return none(seq[tuple[msgId: string, msg: T]]) -proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int = +proc getQueueLength*[T: Serializable]( + store: RateLimitStore[T], queueType: QueueType +): int = case queueType of QueueType.Critical: return store.criticalLength diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim deleted file mode 100644 index e7e7c2f..0000000 --- a/ratelimit/store/memory.nim +++ /dev/null @@ -1,70 +0,0 @@ -import std/[times, options, deques, tables] -import ./store -import chronos - -# Memory Implementation -type MemoryRateLimitStore*[T] = ref object - bucketState: Option[BucketState] - criticalQueue: Deque[seq[tuple[msgId: string, msg: T]]] - normalQueue: Deque[seq[tuple[msgId: string, msg: T]]] - criticalLength: int - normalLength: int - -proc new*[T](M: type[MemoryRateLimitStore[T]]): M = - return M( - bucketState: none(BucketState), - criticalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](), - normalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](), - criticalLength: 0, - normalLength: 0 - ) - -proc saveBucketState*[T]( - store: MemoryRateLimitStore[T], bucketState: BucketState -): Future[bool] {.async.} = - store.bucketState = some(bucketState) - return true - -proc loadBucketState*[T]( - store: MemoryRateLimitStore[T] -): Future[Option[BucketState]] {.async.} = - return store.bucketState - -proc addToQueue*[T]( - store: MemoryRateLimitStore[T], - queueType: QueueType, - msgs: seq[tuple[msgId: string, msg: T]] -): Future[bool] {.async.} = - case queueType - of QueueType.Critical: - store.criticalQueue.addLast(msgs) - inc store.criticalLength - of QueueType.Normal: - store.normalQueue.addLast(msgs) - inc store.normalLength - return true - -proc popFromQueue*[T]( - store: MemoryRateLimitStore[T], - queueType: QueueType -): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = - case queueType - of QueueType.Critical: - if store.criticalQueue.len > 0: - dec store.criticalLength - return some(store.criticalQueue.popFirst()) - of QueueType.Normal: - if store.normalQueue.len > 0: - dec store.normalLength - return some(store.normalQueue.popFirst()) - return none(seq[tuple[msgId: string, msg: T]]) - -proc getQueueLength*[T]( - store: MemoryRateLimitStore[T], - queueType: QueueType -): int = - case queueType - of QueueType.Critical: - return store.criticalLength - of QueueType.Normal: - return store.normalLength diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim deleted file mode 100644 index 0f18eb1..0000000 --- a/ratelimit/store/store.nim +++ /dev/null @@ -1,20 +0,0 @@ -import std/[times, deques, options] -import chronos - -type - BucketState* = object - budget*: int - budgetCap*: int - lastTimeFull*: Moment - - QueueType* {.pure.} = enum - Critical = "critical" - Normal = "normal" - - RateLimitStore* = - concept s - s.saveBucketState(BucketState) is Future[bool] - s.loadBucketState() is Future[Option[BucketState]] - s.addToQueue(QueueType, seq[tuple[msgId: string, msg: untyped]]) is Future[bool] - s.popFromQueue(QueueType) is Future[Option[seq[tuple[msgId: string, msg: untyped]]]] - s.getQueueLength(QueueType) is int diff --git a/tests/test_ratelimit_manager.nim b/tests/test_ratelimit_manager.nim index f6c4039..3a6d870 100644 --- a/tests/test_ratelimit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -1,12 +1,38 @@ import testutils/unittests import ../ratelimit/ratelimit_manager -import ../ratelimit/store/memory +import ../ratelimit/store import chronos +import db_connector/db_sqlite # Implement the Serializable concept for string proc toBytes*(s: string): seq[byte] = cast[seq[byte]](s) +# Helper function to create an in-memory database with the proper schema +proc createTestDatabase(): DbConn = + result = open(":memory:", "", "", "") + # Create the required tables + result.exec( + sql""" + CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB + ) + """ + ) + result.exec( + sql""" + CREATE TABLE IF NOT EXISTS ratelimit_queues ( + queue_type TEXT NOT NULL, + msg_id TEXT NOT NULL, + msg_data BLOB NOT NULL, + batch_id INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (queue_type, batch_id, msg_id) + ) + """ + ) + suite "Queue RateLimitManager": setup: var sentMessages: seq[tuple[msgId: string, msg: string]] @@ -23,8 +49,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let store: MemoryRateLimitStore = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store: RateLimitStore[string] = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) let testMsg = "Hello World" @@ -42,8 +69,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) @@ -63,8 +91,9 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 2, @@ -80,8 +109,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, @@ -130,8 +160,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue normal on 70% capacity": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, @@ -183,8 +214,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - process queued messages": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, diff --git a/tests/test_sqlite_store.nim b/tests/test_store.nim similarity index 92% rename from tests/test_sqlite_store.nim rename to tests/test_store.nim index 90d764c..251cf24 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_store.nim @@ -1,6 +1,5 @@ import testutils/unittests -import ../ratelimit/store/sqlite -import ../ratelimit/store/store +import ../ratelimit/store import chronos import db_connector/db_sqlite import ../chat_sdk/migration @@ -32,7 +31,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "newSqliteRateLimitStore - empty state": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) ## When let loadedState = await store.loadBucketState() @@ -42,7 +41,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "saveBucketState and loadBucketState - state persistence": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) let now = Moment.now() echo "now: ", now.epochSeconds() @@ -62,7 +61,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "queue operations - empty store": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) ## When/Then check store.getQueueLength(QueueType.Critical) == 0 @@ -76,7 +75,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "addToQueue and popFromQueue - single batch": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) let msgs = @[("msg1", "Hello"), ("msg2", "World")] ## When @@ -103,7 +102,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "addToQueue and popFromQueue - multiple batches FIFO": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) let batch1 = @[("msg1", "First")] let batch2 = @[("msg2", "Second")] let batch3 = @[("msg3", "Third")] @@ -141,7 +140,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "queue isolation - critical and normal queues are separate": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) let criticalMsgs = @[("crit1", "Critical Message")] let normalMsgs = @[("norm1", "Normal Message")] @@ -178,14 +177,14 @@ suite "SqliteRateLimitStore Tests": let msgs = @[("persist1", "Persistent Message")] block: - let store1 = newSqliteRateLimitStore[string](db) + let store1 = RateLimitStore[string].new(db) let addResult = await store1.addToQueue(QueueType.Critical, msgs) check addResult == true check store1.getQueueLength(QueueType.Critical) == 1 ## When - Create new store instance block: - let store2 = newSqliteRateLimitStore[string](db) + let store2 = RateLimitStore[string].new(db) ## Then - Queue length should be restored from database check store2.getQueueLength(QueueType.Critical) == 1 @@ -197,7 +196,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "large batch handling": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) var largeBatch: seq[tuple[msgId: string, msg: string]] for i in 1 .. 100: