import std/[times, strutils, json, options, base64]
import db_connector/db_sqlite
import chronos
import flatty

type
  RateLimitStore*[T] = ref object
    ## SQLite-backed store for rate-limiter state: the persisted bucket
    ## state, two batch queues (critical / normal) and per-message statuses.
    db: DbConn
    dbPath: string
    criticalLength: int ## cached number of batches in the critical queue
    normalLength: int ## cached number of batches in the normal queue
    nextBatchId: int ## batch id handed to the next `pushToQueue` call

  BucketState* {.pure.} = object
    ## Snapshot of the bucket persisted by `saveBucketState`.
    budget*: int
    budgetCap*: int
    lastTimeFull*: Moment

  QueueType* {.pure.} = enum
    ## Queue selector; the string value is what is stored in `queue_type`.
    Critical = "critical"
    Normal = "normal"

  MessageStatus* {.pure.} = enum
    ## Status recorded per message id in `ratelimit_message_status`.
    PassedToSender
    Enqueued
    Dropped
    DroppedBatchTooLarge
    DroppedFailedToEnqueue

const BUCKET_STATE_KEY = "rate_limit_bucket_state"

## TODO find a way to make these procs async

func parseIntOr(s: string, fallback: int): int =
  ## Returns `fallback` when `s` is empty, otherwise `parseInt(s)`
  ## (which raises `ValueError` on non-numeric input).
  if s.len == 0: fallback else: parseInt(s)

proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async.} =
  ## Creates a store on top of `db` and initialises the cached queue
  ## lengths and the next batch id from the `ratelimit_queues` table.
  ##
  ## Database or parse errors are not caught here.
  # Initialize cached lengths from database
  let criticalCount = db.getValue(
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
    $QueueType.Critical,
  )
  let normalCount = db.getValue(
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
    $QueueType.Normal,
  )

  # Get next batch ID: one past the largest stored id, or 1 for an empty table
  let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")

  return M(
    db: db,
    criticalLength: parseIntOr(criticalCount, 0),
    normalLength: parseIntOr(normalCount, 0),
    nextBatchId: parseIntOr(maxBatch, 0) + 1,
  )

proc saveBucketState*[T](
    store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
  ## Upserts `bucketState` as JSON into `kv_store` under `BUCKET_STATE_KEY`.
  ##
  ## Returns `true` on success and `false` if a catchable error occurs.
  try:
    # Convert Moment to Unix seconds for storage
    let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
    let jsonState =
      %*{
        "budget": bucketState.budget,
        "budgetCap": bucketState.budgetCap,
        "lastTimeFullSeconds": lastTimeSeconds,
      }
    store.db.exec(
      sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
      BUCKET_STATE_KEY,
      $jsonState,
    )
    return true
  except CatchableError:
    return false

proc loadBucketState*[T](
    store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
  ## Loads the bucket state stored under `BUCKET_STATE_KEY`.
  ##
  ## Returns `none` when the query yields an empty value. Errors from the
  ## query, JSON parsing or missing JSON keys are not caught here.
  let jsonStr =
    store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
  if jsonStr == "":
    return none(BucketState)

  let jsonData = parseJson(jsonStr)
  # Stored as whole Unix seconds; rebuild the Moment with 1-second precision
  let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64
  let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1))

  return some(
    BucketState(
      budget: jsonData["budget"].getInt(),
      budgetCap: jsonData["budgetCap"].getInt(),
      lastTimeFull: lastTimeFull,
    )
  )

proc pushToQueue*[T](
    store: RateLimitStore[T],
    queueType: QueueType,
    msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
  ## Appends `msgs` to the `queueType` queue as a single batch.
  ##
  ## Each message is serialised with flatty and base64-encoded; all rows of
  ## the batch are inserted inside one transaction that is rolled back on
  ## failure. Returns `true` on success and `false` on a catchable error.
  ##
  ## An empty `msgs` is a no-op that returns `true`: no rows would be
  ## written, so counting it as a batch would leave the cached queue length
  ## out of step with the table.
  if msgs.len == 0:
    return true

  try:
    let batchId = store.nextBatchId
    inc store.nextBatchId
    let now = times.getTime().toUnix()
    let queueTypeStr = $queueType

    store.db.exec(sql"BEGIN TRANSACTION")
    try:
      for msg in msgs:
        let serialized = msg.msg.toFlatty()
        let msgData = encode(serialized)
        store.db.exec(
          sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
          queueTypeStr,
          msg.msgId,
          msgData,
          batchId,
          now,
        )
      store.db.exec(sql"COMMIT")
    except CatchableError:
      store.db.exec(sql"ROLLBACK")
      raise

    # Only count the batch once it has been committed
    case queueType
    of QueueType.Critical:
      inc store.criticalLength
    of QueueType.Normal:
      inc store.normalLength

    return true
  except CatchableError:
    return false

proc popFromQueue*[T](
    store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
  ## Removes and returns the batch with the smallest `batch_id` from the
  ## `queueType` queue, with messages in `rowid` order.
  ##
  ## Returns `none` when the queue has no batch or a catchable error occurs
  ## (in which case the batch is not deleted).
  try:
    let queueTypeStr = $queueType

    # Get the oldest batch ID for this queue type
    let oldestBatchStr = store.db.getValue(
      sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr
    )

    if oldestBatchStr == "":
      return none(seq[tuple[msgId: string, msg: T]])

    let batchId = parseInt(oldestBatchStr)

    # Get all messages in this batch (preserve insertion order using rowid)
    let rows = store.db.getAllRows(
      sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
      queueTypeStr,
      batchId,
    )

    if rows.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])

    var msgs = newSeqOfCap[tuple[msgId: string, msg: T]](rows.len)
    for row in rows:
      # row[0] is msg_id, row[1] is the base64-encoded flatty payload
      let serialized = decode(row[1])
      msgs.add((msgId: row[0], msg: serialized.fromFlatty(T)))

    # Delete the batch from database
    store.db.exec(
      sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
      queueTypeStr,
      batchId,
    )

    case queueType
    of QueueType.Critical:
      dec store.criticalLength
    of QueueType.Normal:
      dec store.normalLength

    return some(msgs)
  except CatchableError:
    return none(seq[tuple[msgId: string, msg: T]])

proc updateMessageStatuses*[T](
    store: RateLimitStore[T], messageIds: seq[string], status: MessageStatus
): Future[bool] {.async.} =
  ## Upserts `status` (with the current Unix time) for every id in
  ## `messageIds` inside a single transaction.
  ##
  ## Returns `true` on success and `false` on a catchable error. The
  ## rollback is attempted only after the transaction has been opened, and
  ## a failing rollback is reported as `false` rather than escaping.
  try:
    let now = times.getTime().toUnix()
    store.db.exec(sql"BEGIN TRANSACTION")
    try:
      for msgId in messageIds:
        store.db.exec(
          sql"INSERT INTO ratelimit_message_status (msg_id, status, updated_at) VALUES (?, ?, ?) ON CONFLICT(msg_id) DO UPDATE SET status = excluded.status, updated_at = excluded.updated_at",
          msgId,
          status,
          now,
        )
      store.db.exec(sql"COMMIT")
    except CatchableError:
      store.db.exec(sql"ROLLBACK")
      raise
    return true
  except CatchableError:
    return false

proc getMessageStatus*[T](
    store: RateLimitStore[T], messageId: string
): Future[Option[MessageStatus]] {.async.} =
  ## Looks up the stored status for `messageId`.
  ##
  ## Returns `none` when the query yields an empty value. Query errors and
  ## `parseEnum` failures on an unrecognised status are not caught here.
  let statusStr = store.db.getValue(
    sql"SELECT status FROM ratelimit_message_status WHERE msg_id = ?", messageId
  )
  if statusStr == "":
    return none(MessageStatus)

  return some(parseEnum[MessageStatus](statusStr))

proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int =
  ## Returns the cached number of batches in the `queueType` queue; the
  ## database is not queried.
  case queueType
  of QueueType.Critical: store.criticalLength
  of QueueType.Normal: store.normalLength