mirror of https://github.com/vacp2p/nim-libp2p.git
Race in connection setup (#464)
* check that connection is not closed or eof * don't release connection lock prematurely * test that only valid connections can be added * correct exception type on closed connection * add clarifying comment * use closeWithEOF for more stable test * misc comments * log stream id in buffestream asserts * use closeWithEOF to prevent races in tests * give some time to the remote handler to trigger * adding more tests to make codecov happy
This commit is contained in:
parent
d1c689e5ab
commit
e9d4679059
|
@ -208,8 +208,8 @@ proc delConn(c: ConnManager, conn: Connection) =
|
|||
|
||||
if c.conns[peerId].len == 0:
|
||||
c.conns.del(peerId)
|
||||
libp2p_peers.set(c.conns.len.int64)
|
||||
|
||||
libp2p_peers.set(c.conns.len.int64)
|
||||
trace "Removed connection", conn
|
||||
|
||||
proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
|
||||
|
@ -327,6 +327,10 @@ proc storeConn*(c: ConnManager, conn: Connection) =
|
|||
if isNil(conn):
|
||||
raise newException(CatchableError, "connection cannot be nil")
|
||||
|
||||
if conn.closed() or conn.atEof():
|
||||
trace "Can't store dead connection", conn
|
||||
raise newException(CatchableError, "can't store dead connection")
|
||||
|
||||
if isNil(conn.peerInfo):
|
||||
raise newException(CatchableError, "empty peer info")
|
||||
|
||||
|
@ -370,6 +374,9 @@ proc storeMuxer*(c: ConnManager,
|
|||
if isNil(muxer.connection):
|
||||
raise newException(CatchableError, "muxer's connection cannot be nil")
|
||||
|
||||
if muxer.connection notin c:
|
||||
raise newException(CatchableError, "cant add muxer for untracked connection")
|
||||
|
||||
c.muxed[muxer.connection] = MuxerHolder(
|
||||
muxer: muxer,
|
||||
handle: handle)
|
||||
|
|
|
@ -64,7 +64,9 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
|
|||
## `pushTo` will block if the queue is full, thus maintaining backpressure.
|
||||
##
|
||||
|
||||
doAssert(not s.pushing, "Only one concurrent push allowed")
|
||||
doAssert(not s.pushing,
|
||||
&"Only one concurrent push allowed for stream {s.shortLog()}")
|
||||
|
||||
if s.isClosed or s.pushedEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
||||
|
@ -84,7 +86,8 @@ method pushEof*(s: BufferStream) {.base, async.} =
|
|||
if s.pushedEof:
|
||||
return
|
||||
|
||||
doAssert(not s.pushing, "Only one concurrent push allowed")
|
||||
doAssert(not s.pushing,
|
||||
&"Only one concurrent push allowed for stream {s.shortLog()}")
|
||||
|
||||
s.pushedEof = true
|
||||
|
||||
|
@ -105,7 +108,8 @@ method readOnce*(s: BufferStream,
|
|||
nbytes: int):
|
||||
Future[int] {.async.} =
|
||||
doAssert(nbytes > 0, "nbytes must be positive integer")
|
||||
doAssert(not s.reading, "Only one concurrent read allowed")
|
||||
doAssert(not s.reading,
|
||||
&"Only one concurrent read allowed for stream {s.shortLog()}")
|
||||
|
||||
if s.returnedEof:
|
||||
raise newLPStreamEOFError()
|
||||
|
|
|
@ -172,6 +172,14 @@ proc mux(s: Switch, conn: Connection): Future[Muxer] {.async, gcsafe.} =
|
|||
# store it in muxed connections if we have a peer for it
|
||||
s.connManager.storeMuxer(muxer, muxer.handle()) # store muxer and start read loop
|
||||
|
||||
try:
|
||||
await s.identify(muxer)
|
||||
except CatchableError as exc:
|
||||
# Identify is non-essential, though if it fails, it might indicate that
|
||||
# the connection was closed already - this will be picked up by the read
|
||||
# loop
|
||||
debug "Could not identify connection", conn, msg = exc.msg
|
||||
|
||||
return muxer
|
||||
|
||||
proc disconnect*(s: Switch, peerId: PeerID): Future[void] {.gcsafe.} =
|
||||
|
@ -195,18 +203,10 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g
|
|||
raise newException(UpgradeFailedError,
|
||||
"a muxer is required for outgoing connections")
|
||||
|
||||
try:
|
||||
await s.identify(muxer)
|
||||
except CatchableError as exc:
|
||||
# Identify is non-essential, though if it fails, it might indicate that
|
||||
# the connection was closed already - this will be picked up by the read
|
||||
# loop
|
||||
debug "Could not identify connection", conn, msg = exc.msg
|
||||
|
||||
if isNil(sconn.peerInfo):
|
||||
if sconn.closed() or isNil(sconn.peerInfo):
|
||||
await sconn.close()
|
||||
raise newException(UpgradeFailedError,
|
||||
"No peerInfo for connection, stopping upgrade")
|
||||
"Connection closed or missing peer info, stopping upgrade")
|
||||
|
||||
trace "Upgraded outgoing connection", conn, sconn
|
||||
|
||||
|
@ -302,14 +302,13 @@ proc internalConnect(s: Switch,
|
|||
if s.peerInfo.peerId == peerId:
|
||||
raise newException(CatchableError, "can't dial self!")
|
||||
|
||||
var conn: Connection
|
||||
# Ensure there's only one in-flight attempt per peer
|
||||
let lock = s.dialLock.mgetOrPut(peerId, newAsyncLock())
|
||||
try:
|
||||
await lock.acquire()
|
||||
|
||||
# Check if we have a connection already and try to reuse it
|
||||
conn = s.connManager.selectConn(peerId)
|
||||
var conn = s.connManager.selectConn(peerId)
|
||||
if conn != nil:
|
||||
if conn.atEof or conn.closed:
|
||||
# This connection should already have been removed from the connection
|
||||
|
@ -323,22 +322,25 @@ proc internalConnect(s: Switch,
|
|||
return conn
|
||||
|
||||
conn = await s.dialAndUpgrade(peerId, addrs)
|
||||
if isNil(conn): # None of the addresses connected
|
||||
raise newException(DialFailedError, "Unable to establish outgoing link")
|
||||
|
||||
# We already check for this in Connection manager
|
||||
# but a disconnect could have happened right after
|
||||
# we've added the connection so we check again
|
||||
# to prevent races due to that.
|
||||
if conn.closed() or conn.atEof():
|
||||
# This can happen when the other ends drops us
|
||||
# before we get a chance to return the connection
|
||||
# back to the dialer.
|
||||
trace "Connection dead on arrival", conn
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
return conn
|
||||
finally:
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
|
||||
if isNil(conn): # None of the addresses connected
|
||||
raise newException(DialFailedError, "Unable to establish outgoing link")
|
||||
|
||||
if conn.closed() or conn.atEof():
|
||||
# This can happen when the other ends drops us
|
||||
# before we get a chance to return the connection
|
||||
# back to the dialer.
|
||||
trace "Connection dead on arrival", conn
|
||||
raise newLPStreamClosedError()
|
||||
|
||||
return conn
|
||||
|
||||
proc connect*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]) {.async.} =
|
||||
discard await s.internalConnect(peerId, addrs)
|
||||
|
||||
|
|
|
@ -38,6 +38,29 @@ suite "Connection Manager":
|
|||
|
||||
await connMngr.close()
|
||||
|
||||
asyncTest "shouldn't allow a closed connection":
|
||||
let connMngr = ConnManager.init()
|
||||
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
|
||||
let conn = Connection.init(peer, Direction.In)
|
||||
await conn.close()
|
||||
|
||||
expect CatchableError:
|
||||
connMngr.storeConn(conn)
|
||||
|
||||
await connMngr.close()
|
||||
|
||||
asyncTest "shouldn't allow an EOFed connection":
|
||||
let connMngr = ConnManager.init()
|
||||
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
|
||||
let conn = Connection.init(peer, Direction.In)
|
||||
conn.isEof = true
|
||||
|
||||
expect CatchableError:
|
||||
connMngr.storeConn(conn)
|
||||
|
||||
await conn.close()
|
||||
await connMngr.close()
|
||||
|
||||
asyncTest "add and retrieve a muxer":
|
||||
let connMngr = ConnManager.init()
|
||||
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
|
||||
|
@ -54,6 +77,20 @@ suite "Connection Manager":
|
|||
|
||||
await connMngr.close()
|
||||
|
||||
asyncTest "shouldn't allow a muxer for an untracked connection":
|
||||
let connMngr = ConnManager.init()
|
||||
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
|
||||
let conn = Connection.init(peer, Direction.In)
|
||||
let muxer = new Muxer
|
||||
muxer.connection = conn
|
||||
|
||||
expect CatchableError:
|
||||
connMngr.storeMuxer(muxer)
|
||||
|
||||
await conn.close()
|
||||
await muxer.close()
|
||||
await connMngr.close()
|
||||
|
||||
asyncTest "get conn with direction":
|
||||
let connMngr = ConnManager.init()
|
||||
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
|
||||
|
|
|
@ -907,7 +907,7 @@ suite "Mplex":
|
|||
|
||||
checkTracker(LPChannelTrackerName)
|
||||
|
||||
await conn.close()
|
||||
await conn.closeWithEOF()
|
||||
await mplexDialFut
|
||||
await allFuturesThrowing(
|
||||
transport1.stop(),
|
||||
|
@ -943,7 +943,8 @@ suite "Mplex":
|
|||
for i in 0..9:
|
||||
dialStreams.add((await mplexDial.newStream()))
|
||||
|
||||
await listenConn.close()
|
||||
await sleepAsync(100.millis)
|
||||
await listenConn.closeWithEOF()
|
||||
await allFuturesThrowing(
|
||||
(dialStreams & listenStreams)
|
||||
.mapIt( it.join() ))
|
||||
|
|
|
@ -672,7 +672,7 @@ suite "Switch":
|
|||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let conn = await transport.accept()
|
||||
await conn.close()
|
||||
await conn.closeWithEOF()
|
||||
|
||||
let handlerWait = acceptHandler()
|
||||
let switch = newStandardSwitch(secureManagers = [SecureProtocol.Noise])
|
||||
|
|
Loading…
Reference in New Issue