diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 52277c7..f802f69 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -59,40 +59,32 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = result = await s.secureManagers[manager].secure(conn) -proc identify*(s: Switch, conn: Connection) {.async, gcsafe.} = +proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} = ## identify the connection - # s.peerInfo.protocols = await s.ms.list(conn) # update protos before engagin in identify + try: if (await s.ms.select(conn, s.identity.codec)): let info = await s.identity.identify(conn, conn.peerInfo) - let id = if conn.peerInfo.peerId.isSome: - conn.peerInfo.peerId.get().pretty - else: - "" + if info.pubKey.isSome: + result.peerId = some(PeerID.init(info.pubKey.get())) # we might not have a peerId at all + + if info.addrs.len > 0: + result.addrs = info.addrs + + if info.protos.len > 0: + result.protocols = info.protos - if id.len > 0 and s.connections.contains(id): - let connection = s.connections[id] - var peerInfo = conn.peerInfo - - if info.pubKey.isSome: - peerInfo.peerId = some(PeerID.init(info.pubKey.get())) # we might not have a peerId at all - - if info.addrs.len > 0: - peerInfo.addrs = info.addrs - - if info.protos.len > 0: - peerInfo.protocols = info.protos - - trace "identify: identified remote peer ", peer = peerInfo.peerId.get().pretty + trace "identify: identified remote peer ", peer = result.peerId.get().pretty except IdentityInvalidMsgError as exc: error "identify: invalid message", msg = exc.msg except IdentityNoMatchError as exc: error "identify: peer's public keys don't match ", msg = exc.msg proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = - trace "muxing connection" ## mux incoming connection + + trace "muxing connection" let muxers = toSeq(s.muxers.keys) if muxers.len == 0: warn "no muxers registered, skipping upgrade flow" @@ -118,13 +110,9 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = peer = conn.peerInfo.peerId.get().pretty ) - # do identify first, so that we have a + # do identify first, so that we have a # PeerInfo in case we didn't before - await s.identify(stream) - - # update main connection with refreshed info - if stream.peerInfo.peerId.isSome: - conn.peerInfo = stream.peerInfo + conn.peerInfo = await s.identify(stream) await stream.close() # close idenity stream trace "connection's peerInfo", peerInfo = conn.peerInfo.peerId @@ -138,23 +126,6 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = trace "adding muxer for peer", peer = conn.peerInfo.peerId.get().pretty s.muxed[conn.peerInfo.peerId.get().pretty] = muxer -proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = - trace "handling connection", conn = conn - result = conn - ## perform upgrade flow - if result.peerInfo.peerId.isSome: - let id = result.peerInfo.peerId.get().pretty - if s.connections.contains(id): - # if we already have a connection for this peer, - # close the incoming connection and return the - # existing one - await result.close() - return s.connections[id] - s.connections[id] = result - - result = await s.secure(conn) # secure the connection - await s.mux(result) # mux it if possible - proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = if conn.peerInfo.peerId.isSome: let id = conn.peerInfo.peerId.get().pretty @@ -173,6 +144,50 @@ proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Option[Connection]] { let conn = await muxer.newStream() result = some(conn) +proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = + trace "handling connection", conn = conn + result = conn + ## perform upgrade flow + if result.peerInfo.peerId.isSome: + let id = result.peerInfo.peerId.get().pretty + if s.connections.contains(id): + # if we already have a connection for this peer, + # close the incoming connection and return the + # existing one + await result.close() + return s.connections[id] + s.connections[id] = result + + result = await s.secure(conn) # secure the connection + await s.mux(result) # mux it if possible + +proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = + trace "upgrading incoming connection" + let ms = newMultistream() + + # secure incoming connections + proc securedHandler (conn: Connection, + proto: string) + {.async, gcsafe, closure.} = + trace "Securing connection" + let secure = s.secureManagers[proto] + let sconn = await secure.secure(conn) + if not isNil(sconn): + # add the muxer + for muxer in s.muxers.values: + ms.addHandler(muxer.codec, muxer) + + # handle subsequent requests + await ms.handle(sconn) + + if (await ms.select(conn)): # just handshake + # add the secure handlers + for k in s.secureManagers.keys: + ms.addHandler(k, securedHandler) + + # handle secured connections + await ms.handle(conn) + proc dial*(s: Switch, peer: PeerInfo, proto: string = ""): @@ -208,33 +223,6 @@ proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = s.ms.addHandler(proto.codec, proto) -proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = - trace "upgrading incoming connection" - let ms = newMultistream() - - # secure incoming connections - proc securedHandler (conn: Connection, - proto: string) - {.async, gcsafe, closure.} = - trace "Securing connection" - let secure = s.secureManagers[proto] - let sconn = await secure.secure(conn) - if not isNil(sconn): - # add the muxer - for muxer in s.muxers.values: - ms.addHandler(muxer.codec, muxer) - - # handle subsequent requests - await ms.handle(sconn) - - if (await ms.select(conn)): # just handshake - # add the secure handlers - for k in s.secureManagers.keys: - ms.addHandler(k, securedHandler) - - # handle secured connections - await ms.handle(conn) - proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = try: @@ -244,13 +232,15 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = var startFuts: seq[Future[void]] for t in s.transports: # for each transport - for a in s.peerInfo.addrs: + for i, a in s.peerInfo.addrs: if t.handles(a): # check if it handles the multiaddr var server = await t.listen(a, handle) + s.peerInfo.addrs[i] = t.ma # update peer's address startFuts.add(server) result = startFuts # listen for incoming connections proc stop*(s: Switch) {.async.} = + await allFutures(toSeq(s.connections.values).mapIt(it.close())) await allFutures(s.transports.mapIt(it.close())) proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = @@ -307,7 +297,8 @@ proc newSwitch*(peerInfo: PeerInfo, val.muxerHandler = proc(muxer: Muxer) {.async, gcsafe.} = trace "got new muxer" let stream = await muxer.newStream() - await s.identify(stream) + muxer.connection.peerInfo = await s.identify(stream) + await stream.close() for k in secureManagers.keys: trace "adding secure manager ", codec = secureManagers[k].codec