From c4a99447bd1939d736778e053a2afee527446f76 Mon Sep 17 00:00:00 2001 From: cheatfate Date: Wed, 17 Jul 2019 16:12:31 +0300 Subject: [PATCH] Fix AsyncLock race and refactor asyncsync.nim to properly support cancellation. Fix async macro to not transform nested procedures. --- chronos/asyncmacro2.nim | 5 ++ chronos/asyncsync.nim | 180 ++++++++++++++++++++++++++-------------- tests/testsync.nim | 64 +++++++++++++- 3 files changed, 186 insertions(+), 63 deletions(-) diff --git a/chronos/asyncmacro2.nim b/chronos/asyncmacro2.nim index 46bef06..55e7acd 100644 --- a/chronos/asyncmacro2.nim +++ b/chronos/asyncmacro2.nim @@ -205,6 +205,11 @@ proc processBody(node, retFutureSym: NimNode, else: discard for i in 0 ..< result.len: + # We must not transform nested procedures of any form, otherwise + # `retFutureSym` will be used for all nested procedures as their own + # `retFuture`. + if result[i].kind in {nnkProcDef, nnkMethodDef, nnkDo, nnkLambda}: + continue result[i] = processBody(result[i], retFutureSym, subTypeIsVoid, futureVarIdents) diff --git a/chronos/asyncsync.nim b/chronos/asyncsync.nim index dc77fa5..c9c6fa8 100644 --- a/chronos/asyncsync.nim +++ b/chronos/asyncsync.nim @@ -22,7 +22,7 @@ type ## ``release()`` call resets the state to unlocked; first coroutine which ## is blocked in ``acquire()`` is being processed. locked: bool - waiters: Deque[Future[void]] + waiters: seq[Future[void]] AsyncEvent* = ref object of RootRef ## A primitive event object. @@ -34,9 +34,8 @@ type ## If more than one coroutine blocked in ``wait()`` waiting for event ## state to be signaled, when event get fired, then all coroutines ## continue proceeds in order, they have entered waiting state. - flag: bool - waiters: Deque[Future[void]] + waiters: seq[Future[void]] AsyncQueue*[T] = ref object of RootRef ## A queue, useful for coordinating producer and consumer coroutines. @@ -45,8 +44,8 @@ type ## infinite. If it is an integer greater than ``0``, then "await put()" ## will block when the queue reaches ``maxsize``, until an item is ## removed by "await get()". - getters: Deque[Future[void]] - putters: Deque[Future[void]] + getters: seq[Future[void]] + putters: seq[Future[void]] queue: Deque[T] maxsize: int @@ -72,31 +71,49 @@ proc newAsyncLock*(): AsyncLock = # getGlobalDispatcher() call. discard getGlobalDispatcher() result = new AsyncLock - result.waiters = initDeque[Future[void]]() + result.waiters = newSeq[Future[void]]() result.locked = false +proc wakeUpFirst(lock: AsyncLock) {.inline.} = + ## Wake up the first waiter if it isn't done. + for fut in lock.waiters.mitems(): + if not(fut.finished()): + fut.complete() + break + +proc checkAll(lock: AsyncLock): bool {.inline.} = + ## Returns ``true`` if waiters array is empty or full of cancelled futures. + result = true + for fut in lock.waiters.mitems(): + if not(fut.cancelled()): + result = false + break + +proc removeWaiter(lock: AsyncLock, waiter: Future[void]) {.inline.} = + ## Removes ``waiter`` from list of waiters in ``lock``. + lock.waiters.delete(lock.waiters.find(waiter)) + proc acquire*(lock: AsyncLock) {.async.} = ## Acquire a lock ``lock``. ## ## This procedure blocks until the lock ``lock`` is unlocked, then sets it ## to locked and returns. - if not(lock.locked): + if not(lock.locked) and lock.checkAll(): lock.locked = true else: var w = newFuture[void]("AsyncLock.acquire") - lock.waiters.addLast(w) - await w + lock.waiters.add(w) + try: + try: + await w + finally: + lock.removeWaiter(w) + except CancelledError: + if not(lock.locked): + lock.wakeUpFirst() + raise lock.locked = true -proc own*(lock: AsyncLock) = - ## Acquire a lock ``lock``. - ## - ## This procedure not blocks, if ``lock`` is locked, then ``AsyncLockError`` - ## exception would be raised. - if lock.locked: - raise newException(AsyncLockError, "AsyncLock is already acquired!") - lock.locked = true - proc locked*(lock: AsyncLock): bool = ## Return `true` if the lock ``lock`` is acquired, `false` otherwise. result = lock.locked @@ -107,14 +124,9 @@ proc release*(lock: AsyncLock) = ## When the ``lock`` is locked, reset it to unlocked, and return. If any ## other coroutines are blocked waiting for the lock to become unlocked, ## allow exactly one of them to proceed. - var w: Future[void] if lock.locked: lock.locked = false - while len(lock.waiters) > 0: - w = lock.waiters.popFirst() - if not(w.finished()): - w.complete() - break + lock.wakeUpFirst() else: raise newException(AsyncLockError, "AsyncLock is not acquired!") @@ -130,32 +142,35 @@ proc newAsyncEvent*(): AsyncEvent = # getGlobalDispatcher() call. discard getGlobalDispatcher() result = new AsyncEvent - result.waiters = initDeque[Future[void]]() + result.waiters = newSeq[Future[void]]() result.flag = false +proc removeWaiter(event: AsyncEvent, waiter: Future[void]) {.inline.} = + ## Removes ``waiter`` from list of waiters in ``lock``. + event.waiters.delete(event.waiters.find(waiter)) + proc wait*(event: AsyncEvent) {.async.} = ## Block until the internal flag of ``event`` is `true`. ## If the internal flag is `true` on entry, return immediately. Otherwise, ## block until another task calls `fire()` to set the flag to `true`, ## then return. - if event.flag: - discard - else: + if not(event.flag): var w = newFuture[void]("AsyncEvent.wait") - event.waiters.addLast(w) - await w + event.waiters.add(w) + try: + await w + finally: + event.removeWaiter(w) proc fire*(event: AsyncEvent) = ## Set the internal flag of ``event`` to `true`. All tasks waiting for it ## to become `true` are awakened. Task that call `wait()` once the flag is ## `true` will not block at all. - var w: Future[void] if not(event.flag): event.flag = true - while len(event.waiters) > 0: - w = event.waiters.popFirst() - if not(w.finished()): - w.complete() + for fut in event.waiters: + if not(fut.finished()): + fut.complete() proc clear*(event: AsyncEvent) = ## Reset the internal flag of ``event`` to `false`. Subsequently, tasks @@ -174,11 +189,42 @@ proc newAsyncQueue*[T](maxsize: int = 0): AsyncQueue[T] = # getGlobalDispatcher() call. discard getGlobalDispatcher() result = new AsyncQueue[T] - result.getters = initDeque[Future[void]]() - result.putters = initDeque[Future[void]]() + result.getters = newSeq[Future[void]]() + result.putters = newSeq[Future[void]]() result.queue = initDeque[T]() result.maxsize = maxsize +proc wakeupNext(waiters: var seq[Future[void]]) {.inline.} = + var i = 0 + while i < len(waiters): + var waiter = waiters[i] + if not(waiter.finished()): + let length = len(waiters) - (i + 1) + let offset = len(waiters) - length + if length > 0: + for k in 0..= 0: + waiters.delete(index) + +proc removeWaiter(waiters: var Deque[Future[void]], + fut: Future[void]) {.inline.} = + var nwaiters = initDeque[Future[void]]() + while len(waiters) > 0: + var waiter = waiters.popFirst() + if waiter != fut: + nwaiters.addFirst(waiter) + proc full*[T](aq: AsyncQueue[T]): bool {.inline.} = ## Return ``true`` if there are ``maxsize`` items in the queue. ## @@ -201,10 +247,7 @@ proc addFirstNoWait*[T](aq: AsyncQueue[T], item: T) = if aq.full(): raise newException(AsyncQueueFullError, "AsyncQueue is full!") aq.queue.addFirst(item) - while len(aq.getters) > 0: - w = aq.getters.popFirst() - if not(w.finished()): - w.complete() + aq.getters.wakeupNext() proc addLastNoWait*[T](aq: AsyncQueue[T], item: T) = ## Put an item ``item`` at the end of the queue ``aq`` immediately. @@ -214,10 +257,7 @@ proc addLastNoWait*[T](aq: AsyncQueue[T], item: T) = if aq.full(): raise newException(AsyncQueueFullError, "AsyncQueue is full!") aq.queue.addLast(item) - while len(aq.getters) > 0: - w = aq.getters.popFirst() - if not(w.finished()): - w.complete() + aq.getters.wakeupNext() proc popFirstNoWait*[T](aq: AsyncQueue[T]): T = ## Get an item from the beginning of the queue ``aq`` immediately. @@ -227,10 +267,7 @@ proc popFirstNoWait*[T](aq: AsyncQueue[T]): T = if aq.empty(): raise newException(AsyncQueueEmptyError, "AsyncQueue is empty!") result = aq.queue.popFirst() - while len(aq.putters) > 0: - w = aq.putters.popFirst() - if not(w.finished()): - w.complete() + aq.putters.wakeupNext() proc popLastNoWait*[T](aq: AsyncQueue[T]): T = ## Get an item from the end of the queue ``aq`` immediately. @@ -240,18 +277,21 @@ proc popLastNoWait*[T](aq: AsyncQueue[T]): T = if aq.empty(): raise newException(AsyncQueueEmptyError, "AsyncQueue is empty!") result = aq.queue.popLast() - while len(aq.putters) > 0: - w = aq.putters.popFirst() - if not(w.finished()): - w.complete() + aq.putters.wakeupNext() proc addFirst*[T](aq: AsyncQueue[T], item: T) {.async.} = ## Put an ``item`` to the beginning of the queue ``aq``. If the queue is full, ## wait until a free slot is available before adding item. while aq.full(): var putter = newFuture[void]("AsyncQueue.addFirst") - aq.putters.addLast(putter) - await putter + aq.putters.add(putter) + try: + await putter + except: + aq.putters.removeWaiter(putter) + if not aq.full() and not(putter.cancelled()): + aq.putters.wakeupNext() + raise aq.addFirstNoWait(item) proc addLast*[T](aq: AsyncQueue[T], item: T) {.async.} = @@ -259,8 +299,14 @@ proc addLast*[T](aq: AsyncQueue[T], item: T) {.async.} = ## wait until a free slot is available before adding item. while aq.full(): var putter = newFuture[void]("AsyncQueue.addLast") - aq.putters.addLast(putter) - await putter + aq.putters.add(putter) + try: + await putter + except: + aq.putters.removeWaiter(putter) + if not aq.full() and not(putter.cancelled()): + aq.putters.wakeupNext() + raise aq.addLastNoWait(item) proc popFirst*[T](aq: AsyncQueue[T]): Future[T] {.async.} = @@ -268,8 +314,14 @@ proc popFirst*[T](aq: AsyncQueue[T]): Future[T] {.async.} = ## If the queue is empty, wait until an item is available. while aq.empty(): var getter = newFuture[void]("AsyncQueue.popFirst") - aq.getters.addLast(getter) - await getter + aq.getters.add(getter) + try: + await getter + except: + aq.getters.removeWaiter(getter) + if not(aq.empty()) and not(getter.cancelled()): + aq.getters.wakeupNext() + raise result = aq.popFirstNoWait() proc popLast*[T](aq: AsyncQueue[T]): Future[T] {.async.} = @@ -277,8 +329,14 @@ proc popLast*[T](aq: AsyncQueue[T]): Future[T] {.async.} = ## If the queue is empty, wait until an item is available. while aq.empty(): var getter = newFuture[void]("AsyncQueue.popLast") - aq.getters.addLast(getter) - await getter + aq.getters.add(getter) + try: + await getter + except: + aq.getters.removeWaiter(getter) + if not(aq.empty()) and not(getter.cancelled()): + aq.getters.wakeupNext() + raise result = aq.popLastNoWait() proc putNoWait*[T](aq: AsyncQueue[T], item: T) {.inline.} = diff --git a/tests/testsync.nim b/tests/testsync.nim index 47bbe9b..02b9433 100644 --- a/tests/testsync.nim +++ b/tests/testsync.nim @@ -22,7 +22,7 @@ suite "Asynchronous sync primitives test suite": proc test1(): string = var lock = newAsyncLock() - lock.own() + waitFor lock.acquire() discard testLock(0, lock) discard testLock(1, lock) discard testLock(2, lock) @@ -39,6 +39,52 @@ suite "Asynchronous sync primitives test suite": poll() result = testLockResult + proc testBehaviorLock(n1, n2, n3: Duration): Future[seq[int]] {.async.} = + var stripe: seq[int] + + proc task(lock: AsyncLock, n: int, timeout: Duration) {.async.} = + await lock.acquire() + stripe.add(n * 10) + await sleepAsync(timeout) + lock.release() + await lock.acquire() + stripe.add(n * 10 + 1) + await sleepAsync(timeout) + lock.release() + + var lock = newAsyncLock() + var fut1 = task(lock, 1, n1) + var fut2 = task(lock, 2, n2) + var fut3 = task(lock, 3, n3) + await allFutures(fut1, fut2, fut3) + result = stripe + + proc testCancelLock(n1, n2, n3: Duration, + cancelIndex: int): Future[seq[int]] {.async.} = + var stripe: seq[int] + + proc task(lock: AsyncLock, n: int, timeout: Duration) {.async.} = + await lock.acquire() + stripe.add(n * 10) + await sleepAsync(timeout) + lock.release() + await lock.acquire() + stripe.add(n * 10 + 1) + await sleepAsync(timeout) + lock.release() + + var lock = newAsyncLock() + var fut1 = task(lock, 1, n1) + var fut2 = task(lock, 2, n2) + var fut3 = task(lock, 3, n3) + if cancelIndex == 2: + fut2.cancel() + else: + fut3.cancel() + await allFutures(fut1, fut2, fut3) + result = stripe + + proc testEvent(n: int, ev: AsyncEvent) {.async.} = await ev.wait() testEventResult = testEventResult & $n @@ -197,7 +243,21 @@ suite "Asynchronous sync primitives test suite": result = (5 in q and not(6 in q)) test "AsyncLock() behavior test": - check test1() == "0123456789" + check: + test1() == "0123456789" + waitFor(testBehaviorLock(10.milliseconds, + 20.milliseconds, + 50.milliseconds)) == @[10, 20, 30, 11, 21, 31] + waitFor(testBehaviorLock(50.milliseconds, + 20.milliseconds, + 10.milliseconds)) == @[10, 20, 30, 11, 21, 31] + waitFor(testCancelLock(10.milliseconds, + 20.milliseconds, + 50.milliseconds, 2)) == @[10, 30, 11, 31] + waitFor(testCancelLock(50.milliseconds, + 20.milliseconds, + 10.milliseconds, 3)) == @[10, 20, 11, 21] + test "AsyncEvent() behavior test": check test2() == "0123456789" test "AsyncQueue() behavior test":