From 2c47183fb03ef283e99377cd584c28d327b00b4c Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 4 Aug 2025 10:43:59 +0300 Subject: [PATCH] feat: store queue --- chat_sdk/migration.nim | 21 ++- migrations/001_create_ratelimit_state.sql | 9 ++ ratelimit/store/memory.nim | 65 +++++++- ratelimit/store/sqlite.nim | 167 ++++++++++++++++++++- ratelimit/store/store.nim | 7 + tests/test_sqlite_store.nim | 175 +++++++++++++++++++++- 6 files changed, 421 insertions(+), 23 deletions(-) diff --git a/chat_sdk/migration.nim b/chat_sdk/migration.nim index 8af88d3..45ea614 100644 --- a/chat_sdk/migration.nim +++ b/chat_sdk/migration.nim @@ -1,17 +1,21 @@ -import os, sequtils, algorithm +import os, sequtils, algorithm, strutils import db_connector/db_sqlite import chronicles proc ensureMigrationTable(db: DbConn) = - db.exec(sql""" + db.exec( + sql""" CREATE TABLE IF NOT EXISTS schema_migrations ( filename TEXT PRIMARY KEY, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) - """) + """ + ) proc hasMigrationRun(db: DbConn, filename: string): bool = - for row in db.fastRows(sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename): + for row in db.fastRows( + sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename + ): return true return false @@ -27,6 +31,11 @@ proc runMigrations*(db: DbConn, dir = "migrations") = info "Migration already applied", file else: info "Applying migration", file - let sql = readFile(file) - db.exec(sql(sql)) + let sqlContent = readFile(file) + # Split by semicolon and execute each statement separately + let statements = sqlContent.split(';') + for stmt in statements: + let trimmedStmt = stmt.strip() + if trimmedStmt.len > 0: + db.exec(sql(trimmedStmt)) markMigrationRun(db, file) diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql index 030377c..293c6ee 100644 --- a/migrations/001_create_ratelimit_state.sql +++ b/migrations/001_create_ratelimit_state.sql @@ -1,4 +1,13 @@ CREATE TABLE IF NOT EXISTS kv_store ( key TEXT PRIMARY KEY, value BLOB +); + +CREATE TABLE IF NOT EXISTS ratelimit_queues ( + queue_type TEXT NOT NULL, + msg_id TEXT NOT NULL, + msg_data BLOB NOT NULL, + batch_id INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (queue_type, batch_id, msg_id) ); \ No newline at end of file diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim index 6391314..e7e7c2f 100644 --- a/ratelimit/store/memory.nim +++ b/ratelimit/store/memory.nim @@ -1,21 +1,70 @@ -import std/[times, options] +import std/[times, options, deques, tables] import ./store import chronos # Memory Implementation -type MemoryRateLimitStore* = ref object +type MemoryRateLimitStore*[T] = ref object bucketState: Option[BucketState] + criticalQueue: Deque[seq[tuple[msgId: string, msg: T]]] + normalQueue: Deque[seq[tuple[msgId: string, msg: T]]] + criticalLength: int + normalLength: int -proc new*(T: type[MemoryRateLimitStore]): T = - return T(bucketState: none(BucketState)) +proc new*[T](M: type[MemoryRateLimitStore[T]]): M = + return M( + bucketState: none(BucketState), + criticalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](), + normalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](), + criticalLength: 0, + normalLength: 0 + ) -proc saveBucketState*( - store: MemoryRateLimitStore, bucketState: BucketState +proc saveBucketState*[T]( + store: MemoryRateLimitStore[T], bucketState: BucketState ): Future[bool] {.async.} = store.bucketState = some(bucketState) return true -proc loadBucketState*( - store: MemoryRateLimitStore +proc loadBucketState*[T]( + store: MemoryRateLimitStore[T] ): Future[Option[BucketState]] {.async.} = return store.bucketState + +proc addToQueue*[T]( + store: MemoryRateLimitStore[T], + queueType: QueueType, + msgs: seq[tuple[msgId: string, msg: T]] +): Future[bool] {.async.} = + case queueType + of QueueType.Critical: + store.criticalQueue.addLast(msgs) + inc store.criticalLength + of QueueType.Normal: + store.normalQueue.addLast(msgs) + inc store.normalLength + return true + +proc popFromQueue*[T]( + store: MemoryRateLimitStore[T], + queueType: QueueType +): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = + case queueType + of QueueType.Critical: + if store.criticalQueue.len > 0: + dec store.criticalLength + return some(store.criticalQueue.popFirst()) + of QueueType.Normal: + if store.normalQueue.len > 0: + dec store.normalLength + return some(store.normalQueue.popFirst()) + return none(seq[tuple[msgId: string, msg: T]]) + +proc getQueueLength*[T]( + store: MemoryRateLimitStore[T], + queueType: QueueType +): int = + case queueType + of QueueType.Critical: + return store.criticalLength + of QueueType.Normal: + return store.normalLength diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim index e364e5d..a3369e9 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store/sqlite.nim @@ -3,18 +3,58 @@ import ./store import chronos import db_connector/db_sqlite +# Generic deserialization function for basic types +proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string = + # Convert each byte back to a character + result = newString(bytes.len) + for i, b in bytes: + result[i] = char(b) + # SQLite Implementation -type SqliteRateLimitStore* = ref object +type SqliteRateLimitStore*[T] = ref object db: DbConn dbPath: string + criticalLength: int + normalLength: int + nextBatchId: int const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc newSqliteRateLimitStore*(db: DbConn): SqliteRateLimitStore = - result = SqliteRateLimitStore(db: db) +proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] = + result = + SqliteRateLimitStore[T](db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) -proc saveBucketState*( - store: SqliteRateLimitStore, bucketState: BucketState + # Initialize cached lengths from database + let criticalCount = db.getValue( + sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?", + "critical", + ) + let normalCount = db.getValue( + sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?", + "normal", + ) + + result.criticalLength = + if criticalCount == "": + 0 + else: + parseInt(criticalCount) + result.normalLength = + if normalCount == "": + 0 + else: + parseInt(normalCount) + + # Get next batch ID + let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues") + result.nextBatchId = + if maxBatch == "": + 1 + else: + parseInt(maxBatch) + 1 + +proc saveBucketState*[T]( + store: SqliteRateLimitStore[T], bucketState: BucketState ): Future[bool] {.async.} = try: # Convert Moment to Unix seconds for storage @@ -35,8 +75,8 @@ proc saveBucketState*( except: return false -proc loadBucketState*( - store: SqliteRateLimitStore +proc loadBucketState*[T]( + store: SqliteRateLimitStore[T] ): Future[Option[BucketState]] {.async.} = let jsonStr = store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) @@ -54,3 +94,116 @@ proc loadBucketState*( lastTimeFull: lastTimeFull, ) ) + +proc addToQueue*[T]( + store: SqliteRateLimitStore[T], + queueType: QueueType, + msgs: seq[tuple[msgId: string, msg: T]], +): Future[bool] {.async.} = + try: + let batchId = store.nextBatchId + inc store.nextBatchId + let now = times.getTime().toUnix() + let queueTypeStr = $queueType + + if msgs.len > 0: + store.db.exec(sql"BEGIN TRANSACTION") + try: + for msg in msgs: + # Consistent serialization for all types + let msgBytes = msg.msg.toBytes() + # Convert seq[byte] to string for SQLite storage (each byte becomes a character) + var binaryStr = newString(msgBytes.len) + for i, b in msgBytes: + binaryStr[i] = char(b) + + store.db.exec( + sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)", + queueTypeStr, + msg.msgId, + binaryStr, + batchId, + now, + ) + store.db.exec(sql"COMMIT") + except: + store.db.exec(sql"ROLLBACK") + raise + + case queueType + of QueueType.Critical: + inc store.criticalLength + of QueueType.Normal: + inc store.normalLength + + return true + except: + return false + +proc popFromQueue*[T]( + store: SqliteRateLimitStore[T], queueType: QueueType +): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = + try: + let queueTypeStr = $queueType + + # Get the oldest batch ID for this queue type + let oldestBatchStr = store.db.getValue( + sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr + ) + + if oldestBatchStr == "": + return none(seq[tuple[msgId: string, msg: T]]) + + let batchId = parseInt(oldestBatchStr) + + # Get all messages in this batch (preserve insertion order using rowid) + let rows = store.db.getAllRows( + sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid", + queueTypeStr, + batchId, + ) + + if rows.len == 0: + return none(seq[tuple[msgId: string, msg: T]]) + + var msgs: seq[tuple[msgId: string, msg: T]] + for row in rows: + let msgIdStr = row[0] + let msgData = row[1] # SQLite returns BLOB as string where each char is a byte + # Convert string back to seq[byte] properly (each char in string is a byte) + var msgBytes: seq[byte] + for c in msgData: + msgBytes.add(byte(c)) + + # Generic deserialization - works for any type that implements fromBytes + when T is string: + let msg = fromBytesImpl(msgBytes, T) + msgs.add((msgId: msgIdStr, msg: msg)) + else: + # For other types, they need to provide their own fromBytes in the calling context + let msg = fromBytes(msgBytes, T) + msgs.add((msgId: msgIdStr, msg: msg)) + + # Delete the batch from database + store.db.exec( + sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?", + queueTypeStr, + batchId, + ) + + case queueType + of QueueType.Critical: + dec store.criticalLength + of QueueType.Normal: + dec store.normalLength + + return some(msgs) + except: + return none(seq[tuple[msgId: string, msg: T]]) + +proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int = + case queueType + of QueueType.Critical: + return store.criticalLength + of QueueType.Normal: + return store.normalLength diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim index c916750..0f18eb1 100644 --- a/ratelimit/store/store.nim +++ b/ratelimit/store/store.nim @@ -7,7 +7,14 @@ type budgetCap*: int lastTimeFull*: Moment + QueueType* {.pure.} = enum + Critical = "critical" + Normal = "normal" + RateLimitStore* = concept s s.saveBucketState(BucketState) is Future[bool] s.loadBucketState() is Future[Option[BucketState]] + s.addToQueue(QueueType, seq[tuple[msgId: string, msg: untyped]]) is Future[bool] + s.popFromQueue(QueueType) is Future[Option[seq[tuple[msgId: string, msg: untyped]]]] + s.getQueueLength(QueueType) is int diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim index 6315ec9..90d764c 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_sqlite_store.nim @@ -6,6 +6,19 @@ import db_connector/db_sqlite import ../chat_sdk/migration import std/[options, os] +# Implement the Serializable concept for string (for testing) +proc toBytes*(s: string): seq[byte] = + # Convert each character to a byte + result = newSeq[byte](s.len) + for i, c in s: + result[i] = byte(c) + +proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string = + # Convert each byte back to a character + result = newString(bytes.len) + for i, b in bytes: + result[i] = char(b) + suite "SqliteRateLimitStore Tests": setup: let db = open("test-ratelimit.db", "", "", "") @@ -19,7 +32,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "newSqliteRateLimitStore - empty state": ## Given - let store = newSqliteRateLimitStore(db) + let store = newSqliteRateLimitStore[string](db) ## When let loadedState = await store.loadBucketState() @@ -29,7 +42,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "saveBucketState and loadBucketState - state persistence": ## Given - let store = newSqliteRateLimitStore(db) + let store = newSqliteRateLimitStore[string](db) let now = Moment.now() echo "now: ", now.epochSeconds() @@ -46,3 +59,161 @@ suite "SqliteRateLimitStore Tests": check loadedState.get().budgetCap == newBucketState.budgetCap check loadedState.get().lastTimeFull.epochSeconds() == newBucketState.lastTimeFull.epochSeconds() + + asyncTest "queue operations - empty store": + ## Given + let store = newSqliteRateLimitStore[string](db) + + ## When/Then + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 0 + + let criticalPop = await store.popFromQueue(QueueType.Critical) + let normalPop = await store.popFromQueue(QueueType.Normal) + + check criticalPop.isNone() + check normalPop.isNone() + + asyncTest "addToQueue and popFromQueue - single batch": + ## Given + let store = newSqliteRateLimitStore[string](db) + let msgs = @[("msg1", "Hello"), ("msg2", "World")] + + ## When + let addResult = await store.addToQueue(QueueType.Critical, msgs) + + ## Then + check addResult == true + check store.getQueueLength(QueueType.Critical) == 1 + check store.getQueueLength(QueueType.Normal) == 0 + + ## When + let popResult = await store.popFromQueue(QueueType.Critical) + + ## Then + check popResult.isSome() + let poppedMsgs = popResult.get() + check poppedMsgs.len == 2 + check poppedMsgs[0].msgId == "msg1" + check poppedMsgs[0].msg == "Hello" + check poppedMsgs[1].msgId == "msg2" + check poppedMsgs[1].msg == "World" + + check store.getQueueLength(QueueType.Critical) == 0 + + asyncTest "addToQueue and popFromQueue - multiple batches FIFO": + ## Given + let store = newSqliteRateLimitStore[string](db) + let batch1 = @[("msg1", "First")] + let batch2 = @[("msg2", "Second")] + let batch3 = @[("msg3", "Third")] + + ## When - Add batches + let result1 = await store.addToQueue(QueueType.Normal, batch1) + check result1 == true + let result2 = await store.addToQueue(QueueType.Normal, batch2) + check result2 == true + let result3 = await store.addToQueue(QueueType.Normal, batch3) + check result3 == true + + ## Then - Check lengths + check store.getQueueLength(QueueType.Normal) == 3 + check store.getQueueLength(QueueType.Critical) == 0 + + ## When/Then - Pop in FIFO order + let pop1 = await store.popFromQueue(QueueType.Normal) + check pop1.isSome() + check pop1.get()[0].msg == "First" + check store.getQueueLength(QueueType.Normal) == 2 + + let pop2 = await store.popFromQueue(QueueType.Normal) + check pop2.isSome() + check pop2.get()[0].msg == "Second" + check store.getQueueLength(QueueType.Normal) == 1 + + let pop3 = await store.popFromQueue(QueueType.Normal) + check pop3.isSome() + check pop3.get()[0].msg == "Third" + check store.getQueueLength(QueueType.Normal) == 0 + + let pop4 = await store.popFromQueue(QueueType.Normal) + check pop4.isNone() + + asyncTest "queue isolation - critical and normal queues are separate": + ## Given + let store = newSqliteRateLimitStore[string](db) + let criticalMsgs = @[("crit1", "Critical Message")] + let normalMsgs = @[("norm1", "Normal Message")] + + ## When + let critResult = await store.addToQueue(QueueType.Critical, criticalMsgs) + check critResult == true + let normResult = await store.addToQueue(QueueType.Normal, normalMsgs) + check normResult == true + + ## Then + check store.getQueueLength(QueueType.Critical) == 1 + check store.getQueueLength(QueueType.Normal) == 1 + + ## When - Pop from critical + let criticalPop = await store.popFromQueue(QueueType.Critical) + check criticalPop.isSome() + check criticalPop.get()[0].msg == "Critical Message" + + ## Then - Normal queue unaffected + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 1 + + ## When - Pop from normal + let normalPop = await store.popFromQueue(QueueType.Normal) + check normalPop.isSome() + check normalPop.get()[0].msg == "Normal Message" + + ## Then - All queues empty + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 0 + + asyncTest "queue persistence across store instances": + ## Given + let msgs = @[("persist1", "Persistent Message")] + + block: + let store1 = newSqliteRateLimitStore[string](db) + let addResult = await store1.addToQueue(QueueType.Critical, msgs) + check addResult == true + check store1.getQueueLength(QueueType.Critical) == 1 + + ## When - Create new store instance + block: + let store2 = newSqliteRateLimitStore[string](db) + + ## Then - Queue length should be restored from database + check store2.getQueueLength(QueueType.Critical) == 1 + + let popResult = await store2.popFromQueue(QueueType.Critical) + check popResult.isSome() + check popResult.get()[0].msg == "Persistent Message" + check store2.getQueueLength(QueueType.Critical) == 0 + + asyncTest "large batch handling": + ## Given + let store = newSqliteRateLimitStore[string](db) + var largeBatch: seq[tuple[msgId: string, msg: string]] + + for i in 1 .. 100: + largeBatch.add(("msg" & $i, "Message " & $i)) + + ## When + let addResult = await store.addToQueue(QueueType.Normal, largeBatch) + + ## Then + check addResult == true + check store.getQueueLength(QueueType.Normal) == 1 + + let popResult = await store.popFromQueue(QueueType.Normal) + check popResult.isSome() + let poppedMsgs = popResult.get() + check poppedMsgs.len == 100 + check poppedMsgs[0].msgId == "msg1" + check poppedMsgs[99].msgId == "msg100" + check store.getQueueLength(QueueType.Normal) == 0