238 lines
6.3 KiB
Nim
Raw Normal View History

2025-08-12 12:09:57 +03:00
import std/[times, strutils, json, options, base64]
2025-07-14 11:14:36 +03:00
import db_connector/db_sqlite
2025-08-04 11:31:44 +03:00
import chronos
2025-08-12 12:09:57 +03:00
import flatty
2025-08-04 10:43:59 +03:00
2025-08-04 11:31:44 +03:00
type
  RateLimitStore*[T] = ref object ## SQLite-backed persistence for the rate limiter; `T` is the message payload type stored in the queues.
    db: DbConn
    dbPath: string ## NOTE(review): not assigned anywhere in the procs shown here — confirm whether it is still needed.
    criticalLength, normalLength, nextBatchId: int ## Cached per-queue batch counts and the id given to the next pushed batch.

  BucketState* {.pure.} = object ## Token-bucket snapshot persisted by `saveBucketState`.
    budget*, budgetCap*: int
    lastTimeFull*: Moment

  QueueType* {.pure.} = enum ## Queue selector; the string values are written to the `queue_type` column.
    Critical = "critical"
    Normal = "normal"

  MessageStatus* {.pure.} = enum ## Per-message outcome recorded in `ratelimit_message_status`.
    PassedToSender
    Enqueued
    Dropped
    DroppedBatchTooLarge
    DroppedFailedToEnqueue

const BUCKET_STATE_KEY = "rate_limit_bucket_state" ## `kv_store` key holding the bucket-state JSON document.
2025-09-01 05:55:44 +03:00
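## TODO: the procs below are marked `{.async.}` but call db_sqlite synchronously (nothing is awaited), so each call blocks the event loop; find a way to make the SQLite access truly asynchronous.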
2025-09-01 12:07:39 +03:00
proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async.} =
  ## Creates a store over an already-open `db` connection.
  ##
  ## The cached queue lengths are seeded from the distinct `batch_id`s already
  ## present in `ratelimit_queues`. The next batch id is set to one past the
  ## highest stored id, or 1 when the table is empty.
  ## Despite `{.async.}`, the SQLite queries run synchronously.
  let store = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)

  let countQuery =
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?"
  let criticalRaw = db.getValue(countQuery, "critical")
  let normalRaw = db.getValue(countQuery, "normal")
  # An empty string from getValue means "no value"; keep the defaults then.
  if criticalRaw.len > 0:
    store.criticalLength = parseInt(criticalRaw)
  if normalRaw.len > 0:
    store.normalLength = parseInt(normalRaw)

  let maxRaw = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
  if maxRaw.len > 0:
    store.nextBatchId = parseInt(maxRaw) + 1

  return store
2025-08-12 12:09:57 +03:00
proc saveBucketState*[T](
    store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
  ## Persists `bucketState` as a JSON document in `kv_store` under
  ## `BUCKET_STATE_KEY`, overwriting any previously saved state.
  ##
  ## Returns `true` on success. Returns `false` if building the JSON or the
  ## SQLite upsert raises a `CatchableError`.
  ##
  ## NOTE(review): `lastTimeFull` is reduced to whole seconds via
  ## `epochSeconds`, so sub-second precision is lost on a save/load round trip.
  try:
    # Seconds are counted from chronos' Moment epoch, not necessarily the Unix
    # epoch. NOTE(review): if Moment is monotonic-clock based, a value saved
    # before a reboot may be meaningless after it — confirm.
    let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
    let jsonState = %*{
      "budget": bucketState.budget,
      "budgetCap": bucketState.budgetCap,
      "lastTimeFullSeconds": lastTimeSeconds,
    }
    # Upsert so repeated saves replace the single state row.
    store.db.exec(
      sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
      BUCKET_STATE_KEY,
      $jsonState,
    )
    return true
  except CatchableError:
    # Recoverable failures are reported via the bool; Defects still surface.
    return false
2025-08-12 12:09:57 +03:00
proc loadBucketState*[T](
    store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
  ## Reads back the bucket state written by `saveBucketState`.
  ##
  ## Returns `none` when `kv_store` has no value for `BUCKET_STATE_KEY`.
  ## Errors are not caught here: a malformed document or a missing field
  ## raises from `parseJson` / the `[]` lookups.
  let raw = store.db.getValue(
    sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY
  )
  if raw.len == 0:
    return none(BucketState)

  let doc = parseJson(raw)
  # The timestamp was stored as whole seconds; rebuild it at 1 s precision.
  let secs = doc["lastTimeFullSeconds"].getInt().int64
  let state = BucketState(
    lastTimeFull: Moment.init(secs, chronos.seconds(1)),
    budget: doc["budget"].getInt(),
    budgetCap: doc["budgetCap"].getInt(),
  )
  return some(state)
2025-08-04 10:43:59 +03:00
2025-08-12 12:09:57 +03:00
proc pushToQueue*[T](
    store: RateLimitStore[T],
    queueType: QueueType,
    msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
  ## Appends `msgs` as one new batch at the tail of `queueType`.
  ##
  ## Each message payload is serialized with flatty and base64-encoded. All
  ## messages are inserted into `ratelimit_queues` under one fresh batch id,
  ## inside a single transaction, and the cached queue length is then bumped
  ## by one batch.
  ##
  ## Returns `true` on success. An empty `msgs` is a no-op that also returns
  ## `true`. Returns `false` if any statement raises a `CatchableError`; the
  ## transaction is rolled back and the cached length is left unchanged.
  if msgs.len == 0:
    # Nothing is written for an empty batch, so neither the batch counter nor
    # the cached length may move. Otherwise getQueueLength would report a
    # batch that popFromQueue can never return.
    return true
  try:
    let batchId = store.nextBatchId
    # Consumed up front so a failed batch's id is never reused.
    inc store.nextBatchId
    let now = times.getTime().toUnix()
    let queueTypeStr = $queueType
    store.db.exec(sql"BEGIN TRANSACTION")
    try:
      for msg in msgs:
        # flatty output is binary; base64 it before binding as a string param.
        let msgData = encode(msg.msg.toFlatty())
        store.db.exec(
          sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
          queueTypeStr,
          msg.msgId,
          msgData,
          batchId,
          now,
        )
      store.db.exec(sql"COMMIT")
    except CatchableError:
      store.db.exec(sql"ROLLBACK")
      raise
    case queueType
    of QueueType.Critical:
      inc store.criticalLength
    of QueueType.Normal:
      inc store.normalLength
    return true
  except CatchableError:
    return false
2025-08-12 12:09:57 +03:00
proc popFromQueue*[T](
    store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
  ## Removes and returns the oldest batch (lowest `batch_id`) of `queueType`.
  ##
  ## Messages are returned in insertion order (`ORDER BY rowid`), after base64
  ## decoding and flatty deserialization. The batch rows are deleted and the
  ## cached queue length is decremented.
  ##
  ## Returns `none` when the queue holds no rows, or when reading, decoding or
  ## deleting raises a `CatchableError`.
  ##
  ## NOTE(review): a batch that fails to decode is neither deleted nor skipped,
  ## so every later call returns `none` for this queue until the bad rows are
  ## removed. Consider dead-lettering such batches.
  try:
    let queueTypeStr = $queueType
    let oldestBatchStr = store.db.getValue(
      sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?",
      queueTypeStr,
    )
    # MIN over zero rows is NULL, which getValue reports as "".
    if oldestBatchStr.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])
    let batchId = parseInt(oldestBatchStr)
    let rows = store.db.getAllRows(
      sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
      queueTypeStr,
      batchId,
    )
    if rows.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])
    var msgs: seq[tuple[msgId: string, msg: T]]
    for row in rows:
      # Column 0 is msg_id; column 1 is the base64-encoded flatty payload.
      msgs.add((msgId: row[0], msg: decode(row[1]).fromFlatty(T)))
    # Delete only after the whole batch decoded, so a failure keeps the rows.
    store.db.exec(
      sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
      queueTypeStr,
      batchId,
    )
    case queueType
    of QueueType.Critical:
      dec store.criticalLength
    of QueueType.Normal:
      dec store.normalLength
    return some(msgs)
  except CatchableError:
    return none(seq[tuple[msgId: string, msg: T]])
proc updateMessageStatuses*[T](
    store: RateLimitStore[T], messageIds: seq[string], status: MessageStatus
): Future[bool] {.async.} =
  ## Upserts `status` for every id in `messageIds` in one transaction,
  ## stamping each row with the current Unix time in seconds.
  ##
  ## The enum is bound through `$status` (its identifier name), which is the
  ## text `getMessageStatus` parses back.
  ##
  ## Returns `true` after a successful COMMIT. Returns `false` if starting the
  ## transaction fails, or if any statement inside it fails; in the latter
  ## case the transaction is rolled back.
  let now = times.getTime().toUnix()
  try:
    store.db.exec(sql"BEGIN TRANSACTION")
  except CatchableError:
    # No transaction was opened here. Issuing ROLLBACK now would either raise
    # or abort a transaction opened elsewhere on this connection.
    return false
  try:
    for msgId in messageIds:
      store.db.exec(
        sql"INSERT INTO ratelimit_message_status (msg_id, status, updated_at) VALUES (?, ?, ?) ON CONFLICT(msg_id) DO UPDATE SET status = excluded.status, updated_at = excluded.updated_at",
        msgId,
        status,
        now,
      )
    store.db.exec(sql"COMMIT")
    return true
  except CatchableError:
    try:
      store.db.exec(sql"ROLLBACK")
    except CatchableError:
      # Best effort: the caller is already being told the update failed.
      discard
    return false
proc getMessageStatus*[T](
    store: RateLimitStore[T], messageId: string
): Future[Option[MessageStatus]] {.async.} =
  ## Looks up the status last recorded for `messageId`.
  ##
  ## Returns `none` when no row exists. The stored text is parsed with
  ## `parseEnum`, so an unrecognized value raises `ValueError`.
  let stored = store.db.getValue(
    sql"SELECT status FROM ratelimit_message_status WHERE msg_id = ?", messageId
  )
  result =
    if stored.len == 0:
      none(MessageStatus)
    else:
      some(parseEnum[MessageStatus](stored))
2025-08-12 12:09:57 +03:00
proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int =
  ## Returns the cached number of batches waiting in `queueType`.
  ## Reads only the in-memory counter; no database query is made.
  case queueType
  of QueueType.Critical: store.criticalLength
  of QueueType.Normal: store.normalLength