Fix AsyncLock.locked flag to be consistent. (#129)

* Fix `locked` flag to be more consistent.
Refactor AsyncLock to not use `result`.
Add test for `locked` flag.

* Fixes.

* Fix imports.

* Fix multiple release() without scheduler.
Add more tests.

* Fix review comments.
This commit is contained in:
Eugene Kabanov 2020-09-10 23:28:20 +03:00 committed by GitHub
parent 483054cda6
commit 2134980744
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 109 additions and 29 deletions

View File

@ -9,8 +9,8 @@
# MIT license (LICENSE-MIT)
## This module implements some core synchronization primitives
import std/sequtils
import asyncloop, deques
import std/[sequtils, deques]
import ./asyncloop
type
AsyncLock* = ref object of RootRef
@ -23,6 +23,7 @@ type
## ``release()`` call resets the state to unlocked; first coroutine which
## is blocked in ``acquire()`` is being processed.
locked: bool
acquired: bool
waiters: seq[Future[void]]
AsyncEvent* = ref object of RootRef
@ -71,28 +72,29 @@ proc newAsyncLock*(): AsyncLock =
# Workaround for callSoon() not worked correctly before
# getGlobalDispatcher() call.
discard getGlobalDispatcher()
result = new AsyncLock
result.waiters = newSeq[Future[void]]()
result.locked = false
AsyncLock(waiters: newSeq[Future[void]](), locked: false, acquired: false)
proc wakeUpFirst(lock: AsyncLock) {.inline.} =
proc wakeUpFirst(lock: AsyncLock): bool {.inline.} =
## Wake up the first waiter if it isn't done.
for fut in lock.waiters.mitems():
if not(fut.finished()):
fut.complete()
var i = 0
var res = false
while i < len(lock.waiters):
var waiter = lock.waiters[i]
inc(i)
if not(waiter.finished()):
waiter.complete()
res = true
break
if i > 0:
lock.waiters.delete(0, i - 1)
res
proc checkAll(lock: AsyncLock): bool {.inline.} =
## Returns ``true`` if waiters array is empty or full of cancelled futures.
result = true
for fut in lock.waiters.mitems():
if not(fut.cancelled()):
result = false
break
proc removeWaiter(lock: AsyncLock, waiter: Future[void]) {.inline.} =
## Removes ``waiter`` from list of waiters in ``lock``.
lock.waiters.delete(lock.waiters.find(waiter))
return false
return true
proc acquire*(lock: AsyncLock) {.async.} =
## Acquire a lock ``lock``.
@ -100,24 +102,18 @@ proc acquire*(lock: AsyncLock) {.async.} =
## This procedure blocks until the lock ``lock`` is unlocked, then sets it
## to locked and returns.
if not(lock.locked) and lock.checkAll():
lock.acquired = true
lock.locked = true
else:
var w = newFuture[void]("AsyncLock.acquire")
lock.waiters.add(w)
try:
try:
await w
finally:
lock.removeWaiter(w)
except CancelledError:
if not(lock.locked):
lock.wakeUpFirst()
raise
await w
lock.acquired = true
lock.locked = true
proc locked*(lock: AsyncLock): bool =
## Return `true` if the lock ``lock`` is acquired, `false` otherwise.
result = lock.locked
lock.locked
proc release*(lock: AsyncLock) =
## Release a lock ``lock``.
@ -126,8 +122,15 @@ proc release*(lock: AsyncLock) =
## other coroutines are blocked waiting for the lock to become unlocked,
## allow exactly one of them to proceed.
if lock.locked:
lock.locked = false
lock.wakeUpFirst()
# We set ``lock.locked`` to ``false`` only when there no active waiters.
# If active waiters are present, then ``lock.locked`` will be set to `true`
# in ``acquire()`` procedure's continuation.
if not(lock.acquired):
raise newException(AsyncLockError, "AsyncLock was already released!")
else:
lock.acquired = false
if not(lock.wakeUpFirst()):
lock.locked = false
else:
raise newException(AsyncLockError, "AsyncLock is not acquired!")

View File

@ -41,6 +41,75 @@ suite "Asynchronous sync primitives test suite":
poll()
result = testLockResult
proc testFlag(): Future[bool] {.async.} =
var lock = newAsyncLock()
var futs: array[4, Future[void]]
futs[0] = lock.acquire()
futs[1] = lock.acquire()
futs[2] = lock.acquire()
futs[3] = lock.acquire()
proc checkFlags(b0, b1, b2, b3, b4: bool): bool =
(lock.locked == b0) and
(futs[0].finished == b1) and (futs[1].finished == b2) and
(futs[2].finished == b3) and (futs[3].finished == b4)
if not(checkFlags(true, true, false, false ,false)):
return false
lock.release()
if not(checkFlags(true, true, false, false, false)):
return false
await sleepAsync(10.milliseconds)
if not(checkFlags(true, true, true, false, false)):
return false
lock.release()
if not(checkFlags(true, true, true, false, false)):
return false
await sleepAsync(10.milliseconds)
if not(checkFlags(true, true, true, true, false)):
return false
lock.release()
if not(checkFlags(true, true, true, true, false)):
return false
await sleepAsync(10.milliseconds)
if not(checkFlags(true, true, true, true, true)):
return false
lock.release()
if not(checkFlags(false, true, true, true, true)):
return false
await sleepAsync(10.milliseconds)
if not(checkFlags(false, true, true, true, true)):
return false
return true
proc testNoAcquiredRelease(): Future[bool] {.async.} =
var lock = newAsyncLock()
var res = false
try:
lock.release()
except AsyncLockError:
res = true
return res
proc testDoubleRelease(): Future[bool] {.async.} =
var lock = newAsyncLock()
var fut0 = lock.acquire()
var fut1 = lock.acquire()
var res = false
asyncSpawn fut0
asyncSpawn fut1
lock.release()
try:
lock.release()
except AsyncLockError:
res = true
return res
proc testBehaviorLock(n1, n2, n3: Duration): Future[seq[int]] {.async.} =
var stripe: seq[int]
@ -70,6 +139,7 @@ suite "Asynchronous sync primitives test suite":
stripe.add(n * 10)
await sleepAsync(timeout)
lock.release()
await lock.acquire()
stripe.add(n * 10 + 1)
await sleepAsync(timeout)
@ -252,13 +322,20 @@ suite "Asynchronous sync primitives test suite":
waitFor(testBehaviorLock(50.milliseconds,
20.milliseconds,
10.milliseconds)) == @[10, 20, 30, 11, 21, 31]
test "AsyncLock() cancellation test":
check:
waitFor(testCancelLock(10.milliseconds,
20.milliseconds,
50.milliseconds, 2)) == @[10, 30, 11, 31]
waitFor(testCancelLock(50.milliseconds,
20.milliseconds,
10.milliseconds, 3)) == @[10, 20, 11, 21]
test "AsyncLock() flag consistency test":
check waitFor(testFlag()) == true
test "AsyncLock() double release test":
check waitFor(testDoubleRelease()) == true
test "AsyncLock() non-acquired release test":
check waitFor(testNoAcquiredRelease()) == true
test "AsyncEvent() behavior test":
check test2() == "0123456789"
test "AsyncQueue() behavior test":