From c09d0321331c3a7116c42444366911b52c0eb72f Mon Sep 17 00:00:00 2001 From: Tanguy Date: Thu, 24 Feb 2022 17:31:47 +0100 Subject: [PATCH] Allow force dial (#696) --- libp2p/connmanager.nim | 7 +++-- libp2p/dial.nim | 9 ++++-- libp2p/dialer.nim | 21 ++++++++------ libp2p/switch.nim | 10 ++++--- libp2p/utils/semaphore.nim | 15 +++++++--- tests/testconnmngr.nim | 30 ++++++++++++++++++++ tests/testsemaphore.nim | 56 +++++++++++++++++++++++++++++++++++--- 7 files changed, 123 insertions(+), 25 deletions(-) diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 3d769e5..f56c7b7 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -452,7 +452,8 @@ proc trackIncomingConn*(c: ConnManager, raise exc proc trackOutgoingConn*(c: ConnManager, - provider: ConnProvider): + provider: ConnProvider, + forceDial = false): Future[Connection] {.async.} = ## try acquiring a connection if all slots ## are already taken, raise TooManyConnectionsError @@ -462,7 +463,9 @@ proc trackOutgoingConn*(c: ConnManager, trace "Tracking outgoing connection", count = c.outSema.count, max = c.outSema.size - if not c.outSema.tryAcquire(): + if forceDial: + c.outSema.forceAcquire() + elif not c.outSema.tryAcquire(): trace "Too many outgoing connections!", count = c.outSema.count, max = c.outSema.size raise newTooManyConnectionsError() diff --git a/libp2p/dial.nim b/libp2p/dial.nim index 7eb3a46..ea51270 100644 --- a/libp2p/dial.nim +++ b/libp2p/dial.nim @@ -19,7 +19,8 @@ type method connect*( self: Dial, peerId: PeerId, - addrs: seq[MultiAddress]) {.async, base.} = + addrs: seq[MultiAddress], + forceDial = false) {.async, base.} = ## connect remote peer without negotiating ## a protocol ## @@ -29,7 +30,8 @@ method connect*( method dial*( self: Dial, peerId: PeerId, - protos: seq[string]): Future[Connection] {.async, base.} = + protos: seq[string], + ): Future[Connection] {.async, base.} = ## create a protocol stream over an ## existing connection ## @@ -40,7 +42,8 @@ method dial*( self: Dial, peerId: PeerId, addrs: seq[MultiAddress], - protos: seq[string]): Future[Connection] {.async, base.} = + protos: seq[string], + forceDial = false): Future[Connection] {.async, base.} = ## create a protocol stream and establish ## a connection if one doesn't exist already ## diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim index 3d02379..65cc1d6 100644 --- a/libp2p/dialer.nim +++ b/libp2p/dialer.nim @@ -47,7 +47,8 @@ type proc dialAndUpgrade( self: Dialer, peerId: PeerId, - addrs: seq[MultiAddress]): + addrs: seq[MultiAddress], + forceDial: bool): Future[Connection] {.async.} = debug "Dialing peer", peerId @@ -72,7 +73,8 @@ proc dialAndUpgrade( transportCopy = transport addressCopy = a await self.connManager.trackOutgoingConn( - () => transportCopy.dial(hostname, addressCopy) + () => transportCopy.dial(hostname, addressCopy), + forceDial ) except TooManyConnectionsError as exc: trace "Connection limit reached!" @@ -112,7 +114,8 @@ proc dialAndUpgrade( proc internalConnect( self: Dialer, peerId: PeerId, - addrs: seq[MultiAddress]): + addrs: seq[MultiAddress], + forceDial: bool): Future[Connection] {.async.} = if self.localPeerId == peerId: raise newException(CatchableError, "can't dial self!") @@ -136,7 +139,7 @@ proc internalConnect( trace "Reusing existing connection", conn, direction = $conn.dir return conn - conn = await self.dialAndUpgrade(peerId, addrs) + conn = await self.dialAndUpgrade(peerId, addrs, forceDial) if isNil(conn): # None of the addresses connected raise newException(DialFailedError, "Unable to establish outgoing link") @@ -159,7 +162,8 @@ proc internalConnect( method connect*( self: Dialer, peerId: PeerId, - addrs: seq[MultiAddress]) {.async.} = + addrs: seq[MultiAddress], + forceDial = false) {.async.} = ## connect remote peer without negotiating ## a protocol ## @@ -167,7 +171,7 @@ method connect*( if self.connManager.connCount(peerId) > 0: return - discard await self.internalConnect(peerId, addrs) + discard await self.internalConnect(peerId, addrs, forceDial) proc negotiateStream( self: Dialer, @@ -200,7 +204,8 @@ method dial*( self: Dialer, peerId: PeerId, addrs: seq[MultiAddress], - protos: seq[string]): Future[Connection] {.async.} = + protos: seq[string], + forceDial = false): Future[Connection] {.async.} = ## create a protocol stream and establish ## a connection if one doesn't exist already ## @@ -218,7 +223,7 @@ method dial*( try: trace "Dialing (new)", peerId, protos - conn = await self.internalConnect(peerId, addrs) + conn = await self.internalConnect(peerId, addrs, forceDial) trace "Opening stream", conn stream = await self.connManager.getStream(conn) diff --git a/libp2p/switch.nim b/libp2p/switch.nim index c7a0a5b..da68e86 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -99,8 +99,9 @@ proc disconnect*(s: Switch, peerId: PeerId): Future[void] {.gcsafe.} = method connect*( s: Switch, peerId: PeerId, - addrs: seq[MultiAddress]): Future[void] = - s.dialer.connect(peerId, addrs) + addrs: seq[MultiAddress], + forceDial = false): Future[void] = + s.dialer.connect(peerId, addrs, forceDial) method dial*( s: Switch, @@ -117,8 +118,9 @@ method dial*( s: Switch, peerId: PeerId, addrs: seq[MultiAddress], - protos: seq[string]): Future[Connection] = - s.dialer.dial(peerId, addrs, protos) + protos: seq[string], + forceDial = false): Future[Connection] = + s.dialer.dial(peerId, addrs, protos, forceDial) proc dial*( s: Switch, diff --git a/libp2p/utils/semaphore.nim b/libp2p/utils/semaphore.nim index 8ded05e..e396f27 100644 --- a/libp2p/utils/semaphore.nim +++ b/libp2p/utils/semaphore.nim @@ -54,16 +54,21 @@ proc acquire*(s: AsyncSemaphore): Future[void] = fut.cancelCallback = nil if not fut.finished: s.queue.keepItIf( it != fut ) - s.count.inc fut.cancelCallback = cancellation s.queue.add(fut) - s.count.dec trace "Queued slot", available = s.count, queue = s.queue.len return fut +proc forceAcquire*(s: AsyncSemaphore) = + ## ForceAcquire will always succeed, + ## creating a temporary slot if required. + ## This temporary slot will stay usable until + ## there is less `acquire`s than `release`s + s.count.dec + proc release*(s: AsyncSemaphore) = ## Release a resource from the semaphore, ## by picking the first future from the queue @@ -77,13 +82,15 @@ proc release*(s: AsyncSemaphore) = trace "Releasing slot", available = s.count, queue = s.queue.len - if s.queue.len > 0: + s.count.inc + while s.queue.len > 0: var fut = s.queue[0] s.queue.delete(0) if not fut.finished(): + s.count.dec fut.complete() + break - s.count.inc # increment the resource count trace "Released slot", available = s.count, queue = s.queue.len return diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index 8119265..6aaa4ed 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -463,3 +463,33 @@ suite "Connection Manager": await connMngr.close() await allFuturesThrowing( allFutures(conns.mapIt( it.close() ))) + + asyncTest "allow force dial": + let connMngr = ConnManager.new(maxConnections = 2) + + var conns: seq[Connection] + for i in 0..<3: + let conn = connMngr.trackOutgoingConn( + (proc(): Future[Connection] {.async.} = + return Connection.new( + PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), + Direction.In) + ), true + ) + + check await conn.withTimeout(10.millis) + conns.add(await conn) + + # should throw adding a connection over the limit + expect TooManyConnectionsError: + discard await connMngr.trackOutgoingConn( + (proc(): Future[Connection] {.async.} = + return Connection.new( + PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), + Direction.In) + ), false + ) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) diff --git a/tests/testsemaphore.nim b/tests/testsemaphore.nim index c7a26d8..09a7d29 100644 --- a/tests/testsemaphore.nim +++ b/tests/testsemaphore.nim @@ -36,7 +36,7 @@ suite "AsyncSemaphore": await sema.acquire() let fut = sema.acquire() - check sema.count == -1 + check sema.count == 0 sema.release() sema.release() check sema.count == 1 @@ -66,7 +66,7 @@ suite "AsyncSemaphore": let fut = sema.acquire() check fut.finished == false - check sema.count == -1 + check sema.count == 0 sema.release() sema.release() @@ -104,12 +104,20 @@ suite "AsyncSemaphore": await sema.acquire() - let tmp = sema.acquire() - check not tmp.finished() + let + tmp = sema.acquire() + tmp2 = sema.acquire() + check: + not tmp.finished() + not tmp2.finished() tmp.cancel() sema.release() + check tmp2.finished() + + sema.release() + check await sema.acquire().withTimeout(10.millis) asyncTest "should handle out of order cancellations": @@ -145,3 +153,43 @@ suite "AsyncSemaphore": sema.release() check await sema.acquire().withTimeout(10.millis) + + asyncTest "should handle forceAcquire properly": + let sema = newAsyncSemaphore(1) + + await sema.acquire() + check not(await sema.acquire().withTimeout(1.millis)) # should not acquire but cancel + + let + fut1 = sema.acquire() + fut2 = sema.acquire() + + sema.forceAcquire() + sema.release() + + await fut1 or fut2 or sleepAsync(1.millis) + check: + fut1.finished() + not fut2.finished() + + sema.release() + await fut1 or fut2 or sleepAsync(1.millis) + check: + fut1.finished() + fut2.finished() + + + sema.forceAcquire() + sema.forceAcquire() + + let + fut3 = sema.acquire() + fut4 = sema.acquire() + fut5 = sema.acquire() + sema.release() + sema.release() + await sleepAsync(1.millis) + check: + fut3.finished() + fut4.finished() + not fut5.finished()