2025-08-12 12:09:57 +03:00
import std / [ times , strutils , json , options , base64 ]
2025-07-14 11:14:36 +03:00
import db_connector / db_sqlite
2025-08-04 11:31:44 +03:00
import chronos
2025-08-12 12:09:57 +03:00
import flatty
2025-08-04 10:43:59 +03:00
2025-08-04 11:31:44 +03:00
type
  RateLimitStore*[T] = ref object
    ## SQLite-backed storage used by the rate limiter for its bucket state
    ## (`kv_store` table), queued message batches (`ratelimit_queues`) and
    ## per-message statuses (`ratelimit_message_status`).
    ## `T` is the queued message type; messages are stored flatty-encoded.
    db: DbConn
    # NOTE(review): no proc visible in this module assigns or reads
    # `dbPath`; confirm whether it is still needed.
    dbPath: string
    # Cached number of distinct critical / normal batches in
    # `ratelimit_queues`; seeded by `new`, incremented by `pushToQueue`
    # and decremented by `popFromQueue`.
    criticalLength: int
    normalLength: int
    # Batch id that the next `pushToQueue` call will use.
    nextBatchId: int

  BucketState* {.pure.} = object
    ## Rate-limiter bucket state persisted by `saveBucketState` and
    ## restored by `loadBucketState`.
    budget*: int
    budgetCap*: int
    # Persisted as whole epoch seconds, so any sub-second part is lost on a
    # save/load round-trip.
    lastTimeFull*: Moment

  QueueType* {.pure.} = enum
    ## Queue a batch belongs to. The string values are written verbatim to
    ## the `queue_type` column, so changing them orphans stored rows.
    Critical = "critical"
    Normal = "normal"

  MessageStatus* {.pure.} = enum
    ## Per-message outcome. Stored by member name (`$status`) and read back
    ## with `parseEnum`, so renaming a member breaks existing rows.
    PassedToSender
    Enqueued
    Dropped
    DroppedBatchTooLarge
    DroppedFailedToEnqueue
2025-07-16 09:22:29 +03:00
# Key of the `kv_store` row that holds the JSON-encoded `BucketState`.
const BUCKET_STATE_KEY = "rate_limit_bucket_state"

## TODO: the `{.async.}` procs below only make blocking db_sqlite calls
## (nothing in them is awaited); find a way to make them truly asynchronous.
2025-09-01 12:07:39 +03:00
proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async.} =
  ## Creates a store over an open SQLite connection.
  ##
  ## The cached per-queue batch counts and the next batch id are recovered
  ## from rows already present in `ratelimit_queues`, so a store re-created
  ## over the same database continues where the previous one stopped.
  ## Presumably the `ratelimit_queues` table must already exist (it is not
  ## created here). Database and parse errors propagate to the caller.
  result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)

  let countQuery =
    sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?"
  # Derive the column values from QueueType, exactly as pushToQueue and
  # popFromQueue do, instead of repeating the string literals here.
  let criticalCount = db.getValue(countQuery, $QueueType.Critical)
  let normalCount = db.getValue(countQuery, $QueueType.Normal)
  result.criticalLength =
    if criticalCount.len == 0: 0 else: parseInt(criticalCount)
  result.normalLength =
    if normalCount.len == 0: 0 else: parseInt(normalCount)

  # MAX() is NULL on an empty table, which getValue reports as "".
  let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
  result.nextBatchId =
    if maxBatch.len == 0: 1 else: parseInt(maxBatch) + 1
  return result
2025-08-12 12:09:57 +03:00
proc saveBucketState*[T](
    store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
  ## Persists `bucketState` as a JSON object in `kv_store` under
  ## `BUCKET_STATE_KEY`, overwriting any previously saved value (upsert).
  ##
  ## `lastTimeFull` is stored as whole seconds (`epochSeconds`), so its
  ## sub-second part is dropped; `loadBucketState` restores it with
  ## one-second precision.
  ##
  ## Returns `true` on success and `false` if the write failed.
  try:
    let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
    let jsonState = %*{
      "budget": bucketState.budget,
      "budgetCap": bucketState.budgetCap,
      "lastTimeFullSeconds": lastTimeSeconds,
    }
    store.db.exec(
      sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
      BUCKET_STATE_KEY,
      $jsonState,
    )
    return true
  except CatchableError:
    # Best-effort persistence: report the failure instead of raising.
    # Defects (programming errors) are deliberately not swallowed.
    return false
2025-08-12 12:09:57 +03:00
proc loadBucketState*[T](
    store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
  ## Reads back the bucket state written by `saveBucketState`.
  ##
  ## Returns `none` when no state has been saved yet. `lastTimeFull` is
  ## rebuilt from whole seconds, so it has one-second precision.
  ## Database errors and std/json parsing/lookup errors (malformed JSON,
  ## missing field) are not caught here and propagate to the caller.
  let jsonStr =
    store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
  if jsonStr == "":
    return none(BucketState)

  let jsonData = parseJson(jsonStr)
  # getBiggestInt keeps the full 64-bit value; getInt() converts through
  # `int`, which is only 32 bits wide on 32-bit targets.
  let unixSeconds = jsonData["lastTimeFullSeconds"].getBiggestInt().int64
  let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1))
  return some(
    BucketState(
      budget: jsonData["budget"].getInt(),
      budgetCap: jsonData["budgetCap"].getInt(),
      lastTimeFull: lastTimeFull,
    )
  )
2025-08-04 10:43:59 +03:00
2025-08-12 12:09:57 +03:00
proc pushToQueue*[T](
    store: RateLimitStore[T],
    queueType: QueueType,
    msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
  ## Appends `msgs` to the `queueType` queue as a single batch.
  ##
  ## Each message is serialized with flatty, base64-encoded and inserted
  ## into `ratelimit_queues` under a fresh batch id. All inserts run in one
  ## transaction, so either the whole batch is stored or none of it is.
  ## On success the cached queue length grows by one batch.
  ##
  ## An empty `msgs` writes nothing, leaves the queue length unchanged and
  ## returns `true`. Returns `false` if any database step fails.
  if msgs.len == 0:
    # No rows would be written, so popFromQueue could never remove this
    # "batch"; counting it would leave the cached length permanently high.
    return true
  try:
    let batchId = store.nextBatchId
    inc store.nextBatchId
    let now = times.getTime().toUnix()
    let queueTypeStr = $queueType
    store.db.exec(sql"BEGIN TRANSACTION")
    try:
      for msg in msgs:
        let msgData = encode(msg.msg.toFlatty())
        store.db.exec(
          sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
          queueTypeStr,
          msg.msgId,
          msgData,
          batchId,
          now,
        )
      store.db.exec(sql"COMMIT")
    except CatchableError:
      store.db.exec(sql"ROLLBACK")
      raise
    case queueType
    of QueueType.Critical:
      inc store.criticalLength
    of QueueType.Normal:
      inc store.normalLength
    return true
  except CatchableError:
    # The consumed batch id is simply skipped; gaps in batch ids are harmless.
    return false
2025-08-12 12:09:57 +03:00
proc popFromQueue*[T](
    store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
  ## Removes and returns the oldest batch (lowest batch id) of `queueType`.
  ##
  ## Messages come back in insertion order (by rowid), base64-decoded and
  ## deserialized with flatty. On success the batch's rows are deleted and
  ## the cached queue length shrinks by one.
  ##
  ## Returns `none` when the queue is empty or when any step fails; on
  ## failure the batch stays in the database and the length is unchanged.
  try:
    let queueTypeStr = $queueType
    let oldestBatchStr = store.db.getValue(
      sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr
    )
    # MIN() is NULL (reported as "") when the queue has no rows.
    if oldestBatchStr == "":
      return none(seq[tuple[msgId: string, msg: T]])
    let batchId = parseInt(oldestBatchStr)
    let rows = store.db.getAllRows(
      sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
      queueTypeStr,
      batchId,
    )
    if rows.len == 0:
      return none(seq[tuple[msgId: string, msg: T]])
    var msgs = newSeqOfCap[tuple[msgId: string, msg: T]](rows.len)
    for row in rows:
      # NOTE(review): if a row cannot be decoded, the batch is left in the
      # table and every later pop of this queue fails on it again.
      let msg = decode(row[1]).fromFlatty(T)
      msgs.add((msgId: row[0], msg: msg))
    store.db.exec(
      sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
      queueTypeStr,
      batchId,
    )
    case queueType
    of QueueType.Critical:
      dec store.criticalLength
    of QueueType.Normal:
      dec store.normalLength
    return some(msgs)
  except CatchableError:
    return none(seq[tuple[msgId: string, msg: T]])
2025-09-01 11:50:41 +03:00
proc updateMessageStatuses*[T](
    store: RateLimitStore[T], messageIds: seq[string], status: MessageStatus
): Future[bool] {.async.} =
  ## Upserts `status` (stored by enum member name) for every id in
  ## `messageIds` into `ratelimit_message_status`, stamping each row with
  ## the current Unix time. All ids are written in one transaction.
  ##
  ## Returns `true` on success. Returns `false` if any step fails; an open
  ## transaction is then rolled back.
  var inTransaction = false
  try:
    let now = times.getTime().toUnix()
    store.db.exec(sql"BEGIN TRANSACTION")
    inTransaction = true
    for msgId in messageIds:
      store.db.exec(
        sql"INSERT INTO ratelimit_message_status (msg_id, status, updated_at) VALUES (?, ?, ?) ON CONFLICT(msg_id) DO UPDATE SET status = excluded.status, updated_at = excluded.updated_at",
        msgId,
        status,
        now,
      )
    store.db.exec(sql"COMMIT")
    return true
  except CatchableError:
    # Only roll back if BEGIN succeeded: ROLLBACK without an active
    # transaction raises, which would escape this handler instead of
    # returning false. A failing ROLLBACK must not escape either.
    if inTransaction:
      try:
        store.db.exec(sql"ROLLBACK")
      except CatchableError:
        discard
    return false
proc getMessageStatus*[T](
    store: RateLimitStore[T], messageId: string
): Future[Option[MessageStatus]] {.async.} =
  ## Looks up the status last recorded for `messageId` by
  ## `updateMessageStatuses`.
  ##
  ## Returns `none` when nothing has been recorded for that id. A stored
  ## value that is not a `MessageStatus` member name makes `parseEnum`
  ## raise `ValueError`, which is not caught here.
  let stored = store.db.getValue(
    sql"SELECT status FROM ratelimit_message_status WHERE msg_id = ?", messageId
  )
  if stored.len > 0:
    return some(parseEnum[MessageStatus](stored))
  return none(MessageStatus)
2025-09-01 11:50:41 +03:00
2025-08-12 12:09:57 +03:00
proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int =
  ## Number of batches currently queued for `queueType`, taken from the
  ## store's cached counters (no database access).
  result =
    case queueType
    of QueueType.Critical: store.criticalLength
    of QueueType.Normal: store.normalLength