mirror of
https://github.com/logos-messaging/nim-chat-sdk.git
synced 2026-01-03 14:43:07 +00:00
238 lines
6.3 KiB
Nim
238 lines
6.3 KiB
Nim
import std/[times, strutils, json, options, base64]
|
|
import db_connector/db_sqlite
|
|
import chronos
|
|
import flatty
|
|
|
|
type
  RateLimitStore*[T] = ref object
    ## SQLite-backed persistence for the rate limiter: message queues,
    ## per-message statuses and the saved bucket state.
    db: DbConn ## Connection every query in this module runs on.
    dbPath: string
      ## NOTE(review): not assigned or read by any proc visible in this
      ## module section - confirm whether it is still needed.
    criticalLength, normalLength: int
      ## Cached number of batches currently stored per queue type; primed
      ## from the database in `new` and adjusted on push/pop.
    nextBatchId: int ## Batch id handed to the next pushed batch.

  BucketState* {.pure.} = object
    ## Bucket snapshot persisted by `saveBucketState` and restored by
    ## `loadBucketState`.
    budget*, budgetCap*: int
      ## Presumably the remaining and the maximum budget of the bucket -
      ## TODO confirm against the rate limiter that owns this state.
    lastTimeFull*: Moment ## Stored with whole-second precision.

  QueueType* {.pure.} = enum
    ## Queue selector; the string value is what gets written to the
    ## `queue_type` column of `ratelimit_queues`.
    Critical = "critical"
    Normal = "normal"

  MessageStatus* {.pure.} = enum
    ## Status kept per message id in `ratelimit_message_status`. It is
    ## stored as the member name and parsed back with `parseEnum`.
    PassedToSender
    Enqueued
    Dropped
    DroppedBatchTooLarge
    DroppedFailedToEnqueue
|
|
|
|
# Key of the `kv_store` row that holds the JSON-encoded bucket state.
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
|
|
|
|
## TODO: the procs below are declared `{.async.}` but still call SQLite
## synchronously; find a way to make the database access non-blocking.
|
|
|
|
proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async.} =
  ## Builds a store on top of `db` and primes its cached counters from the
  ## `ratelimit_queues` table.
  ##
  ## The per-queue lengths are the number of distinct batch ids stored for
  ## each queue type; the next batch id continues after the largest stored
  ## one, or starts at 1 when the table has no batches.
  let store = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)

  let countBatches =
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?"
  let criticalRaw = db.getValue(countBatches, "critical")
  let normalRaw = db.getValue(countBatches, "normal")

  # An empty string means the query produced no value; keep the defaults
  # set at construction in that case.
  if criticalRaw != "":
    store.criticalLength = parseInt(criticalRaw)
  if normalRaw != "":
    store.normalLength = parseInt(normalRaw)

  let maxBatchRaw = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
  if maxBatchRaw != "":
    store.nextBatchId = parseInt(maxBatchRaw) + 1

  return store
|
|
|
|
proc saveBucketState*[T](
    store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
  ## Persists `bucketState` as a JSON document in `kv_store` under
  ## `BUCKET_STATE_KEY`, overwriting a previously stored value.
  ##
  ## Returns `true` when the row was written and `false` when a recoverable
  ## error (`CatchableError`, e.g. a database error) occurred. Defects are
  ## not swallowed: they indicate programmer bugs and propagate.
  try:
    # `Moment` is stored as whole seconds so it can live in plain JSON.
    let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()

    let jsonState =
      %*{
        "budget": bucketState.budget,
        "budgetCap": bucketState.budgetCap,
        "lastTimeFullSeconds": lastTimeSeconds,
      }
    store.db.exec(
      sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
      BUCKET_STATE_KEY,
      $jsonState,
    )
    return true
  except CatchableError:
    return false
|
|
|
|
proc loadBucketState*[T](
    store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
  ## Reads the bucket state stored under `BUCKET_STATE_KEY` in `kv_store`.
  ##
  ## Returns `none` when no value is stored. Errors from the query, from
  ## JSON parsing or from a missing JSON field are not caught here and
  ## fail the returned future.
  let stored =
    store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
  if stored.len == 0:
    return none(BucketState)

  let node = parseJson(stored)
  # The timestamp was written as whole seconds; rebuild the `Moment` with
  # one-second precision.
  let secs = node["lastTimeFullSeconds"].getInt().int64
  let state = BucketState(
    budget: node["budget"].getInt(),
    budgetCap: node["budgetCap"].getInt(),
    lastTimeFull: Moment.init(secs, chronos.seconds(1)),
  )
  return some(state)
|
|
|
|
proc pushToQueue*[T](
    store: RateLimitStore[T],
    queueType: QueueType,
    msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
  ## Stores `msgs` as one batch in the `queueType` queue.
  ##
  ## Every message is serialized with `toFlatty`, base64-encoded and
  ## inserted into `ratelimit_queues` with a shared batch id and the current
  ## Unix time as `created_at`. All inserts of a batch run in a single
  ## transaction that is rolled back if any of them fails.
  ##
  ## Returns `true` on success (including for an empty `msgs`, which is a
  ## no-op) and `false` when a recoverable error (`CatchableError`) occurred.
  ## Defects propagate after the transaction has been rolled back.

  # An empty batch writes no rows, so it must neither consume a batch id nor
  # bump the cached queue length; otherwise the cache drifts from the table.
  if msgs.len == 0:
    return true

  try:
    let batchId = store.nextBatchId
    inc store.nextBatchId
    let now = times.getTime().toUnix()
    let queueTypeStr = $queueType

    store.db.exec(sql"BEGIN TRANSACTION")
    var committed = false
    try:
      for msg in msgs:
        let serialized = msg.msg.toFlatty()
        let msgData = encode(serialized)
        store.db.exec(
          sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
          queueTypeStr,
          msg.msgId,
          msgData,
          batchId,
          now,
        )
      store.db.exec(sql"COMMIT")
      committed = true
    finally:
      # Undo a partially written batch on any failure, then let the error
      # continue to the handler below (or propagate, for Defects).
      if not committed:
        store.db.exec(sql"ROLLBACK")

    # The cached length counts batches, not individual messages.
    case queueType
    of QueueType.Critical:
      inc store.criticalLength
    of QueueType.Normal:
      inc store.normalLength

    return true
  except CatchableError:
    return false
|
|
|
|
proc popFromQueue*[T](
    store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
  ## Removes the batch with the smallest batch id from the `queueType`
  ## queue and returns its messages in insertion (`rowid`) order.
  ##
  ## Returns `none` when the queue holds no batch, and also when anything
  ## raises while reading, decoding or deleting; in that failure case the
  ## cached length is left untouched.
  ##
  ## NOTE(review): the bare `except` also swallows Defects, such as ones a
  ## decoder might raise on corrupt stored data - confirm this is intended
  ## before narrowing it.
  try:
    let queueName = $queueType

    # The lowest batch id is the oldest batch of this queue.
    let minBatch = store.db.getValue(
      sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueName
    )
    if minBatch.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])

    let batchId = parseInt(minBatch)

    let rows = store.db.getAllRows(
      sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
      queueName,
      batchId,
    )
    if rows.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])

    # Reverse of the push path: base64-decode, then deserialize with flatty.
    var batch = newSeqOfCap[tuple[msgId: string, msg: T]](rows.len)
    for row in rows:
      batch.add((msgId: row[0], msg: decode(row[1]).fromFlatty(T)))

    # Rows are deleted only after the whole batch decoded successfully.
    store.db.exec(
      sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
      queueName,
      batchId,
    )

    case queueType
    of QueueType.Critical:
      dec store.criticalLength
    of QueueType.Normal:
      dec store.normalLength

    return some(batch)
  except:
    return none(seq[tuple[msgId: string, msg: T]])
|
|
|
|
proc updateMessageStatuses*[T](
    store: RateLimitStore[T], messageIds: seq[string], status: MessageStatus
): Future[bool] {.async.} =
  ## Sets `status` for every id in `messageIds` in
  ## `ratelimit_message_status`, inserting missing rows and updating
  ## existing ones, with the current Unix time as `updated_at`.
  ##
  ## All writes run in one transaction. Returns `true` once it is committed
  ## and `false` when a recoverable error (`CatchableError`) occurred, in
  ## which case the transaction is rolled back. Defects propagate.
  try:
    let now = times.getTime().toUnix()
    store.db.exec(sql"BEGIN TRANSACTION")

    # Roll back only a transaction that was actually opened. If the
    # rollback itself fails, the handler below still reports `false`
    # instead of letting the error escape.
    var committed = false
    try:
      for msgId in messageIds:
        store.db.exec(
          sql"INSERT INTO ratelimit_message_status (msg_id, status, updated_at) VALUES (?, ?, ?) ON CONFLICT(msg_id) DO UPDATE SET status = excluded.status, updated_at = excluded.updated_at",
          msgId,
          status,
          now,
        )
      store.db.exec(sql"COMMIT")
      committed = true
    finally:
      if not committed:
        store.db.exec(sql"ROLLBACK")

    return true
  except CatchableError:
    return false
|
|
|
|
proc getMessageStatus*[T](
    store: RateLimitStore[T], messageId: string
): Future[Option[MessageStatus]] {.async.} =
  ## Looks up the status stored for `messageId` in
  ## `ratelimit_message_status`.
  ##
  ## Returns `none` when the query yields no value. A stored value that is
  ## not a `MessageStatus` member makes `parseEnum` raise; that error is not
  ## caught here.
  let raw = store.db.getValue(
    sql"SELECT status FROM ratelimit_message_status WHERE msg_id = ?", messageId
  )
  let found =
    if raw.len == 0:
      none(MessageStatus)
    else:
      some(parseEnum[MessageStatus](raw))
  return found
|
|
|
|
proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int =
  ## Returns the cached number of batches in the `queueType` queue. The
  ## value comes from the in-memory counters; the database is not queried.
  result =
    case queueType
    of QueueType.Critical: store.criticalLength
    of QueueType.Normal: store.normalLength
|