diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim index 36fbe6d..ad37f00 100644 --- a/libp2p/connmanager.nim +++ b/libp2p/connmanager.nim @@ -208,8 +208,8 @@ proc delConn(c: ConnManager, conn: Connection) = if c.conns[peerId].len == 0: c.conns.del(peerId) - libp2p_peers.set(c.conns.len.int64) + libp2p_peers.set(c.conns.len.int64) trace "Removed connection", conn proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = @@ -327,6 +327,10 @@ proc storeConn*(c: ConnManager, conn: Connection) = if isNil(conn): raise newException(CatchableError, "connection cannot be nil") + if conn.closed() or conn.atEof(): + trace "Can't store dead connection", conn + raise newException(CatchableError, "can't store dead connection") + if isNil(conn.peerInfo): raise newException(CatchableError, "empty peer info") @@ -370,6 +374,9 @@ proc storeMuxer*(c: ConnManager, if isNil(muxer.connection): raise newException(CatchableError, "muxer's connection cannot be nil") + if muxer.connection notin c: + raise newException(CatchableError, "cant add muxer for untracked connection") + c.muxed[muxer.connection] = MuxerHolder( muxer: muxer, handle: handle) diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 69e4e57..60a551d 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -64,7 +64,9 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} = ## `pushTo` will block if the queue is full, thus maintaining backpressure. ## - doAssert(not s.pushing, "Only one concurrent push allowed") + doAssert(not s.pushing, + &"Only one concurrent push allowed for stream {s.shortLog()}") + if s.isClosed or s.pushedEof: raise newLPStreamEOFError() @@ -84,7 +86,8 @@ method pushEof*(s: BufferStream) {.base, async.} = if s.pushedEof: return - doAssert(not s.pushing, "Only one concurrent push allowed") + doAssert(not s.pushing, + &"Only one concurrent push allowed for stream {s.shortLog()}") s.pushedEof = true @@ -105,7 +108,8 @@ method readOnce*(s: BufferStream, nbytes: int): Future[int] {.async.} = doAssert(nbytes > 0, "nbytes must be positive integer") - doAssert(not s.reading, "Only one concurrent read allowed") + doAssert(not s.reading, + &"Only one concurrent read allowed for stream {s.shortLog()}") if s.returnedEof: raise newLPStreamEOFError() diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 3a67ef1..d254ae6 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -172,6 +172,14 @@ proc mux(s: Switch, conn: Connection): Future[Muxer] {.async, gcsafe.} = # store it in muxed connections if we have a peer for it s.connManager.storeMuxer(muxer, muxer.handle()) # store muxer and start read loop + try: + await s.identify(muxer) + except CatchableError as exc: + # Identify is non-essential, though if it fails, it might indicate that + # the connection was closed already - this will be picked up by the read + # loop + debug "Could not identify connection", conn, msg = exc.msg + return muxer proc disconnect*(s: Switch, peerId: PeerID): Future[void] {.gcsafe.} = @@ -195,18 +203,10 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g raise newException(UpgradeFailedError, "a muxer is required for outgoing connections") - try: - await s.identify(muxer) - except CatchableError as exc: - # Identify is non-essential, though if it fails, it might indicate that - # the connection was closed already - this will be picked up by the read - # loop - debug "Could not identify connection", conn, msg = exc.msg - - if isNil(sconn.peerInfo): + if sconn.closed() or isNil(sconn.peerInfo): await sconn.close() raise newException(UpgradeFailedError, - "No peerInfo for connection, stopping upgrade") + "Connection closed or missing peer info, stopping upgrade") trace "Upgraded outgoing connection", conn, sconn @@ -302,14 +302,13 @@ proc internalConnect(s: Switch, if s.peerInfo.peerId == peerId: raise newException(CatchableError, "can't dial self!") - var conn: Connection # Ensure there's only one in-flight attempt per peer let lock = s.dialLock.mgetOrPut(peerId, newAsyncLock()) try: await lock.acquire() # Check if we have a connection already and try to reuse it - conn = s.connManager.selectConn(peerId) + var conn = s.connManager.selectConn(peerId) if conn != nil: if conn.atEof or conn.closed: # This connection should already have been removed from the connection @@ -323,22 +322,25 @@ proc internalConnect(s: Switch, return conn conn = await s.dialAndUpgrade(peerId, addrs) + if isNil(conn): # None of the addresses connected + raise newException(DialFailedError, "Unable to establish outgoing link") + + # We already check for this in Connection manager + # but a disconnect could have happened right after + # we've added the connection so we check again + # to prevent races due to that. + if conn.closed() or conn.atEof(): + # This can happen when the other ends drops us + # before we get a chance to return the connection + # back to the dialer. + trace "Connection dead on arrival", conn + raise newLPStreamClosedError() + + return conn finally: if lock.locked(): lock.release() - if isNil(conn): # None of the addresses connected - raise newException(DialFailedError, "Unable to establish outgoing link") - - if conn.closed() or conn.atEof(): - # This can happen when the other ends drops us - # before we get a chance to return the connection - # back to the dialer. - trace "Connection dead on arrival", conn - raise newLPStreamClosedError() - - return conn - proc connect*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]) {.async.} = discard await s.internalConnect(peerId, addrs) diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim index 8afad8a..f700200 100644 --- a/tests/testconnmngr.nim +++ b/tests/testconnmngr.nim @@ -38,6 +38,29 @@ suite "Connection Manager": await connMngr.close() + asyncTest "shouldn't allow a closed connection": + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + await conn.close() + + expect CatchableError: + connMngr.storeConn(conn) + + await connMngr.close() + + asyncTest "shouldn't allow an EOFed connection": + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + conn.isEof = true + + expect CatchableError: + connMngr.storeConn(conn) + + await conn.close() + await connMngr.close() + asyncTest "add and retrieve a muxer": let connMngr = ConnManager.init() let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) @@ -54,6 +77,20 @@ suite "Connection Manager": await connMngr.close() + asyncTest "shouldn't allow a muxer for an untracked connection": + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + let muxer = new Muxer + muxer.connection = conn + + expect CatchableError: + connMngr.storeMuxer(muxer) + + await conn.close() + await muxer.close() + await connMngr.close() + asyncTest "get conn with direction": let connMngr = ConnManager.init() let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 19d3082..028a5bb 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -907,7 +907,7 @@ suite "Mplex": checkTracker(LPChannelTrackerName) - await conn.close() + await conn.closeWithEOF() await mplexDialFut await allFuturesThrowing( transport1.stop(), @@ -943,7 +943,8 @@ suite "Mplex": for i in 0..9: dialStreams.add((await mplexDial.newStream())) - await listenConn.close() + await sleepAsync(100.millis) + await listenConn.closeWithEOF() await allFuturesThrowing( (dialStreams & listenStreams) .mapIt( it.join() )) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index 4786f85..b77cff9 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -672,7 +672,7 @@ suite "Switch": proc acceptHandler() {.async, gcsafe.} = let conn = await transport.accept() - await conn.close() + await conn.closeWithEOF() let handlerWait = acceptHandler() let switch = newStandardSwitch(secureManagers = [SecureProtocol.Noise])