diff --git a/chat_sdk.nimble b/chat_sdk.nimble index e5fbb06..8841ff3 100644 --- a/chat_sdk.nimble +++ b/chat_sdk.nimble @@ -7,9 +7,7 @@ license = "MIT" srcDir = "src" ### Dependencies -requires "nim >= 2.2.4", - "chronicles", "chronos", "db_connector", - "https://github.com/waku-org/token_bucket.git" +requires "nim >= 2.2.4", "chronicles", "chronos", "db_connector", "flatty" task buildSharedLib, "Build shared library for C bindings": exec "nim c --mm:refc --app:lib --out:../library/c-bindings/libchatsdk.so chat_sdk/chat_sdk.nim" diff --git a/ratelimit/ratelimit_manager.nim b/ratelimit/ratelimit_manager.nim index 333dc64..a30b5be 100644 --- a/ratelimit/ratelimit_manager.nim +++ b/ratelimit/ratelimit_manager.nim @@ -23,18 +23,18 @@ type Normal Optional - MsgIdMsg[T: Serializable] = tuple[msgId: string, msg: T] + MsgIdMsg[T] = tuple[msgId: string, msg: T] - MessageSender*[T: Serializable] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} + MessageSender*[T] = proc(msgs: seq[MsgIdMsg[T]]) {.async.} - RateLimitManager*[T: Serializable] = ref object + RateLimitManager*[T] = ref object store: RateLimitStore[T] bucket: TokenBucket sender: MessageSender[T] sleepDuration: chronos.Duration pxQueueHandleLoop: Future[void] -proc new*[T: Serializable]( +proc new*[T]( M: type[RateLimitManager[T]], store: RateLimitStore[T], sender: MessageSender[T], @@ -63,7 +63,7 @@ proc new*[T: Serializable]( sleepDuration: sleepDuration, ) -proc getCapacityState[T: Serializable]( +proc getCapacityState[T]( manager: RateLimitManager[T], now: Moment, count: int = 1 ): CapacityState = let (budget, budgetCap, _) = manager.bucket.getAvailableCapacity(now) @@ -76,7 +76,7 @@ proc getCapacityState[T: Serializable]( else: return CapacityState.Normal -proc passToSender[T: Serializable]( +proc passToSender[T]( manager: RateLimitManager[T], msgs: seq[tuple[msgId: string, msg: T]], now: Moment, @@ -87,10 +87,10 @@ proc passToSender[T: Serializable]( if not consumed: case priority of Priority.Critical: - discard await manager.store.addToQueue(QueueType.Critical, msgs) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - discard await manager.store.addToQueue(QueueType.Normal, msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped @@ -102,7 +102,7 @@ proc passToSender[T: Serializable]( await manager.sender(msgs) return SendResult.PassedToSender -proc processCriticalQueue[T: Serializable]( +proc processCriticalQueue[T]( manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = while manager.store.getQueueLength(QueueType.Critical) > 0: @@ -119,10 +119,10 @@ proc processCriticalQueue[T: Serializable]( discard await manager.passToSender(msgs, now, Priority.Critical) else: # Put back to critical queue (add to front not possible, so we add to back and exit) - discard await manager.store.addToQueue(QueueType.Critical, msgs) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) break -proc processNormalQueue[T: Serializable]( +proc processNormalQueue[T]( manager: RateLimitManager[T], now: Moment ): Future[void] {.async.} = while manager.store.getQueueLength(QueueType.Normal) > 0: @@ -137,10 +137,10 @@ proc processNormalQueue[T: Serializable]( discard await manager.passToSender(msgs, now, Priority.Normal) else: # Put back to normal queue (add to front not possible, so we add to back and exit) - discard await manager.store.addToQueue(QueueType.Normal, msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) break -proc sendOrEnqueue*[T: Serializable]( +proc sendOrEnqueue*[T]( manager: RateLimitManager[T], msgs: seq[tuple[msgId: string, msg: T]], priority: Priority, @@ -160,22 +160,22 @@ proc sendOrEnqueue*[T: Serializable]( of Priority.Critical: return await manager.passToSender(msgs, now, priority) of Priority.Normal: - discard await manager.store.addToQueue(QueueType.Normal, msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped of CapacityState.None: case priority of Priority.Critical: - discard await manager.store.addToQueue(QueueType.Critical, msgs) + discard await manager.store.pushToQueue(QueueType.Critical, msgs) return SendResult.Enqueued of Priority.Normal: - discard await manager.store.addToQueue(QueueType.Normal, msgs) + discard await manager.store.pushToQueue(QueueType.Normal, msgs) return SendResult.Enqueued of Priority.Optional: return SendResult.Dropped -proc queueHandleLoop*[T: Serializable]( +proc queueHandleLoop*[T]( manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), @@ -191,18 +191,18 @@ proc queueHandleLoop*[T: Serializable]( # configurable sleep duration for processing queued messages await sleepAsync(manager.sleepDuration) -proc start*[T: Serializable]( +proc start*[T]( manager: RateLimitManager[T], nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} = Moment.now(), ) {.async.} = manager.pxQueueHandleLoop = queueHandleLoop(manager, nowProvider) -proc stop*[T: Serializable](manager: RateLimitManager[T]) {.async.} = +proc stop*[T](manager: RateLimitManager[T]) {.async.} = if not isNil(manager.pxQueueHandleLoop): await manager.pxQueueHandleLoop.cancelAndWait() -func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} = +func `$`*[T](b: RateLimitManager[T]): string {.inline.} = if isNil(b): return "nil" return diff --git a/ratelimit/store.nim b/ratelimit/store.nim index 6aa0be5..3c023aa 100644 --- a/ratelimit/store.nim +++ b/ratelimit/store.nim @@ -1,20 +1,10 @@ -import std/[times, strutils, json, options] +import std/[times, strutils, json, options, base64] import db_connector/db_sqlite import chronos - -# Generic deserialization function for basic types -proc fromBytesImpl(bytes: seq[byte], T: typedesc[string]): string = - # Convert each byte back to a character - result = newString(bytes.len) - for i, b in bytes: - result[i] = char(b) +import flatty type - Serializable* = - concept x - x.toBytes() is seq[byte] - - RateLimitStore*[T: Serializable] = ref object + RateLimitStore*[T] = ref object db: DbConn dbPath: string criticalLength: int @@ -32,7 +22,7 @@ type const BUCKET_STATE_KEY = "rate_limit_bucket_state" -proc new*[T: Serializable](M: type[RateLimitStore[T]], db: DbConn): M = +proc new*[T](M: type[RateLimitStore[T]], db: DbConn): M = result = M(db: db, criticalLength: 0, normalLength: 0, nextBatchId: 1) # Initialize cached lengths from database @@ -66,7 +56,7 @@ proc new*[T: Serializable](M: type[RateLimitStore[T]], db: DbConn): M = return result -proc saveBucketState*[T: Serializable]( +proc saveBucketState*[T]( store: RateLimitStore[T], bucketState: BucketState ): Future[bool] {.async.} = try: @@ -88,7 +78,7 @@ proc saveBucketState*[T: Serializable]( except: return false -proc loadBucketState*[T: Serializable]( +proc loadBucketState*[T]( store: RateLimitStore[T] ): Future[Option[BucketState]] {.async.} = let jsonStr = @@ -108,7 +98,7 @@ proc loadBucketState*[T: Serializable]( ) ) -proc pushToQueue*[T: Serializable]( +proc pushToQueue*[T]( store: RateLimitStore[T], queueType: QueueType, msgs: seq[tuple[msgId: string, msg: T]], @@ -123,18 +113,13 @@ proc pushToQueue*[T: Serializable]( store.db.exec(sql"BEGIN TRANSACTION") try: for msg in msgs: - # Consistent serialization for all types - let msgBytes = msg.msg.toBytes() - # Convert seq[byte] to string for SQLite storage (each byte becomes a character) - var binaryStr = newString(msgBytes.len) - for i, b in msgBytes: - binaryStr[i] = char(b) - + let serialized = msg.msg.toFlatty() + let msgData = encode(serialized) store.db.exec( sql"INSERT INTO ratelimit_queues (queue_type, msg_id, msg_data, batch_id, created_at) VALUES (?, ?, ?, ?, ?)", queueTypeStr, msg.msgId, - binaryStr, + msgData, batchId, now, ) @@ -153,7 +138,7 @@ proc pushToQueue*[T: Serializable]( except: return false -proc popFromQueue*[T: Serializable]( +proc popFromQueue*[T]( store: RateLimitStore[T], queueType: QueueType ): Future[Option[seq[tuple[msgId: string, msg: T]]]] {.async.} = try: @@ -182,20 +167,11 @@ proc popFromQueue*[T: Serializable]( var msgs: seq[tuple[msgId: string, msg: T]] for row in rows: let msgIdStr = row[0] - let msgData = row[1] # SQLite returns BLOB as string where each char is a byte - # Convert string back to seq[byte] properly (each char in string is a byte) - var msgBytes: seq[byte] - for c in msgData: - msgBytes.add(byte(c)) + let msgDataB64 = row[1] - # Generic deserialization - works for any type that implements fromBytes - when T is string: - let msg = fromBytesImpl(msgBytes, T) - msgs.add((msgId: msgIdStr, msg: msg)) - else: - # For other types, they need to provide their own fromBytes in the calling context - let msg = fromBytes(msgBytes, T) - msgs.add((msgId: msgIdStr, msg: msg)) + let serialized = decode(msgDataB64) + let msg = serialized.fromFlatty(T) + msgs.add((msgId: msgIdStr, msg: msg)) # Delete the batch from database store.db.exec( @@ -214,9 +190,7 @@ proc popFromQueue*[T: Serializable]( except: return none(seq[tuple[msgId: string, msg: T]]) -proc getQueueLength*[T: Serializable]( - store: RateLimitStore[T], queueType: QueueType -): int = +proc getQueueLength*[T](store: RateLimitStore[T], queueType: QueueType): int = case queueType of QueueType.Critical: return store.criticalLength diff --git a/tests/test_ratelimit_manager.nim b/tests/test_ratelimit_manager.nim index 88a2335..70ebe2c 100644 --- a/tests/test_ratelimit_manager.nim +++ b/tests/test_ratelimit_manager.nim @@ -6,10 +6,6 @@ import db_connector/db_sqlite import ../chat_sdk/migration import std/[os, options] -# Implement the Serializable concept for string -proc toBytes*(s: string): seq[byte] = - cast[seq[byte]](s) - var dbName = "test_ratelimit_manager.db" suite "Queue RateLimitManager": diff --git a/tests/test_store.nim b/tests/test_store.nim index ae5f009..b08b7d3 100644 --- a/tests/test_store.nim +++ b/tests/test_store.nim @@ -3,23 +3,11 @@ import ../ratelimit/store import chronos import db_connector/db_sqlite import ../chat_sdk/migration -import std/[options, os] +import std/[options, os, json] +import flatty const dbName = "test_store.db" -# Implement the Serializable concept for string (for testing) -proc toBytes*(s: string): seq[byte] = - # Convert each character to a byte - result = newSeq[byte](s.len) - for i, c in s: - result[i] = byte(c) - -proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string = - # Convert each byte back to a character - result = newString(bytes.len) - for i, b in bytes: - result[i] = char(b) - suite "SqliteRateLimitStore Tests": setup: let db = open(dbName, "", "", "") @@ -81,7 +69,7 @@ suite "SqliteRateLimitStore Tests": let msgs = @[("msg1", "Hello"), ("msg2", "World")] ## When - let addResult = await store.addToQueue(QueueType.Critical, msgs) + let addResult = await store.pushToQueue(QueueType.Critical, msgs) ## Then check addResult == true @@ -110,11 +98,11 @@ suite "SqliteRateLimitStore Tests": let batch3 = @[("msg3", "Third")] ## When - Add batches - let result1 = await store.addToQueue(QueueType.Normal, batch1) + let result1 = await store.pushToQueue(QueueType.Normal, batch1) check result1 == true - let result2 = await store.addToQueue(QueueType.Normal, batch2) + let result2 = await store.pushToQueue(QueueType.Normal, batch2) check result2 == true - let result3 = await store.addToQueue(QueueType.Normal, batch3) + let result3 = await store.pushToQueue(QueueType.Normal, batch3) check result3 == true ## Then - Check lengths @@ -147,9 +135,9 @@ suite "SqliteRateLimitStore Tests": let normalMsgs = @[("norm1", "Normal Message")] ## When - let critResult = await store.addToQueue(QueueType.Critical, criticalMsgs) + let critResult = await store.pushToQueue(QueueType.Critical, criticalMsgs) check critResult == true - let normResult = await store.addToQueue(QueueType.Normal, normalMsgs) + let normResult = await store.pushToQueue(QueueType.Normal, normalMsgs) check normResult == true ## Then @@ -180,7 +168,7 @@ suite "SqliteRateLimitStore Tests": block: let store1 = RateLimitStore[string].new(db) - let addResult = await store1.addToQueue(QueueType.Critical, msgs) + let addResult = await store1.pushToQueue(QueueType.Critical, msgs) check addResult == true check store1.getQueueLength(QueueType.Critical) == 1 @@ -205,7 +193,7 @@ suite "SqliteRateLimitStore Tests": largeBatch.add(("msg" & $i, "Message " & $i)) ## When - let addResult = await store.addToQueue(QueueType.Normal, largeBatch) + let addResult = await store.pushToQueue(QueueType.Normal, largeBatch) ## Then check addResult == true