Merge pull request #6 from waku-org/feat/ratelimit-store-queue

Feat/ratelimit store queue
Pablo Lopez 2025-09-01 05:50:10 +03:00 committed by GitHub
commit a567e0ab9f
11 changed files with 526 additions and 266 deletions

View File

@ -7,9 +7,7 @@ license = "MIT"
srcDir = "src"
### Dependencies
requires "nim >= 2.2.4",
"chronicles", "chronos", "db_connector",
"https://github.com/waku-org/token_bucket.git"
requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "flatty"
task buildSharedLib, "Build shared library for C bindings":
exec "nim c --mm:refc --app:lib --out:../library/c-bindings/libchatsdk.so chat_sdk/chat_sdk.nim"

View File

@ -1,17 +1,21 @@
import os, sequtils, algorithm
import os, sequtils, algorithm, strutils
import db_connector/db_sqlite
import chronicles
proc ensureMigrationTable(db: DbConn) =
db.exec(sql"""
db.exec(
sql"""
CREATE TABLE IF NOT EXISTS schema_migrations (
filename TEXT PRIMARY KEY,
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
"""
)
proc hasMigrationRun(db: DbConn, filename: string): bool =
for row in db.fastRows(sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename):
for row in db.fastRows(
sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename
):
return true
return false
@ -27,6 +31,16 @@ proc runMigrations*(db: DbConn, dir = "migrations") =
info "Migration already applied", file
else:
info "Applying migration", file
let sql = readFile(file)
db.exec(sql(sql))
markMigrationRun(db, file)
let sqlContent = readFile(file)
db.exec(sql"BEGIN TRANSACTION")
try:
# Split by semicolon and execute each statement separately
for stmt in sqlContent.split(';'):
let trimmedStmt = stmt.strip()
if trimmedStmt.len > 0:
db.exec(sql(trimmedStmt))
markMigrationRun(db, file)
db.exec(sql"COMMIT")
except:
db.exec(sql"ROLLBACK")
raise newException(ValueError, "Migration failed: " & file & " - " & getCurrentExceptionMsg())
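
A minimal usage sketch of the transactional runner above, assuming the module path used by the test files (`chat_sdk/migration`) and a placeholder database file name:

import db_connector/db_sqlite
import chat_sdk/migration   # path assumed from the test imports

let db = open("chat.db", "", "", "")   # placeholder SQLite file
try:
  # Each .sql file is split on ';' and applied inside BEGIN/COMMIT,
  # so a failing statement rolls the whole file back and raises.
  runMigrations(db, dir = "migrations")
finally:
  db.close()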

View File

@ -1,4 +1,13 @@
CREATE TABLE IF NOT EXISTS kv_store (
key TEXT PRIMARY KEY,
value BLOB
);
CREATE TABLE IF NOT EXISTS ratelimit_queues (
queue_type TEXT NOT NULL,
msg_id TEXT NOT NULL,
msg_data BLOB NOT NULL,
batch_id INTEGER NOT NULL,
created_at INTEGER NOT NULL,
PRIMARY KEY (queue_type, batch_id, msg_id)
);
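
The `batch_id` column is what gives the queues FIFO semantics: every queued batch of messages shares one `batch_id`, the oldest pending batch is the one with the smallest `batch_id`, and rows inside a batch keep their insertion order via SQLite's implicit rowid. An illustrative read-only sketch against this schema (assuming the migration has been applied; `chat.db` is a placeholder path):

import db_connector/db_sqlite

let db = open("chat.db", "", "", "")   # placeholder path
# Oldest pending batch for the critical queue, if any.
let oldest = db.getValue(
  sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", "critical"
)
if oldest != "":
  # Rows within the batch come back in insertion order via rowid.
  for row in db.fastRows(
    sql"SELECT msg_id FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
    "critical", oldest
  ):
    echo row[0]
db.close()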

View File

@ -1,9 +1,10 @@
import std/[times, deques, options]
import std/[times, options]
# TODO: move to waku's, chronos', or a standalone token_bucket lib once it is decided where this will live
import ./token_bucket
# import waku/common/rate_limit/token_bucket
import ./store/store
import ./store
import chronos
import db_connector/db_sqlite
type
CapacityState {.pure.} = enum
@ -22,31 +23,25 @@ type
Normal
Optional
Serializable* =
concept x
x.toBytes() is seq[byte]
MsgIdMsg[T] = tuple[msgId: string, msg: T]
MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T]
MessageSender*[T] = proc(msgs: seq[MsgIdMsg[T]]) {.async.}
MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.}
RateLimitManager*[T: Serializable, S: RateLimitStore] = ref object
store: S
RateLimitManager*[T] = ref object
store: RateLimitStore[T]
bucket: TokenBucket
sender: MessageSender[T]
queueCritical: Deque[seq[MsgIdMsg[T]]]
queueNormal: Deque[seq[MsgIdMsg[T]]]
sleepDuration: chronos.Duration
pxQueueHandleLoop: Future[void]
proc new*[T: Serializable, S: RateLimitStore](
M: type[RateLimitManager[T, S]],
store: S,
proc new*[T](
M: type[RateLimitManager[T]],
store: RateLimitStore[T],
sender: MessageSender[T],
capacity: int = 100,
duration: chronos.Duration = chronos.minutes(10),
sleepDuration: chronos.Duration = chronos.milliseconds(1000),
): Future[RateLimitManager[T, S]] {.async.} =
): Future[RateLimitManager[T]] {.async.} =
var current = await store.loadBucketState()
if current.isNone():
# initialize bucket state with full capacity
@ -55,7 +50,7 @@ proc new*[T: Serializable, S: RateLimitStore](
)
discard await store.saveBucketState(current.get())
return RateLimitManager[T, S](
return RateLimitManager[T](
store: store,
bucket: TokenBucket.new(
current.get().budgetCap,
@ -65,13 +60,11 @@ proc new*[T: Serializable, S: RateLimitStore](
current.get().lastTimeFull,
),
sender: sender,
queueCritical: Deque[seq[MsgIdMsg[T]]](),
queueNormal: Deque[seq[MsgIdMsg[T]]](),
sleepDuration: sleepDuration,
)
proc getCapacityState[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S], now: Moment, count: int = 1
proc getCapacityState[T](
manager: RateLimitManager[T], now: Moment, count: int = 1
): CapacityState =
let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now)
let countAfter = budget - count
@ -83,8 +76,8 @@ proc getCapacityState[T: Serializable, S: RateLimitStore](
else:
return CapacityState.Normal
proc passToSender[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S],
proc passToSender[T](
manager: RateLimitManager[T],
msgs: seq[tuple[msgId: string, msg: T]],
now: Moment,
priority: Priority,
@ -94,10 +87,10 @@ proc passToSender[T: Serializable, S: RateLimitStore](
if not consumed:
case priority
of Priority.Critical:
manager.queueCritical.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
return SendResult.Enqueued
of Priority.Normal:
manager.queueNormal.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
return SendResult.Enqueued
of Priority.Optional:
return SendResult.Dropped
@ -109,36 +102,46 @@ proc passToSender[T: Serializable, S: RateLimitStore](
await manager.sender(msgs)
return SendResult.PassedToSender
proc processCriticalQueue[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S], now: Moment
proc processCriticalQueue[T](
manager: RateLimitManager[T], now: Moment
): Future[void] {.async.} =
while manager.queueCritical.len > 0:
let msgs = manager.queueCritical.popFirst()
while manager.store.getQueueLength(QueueType.Critical) > 0:
# Pop the next batch; if there is not enough capacity it is pushed back below
let maybeMsgs = await manager.store.popFromQueue(QueueType.Critical)
if maybeMsgs.isNone():
break
let msgs = maybeMsgs.get()
let capacityState = manager.getCapacityState(now, msgs.len)
if capacityState == CapacityState.Normal:
discard await manager.passToSender(msgs, now, Priority.Critical)
elif capacityState == CapacityState.AlmostNone:
discard await manager.passToSender(msgs, now, Priority.Critical)
else:
# add back to critical queue
manager.queueCritical.addFirst(msgs)
# Put the batch back on the critical queue (the store only appends, so it re-enters at the back) and stop processing
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
break
proc processNormalQueue[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S], now: Moment
proc processNormalQueue[T](
manager: RateLimitManager[T], now: Moment
): Future[void] {.async.} =
while manager.queueNormal.len > 0:
let msgs = manager.queueNormal.popFirst()
while manager.store.getQueueLength(QueueType.Normal) > 0:
# Pop the next batch; if there is not enough capacity it is pushed back below
let maybeMsgs = await manager.store.popFromQueue(QueueType.Normal)
if maybeMsgs.isNone():
break
let msgs = maybeMsgs.get()
let capacityState = manager.getCapacityState(now, msgs.len)
if capacityState == CapacityState.Normal:
discard await manager.passToSender(msgs, now, Priority.Normal)
else:
# add back to critical queue
manager.queueNormal.addFirst(msgs)
# Put the batch back on the normal queue (the store only appends, so it re-enters at the back) and stop processing
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
break
proc sendOrEnqueue*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S],
proc sendOrEnqueue*[T](
manager: RateLimitManager[T],
msgs: seq[tuple[msgId: string, msg: T]],
priority: Priority,
now: Moment = Moment.now(),
@ -157,39 +160,23 @@ proc sendOrEnqueue*[T: Serializable, S: RateLimitStore](
of Priority.Critical:
return await manager.passToSender(msgs, now, priority)
of Priority.Normal:
manager.queueNormal.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
return SendResult.Enqueued
of Priority.Optional:
return SendResult.Dropped
of CapacityState.None:
case priority
of Priority.Critical:
manager.queueCritical.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
return SendResult.Enqueued
of Priority.Normal:
manager.queueNormal.addLast(msgs)
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
return SendResult.Enqueued
of Priority.Optional:
return SendResult.Dropped
proc getEnqueued*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S]
): tuple[
critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]]
] =
var criticalMsgs: seq[tuple[msgId: string, msg: T]]
var normalMsgs: seq[tuple[msgId: string, msg: T]]
for batch in manager.queueCritical:
criticalMsgs.add(batch)
for batch in manager.queueNormal:
normalMsgs.add(batch)
return (criticalMsgs, normalMsgs)
proc queueHandleLoop*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S],
proc queueHandleLoop*[T](
manager: RateLimitManager[T],
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
Moment.now(),
) {.async.} =
@ -204,24 +191,20 @@ proc queueHandleLoop*[T: Serializable, S: RateLimitStore](
# configurable sleep duration for processing queued messages
await sleepAsync(manager.sleepDuration)
proc start*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S],
proc start*[T](
manager: RateLimitManager[T],
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
Moment.now(),
) {.async.} =
manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider)
proc stop*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S]
) {.async.} =
proc stop*[T](manager: RateLimitManager[T]) {.async.} =
if not isNil(manager.pxQueueHandleLoop):
await manager.pxQueueHandleLoop.cancelAndWait()
func `$`*[T: Serializable, S: RateLimitStore](
b: RateLimitManager[T, S]
): string {.inline.} =
func `$`*[T](b: RateLimitManager[T]): string {.inline.} =
if isNil(b):
return "nil"
return
"RateLimitManager(critical: " & $b.queueCritical.len & ", normal: " &
$b.queueNormal.len & ")"
"RateLimitManager(critical: " & $b.store.getQueueLength(QueueType.Critical) &
", normal: " & $b.store.getQueueLength(QueueType.Normal) & ")"

ratelimit/store.nim (new file, 198 lines)
View File

@ -0,0 +1,198 @@
import std/[times, strutils, json, options, base64]
import db_connector/db_sqlite
import chronos
import flatty
type
RateLimitStore*[T] = ref object
db: DbConn
dbPath: string
criticalLength: int
normalLength: int
nextBatchId: int
BucketState* {.pure.} = object
budget*: int
budgetCap*: int
lastTimeFull*: Moment
QueueType* {.pure.} = enum
Critical = "critical"
Normal = "normal"
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async.} =
result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
# Initialize cached lengths from database
let criticalCount = db.getValue(
sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
"critical",
)
let normalCount = db.getValue(
sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
"normal",
)
result.criticalLength =
if criticalCount == "":
0
else:
parseInt(criticalCount)
result.normalLength =
if normalCount == "":
0
else:
parseInt(normalCount)
# Get next batch ID
let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
result.nextBatchId =
if maxBatch == "":
1
else:
parseInt(maxBatch) + 1
return result
proc saveBucketState*[T](
store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
try:
# Convert Moment to Unix seconds for storage
let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
let jsonState =
%*{
"budget": bucketState.budget,
"budgetCap": bucketState.budgetCap,
"lastTimeFullSeconds": lastTimeSeconds,
}
store.db.exec(
sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
BUCKET_STATE_KEY,
$jsonState,
)
return true
except:
return false
proc loadBucketState*[T](
store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
let jsonStr =
store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
if jsonStr == "":
return none(BucketState)
let jsonData = parseJson(jsonStr)
let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64
let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1))
return some(
BucketState(
budget: jsonData["budget"].getInt(),
budgetCap: jsonData["budgetCap"].getInt(),
lastTimeFull: lastTimeFull,
)
)
proc pushToQueue*[T](
store: RateLimitStore[T],
queueType: QueueType,
msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
try:
let batchId = store.nextBatchId
inc store.nextBatchId
let now = times.getTime().toUnix()
let queueTypeStr = $queueType
if msgs.len > 0:
store.db.exec(sql"BEGIN TRANSACTION")
try:
for msg in msgs:
let serialized = msg.msg.toFlatty()
let msgData = encode(serialized)
store.db.exec(
sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
queueTypeStr,
msg.msgId,
msgData,
batchId,
now,
)
store.db.exec(sql"COMMIT")
except:
store.db.exec(sql"ROLLBACK")
raise
case queueType
of QueueType.Critical:
inc store.criticalLength
of QueueType.Normal:
inc store.normalLength
return true
except:
return false
proc popFromQueue*[T](
store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
try:
let queueTypeStr = $queueType
# Get the oldest batch ID for this queue type
let oldestBatchStr = store.db.getValue(
sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr
)
if oldestBatchStr == "":
return none(seq[tuple[msgId: string, msg: T]])
let batchId = parseInt(oldestBatchStr)
# Get all messages in this batch (preserve insertion order using rowid)
let rows = store.db.getAllRows(
sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
queueTypeStr,
batchId,
)
if rows.len == 0:
return none(seq[tuple[msgId: string, msg: T]])
var msgs: seq[tuple[msgId: string, msg: T]]
for row in rows:
let msgIdStr = row[0]
let msgDataB64 = row[1]
let serialized = decode(msgDataB64)
let msg = serialized.fromFlatty(T)
msgs.add((msgId: msgIdStr, msg: msg))
# Delete the batch from database
store.db.exec(
sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
queueTypeStr,
batchId,
)
case queueType
of QueueType.Critical:
dec store.criticalLength
of QueueType.Normal:
dec store.normalLength
return some(msgs)
except:
return none(seq[tuple[msgId: string, msg: T]])
proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int =
case queueType
of QueueType.Critical:
return store.criticalLength
of QueueType.Normal:
return store.normalLength
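
Because the store serializes payloads with flatty, `T` is not limited to `string`. A sketch with a made-up object payload (the `ChatMessage` type, its fields, and the import paths are illustrative assumptions):

import std/options
import chronos
import db_connector/db_sqlite
import chat_sdk/migration   # paths assumed from the test imports
import ratelimit/store

type ChatMessage = object   # illustrative payload, not part of the SDK
  author: string
  body: string

proc demo() {.async.} =
  let db = open("chat.db", "", "", "")   # placeholder path
  runMigrations(db)
  let store = await RateLimitStore[ChatMessage].new(db)
  let batch = @[(msgId: "m1", msg: ChatMessage(author: "alice", body: "hi"))]
  # flatty serializes plain object types directly, so no custom hook is needed.
  discard await store.pushToQueue(QueueType.Critical, batch)
  echo store.getQueueLength(QueueType.Critical)   # 1: lengths count batches, not rows
  let popped = await store.popFromQueue(QueueType.Critical)
  if popped.isSome():
    echo popped.get()[0].msg.body                 # "hi"
  db.close()

waitFor demo()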

View File

@ -1,21 +0,0 @@
import std/[times, options]
import ./store
import chronos
# Memory Implementation
type MemoryRateLimitStore* = ref object
bucketState: Option[BucketState]
proc new*(T: type[MemoryRateLimitStore]): T =
return T(bucketState: none(BucketState))
proc saveBucketState*(
store: MemoryRateLimitStore, bucketState: BucketState
): Future[bool] {.async.} =
store.bucketState = some(bucketState)
return true
proc loadBucketState*(
store: MemoryRateLimitStore
): Future[Option[BucketState]] {.async.} =
return store.bucketState

View File

@ -1,56 +0,0 @@
import std/[times, strutils, json, options]
import ./store
import chronos
import db_connector/db_sqlite
# SQLite Implementation
type SqliteRateLimitStore* = ref object
db: DbConn
dbPath: string
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
proc newSqliteRateLimitStore*(db: DbConn): SqliteRateLimitStore =
result = SqliteRateLimitStore(db: db)
proc saveBucketState*(
store: SqliteRateLimitStore, bucketState: BucketState
): Future[bool] {.async.} =
try:
# Convert Moment to Unix seconds for storage
let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
let jsonState =
%*{
"budget": bucketState.budget,
"budgetCap": bucketState.budgetCap,
"lastTimeFullSeconds": lastTimeSeconds,
}
store.db.exec(
sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
BUCKET_STATE_KEY,
$jsonState,
)
return true
except:
return false
proc loadBucketState*(
store: SqliteRateLimitStore
): Future[Option[BucketState]] {.async.} =
let jsonStr =
store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
if jsonStr == "":
return none(BucketState)
let jsonData = parseJson(jsonStr)
let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64
let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1))
return some(
BucketState(
budget: jsonData["budget"].getInt(),
budgetCap: jsonData["budgetCap"].getInt(),
lastTimeFull: lastTimeFull,
)
)

View File

@ -1,13 +0,0 @@
import std/[times, deques, options]
import chronos
type
BucketState* = object
budget*: int
budgetCap*: int
lastTimeFull*: Moment
RateLimitStore* =
concept s
s.saveBucketState(BucketState) is Future[bool]
s.loadBucketState() is Future[Option[BucketState]]

View File

@ -1,14 +1,18 @@
import testutils/unittests
import ../ratelimit/ratelimit_manager
import ../ratelimit/store/memory
import ../ratelimit/store
import chronos
import db_connector/db_sqlite
import ../chat_sdk/migration
import std/[os, options]
# Implement the Serializable concept for string
proc toBytes*(s: string): seq[byte] =
cast[seq[byte]](s)
var dbName = "test_ratelimit_manager.db"
suite "Queue RateLimitManager":
setup:
let db = open(dbName, "", "", "")
runMigrations(db)
var sentMessages: seq[tuple[msgId: string, msg: string]]
var senderCallCount: int = 0
@ -21,10 +25,16 @@ suite "Queue RateLimitManager":
sentMessages.add(msg)
await sleepAsync(chronos.milliseconds(10))
teardown:
if db != nil:
db.close()
if fileExists(dbName):
removeFile(dbName)
asyncTest "sendOrEnqueue - immediate send when capacity available":
## Given
let store: MemoryRateLimitStore = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let store: RateLimitStore[string] = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
)
let testMsg = "Hello World"
@ -42,8 +52,8 @@ suite "Queue RateLimitManager":
asyncTest "sendOrEnqueue - multiple messages":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
)
@ -63,8 +73,8 @@ suite "Queue RateLimitManager":
asyncTest "start and stop - drop large batch":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 2,
@ -80,8 +90,8 @@ suite "Queue RateLimitManager":
asyncTest "enqueue - enqueue critical only when exceeded":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
@ -122,16 +132,10 @@ suite "Queue RateLimitManager":
r10 == PassedToSender
r11 == Enqueued
let (critical, normal) = manager.getEnqueued()
check:
critical.len == 1
normal.len == 0
critical[0].msgId == "msg11"
asyncTest "enqueue - enqueue normal on 70% capacity":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
## Given
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
@ -173,18 +177,10 @@ suite "Queue RateLimitManager":
r11 == PassedToSender
r12 == Dropped
let (critical, normal) = manager.getEnqueued()
check:
critical.len == 0
normal.len == 3
normal[0].msgId == "msg8"
normal[1].msgId == "msg9"
normal[2].msgId == "msg10"
asyncTest "enqueue - process queued messages":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let store = await RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
@ -236,24 +232,9 @@ suite "Queue RateLimitManager":
r14 == PassedToSender
r15 == Enqueued
var (critical, normal) = manager.getEnqueued()
check:
critical.len == 1
normal.len == 3
normal[0].msgId == "8"
normal[1].msgId == "9"
normal[2].msgId == "10"
critical[0].msgId == "15"
nowRef.value = now + chronos.milliseconds(250)
await sleepAsync(chronos.milliseconds(250))
(critical, normal) = manager.getEnqueued()
check:
critical.len == 0
normal.len == 0
senderCallCount == 14
sentMessages.len == 14
senderCallCount == 10 # 10 messages passed to sender
sentMessages.len == 10
sentMessages[0].msgId == "1"
sentMessages[1].msgId == "2"
sentMessages[2].msgId == "3"
@ -264,6 +245,13 @@ suite "Queue RateLimitManager":
sentMessages[7].msgId == "11"
sentMessages[8].msgId == "13"
sentMessages[9].msgId == "14"
nowRef.value = now + chronos.milliseconds(250)
await sleepAsync(chronos.milliseconds(250))
check:
senderCallCount == 14
sentMessages.len == 14
sentMessages[10].msgId == "15"
sentMessages[11].msgId == "8"
sentMessages[12].msgId == "9"

View File

@ -1,48 +0,0 @@
import testutils/unittests
import ../ratelimit/store/sqlite
import ../ratelimit/store/store
import chronos
import db_connector/db_sqlite
import ../chat_sdk/migration
import std/[options, os]
suite "SqliteRateLimitStore Tests":
setup:
let db = open("test-ratelimit.db", "", "", "")
runMigrations(db)
teardown:
if db != nil:
db.close()
if fileExists("test-ratelimit.db"):
removeFile("test-ratelimit.db")
asyncTest "newSqliteRateLimitStore - empty state":
## Given
let store = newSqliteRateLimitStore(db)
## When
let loadedState = await store.loadBucketState()
## Then
check loadedState.isNone()
asyncTest "saveBucketState and loadBucketState - state persistence":
## Given
let store = newSqliteRateLimitStore(db)
let now = Moment.now()
echo "now: ", now.epochSeconds()
let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now)
## When
let saveResult = await store.saveBucketState(newBucketState)
let loadedState = await store.loadBucketState()
## Then
check saveResult == true
check loadedState.isSome()
check loadedState.get().budget == newBucketState.budget
check loadedState.get().budgetCap == newBucketState.budgetCap
check loadedState.get().lastTimeFull.epochSeconds() ==
newBucketState.lastTimeFull.epochSeconds()

tests/test_store.nim (new file, 208 lines)
View File

@ -0,0 +1,208 @@
import testutils/unittests
import ../ratelimit/store
import chronos
import db_connector/db_sqlite
import ../chat_sdk/migration
import std/[options, os, json]
import flatty
const dbName = "test_store.db"
suite "SqliteRateLimitStore Tests":
setup:
let db = open(dbName, "", "", "")
runMigrations(db)
teardown:
if db != nil:
db.close()
if fileExists(dbName):
removeFile(dbName)
asyncTest "newSqliteRateLimitStore - empty state":
## Given
let store = await RateLimitStore[string].new(db)
## When
let loadedState = await store.loadBucketState()
## Then
check loadedState.isNone()
asyncTest "saveBucketState and loadBucketState - state persistence":
## Given
let store = await RateLimitStore[string].new(db)
let now = Moment.now()
echo "now: ", now.epochSeconds()
let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now)
## When
let saveResult = await store.saveBucketState(newBucketState)
let loadedState = await store.loadBucketState()
## Then
check saveResult == true
check loadedState.isSome()
check loadedState.get().budget == newBucketState.budget
check loadedState.get().budgetCap == newBucketState.budgetCap
check loadedState.get().lastTimeFull.epochSeconds() ==
newBucketState.lastTimeFull.epochSeconds()
asyncTest "queue operations - empty store":
## Given
let store = await RateLimitStore[string].new(db)
## When/Then
check store.getQueueLength(QueueType.Critical) == 0
check store.getQueueLength(QueueType.Normal) == 0
let criticalPop = await store.popFromQueue(QueueType.Critical)
let normalPop = await store.popFromQueue(QueueType.Normal)
check criticalPop.isNone()
check normalPop.isNone()
asyncTest "addToQueue and popFromQueue - single batch":
## Given
let store = await RateLimitStore[string].new(db)
let msgs = @[("msg1", "Hello"), ("msg2", "World")]
## When
let addResult = await store.pushToQueue(QueueType.Critical, msgs)
## Then
check addResult == true
check store.getQueueLength(QueueType.Critical) == 1
check store.getQueueLength(QueueType.Normal) == 0
## When
let popResult = await store.popFromQueue(QueueType.Critical)
## Then
check popResult.isSome()
let poppedMsgs = popResult.get()
check poppedMsgs.len == 2
check poppedMsgs[0].msgId == "msg1"
check poppedMsgs[0].msg == "Hello"
check poppedMsgs[1].msgId == "msg2"
check poppedMsgs[1].msg == "World"
check store.getQueueLength(QueueType.Critical) == 0
asyncTest "addToQueue and popFromQueue - multiple batches FIFO":
## Given
let store = await RateLimitStore[string].new(db)
let batch1 = @[("msg1", "First")]
let batch2 = @[("msg2", "Second")]
let batch3 = @[("msg3", "Third")]
## When - Add batches
let result1 = await store.pushToQueue(QueueType.Normal, batch1)
check result1 == true
let result2 = await store.pushToQueue(QueueType.Normal, batch2)
check result2 == true
let result3 = await store.pushToQueue(QueueType.Normal, batch3)
check result3 == true
## Then - Check lengths
check store.getQueueLength(QueueType.Normal) == 3
check store.getQueueLength(QueueType.Critical) == 0
## When/Then - Pop in FIFO order
let pop1 = await store.popFromQueue(QueueType.Normal)
check pop1.isSome()
check pop1.get()[0].msg == "First"
check store.getQueueLength(QueueType.Normal) == 2
let pop2 = await store.popFromQueue(QueueType.Normal)
check pop2.isSome()
check pop2.get()[0].msg == "Second"
check store.getQueueLength(QueueType.Normal) == 1
let pop3 = await store.popFromQueue(QueueType.Normal)
check pop3.isSome()
check pop3.get()[0].msg == "Third"
check store.getQueueLength(QueueType.Normal) == 0
let pop4 = await store.popFromQueue(QueueType.Normal)
check pop4.isNone()
asyncTest "queue isolation - critical and normal queues are separate":
## Given
let store = await RateLimitStore[string].new(db)
let criticalMsgs = @[("crit1", "Critical Message")]
let normalMsgs = @[("norm1", "Normal Message")]
## When
let critResult = await store.pushToQueue(QueueType.Critical, criticalMsgs)
check critResult == true
let normResult = await store.pushToQueue(QueueType.Normal, normalMsgs)
check normResult == true
## Then
check store.getQueueLength(QueueType.Critical) == 1
check store.getQueueLength(QueueType.Normal) == 1
## When - Pop from critical
let criticalPop = await store.popFromQueue(QueueType.Critical)
check criticalPop.isSome()
check criticalPop.get()[0].msg == "Critical Message"
## Then - Normal queue unaffected
check store.getQueueLength(QueueType.Critical) == 0
check store.getQueueLength(QueueType.Normal) == 1
## When - Pop from normal
let normalPop = await store.popFromQueue(QueueType.Normal)
check normalPop.isSome()
check normalPop.get()[0].msg == "Normal Message"
## Then - All queues empty
check store.getQueueLength(QueueType.Critical) == 0
check store.getQueueLength(QueueType.Normal) == 0
asyncTest "queue persistence across store instances":
## Given
let msgs = @[("persist1", "Persistent Message")]
block:
let store1 = await RateLimitStore[string].new(db)
let addResult = await store1.pushToQueue(QueueType.Critical, msgs)
check addResult == true
check store1.getQueueLength(QueueType.Critical) == 1
## When - Create new store instance
block:
let store2 = await RateLimitStore[string].new(db)
## Then - Queue length should be restored from database
check store2.getQueueLength(QueueType.Critical) == 1
let popResult = await store2.popFromQueue(QueueType.Critical)
check popResult.isSome()
check popResult.get()[0].msg == "Persistent Message"
check store2.getQueueLength(QueueType.Critical) == 0
asyncTest "large batch handling":
## Given
let store = await RateLimitStore[string].new(db)
var largeBatch: seq[tuple[msgId: string, msg: string]]
for i in 1 .. 100:
largeBatch.add(("msg" & $i, "Message " & $i))
## When
let addResult = await store.pushToQueue(QueueType.Normal, largeBatch)
## Then
check addResult == true
check store.getQueueLength(QueueType.Normal) == 1
let popResult = await store.popFromQueue(QueueType.Normal)
check popResult.isSome()
let poppedMsgs = popResult.get()
check poppedMsgs.len == 100
check poppedMsgs[0].msgId == "msg1"
check poppedMsgs[99].msgId == "msg100"
check store.getQueueLength(QueueType.Normal) == 0