feat: add SafeAsyncIter chaining

This commit is contained in:
gmega 2025-06-27 16:04:49 -03:00
parent 313d6bac1f
commit d94bfe60f6
No known key found for this signature in database
GPG Key ID: 6290D34EAD824B18
2 changed files with 44 additions and 4 deletions

View File

@ -232,3 +232,28 @@ proc empty*[T](_: type SafeAsyncIter[T]): SafeAsyncIter[T] =
true
SafeAsyncIter[T].new(genNext, isFinished)
proc chain*[T](iters: seq[SafeAsyncIter[T]]): SafeAsyncIter[T] =
if iters.len == 0:
return SafeAsyncIter[T].empty
var curIdx = 0
proc ensureNext(): void =
while curIdx < iters.len and iters[curIdx].finished:
inc(curIdx)
proc isFinished(): bool =
curIdx == iters.len
proc genNext(): Future[?!T] {.async: (raises: [CancelledError]).} =
let item = await iters[curIdx].next()
ensureNext()
return item
ensureNext()
return SafeAsyncIter[T].new(genNext, isFinished)
proc chain*[T](iters: varargs[SafeAsyncIter[T]]): SafeAsyncIter[T] =
chain(iters.toSeq)

View File

@ -373,7 +373,7 @@ asyncchecksuite "Test SafeAsyncIter":
# Now, to make sure that this mechanism works, and to document its
# cancellation semantics, this test shows that when the async predicate
# function is cancelled, this cancellation has immediate effect, which means
# that `next()` (or more precisely `getNext()` in `mapFilter` function), is
# that `next()` (or more precisely `getNext()` in `mapFilter` function), is
# interrupted immediately. If this is the case, the the iterator be interrupted
# before `next()` returns this locally captured value from the previous
# iteration and this is exactly the reason why at the end of the test
@ -404,10 +404,8 @@ asyncchecksuite "Test SafeAsyncIter":
expect CancelledError:
for fut in iter2:
if i =? (await fut):
without i =? (await fut), err:
collected.add(i)
else:
fail()
check:
# We expect only values "0" and "1" to be collected
@ -415,3 +413,20 @@ asyncchecksuite "Test SafeAsyncIter":
# will not be returned because of the cancellation.
collected == @["0", "1"]
iter2.finished
test "should allow chaining":
let
iter1 = SafeAsyncIter[int].new(0 ..< 5)
iter2 = SafeAsyncIter[int].new(5 ..< 10)
iter3 = chain[int](iter1, SafeAsyncIter[int].empty, iter2)
var collected: seq[int]
for fut in iter3:
without i =? (await fut), err:
fail()
collected.add(i)
check:
iter3.finished
collected == @[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]