Race in connection setup (#464)

* check that connection is not closed or eof

* don't release connection lock prematurely

* test that only valid connections can be added

* correct exception type on closed connection

* add clarifying comment

* use closeWithEOF for more stable test

* misc comments

* log stream id in buffestream asserts

* use closeWithEOF to prevent races in tests

* give some time to the remote handler to trigger

* adding more tests to make codecov happy
This commit is contained in:
Dmitriy Ryajov 2020-12-02 19:24:48 -06:00 committed by GitHub
parent d1c689e5ab
commit e9d4679059
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 82 additions and 31 deletions

View File

@ -208,8 +208,8 @@ proc delConn(c: ConnManager, conn: Connection) =
if c.conns[peerId].len == 0: if c.conns[peerId].len == 0:
c.conns.del(peerId) c.conns.del(peerId)
libp2p_peers.set(c.conns.len.int64)
libp2p_peers.set(c.conns.len.int64)
trace "Removed connection", conn trace "Removed connection", conn
proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
@ -327,6 +327,10 @@ proc storeConn*(c: ConnManager, conn: Connection) =
if isNil(conn): if isNil(conn):
raise newException(CatchableError, "connection cannot be nil") raise newException(CatchableError, "connection cannot be nil")
if conn.closed() or conn.atEof():
trace "Can't store dead connection", conn
raise newException(CatchableError, "can't store dead connection")
if isNil(conn.peerInfo): if isNil(conn.peerInfo):
raise newException(CatchableError, "empty peer info") raise newException(CatchableError, "empty peer info")
@ -370,6 +374,9 @@ proc storeMuxer*(c: ConnManager,
if isNil(muxer.connection): if isNil(muxer.connection):
raise newException(CatchableError, "muxer's connection cannot be nil") raise newException(CatchableError, "muxer's connection cannot be nil")
if muxer.connection notin c:
raise newException(CatchableError, "cant add muxer for untracked connection")
c.muxed[muxer.connection] = MuxerHolder( c.muxed[muxer.connection] = MuxerHolder(
muxer: muxer, muxer: muxer,
handle: handle) handle: handle)

View File

@ -64,7 +64,9 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
## `pushTo` will block if the queue is full, thus maintaining backpressure. ## `pushTo` will block if the queue is full, thus maintaining backpressure.
## ##
doAssert(not s.pushing, "Only one concurrent push allowed") doAssert(not s.pushing,
&"Only one concurrent push allowed for stream {s.shortLog()}")
if s.isClosed or s.pushedEof: if s.isClosed or s.pushedEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()
@ -84,7 +86,8 @@ method pushEof*(s: BufferStream) {.base, async.} =
if s.pushedEof: if s.pushedEof:
return return
doAssert(not s.pushing, "Only one concurrent push allowed") doAssert(not s.pushing,
&"Only one concurrent push allowed for stream {s.shortLog()}")
s.pushedEof = true s.pushedEof = true
@ -105,7 +108,8 @@ method readOnce*(s: BufferStream,
nbytes: int): nbytes: int):
Future[int] {.async.} = Future[int] {.async.} =
doAssert(nbytes > 0, "nbytes must be positive integer") doAssert(nbytes > 0, "nbytes must be positive integer")
doAssert(not s.reading, "Only one concurrent read allowed") doAssert(not s.reading,
&"Only one concurrent read allowed for stream {s.shortLog()}")
if s.returnedEof: if s.returnedEof:
raise newLPStreamEOFError() raise newLPStreamEOFError()

View File

@ -172,6 +172,14 @@ proc mux(s: Switch, conn: Connection): Future[Muxer] {.async, gcsafe.} =
# store it in muxed connections if we have a peer for it # store it in muxed connections if we have a peer for it
s.connManager.storeMuxer(muxer, muxer.handle()) # store muxer and start read loop s.connManager.storeMuxer(muxer, muxer.handle()) # store muxer and start read loop
try:
await s.identify(muxer)
except CatchableError as exc:
# Identify is non-essential, though if it fails, it might indicate that
# the connection was closed already - this will be picked up by the read
# loop
debug "Could not identify connection", conn, msg = exc.msg
return muxer return muxer
proc disconnect*(s: Switch, peerId: PeerID): Future[void] {.gcsafe.} = proc disconnect*(s: Switch, peerId: PeerID): Future[void] {.gcsafe.} =
@ -195,18 +203,10 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g
raise newException(UpgradeFailedError, raise newException(UpgradeFailedError,
"a muxer is required for outgoing connections") "a muxer is required for outgoing connections")
try: if sconn.closed() or isNil(sconn.peerInfo):
await s.identify(muxer)
except CatchableError as exc:
# Identify is non-essential, though if it fails, it might indicate that
# the connection was closed already - this will be picked up by the read
# loop
debug "Could not identify connection", conn, msg = exc.msg
if isNil(sconn.peerInfo):
await sconn.close() await sconn.close()
raise newException(UpgradeFailedError, raise newException(UpgradeFailedError,
"No peerInfo for connection, stopping upgrade") "Connection closed or missing peer info, stopping upgrade")
trace "Upgraded outgoing connection", conn, sconn trace "Upgraded outgoing connection", conn, sconn
@ -302,14 +302,13 @@ proc internalConnect(s: Switch,
if s.peerInfo.peerId == peerId: if s.peerInfo.peerId == peerId:
raise newException(CatchableError, "can't dial self!") raise newException(CatchableError, "can't dial self!")
var conn: Connection
# Ensure there's only one in-flight attempt per peer # Ensure there's only one in-flight attempt per peer
let lock = s.dialLock.mgetOrPut(peerId, newAsyncLock()) let lock = s.dialLock.mgetOrPut(peerId, newAsyncLock())
try: try:
await lock.acquire() await lock.acquire()
# Check if we have a connection already and try to reuse it # Check if we have a connection already and try to reuse it
conn = s.connManager.selectConn(peerId) var conn = s.connManager.selectConn(peerId)
if conn != nil: if conn != nil:
if conn.atEof or conn.closed: if conn.atEof or conn.closed:
# This connection should already have been removed from the connection # This connection should already have been removed from the connection
@ -323,22 +322,25 @@ proc internalConnect(s: Switch,
return conn return conn
conn = await s.dialAndUpgrade(peerId, addrs) conn = await s.dialAndUpgrade(peerId, addrs)
if isNil(conn): # None of the addresses connected
raise newException(DialFailedError, "Unable to establish outgoing link")
# We already check for this in Connection manager
# but a disconnect could have happened right after
# we've added the connection so we check again
# to prevent races due to that.
if conn.closed() or conn.atEof():
# This can happen when the other ends drops us
# before we get a chance to return the connection
# back to the dialer.
trace "Connection dead on arrival", conn
raise newLPStreamClosedError()
return conn
finally: finally:
if lock.locked(): if lock.locked():
lock.release() lock.release()
if isNil(conn): # None of the addresses connected
raise newException(DialFailedError, "Unable to establish outgoing link")
if conn.closed() or conn.atEof():
# This can happen when the other ends drops us
# before we get a chance to return the connection
# back to the dialer.
trace "Connection dead on arrival", conn
raise newLPStreamClosedError()
return conn
proc connect*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]) {.async.} = proc connect*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]) {.async.} =
discard await s.internalConnect(peerId, addrs) discard await s.internalConnect(peerId, addrs)

View File

@ -38,6 +38,29 @@ suite "Connection Manager":
await connMngr.close() await connMngr.close()
asyncTest "shouldn't allow a closed connection":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
await conn.close()
expect CatchableError:
connMngr.storeConn(conn)
await connMngr.close()
asyncTest "shouldn't allow an EOFed connection":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
conn.isEof = true
expect CatchableError:
connMngr.storeConn(conn)
await conn.close()
await connMngr.close()
asyncTest "add and retrieve a muxer": asyncTest "add and retrieve a muxer":
let connMngr = ConnManager.init() let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
@ -54,6 +77,20 @@ suite "Connection Manager":
await connMngr.close() await connMngr.close()
asyncTest "shouldn't allow a muxer for an untracked connection":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new Muxer
muxer.connection = conn
expect CatchableError:
connMngr.storeMuxer(muxer)
await conn.close()
await muxer.close()
await connMngr.close()
asyncTest "get conn with direction": asyncTest "get conn with direction":
let connMngr = ConnManager.init() let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())

View File

@ -907,7 +907,7 @@ suite "Mplex":
checkTracker(LPChannelTrackerName) checkTracker(LPChannelTrackerName)
await conn.close() await conn.closeWithEOF()
await mplexDialFut await mplexDialFut
await allFuturesThrowing( await allFuturesThrowing(
transport1.stop(), transport1.stop(),
@ -943,7 +943,8 @@ suite "Mplex":
for i in 0..9: for i in 0..9:
dialStreams.add((await mplexDial.newStream())) dialStreams.add((await mplexDial.newStream()))
await listenConn.close() await sleepAsync(100.millis)
await listenConn.closeWithEOF()
await allFuturesThrowing( await allFuturesThrowing(
(dialStreams & listenStreams) (dialStreams & listenStreams)
.mapIt( it.join() )) .mapIt( it.join() ))

View File

@ -672,7 +672,7 @@ suite "Switch":
proc acceptHandler() {.async, gcsafe.} = proc acceptHandler() {.async, gcsafe.} =
let conn = await transport.accept() let conn = await transport.accept()
await conn.close() await conn.closeWithEOF()
let handlerWait = acceptHandler() let handlerWait = acceptHandler()
let switch = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) let switch = newStandardSwitch(secureManagers = [SecureProtocol.Noise])