Allow force dial (#696)

This commit is contained in:
Tanguy 2022-02-24 17:31:47 +01:00 committed by GitHub
parent f98bf612bd
commit c09d032133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 123 additions and 25 deletions

View File

@ -452,7 +452,8 @@ proc trackIncomingConn*(c: ConnManager,
raise exc raise exc
proc trackOutgoingConn*(c: ConnManager, proc trackOutgoingConn*(c: ConnManager,
provider: ConnProvider): provider: ConnProvider,
forceDial = false):
Future[Connection] {.async.} = Future[Connection] {.async.} =
## try acquiring a connection if all slots ## try acquiring a connection if all slots
## are already taken, raise TooManyConnectionsError ## are already taken, raise TooManyConnectionsError
@ -462,7 +463,9 @@ proc trackOutgoingConn*(c: ConnManager,
trace "Tracking outgoing connection", count = c.outSema.count, trace "Tracking outgoing connection", count = c.outSema.count,
max = c.outSema.size max = c.outSema.size
if not c.outSema.tryAcquire(): if forceDial:
c.outSema.forceAcquire()
elif not c.outSema.tryAcquire():
trace "Too many outgoing connections!", count = c.outSema.count, trace "Too many outgoing connections!", count = c.outSema.count,
max = c.outSema.size max = c.outSema.size
raise newTooManyConnectionsError() raise newTooManyConnectionsError()

View File

@ -19,7 +19,8 @@ type
method connect*( method connect*(
self: Dial, self: Dial,
peerId: PeerId, peerId: PeerId,
addrs: seq[MultiAddress]) {.async, base.} = addrs: seq[MultiAddress],
forceDial = false) {.async, base.} =
## connect remote peer without negotiating ## connect remote peer without negotiating
## a protocol ## a protocol
## ##
@ -29,7 +30,8 @@ method connect*(
method dial*( method dial*(
self: Dial, self: Dial,
peerId: PeerId, peerId: PeerId,
protos: seq[string]): Future[Connection] {.async, base.} = protos: seq[string],
): Future[Connection] {.async, base.} =
## create a protocol stream over an ## create a protocol stream over an
## existing connection ## existing connection
## ##
@ -40,7 +42,8 @@ method dial*(
self: Dial, self: Dial,
peerId: PeerId, peerId: PeerId,
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
protos: seq[string]): Future[Connection] {.async, base.} = protos: seq[string],
forceDial = false): Future[Connection] {.async, base.} =
## create a protocol stream and establish ## create a protocol stream and establish
## a connection if one doesn't exist already ## a connection if one doesn't exist already
## ##

View File

@ -47,7 +47,8 @@ type
proc dialAndUpgrade( proc dialAndUpgrade(
self: Dialer, self: Dialer,
peerId: PeerId, peerId: PeerId,
addrs: seq[MultiAddress]): addrs: seq[MultiAddress],
forceDial: bool):
Future[Connection] {.async.} = Future[Connection] {.async.} =
debug "Dialing peer", peerId debug "Dialing peer", peerId
@ -72,7 +73,8 @@ proc dialAndUpgrade(
transportCopy = transport transportCopy = transport
addressCopy = a addressCopy = a
await self.connManager.trackOutgoingConn( await self.connManager.trackOutgoingConn(
() => transportCopy.dial(hostname, addressCopy) () => transportCopy.dial(hostname, addressCopy),
forceDial
) )
except TooManyConnectionsError as exc: except TooManyConnectionsError as exc:
trace "Connection limit reached!" trace "Connection limit reached!"
@ -112,7 +114,8 @@ proc dialAndUpgrade(
proc internalConnect( proc internalConnect(
self: Dialer, self: Dialer,
peerId: PeerId, peerId: PeerId,
addrs: seq[MultiAddress]): addrs: seq[MultiAddress],
forceDial: bool):
Future[Connection] {.async.} = Future[Connection] {.async.} =
if self.localPeerId == peerId: if self.localPeerId == peerId:
raise newException(CatchableError, "can't dial self!") raise newException(CatchableError, "can't dial self!")
@ -136,7 +139,7 @@ proc internalConnect(
trace "Reusing existing connection", conn, direction = $conn.dir trace "Reusing existing connection", conn, direction = $conn.dir
return conn return conn
conn = await self.dialAndUpgrade(peerId, addrs) conn = await self.dialAndUpgrade(peerId, addrs, forceDial)
if isNil(conn): # None of the addresses connected if isNil(conn): # None of the addresses connected
raise newException(DialFailedError, "Unable to establish outgoing link") raise newException(DialFailedError, "Unable to establish outgoing link")
@ -159,7 +162,8 @@ proc internalConnect(
method connect*( method connect*(
self: Dialer, self: Dialer,
peerId: PeerId, peerId: PeerId,
addrs: seq[MultiAddress]) {.async.} = addrs: seq[MultiAddress],
forceDial = false) {.async.} =
## connect remote peer without negotiating ## connect remote peer without negotiating
## a protocol ## a protocol
## ##
@ -167,7 +171,7 @@ method connect*(
if self.connManager.connCount(peerId) > 0: if self.connManager.connCount(peerId) > 0:
return return
discard await self.internalConnect(peerId, addrs) discard await self.internalConnect(peerId, addrs, forceDial)
proc negotiateStream( proc negotiateStream(
self: Dialer, self: Dialer,
@ -200,7 +204,8 @@ method dial*(
self: Dialer, self: Dialer,
peerId: PeerId, peerId: PeerId,
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
protos: seq[string]): Future[Connection] {.async.} = protos: seq[string],
forceDial = false): Future[Connection] {.async.} =
## create a protocol stream and establish ## create a protocol stream and establish
## a connection if one doesn't exist already ## a connection if one doesn't exist already
## ##
@ -218,7 +223,7 @@ method dial*(
try: try:
trace "Dialing (new)", peerId, protos trace "Dialing (new)", peerId, protos
conn = await self.internalConnect(peerId, addrs) conn = await self.internalConnect(peerId, addrs, forceDial)
trace "Opening stream", conn trace "Opening stream", conn
stream = await self.connManager.getStream(conn) stream = await self.connManager.getStream(conn)

View File

@ -99,8 +99,9 @@ proc disconnect*(s: Switch, peerId: PeerId): Future[void] {.gcsafe.} =
method connect*( method connect*(
s: Switch, s: Switch,
peerId: PeerId, peerId: PeerId,
addrs: seq[MultiAddress]): Future[void] = addrs: seq[MultiAddress],
s.dialer.connect(peerId, addrs) forceDial = false): Future[void] =
s.dialer.connect(peerId, addrs, forceDial)
method dial*( method dial*(
s: Switch, s: Switch,
@ -117,8 +118,9 @@ method dial*(
s: Switch, s: Switch,
peerId: PeerId, peerId: PeerId,
addrs: seq[MultiAddress], addrs: seq[MultiAddress],
protos: seq[string]): Future[Connection] = protos: seq[string],
s.dialer.dial(peerId, addrs, protos) forceDial = false): Future[Connection] =
s.dialer.dial(peerId, addrs, protos, forceDial)
proc dial*( proc dial*(
s: Switch, s: Switch,

View File

@ -54,16 +54,21 @@ proc acquire*(s: AsyncSemaphore): Future[void] =
fut.cancelCallback = nil fut.cancelCallback = nil
if not fut.finished: if not fut.finished:
s.queue.keepItIf( it != fut ) s.queue.keepItIf( it != fut )
s.count.inc
fut.cancelCallback = cancellation fut.cancelCallback = cancellation
s.queue.add(fut) s.queue.add(fut)
s.count.dec
trace "Queued slot", available = s.count, queue = s.queue.len trace "Queued slot", available = s.count, queue = s.queue.len
return fut return fut
proc forceAcquire*(s: AsyncSemaphore) =
## ForceAcquire will always succeed,
## creating a temporary slot if required.
## This temporary slot will stay usable until
## there is less `acquire`s than `release`s
s.count.dec
proc release*(s: AsyncSemaphore) = proc release*(s: AsyncSemaphore) =
## Release a resource from the semaphore, ## Release a resource from the semaphore,
## by picking the first future from the queue ## by picking the first future from the queue
@ -77,13 +82,15 @@ proc release*(s: AsyncSemaphore) =
trace "Releasing slot", available = s.count, trace "Releasing slot", available = s.count,
queue = s.queue.len queue = s.queue.len
if s.queue.len > 0: s.count.inc
while s.queue.len > 0:
var fut = s.queue[0] var fut = s.queue[0]
s.queue.delete(0) s.queue.delete(0)
if not fut.finished(): if not fut.finished():
s.count.dec
fut.complete() fut.complete()
break
s.count.inc # increment the resource count
trace "Released slot", available = s.count, trace "Released slot", available = s.count,
queue = s.queue.len queue = s.queue.len
return return

View File

@ -463,3 +463,33 @@ suite "Connection Manager":
await connMngr.close() await connMngr.close()
await allFuturesThrowing( await allFuturesThrowing(
allFutures(conns.mapIt( it.close() ))) allFutures(conns.mapIt( it.close() )))
asyncTest "allow force dial":
let connMngr = ConnManager.new(maxConnections = 2)
var conns: seq[Connection]
for i in 0..<3:
let conn = connMngr.trackOutgoingConn(
(proc(): Future[Connection] {.async.} =
return Connection.new(
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
Direction.In)
), true
)
check await conn.withTimeout(10.millis)
conns.add(await conn)
# should throw adding a connection over the limit
expect TooManyConnectionsError:
discard await connMngr.trackOutgoingConn(
(proc(): Future[Connection] {.async.} =
return Connection.new(
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
Direction.In)
), false
)
await connMngr.close()
await allFuturesThrowing(
allFutures(conns.mapIt( it.close() )))

View File

@ -36,7 +36,7 @@ suite "AsyncSemaphore":
await sema.acquire() await sema.acquire()
let fut = sema.acquire() let fut = sema.acquire()
check sema.count == -1 check sema.count == 0
sema.release() sema.release()
sema.release() sema.release()
check sema.count == 1 check sema.count == 1
@ -66,7 +66,7 @@ suite "AsyncSemaphore":
let fut = sema.acquire() let fut = sema.acquire()
check fut.finished == false check fut.finished == false
check sema.count == -1 check sema.count == 0
sema.release() sema.release()
sema.release() sema.release()
@ -104,12 +104,20 @@ suite "AsyncSemaphore":
await sema.acquire() await sema.acquire()
let tmp = sema.acquire() let
check not tmp.finished() tmp = sema.acquire()
tmp2 = sema.acquire()
check:
not tmp.finished()
not tmp2.finished()
tmp.cancel() tmp.cancel()
sema.release() sema.release()
check tmp2.finished()
sema.release()
check await sema.acquire().withTimeout(10.millis) check await sema.acquire().withTimeout(10.millis)
asyncTest "should handle out of order cancellations": asyncTest "should handle out of order cancellations":
@ -145,3 +153,43 @@ suite "AsyncSemaphore":
sema.release() sema.release()
check await sema.acquire().withTimeout(10.millis) check await sema.acquire().withTimeout(10.millis)
asyncTest "should handle forceAcquire properly":
let sema = newAsyncSemaphore(1)
await sema.acquire()
check not(await sema.acquire().withTimeout(1.millis)) # should not acquire but cancel
let
fut1 = sema.acquire()
fut2 = sema.acquire()
sema.forceAcquire()
sema.release()
await fut1 or fut2 or sleepAsync(1.millis)
check:
fut1.finished()
not fut2.finished()
sema.release()
await fut1 or fut2 or sleepAsync(1.millis)
check:
fut1.finished()
fut2.finished()
sema.forceAcquire()
sema.forceAcquire()
let
fut3 = sema.acquire()
fut4 = sema.acquire()
fut5 = sema.acquire()
sema.release()
sema.release()
await sleepAsync(1.millis)
check:
fut3.finished()
fut4.finished()
not fut5.finished()