From 2c47183fb03ef283e99377cd584c28d327b00b4c Mon Sep 17 00:00:00 2001
From: pablo
Date: Mon, 4 Aug 2025 10:43:59 +0300
Subject: [PATCH] feat: store queue
---
chat_sdk/migration.nim | 21 ++-
migrations/001_create_ratelimit_state.sql | 9 ++
ratelimit/store/memory.nim | 65 +++++++-
ratelimit/store/sqlite.nim | 167 ++++++++++++++++++++-
ratelimit/store/store.nim | 7 +
tests/test_sqlite_store.nim | 175 +++++++++++++++++++++-
6 files changed, 421 insertions(+), 23 deletions(-)
diff --git a/chat_sdk/migration.nim b/chat_sdk/migration.nim
index 8af88d3..45ea614 100644
--- a/chat_sdk/migration.nim
+++ b/chat_sdk/migration.nim
@@ -1,17 +1,21 @@
-import os, sequtils, algorithm
+import os, sequtils, algorithm, strutils
import db_connector/db_sqlite
import chronicles
proc ensureMigrationTable(db: DbConn) =
- db.exec(sql"""
+ db.exec(
+ sql"""
CREATE TABLE IF NOT EXISTS schema_migrations (
filename TEXT PRIMARY KEY,
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
- """)
+ """
+ )
proc hasMigrationRun(db: DbConn, filename: string): bool =
- for row in db.fastRows(sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename):
+ for row in db.fastRows(
+ sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename
+ ):
return true
return false
@@ -27,6 +31,11 @@ proc runMigrations*(db: DbConn, dir = "migrations") =
info "Migration already applied", file
else:
info "Applying migration", file
- let sql = readFile(file)
- db.exec(sql(sql))
+ let sqlContent = readFile(file)
+ # Split by semicolon and execute each statement separately
+ let statements = sqlContent.split(';')
+ for stmt in statements:
+ let trimmedStmt = stmt.strip()
+ if trimmedStmt.len > 0:
+ db.exec(sql(trimmedStmt))
markMigrationRun(db, file)
diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql
index 030377c..293c6ee 100644
--- a/migrations/001_create_ratelimit_state.sql
+++ b/migrations/001_create_ratelimit_state.sql
@@ -1,4 +1,13 @@
CREATE TABLE IF NOT EXISTS kv_store (
key TEXT PRIMARY KEY,
value BLOB
+);
+
+CREATE TABLE IF NOT EXISTS ratelimit_queues (
+ queue_type TEXT NOT NULL,
+ msg_id TEXT NOT NULL,
+ msg_data BLOB NOT NULL,
+ batch_id INTEGER NOT NULL,
+ created_at INTEGER NOT NULL,
+ PRIMARY KEY (queue_type, batch_id, msg_id)
);
\ No newline at end of file
diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim
index 6391314..e7e7c2f 100644
--- a/ratelimit/store/memory.nim
+++ b/ratelimit/store/memory.nim
@@ -1,21 +1,70 @@
-import std/[times, options]
+import std/[times, options, deques, tables]
import ./store
import chronos
# Memory Implementation
-type MemoryRateLimitStore* = ref object
+type MemoryRateLimitStore*[T] = ref object
bucketState: Option[BucketState]
+ criticalQueue: Deque[seq[tuple[msgId: string, msg: T]]]
+ normalQueue: Deque[seq[tuple[msgId: string, msg: T]]]
+ criticalLength: int
+ normalLength: int
-proc new*(T: type[MemoryRateLimitStore]): T =
- return T(bucketState: none(BucketState))
+proc new*[T](M: type[MemoryRateLimitStore[T]]): M =
+ return M(
+ bucketState: none(BucketState),
+ criticalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](),
+ normalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](),
+ criticalLength: 0,
+ normalLength: 0
+ )
-proc saveBucketState*(
- store: MemoryRateLimitStore, bucketState: BucketState
+proc saveBucketState*[T](
+ store: MemoryRateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
store.bucketState = some(bucketState)
return true
-proc loadBucketState*(
- store: MemoryRateLimitStore
+proc loadBucketState*[T](
+ store: MemoryRateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
return store.bucketState
+
+proc addToQueue*[T](
+ store: MemoryRateLimitStore[T],
+ queueType: QueueType,
+ msgs: seq[tuple[msgId: string, msg: T]]
+): Future[bool] {.async.} =
+ case queueType
+ of QueueType.Critical:
+ store.criticalQueue.addLast(msgs)
+ inc store.criticalLength
+ of QueueType.Normal:
+ store.normalQueue.addLast(msgs)
+ inc store.normalLength
+ return true
+
+proc popFromQueue*[T](
+ store: MemoryRateLimitStore[T],
+ queueType: QueueType
+): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
+ case queueType
+ of QueueType.Critical:
+ if store.criticalQueue.len > 0:
+ dec store.criticalLength
+ return some(store.criticalQueue.popFirst())
+ of QueueType.Normal:
+ if store.normalQueue.len > 0:
+ dec store.normalLength
+ return some(store.normalQueue.popFirst())
+ return none(seq[tuple[msgId: string, msg: T]])
+
+proc getQueueLength*[T](
+ store: MemoryRateLimitStore[T],
+ queueType: QueueType
+): int =
+ case queueType
+ of QueueType.Critical:
+ return store.criticalLength
+ of QueueType.Normal:
+ return store.normalLength
diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim
index e364e5d..a3369e9 100644
--- a/ratelimit/store/sqlite.nim
+++ b/ratelimit/store/sqlite.nim
@@ -3,18 +3,58 @@ import ./store
import chronos
import db_connector/db_sqlite
+# Generic deserialization function for basic types
+proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string =
+ # Convert each byte back to a character
+ result = newString(bytes.len)
+ for i, b in bytes:
+ result[i] = char(b)
+
# SQLite Implementation
-type SqliteRateLimitStore* = ref object
+type SqliteRateLimitStore*[T] = ref object
db: DbConn
dbPath: string
+ criticalLength: int
+ normalLength: int
+ nextBatchId: int
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
-proc newSqliteRateLimitStore*(db: DbConn): SqliteRateLimitStore =
- result = SqliteRateLimitStore(db: db)
+proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] =
+ result =
+ SqliteRateLimitStore[T](db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
-proc saveBucketState*(
- store: SqliteRateLimitStore, bucketState: BucketState
+ # Initialize cached lengths from database
+ let criticalCount = db.getValue(
+ sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
+ "critical",
+ )
+ let normalCount = db.getValue(
+ sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
+ "normal",
+ )
+
+ result.criticalLength =
+ if criticalCount == "":
+ 0
+ else:
+ parseInt(criticalCount)
+ result.normalLength =
+ if normalCount == "":
+ 0
+ else:
+ parseInt(normalCount)
+
+ # Get next batch ID
+ let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
+ result.nextBatchId =
+ if maxBatch == "":
+ 1
+ else:
+ parseInt(maxBatch) + 1
+
+proc saveBucketState*[T](
+ store: SqliteRateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
try:
# Convert Moment to Unix seconds for storage
@@ -35,8 +75,8 @@ proc saveBucketState*(
except:
return false
-proc loadBucketState*(
- store: SqliteRateLimitStore
+proc loadBucketState*[T](
+ store: SqliteRateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
let jsonStr =
store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
@@ -54,3 +94,116 @@ proc loadBucketState*(
lastTimeFull: lastTimeFull,
)
)
+
+proc addToQueue*[T](
+ store: SqliteRateLimitStore[T],
+ queueType: QueueType,
+ msgs: seq[tuple[msgId: string, msg: T]],
+): Future[bool] {.async.} =
+ try:
+ let batchId = store.nextBatchId
+ inc store.nextBatchId
+ let now = times.getTime().toUnix()
+ let queueTypeStr = $queueType
+
+ if msgs.len > 0:
+ store.db.exec(sql"BEGIN TRANSACTION")
+ try:
+ for msg in msgs:
+ # Consistent serialization for all types
+ let msgBytes = msg.msg.toBytes()
+ # Convert seq[byte] to string for SQLite storage (each byte becomes a character)
+ var binaryStr = newString(msgBytes.len)
+ for i, b in msgBytes:
+ binaryStr[i] = char(b)
+
+ store.db.exec(
+ sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
+ queueTypeStr,
+ msg.msgId,
+ binaryStr,
+ batchId,
+ now,
+ )
+ store.db.exec(sql"COMMIT")
+ except:
+ store.db.exec(sql"ROLLBACK")
+ raise
+
+ case queueType
+ of QueueType.Critical:
+ inc store.criticalLength
+ of QueueType.Normal:
+ inc store.normalLength
+
+ return true
+ except:
+ return false
+
+proc popFromQueue*[T](
+ store: SqliteRateLimitStore[T], queueType: QueueType
+): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
+ try:
+ let queueTypeStr = $queueType
+
+ # Get the oldest batch ID for this queue type
+ let oldestBatchStr = store.db.getValue(
+ sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr
+ )
+
+ if oldestBatchStr == "":
+ return none(seq[tuple[msgId: string, msg: T]])
+
+ let batchId = parseInt(oldestBatchStr)
+
+ # Get all messages in this batch (preserve insertion order using rowid)
+ let rows = store.db.getAllRows(
+ sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
+ queueTypeStr,
+ batchId,
+ )
+
+ if rows.len == 0:
+ return none(seq[tuple[msgId: string, msg: T]])
+
+ var msgs: seq[tuple[msgId: string, msg: T]]
+ for row in rows:
+ let msgIdStr = row[0]
+ let msgData = row[1] # SQLite returns BLOB as string where each char is a byte
+ # Convert string back to seq[byte] properly (each char in string is a byte)
+ var msgBytes: seq[byte]
+ for c in msgData:
+ msgBytes.add(byte(c))
+
+ # Generic deserialization - works for any type that implements fromBytes
+ when T is string:
+ let msg = fromBytesImpl(msgBytes, T)
+ msgs.add((msgId: msgIdStr, msg: msg))
+ else:
+ # For other types, they need to provide their own fromBytes in the calling context
+ let msg = fromBytes(msgBytes, T)
+ msgs.add((msgId: msgIdStr, msg: msg))
+
+ # Delete the batch from database
+ store.db.exec(
+ sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
+ queueTypeStr,
+ batchId,
+ )
+
+ case queueType
+ of QueueType.Critical:
+ dec store.criticalLength
+ of QueueType.Normal:
+ dec store.normalLength
+
+ return some(msgs)
+ except:
+ return none(seq[tuple[msgId: string, msg: T]])
+
+proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int =
+ case queueType
+ of QueueType.Critical:
+ return store.criticalLength
+ of QueueType.Normal:
+ return store.normalLength
diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim
index c916750..0f18eb1 100644
--- a/ratelimit/store/store.nim
+++ b/ratelimit/store/store.nim
@@ -7,7 +7,14 @@ type
budgetCap*: int
lastTimeFull*: Moment
+ QueueType* {.pure.} = enum
+ Critical = "critical"
+ Normal = "normal"
+
RateLimitStore* =
concept s
s.saveBucketState(BucketState) is Future[bool]
s.loadBucketState() is Future[Option[BucketState]]
+ s.addToQueue(QueueType, seq[tuple[msgId: string, msg: untyped]]) is Future[bool]
+ s.popFromQueue(QueueType) is Future[Option[seq[tuple[msgId: string, msg: untyped]]]]
+ s.getQueueLength(QueueType) is int
diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim
index 6315ec9..90d764c 100644
--- a/tests/test_sqlite_store.nim
+++ b/tests/test_sqlite_store.nim
@@ -6,6 +6,19 @@ import db_connector/db_sqlite
import ../chat_sdk/migration
import std/[options, os]
+# Implement the Serializable concept for string (for testing)
+proc toBytes*(s: string): seq[byte] =
+ # Convert each character to a byte
+ result = newSeq[byte](s.len)
+ for i, c in s:
+ result[i] = byte(c)
+
+proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string =
+ # Convert each byte back to a character
+ result = newString(bytes.len)
+ for i, b in bytes:
+ result[i] = char(b)
+
suite "SqliteRateLimitStore Tests":
setup:
let db = open("test-ratelimit.db", "", "", "")
@@ -19,7 +32,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "newSqliteRateLimitStore - empty state":
## Given
- let store = newSqliteRateLimitStore(db)
+ let store = newSqliteRateLimitStore[string](db)
## When
let loadedState = await store.loadBucketState()
@@ -29,7 +42,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "saveBucketState and loadBucketState - state persistence":
## Given
- let store = newSqliteRateLimitStore(db)
+ let store = newSqliteRateLimitStore[string](db)
let now = Moment.now()
echo "now: ", now.epochSeconds()
@@ -46,3 +59,161 @@ suite "SqliteRateLimitStore Tests":
check loadedState.get().budgetCap == newBucketState.budgetCap
check loadedState.get().lastTimeFull.epochSeconds() ==
newBucketState.lastTimeFull.epochSeconds()
+
+ asyncTest "queue operations - empty store":
+ ## Given
+ let store = newSqliteRateLimitStore[string](db)
+
+ ## When/Then
+ check store.getQueueLength(QueueType.Critical) == 0
+ check store.getQueueLength(QueueType.Normal) == 0
+
+ let criticalPop = await store.popFromQueue(QueueType.Critical)
+ let normalPop = await store.popFromQueue(QueueType.Normal)
+
+ check criticalPop.isNone()
+ check normalPop.isNone()
+
+ asyncTest "addToQueue and popFromQueue - single batch":
+ ## Given
+ let store = newSqliteRateLimitStore[string](db)
+ let msgs = @[("msg1", "Hello"), ("msg2", "World")]
+
+ ## When
+ let addResult = await store.addToQueue(QueueType.Critical, msgs)
+
+ ## Then
+ check addResult == true
+ check store.getQueueLength(QueueType.Critical) == 1
+ check store.getQueueLength(QueueType.Normal) == 0
+
+ ## When
+ let popResult = await store.popFromQueue(QueueType.Critical)
+
+ ## Then
+ check popResult.isSome()
+ let poppedMsgs = popResult.get()
+ check poppedMsgs.len == 2
+ check poppedMsgs[0].msgId == "msg1"
+ check poppedMsgs[0].msg == "Hello"
+ check poppedMsgs[1].msgId == "msg2"
+ check poppedMsgs[1].msg == "World"
+
+ check store.getQueueLength(QueueType.Critical) == 0
+
+ asyncTest "addToQueue and popFromQueue - multiple batches FIFO":
+ ## Given
+ let store = newSqliteRateLimitStore[string](db)
+ let batch1 = @[("msg1", "First")]
+ let batch2 = @[("msg2", "Second")]
+ let batch3 = @[("msg3", "Third")]
+
+ ## When - Add batches
+ let result1 = await store.addToQueue(QueueType.Normal, batch1)
+ check result1 == true
+ let result2 = await store.addToQueue(QueueType.Normal, batch2)
+ check result2 == true
+ let result3 = await store.addToQueue(QueueType.Normal, batch3)
+ check result3 == true
+
+ ## Then - Check lengths
+ check store.getQueueLength(QueueType.Normal) == 3
+ check store.getQueueLength(QueueType.Critical) == 0
+
+ ## When/Then - Pop in FIFO order
+ let pop1 = await store.popFromQueue(QueueType.Normal)
+ check pop1.isSome()
+ check pop1.get()[0].msg == "First"
+ check store.getQueueLength(QueueType.Normal) == 2
+
+ let pop2 = await store.popFromQueue(QueueType.Normal)
+ check pop2.isSome()
+ check pop2.get()[0].msg == "Second"
+ check store.getQueueLength(QueueType.Normal) == 1
+
+ let pop3 = await store.popFromQueue(QueueType.Normal)
+ check pop3.isSome()
+ check pop3.get()[0].msg == "Third"
+ check store.getQueueLength(QueueType.Normal) == 0
+
+ let pop4 = await store.popFromQueue(QueueType.Normal)
+ check pop4.isNone()
+
+ asyncTest "queue isolation - critical and normal queues are separate":
+ ## Given
+ let store = newSqliteRateLimitStore[string](db)
+ let criticalMsgs = @[("crit1", "Critical Message")]
+ let normalMsgs = @[("norm1", "Normal Message")]
+
+ ## When
+ let critResult = await store.addToQueue(QueueType.Critical, criticalMsgs)
+ check critResult == true
+ let normResult = await store.addToQueue(QueueType.Normal, normalMsgs)
+ check normResult == true
+
+ ## Then
+ check store.getQueueLength(QueueType.Critical) == 1
+ check store.getQueueLength(QueueType.Normal) == 1
+
+ ## When - Pop from critical
+ let criticalPop = await store.popFromQueue(QueueType.Critical)
+ check criticalPop.isSome()
+ check criticalPop.get()[0].msg == "Critical Message"
+
+ ## Then - Normal queue unaffected
+ check store.getQueueLength(QueueType.Critical) == 0
+ check store.getQueueLength(QueueType.Normal) == 1
+
+ ## When - Pop from normal
+ let normalPop = await store.popFromQueue(QueueType.Normal)
+ check normalPop.isSome()
+ check normalPop.get()[0].msg == "Normal Message"
+
+ ## Then - All queues empty
+ check store.getQueueLength(QueueType.Critical) == 0
+ check store.getQueueLength(QueueType.Normal) == 0
+
+ asyncTest "queue persistence across store instances":
+ ## Given
+ let msgs = @[("persist1", "Persistent Message")]
+
+ block:
+ let store1 = newSqliteRateLimitStore[string](db)
+ let addResult = await store1.addToQueue(QueueType.Critical, msgs)
+ check addResult == true
+ check store1.getQueueLength(QueueType.Critical) == 1
+
+ ## When - Create new store instance
+ block:
+ let store2 = newSqliteRateLimitStore[string](db)
+
+ ## Then - Queue length should be restored from database
+ check store2.getQueueLength(QueueType.Critical) == 1
+
+ let popResult = await store2.popFromQueue(QueueType.Critical)
+ check popResult.isSome()
+ check popResult.get()[0].msg == "Persistent Message"
+ check store2.getQueueLength(QueueType.Critical) == 0
+
+ asyncTest "large batch handling":
+ ## Given
+ let store = newSqliteRateLimitStore[string](db)
+ var largeBatch: seq[tuple[msgId: string, msg: string]]
+
+ for i in 1 .. 100:
+ largeBatch.add(("msg" & $i, "Message " & $i))
+
+ ## When
+ let addResult = await store.addToQueue(QueueType.Normal, largeBatch)
+
+ ## Then
+ check addResult == true
+ check store.getQueueLength(QueueType.Normal) == 1
+
+ let popResult = await store.popFromQueue(QueueType.Normal)
+ check popResult.isSome()
+ let poppedMsgs = popResult.get()
+ check poppedMsgs.len == 100
+ check poppedMsgs[0].msgId == "msg1"
+ check poppedMsgs[99].msgId == "msg100"
+ check store.getQueueLength(QueueType.Normal) == 0