diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 16d463a..45aaca3 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import tables, sequtils, sets +import std/[options, tables, sequtils, sets] import chronos, chronicles, metrics import peerinfo, stream/connection, @@ -31,7 +31,6 @@ type # copies and mangling by unrelated code. conns: Table[PeerID, HashSet[Connection]] muxed: Table[Connection, MuxerHolder] - cleanUpLock: Table[PeerInfo, AsyncLock] maxConns: int proc newTooManyConnections(): ref TooManyConnections {.inline.} = @@ -54,9 +53,6 @@ proc contains*(c: ConnManager, conn: Connection): bool = if isNil(conn.peerInfo): return - if conn.peerInfo.peerId notin c.conns: - return - return conn in c.conns[conn.peerInfo.peerId] proc contains*(c: ConnManager, peerId: PeerID): bool = @@ -79,9 +75,24 @@ proc contains*(c: ConnManager, muxer: Muxer): bool = return muxer == c.muxed[conn].muxer +proc closeMuxerHolder(muxerHolder: MuxerHolder) {.async.} = + trace "cleaning up muxer for peer" + + await muxerHolder.muxer.close() + if not(isNil(muxerHolder.handle)): + await muxerHolder.handle # TODO noraises? + +proc delConn(c: ConnManager, conn: Connection) = + let peerId = conn.peerInfo.peerId + if peerId in c.conns: + c.conns[peerId].excl(conn) + + if c.conns[peerId].len == 0: + c.conns.del(peerId) + libp2p_peers.set(c.conns.len.int64) + proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = ## clean connection's resources such as muxers and streams - ## if isNil(conn): return @@ -89,37 +100,20 @@ proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = if isNil(conn.peerInfo): return - let peerInfo = conn.peerInfo - let lock = c.cleanUpLock.mgetOrPut(peerInfo, newAsyncLock()) + # Remove connection from all tables without async breaks + var muxer = some(MuxerHolder()) + if not c.muxed.pop(conn, muxer.get()): + muxer = none(MuxerHolder) + + delConn(c, conn) try: - await lock.acquire() - trace "cleaning up connection for peer", peer = $peerInfo - if conn in c.muxed: - let muxerHolder = c.muxed[conn] - c.muxed.del(conn) - - await muxerHolder.muxer.close() - if not(isNil(muxerHolder.handle)): - await muxerHolder.handle - - if peerInfo.peerId in c.conns: - c.conns[peerInfo.peerId].excl(conn) - - if c.conns[peerInfo.peerId].len == 0: - c.conns.del(peerInfo.peerId) - - if not(conn.peerInfo.isClosed()): - conn.peerInfo.close() - + if muxer.isSome: + await closeMuxerHolder(muxer.get()) finally: await conn.close() - libp2p_peers.set(c.conns.len.int64) - if lock.locked(): - lock.release() - - trace "connection cleaned up" + trace "connection cleaned up", peer = $conn.peerInfo proc onClose(c: ConnManager, conn: Connection) {.async.} = ## connection close even handler @@ -132,32 +126,25 @@ proc onClose(c: ConnManager, conn: Connection) {.async.} = await c.cleanupConn(conn) proc selectConn*(c: ConnManager, - peerInfo: PeerInfo, + peerId: PeerID, dir: Direction): Connection = ## Select a connection for the provided peer and direction ## - - if isNil(peerInfo): - return - let conns = toSeq( - c.conns.getOrDefault(peerInfo.peerId)) + c.conns.getOrDefault(peerId)) .filterIt( it.dir == dir ) if conns.len > 0: return conns[0] -proc selectConn*(c: ConnManager, peerInfo: PeerInfo): Connection = +proc selectConn*(c: ConnManager, peerId: PeerID): Connection = ## Select a connection for the provided giving priority ## to outgoing connections ## - if isNil(peerInfo): - return - - var conn = c.selectConn(peerInfo, Direction.Out) + var conn = c.selectConn(peerId, Direction.Out) if isNil(conn): - conn = c.selectConn(peerInfo, Direction.In) + conn = c.selectConn(peerId, Direction.In) return conn @@ -181,18 +168,18 @@ proc storeConn*(c: ConnManager, conn: Connection) = if isNil(conn.peerInfo): raise newException(CatchableError, "empty peer info") - let peerInfo = conn.peerInfo - if c.conns.getOrDefault(peerInfo.peerId).len > c.maxConns: - trace "too many connections", peer = $conn.peerInfo, + let peerId = conn.peerInfo.peerId + if c.conns.getOrDefault(peerId).len > c.maxConns: + trace "too many connections", peer = $peerId, conns = c.conns - .getOrDefault(peerInfo.peerId).len + .getOrDefault(peerId).len raise newTooManyConnections() - if peerInfo.peerId notin c.conns: - c.conns[peerInfo.peerId] = initHashSet[Connection]() + if peerId notin c.conns: + c.conns[peerId] = initHashSet[Connection]() - c.conns[peerInfo.peerId].incl(conn) + c.conns[peerId].incl(conn) # launch on close listener asyncCheck c.onClose(conn) @@ -222,25 +209,25 @@ proc storeMuxer*(c: ConnManager, muxer: muxer, handle: handle) - trace "storred connection", connections = c.conns.len + trace "stored connection", connections = c.conns.len proc getMuxedStream*(c: ConnManager, - peerInfo: PeerInfo, + peerId: PeerID, dir: Direction): Future[Connection] {.async, gcsafe.} = ## get a muxed stream for the provided peer ## with the given direction ## - let muxer = c.selectMuxer(c.selectConn(peerInfo, dir)) + let muxer = c.selectMuxer(c.selectConn(peerId, dir)) if not(isNil(muxer)): return await muxer.newStream() proc getMuxedStream*(c: ConnManager, - peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} = + peerId: PeerID): Future[Connection] {.async, gcsafe.} = ## get a muxed stream for the passed peer from any connection ## - let muxer = c.selectMuxer(c.selectConn(peerInfo)) + let muxer = c.selectMuxer(c.selectConn(peerId)) if not(isNil(muxer)): return await muxer.newStream() @@ -253,24 +240,38 @@ proc getMuxedStream*(c: ConnManager, if not(isNil(muxer)): return await muxer.newStream() -proc dropPeer*(c: ConnManager, peerInfo: PeerInfo) {.async.} = +proc dropPeer*(c: ConnManager, peerId: PeerID) {.async.} = ## drop connections and cleanup resources for peer ## + let conns = c.conns.getOrDefault(peerId) + for conn in conns: + delConn(c, conn) - for conn in c.conns.getOrDefault(peerInfo.peerId): - if not(isNil(conn)): - await c.cleanupConn(conn) + var muxers: seq[MuxerHolder] + for conn in conns: + if conn in c.muxed: + muxers.add c.muxed[conn] + c.muxed.del(conn) + + for muxer in muxers: + await closeMuxerHolder(muxer) + + for conn in conns: + await conn.close() proc close*(c: ConnManager) {.async.} = ## cleanup resources for the connection ## manager ## + let conns = c.conns + c.conns.clear() - for conns in toSeq(c.conns.values): - for conn in conns: - try: - await c.cleanupConn(conn) - except CancelledError as exc: - raise exc - except CatchableError as exc: - warn "error cleaning up connections" + let muxed = c.muxed + c.muxed.clear() + + for _, muxer in muxed: + await closeMuxerHolder(muxer) + + for _, conns2 in conns: + for conn in conns2: + await conn.close() diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 4e5c27c..1accfe5 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -188,7 +188,6 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} = # new stream for identify var stream = await muxer.newStream() - var handlerFut: Future[void] defer: if not(isNil(stream)): @@ -196,7 +195,7 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} = # call muxer handler, this should # not end until muxer ends - handlerFut = muxer.handle() + let handlerFut = muxer.handle() # do identify first, so that we have a # PeerInfo in case we didn't before @@ -212,7 +211,8 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} = s.connManager.storeMuxer(muxer, handlerFut) # update muxer with handler proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = - await s.connManager.dropPeer(peer) + if not peer.isNil: + await s.connManager.dropPeer(peer.peerId) proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = logScope: @@ -231,7 +231,7 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g raise newException(CatchableError, "unable to identify connection, stopping upgrade") - trace "succesfully upgraded outgoing connection", oid = sconn.oid + trace "successfully upgraded outgoing connection", oid = sconn.oid return sconn @@ -290,7 +290,7 @@ proc internalConnect(s: Switch, try: await lock.acquire() trace "about to dial peer", peer = id - conn = s.connManager.selectConn(peer) + conn = s.connManager.selectConn(peer.peerId) if conn.isNil or (conn.closed or conn.atEof): trace "Dialing peer", peer = id for t in s.transports: # for each transport @@ -323,7 +323,7 @@ proc internalConnect(s: Switch, s.connManager.storeOutgoing(uconn) asyncCheck s.triggerHooks(uconn.peerInfo, Lifecycle.Upgraded) conn = uconn - trace "dial succesfull", oid = $conn.oid, peer = $conn.peerInfo + trace "dial successful", oid = $conn.oid, peer = $conn.peerInfo except CatchableError as exc: if not(isNil(conn)): await conn.close() @@ -354,7 +354,7 @@ proc internalConnect(s: Switch, doAssert(conn in s.connManager, "connection not tracked!") - trace "dial succesfull", oid = $conn.oid, + trace "dial successful", oid = $conn.oid, peer = $conn.peerInfo await s.subscribePeer(peer) @@ -475,7 +475,7 @@ proc subscribePeerInternal(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog() var stream: Connection try: - stream = await s.connManager.getMuxedStream(peerInfo) + stream = await s.connManager.getMuxedStream(peerInfo.peerId) if isNil(stream): trace "unable to subscribe to peer", peer = peerInfo.shortLog return