diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index e64947e51..b7ac3d87a 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -33,9 +33,6 @@ const type TooManyConnectionsError* = object of LPError - ConnProvider* = proc(): Future[Connection] - {.gcsafe, closure, raises: [Defect].} - ConnEventKind* {.pure.} = enum Connected, # A connection was made and securely upgraded - there may be # more than one concurrent connection thus more than one upgrade @@ -84,6 +81,10 @@ type peerEvents: array[PeerEventKind, OrderedSet[PeerEventHandler]] peerStore*: PeerStore + ConnectionSlot* = object + connManager: ConnManager + direction: Direction + proc newTooManyConnectionsError(): ref TooManyConnectionsError {.inline.} = result = newException(TooManyConnectionsError, "Too many connections") @@ -404,90 +405,39 @@ proc storeConn*(c: ConnManager, conn: Connection) trace "Stored connection", conn, direction = $conn.dir, connections = c.conns.len -proc trackConn(c: ConnManager, - provider: ConnProvider, - sema: AsyncSemaphore): - Future[Connection] {.async.} = - var conn: Connection - try: - conn = await provider() - - if isNil(conn): - return - - trace "Got connection", conn - - proc semaphoreMonitor() {.async.} = - try: - await conn.join() - except CatchableError as exc: - trace "Exception in semaphore monitor, ignoring", exc = exc.msg - - sema.release() - - asyncSpawn semaphoreMonitor() - except CatchableError as exc: - trace "Exception tracking connection", exc = exc.msg - if not isNil(conn): - await conn.close() - - raise exc - - return conn - -proc trackIncomingConn*(c: ConnManager, - provider: ConnProvider): - Future[Connection] {.async.} = - ## await for a connection slot before attempting - ## to call the connection provider - ## - - var conn: Connection - try: - trace "Tracking incoming connection" - await c.inSema.acquire() - conn = await c.trackConn(provider, c.inSema) - if isNil(conn): - trace "Couldn't acquire connection, releasing semaphore slot", dir = $Direction.In - c.inSema.release() - - return conn - except CatchableError as exc: - trace "Exception tracking connection", exc = exc.msg - c.inSema.release() - raise exc - -proc trackOutgoingConn*(c: ConnManager, - provider: ConnProvider, - forceDial = false): - Future[Connection] {.async.} = - ## try acquiring a connection if all slots - ## are already taken, raise TooManyConnectionsError - ## exception - ## - - trace "Tracking outgoing connection", count = c.outSema.count, - max = c.outSema.size +proc getIncomingSlot*(c: ConnManager): Future[ConnectionSlot] {.async.} = + await c.inSema.acquire() + return ConnectionSlot(connManager: c, direction: In) +proc getOutgoingSlot*(c: ConnManager, forceDial = false): Future[ConnectionSlot] {.async.} = if forceDial: c.outSema.forceAcquire() elif not c.outSema.tryAcquire(): trace "Too many outgoing connections!", count = c.outSema.count, max = c.outSema.size raise newTooManyConnectionsError() + return ConnectionSlot(connManager: c, direction: Out) - var conn: Connection - try: - conn = await c.trackConn(provider, c.outSema) - if isNil(conn): - trace "Couldn't acquire connection, releasing semaphore slot", dir = $Direction.Out - c.outSema.release() +proc release*(cs: ConnectionSlot) = + if cs.direction == In: + cs.connManager.inSema.release() + else: + cs.connManager.outSema.release() - return conn - except CatchableError as exc: - trace "Exception tracking connection", exc = exc.msg - c.outSema.release() - raise exc +proc trackConnection*(cs: ConnectionSlot, conn: Connection) = + if isNil(conn): + cs.release() + return + + proc semaphoreMonitor() {.async.} = + try: + await conn.join() + except CatchableError as exc: + trace "Exception in semaphore monitor, ignoring", exc = exc.msg + + cs.release() + + asyncSpawn semaphoreMonitor() proc storeMuxer*(c: ConnManager, muxer: Muxer, diff --git a/libp2p/dialer.nim b/libp2p/dialer.nim index 9d50cb547..ece835af7 100644 --- a/libp2p/dialer.nim +++ b/libp2p/dialer.nim @@ -47,8 +47,7 @@ type proc dialAndUpgrade( self: Dialer, peerId: PeerId, - addrs: seq[MultiAddress], - forceDial: bool): + addrs: seq[MultiAddress]): Future[Connection] {.async.} = debug "Dialing peer", peerId @@ -65,20 +64,7 @@ proc dialAndUpgrade( trace "Dialing address", address = $a, peerId, hostname let dialed = try: libp2p_total_dial_attempts.inc() - # await a connection slot when the total - # connection count is equal to `maxConns` - # - # Need to copy to avoid "cannot be captured" errors in Nim-1.4.x. - let - transportCopy = transport - addressCopy = a - await self.connManager.trackOutgoingConn( - () => transportCopy.dial(hostname, addressCopy), - forceDial - ) - except TooManyConnectionsError as exc: - trace "Connection limit reached!" - raise exc + await transport.dial(hostname, a) except CancelledError as exc: debug "Dialing canceled", msg = exc.msg, peerId raise exc @@ -101,6 +87,7 @@ proc dialAndUpgrade( except CatchableError as exc: # If we failed to establish the connection through one transport, # we won't succeeded through another - no use in trying again + # TODO we should try another address though await dialed.close() debug "Upgrade failed", msg = exc.msg, peerId if exc isnot CancelledError: @@ -139,12 +126,18 @@ proc internalConnect( trace "Reusing existing connection", conn, direction = $conn.dir return conn - conn = await self.dialAndUpgrade(peerId, addrs, forceDial) + let slot = await self.connManager.getOutgoingSlot(forceDial) + conn = + try: + await self.dialAndUpgrade(peerId, addrs) + except CatchableError as exc: + slot.release() + raise exc + slot.trackConnection(conn) if isNil(conn): # None of the addresses connected raise newException(DialFailedError, "Unable to establish outgoing link") - # We already check for this in Connection manager - # but a disconnect could have happened right after + # A disconnect could have happened right after # we've added the connection so we check again # to prevent races due to that. if conn.closed() or conn.atEof(): diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 6a6cf6979..f6f292510 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -214,10 +214,14 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises # the upgrade succeeds or fails, this is # currently done by the `upgradeMonitor` await upgrades.acquire() # first wait for an upgrade slot to become available - conn = await s.connManager # next attempt to get an incoming connection - .trackIncomingConn( - () => transport.accept() - ) + let slot = await s.connManager.getIncomingSlot() + conn = + try: + await transport.accept() + except CatchableError as exc: + slot.release() + raise exc + slot.trackConnection(conn) if isNil(conn): # A nil connection means that we might have hit a # file-handle limit (or another non-fatal error), diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index 6aaa4edb0..5c705bf36 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -243,253 +243,131 @@ suite "Connection Manager": asyncTest "track total incoming connection limits": let connMngr = ConnManager.new(maxConnections = 3) - var conns: seq[Connection] for i in 0..<3: - let conn = connMngr.trackIncomingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - check await conn.withTimeout(10.millis) - conns.add(await conn) + check await connMngr.getIncomingSlot().withTimeout(10.millis) # should timeout adding a connection over the limit - let conn = connMngr.trackIncomingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - check not(await conn.withTimeout(10.millis)) + check not(await connMngr.getIncomingSlot().withTimeout(10.millis)) await connMngr.close() - await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) asyncTest "track total outgoing connection limits": let connMngr = ConnManager.new(maxConnections = 3) - var conns: seq[Connection] for i in 0..<3: - let conn = await connMngr.trackOutgoingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - conns.add(conn) + check await connMngr.getOutgoingSlot().withTimeout(10.millis) # should throw adding a connection over the limit expect TooManyConnectionsError: - discard await connMngr.trackOutgoingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) + discard await connMngr.getOutgoingSlot() await connMngr.close() - await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) asyncTest "track both incoming and outgoing total connections limits - fail on incoming": let connMngr = ConnManager.new(maxConnections = 3) - var conns: seq[Connection] for i in 0..<3: - let conn = await connMngr.trackOutgoingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - conns.add(conn) + check await connMngr.getOutgoingSlot().withTimeout(10.millis) # should timeout adding a connection over the limit - let conn = connMngr.trackIncomingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - check not(await conn.withTimeout(10.millis)) + check not(await connMngr.getIncomingSlot().withTimeout(10.millis)) await connMngr.close() - await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) asyncTest "track both incoming and outgoing total connections limits - fail on outgoing": let connMngr = ConnManager.new(maxConnections = 3) - var conns: seq[Connection] for i in 0..<3: - let conn = connMngr.trackIncomingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - check await conn.withTimeout(10.millis) - conns.add(await conn) + check await connMngr.getIncomingSlot().withTimeout(10.millis) # should throw adding a connection over the limit expect TooManyConnectionsError: - discard await connMngr.trackOutgoingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) + discard await connMngr.getOutgoingSlot() await connMngr.close() - await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) asyncTest "track max incoming connection limits": let connMngr = ConnManager.new(maxIn = 3) - var conns: seq[Connection] for i in 0..<3: - let conn = connMngr.trackIncomingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) + check await connMngr.getIncomingSlot().withTimeout(10.millis) - check await conn.withTimeout(10.millis) - conns.add(await conn) - - # should timeout adding a connection over the limit - let conn = connMngr.trackIncomingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - check not(await conn.withTimeout(10.millis)) + check not(await connMngr.getIncomingSlot().withTimeout(10.millis)) await connMngr.close() - await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) asyncTest "track max outgoing connection limits": let connMngr = ConnManager.new(maxOut = 3) - var conns: seq[Connection] for i in 0..<3: - let conn = await connMngr.trackOutgoingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - conns.add(conn) + check await connMngr.getOutgoingSlot().withTimeout(10.millis) # should throw adding a connection over the limit expect TooManyConnectionsError: - discard await connMngr.trackOutgoingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) + discard await connMngr.getOutgoingSlot() await connMngr.close() - await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) asyncTest "track incoming max connections limits - fail on incoming": let connMngr = ConnManager.new(maxOut = 3) - var conns: seq[Connection] for i in 0..<3: - let conn = await connMngr.trackOutgoingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - conns.add(conn) + check await connMngr.getOutgoingSlot().withTimeout(10.millis) # should timeout adding a connection over the limit - let conn = connMngr.trackIncomingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - check not(await conn.withTimeout(10.millis)) + check not(await connMngr.getIncomingSlot().withTimeout(10.millis)) await connMngr.close() - await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) asyncTest "track incoming max connections limits - fail on outgoing": let connMngr = ConnManager.new(maxIn = 3) var conns: seq[Connection] for i in 0..<3: - let conn = connMngr.trackIncomingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) - - check await conn.withTimeout(10.millis) - conns.add(await conn) + check await connMngr.getIncomingSlot().withTimeout(10.millis) # should throw adding a connection over the limit expect TooManyConnectionsError: - discard await connMngr.trackOutgoingConn( - proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ) + discard await connMngr.getOutgoingSlot() await connMngr.close() - await allFuturesThrowing( - allFutures(conns.mapIt( it.close() ))) asyncTest "allow force dial": let connMngr = ConnManager.new(maxConnections = 2) var conns: seq[Connection] for i in 0..<3: - let conn = connMngr.trackOutgoingConn( - (proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ), true - ) - - check await conn.withTimeout(10.millis) - conns.add(await conn) + check await connMngr.getOutgoingSlot(true).withTimeout(10.millis) # should throw adding a connection over the limit expect TooManyConnectionsError: - discard await connMngr.trackOutgoingConn( - (proc(): Future[Connection] {.async.} = - return Connection.new( - PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), - Direction.In) - ), false - ) + discard await connMngr.getOutgoingSlot(false) await connMngr.close() + + asyncTest "release slot on connection end": + let connMngr = ConnManager.new(maxConnections = 3) + + var conns: seq[Connection] + for i in 0..<3: + let slot = await ((connMngr.getOutgoingSlot()).wait(10.millis)) + + let conn = + Connection.new( + PeerId.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()).tryGet(), + Direction.In) + + slot.trackConnection(conn) + conns.add(conn) + + # should be full now + let incomingSlot = connMngr.getIncomingSlot() + + check (await incomingSlot.withTimeout(10.millis)) == false + await allFuturesThrowing( allFutures(conns.mapIt( it.close() ))) + + check await incomingSlot.withTimeout(10.millis) + + await connMngr.close() diff --git a/tests/testswitch.nim b/tests/testswitch.nim index fc64a21f7..daaa0f088 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -21,6 +21,7 @@ import ../libp2p/[errors, nameresolving/nameresolver, nameresolving/mockresolver, stream/chronosstream, + utils/semaphore, transports/tcptransport, transports/wstransport] import ./helpers @@ -206,6 +207,12 @@ suite "Switch": await switch1.start() await switch2.start() + let startCounts = + @[ + switch1.connManager.inSema.count, switch1.connManager.outSema.count, + switch2.connManager.inSema.count, switch2.connManager.outSema.count + ] + await switch2.connect(switch1.peerInfo.peerId, switch1.peerInfo.addrs) check switch1.isConnected(switch2.peerInfo.peerId) @@ -219,6 +226,15 @@ suite "Switch": checkTracker(LPChannelTrackerName) checkTracker(SecureConnTrackerName) + await sleepAsync(1.seconds) + + check: + startCounts == + @[ + switch1.connManager.inSema.count, switch1.connManager.outSema.count, + switch2.connManager.inSema.count, switch2.connManager.outSema.count + ] + await allFuturesThrowing( switch1.stop(), switch2.stop())