diff --git a/library/sds_thread/sds_thread.nim b/library/sds_thread/sds_thread.nim index c5c12cb..85e5073 100644 --- a/library/sds_thread/sds_thread.nim +++ b/library/sds_thread/sds_thread.nim @@ -43,13 +43,18 @@ proc runSds(ctx: ptr SdsContext) {.async.} = error "sds thread could not receive a request" continue - ## Handle the request - asyncSpawn SdsThreadRequest.process(request, addr rm) - + ## Ack receipt to the requester thread BEFORE processing — it only + ## waits for "received", not "processed", so the caller's throughput + ## doesn't change. Processing is then awaited (was: asyncSpawn'd), + ## which serializes requests on this worker. The SP channel + lock + ## above already assume no concurrent requests, so awaiting here + ## aligns the processing side with that assumption. let fireRes = ctx.reqReceivedSignal.fireSync() if fireRes.isErr(): error "could not fireSync back to requester thread", error = fireRes.error + await SdsThreadRequest.process(request, addr rm) + proc run(ctx: ptr SdsContext) {.thread.} = ## Launch sds worker waitFor runSds(ctx) diff --git a/sds.nim b/sds.nim index e426637..29e4a2f 100644 --- a/sds.nim +++ b/sds.nim @@ -511,10 +511,13 @@ proc periodicRepairSweep( await sleepAsync(chronos.milliseconds(rm.config.repairSweepInterval.inMilliseconds)) proc startPeriodicTasks*(rm: ReliabilityManager) = - ## Starts the periodic tasks for buffer sweeping and sync message sending. - asyncSpawn rm.periodicBufferSweep() - asyncSpawn rm.periodicSyncMessage() - asyncSpawn rm.periodicRepairSweep() + ## Starts the periodic background tasks (buffer sweep, sync message, + ## SDS-R repair sweep). The futures are kept on the manager so `cleanup` + ## can cancel them — without that, the loops would outlive a cleaned-up + ## manager and keep firing against cleared state. + rm.periodicTasks.add(FutureBase(rm.periodicBufferSweep())) + rm.periodicTasks.add(FutureBase(rm.periodicSyncMessage())) + rm.periodicTasks.add(FutureBase(rm.periodicRepairSweep())) proc resetReliabilityManager*( rm: ReliabilityManager diff --git a/sds/sds_utils.nim b/sds/sds_utils.nim index 9770f30..0e3a07d 100644 --- a/sds/sds_utils.nim +++ b/sds/sds_utils.nim @@ -29,8 +29,16 @@ proc cleanup*( ## reconstructed against the same backend after cleanup, so disk state must ## survive. For deliberate disk wipe, use `removeChannel` or ## `resetReliabilityManager`. + ## + ## Periodic tasks are cancelled BEFORE acquiring the lock so that a task + ## currently blocked on `lock.acquire()` can unwind via CancelledError + ## without deadlocking against cleanup itself. if rm.isNil(): return + for task in rm.periodicTasks: + if not task.finished: + await task.cancelAndWait() + rm.periodicTasks.setLen(0) try: await rm.lock.acquire() try: diff --git a/sds/types/reliability_manager.nim b/sds/types/reliability_manager.nim index d487c48..6b4cc7e 100644 --- a/sds/types/reliability_manager.nim +++ b/sds/types/reliability_manager.nim @@ -21,6 +21,9 @@ type ReliabilityManager* = ref object ## one another at await points; the manager assumes all calls come from ## the same Chronos event loop (the FFI worker thread). Multi-OS-thread ## use is the caller's responsibility. + periodicTasks*: seq[FutureBase] + ## Handles to the background loops started by `startPeriodicTasks` so + ## `cleanup` can cancel them on shutdown instead of leaking them. onMessageReady*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMessageSent*: proc(messageId: SdsMessageID, channelId: SdsChannelID) {.gcsafe.} onMissingDependencies*: proc( @@ -48,5 +51,6 @@ proc new*( participantId: participantId, persistence: persistence, lock: newAsyncLock(), + periodicTasks: @[], ) return rm diff --git a/tests/async_unittest.nim b/tests/async_unittest.nim new file mode 100644 index 0000000..7f625fe --- /dev/null +++ b/tests/async_unittest.nim @@ -0,0 +1,69 @@ +## Shared async-aware wrappers around `unittest` so tests in this repo can +## `await` directly in setup/test/teardown blocks instead of sprinkling +## `waitFor` at each call site. +## +## Usage: +## +## ```nim +## import ./async_unittest +## +## suite "X": +## var rm: ReliabilityManager +## +## asyncSetup: +## rm = newReliabilityManager(...).get() +## check (await rm.ensureChannel("ch")).isOk() +## +## asyncTeardown: +## if not rm.isNil: +## await rm.cleanup() +## +## asyncTest "Y": +## await rm.wrapOutgoingMessage(...) +## ``` +## +## All three blocks run inside the same async proc (per test). unittest's +## own `setup:`/`teardown:` still work for purely synchronous fixtures. + +import unittest, chronos +export unittest, chronos + +template asyncSetup*(body: untyped) {.dirty.} = + ## Async counterpart to unittest's `setup:`. Runs inside each asyncTest's + ## async proc, so `await` works. + template asyncTestSetupIMPL(): untyped {.dirty.} = + body + +template asyncTeardown*(body: untyped) {.dirty.} = + ## Async counterpart to unittest's `teardown:`. Runs in a `finally` so it + ## executes even when the test body (or setup) raises. + template asyncTestTeardownIMPL(): untyped {.dirty.} = + body + +template asyncTest*(name: string, body: untyped) = + ## Wraps a unittest `test` body in an async proc so `await` works on the + ## now-async ReliabilityManager API. unittest's `check` raises Exception, + ## which is wider than chronos's default CatchableError; the exception is + ## caught inside the async body, stashed, and re-raised after waitFor so + ## unittest's normal failure handling sees it. + ## + ## `cast(gcsafe)` is needed because suite-level vars (e.g. `var rm`) look + ## like globals to the async closure, but the FFI runtime is single-thread + ## so the "not gcsafe" warning isn't a real hazard here. + test name: + var asyncTestErr {.inject.}: ref Exception = nil + proc inner() {.async.} = + {.cast(gcsafe).}: + try: + when declared(asyncTestSetupIMPL): + asyncTestSetupIMPL() + try: + body + finally: + when declared(asyncTestTeardownIMPL): + asyncTestTeardownIMPL() + except Exception as e: + asyncTestErr = e + waitFor inner() + if asyncTestErr != nil: + raise asyncTestErr diff --git a/tests/test_persistence.nim b/tests/test_persistence.nim index 981780d..6e26aa9 100644 --- a/tests/test_persistence.nim +++ b/tests/test_persistence.nim @@ -1,28 +1,12 @@ -import unittest, results, chronos, std/[tables, sets, times] +import results, std/[tables, sets, times] import sds +import ./async_unittest import ./in_memory_persistence converter toParticipantID(s: string): SdsParticipantID = s.SdsParticipantID const testChannel = "testChannel" -template asyncTest(name: string, body: untyped) = - ## Wraps a unittest `test` body in an async proc and runs it to completion. - ## Tests can now `await` rm.* and rm.persistence.* calls directly. - ## unittest's `check` raises Exception (wider than chronos's CatchableError), - ## so catch inside the async body and re-raise after waitFor. - test name: - var asyncTestErr {.inject.}: ref Exception = nil - proc inner() {.async.} = - {.cast(gcsafe).}: - try: - body - except Exception as e: - asyncTestErr = e - waitFor inner() - if asyncTestErr != nil: - raise asyncTestErr - suite "Persistence: write → restart → read-back": asyncTest "outgoing buffer survives restart": let store = newInMemoryStore() diff --git a/tests/test_reliability.nim b/tests/test_reliability.nim index a5dcdf1..c81abdd 100644 --- a/tests/test_reliability.nim +++ b/tests/test_reliability.nim @@ -1,5 +1,6 @@ -import unittest, results, chronos, std/[times, options, tables] +import results, std/[times, options, tables] import sds +import ./async_unittest # Test-only convenience: implicit string → SdsParticipantID so test fixtures # can use string literals. Production code retains the distinct-type safety. @@ -7,28 +8,6 @@ converter toParticipantID(s: string): SdsParticipantID = s.SdsParticipantID const testChannel = "testChannel" -template asyncTest(name: string, body: untyped) = - ## Wraps a unittest `test` body in an async proc so tests can `await` the - ## now-async ReliabilityManager API directly. Setup/teardown blocks still - ## run in the outer (sync) scope — use `waitFor` for any async calls there. - ## unittest's `check` raises `Exception`, which is wider than chronos's - ## default `CatchableError` for async procs — so we catch it inside and - ## re-raise after waitFor, where unittest's normal handling can see it. - ## cast(gcsafe) is needed because suite-level `var rm` looks like a global - ## to the closure capture, but the FFI runtime is single-threaded so the - ## "not gcsafe" warning isn't a real hazard here. - test name: - var asyncTestErr {.inject.}: ref Exception = nil - proc inner() {.async.} = - {.cast(gcsafe).}: - try: - body - except Exception as e: - asyncTestErr = e - waitFor inner() - if asyncTestErr != nil: - raise asyncTestErr - proc seedBloom( rm: ReliabilityManager, channel: SdsChannelID, n: int, prefix = "noise" ) = @@ -43,15 +22,15 @@ proc seedBloom( suite "Core Operations": var rm: ReliabilityManager - setup: + asyncSetup: let rmResult = newReliabilityManager(participantId = "alice") check rmResult.isOk() rm = rmResult.get() - check (waitFor rm.ensureChannel(testChannel)).isOk() + check (await rm.ensureChannel(testChannel)).isOk() - teardown: + asyncTeardown: if not rm.isNil: - waitFor rm.cleanup() + await rm.cleanup() asyncTest "can create with default config": let config = defaultConfig() @@ -140,15 +119,15 @@ suite "Core Operations": suite "Reliability Mechanisms": var rm: ReliabilityManager - setup: + asyncSetup: let rmResult = newReliabilityManager(participantId = "alice") check rmResult.isOk() rm = rmResult.get() - check (waitFor rm.ensureChannel(testChannel)).isOk() + check (await rm.ensureChannel(testChannel)).isOk() - teardown: + asyncTeardown: if not rm.isNil: - waitFor rm.cleanup() + await rm.cleanup() asyncTest "dependency detection and resolution": var messageReadyCount = 0 @@ -558,15 +537,15 @@ suite "Reliability Mechanisms": suite "Periodic Tasks & Buffer Management": var rm: ReliabilityManager - setup: + asyncSetup: let rmResult = newReliabilityManager(participantId = "alice") check rmResult.isOk() rm = rmResult.get() - check (waitFor rm.ensureChannel(testChannel)).isOk() + check (await rm.ensureChannel(testChannel)).isOk() - teardown: + asyncTeardown: if not rm.isNil: - waitFor rm.cleanup() + await rm.cleanup() asyncTest "outgoing buffer management": var messageSentCount = 0 @@ -696,15 +675,15 @@ suite "Periodic Tasks & Buffer Management": suite "Special Cases Handling": var rm: ReliabilityManager - setup: + asyncSetup: let rmResult = newReliabilityManager(participantId = "alice") check rmResult.isOk() rm = rmResult.get() - check (waitFor rm.ensureChannel(testChannel)).isOk() + check (await rm.ensureChannel(testChannel)).isOk() - teardown: + asyncTeardown: if not rm.isNil: - waitFor rm.cleanup() + await rm.cleanup() asyncTest "message history limits": # Add messages up to max history size @@ -810,14 +789,14 @@ suite "cleanup": suite "Multi-Channel ReliabilityManager Tests": var rm: ReliabilityManager - setup: + asyncSetup: let rmResult = newReliabilityManager(participantId = "alice") check rmResult.isOk() rm = rmResult.get() - teardown: + asyncTeardown: if not rm.isNil: - waitFor rm.cleanup() + await rm.cleanup() asyncTest "can create multi-channel manager without channel ID": check rm.channels.len == 0 @@ -1050,17 +1029,17 @@ suite "SDS-R: Computation Functions": suite "SDS-R: Repair Buffer Management": var rm: ReliabilityManager - setup: + asyncSetup: let rmResult = newReliabilityManager( participantId = "test-participant" ) check rmResult.isOk() rm = rmResult.get() - check (waitFor rm.ensureChannel(testChannel)).isOk() + check (await rm.ensureChannel(testChannel)).isOk() - teardown: + asyncTeardown: if not rm.isNil: - waitFor rm.cleanup() + await rm.cleanup() asyncTest "missing deps added to outgoing repair buffer": var missingDepsCount = 0 @@ -1580,13 +1559,13 @@ suite "SDS-R: Lifecycle and State": suite "SDS-R: Repair Sweep": var rm: ReliabilityManager - setup: + asyncSetup: rm = newReliabilityManager(participantId = "bob").get() - check (waitFor rm.ensureChannel(testChannel)).isOk() + check (await rm.ensureChannel(testChannel)).isOk() - teardown: + asyncTeardown: if not rm.isNil: - waitFor rm.cleanup() + await rm.cleanup() asyncTest "runRepairSweep fires onRepairReady for expired tResp": var fireCount = 0 @@ -1666,12 +1645,16 @@ type delivered: Table[SdsParticipantID, seq[SdsMessageID]] # Log of raw message-ids placed on the wire, tagged with the source peer. wireLog: seq[tuple[senderId: SdsParticipantID, messageId: SdsMessageID]] + # Queue of (sender, bytes) the repair callback would have delivered if it + # could await. Drained explicitly by `bus.drain()` from the test body. + pending: seq[(SdsParticipantID, seq[byte])] proc newTestBus(): TestBus = TestBus( peers: initOrderedTable[SdsParticipantID, ReliabilityManager](), delivered: initTable[SdsParticipantID, seq[SdsMessageID]](), wireLog: @[], + pending: @[], ) proc recordWire(bus: TestBus, senderId: SdsParticipantID, bytes: seq[byte]) {.gcsafe.} = @@ -1690,6 +1673,16 @@ proc deliverExcept( continue discard await peer.unwrapReceivedMessage(bytes) +proc drain(bus: TestBus): Future[void] {.async.} = + ## Delivers every (sender, bytes) the repair callback enqueued. Loops until + ## the queue stays empty across one full pass — a delivery may trigger a + ## new repair-ready callback that re-enqueues. + while bus.pending.len > 0: + let batch = move bus.pending + bus.pending = @[] + for entry in batch: + await bus.deliverExcept(entry[0], entry[1], @[]) + proc addPeer( bus: TestBus, participantId: SdsParticipantID, @@ -1709,12 +1702,11 @@ proc addPeer( proc(msgId: SdsMessageID, ch: SdsChannelID) {.gcsafe.} = discard, proc(msgId: SdsMessageID, deps: seq[HistoryEntry], ch: SdsChannelID) {.gcsafe.} = discard, onRepairReady = proc(bytes: seq[byte], ch: SdsChannelID) {.gcsafe.} = + # The callback contract is sync, so we cannot `await` here. Enqueue the + # delivery and let the test drive it via `bus.drain()` instead. {.cast(gcsafe).}: busRef.recordWire(pid, bytes) - # Fire-and-forget delivery from a sync callback context — we cannot - # await here, so spawn the async delivery onto the same event loop. - asyncSpawn(busRef.deliverExcept(pid, bytes, @[])) - , + busRef.pending.add((pid, bytes)), ) return rm @@ -1786,9 +1778,7 @@ suite "SDS-R: Multi-Participant Integration": # then run her sweep. She rebroadcasts M1. alice.forceIncomingExpired("m1") await alice.runRepairSweep() - - # Allow any asyncSpawn'd deliveries from the repair callback to run. - await sleepAsync(chronos.milliseconds(10)) + await bus.drain() # Bob now has M1 and M2 delivered. check: @@ -1820,7 +1810,7 @@ suite "SDS-R: Multi-Participant Integration": # pending entry when Carol receives the rebroadcast. alice.forceIncomingExpired("m1") await alice.runRepairSweep() - await sleepAsync(chronos.milliseconds(10)) + await bus.drain() # Carol's pending response must have been cleared by the dedup-path cleanup. check "m1" notin carol.channels[testChannel].incomingRepairBuffer @@ -1828,7 +1818,7 @@ suite "SDS-R: Multi-Participant Integration": # Even if we now force-run Carol's sweep, nothing should fire. let wireCountBefore = bus.wireLog.len await carol.runRepairSweep() - await sleepAsync(chronos.milliseconds(10)) + await bus.drain() check bus.wireLog.len == wireCountBefore # Bob received exactly one rebroadcast of M1.