From d94bfe60f6f68d5e4b44cdd2e4d148c0b953265f Mon Sep 17 00:00:00 2001 From: gmega Date: Fri, 27 Jun 2025 16:04:49 -0300 Subject: [PATCH] feat: add SafeAsyncIter chaining --- codex/utils/safeasynciter.nim | 25 +++++++++++++++++++++++++ tests/codex/utils/testsafeasynciter.nim | 23 +++++++++++++++++++---- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/codex/utils/safeasynciter.nim b/codex/utils/safeasynciter.nim index d582fec3..be46cc0f 100644 --- a/codex/utils/safeasynciter.nim +++ b/codex/utils/safeasynciter.nim @@ -232,3 +232,28 @@ proc empty*[T](_: type SafeAsyncIter[T]): SafeAsyncIter[T] = true SafeAsyncIter[T].new(genNext, isFinished) + +proc chain*[T](iters: seq[SafeAsyncIter[T]]): SafeAsyncIter[T] = + if iters.len == 0: + return SafeAsyncIter[T].empty + + var curIdx = 0 + + proc ensureNext(): void = + while curIdx < iters.len and iters[curIdx].finished: + inc(curIdx) + + proc isFinished(): bool = + curIdx == iters.len + + proc genNext(): Future[?!T] {.async: (raises: [CancelledError]).} = + let item = await iters[curIdx].next() + ensureNext() + return item + + ensureNext() + + return SafeAsyncIter[T].new(genNext, isFinished) + +proc chain*[T](iters: varargs[SafeAsyncIter[T]]): SafeAsyncIter[T] = + chain(iters.toSeq) diff --git a/tests/codex/utils/testsafeasynciter.nim b/tests/codex/utils/testsafeasynciter.nim index 1aeba4d2..7acd3640 100644 --- a/tests/codex/utils/testsafeasynciter.nim +++ b/tests/codex/utils/testsafeasynciter.nim @@ -373,7 +373,7 @@ asyncchecksuite "Test SafeAsyncIter": # Now, to make sure that this mechanism works, and to document its # cancellation semantics, this test shows that when the async predicate # function is cancelled, this cancellation has immediate effect, which means - # that `next()` (or more precisely `getNext()` in `mapFilter` function), is + # that `next()` (or more precisely `getNext()` in `mapFilter` function), is # interrupted immediately. If this is the case, the the iterator be interrupted # before `next()` returns this locally captured value from the previous # iteration and this is exactly the reason why at the end of the test @@ -404,10 +404,8 @@ asyncchecksuite "Test SafeAsyncIter": expect CancelledError: for fut in iter2: - if i =? (await fut): + without i =? (await fut), err: collected.add(i) - else: - fail() check: # We expect only values "0" and "1" to be collected @@ -415,3 +413,20 @@ asyncchecksuite "Test SafeAsyncIter": # will not be returned because of the cancellation. collected == @["0", "1"] iter2.finished + + test "should allow chaining": + let + iter1 = SafeAsyncIter[int].new(0 ..< 5) + iter2 = SafeAsyncIter[int].new(5 ..< 10) + iter3 = chain[int](iter1, SafeAsyncIter[int].empty, iter2) + + var collected: seq[int] + + for fut in iter3: + without i =? (await fut), err: + fail() + collected.add(i) + + check: + iter3.finished + collected == @[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]