mirror of
https://github.com/logos-messaging/nim-chat-sdk.git
synced 2026-05-21 01:19:33 +00:00
fix: tests
This commit is contained in:
parent
dd0082041c
commit
bcdb56c1ca
@ -1,9 +1,10 @@
|
|||||||
import std/[times, deques, options]
|
import std/[times, options]
|
||||||
# TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live
|
# TODO: move to waku's, chronos' or a lib tocken_bucket once decided where this will live
|
||||||
import ./token_bucket
|
import ./token_bucket
|
||||||
# import waku/common/rate_limit/token_bucket
|
# import waku/common/rate_limit/token_bucket
|
||||||
import ./store
|
import ./store
|
||||||
import chronos
|
import chronos
|
||||||
|
import db_connector/db_sqlite
|
||||||
|
|
||||||
type
|
type
|
||||||
CapacityState {.pure.} = enum
|
CapacityState {.pure.} = enum
|
||||||
@ -30,8 +31,6 @@ type
|
|||||||
store: RateLimitStore[T]
|
store: RateLimitStore[T]
|
||||||
bucket: TokenBucket
|
bucket: TokenBucket
|
||||||
sender: MessageSender[T]
|
sender: MessageSender[T]
|
||||||
queueCritical: Deque[seq[MsgIdMsg[T]]]
|
|
||||||
queueNormal: Deque[seq[MsgIdMsg[T]]]
|
|
||||||
sleepDuration: chronos.Duration
|
sleepDuration: chronos.Duration
|
||||||
pxQueueHandleLoop: Future[void]
|
pxQueueHandleLoop: Future[void]
|
||||||
|
|
||||||
@ -61,8 +60,6 @@ proc new*[T: Serializable](
|
|||||||
current.get().lastTimeFull,
|
current.get().lastTimeFull,
|
||||||
),
|
),
|
||||||
sender: sender,
|
sender: sender,
|
||||||
queueCritical: Deque[seq[MsgIdMsg[T]]](),
|
|
||||||
queueNormal: Deque[seq[MsgIdMsg[T]]](),
|
|
||||||
sleepDuration: sleepDuration,
|
sleepDuration: sleepDuration,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,10 +87,10 @@ proc passToSender[T: Serializable](
|
|||||||
if not consumed:
|
if not consumed:
|
||||||
case priority
|
case priority
|
||||||
of Priority.Critical:
|
of Priority.Critical:
|
||||||
manager.queueCritical.addLast(msgs)
|
discard await manager.store.addToQueue(QueueType.Critical, msgs)
|
||||||
return SendResult.Enqueued
|
return SendResult.Enqueued
|
||||||
of Priority.Normal:
|
of Priority.Normal:
|
||||||
manager.queueNormal.addLast(msgs)
|
discard await manager.store.addToQueue(QueueType.Normal, msgs)
|
||||||
return SendResult.Enqueued
|
return SendResult.Enqueued
|
||||||
of Priority.Optional:
|
of Priority.Optional:
|
||||||
return SendResult.Dropped
|
return SendResult.Dropped
|
||||||
@ -108,29 +105,39 @@ proc passToSender[T: Serializable](
|
|||||||
proc processCriticalQueue[T: Serializable](
|
proc processCriticalQueue[T: Serializable](
|
||||||
manager: RateLimitManager[T], now: Moment
|
manager: RateLimitManager[T], now: Moment
|
||||||
): Future[void] {.async.} =
|
): Future[void] {.async.} =
|
||||||
while manager.queueCritical.len > 0:
|
while manager.store.getQueueLength(QueueType.Critical) > 0:
|
||||||
let msgs = manager.queueCritical.popFirst()
|
# Peek at the next batch by getting it, but we'll handle putting it back if needed
|
||||||
|
let maybeMsgs = await manager.store.popFromQueue(QueueType.Critical)
|
||||||
|
if maybeMsgs.isNone():
|
||||||
|
break
|
||||||
|
|
||||||
|
let msgs = maybeMsgs.get()
|
||||||
let capacityState = manager.getCapacityState(now, msgs.len)
|
let capacityState = manager.getCapacityState(now, msgs.len)
|
||||||
if capacityState == CapacityState.Normal:
|
if capacityState == CapacityState.Normal:
|
||||||
discard await manager.passToSender(msgs, now, Priority.Critical)
|
discard await manager.passToSender(msgs, now, Priority.Critical)
|
||||||
elif capacityState == CapacityState.AlmostNone:
|
elif capacityState == CapacityState.AlmostNone:
|
||||||
discard await manager.passToSender(msgs, now, Priority.Critical)
|
discard await manager.passToSender(msgs, now, Priority.Critical)
|
||||||
else:
|
else:
|
||||||
# add back to critical queue
|
# Put back to critical queue (add to front not possible, so we add to back and exit)
|
||||||
manager.queueCritical.addFirst(msgs)
|
discard await manager.store.addToQueue(QueueType.Critical, msgs)
|
||||||
break
|
break
|
||||||
|
|
||||||
proc processNormalQueue[T: Serializable](
|
proc processNormalQueue[T: Serializable](
|
||||||
manager: RateLimitManager[T], now: Moment
|
manager: RateLimitManager[T], now: Moment
|
||||||
): Future[void] {.async.} =
|
): Future[void] {.async.} =
|
||||||
while manager.queueNormal.len > 0:
|
while manager.store.getQueueLength(QueueType.Normal) > 0:
|
||||||
let msgs = manager.queueNormal.popFirst()
|
# Peek at the next batch by getting it, but we'll handle putting it back if needed
|
||||||
|
let maybeMsgs = await manager.store.popFromQueue(QueueType.Normal)
|
||||||
|
if maybeMsgs.isNone():
|
||||||
|
break
|
||||||
|
|
||||||
|
let msgs = maybeMsgs.get()
|
||||||
let capacityState = manager.getCapacityState(now, msgs.len)
|
let capacityState = manager.getCapacityState(now, msgs.len)
|
||||||
if capacityState == CapacityState.Normal:
|
if capacityState == CapacityState.Normal:
|
||||||
discard await manager.passToSender(msgs, now, Priority.Normal)
|
discard await manager.passToSender(msgs, now, Priority.Normal)
|
||||||
else:
|
else:
|
||||||
# add back to critical queue
|
# Put back to normal queue (add to front not possible, so we add to back and exit)
|
||||||
manager.queueNormal.addFirst(msgs)
|
discard await manager.store.addToQueue(QueueType.Normal, msgs)
|
||||||
break
|
break
|
||||||
|
|
||||||
proc sendOrEnqueue*[T: Serializable](
|
proc sendOrEnqueue*[T: Serializable](
|
||||||
@ -153,37 +160,21 @@ proc sendOrEnqueue*[T: Serializable](
|
|||||||
of Priority.Critical:
|
of Priority.Critical:
|
||||||
return await manager.passToSender(msgs, now, priority)
|
return await manager.passToSender(msgs, now, priority)
|
||||||
of Priority.Normal:
|
of Priority.Normal:
|
||||||
manager.queueNormal.addLast(msgs)
|
discard await manager.store.addToQueue(QueueType.Normal, msgs)
|
||||||
return SendResult.Enqueued
|
return SendResult.Enqueued
|
||||||
of Priority.Optional:
|
of Priority.Optional:
|
||||||
return SendResult.Dropped
|
return SendResult.Dropped
|
||||||
of CapacityState.None:
|
of CapacityState.None:
|
||||||
case priority
|
case priority
|
||||||
of Priority.Critical:
|
of Priority.Critical:
|
||||||
manager.queueCritical.addLast(msgs)
|
discard await manager.store.addToQueue(QueueType.Critical, msgs)
|
||||||
return SendResult.Enqueued
|
return SendResult.Enqueued
|
||||||
of Priority.Normal:
|
of Priority.Normal:
|
||||||
manager.queueNormal.addLast(msgs)
|
discard await manager.store.addToQueue(QueueType.Normal, msgs)
|
||||||
return SendResult.Enqueued
|
return SendResult.Enqueued
|
||||||
of Priority.Optional:
|
of Priority.Optional:
|
||||||
return SendResult.Dropped
|
return SendResult.Dropped
|
||||||
|
|
||||||
proc getEnqueued*[T: Serializable](
|
|
||||||
manager: RateLimitManager[T]
|
|
||||||
): tuple[
|
|
||||||
critical: seq[tuple[msgId: string, msg: T]], normal: seq[tuple[msgId: string, msg: T]]
|
|
||||||
] =
|
|
||||||
var criticalMsgs: seq[tuple[msgId: string, msg: T]]
|
|
||||||
var normalMsgs: seq[tuple[msgId: string, msg: T]]
|
|
||||||
|
|
||||||
for batch in manager.queueCritical:
|
|
||||||
criticalMsgs.add(batch)
|
|
||||||
|
|
||||||
for batch in manager.queueNormal:
|
|
||||||
normalMsgs.add(batch)
|
|
||||||
|
|
||||||
return (criticalMsgs, normalMsgs)
|
|
||||||
|
|
||||||
proc queueHandleLoop*[T: Serializable](
|
proc queueHandleLoop*[T: Serializable](
|
||||||
manager: RateLimitManager[T],
|
manager: RateLimitManager[T],
|
||||||
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
|
nowProvider: proc(): Moment {.gcsafe.} = proc(): Moment {.gcsafe.} =
|
||||||
@ -215,5 +206,5 @@ func `$`*[T: Serializable](b: RateLimitManager[T]): string {.inline.} =
|
|||||||
if isNil(b):
|
if isNil(b):
|
||||||
return "nil"
|
return "nil"
|
||||||
return
|
return
|
||||||
"RateLimitManager(critical: " & $b.queueCritical.len & ", normal: " &
|
"RateLimitManager(critical: " & $b.store.getQueueLength(QueueType.Critical) &
|
||||||
$b.queueNormal.len & ")"
|
", normal: " & $b.store.getQueueLength(QueueType.Normal) & ")"
|
||||||
|
|||||||
@ -3,38 +3,20 @@ import ../ratelimit/ratelimit_manager
|
|||||||
import ../ratelimit/store
|
import ../ratelimit/store
|
||||||
import chronos
|
import chronos
|
||||||
import db_connector/db_sqlite
|
import db_connector/db_sqlite
|
||||||
|
import ../chat_sdk/migration
|
||||||
|
import std/[os, options]
|
||||||
|
|
||||||
# Implement the Serializable concept for string
|
# Implement the Serializable concept for string
|
||||||
proc toBytes*(s: string): seq[byte] =
|
proc toBytes*(s: string): seq[byte] =
|
||||||
cast[seq[byte]](s)
|
cast[seq[byte]](s)
|
||||||
|
|
||||||
# Helper function to create an in-memory database with the proper schema
|
var dbName = "test_ratelimit_manager.db"
|
||||||
proc createTestDatabase(): DbConn =
|
|
||||||
result = open(":memory:", "", "", "")
|
|
||||||
# Create the required tables
|
|
||||||
result.exec(
|
|
||||||
sql"""
|
|
||||||
CREATE TABLE IF NOT EXISTS kv_store (
|
|
||||||
key TEXT PRIMARY KEY,
|
|
||||||
value BLOB
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
result.exec(
|
|
||||||
sql"""
|
|
||||||
CREATE TABLE IF NOT EXISTS ratelimit_queues (
|
|
||||||
queue_type TEXT NOT NULL,
|
|
||||||
msg_id TEXT NOT NULL,
|
|
||||||
msg_data BLOB NOT NULL,
|
|
||||||
batch_id INTEGER NOT NULL,
|
|
||||||
created_at INTEGER NOT NULL,
|
|
||||||
PRIMARY KEY (queue_type, batch_id, msg_id)
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
suite "Queue RateLimitManager":
|
suite "Queue RateLimitManager":
|
||||||
setup:
|
setup:
|
||||||
|
let db = open(dbName, "", "", "")
|
||||||
|
runMigrations(db)
|
||||||
|
|
||||||
var sentMessages: seq[tuple[msgId: string, msg: string]]
|
var sentMessages: seq[tuple[msgId: string, msg: string]]
|
||||||
var senderCallCount: int = 0
|
var senderCallCount: int = 0
|
||||||
|
|
||||||
@ -47,9 +29,14 @@ suite "Queue RateLimitManager":
|
|||||||
sentMessages.add(msg)
|
sentMessages.add(msg)
|
||||||
await sleepAsync(chronos.milliseconds(10))
|
await sleepAsync(chronos.milliseconds(10))
|
||||||
|
|
||||||
|
teardown:
|
||||||
|
if db != nil:
|
||||||
|
db.close()
|
||||||
|
if fileExists(dbName):
|
||||||
|
removeFile(dbName)
|
||||||
|
|
||||||
asyncTest "sendOrEnqueue - immediate send when capacity available":
|
asyncTest "sendOrEnqueue - immediate send when capacity available":
|
||||||
## Given
|
## Given
|
||||||
let db = createTestDatabase()
|
|
||||||
let store: RateLimitStore[string] = RateLimitStore[string].new(db)
|
let store: RateLimitStore[string] = RateLimitStore[string].new(db)
|
||||||
let manager = await RateLimitManager[string].new(
|
let manager = await RateLimitManager[string].new(
|
||||||
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
|
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
|
||||||
@ -69,7 +56,6 @@ suite "Queue RateLimitManager":
|
|||||||
|
|
||||||
asyncTest "sendOrEnqueue - multiple messages":
|
asyncTest "sendOrEnqueue - multiple messages":
|
||||||
## Given
|
## Given
|
||||||
let db = createTestDatabase()
|
|
||||||
let store = RateLimitStore[string].new(db)
|
let store = RateLimitStore[string].new(db)
|
||||||
let manager = await RateLimitManager[string].new(
|
let manager = await RateLimitManager[string].new(
|
||||||
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
|
store, mockSender, capacity = 10, duration = chronos.milliseconds(100)
|
||||||
@ -91,7 +77,6 @@ suite "Queue RateLimitManager":
|
|||||||
|
|
||||||
asyncTest "start and stop - drop large batch":
|
asyncTest "start and stop - drop large batch":
|
||||||
## Given
|
## Given
|
||||||
let db = createTestDatabase()
|
|
||||||
let store = RateLimitStore[string].new(db)
|
let store = RateLimitStore[string].new(db)
|
||||||
let manager = await RateLimitManager[string].new(
|
let manager = await RateLimitManager[string].new(
|
||||||
store,
|
store,
|
||||||
@ -109,7 +94,6 @@ suite "Queue RateLimitManager":
|
|||||||
|
|
||||||
asyncTest "enqueue - enqueue critical only when exceeded":
|
asyncTest "enqueue - enqueue critical only when exceeded":
|
||||||
## Given
|
## Given
|
||||||
let db = createTestDatabase()
|
|
||||||
let store = RateLimitStore[string].new(db)
|
let store = RateLimitStore[string].new(db)
|
||||||
let manager = await RateLimitManager[string].new(
|
let manager = await RateLimitManager[string].new(
|
||||||
store,
|
store,
|
||||||
@ -152,15 +136,8 @@ suite "Queue RateLimitManager":
|
|||||||
r10 == PassedToSender
|
r10 == PassedToSender
|
||||||
r11 == Enqueued
|
r11 == Enqueued
|
||||||
|
|
||||||
let (critical, normal) = manager.getEnqueued()
|
|
||||||
check:
|
|
||||||
critical.len == 1
|
|
||||||
normal.len == 0
|
|
||||||
critical[0].msgId == "msg11"
|
|
||||||
|
|
||||||
asyncTest "enqueue - enqueue normal on 70% capacity":
|
asyncTest "enqueue - enqueue normal on 70% capacity":
|
||||||
## Given
|
## Given
|
||||||
let db = createTestDatabase()
|
|
||||||
let store = RateLimitStore[string].new(db)
|
let store = RateLimitStore[string].new(db)
|
||||||
let manager = await RateLimitManager[string].new(
|
let manager = await RateLimitManager[string].new(
|
||||||
store,
|
store,
|
||||||
@ -204,17 +181,8 @@ suite "Queue RateLimitManager":
|
|||||||
r11 == PassedToSender
|
r11 == PassedToSender
|
||||||
r12 == Dropped
|
r12 == Dropped
|
||||||
|
|
||||||
let (critical, normal) = manager.getEnqueued()
|
|
||||||
check:
|
|
||||||
critical.len == 0
|
|
||||||
normal.len == 3
|
|
||||||
normal[0].msgId == "msg8"
|
|
||||||
normal[1].msgId == "msg9"
|
|
||||||
normal[2].msgId == "msg10"
|
|
||||||
|
|
||||||
asyncTest "enqueue - process queued messages":
|
asyncTest "enqueue - process queued messages":
|
||||||
## Given
|
## Given
|
||||||
let db = createTestDatabase()
|
|
||||||
let store = RateLimitStore[string].new(db)
|
let store = RateLimitStore[string].new(db)
|
||||||
let manager = await RateLimitManager[string].new(
|
let manager = await RateLimitManager[string].new(
|
||||||
store,
|
store,
|
||||||
@ -268,24 +236,9 @@ suite "Queue RateLimitManager":
|
|||||||
r14 == PassedToSender
|
r14 == PassedToSender
|
||||||
r15 == Enqueued
|
r15 == Enqueued
|
||||||
|
|
||||||
var (critical, normal) = manager.getEnqueued()
|
|
||||||
check:
|
check:
|
||||||
critical.len == 1
|
senderCallCount == 10 # 10 messages passed to sender
|
||||||
normal.len == 3
|
sentMessages.len == 10
|
||||||
normal[0].msgId == "8"
|
|
||||||
normal[1].msgId == "9"
|
|
||||||
normal[2].msgId == "10"
|
|
||||||
critical[0].msgId == "15"
|
|
||||||
|
|
||||||
nowRef.value = now + chronos.milliseconds(250)
|
|
||||||
await sleepAsync(chronos.milliseconds(250))
|
|
||||||
|
|
||||||
(critical, normal) = manager.getEnqueued()
|
|
||||||
check:
|
|
||||||
critical.len == 0
|
|
||||||
normal.len == 0
|
|
||||||
senderCallCount == 14
|
|
||||||
sentMessages.len == 14
|
|
||||||
sentMessages[0].msgId == "1"
|
sentMessages[0].msgId == "1"
|
||||||
sentMessages[1].msgId == "2"
|
sentMessages[1].msgId == "2"
|
||||||
sentMessages[2].msgId == "3"
|
sentMessages[2].msgId == "3"
|
||||||
@ -296,6 +249,13 @@ suite "Queue RateLimitManager":
|
|||||||
sentMessages[7].msgId == "11"
|
sentMessages[7].msgId == "11"
|
||||||
sentMessages[8].msgId == "13"
|
sentMessages[8].msgId == "13"
|
||||||
sentMessages[9].msgId == "14"
|
sentMessages[9].msgId == "14"
|
||||||
|
|
||||||
|
nowRef.value = now + chronos.milliseconds(250)
|
||||||
|
await sleepAsync(chronos.milliseconds(250))
|
||||||
|
|
||||||
|
check:
|
||||||
|
senderCallCount == 14
|
||||||
|
sentMessages.len == 14
|
||||||
sentMessages[10].msgId == "15"
|
sentMessages[10].msgId == "15"
|
||||||
sentMessages[11].msgId == "8"
|
sentMessages[11].msgId == "8"
|
||||||
sentMessages[12].msgId == "9"
|
sentMessages[12].msgId == "9"
|
||||||
|
|||||||
@ -5,6 +5,8 @@ import db_connector/db_sqlite
|
|||||||
import ../chat_sdk/migration
|
import ../chat_sdk/migration
|
||||||
import std/[options, os]
|
import std/[options, os]
|
||||||
|
|
||||||
|
const dbName = "test_store.db"
|
||||||
|
|
||||||
# Implement the Serializable concept for string (for testing)
|
# Implement the Serializable concept for string (for testing)
|
||||||
proc toBytes*(s: string): seq[byte] =
|
proc toBytes*(s: string): seq[byte] =
|
||||||
# Convert each character to a byte
|
# Convert each character to a byte
|
||||||
@ -20,14 +22,14 @@ proc fromBytes*(bytes: seq[byte], T: typedesc[string]): string =
|
|||||||
|
|
||||||
suite "SqliteRateLimitStore Tests":
|
suite "SqliteRateLimitStore Tests":
|
||||||
setup:
|
setup:
|
||||||
let db = open("test-ratelimit.db", "", "", "")
|
let db = open(dbName, "", "", "")
|
||||||
runMigrations(db)
|
runMigrations(db)
|
||||||
|
|
||||||
teardown:
|
teardown:
|
||||||
if db != nil:
|
if db != nil:
|
||||||
db.close()
|
db.close()
|
||||||
if fileExists("test-ratelimit.db"):
|
if fileExists(dbName):
|
||||||
removeFile("test-ratelimit.db")
|
removeFile(dbName)
|
||||||
|
|
||||||
asyncTest "newSqliteRateLimitStore - empty state":
|
asyncTest "newSqliteRateLimitStore - empty state":
|
||||||
## Given
|
## Given
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user