Allow force dial (#696)

This commit is contained in:
Tanguy 2022-02-24 17:31:47 +01:00 committed by GitHub
parent f98bf612bd
commit c09d032133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 123 additions and 25 deletions

View File

@ -452,7 +452,8 @@ proc trackIncomingConn*(c: ConnManager,
raise exc
proc trackOutgoingConn*(c: ConnManager,
provider: ConnProvider):
provider: ConnProvider,
forceDial = false):
Future[Connection] {.async.} =
## try acquiring a connection if all slots
## are already taken, raise TooManyConnectionsError
@ -462,7 +463,9 @@ proc trackOutgoingConn*(c: ConnManager,
trace "Tracking outgoing connection", count = c.outSema.count,
max = c.outSema.size
if not c.outSema.tryAcquire():
if forceDial:
c.outSema.forceAcquire()
elif not c.outSema.tryAcquire():
trace "Too many outgoing connections!", count = c.outSema.count,
max = c.outSema.size
raise newTooManyConnectionsError()

View File

@ -19,7 +19,8 @@ type
method connect*(
self: Dial,
peerId: PeerId,
addrs: seq[MultiAddress]) {.async, base.} =
addrs: seq[MultiAddress],
forceDial = false) {.async, base.} =
## connect remote peer without negotiating
## a protocol
##
@ -29,7 +30,8 @@ method connect*(
method dial*(
self: Dial,
peerId: PeerId,
protos: seq[string]): Future[Connection] {.async, base.} =
protos: seq[string],
): Future[Connection] {.async, base.} =
## create a protocol stream over an
## existing connection
##
@ -40,7 +42,8 @@ method dial*(
self: Dial,
peerId: PeerId,
addrs: seq[MultiAddress],
protos: seq[string]): Future[Connection] {.async, base.} =
protos: seq[string],
forceDial = false): Future[Connection] {.async, base.} =
## create a protocol stream and establish
## a connection if one doesn't exist already
##

View File

@ -47,7 +47,8 @@ type
proc dialAndUpgrade(
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress]):
addrs: seq[MultiAddress],
forceDial: bool):
Future[Connection] {.async.} =
debug "Dialing peer", peerId
@ -72,7 +73,8 @@ proc dialAndUpgrade(
transportCopy = transport
addressCopy = a
await self.connManager.trackOutgoingConn(
() => transportCopy.dial(hostname, addressCopy)
() => transportCopy.dial(hostname, addressCopy),
forceDial
)
except TooManyConnectionsError as exc:
trace "Connection limit reached!"
@ -112,7 +114,8 @@ proc dialAndUpgrade(
proc internalConnect(
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress]):
addrs: seq[MultiAddress],
forceDial: bool):
Future[Connection] {.async.} =
if self.localPeerId == peerId:
raise newException(CatchableError, "can't dial self!")
@ -136,7 +139,7 @@ proc internalConnect(
trace "Reusing existing connection", conn, direction = $conn.dir
return conn
conn = await self.dialAndUpgrade(peerId, addrs)
conn = await self.dialAndUpgrade(peerId, addrs, forceDial)
if isNil(conn): # None of the addresses connected
raise newException(DialFailedError, "Unable to establish outgoing link")
@ -159,7 +162,8 @@ proc internalConnect(
method connect*(
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress]) {.async.} =
addrs: seq[MultiAddress],
forceDial = false) {.async.} =
## connect remote peer without negotiating
## a protocol
##
@ -167,7 +171,7 @@ method connect*(
if self.connManager.connCount(peerId) > 0:
return
discard await self.internalConnect(peerId, addrs)
discard await self.internalConnect(peerId, addrs, forceDial)
proc negotiateStream(
self: Dialer,
@ -200,7 +204,8 @@ method dial*(
self: Dialer,
peerId: PeerId,
addrs: seq[MultiAddress],
protos: seq[string]): Future[Connection] {.async.} =
protos: seq[string],
forceDial = false): Future[Connection] {.async.} =
## create a protocol stream and establish
## a connection if one doesn't exist already
##
@ -218,7 +223,7 @@ method dial*(
try:
trace "Dialing (new)", peerId, protos
conn = await self.internalConnect(peerId, addrs)
conn = await self.internalConnect(peerId, addrs, forceDial)
trace "Opening stream", conn
stream = await self.connManager.getStream(conn)

View File

@ -99,8 +99,9 @@ proc disconnect*(s: Switch, peerId: PeerId): Future[void] {.gcsafe.} =
method connect*(
s: Switch,
peerId: PeerId,
addrs: seq[MultiAddress]): Future[void] =
s.dialer.connect(peerId, addrs)
addrs: seq[MultiAddress],
forceDial = false): Future[void] =
s.dialer.connect(peerId, addrs, forceDial)
method dial*(
s: Switch,
@ -117,8 +118,9 @@ method dial*(
s: Switch,
peerId: PeerId,
addrs: seq[MultiAddress],
protos: seq[string]): Future[Connection] =
s.dialer.dial(peerId, addrs, protos)
protos: seq[string],
forceDial = false): Future[Connection] =
s.dialer.dial(peerId, addrs, protos, forceDial)
proc dial*(
s: Switch,

View File

@ -54,16 +54,21 @@ proc acquire*(s: AsyncSemaphore): Future[void] =
fut.cancelCallback = nil
if not fut.finished:
s.queue.keepItIf( it != fut )
s.count.inc
fut.cancelCallback = cancellation
s.queue.add(fut)
s.count.dec
trace "Queued slot", available = s.count, queue = s.queue.len
return fut
proc forceAcquire*(s: AsyncSemaphore) =
## ForceAcquire will always succeed,
## creating a temporary slot if required.
## This temporary slot will stay usable until
## there is less `acquire`s than `release`s
s.count.dec
proc release*(s: AsyncSemaphore) =
## Release a resource from the semaphore,
## by picking the first future from the queue
@ -77,13 +82,15 @@ proc release*(s: AsyncSemaphore) =
trace "Releasing slot", available = s.count,
queue = s.queue.len
if s.queue.len > 0:
s.count.inc
while s.queue.len > 0:
var fut = s.queue[0]
s.queue.delete(0)
if not fut.finished():
s.count.dec
fut.complete()
break
s.count.inc # increment the resource count
trace "Released slot", available = s.count,
queue = s.queue.len
return

View File

@ -463,3 +463,33 @@ suite "Connection Manager":
await connMngr.close()
await allFuturesThrowing(
allFutures(conns.mapIt( it.close() )))
asyncTest "allow force dial":
let connMngr = ConnManager.new(maxConnections = 2)
var conns: seq[Connection]
for i in 0..<3:
let conn = connMngr.trackOutgoingConn(
(proc(): Future[Connection] {.async.} =
return Connection.new(
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
Direction.In)
), true
)
check await conn.withTimeout(10.millis)
conns.add(await conn)
# should throw adding a connection over the limit
expect TooManyConnectionsError:
discard await connMngr.trackOutgoingConn(
(proc(): Future[Connection] {.async.} =
return Connection.new(
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
Direction.In)
), false
)
await connMngr.close()
await allFuturesThrowing(
allFutures(conns.mapIt( it.close() )))

View File

@ -36,7 +36,7 @@ suite "AsyncSemaphore":
await sema.acquire()
let fut = sema.acquire()
check sema.count == -1
check sema.count == 0
sema.release()
sema.release()
check sema.count == 1
@ -66,7 +66,7 @@ suite "AsyncSemaphore":
let fut = sema.acquire()
check fut.finished == false
check sema.count == -1
check sema.count == 0
sema.release()
sema.release()
@ -104,12 +104,20 @@ suite "AsyncSemaphore":
await sema.acquire()
let tmp = sema.acquire()
check not tmp.finished()
let
tmp = sema.acquire()
tmp2 = sema.acquire()
check:
not tmp.finished()
not tmp2.finished()
tmp.cancel()
sema.release()
check tmp2.finished()
sema.release()
check await sema.acquire().withTimeout(10.millis)
asyncTest "should handle out of order cancellations":
@ -145,3 +153,43 @@ suite "AsyncSemaphore":
sema.release()
check await sema.acquire().withTimeout(10.millis)
asyncTest "should handle forceAcquire properly":
let sema = newAsyncSemaphore(1)
await sema.acquire()
check not(await sema.acquire().withTimeout(1.millis)) # should not acquire but cancel
let
fut1 = sema.acquire()
fut2 = sema.acquire()
sema.forceAcquire()
sema.release()
await fut1 or fut2 or sleepAsync(1.millis)
check:
fut1.finished()
not fut2.finished()
sema.release()
await fut1 or fut2 or sleepAsync(1.millis)
check:
fut1.finished()
fut2.finished()
sema.forceAcquire()
sema.forceAcquire()
let
fut3 = sema.acquire()
fut4 = sema.acquire()
fut5 = sema.acquire()
sema.release()
sema.release()
await sleepAsync(1.millis)
check:
fut3.finished()
fut4.finished()
not fut5.finished()