refactor: drop the store type parameter from RateLimitManager and replace the RateLimitStore concept with a concrete SQLite-backed RateLimitStore[T]

pablo 2025-08-04 11:31:44 +03:00
parent 2c47183fb0
commit dd0082041c
GPG Key ID: 78F35FCC60FDC63A
6 changed files with 117 additions and 169 deletions

View File

@@ -2,7 +2,7 @@ import std/[times, deques, options]
# TODO: move to waku's, chronos' or a standalone token_bucket lib once it is decided where this will live
import ./token_bucket
# import waku/common/rate_limit/token_bucket
import ./store/store
import ./store
import chronos
type
@@ -22,16 +22,12 @@ type
Normal
Optional
Serializable* =
concept x
x.toBytes() is seq[byte]
MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T]
MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.}
RateLimitManager*[T: Serializable, S: RateLimitStore] = ref object
store: S
RateLimitManager*[T: Serializable] = ref object
store: RateLimitStore[T]
bucket: TokenBucket
sender: MessageSender[T]
queueCritical: Deque[seq[MsgIdMsg[T]]]
@@ -39,14 +35,14 @@ type
sleepDuration: chronos.Duration
pxQueueHandleLoop: Future[void]
proc new*[T: Serializable, S: RateLimitStore](
M: type[RateLimitManager[T, S]],
store: S,
proc new*[T: Serializable](
M: type[RateLimitManager[T]],
store: RateLimitStore[T],
sender: MessageSender[T],
capacity: int = 100,
duration: chronos.Duration = chronos.minutes(10),
sleepDuration: chronos.Duration = chronos.milliseconds(1000),
): Future[RateLimitManager[T, S]] {.async.} =
): Future[RateLimitManager[T]] {.async.} =
var current = await store.loadBucketState()
if current.isNone():
# initialize bucket state with full capacity
@@ -55,7 +51,7 @@ proc new*[T: Serializable, S: RateLimitStore](
)
discard await store.saveBucketState(current.get())
return RateLimitManager[T, S](
return RateLimitManager[T](
store: store,
bucket: TokenBucket.new(
current.get().budgetCap,
@@ -70,8 +66,8 @@ proc new*[T: Serializable, S: RateLimitStore](
sleepDuration: sleepDuration,
)
proc getCapacityState[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S], now: Moment, count: int = 1
proc getCapacityState[T: Serializable](
manager: RateLimitManager[T], now: Moment, count: int = 1
): CapacityState =
let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now)
let countAfter = budget - count
@@ -83,8 +79,8 @@ proc getCapacityState[T: Serializable, S: RateLimitStore](
else:
return CapacityState.Normal
proc passToSender[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S],
proc passToSender[T: Serializable](
manager: RateLimitManager[T],
msgs: seq[tuple[msgId: string, msg: T]],
now: Moment,
priority: Priority,
@@ -109,8 +105,8 @@ proc passToSender[T: Serializable, S: RateLimitStore](
await manager.sender(msgs)
return SendResult.PassedToSender
proc processCriticalQueue[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S], now: Moment
proc processCriticalQueue[T: Serializable](
manager: RateLimitManager[T], now: Moment
): Future[void] {.async.} =
while manager.queueCritical.len > 0:
let msgs = manager.queueCritical.popFirst()
@@ -124,8 +120,8 @@ proc processCriticalQueue[T: Serializable, S: RateLimitStore](
manager.queueCritical.addFirst(msgs)
break
proc processNormalQueue[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S], now: Moment
proc processNormalQueue[T: Serializable](
manager: RateLimitManager[T], now: Moment
): Future[void] {.async.} =
while manager.queueNormal.len > 0:
let msgs = manager.queueNormal.popFirst()
@@ -137,8 +133,8 @@ proc processNormalQueue[T: Serializable, S: RateLimitStore](
manager.queueNormal.addFirst(msgs)
break
proc sendOrEnqueue*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S],
proc sendOrEnqueue*[T: Serializable](
manager: RateLimitManager[T],
msgs: seq[tuple[msgId: string, msg: T]],
priority: Priority,
now: Moment = Moment.now(),
@@ -172,8 +168,8 @@ proc sendOrEnqueue*[T: Serializable, S: RateLimitStore](
of Priority.Optional:
return SendResult.Dropped
proc getEnqueued*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S]
proc getEnqueued*[T: Serializable](
manager: RateLimitManager[T]
): tuple[
critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]]
] =
@@ -188,8 +184,8 @@ proc getEnqueued*[T: Serializable, S: RateLimitStore](
return (criticalMsgs, normalMsgs)
proc queueHandleLoop*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S],
proc queueHandleLoop*[T: Serializable](
manager: RateLimitManager[T],
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
Moment.now(),
) {.async.} =
@@ -204,22 +200,18 @@ proc queueHandleLoop*[T: Serializable, S: RateLimitStore](
# configurable sleep duration for processing queued messages
await sleepAsync(manager.sleepDuration)
proc start*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S],
proc start*[T: Serializable](
manager: RateLimitManager[T],
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
Moment.now(),
) {.async.} =
manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider)
proc stop*[T: Serializable, S: RateLimitStore](
manager: RateLimitManager[T, S]
) {.async.} =
proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} =
if not isNil(manager.pxQueueHandleLoop):
await manager.pxQueueHandleLoop.cancelAndWait()
func `$`*[T: Serializable, S: RateLimitStore](
b: RateLimitManager[T, S]
): string {.inline.} =
func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} =
if isNil(b):
return "nil"
return
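
A minimal usage sketch of the refactored API, for illustration only. ChatNote, noteSender and the import paths are assumptions (the tests below use string messages and relative imports), the kv_store and ratelimit_queues tables are assumed to already exist (see createTestDatabase further down), and Priority.Critical is inferred from the critical-queue handling above; the rest follows the signatures in this diff. The manager now takes a single type parameter plus a concrete RateLimitStore[T] instead of a second, concept-constrained store parameter.

import chronos
import db_connector/db_sqlite
import ../ratelimit/ratelimit_manager
import ../ratelimit/store

type ChatNote = object # hypothetical message type
  body: string

# Satisfies the Serializable concept: any type exposing toBytes(): seq[byte] qualifies.
proc toBytes*(n: ChatNote): seq[byte] =
  cast[seq[byte]](n.body)

# Hypothetical sender; the manager hands it batches of (msgId, msg) tuples.
proc noteSender(msgs: seq[tuple[msgId: string, msg: ChatNote]]) {.async.} =
  discard # deliver the batch here

proc example(db: DbConn) {.async.} =
  # One type parameter instead of two: the store is the concrete RateLimitStore[T].
  let store = RateLimitStore[ChatNote].new(db)
  let manager = await RateLimitManager[ChatNote].new(
    store, noteSender, capacity = 10, duration = chronos.minutes(1)
  )
  # Priority.Critical is assumed from the critical-queue handling in the diff above.
  discard await manager.sendOrEnqueue(@[("id-1", ChatNote(body: "hi"))], Priority.Critical)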

View File

@@ -1,7 +1,6 @@
import std/[times, strutils, json, options]
import ./store
import chronos
import db_connector/db_sqlite
import chronos
# Generic deserialization function for basic types
proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string =
@@ -10,19 +9,31 @@ proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string =
for i, b in bytes:
result[i] = char(b)
# SQLite Implementation
type SqliteRateLimitStore*[T] = ref object
db: DbConn
dbPath: string
criticalLength: int
normalLength: int
nextBatchId: int
type
Serializable* =
concept x
x.toBytes() is seq[byte]
RateLimitStore*[T: Serializable] = ref object
db: DbConn
dbPath: string
criticalLength: int
normalLength: int
nextBatchId: int
BucketState* = object
budget*: int
budgetCap*: int
lastTimeFull*: Moment
QueueType* {.pure.} = enum
Critical = "critical"
Normal = "normal"
const BUCKET_STATE_KEY = "rate_limit_bucket_state"
proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] =
result =
SqliteRateLimitStore[T](db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
proc new*[T: Serializable](M: type[RateLimitStore[T]], db: DbConn): M =
result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1)
# Initialize cached lengths from database
let criticalCount = db.getValue(
@@ -53,8 +64,10 @@ proc newSqliteRateLimitStore*[T](db: DbConn): SqliteRateLimitStore[T] =
else:
parseInt(maxBatch) + 1
proc saveBucketState*[T](
store: SqliteRateLimitStore[T], bucketState: BucketState
return result
proc saveBucketState*[T: Serializable](
store: RateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
try:
# Convert Moment to Unix seconds for storage
@@ -75,8 +88,8 @@ proc saveBucketState*[T](
except:
return false
proc loadBucketState*[T](
store: SqliteRateLimitStore[T]
proc loadBucketState*[T: Serializable](
store: RateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
let jsonStr =
store.db.getValue(sql"SELECT value FROM kv_store WHERE key = ?", BUCKET_STATE_KEY)
@@ -95,8 +108,8 @@ proc loadBucketState*[T](
)
)
proc addToQueue*[T](
store: SqliteRateLimitStore[T],
proc addToQueue*[T: Serializable](
store: RateLimitStore[T],
queueType: QueueType,
msgs: seq[tuple[msgId: string, msg: T]],
): Future[bool] {.async.} =
@@ -140,8 +153,8 @@ proc addToQueue*[T](
except:
return false
proc popFromQueue*[T](
store: SqliteRateLimitStore[T], queueType: QueueType
proc popFromQueue*[T: Serializable](
store: RateLimitStore[T], queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
try:
let queueTypeStr = $queueType
@@ -201,7 +214,9 @@ proc popFromQueue*[T](
except:
return none(seq[tuple[msgId: string, msg: T]])
proc getQueueLength*[T](store: SqliteRateLimitStore[T], queueType: QueueType): int =
proc getQueueLength*[T: Serializable](
store: RateLimitStore[T], queueType: QueueType
): int =
case queueType
of QueueType.Critical:
return store.criticalLength
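
Both queues and the token-bucket state now persist through this single SQLite-backed store. A minimal round-trip sketch, for illustration: the toBytes helper for string mirrors the one the tests define below, the import paths are assumptions, and the database is expected to already contain the kv_store and ratelimit_queues tables.

import std/options
import chronos
import db_connector/db_sqlite
import ../ratelimit/store

# string satisfies Serializable via the same helper the tests define below.
proc toBytes*(s: string): seq[byte] =
  cast[seq[byte]](s)

proc roundTrip(db: DbConn) {.async.} =
  let store = RateLimitStore[string].new(db)
  let saved = await store.saveBucketState(
    BucketState(budget: 5, budgetCap: 10, lastTimeFull: Moment.now())
  )
  doAssert saved
  # lastTimeFull is persisted as Unix seconds, so sub-second precision is lost on reload.
  let loaded = await store.loadBucketState()
  doAssert loaded.isSome() and loaded.get().budgetCap == 10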

View File

@@ -1,70 +0,0 @@
import std/[times, options, deques, tables]
import ./store
import chronos
# Memory Implementation
type MemoryRateLimitStore*[T] = ref object
bucketState: Option[BucketState]
criticalQueue: Deque[seq[tuple[msgId: string, msg: T]]]
normalQueue: Deque[seq[tuple[msgId: string, msg: T]]]
criticalLength: int
normalLength: int
proc new*[T](M: type[MemoryRateLimitStore[T]]): M =
return M(
bucketState: none(BucketState),
criticalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](),
normalQueue: initDeque[seq[tuple[msgId: string, msg: T]]](),
criticalLength: 0,
normalLength: 0
)
proc saveBucketState*[T](
store: MemoryRateLimitStore[T], bucketState: BucketState
): Future[bool] {.async.} =
store.bucketState = some(bucketState)
return true
proc loadBucketState*[T](
store: MemoryRateLimitStore[T]
): Future[Option[BucketState]] {.async.} =
return store.bucketState
proc addToQueue*[T](
store: MemoryRateLimitStore[T],
queueType: QueueType,
msgs: seq[tuple[msgId: string, msg: T]]
): Future[bool] {.async.} =
case queueType
of QueueType.Critical:
store.criticalQueue.addLast(msgs)
inc store.criticalLength
of QueueType.Normal:
store.normalQueue.addLast(msgs)
inc store.normalLength
return true
proc popFromQueue*[T](
store: MemoryRateLimitStore[T],
queueType: QueueType
): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} =
case queueType
of QueueType.Critical:
if store.criticalQueue.len > 0:
dec store.criticalLength
return some(store.criticalQueue.popFirst())
of QueueType.Normal:
if store.normalQueue.len > 0:
dec store.normalLength
return some(store.normalQueue.popFirst())
return none(seq[tuple[msgId: string, msg: T]])
proc getQueueLength*[T](
store: MemoryRateLimitStore[T],
queueType: QueueType
): int =
case queueType
of QueueType.Critical:
return store.criticalLength
of QueueType.Normal:
return store.normalLength
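
With the in-memory store deleted, the updated tests below cover the same ground by pointing the SQLite-backed RateLimitStore[T] at an in-memory database instead. A minimal sketch of that substitution; the schema still has to be created first, as createTestDatabase does further down.

import db_connector/db_sqlite

# An in-memory SQLite database takes over the role of the removed MemoryRateLimitStore;
# create the kv_store and ratelimit_queues tables before constructing RateLimitStore[T] on it.
let db = open(":memory:", "", "", "")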

View File

@@ -1,20 +0,0 @@
import std/[times, deques, options]
import chronos
type
BucketState* = object
budget*: int
budgetCap*: int
lastTimeFull*: Moment
QueueType* {.pure.} = enum
Critical = "critical"
Normal = "normal"
RateLimitStore* =
concept s
s.saveBucketState(BucketState) is Future[bool]
s.loadBucketState() is Future[Option[BucketState]]
s.addToQueue(QueueType, seq[tuple[msgId: string, msg: untyped]]) is Future[bool]
s.popFromQueue(QueueType) is Future[Option[seq[tuple[msgId: string, msg: untyped]]]]
s.getQueueLength(QueueType) is int

View File

@@ -1,12 +1,38 @@
import testutils/unittests
import ../ratelimit/ratelimit_manager
import ../ratelimit/store/memory
import ../ratelimit/store
import chronos
import db_connector/db_sqlite
# Implement the Serializable concept for string
proc toBytes*(s: string): seq[byte] =
cast[seq[byte]](s)
# Helper function to create an in-memory database with the proper schema
proc createTestDatabase(): DbConn =
result = open(":memory:", "", "", "")
# Create the required tables
result.exec(
sql"""
CREATE TABLE IF NOT EXISTS kv_store (
key TEXT PRIMARY KEY,
value BLOB
)
"""
)
result.exec(
sql"""
CREATE TABLE IF NOT EXISTS ratelimit_queues (
queue_type TEXT NOT NULL,
msg_id TEXT NOT NULL,
msg_data BLOB NOT NULL,
batch_id INTEGER NOT NULL,
created_at INTEGER NOT NULL,
PRIMARY KEY (queue_type, batch_id, msg_id)
)
"""
)
suite "Queue RateLimitManager":
setup:
var sentMessages: seq[tuple[msgId: string, msg: string]]
@@ -23,8 +49,9 @@ suite "Queue RateLimitManager":
asyncTest "sendOrEnqueue - immediate send when capacity available":
## Given
let store: MemoryRateLimitStore = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let db = createTestDatabase()
let store: RateLimitStore[string] = RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
)
let testMsg = "Hello World"
@@ -42,8 +69,9 @@ suite "Queue RateLimitManager":
asyncTest "sendOrEnqueue - multiple messages":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let db = createTestDatabase()
let store = RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
)
@@ -63,8 +91,9 @@ suite "Queue RateLimitManager":
asyncTest "start and stop - drop large batch":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let db = createTestDatabase()
let store = RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 2,
@@ -80,8 +109,9 @@ suite "Queue RateLimitManager":
asyncTest "enqueue - enqueue critical only when exceeded":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let db = createTestDatabase()
let store = RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
@@ -130,8 +160,9 @@ suite "Queue RateLimitManager":
asyncTest "enqueue - enqueue normal on 70% capacity":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let db = createTestDatabase()
let store = RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,
@@ -183,8 +214,9 @@ suite "Queue RateLimitManager":
asyncTest "enqueue - process queued messages":
## Given
let store = MemoryRateLimitStore.new()
let manager = await RateLimitManager[string, MemoryRateLimitStore].new(
let db = createTestDatabase()
let store = RateLimitStore[string].new(db)
let manager = await RateLimitManager[string].new(
store,
mockSender,
capacity = 10,

View File

@@ -1,6 +1,5 @@
import testutils/unittests
import ../ratelimit/store/sqlite
import ../ratelimit/store/store
import ../ratelimit/store
import chronos
import db_connector/db_sqlite
import ../chat_sdk/migration
@@ -32,7 +31,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "newSqliteRateLimitStore - empty state":
## Given
let store = newSqliteRateLimitStore[string](db)
let store = RateLimitStore[string].new(db)
## When
let loadedState = await store.loadBucketState()
@@ -42,7 +41,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "saveBucketState and loadBucketState - state persistence":
## Given
let store = newSqliteRateLimitStore[string](db)
let store = RateLimitStore[string].new(db)
let now = Moment.now()
echo "now: ", now.epochSeconds()
@@ -62,7 +61,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "queue operations - empty store":
## Given
let store = newSqliteRateLimitStore[string](db)
let store = RateLimitStore[string].new(db)
## When/Then
check store.getQueueLength(QueueType.Critical) == 0
@@ -76,7 +75,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "addToQueue and popFromQueue - single batch":
## Given
let store = newSqliteRateLimitStore[string](db)
let store = RateLimitStore[string].new(db)
let msgs = @[("msg1", "Hello"), ("msg2", "World")]
## When
@@ -103,7 +102,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "addToQueue and popFromQueue - multiple batches FIFO":
## Given
let store = newSqliteRateLimitStore[string](db)
let store = RateLimitStore[string].new(db)
let batch1 = @[("msg1", "First")]
let batch2 = @[("msg2", "Second")]
let batch3 = @[("msg3", "Third")]
@@ -141,7 +140,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "queue isolation - critical and normal queues are separate":
## Given
let store = newSqliteRateLimitStore[string](db)
let store = RateLimitStore[string].new(db)
let criticalMsgs = @[("crit1", "Critical Message")]
let normalMsgs = @[("norm1", "Normal Message")]
@@ -178,14 +177,14 @@ suite "SqliteRateLimitStore Tests":
let msgs = @[("persist1", "Persistent Message")]
block:
let store1 = newSqliteRateLimitStore[string](db)
let store1 = RateLimitStore[string].new(db)
let addResult = await store1.addToQueue(QueueType.Critical, msgs)
check addResult == true
check store1.getQueueLength(QueueType.Critical) == 1
## When - Create new store instance
block:
let store2 = newSqliteRateLimitStore[string](db)
let store2 = RateLimitStore[string].new(db)
## Then - Queue length should be restored from database
check store2.getQueueLength(QueueType.Critical) == 1
@@ -197,7 +196,7 @@ suite "SqliteRateLimitStore Tests":
asyncTest "large batch handling":
## Given
let store = newSqliteRateLimitStore[string](db)
let store = RateLimitStore[string].new(db)
var largeBatch: seq[tuple[msgId: string, msg: string]]
for i in 1 .. 100: