import std/[times, strutils, json, options, base64]
import db_connector/db_sqlite
import chronos
import flatty

type
  RateLimitStore*[T] = ref object
    db: DbConn
    dbPath: string
    criticalLength: int
    normalLength: int
    nextBatchId: int

  BucketState* {.pure.} = object
    budget*: int
    budgetCap*: int
    lastTimeFull*: Moment

  QueueType* {.pure.} = enum
    Critical = "critical"
    Normal = "normal"

  MessageStatus* {.pure.} = enum
    PassedToSender
    Enqueued
    Dropped
    DroppedBatchTooLarge
    DroppedFailedToEnqueue

const BUCKET_STATE_KEY = "rate_limit_bucket_state"

## TODO: the procs below are declared {.async.}, but the db_sqlite calls they
## make are blocking; find a way to make the database access truly asynchronous.
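
# Assumed SQLite schema (not created in this module; inferred from the queries
# below, with table and column names taken verbatim from the SQL statements):
#   ratelimit_queues(queue_type TEXT, msg_id TEXT, msg_data TEXT,
#                    batch_id INTEGER, created_at INTEGER)
#   kv_store(key TEXT PRIMARY KEY, value TEXT)              -- ON CONFLICT(key) requires uniqueness
#   ratelimit_message_status(msg_id TEXT PRIMARY KEY,       -- ON CONFLICT(msg_id) requires uniqueness
#                            status TEXT, updated_at INTEGER)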

proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async.} =
  ## Creates a store backed by `db` and warms the cached batch counts and the
  ## next batch id from whatever is already persisted.
  result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
  # Initialize cached lengths (number of pending batches) from the database.
  let criticalCount = db.getValue(
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
    "critical",
  )
  let normalCount = db.getValue(
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
    "normal",
  )
  result.criticalLength =
    if criticalCount == "":
      0
    else:
      parseInt(criticalCount)
  result.normalLength =
    if normalCount == "":
      0
    else:
      parseInt(normalCount)
  # Resume batch numbering after the highest persisted batch id.
  let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
  result.nextBatchId =
    if maxBatch == "":
      1
    else:
      parseInt(maxBatch) + 1
  return result

proc saveBucketState*[T](
    store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
  ## Persists the token-bucket state as JSON under BUCKET_STATE_KEY.
  ## Returns false if the write fails.
  try:
    # Convert the chronos Moment to Unix seconds for storage.
    let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
    let jsonState =
      %*{
        "budget": bucketState.budget,
        "budgetCap": bucketState.budgetCap,
        "lastTimeFullSeconds": lastTimeSeconds,
      }
    store.db.exec(
      sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
      BUCKET_STATE_KEY,
      $jsonState,
    )
    return true
  except:
    return false

proc loadBucketState*[T](
    store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
  ## Loads the persisted bucket state, or none if nothing has been saved yet.
  let jsonStr =
    store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
  if jsonStr == "":
    return none(BucketState)
  let jsonData = parseJson(jsonStr)
  # Rebuild the Moment from the stored Unix seconds (second precision).
  let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64
  let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1))
  return some(
    BucketState(
      budget: jsonData["budget"].getInt(),
      budgetCap: jsonData["budgetCap"].getInt(),
      lastTimeFull: lastTimeFull,
    )
  )

proc pushToQueue*[T](
    store: RateLimitStore[T],
    queueType: QueueType,
    msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
  ## Appends `msgs` to the given queue as a single batch, inserted in one
  ## transaction. Returns false if the insert fails.
  try:
    let batchId = store.nextBatchId
    inc store.nextBatchId
    let now = times.getTime().toUnix()
    let queueTypeStr = $queueType
    if msgs.len > 0:
      store.db.exec(sql"BEGIN TRANSACTION")
      try:
        for msg in msgs:
          # Serialize with flatty, then base64-encode for TEXT storage.
          let serialized = msg.msg.toFlatty()
          let msgData = encode(serialized)
          store.db.exec(
            sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
            queueTypeStr,
            msg.msgId,
            msgData,
            batchId,
            now,
          )
        store.db.exec(sql"COMMIT")
      except:
        store.db.exec(sql"ROLLBACK")
        raise
      # Only count the batch once it has been committed, so the cached length
      # stays in sync with COUNT(DISTINCT batch_id) in the database.
      case queueType
      of QueueType.Critical:
        inc store.criticalLength
      of QueueType.Normal:
        inc store.normalLength
    return true
  except:
    return false

proc popFromQueue*[T](
    store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
  ## Removes and returns the oldest batch for the given queue, or none if the
  ## queue is empty or the read/delete fails.
  try:
    let queueTypeStr = $queueType
    # Find the oldest batch id for this queue type.
    let oldestBatchStr = store.db.getValue(
      sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr
    )
    if oldestBatchStr == "":
      return none(seq[tuple[msgId: string, msg: T]])
    let batchId = parseInt(oldestBatchStr)
    # Fetch every message in the batch, preserving insertion order via rowid.
    let rows = store.db.getAllRows(
      sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
      queueTypeStr,
      batchId,
    )
    if rows.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])
    var msgs: seq[tuple[msgId: string, msg: T]]
    for row in rows:
      let msgIdStr = row[0]
      let msgDataB64 = row[1]
      # Reverse the base64 + flatty encoding done in pushToQueue.
      let serialized = decode(msgDataB64)
      let msg = serialized.fromFlatty(T)
      msgs.add((msgId: msgIdStr, msg: msg))
    # Delete the batch now that it has been read.
    store.db.exec(
      sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
      queueTypeStr,
      batchId,
    )
    case queueType
    of QueueType.Critical:
      dec store.criticalLength
    of QueueType.Normal:
      dec store.normalLength
    return some(msgs)
  except:
    return none(seq[tuple[msgId: string, msg: T]])

proc updateMessageStatuses*[T](
    store: RateLimitStore[T], messageIds: seq[string], status: MessageStatus
): Future[bool] {.async.} =
  ## Upserts `status` for every id in `messageIds` inside one transaction.
  ## Returns false (after rolling back) if any write fails.
  try:
    let now = times.getTime().toUnix()
    store.db.exec(sql"BEGIN TRANSACTION")
    for msgId in messageIds:
      store.db.exec(
        sql"INSERT INTO ratelimit_message_status (msg_id, status, updated_at) VALUES (?, ?, ?) ON CONFLICT(msg_id) DO UPDATE SET status = excluded.status, updated_at = excluded.updated_at",
        msgId,
        status,
        now,
      )
    store.db.exec(sql"COMMIT")
    return true
  except:
    store.db.exec(sql"ROLLBACK")
    return false

proc getMessageStatus*[T](
    store: RateLimitStore[T], messageId: string
): Future[Option[MessageStatus]] {.async.} =
  ## Returns the last recorded status for `messageId`, or none if unknown.
  let statusStr = store.db.getValue(
    sql"SELECT status FROM ratelimit_message_status WHERE msg_id = ?", messageId
  )
  if statusStr == "":
    return none(MessageStatus)
  return some(parseEnum[MessageStatus](statusStr))

proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int =
  ## Returns the cached number of pending batches for the given queue.
  case queueType
  of QueueType.Critical:
    return store.criticalLength
  of QueueType.Normal:
    return store.normalLength
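
when isMainModule:
  # Minimal usage sketch (not part of the original module): exercises the store
  # against an in-memory SQLite database. The CREATE TABLE statements are
  # assumptions inferred from the queries above; the real schema is expected to
  # be created elsewhere.
  let db = open(":memory:", "", "", "")
  db.exec(sql"""CREATE TABLE IF NOT EXISTS ratelimit_queues (
    queue_type TEXT, msg_id TEXT, msg_data TEXT, batch_id INTEGER, created_at INTEGER)""")
  db.exec(sql"CREATE TABLE IF NOT EXISTS kv_store (key TEXT PRIMARY KEY, value TEXT)")
  db.exec(sql"""CREATE TABLE IF NOT EXISTS ratelimit_message_status (
    msg_id TEXT PRIMARY KEY, status TEXT, updated_at INTEGER)""")

  let store = waitFor RateLimitStore[string].new(db)

  # Enqueue one critical batch, then pop it back out.
  discard waitFor store.pushToQueue(
    QueueType.Critical, @[(msgId: "m1", msg: "hello"), (msgId: "m2", msg: "world")]
  )
  echo "critical batches: ", store.getQueueLength(QueueType.Critical)
  let popped = waitFor store.popFromQueue(QueueType.Critical)
  if popped.isSome():
    for item in popped.get():
      echo item.msgId, " -> ", item.msg

  # Persist and reload the bucket state.
  discard waitFor store.saveBucketState(
    BucketState(budget: 10, budgetCap: 100, lastTimeFull: Moment.now())
  )
  let loaded = waitFor store.loadBucketState()
  if loaded.isSome():
    echo "restored budget: ", loaded.get().budget
  db.close()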