diff --git a/.gitignore b/.gitignore index bd74343..9a9c03f 100644 --- a/.gitignore +++ b/.gitignore @@ -21,8 +21,10 @@ nimcache/ # Compiled files chat_sdk/* apps/* +!*.nim tests/* - +!*.nim +ratelimit/* !*.nim !*.proto nimble.develop diff --git a/chat_sdk.nimble b/chat_sdk.nimble index 199cdf5..b1165f3 100644 --- a/chat_sdk.nimble +++ b/chat_sdk.nimble @@ -7,7 +7,7 @@ license = "MIT" srcDir = "src" ### Dependencies -requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "waku" +requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "flatty" task buildSharedLib, "Build shared library for C bindings": exec "nim c --mm:refc --app:lib --out:../library/c-bindings/libchatsdk.so chat_sdk/chat_sdk.nim" diff --git a/chat_sdk/migration.nim b/chat_sdk/migration.nim index a71f466..2c91f21 100644 --- a/chat_sdk/migration.nim +++ b/chat_sdk/migration.nim @@ -3,15 +3,19 @@ import db_connector/db_sqlite import chronicles proc ensureMigrationTable(db: DbConn) = - db.exec(sql""" + db.exec( + sql""" CREATE TABLE IF NOT EXISTS schema_migrations ( filename TEXT PRIMARY KEY, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) - """) + """ + ) proc hasMigrationRun(db: DbConn, filename: string): bool = - for row in db.fastRows(sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename): + for row in db.fastRows( + sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename + ): return true return false diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql new file mode 100644 index 0000000..293c6ee --- /dev/null +++ b/migrations/001_create_ratelimit_state.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB +); + +CREATE TABLE IF NOT EXISTS ratelimit_queues ( + queue_type TEXT NOT NULL, + msg_id TEXT NOT NULL, + msg_data BLOB NOT NULL, + batch_id INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (queue_type, batch_id, msg_id) +); \ No newline at end of file diff --git a/ratelimit/rate_limit_manager.nim b/ratelimit/ratelimit_manager.nim similarity index 52% rename from ratelimit/rate_limit_manager.nim rename to ratelimit/ratelimit_manager.nim index c8f51f9..a30b5be 100644 --- a/ratelimit/rate_limit_manager.nim +++ b/ratelimit/ratelimit_manager.nim @@ -1,6 +1,10 @@ -import std/[times, deques, options] -import waku/common/rate_limit/token_bucket +import std/[times, options] +# TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live +import ./token_bucket +# import waku/common/rate_limit/token_bucket +import ./store import chronos +import db_connector/db_sqlite type CapacityState {.pure.} = enum @@ -19,41 +23,50 @@ type Normal Optional - Serializable* = - concept x - x.toBytes() is seq[byte] + MsgIdMsg[T] = tuple[msgId: string, msg: T] - MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T] + MessageSender*[T] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} - MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} - - RateLimitManager*[T: Serializable] = ref object + RateLimitManager*[T] = ref object + store: RateLimitStore[T] bucket: TokenBucket sender: MessageSender[T] - queueCritical: Deque[seq[MsgIdMsg[T]]] - queueNormal: Deque[seq[MsgIdMsg[T]]] sleepDuration: chronos.Duration pxQueueHandleLoop: Future[void] -proc new*[T: Serializable]( +proc new*[T]( M: type[RateLimitManager[T]], + store: RateLimitStore[T], sender: MessageSender[T], capacity: int = 100, duration: chronos.Duration = chronos.minutes(10), sleepDuration: chronos.Duration = chronos.milliseconds(1000), -): M = - M( - bucket: TokenBucket.newStrict(capacity, duration), +): Future[RateLimitManager[T]] {.async.} = + var current = await store.loadBucketState() + if current.isNone(): + # initialize bucket state with full capacity + current = some( + BucketState(budget: capacity, budgetCap: capacity, lastTimeFull: Moment.now()) + ) + discard await store.saveBucketState(current.get()) + + return RateLimitManager[T]( + store: store, + bucket: TokenBucket.new( + current.get().budgetCap, + duration, + ReplenishMode.Strict, + current.get().budget, + current.get().lastTimeFull, + ), sender: sender, - queueCritical: Deque[seq[MsgIdMsg[T]]](), - queueNormal: Deque[seq[MsgIdMsg[T]]](), sleepDuration: sleepDuration, ) -proc getCapacityState[T: Serializable]( +proc getCapacityState[T]( manager: RateLimitManager[T], now: Moment, count: int = 1 ): CapacityState = - let (budget, budgetCap) = manager.bucket.getAvailableCapacity(now) + let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now) let countAfter = budget - count let ratio = countAfter.float / budgetCap.float if ratio < 0.0: @@ -63,62 +76,77 @@ proc getCapacityState[T: Serializable]( else: return CapacityState.Normal -proc passToSender[T: Serializable]( +proc passToSender[T]( manager: RateLimitManager[T], - msgs: sink seq[MsgIdMsg[T]], + msgs: seq[tuple[msgId: string, msg: T]], now: Moment, priority: Priority, ): Future[SendResult] {.async.} = let count = msgs.len - let capacity = manager.bucket.tryConsume(count, now) - if not capacity: + let consumed = manager.bucket.tryConsume(count, now) + if not consumed: case priority of Priority.Critical: - manager.queueCritical.addLast(msgs) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped + + let (budget, budgetCap, lastTimeFull) = manager.bucket.getAvailableCapacity(now) + discard await manager.store.saveBucketState( + BucketState(budget: budget, budgetCap: budgetCap, lastTimeFull: lastTimeFull) + ) await manager.sender(msgs) return SendResult.PassedToSender -proc processCriticalQueue[T: Serializable]( +proc processCriticalQueue[T]( manager: RateLimitManager[T], now: Moment -) {.async.} = - while manager.queueCritical.len > 0: - let msgs = manager.queueCritical.popFirst() +): Future[void] {.async.} = + while manager.store.getQueueLength(QueueType.Critical) > 0: + # Peek at the next batch by getting it, but we'll handle putting it back if needed + let maybeMsgs = await manager.store.popFromQueue(QueueType.Critical) + if maybeMsgs.isNone(): + break + + let msgs = maybeMsgs.get() let capacityState = manager.getCapacityState(now, msgs.len) if capacityState == CapacityState.Normal: discard await manager.passToSender(msgs, now, Priority.Critical) elif capacityState == CapacityState.AlmostNone: discard await manager.passToSender(msgs, now, Priority.Critical) else: - # add back to critical queue - manager.queueCritical.addFirst(msgs) + # Put back to critical queue (add to front not possible, so we add to back and exit) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) break -proc processNormalQueue[T: Serializable]( +proc processNormalQueue[T]( manager: RateLimitManager[T], now: Moment -) {.async.} = - while manager.queueNormal.len > 0: - let msgs = manager.queueNormal.popFirst() +): Future[void] {.async.} = + while manager.store.getQueueLength(QueueType.Normal) > 0: + # Peek at the next batch by getting it, but we'll handle putting it back if needed + let maybeMsgs = await manager.store.popFromQueue(QueueType.Normal) + if maybeMsgs.isNone(): + break + + let msgs = maybeMsgs.get() let capacityState = manager.getCapacityState(now, msgs.len) if capacityState == CapacityState.Normal: discard await manager.passToSender(msgs, now, Priority.Normal) else: - # add back to critical queue - manager.queueNormal.addFirst(msgs) + # Put back to normal queue (add to front not possible, so we add to back and exit) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) break -proc sendOrEnqueue*[T: Serializable]( +proc sendOrEnqueue*[T]( manager: RateLimitManager[T], - msgs: seq[MsgIdMsg[T]], + msgs: seq[tuple[msgId: string, msg: T]], priority: Priority, now: Moment = Moment.now(), ): Future[SendResult] {.async.} = - let (_, budgetCap) = manager.bucket.getAvailableCapacity(now) + let (_, budgetCap, _) = manager.bucket.getAvailableCapacity(now) if msgs.len.float / budgetCap.float >= 0.3: # drop batch if it's too large to avoid starvation return SendResult.DroppedBatchTooLarge @@ -132,36 +160,22 @@ proc sendOrEnqueue*[T: Serializable]( of Priority.Critical: return await manager.passToSender(msgs, now, priority) of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped of CapacityState.None: case priority of Priority.Critical: - manager.queueCritical.addLast(msgs) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped -proc getEnqueued*[T: Serializable]( - manager: RateLimitManager[T] -): tuple[critical: seq[MsgIdMsg[T]], normal: seq[MsgIdMsg[T]]] = - var criticalMsgs: seq[MsgIdMsg[T]] - var normalMsgs: seq[MsgIdMsg[T]] - - for batch in manager.queueCritical: - criticalMsgs.add(batch) - - for batch in manager.queueNormal: - normalMsgs.add(batch) - - return (criticalMsgs, normalMsgs) - -proc queueHandleLoop[T: Serializable]( +proc queueHandleLoop*[T]( manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), @@ -177,20 +191,20 @@ proc queueHandleLoop[T: Serializable]( # configurable sleep duration for processing queued messages await sleepAsync(manager.sleepDuration) -proc start*[T: Serializable]( +proc start*[T]( manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = - manager.pxQueueHandleLoop = manager.queueHandleLoop(nowProvider) + manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider) -proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} = +proc stop*[T](manager: RateLimitManager[T]) {.async.} = if not isNil(manager.pxQueueHandleLoop): await manager.pxQueueHandleLoop.cancelAndWait() -func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = +func `$`*[T](b: RateLimitManager[T]): string {.inline.} = if isNil(b): return "nil" return - "RateLimitManager(critical: " & $b.queueCritical.len & ", normal: " & - $b.queueNormal.len & ")" + "RateLimitManager(critical: " & $b.store.getQueueLength(QueueType.Critical) & + ", normal: " & $b.store.getQueueLength(QueueType.Normal) & ")" diff --git a/ratelimit/store.nim b/ratelimit/store.nim new file mode 100644 index 0000000..42fd152 --- /dev/null +++ b/ratelimit/store.nim @@ -0,0 +1,200 @@ +import std/[times, strutils, json, options, base64] +import db_connector/db_sqlite +import chronos +import flatty + +type + RateLimitStore*[T] = ref object + db: DbConn + dbPath: string + criticalLength: int + normalLength: int + nextBatchId: int + + BucketState* {.pure} = object + budget*: int + budgetCap*: int + lastTimeFull*: Moment + + QueueType* {.pure.} = enum + Critical = "critical" + Normal = "normal" + +const BUCKET_STATE_KEY = "rate_limit_bucket_state" + +## TODO find a way to make these procs async + +proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async} = + result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) + + # Initialize cached lengths from database + let criticalCount = db.getValue( + sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?", + "critical", + ) + let normalCount = db.getValue( + sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?", + "normal", + ) + + result.criticalLength = + if criticalCount == "": + 0 + else: + parseInt(criticalCount) + result.normalLength = + if normalCount == "": + 0 + else: + parseInt(normalCount) + + # Get next batch ID + let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues") + result.nextBatchId = + if maxBatch == "": + 1 + else: + parseInt(maxBatch) + 1 + + return result + +proc saveBucketState*[T]( + store: RateLimitStore[T], bucketState: BucketState +): Future[bool] {.async.} = + try: + # Convert Moment to Unix seconds for storage + let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() + + let jsonState = + %*{ + "budget": bucketState.budget, + "budgetCap": bucketState.budgetCap, + "lastTimeFullSeconds": lastTimeSeconds, + } + store.db.exec( + sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", + BUCKET_STATE_KEY, + $jsonState, + ) + return true + except: + return false + +proc loadBucketState*[T]( + store: RateLimitStore[T] +): Future[Option[BucketState]] {.async.} = + let jsonStr = + store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + if jsonStr == "": + return none(BucketState) + + let jsonData = parseJson(jsonStr) + let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64 + let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) + + return some( + BucketState( + budget: jsonData["budget"].getInt(), + budgetCap: jsonData["budgetCap"].getInt(), + lastTimeFull: lastTimeFull, + ) + ) + +proc pushToQueue*[T]( + store: RateLimitStore[T], + queueType: QueueType, + msgs: seq[tuple[msgId: string, msg: T]], +): Future[bool] {.async.} = + try: + let batchId = store.nextBatchId + inc store.nextBatchId + let now = times.getTime().toUnix() + let queueTypeStr = $queueType + + if msgs.len > 0: + store.db.exec(sql"BEGIN TRANSACTION") + try: + for msg in msgs: + let serialized = msg.msg.toFlatty() + let msgData = encode(serialized) + store.db.exec( + sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)", + queueTypeStr, + msg.msgId, + msgData, + batchId, + now, + ) + store.db.exec(sql"COMMIT") + except: + store.db.exec(sql"ROLLBACK") + raise + + case queueType + of QueueType.Critical: + inc store.criticalLength + of QueueType.Normal: + inc store.normalLength + + return true + except: + return false + +proc popFromQueue*[T]( + store: RateLimitStore[T], queueType: QueueType +): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = + try: + let queueTypeStr = $queueType + + # Get the oldest batch ID for this queue type + let oldestBatchStr = store.db.getValue( + sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr + ) + + if oldestBatchStr == "": + return none(seq[tuple[msgId: string, msg: T]]) + + let batchId = parseInt(oldestBatchStr) + + # Get all messages in this batch (preserve insertion order using rowid) + let rows = store.db.getAllRows( + sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid", + queueTypeStr, + batchId, + ) + + if rows.len == 0: + return none(seq[tuple[msgId: string, msg: T]]) + + var msgs: seq[tuple[msgId: string, msg: T]] + for row in rows: + let msgIdStr = row[0] + let msgDataB64 = row[1] + + let serialized = decode(msgDataB64) + let msg = serialized.fromFlatty(T) + msgs.add((msgId: msgIdStr, msg: msg)) + + # Delete the batch from database + store.db.exec( + sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?", + queueTypeStr, + batchId, + ) + + case queueType + of QueueType.Critical: + dec store.criticalLength + of QueueType.Normal: + dec store.normalLength + + return some(msgs) + except: + return none(seq[tuple[msgId: string, msg: T]]) + +proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int = + case queueType + of QueueType.Critical: + return store.criticalLength + of QueueType.Normal: + return store.normalLength diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim new file mode 100644 index 0000000..13e40a1 --- /dev/null +++ b/ratelimit/token_bucket.nim @@ -0,0 +1,206 @@ +{.push raises: [].} + +import chronos, std/math, std/options + +const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25 + +## TODO! This will be remoded and replaced by https://github.com/status-im/nim-chronos/pull/582 + +## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation. +## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class. +## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation. +## ref: https://github.com/status-im/nim-chronos/issues/500 +## +## This version of TokenBucket is different from the original one in chronos/rate_limit.nim in many ways: +## - It has a new mode called `Compensating` which is the default mode. +## Compensation is calculated as the not used bucket capacity in the last measured period(s) in average. +## or up until maximum the allowed compansation treshold (Currently it is const 25%). +## Also compensation takes care of the proper time period calculation to avoid non-usage periods that can lead to +## overcompensation. +## - Strict mode is also available which will only replenish when time period is over but also will fill +## the bucket to the max capacity. + +type + ReplenishMode* = enum + Strict + Compensating + + TokenBucket* = ref object + budget: int ## Current number of tokens in the bucket + budgetCap: int ## Bucket capacity + lastTimeFull: Moment + ## This timer measures the proper periodizaiton of the bucket refilling + fillDuration: Duration ## Refill period + case replenishMode*: ReplenishMode + of Strict: + ## In strict mode, the bucket is refilled only till the budgetCap + discard + of Compensating: + ## This is the default mode. + maxCompensation: float + +func periodDistance(bucket: TokenBucket, currentTime: Moment): float = + ## notice fillDuration cannot be zero by design + ## period distance is a float number representing the calculated period time + ## since the last time bucket was refilled. + return + nanoseconds(currentTime - bucket.lastTimeFull).float / + nanoseconds(bucket.fillDuration).float + +func getUsageAverageSince(bucket: TokenBucket, distance: float): float = + if distance == 0.float: + ## in case there is zero time difference than the usage percentage is 100% + return 1.0 + + ## budgetCap can never be zero + ## usage average is calculated as a percentage of total capacity available over + ## the measured period + return bucket.budget.float / bucket.budgetCap.float / distance + +proc calcCompensation(bucket: TokenBucket, averageUsage: float): int = + # if we already fully used or even overused the tokens, there is no place for compensation + if averageUsage >= 1.0: + return 0 + + ## compensation is the not used bucket capacity in the last measured period(s) in average. + ## or maximum the allowed compansation treshold + let compensationPercent = + min((1.0 - averageUsage) * bucket.budgetCap.float, bucket.maxCompensation) + return trunc(compensationPercent).int + +func periodElapsed(bucket: TokenBucket, currentTime: Moment): bool = + return currentTime - bucket.lastTimeFull >= bucket.fillDuration + +## Update will take place if bucket is empty and trying to consume tokens. +## It checks if the bucket can be replenished as refill duration is passed or not. +## - strict mode: +proc updateStrict(bucket: TokenBucket, currentTime: Moment) = + if bucket.fillDuration == default(Duration): + bucket.budget = min(bucket.budgetCap, bucket.budget) + return + + if not periodElapsed(bucket, currentTime): + return + + bucket.budget = bucket.budgetCap + bucket.lastTimeFull = currentTime + +## - compensating - ballancing load: +## - between updates we calculate average load (current bucket capacity / number of periods till last update) +## - gives the percentage load used recently +## - with this we can replenish bucket up to 100% + calculated leftover from previous period (caped with max treshold) +proc updateWithCompensation(bucket: TokenBucket, currentTime: Moment) = + if bucket.fillDuration == default(Duration): + bucket.budget = min(bucket.budgetCap, bucket.budget) + return + + # do not replenish within the same period + if not periodElapsed(bucket, currentTime): + return + + let distance = bucket.periodDistance(currentTime) + let recentAvgUsage = bucket.getUsageAverageSince(distance) + let compensation = bucket.calcCompensation(recentAvgUsage) + + bucket.budget = bucket.budgetCap + compensation + bucket.lastTimeFull = currentTime + +proc update(bucket: TokenBucket, currentTime: Moment) = + if bucket.replenishMode == ReplenishMode.Compensating: + updateWithCompensation(bucket, currentTime) + else: + updateStrict(bucket, currentTime) + +proc getAvailableCapacity*( + bucket: TokenBucket, currentTime: Moment = Moment.now() +): tuple[budget: int, budgetCap: int, lastTimeFull: Moment] = + ## Returns the available capacity of the bucket: (budget, budgetCap, lastTimeFull) + + if periodElapsed(bucket, currentTime): + case bucket.replenishMode + of ReplenishMode.Strict: + return (bucket.budgetCap, bucket.budgetCap, bucket.lastTimeFull) + of ReplenishMode.Compensating: + let distance = bucket.periodDistance(currentTime) + let recentAvgUsage = bucket.getUsageAverageSince(distance) + let compensation = bucket.calcCompensation(recentAvgUsage) + let availableBudget = bucket.budgetCap + compensation + return (availableBudget, bucket.budgetCap, bucket.lastTimeFull) + return (bucket.budget, bucket.budgetCap, bucket.lastTimeFull) + +proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = + ## If `tokens` are available, consume them, + ## Otherwhise, return false. + + if bucket.budget >= bucket.budgetCap: + bucket.lastTimeFull = now + + if bucket.budget >= tokens: + bucket.budget -= tokens + return true + + bucket.update(now) + + if bucket.budget >= tokens: + bucket.budget -= tokens + return true + else: + return false + +proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) = + ## Add `tokens` to the budget (capped to the bucket capacity) + bucket.budget += tokens + bucket.update(now) + +proc new*( + T: type[TokenBucket], + budgetCap: int, + fillDuration: Duration = 1.seconds, + mode: ReplenishMode = ReplenishMode.Compensating, + budget: int = -1, # -1 means "use budgetCap" + lastTimeFull: Moment = Moment.now(), +): T = + assert not isZero(fillDuration) + assert budgetCap != 0 + assert lastTimeFull <= Moment.now() + let actualBudget = if budget == -1: budgetCap else: budget + assert actualBudget >= 0 and actualBudget <= budgetCap + + ## Create different mode TokenBucket + case mode + of ReplenishMode.Strict: + return T( + budget: actualBudget, + budgetCap: budgetCap, + fillDuration: fillDuration, + lastTimeFull: lastTimeFull, + replenishMode: mode, + ) + of ReplenishMode.Compensating: + T( + budget: actualBudget, + budgetCap: budgetCap, + fillDuration: fillDuration, + lastTimeFull: lastTimeFull, + replenishMode: mode, + maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT, + ) + +proc newStrict*(T: type[TokenBucket], capacity: int, period: Duration): TokenBucket = + T.new(capacity, period, ReplenishMode.Strict) + +proc newCompensating*( + T: type[TokenBucket], capacity: int, period: Duration +): TokenBucket = + T.new(capacity, period, ReplenishMode.Compensating) + +func `$`*(b: TokenBucket): string {.inline.} = + if isNil(b): + return "nil" + return $b.budgetCap & "/" & $b.fillDuration + +func `$`*(ob: Option[TokenBucket]): string {.inline.} = + if ob.isNone(): + return "no-limit" + + return $ob.get() diff --git a/tests/test_rate_limit_manager.nim b/tests/test_ratelimit_manager.nim similarity index 84% rename from tests/test_rate_limit_manager.nim rename to tests/test_ratelimit_manager.nim index 6c34c54..50a2000 100644 --- a/tests/test_rate_limit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -1,13 +1,18 @@ import testutils/unittests -import ../ratelimit/rate_limit_manager +import ../ratelimit/ratelimit_manager +import ../ratelimit/store import chronos +import db_connector/db_sqlite +import ../chat_sdk/migration +import std/[os, options] -# Implement the Serializable concept for string -proc toBytes*(s: string): seq[byte] = - cast[seq[byte]](s) +var dbName = "test_ratelimit_manager.db" suite "Queue RateLimitManager": setup: + let db = open(dbName, "", "", "") + runMigrations(db) + var sentMessages: seq[tuple[msgId: string, msg: string]] var senderCallCount: int = 0 @@ -20,10 +25,17 @@ suite "Queue RateLimitManager": sentMessages.add(msg) await sleepAsync(chronos.milliseconds(10)) + teardown: + if db != nil: + db.close() + if fileExists(dbName): + removeFile(dbName) + asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let manager = RateLimitManager[string].new( - mockSender, capacity = 10, duration = chronos.milliseconds(100) + let store: RateLimitStore[string] = await RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) let testMsg = "Hello World" @@ -40,8 +52,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let manager = RateLimitManager[string].new( - mockSender, capacity = 10, duration = chronos.milliseconds(100) + let store = await RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) ## When @@ -60,7 +73,9 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let manager = RateLimitManager[string].new( + let store = await RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( + store, mockSender, capacity = 2, duration = chronos.milliseconds(100), @@ -75,7 +90,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let manager = RateLimitManager[string].new( + let store = await RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -115,15 +132,11 @@ suite "Queue RateLimitManager": r10 == PassedToSender r11 == Enqueued - let (critical, normal) = manager.getEnqueued() - check: - critical.len == 1 - normal.len == 0 - critical[0].msgId == "msg11" - asyncTest "enqueue - enqueue normal on 70% capacity": - ## Given - let manager = RateLimitManager[string].new( + ## Given + let store = await RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -164,17 +177,11 @@ suite "Queue RateLimitManager": r11 == PassedToSender r12 == Dropped - let (critical, normal) = manager.getEnqueued() - check: - critical.len == 0 - normal.len == 3 - normal[0].msgId == "msg8" - normal[1].msgId == "msg9" - normal[2].msgId == "msg10" - asyncTest "enqueue - process queued messages": ## Given - let manager = RateLimitManager[string].new( + let store = await RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(200), @@ -225,24 +232,9 @@ suite "Queue RateLimitManager": r14 == PassedToSender r15 == Enqueued - var (critical, normal) = manager.getEnqueued() check: - critical.len == 1 - normal.len == 3 - normal[0].msgId == "8" - normal[1].msgId == "9" - normal[2].msgId == "10" - critical[0].msgId == "15" - - nowRef.value = now + chronos.milliseconds(250) - await sleepAsync(chronos.milliseconds(250)) - - (critical, normal) = manager.getEnqueued() - check: - critical.len == 0 - normal.len == 0 - senderCallCount == 14 - sentMessages.len == 14 + senderCallCount == 10 # 10 messages passed to sender + sentMessages.len == 10 sentMessages[0].msgId == "1" sentMessages[1].msgId == "2" sentMessages[2].msgId == "3" @@ -253,6 +245,13 @@ suite "Queue RateLimitManager": sentMessages[7].msgId == "11" sentMessages[8].msgId == "13" sentMessages[9].msgId == "14" + + nowRef.value = now + chronos.milliseconds(250) + await sleepAsync(chronos.milliseconds(250)) + + check: + senderCallCount == 14 + sentMessages.len == 14 sentMessages[10].msgId == "15" sentMessages[11].msgId == "8" sentMessages[12].msgId == "9" diff --git a/tests/test_store.nim b/tests/test_store.nim new file mode 100644 index 0000000..16cfdfe --- /dev/null +++ b/tests/test_store.nim @@ -0,0 +1,208 @@ +import testutils/unittests +import ../ratelimit/store +import chronos +import db_connector/db_sqlite +import ../chat_sdk/migration +import std/[options, os, json] +import flatty + +const dbName = "test_store.db" + +suite "SqliteRateLimitStore Tests": + setup: + let db = open(dbName, "", "", "") + runMigrations(db) + + teardown: + if db != nil: + db.close() + if fileExists(dbName): + removeFile(dbName) + + asyncTest "newSqliteRateLimitStore - empty state": + ## Given + let store = await RateLimitStore[string].new(db) + + ## When + let loadedState = await store.loadBucketState() + + ## Then + check loadedState.isNone() + + asyncTest "saveBucketState and loadBucketState - state persistence": + ## Given + let store = await RateLimitStore[string].new(db) + + let now = Moment.now() + echo "now: ", now.epochSeconds() + let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now) + + ## When + let saveResult = await store.saveBucketState(newBucketState) + let loadedState = await store.loadBucketState() + + ## Then + check saveResult == true + check loadedState.isSome() + check loadedState.get().budget == newBucketState.budget + check loadedState.get().budgetCap == newBucketState.budgetCap + check loadedState.get().lastTimeFull.epochSeconds() == + newBucketState.lastTimeFull.epochSeconds() + + asyncTest "queue operations - empty store": + ## Given + let store = await RateLimitStore[string].new(db) + + ## When/Then + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 0 + + let criticalPop = await store.popFromQueue(QueueType.Critical) + let normalPop = await store.popFromQueue(QueueType.Normal) + + check criticalPop.isNone() + check normalPop.isNone() + + asyncTest "addToQueue and popFromQueue - single batch": + ## Given + let store = await RateLimitStore[string].new(db) + let msgs = @[("msg1", "Hello"), ("msg2", "World")] + + ## When + let addResult = await store.pushToQueue(QueueType.Critical, msgs) + + ## Then + check addResult == true + check store.getQueueLength(QueueType.Critical) == 1 + check store.getQueueLength(QueueType.Normal) == 0 + + ## When + let popResult = await store.popFromQueue(QueueType.Critical) + + ## Then + check popResult.isSome() + let poppedMsgs = popResult.get() + check poppedMsgs.len == 2 + check poppedMsgs[0].msgId == "msg1" + check poppedMsgs[0].msg == "Hello" + check poppedMsgs[1].msgId == "msg2" + check poppedMsgs[1].msg == "World" + + check store.getQueueLength(QueueType.Critical) == 0 + + asyncTest "addToQueue and popFromQueue - multiple batches FIFO": + ## Given + let store = await RateLimitStore[string].new(db) + let batch1 = @[("msg1", "First")] + let batch2 = @[("msg2", "Second")] + let batch3 = @[("msg3", "Third")] + + ## When - Add batches + let result1 = await store.pushToQueue(QueueType.Normal, batch1) + check result1 == true + let result2 = await store.pushToQueue(QueueType.Normal, batch2) + check result2 == true + let result3 = await store.pushToQueue(QueueType.Normal, batch3) + check result3 == true + + ## Then - Check lengths + check store.getQueueLength(QueueType.Normal) == 3 + check store.getQueueLength(QueueType.Critical) == 0 + + ## When/Then - Pop in FIFO order + let pop1 = await store.popFromQueue(QueueType.Normal) + check pop1.isSome() + check pop1.get()[0].msg == "First" + check store.getQueueLength(QueueType.Normal) == 2 + + let pop2 = await store.popFromQueue(QueueType.Normal) + check pop2.isSome() + check pop2.get()[0].msg == "Second" + check store.getQueueLength(QueueType.Normal) == 1 + + let pop3 = await store.popFromQueue(QueueType.Normal) + check pop3.isSome() + check pop3.get()[0].msg == "Third" + check store.getQueueLength(QueueType.Normal) == 0 + + let pop4 = await store.popFromQueue(QueueType.Normal) + check pop4.isNone() + + asyncTest "queue isolation - critical and normal queues are separate": + ## Given + let store = await RateLimitStore[string].new(db) + let criticalMsgs = @[("crit1", "Critical Message")] + let normalMsgs = @[("norm1", "Normal Message")] + + ## When + let critResult = await store.pushToQueue(QueueType.Critical, criticalMsgs) + check critResult == true + let normResult = await store.pushToQueue(QueueType.Normal, normalMsgs) + check normResult == true + + ## Then + check store.getQueueLength(QueueType.Critical) == 1 + check store.getQueueLength(QueueType.Normal) == 1 + + ## When - Pop from critical + let criticalPop = await store.popFromQueue(QueueType.Critical) + check criticalPop.isSome() + check criticalPop.get()[0].msg == "Critical Message" + + ## Then - Normal queue unaffected + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 1 + + ## When - Pop from normal + let normalPop = await store.popFromQueue(QueueType.Normal) + check normalPop.isSome() + check normalPop.get()[0].msg == "Normal Message" + + ## Then - All queues empty + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 0 + + asyncTest "queue persistence across store instances": + ## Given + let msgs = @[("persist1", "Persistent Message")] + + block: + let store1 = await RateLimitStore[string].new(db) + let addResult = await store1.pushToQueue(QueueType.Critical, msgs) + check addResult == true + check store1.getQueueLength(QueueType.Critical) == 1 + + ## When - Create new store instance + block: + let store2 =await RateLimitStore[string].new(db) + + ## Then - Queue length should be restored from database + check store2.getQueueLength(QueueType.Critical) == 1 + + let popResult = await store2.popFromQueue(QueueType.Critical) + check popResult.isSome() + check popResult.get()[0].msg == "Persistent Message" + check store2.getQueueLength(QueueType.Critical) == 0 + + asyncTest "large batch handling": + ## Given + let store = await RateLimitStore[string].new(db) + var largeBatch: seq[tuple[msgId: string, msg: string]] + + for i in 1 .. 100: + largeBatch.add(("msg" & $i, "Message " & $i)) + + ## When + let addResult = await store.pushToQueue(QueueType.Normal, largeBatch) + + ## Then + check addResult == true + check store.getQueueLength(QueueType.Normal) == 1 + + let popResult = await store.popFromQueue(QueueType.Normal) + check popResult.isSome() + let poppedMsgs = popResult.get() + check poppedMsgs.len == 100 + check poppedMsgs[0].msgId == "msg1" + check poppedMsgs[99].msgId == "msg100" + check store.getQueueLength(QueueType.Normal) == 0