From 7e4f930ae30827edcab643e6cdf8f1be81aac59e Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 14 Jul 2025 11:14:36 +0300 Subject: [PATCH 01/17] feat: store rate limit --- .gitignore | 2 + migrations/001_create_ratelimit_state.sql | 7 + ratelimit/store/memory.nim | 21 +++ ratelimit/store/sqlite.nim | 78 +++++++++ ratelimit/store/store.nim | 13 ++ ratelimit/token_bucket.nim | 198 ++++++++++++++++++++++ tests/test_sqlite_store.nim | 40 +++++ 7 files changed, 359 insertions(+) create mode 100644 migrations/001_create_ratelimit_state.sql create mode 100644 ratelimit/store/memory.nim create mode 100644 ratelimit/store/sqlite.nim create mode 100644 ratelimit/store/store.nim create mode 100644 ratelimit/token_bucket.nim create mode 100644 tests/test_sqlite_store.nim diff --git a/.gitignore b/.gitignore index 3a79a35..4ee6ee2 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,7 @@ chat_sdk/* !*.nim apps/* !*.nim +tests/* +!*.nim nimble.develop nimble.paths diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql new file mode 100644 index 0000000..11431ab --- /dev/null +++ b/migrations/001_create_ratelimit_state.sql @@ -0,0 +1,7 @@ + -- will only exist one row in the table + CREATE TABLE IF NOT EXISTS bucket_state ( + id INTEGER PRIMARY KEY, + budget INTEGER NOT NULL, + budget_cap INTEGER NOT NULL, + last_time_full_seconds INTEGER NOT NULL + ) \ No newline at end of file diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim new file mode 100644 index 0000000..ef19de1 --- /dev/null +++ b/ratelimit/store/memory.nim @@ -0,0 +1,21 @@ +import std/times +import ./store +import chronos + +# Memory Implementation +type MemoryRateLimitStore* = ref object + bucketState: BucketState + +proc newMemoryRateLimitStore*(): MemoryRateLimitStore = + result = MemoryRateLimitStore() + result.bucketState = + BucketState(budget: 10, budgetCap: 10, lastTimeFull: Moment.now()) + +proc saveBucketState*( + store: MemoryRateLimitStore, bucketState: BucketState +): Future[bool] {.async.} = + store.bucketState = bucketState + return true + +proc loadBucketState*(store: MemoryRateLimitStore): Future[BucketState] {.async.} = + return store.bucketState diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim new file mode 100644 index 0000000..d8aaf38 --- /dev/null +++ b/ratelimit/store/sqlite.nim @@ -0,0 +1,78 @@ +import std/times +import std/strutils +import ./store +import chronos +import db_connector/db_sqlite + +# SQLite Implementation +type SqliteRateLimitStore* = ref object + db: DbConn + dbPath: string + +proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = + result = SqliteRateLimitStore(dbPath: dbPath) + result.db = open(dbPath, "", "", "") + + # Create table if it doesn't exist + result.db.exec( + sql""" + CREATE TABLE IF NOT EXISTS bucket_state ( + id INTEGER PRIMARY KEY, + budget INTEGER NOT NULL, + budget_cap INTEGER NOT NULL, + last_time_full_seconds INTEGER NOT NULL + ) + """ + ) + + # Insert default state if table is empty + let count = result.db.getValue(sql"SELECT COUNT(*) FROM bucket_state").parseInt() + if count == 0: + let defaultTimeSeconds = Moment.now().epochSeconds() + result.db.exec( + sql""" + INSERT INTO bucket_state (id, budget, budget_cap, last_time_full_seconds) + VALUES (1, 10, 10, ?) + """, + defaultTimeSeconds, + ) + +proc close*(store: SqliteRateLimitStore) = + if store.db != nil: + store.db.close() + +proc saveBucketState*( + store: SqliteRateLimitStore, bucketState: BucketState +): Future[bool] {.async.} = + try: + # Convert Moment to Unix seconds for storage + let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() + store.db.exec( + sql""" + UPDATE bucket_state + SET budget = ?, budget_cap = ?, last_time_full_seconds = ? + WHERE id = 1 + """, + bucketState.budget, + bucketState.budgetCap, + lastTimeSeconds, + ) + return true + except: + return false + +proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = + let row = store.db.getRow( + sql""" + SELECT budget, budget_cap, last_time_full_seconds + FROM bucket_state + WHERE id = 1 + """ + ) + # Convert Unix seconds back to Moment (seconds precission) + let unixSeconds = row[2].parseInt().int64 + let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) + + return BucketState( + budget: row[0].parseInt(), budgetCap: row[1].parseInt(), lastTimeFull: lastTimeFull + ) diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim new file mode 100644 index 0000000..fab578d --- /dev/null +++ b/ratelimit/store/store.nim @@ -0,0 +1,13 @@ +import std/[times, deques] +import chronos + +type + BucketState* = object + budget*: int + budgetCap*: int + lastTimeFull*: Moment + + RateLimitStoreConcept* = + concept s + s.saveBucketState(BucketState) is Future[bool] + s.loadBucketState() is Future[BucketState] diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim new file mode 100644 index 0000000..447569a --- /dev/null +++ b/ratelimit/token_bucket.nim @@ -0,0 +1,198 @@ +{.push raises: [].} + +import chronos, std/math, std/options + +const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25 + +## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation. +## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class. +## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation. +## ref: https://github.com/status-im/nim-chronos/issues/500 +## +## This version of TokenBucket is different from the original one in chronos/rate_limit.nim in many ways: +## - It has a new mode called `Compensating` which is the default mode. +## Compensation is calculated as the not used bucket capacity in the last measured period(s) in average. +## or up until maximum the allowed compansation treshold (Currently it is const 25%). +## Also compensation takes care of the proper time period calculation to avoid non-usage periods that can lead to +## overcompensation. +## - Strict mode is also available which will only replenish when time period is over but also will fill +## the bucket to the max capacity. + +type + ReplenishMode* = enum + Strict + Compensating + + TokenBucket* = ref object + budget: int ## Current number of tokens in the bucket + budgetCap: int ## Bucket capacity + lastTimeFull: Moment + ## This timer measures the proper periodizaiton of the bucket refilling + fillDuration: Duration ## Refill period + case replenishMode*: ReplenishMode + of Strict: + ## In strict mode, the bucket is refilled only till the budgetCap + discard + of Compensating: + ## This is the default mode. + maxCompensation: float + +func periodDistance(bucket: TokenBucket, currentTime: Moment): float = + ## notice fillDuration cannot be zero by design + ## period distance is a float number representing the calculated period time + ## since the last time bucket was refilled. + return + nanoseconds(currentTime - bucket.lastTimeFull).float / + nanoseconds(bucket.fillDuration).float + +func getUsageAverageSince(bucket: TokenBucket, distance: float): float = + if distance == 0.float: + ## in case there is zero time difference than the usage percentage is 100% + return 1.0 + + ## budgetCap can never be zero + ## usage average is calculated as a percentage of total capacity available over + ## the measured period + return bucket.budget.float / bucket.budgetCap.float / distance + +proc calcCompensation(bucket: TokenBucket, averageUsage: float): int = + # if we already fully used or even overused the tokens, there is no place for compensation + if averageUsage >= 1.0: + return 0 + + ## compensation is the not used bucket capacity in the last measured period(s) in average. + ## or maximum the allowed compansation treshold + let compensationPercent = + min((1.0 - averageUsage) * bucket.budgetCap.float, bucket.maxCompensation) + return trunc(compensationPercent).int + +func periodElapsed(bucket: TokenBucket, currentTime: Moment): bool = + return currentTime - bucket.lastTimeFull >= bucket.fillDuration + +## Update will take place if bucket is empty and trying to consume tokens. +## It checks if the bucket can be replenished as refill duration is passed or not. +## - strict mode: +proc updateStrict(bucket: TokenBucket, currentTime: Moment) = + if bucket.fillDuration == default(Duration): + bucket.budget = min(bucket.budgetCap, bucket.budget) + return + + if not periodElapsed(bucket, currentTime): + return + + bucket.budget = bucket.budgetCap + bucket.lastTimeFull = currentTime + +## - compensating - ballancing load: +## - between updates we calculate average load (current bucket capacity / number of periods till last update) +## - gives the percentage load used recently +## - with this we can replenish bucket up to 100% + calculated leftover from previous period (caped with max treshold) +proc updateWithCompensation(bucket: TokenBucket, currentTime: Moment) = + if bucket.fillDuration == default(Duration): + bucket.budget = min(bucket.budgetCap, bucket.budget) + return + + # do not replenish within the same period + if not periodElapsed(bucket, currentTime): + return + + let distance = bucket.periodDistance(currentTime) + let recentAvgUsage = bucket.getUsageAverageSince(distance) + let compensation = bucket.calcCompensation(recentAvgUsage) + + bucket.budget = bucket.budgetCap + compensation + bucket.lastTimeFull = currentTime + +proc update(bucket: TokenBucket, currentTime: Moment) = + if bucket.replenishMode == ReplenishMode.Compensating: + updateWithCompensation(bucket, currentTime) + else: + updateStrict(bucket, currentTime) + +## Returns the available capacity of the bucket: (budget, budgetCap) +proc getAvailableCapacity*( + bucket: TokenBucket, currentTime: Moment +): tuple[budget: int, budgetCap: int] = + if periodElapsed(bucket, currentTime): + case bucket.replenishMode + of ReplenishMode.Strict: + return (bucket.budgetCap, bucket.budgetCap) + of ReplenishMode.Compensating: + let distance = bucket.periodDistance(currentTime) + let recentAvgUsage = bucket.getUsageAverageSince(distance) + let compensation = bucket.calcCompensation(recentAvgUsage) + let availableBudget = bucket.budgetCap + compensation + return (availableBudget, bucket.budgetCap) + return (bucket.budget, bucket.budgetCap) + +proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = + ## If `tokens` are available, consume them, + ## Otherwhise, return false. + + if bucket.budget >= bucket.budgetCap: + bucket.lastTimeFull = now + + if bucket.budget >= tokens: + bucket.budget -= tokens + return true + + bucket.update(now) + + if bucket.budget >= tokens: + bucket.budget -= tokens + return true + else: + return false + +proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) = + ## Add `tokens` to the budget (capped to the bucket capacity) + bucket.budget += tokens + bucket.update(now) + +proc new*( + T: type[TokenBucket], + budgetCap: int, + fillDuration: Duration = 1.seconds, + mode: ReplenishMode = ReplenishMode.Compensating, +): T = + assert not isZero(fillDuration) + assert budgetCap != 0 + + ## Create different mode TokenBucket + case mode + of ReplenishMode.Strict: + return T( + budget: budgetCap, + budgetCap: budgetCap, + fillDuration: fillDuration, + lastTimeFull: Moment.now(), + replenishMode: mode, + ) + of ReplenishMode.Compensating: + T( + budget: budgetCap, + budgetCap: budgetCap, + fillDuration: fillDuration, + lastTimeFull: Moment.now(), + replenishMode: mode, + maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT, + ) + +proc newStrict*(T: type[TokenBucket], capacity: int, period: Duration): TokenBucket = + T.new(capacity, period, ReplenishMode.Strict) + +proc newCompensating*( + T: type[TokenBucket], capacity: int, period: Duration +): TokenBucket = + T.new(capacity, period, ReplenishMode.Compensating) + +func `$`*(b: TokenBucket): string {.inline.} = + if isNil(b): + return "nil" + return $b.budgetCap & "/" & $b.fillDuration + +func `$`*(ob: Option[TokenBucket]): string {.inline.} = + if ob.isNone(): + return "no-limit" + + return $ob.get() diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim new file mode 100644 index 0000000..98dbe08 --- /dev/null +++ b/tests/test_sqlite_store.nim @@ -0,0 +1,40 @@ +{.used.} + +import testutils/unittests +import ../ratelimit/store/sqlite +import ../ratelimit/store/store +import chronos + +suite "SqliteRateLimitStore Tests": + asyncTest "newSqliteRateLimitStore - creates store with default values": + ## Given & When + let now = Moment.now() + let store = newSqliteRateLimitStore() + defer: + store.close() + + ## Then + let bucketState = await store.loadBucketState() + check bucketState.budget == 10 + check bucketState.budgetCap == 10 + check bucketState.lastTimeFull.epochSeconds() == now.epochSeconds() + + asyncTest "saveBucketState and loadBucketState - state persistence": + ## Given + let now = Moment.now() + let store = newSqliteRateLimitStore() + defer: + store.close() + + let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now) + + ## When + let saveResult = await store.saveBucketState(newBucketState) + let loadedState = await store.loadBucketState() + + ## Then + check saveResult == true + check loadedState.budget == newBucketState.budget + check loadedState.budgetCap == newBucketState.budgetCap + check loadedState.lastTimeFull.epochSeconds() == + newBucketState.lastTimeFull.epochSeconds() From 378d6a5433748d5998e746682c5d2fa606817e0a Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 09:22:29 +0300 Subject: [PATCH 02/17] fix: use kv store --- migrations/001_create_ratelimit_state.sql | 11 ++-- ratelimit/store/sqlite.nim | 74 ++++++++++++++--------- tests/test_sqlite_store.nim | 2 - 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql index 11431ab..030377c 100644 --- a/migrations/001_create_ratelimit_state.sql +++ b/migrations/001_create_ratelimit_state.sql @@ -1,7 +1,4 @@ - -- will only exist one row in the table - CREATE TABLE IF NOT EXISTS bucket_state ( - id INTEGER PRIMARY KEY, - budget INTEGER NOT NULL, - budget_cap INTEGER NOT NULL, - last_time_full_seconds INTEGER NOT NULL - ) \ No newline at end of file +CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB +); \ No newline at end of file diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim index d8aaf38..ba3e72e 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store/sqlite.nim @@ -1,5 +1,6 @@ import std/times import std/strutils +import std/json import ./store import chronos import db_connector/db_sqlite @@ -9,6 +10,8 @@ type SqliteRateLimitStore* = ref object db: DbConn dbPath: string +const BUCKET_STATE_KEY = "rate_limit_bucket_state" + proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = result = SqliteRateLimitStore(dbPath: dbPath) result.db = open(dbPath, "", "", "") @@ -16,25 +19,35 @@ proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore # Create table if it doesn't exist result.db.exec( sql""" - CREATE TABLE IF NOT EXISTS bucket_state ( - id INTEGER PRIMARY KEY, - budget INTEGER NOT NULL, - budget_cap INTEGER NOT NULL, - last_time_full_seconds INTEGER NOT NULL + CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB ) """ ) - # Insert default state if table is empty - let count = result.db.getValue(sql"SELECT COUNT(*) FROM bucket_state").parseInt() + # Insert default state if key doesn't exist + let count = result.db + .getValue(sql"SELECT COUNT(*) FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + .parseInt() if count == 0: let defaultTimeSeconds = Moment.now().epochSeconds() + let defaultState = BucketState( + budget: 10, + budgetCap: 10, + lastTimeFull: Moment.init(defaultTimeSeconds, chronos.seconds(1)), + ) + + # Serialize to JSON + let jsonState = + %*{ + "budget": defaultState.budget, + "budgetCap": defaultState.budgetCap, + "lastTimeFullSeconds": defaultTimeSeconds, + } + result.db.exec( - sql""" - INSERT INTO bucket_state (id, budget, budget_cap, last_time_full_seconds) - VALUES (1, 10, 10, ?) - """, - defaultTimeSeconds, + sql"INSERT INTO kv_store (key, value) VALUES (?, ?)", BUCKET_STATE_KEY, $jsonState ) proc close*(store: SqliteRateLimitStore) = @@ -47,32 +60,33 @@ proc saveBucketState*( try: # Convert Moment to Unix seconds for storage let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() + + # Serialize to JSON + let jsonState = + %*{ + "budget": bucketState.budget, + "budgetCap": bucketState.budgetCap, + "lastTimeFullSeconds": lastTimeSeconds, + } + store.db.exec( - sql""" - UPDATE bucket_state - SET budget = ?, budget_cap = ?, last_time_full_seconds = ? - WHERE id = 1 - """, - bucketState.budget, - bucketState.budgetCap, - lastTimeSeconds, + sql"UPDATE kv_store SET value = ? WHERE key = ?", $jsonState, BUCKET_STATE_KEY ) return true except: return false proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = - let row = store.db.getRow( - sql""" - SELECT budget, budget_cap, last_time_full_seconds - FROM bucket_state - WHERE id = 1 - """ - ) - # Convert Unix seconds back to Moment (seconds precission) - let unixSeconds = row[2].parseInt().int64 + let jsonStr = + store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + + # Parse JSON and reconstruct BucketState + let jsonData = parseJson(jsonStr) + let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64 let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) return BucketState( - budget: row[0].parseInt(), budgetCap: row[1].parseInt(), lastTimeFull: lastTimeFull + budget: jsonData["budget"].getInt(), + budgetCap: jsonData["budgetCap"].getInt(), + lastTimeFull: lastTimeFull, ) diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim index 98dbe08..8728992 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_sqlite_store.nim @@ -1,5 +1,3 @@ -{.used.} - import testutils/unittests import ../ratelimit/store/sqlite import ../ratelimit/store/store From ce18bb0f50911d4316220e361625f8ef23fbf739 Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 10:05:47 +0300 Subject: [PATCH 03/17] fix: tests --- ratelimit/store/memory.nim | 10 +++--- ratelimit/store/sqlite.nim | 70 +++++++++---------------------------- ratelimit/store/store.nim | 4 +-- tests/test_sqlite_store.nim | 44 ++++++++++++++--------- 4 files changed, 51 insertions(+), 77 deletions(-) diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim index ef19de1..557a17d 100644 --- a/ratelimit/store/memory.nim +++ b/ratelimit/store/memory.nim @@ -1,4 +1,4 @@ -import std/times +import std/[times, options] import ./store import chronos @@ -8,8 +8,6 @@ type MemoryRateLimitStore* = ref object proc newMemoryRateLimitStore*(): MemoryRateLimitStore = result = MemoryRateLimitStore() - result.bucketState = - BucketState(budget: 10, budgetCap: 10, lastTimeFull: Moment.now()) proc saveBucketState*( store: MemoryRateLimitStore, bucketState: BucketState @@ -17,5 +15,7 @@ proc saveBucketState*( store.bucketState = bucketState return true -proc loadBucketState*(store: MemoryRateLimitStore): Future[BucketState] {.async.} = - return store.bucketState +proc loadBucketState*( + store: MemoryRateLimitStore +): Future[Option[BucketState]] {.async.} = + return some(store.bucketState) diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim index ba3e72e..e364e5d 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store/sqlite.nim @@ -1,6 +1,4 @@ -import std/times -import std/strutils -import std/json +import std/[times, strutils, json, options] import ./store import chronos import db_connector/db_sqlite @@ -12,47 +10,8 @@ type SqliteRateLimitStore* = ref object const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = - result = SqliteRateLimitStore(dbPath: dbPath) - result.db = open(dbPath, "", "", "") - - # Create table if it doesn't exist - result.db.exec( - sql""" - CREATE TABLE IF NOT EXISTS kv_store ( - key TEXT PRIMARY KEY, - value BLOB - ) - """ - ) - - # Insert default state if key doesn't exist - let count = result.db - .getValue(sql"SELECT COUNT(*) FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) - .parseInt() - if count == 0: - let defaultTimeSeconds = Moment.now().epochSeconds() - let defaultState = BucketState( - budget: 10, - budgetCap: 10, - lastTimeFull: Moment.init(defaultTimeSeconds, chronos.seconds(1)), - ) - - # Serialize to JSON - let jsonState = - %*{ - "budget": defaultState.budget, - "budgetCap": defaultState.budgetCap, - "lastTimeFullSeconds": defaultTimeSeconds, - } - - result.db.exec( - sql"INSERT INTO kv_store (key, value) VALUES (?, ?)", BUCKET_STATE_KEY, $jsonState - ) - -proc close*(store: SqliteRateLimitStore) = - if store.db != nil: - store.db.close() +proc newSqliteRateLimitStore*(db: DbConn): SqliteRateLimitStore = + result = SqliteRateLimitStore(db: db) proc saveBucketState*( store: SqliteRateLimitStore, bucketState: BucketState @@ -61,32 +20,37 @@ proc saveBucketState*( # Convert Moment to Unix seconds for storage let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() - # Serialize to JSON let jsonState = %*{ "budget": bucketState.budget, "budgetCap": bucketState.budgetCap, "lastTimeFullSeconds": lastTimeSeconds, } - store.db.exec( - sql"UPDATE kv_store SET value = ? WHERE key = ?", $jsonState, BUCKET_STATE_KEY + sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", + BUCKET_STATE_KEY, + $jsonState, ) return true except: return false -proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = +proc loadBucketState*( + store: SqliteRateLimitStore +): Future[Option[BucketState]] {.async.} = let jsonStr = store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + if jsonStr == "": + return none(BucketState) - # Parse JSON and reconstruct BucketState let jsonData = parseJson(jsonStr) let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64 let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) - return BucketState( - budget: jsonData["budget"].getInt(), - budgetCap: jsonData["budgetCap"].getInt(), - lastTimeFull: lastTimeFull, + return some( + BucketState( + budget: jsonData["budget"].getInt(), + budgetCap: jsonData["budgetCap"].getInt(), + lastTimeFull: lastTimeFull, + ) ) diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim index fab578d..c4f6da3 100644 --- a/ratelimit/store/store.nim +++ b/ratelimit/store/store.nim @@ -1,4 +1,4 @@ -import std/[times, deques] +import std/[times, deques, options] import chronos type @@ -10,4 +10,4 @@ type RateLimitStoreConcept* = concept s s.saveBucketState(BucketState) is Future[bool] - s.loadBucketState() is Future[BucketState] + s.loadBucketState() is Future[Option[BucketState]] diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim index 8728992..6315ec9 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_sqlite_store.nim @@ -2,28 +2,37 @@ import testutils/unittests import ../ratelimit/store/sqlite import ../ratelimit/store/store import chronos +import db_connector/db_sqlite +import ../chat_sdk/migration +import std/[options, os] suite "SqliteRateLimitStore Tests": - asyncTest "newSqliteRateLimitStore - creates store with default values": - ## Given & When - let now = Moment.now() - let store = newSqliteRateLimitStore() - defer: - store.close() + setup: + let db = open("test-ratelimit.db", "", "", "") + runMigrations(db) + + teardown: + if db != nil: + db.close() + if fileExists("test-ratelimit.db"): + removeFile("test-ratelimit.db") + + asyncTest "newSqliteRateLimitStore - empty state": + ## Given + let store = newSqliteRateLimitStore(db) + + ## When + let loadedState = await store.loadBucketState() ## Then - let bucketState = await store.loadBucketState() - check bucketState.budget == 10 - check bucketState.budgetCap == 10 - check bucketState.lastTimeFull.epochSeconds() == now.epochSeconds() + check loadedState.isNone() asyncTest "saveBucketState and loadBucketState - state persistence": ## Given - let now = Moment.now() - let store = newSqliteRateLimitStore() - defer: - store.close() + let store = newSqliteRateLimitStore(db) + let now = Moment.now() + echo "now: ", now.epochSeconds() let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now) ## When @@ -32,7 +41,8 @@ suite "SqliteRateLimitStore Tests": ## Then check saveResult == true - check loadedState.budget == newBucketState.budget - check loadedState.budgetCap == newBucketState.budgetCap - check loadedState.lastTimeFull.epochSeconds() == + check loadedState.isSome() + check loadedState.get().budget == newBucketState.budget + check loadedState.get().budgetCap == newBucketState.budgetCap + check loadedState.get().lastTimeFull.epochSeconds() == newBucketState.lastTimeFull.epochSeconds() From 9b6c9f359d8383fee5b863fc9a2ffe1865456186 Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 15:51:26 +0300 Subject: [PATCH 04/17] feat: add rate limit store --- .gitignore | 2 + ratelimit/ratelimit.nim | 79 +++++++++++++++++++++++++------------- ratelimit/store/memory.nim | 10 ++--- ratelimit/store/store.nim | 2 +- ratelimit/token_bucket.nim | 20 ++++++---- tests/test_ratelimit.nim | 31 ++++++++++----- 6 files changed, 94 insertions(+), 50 deletions(-) diff --git a/.gitignore b/.gitignore index 4ee6ee2..1843ca9 100644 --- a/.gitignore +++ b/.gitignore @@ -25,5 +25,7 @@ apps/* !*.nim tests/* !*.nim +ratelimit/* +!*.nim nimble.develop nimble.paths diff --git a/ratelimit/ratelimit.nim b/ratelimit/ratelimit.nim index 9bd1bf6..ec10f2d 100644 --- a/ratelimit/ratelimit.nim +++ b/ratelimit/ratelimit.nim @@ -1,6 +1,7 @@ import std/[times, deques, options] -# import ./token_bucket -import waku/common/rate_limit/token_bucket +import ./token_bucket +# import waku/common/rate_limit/token_bucket +import ./store/store import chronos type @@ -27,7 +28,8 @@ type MessageSender*[T: Serializable] = proc(msgs: seq[tuple[msgId: string, msg: T]]): Future[void] {.async.} - RateLimitManager*[T: Serializable] = ref object + RateLimitManager*[T: Serializable, S: RateLimitStore] = ref object + store: S bucket: TokenBucket sender: MessageSender[T] running: bool @@ -35,14 +37,30 @@ type queueNormal: Deque[seq[tuple[msgId: string, msg: T]]] sleepDuration: chronos.Duration -proc newRateLimitManager*[T: Serializable]( +proc newRateLimitManager*[T: Serializable, S: RateLimitStore]( + store: S, sender: MessageSender[T], capacity: int = 100, duration: chronos.Duration = chronos.minutes(10), sleepDuration: chronos.Duration = chronos.milliseconds(1000), -): RateLimitManager[T] = - RateLimitManager[T]( - bucket: TokenBucket.newStrict(capacity, duration), +): Future[RateLimitManager[T, S]] {.async.} = + var current = await store.loadBucketState() + if current.isNone(): + # initialize bucket state with full capacity + current = some( + BucketState(budget: capacity, budgetCap: capacity, lastTimeFull: Moment.now()) + ) + discard await store.saveBucketState(current.get()) + + return RateLimitManager[T, S]( + store: store, + bucket: TokenBucket.new( + current.get().budgetCap, + duration, + ReplenishMode.Strict, + current.get().budget, + current.get().lastTimeFull, + ), sender: sender, running: false, queueCritical: Deque[seq[tuple[msgId: string, msg: T]]](), @@ -50,10 +68,10 @@ proc newRateLimitManager*[T: Serializable]( sleepDuration: sleepDuration, ) -proc getCapacityState[T: Serializable]( - manager: RateLimitManager[T], now: Moment, count: int = 1 +proc getCapacityState[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment, count: int = 1 ): CapacityState = - let (budget, budgetCap) = manager.bucket.getAvailableCapacity(now) + let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now) let countAfter = budget - count let ratio = countAfter.float / budgetCap.float if ratio < 0.0: @@ -63,15 +81,15 @@ proc getCapacityState[T: Serializable]( else: return CapacityState.Normal -proc passToSender[T: Serializable]( - manager: RateLimitManager[T], +proc passToSender[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], msgs: seq[tuple[msgId: string, msg: T]], now: Moment, priority: Priority, ): Future[SendResult] {.async.} = let count = msgs.len - let capacity = manager.bucket.tryConsume(count, now) - if not capacity: + let consumed = manager.bucket.tryConsume(count, now) + if not consumed: case priority of Priority.Critical: manager.queueCritical.addLast(msgs) @@ -81,11 +99,16 @@ proc passToSender[T: Serializable]( return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped + + let (budget, budgetCap, lastTimeFull) = manager.bucket.getAvailableCapacity(now) + discard await manager.store.saveBucketState( + BucketState(budget: budget, budgetCap: budgetCap, lastTimeFull: lastTimeFull) + ) await manager.sender(msgs) return SendResult.PassedToSender -proc processCriticalQueue[T: Serializable]( - manager: RateLimitManager[T], now: Moment +proc processCriticalQueue[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment ): Future[void] {.async.} = while manager.queueCritical.len > 0: let msgs = manager.queueCritical.popFirst() @@ -99,8 +122,8 @@ proc processCriticalQueue[T: Serializable]( manager.queueCritical.addFirst(msgs) break -proc processNormalQueue[T: Serializable]( - manager: RateLimitManager[T], now: Moment +proc processNormalQueue[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment ): Future[void] {.async.} = while manager.queueNormal.len > 0: let msgs = manager.queueNormal.popFirst() @@ -112,13 +135,13 @@ proc processNormalQueue[T: Serializable]( manager.queueNormal.addFirst(msgs) break -proc sendOrEnqueue*[T: Serializable]( - manager: RateLimitManager[T], +proc sendOrEnqueue*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], msgs: seq[tuple[msgId: string, msg: T]], priority: Priority, now: Moment = Moment.now(), ): Future[SendResult] {.async.} = - let (_, budgetCap) = manager.bucket.getAvailableCapacity(now) + let (_, budgetCap, _) = manager.bucket.getAvailableCapacity(now) if msgs.len.float / budgetCap.float >= 0.3: # drop batch if it's too large to avoid starvation return SendResult.DroppedBatchTooLarge @@ -147,8 +170,8 @@ proc sendOrEnqueue*[T: Serializable]( of Priority.Optional: return SendResult.Dropped -proc getEnqueued*[T: Serializable]( - manager: RateLimitManager[T] +proc getEnqueued*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S] ): tuple[ critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]] ] = @@ -163,8 +186,8 @@ proc getEnqueued*[T: Serializable]( return (criticalMsgs, normalMsgs) -proc start*[T: Serializable]( - manager: RateLimitManager[T], +proc start*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ): Future[void] {.async.} = @@ -181,10 +204,12 @@ proc start*[T: Serializable]( echo "Error in queue processing: ", e.msg await sleepAsync(manager.sleepDuration) -proc stop*[T: Serializable](manager: RateLimitManager[T]) = +proc stop*[T: Serializable, S: RateLimitStore](manager: RateLimitManager[T, S]) = manager.running = false -func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = +func `$`*[T: Serializable, S: RateLimitStore]( + b: RateLimitManager[T, S] +): string {.inline.} = if isNil(b): return "nil" return diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim index 557a17d..6391314 100644 --- a/ratelimit/store/memory.nim +++ b/ratelimit/store/memory.nim @@ -4,18 +4,18 @@ import chronos # Memory Implementation type MemoryRateLimitStore* = ref object - bucketState: BucketState + bucketState: Option[BucketState] -proc newMemoryRateLimitStore*(): MemoryRateLimitStore = - result = MemoryRateLimitStore() +proc new*(T: type[MemoryRateLimitStore]): T = + return T(bucketState: none(BucketState)) proc saveBucketState*( store: MemoryRateLimitStore, bucketState: BucketState ): Future[bool] {.async.} = - store.bucketState = bucketState + store.bucketState = some(bucketState) return true proc loadBucketState*( store: MemoryRateLimitStore ): Future[Option[BucketState]] {.async.} = - return some(store.bucketState) + return store.bucketState diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim index c4f6da3..c916750 100644 --- a/ratelimit/store/store.nim +++ b/ratelimit/store/store.nim @@ -7,7 +7,7 @@ type budgetCap*: int lastTimeFull*: Moment - RateLimitStoreConcept* = + RateLimitStore* = concept s s.saveBucketState(BucketState) is Future[bool] s.loadBucketState() is Future[Option[BucketState]] diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim index 447569a..e4a4487 100644 --- a/ratelimit/token_bucket.nim +++ b/ratelimit/token_bucket.nim @@ -112,18 +112,18 @@ proc update(bucket: TokenBucket, currentTime: Moment) = ## Returns the available capacity of the bucket: (budget, budgetCap) proc getAvailableCapacity*( bucket: TokenBucket, currentTime: Moment -): tuple[budget: int, budgetCap: int] = +): tuple[budget: int, budgetCap: int, lastTimeFull: Moment] = if periodElapsed(bucket, currentTime): case bucket.replenishMode of ReplenishMode.Strict: - return (bucket.budgetCap, bucket.budgetCap) + return (bucket.budgetCap, bucket.budgetCap, bucket.lastTimeFull) of ReplenishMode.Compensating: let distance = bucket.periodDistance(currentTime) let recentAvgUsage = bucket.getUsageAverageSince(distance) let compensation = bucket.calcCompensation(recentAvgUsage) let availableBudget = bucket.budgetCap + compensation - return (availableBudget, bucket.budgetCap) - return (bucket.budget, bucket.budgetCap) + return (availableBudget, bucket.budgetCap, bucket.lastTimeFull) + return (bucket.budget, bucket.budgetCap, bucket.lastTimeFull) proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = ## If `tokens` are available, consume them, @@ -154,26 +154,30 @@ proc new*( budgetCap: int, fillDuration: Duration = 1.seconds, mode: ReplenishMode = ReplenishMode.Compensating, + budget: int = -1, # -1 means "use budgetCap" + lastTimeFull: Moment = Moment.now(), ): T = assert not isZero(fillDuration) assert budgetCap != 0 + let actualBudget = if budget == -1: budgetCap else: budget + assert actualBudget >= 0 and actualBudget <= budgetCap ## Create different mode TokenBucket case mode of ReplenishMode.Strict: return T( - budget: budgetCap, + budget: actualBudget, budgetCap: budgetCap, fillDuration: fillDuration, - lastTimeFull: Moment.now(), + lastTimeFull: lastTimeFull, replenishMode: mode, ) of ReplenishMode.Compensating: T( - budget: budgetCap, + budget: actualBudget, budgetCap: budgetCap, fillDuration: fillDuration, - lastTimeFull: Moment.now(), + lastTimeFull: lastTimeFull, replenishMode: mode, maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT, ) diff --git a/tests/test_ratelimit.nim b/tests/test_ratelimit.nim index f38ddbf..a10dd9e 100644 --- a/tests/test_ratelimit.nim +++ b/tests/test_ratelimit.nim @@ -2,6 +2,7 @@ import testutils/unittests import ../ratelimit/ratelimit +import ../ratelimit/store/memory import chronos import strutils @@ -25,8 +26,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let manager = newRateLimitManager[string]( - mockSender, capacity = 10, duration = chronos.milliseconds(100) + let store: MemoryRateLimitStore = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) let testMsg = "Hello World" @@ -43,8 +45,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let manager = newRateLimitManager[string]( - mockSender, capacity = 10, duration = chronos.milliseconds(100) + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) ## When @@ -63,7 +66,9 @@ suite "Queue RateLimitManager": asyncTest "start and stop - basic functionality": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -94,7 +99,9 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 2, duration = chronos.milliseconds(100), @@ -109,7 +116,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -163,7 +172,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue normal on 70% capacity": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -220,7 +231,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - process queued messages": ## Given - let manager = newRateLimitManager[string]( + let store = MemoryRateLimitStore.new() + let manager = await newRateLimitManager[string, MemoryRateLimitStore]( + store, mockSender, capacity = 10, duration = chronos.milliseconds(200), From 27fd37c4336cb47c1613f691dc5f7edbe6e931e8 Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 14 Jul 2025 11:14:36 +0300 Subject: [PATCH 05/17] feat: store rate limit --- .gitignore | 2 + migrations/001_create_ratelimit_state.sql | 7 + ratelimit/store/memory.nim | 21 +++ ratelimit/store/sqlite.nim | 78 +++++++++ ratelimit/store/store.nim | 13 ++ ratelimit/token_bucket.nim | 198 ++++++++++++++++++++++ tests/test_sqlite_store.nim | 40 +++++ 7 files changed, 359 insertions(+) create mode 100644 migrations/001_create_ratelimit_state.sql create mode 100644 ratelimit/store/memory.nim create mode 100644 ratelimit/store/sqlite.nim create mode 100644 ratelimit/store/store.nim create mode 100644 ratelimit/token_bucket.nim create mode 100644 tests/test_sqlite_store.nim diff --git a/.gitignore b/.gitignore index 3a79a35..4ee6ee2 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,7 @@ chat_sdk/* !*.nim apps/* !*.nim +tests/* +!*.nim nimble.develop nimble.paths diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql new file mode 100644 index 0000000..11431ab --- /dev/null +++ b/migrations/001_create_ratelimit_state.sql @@ -0,0 +1,7 @@ + -- will only exist one row in the table + CREATE TABLE IF NOT EXISTS bucket_state ( + id INTEGER PRIMARY KEY, + budget INTEGER NOT NULL, + budget_cap INTEGER NOT NULL, + last_time_full_seconds INTEGER NOT NULL + ) \ No newline at end of file diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim new file mode 100644 index 0000000..ef19de1 --- /dev/null +++ b/ratelimit/store/memory.nim @@ -0,0 +1,21 @@ +import std/times +import ./store +import chronos + +# Memory Implementation +type MemoryRateLimitStore* = ref object + bucketState: BucketState + +proc newMemoryRateLimitStore*(): MemoryRateLimitStore = + result = MemoryRateLimitStore() + result.bucketState = + BucketState(budget: 10, budgetCap: 10, lastTimeFull: Moment.now()) + +proc saveBucketState*( + store: MemoryRateLimitStore, bucketState: BucketState +): Future[bool] {.async.} = + store.bucketState = bucketState + return true + +proc loadBucketState*(store: MemoryRateLimitStore): Future[BucketState] {.async.} = + return store.bucketState diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim new file mode 100644 index 0000000..d8aaf38 --- /dev/null +++ b/ratelimit/store/sqlite.nim @@ -0,0 +1,78 @@ +import std/times +import std/strutils +import ./store +import chronos +import db_connector/db_sqlite + +# SQLite Implementation +type SqliteRateLimitStore* = ref object + db: DbConn + dbPath: string + +proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = + result = SqliteRateLimitStore(dbPath: dbPath) + result.db = open(dbPath, "", "", "") + + # Create table if it doesn't exist + result.db.exec( + sql""" + CREATE TABLE IF NOT EXISTS bucket_state ( + id INTEGER PRIMARY KEY, + budget INTEGER NOT NULL, + budget_cap INTEGER NOT NULL, + last_time_full_seconds INTEGER NOT NULL + ) + """ + ) + + # Insert default state if table is empty + let count = result.db.getValue(sql"SELECT COUNT(*) FROM bucket_state").parseInt() + if count == 0: + let defaultTimeSeconds = Moment.now().epochSeconds() + result.db.exec( + sql""" + INSERT INTO bucket_state (id, budget, budget_cap, last_time_full_seconds) + VALUES (1, 10, 10, ?) + """, + defaultTimeSeconds, + ) + +proc close*(store: SqliteRateLimitStore) = + if store.db != nil: + store.db.close() + +proc saveBucketState*( + store: SqliteRateLimitStore, bucketState: BucketState +): Future[bool] {.async.} = + try: + # Convert Moment to Unix seconds for storage + let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() + store.db.exec( + sql""" + UPDATE bucket_state + SET budget = ?, budget_cap = ?, last_time_full_seconds = ? + WHERE id = 1 + """, + bucketState.budget, + bucketState.budgetCap, + lastTimeSeconds, + ) + return true + except: + return false + +proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = + let row = store.db.getRow( + sql""" + SELECT budget, budget_cap, last_time_full_seconds + FROM bucket_state + WHERE id = 1 + """ + ) + # Convert Unix seconds back to Moment (seconds precission) + let unixSeconds = row[2].parseInt().int64 + let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) + + return BucketState( + budget: row[0].parseInt(), budgetCap: row[1].parseInt(), lastTimeFull: lastTimeFull + ) diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim new file mode 100644 index 0000000..fab578d --- /dev/null +++ b/ratelimit/store/store.nim @@ -0,0 +1,13 @@ +import std/[times, deques] +import chronos + +type + BucketState* = object + budget*: int + budgetCap*: int + lastTimeFull*: Moment + + RateLimitStoreConcept* = + concept s + s.saveBucketState(BucketState) is Future[bool] + s.loadBucketState() is Future[BucketState] diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim new file mode 100644 index 0000000..447569a --- /dev/null +++ b/ratelimit/token_bucket.nim @@ -0,0 +1,198 @@ +{.push raises: [].} + +import chronos, std/math, std/options + +const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25 + +## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation. +## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class. +## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation. +## ref: https://github.com/status-im/nim-chronos/issues/500 +## +## This version of TokenBucket is different from the original one in chronos/rate_limit.nim in many ways: +## - It has a new mode called `Compensating` which is the default mode. +## Compensation is calculated as the not used bucket capacity in the last measured period(s) in average. +## or up until maximum the allowed compansation treshold (Currently it is const 25%). +## Also compensation takes care of the proper time period calculation to avoid non-usage periods that can lead to +## overcompensation. +## - Strict mode is also available which will only replenish when time period is over but also will fill +## the bucket to the max capacity. + +type + ReplenishMode* = enum + Strict + Compensating + + TokenBucket* = ref object + budget: int ## Current number of tokens in the bucket + budgetCap: int ## Bucket capacity + lastTimeFull: Moment + ## This timer measures the proper periodizaiton of the bucket refilling + fillDuration: Duration ## Refill period + case replenishMode*: ReplenishMode + of Strict: + ## In strict mode, the bucket is refilled only till the budgetCap + discard + of Compensating: + ## This is the default mode. + maxCompensation: float + +func periodDistance(bucket: TokenBucket, currentTime: Moment): float = + ## notice fillDuration cannot be zero by design + ## period distance is a float number representing the calculated period time + ## since the last time bucket was refilled. + return + nanoseconds(currentTime - bucket.lastTimeFull).float / + nanoseconds(bucket.fillDuration).float + +func getUsageAverageSince(bucket: TokenBucket, distance: float): float = + if distance == 0.float: + ## in case there is zero time difference than the usage percentage is 100% + return 1.0 + + ## budgetCap can never be zero + ## usage average is calculated as a percentage of total capacity available over + ## the measured period + return bucket.budget.float / bucket.budgetCap.float / distance + +proc calcCompensation(bucket: TokenBucket, averageUsage: float): int = + # if we already fully used or even overused the tokens, there is no place for compensation + if averageUsage >= 1.0: + return 0 + + ## compensation is the not used bucket capacity in the last measured period(s) in average. + ## or maximum the allowed compansation treshold + let compensationPercent = + min((1.0 - averageUsage) * bucket.budgetCap.float, bucket.maxCompensation) + return trunc(compensationPercent).int + +func periodElapsed(bucket: TokenBucket, currentTime: Moment): bool = + return currentTime - bucket.lastTimeFull >= bucket.fillDuration + +## Update will take place if bucket is empty and trying to consume tokens. +## It checks if the bucket can be replenished as refill duration is passed or not. +## - strict mode: +proc updateStrict(bucket: TokenBucket, currentTime: Moment) = + if bucket.fillDuration == default(Duration): + bucket.budget = min(bucket.budgetCap, bucket.budget) + return + + if not periodElapsed(bucket, currentTime): + return + + bucket.budget = bucket.budgetCap + bucket.lastTimeFull = currentTime + +## - compensating - ballancing load: +## - between updates we calculate average load (current bucket capacity / number of periods till last update) +## - gives the percentage load used recently +## - with this we can replenish bucket up to 100% + calculated leftover from previous period (caped with max treshold) +proc updateWithCompensation(bucket: TokenBucket, currentTime: Moment) = + if bucket.fillDuration == default(Duration): + bucket.budget = min(bucket.budgetCap, bucket.budget) + return + + # do not replenish within the same period + if not periodElapsed(bucket, currentTime): + return + + let distance = bucket.periodDistance(currentTime) + let recentAvgUsage = bucket.getUsageAverageSince(distance) + let compensation = bucket.calcCompensation(recentAvgUsage) + + bucket.budget = bucket.budgetCap + compensation + bucket.lastTimeFull = currentTime + +proc update(bucket: TokenBucket, currentTime: Moment) = + if bucket.replenishMode == ReplenishMode.Compensating: + updateWithCompensation(bucket, currentTime) + else: + updateStrict(bucket, currentTime) + +## Returns the available capacity of the bucket: (budget, budgetCap) +proc getAvailableCapacity*( + bucket: TokenBucket, currentTime: Moment +): tuple[budget: int, budgetCap: int] = + if periodElapsed(bucket, currentTime): + case bucket.replenishMode + of ReplenishMode.Strict: + return (bucket.budgetCap, bucket.budgetCap) + of ReplenishMode.Compensating: + let distance = bucket.periodDistance(currentTime) + let recentAvgUsage = bucket.getUsageAverageSince(distance) + let compensation = bucket.calcCompensation(recentAvgUsage) + let availableBudget = bucket.budgetCap + compensation + return (availableBudget, bucket.budgetCap) + return (bucket.budget, bucket.budgetCap) + +proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = + ## If `tokens` are available, consume them, + ## Otherwhise, return false. + + if bucket.budget >= bucket.budgetCap: + bucket.lastTimeFull = now + + if bucket.budget >= tokens: + bucket.budget -= tokens + return true + + bucket.update(now) + + if bucket.budget >= tokens: + bucket.budget -= tokens + return true + else: + return false + +proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) = + ## Add `tokens` to the budget (capped to the bucket capacity) + bucket.budget += tokens + bucket.update(now) + +proc new*( + T: type[TokenBucket], + budgetCap: int, + fillDuration: Duration = 1.seconds, + mode: ReplenishMode = ReplenishMode.Compensating, +): T = + assert not isZero(fillDuration) + assert budgetCap != 0 + + ## Create different mode TokenBucket + case mode + of ReplenishMode.Strict: + return T( + budget: budgetCap, + budgetCap: budgetCap, + fillDuration: fillDuration, + lastTimeFull: Moment.now(), + replenishMode: mode, + ) + of ReplenishMode.Compensating: + T( + budget: budgetCap, + budgetCap: budgetCap, + fillDuration: fillDuration, + lastTimeFull: Moment.now(), + replenishMode: mode, + maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT, + ) + +proc newStrict*(T: type[TokenBucket], capacity: int, period: Duration): TokenBucket = + T.new(capacity, period, ReplenishMode.Strict) + +proc newCompensating*( + T: type[TokenBucket], capacity: int, period: Duration +): TokenBucket = + T.new(capacity, period, ReplenishMode.Compensating) + +func `$`*(b: TokenBucket): string {.inline.} = + if isNil(b): + return "nil" + return $b.budgetCap & "/" & $b.fillDuration + +func `$`*(ob: Option[TokenBucket]): string {.inline.} = + if ob.isNone(): + return "no-limit" + + return $ob.get() diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim new file mode 100644 index 0000000..98dbe08 --- /dev/null +++ b/tests/test_sqlite_store.nim @@ -0,0 +1,40 @@ +{.used.} + +import testutils/unittests +import ../ratelimit/store/sqlite +import ../ratelimit/store/store +import chronos + +suite "SqliteRateLimitStore Tests": + asyncTest "newSqliteRateLimitStore - creates store with default values": + ## Given & When + let now = Moment.now() + let store = newSqliteRateLimitStore() + defer: + store.close() + + ## Then + let bucketState = await store.loadBucketState() + check bucketState.budget == 10 + check bucketState.budgetCap == 10 + check bucketState.lastTimeFull.epochSeconds() == now.epochSeconds() + + asyncTest "saveBucketState and loadBucketState - state persistence": + ## Given + let now = Moment.now() + let store = newSqliteRateLimitStore() + defer: + store.close() + + let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now) + + ## When + let saveResult = await store.saveBucketState(newBucketState) + let loadedState = await store.loadBucketState() + + ## Then + check saveResult == true + check loadedState.budget == newBucketState.budget + check loadedState.budgetCap == newBucketState.budgetCap + check loadedState.lastTimeFull.epochSeconds() == + newBucketState.lastTimeFull.epochSeconds() From 57ae8e87c01d5f6aea3e7fe49018859ba3e1513e Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 09:22:29 +0300 Subject: [PATCH 06/17] fix: use kv store --- migrations/001_create_ratelimit_state.sql | 11 ++-- ratelimit/store/sqlite.nim | 74 ++++++++++++++--------- tests/test_sqlite_store.nim | 2 - 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql index 11431ab..030377c 100644 --- a/migrations/001_create_ratelimit_state.sql +++ b/migrations/001_create_ratelimit_state.sql @@ -1,7 +1,4 @@ - -- will only exist one row in the table - CREATE TABLE IF NOT EXISTS bucket_state ( - id INTEGER PRIMARY KEY, - budget INTEGER NOT NULL, - budget_cap INTEGER NOT NULL, - last_time_full_seconds INTEGER NOT NULL - ) \ No newline at end of file +CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB +); \ No newline at end of file diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim index d8aaf38..ba3e72e 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store/sqlite.nim @@ -1,5 +1,6 @@ import std/times import std/strutils +import std/json import ./store import chronos import db_connector/db_sqlite @@ -9,6 +10,8 @@ type SqliteRateLimitStore* = ref object db: DbConn dbPath: string +const BUCKET_STATE_KEY = "rate_limit_bucket_state" + proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = result = SqliteRateLimitStore(dbPath: dbPath) result.db = open(dbPath, "", "", "") @@ -16,25 +19,35 @@ proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore # Create table if it doesn't exist result.db.exec( sql""" - CREATE TABLE IF NOT EXISTS bucket_state ( - id INTEGER PRIMARY KEY, - budget INTEGER NOT NULL, - budget_cap INTEGER NOT NULL, - last_time_full_seconds INTEGER NOT NULL + CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB ) """ ) - # Insert default state if table is empty - let count = result.db.getValue(sql"SELECT COUNT(*) FROM bucket_state").parseInt() + # Insert default state if key doesn't exist + let count = result.db + .getValue(sql"SELECT COUNT(*) FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + .parseInt() if count == 0: let defaultTimeSeconds = Moment.now().epochSeconds() + let defaultState = BucketState( + budget: 10, + budgetCap: 10, + lastTimeFull: Moment.init(defaultTimeSeconds, chronos.seconds(1)), + ) + + # Serialize to JSON + let jsonState = + %*{ + "budget": defaultState.budget, + "budgetCap": defaultState.budgetCap, + "lastTimeFullSeconds": defaultTimeSeconds, + } + result.db.exec( - sql""" - INSERT INTO bucket_state (id, budget, budget_cap, last_time_full_seconds) - VALUES (1, 10, 10, ?) - """, - defaultTimeSeconds, + sql"INSERT INTO kv_store (key, value) VALUES (?, ?)", BUCKET_STATE_KEY, $jsonState ) proc close*(store: SqliteRateLimitStore) = @@ -47,32 +60,33 @@ proc saveBucketState*( try: # Convert Moment to Unix seconds for storage let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() + + # Serialize to JSON + let jsonState = + %*{ + "budget": bucketState.budget, + "budgetCap": bucketState.budgetCap, + "lastTimeFullSeconds": lastTimeSeconds, + } + store.db.exec( - sql""" - UPDATE bucket_state - SET budget = ?, budget_cap = ?, last_time_full_seconds = ? - WHERE id = 1 - """, - bucketState.budget, - bucketState.budgetCap, - lastTimeSeconds, + sql"UPDATE kv_store SET value = ? WHERE key = ?", $jsonState, BUCKET_STATE_KEY ) return true except: return false proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = - let row = store.db.getRow( - sql""" - SELECT budget, budget_cap, last_time_full_seconds - FROM bucket_state - WHERE id = 1 - """ - ) - # Convert Unix seconds back to Moment (seconds precission) - let unixSeconds = row[2].parseInt().int64 + let jsonStr = + store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + + # Parse JSON and reconstruct BucketState + let jsonData = parseJson(jsonStr) + let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64 let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) return BucketState( - budget: row[0].parseInt(), budgetCap: row[1].parseInt(), lastTimeFull: lastTimeFull + budget: jsonData["budget"].getInt(), + budgetCap: jsonData["budgetCap"].getInt(), + lastTimeFull: lastTimeFull, ) diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim index 98dbe08..8728992 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_sqlite_store.nim @@ -1,5 +1,3 @@ -{.used.} - import testutils/unittests import ../ratelimit/store/sqlite import ../ratelimit/store/store From 9f52377d44c714e1b764a0ef624fb723bec40c3e Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 10:05:47 +0300 Subject: [PATCH 07/17] fix: tests --- ratelimit/store/memory.nim | 10 +++--- ratelimit/store/sqlite.nim | 70 +++++++++---------------------------- ratelimit/store/store.nim | 4 +-- tests/test_sqlite_store.nim | 44 ++++++++++++++--------- 4 files changed, 51 insertions(+), 77 deletions(-) diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim index ef19de1..557a17d 100644 --- a/ratelimit/store/memory.nim +++ b/ratelimit/store/memory.nim @@ -1,4 +1,4 @@ -import std/times +import std/[times, options] import ./store import chronos @@ -8,8 +8,6 @@ type MemoryRateLimitStore* = ref object proc newMemoryRateLimitStore*(): MemoryRateLimitStore = result = MemoryRateLimitStore() - result.bucketState = - BucketState(budget: 10, budgetCap: 10, lastTimeFull: Moment.now()) proc saveBucketState*( store: MemoryRateLimitStore, bucketState: BucketState @@ -17,5 +15,7 @@ proc saveBucketState*( store.bucketState = bucketState return true -proc loadBucketState*(store: MemoryRateLimitStore): Future[BucketState] {.async.} = - return store.bucketState +proc loadBucketState*( + store: MemoryRateLimitStore +): Future[Option[BucketState]] {.async.} = + return some(store.bucketState) diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim index ba3e72e..e364e5d 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store/sqlite.nim @@ -1,6 +1,4 @@ -import std/times -import std/strutils -import std/json +import std/[times, strutils, json, options] import ./store import chronos import db_connector/db_sqlite @@ -12,47 +10,8 @@ type SqliteRateLimitStore* = ref object const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc newSqliteRateLimitStore*(dbPath: string = ":memory:"): SqliteRateLimitStore = - result = SqliteRateLimitStore(dbPath: dbPath) - result.db = open(dbPath, "", "", "") - - # Create table if it doesn't exist - result.db.exec( - sql""" - CREATE TABLE IF NOT EXISTS kv_store ( - key TEXT PRIMARY KEY, - value BLOB - ) - """ - ) - - # Insert default state if key doesn't exist - let count = result.db - .getValue(sql"SELECT COUNT(*) FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) - .parseInt() - if count == 0: - let defaultTimeSeconds = Moment.now().epochSeconds() - let defaultState = BucketState( - budget: 10, - budgetCap: 10, - lastTimeFull: Moment.init(defaultTimeSeconds, chronos.seconds(1)), - ) - - # Serialize to JSON - let jsonState = - %*{ - "budget": defaultState.budget, - "budgetCap": defaultState.budgetCap, - "lastTimeFullSeconds": defaultTimeSeconds, - } - - result.db.exec( - sql"INSERT INTO kv_store (key, value) VALUES (?, ?)", BUCKET_STATE_KEY, $jsonState - ) - -proc close*(store: SqliteRateLimitStore) = - if store.db != nil: - store.db.close() +proc newSqliteRateLimitStore*(db: DbConn): SqliteRateLimitStore = + result = SqliteRateLimitStore(db: db) proc saveBucketState*( store: SqliteRateLimitStore, bucketState: BucketState @@ -61,32 +20,37 @@ proc saveBucketState*( # Convert Moment to Unix seconds for storage let lastTimeSeconds = bucketState.lastTimeFull.epochSeconds() - # Serialize to JSON let jsonState = %*{ "budget": bucketState.budget, "budgetCap": bucketState.budgetCap, "lastTimeFullSeconds": lastTimeSeconds, } - store.db.exec( - sql"UPDATE kv_store SET value = ? WHERE key = ?", $jsonState, BUCKET_STATE_KEY + sql"INSERT INTO kv_store (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", + BUCKET_STATE_KEY, + $jsonState, ) return true except: return false -proc loadBucketState*(store: SqliteRateLimitStore): Future[BucketState] {.async.} = +proc loadBucketState*( + store: SqliteRateLimitStore +): Future[Option[BucketState]] {.async.} = let jsonStr = store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) + if jsonStr == "": + return none(BucketState) - # Parse JSON and reconstruct BucketState let jsonData = parseJson(jsonStr) let unixSeconds = jsonData["lastTimeFullSeconds"].getInt().int64 let lastTimeFull = Moment.init(unixSeconds, chronos.seconds(1)) - return BucketState( - budget: jsonData["budget"].getInt(), - budgetCap: jsonData["budgetCap"].getInt(), - lastTimeFull: lastTimeFull, + return some( + BucketState( + budget: jsonData["budget"].getInt(), + budgetCap: jsonData["budgetCap"].getInt(), + lastTimeFull: lastTimeFull, + ) ) diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim index fab578d..c4f6da3 100644 --- a/ratelimit/store/store.nim +++ b/ratelimit/store/store.nim @@ -1,4 +1,4 @@ -import std/[times, deques] +import std/[times, deques, options] import chronos type @@ -10,4 +10,4 @@ type RateLimitStoreConcept* = concept s s.saveBucketState(BucketState) is Future[bool] - s.loadBucketState() is Future[BucketState] + s.loadBucketState() is Future[Option[BucketState]] diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim index 8728992..6315ec9 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_sqlite_store.nim @@ -2,28 +2,37 @@ import testutils/unittests import ../ratelimit/store/sqlite import ../ratelimit/store/store import chronos +import db_connector/db_sqlite +import ../chat_sdk/migration +import std/[options, os] suite "SqliteRateLimitStore Tests": - asyncTest "newSqliteRateLimitStore - creates store with default values": - ## Given & When - let now = Moment.now() - let store = newSqliteRateLimitStore() - defer: - store.close() + setup: + let db = open("test-ratelimit.db", "", "", "") + runMigrations(db) + + teardown: + if db != nil: + db.close() + if fileExists("test-ratelimit.db"): + removeFile("test-ratelimit.db") + + asyncTest "newSqliteRateLimitStore - empty state": + ## Given + let store = newSqliteRateLimitStore(db) + + ## When + let loadedState = await store.loadBucketState() ## Then - let bucketState = await store.loadBucketState() - check bucketState.budget == 10 - check bucketState.budgetCap == 10 - check bucketState.lastTimeFull.epochSeconds() == now.epochSeconds() + check loadedState.isNone() asyncTest "saveBucketState and loadBucketState - state persistence": ## Given - let now = Moment.now() - let store = newSqliteRateLimitStore() - defer: - store.close() + let store = newSqliteRateLimitStore(db) + let now = Moment.now() + echo "now: ", now.epochSeconds() let newBucketState = BucketState(budget: 5, budgetCap: 20, lastTimeFull: now) ## When @@ -32,7 +41,8 @@ suite "SqliteRateLimitStore Tests": ## Then check saveResult == true - check loadedState.budget == newBucketState.budget - check loadedState.budgetCap == newBucketState.budgetCap - check loadedState.lastTimeFull.epochSeconds() == + check loadedState.isSome() + check loadedState.get().budget == newBucketState.budget + check loadedState.get().budgetCap == newBucketState.budgetCap + check loadedState.get().lastTimeFull.epochSeconds() == newBucketState.lastTimeFull.epochSeconds() From 414ec6f920f8f0df8c56581ba4c112a7f25dea73 Mon Sep 17 00:00:00 2001 From: pablo Date: Wed, 16 Jul 2025 15:51:26 +0300 Subject: [PATCH 08/17] feat: add rate limit store --- .gitignore | 2 + chat_sdk.nimble | 4 +- ratelimit/rate_limit_manager.nim | 103 ++++++++++------ ratelimit/store/memory.nim | 10 +- ratelimit/store/store.nim | 2 +- ratelimit/token_bucket.nim | 198 ------------------------------ tests/test_rate_limit_manager.nim | 27 ++-- 7 files changed, 96 insertions(+), 250 deletions(-) delete mode 100644 ratelimit/token_bucket.nim diff --git a/.gitignore b/.gitignore index 4ee6ee2..1843ca9 100644 --- a/.gitignore +++ b/.gitignore @@ -25,5 +25,7 @@ apps/* !*.nim tests/* !*.nim +ratelimit/* +!*.nim nimble.develop nimble.paths diff --git a/chat_sdk.nimble b/chat_sdk.nimble index b7d1da0..e5fbb06 100644 --- a/chat_sdk.nimble +++ b/chat_sdk.nimble @@ -7,7 +7,9 @@ license = "MIT" srcDir = "src" ### Dependencies -requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "waku" +requires "nim >= 2.2.4", + "chronicles", "chronos", "db_connector", + "https://github.com/waku-org/token_bucket.git" task buildSharedLib, "Build shared library for C bindings": exec "nim c --mm:refc --app:lib --out:../library/c-bindings/libchatsdk.so chat_sdk/chat_sdk.nim" diff --git a/ratelimit/rate_limit_manager.nim b/ratelimit/rate_limit_manager.nim index c8f51f9..5451216 100644 --- a/ratelimit/rate_limit_manager.nim +++ b/ratelimit/rate_limit_manager.nim @@ -1,5 +1,6 @@ import std/[times, deques, options] -import waku/common/rate_limit/token_bucket +import token_bucket +import ./store/store import chronos type @@ -27,7 +28,8 @@ type MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} - RateLimitManager*[T: Serializable] = ref object + RateLimitManager*[T: Serializable, S: RateLimitStore] = ref object + store: S bucket: TokenBucket sender: MessageSender[T] queueCritical: Deque[seq[MsgIdMsg[T]]] @@ -35,25 +37,41 @@ type sleepDuration: chronos.Duration pxQueueHandleLoop: Future[void] -proc new*[T: Serializable]( - M: type[RateLimitManager[T]], +proc new*[T: Serializable, S: RateLimitStore]( + M: type[RateLimitManager[T, S]], + store: S, sender: MessageSender[T], capacity: int = 100, duration: chronos.Duration = chronos.minutes(10), sleepDuration: chronos.Duration = chronos.milliseconds(1000), -): M = - M( - bucket: TokenBucket.newStrict(capacity, duration), +): Future[RateLimitManager[T, S]] {.async.} = + var current = await store.loadBucketState() + if current.isNone(): + # initialize bucket state with full capacity + current = some( + BucketState(budget: capacity, budgetCap: capacity, lastTimeFull: Moment.now()) + ) + discard await store.saveBucketState(current.get()) + + return RateLimitManager[T, S]( + store: store, + bucket: TokenBucket.new( + current.get().budgetCap, + duration, + ReplenishMode.Strict, + current.get().budget, + current.get().lastTimeFull, + ), sender: sender, queueCritical: Deque[seq[MsgIdMsg[T]]](), queueNormal: Deque[seq[MsgIdMsg[T]]](), sleepDuration: sleepDuration, ) -proc getCapacityState[T: Serializable]( - manager: RateLimitManager[T], now: Moment, count: int = 1 +proc getCapacityState[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment, count: int = 1 ): CapacityState = - let (budget, budgetCap) = manager.bucket.getAvailableCapacity(now) + let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now) let countAfter = budget - count let ratio = countAfter.float / budgetCap.float if ratio < 0.0: @@ -63,15 +81,15 @@ proc getCapacityState[T: Serializable]( else: return CapacityState.Normal -proc passToSender[T: Serializable]( - manager: RateLimitManager[T], - msgs: sink seq[MsgIdMsg[T]], +proc passToSender[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], + msgs: seq[tuple[msgId: string, msg: T]], now: Moment, priority: Priority, ): Future[SendResult] {.async.} = let count = msgs.len - let capacity = manager.bucket.tryConsume(count, now) - if not capacity: + let consumed = manager.bucket.tryConsume(count, now) + if not consumed: case priority of Priority.Critical: manager.queueCritical.addLast(msgs) @@ -81,12 +99,17 @@ proc passToSender[T: Serializable]( return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped + + let (budget, budgetCap, lastTimeFull) = manager.bucket.getAvailableCapacity(now) + discard await manager.store.saveBucketState( + BucketState(budget: budget, budgetCap: budgetCap, lastTimeFull: lastTimeFull) + ) await manager.sender(msgs) return SendResult.PassedToSender -proc processCriticalQueue[T: Serializable]( - manager: RateLimitManager[T], now: Moment -) {.async.} = +proc processCriticalQueue[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment +): Future[void] {.async.} = while manager.queueCritical.len > 0: let msgs = manager.queueCritical.popFirst() let capacityState = manager.getCapacityState(now, msgs.len) @@ -99,9 +122,9 @@ proc processCriticalQueue[T: Serializable]( manager.queueCritical.addFirst(msgs) break -proc processNormalQueue[T: Serializable]( - manager: RateLimitManager[T], now: Moment -) {.async.} = +proc processNormalQueue[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], now: Moment +): Future[void] {.async.} = while manager.queueNormal.len > 0: let msgs = manager.queueNormal.popFirst() let capacityState = manager.getCapacityState(now, msgs.len) @@ -112,13 +135,13 @@ proc processNormalQueue[T: Serializable]( manager.queueNormal.addFirst(msgs) break -proc sendOrEnqueue*[T: Serializable]( - manager: RateLimitManager[T], - msgs: seq[MsgIdMsg[T]], +proc sendOrEnqueue*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], + msgs: seq[tuple[msgId: string, msg: T]], priority: Priority, now: Moment = Moment.now(), ): Future[SendResult] {.async.} = - let (_, budgetCap) = manager.bucket.getAvailableCapacity(now) + let (_, budgetCap, _) = manager.bucket.getAvailableCapacity(now) if msgs.len.float / budgetCap.float >= 0.3: # drop batch if it's too large to avoid starvation return SendResult.DroppedBatchTooLarge @@ -147,11 +170,13 @@ proc sendOrEnqueue*[T: Serializable]( of Priority.Optional: return SendResult.Dropped -proc getEnqueued*[T: Serializable]( - manager: RateLimitManager[T] -): tuple[critical: seq[MsgIdMsg[T]], normal: seq[MsgIdMsg[T]]] = - var criticalMsgs: seq[MsgIdMsg[T]] - var normalMsgs: seq[MsgIdMsg[T]] +proc getEnqueued*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S] +): tuple[ + critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]] +] = + var criticalMsgs: seq[tuple[msgId: string, msg: T]] + var normalMsgs: seq[tuple[msgId: string, msg: T]] for batch in manager.queueCritical: criticalMsgs.add(batch) @@ -161,8 +186,8 @@ proc getEnqueued*[T: Serializable]( return (criticalMsgs, normalMsgs) -proc queueHandleLoop[T: Serializable]( - manager: RateLimitManager[T], +proc queueHandleLoop*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = @@ -177,18 +202,22 @@ proc queueHandleLoop[T: Serializable]( # configurable sleep duration for processing queued messages await sleepAsync(manager.sleepDuration) -proc start*[T: Serializable]( - manager: RateLimitManager[T], +proc start*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = - manager.pxQueueHandleLoop = manager.queueHandleLoop(nowProvider) + manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider) -proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} = +proc stop*[T: Serializable, S: RateLimitStore]( + manager: RateLimitManager[T, S] +) {.async.} = if not isNil(manager.pxQueueHandleLoop): await manager.pxQueueHandleLoop.cancelAndWait() -func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = +func `$`*[T: Serializable, S: RateLimitStore]( + b: RateLimitManager[T, S] +): string {.inline.} = if isNil(b): return "nil" return diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim index 557a17d..6391314 100644 --- a/ratelimit/store/memory.nim +++ b/ratelimit/store/memory.nim @@ -4,18 +4,18 @@ import chronos # Memory Implementation type MemoryRateLimitStore* = ref object - bucketState: BucketState + bucketState: Option[BucketState] -proc newMemoryRateLimitStore*(): MemoryRateLimitStore = - result = MemoryRateLimitStore() +proc new*(T: type[MemoryRateLimitStore]): T = + return T(bucketState: none(BucketState)) proc saveBucketState*( store: MemoryRateLimitStore, bucketState: BucketState ): Future[bool] {.async.} = - store.bucketState = bucketState + store.bucketState = some(bucketState) return true proc loadBucketState*( store: MemoryRateLimitStore ): Future[Option[BucketState]] {.async.} = - return some(store.bucketState) + return store.bucketState diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim index c4f6da3..c916750 100644 --- a/ratelimit/store/store.nim +++ b/ratelimit/store/store.nim @@ -7,7 +7,7 @@ type budgetCap*: int lastTimeFull*: Moment - RateLimitStoreConcept* = + RateLimitStore* = concept s s.saveBucketState(BucketState) is Future[bool] s.loadBucketState() is Future[Option[BucketState]] diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim deleted file mode 100644 index 447569a..0000000 --- a/ratelimit/token_bucket.nim +++ /dev/null @@ -1,198 +0,0 @@ -{.push raises: [].} - -import chronos, std/math, std/options - -const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25 - -## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation. -## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class. -## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation. -## ref: https://github.com/status-im/nim-chronos/issues/500 -## -## This version of TokenBucket is different from the original one in chronos/rate_limit.nim in many ways: -## - It has a new mode called `Compensating` which is the default mode. -## Compensation is calculated as the not used bucket capacity in the last measured period(s) in average. -## or up until maximum the allowed compansation treshold (Currently it is const 25%). -## Also compensation takes care of the proper time period calculation to avoid non-usage periods that can lead to -## overcompensation. -## - Strict mode is also available which will only replenish when time period is over but also will fill -## the bucket to the max capacity. - -type - ReplenishMode* = enum - Strict - Compensating - - TokenBucket* = ref object - budget: int ## Current number of tokens in the bucket - budgetCap: int ## Bucket capacity - lastTimeFull: Moment - ## This timer measures the proper periodizaiton of the bucket refilling - fillDuration: Duration ## Refill period - case replenishMode*: ReplenishMode - of Strict: - ## In strict mode, the bucket is refilled only till the budgetCap - discard - of Compensating: - ## This is the default mode. - maxCompensation: float - -func periodDistance(bucket: TokenBucket, currentTime: Moment): float = - ## notice fillDuration cannot be zero by design - ## period distance is a float number representing the calculated period time - ## since the last time bucket was refilled. - return - nanoseconds(currentTime - bucket.lastTimeFull).float / - nanoseconds(bucket.fillDuration).float - -func getUsageAverageSince(bucket: TokenBucket, distance: float): float = - if distance == 0.float: - ## in case there is zero time difference than the usage percentage is 100% - return 1.0 - - ## budgetCap can never be zero - ## usage average is calculated as a percentage of total capacity available over - ## the measured period - return bucket.budget.float / bucket.budgetCap.float / distance - -proc calcCompensation(bucket: TokenBucket, averageUsage: float): int = - # if we already fully used or even overused the tokens, there is no place for compensation - if averageUsage >= 1.0: - return 0 - - ## compensation is the not used bucket capacity in the last measured period(s) in average. - ## or maximum the allowed compansation treshold - let compensationPercent = - min((1.0 - averageUsage) * bucket.budgetCap.float, bucket.maxCompensation) - return trunc(compensationPercent).int - -func periodElapsed(bucket: TokenBucket, currentTime: Moment): bool = - return currentTime - bucket.lastTimeFull >= bucket.fillDuration - -## Update will take place if bucket is empty and trying to consume tokens. -## It checks if the bucket can be replenished as refill duration is passed or not. -## - strict mode: -proc updateStrict(bucket: TokenBucket, currentTime: Moment) = - if bucket.fillDuration == default(Duration): - bucket.budget = min(bucket.budgetCap, bucket.budget) - return - - if not periodElapsed(bucket, currentTime): - return - - bucket.budget = bucket.budgetCap - bucket.lastTimeFull = currentTime - -## - compensating - ballancing load: -## - between updates we calculate average load (current bucket capacity / number of periods till last update) -## - gives the percentage load used recently -## - with this we can replenish bucket up to 100% + calculated leftover from previous period (caped with max treshold) -proc updateWithCompensation(bucket: TokenBucket, currentTime: Moment) = - if bucket.fillDuration == default(Duration): - bucket.budget = min(bucket.budgetCap, bucket.budget) - return - - # do not replenish within the same period - if not periodElapsed(bucket, currentTime): - return - - let distance = bucket.periodDistance(currentTime) - let recentAvgUsage = bucket.getUsageAverageSince(distance) - let compensation = bucket.calcCompensation(recentAvgUsage) - - bucket.budget = bucket.budgetCap + compensation - bucket.lastTimeFull = currentTime - -proc update(bucket: TokenBucket, currentTime: Moment) = - if bucket.replenishMode == ReplenishMode.Compensating: - updateWithCompensation(bucket, currentTime) - else: - updateStrict(bucket, currentTime) - -## Returns the available capacity of the bucket: (budget, budgetCap) -proc getAvailableCapacity*( - bucket: TokenBucket, currentTime: Moment -): tuple[budget: int, budgetCap: int] = - if periodElapsed(bucket, currentTime): - case bucket.replenishMode - of ReplenishMode.Strict: - return (bucket.budgetCap, bucket.budgetCap) - of ReplenishMode.Compensating: - let distance = bucket.periodDistance(currentTime) - let recentAvgUsage = bucket.getUsageAverageSince(distance) - let compensation = bucket.calcCompensation(recentAvgUsage) - let availableBudget = bucket.budgetCap + compensation - return (availableBudget, bucket.budgetCap) - return (bucket.budget, bucket.budgetCap) - -proc tryConsume*(bucket: TokenBucket, tokens: int, now = Moment.now()): bool = - ## If `tokens` are available, consume them, - ## Otherwhise, return false. - - if bucket.budget >= bucket.budgetCap: - bucket.lastTimeFull = now - - if bucket.budget >= tokens: - bucket.budget -= tokens - return true - - bucket.update(now) - - if bucket.budget >= tokens: - bucket.budget -= tokens - return true - else: - return false - -proc replenish*(bucket: TokenBucket, tokens: int, now = Moment.now()) = - ## Add `tokens` to the budget (capped to the bucket capacity) - bucket.budget += tokens - bucket.update(now) - -proc new*( - T: type[TokenBucket], - budgetCap: int, - fillDuration: Duration = 1.seconds, - mode: ReplenishMode = ReplenishMode.Compensating, -): T = - assert not isZero(fillDuration) - assert budgetCap != 0 - - ## Create different mode TokenBucket - case mode - of ReplenishMode.Strict: - return T( - budget: budgetCap, - budgetCap: budgetCap, - fillDuration: fillDuration, - lastTimeFull: Moment.now(), - replenishMode: mode, - ) - of ReplenishMode.Compensating: - T( - budget: budgetCap, - budgetCap: budgetCap, - fillDuration: fillDuration, - lastTimeFull: Moment.now(), - replenishMode: mode, - maxCompensation: budgetCap.float * BUDGET_COMPENSATION_LIMIT_PERCENT, - ) - -proc newStrict*(T: type[TokenBucket], capacity: int, period: Duration): TokenBucket = - T.new(capacity, period, ReplenishMode.Strict) - -proc newCompensating*( - T: type[TokenBucket], capacity: int, period: Duration -): TokenBucket = - T.new(capacity, period, ReplenishMode.Compensating) - -func `$`*(b: TokenBucket): string {.inline.} = - if isNil(b): - return "nil" - return $b.budgetCap & "/" & $b.fillDuration - -func `$`*(ob: Option[TokenBucket]): string {.inline.} = - if ob.isNone(): - return "no-limit" - - return $ob.get() diff --git a/tests/test_rate_limit_manager.nim b/tests/test_rate_limit_manager.nim index 6c34c54..b15606e 100644 --- a/tests/test_rate_limit_manager.nim +++ b/tests/test_rate_limit_manager.nim @@ -1,5 +1,6 @@ import testutils/unittests import ../ratelimit/rate_limit_manager +import ../ratelimit/store/memory import chronos # Implement the Serializable concept for string @@ -22,8 +23,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let manager = RateLimitManager[string].new( - mockSender, capacity = 10, duration = chronos.milliseconds(100) + let store: MemoryRateLimitStore = MemoryRateLimitStore.new() + let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) let testMsg = "Hello World" @@ -40,8 +42,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let manager = RateLimitManager[string].new( - mockSender, capacity = 10, duration = chronos.milliseconds(100) + let store = MemoryRateLimitStore.new() + let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) ## When @@ -60,7 +63,9 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let manager = RateLimitManager[string].new( + let store = MemoryRateLimitStore.new() + let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + store, mockSender, capacity = 2, duration = chronos.milliseconds(100), @@ -75,7 +80,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let manager = RateLimitManager[string].new( + let store = MemoryRateLimitStore.new() + let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -123,7 +130,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue normal on 70% capacity": ## Given - let manager = RateLimitManager[string].new( + let store = MemoryRateLimitStore.new() + let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(100), @@ -174,7 +183,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - process queued messages": ## Given - let manager = RateLimitManager[string].new( + let store = MemoryRateLimitStore.new() + let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + store, mockSender, capacity = 10, duration = chronos.milliseconds(200), From 109b5769da3382b483ee318192123ef6f47bffd6 Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 4 Aug 2025 09:04:52 +0300 Subject: [PATCH 09/17] fix: pr comments --- ...imit_manager.nim => ratelimit_manager.nim} | 19 +------------------ ratelimit/token_bucket.nim | 6 ++++-- ...manager.nim => test_ratelimit_manager.nim} | 2 +- 3 files changed, 6 insertions(+), 21 deletions(-) rename ratelimit/{rate_limit_manager.nim => ratelimit_manager.nim} (92%) rename tests/{test_rate_limit_manager.nim => test_ratelimit_manager.nim} (99%) diff --git a/ratelimit/rate_limit_manager.nim b/ratelimit/ratelimit_manager.nim similarity index 92% rename from ratelimit/rate_limit_manager.nim rename to ratelimit/ratelimit_manager.nim index e0b00be..dce9632 100644 --- a/ratelimit/rate_limit_manager.nim +++ b/ratelimit/ratelimit_manager.nim @@ -1,4 +1,5 @@ import std/[times, deques, options] +# TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live import ./token_bucket # import waku/common/rate_limit/token_bucket import ./store/store @@ -54,24 +55,6 @@ proc new*[T: Serializable, S: RateLimitStore]( ) discard await store.saveBucketState(current.get()) - return RateLimitManager[T, S]( - store: store, - bucket: TokenBucket.new( - current.get().budgetCap, - duration, - ReplenishMode.Strict, - current.get().budget, - current.get().lastTimeFull, - ), -): Future[RateLimitManager[T, S]] {.async.} = - var current = await store.loadBucketState() - if current.isNone(): - # initialize bucket state with full capacity - current = some( - BucketState(budget: capacity, budgetCap: capacity, lastTimeFull: Moment.now()) - ) - discard await store.saveBucketState(current.get()) - return RateLimitManager[T, S]( store: store, bucket: TokenBucket.new( diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim index e4a4487..e2c606d 100644 --- a/ratelimit/token_bucket.nim +++ b/ratelimit/token_bucket.nim @@ -109,10 +109,11 @@ proc update(bucket: TokenBucket, currentTime: Moment) = else: updateStrict(bucket, currentTime) -## Returns the available capacity of the bucket: (budget, budgetCap) proc getAvailableCapacity*( - bucket: TokenBucket, currentTime: Moment + bucket: TokenBucket, currentTime: Moment = Moment.now() ): tuple[budget: int, budgetCap: int, lastTimeFull: Moment] = + ## Returns the available capacity of the bucket: (budget, budgetCap, lastTimeFull) + if periodElapsed(bucket, currentTime): case bucket.replenishMode of ReplenishMode.Strict: @@ -159,6 +160,7 @@ proc new*( ): T = assert not isZero(fillDuration) assert budgetCap != 0 + assert lastTimeFull <= Moment.now() let actualBudget = if budget == -1: budgetCap else: budget assert actualBudget >= 0 and actualBudget <= budgetCap diff --git a/tests/test_rate_limit_manager.nim b/tests/test_ratelimit_manager.nim similarity index 99% rename from tests/test_rate_limit_manager.nim rename to tests/test_ratelimit_manager.nim index b15606e..f6c4039 100644 --- a/tests/test_rate_limit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -1,5 +1,5 @@ import testutils/unittests -import ../ratelimit/rate_limit_manager +import ../ratelimit/ratelimit_manager import ../ratelimit/store/memory import chronos From 2c47183fb03ef283e99377cd584c28d327b00b4c Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 4 Aug 2025 10:43:59 +0300 Subject: [PATCH 10/17] feat: store queue --- chat_sdk/migration.nim | 21 ++- migrations/001_create_ratelimit_state.sql | 9 ++ ratelimit/store/memory.nim | 65 +++++++- ratelimit/store/sqlite.nim | 167 ++++++++++++++++++++- ratelimit/store/store.nim | 7 + tests/test_sqlite_store.nim | 175 +++++++++++++++++++++- 6 files changed, 421 insertions(+), 23 deletions(-) diff --git a/chat_sdk/migration.nim b/chat_sdk/migration.nim index 8af88d3..45ea614 100644 --- a/chat_sdk/migration.nim +++ b/chat_sdk/migration.nim @@ -1,17 +1,21 @@ -import os, sequtils, algorithm +import os, sequtils, algorithm, strutils import db_connector/db_sqlite import chronicles proc ensureMigrationTable(db: DbConn) = - db.exec(sql""" + db.exec( + sql""" CREATE TABLE IF NOT EXISTS schema_migrations ( filename TEXT PRIMARY KEY, applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) - """) + """ + ) proc hasMigrationRun(db: DbConn, filename: string): bool = - for row in db.fastRows(sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename): + for row in db.fastRows( + sql"SELECT 1 FROM schema_migrations WHERE filename = ?", filename + ): return true return false @@ -27,6 +31,11 @@ proc runMigrations*(db: DbConn, dir = "migrations") = info "Migration already applied", file else: info "Applying migration", file - let sql = readFile(file) - db.exec(sql(sql)) + let sqlContent = readFile(file) + # Split by semicolon and execute each statement separately + let statements = sqlContent.split(';') + for stmt in statements: + let trimmedStmt = stmt.strip() + if trimmedStmt.len > 0: + db.exec(sql(trimmedStmt)) markMigrationRun(db, file) diff --git a/migrations/001_create_ratelimit_state.sql b/migrations/001_create_ratelimit_state.sql index 030377c..293c6ee 100644 --- a/migrations/001_create_ratelimit_state.sql +++ b/migrations/001_create_ratelimit_state.sql @@ -1,4 +1,13 @@ CREATE TABLE IF NOT EXISTS kv_store ( key TEXT PRIMARY KEY, value BLOB +); + +CREATE TABLE IF NOT EXISTS ratelimit_queues ( + queue_type TEXT NOT NULL, + msg_id TEXT NOT NULL, + msg_data BLOB NOT NULL, + batch_id INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (queue_type, batch_id, msg_id) ); \ No newline at end of file diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim index 6391314..e7e7c2f 100644 --- a/ratelimit/store/memory.nim +++ b/ratelimit/store/memory.nim @@ -1,21 +1,70 @@ -import std/[times, options] +import std/[times, options, deques, tables] import ./store import chronos # Memory Implementation -type MemoryRateLimitStore* = ref object +type MemoryRateLimitStore*[T] = ref object bucketState: Option[BucketState] + criticalQueue: Deque[seq[tuple[msgId: string, msg: T]]] + normalQueue: Deque[seq[tuple[msgId: string, msg: T]]] + criticalLength: int + normalLength: int -proc new*(T: type[MemoryRateLimitStore]): T = - return T(bucketState: none(BucketState)) +proc new*[T](M: type[MemoryRateLimitStore[T]]): M = + return M( + bucketState: none(BucketState), + criticalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](), + normalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](), + criticalLength: 0, + normalLength: 0 + ) -proc saveBucketState*( - store: MemoryRateLimitStore, bucketState: BucketState +proc saveBucketState*[T]( + store: MemoryRateLimitStore[T], bucketState: BucketState ): Future[bool] {.async.} = store.bucketState = some(bucketState) return true -proc loadBucketState*( - store: MemoryRateLimitStore +proc loadBucketState*[T]( + store: MemoryRateLimitStore[T] ): Future[Option[BucketState]] {.async.} = return store.bucketState + +proc addToQueue*[T]( + store: MemoryRateLimitStore[T], + queueType: QueueType, + msgs: seq[tuple[msgId: string, msg: T]] +): Future[bool] {.async.} = + case queueType + of QueueType.Critical: + store.criticalQueue.addLast(msgs) + inc store.criticalLength + of QueueType.Normal: + store.normalQueue.addLast(msgs) + inc store.normalLength + return true + +proc popFromQueue*[T]( + store: MemoryRateLimitStore[T], + queueType: QueueType +): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = + case queueType + of QueueType.Critical: + if store.criticalQueue.len > 0: + dec store.criticalLength + return some(store.criticalQueue.popFirst()) + of QueueType.Normal: + if store.normalQueue.len > 0: + dec store.normalLength + return some(store.normalQueue.popFirst()) + return none(seq[tuple[msgId: string, msg: T]]) + +proc getQueueLength*[T]( + store: MemoryRateLimitStore[T], + queueType: QueueType +): int = + case queueType + of QueueType.Critical: + return store.criticalLength + of QueueType.Normal: + return store.normalLength diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store/sqlite.nim index e364e5d..a3369e9 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store/sqlite.nim @@ -3,18 +3,58 @@ import ./store import chronos import db_connector/db_sqlite +# Generic deserialization function for basic types +proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string = + # Convert each byte back to a character + result = newString(bytes.len) + for i, b in bytes: + result[i] = char(b) + # SQLite Implementation -type SqliteRateLimitStore* = ref object +type SqliteRateLimitStore*[T] = ref object db: DbConn dbPath: string + criticalLength: int + normalLength: int + nextBatchId: int const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc newSqliteRateLimitStore*(db: DbConn): SqliteRateLimitStore = - result = SqliteRateLimitStore(db: db) +proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] = + result = + SqliteRateLimitStore[T](db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) -proc saveBucketState*( - store: SqliteRateLimitStore, bucketState: BucketState + # Initialize cached lengths from database + let criticalCount = db.getValue( + sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?", + "critical", + ) + let normalCount = db.getValue( + sql"SELECT COUNT(DISTINCT batch_id) FROM ratelimit_queues WHERE queue_type = ?", + "normal", + ) + + result.criticalLength = + if criticalCount == "": + 0 + else: + parseInt(criticalCount) + result.normalLength = + if normalCount == "": + 0 + else: + parseInt(normalCount) + + # Get next batch ID + let maxBatch = db.getValue(sql"SELECT MAX(batch_id) FROM ratelimit_queues") + result.nextBatchId = + if maxBatch == "": + 1 + else: + parseInt(maxBatch) + 1 + +proc saveBucketState*[T]( + store: SqliteRateLimitStore[T], bucketState: BucketState ): Future[bool] {.async.} = try: # Convert Moment to Unix seconds for storage @@ -35,8 +75,8 @@ proc saveBucketState*( except: return false -proc loadBucketState*( - store: SqliteRateLimitStore +proc loadBucketState*[T]( + store: SqliteRateLimitStore[T] ): Future[Option[BucketState]] {.async.} = let jsonStr = store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) @@ -54,3 +94,116 @@ proc loadBucketState*( lastTimeFull: lastTimeFull, ) ) + +proc addToQueue*[T]( + store: SqliteRateLimitStore[T], + queueType: QueueType, + msgs: seq[tuple[msgId: string, msg: T]], +): Future[bool] {.async.} = + try: + let batchId = store.nextBatchId + inc store.nextBatchId + let now = times.getTime().toUnix() + let queueTypeStr = $queueType + + if msgs.len > 0: + store.db.exec(sql"BEGIN TRANSACTION") + try: + for msg in msgs: + # Consistent serialization for all types + let msgBytes = msg.msg.toBytes() + # Convert seq[byte] to string for SQLite storage (each byte becomes a character) + var binaryStr = newString(msgBytes.len) + for i, b in msgBytes: + binaryStr[i] = char(b) + + store.db.exec( + sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)", + queueTypeStr, + msg.msgId, + binaryStr, + batchId, + now, + ) + store.db.exec(sql"COMMIT") + except: + store.db.exec(sql"ROLLBACK") + raise + + case queueType + of QueueType.Critical: + inc store.criticalLength + of QueueType.Normal: + inc store.normalLength + + return true + except: + return false + +proc popFromQueue*[T]( + store: SqliteRateLimitStore[T], queueType: QueueType +): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = + try: + let queueTypeStr = $queueType + + # Get the oldest batch ID for this queue type + let oldestBatchStr = store.db.getValue( + sql"SELECT MIN(batch_id) FROM ratelimit_queues WHERE queue_type = ?", queueTypeStr + ) + + if oldestBatchStr == "": + return none(seq[tuple[msgId: string, msg: T]]) + + let batchId = parseInt(oldestBatchStr) + + # Get all messages in this batch (preserve insertion order using rowid) + let rows = store.db.getAllRows( + sql"SELECT msg_id, msg_data FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ? ORDER BY rowid", + queueTypeStr, + batchId, + ) + + if rows.len == 0: + return none(seq[tuple[msgId: string, msg: T]]) + + var msgs: seq[tuple[msgId: string, msg: T]] + for row in rows: + let msgIdStr = row[0] + let msgData = row[1] # SQLite returns BLOB as string where each char is a byte + # Convert string back to seq[byte] properly (each char in string is a byte) + var msgBytes: seq[byte] + for c in msgData: + msgBytes.add(byte(c)) + + # Generic deserialization - works for any type that implements fromBytes + when T is string: + let msg = fromBytesImpl(msgBytes, T) + msgs.add((msgId: msgIdStr, msg: msg)) + else: + # For other types, they need to provide their own fromBytes in the calling context + let msg = fromBytes(msgBytes, T) + msgs.add((msgId: msgIdStr, msg: msg)) + + # Delete the batch from database + store.db.exec( + sql"DELETE FROM ratelimit_queues WHERE queue_type = ? AND batch_id = ?", + queueTypeStr, + batchId, + ) + + case queueType + of QueueType.Critical: + dec store.criticalLength + of QueueType.Normal: + dec store.normalLength + + return some(msgs) + except: + return none(seq[tuple[msgId: string, msg: T]]) + +proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int = + case queueType + of QueueType.Critical: + return store.criticalLength + of QueueType.Normal: + return store.normalLength diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim index c916750..0f18eb1 100644 --- a/ratelimit/store/store.nim +++ b/ratelimit/store/store.nim @@ -7,7 +7,14 @@ type budgetCap*: int lastTimeFull*: Moment + QueueType* {.pure.} = enum + Critical = "critical" + Normal = "normal" + RateLimitStore* = concept s s.saveBucketState(BucketState) is Future[bool] s.loadBucketState() is Future[Option[BucketState]] + s.addToQueue(QueueType, seq[tuple[msgId: string, msg: untyped]]) is Future[bool] + s.popFromQueue(QueueType) is Future[Option[seq[tuple[msgId: string, msg: untyped]]]] + s.getQueueLength(QueueType) is int diff --git a/tests/test_sqlite_store.nim b/tests/test_sqlite_store.nim index 6315ec9..90d764c 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_sqlite_store.nim @@ -6,6 +6,19 @@ import db_connector/db_sqlite import ../chat_sdk/migration import std/[options, os] +# Implement the Serializable concept for string (for testing) +proc toBytes*(s: string): seq[byte] = + # Convert each character to a byte + result = newSeq[byte](s.len) + for i, c in s: + result[i] = byte(c) + +proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string = + # Convert each byte back to a character + result = newString(bytes.len) + for i, b in bytes: + result[i] = char(b) + suite "SqliteRateLimitStore Tests": setup: let db = open("test-ratelimit.db", "", "", "") @@ -19,7 +32,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "newSqliteRateLimitStore - empty state": ## Given - let store = newSqliteRateLimitStore(db) + let store = newSqliteRateLimitStore[string](db) ## When let loadedState = await store.loadBucketState() @@ -29,7 +42,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "saveBucketState and loadBucketState - state persistence": ## Given - let store = newSqliteRateLimitStore(db) + let store = newSqliteRateLimitStore[string](db) let now = Moment.now() echo "now: ", now.epochSeconds() @@ -46,3 +59,161 @@ suite "SqliteRateLimitStore Tests": check loadedState.get().budgetCap == newBucketState.budgetCap check loadedState.get().lastTimeFull.epochSeconds() == newBucketState.lastTimeFull.epochSeconds() + + asyncTest "queue operations - empty store": + ## Given + let store = newSqliteRateLimitStore[string](db) + + ## When/Then + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 0 + + let criticalPop = await store.popFromQueue(QueueType.Critical) + let normalPop = await store.popFromQueue(QueueType.Normal) + + check criticalPop.isNone() + check normalPop.isNone() + + asyncTest "addToQueue and popFromQueue - single batch": + ## Given + let store = newSqliteRateLimitStore[string](db) + let msgs = @[("msg1", "Hello"), ("msg2", "World")] + + ## When + let addResult = await store.addToQueue(QueueType.Critical, msgs) + + ## Then + check addResult == true + check store.getQueueLength(QueueType.Critical) == 1 + check store.getQueueLength(QueueType.Normal) == 0 + + ## When + let popResult = await store.popFromQueue(QueueType.Critical) + + ## Then + check popResult.isSome() + let poppedMsgs = popResult.get() + check poppedMsgs.len == 2 + check poppedMsgs[0].msgId == "msg1" + check poppedMsgs[0].msg == "Hello" + check poppedMsgs[1].msgId == "msg2" + check poppedMsgs[1].msg == "World" + + check store.getQueueLength(QueueType.Critical) == 0 + + asyncTest "addToQueue and popFromQueue - multiple batches FIFO": + ## Given + let store = newSqliteRateLimitStore[string](db) + let batch1 = @[("msg1", "First")] + let batch2 = @[("msg2", "Second")] + let batch3 = @[("msg3", "Third")] + + ## When - Add batches + let result1 = await store.addToQueue(QueueType.Normal, batch1) + check result1 == true + let result2 = await store.addToQueue(QueueType.Normal, batch2) + check result2 == true + let result3 = await store.addToQueue(QueueType.Normal, batch3) + check result3 == true + + ## Then - Check lengths + check store.getQueueLength(QueueType.Normal) == 3 + check store.getQueueLength(QueueType.Critical) == 0 + + ## When/Then - Pop in FIFO order + let pop1 = await store.popFromQueue(QueueType.Normal) + check pop1.isSome() + check pop1.get()[0].msg == "First" + check store.getQueueLength(QueueType.Normal) == 2 + + let pop2 = await store.popFromQueue(QueueType.Normal) + check pop2.isSome() + check pop2.get()[0].msg == "Second" + check store.getQueueLength(QueueType.Normal) == 1 + + let pop3 = await store.popFromQueue(QueueType.Normal) + check pop3.isSome() + check pop3.get()[0].msg == "Third" + check store.getQueueLength(QueueType.Normal) == 0 + + let pop4 = await store.popFromQueue(QueueType.Normal) + check pop4.isNone() + + asyncTest "queue isolation - critical and normal queues are separate": + ## Given + let store = newSqliteRateLimitStore[string](db) + let criticalMsgs = @[("crit1", "Critical Message")] + let normalMsgs = @[("norm1", "Normal Message")] + + ## When + let critResult = await store.addToQueue(QueueType.Critical, criticalMsgs) + check critResult == true + let normResult = await store.addToQueue(QueueType.Normal, normalMsgs) + check normResult == true + + ## Then + check store.getQueueLength(QueueType.Critical) == 1 + check store.getQueueLength(QueueType.Normal) == 1 + + ## When - Pop from critical + let criticalPop = await store.popFromQueue(QueueType.Critical) + check criticalPop.isSome() + check criticalPop.get()[0].msg == "Critical Message" + + ## Then - Normal queue unaffected + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 1 + + ## When - Pop from normal + let normalPop = await store.popFromQueue(QueueType.Normal) + check normalPop.isSome() + check normalPop.get()[0].msg == "Normal Message" + + ## Then - All queues empty + check store.getQueueLength(QueueType.Critical) == 0 + check store.getQueueLength(QueueType.Normal) == 0 + + asyncTest "queue persistence across store instances": + ## Given + let msgs = @[("persist1", "Persistent Message")] + + block: + let store1 = newSqliteRateLimitStore[string](db) + let addResult = await store1.addToQueue(QueueType.Critical, msgs) + check addResult == true + check store1.getQueueLength(QueueType.Critical) == 1 + + ## When - Create new store instance + block: + let store2 = newSqliteRateLimitStore[string](db) + + ## Then - Queue length should be restored from database + check store2.getQueueLength(QueueType.Critical) == 1 + + let popResult = await store2.popFromQueue(QueueType.Critical) + check popResult.isSome() + check popResult.get()[0].msg == "Persistent Message" + check store2.getQueueLength(QueueType.Critical) == 0 + + asyncTest "large batch handling": + ## Given + let store = newSqliteRateLimitStore[string](db) + var largeBatch: seq[tuple[msgId: string, msg: string]] + + for i in 1 .. 100: + largeBatch.add(("msg" & $i, "Message " & $i)) + + ## When + let addResult = await store.addToQueue(QueueType.Normal, largeBatch) + + ## Then + check addResult == true + check store.getQueueLength(QueueType.Normal) == 1 + + let popResult = await store.popFromQueue(QueueType.Normal) + check popResult.isSome() + let poppedMsgs = popResult.get() + check poppedMsgs.len == 100 + check poppedMsgs[0].msgId == "msg1" + check poppedMsgs[99].msgId == "msg100" + check store.getQueueLength(QueueType.Normal) == 0 From dd0082041c1ea069f7b09d5ec6fd864fbefe7f27 Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 4 Aug 2025 11:31:44 +0300 Subject: [PATCH 11/17] fix: refactor --- ratelimit/ratelimit_manager.nim | 60 +++++++--------- ratelimit/{store/sqlite.nim => store.nim} | 57 +++++++++------ ratelimit/store/memory.nim | 70 ------------------- ratelimit/store/store.nim | 20 ------ tests/test_ratelimit_manager.nim | 58 +++++++++++---- .../{test_sqlite_store.nim => test_store.nim} | 21 +++--- 6 files changed, 117 insertions(+), 169 deletions(-) rename ratelimit/{store/sqlite.nim => store.nim} (84%) delete mode 100644 ratelimit/store/memory.nim delete mode 100644 ratelimit/store/store.nim rename tests/{test_sqlite_store.nim => test_store.nim} (92%) diff --git a/ratelimit/ratelimit_manager.nim b/ratelimit/ratelimit_manager.nim index dce9632..438619f 100644 --- a/ratelimit/ratelimit_manager.nim +++ b/ratelimit/ratelimit_manager.nim @@ -2,7 +2,7 @@ import std/[times, deques, options] # TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live import ./token_bucket # import waku/common/rate_limit/token_bucket -import ./store/store +import ./store import chronos type @@ -22,16 +22,12 @@ type Normal Optional - Serializable* = - concept x - x.toBytes() is seq[byte] - MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T] MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} - RateLimitManager*[T: Serializable, S: RateLimitStore] = ref object - store: S + RateLimitManager*[T: Serializable] = ref object + store: RateLimitStore[T] bucket: TokenBucket sender: MessageSender[T] queueCritical: Deque[seq[MsgIdMsg[T]]] @@ -39,14 +35,14 @@ type sleepDuration: chronos.Duration pxQueueHandleLoop: Future[void] -proc new*[T: Serializable, S: RateLimitStore]( - M: type[RateLimitManager[T, S]], - store: S, +proc new*[T: Serializable]( + M: type[RateLimitManager[T]], + store: RateLimitStore[T], sender: MessageSender[T], capacity: int = 100, duration: chronos.Duration = chronos.minutes(10), sleepDuration: chronos.Duration = chronos.milliseconds(1000), -): Future[RateLimitManager[T, S]] {.async.} = +): Future[RateLimitManager[T]] {.async.} = var current = await store.loadBucketState() if current.isNone(): # initialize bucket state with full capacity @@ -55,7 +51,7 @@ proc new*[T: Serializable, S: RateLimitStore]( ) discard await store.saveBucketState(current.get()) - return RateLimitManager[T, S]( + return RateLimitManager[T]( store: store, bucket: TokenBucket.new( current.get().budgetCap, @@ -70,8 +66,8 @@ proc new*[T: Serializable, S: RateLimitStore]( sleepDuration: sleepDuration, ) -proc getCapacityState[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], now: Moment, count: int = 1 +proc getCapacityState[T: Serializable]( + manager: RateLimitManager[T], now: Moment, count: int = 1 ): CapacityState = let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now) let countAfter = budget - count @@ -83,8 +79,8 @@ proc getCapacityState[T: Serializable, S: RateLimitStore]( else: return CapacityState.Normal -proc passToSender[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], +proc passToSender[T: Serializable]( + manager: RateLimitManager[T], msgs: seq[tuple[msgId: string, msg: T]], now: Moment, priority: Priority, @@ -109,8 +105,8 @@ proc passToSender[T: Serializable, S: RateLimitStore]( await manager.sender(msgs) return SendResult.PassedToSender -proc processCriticalQueue[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], now: Moment +proc processCriticalQueue[T: Serializable]( + manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = while manager.queueCritical.len > 0: let msgs = manager.queueCritical.popFirst() @@ -124,8 +120,8 @@ proc processCriticalQueue[T: Serializable, S: RateLimitStore]( manager.queueCritical.addFirst(msgs) break -proc processNormalQueue[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], now: Moment +proc processNormalQueue[T: Serializable]( + manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = while manager.queueNormal.len > 0: let msgs = manager.queueNormal.popFirst() @@ -137,8 +133,8 @@ proc processNormalQueue[T: Serializable, S: RateLimitStore]( manager.queueNormal.addFirst(msgs) break -proc sendOrEnqueue*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], +proc sendOrEnqueue*[T: Serializable]( + manager: RateLimitManager[T], msgs: seq[tuple[msgId: string, msg: T]], priority: Priority, now: Moment = Moment.now(), @@ -172,8 +168,8 @@ proc sendOrEnqueue*[T: Serializable, S: RateLimitStore]( of Priority.Optional: return SendResult.Dropped -proc getEnqueued*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S] +proc getEnqueued*[T: Serializable]( + manager: RateLimitManager[T] ): tuple[ critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]] ] = @@ -188,8 +184,8 @@ proc getEnqueued*[T: Serializable, S: RateLimitStore]( return (criticalMsgs, normalMsgs) -proc queueHandleLoop*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], +proc queueHandleLoop*[T: Serializable]( + manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = @@ -204,22 +200,18 @@ proc queueHandleLoop*[T: Serializable, S: RateLimitStore]( # configurable sleep duration for processing queued messages await sleepAsync(manager.sleepDuration) -proc start*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S], +proc start*[T: Serializable]( + manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider) -proc stop*[T: Serializable, S: RateLimitStore]( - manager: RateLimitManager[T, S] -) {.async.} = +proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} = if not isNil(manager.pxQueueHandleLoop): await manager.pxQueueHandleLoop.cancelAndWait() -func `$`*[T: Serializable, S: RateLimitStore]( - b: RateLimitManager[T, S] -): string {.inline.} = +func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = if isNil(b): return "nil" return diff --git a/ratelimit/store/sqlite.nim b/ratelimit/store.nim similarity index 84% rename from ratelimit/store/sqlite.nim rename to ratelimit/store.nim index a3369e9..2618589 100644 --- a/ratelimit/store/sqlite.nim +++ b/ratelimit/store.nim @@ -1,7 +1,6 @@ import std/[times, strutils, json, options] -import ./store -import chronos import db_connector/db_sqlite +import chronos # Generic deserialization function for basic types proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string = @@ -10,19 +9,31 @@ proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string = for i, b in bytes: result[i] = char(b) -# SQLite Implementation -type SqliteRateLimitStore*[T] = ref object - db: DbConn - dbPath: string - criticalLength: int - normalLength: int - nextBatchId: int +type + Serializable* = + concept x + x.toBytes() is seq[byte] + + RateLimitStore*[T: Serializable] = ref object + db: DbConn + dbPath: string + criticalLength: int + normalLength: int + nextBatchId: int + + BucketState* = object + budget*: int + budgetCap*: int + lastTimeFull*: Moment + + QueueType* {.pure.} = enum + Critical = "critical" + Normal = "normal" const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] = - result = - SqliteRateLimitStore[T](db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) +proc new*[T: Serializable](M: type[RateLimitStore[T]], db: DbConn): M = + result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) # Initialize cached lengths from database let criticalCount = db.getValue( @@ -53,8 +64,10 @@ proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] = else: parseInt(maxBatch) + 1 -proc saveBucketState*[T]( - store: SqliteRateLimitStore[T], bucketState: BucketState + return result + +proc saveBucketState*[T: Serializable]( + store: RateLimitStore[T], bucketState: BucketState ): Future[bool] {.async.} = try: # Convert Moment to Unix seconds for storage @@ -75,8 +88,8 @@ proc saveBucketState*[T]( except: return false -proc loadBucketState*[T]( - store: SqliteRateLimitStore[T] +proc loadBucketState*[T: Serializable]( + store: RateLimitStore[T] ): Future[Option[BucketState]] {.async.} = let jsonStr = store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY) @@ -95,8 +108,8 @@ proc loadBucketState*[T]( ) ) -proc addToQueue*[T]( - store: SqliteRateLimitStore[T], +proc addToQueue*[T: Serializable]( + store: RateLimitStore[T], queueType: QueueType, msgs: seq[tuple[msgId: string, msg: T]], ): Future[bool] {.async.} = @@ -140,8 +153,8 @@ proc addToQueue*[T]( except: return false -proc popFromQueue*[T]( - store: SqliteRateLimitStore[T], queueType: QueueType +proc popFromQueue*[T: Serializable]( + store: RateLimitStore[T], queueType: QueueType ): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = try: let queueTypeStr = $queueType @@ -201,7 +214,9 @@ proc popFromQueue*[T]( except: return none(seq[tuple[msgId: string, msg: T]]) -proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int = +proc getQueueLength*[T: Serializable]( + store: RateLimitStore[T], queueType: QueueType +): int = case queueType of QueueType.Critical: return store.criticalLength diff --git a/ratelimit/store/memory.nim b/ratelimit/store/memory.nim deleted file mode 100644 index e7e7c2f..0000000 --- a/ratelimit/store/memory.nim +++ /dev/null @@ -1,70 +0,0 @@ -import std/[times, options, deques, tables] -import ./store -import chronos - -# Memory Implementation -type MemoryRateLimitStore*[T] = ref object - bucketState: Option[BucketState] - criticalQueue: Deque[seq[tuple[msgId: string, msg: T]]] - normalQueue: Deque[seq[tuple[msgId: string, msg: T]]] - criticalLength: int - normalLength: int - -proc new*[T](M: type[MemoryRateLimitStore[T]]): M = - return M( - bucketState: none(BucketState), - criticalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](), - normalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](), - criticalLength: 0, - normalLength: 0 - ) - -proc saveBucketState*[T]( - store: MemoryRateLimitStore[T], bucketState: BucketState -): Future[bool] {.async.} = - store.bucketState = some(bucketState) - return true - -proc loadBucketState*[T]( - store: MemoryRateLimitStore[T] -): Future[Option[BucketState]] {.async.} = - return store.bucketState - -proc addToQueue*[T]( - store: MemoryRateLimitStore[T], - queueType: QueueType, - msgs: seq[tuple[msgId: string, msg: T]] -): Future[bool] {.async.} = - case queueType - of QueueType.Critical: - store.criticalQueue.addLast(msgs) - inc store.criticalLength - of QueueType.Normal: - store.normalQueue.addLast(msgs) - inc store.normalLength - return true - -proc popFromQueue*[T]( - store: MemoryRateLimitStore[T], - queueType: QueueType -): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = - case queueType - of QueueType.Critical: - if store.criticalQueue.len > 0: - dec store.criticalLength - return some(store.criticalQueue.popFirst()) - of QueueType.Normal: - if store.normalQueue.len > 0: - dec store.normalLength - return some(store.normalQueue.popFirst()) - return none(seq[tuple[msgId: string, msg: T]]) - -proc getQueueLength*[T]( - store: MemoryRateLimitStore[T], - queueType: QueueType -): int = - case queueType - of QueueType.Critical: - return store.criticalLength - of QueueType.Normal: - return store.normalLength diff --git a/ratelimit/store/store.nim b/ratelimit/store/store.nim deleted file mode 100644 index 0f18eb1..0000000 --- a/ratelimit/store/store.nim +++ /dev/null @@ -1,20 +0,0 @@ -import std/[times, deques, options] -import chronos - -type - BucketState* = object - budget*: int - budgetCap*: int - lastTimeFull*: Moment - - QueueType* {.pure.} = enum - Critical = "critical" - Normal = "normal" - - RateLimitStore* = - concept s - s.saveBucketState(BucketState) is Future[bool] - s.loadBucketState() is Future[Option[BucketState]] - s.addToQueue(QueueType, seq[tuple[msgId: string, msg: untyped]]) is Future[bool] - s.popFromQueue(QueueType) is Future[Option[seq[tuple[msgId: string, msg: untyped]]]] - s.getQueueLength(QueueType) is int diff --git a/tests/test_ratelimit_manager.nim b/tests/test_ratelimit_manager.nim index f6c4039..3a6d870 100644 --- a/tests/test_ratelimit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -1,12 +1,38 @@ import testutils/unittests import ../ratelimit/ratelimit_manager -import ../ratelimit/store/memory +import ../ratelimit/store import chronos +import db_connector/db_sqlite # Implement the Serializable concept for string proc toBytes*(s: string): seq[byte] = cast[seq[byte]](s) +# Helper function to create an in-memory database with the proper schema +proc createTestDatabase(): DbConn = + result = open(":memory:", "", "", "") + # Create the required tables + result.exec( + sql""" + CREATE TABLE IF NOT EXISTS kv_store ( + key TEXT PRIMARY KEY, + value BLOB + ) + """ + ) + result.exec( + sql""" + CREATE TABLE IF NOT EXISTS ratelimit_queues ( + queue_type TEXT NOT NULL, + msg_id TEXT NOT NULL, + msg_data BLOB NOT NULL, + batch_id INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (queue_type, batch_id, msg_id) + ) + """ + ) + suite "Queue RateLimitManager": setup: var sentMessages: seq[tuple[msgId: string, msg: string]] @@ -23,8 +49,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let store: MemoryRateLimitStore = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store: RateLimitStore[string] = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) let testMsg = "Hello World" @@ -42,8 +69,9 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) @@ -63,8 +91,9 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 2, @@ -80,8 +109,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, @@ -130,8 +160,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue normal on 70% capacity": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, @@ -183,8 +214,9 @@ suite "Queue RateLimitManager": asyncTest "enqueue - process queued messages": ## Given - let store = MemoryRateLimitStore.new() - let manager = await RateLimitManager[string, MemoryRateLimitStore].new( + let db = createTestDatabase() + let store = RateLimitStore[string].new(db) + let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, diff --git a/tests/test_sqlite_store.nim b/tests/test_store.nim similarity index 92% rename from tests/test_sqlite_store.nim rename to tests/test_store.nim index 90d764c..251cf24 100644 --- a/tests/test_sqlite_store.nim +++ b/tests/test_store.nim @@ -1,6 +1,5 @@ import testutils/unittests -import ../ratelimit/store/sqlite -import ../ratelimit/store/store +import ../ratelimit/store import chronos import db_connector/db_sqlite import ../chat_sdk/migration @@ -32,7 +31,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "newSqliteRateLimitStore - empty state": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) ## When let loadedState = await store.loadBucketState() @@ -42,7 +41,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "saveBucketState and loadBucketState - state persistence": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) let now = Moment.now() echo "now: ", now.epochSeconds() @@ -62,7 +61,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "queue operations - empty store": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) ## When/Then check store.getQueueLength(QueueType.Critical) == 0 @@ -76,7 +75,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "addToQueue and popFromQueue - single batch": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) let msgs = @[("msg1", "Hello"), ("msg2", "World")] ## When @@ -103,7 +102,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "addToQueue and popFromQueue - multiple batches FIFO": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) let batch1 = @[("msg1", "First")] let batch2 = @[("msg2", "Second")] let batch3 = @[("msg3", "Third")] @@ -141,7 +140,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "queue isolation - critical and normal queues are separate": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) let criticalMsgs = @[("crit1", "Critical Message")] let normalMsgs = @[("norm1", "Normal Message")] @@ -178,14 +177,14 @@ suite "SqliteRateLimitStore Tests": let msgs = @[("persist1", "Persistent Message")] block: - let store1 = newSqliteRateLimitStore[string](db) + let store1 = RateLimitStore[string].new(db) let addResult = await store1.addToQueue(QueueType.Critical, msgs) check addResult == true check store1.getQueueLength(QueueType.Critical) == 1 ## When - Create new store instance block: - let store2 = newSqliteRateLimitStore[string](db) + let store2 = RateLimitStore[string].new(db) ## Then - Queue length should be restored from database check store2.getQueueLength(QueueType.Critical) == 1 @@ -197,7 +196,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "large batch handling": ## Given - let store = newSqliteRateLimitStore[string](db) + let store = RateLimitStore[string].new(db) var largeBatch: seq[tuple[msgId: string, msg: string]] for i in 1 .. 100: From bcdb56c1ca9e7b6273dfe3dea95e0bb1819fa154 Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 4 Aug 2025 11:48:20 +0300 Subject: [PATCH 12/17] fix: tests --- ratelimit/ratelimit_manager.nim | 63 ++++++++++-------------- tests/test_ratelimit_manager.nim | 84 +++++++++----------------------- tests/test_store.nim | 8 +-- 3 files changed, 54 insertions(+), 101 deletions(-) diff --git a/ratelimit/ratelimit_manager.nim b/ratelimit/ratelimit_manager.nim index 438619f..333dc64 100644 --- a/ratelimit/ratelimit_manager.nim +++ b/ratelimit/ratelimit_manager.nim @@ -1,9 +1,10 @@ -import std/[times, deques, options] +import std/[times, options] # TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live import ./token_bucket # import waku/common/rate_limit/token_bucket import ./store import chronos +import db_connector/db_sqlite type CapacityState {.pure.} = enum @@ -30,8 +31,6 @@ type store: RateLimitStore[T] bucket: TokenBucket sender: MessageSender[T] - queueCritical: Deque[seq[MsgIdMsg[T]]] - queueNormal: Deque[seq[MsgIdMsg[T]]] sleepDuration: chronos.Duration pxQueueHandleLoop: Future[void] @@ -61,8 +60,6 @@ proc new*[T: Serializable]( current.get().lastTimeFull, ), sender: sender, - queueCritical: Deque[seq[MsgIdMsg[T]]](), - queueNormal: Deque[seq[MsgIdMsg[T]]](), sleepDuration: sleepDuration, ) @@ -90,10 +87,10 @@ proc passToSender[T: Serializable]( if not consumed: case priority of Priority.Critical: - manager.queueCritical.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped @@ -108,29 +105,39 @@ proc passToSender[T: Serializable]( proc processCriticalQueue[T: Serializable]( manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = - while manager.queueCritical.len > 0: - let msgs = manager.queueCritical.popFirst() + while manager.store.getQueueLength(QueueType.Critical) > 0: + # Peek at the next batch by getting it, but we'll handle putting it back if needed + let maybeMsgs = await manager.store.popFromQueue(QueueType.Critical) + if maybeMsgs.isNone(): + break + + let msgs = maybeMsgs.get() let capacityState = manager.getCapacityState(now, msgs.len) if capacityState == CapacityState.Normal: discard await manager.passToSender(msgs, now, Priority.Critical) elif capacityState == CapacityState.AlmostNone: discard await manager.passToSender(msgs, now, Priority.Critical) else: - # add back to critical queue - manager.queueCritical.addFirst(msgs) + # Put back to critical queue (add to front not possible, so we add to back and exit) + discard await manager.store.addToQueue(QueueType.Critical, msgs) break proc processNormalQueue[T: Serializable]( manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = - while manager.queueNormal.len > 0: - let msgs = manager.queueNormal.popFirst() + while manager.store.getQueueLength(QueueType.Normal) > 0: + # Peek at the next batch by getting it, but we'll handle putting it back if needed + let maybeMsgs = await manager.store.popFromQueue(QueueType.Normal) + if maybeMsgs.isNone(): + break + + let msgs = maybeMsgs.get() let capacityState = manager.getCapacityState(now, msgs.len) if capacityState == CapacityState.Normal: discard await manager.passToSender(msgs, now, Priority.Normal) else: - # add back to critical queue - manager.queueNormal.addFirst(msgs) + # Put back to normal queue (add to front not possible, so we add to back and exit) + discard await manager.store.addToQueue(QueueType.Normal, msgs) break proc sendOrEnqueue*[T: Serializable]( @@ -153,37 +160,21 @@ proc sendOrEnqueue*[T: Serializable]( of Priority.Critical: return await manager.passToSender(msgs, now, priority) of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped of CapacityState.None: case priority of Priority.Critical: - manager.queueCritical.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - manager.queueNormal.addLast(msgs) + discard await manager.store.addToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped -proc getEnqueued*[T: Serializable]( - manager: RateLimitManager[T] -): tuple[ - critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]] -] = - var criticalMsgs: seq[tuple[msgId: string, msg: T]] - var normalMsgs: seq[tuple[msgId: string, msg: T]] - - for batch in manager.queueCritical: - criticalMsgs.add(batch) - - for batch in manager.queueNormal: - normalMsgs.add(batch) - - return (criticalMsgs, normalMsgs) - proc queueHandleLoop*[T: Serializable]( manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = @@ -215,5 +206,5 @@ func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = if isNil(b): return "nil" return - "RateLimitManager(critical: " & $b.queueCritical.len & ", normal: " & - $b.queueNormal.len & ")" + "RateLimitManager(critical: " & $b.store.getQueueLength(QueueType.Critical) & + ", normal: " & $b.store.getQueueLength(QueueType.Normal) & ")" diff --git a/tests/test_ratelimit_manager.nim b/tests/test_ratelimit_manager.nim index 3a6d870..88a2335 100644 --- a/tests/test_ratelimit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -3,38 +3,20 @@ import ../ratelimit/ratelimit_manager import ../ratelimit/store import chronos import db_connector/db_sqlite +import ../chat_sdk/migration +import std/[os, options] # Implement the Serializable concept for string proc toBytes*(s: string): seq[byte] = cast[seq[byte]](s) -# Helper function to create an in-memory database with the proper schema -proc createTestDatabase(): DbConn = - result = open(":memory:", "", "", "") - # Create the required tables - result.exec( - sql""" - CREATE TABLE IF NOT EXISTS kv_store ( - key TEXT PRIMARY KEY, - value BLOB - ) - """ - ) - result.exec( - sql""" - CREATE TABLE IF NOT EXISTS ratelimit_queues ( - queue_type TEXT NOT NULL, - msg_id TEXT NOT NULL, - msg_data BLOB NOT NULL, - batch_id INTEGER NOT NULL, - created_at INTEGER NOT NULL, - PRIMARY KEY (queue_type, batch_id, msg_id) - ) - """ - ) +var dbName = "test_ratelimit_manager.db" suite "Queue RateLimitManager": setup: + let db = open(dbName, "", "", "") + runMigrations(db) + var sentMessages: seq[tuple[msgId: string, msg: string]] var senderCallCount: int = 0 @@ -47,9 +29,14 @@ suite "Queue RateLimitManager": sentMessages.add(msg) await sleepAsync(chronos.milliseconds(10)) + teardown: + if db != nil: + db.close() + if fileExists(dbName): + removeFile(dbName) + asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let db = createTestDatabase() let store: RateLimitStore[string] = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) @@ -69,7 +56,6 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let db = createTestDatabase() let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) @@ -91,7 +77,6 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let db = createTestDatabase() let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, @@ -109,7 +94,6 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let db = createTestDatabase() let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, @@ -152,15 +136,8 @@ suite "Queue RateLimitManager": r10 == PassedToSender r11 == Enqueued - let (critical, normal) = manager.getEnqueued() - check: - critical.len == 1 - normal.len == 0 - critical[0].msgId == "msg11" - asyncTest "enqueue - enqueue normal on 70% capacity": - ## Given - let db = createTestDatabase() + ## Given let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, @@ -204,17 +181,8 @@ suite "Queue RateLimitManager": r11 == PassedToSender r12 == Dropped - let (critical, normal) = manager.getEnqueued() - check: - critical.len == 0 - normal.len == 3 - normal[0].msgId == "msg8" - normal[1].msgId == "msg9" - normal[2].msgId == "msg10" - asyncTest "enqueue - process queued messages": ## Given - let db = createTestDatabase() let store = RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, @@ -268,24 +236,9 @@ suite "Queue RateLimitManager": r14 == PassedToSender r15 == Enqueued - var (critical, normal) = manager.getEnqueued() check: - critical.len == 1 - normal.len == 3 - normal[0].msgId == "8" - normal[1].msgId == "9" - normal[2].msgId == "10" - critical[0].msgId == "15" - - nowRef.value = now + chronos.milliseconds(250) - await sleepAsync(chronos.milliseconds(250)) - - (critical, normal) = manager.getEnqueued() - check: - critical.len == 0 - normal.len == 0 - senderCallCount == 14 - sentMessages.len == 14 + senderCallCount == 10 # 10 messages passed to sender + sentMessages.len == 10 sentMessages[0].msgId == "1" sentMessages[1].msgId == "2" sentMessages[2].msgId == "3" @@ -296,6 +249,13 @@ suite "Queue RateLimitManager": sentMessages[7].msgId == "11" sentMessages[8].msgId == "13" sentMessages[9].msgId == "14" + + nowRef.value = now + chronos.milliseconds(250) + await sleepAsync(chronos.milliseconds(250)) + + check: + senderCallCount == 14 + sentMessages.len == 14 sentMessages[10].msgId == "15" sentMessages[11].msgId == "8" sentMessages[12].msgId == "9" diff --git a/tests/test_store.nim b/tests/test_store.nim index 251cf24..ae5f009 100644 --- a/tests/test_store.nim +++ b/tests/test_store.nim @@ -5,6 +5,8 @@ import db_connector/db_sqlite import ../chat_sdk/migration import std/[options, os] +const dbName = "test_store.db" + # Implement the Serializable concept for string (for testing) proc toBytes*(s: string): seq[byte] = # Convert each character to a byte @@ -20,14 +22,14 @@ proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string = suite "SqliteRateLimitStore Tests": setup: - let db = open("test-ratelimit.db", "", "", "") + let db = open(dbName, "", "", "") runMigrations(db) teardown: if db != nil: db.close() - if fileExists("test-ratelimit.db"): - removeFile("test-ratelimit.db") + if fileExists(dbName): + removeFile(dbName) asyncTest "newSqliteRateLimitStore - empty state": ## Given From 1039d379db09a26ad074910dc5882efb4f7e1320 Mon Sep 17 00:00:00 2001 From: Pablo Lopez Date: Mon, 11 Aug 2025 14:58:04 +0300 Subject: [PATCH 13/17] Update ratelimit/store.nim Co-authored-by: NagyZoltanPeter <113987313+NagyZoltanPeter@users.noreply.github.com> --- ratelimit/store.nim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ratelimit/store.nim b/ratelimit/store.nim index 2618589..6aa0be5 100644 --- a/ratelimit/store.nim +++ b/ratelimit/store.nim @@ -108,7 +108,7 @@ proc loadBucketState*[T: Serializable]( ) ) -proc addToQueue*[T: Serializable]( +proc pushToQueue*[T: Serializable]( store: RateLimitStore[T], queueType: QueueType, msgs: seq[tuple[msgId: string, msg: T]], From b50240942f50b6a891db921fb934d5b3a5dc1754 Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 12 Aug 2025 12:09:57 +0300 Subject: [PATCH 14/17] fix: serialize using flatty --- chat_sdk.nimble | 4 +-- ratelimit/ratelimit_manager.nim | 40 +++++++++++----------- ratelimit/store.nim | 58 +++++++++----------------------- tests/test_ratelimit_manager.nim | 4 --- tests/test_store.nim | 32 ++++++------------ 5 files changed, 47 insertions(+), 91 deletions(-) diff --git a/chat_sdk.nimble b/chat_sdk.nimble index e5fbb06..8841ff3 100644 --- a/chat_sdk.nimble +++ b/chat_sdk.nimble @@ -7,9 +7,7 @@ license = "MIT" srcDir = "src" ### Dependencies -requires "nim >= 2.2.4", - "chronicles", "chronos", "db_connector", - "https://github.com/waku-org/token_bucket.git" +requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "flatty" task buildSharedLib, "Build shared library for C bindings": exec "nim c --mm:refc --app:lib --out:../library/c-bindings/libchatsdk.so chat_sdk/chat_sdk.nim" diff --git a/ratelimit/ratelimit_manager.nim b/ratelimit/ratelimit_manager.nim index 333dc64..a30b5be 100644 --- a/ratelimit/ratelimit_manager.nim +++ b/ratelimit/ratelimit_manager.nim @@ -23,18 +23,18 @@ type Normal Optional - MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T] + MsgIdMsg[T] = tuple[msgId: string, msg: T] - MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} + MessageSender*[T] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} - RateLimitManager*[T: Serializable] = ref object + RateLimitManager*[T] = ref object store: RateLimitStore[T] bucket: TokenBucket sender: MessageSender[T] sleepDuration: chronos.Duration pxQueueHandleLoop: Future[void] -proc new*[T: Serializable]( +proc new*[T]( M: type[RateLimitManager[T]], store: RateLimitStore[T], sender: MessageSender[T], @@ -63,7 +63,7 @@ proc new*[T: Serializable]( sleepDuration: sleepDuration, ) -proc getCapacityState[T: Serializable]( +proc getCapacityState[T]( manager: RateLimitManager[T], now: Moment, count: int = 1 ): CapacityState = let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now) @@ -76,7 +76,7 @@ proc getCapacityState[T: Serializable]( else: return CapacityState.Normal -proc passToSender[T: Serializable]( +proc passToSender[T]( manager: RateLimitManager[T], msgs: seq[tuple[msgId: string, msg: T]], now: Moment, @@ -87,10 +87,10 @@ proc passToSender[T: Serializable]( if not consumed: case priority of Priority.Critical: - discard await manager.store.addToQueue(QueueType.Critical, msgs) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - discard await manager.store.addToQueue(QueueType.Normal, msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped @@ -102,7 +102,7 @@ proc passToSender[T: Serializable]( await manager.sender(msgs) return SendResult.PassedToSender -proc processCriticalQueue[T: Serializable]( +proc processCriticalQueue[T]( manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = while manager.store.getQueueLength(QueueType.Critical) > 0: @@ -119,10 +119,10 @@ proc processCriticalQueue[T: Serializable]( discard await manager.passToSender(msgs, now, Priority.Critical) else: # Put back to critical queue (add to front not possible, so we add to back and exit) - discard await manager.store.addToQueue(QueueType.Critical, msgs) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) break -proc processNormalQueue[T: Serializable]( +proc processNormalQueue[T]( manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = while manager.store.getQueueLength(QueueType.Normal) > 0: @@ -137,10 +137,10 @@ proc processNormalQueue[T: Serializable]( discard await manager.passToSender(msgs, now, Priority.Normal) else: # Put back to normal queue (add to front not possible, so we add to back and exit) - discard await manager.store.addToQueue(QueueType.Normal, msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) break -proc sendOrEnqueue*[T: Serializable]( +proc sendOrEnqueue*[T]( manager: RateLimitManager[T], msgs: seq[tuple[msgId: string, msg: T]], priority: Priority, @@ -160,22 +160,22 @@ proc sendOrEnqueue*[T: Serializable]( of Priority.Critical: return await manager.passToSender(msgs, now, priority) of Priority.Normal: - discard await manager.store.addToQueue(QueueType.Normal, msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped of CapacityState.None: case priority of Priority.Critical: - discard await manager.store.addToQueue(QueueType.Critical, msgs) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - discard await manager.store.addToQueue(QueueType.Normal, msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped -proc queueHandleLoop*[T: Serializable]( +proc queueHandleLoop*[T]( manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), @@ -191,18 +191,18 @@ proc queueHandleLoop*[T: Serializable]( # configurable sleep duration for processing queued messages await sleepAsync(manager.sleepDuration) -proc start*[T: Serializable]( +proc start*[T]( manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider) -proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} = +proc stop*[T](manager: RateLimitManager[T]) {.async.} = if not isNil(manager.pxQueueHandleLoop): await manager.pxQueueHandleLoop.cancelAndWait() -func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = +func `$`*[T](b: RateLimitManager[T]): string {.inline.} = if isNil(b): return "nil" return diff --git a/ratelimit/store.nim b/ratelimit/store.nim index 6aa0be5..3c023aa 100644 --- a/ratelimit/store.nim +++ b/ratelimit/store.nim @@ -1,20 +1,10 @@ -import std/[times, strutils, json, options] +import std/[times, strutils, json, options, base64] import db_connector/db_sqlite import chronos - -# Generic deserialization function for basic types -proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string = - # Convert each byte back to a character - result = newString(bytes.len) - for i, b in bytes: - result[i] = char(b) +import flatty type - Serializable* = - concept x - x.toBytes() is seq[byte] - - RateLimitStore*[T: Serializable] = ref object + RateLimitStore*[T] = ref object db: DbConn dbPath: string criticalLength: int @@ -32,7 +22,7 @@ type const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc new*[T: Serializable](M: type[RateLimitStore[T]], db: DbConn): M = +proc new*[T](M: type[RateLimitStore[T]], db: DbConn): M = result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) # Initialize cached lengths from database @@ -66,7 +56,7 @@ proc new*[T: Serializable](M: type[RateLimitStore[T]], db: DbConn): M = return result -proc saveBucketState*[T: Serializable]( +proc saveBucketState*[T]( store: RateLimitStore[T], bucketState: BucketState ): Future[bool] {.async.} = try: @@ -88,7 +78,7 @@ proc saveBucketState*[T: Serializable]( except: return false -proc loadBucketState*[T: Serializable]( +proc loadBucketState*[T]( store: RateLimitStore[T] ): Future[Option[BucketState]] {.async.} = let jsonStr = @@ -108,7 +98,7 @@ proc loadBucketState*[T: Serializable]( ) ) -proc pushToQueue*[T: Serializable]( +proc pushToQueue*[T]( store: RateLimitStore[T], queueType: QueueType, msgs: seq[tuple[msgId: string, msg: T]], @@ -123,18 +113,13 @@ proc pushToQueue*[T: Serializable]( store.db.exec(sql"BEGIN TRANSACTION") try: for msg in msgs: - # Consistent serialization for all types - let msgBytes = msg.msg.toBytes() - # Convert seq[byte] to string for SQLite storage (each byte becomes a character) - var binaryStr = newString(msgBytes.len) - for i, b in msgBytes: - binaryStr[i] = char(b) - + let serialized = msg.msg.toFlatty() + let msgData = encode(serialized) store.db.exec( sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)", queueTypeStr, msg.msgId, - binaryStr, + msgData, batchId, now, ) @@ -153,7 +138,7 @@ proc pushToQueue*[T: Serializable]( except: return false -proc popFromQueue*[T: Serializable]( +proc popFromQueue*[T]( store: RateLimitStore[T], queueType: QueueType ): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = try: @@ -182,20 +167,11 @@ proc popFromQueue*[T: Serializable]( var msgs: seq[tuple[msgId: string, msg: T]] for row in rows: let msgIdStr = row[0] - let msgData = row[1] # SQLite returns BLOB as string where each char is a byte - # Convert string back to seq[byte] properly (each char in string is a byte) - var msgBytes: seq[byte] - for c in msgData: - msgBytes.add(byte(c)) + let msgDataB64 = row[1] - # Generic deserialization - works for any type that implements fromBytes - when T is string: - let msg = fromBytesImpl(msgBytes, T) - msgs.add((msgId: msgIdStr, msg: msg)) - else: - # For other types, they need to provide their own fromBytes in the calling context - let msg = fromBytes(msgBytes, T) - msgs.add((msgId: msgIdStr, msg: msg)) + let serialized = decode(msgDataB64) + let msg = serialized.fromFlatty(T) + msgs.add((msgId: msgIdStr, msg: msg)) # Delete the batch from database store.db.exec( @@ -214,9 +190,7 @@ proc popFromQueue*[T: Serializable]( except: return none(seq[tuple[msgId: string, msg: T]]) -proc getQueueLength*[T: Serializable]( - store: RateLimitStore[T], queueType: QueueType -): int = +proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int = case queueType of QueueType.Critical: return store.criticalLength diff --git a/tests/test_ratelimit_manager.nim b/tests/test_ratelimit_manager.nim index 88a2335..70ebe2c 100644 --- a/tests/test_ratelimit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -6,10 +6,6 @@ import db_connector/db_sqlite import ../chat_sdk/migration import std/[os, options] -# Implement the Serializable concept for string -proc toBytes*(s: string): seq[byte] = - cast[seq[byte]](s) - var dbName = "test_ratelimit_manager.db" suite "Queue RateLimitManager": diff --git a/tests/test_store.nim b/tests/test_store.nim index ae5f009..b08b7d3 100644 --- a/tests/test_store.nim +++ b/tests/test_store.nim @@ -3,23 +3,11 @@ import ../ratelimit/store import chronos import db_connector/db_sqlite import ../chat_sdk/migration -import std/[options, os] +import std/[options, os, json] +import flatty const dbName = "test_store.db" -# Implement the Serializable concept for string (for testing) -proc toBytes*(s: string): seq[byte] = - # Convert each character to a byte - result = newSeq[byte](s.len) - for i, c in s: - result[i] = byte(c) - -proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string = - # Convert each byte back to a character - result = newString(bytes.len) - for i, b in bytes: - result[i] = char(b) - suite "SqliteRateLimitStore Tests": setup: let db = open(dbName, "", "", "") @@ -81,7 +69,7 @@ suite "SqliteRateLimitStore Tests": let msgs = @[("msg1", "Hello"), ("msg2", "World")] ## When - let addResult = await store.addToQueue(QueueType.Critical, msgs) + let addResult = await store.pushToQueue(QueueType.Critical, msgs) ## Then check addResult == true @@ -110,11 +98,11 @@ suite "SqliteRateLimitStore Tests": let batch3 = @[("msg3", "Third")] ## When - Add batches - let result1 = await store.addToQueue(QueueType.Normal, batch1) + let result1 = await store.pushToQueue(QueueType.Normal, batch1) check result1 == true - let result2 = await store.addToQueue(QueueType.Normal, batch2) + let result2 = await store.pushToQueue(QueueType.Normal, batch2) check result2 == true - let result3 = await store.addToQueue(QueueType.Normal, batch3) + let result3 = await store.pushToQueue(QueueType.Normal, batch3) check result3 == true ## Then - Check lengths @@ -147,9 +135,9 @@ suite "SqliteRateLimitStore Tests": let normalMsgs = @[("norm1", "Normal Message")] ## When - let critResult = await store.addToQueue(QueueType.Critical, criticalMsgs) + let critResult = await store.pushToQueue(QueueType.Critical, criticalMsgs) check critResult == true - let normResult = await store.addToQueue(QueueType.Normal, normalMsgs) + let normResult = await store.pushToQueue(QueueType.Normal, normalMsgs) check normResult == true ## Then @@ -180,7 +168,7 @@ suite "SqliteRateLimitStore Tests": block: let store1 = RateLimitStore[string].new(db) - let addResult = await store1.addToQueue(QueueType.Critical, msgs) + let addResult = await store1.pushToQueue(QueueType.Critical, msgs) check addResult == true check store1.getQueueLength(QueueType.Critical) == 1 @@ -205,7 +193,7 @@ suite "SqliteRateLimitStore Tests": largeBatch.add(("msg" & $i, "Message " & $i)) ## When - let addResult = await store.addToQueue(QueueType.Normal, largeBatch) + let addResult = await store.pushToQueue(QueueType.Normal, largeBatch) ## Then check addResult == true From faadd4f68c014ad1b974c933496f22036e63ba71 Mon Sep 17 00:00:00 2001 From: pablo Date: Tue, 12 Aug 2025 12:29:57 +0300 Subject: [PATCH 15/17] fix: async store --- ratelimit/store.nim | 4 ++-- tests/test_ratelimit_manager.nim | 12 ++++++------ tests/test_store.nim | 18 +++++++++--------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ratelimit/store.nim b/ratelimit/store.nim index 3c023aa..262bd41 100644 --- a/ratelimit/store.nim +++ b/ratelimit/store.nim @@ -11,7 +11,7 @@ type normalLength: int nextBatchId: int - BucketState* = object + BucketState* {.pure} = object budget*: int budgetCap*: int lastTimeFull*: Moment @@ -22,7 +22,7 @@ type const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc new*[T](M: type[RateLimitStore[T]], db: DbConn): M = +proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async} = result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) # Initialize cached lengths from database diff --git a/tests/test_ratelimit_manager.nim b/tests/test_ratelimit_manager.nim index 70ebe2c..50a2000 100644 --- a/tests/test_ratelimit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -33,7 +33,7 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - immediate send when capacity available": ## Given - let store: RateLimitStore[string] = RateLimitStore[string].new(db) + let store: RateLimitStore[string] = await RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) @@ -52,7 +52,7 @@ suite "Queue RateLimitManager": asyncTest "sendOrEnqueue - multiple messages": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, capacity = 10, duration = chronos.milliseconds(100) ) @@ -73,7 +73,7 @@ suite "Queue RateLimitManager": asyncTest "start and stop - drop large batch": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, @@ -90,7 +90,7 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue critical only when exceeded": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, @@ -134,7 +134,7 @@ suite "Queue RateLimitManager": asyncTest "enqueue - enqueue normal on 70% capacity": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, @@ -179,7 +179,7 @@ suite "Queue RateLimitManager": asyncTest "enqueue - process queued messages": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let manager = await RateLimitManager[string].new( store, mockSender, diff --git a/tests/test_store.nim b/tests/test_store.nim index b08b7d3..16cfdfe 100644 --- a/tests/test_store.nim +++ b/tests/test_store.nim @@ -21,7 +21,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "newSqliteRateLimitStore - empty state": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) ## When let loadedState = await store.loadBucketState() @@ -31,7 +31,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "saveBucketState and loadBucketState - state persistence": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let now = Moment.now() echo "now: ", now.epochSeconds() @@ -51,7 +51,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "queue operations - empty store": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) ## When/Then check store.getQueueLength(QueueType.Critical) == 0 @@ -65,7 +65,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "addToQueue and popFromQueue - single batch": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let msgs = @[("msg1", "Hello"), ("msg2", "World")] ## When @@ -92,7 +92,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "addToQueue and popFromQueue - multiple batches FIFO": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let batch1 = @[("msg1", "First")] let batch2 = @[("msg2", "Second")] let batch3 = @[("msg3", "Third")] @@ -130,7 +130,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "queue isolation - critical and normal queues are separate": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) let criticalMsgs = @[("crit1", "Critical Message")] let normalMsgs = @[("norm1", "Normal Message")] @@ -167,14 +167,14 @@ suite "SqliteRateLimitStore Tests": let msgs = @[("persist1", "Persistent Message")] block: - let store1 = RateLimitStore[string].new(db) + let store1 = await RateLimitStore[string].new(db) let addResult = await store1.pushToQueue(QueueType.Critical, msgs) check addResult == true check store1.getQueueLength(QueueType.Critical) == 1 ## When - Create new store instance block: - let store2 = RateLimitStore[string].new(db) + let store2 =await RateLimitStore[string].new(db) ## Then - Queue length should be restored from database check store2.getQueueLength(QueueType.Critical) == 1 @@ -186,7 +186,7 @@ suite "SqliteRateLimitStore Tests": asyncTest "large batch handling": ## Given - let store = RateLimitStore[string].new(db) + let store = await RateLimitStore[string].new(db) var largeBatch: seq[tuple[msgId: string, msg: string]] for i in 1 .. 100: From faefaa79663de70623ea28d220a909fc624aba66 Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 18 Aug 2025 16:40:58 +0300 Subject: [PATCH 16/17] fix: pr comment --- chat_sdk/migration.nim | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/chat_sdk/migration.nim b/chat_sdk/migration.nim index 45ea614..a23f740 100644 --- a/chat_sdk/migration.nim +++ b/chat_sdk/migration.nim @@ -32,10 +32,15 @@ proc runMigrations*(db: DbConn, dir = "migrations") = else: info "Applying migration", file let sqlContent = readFile(file) - # Split by semicolon and execute each statement separately - let statements = sqlContent.split(';') - for stmt in statements: - let trimmedStmt = stmt.strip() - if trimmedStmt.len > 0: - db.exec(sql(trimmedStmt)) - markMigrationRun(db, file) + db.exec(sql"BEGIN TRANSACTION") + try: + # Split by semicolon and execute each statement separately + for stmt in sqlContent.split(';'): + let trimmedStmt = stmt.strip() + if trimmedStmt.len > 0: + db.exec(sql(trimmedStmt)) + markMigrationRun(db, file) + db.exec(sql"COMMIT") + except: + db.exec(sql"ROLLBACK") + raise newException(ValueError, "Migration failed: " & file & " - " & getCurrentExceptionMsg()) From f2d8113f70e8a009de13d812df9ba5b3c41f23fd Mon Sep 17 00:00:00 2001 From: pablo Date: Mon, 1 Sep 2025 05:55:44 +0300 Subject: [PATCH 17/17] fix: added todos --- ratelimit/store.nim | 2 ++ ratelimit/token_bucket.nim | 2 ++ 2 files changed, 4 insertions(+) diff --git a/ratelimit/store.nim b/ratelimit/store.nim index 262bd41..42fd152 100644 --- a/ratelimit/store.nim +++ b/ratelimit/store.nim @@ -22,6 +22,8 @@ type const BUCKET_STATE_KEY = "rate_limit_bucket_state" +## TODO find a way to make these procs async + proc new*[T](M: type[RateLimitStore[T]], db: DbConn): Future[M] {.async} = result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) diff --git a/ratelimit/token_bucket.nim b/ratelimit/token_bucket.nim index e2c606d..13e40a1 100644 --- a/ratelimit/token_bucket.nim +++ b/ratelimit/token_bucket.nim @@ -4,6 +4,8 @@ import chronos, std/math, std/options const BUDGET_COMPENSATION_LIMIT_PERCENT = 0.25 +## TODO! This will be remoded and replaced by https://github.com/status-im/nim-chronos/pull/582 + ## This is an extract from chronos/rate_limit.nim due to the found bug in the original implementation. ## Unfortunately that bug cannot be solved without harm the original features of TokenBucket class. ## So, this current shortcut is used to enable move ahead with nwaku rate limiter implementation.