From 7e4f930ae30827edcab643e6cdf8f1be81aac59e Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 14 Jul 2025 11:14:36 +0300 Subject: [PATCH 1/4] feat: store rate limit --- .gitignore | 2 + migrations/001_create_ratelimit_state.sql | 7 + ratelimit/store/memory.nim | 21 +++ ratelimit/store/sqlite.nim | 78 +++++++++ ratelimit/store/store.nim | 13 ++ ratelimit/token_bucket.nim | 198 ++++++++++++++++++++++ tests/test_sqlite_store.nim | 40 +++++ 7 files changed, 359 insertions(+) create mode 100644 migrations/001_create_ratelimit_state.sql create mode 100644 ratelimit/store/memory.nim create mode 100644 ratelimit/store/sqlite.nim create mode 100644 ratelimit/store/store.nim create mode 100644 ratelimit/token_bucket.nim create mode 100644 tests/test_sqlite_store.nim diff --git a/.gitignore b/.gitignore index 3a79a35..4ee6ee2 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,7 @@ chat_sdk/* !*.nim apps/* !*.nim +tests/* +!*.nim nimble.develop nimble.paths diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql new file mode 100644 index 0000000..11431ab --- /dev/null +++ b/migrations/001_create_ratelimit_state.sql @@ -0,0 +1,7 @@ + -- will only exist one row in the table + CREATE TABLE IF NOT EXISTS bucket_state ( + id INTEGER PRIMARY KEY, + budget INTEGER NOT NULL, + budget_cap INTEGER NOT NULL, + last_time_full_seconds INTEGER NOT NULL + ) \ No newline at end of file diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim new file mode 100644 index 0000000..ef19de1 --- /dev/null +++ b/ratelimit/store/memory.nim @@ -0,0 +1,21 @@ +import std/times +import ./store +import chronos + +# Memory Implementation +type MemoryRateLimitStore* = ref object + bucketState: BucketState + +proc newMemoryRateLimitStore*(): MemoryRateLimitStore = + result = MemoryRateLimitStore() + result.bucketState = + BucketState(budget: 10, budgetCap: 10, lastTimeFull: Moment.now()) + +proc saveBucketState*( + store: MemoryRateLimitStore, bucketState: BucketState +): Future[bool] {.async.} = + store.bucketState = bucketState + return true + +proc loadBucketState*(store: MemoryRateLimitStore): Future[BucketState] {.async.} = + return store.bucketState diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim new file mode 100644 index 0000000..d8aaf38 --- /dev/null +++ b/ratelimit/store/sqlite.nim @@ -0,0 +1,78 @@ +import std/times +import std/strutils +import ./store +import chronos +import db_connector/db_sqlite + +# SQLite Implementation +type SqliteRateLimitStore* = ref object + db: DbConn + dbPath: string + +proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = + result = SqliteRateLimitStore(dbPath: dbPath) + result.db = open(dbPath, "", "", "") + + # Create table if it doesn't exist + result.db.exec( + sql""" + CREATE TABLE IF NOT EXISTS bucket_state ( + id INTEGER PRIMARY KEY, + budget INTEGER NOT NULL, + budget_cap INTEGER NOT NULL, + last_time_full_seconds INTEGER NOT NULL + ) + """ + ) + + # Insert default state if table is empty + let count = result.db.getValue(sql"SELECT COUNT(*) FROM bucket_state").parseInt() + if count == 0: + let defaultTimeSeconds = Moment.now().epochSeconds() + result.db.exec( + sql""" + INSERT INTO bucket_state (id, budget, budget_cap, last_time_full_seconds) + VALUES (1, 10, 10, ?) + """, + defaultTimeSeconds, + ) + +proc close*(store: SqliteRateLimitStore) = + if store.db != nil: + store.db.close() + +proc saveBucketState*( + store: SqliteRateLimitStore, bucketState: BucketState +): Future[bool] {.async.} = + try: + # Convert Moment to Unix seconds for storage + let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() + store.db.exec( + sql""" + UPDATE bucket_state + SET budget = ?, budget_cap = ?, last_time_full_seconds = ? + WHERE id = 1 + """, + bucketState.budget, + bucketState.budgetCap, + lastTimeSeconds, + ) + return true + except: + return false + +proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = + let row = store.db.getRow( + sql""" + SELECT budget, budget_cap, last_time_full_seconds + FROM bucket_state + WHERE id = 1 + """ + ) + # Convert Unix seconds back to Moment (seconds precission) + let unixSeconds = row[2].parseInt().int64 + let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) + + return BucketState( + budget: row[0].parseInt(), budgetCap: row[1].parseInt(), lastTimeFull: lastTimeFull + ) diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim new file mode 100644 index 0000000..fab578d --- /dev/null +++ b/ratelimit/store/store.nim @@ -0,0 +1,13 @@ +import std/[times, deques] +import chronos + +type + BucketState* = object + budget*: int + budgetCap*: int + lastTimeFull*: Moment + + RateLimitStoreConcept* = + concept s + s.saveBucketState(BucketState) is Future[bool] + s.loadBucketState() is Future[BucketState] diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim new file mode 100644 index 0000000..447569a --- /dev/null +++ b/ratelimit/token_bucket.nim @@ -0,0 +1,198 @@ +{.push raises: [].} + +import chronos, std/math, std/options + +const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25 + +## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation. +## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class. +## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation. +## ref: https://github.com/status-im/nim-chronos/issues/500 +## +## This version of TokenBucket is different from the original one in chronos/rate_limit.nim in many ways: +## - It has a new mode called `Compensating` which is the default mode. +## Compensation is calculated as the not used bucket capacity in the last measured period(s) in average. +## or up until maximum the allowed compansation treshold (Currently it is const 25%). +## Also compensation takes care of the proper time period calculation to avoid non-usage periods that can lead to +## overcompensation. +## - Strict mode is also available which will only replenish when time period is over but also will fill +## the bucket to the max capacity. + +type + ReplenishMode* = enum + Strict + Compensating + + TokenBucket* = ref object + budget: int ## Current number of tokens in the bucket + budgetCap: int ## Bucket capacity + lastTimeFull: Moment + ## This timer measures the proper periodizaiton of the bucket refilling + fillDuration: Duration ## Refill period + case replenishMode*: ReplenishMode + of Strict: + ## In strict mode, the bucket is refilled only till the budgetCap + discard + of Compensating: + ## This is the default mode. + maxCompensation: float + +func periodDistance(bucket: TokenBucket, currentTime: Moment): float = + ## notice fillDuration cannot be zero by design + ## period distance is a float number representing the calculated period time + ## since the last time bucket was refilled. + return + nanoseconds(currentTime - bucket.lastTimeFull).float / + nanoseconds(bucket.fillDuration).float + +func getUsageAverageSince(bucket: TokenBucket, distance: float): float = + if distance == 0.float: + ## in case there is zero time difference than the usage percentage is 100% + return 1.0 + + ## budgetCap can never be zero + ## usage average is calculated as a percentage of total capacity available over + ## the measured period + return bucket.budget.float / bucket.budgetCap.float / distance + +proc calcCompensation(bucket: TokenBucket, averageUsage: float): int = + # if we already fully used or even overused the tokens, there is no place for compensation + if averageUsage >= 1.0: + return 0 + + ## compensation is the not used bucket capacity in the last measured period(s) in average. + ## or maximum the allowed compansation treshold + let compensationPercent = + min((1.0 - averageUsage) * bucket.budgetCap.float, bucket.maxCompensation) + return trunc(compensationPercent).int + +func periodElapsed(bucket: TokenBucket, currentTime: Moment): bool = + return currentTime - bucket.lastTimeFull >= bucket.fillDuration + +## Update will take place if bucket is empty and trying to consume tokens. +## It checks if the bucket can be replenished as refill duration is passed or not. +## - strict mode: +proc updateStrict(bucket: TokenBucket, currentTime: Moment) = + if bucket.fillDuration == default(Duration): + bucket.budget = min(bucket.budgetCap, bucket.budget) + return + + if not periodElapsed(bucket, currentTime): + return + + bucket.budget = bucket.budgetCap + bucket.lastTimeFull = currentTime + +## - compensating - ballancing load: +## - between updates we calculate average load (current bucket capacity / number of periods till last update) +## - gives the percentage load used recently +## - with this we can replenish bucket up to 100% + calculated leftover from previous period (caped with max treshold) +proc updateWithCompensation(bucket: TokenBucket, currentTime: Moment) = + if bucket.fillDuration == default(Duration): + bucket.budget = min(bucket.budgetCap, bucket.budget) + return + + # do not replenish within the same period + if not periodElapsed(bucket, currentTime): + return + + let distance = bucket.periodDistance(currentTime) + let recentAvgUsage = bucket.getUsageAverageSince(distance) + let compensation = bucket.calcCompensation(recentAvgUsage) + + bucket.budget = bucket.budgetCap + compensation + bucket.lastTimeFull = currentTime + +proc update(bucket: TokenBucket, currentTime: Moment) = + if bucket.replenishMode == ReplenishMode.Compensating: + updateWithCompensation(bucket, currentTime) + else: + updateStrict(bucket, currentTime) + +## Returns the available capacity of the bucket: (budget, budgetCap) +proc getAvailableCapacity*( + bucket: TokenBucket, currentTime: Moment +): tuple[budget: int, budgetCap: int] = + if periodElapsed(bucket, currentTime): + case bucket.replenishMode + of ReplenishMode.Strict: + return (bucket.budgetCap, bucket.budgetCap) + of ReplenishMode.Compensating: + let distance = bucket.periodDistance(currentTime) + let recentAvgUsage = bucket.getUsageAverageSince(distance) + let compensation = bucket.calcCompensation(recentAvgUsage) + let availableBudget = bucket.budgetCap + compensation + return (availableBudget, bucket.budgetCap) + return (bucket.budget, bucket.budgetCap) + +proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = + ## If `tokens` are available, consume them, + ## Otherwhise, return false. + + if bucket.budget >= bucket.budgetCap: + bucket.lastTimeFull = now + + if bucket.budget >= tokens: + bucket.budget -= tokens + return true + + bucket.update(now) + + if bucket.budget >= tokens: + bucket.budget -= tokens + return true + else: + return false + +proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) = + ## Add `tokens` to the budget (capped to the bucket capacity) + bucket.budget += tokens + bucket.update(now) + +proc new*( + T: type[TokenBucket], + budgetCap: int, + fillDuration: Duration = 1.seconds, + mode: ReplenishMode = ReplenishMode.Compensating, +): T = + assert not isZero(fillDuration) + assert budgetCap != 0 + + ## Create different mode TokenBucket + case mode + of ReplenishMode.Strict: + return T( + budget: budgetCap, + budgetCap: budgetCap, + fillDuration: fillDuration, + lastTimeFull: Moment.now(), + replenishMode: mode, + ) + of ReplenishMode.Compensating: + T( + budget: budgetCap, + budgetCap: budgetCap, + fillDuration: fillDuration, + lastTimeFull: Moment.now(), + replenishMode: mode, + maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT, + ) + +proc newStrict*(T: type[TokenBucket], capacity: int, period: Duration): TokenBucket = + T.new(capacity, period, ReplenishMode.Strict) + +proc newCompensating*( + T: type[TokenBucket], capacity: int, period: Duration +): TokenBucket = + T.new(capacity, period, ReplenishMode.Compensating) + +func `$`*(b: TokenBucket): string {.inline.} = + if isNil(b): + return "nil" + return $b.budgetCap & "/" & $b.fillDuration + +func `$`*(ob: Option[TokenBucket]): string {.inline.} = + if ob.isNone(): + return "no-limit" + + return $ob.get() diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim new file mode 100644 index 0000000..98dbe08 --- /dev/null +++ b/tests/test_sqlite_store.nim @@ -0,0 +1,40 @@ +{.used.} + +import testutils/unittests +import ../ratelimit/store/sqlite +import ../ratelimit/store/store +import chronos + +suite "SqliteRateLimitStore Tests": + asyncTest "newSqliteRateLimitStore - creates store with default values": + ## Given & When + let now = Moment.now() + let store = newSqliteRateLimitStore() + defer: + store.close() + + ## Then + let bucketState = await store.loadBucketState() + check bucketState.budget == 10 + check bucketState.budgetCap == 10 + check bucketState.lastTimeFull.epochSeconds() == now.epochSeconds() + + asyncTest "saveBucketState and loadBucketState - state persistence": + ## Given + let now = Moment.now() + let store = newSqliteRateLimitStore() + defer: + store.close() + + let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now) + + ## When + let saveResult = await store.saveBucketState(newBucketState) + let loadedState = await store.loadBucketState() + + ## Then + check saveResult == true + check loadedState.budget == newBucketState.budget + check loadedState.budgetCap == newBucketState.budgetCap + check loadedState.lastTimeFull.epochSeconds() == + newBucketState.lastTimeFull.epochSeconds() From 378d6a5433748d5998e746682c5d2fa606817e0a Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 09:22:29 +0300 Subject: [PATCH 2/4] fix: use kv store --- migrations/001_create_ratelimit_state.sql | 11 ++-- ratelimit/store/sqlite.nim | 74 ++++++++++++++--------- tests/test_sqlite_store.nim | 2 - 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql index 11431ab..030377c 100644 --- a/migrations/001_create_ratelimit_state.sql +++ b/migrations/001_create_ratelimit_state.sql @@ -1,7 +1,4 @@ - -- will only exist one row in the table - CREATE TABLE IF NOT EXISTS bucket_state ( - id INTEGER PRIMARY KEY, - budget INTEGER NOT NULL, - budget_cap INTEGER NOT NULL, - last_time_full_seconds INTEGER NOT NULL - ) \ No newline at end of file +CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB +); \ No newline at end of file diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim index d8aaf38..ba3e72e 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store/sqlite.nim @@ -1,5 +1,6 @@ import std/times import std/strutils +import std/json import ./store import chronos import db_connector/db_sqlite @@ -9,6 +10,8 @@ type SqliteRateLimitStore* = ref object db: DbConn dbPath: string +const BUCKET_STATE_KEY = "rate_limit_bucket_state" + proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = result = SqliteRateLimitStore(dbPath: dbPath) result.db = open(dbPath, "", "", "") @@ -16,25 +19,35 @@ proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore # Create table if it doesn't exist result.db.exec( sql""" - CREATE TABLE IF NOT EXISTS bucket_state ( - id INTEGER PRIMARY KEY, - budget INTEGER NOT NULL, - budget_cap INTEGER NOT NULL, - last_time_full_seconds INTEGER NOT NULL + CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB ) """ ) - # Insert default state if table is empty - let count = result.db.getValue(sql"SELECT COUNT(*) FROM bucket_state").parseInt() + # Insert default state if key doesn't exist + let count = result.db + .getValue(sql"SELECT COUNT(*) FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + .parseInt() if count == 0: let defaultTimeSeconds = Moment.now().epochSeconds() + let defaultState = BucketState( + budget: 10, + budgetCap: 10, + lastTimeFull: Moment.init(defaultTimeSeconds, chronos.seconds(1)), + ) + + # Serialize to JSON + let jsonState = + %*{ + "budget": defaultState.budget, + "budgetCap": defaultState.budgetCap, + "lastTimeFullSeconds": defaultTimeSeconds, + } + result.db.exec( - sql""" - INSERT INTO bucket_state (id, budget, budget_cap, last_time_full_seconds) - VALUES (1, 10, 10, ?) - """, - defaultTimeSeconds, + sql"INSERT INTO kv_store (key, value) VALUES (?, ?)", BUCKET_STATE_KEY, $jsonState ) proc close*(store: SqliteRateLimitStore) = @@ -47,32 +60,33 @@ proc saveBucketState*( try: # Convert Moment to Unix seconds for storage let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() + + # Serialize to JSON + let jsonState = + %*{ + "budget": bucketState.budget, + "budgetCap": bucketState.budgetCap, + "lastTimeFullSeconds": lastTimeSeconds, + } + store.db.exec( - sql""" - UPDATE bucket_state - SET budget = ?, budget_cap = ?, last_time_full_seconds = ? - WHERE id = 1 - """, - bucketState.budget, - bucketState.budgetCap, - lastTimeSeconds, + sql"UPDATE kv_store SET value = ? WHERE key = ?", $jsonState, BUCKET_STATE_KEY ) return true except: return false proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = - let row = store.db.getRow( - sql""" - SELECT budget, budget_cap, last_time_full_seconds - FROM bucket_state - WHERE id = 1 - """ - ) - # Convert Unix seconds back to Moment (seconds precission) - let unixSeconds = row[2].parseInt().int64 + let jsonStr = + store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + + # Parse JSON and reconstruct BucketState + let jsonData = parseJson(jsonStr) + let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64 let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) return BucketState( - budget: row[0].parseInt(), budgetCap: row[1].parseInt(), lastTimeFull: lastTimeFull + budget: jsonData["budget"].getInt(), + budgetCap: jsonData["budgetCap"].getInt(), + lastTimeFull: lastTimeFull, ) diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim index 98dbe08..8728992 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_sqlite_store.nim @@ -1,5 +1,3 @@ -{.used.} - import testutils/unittests import ../ratelimit/store/sqlite import ../ratelimit/store/store From ce18bb0f50911d4316220e361625f8ef23fbf739 Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 10:05:47 +0300 Subject: [PATCH 3/4] fix: tests --- ratelimit/store/memory.nim | 10 +++--- ratelimit/store/sqlite.nim | 70 +++++++++---------------------------- ratelimit/store/store.nim | 4 +-- tests/test_sqlite_store.nim | 44 ++++++++++++++--------- 4 files changed, 51 insertions(+), 77 deletions(-) diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim index ef19de1..557a17d 100644 --- a/ratelimit/store/memory.nim +++ b/ratelimit/store/memory.nim @@ -1,4 +1,4 @@ -import std/times +import std/[times, options] import ./store import chronos @@ -8,8 +8,6 @@ type MemoryRateLimitStore* = ref object proc newMemoryRateLimitStore*(): MemoryRateLimitStore = result = MemoryRateLimitStore() - result.bucketState = - BucketState(budget: 10, budgetCap: 10, lastTimeFull: Moment.now()) proc saveBucketState*( store: MemoryRateLimitStore, bucketState: BucketState @@ -17,5 +15,7 @@ proc saveBucketState*( store.bucketState = bucketState return true -proc loadBucketState*(store: MemoryRateLimitStore): Future[BucketState] {.async.} = - return store.bucketState +proc loadBucketState*( + store: MemoryRateLimitStore +): Future[Option[BucketState]] {.async.} = + return some(store.bucketState) diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim index ba3e72e..e364e5d 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store/sqlite.nim @@ -1,6 +1,4 @@ -import std/times -import std/strutils -import std/json +import std/[times, strutils, json, options] import ./store import chronos import db_connector/db_sqlite @@ -12,47 +10,8 @@ type SqliteRateLimitStore* = ref object const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = - result = SqliteRateLimitStore(dbPath: dbPath) - result.db = open(dbPath, "", "", "") - - # Create table if it doesn't exist - result.db.exec( - sql""" - CREATE TABLE IF NOT EXISTS kv_store ( - key TEXT PRIMARY KEY, - value BLOB - ) - """ - ) - - # Insert default state if key doesn't exist - let count = result.db - .getValue(sql"SELECT COUNT(*) FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) - .parseInt() - if count == 0: - let defaultTimeSeconds = Moment.now().epochSeconds() - let defaultState = BucketState( - budget: 10, - budgetCap: 10, - lastTimeFull: Moment.init(defaultTimeSeconds, chronos.seconds(1)), - ) - - # Serialize to JSON - let jsonState = - %*{ - "budget": defaultState.budget, - "budgetCap": defaultState.budgetCap, - "lastTimeFullSeconds": defaultTimeSeconds, - } - - result.db.exec( - sql"INSERT INTO kv_store (key, value) VALUES (?, ?)", BUCKET_STATE_KEY, $jsonState - ) - -proc close*(store: SqliteRateLimitStore) = - if store.db != nil: - store.db.close() +proc newSqliteRateLimitStore*(db: DbConn): SqliteRateLimitStore = + result = SqliteRateLimitStore(db: db) proc saveBucketState*( store: SqliteRateLimitStore, bucketState: BucketState @@ -61,32 +20,37 @@ proc saveBucketState*( # Convert Moment to Unix seconds for storage let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() - # Serialize to JSON let jsonState = %*{ "budget": bucketState.budget, "budgetCap": bucketState.budgetCap, "lastTimeFullSeconds": lastTimeSeconds, } - store.db.exec( - sql"UPDATE kv_store SET value = ? WHERE key = ?", $jsonState, BUCKET_STATE_KEY + sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", + BUCKET_STATE_KEY, + $jsonState, ) return true except: return false -proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = +proc loadBucketState*( + store: SqliteRateLimitStore +): Future[Option[BucketState]] {.async.} = let jsonStr = store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + if jsonStr == "": + return none(BucketState) - # Parse JSON and reconstruct BucketState let jsonData = parseJson(jsonStr) let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64 let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) - return BucketState( - budget: jsonData["budget"].getInt(), - budgetCap: jsonData["budgetCap"].getInt(), - lastTimeFull: lastTimeFull, + return some( + BucketState( + budget: jsonData["budget"].getInt(), + budgetCap: jsonData["budgetCap"].getInt(), + lastTimeFull: lastTimeFull, + ) ) diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim index fab578d..c4f6da3 100644 --- a/ratelimit/store/store.nim +++ b/ratelimit/store/store.nim @@ -1,4 +1,4 @@ -import std/[times, deques] +import std/[times, deques, options] import chronos type @@ -10,4 +10,4 @@ type RateLimitStoreConcept* = concept s s.saveBucketState(BucketState) is Future[bool] - s.loadBucketState() is Future[BucketState] + s.loadBucketState() is Future[Option[BucketState]] diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim index 8728992..6315ec9 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_sqlite_store.nim @@ -2,28 +2,37 @@ import testutils/unittests import ../ratelimit/store/sqlite import ../ratelimit/store/store import chronos +import db_connector/db_sqlite +import ../chat_sdk/migration +import std/[options, os] suite "SqliteRateLimitStore Tests": - asyncTest "newSqliteRateLimitStore - creates store with default values": - ## Given & When - let now = Moment.now() - let store = newSqliteRateLimitStore() - defer: - store.close() + setup: + let db = open("test-ratelimit.db", "", "", "") + runMigrations(db) + + teardown: + if db != nil: + db.close() + if fileExists("test-ratelimit.db"): + removeFile("test-ratelimit.db") + + asyncTest "newSqliteRateLimitStore - empty state": + ## Given + let store = newSqliteRateLimitStore(db) + + ## When + let loadedState = await store.loadBucketState() ## Then - let bucketState = await store.loadBucketState() - check bucketState.budget == 10 - check bucketState.budgetCap == 10 - check bucketState.lastTimeFull.epochSeconds() == now.epochSeconds() + check loadedState.isNone() asyncTest "saveBucketState and loadBucketState - state persistence": ## Given - let now = Moment.now() - let store = newSqliteRateLimitStore() - defer: - store.close() + let store = newSqliteRateLimitStore(db) + let now = Moment.now() + echo "now: ", now.epochSeconds() let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now) ## When @@ -32,7 +41,8 @@ suite "SqliteRateLimitStore Tests": ## Then check saveResult == true - check loadedState.budget == newBucketState.budget - check loadedState.budgetCap == newBucketState.budgetCap - check loadedState.lastTimeFull.epochSeconds() == + check loadedState.isSome() + check loadedState.get().budget == newBucketState.budget + check loadedState.get().budgetCap == newBucketState.budgetCap + check loadedState.get().lastTimeFull.epochSeconds() == newBucketState.lastTimeFull.epochSeconds() From 9b6c9f359d8383fee5b863fc9a2ffe1865456186 Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 15:51:26 +0300 Subject: [PATCH 4/4] feat: add rate limit store --- .gitignore | 2 + ratelimit/ratelimit.nim | 79 +++++++++++++++++++++++++------------- ratelimit/store/memory.nim | 10 ++--- ratelimit/store/store.nim | 2 +- ratelimit/token_bucket.nim | 20 ++++++---- tests/test_ratelimit.nim | 31 ++++++++++----- 6 files changed, 94 insertions(+), 50 deletions(-) diff --git a/.gitignore b/.gitignore index 4ee6ee2..1843ca9 100644 --- a/.gitignore +++ b/.gitignore @@ -25,5 +25,7 @@ apps/* !*.nim tests/* !*.nim +ratelimit/* +!*.nim nimble.develop nimble.paths diff --git a/ratelimit/ratelimit.nim b/ratelimit/ratelimit.nim index 9bd1bf6..ec10f2d 100644 --- a/ratelimit/ratelimit.nim +++ b/ratelimit/ratelimit.nim @@ -1,6 +1,7 @@ import std/[times, deques, options] -# import ./token_bucket -import waku/common/rate_limit/token_bucket +import ./token_bucket +# import waku/common/rate_limit/token_bucket +import ./store/store import chronos type @@ -27,7 +28,8 @@ type MessageSender*[T: Serializable] = proc(msgs: seq[tuple[msgId: string, msg: T]]): Future[void] {.async.} - RateLimitManager*[T: Serializable] = ref object + RateLimitManager*[T: Serializable, S: RateLimitStore] = ref object + store: S bucket: TokenBucket sender: MessageSender[T] running: bool @@ -35,14 +37,30 @@ type queueNormal: Deque[seq[tuple[msgId: string, msg: T]]] sleepDuration: chronos.Duration -proc newRateLimitManager*[T: Serializable]( +proc newRateLimitManager*[T: Serializable, S: RateLimitStore]( + store: S, sender: MessageSender[T], capacity: int = 100, duration: chronos.Duration = chronos.minutes(10), sleepDuration: chronos.Duration = chronos.milliseconds(1000), -): RateLimitManager[T] = - RateLimitManager[T]( - bucket: TokenBucket.newStrict(capacity, duration), +): Future[RateLimitManager[T, S]] {.async.} = + var current = await store.loadBucketState() + if current.isNone(): + # initialize bucket state with full capacity + current = some( + BucketState(budget: capacity, budgetCap: capacity, lastTimeFull: Moment.now()) + ) + discard await store.saveBucketState(current.get()) + + return RateLimitManager[T, S]( + store: store, + bucket: TokenBucket.new( + current.get().budgetCap, + duration, + ReplenishMode.Strict, + current.get().budget, + current.get().lastTimeFull, + ), sender: sender, running: false, queueCritical: Deque[seq[tuple[msgId: string, msg: T]]](), @@ -50,10 +68,10 @@ proc newRateLimitManager*[T: Serializable]( sleepDuration: sleepDuration, ) -proc getCapacityState[T: Serializable]( - manager: RateLimitManager[T], now: Moment, count: int = 1 +proc getCapacityState[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment, count: int = 1 ): CapacityState = - let (budget, budgetCap) = manager.bucket.getAvailableCapacity(now) + let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now) let countAfter = budget - count let ratio = countAfter.float / budgetCap.float if ratio < 0.0: @@ -63,15 +81,15 @@ proc getCapacityState[T: Serializable]( else: return CapacityState.Normal -proc passToSender[T: Serializable]( - manager: RateLimitManager[T], +proc passToSender[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], msgs: seq[tuple[msgId: string, msg: T]], now: Moment, priority: Priority, ): Future[SendResult] {.async.} = let count = msgs.len - let capacity = manager.bucket.tryConsume(count, now) - if not capacity: + let consumed = manager.bucket.tryConsume(count, now) + if not consumed: case priority of Priority.Critical: manager.queueCritical.addLast(msgs) @@ -81,11 +99,16 @@ proc passToSender[T: Serializable]( return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped + + let (budget, budgetCap, lastTimeFull) = manager.bucket.getAvailableCapacity(now) + discard await manager.store.saveBucketState( + BucketState(budget: budget, budgetCap: budgetCap, lastTimeFull: lastTimeFull) + ) await manager.sender(msgs) return SendResult.PassedToSender -proc processCriticalQueue[T: Serializable]( - manager: RateLimitManager[T], now: Moment +proc processCriticalQueue[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment ): Future[void] {.async.} = while manager.queueCritical.len > 0: let msgs = manager.queueCritical.popFirst() @@ -99,8 +122,8 @@ proc processCriticalQueue[T: Serializable]( manager.queueCritical.addFirst(msgs) break -proc processNormalQueue[T: Serializable]( - manager: RateLimitManager[T], now: Moment +proc processNormalQueue[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment ): Future[void] {.async.} = while manager.queueNormal.len > 0: let msgs = manager.queueNormal.popFirst() @@ -112,13 +135,13 @@ proc processNormalQueue[T: Serializable]( manager.queueNormal.addFirst(msgs) break -proc sendOrEnqueue*[T: Serializable]( - manager: RateLimitManager[T], +proc sendOrEnqueue*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], msgs: seq[tuple[msgId: string, msg: T]], priority: Priority, now: Moment = Moment.now(), ): Future[SendResult] {.async.} = - let (_, budgetCap) = manager.bucket.getAvailableCapacity(now) + let (_, budgetCap, _) = manager.bucket.getAvailableCapacity(now) if msgs.len.float / budgetCap.float >= 0.3: # drop batch if it's too large to avoid starvation return SendResult.DroppedBatchTooLarge @@ -147,8 +170,8 @@ proc sendOrEnqueue*[T: Serializable]( of Priority.Optional: return SendResult.Dropped -proc getEnqueued*[T: Serializable]( - manager: RateLimitManager[T] +proc getEnqueued*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S] ): tuple[ critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]] ] = @@ -163,8 +186,8 @@ proc getEnqueued*[T: Serializable]( return (criticalMsgs, normalMsgs) -proc start*[T: Serializable]( - manager: RateLimitManager[T], +proc start*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ): Future[void] {.async.} = @@ -181,10 +204,12 @@ proc start*[T: Serializable]( echo "Error in queue processing: ", e.msg await sleepAsync(manager.sleepDuration) -proc stop*[T: Serializable](manager: RateLimitManager[T]) = +proc stop*[T: Serializable, S: RateLimitStore](manager: RateLimitManager[T, S]) = manager.running = false -func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = +func `$`*[T: Serializable, S: RateLimitStore]( + b: RateLimitManager[T, S] +): string {.inline.} = if isNil(b): return "nil" return diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim index 557a17d..6391314 100644 --- a/ratelimit/store/memory.nim +++ b/ratelimit/store/memory.nim @@ -4,18 +4,18 @@ import chronos # Memory Implementation type MemoryRateLimitStore* = ref object - bucketState: BucketState + bucketState: Option[BucketState] -proc newMemoryRateLimitStore*(): MemoryRateLimitStore = - result = MemoryRateLimitStore() +proc new*(T: type[MemoryRateLimitStore]): T = + return T(bucketState: none(BucketState)) proc saveBucketState*( store: MemoryRateLimitStore, bucketState: BucketState ): Future[bool] {.async.} = - store.bucketState = bucketState + store.bucketState = some(bucketState) return true proc loadBucketState*( store: MemoryRateLimitStore ): Future[Option[BucketState]] {.async.} = - return some(store.bucketState) + return store.bucketState diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim index c4f6da3..c916750 100644 --- a/ratelimit/store/store.nim +++ b/ratelimit/store/store.nim @@ -7,7 +7,7 @@ type budgetCap*: int lastTimeFull*: Moment - RateLimitStoreConcept* = + RateLimitStore* = concept s s.saveBucketState(BucketState) is Future[bool] s.loadBucketState() is Future[Option[BucketState]] diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim index 447569a..e4a4487 100644 --- a/ratelimit/token_bucket.nim +++ b/ratelimit/token_bucket.nim @@ -112,18 +112,18 @@ proc update(bucket: TokenBucket, currentTime: Moment) = ## Returns the available capacity of the bucket: (budget, budgetCap) proc getAvailableCapacity*( bucket: TokenBucket, currentTime: Moment -): tuple[budget: int, budgetCap: int] = +): tuple[budget: int, budgetCap: int, lastTimeFull: Moment] = if periodElapsed(bucket, currentTime): case bucket.replenishMode of ReplenishMode.Strict: - return (bucket.budgetCap, bucket.budgetCap) + return (bucket.budgetCap, bucket.budgetCap, bucket.lastTimeFull) of ReplenishMode.Compensating: let distance = bucket.periodDistance(currentTime) let recentAvgUsage = bucket.getUsageAverageSince(distance) let compensation = bucket.calcCompensation(recentAvgUsage) let availableBudget = bucket.budgetCap + compensation - return (availableBudget, bucket.budgetCap) - return (bucket.budget, bucket.budgetCap) + return (availableBudget, bucket.budgetCap, bucket.lastTimeFull) + return (bucket.budget, bucket.budgetCap, bucket.lastTimeFull) proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = ## If `tokens` are available, consume them, @@ -154,26 +154,30 @@ proc new*( budgetCap: int, fillDuration: Duration = 1.seconds, mode: ReplenishMode = ReplenishMode.Compensating, + budget: int = -1, # -1 means "use budgetCap" + lastTimeFull: Moment = Moment.now(), ): T = assert not isZero(fillDuration) assert budgetCap != 0 + let actualBudget = if budget == -1: budgetCap else: budget + assert actualBudget >= 0 and actualBudget <= budgetCap ## Create different mode TokenBucket case mode of ReplenishMode.Strict: return T( - budget: budgetCap, + budget: actualBudget, budgetCap: budgetCap, fillDuration: fillDuration, - lastTimeFull: Moment.now(), + lastTimeFull: lastTimeFull, replenishMode: mode, ) of ReplenishMode.Compensating: T( - budget: budgetCap, + budget: actualBudget, budgetCap: budgetCap, fillDuration: fillDuration, - lastTimeFull: Moment.now(), + lastTimeFull: lastTimeFull, replenishMode: mode, maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT, ) diff --git a/tests/test_ratelimit.nim b/tests/test_ratelimit.nim index f38ddbf..a10dd9e 100644 --- a/tests/test_ratelimit.nim +++ b/tests/test_ratelimit.nim @@ -2,6 +2,7 @@ import testutils/unittests import ../ratelimit/ratelimit +import ../ratelimit/store/memory import chronos import strutils @@ -25,8 +26,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let manager = newRateLimitManager[string]( - mockSender, capacity = 10, duration = chronos.milliseconds(100) + let store: MemoryRateLimitStore = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) let testMsg = "Hello World" @@ -43,8 +45,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let manager = newRateLimitManager[string]( - mockSender, capacity = 10, duration = chronos.milliseconds(100) + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) ## When @@ -63,7 +66,9 @@ suite "Queue RateLimitManager": asyncTest "start and stop - basic functionality": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -94,7 +99,9 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 2, duration = chronos.milliseconds(100), @@ -109,7 +116,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -163,7 +172,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue normal on 70% capacity": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -220,7 +231,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - process queued messages": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(200),