diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index e20642e65..818af37a8 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -83,7 +83,7 @@ proc init*(C: type ConnManager, maxIn = -1, maxOut = -1): ConnManager = var inSema, outSema: AsyncSemaphore - if maxIn > 0 and maxOut > 0: + if maxIn > 0 or maxOut > 0: inSema = newAsyncSemaphore(maxIn) outSema = newAsyncSemaphore(maxOut) elif maxConnections > 0: diff --git a/libp2p/standard_setup.nim b/libp2p/standard_setup.nim index b61a45454..50aee71bc 100644 --- a/libp2p/standard_setup.nim +++ b/libp2p/standard_setup.nim @@ -20,8 +20,6 @@ type proc newStandardSwitch*(privKey = none(PrivateKey), address = MultiAddress.init("/ip4/127.0.0.1/tcp/0").tryGet(), secureManagers: openarray[SecureProtocol] = [ - # array cos order matters - SecureProtocol.Secio, SecureProtocol.Noise, ], transportFlags: set[ServerFlags] = {}, diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index a523673f1..27a287d81 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -41,11 +41,12 @@ proc isUpgraded*(s: Connection): bool = return s.upgraded.finished proc upgrade*(s: Connection, failed: ref Exception = nil) = - if not isNil(failed): - s.upgraded.fail(failed) - return + if not isNil(s.upgraded): + if not isNil(failed): + s.upgraded.fail(failed) + return - s.upgraded.complete() + s.upgraded.complete() proc onUpgrade*(s: Connection) {.async.} = if not isNil(s.upgraded): diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index d427bcb9a..86aaa67d1 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -239,3 +239,227 @@ suite "Connection Manager": check isNil(connMngr.selectConn(peer.peerId, Direction.Out)) await connMngr.close() + + asyncTest "track total incoming connection limits": + let connMngr = ConnManager.init(maxConnections = 3) + + var conns: seq[Connection] + for i in 0..<3: + let conn = connMngr.trackIncomingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + check await conn.withTimeout(10.millis) + conns.add(await conn) + + # should timeout adding a connection over the limit + let conn = connMngr.trackIncomingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + check not(await conn.withTimeout(10.millis)) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) + + asyncTest "track total outgoing connection limits": + let connMngr = ConnManager.init(maxConnections = 3) + + var conns: seq[Connection] + for i in 0..<3: + let conn = await connMngr.trackOutgoingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + conns.add(conn) + + # should throw adding a connection over the limit + expect TooManyConnectionsError: + discard await connMngr.trackOutgoingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) + + asyncTest "track both incoming and outgoing total connections limits - fail on incoming": + let connMngr = ConnManager.init(maxConnections = 3) + + var conns: seq[Connection] + for i in 0..<3: + let conn = await connMngr.trackOutgoingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + conns.add(conn) + + # should timeout adding a connection over the limit + let conn = connMngr.trackIncomingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + check not(await conn.withTimeout(10.millis)) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) + + asyncTest "track both incoming and outgoing total connections limits - fail on outgoing": + let connMngr = ConnManager.init(maxConnections = 3) + + var conns: seq[Connection] + for i in 0..<3: + let conn = connMngr.trackIncomingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + check await conn.withTimeout(10.millis) + conns.add(await conn) + + # should throw adding a connection over the limit + expect TooManyConnectionsError: + discard await connMngr.trackOutgoingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) + + asyncTest "track max incoming connection limits": + let connMngr = ConnManager.init(maxIn = 3) + + var conns: seq[Connection] + for i in 0..<3: + let conn = connMngr.trackIncomingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + check await conn.withTimeout(10.millis) + conns.add(await conn) + + # should timeout adding a connection over the limit + let conn = connMngr.trackIncomingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + check not(await conn.withTimeout(10.millis)) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) + + asyncTest "track max outgoing connection limits": + let connMngr = ConnManager.init(maxOut = 3) + + var conns: seq[Connection] + for i in 0..<3: + let conn = await connMngr.trackOutgoingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + conns.add(conn) + + # should throw adding a connection over the limit + expect TooManyConnectionsError: + discard await connMngr.trackOutgoingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) + + asyncTest "track incoming max connections limits - fail on incoming": + let connMngr = ConnManager.init(maxOut = 3) + + var conns: seq[Connection] + for i in 0..<3: + let conn = await connMngr.trackOutgoingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + conns.add(conn) + + # should timeout adding a connection over the limit + let conn = connMngr.trackIncomingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + check not(await conn.withTimeout(10.millis)) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) + + asyncTest "track incoming max connections limits - fail on outgoing": + let connMngr = ConnManager.init(maxIn = 3) + + var conns: seq[Connection] + for i in 0..<3: + let conn = connMngr.trackIncomingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + check await conn.withTimeout(10.millis) + conns.add(await conn) + + # should throw adding a connection over the limit + expect TooManyConnectionsError: + discard await connMngr.trackOutgoingConn( + proc(): Future[Connection] {.async.} = + return Connection.init( + PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()), + Direction.In) + ) + + await connMngr.close() + await allFuturesThrowing( + allFutures(conns.mapIt( it.close() ))) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 205ac2cdf..fc371b34c 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -741,3 +741,115 @@ suite "Switch": expect(DialFailedError): let conn = await switch2.dial(somePeer, TestCodec) await switch2.stop() + + asyncTest "e2e total connection limits on incoming connections": + var awaiters: seq[Future[void]] + + var switches: seq[Switch] + let destSwitch = newStandardSwitch(maxConnections = 3) + switches.add(destSwitch) + awaiters.add(await destSwitch.start()) + + let destPeerInfo = destSwitch.peerInfo + for i in 0..<3: + let switch = newStandardSwitch() + switches.add(switch) + awaiters.add(await switch.start()) + + check await switch.connect(destPeerInfo) + .withTimeout(100.millis) + + let switchFail = newStandardSwitch() + switches.add(switchFail) + awaiters.add(await switchFail.start()) + + check not(await switchFail.connect(destPeerInfo) + .withTimeout(100.millis)) + + await allFuturesThrowing( + allFutures(switches.mapIt( it.stop() ))) + await allFuturesThrowing(awaiters) + + asyncTest "e2e total connection limits on incoming connections": + var awaiters: seq[Future[void]] + + var switches: seq[Switch] + for i in 0..<3: + switches.add(newStandardSwitch()) + awaiters.add(await switches[i].start()) + + let srcSwitch = newStandardSwitch(maxConnections = 3) + awaiters.add(await srcSwitch.start()) + + let dstSwitch = newStandardSwitch() + awaiters.add(await dstSwitch.start()) + + for s in switches: + check await srcSwitch.connect(s.peerInfo) + .withTimeout(100.millis) + + expect TooManyConnectionsError: + await srcSwitch.connect(dstSwitch.peerInfo) + + switches.add(srcSwitch) + switches.add(dstSwitch) + + await allFuturesThrowing( + allFutures(switches.mapIt( it.stop() ))) + await allFuturesThrowing(awaiters) + + asyncTest "e2e max incoming connection limits": + var awaiters: seq[Future[void]] + + var switches: seq[Switch] + let destSwitch = newStandardSwitch(maxIn = 3) + switches.add(destSwitch) + awaiters.add(await destSwitch.start()) + + let destPeerInfo = destSwitch.peerInfo + for i in 0..<3: + let switch = newStandardSwitch() + switches.add(switch) + awaiters.add(await switch.start()) + + check await switch.connect(destPeerInfo) + .withTimeout(100.millis) + + let switchFail = newStandardSwitch() + switches.add(switchFail) + awaiters.add(await switchFail.start()) + + check not(await switchFail.connect(destPeerInfo) + .withTimeout(100.millis)) + + await allFuturesThrowing( + allFutures(switches.mapIt( it.stop() ))) + await allFuturesThrowing(awaiters) + + asyncTest "e2e max outgoing connection limits": + var awaiters: seq[Future[void]] + + var switches: seq[Switch] + for i in 0..<3: + switches.add(newStandardSwitch()) + awaiters.add(await switches[i].start()) + + let srcSwitch = newStandardSwitch(maxOut = 3) + awaiters.add(await srcSwitch.start()) + + let dstSwitch = newStandardSwitch() + awaiters.add(await dstSwitch.start()) + + for s in switches: + check await srcSwitch.connect(s.peerInfo) + .withTimeout(100.millis) + + expect TooManyConnectionsError: + await srcSwitch.connect(dstSwitch.peerInfo) + + switches.add(srcSwitch) + switches.add(dstSwitch) + + await allFuturesThrowing( + allFutures(switches.mapIt( it.stop() ))) + await allFuturesThrowing(awaiters)