diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 42a1a78f6..26028d015 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -209,10 +209,8 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} -proc dial*(s: Switch, - peer: PeerInfo, - proto: string = ""): - Future[Connection] {.async.} = +proc internalConnect(s: Switch, + peer: PeerInfo): Future[Connection] {.async.} = let id = peer.id trace "Dialing peer", peer = id var conn = s.connections.getOrDefault(id) @@ -239,21 +237,34 @@ proc dial*(s: Switch, else: trace "Reusing existing connection" + await s.subscribeToPeer(peer) + result = conn + +proc connect*(s: Switch, peer: PeerInfo) {.async.} = + var conn = s.internalConnect(peer) + if isNil(conn): + raise newException(CatchableError, "Unable to connect to peer") + +proc dial*(s: Switch, + peer: PeerInfo, + proto: string): + Future[Connection] {.async.} = + var conn = await s.internalConnect(peer) if isNil(conn): raise newException(CatchableError, "Unable to establish outgoing link") - if proto.len > 0 and not conn.closed: - let stream = await s.getMuxedStream(peer) - if not isNil(stream): - trace "Connection is muxed, return muxed stream" - result = stream - trace "Attempting to select remote", proto = proto + if conn.closed: + raise newException(CatchableError, "Connection dead on arrival") - if not await s.ms.select(result, proto): - error "Unable to select sub-protocol", proto = proto - raise newException(CatchableError, &"unable to select protocol: {proto}") + let stream = await s.getMuxedStream(peer) + if not isNil(stream): + trace "Connection is muxed, return muxed stream" + result = stream + trace "Attempting to select remote", proto = proto - await s.subscribeToPeer(peer) + if not await s.ms.select(result, proto): + warn "Unable to select sub-protocol", proto = proto + raise newException(CatchableError, &"unable to select protocol: {proto}") proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = if isNil(proto.handler): @@ -279,7 +290,7 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = await conn.close() await s.cleanupConn(conn) - + var startFuts: seq[Future[void]] for t in s.transports: # for each transport for i, a in s.peerInfo.addrs: @@ -310,7 +321,7 @@ proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = let conn = await s.dial(peerInfo, s.pubSub.get().codec) await s.pubSub.get().subscribeToPeer(conn) except CatchableError as exc: - warn "unable to initiate pubsub", exc = exc.msg + trace "unable to initiate pubsub", exc = exc.msg s.dialedPubSubPeers.excl(peerInfo.id) proc subscribe*(s: Switch, topic: string, handler: TopicHandler): Future[void] {.gcsafe.} = diff --git a/tests/pubsub/utils.nim b/tests/pubsub/utils.nim index 79e4e4a8b..22babb4ff 100644 --- a/tests/pubsub/utils.nim +++ b/tests/pubsub/utils.nim @@ -1,4 +1,3 @@ -import options, tables import chronos import ../../libp2p/standard_setup export standard_setup @@ -8,10 +7,10 @@ proc generateNodes*(num: Natural, gossip: bool = false): seq[Switch] = result.add(newStandardSwitch(gossip = gossip)) proc subscribeNodes*(nodes: seq[Switch]) {.async.} = - var dials: seq[Future[Connection]] + var dials: seq[Future[void]] for dialer in nodes: for node in nodes: if dialer.peerInfo.peerId != node.peerInfo.peerId: - dials.add(dialer.dial(node.peerInfo)) + dials.add(dialer.connect(node.peerInfo)) await sleepAsync(100.millis) await allFutures(dials) diff --git a/tests/testinterop.nim b/tests/testinterop.nim index 33fd50076..6d1ed765e 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -118,8 +118,8 @@ proc testPubSubDaemonPublish(gossip: bool = false, if times >= count and not handlerFuture.finished: handlerFuture.complete(true) - discard await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer, - daemonPeer.addresses)) + await nativeNode.connect(NativePeerInfo.init(daemonPeer.peer, + daemonPeer.addresses)) await sleepAsync(1.seconds) await daemonNode.connect(nativePeer.peerId, nativePeer.addrs) @@ -157,8 +157,8 @@ proc testPubSubNodePublish(gossip: bool = false, let nativePeer = nativeNode.peerInfo var handlerFuture = newFuture[bool]() - discard await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer, - daemonPeer.addresses)) + await nativeNode.connect(NativePeerInfo.init(daemonPeer.peer, + daemonPeer.addresses)) await sleepAsync(1.seconds) await daemonNode.connect(nativePeer.peerId, nativePeer.addrs) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index ce689efbc..b9d968c3e 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -91,18 +91,26 @@ suite "Switch": var peerInfo1, peerInfo2: PeerInfo var switch1, switch2: Switch - (switch1, peerInfo1) = createSwitch(ma1) var awaiters: seq[Future[void]] - awaiters.add(await switch1.start()) + (switch1, peerInfo1) = createSwitch(ma1) + + let testProto = new TestProto + testProto.init() + testProto.codec = TestCodec + switch1.mount(testProto) (switch2, peerInfo2) = createSwitch(ma2) + awaiters.add(await switch1.start()) awaiters.add(await switch2.start()) - var conn = await switch2.dial(switch1.peerInfo) + await switch2.connect(switch1.peerInfo) + let conn = await switch2.dial(switch1.peerInfo, TestCodec) + await conn.writeLp("Hello!") + let msg = cast[string](await conn.readLp()) + check "Hello!" == msg - check isNil(conn) - discard allFutures(switch1.stop(), switch2.stop()) + await allFutures(switch1.stop(), switch2.stop()) await allFutures(awaiters) result = true check: - waitFor(testSwitch()) == true + waitFor(testSwitch()) == true \ No newline at end of file