diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 726a3a0e6..f29e0d6b2 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -257,6 +257,45 @@ proc upgradeIncoming(s: Switch, incomingConn: Connection) {.async, gcsafe.} = # finally: await incomingConn.close() +proc dialAndUpgrade(s: Switch, + peerId: PeerID, + addrs: seq[MultiAddress]): + Future[Connection] {.async.} = + debug "Dialing peer", peerId + for t in s.transports: # for each transport + for a in addrs: # for each address + if t.handles(a): # check if it can dial it + trace "Dialing address", address = $a, peerId + let dialed = try: + await t.dial(a) + except CancelledError as exc: + debug "Dialing canceled", msg = exc.msg, peerId + raise exc + except CatchableError as exc: + debug "Dialing failed", msg = exc.msg, peerId + libp2p_failed_dials.inc() + continue # Try the next address + + # make sure to assign the peer to the connection + dialed.peerInfo = PeerInfo.init(peerId, addrs) + + libp2p_dialed_peers.inc() + + let conn = try: + await s.upgradeOutgoing(dialed) + except CatchableError as exc: + # If we failed to establish the connection through one transport, + # we won't succeeded through another - no use in trying again + await dialed.close() + debug "Upgrade failed", msg = exc.msg, peerId + if exc isnot CancelledError: + libp2p_failed_upgrade.inc() + raise exc + + doAssert not isNil(conn), "connection died after upgradeOutgoing" + debug "Dial successful", conn, peerInfo = conn.peerInfo + return conn + proc internalConnect(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]): Future[Connection] {.async.} = @@ -281,45 +320,9 @@ proc internalConnect(s: Switch, raise newException(DialFailedError, "Zombie connection encountered") trace "Reusing existing connection", conn, direction = $conn.dir - return conn - debug "Dialing peer", peerId - for t in s.transports: # for each transport - for a in addrs: # for each address - if t.handles(a): # check if it can dial it - trace "Dialing address", address = $a, peerId - let dialed = try: - await t.dial(a) - except CancelledError as exc: - debug "Dialing canceled", msg = exc.msg, peerId - raise exc - except CatchableError as exc: - debug "Dialing failed", msg = exc.msg, peerId - libp2p_failed_dials.inc() - continue # Try the next address - - # make sure to assign the peer to the connection - dialed.peerInfo = PeerInfo.init(peerId, addrs) - - libp2p_dialed_peers.inc() - - let upgraded = try: - await s.upgradeOutgoing(dialed) - except CatchableError as exc: - # If we failed to establish the connection through one transport, - # we won't succeeded through another - no use in trying again - await dialed.close() - debug "Upgrade failed", msg = exc.msg, peerId - if exc isnot CancelledError: - libp2p_failed_upgrade.inc() - raise exc - - doAssert not isNil(upgraded), "connection died after upgradeOutgoing" - - conn = upgraded - debug "Dial successful", conn, peerInfo = conn.peerInfo - break + conn = await s.dialAndUpgrade(peerId, addrs) finally: if lock.locked(): lock.release() diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index eb7f413ed..48f0e57ce 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -20,6 +20,8 @@ import transport, logScope: topics = "tcptransport" +export transport + const TcpTransportTrackerName* = "libp2p.tcptransport" diff --git a/tests/testswitch.nim b/tests/testswitch.nim index c113ce11b..a6fa9bca0 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -17,7 +17,9 @@ import ../libp2p/[errors, protocols/secure/secure, muxers/muxer, muxers/mplex/lpchannel, - stream/lpstream] + stream/lpstream, + stream/chronosstream, + transports/tcptransport] import ./helpers const @@ -624,6 +626,78 @@ suite "Switch": switches.mapIt( it.stop() )) await allFuturesThrowing(awaiters) + # TODO: we should be able to test cancellation + # for most of the steps in the upgrade flow - + # this is just a basic test for dials + asyncTest "e2e canceling dial should not leak": + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + let transport = TcpTransport.init() + await transport.start(ma) + + proc acceptHandler() {.async, gcsafe.} = + try: + let conn = await transport.accept() + discard await conn.readLp(100) + except CatchableError as exc: + discard + + let handlerWait = acceptHandler() + let switch = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) + + var awaiters: seq[Future[void]] + awaiters.add(await switch.start()) + + var peerId = PeerID.init(PrivateKey.random(ECDSA, rng[]).get()).get() + let connectFut = switch.connect(peerId, @[transport.ma]) + await sleepAsync(500.millis) + connectFut.cancel() + await handlerWait + + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) + checkTracker(ChronosStreamTrackerName) + + await allFuturesThrowing( + transport.stop(), + switch.stop()) + + # this needs to go at end + await allFuturesThrowing(awaiters) + + asyncTest "e2e closing remote conn should not leak": + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + let transport = TcpTransport.init() + await transport.start(ma) + + proc acceptHandler() {.async, gcsafe.} = + let conn = await transport.accept() + await conn.close() + + let handlerWait = acceptHandler() + let switch = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) + + var awaiters: seq[Future[void]] + awaiters.add(await switch.start()) + + var peerId = PeerID.init(PrivateKey.random(ECDSA, rng[]).get()).get() + expect LPStreamClosedError: + await switch.connect(peerId, @[transport.ma]) + + await handlerWait + + checkTracker(LPChannelTrackerName) + checkTracker(SecureConnTrackerName) + checkTracker(ChronosStreamTrackerName) + + await allFuturesThrowing( + transport.stop(), + switch.stop()) + + # this needs to go at end + await allFuturesThrowing(awaiters) + asyncTest "connect to inexistent peer": let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) let sfut = await switch2.start()