mirror of
https://github.com/codex-storage/nim-libp2p.git
synced 2025-01-23 17:29:58 +00:00
Allow force dial (#696)
This commit is contained in:
parent
f98bf612bd
commit
c09d032133
@ -452,7 +452,8 @@ proc trackIncomingConn*(c: ConnManager,
|
||||
raise exc
|
||||
|
||||
proc trackOutgoingConn*(c: ConnManager,
|
||||
provider: ConnProvider):
|
||||
provider: ConnProvider,
|
||||
forceDial = false):
|
||||
Future[Connection] {.async.} =
|
||||
## try acquiring a connection if all slots
|
||||
## are already taken, raise TooManyConnectionsError
|
||||
@ -462,7 +463,9 @@ proc trackOutgoingConn*(c: ConnManager,
|
||||
trace "Tracking outgoing connection", count = c.outSema.count,
|
||||
max = c.outSema.size
|
||||
|
||||
if not c.outSema.tryAcquire():
|
||||
if forceDial:
|
||||
c.outSema.forceAcquire()
|
||||
elif not c.outSema.tryAcquire():
|
||||
trace "Too many outgoing connections!", count = c.outSema.count,
|
||||
max = c.outSema.size
|
||||
raise newTooManyConnectionsError()
|
||||
|
@ -19,7 +19,8 @@ type
|
||||
method connect*(
|
||||
self: Dial,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress]) {.async, base.} =
|
||||
addrs: seq[MultiAddress],
|
||||
forceDial = false) {.async, base.} =
|
||||
## connect remote peer without negotiating
|
||||
## a protocol
|
||||
##
|
||||
@ -29,7 +30,8 @@ method connect*(
|
||||
method dial*(
|
||||
self: Dial,
|
||||
peerId: PeerId,
|
||||
protos: seq[string]): Future[Connection] {.async, base.} =
|
||||
protos: seq[string],
|
||||
): Future[Connection] {.async, base.} =
|
||||
## create a protocol stream over an
|
||||
## existing connection
|
||||
##
|
||||
@ -40,7 +42,8 @@ method dial*(
|
||||
self: Dial,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress],
|
||||
protos: seq[string]): Future[Connection] {.async, base.} =
|
||||
protos: seq[string],
|
||||
forceDial = false): Future[Connection] {.async, base.} =
|
||||
## create a protocol stream and establish
|
||||
## a connection if one doesn't exist already
|
||||
##
|
||||
|
@ -47,7 +47,8 @@ type
|
||||
proc dialAndUpgrade(
|
||||
self: Dialer,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress]):
|
||||
addrs: seq[MultiAddress],
|
||||
forceDial: bool):
|
||||
Future[Connection] {.async.} =
|
||||
debug "Dialing peer", peerId
|
||||
|
||||
@ -72,7 +73,8 @@ proc dialAndUpgrade(
|
||||
transportCopy = transport
|
||||
addressCopy = a
|
||||
await self.connManager.trackOutgoingConn(
|
||||
() => transportCopy.dial(hostname, addressCopy)
|
||||
() => transportCopy.dial(hostname, addressCopy),
|
||||
forceDial
|
||||
)
|
||||
except TooManyConnectionsError as exc:
|
||||
trace "Connection limit reached!"
|
||||
@ -112,7 +114,8 @@ proc dialAndUpgrade(
|
||||
proc internalConnect(
|
||||
self: Dialer,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress]):
|
||||
addrs: seq[MultiAddress],
|
||||
forceDial: bool):
|
||||
Future[Connection] {.async.} =
|
||||
if self.localPeerId == peerId:
|
||||
raise newException(CatchableError, "can't dial self!")
|
||||
@ -136,7 +139,7 @@ proc internalConnect(
|
||||
trace "Reusing existing connection", conn, direction = $conn.dir
|
||||
return conn
|
||||
|
||||
conn = await self.dialAndUpgrade(peerId, addrs)
|
||||
conn = await self.dialAndUpgrade(peerId, addrs, forceDial)
|
||||
if isNil(conn): # None of the addresses connected
|
||||
raise newException(DialFailedError, "Unable to establish outgoing link")
|
||||
|
||||
@ -159,7 +162,8 @@ proc internalConnect(
|
||||
method connect*(
|
||||
self: Dialer,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress]) {.async.} =
|
||||
addrs: seq[MultiAddress],
|
||||
forceDial = false) {.async.} =
|
||||
## connect remote peer without negotiating
|
||||
## a protocol
|
||||
##
|
||||
@ -167,7 +171,7 @@ method connect*(
|
||||
if self.connManager.connCount(peerId) > 0:
|
||||
return
|
||||
|
||||
discard await self.internalConnect(peerId, addrs)
|
||||
discard await self.internalConnect(peerId, addrs, forceDial)
|
||||
|
||||
proc negotiateStream(
|
||||
self: Dialer,
|
||||
@ -200,7 +204,8 @@ method dial*(
|
||||
self: Dialer,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress],
|
||||
protos: seq[string]): Future[Connection] {.async.} =
|
||||
protos: seq[string],
|
||||
forceDial = false): Future[Connection] {.async.} =
|
||||
## create a protocol stream and establish
|
||||
## a connection if one doesn't exist already
|
||||
##
|
||||
@ -218,7 +223,7 @@ method dial*(
|
||||
|
||||
try:
|
||||
trace "Dialing (new)", peerId, protos
|
||||
conn = await self.internalConnect(peerId, addrs)
|
||||
conn = await self.internalConnect(peerId, addrs, forceDial)
|
||||
trace "Opening stream", conn
|
||||
stream = await self.connManager.getStream(conn)
|
||||
|
||||
|
@ -99,8 +99,9 @@ proc disconnect*(s: Switch, peerId: PeerId): Future[void] {.gcsafe.} =
|
||||
method connect*(
|
||||
s: Switch,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress]): Future[void] =
|
||||
s.dialer.connect(peerId, addrs)
|
||||
addrs: seq[MultiAddress],
|
||||
forceDial = false): Future[void] =
|
||||
s.dialer.connect(peerId, addrs, forceDial)
|
||||
|
||||
method dial*(
|
||||
s: Switch,
|
||||
@ -117,8 +118,9 @@ method dial*(
|
||||
s: Switch,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress],
|
||||
protos: seq[string]): Future[Connection] =
|
||||
s.dialer.dial(peerId, addrs, protos)
|
||||
protos: seq[string],
|
||||
forceDial = false): Future[Connection] =
|
||||
s.dialer.dial(peerId, addrs, protos, forceDial)
|
||||
|
||||
proc dial*(
|
||||
s: Switch,
|
||||
|
@ -54,16 +54,21 @@ proc acquire*(s: AsyncSemaphore): Future[void] =
|
||||
fut.cancelCallback = nil
|
||||
if not fut.finished:
|
||||
s.queue.keepItIf( it != fut )
|
||||
s.count.inc
|
||||
|
||||
fut.cancelCallback = cancellation
|
||||
|
||||
s.queue.add(fut)
|
||||
s.count.dec
|
||||
|
||||
trace "Queued slot", available = s.count, queue = s.queue.len
|
||||
return fut
|
||||
|
||||
proc forceAcquire*(s: AsyncSemaphore) =
|
||||
## ForceAcquire will always succeed,
|
||||
## creating a temporary slot if required.
|
||||
## This temporary slot will stay usable until
|
||||
## there is less `acquire`s than `release`s
|
||||
s.count.dec
|
||||
|
||||
proc release*(s: AsyncSemaphore) =
|
||||
## Release a resource from the semaphore,
|
||||
## by picking the first future from the queue
|
||||
@ -77,13 +82,15 @@ proc release*(s: AsyncSemaphore) =
|
||||
trace "Releasing slot", available = s.count,
|
||||
queue = s.queue.len
|
||||
|
||||
if s.queue.len > 0:
|
||||
s.count.inc
|
||||
while s.queue.len > 0:
|
||||
var fut = s.queue[0]
|
||||
s.queue.delete(0)
|
||||
if not fut.finished():
|
||||
s.count.dec
|
||||
fut.complete()
|
||||
break
|
||||
|
||||
s.count.inc # increment the resource count
|
||||
trace "Released slot", available = s.count,
|
||||
queue = s.queue.len
|
||||
return
|
||||
|
@ -463,3 +463,33 @@ suite "Connection Manager":
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "allow force dial":
|
||||
let connMngr = ConnManager.new(maxConnections = 2)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = connMngr.trackOutgoingConn(
|
||||
(proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
), true
|
||||
)
|
||||
|
||||
check await conn.withTimeout(10.millis)
|
||||
conns.add(await conn)
|
||||
|
||||
# should throw adding a connection over the limit
|
||||
expect TooManyConnectionsError:
|
||||
discard await connMngr.trackOutgoingConn(
|
||||
(proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
), false
|
||||
)
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
@ -36,7 +36,7 @@ suite "AsyncSemaphore":
|
||||
await sema.acquire()
|
||||
let fut = sema.acquire()
|
||||
|
||||
check sema.count == -1
|
||||
check sema.count == 0
|
||||
sema.release()
|
||||
sema.release()
|
||||
check sema.count == 1
|
||||
@ -66,7 +66,7 @@ suite "AsyncSemaphore":
|
||||
|
||||
let fut = sema.acquire()
|
||||
check fut.finished == false
|
||||
check sema.count == -1
|
||||
check sema.count == 0
|
||||
|
||||
sema.release()
|
||||
sema.release()
|
||||
@ -104,12 +104,20 @@ suite "AsyncSemaphore":
|
||||
|
||||
await sema.acquire()
|
||||
|
||||
let tmp = sema.acquire()
|
||||
check not tmp.finished()
|
||||
let
|
||||
tmp = sema.acquire()
|
||||
tmp2 = sema.acquire()
|
||||
check:
|
||||
not tmp.finished()
|
||||
not tmp2.finished()
|
||||
|
||||
tmp.cancel()
|
||||
sema.release()
|
||||
|
||||
check tmp2.finished()
|
||||
|
||||
sema.release()
|
||||
|
||||
check await sema.acquire().withTimeout(10.millis)
|
||||
|
||||
asyncTest "should handle out of order cancellations":
|
||||
@ -145,3 +153,43 @@ suite "AsyncSemaphore":
|
||||
sema.release()
|
||||
|
||||
check await sema.acquire().withTimeout(10.millis)
|
||||
|
||||
asyncTest "should handle forceAcquire properly":
|
||||
let sema = newAsyncSemaphore(1)
|
||||
|
||||
await sema.acquire()
|
||||
check not(await sema.acquire().withTimeout(1.millis)) # should not acquire but cancel
|
||||
|
||||
let
|
||||
fut1 = sema.acquire()
|
||||
fut2 = sema.acquire()
|
||||
|
||||
sema.forceAcquire()
|
||||
sema.release()
|
||||
|
||||
await fut1 or fut2 or sleepAsync(1.millis)
|
||||
check:
|
||||
fut1.finished()
|
||||
not fut2.finished()
|
||||
|
||||
sema.release()
|
||||
await fut1 or fut2 or sleepAsync(1.millis)
|
||||
check:
|
||||
fut1.finished()
|
||||
fut2.finished()
|
||||
|
||||
|
||||
sema.forceAcquire()
|
||||
sema.forceAcquire()
|
||||
|
||||
let
|
||||
fut3 = sema.acquire()
|
||||
fut4 = sema.acquire()
|
||||
fut5 = sema.acquire()
|
||||
sema.release()
|
||||
sema.release()
|
||||
await sleepAsync(1.millis)
|
||||
check:
|
||||
fut3.finished()
|
||||
fut4.finished()
|
||||
not fut5.finished()
|
||||
|
Loading…
x
Reference in New Issue
Block a user