diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 6e53a49bc..88e54d57b 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -48,7 +48,7 @@ type pubSub*: Option[PubSub] dialedPubSubPeers: HashSet[string] -proc newNoPubSubException(): ref Exception {.inline.} = +proc newNoPubSubException(): ref CatchableError {.inline.} = result = newException(NoPubSubException, "no pubsub provided!") proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = @@ -134,23 +134,26 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = s.muxed[conn.peerInfo.id] = muxer proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = - if not isNil(conn.peerInfo): - let id = conn.peerInfo.id - trace "cleaning up connection for peer", peerId = id - if id in s.muxed: - await s.muxed[id].close() - s.muxed.del(id) + try: + if not isNil(conn.peerInfo): + let id = conn.peerInfo.id + trace "cleaning up connection for peer", peerId = id + if id in s.muxed: + await s.muxed[id].close() + s.muxed.del(id) - if id in s.connections: - if not s.connections[id].closed: - await s.connections[id].close() - s.connections.del(id) + if id in s.connections: + if not s.connections[id].closed: + await s.connections[id].close() + s.connections.del(id) - s.dialedPubSubPeers.excl(id) + s.dialedPubSubPeers.excl(id) - # TODO: Investigate cleanupConn() always called twice for one peer. - if not(conn.peerInfo.isClosed()): - conn.peerInfo.close() + # TODO: Investigate cleanupConn() always called twice for one peer. + if not(conn.peerInfo.isClosed()): + conn.peerInfo.close() + except CatchableError as exc: + trace "exception cleaning up connection", exc = exc.msg proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = let conn = s.connections.getOrDefault(peer.id) @@ -167,25 +170,19 @@ proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async, result = conn proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = - try: - trace "handling connection", conn = $conn - result = conn + trace "handling connection", conn = $conn + result = conn - # don't mux/secure twise - if conn.peerInfo.id in s.muxed: - return + # don't mux/secure twise + if conn.peerInfo.id in s.muxed: + return - result = await s.secure(result) # secure the connection - if isNil(result): - return + result = await s.secure(result) # secure the connection + if isNil(result): + return - await s.mux(result) # mux it if possible - s.connections[conn.peerInfo.id] = result - except CancelledError as exc: - raise exc - except CatchableError as exc: - debug "Couldn't upgrade outgoing connection", msg = exc.msg - return nil + await s.mux(result) # mux it if possible + s.connections[conn.peerInfo.id] = result proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = trace "upgrading incoming connection", conn = $conn @@ -207,27 +204,33 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = ms.addHandler(muxer.codec, muxer) # handle subsequent requests - await ms.handle(sconn) - await sconn.close() + try: + await ms.handle(sconn) + finally: + await sconn.close() + except CancelledError as exc: raise exc except CatchableError as exc: debug "ending secured handler", err = exc.msg - if (await ms.select(conn)): # just handshake - # add the secure handlers - for k in s.secureManagers.keys: - ms.addHandler(k, securedHandler) - try: - # handle secured connections - await ms.handle(conn) + try: + if (await ms.select(conn)): # just handshake + # add the secure handlers + for k in s.secureManagers.keys: + ms.addHandler(k, securedHandler) + + # handle secured connections + await ms.handle(conn) + finally: + await conn.close() except CancelledError as exc: raise exc except CatchableError as exc: debug "ending multistream", err = exc.msg -proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} +proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} proc internalConnect(s: Switch, peer: PeerInfo): Future[Connection] {.async.} = @@ -239,13 +242,7 @@ proc internalConnect(s: Switch, for a in peer.addrs: # for each address if t.handles(a): # check if it can dial it trace "Dialing address", address = $a - try: - conn = await t.dial(a) - except CancelledError as exc: - raise exc - except CatchableError as exc: - trace "couldn't dial peer, transport failed", exc = exc.msg, address = a - continue + conn = await t.dial(a) # make sure to assign the peer to the connection conn.peerInfo = peer conn = await s.upgradeOutgoing(conn) @@ -253,8 +250,9 @@ proc internalConnect(s: Switch, continue conn.closeEvent.wait() - .addCallback do (udata: pointer): - asyncCheck s.cleanupConn(conn) + .addCallback do(udata: pointer): + asyncCheck s.cleanupConn(conn) + break else: trace "Reusing existing connection" @@ -289,7 +287,7 @@ proc dial*(s: Switch, if not await s.ms.select(result, proto): warn "Unable to select sub-protocol", proto = proto - raise newException(CatchableError, &"unable to select protocol: {proto}") + return nil proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = if isNil(proto.handler): @@ -307,14 +305,14 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = try: - await s.upgradeIncoming(conn) # perform upgrade on incoming connection + try: + await s.upgradeIncoming(conn) # perform upgrade on incoming connection + finally: + await s.cleanupConn(conn) except CancelledError as exc: raise exc except CatchableError as exc: trace "Exception occurred in Switch.start", exc = exc.msg - finally: - await conn.close() - await s.cleanupConn(conn) var startFuts: seq[Future[void]] for t in s.transports: # for each transport @@ -338,27 +336,27 @@ proc stop*(s: Switch) {.async.} = if s.pubSub.isSome: await s.pubSub.get().stop() - checkFutures( - await allFinished( - toSeq(s.connections.values).mapIt(s.cleanupConn(it)))) + await all( + toSeq(s.connections.values) + .mapIt(s.cleanupConn(it))) - checkFutures( - await allFinished( - s.transports.mapIt(it.close()))) + await all( + s.transports.mapIt(it.close())) trace "switch stopped" -proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = +proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = ## Subscribe to pub sub peer if s.pubSub.isSome and peerInfo.id notin s.dialedPubSubPeers: try: s.dialedPubSubPeers.incl(peerInfo.id) let conn = await s.dial(peerInfo, s.pubSub.get().codec) + if isNil(conn): + trace "unable to subscribe to peer" + return await s.pubSub.get().subscribeToPeer(conn) - except CancelledError as exc: - raise exc except CatchableError as exc: - warn "unable to initiate pubsub", exc = exc.msg + trace "exception in subscribe to peer", exc = exc.msg finally: s.dialedPubSubPeers.excl(peerInfo.id) @@ -434,19 +432,25 @@ proc newSwitch*(peerInfo: PeerInfo, for key, val in muxers: val.streamHandler = result.streamHandler val.muxerHandler = proc(muxer: Muxer) {.async, gcsafe.} = - trace "got new muxer" - let stream = await muxer.newStream() - muxer.connection.peerInfo = await s.identify(stream) - await stream.close() + var stream: Connection + try: + trace "got new muxer" + stream = await muxer.newStream() + muxer.connection.peerInfo = await s.identify(stream) - # store muxer for connection - s.muxed[muxer.connection.peerInfo.id] = muxer + # store muxer for connection + s.muxed[muxer.connection.peerInfo.id] = muxer - # store muxed connection - s.connections[muxer.connection.peerInfo.id] = muxer.connection + # store muxed connection + s.connections[muxer.connection.peerInfo.id] = muxer.connection - # try establishing a pubsub connection - await s.subscribeToPeer(muxer.connection.peerInfo) + # try establishing a pubsub connection + await s.subscribeToPeer(muxer.connection.peerInfo) + except CatchableError as exc: + trace "exception in muxer handler", exc = exc.msg + finally: + if not(isNil(stream)): + await stream.close() for proto in secureManagers: trace "adding secure manager ", codec = proto.codec