diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 76c98fdf4..89c434d27 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -451,7 +451,7 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises ## switch accept loop, ran for every transport ## - let upgrades = AsyncSemaphore.init(ConcurrentUpgrades) + let upgrades = newAsyncSemaphore(ConcurrentUpgrades) while transport.running: var conn: Connection try: diff --git a/libp2p/utils/semaphore.nim b/libp2p/utils/semaphore.nim index 72ed68b75..301a3e517 100644 --- a/libp2p/utils/semaphore.nim +++ b/libp2p/utils/semaphore.nim @@ -18,11 +18,13 @@ logScope: type AsyncSemaphore* = ref object of RootObj size*: int - count*: int - queue*: seq[Future[void]] + count: int + queue: seq[Future[void]] -proc init*(T: type AsyncSemaphore, size: int): T = - T(size: size, count: size) +proc newAsyncSemaphore*(size: int): AsyncSemaphore = + AsyncSemaphore(size: size, count: size) + +proc `count`*(s: AsyncSemaphore): int = s.count proc tryAcquire*(s: AsyncSemaphore): bool = ## Attempts to acquire a resource, if successful @@ -38,7 +40,7 @@ proc acquire*(s: AsyncSemaphore): Future[void] = ## Acquire a resource and decrement the resource ## counter. If no more resources are available, ## the returned future will not complete until - ## the resource count goes above 0 again. + ## the resource count goes above 0. ## let fut = newFuture[void]("AsyncSemaphore.acquire") @@ -46,8 +48,17 @@ proc acquire*(s: AsyncSemaphore): Future[void] = fut.complete() return fut + proc cancellation(udata: pointer) {.gcsafe.} = + fut.cancelCallback = nil + if not fut.finished: + s.queue.keepItIf( it != fut ) + s.count.inc + + fut.cancelCallback = cancellation + s.queue.add(fut) s.count.dec + trace "Queued slot", available = s.count, queue = s.queue.len return fut @@ -65,7 +76,8 @@ proc release*(s: AsyncSemaphore) = queue = s.queue.len if s.queue.len > 0: - var fut = s.queue.pop() + var fut = s.queue[0] + s.queue.delete(0) if not fut.finished(): fut.complete() diff --git a/tests/testsemaphore.nim b/tests/testsemaphore.nim index 39134b69e..c7a26d86d 100644 --- a/tests/testsemaphore.nim +++ b/tests/testsemaphore.nim @@ -1,5 +1,6 @@ import random import chronos + import ../libp2p/utils/semaphore import ./helpers @@ -8,7 +9,7 @@ randomize() suite "AsyncSemaphore": asyncTest "should acquire": - let sema = AsyncSemaphore.init(3) + let sema = newAsyncSemaphore(3) await sema.acquire() await sema.acquire() @@ -17,7 +18,7 @@ suite "AsyncSemaphore": check sema.count == 0 asyncTest "should release": - let sema = AsyncSemaphore.init(3) + let sema = newAsyncSemaphore(3) await sema.acquire() await sema.acquire() @@ -30,13 +31,12 @@ suite "AsyncSemaphore": check sema.count == 3 asyncTest "should queue acquire": - let sema = AsyncSemaphore.init(1) + let sema = newAsyncSemaphore(1) await sema.acquire() let fut = sema.acquire() check sema.count == -1 - check sema.queue.len == 1 sema.release() sema.release() check sema.count == 1 @@ -45,19 +45,19 @@ suite "AsyncSemaphore": check fut.finished() asyncTest "should keep count == size": - let sema = AsyncSemaphore.init(1) + let sema = newAsyncSemaphore(1) sema.release() sema.release() sema.release() check sema.count == 1 asyncTest "should tryAcquire": - let sema = AsyncSemaphore.init(1) + let sema = newAsyncSemaphore(1) await sema.acquire() check sema.tryAcquire() == false asyncTest "should tryAcquire and acquire": - let sema = AsyncSemaphore.init(4) + let sema = newAsyncSemaphore(4) check sema.tryAcquire() == true check sema.tryAcquire() == true check sema.tryAcquire() == true @@ -67,8 +67,6 @@ suite "AsyncSemaphore": let fut = sema.acquire() check fut.finished == false check sema.count == -1 - # queue is only used when count is < 0 - check sema.queue.len == 1 sema.release() sema.release() @@ -78,10 +76,9 @@ suite "AsyncSemaphore": check fut.finished == true check sema.count == 4 - check sema.queue.len == 0 asyncTest "should restrict resource access": - let sema = AsyncSemaphore.init(3) + let sema = newAsyncSemaphore(3) var resource = 0 proc task() {.async.} = @@ -101,3 +98,50 @@ suite "AsyncSemaphore": tasks.add(task()) await allFutures(tasks) + + asyncTest "should cancel sequential semaphore slot": + let sema = newAsyncSemaphore(1) + + await sema.acquire() + + let tmp = sema.acquire() + check not tmp.finished() + + tmp.cancel() + sema.release() + + check await sema.acquire().withTimeout(10.millis) + + asyncTest "should handle out of order cancellations": + let sema = newAsyncSemaphore(1) + + await sema.acquire() # 1st acquire + let tmp1 = sema.acquire() # 2nd acquire + check not tmp1.finished() + + let tmp2 = sema.acquire() # 3rd acquire + check not tmp2.finished() + + let tmp3 = sema.acquire() # 4th acquire + check not tmp3.finished() + + # up to this point, we've called acquire 4 times + tmp1.cancel() # 1st release (implicit) + tmp2.cancel() # 2nd release (implicit) + + check not tmp3.finished() # check that we didn't release the wrong slot + + sema.release() # 3rd release (explicit) + check tmp3.finished() + + sema.release() # 4th release + check await sema.acquire().withTimeout(10.millis) + + asyncTest "should properly handle timeouts and cancellations": + let sema = newAsyncSemaphore(1) + + await sema.acquire() + check not(await sema.acquire().withTimeout(1.millis)) # should not acquire but cancel + sema.release() + + check await sema.acquire().withTimeout(10.millis)