mirror of
https://github.com/logos-messaging/nim-chat-sdk.git
synced 2026-01-02 14:13:07 +00:00
feat: store queue
This commit is contained in:
parent
109b5769da
commit
2c47183fb0
@ -1,17 +1,21 @@
|
||||
import os, sequtils, algorithm
|
||||
import os, sequtils, algorithm, strutils
|
||||
import db_connector/db_sqlite
|
||||
import chronicles
|
||||
|
||||
proc ensureMigrationTable(db: DbConn) =
|
||||
db.exec(sql"""
|
||||
db.exec(
|
||||
sql"""
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
proc hasMigrationRun(db: DbConn, filename: string): bool =
|
||||
for row in db.fastRows(sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename):
|
||||
for row in db.fastRows(
|
||||
sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename
|
||||
):
|
||||
return true
|
||||
return false
|
||||
|
||||
@ -27,6 +31,11 @@ proc runMigrations*(db: DbConn, dir = "migrations") =
|
||||
info "Migration already applied", file
|
||||
else:
|
||||
info "Applying migration", file
|
||||
let sql = readFile(file)
|
||||
db.exec(sql(sql))
|
||||
let sqlContent = readFile(file)
|
||||
# Split by semicolon and execute each statement separately
|
||||
let statements = sqlContent.split(';')
|
||||
for stmt in statements:
|
||||
let trimmedStmt = stmt.strip()
|
||||
if trimmedStmt.len > 0:
|
||||
db.exec(sql(trimmedStmt))
|
||||
markMigrationRun(db, file)
|
||||
|
||||
@ -1,4 +1,13 @@
|
||||
CREATE TABLE IF NOT EXISTS kv_store (
|
||||
key TEXT PRIMARY KEY,
|
||||
value BLOB
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ratelimit_queues (
|
||||
queue_type TEXT NOT NULL,
|
||||
msg_id TEXT NOT NULL,
|
||||
msg_data BLOB NOT NULL,
|
||||
batch_id INTEGER NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
PRIMARY KEY (queue_type, batch_id, msg_id)
|
||||
);
|
||||
@ -1,21 +1,70 @@
|
||||
import std/[times, options]
|
||||
import std/[times, options, deques, tables]
|
||||
import ./store
|
||||
import chronos
|
||||
|
||||
# Memory Implementation
|
||||
type MemoryRateLimitStore* = ref object
|
||||
type MemoryRateLimitStore*[T] = ref object
|
||||
bucketState: Option[BucketState]
|
||||
criticalQueue: Deque[seq[tuple[msgId: string, msg: T]]]
|
||||
normalQueue: Deque[seq[tuple[msgId: string, msg: T]]]
|
||||
criticalLength: int
|
||||
normalLength: int
|
||||
|
||||
proc new*(T: type[MemoryRateLimitStore]): T =
|
||||
return T(bucketState: none(BucketState))
|
||||
proc new*[T](M: type[MemoryRateLimitStore[T]]): M =
|
||||
return M(
|
||||
bucketState: none(BucketState),
|
||||
criticalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](),
|
||||
normalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](),
|
||||
criticalLength: 0,
|
||||
normalLength: 0
|
||||
)
|
||||
|
||||
proc saveBucketState*(
|
||||
store: MemoryRateLimitStore, bucketState: BucketState
|
||||
proc saveBucketState*[T](
|
||||
store: MemoryRateLimitStore[T], bucketState: BucketState
|
||||
): Future[bool] {.async.} =
|
||||
store.bucketState = some(bucketState)
|
||||
return true
|
||||
|
||||
proc loadBucketState*(
|
||||
store: MemoryRateLimitStore
|
||||
proc loadBucketState*[T](
|
||||
store: MemoryRateLimitStore[T]
|
||||
): Future[Option[BucketState]] {.async.} =
|
||||
return store.bucketState
|
||||
|
||||
proc addToQueue*[T](
|
||||
store: MemoryRateLimitStore[T],
|
||||
queueType: QueueType,
|
||||
msgs: seq[tuple[msgId: string, msg: T]]
|
||||
): Future[bool] {.async.} =
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
store.criticalQueue.addLast(msgs)
|
||||
inc store.criticalLength
|
||||
of QueueType.Normal:
|
||||
store.normalQueue.addLast(msgs)
|
||||
inc store.normalLength
|
||||
return true
|
||||
|
||||
proc popFromQueue*[T](
|
||||
store: MemoryRateLimitStore[T],
|
||||
queueType: QueueType
|
||||
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
if store.criticalQueue.len > 0:
|
||||
dec store.criticalLength
|
||||
return some(store.criticalQueue.popFirst())
|
||||
of QueueType.Normal:
|
||||
if store.normalQueue.len > 0:
|
||||
dec store.normalLength
|
||||
return some(store.normalQueue.popFirst())
|
||||
return none(seq[tuple[msgId: string, msg: T]])
|
||||
|
||||
proc getQueueLength*[T](
|
||||
store: MemoryRateLimitStore[T],
|
||||
queueType: QueueType
|
||||
): int =
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
return store.criticalLength
|
||||
of QueueType.Normal:
|
||||
return store.normalLength
|
||||
|
||||
@ -3,18 +3,58 @@ import ./store
|
||||
import chronos
|
||||
import db_connector/db_sqlite
|
||||
|
||||
# Generic deserialization function for basic types
|
||||
proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string =
|
||||
# Convert each byte back to a character
|
||||
result = newString(bytes.len)
|
||||
for i, b in bytes:
|
||||
result[i] = char(b)
|
||||
|
||||
# SQLite Implementation
|
||||
type SqliteRateLimitStore* = ref object
|
||||
type SqliteRateLimitStore*[T] = ref object
|
||||
db: DbConn
|
||||
dbPath: string
|
||||
criticalLength: int
|
||||
normalLength: int
|
||||
nextBatchId: int
|
||||
|
||||
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
|
||||
|
||||
proc newSqliteRateLimitStore*(db: DbConn): SqliteRateLimitStore =
|
||||
result = SqliteRateLimitStore(db: db)
|
||||
proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] =
|
||||
result =
|
||||
SqliteRateLimitStore[T](db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
|
||||
|
||||
proc saveBucketState*(
|
||||
store: SqliteRateLimitStore, bucketState: BucketState
|
||||
# Initialize cached lengths from database
|
||||
let criticalCount = db.getValue(
|
||||
sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
|
||||
"critical",
|
||||
)
|
||||
let normalCount = db.getValue(
|
||||
sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
|
||||
"normal",
|
||||
)
|
||||
|
||||
result.criticalLength =
|
||||
if criticalCount == "":
|
||||
0
|
||||
else:
|
||||
parseInt(criticalCount)
|
||||
result.normalLength =
|
||||
if normalCount == "":
|
||||
0
|
||||
else:
|
||||
parseInt(normalCount)
|
||||
|
||||
# Get next batch ID
|
||||
let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
|
||||
result.nextBatchId =
|
||||
if maxBatch == "":
|
||||
1
|
||||
else:
|
||||
parseInt(maxBatch) + 1
|
||||
|
||||
proc saveBucketState*[T](
|
||||
store: SqliteRateLimitStore[T], bucketState: BucketState
|
||||
): Future[bool] {.async.} =
|
||||
try:
|
||||
# Convert Moment to Unix seconds for storage
|
||||
@ -35,8 +75,8 @@ proc saveBucketState*(
|
||||
except:
|
||||
return false
|
||||
|
||||
proc loadBucketState*(
|
||||
store: SqliteRateLimitStore
|
||||
proc loadBucketState*[T](
|
||||
store: SqliteRateLimitStore[T]
|
||||
): Future[Option[BucketState]] {.async.} =
|
||||
let jsonStr =
|
||||
store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
|
||||
@ -54,3 +94,116 @@ proc loadBucketState*(
|
||||
lastTimeFull: lastTimeFull,
|
||||
)
|
||||
)
|
||||
|
||||
proc addToQueue*[T](
|
||||
store: SqliteRateLimitStore[T],
|
||||
queueType: QueueType,
|
||||
msgs: seq[tuple[msgId: string, msg: T]],
|
||||
): Future[bool] {.async.} =
|
||||
try:
|
||||
let batchId = store.nextBatchId
|
||||
inc store.nextBatchId
|
||||
let now = times.getTime().toUnix()
|
||||
let queueTypeStr = $queueType
|
||||
|
||||
if msgs.len > 0:
|
||||
store.db.exec(sql"BEGIN TRANSACTION")
|
||||
try:
|
||||
for msg in msgs:
|
||||
# Consistent serialization for all types
|
||||
let msgBytes = msg.msg.toBytes()
|
||||
# Convert seq[byte] to string for SQLite storage (each byte becomes a character)
|
||||
var binaryStr = newString(msgBytes.len)
|
||||
for i, b in msgBytes:
|
||||
binaryStr[i] = char(b)
|
||||
|
||||
store.db.exec(
|
||||
sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
|
||||
queueTypeStr,
|
||||
msg.msgId,
|
||||
binaryStr,
|
||||
batchId,
|
||||
now,
|
||||
)
|
||||
store.db.exec(sql"COMMIT")
|
||||
except:
|
||||
store.db.exec(sql"ROLLBACK")
|
||||
raise
|
||||
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
inc store.criticalLength
|
||||
of QueueType.Normal:
|
||||
inc store.normalLength
|
||||
|
||||
return true
|
||||
except:
|
||||
return false
|
||||
|
||||
proc popFromQueue*[T](
|
||||
store: SqliteRateLimitStore[T], queueType: QueueType
|
||||
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
|
||||
try:
|
||||
let queueTypeStr = $queueType
|
||||
|
||||
# Get the oldest batch ID for this queue type
|
||||
let oldestBatchStr = store.db.getValue(
|
||||
sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr
|
||||
)
|
||||
|
||||
if oldestBatchStr == "":
|
||||
return none(seq[tuple[msgId: string, msg: T]])
|
||||
|
||||
let batchId = parseInt(oldestBatchStr)
|
||||
|
||||
# Get all messages in this batch (preserve insertion order using rowid)
|
||||
let rows = store.db.getAllRows(
|
||||
sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
|
||||
queueTypeStr,
|
||||
batchId,
|
||||
)
|
||||
|
||||
if rows.len == 0:
|
||||
return none(seq[tuple[msgId: string, msg: T]])
|
||||
|
||||
var msgs: seq[tuple[msgId: string, msg: T]]
|
||||
for row in rows:
|
||||
let msgIdStr = row[0]
|
||||
let msgData = row[1] # SQLite returns BLOB as string where each char is a byte
|
||||
# Convert string back to seq[byte] properly (each char in string is a byte)
|
||||
var msgBytes: seq[byte]
|
||||
for c in msgData:
|
||||
msgBytes.add(byte(c))
|
||||
|
||||
# Generic deserialization - works for any type that implements fromBytes
|
||||
when T is string:
|
||||
let msg = fromBytesImpl(msgBytes, T)
|
||||
msgs.add((msgId: msgIdStr, msg: msg))
|
||||
else:
|
||||
# For other types, they need to provide their own fromBytes in the calling context
|
||||
let msg = fromBytes(msgBytes, T)
|
||||
msgs.add((msgId: msgIdStr, msg: msg))
|
||||
|
||||
# Delete the batch from database
|
||||
store.db.exec(
|
||||
sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
|
||||
queueTypeStr,
|
||||
batchId,
|
||||
)
|
||||
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
dec store.criticalLength
|
||||
of QueueType.Normal:
|
||||
dec store.normalLength
|
||||
|
||||
return some(msgs)
|
||||
except:
|
||||
return none(seq[tuple[msgId: string, msg: T]])
|
||||
|
||||
proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int =
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
return store.criticalLength
|
||||
of QueueType.Normal:
|
||||
return store.normalLength
|
||||
|
||||
@ -7,7 +7,14 @@ type
|
||||
budgetCap*: int
|
||||
lastTimeFull*: Moment
|
||||
|
||||
QueueType* {.pure.} = enum
|
||||
Critical = "critical"
|
||||
Normal = "normal"
|
||||
|
||||
RateLimitStore* =
|
||||
concept s
|
||||
s.saveBucketState(BucketState) is Future[bool]
|
||||
s.loadBucketState() is Future[Option[BucketState]]
|
||||
s.addToQueue(QueueType, seq[tuple[msgId: string, msg: untyped]]) is Future[bool]
|
||||
s.popFromQueue(QueueType) is Future[Option[seq[tuple[msgId: string, msg: untyped]]]]
|
||||
s.getQueueLength(QueueType) is int
|
||||
|
||||
@ -6,6 +6,19 @@ import db_connector/db_sqlite
|
||||
import ../chat_sdk/migration
|
||||
import std/[options, os]
|
||||
|
||||
# Implement the Serializable concept for string (for testing)
|
||||
proc toBytes*(s: string): seq[byte] =
|
||||
# Convert each character to a byte
|
||||
result = newSeq[byte](s.len)
|
||||
for i, c in s:
|
||||
result[i] = byte(c)
|
||||
|
||||
proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string =
|
||||
# Convert each byte back to a character
|
||||
result = newString(bytes.len)
|
||||
for i, b in bytes:
|
||||
result[i] = char(b)
|
||||
|
||||
suite "SqliteRateLimitStore Tests":
|
||||
setup:
|
||||
let db = open("test-ratelimit.db", "", "", "")
|
||||
@ -19,7 +32,7 @@ suite "SqliteRateLimitStore Tests":
|
||||
|
||||
asyncTest "newSqliteRateLimitStore - empty state":
|
||||
## Given
|
||||
let store = newSqliteRateLimitStore(db)
|
||||
let store = newSqliteRateLimitStore[string](db)
|
||||
|
||||
## When
|
||||
let loadedState = await store.loadBucketState()
|
||||
@ -29,7 +42,7 @@ suite "SqliteRateLimitStore Tests":
|
||||
|
||||
asyncTest "saveBucketState and loadBucketState - state persistence":
|
||||
## Given
|
||||
let store = newSqliteRateLimitStore(db)
|
||||
let store = newSqliteRateLimitStore[string](db)
|
||||
|
||||
let now = Moment.now()
|
||||
echo "now: ", now.epochSeconds()
|
||||
@ -46,3 +59,161 @@ suite "SqliteRateLimitStore Tests":
|
||||
check loadedState.get().budgetCap == newBucketState.budgetCap
|
||||
check loadedState.get().lastTimeFull.epochSeconds() ==
|
||||
newBucketState.lastTimeFull.epochSeconds()
|
||||
|
||||
asyncTest "queue operations - empty store":
|
||||
## Given
|
||||
let store = newSqliteRateLimitStore[string](db)
|
||||
|
||||
## When/Then
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
let criticalPop = await store.popFromQueue(QueueType.Critical)
|
||||
let normalPop = await store.popFromQueue(QueueType.Normal)
|
||||
|
||||
check criticalPop.isNone()
|
||||
check normalPop.isNone()
|
||||
|
||||
asyncTest "addToQueue and popFromQueue - single batch":
|
||||
## Given
|
||||
let store = newSqliteRateLimitStore[string](db)
|
||||
let msgs = @[("msg1", "Hello"), ("msg2", "World")]
|
||||
|
||||
## When
|
||||
let addResult = await store.addToQueue(QueueType.Critical, msgs)
|
||||
|
||||
## Then
|
||||
check addResult == true
|
||||
check store.getQueueLength(QueueType.Critical) == 1
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
## When
|
||||
let popResult = await store.popFromQueue(QueueType.Critical)
|
||||
|
||||
## Then
|
||||
check popResult.isSome()
|
||||
let poppedMsgs = popResult.get()
|
||||
check poppedMsgs.len == 2
|
||||
check poppedMsgs[0].msgId == "msg1"
|
||||
check poppedMsgs[0].msg == "Hello"
|
||||
check poppedMsgs[1].msgId == "msg2"
|
||||
check poppedMsgs[1].msg == "World"
|
||||
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
|
||||
asyncTest "addToQueue and popFromQueue - multiple batches FIFO":
|
||||
## Given
|
||||
let store = newSqliteRateLimitStore[string](db)
|
||||
let batch1 = @[("msg1", "First")]
|
||||
let batch2 = @[("msg2", "Second")]
|
||||
let batch3 = @[("msg3", "Third")]
|
||||
|
||||
## When - Add batches
|
||||
let result1 = await store.addToQueue(QueueType.Normal, batch1)
|
||||
check result1 == true
|
||||
let result2 = await store.addToQueue(QueueType.Normal, batch2)
|
||||
check result2 == true
|
||||
let result3 = await store.addToQueue(QueueType.Normal, batch3)
|
||||
check result3 == true
|
||||
|
||||
## Then - Check lengths
|
||||
check store.getQueueLength(QueueType.Normal) == 3
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
|
||||
## When/Then - Pop in FIFO order
|
||||
let pop1 = await store.popFromQueue(QueueType.Normal)
|
||||
check pop1.isSome()
|
||||
check pop1.get()[0].msg == "First"
|
||||
check store.getQueueLength(QueueType.Normal) == 2
|
||||
|
||||
let pop2 = await store.popFromQueue(QueueType.Normal)
|
||||
check pop2.isSome()
|
||||
check pop2.get()[0].msg == "Second"
|
||||
check store.getQueueLength(QueueType.Normal) == 1
|
||||
|
||||
let pop3 = await store.popFromQueue(QueueType.Normal)
|
||||
check pop3.isSome()
|
||||
check pop3.get()[0].msg == "Third"
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
let pop4 = await store.popFromQueue(QueueType.Normal)
|
||||
check pop4.isNone()
|
||||
|
||||
asyncTest "queue isolation - critical and normal queues are separate":
|
||||
## Given
|
||||
let store = newSqliteRateLimitStore[string](db)
|
||||
let criticalMsgs = @[("crit1", "Critical Message")]
|
||||
let normalMsgs = @[("norm1", "Normal Message")]
|
||||
|
||||
## When
|
||||
let critResult = await store.addToQueue(QueueType.Critical, criticalMsgs)
|
||||
check critResult == true
|
||||
let normResult = await store.addToQueue(QueueType.Normal, normalMsgs)
|
||||
check normResult == true
|
||||
|
||||
## Then
|
||||
check store.getQueueLength(QueueType.Critical) == 1
|
||||
check store.getQueueLength(QueueType.Normal) == 1
|
||||
|
||||
## When - Pop from critical
|
||||
let criticalPop = await store.popFromQueue(QueueType.Critical)
|
||||
check criticalPop.isSome()
|
||||
check criticalPop.get()[0].msg == "Critical Message"
|
||||
|
||||
## Then - Normal queue unaffected
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
check store.getQueueLength(QueueType.Normal) == 1
|
||||
|
||||
## When - Pop from normal
|
||||
let normalPop = await store.popFromQueue(QueueType.Normal)
|
||||
check normalPop.isSome()
|
||||
check normalPop.get()[0].msg == "Normal Message"
|
||||
|
||||
## Then - All queues empty
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
asyncTest "queue persistence across store instances":
|
||||
## Given
|
||||
let msgs = @[("persist1", "Persistent Message")]
|
||||
|
||||
block:
|
||||
let store1 = newSqliteRateLimitStore[string](db)
|
||||
let addResult = await store1.addToQueue(QueueType.Critical, msgs)
|
||||
check addResult == true
|
||||
check store1.getQueueLength(QueueType.Critical) == 1
|
||||
|
||||
## When - Create new store instance
|
||||
block:
|
||||
let store2 = newSqliteRateLimitStore[string](db)
|
||||
|
||||
## Then - Queue length should be restored from database
|
||||
check store2.getQueueLength(QueueType.Critical) == 1
|
||||
|
||||
let popResult = await store2.popFromQueue(QueueType.Critical)
|
||||
check popResult.isSome()
|
||||
check popResult.get()[0].msg == "Persistent Message"
|
||||
check store2.getQueueLength(QueueType.Critical) == 0
|
||||
|
||||
asyncTest "large batch handling":
|
||||
## Given
|
||||
let store = newSqliteRateLimitStore[string](db)
|
||||
var largeBatch: seq[tuple[msgId: string, msg: string]]
|
||||
|
||||
for i in 1 .. 100:
|
||||
largeBatch.add(("msg" & $i, "Message " & $i))
|
||||
|
||||
## When
|
||||
let addResult = await store.addToQueue(QueueType.Normal, largeBatch)
|
||||
|
||||
## Then
|
||||
check addResult == true
|
||||
check store.getQueueLength(QueueType.Normal) == 1
|
||||
|
||||
let popResult = await store.popFromQueue(QueueType.Normal)
|
||||
check popResult.isSome()
|
||||
let poppedMsgs = popResult.get()
|
||||
check poppedMsgs.len == 100
|
||||
check poppedMsgs[0].msgId == "msg1"
|
||||
check poppedMsgs[99].msgId == "msg100"
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user