mirror of https://github.com/vacp2p/nim-libp2p.git
ConnManager connection tracking refacto (#749)
This commit is contained in:
parent
2fbe82bf9d
commit
2d864633ea
|
@ -33,9 +33,6 @@ const
|
|||
type
|
||||
TooManyConnectionsError* = object of LPError
|
||||
|
||||
ConnProvider* = proc(): Future[Connection]
|
||||
{.gcsafe, closure, raises: [Defect].}
|
||||
|
||||
ConnEventKind* {.pure.} = enum
|
||||
Connected, # A connection was made and securely upgraded - there may be
|
||||
# more than one concurrent connection thus more than one upgrade
|
||||
|
@ -84,6 +81,10 @@ type
|
|||
peerEvents: array[PeerEventKind, OrderedSet[PeerEventHandler]]
|
||||
peerStore*: PeerStore
|
||||
|
||||
ConnectionSlot* = object
|
||||
connManager: ConnManager
|
||||
direction: Direction
|
||||
|
||||
proc newTooManyConnectionsError(): ref TooManyConnectionsError {.inline.} =
|
||||
result = newException(TooManyConnectionsError, "Too many connections")
|
||||
|
||||
|
@ -404,90 +405,39 @@ proc storeConn*(c: ConnManager, conn: Connection)
|
|||
trace "Stored connection",
|
||||
conn, direction = $conn.dir, connections = c.conns.len
|
||||
|
||||
proc trackConn(c: ConnManager,
|
||||
provider: ConnProvider,
|
||||
sema: AsyncSemaphore):
|
||||
Future[Connection] {.async.} =
|
||||
var conn: Connection
|
||||
try:
|
||||
conn = await provider()
|
||||
|
||||
if isNil(conn):
|
||||
return
|
||||
|
||||
trace "Got connection", conn
|
||||
|
||||
proc semaphoreMonitor() {.async.} =
|
||||
try:
|
||||
await conn.join()
|
||||
except CatchableError as exc:
|
||||
trace "Exception in semaphore monitor, ignoring", exc = exc.msg
|
||||
|
||||
sema.release()
|
||||
|
||||
asyncSpawn semaphoreMonitor()
|
||||
except CatchableError as exc:
|
||||
trace "Exception tracking connection", exc = exc.msg
|
||||
if not isNil(conn):
|
||||
await conn.close()
|
||||
|
||||
raise exc
|
||||
|
||||
return conn
|
||||
|
||||
proc trackIncomingConn*(c: ConnManager,
|
||||
provider: ConnProvider):
|
||||
Future[Connection] {.async.} =
|
||||
## await for a connection slot before attempting
|
||||
## to call the connection provider
|
||||
##
|
||||
|
||||
var conn: Connection
|
||||
try:
|
||||
trace "Tracking incoming connection"
|
||||
await c.inSema.acquire()
|
||||
conn = await c.trackConn(provider, c.inSema)
|
||||
if isNil(conn):
|
||||
trace "Couldn't acquire connection, releasing semaphore slot", dir = $Direction.In
|
||||
c.inSema.release()
|
||||
|
||||
return conn
|
||||
except CatchableError as exc:
|
||||
trace "Exception tracking connection", exc = exc.msg
|
||||
c.inSema.release()
|
||||
raise exc
|
||||
|
||||
proc trackOutgoingConn*(c: ConnManager,
|
||||
provider: ConnProvider,
|
||||
forceDial = false):
|
||||
Future[Connection] {.async.} =
|
||||
## try acquiring a connection if all slots
|
||||
## are already taken, raise TooManyConnectionsError
|
||||
## exception
|
||||
##
|
||||
|
||||
trace "Tracking outgoing connection", count = c.outSema.count,
|
||||
max = c.outSema.size
|
||||
proc getIncomingSlot*(c: ConnManager): Future[ConnectionSlot] {.async.} =
|
||||
await c.inSema.acquire()
|
||||
return ConnectionSlot(connManager: c, direction: In)
|
||||
|
||||
proc getOutgoingSlot*(c: ConnManager, forceDial = false): Future[ConnectionSlot] {.async.} =
|
||||
if forceDial:
|
||||
c.outSema.forceAcquire()
|
||||
elif not c.outSema.tryAcquire():
|
||||
trace "Too many outgoing connections!", count = c.outSema.count,
|
||||
max = c.outSema.size
|
||||
raise newTooManyConnectionsError()
|
||||
return ConnectionSlot(connManager: c, direction: Out)
|
||||
|
||||
var conn: Connection
|
||||
try:
|
||||
conn = await c.trackConn(provider, c.outSema)
|
||||
if isNil(conn):
|
||||
trace "Couldn't acquire connection, releasing semaphore slot", dir = $Direction.Out
|
||||
c.outSema.release()
|
||||
proc release*(cs: ConnectionSlot) =
|
||||
if cs.direction == In:
|
||||
cs.connManager.inSema.release()
|
||||
else:
|
||||
cs.connManager.outSema.release()
|
||||
|
||||
return conn
|
||||
except CatchableError as exc:
|
||||
trace "Exception tracking connection", exc = exc.msg
|
||||
c.outSema.release()
|
||||
raise exc
|
||||
proc trackConnection*(cs: ConnectionSlot, conn: Connection) =
|
||||
if isNil(conn):
|
||||
cs.release()
|
||||
return
|
||||
|
||||
proc semaphoreMonitor() {.async.} =
|
||||
try:
|
||||
await conn.join()
|
||||
except CatchableError as exc:
|
||||
trace "Exception in semaphore monitor, ignoring", exc = exc.msg
|
||||
|
||||
cs.release()
|
||||
|
||||
asyncSpawn semaphoreMonitor()
|
||||
|
||||
proc storeMuxer*(c: ConnManager,
|
||||
muxer: Muxer,
|
||||
|
|
|
@ -47,8 +47,7 @@ type
|
|||
proc dialAndUpgrade(
|
||||
self: Dialer,
|
||||
peerId: PeerId,
|
||||
addrs: seq[MultiAddress],
|
||||
forceDial: bool):
|
||||
addrs: seq[MultiAddress]):
|
||||
Future[Connection] {.async.} =
|
||||
debug "Dialing peer", peerId
|
||||
|
||||
|
@ -65,20 +64,7 @@ proc dialAndUpgrade(
|
|||
trace "Dialing address", address = $a, peerId, hostname
|
||||
let dialed = try:
|
||||
libp2p_total_dial_attempts.inc()
|
||||
# await a connection slot when the total
|
||||
# connection count is equal to `maxConns`
|
||||
#
|
||||
# Need to copy to avoid "cannot be captured" errors in Nim-1.4.x.
|
||||
let
|
||||
transportCopy = transport
|
||||
addressCopy = a
|
||||
await self.connManager.trackOutgoingConn(
|
||||
() => transportCopy.dial(hostname, addressCopy),
|
||||
forceDial
|
||||
)
|
||||
except TooManyConnectionsError as exc:
|
||||
trace "Connection limit reached!"
|
||||
raise exc
|
||||
await transport.dial(hostname, a)
|
||||
except CancelledError as exc:
|
||||
debug "Dialing canceled", msg = exc.msg, peerId
|
||||
raise exc
|
||||
|
@ -101,6 +87,7 @@ proc dialAndUpgrade(
|
|||
except CatchableError as exc:
|
||||
# If we failed to establish the connection through one transport,
|
||||
# we won't succeeded through another - no use in trying again
|
||||
# TODO we should try another address though
|
||||
await dialed.close()
|
||||
debug "Upgrade failed", msg = exc.msg, peerId
|
||||
if exc isnot CancelledError:
|
||||
|
@ -139,12 +126,18 @@ proc internalConnect(
|
|||
trace "Reusing existing connection", conn, direction = $conn.dir
|
||||
return conn
|
||||
|
||||
conn = await self.dialAndUpgrade(peerId, addrs, forceDial)
|
||||
let slot = await self.connManager.getOutgoingSlot(forceDial)
|
||||
conn =
|
||||
try:
|
||||
await self.dialAndUpgrade(peerId, addrs)
|
||||
except CatchableError as exc:
|
||||
slot.release()
|
||||
raise exc
|
||||
slot.trackConnection(conn)
|
||||
if isNil(conn): # None of the addresses connected
|
||||
raise newException(DialFailedError, "Unable to establish outgoing link")
|
||||
|
||||
# We already check for this in Connection manager
|
||||
# but a disconnect could have happened right after
|
||||
# A disconnect could have happened right after
|
||||
# we've added the connection so we check again
|
||||
# to prevent races due to that.
|
||||
if conn.closed() or conn.atEof():
|
||||
|
|
|
@ -214,10 +214,14 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises
|
|||
# the upgrade succeeds or fails, this is
|
||||
# currently done by the `upgradeMonitor`
|
||||
await upgrades.acquire() # first wait for an upgrade slot to become available
|
||||
conn = await s.connManager # next attempt to get an incoming connection
|
||||
.trackIncomingConn(
|
||||
() => transport.accept()
|
||||
)
|
||||
let slot = await s.connManager.getIncomingSlot()
|
||||
conn =
|
||||
try:
|
||||
await transport.accept()
|
||||
except CatchableError as exc:
|
||||
slot.release()
|
||||
raise exc
|
||||
slot.trackConnection(conn)
|
||||
if isNil(conn):
|
||||
# A nil connection means that we might have hit a
|
||||
# file-handle limit (or another non-fatal error),
|
||||
|
|
|
@ -243,253 +243,131 @@ suite "Connection Manager":
|
|||
asyncTest "track total incoming connection limits":
|
||||
let connMngr = ConnManager.new(maxConnections = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = connMngr.trackIncomingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
check await conn.withTimeout(10.millis)
|
||||
conns.add(await conn)
|
||||
check await connMngr.getIncomingSlot().withTimeout(10.millis)
|
||||
|
||||
# should timeout adding a connection over the limit
|
||||
let conn = connMngr.trackIncomingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
check not(await conn.withTimeout(10.millis))
|
||||
check not(await connMngr.getIncomingSlot().withTimeout(10.millis))
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "track total outgoing connection limits":
|
||||
let connMngr = ConnManager.new(maxConnections = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = await connMngr.trackOutgoingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
conns.add(conn)
|
||||
check await connMngr.getOutgoingSlot().withTimeout(10.millis)
|
||||
|
||||
# should throw adding a connection over the limit
|
||||
expect TooManyConnectionsError:
|
||||
discard await connMngr.trackOutgoingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
discard await connMngr.getOutgoingSlot()
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "track both incoming and outgoing total connections limits - fail on incoming":
|
||||
let connMngr = ConnManager.new(maxConnections = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = await connMngr.trackOutgoingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
conns.add(conn)
|
||||
check await connMngr.getOutgoingSlot().withTimeout(10.millis)
|
||||
|
||||
# should timeout adding a connection over the limit
|
||||
let conn = connMngr.trackIncomingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
check not(await conn.withTimeout(10.millis))
|
||||
check not(await connMngr.getIncomingSlot().withTimeout(10.millis))
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "track both incoming and outgoing total connections limits - fail on outgoing":
|
||||
let connMngr = ConnManager.new(maxConnections = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = connMngr.trackIncomingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
check await conn.withTimeout(10.millis)
|
||||
conns.add(await conn)
|
||||
check await connMngr.getIncomingSlot().withTimeout(10.millis)
|
||||
|
||||
# should throw adding a connection over the limit
|
||||
expect TooManyConnectionsError:
|
||||
discard await connMngr.trackOutgoingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
discard await connMngr.getOutgoingSlot()
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "track max incoming connection limits":
|
||||
let connMngr = ConnManager.new(maxIn = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = connMngr.trackIncomingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
check await connMngr.getIncomingSlot().withTimeout(10.millis)
|
||||
|
||||
check await conn.withTimeout(10.millis)
|
||||
conns.add(await conn)
|
||||
|
||||
# should timeout adding a connection over the limit
|
||||
let conn = connMngr.trackIncomingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
check not(await conn.withTimeout(10.millis))
|
||||
check not(await connMngr.getIncomingSlot().withTimeout(10.millis))
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "track max outgoing connection limits":
|
||||
let connMngr = ConnManager.new(maxOut = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = await connMngr.trackOutgoingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
conns.add(conn)
|
||||
check await connMngr.getOutgoingSlot().withTimeout(10.millis)
|
||||
|
||||
# should throw adding a connection over the limit
|
||||
expect TooManyConnectionsError:
|
||||
discard await connMngr.trackOutgoingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
discard await connMngr.getOutgoingSlot()
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "track incoming max connections limits - fail on incoming":
|
||||
let connMngr = ConnManager.new(maxOut = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = await connMngr.trackOutgoingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
conns.add(conn)
|
||||
check await connMngr.getOutgoingSlot().withTimeout(10.millis)
|
||||
|
||||
# should timeout adding a connection over the limit
|
||||
let conn = connMngr.trackIncomingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
check not(await conn.withTimeout(10.millis))
|
||||
check not(await connMngr.getIncomingSlot().withTimeout(10.millis))
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "track incoming max connections limits - fail on outgoing":
|
||||
let connMngr = ConnManager.new(maxIn = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = connMngr.trackIncomingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
|
||||
check await conn.withTimeout(10.millis)
|
||||
conns.add(await conn)
|
||||
check await connMngr.getIncomingSlot().withTimeout(10.millis)
|
||||
|
||||
# should throw adding a connection over the limit
|
||||
expect TooManyConnectionsError:
|
||||
discard await connMngr.trackOutgoingConn(
|
||||
proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
)
|
||||
discard await connMngr.getOutgoingSlot()
|
||||
|
||||
await connMngr.close()
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
asyncTest "allow force dial":
|
||||
let connMngr = ConnManager.new(maxConnections = 2)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let conn = connMngr.trackOutgoingConn(
|
||||
(proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
), true
|
||||
)
|
||||
|
||||
check await conn.withTimeout(10.millis)
|
||||
conns.add(await conn)
|
||||
check await connMngr.getOutgoingSlot(true).withTimeout(10.millis)
|
||||
|
||||
# should throw adding a connection over the limit
|
||||
expect TooManyConnectionsError:
|
||||
discard await connMngr.trackOutgoingConn(
|
||||
(proc(): Future[Connection] {.async.} =
|
||||
return Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
), false
|
||||
)
|
||||
discard await connMngr.getOutgoingSlot(false)
|
||||
|
||||
await connMngr.close()
|
||||
|
||||
asyncTest "release slot on connection end":
|
||||
let connMngr = ConnManager.new(maxConnections = 3)
|
||||
|
||||
var conns: seq[Connection]
|
||||
for i in 0..<3:
|
||||
let slot = await ((connMngr.getOutgoingSlot()).wait(10.millis))
|
||||
|
||||
let conn =
|
||||
Connection.new(
|
||||
PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(),
|
||||
Direction.In)
|
||||
|
||||
slot.trackConnection(conn)
|
||||
conns.add(conn)
|
||||
|
||||
# should be full now
|
||||
let incomingSlot = connMngr.getIncomingSlot()
|
||||
|
||||
check (await incomingSlot.withTimeout(10.millis)) == false
|
||||
|
||||
await allFuturesThrowing(
|
||||
allFutures(conns.mapIt( it.close() )))
|
||||
|
||||
check await incomingSlot.withTimeout(10.millis)
|
||||
|
||||
await connMngr.close()
|
||||
|
|
|
@ -21,6 +21,7 @@ import ../libp2p/[errors,
|
|||
nameresolving/nameresolver,
|
||||
nameresolving/mockresolver,
|
||||
stream/chronosstream,
|
||||
utils/semaphore,
|
||||
transports/tcptransport,
|
||||
transports/wstransport]
|
||||
import ./helpers
|
||||
|
@ -206,6 +207,12 @@ suite "Switch":
|
|||
await switch1.start()
|
||||
await switch2.start()
|
||||
|
||||
let startCounts =
|
||||
@[
|
||||
switch1.connManager.inSema.count, switch1.connManager.outSema.count,
|
||||
switch2.connManager.inSema.count, switch2.connManager.outSema.count
|
||||
]
|
||||
|
||||
await switch2.connect(switch1.peerInfo.peerId, switch1.peerInfo.addrs)
|
||||
|
||||
check switch1.isConnected(switch2.peerInfo.peerId)
|
||||
|
@ -219,6 +226,15 @@ suite "Switch":
|
|||
checkTracker(LPChannelTrackerName)
|
||||
checkTracker(SecureConnTrackerName)
|
||||
|
||||
await sleepAsync(1.seconds)
|
||||
|
||||
check:
|
||||
startCounts ==
|
||||
@[
|
||||
switch1.connManager.inSema.count, switch1.connManager.outSema.count,
|
||||
switch2.connManager.inSema.count, switch2.connManager.outSema.count
|
||||
]
|
||||
|
||||
await allFuturesThrowing(
|
||||
switch1.stop(),
|
||||
switch2.stop())
|
||||
|
|
Loading…
Reference in New Issue