Merge pull request #5 from waku-org/feat/rate-limit-store-state

Feat/rate-limit-store-state
This commit is contained in:
Pablo Lopez 2025-09-01 05:58:52 +03:00 committed by GitHub
commit 2b1b7f5699
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 759 additions and 113 deletions

4
.gitignore vendored
View File

@ -21,8 +21,10 @@ nimcache/
# Compiled files
chat_sdk/*
apps/*
!*.nim
tests/*
!*.nim
ratelimit/*
!*.nim
!*.proto
nimble.develop

View File

@ -7,7 +7,7 @@ license = "MIT"
srcDir = "src"
### Dependencies
requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "waku"
requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "flatty"
task buildSharedLib, "Build shared library for C bindings":
exec "nim c --mm:refc --app:lib --out:../library/c-bindings/libchatsdk.so chat_sdk/chat_sdk.nim"

View File

@ -3,15 +3,19 @@ import db_connector/db_sqlite
import chronicles
proc ensureMigrationTable(db: DbConn) =
db.exec(sql"""
db.exec(
sql"""
CREATE TABLE IF NOT EXISTS schema_migrations (
filename TEXT PRIMARY KEY,
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
"""
)
proc hasMigrationRun(db: DbConn, filename: string): bool =
for row in db.fastRows(sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename):
for row in db.fastRows(
sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename
):
return true
return false

View File

@ -0,0 +1,13 @@
-- Generic key/value store; used by the rate limiter to persist the token
-- bucket state as JSON under a fixed key (see ratelimit/store.nim).
CREATE TABLE IF NOT EXISTS kv_store (
  key TEXT PRIMARY KEY,
  value BLOB
);
-- Pending rate-limited message batches: one row per message, grouped into
-- batches by batch_id. FIFO ordering is by ascending batch_id per queue_type.
CREATE TABLE IF NOT EXISTS ratelimit_queues (
  queue_type TEXT NOT NULL,   -- "critical" | "normal" (QueueType enum values)
  msg_id TEXT NOT NULL,
  msg_data BLOB NOT NULL,     -- base64-encoded, flatty-serialized payload
  batch_id INTEGER NOT NULL,  -- monotonically increasing batch counter
  created_at INTEGER NOT NULL, -- Unix seconds at insertion time
  PRIMARY KEY (queue_type, batch_id, msg_id)
);

View File

@ -1,6 +1,10 @@
import std/[times, deques, options]
import waku/common/rate_limit/token_bucket
import std/[times, options]
# TODO: move to waku's, chronos' or a lib token_bucket once decided where this will live
import ./token_bucket
# import waku/common/rate_limit/token_bucket
import ./store
import chronos
import db_connector/db_sqlite
type
CapacityState {.pure.} = enum
@ -19,41 +23,50 @@ type
Normal
Optional
Serializable* =
concept x
x.toBytes() is seq[byte]
MsgIdMsg[T] = tuple[msgId: string, msg: T]
MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T]
MessageSender*[T] = proc(msgs: seq[MsgIdMsg[T]]) {.async.}
MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.}
RateLimitManager*[T: Serializable] = ref object
RateLimitManager*[T] = ref object
store: RateLimitStore[T]
bucket: TokenBucket
sender: MessageSender[T]
queueCritical: Deque[seq[MsgIdMsg[T]]]
queueNormal: Deque[seq[MsgIdMsg[T]]]
sleepDuration: chronos.Duration
pxQueueHandleLoop: Future[void]
proc new*[T: Serializable](
proc new*[T](
M: type[RateLimitManager[T]],
store: RateLimitStore[T],
sender: MessageSender[T],
capacity: int = 100,
duration: chronos.Duration = chronos.minutes(10),
sleepDuration: chronos.Duration = chronos.milliseconds(1000),
): M =
M(
bucket: TokenBucket.newStrict(capacity, duration),
): Future[RateLimitManager[T]] {.async.} =
var current = await store.loadBucketState()
if current.isNone():
# initialize bucket state with full capacity
current = some(
BucketState(budget: capacity, budgetCap: capacity, lastTimeFull: Moment.now())
)
discard await store.saveBucketState(current.get())
return RateLimitManager[T](
store: store,
bucket: TokenBucket.new(
current.get().budgetCap,
duration,
ReplenishMode.Strict,
current.get().budget,
current.get().lastTimeFull,
),
sender: sender,
queueCritical: Deque[seq[MsgIdMsg[T]]](),
queueNormal: Deque[seq[MsgIdMsg[T]]](),
sleepDuration: sleepDuration,
)
proc getCapacityState[T: Serializable](
proc getCapacityState[T](
manager: RateLimitManager[T], now: Moment, count: int = 1
): CapacityState =
let (budget, budgetCap) = manager.bucket.getAvailableCapacity(now)
let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now)
let countAfter = budget - count
let ratio = countAfter.float / budgetCap.float
if ratio < 0.0:
@ -63,62 +76,77 @@ proc getCapacityState[T: Serializable](
else:
return CapacityState.Normal
proc passToSender[T: Serializable](
proc passToSender[T](
manager: RateLimitManager[T],
msgs: sink seq[MsgIdMsg[T]],
msgs: seq[tuple[msgId: string, msg: T]],
now: Moment,
priority: Priority,
): Future[SendResult] {.async.} =
let count = msgs.len
let capacity = manager.bucket.tryConsume(count, now)
if not capacity:
let consumed = manager.bucket.tryConsume(count, now)
if not consumed:
case priority
of Priority.Critical:
manager.queueCritical.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
return SendResult.Enqueued
of Priority.Normal:
manager.queueNormal.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
return SendResult.Enqueued
of Priority.Optional:
return SendResult.Dropped
let (budget, budgetCap, lastTimeFull) = manager.bucket.getAvailableCapacity(now)
discard await manager.store.saveBucketState(
BucketState(budget: budget, budgetCap: budgetCap, lastTimeFull: lastTimeFull)
)
await manager.sender(msgs)
return SendResult.PassedToSender
proc processCriticalQueue[T: Serializable](
proc processCriticalQueue[T](
manager: RateLimitManager[T], now: Moment
) {.async.} =
while manager.queueCritical.len > 0:
let msgs = manager.queueCritical.popFirst()
): Future[void] {.async.} =
while manager.store.getQueueLength(QueueType.Critical) > 0:
# Peek at the next batch by getting it, but we'll handle putting it back if needed
let maybeMsgs = await manager.store.popFromQueue(QueueType.Critical)
if maybeMsgs.isNone():
break
let msgs = maybeMsgs.get()
let capacityState = manager.getCapacityState(now, msgs.len)
if capacityState == CapacityState.Normal:
discard await manager.passToSender(msgs, now, Priority.Critical)
elif capacityState == CapacityState.AlmostNone:
discard await manager.passToSender(msgs, now, Priority.Critical)
else:
# add back to critical queue
manager.queueCritical.addFirst(msgs)
# Put back to critical queue (add to front not possible, so we add to back and exit)
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
break
proc processNormalQueue[T: Serializable](
proc processNormalQueue[T](
manager: RateLimitManager[T], now: Moment
) {.async.} =
while manager.queueNormal.len > 0:
let msgs = manager.queueNormal.popFirst()
): Future[void] {.async.} =
while manager.store.getQueueLength(QueueType.Normal) > 0:
# Peek at the next batch by getting it, but we'll handle putting it back if needed
let maybeMsgs = await manager.store.popFromQueue(QueueType.Normal)
if maybeMsgs.isNone():
break
let msgs = maybeMsgs.get()
let capacityState = manager.getCapacityState(now, msgs.len)
if capacityState == CapacityState.Normal:
discard await manager.passToSender(msgs, now, Priority.Normal)
else:
# add back to critical queue
manager.queueNormal.addFirst(msgs)
# Put back to normal queue (add to front not possible, so we add to back and exit)
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
break
proc sendOrEnqueue*[T: Serializable](
proc sendOrEnqueue*[T](
manager: RateLimitManager[T],
msgs: seq[MsgIdMsg[T]],
msgs: seq[tuple[msgId: string, msg: T]],
priority: Priority,
now: Moment = Moment.now(),
): Future[SendResult] {.async.} =
let (_, budgetCap) = manager.bucket.getAvailableCapacity(now)
let (_, budgetCap, _) = manager.bucket.getAvailableCapacity(now)
if msgs.len.float / budgetCap.float >= 0.3:
# drop batch if it's too large to avoid starvation
return SendResult.DroppedBatchTooLarge
@ -132,36 +160,22 @@ proc sendOrEnqueue*[T: Serializable](
of Priority.Critical:
return await manager.passToSender(msgs, now, priority)
of Priority.Normal:
manager.queueNormal.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
return SendResult.Enqueued
of Priority.Optional:
return SendResult.Dropped
of CapacityState.None:
case priority
of Priority.Critical:
manager.queueCritical.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
return SendResult.Enqueued
of Priority.Normal:
manager.queueNormal.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
return SendResult.Enqueued
of Priority.Optional:
return SendResult.Dropped
proc getEnqueued*[T: Serializable](
manager: RateLimitManager[T]
): tuple[critical: seq[MsgIdMsg[T]], normal: seq[MsgIdMsg[T]]] =
var criticalMsgs: seq[MsgIdMsg[T]]
var normalMsgs: seq[MsgIdMsg[T]]
for batch in manager.queueCritical:
criticalMsgs.add(batch)
for batch in manager.queueNormal:
normalMsgs.add(batch)
return (criticalMsgs, normalMsgs)
proc queueHandleLoop[T: Serializable](
proc queueHandleLoop*[T](
manager: RateLimitManager[T],
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
Moment.now(),
@ -177,20 +191,20 @@ proc queueHandleLoop[T: Serializable](
# configurable sleep duration for processing queued messages
await sleepAsync(manager.sleepDuration)
proc start*[T: Serializable](
proc start*[T](
manager: RateLimitManager[T],
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
Moment.now(),
) {.async.} =
manager.pxQueueHandleLoop = manager.queueHandleLoop(nowProvider)
manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider)
proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} =
proc stop*[T](manager: RateLimitManager[T]) {.async.} =
if not isNil(manager.pxQueueHandleLoop):
await manager.pxQueueHandleLoop.cancelAndWait()
func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} =
func `$`*[T](b: RateLimitManager[T]): string {.inline.} =
if isNil(b):
return "nil"
return
"RateLimitManager(critical: " & $b.queueCritical.len & ", normal: " &
$b.queueNormal.len & ")"
"RateLimitManager(critical: " & $b.store.getQueueLength(QueueType.Critical) &
", normal: " & $b.store.getQueueLength(QueueType.Normal) & ")"

200
ratelimit/store.nim Normal file
View File

@ -0,0 +1,200 @@
import std/[times, strutils, json, options, base64]
import db_connector/db_sqlite
import chronos
import flatty
type
  ## Persistent backing store for the rate limiter: bucket state lives in the
  ## `kv_store` table, queued message batches in `ratelimit_queues`.
  RateLimitStore*[T] = ref object
    db: DbConn # open SQLite connection (owned by the caller)
    dbPath: string # NOTE(review): never assigned in the visible code — confirm it is needed
    criticalLength: int # cached count of pending critical batches
    normalLength: int # cached count of pending normal batches
    nextBatchId: int # next batch id to assign; restored from MAX(batch_id)

  ## Snapshot of a TokenBucket, serializable to/from the kv_store table.
  BucketState* {.pure} = object
    budget*: int # tokens currently available
    budgetCap*: int # maximum tokens the bucket can hold
    lastTimeFull*: Moment # when the bucket was last full (second precision on disk)

  ## Queue identifier; the string values are stored in the queue_type column.
  QueueType* {.pure.} = enum
    Critical = "critical"
    Normal = "normal"

# kv_store key under which the serialized BucketState JSON is kept.
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
## TODO find a way to make these procs async
proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async} =
  ## Construct a store over `db`, priming the cached queue lengths and the
  ## next batch id from whatever rows are already persisted, so queue state
  ## survives process restarts.
  proc parseOr(value: string, fallback: int): int =
    # db_sqlite returns "" for NULL / no rows; fall back in that case.
    if value == "": fallback else: parseInt(value)

  let criticalCount = db.getValue(
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
    "critical",
  )
  let normalCount = db.getValue(
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
    "normal",
  )
  # Continue batch numbering strictly after the highest id ever stored.
  let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
  return M(
    db: db,
    criticalLength: parseOr(criticalCount, 0),
    normalLength: parseOr(normalCount, 0),
    nextBatchId: parseOr(maxBatch, 0) + 1,
  )
proc saveBucketState*[T](
    store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
  ## Persist `bucketState` as JSON under BUCKET_STATE_KEY in kv_store
  ## (upsert). Returns false instead of raising when the write fails.
  try:
    # Convert Moment to Unix seconds for storage
    let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
    let jsonState =
      %*{
        "budget": bucketState.budget,
        "budgetCap": bucketState.budgetCap,
        "lastTimeFullSeconds": lastTimeSeconds,
      }
    store.db.exec(
      sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
      BUCKET_STATE_KEY,
      $jsonState,
    )
    return true
  except CatchableError:
    # A bare `except:` would also swallow Defects (programmer errors); only
    # recoverable failures should translate into a `false` result.
    return false
proc loadBucketState*[T](
    store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
  ## Load the persisted bucket state from kv_store, or `none` when no state
  ## has been saved yet.
  ## NOTE(review): unlike `saveBucketState`, this proc catches nothing —
  ## corrupted JSON in the row will raise; confirm that is intended.
  let jsonStr =
    store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
  if jsonStr == "":
    # No row stored yet (getValue returns "" when the key is absent).
    return none(BucketState)
  let jsonData = parseJson(jsonStr)
  # Stored as Unix seconds by the save path; rebuild a Moment at second
  # precision. NOTE(review): confirm `Moment.init(seconds, seconds(1))`
  # round-trips with the `epochSeconds()` used in saveBucketState.
  let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64
  let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1))
  return some(
    BucketState(
      budget: jsonData["budget"].getInt(),
      budgetCap: jsonData["budgetCap"].getInt(),
      lastTimeFull: lastTimeFull,
    )
  )
proc pushToQueue*[T](
    store: RateLimitStore[T],
    queueType: QueueType,
    msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
  ## Append `msgs` to `queueType` as a single FIFO batch. All rows are
  ## inserted in one transaction; returns false when persistence fails.
  if msgs.len == 0:
    # Nothing to store: do not burn a batch id or grow the cached length.
    # The original incremented the counters unconditionally, so pushing an
    # empty seq made getQueueLength report a batch that does not exist.
    return true
  try:
    let batchId = store.nextBatchId
    inc store.nextBatchId
    let now = times.getTime().toUnix()
    let queueTypeStr = $queueType
    store.db.exec(sql"BEGIN TRANSACTION")
    try:
      for msg in msgs:
        # flatty-serialize, then base64 so the blob survives text handling.
        let serialized = msg.msg.toFlatty()
        let msgData = encode(serialized)
        store.db.exec(
          sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
          queueTypeStr,
          msg.msgId,
          msgData,
          batchId,
          now,
        )
      store.db.exec(sql"COMMIT")
    except CatchableError:
      store.db.exec(sql"ROLLBACK")
      raise
    # Only bump the cached length after a successful commit.
    case queueType
    of QueueType.Critical:
      inc store.criticalLength
    of QueueType.Normal:
      inc store.normalLength
    return true
  except CatchableError:
    # Do not swallow Defects; only recoverable failures map to `false`.
    return false
proc popFromQueue*[T](
    store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
  ## Remove and return the oldest batch (FIFO by batch_id) from `queueType`,
  ## or `none` when the queue is empty or the operation fails.
  try:
    let queueTypeStr = $queueType
    # Get the oldest batch ID for this queue type
    let oldestBatchStr = store.db.getValue(
      sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr
    )
    if oldestBatchStr == "":
      return none(seq[tuple[msgId: string, msg: T]])
    let batchId = parseInt(oldestBatchStr)
    # Get all messages in this batch (preserve insertion order using rowid)
    let rows = store.db.getAllRows(
      sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
      queueTypeStr,
      batchId,
    )
    if rows.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])
    var msgs: seq[tuple[msgId: string, msg: T]]
    for row in rows:
      let msgIdStr = row[0]
      let msgDataB64 = row[1]
      # base64-decode, then flatty-deserialize back into T.
      let serialized = decode(msgDataB64)
      let msg = serialized.fromFlatty(T)
      msgs.add((msgId: msgIdStr, msg: msg))
    # Delete the batch from database
    store.db.exec(
      sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
      queueTypeStr,
      batchId,
    )
    case queueType
    of QueueType.Critical:
      dec store.criticalLength
    of QueueType.Normal:
      dec store.normalLength
    return some(msgs)
  except CatchableError:
    # A bare `except:` would also hide Defects (e.g. index errors above);
    # only recoverable failures should map to `none`.
    return none(seq[tuple[msgId: string, msg: T]])
proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int =
  ## Cached number of pending batches (not individual messages) in the queue.
  result =
    case queueType
    of QueueType.Critical: store.criticalLength
    of QueueType.Normal: store.normalLength

206
ratelimit/token_bucket.nim Normal file
View File

@ -0,0 +1,206 @@
{.push raises: [].}
import chronos, std/math, std/options
const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25
## TODO! This will be removed and replaced by https://github.com/status-im/nim-chronos/pull/582
## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation.
## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class.
## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation.
## ref: https://github.com/status-im/nim-chronos/issues/500
##
## This version of TokenBucket is different from the original one in chronos/rate_limit.nim in many ways:
## - It has a new mode called `Compensating` which is the default mode.
## Compensation is calculated as the not used bucket capacity in the last measured period(s) in average.
## or up until maximum the allowed compensation threshold (currently a const of 25%).
## Also compensation takes care of the proper time period calculation to avoid non-usage periods that can lead to
## overcompensation.
## - Strict mode is also available which will only replenish when time period is over but also will fill
## the bucket to the max capacity.
type
  ## How the bucket refills once a period elapses.
  ReplenishMode* = enum
    Strict
    Compensating

  TokenBucket* = ref object
    budget: int ## Current number of tokens in the bucket
    budgetCap: int ## Bucket capacity
    lastTimeFull: Moment
      ## This timer measures the proper periodization of the bucket refilling
    fillDuration: Duration ## Refill period
    case replenishMode*: ReplenishMode
    of Strict:
      ## In strict mode, the bucket is refilled only till the budgetCap
      discard
    of Compensating:
      ## This is the default mode.
      # Upper bound (in tokens) for the extra budget granted to compensate
      # for unused capacity in previous periods.
      maxCompensation: float
func periodDistance(bucket: TokenBucket, currentTime: Moment): float =
  ## Elapsed time since the bucket was last full, expressed as a (possibly
  ## fractional) number of refill periods.
  ## `fillDuration` cannot be zero by design, so the division is safe.
  let elapsedNs = nanoseconds(currentTime - bucket.lastTimeFull).float
  let periodNs = nanoseconds(bucket.fillDuration).float
  elapsedNs / periodNs
func getUsageAverageSince(bucket: TokenBucket, distance: float): float =
  ## Average usage over `distance` periods, as a fraction of total capacity.
  ## A zero time difference counts as fully used (1.0).
  ## `budgetCap` can never be zero, so the division is safe.
  if distance != 0.float:
    result = bucket.budget.float / bucket.budgetCap.float / distance
  else:
    result = 1.0
proc calcCompensation(bucket: TokenBucket, averageUsage: float): int =
  ## Number of extra tokens to grant on top of the capacity: the unused
  ## capacity of the last measured period(s) on average, bounded by the
  ## bucket's `maxCompensation` threshold.
  if averageUsage >= 1.0:
    # tokens were fully used (or overused): nothing to compensate
    return 0
  # note: this is a token count, not a percentage, despite being derived
  # from the unused fraction of the capacity
  let unusedTokens = (1.0 - averageUsage) * bucket.budgetCap.float
  trunc(min(unusedTokens, bucket.maxCompensation)).int
func periodElapsed(bucket: TokenBucket, currentTime: Moment): bool =
  ## True once at least one full refill period has passed since the bucket
  ## was last full.
  currentTime - bucket.lastTimeFull >= bucket.fillDuration
## Update takes place when the bucket is empty while trying to consume tokens.
## It checks whether the bucket can be replenished, i.e. whether the refill duration has passed.
## - strict mode:
proc updateStrict(bucket: TokenBucket, currentTime: Moment) =
  ## Strict replenish: once a full period has elapsed, refill exactly to
  ## capacity and restart the period. A zero `fillDuration` only clamps the
  ## budget to the cap.
  if bucket.fillDuration == default(Duration):
    bucket.budget = min(bucket.budgetCap, bucket.budget)
  elif periodElapsed(bucket, currentTime):
    bucket.budget = bucket.budgetCap
    bucket.lastTimeFull = currentTime
## - compensating - balancing load:
## - between updates we calculate average load (current bucket capacity / number of periods till last update)
## - gives the percentage load used recently
## - with this we can replenish bucket up to 100% + calculated leftover from previous period (capped with max threshold)
proc updateWithCompensation(bucket: TokenBucket, currentTime: Moment) =
  ## Compensating replenish: refill to capacity plus the average unused
  ## capacity of the elapsed period(s), bounded by `maxCompensation`.
  if bucket.fillDuration == default(Duration):
    bucket.budget = min(bucket.budgetCap, bucket.budget)
    return
  if not periodElapsed(bucket, currentTime):
    # still inside the current period — do not replenish yet
    return
  let avgUsage = bucket.getUsageAverageSince(bucket.periodDistance(currentTime))
  bucket.budget = bucket.budgetCap + bucket.calcCompensation(avgUsage)
  bucket.lastTimeFull = currentTime
proc update(bucket: TokenBucket, currentTime: Moment) =
  ## Dispatch to the replenish strategy configured for this bucket.
  case bucket.replenishMode
  of ReplenishMode.Compensating:
    updateWithCompensation(bucket, currentTime)
  of ReplenishMode.Strict:
    updateStrict(bucket, currentTime)
proc getAvailableCapacity*(
    bucket: TokenBucket, currentTime: Moment = Moment.now()
): tuple[budget: int, budgetCap: int, lastTimeFull: Moment] =
  ## Report (budget, budgetCap, lastTimeFull) without mutating the bucket.
  ## When a refill period has already elapsed, the budget that WOULD be
  ## available after replenishing is reported instead of the stale value.
  if not periodElapsed(bucket, currentTime):
    return (bucket.budget, bucket.budgetCap, bucket.lastTimeFull)
  case bucket.replenishMode
  of ReplenishMode.Strict:
    (bucket.budgetCap, bucket.budgetCap, bucket.lastTimeFull)
  of ReplenishMode.Compensating:
    let avgUsage = bucket.getUsageAverageSince(bucket.periodDistance(currentTime))
    let projectedBudget = bucket.budgetCap + bucket.calcCompensation(avgUsage)
    (projectedBudget, bucket.budgetCap, bucket.lastTimeFull)
proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool =
  ## If `tokens` are available, consume them,
  ## Otherwise, return false.
  # A full (or over-full) bucket marks `now` as the start of a fresh period,
  # so the upcoming consumption is accounted against the current period.
  if bucket.budget >= bucket.budgetCap:
    bucket.lastTimeFull = now
  if bucket.budget >= tokens:
    bucket.budget -= tokens
    return true
  # Not enough tokens: attempt a replenish, then retry exactly once.
  bucket.update(now)
  if bucket.budget >= tokens:
    bucket.budget -= tokens
    return true
  else:
    return false
proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) =
  ## Add `tokens` to the budget (capped to the bucket capacity)
  # NOTE(review): `update` only clamps to the cap when `fillDuration` is zero
  # or a full period has elapsed — confirm the "capped" claim above holds for
  # mid-period calls, where the budget may temporarily exceed budgetCap.
  bucket.budget += tokens
  bucket.update(now)
proc new*(
    T: type[TokenBucket],
    budgetCap: int,
    fillDuration: Duration = 1.seconds,
    mode: ReplenishMode = ReplenishMode.Compensating,
    budget: int = -1, # -1 means "use budgetCap"
    lastTimeFull: Moment = Moment.now(),
): T =
  ## Create a TokenBucket.
  ##
  ## `budgetCap`    - maximum tokens the bucket can hold (must be non-zero)
  ## `fillDuration` - refill period (must be non-zero)
  ## `mode`         - replenish strategy (Strict or Compensating)
  ## `budget`       - initial token count; -1 (the default) starts full
  ## `lastTimeFull` - when the bucket was last full (must not be in the future)
  assert not isZero(fillDuration)
  assert budgetCap != 0
  assert lastTimeFull <= Moment.now()
  let actualBudget = if budget == -1: budgetCap else: budget
  assert actualBudget >= 0 and actualBudget <= budgetCap
  # Both branches are expressions, so the case itself is the proc result.
  # (The original inconsistently mixed `return T(...)` in the Strict branch
  # with a bare expression in the Compensating branch.)
  case mode
  of ReplenishMode.Strict:
    T(
      budget: actualBudget,
      budgetCap: budgetCap,
      fillDuration: fillDuration,
      lastTimeFull: lastTimeFull,
      replenishMode: mode,
    )
  of ReplenishMode.Compensating:
    T(
      budget: actualBudget,
      budgetCap: budgetCap,
      fillDuration: fillDuration,
      lastTimeFull: lastTimeFull,
      replenishMode: mode,
      # compensation is bounded at 25% of the capacity
      maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT,
    )
proc newStrict*(T: type[TokenBucket], capacity: int, period: Duration): TokenBucket =
  ## Convenience constructor for a strict-mode bucket.
  result = T.new(capacity, period, ReplenishMode.Strict)

proc newCompensating*(
    T: type[TokenBucket], capacity: int, period: Duration
): TokenBucket =
  ## Convenience constructor for a compensating-mode bucket.
  result = T.new(capacity, period, ReplenishMode.Compensating)
func `$`*(b: TokenBucket): string {.inline.} =
  ## Human-readable "cap/period" summary; "nil" for a nil reference.
  if isNil(b):
    "nil"
  else:
    $b.budgetCap & "/" & $b.fillDuration

func `$`*(ob: Option[TokenBucket]): string {.inline.} =
  ## "no-limit" when absent, otherwise the contained bucket's summary.
  if ob.isNone():
    "no-limit"
  else:
    $ob.get()

View File

@ -1,13 +1,18 @@
import testutils/unittests
import ../ratelimit/rate_limit_manager
import ../ratelimit/ratelimit_manager
import ../ratelimit/store
import chronos
import db_connector/db_sqlite
import ../chat_sdk/migration
import std/[os, options]
# Implement the Serializable concept for string
proc toBytes*(s: string): seq[byte] =
cast[seq[byte]](s)
var dbName = "test_ratelimit_manager.db"
suite "Queue RateLimitManager":
setup:
let db = open(dbName, "", "", "")
runMigrations(db)
var sentMessages: seq[tuple[msgId: string, msg: string]]
var senderCallCount: int = 0
@ -20,10 +25,17 @@ suite "Queue RateLimitManager":
sentMessages.add(msg)
await sleepAsync(chronos.milliseconds(10))
teardown:
if db != nil:
db.close()
if fileExists(dbName):
removeFile(dbName)
asyncTest "sendOrEnqueue - immediate send when capacity available":
## Given
let manager = RateLimitManager[string].new(
mockSender, capacity = 10, duration = chronos.milliseconds(100)
let store: RateLimitStore[string] = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
)
let testMsg = "Hello World"
@ -40,8 +52,9 @@ suite "Queue RateLimitManager":
asyncTest "sendOrEnqueue - multiple messages":
## Given
let manager = RateLimitManager[string].new(
mockSender, capacity = 10, duration = chronos.milliseconds(100)
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
)
## When
@ -60,7 +73,9 @@ suite "Queue RateLimitManager":
asyncTest "start and stop - drop large batch":
## Given
let manager = RateLimitManager[string].new(
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 2,
duration = chronos.milliseconds(100),
@ -75,7 +90,9 @@ suite "Queue RateLimitManager":
asyncTest "enqueue - enqueue critical only when exceeded":
## Given
let manager = RateLimitManager[string].new(
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
duration = chronos.milliseconds(100),
@ -115,15 +132,11 @@ suite "Queue RateLimitManager":
r10 == PassedToSender
r11 == Enqueued
let (critical, normal) = manager.getEnqueued()
check:
critical.len == 1
normal.len == 0
critical[0].msgId == "msg11"
asyncTest "enqueue - enqueue normal on 70% capacity":
## Given
let manager = RateLimitManager[string].new(
## Given
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
duration = chronos.milliseconds(100),
@ -164,17 +177,11 @@ suite "Queue RateLimitManager":
r11 == PassedToSender
r12 == Dropped
let (critical, normal) = manager.getEnqueued()
check:
critical.len == 0
normal.len == 3
normal[0].msgId == "msg8"
normal[1].msgId == "msg9"
normal[2].msgId == "msg10"
asyncTest "enqueue - process queued messages":
## Given
let manager = RateLimitManager[string].new(
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
duration = chronos.milliseconds(200),
@ -225,24 +232,9 @@ suite "Queue RateLimitManager":
r14 == PassedToSender
r15 == Enqueued
var (critical, normal) = manager.getEnqueued()
check:
critical.len == 1
normal.len == 3
normal[0].msgId == "8"
normal[1].msgId == "9"
normal[2].msgId == "10"
critical[0].msgId == "15"
nowRef.value = now + chronos.milliseconds(250)
await sleepAsync(chronos.milliseconds(250))
(critical, normal) = manager.getEnqueued()
check:
critical.len == 0
normal.len == 0
senderCallCount == 14
sentMessages.len == 14
senderCallCount == 10 # 10 messages passed to sender
sentMessages.len == 10
sentMessages[0].msgId == "1"
sentMessages[1].msgId == "2"
sentMessages[2].msgId == "3"
@ -253,6 +245,13 @@ suite "Queue RateLimitManager":
sentMessages[7].msgId == "11"
sentMessages[8].msgId == "13"
sentMessages[9].msgId == "14"
nowRef.value = now + chronos.milliseconds(250)
await sleepAsync(chronos.milliseconds(250))
check:
senderCallCount == 14
sentMessages.len == 14
sentMessages[10].msgId == "15"
sentMessages[11].msgId == "8"
sentMessages[12].msgId == "9"

208
tests/test_store.nim Normal file
View File

@ -0,0 +1,208 @@
import testutils/unittests
import ../ratelimit/store
import chronos
import db_connector/db_sqlite
import ../chat_sdk/migration
import std/[options, os, json]
import flatty
# Integration tests for RateLimitStore over a real on-disk SQLite database;
# each test gets a fresh database via setup/teardown.
const dbName = "test_store.db"

suite "SqliteRateLimitStore Tests":
  setup:
    let db = open(dbName, "", "", "")
    runMigrations(db)

  teardown:
    # Close the handle and remove the database file between tests.
    if db != nil:
      db.close()
    if fileExists(dbName):
      removeFile(dbName)

  asyncTest "newSqliteRateLimitStore - empty state":
    ## Given
    let store = await RateLimitStore[string].new(db)
    ## When
    let loadedState = await store.loadBucketState()
    ## Then
    check loadedState.isNone()

  asyncTest "saveBucketState and loadBucketState - state persistence":
    ## Given
    let store = await RateLimitStore[string].new(db)
    let now = Moment.now()
    echo "now: ", now.epochSeconds()
    let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now)
    ## When
    let saveResult = await store.saveBucketState(newBucketState)
    let loadedState = await store.loadBucketState()
    ## Then
    check saveResult == true
    check loadedState.isSome()
    check loadedState.get().budget == newBucketState.budget
    check loadedState.get().budgetCap == newBucketState.budgetCap
    # Moments are persisted at second precision, so compare epoch seconds.
    check loadedState.get().lastTimeFull.epochSeconds() ==
      newBucketState.lastTimeFull.epochSeconds()

  asyncTest "queue operations - empty store":
    ## Given
    let store = await RateLimitStore[string].new(db)
    ## When/Then
    check store.getQueueLength(QueueType.Critical) == 0
    check store.getQueueLength(QueueType.Normal) == 0
    let criticalPop = await store.popFromQueue(QueueType.Critical)
    let normalPop = await store.popFromQueue(QueueType.Normal)
    check criticalPop.isNone()
    check normalPop.isNone()

  asyncTest "addToQueue and popFromQueue - single batch":
    ## Given
    let store = await RateLimitStore[string].new(db)
    let msgs = @[("msg1", "Hello"), ("msg2", "World")]
    ## When
    let addResult = await store.pushToQueue(QueueType.Critical, msgs)
    ## Then
    check addResult == true
    check store.getQueueLength(QueueType.Critical) == 1
    check store.getQueueLength(QueueType.Normal) == 0
    ## When
    let popResult = await store.popFromQueue(QueueType.Critical)
    ## Then
    check popResult.isSome()
    let poppedMsgs = popResult.get()
    check poppedMsgs.len == 2
    check poppedMsgs[0].msgId == "msg1"
    check poppedMsgs[0].msg == "Hello"
    check poppedMsgs[1].msgId == "msg2"
    check poppedMsgs[1].msg == "World"
    check store.getQueueLength(QueueType.Critical) == 0

  asyncTest "addToQueue and popFromQueue - multiple batches FIFO":
    ## Given
    let store = await RateLimitStore[string].new(db)
    let batch1 = @[("msg1", "First")]
    let batch2 = @[("msg2", "Second")]
    let batch3 = @[("msg3", "Third")]
    ## When - Add batches
    let result1 = await store.pushToQueue(QueueType.Normal, batch1)
    check result1 == true
    let result2 = await store.pushToQueue(QueueType.Normal, batch2)
    check result2 == true
    let result3 = await store.pushToQueue(QueueType.Normal, batch3)
    check result3 == true
    ## Then - Check lengths
    check store.getQueueLength(QueueType.Normal) == 3
    check store.getQueueLength(QueueType.Critical) == 0
    ## When/Then - Pop in FIFO order
    let pop1 = await store.popFromQueue(QueueType.Normal)
    check pop1.isSome()
    check pop1.get()[0].msg == "First"
    check store.getQueueLength(QueueType.Normal) == 2
    let pop2 = await store.popFromQueue(QueueType.Normal)
    check pop2.isSome()
    check pop2.get()[0].msg == "Second"
    check store.getQueueLength(QueueType.Normal) == 1
    let pop3 = await store.popFromQueue(QueueType.Normal)
    check pop3.isSome()
    check pop3.get()[0].msg == "Third"
    check store.getQueueLength(QueueType.Normal) == 0
    let pop4 = await store.popFromQueue(QueueType.Normal)
    check pop4.isNone()

  asyncTest "queue isolation - critical and normal queues are separate":
    ## Given
    let store = await RateLimitStore[string].new(db)
    let criticalMsgs = @[("crit1", "Critical Message")]
    let normalMsgs = @[("norm1", "Normal Message")]
    ## When
    let critResult = await store.pushToQueue(QueueType.Critical, criticalMsgs)
    check critResult == true
    let normResult = await store.pushToQueue(QueueType.Normal, normalMsgs)
    check normResult == true
    ## Then
    check store.getQueueLength(QueueType.Critical) == 1
    check store.getQueueLength(QueueType.Normal) == 1
    ## When - Pop from critical
    let criticalPop = await store.popFromQueue(QueueType.Critical)
    check criticalPop.isSome()
    check criticalPop.get()[0].msg == "Critical Message"
    ## Then - Normal queue unaffected
    check store.getQueueLength(QueueType.Critical) == 0
    check store.getQueueLength(QueueType.Normal) == 1
    ## When - Pop from normal
    let normalPop = await store.popFromQueue(QueueType.Normal)
    check normalPop.isSome()
    check normalPop.get()[0].msg == "Normal Message"
    ## Then - All queues empty
    check store.getQueueLength(QueueType.Critical) == 0
    check store.getQueueLength(QueueType.Normal) == 0

  asyncTest "queue persistence across store instances":
    ## Given
    let msgs = @[("persist1", "Persistent Message")]
    block:
      let store1 = await RateLimitStore[string].new(db)
      let addResult = await store1.pushToQueue(QueueType.Critical, msgs)
      check addResult == true
      check store1.getQueueLength(QueueType.Critical) == 1
    ## When - Create new store instance
    block:
      let store2 = await RateLimitStore[string].new(db)
      ## Then - Queue length should be restored from database
      check store2.getQueueLength(QueueType.Critical) == 1
      let popResult = await store2.popFromQueue(QueueType.Critical)
      check popResult.isSome()
      check popResult.get()[0].msg == "Persistent Message"
      check store2.getQueueLength(QueueType.Critical) == 0

  asyncTest "large batch handling":
    ## Given
    let store = await RateLimitStore[string].new(db)
    var largeBatch: seq[tuple[msgId: string, msg: string]]
    for i in 1 .. 100:
      largeBatch.add(("msg" & $i, "Message " & $i))
    ## When
    let addResult = await store.pushToQueue(QueueType.Normal, largeBatch)
    ## Then
    check addResult == true
    check store.getQueueLength(QueueType.Normal) == 1
    let popResult = await store.popFromQueue(QueueType.Normal)
    check popResult.isSome()
    let poppedMsgs = popResult.get()
    check poppedMsgs.len == 100
    check poppedMsgs[0].msgId == "msg1"
    check poppedMsgs[99].msgId == "msg100"
    check store.getQueueLength(QueueType.Normal) == 0