mirror of
https://github.com/logos-messaging/nim-chat-sdk.git
synced 2026-01-02 14:13:07 +00:00
Merge pull request #5 from waku-org/feat/rate-limit-store-state
Feat/rate-limit-store-state
This commit is contained in:
commit
2b1b7f5699
4
.gitignore
vendored
4
.gitignore
vendored
@ -21,8 +21,10 @@ nimcache/
|
||||
# Compiled files
|
||||
chat_sdk/*
|
||||
apps/*
|
||||
!*.nim
|
||||
tests/*
|
||||
|
||||
!*.nim
|
||||
ratelimit/*
|
||||
!*.nim
|
||||
!*.proto
|
||||
nimble.develop
|
||||
|
||||
@ -7,7 +7,7 @@ license = "MIT"
|
||||
srcDir = "src"
|
||||
|
||||
### Dependencies
|
||||
requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "waku"
|
||||
requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "flatty"
|
||||
|
||||
task buildSharedLib, "Build shared library for C bindings":
|
||||
exec "nim c --mm:refc --app:lib --out:../library/c-bindings/libchatsdk.so chat_sdk/chat_sdk.nim"
|
||||
|
||||
@ -3,15 +3,19 @@ import db_connector/db_sqlite
|
||||
import chronicles
|
||||
|
||||
proc ensureMigrationTable(db: DbConn) =
|
||||
db.exec(sql"""
|
||||
db.exec(
|
||||
sql"""
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
proc hasMigrationRun(db: DbConn, filename: string): bool =
|
||||
for row in db.fastRows(sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename):
|
||||
for row in db.fastRows(
|
||||
sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename
|
||||
):
|
||||
return true
|
||||
return false
|
||||
|
||||
|
||||
13
migrations/001_create_ratelimit_state.sql
Normal file
13
migrations/001_create_ratelimit_state.sql
Normal file
@ -0,0 +1,13 @@
|
||||
CREATE TABLE IF NOT EXISTS kv_store (
|
||||
key TEXT PRIMARY KEY,
|
||||
value BLOB
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ratelimit_queues (
|
||||
queue_type TEXT NOT NULL,
|
||||
msg_id TEXT NOT NULL,
|
||||
msg_data BLOB NOT NULL,
|
||||
batch_id INTEGER NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
PRIMARY KEY (queue_type, batch_id, msg_id)
|
||||
);
|
||||
@ -1,6 +1,10 @@
|
||||
import std/[times, deques, options]
|
||||
import waku/common/rate_limit/token_bucket
|
||||
import std/[times, options]
|
||||
# TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live
|
||||
import ./token_bucket
|
||||
# import waku/common/rate_limit/token_bucket
|
||||
import ./store
|
||||
import chronos
|
||||
import db_connector/db_sqlite
|
||||
|
||||
type
|
||||
CapacityState {.pure.} = enum
|
||||
@ -19,41 +23,50 @@ type
|
||||
Normal
|
||||
Optional
|
||||
|
||||
Serializable* =
|
||||
concept x
|
||||
x.toBytes() is seq[byte]
|
||||
MsgIdMsg[T] = tuple[msgId: string, msg: T]
|
||||
|
||||
MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T]
|
||||
MessageSender*[T] = proc(msgs: seq[MsgIdMsg[T]]) {.async.}
|
||||
|
||||
MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.}
|
||||
|
||||
RateLimitManager*[T: Serializable] = ref object
|
||||
RateLimitManager*[T] = ref object
|
||||
store: RateLimitStore[T]
|
||||
bucket: TokenBucket
|
||||
sender: MessageSender[T]
|
||||
queueCritical: Deque[seq[MsgIdMsg[T]]]
|
||||
queueNormal: Deque[seq[MsgIdMsg[T]]]
|
||||
sleepDuration: chronos.Duration
|
||||
pxQueueHandleLoop: Future[void]
|
||||
|
||||
proc new*[T: Serializable](
|
||||
proc new*[T](
|
||||
M: type[RateLimitManager[T]],
|
||||
store: RateLimitStore[T],
|
||||
sender: MessageSender[T],
|
||||
capacity: int = 100,
|
||||
duration: chronos.Duration = chronos.minutes(10),
|
||||
sleepDuration: chronos.Duration = chronos.milliseconds(1000),
|
||||
): M =
|
||||
M(
|
||||
bucket: TokenBucket.newStrict(capacity, duration),
|
||||
): Future[RateLimitManager[T]] {.async.} =
|
||||
var current = await store.loadBucketState()
|
||||
if current.isNone():
|
||||
# initialize bucket state with full capacity
|
||||
current = some(
|
||||
BucketState(budget: capacity, budgetCap: capacity, lastTimeFull: Moment.now())
|
||||
)
|
||||
discard await store.saveBucketState(current.get())
|
||||
|
||||
return RateLimitManager[T](
|
||||
store: store,
|
||||
bucket: TokenBucket.new(
|
||||
current.get().budgetCap,
|
||||
duration,
|
||||
ReplenishMode.Strict,
|
||||
current.get().budget,
|
||||
current.get().lastTimeFull,
|
||||
),
|
||||
sender: sender,
|
||||
queueCritical: Deque[seq[MsgIdMsg[T]]](),
|
||||
queueNormal: Deque[seq[MsgIdMsg[T]]](),
|
||||
sleepDuration: sleepDuration,
|
||||
)
|
||||
|
||||
proc getCapacityState[T: Serializable](
|
||||
proc getCapacityState[T](
|
||||
manager: RateLimitManager[T], now: Moment, count: int = 1
|
||||
): CapacityState =
|
||||
let (budget, budgetCap) = manager.bucket.getAvailableCapacity(now)
|
||||
let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now)
|
||||
let countAfter = budget - count
|
||||
let ratio = countAfter.float / budgetCap.float
|
||||
if ratio < 0.0:
|
||||
@ -63,62 +76,77 @@ proc getCapacityState[T: Serializable](
|
||||
else:
|
||||
return CapacityState.Normal
|
||||
|
||||
proc passToSender[T: Serializable](
|
||||
proc passToSender[T](
|
||||
manager: RateLimitManager[T],
|
||||
msgs: sink seq[MsgIdMsg[T]],
|
||||
msgs: seq[tuple[msgId: string, msg: T]],
|
||||
now: Moment,
|
||||
priority: Priority,
|
||||
): Future[SendResult] {.async.} =
|
||||
let count = msgs.len
|
||||
let capacity = manager.bucket.tryConsume(count, now)
|
||||
if not capacity:
|
||||
let consumed = manager.bucket.tryConsume(count, now)
|
||||
if not consumed:
|
||||
case priority
|
||||
of Priority.Critical:
|
||||
manager.queueCritical.addLast(msgs)
|
||||
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
|
||||
return SendResult.Enqueued
|
||||
of Priority.Normal:
|
||||
manager.queueNormal.addLast(msgs)
|
||||
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
|
||||
return SendResult.Enqueued
|
||||
of Priority.Optional:
|
||||
return SendResult.Dropped
|
||||
|
||||
let (budget, budgetCap, lastTimeFull) = manager.bucket.getAvailableCapacity(now)
|
||||
discard await manager.store.saveBucketState(
|
||||
BucketState(budget: budget, budgetCap: budgetCap, lastTimeFull: lastTimeFull)
|
||||
)
|
||||
await manager.sender(msgs)
|
||||
return SendResult.PassedToSender
|
||||
|
||||
proc processCriticalQueue[T: Serializable](
|
||||
proc processCriticalQueue[T](
|
||||
manager: RateLimitManager[T], now: Moment
|
||||
) {.async.} =
|
||||
while manager.queueCritical.len > 0:
|
||||
let msgs = manager.queueCritical.popFirst()
|
||||
): Future[void] {.async.} =
|
||||
while manager.store.getQueueLength(QueueType.Critical) > 0:
|
||||
# Peek at the next batch by getting it, but we'll handle putting it back if needed
|
||||
let maybeMsgs = await manager.store.popFromQueue(QueueType.Critical)
|
||||
if maybeMsgs.isNone():
|
||||
break
|
||||
|
||||
let msgs = maybeMsgs.get()
|
||||
let capacityState = manager.getCapacityState(now, msgs.len)
|
||||
if capacityState == CapacityState.Normal:
|
||||
discard await manager.passToSender(msgs, now, Priority.Critical)
|
||||
elif capacityState == CapacityState.AlmostNone:
|
||||
discard await manager.passToSender(msgs, now, Priority.Critical)
|
||||
else:
|
||||
# add back to critical queue
|
||||
manager.queueCritical.addFirst(msgs)
|
||||
# Put back to critical queue (add to front not possible, so we add to back and exit)
|
||||
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
|
||||
break
|
||||
|
||||
proc processNormalQueue[T: Serializable](
|
||||
proc processNormalQueue[T](
|
||||
manager: RateLimitManager[T], now: Moment
|
||||
) {.async.} =
|
||||
while manager.queueNormal.len > 0:
|
||||
let msgs = manager.queueNormal.popFirst()
|
||||
): Future[void] {.async.} =
|
||||
while manager.store.getQueueLength(QueueType.Normal) > 0:
|
||||
# Peek at the next batch by getting it, but we'll handle putting it back if needed
|
||||
let maybeMsgs = await manager.store.popFromQueue(QueueType.Normal)
|
||||
if maybeMsgs.isNone():
|
||||
break
|
||||
|
||||
let msgs = maybeMsgs.get()
|
||||
let capacityState = manager.getCapacityState(now, msgs.len)
|
||||
if capacityState == CapacityState.Normal:
|
||||
discard await manager.passToSender(msgs, now, Priority.Normal)
|
||||
else:
|
||||
# add back to critical queue
|
||||
manager.queueNormal.addFirst(msgs)
|
||||
# Put back to normal queue (add to front not possible, so we add to back and exit)
|
||||
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
|
||||
break
|
||||
|
||||
proc sendOrEnqueue*[T: Serializable](
|
||||
proc sendOrEnqueue*[T](
|
||||
manager: RateLimitManager[T],
|
||||
msgs: seq[MsgIdMsg[T]],
|
||||
msgs: seq[tuple[msgId: string, msg: T]],
|
||||
priority: Priority,
|
||||
now: Moment = Moment.now(),
|
||||
): Future[SendResult] {.async.} =
|
||||
let (_, budgetCap) = manager.bucket.getAvailableCapacity(now)
|
||||
let (_, budgetCap, _) = manager.bucket.getAvailableCapacity(now)
|
||||
if msgs.len.float / budgetCap.float >= 0.3:
|
||||
# drop batch if it's too large to avoid starvation
|
||||
return SendResult.DroppedBatchTooLarge
|
||||
@ -132,36 +160,22 @@ proc sendOrEnqueue*[T: Serializable](
|
||||
of Priority.Critical:
|
||||
return await manager.passToSender(msgs, now, priority)
|
||||
of Priority.Normal:
|
||||
manager.queueNormal.addLast(msgs)
|
||||
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
|
||||
return SendResult.Enqueued
|
||||
of Priority.Optional:
|
||||
return SendResult.Dropped
|
||||
of CapacityState.None:
|
||||
case priority
|
||||
of Priority.Critical:
|
||||
manager.queueCritical.addLast(msgs)
|
||||
discard await manager.store.pushToQueue(QueueType.Critical, msgs)
|
||||
return SendResult.Enqueued
|
||||
of Priority.Normal:
|
||||
manager.queueNormal.addLast(msgs)
|
||||
discard await manager.store.pushToQueue(QueueType.Normal, msgs)
|
||||
return SendResult.Enqueued
|
||||
of Priority.Optional:
|
||||
return SendResult.Dropped
|
||||
|
||||
proc getEnqueued*[T: Serializable](
|
||||
manager: RateLimitManager[T]
|
||||
): tuple[critical: seq[MsgIdMsg[T]], normal: seq[MsgIdMsg[T]]] =
|
||||
var criticalMsgs: seq[MsgIdMsg[T]]
|
||||
var normalMsgs: seq[MsgIdMsg[T]]
|
||||
|
||||
for batch in manager.queueCritical:
|
||||
criticalMsgs.add(batch)
|
||||
|
||||
for batch in manager.queueNormal:
|
||||
normalMsgs.add(batch)
|
||||
|
||||
return (criticalMsgs, normalMsgs)
|
||||
|
||||
proc queueHandleLoop[T: Serializable](
|
||||
proc queueHandleLoop*[T](
|
||||
manager: RateLimitManager[T],
|
||||
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
|
||||
Moment.now(),
|
||||
@ -177,20 +191,20 @@ proc queueHandleLoop[T: Serializable](
|
||||
# configurable sleep duration for processing queued messages
|
||||
await sleepAsync(manager.sleepDuration)
|
||||
|
||||
proc start*[T: Serializable](
|
||||
proc start*[T](
|
||||
manager: RateLimitManager[T],
|
||||
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
|
||||
Moment.now(),
|
||||
) {.async.} =
|
||||
manager.pxQueueHandleLoop = manager.queueHandleLoop(nowProvider)
|
||||
manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider)
|
||||
|
||||
proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} =
|
||||
proc stop*[T](manager: RateLimitManager[T]) {.async.} =
|
||||
if not isNil(manager.pxQueueHandleLoop):
|
||||
await manager.pxQueueHandleLoop.cancelAndWait()
|
||||
|
||||
func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} =
|
||||
func `$`*[T](b: RateLimitManager[T]): string {.inline.} =
|
||||
if isNil(b):
|
||||
return "nil"
|
||||
return
|
||||
"RateLimitManager(critical: " & $b.queueCritical.len & ", normal: " &
|
||||
$b.queueNormal.len & ")"
|
||||
"RateLimitManager(critical: " & $b.store.getQueueLength(QueueType.Critical) &
|
||||
", normal: " & $b.store.getQueueLength(QueueType.Normal) & ")"
|
||||
200
ratelimit/store.nim
Normal file
200
ratelimit/store.nim
Normal file
@ -0,0 +1,200 @@
|
||||
import std/[times, strutils, json, options, base64]
|
||||
import db_connector/db_sqlite
|
||||
import chronos
|
||||
import flatty
|
||||
|
||||
type
|
||||
RateLimitStore*[T] = ref object
|
||||
db: DbConn
|
||||
dbPath: string
|
||||
criticalLength: int
|
||||
normalLength: int
|
||||
nextBatchId: int
|
||||
|
||||
BucketState* {.pure} = object
|
||||
budget*: int
|
||||
budgetCap*: int
|
||||
lastTimeFull*: Moment
|
||||
|
||||
QueueType* {.pure.} = enum
|
||||
Critical = "critical"
|
||||
Normal = "normal"
|
||||
|
||||
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
|
||||
|
||||
## TODO find a way to make these procs async
|
||||
|
||||
proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async} =
|
||||
result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
|
||||
|
||||
# Initialize cached lengths from database
|
||||
let criticalCount = db.getValue(
|
||||
sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
|
||||
"critical",
|
||||
)
|
||||
let normalCount = db.getValue(
|
||||
sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?",
|
||||
"normal",
|
||||
)
|
||||
|
||||
result.criticalLength =
|
||||
if criticalCount == "":
|
||||
0
|
||||
else:
|
||||
parseInt(criticalCount)
|
||||
result.normalLength =
|
||||
if normalCount == "":
|
||||
0
|
||||
else:
|
||||
parseInt(normalCount)
|
||||
|
||||
# Get next batch ID
|
||||
let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues")
|
||||
result.nextBatchId =
|
||||
if maxBatch == "":
|
||||
1
|
||||
else:
|
||||
parseInt(maxBatch) + 1
|
||||
|
||||
return result
|
||||
|
||||
proc saveBucketState*[T](
|
||||
store: RateLimitStore[T], bucketState: BucketState
|
||||
): Future[bool] {.async.} =
|
||||
try:
|
||||
# Convert Moment to Unix seconds for storage
|
||||
let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds()
|
||||
|
||||
let jsonState =
|
||||
%*{
|
||||
"budget": bucketState.budget,
|
||||
"budgetCap": bucketState.budgetCap,
|
||||
"lastTimeFullSeconds": lastTimeSeconds,
|
||||
}
|
||||
store.db.exec(
|
||||
sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
|
||||
BUCKET_STATE_KEY,
|
||||
$jsonState,
|
||||
)
|
||||
return true
|
||||
except:
|
||||
return false
|
||||
|
||||
proc loadBucketState*[T](
|
||||
store: RateLimitStore[T]
|
||||
): Future[Option[BucketState]] {.async.} =
|
||||
let jsonStr =
|
||||
store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
|
||||
if jsonStr == "":
|
||||
return none(BucketState)
|
||||
|
||||
let jsonData = parseJson(jsonStr)
|
||||
let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64
|
||||
let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1))
|
||||
|
||||
return some(
|
||||
BucketState(
|
||||
budget: jsonData["budget"].getInt(),
|
||||
budgetCap: jsonData["budgetCap"].getInt(),
|
||||
lastTimeFull: lastTimeFull,
|
||||
)
|
||||
)
|
||||
|
||||
proc pushToQueue*[T](
|
||||
store: RateLimitStore[T],
|
||||
queueType: QueueType,
|
||||
msgs: seq[tuple[msgId: string, msg: T]],
|
||||
): Future[bool] {.async.} =
|
||||
try:
|
||||
let batchId = store.nextBatchId
|
||||
inc store.nextBatchId
|
||||
let now = times.getTime().toUnix()
|
||||
let queueTypeStr = $queueType
|
||||
|
||||
if msgs.len > 0:
|
||||
store.db.exec(sql"BEGIN TRANSACTION")
|
||||
try:
|
||||
for msg in msgs:
|
||||
let serialized = msg.msg.toFlatty()
|
||||
let msgData = encode(serialized)
|
||||
store.db.exec(
|
||||
sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)",
|
||||
queueTypeStr,
|
||||
msg.msgId,
|
||||
msgData,
|
||||
batchId,
|
||||
now,
|
||||
)
|
||||
store.db.exec(sql"COMMIT")
|
||||
except:
|
||||
store.db.exec(sql"ROLLBACK")
|
||||
raise
|
||||
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
inc store.criticalLength
|
||||
of QueueType.Normal:
|
||||
inc store.normalLength
|
||||
|
||||
return true
|
||||
except:
|
||||
return false
|
||||
|
||||
proc popFromQueue*[T](
|
||||
store: RateLimitStore[T], queueType: QueueType
|
||||
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
|
||||
try:
|
||||
let queueTypeStr = $queueType
|
||||
|
||||
# Get the oldest batch ID for this queue type
|
||||
let oldestBatchStr = store.db.getValue(
|
||||
sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr
|
||||
)
|
||||
|
||||
if oldestBatchStr == "":
|
||||
return none(seq[tuple[msgId: string, msg: T]])
|
||||
|
||||
let batchId = parseInt(oldestBatchStr)
|
||||
|
||||
# Get all messages in this batch (preserve insertion order using rowid)
|
||||
let rows = store.db.getAllRows(
|
||||
sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid",
|
||||
queueTypeStr,
|
||||
batchId,
|
||||
)
|
||||
|
||||
if rows.len == 0:
|
||||
return none(seq[tuple[msgId: string, msg: T]])
|
||||
|
||||
var msgs: seq[tuple[msgId: string, msg: T]]
|
||||
for row in rows:
|
||||
let msgIdStr = row[0]
|
||||
let msgDataB64 = row[1]
|
||||
|
||||
let serialized = decode(msgDataB64)
|
||||
let msg = serialized.fromFlatty(T)
|
||||
msgs.add((msgId: msgIdStr, msg: msg))
|
||||
|
||||
# Delete the batch from database
|
||||
store.db.exec(
|
||||
sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?",
|
||||
queueTypeStr,
|
||||
batchId,
|
||||
)
|
||||
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
dec store.criticalLength
|
||||
of QueueType.Normal:
|
||||
dec store.normalLength
|
||||
|
||||
return some(msgs)
|
||||
except:
|
||||
return none(seq[tuple[msgId: string, msg: T]])
|
||||
|
||||
proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int =
|
||||
case queueType
|
||||
of QueueType.Critical:
|
||||
return store.criticalLength
|
||||
of QueueType.Normal:
|
||||
return store.normalLength
|
||||
206
ratelimit/token_bucket.nim
Normal file
206
ratelimit/token_bucket.nim
Normal file
@ -0,0 +1,206 @@
|
||||
{.push raises: [].}
|
||||
|
||||
import chronos, std/math, std/options
|
||||
|
||||
const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25
|
||||
|
||||
## TODO! This will be remoded and replaced by https://github.com/status-im/nim-chronos/pull/582
|
||||
|
||||
## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation.
|
||||
## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class.
|
||||
## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation.
|
||||
## ref: https://github.com/status-im/nim-chronos/issues/500
|
||||
##
|
||||
## This version of TokenBucket is different from the original one in chronos/rate_limit.nim in many ways:
|
||||
## - It has a new mode called `Compensating` which is the default mode.
|
||||
## Compensation is calculated as the not used bucket capacity in the last measured period(s) in average.
|
||||
## or up until maximum the allowed compansation treshold (Currently it is const 25%).
|
||||
## Also compensation takes care of the proper time period calculation to avoid non-usage periods that can lead to
|
||||
## overcompensation.
|
||||
## - Strict mode is also available which will only replenish when time period is over but also will fill
|
||||
## the bucket to the max capacity.
|
||||
|
||||
type
|
||||
ReplenishMode* = enum
|
||||
Strict
|
||||
Compensating
|
||||
|
||||
TokenBucket* = ref object
|
||||
budget: int ## Current number of tokens in the bucket
|
||||
budgetCap: int ## Bucket capacity
|
||||
lastTimeFull: Moment
|
||||
## This timer measures the proper periodizaiton of the bucket refilling
|
||||
fillDuration: Duration ## Refill period
|
||||
case replenishMode*: ReplenishMode
|
||||
of Strict:
|
||||
## In strict mode, the bucket is refilled only till the budgetCap
|
||||
discard
|
||||
of Compensating:
|
||||
## This is the default mode.
|
||||
maxCompensation: float
|
||||
|
||||
func periodDistance(bucket: TokenBucket, currentTime: Moment): float =
|
||||
## notice fillDuration cannot be zero by design
|
||||
## period distance is a float number representing the calculated period time
|
||||
## since the last time bucket was refilled.
|
||||
return
|
||||
nanoseconds(currentTime - bucket.lastTimeFull).float /
|
||||
nanoseconds(bucket.fillDuration).float
|
||||
|
||||
func getUsageAverageSince(bucket: TokenBucket, distance: float): float =
|
||||
if distance == 0.float:
|
||||
## in case there is zero time difference than the usage percentage is 100%
|
||||
return 1.0
|
||||
|
||||
## budgetCap can never be zero
|
||||
## usage average is calculated as a percentage of total capacity available over
|
||||
## the measured period
|
||||
return bucket.budget.float / bucket.budgetCap.float / distance
|
||||
|
||||
proc calcCompensation(bucket: TokenBucket, averageUsage: float): int =
|
||||
# if we already fully used or even overused the tokens, there is no place for compensation
|
||||
if averageUsage >= 1.0:
|
||||
return 0
|
||||
|
||||
## compensation is the not used bucket capacity in the last measured period(s) in average.
|
||||
## or maximum the allowed compansation treshold
|
||||
let compensationPercent =
|
||||
min((1.0 - averageUsage) * bucket.budgetCap.float, bucket.maxCompensation)
|
||||
return trunc(compensationPercent).int
|
||||
|
||||
func periodElapsed(bucket: TokenBucket, currentTime: Moment): bool =
|
||||
return currentTime - bucket.lastTimeFull >= bucket.fillDuration
|
||||
|
||||
## Update will take place if bucket is empty and trying to consume tokens.
|
||||
## It checks if the bucket can be replenished as refill duration is passed or not.
|
||||
## - strict mode:
|
||||
proc updateStrict(bucket: TokenBucket, currentTime: Moment) =
|
||||
if bucket.fillDuration == default(Duration):
|
||||
bucket.budget = min(bucket.budgetCap, bucket.budget)
|
||||
return
|
||||
|
||||
if not periodElapsed(bucket, currentTime):
|
||||
return
|
||||
|
||||
bucket.budget = bucket.budgetCap
|
||||
bucket.lastTimeFull = currentTime
|
||||
|
||||
## - compensating - ballancing load:
|
||||
## - between updates we calculate average load (current bucket capacity / number of periods till last update)
|
||||
## - gives the percentage load used recently
|
||||
## - with this we can replenish bucket up to 100% + calculated leftover from previous period (caped with max treshold)
|
||||
proc updateWithCompensation(bucket: TokenBucket, currentTime: Moment) =
|
||||
if bucket.fillDuration == default(Duration):
|
||||
bucket.budget = min(bucket.budgetCap, bucket.budget)
|
||||
return
|
||||
|
||||
# do not replenish within the same period
|
||||
if not periodElapsed(bucket, currentTime):
|
||||
return
|
||||
|
||||
let distance = bucket.periodDistance(currentTime)
|
||||
let recentAvgUsage = bucket.getUsageAverageSince(distance)
|
||||
let compensation = bucket.calcCompensation(recentAvgUsage)
|
||||
|
||||
bucket.budget = bucket.budgetCap + compensation
|
||||
bucket.lastTimeFull = currentTime
|
||||
|
||||
proc update(bucket: TokenBucket, currentTime: Moment) =
|
||||
if bucket.replenishMode == ReplenishMode.Compensating:
|
||||
updateWithCompensation(bucket, currentTime)
|
||||
else:
|
||||
updateStrict(bucket, currentTime)
|
||||
|
||||
proc getAvailableCapacity*(
|
||||
bucket: TokenBucket, currentTime: Moment = Moment.now()
|
||||
): tuple[budget: int, budgetCap: int, lastTimeFull: Moment] =
|
||||
## Returns the available capacity of the bucket: (budget, budgetCap, lastTimeFull)
|
||||
|
||||
if periodElapsed(bucket, currentTime):
|
||||
case bucket.replenishMode
|
||||
of ReplenishMode.Strict:
|
||||
return (bucket.budgetCap, bucket.budgetCap, bucket.lastTimeFull)
|
||||
of ReplenishMode.Compensating:
|
||||
let distance = bucket.periodDistance(currentTime)
|
||||
let recentAvgUsage = bucket.getUsageAverageSince(distance)
|
||||
let compensation = bucket.calcCompensation(recentAvgUsage)
|
||||
let availableBudget = bucket.budgetCap + compensation
|
||||
return (availableBudget, bucket.budgetCap, bucket.lastTimeFull)
|
||||
return (bucket.budget, bucket.budgetCap, bucket.lastTimeFull)
|
||||
|
||||
proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool =
|
||||
## If `tokens` are available, consume them,
|
||||
## Otherwhise, return false.
|
||||
|
||||
if bucket.budget >= bucket.budgetCap:
|
||||
bucket.lastTimeFull = now
|
||||
|
||||
if bucket.budget >= tokens:
|
||||
bucket.budget -= tokens
|
||||
return true
|
||||
|
||||
bucket.update(now)
|
||||
|
||||
if bucket.budget >= tokens:
|
||||
bucket.budget -= tokens
|
||||
return true
|
||||
else:
|
||||
return false
|
||||
|
||||
proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) =
|
||||
## Add `tokens` to the budget (capped to the bucket capacity)
|
||||
bucket.budget += tokens
|
||||
bucket.update(now)
|
||||
|
||||
proc new*(
|
||||
T: type[TokenBucket],
|
||||
budgetCap: int,
|
||||
fillDuration: Duration = 1.seconds,
|
||||
mode: ReplenishMode = ReplenishMode.Compensating,
|
||||
budget: int = -1, # -1 means "use budgetCap"
|
||||
lastTimeFull: Moment = Moment.now(),
|
||||
): T =
|
||||
assert not isZero(fillDuration)
|
||||
assert budgetCap != 0
|
||||
assert lastTimeFull <= Moment.now()
|
||||
let actualBudget = if budget == -1: budgetCap else: budget
|
||||
assert actualBudget >= 0 and actualBudget <= budgetCap
|
||||
|
||||
## Create different mode TokenBucket
|
||||
case mode
|
||||
of ReplenishMode.Strict:
|
||||
return T(
|
||||
budget: actualBudget,
|
||||
budgetCap: budgetCap,
|
||||
fillDuration: fillDuration,
|
||||
lastTimeFull: lastTimeFull,
|
||||
replenishMode: mode,
|
||||
)
|
||||
of ReplenishMode.Compensating:
|
||||
T(
|
||||
budget: actualBudget,
|
||||
budgetCap: budgetCap,
|
||||
fillDuration: fillDuration,
|
||||
lastTimeFull: lastTimeFull,
|
||||
replenishMode: mode,
|
||||
maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT,
|
||||
)
|
||||
|
||||
proc newStrict*(T: type[TokenBucket], capacity: int, period: Duration): TokenBucket =
|
||||
T.new(capacity, period, ReplenishMode.Strict)
|
||||
|
||||
proc newCompensating*(
|
||||
T: type[TokenBucket], capacity: int, period: Duration
|
||||
): TokenBucket =
|
||||
T.new(capacity, period, ReplenishMode.Compensating)
|
||||
|
||||
func `$`*(b: TokenBucket): string {.inline.} =
|
||||
if isNil(b):
|
||||
return "nil"
|
||||
return $b.budgetCap & "/" & $b.fillDuration
|
||||
|
||||
func `$`*(ob: Option[TokenBucket]): string {.inline.} =
|
||||
if ob.isNone():
|
||||
return "no-limit"
|
||||
|
||||
return $ob.get()
|
||||
@ -1,13 +1,18 @@
|
||||
import testutils/unittests
|
||||
import ../ratelimit/rate_limit_manager
|
||||
import ../ratelimit/ratelimit_manager
|
||||
import ../ratelimit/store
|
||||
import chronos
|
||||
import db_connector/db_sqlite
|
||||
import ../chat_sdk/migration
|
||||
import std/[os, options]
|
||||
|
||||
# Implement the Serializable concept for string
|
||||
proc toBytes*(s: string): seq[byte] =
|
||||
cast[seq[byte]](s)
|
||||
var dbName = "test_ratelimit_manager.db"
|
||||
|
||||
suite "Queue RateLimitManager":
|
||||
setup:
|
||||
let db = open(dbName, "", "", "")
|
||||
runMigrations(db)
|
||||
|
||||
var sentMessages: seq[tuple[msgId: string, msg: string]]
|
||||
var senderCallCount: int = 0
|
||||
|
||||
@ -20,10 +25,17 @@ suite "Queue RateLimitManager":
|
||||
sentMessages.add(msg)
|
||||
await sleepAsync(chronos.milliseconds(10))
|
||||
|
||||
teardown:
|
||||
if db != nil:
|
||||
db.close()
|
||||
if fileExists(dbName):
|
||||
removeFile(dbName)
|
||||
|
||||
asyncTest "sendOrEnqueue - immediate send when capacity available":
|
||||
## Given
|
||||
let manager = RateLimitManager[string].new(
|
||||
mockSender, capacity = 10, duration = chronos.milliseconds(100)
|
||||
let store: RateLimitStore[string] = await RateLimitStore[string].new(db)
|
||||
let manager = await RateLimitManager[string].new(
|
||||
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
|
||||
)
|
||||
let testMsg = "Hello World"
|
||||
|
||||
@ -40,8 +52,9 @@ suite "Queue RateLimitManager":
|
||||
|
||||
asyncTest "sendOrEnqueue - multiple messages":
|
||||
## Given
|
||||
let manager = RateLimitManager[string].new(
|
||||
mockSender, capacity = 10, duration = chronos.milliseconds(100)
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
let manager = await RateLimitManager[string].new(
|
||||
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
|
||||
)
|
||||
|
||||
## When
|
||||
@ -60,7 +73,9 @@ suite "Queue RateLimitManager":
|
||||
|
||||
asyncTest "start and stop - drop large batch":
|
||||
## Given
|
||||
let manager = RateLimitManager[string].new(
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
let manager = await RateLimitManager[string].new(
|
||||
store,
|
||||
mockSender,
|
||||
capacity = 2,
|
||||
duration = chronos.milliseconds(100),
|
||||
@ -75,7 +90,9 @@ suite "Queue RateLimitManager":
|
||||
|
||||
asyncTest "enqueue - enqueue critical only when exceeded":
|
||||
## Given
|
||||
let manager = RateLimitManager[string].new(
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
let manager = await RateLimitManager[string].new(
|
||||
store,
|
||||
mockSender,
|
||||
capacity = 10,
|
||||
duration = chronos.milliseconds(100),
|
||||
@ -115,15 +132,11 @@ suite "Queue RateLimitManager":
|
||||
r10 == PassedToSender
|
||||
r11 == Enqueued
|
||||
|
||||
let (critical, normal) = manager.getEnqueued()
|
||||
check:
|
||||
critical.len == 1
|
||||
normal.len == 0
|
||||
critical[0].msgId == "msg11"
|
||||
|
||||
asyncTest "enqueue - enqueue normal on 70% capacity":
|
||||
## Given
|
||||
let manager = RateLimitManager[string].new(
|
||||
## Given
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
let manager = await RateLimitManager[string].new(
|
||||
store,
|
||||
mockSender,
|
||||
capacity = 10,
|
||||
duration = chronos.milliseconds(100),
|
||||
@ -164,17 +177,11 @@ suite "Queue RateLimitManager":
|
||||
r11 == PassedToSender
|
||||
r12 == Dropped
|
||||
|
||||
let (critical, normal) = manager.getEnqueued()
|
||||
check:
|
||||
critical.len == 0
|
||||
normal.len == 3
|
||||
normal[0].msgId == "msg8"
|
||||
normal[1].msgId == "msg9"
|
||||
normal[2].msgId == "msg10"
|
||||
|
||||
asyncTest "enqueue - process queued messages":
|
||||
## Given
|
||||
let manager = RateLimitManager[string].new(
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
let manager = await RateLimitManager[string].new(
|
||||
store,
|
||||
mockSender,
|
||||
capacity = 10,
|
||||
duration = chronos.milliseconds(200),
|
||||
@ -225,24 +232,9 @@ suite "Queue RateLimitManager":
|
||||
r14 == PassedToSender
|
||||
r15 == Enqueued
|
||||
|
||||
var (critical, normal) = manager.getEnqueued()
|
||||
check:
|
||||
critical.len == 1
|
||||
normal.len == 3
|
||||
normal[0].msgId == "8"
|
||||
normal[1].msgId == "9"
|
||||
normal[2].msgId == "10"
|
||||
critical[0].msgId == "15"
|
||||
|
||||
nowRef.value = now + chronos.milliseconds(250)
|
||||
await sleepAsync(chronos.milliseconds(250))
|
||||
|
||||
(critical, normal) = manager.getEnqueued()
|
||||
check:
|
||||
critical.len == 0
|
||||
normal.len == 0
|
||||
senderCallCount == 14
|
||||
sentMessages.len == 14
|
||||
senderCallCount == 10 # 10 messages passed to sender
|
||||
sentMessages.len == 10
|
||||
sentMessages[0].msgId == "1"
|
||||
sentMessages[1].msgId == "2"
|
||||
sentMessages[2].msgId == "3"
|
||||
@ -253,6 +245,13 @@ suite "Queue RateLimitManager":
|
||||
sentMessages[7].msgId == "11"
|
||||
sentMessages[8].msgId == "13"
|
||||
sentMessages[9].msgId == "14"
|
||||
|
||||
nowRef.value = now + chronos.milliseconds(250)
|
||||
await sleepAsync(chronos.milliseconds(250))
|
||||
|
||||
check:
|
||||
senderCallCount == 14
|
||||
sentMessages.len == 14
|
||||
sentMessages[10].msgId == "15"
|
||||
sentMessages[11].msgId == "8"
|
||||
sentMessages[12].msgId == "9"
|
||||
208
tests/test_store.nim
Normal file
208
tests/test_store.nim
Normal file
@ -0,0 +1,208 @@
|
||||
import testutils/unittests
|
||||
import ../ratelimit/store
|
||||
import chronos
|
||||
import db_connector/db_sqlite
|
||||
import ../chat_sdk/migration
|
||||
import std/[options, os, json]
|
||||
import flatty
|
||||
|
||||
const dbName = "test_store.db"
|
||||
|
||||
suite "SqliteRateLimitStore Tests":
|
||||
setup:
|
||||
let db = open(dbName, "", "", "")
|
||||
runMigrations(db)
|
||||
|
||||
teardown:
|
||||
if db != nil:
|
||||
db.close()
|
||||
if fileExists(dbName):
|
||||
removeFile(dbName)
|
||||
|
||||
asyncTest "newSqliteRateLimitStore - empty state":
|
||||
## Given
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
|
||||
## When
|
||||
let loadedState = await store.loadBucketState()
|
||||
|
||||
## Then
|
||||
check loadedState.isNone()
|
||||
|
||||
asyncTest "saveBucketState and loadBucketState - state persistence":
|
||||
## Given
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
|
||||
let now = Moment.now()
|
||||
echo "now: ", now.epochSeconds()
|
||||
let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now)
|
||||
|
||||
## When
|
||||
let saveResult = await store.saveBucketState(newBucketState)
|
||||
let loadedState = await store.loadBucketState()
|
||||
|
||||
## Then
|
||||
check saveResult == true
|
||||
check loadedState.isSome()
|
||||
check loadedState.get().budget == newBucketState.budget
|
||||
check loadedState.get().budgetCap == newBucketState.budgetCap
|
||||
check loadedState.get().lastTimeFull.epochSeconds() ==
|
||||
newBucketState.lastTimeFull.epochSeconds()
|
||||
|
||||
asyncTest "queue operations - empty store":
|
||||
## Given
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
|
||||
## When/Then
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
let criticalPop = await store.popFromQueue(QueueType.Critical)
|
||||
let normalPop = await store.popFromQueue(QueueType.Normal)
|
||||
|
||||
check criticalPop.isNone()
|
||||
check normalPop.isNone()
|
||||
|
||||
asyncTest "addToQueue and popFromQueue - single batch":
|
||||
## Given
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
let msgs = @[("msg1", "Hello"), ("msg2", "World")]
|
||||
|
||||
## When
|
||||
let addResult = await store.pushToQueue(QueueType.Critical, msgs)
|
||||
|
||||
## Then
|
||||
check addResult == true
|
||||
check store.getQueueLength(QueueType.Critical) == 1
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
## When
|
||||
let popResult = await store.popFromQueue(QueueType.Critical)
|
||||
|
||||
## Then
|
||||
check popResult.isSome()
|
||||
let poppedMsgs = popResult.get()
|
||||
check poppedMsgs.len == 2
|
||||
check poppedMsgs[0].msgId == "msg1"
|
||||
check poppedMsgs[0].msg == "Hello"
|
||||
check poppedMsgs[1].msgId == "msg2"
|
||||
check poppedMsgs[1].msg == "World"
|
||||
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
|
||||
asyncTest "addToQueue and popFromQueue - multiple batches FIFO":
|
||||
## Given
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
let batch1 = @[("msg1", "First")]
|
||||
let batch2 = @[("msg2", "Second")]
|
||||
let batch3 = @[("msg3", "Third")]
|
||||
|
||||
## When - Add batches
|
||||
let result1 = await store.pushToQueue(QueueType.Normal, batch1)
|
||||
check result1 == true
|
||||
let result2 = await store.pushToQueue(QueueType.Normal, batch2)
|
||||
check result2 == true
|
||||
let result3 = await store.pushToQueue(QueueType.Normal, batch3)
|
||||
check result3 == true
|
||||
|
||||
## Then - Check lengths
|
||||
check store.getQueueLength(QueueType.Normal) == 3
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
|
||||
## When/Then - Pop in FIFO order
|
||||
let pop1 = await store.popFromQueue(QueueType.Normal)
|
||||
check pop1.isSome()
|
||||
check pop1.get()[0].msg == "First"
|
||||
check store.getQueueLength(QueueType.Normal) == 2
|
||||
|
||||
let pop2 = await store.popFromQueue(QueueType.Normal)
|
||||
check pop2.isSome()
|
||||
check pop2.get()[0].msg == "Second"
|
||||
check store.getQueueLength(QueueType.Normal) == 1
|
||||
|
||||
let pop3 = await store.popFromQueue(QueueType.Normal)
|
||||
check pop3.isSome()
|
||||
check pop3.get()[0].msg == "Third"
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
let pop4 = await store.popFromQueue(QueueType.Normal)
|
||||
check pop4.isNone()
|
||||
|
||||
asyncTest "queue isolation - critical and normal queues are separate":
|
||||
## Given
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
let criticalMsgs = @[("crit1", "Critical Message")]
|
||||
let normalMsgs = @[("norm1", "Normal Message")]
|
||||
|
||||
## When
|
||||
let critResult = await store.pushToQueue(QueueType.Critical, criticalMsgs)
|
||||
check critResult == true
|
||||
let normResult = await store.pushToQueue(QueueType.Normal, normalMsgs)
|
||||
check normResult == true
|
||||
|
||||
## Then
|
||||
check store.getQueueLength(QueueType.Critical) == 1
|
||||
check store.getQueueLength(QueueType.Normal) == 1
|
||||
|
||||
## When - Pop from critical
|
||||
let criticalPop = await store.popFromQueue(QueueType.Critical)
|
||||
check criticalPop.isSome()
|
||||
check criticalPop.get()[0].msg == "Critical Message"
|
||||
|
||||
## Then - Normal queue unaffected
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
check store.getQueueLength(QueueType.Normal) == 1
|
||||
|
||||
## When - Pop from normal
|
||||
let normalPop = await store.popFromQueue(QueueType.Normal)
|
||||
check normalPop.isSome()
|
||||
check normalPop.get()[0].msg == "Normal Message"
|
||||
|
||||
## Then - All queues empty
|
||||
check store.getQueueLength(QueueType.Critical) == 0
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
|
||||
asyncTest "queue persistence across store instances":
|
||||
## Given
|
||||
let msgs = @[("persist1", "Persistent Message")]
|
||||
|
||||
block:
|
||||
let store1 = await RateLimitStore[string].new(db)
|
||||
let addResult = await store1.pushToQueue(QueueType.Critical, msgs)
|
||||
check addResult == true
|
||||
check store1.getQueueLength(QueueType.Critical) == 1
|
||||
|
||||
## When - Create new store instance
|
||||
block:
|
||||
let store2 =await RateLimitStore[string].new(db)
|
||||
|
||||
## Then - Queue length should be restored from database
|
||||
check store2.getQueueLength(QueueType.Critical) == 1
|
||||
|
||||
let popResult = await store2.popFromQueue(QueueType.Critical)
|
||||
check popResult.isSome()
|
||||
check popResult.get()[0].msg == "Persistent Message"
|
||||
check store2.getQueueLength(QueueType.Critical) == 0
|
||||
|
||||
asyncTest "large batch handling":
|
||||
## Given
|
||||
let store = await RateLimitStore[string].new(db)
|
||||
var largeBatch: seq[tuple[msgId: string, msg: string]]
|
||||
|
||||
for i in 1 .. 100:
|
||||
largeBatch.add(("msg" & $i, "Message " & $i))
|
||||
|
||||
## When
|
||||
let addResult = await store.pushToQueue(QueueType.Normal, largeBatch)
|
||||
|
||||
## Then
|
||||
check addResult == true
|
||||
check store.getQueueLength(QueueType.Normal) == 1
|
||||
|
||||
let popResult = await store.popFromQueue(QueueType.Normal)
|
||||
check popResult.isSome()
|
||||
let poppedMsgs = popResult.get()
|
||||
check poppedMsgs.len == 100
|
||||
check poppedMsgs[0].msgId == "msg1"
|
||||
check poppedMsgs[99].msgId == "msg100"
|
||||
check store.getQueueLength(QueueType.Normal) == 0
|
||||
Loading…
x
Reference in New Issue
Block a user