import std/[times, strutils, json, options]
import ./store
import chronos
import db_connector/db_sqlite

proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string =
  ## Deserialize a byte sequence into a `string` by copying each byte
  ## verbatim into the corresponding character position (no decoding).
  result = newString(bytes.len)
  for i, b in bytes:
    result[i] = char(b)

type SqliteRateLimitStore*[T] = ref object
  ## SQLite-backed persistence for rate-limit bucket state and message
  ## queues. Queue lengths are cached in memory — counted in *batches*,
  ## not individual messages — to avoid a COUNT query on every access.
  db: DbConn
  dbPath: string
  criticalLength: int   # cached number of batches in the critical queue
  normalLength: int     # cached number of batches in the normal queue
  nextBatchId: int      # next batch id to hand out (monotonically increasing)

const BUCKET_STATE_KEY = "rate_limit_bucket_state"

proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] =
  ## Create a store over an already-open database connection and prime the
  ## cached batch counts and the next batch id from the existing rows.
  ## Assumes the `ratelimit_queues` and `kv_store` tables already exist —
  ## TODO confirm schema creation happens elsewhere (e.g. in ./store).
  result = SqliteRateLimitStore[T](
    db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1
  )
  # Initialize cached lengths from the database (batches, not rows).
  let criticalCount = db.getValue(
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
    "critical",
  )
  let normalCount = db.getValue(
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
    "normal",
  )
  # getValue returns "" when the query yields no row.
  result.criticalLength =
    if criticalCount == "": 0 else: parseInt(criticalCount)
  result.normalLength =
    if normalCount == "": 0 else: parseInt(normalCount)
  # Resume batch numbering after the highest persisted batch id.
  let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
  result.nextBatchId =
    if maxBatch == "": 1 else: parseInt(maxBatch) + 1

proc saveBucketState*[T](
    store: SqliteRateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
  ## Persist the token-bucket state as JSON under a fixed key in `kv_store`,
  ## upserting any previous value. Returns false on any recoverable failure.
  try:
    # Moment is monotonic in-process; store it as Unix seconds so it can be
    # reconstructed after a restart.
    let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
    let jsonState = %*{
      "budget": bucketState.budget,
      "budgetCap": bucketState.budgetCap,
      "lastTimeFullSeconds": lastTimeSeconds,
    }
    store.db.exec(
      sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
      BUCKET_STATE_KEY,
      $jsonState,
    )
    return true
  except CatchableError:
    return false

proc loadBucketState*[T](
    store: SqliteRateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
  ## Load the bucket state saved by `saveBucketState`. Returns `none` when
  ## no state has been persisted yet. NOTE(review): malformed JSON raises
  ## rather than returning `none` — confirm that is the intended contract.
  let jsonStr =
    store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
  if jsonStr == "":
    return none(BucketState)
  let jsonData = parseJson(jsonStr)
  let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64
  # Rebuild the chronos Moment from the stored Unix seconds with 1-second
  # precision (the granularity used when saving).
  let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1))
  return some(
    BucketState(
      budget: jsonData["budget"].getInt(),
      budgetCap: jsonData["budgetCap"].getInt(),
      lastTimeFull: lastTimeFull,
    )
  )

proc addToQueue*[T](
    store: SqliteRateLimitStore[T],
    queueType: QueueType,
    msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
  ## Append `msgs` as a single batch to the given queue, inside one SQLite
  ## transaction. Returns true on success (including the empty-input no-op),
  ## false on any recoverable failure.
  try:
    if msgs.len == 0:
      # Fix: previously an empty call still consumed a batch id and
      # incremented the cached length, desynchronizing it from the
      # COUNT(DISTINCT batch_id) computed at construction time.
      return true
    let batchId = store.nextBatchId
    inc store.nextBatchId
    let now = times.getTime().toUnix()
    let queueTypeStr = $queueType
    store.db.exec(sql"BEGIN TRANSACTION")
    try:
      for msg in msgs:
        # Consistent serialization for all payload types.
        let msgBytes = msg.msg.toBytes()
        # SQLite text binding: map each byte to a character verbatim.
        var binaryStr = newString(msgBytes.len)
        for i, b in msgBytes:
          binaryStr[i] = char(b)
        store.db.exec(
          sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
          queueTypeStr,
          msg.msgId,
          binaryStr,
          batchId,
          now,
        )
      store.db.exec(sql"COMMIT")
    except CatchableError:
      store.db.exec(sql"ROLLBACK")
      raise
    # Lengths are counted in batches, mirroring the constructor's query.
    case queueType
    of QueueType.Critical: inc store.criticalLength
    of QueueType.Normal: inc store.normalLength
    return true
  except CatchableError:
    return false

proc popFromQueue*[T](
    store: SqliteRateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
  ## Remove and return the oldest batch (FIFO by batch id) from the given
  ## queue, preserving per-batch insertion order via rowid. Returns `none`
  ## when the queue is empty or on any recoverable failure.
  try:
    let queueTypeStr = $queueType
    # Oldest batch for this queue type, if any.
    let oldestBatchStr = store.db.getValue(
      sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?",
      queueTypeStr,
    )
    if oldestBatchStr == "":
      return none(seq[tuple[msgId: string, msg: T]])
    let batchId = parseInt(oldestBatchStr)
    # All messages in this batch, in insertion order.
    let rows = store.db.getAllRows(
      sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
      queueTypeStr,
      batchId,
    )
    if rows.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])
    var msgs: seq[tuple[msgId: string, msg: T]]
    for row in rows:
      let msgIdStr = row[0]
      let msgData = row[1]
      # Undo the byte→char mapping done on insert.
      var msgBytes: seq[byte]
      for c in msgData:
        msgBytes.add(byte(c))
      when T is string:
        let msg = fromBytesImpl(msgBytes, T)
        msgs.add((msgId: msgIdStr, msg: msg))
      else:
        # Non-string payloads must provide `fromBytes` in the calling context.
        let msg = fromBytes(msgBytes, T)
        msgs.add((msgId: msgIdStr, msg: msg))
    # Delete the consumed batch.
    store.db.exec(
      sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
      queueTypeStr,
      batchId,
    )
    case queueType
    of QueueType.Critical: dec store.criticalLength
    of QueueType.Normal: dec store.normalLength
    return some(msgs)
  except CatchableError:
    return none(seq[tuple[msgId: string, msg: T]])

proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int =
  ## Cached number of *batches* currently in the given queue (not the
  ## number of individual messages).
  case queueType
  of QueueType.Critical: return store.criticalLength
  of QueueType.Normal: return store.normalLength