mirror of
https://github.com/vacp2p/nim-libp2p-experimental.git
synced 2025-01-27 02:25:21 +00:00
Allow force dial (#696)
This commit is contained in:
parent
f98bf612bd
commit
c09d032133
@ -452,7 +452,8 @@ proc trackIncomingConn*(c: ConnManager,
|
|||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
proc trackOutgoingConn*(c: ConnManager,
|
proc trackOutgoingConn*(c: ConnManager,
|
||||||
provider: ConnProvider):
|
provider: ConnProvider,
|
||||||
|
forceDial = false):
|
||||||
Future[Connection] {.async.} =
|
Future[Connection] {.async.} =
|
||||||
## try acquiring a connection if all slots
|
## try acquiring a connection if all slots
|
||||||
## are already taken, raise TooManyConnectionsError
|
## are already taken, raise TooManyConnectionsError
|
||||||
@ -462,7 +463,9 @@ proc trackOutgoingConn*(c: ConnManager,
|
|||||||
trace "Tracking outgoing connection", count = c.outSema.count,
|
trace "Tracking outgoing connection", count = c.outSema.count,
|
||||||
max = c.outSema.size
|
max = c.outSema.size
|
||||||
|
|
||||||
if not c.outSema.tryAcquire():
|
if forceDial:
|
||||||
|
c.outSema.forceAcquire()
|
||||||
|
elif not c.outSema.tryAcquire():
|
||||||
trace "Too many outgoing connections!", count = c.outSema.count,
|
trace "Too many outgoing connections!", count = c.outSema.count,
|
||||||
max = c.outSema.size
|
max = c.outSema.size
|
||||||
raise newTooManyConnectionsError()
|
raise newTooManyConnectionsError()
|
||||||
|
@ -19,7 +19,8 @@ type
|
|||||||
method connect*(
|
method connect*(
|
||||||
self: Dial,
|
self: Dial,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
addrs: seq[MultiAddress]) {.async, base.} =
|
addrs: seq[MultiAddress],
|
||||||
|
forceDial = false) {.async, base.} =
|
||||||
## connect remote peer without negotiating
|
## connect remote peer without negotiating
|
||||||
## a protocol
|
## a protocol
|
||||||
##
|
##
|
||||||
@ -29,7 +30,8 @@ method connect*(
|
|||||||
method dial*(
|
method dial*(
|
||||||
self: Dial,
|
self: Dial,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
protos: seq[string]): Future[Connection] {.async, base.} =
|
protos: seq[string],
|
||||||
|
): Future[Connection] {.async, base.} =
|
||||||
## create a protocol stream over an
|
## create a protocol stream over an
|
||||||
## existing connection
|
## existing connection
|
||||||
##
|
##
|
||||||
@ -40,7 +42,8 @@ method dial*(
|
|||||||
self: Dial,
|
self: Dial,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
addrs: seq[MultiAddress],
|
addrs: seq[MultiAddress],
|
||||||
protos: seq[string]): Future[Connection] {.async, base.} =
|
protos: seq[string],
|
||||||
|
forceDial = false): Future[Connection] {.async, base.} =
|
||||||
## create a protocol stream and establish
|
## create a protocol stream and establish
|
||||||
## a connection if one doesn't exist already
|
## a connection if one doesn't exist already
|
||||||
##
|
##
|
||||||
|
@ -47,7 +47,8 @@ type
|
|||||||
proc dialAndUpgrade(
|
proc dialAndUpgrade(
|
||||||
self: Dialer,
|
self: Dialer,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
addrs: seq[MultiAddress]):
|
addrs: seq[MultiAddress],
|
||||||
|
forceDial: bool):
|
||||||
Future[Connection] {.async.} =
|
Future[Connection] {.async.} =
|
||||||
debug "Dialing peer", peerId
|
debug "Dialing peer", peerId
|
||||||
|
|
||||||
@ -72,7 +73,8 @@ proc dialAndUpgrade(
|
|||||||
transportCopy = transport
|
transportCopy = transport
|
||||||
addressCopy = a
|
addressCopy = a
|
||||||
await self.connManager.trackOutgoingConn(
|
await self.connManager.trackOutgoingConn(
|
||||||
() => transportCopy.dial(hostname, addressCopy)
|
() => transportCopy.dial(hostname, addressCopy),
|
||||||
|
forceDial
|
||||||
)
|
)
|
||||||
except TooManyConnectionsError as exc:
|
except TooManyConnectionsError as exc:
|
||||||
trace "Connection limit reached!"
|
trace "Connection limit reached!"
|
||||||
@ -112,7 +114,8 @@ proc dialAndUpgrade(
|
|||||||
proc internalConnect(
|
proc internalConnect(
|
||||||
self: Dialer,
|
self: Dialer,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
addrs: seq[MultiAddress]):
|
addrs: seq[MultiAddress],
|
||||||
|
forceDial: bool):
|
||||||
Future[Connection] {.async.} =
|
Future[Connection] {.async.} =
|
||||||
if self.localPeerId == peerId:
|
if self.localPeerId == peerId:
|
||||||
raise newException(CatchableError, "can't dial self!")
|
raise newException(CatchableError, "can't dial self!")
|
||||||
@ -136,7 +139,7 @@ proc internalConnect(
|
|||||||
trace "Reusing existing connection", conn, direction = $conn.dir
|
trace "Reusing existing connection", conn, direction = $conn.dir
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
conn = await self.dialAndUpgrade(peerId, addrs)
|
conn = await self.dialAndUpgrade(peerId, addrs, forceDial)
|
||||||
if isNil(conn): # None of the addresses connected
|
if isNil(conn): # None of the addresses connected
|
||||||
raise newException(DialFailedError, "Unable to establish outgoing link")
|
raise newException(DialFailedError, "Unable to establish outgoing link")
|
||||||
|
|
||||||
@ -159,7 +162,8 @@ proc internalConnect(
|
|||||||
method connect*(
|
method connect*(
|
||||||
self: Dialer,
|
self: Dialer,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
addrs: seq[MultiAddress]) {.async.} =
|
addrs: seq[MultiAddress],
|
||||||
|
forceDial = false) {.async.} =
|
||||||
## connect remote peer without negotiating
|
## connect remote peer without negotiating
|
||||||
## a protocol
|
## a protocol
|
||||||
##
|
##
|
||||||
@ -167,7 +171,7 @@ method connect*(
|
|||||||
if self.connManager.connCount(peerId) > 0:
|
if self.connManager.connCount(peerId) > 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
discard await self.internalConnect(peerId, addrs)
|
discard await self.internalConnect(peerId, addrs, forceDial)
|
||||||
|
|
||||||
proc negotiateStream(
|
proc negotiateStream(
|
||||||
self: Dialer,
|
self: Dialer,
|
||||||
@ -200,7 +204,8 @@ method dial*(
|
|||||||
self: Dialer,
|
self: Dialer,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
addrs: seq[MultiAddress],
|
addrs: seq[MultiAddress],
|
||||||
protos: seq[string]): Future[Connection] {.async.} =
|
protos: seq[string],
|
||||||
|
forceDial = false): Future[Connection] {.async.} =
|
||||||
## create a protocol stream and establish
|
## create a protocol stream and establish
|
||||||
## a connection if one doesn't exist already
|
## a connection if one doesn't exist already
|
||||||
##
|
##
|
||||||
@ -218,7 +223,7 @@ method dial*(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
trace "Dialing (new)", peerId, protos
|
trace "Dialing (new)", peerId, protos
|
||||||
conn = await self.internalConnect(peerId, addrs)
|
conn = await self.internalConnect(peerId, addrs, forceDial)
|
||||||
trace "Opening stream", conn
|
trace "Opening stream", conn
|
||||||
stream = await self.connManager.getStream(conn)
|
stream = await self.connManager.getStream(conn)
|
||||||
|
|
||||||
|
@ -99,8 +99,9 @@ proc disconnect*(s: Switch, peerId: PeerId): Future[void] {.gcsafe.} =
|
|||||||
method connect*(
|
method connect*(
|
||||||
s: Switch,
|
s: Switch,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
addrs: seq[MultiAddress]): Future[void] =
|
addrs: seq[MultiAddress],
|
||||||
s.dialer.connect(peerId, addrs)
|
forceDial = false): Future[void] =
|
||||||
|
s.dialer.connect(peerId, addrs, forceDial)
|
||||||
|
|
||||||
method dial*(
|
method dial*(
|
||||||
s: Switch,
|
s: Switch,
|
||||||
@ -117,8 +118,9 @@ method dial*(
|
|||||||
s: Switch,
|
s: Switch,
|
||||||
peerId: PeerId,
|
peerId: PeerId,
|
||||||
addrs: seq[MultiAddress],
|
addrs: seq[MultiAddress],
|
||||||
protos: seq[string]): Future[Connection] =
|
protos: seq[string],
|
||||||
s.dialer.dial(peerId, addrs, protos)
|
forceDial = false): Future[Connection] =
|
||||||
|
s.dialer.dial(peerId, addrs, protos, forceDial)
|
||||||
|
|
||||||
proc dial*(
|
proc dial*(
|
||||||
s: Switch,
|
s: Switch,
|
||||||
|
@ -54,16 +54,21 @@ proc acquire*(s: AsyncSemaphore): Future[void] =
|
|||||||
fut.cancelCallback = nil
|
fut.cancelCallback = nil
|
||||||
if not fut.finished:
|
if not fut.finished:
|
||||||
s.queue.keepItIf( it != fut )
|
s.queue.keepItIf( it != fut )
|
||||||
s.count.inc
|
|
||||||
|
|
||||||
fut.cancelCallback = cancellation
|
fut.cancelCallback = cancellation
|
||||||
|
|
||||||
s.queue.add(fut)
|
s.queue.add(fut)
|
||||||
s.count.dec
|
|
||||||
|
|
||||||
trace "Queued slot", available = s.count, queue = s.queue.len
|
trace "Queued slot", available = s.count, queue = s.queue.len
|
||||||
return fut
|
return fut
|
||||||
|
|
||||||
|
proc forceAcquire*(s: AsyncSemaphore) =
|
||||||
|
## ForceAcquire will always succeed,
|
||||||
|
## creating a temporary slot if required.
|
||||||
|
## This temporary slot will stay usable until
|
||||||
|
## there is less `acquire`s than `release`s
|
||||||
|
s.count.dec
|
||||||
|
|
||||||
proc release*(s: AsyncSemaphore) =
|
proc release*(s: AsyncSemaphore) =
|
||||||
## Release a resource from the semaphore,
|
## Release a resource from the semaphore,
|
||||||
## by picking the first future from the queue
|
## by picking the first future from the queue
|
||||||
@ -77,13 +82,15 @@ proc release*(s: AsyncSemaphore) =
|
|||||||
trace "Releasing slot", available = s.count,
|
trace "Releasing slot", available = s.count,
|
||||||
queue = s.queue.len
|
queue = s.queue.len
|
||||||
|
|
||||||
if s.queue.len > 0:
|
s.count.inc
|
||||||
|
while s.queue.len > 0:
|
||||||
var fut = s.queue[0]
|
var fut = s.queue[0]
|
||||||
s.queue.delete(0)
|
s.queue.delete(0)
|
||||||
if not fut.finished():
|
if not fut.finished():
|
||||||
|
s.count.dec
|
||||||
fut.complete()
|
fut.complete()
|
||||||
|
break
|
||||||
|
|
||||||
s.count.inc # increment the resource count
|
|
||||||
trace "Released slot", available = s.count,
|
trace "Released slot", available = s.count,
|
||||||
queue = s.queue.len
|
queue = s.queue.len
|
||||||
return
|
return
|
||||||
|
@ -463,3 +463,33 @@ suite "Connection Manager":
|
|||||||
await connMngr.close()
|
await connMngr.close()
|
||||||
await allFuturesThrowing(
|
await allFuturesThrowing(
|
||||||
allFutures(conns.mapIt( it.close() )))
|
allFutures(conns.mapIt( it.close() )))
|
||||||
|
|
||||||
|
asyncTest "allow force dial":
|
||||||
|
let connMngr = ConnManager.new(maxConnections = 2)
|
||||||
|
|
||||||
|
var conns: seq[Connection]
|
||||||
|
for i in 0..<3:
|
||||||
|
let conn = connMngr.trackOutgoingConn(
|
||||||
|
(proc(): Future[Connection] {.async.} =
|
||||||
|
return Connection.new(
|
||||||
|
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||||
|
Direction.In)
|
||||||
|
), true
|
||||||
|
)
|
||||||
|
|
||||||
|
check await conn.withTimeout(10.millis)
|
||||||
|
conns.add(await conn)
|
||||||
|
|
||||||
|
# should throw adding a connection over the limit
|
||||||
|
expect TooManyConnectionsError:
|
||||||
|
discard await connMngr.trackOutgoingConn(
|
||||||
|
(proc(): Future[Connection] {.async.} =
|
||||||
|
return Connection.new(
|
||||||
|
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||||
|
Direction.In)
|
||||||
|
), false
|
||||||
|
)
|
||||||
|
|
||||||
|
await connMngr.close()
|
||||||
|
await allFuturesThrowing(
|
||||||
|
allFutures(conns.mapIt( it.close() )))
|
||||||
|
@ -36,7 +36,7 @@ suite "AsyncSemaphore":
|
|||||||
await sema.acquire()
|
await sema.acquire()
|
||||||
let fut = sema.acquire()
|
let fut = sema.acquire()
|
||||||
|
|
||||||
check sema.count == -1
|
check sema.count == 0
|
||||||
sema.release()
|
sema.release()
|
||||||
sema.release()
|
sema.release()
|
||||||
check sema.count == 1
|
check sema.count == 1
|
||||||
@ -66,7 +66,7 @@ suite "AsyncSemaphore":
|
|||||||
|
|
||||||
let fut = sema.acquire()
|
let fut = sema.acquire()
|
||||||
check fut.finished == false
|
check fut.finished == false
|
||||||
check sema.count == -1
|
check sema.count == 0
|
||||||
|
|
||||||
sema.release()
|
sema.release()
|
||||||
sema.release()
|
sema.release()
|
||||||
@ -104,12 +104,20 @@ suite "AsyncSemaphore":
|
|||||||
|
|
||||||
await sema.acquire()
|
await sema.acquire()
|
||||||
|
|
||||||
let tmp = sema.acquire()
|
let
|
||||||
check not tmp.finished()
|
tmp = sema.acquire()
|
||||||
|
tmp2 = sema.acquire()
|
||||||
|
check:
|
||||||
|
not tmp.finished()
|
||||||
|
not tmp2.finished()
|
||||||
|
|
||||||
tmp.cancel()
|
tmp.cancel()
|
||||||
sema.release()
|
sema.release()
|
||||||
|
|
||||||
|
check tmp2.finished()
|
||||||
|
|
||||||
|
sema.release()
|
||||||
|
|
||||||
check await sema.acquire().withTimeout(10.millis)
|
check await sema.acquire().withTimeout(10.millis)
|
||||||
|
|
||||||
asyncTest "should handle out of order cancellations":
|
asyncTest "should handle out of order cancellations":
|
||||||
@ -145,3 +153,43 @@ suite "AsyncSemaphore":
|
|||||||
sema.release()
|
sema.release()
|
||||||
|
|
||||||
check await sema.acquire().withTimeout(10.millis)
|
check await sema.acquire().withTimeout(10.millis)
|
||||||
|
|
||||||
|
asyncTest "should handle forceAcquire properly":
|
||||||
|
let sema = newAsyncSemaphore(1)
|
||||||
|
|
||||||
|
await sema.acquire()
|
||||||
|
check not(await sema.acquire().withTimeout(1.millis)) # should not acquire but cancel
|
||||||
|
|
||||||
|
let
|
||||||
|
fut1 = sema.acquire()
|
||||||
|
fut2 = sema.acquire()
|
||||||
|
|
||||||
|
sema.forceAcquire()
|
||||||
|
sema.release()
|
||||||
|
|
||||||
|
await fut1 or fut2 or sleepAsync(1.millis)
|
||||||
|
check:
|
||||||
|
fut1.finished()
|
||||||
|
not fut2.finished()
|
||||||
|
|
||||||
|
sema.release()
|
||||||
|
await fut1 or fut2 or sleepAsync(1.millis)
|
||||||
|
check:
|
||||||
|
fut1.finished()
|
||||||
|
fut2.finished()
|
||||||
|
|
||||||
|
|
||||||
|
sema.forceAcquire()
|
||||||
|
sema.forceAcquire()
|
||||||
|
|
||||||
|
let
|
||||||
|
fut3 = sema.acquire()
|
||||||
|
fut4 = sema.acquire()
|
||||||
|
fut5 = sema.acquire()
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
await sleepAsync(1.millis)
|
||||||
|
check:
|
||||||
|
fut3.finished()
|
||||||
|
fut4.finished()
|
||||||
|
not fut5.finished()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user