diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 0c9c3704d..b70118aa7 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -175,19 +175,43 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g return await s.mux(result) # mux it if possible - s.connections[conn.peerInfo.id] = result + if result.peerInfo.id notin s.connections: + s.connections[result.peerInfo.id] = result proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = + ## Upgrade incoming connections, this roughly looks like: + ## - First, register all the secure handlers and await for + ## a secure request + ## + ## - Next, when a secure request arrives, handle it in the + ## catch all ``securedHandler`` handler + ## + ## - Lastly, register muxers and handler subsequent muxer + ## requests + ## trace "upgrading incoming connection" let ms = newMultistream() # secure incoming connections - proc securedHandler (conn: Connection, - proto: string) + proc securedHandler (conn: Connection, proto: string) {.async, gcsafe, closure.} = + ## generic handler for secure managers trace "Securing connection" + + # get the secure handler for the proto let secure = s.secureManagers[proto] let sconn = await secure.secure(conn) + + # if the connection has been already + # established while negotiating this + # one we drop it + if sconn.peerInfo.id in s.connections: + await sconn.close() + return + + s.connections[sconn.peerInfo.id] = sconn + + # if securing succedded, handle muxer requests if not isNil(sconn): # add the muxer for muxer in s.muxers.values: @@ -197,7 +221,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = await ms.handle(sconn) if (await ms.select(conn)): # just handshake - # add the secure handlers + # register all the secure managers to be + # handled by the catch all ``securedHandler`` for k in s.secureManagers.keys: ms.addHandler(k, securedHandler) @@ -217,8 +242,17 @@ proc dial*(s: Switch, if t.handles(a): # check if it can dial it trace "Dialing address", address = $a conn = await t.dial(a) + + # avoid raicing with incoming connections + if peer.id in s.connections: + if not isNil(conn) and not conn.closed(): + await conn.close() + conn = s.connections[peer.id] + # make sure to assign the peer to the connection - conn.peerInfo = peer + if isNil(conn.peerInfo): + conn.peerInfo = peer + conn = await s.upgradeOutgoing(conn) if isNil(conn): continue @@ -234,7 +268,8 @@ proc dial*(s: Switch, raise newException(CatchableError, "Unable to establish outgoing link") if proto.len > 0 and not conn.closed: - let stream = await s.getMuxedStream(peer) + result = conn + var stream = await s.getMuxedStream(peer) if not isNil(stream): trace "Connection is muxed, return muxed stream" result = stream diff --git a/tests/testswitch.nim b/tests/testswitch.nim index ce689efbc..4cce425e3 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -106,3 +106,49 @@ suite "Switch": check: waitFor(testSwitch()) == true + + test "e2e use switch nested dial": + proc testSwitch(): Future[bool] {.async, gcsafe.} = + let ma1: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") + let ma2: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") + + var peerInfo1, peerInfo2: PeerInfo + var switch1, switch2: Switch + (switch1, peerInfo1) = createSwitch(ma1) + var awaiters: seq[Future[void]] + awaiters.add(await switch1.start()) + + (switch2, peerInfo2) = createSwitch(ma2) + awaiters.add(await switch2.start()) + + var proto1 = new LPProtocol + proto1.codec = "/proto/1" + var awaiter = newFuture[void]() + proc handler1(conn: Connection, proto: string) {.async, gcsafe.} = + var nested = await switch1.dial(switch2.peerInfo, "/proto/2") + await nested.writeLp("proto 1") + check cast[string](await nested.readLp()) == "proto 2" + await nested.close() + awaiter.complete() + + proto1.handler = handler1 + switch1.mount(proto1) + + var proto2 = new LPProtocol + proto2.codec = "/proto/2" + proc handler2(conn: Connection, proto: string) {.async, gcsafe.} = + check cast[string](await conn.readLp()) == "proto 1" + await conn.writeLp("proto 2") + await conn.close() + + proto2.handler = handler2 + switch2.mount(proto2) + + discard await switch1.dial(switch1.peerInfo, "/proto/1") + await awaiter + discard allFutures(switch1.stop(), switch2.stop()) + await allFutures(awaiters) + result = true + + check: + waitFor(testSwitch()) == true