Race in connection setup (#464)
* check that connection is not closed or EOF
* don't release connection lock prematurely
* test that only valid connections can be added
* correct exception type on closed connection
* add clarifying comment
* use closeWithEOF for more stable test
* misc comments
* log stream id in bufferstream asserts
* use closeWithEOF to prevent races in tests
* give the remote handler some time to trigger
* add more tests to make codecov happy
This commit is contained in:
parent d1c689e5ab
commit e9d4679059
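The gist of the race being fixed: the remote can drop a connection between the moment it is dialed and the moment it is handed back, so a connection that is already closed or at EOF has to be rejected both when the connection manager stores it and again before the dialer returns it. Below is a minimal standalone sketch of that guard; the types and names are illustrative only, the real checks live in storeConn and internalConnect in the hunks that follow.

type
  Connection = ref object
    closedFlag, eofFlag: bool

proc closed(c: Connection): bool = c.closedFlag
proc atEof(c: Connection): bool = c.eofFlag

proc storeConn(conns: var seq[Connection], conn: Connection) =
  # reject dead connections instead of tracking them
  if conn.isNil:
    raise newException(CatchableError, "connection cannot be nil")
  if conn.closed() or conn.atEof():
    raise newException(CatchableError, "can't store dead connection")
  conns.add(conn)

when isMainModule:
  var conns: seq[Connection]
  storeConn(conns, Connection())                   # live: accepted
  try:
    storeConn(conns, Connection(closedFlag: true)) # dead on arrival: rejected
  except CatchableError as exc:
    echo "rejected: ", exc.msg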
@@ -208,8 +208,8 @@ proc delConn(c: ConnManager, conn: Connection) =
   if c.conns[peerId].len == 0:
     c.conns.del(peerId)
-    libp2p_peers.set(c.conns.len.int64)
 
+  libp2p_peers.set(c.conns.len.int64)
   trace "Removed connection", conn
 
 proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
@@ -327,6 +327,10 @@ proc storeConn*(c: ConnManager, conn: Connection) =
   if isNil(conn):
     raise newException(CatchableError, "connection cannot be nil")
 
+  if conn.closed() or conn.atEof():
+    trace "Can't store dead connection", conn
+    raise newException(CatchableError, "can't store dead connection")
+
   if isNil(conn.peerInfo):
     raise newException(CatchableError, "empty peer info")
 
@@ -370,6 +374,9 @@ proc storeMuxer*(c: ConnManager,
   if isNil(muxer.connection):
     raise newException(CatchableError, "muxer's connection cannot be nil")
 
+  if muxer.connection notin c:
+    raise newException(CatchableError, "cant add muxer for untracked connection")
+
   c.muxed[muxer.connection] = MuxerHolder(
     muxer: muxer,
     handle: handle)
@@ -64,7 +64,9 @@ method pushData*(s: BufferStream, data: seq[byte]) {.base, async.} =
   ## `pushTo` will block if the queue is full, thus maintaining backpressure.
   ##
 
-  doAssert(not s.pushing, "Only one concurrent push allowed")
+  doAssert(not s.pushing,
+    &"Only one concurrent push allowed for stream {s.shortLog()}")
+
   if s.isClosed or s.pushedEof:
     raise newLPStreamEOFError()
@@ -84,7 +86,8 @@ method pushEof*(s: BufferStream) {.base, async.} =
   if s.pushedEof:
     return
 
-  doAssert(not s.pushing, "Only one concurrent push allowed")
+  doAssert(not s.pushing,
+    &"Only one concurrent push allowed for stream {s.shortLog()}")
 
   s.pushedEof = true
@@ -105,7 +108,8 @@ method readOnce*(s: BufferStream,
                  nbytes: int):
                  Future[int] {.async.} =
   doAssert(nbytes > 0, "nbytes must be positive integer")
-  doAssert(not s.reading, "Only one concurrent read allowed")
+  doAssert(not s.reading,
+    &"Only one concurrent read allowed for stream {s.shortLog()}")
 
   if s.returnedEof:
     raise newLPStreamEOFError()
@@ -172,6 +172,14 @@ proc mux(s: Switch, conn: Connection): Future[Muxer] {.async, gcsafe.} =
   # store it in muxed connections if we have a peer for it
   s.connManager.storeMuxer(muxer, muxer.handle()) # store muxer and start read loop
 
+  try:
+    await s.identify(muxer)
+  except CatchableError as exc:
+    # Identify is non-essential, though if it fails, it might indicate that
+    # the connection was closed already - this will be picked up by the read
+    # loop
+    debug "Could not identify connection", conn, msg = exc.msg
+
   return muxer
 
 proc disconnect*(s: Switch, peerId: PeerID): Future[void] {.gcsafe.} =
@@ -195,18 +203,10 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g
     raise newException(UpgradeFailedError,
       "a muxer is required for outgoing connections")
 
-  try:
-    await s.identify(muxer)
-  except CatchableError as exc:
-    # Identify is non-essential, though if it fails, it might indicate that
-    # the connection was closed already - this will be picked up by the read
-    # loop
-    debug "Could not identify connection", conn, msg = exc.msg
-
-  if isNil(sconn.peerInfo):
+  if sconn.closed() or isNil(sconn.peerInfo):
     await sconn.close()
     raise newException(UpgradeFailedError,
-      "No peerInfo for connection, stopping upgrade")
+      "Connection closed or missing peer info, stopping upgrade")
 
   trace "Upgraded outgoing connection", conn, sconn
@@ -302,14 +302,13 @@ proc internalConnect(s: Switch,
   if s.peerInfo.peerId == peerId:
     raise newException(CatchableError, "can't dial self!")
 
-  var conn: Connection
   # Ensure there's only one in-flight attempt per peer
   let lock = s.dialLock.mgetOrPut(peerId, newAsyncLock())
   try:
     await lock.acquire()
 
     # Check if we have a connection already and try to reuse it
-    conn = s.connManager.selectConn(peerId)
+    var conn = s.connManager.selectConn(peerId)
     if conn != nil:
       if conn.atEof or conn.closed:
         # This connection should already have been removed from the connection
@@ -323,22 +322,25 @@
         return conn
 
     conn = await s.dialAndUpgrade(peerId, addrs)
+    if isNil(conn): # None of the addresses connected
+      raise newException(DialFailedError, "Unable to establish outgoing link")
+
+    # We already check for this in Connection manager
+    # but a disconnect could have happened right after
+    # we've added the connection so we check again
+    # to prevent races due to that.
+    if conn.closed() or conn.atEof():
+      # This can happen when the other ends drops us
+      # before we get a chance to return the connection
+      # back to the dialer.
+      trace "Connection dead on arrival", conn
+      raise newLPStreamClosedError()
+
+    return conn
   finally:
     if lock.locked():
       lock.release()
 
-  if isNil(conn): # None of the addresses connected
-    raise newException(DialFailedError, "Unable to establish outgoing link")
-
-  if conn.closed() or conn.atEof():
-    # This can happen when the other ends drops us
-    # before we get a chance to return the connection
-    # back to the dialer.
-    trace "Connection dead on arrival", conn
-    raise newLPStreamClosedError()
-
-  return conn
-
 proc connect*(s: Switch, peerId: PeerID, addrs: seq[MultiAddress]) {.async.} =
   discard await s.internalConnect(peerId, addrs)
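Why the validity checks moved inside the try block above: the per-peer dial lock used to be released in finally before the freshly dialed connection had been validated, so a second dial to the same peer could interleave with the checks. Keeping the checks and the return under the lock closes that window. A rough sketch of the pattern follows, assuming the chronos package that nim-libp2p already depends on; internalConnectSketch, dialAndUpgrade and the Conn type are made up for illustration.

import chronos

type Conn = ref object
  dead: bool

proc dialAndUpgrade(): Future[Conn] {.async.} =
  # stand-in for the real dial; pretend the remote dropped us right away
  return Conn(dead: true)

proc internalConnectSketch(lock: AsyncLock): Future[Conn] {.async.} =
  try:
    await lock.acquire()
    let conn = await dialAndUpgrade()
    if conn.isNil or conn.dead:
      # decide success or failure while still holding the lock, so a
      # concurrent dial to the same peer can't slip in mid-validation
      raise newException(CatchableError, "connection dead on arrival")
    return conn
  finally:
    if lock.locked():
      lock.release()

when isMainModule:
  let lock = newAsyncLock()
  try:
    discard waitFor internalConnectSketch(lock)
  except CatchableError as exc:
    echo "dial failed: ", exc.msg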
@@ -38,6 +38,29 @@ suite "Connection Manager":
 
     await connMngr.close()
 
+  asyncTest "shouldn't allow a closed connection":
+    let connMngr = ConnManager.init()
+    let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
+    let conn = Connection.init(peer, Direction.In)
+    await conn.close()
+
+    expect CatchableError:
+      connMngr.storeConn(conn)
+
+    await connMngr.close()
+
+  asyncTest "shouldn't allow an EOFed connection":
+    let connMngr = ConnManager.init()
+    let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
+    let conn = Connection.init(peer, Direction.In)
+    conn.isEof = true
+
+    expect CatchableError:
+      connMngr.storeConn(conn)
+
+    await conn.close()
+    await connMngr.close()
+
   asyncTest "add and retrieve a muxer":
     let connMngr = ConnManager.init()
     let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
@@ -54,6 +77,20 @@ suite "Connection Manager":
 
     await connMngr.close()
 
+  asyncTest "shouldn't allow a muxer for an untracked connection":
+    let connMngr = ConnManager.init()
+    let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
+    let conn = Connection.init(peer, Direction.In)
+    let muxer = new Muxer
+    muxer.connection = conn
+
+    expect CatchableError:
+      connMngr.storeMuxer(muxer)
+
+    await conn.close()
+    await muxer.close()
+    await connMngr.close()
+
   asyncTest "get conn with direction":
     let connMngr = ConnManager.init()
     let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
@@ -907,7 +907,7 @@ suite "Mplex":
 
       checkTracker(LPChannelTrackerName)
 
-      await conn.close()
+      await conn.closeWithEOF()
       await mplexDialFut
       await allFuturesThrowing(
         transport1.stop(),
@@ -943,7 +943,8 @@ suite "Mplex":
       for i in 0..9:
         dialStreams.add((await mplexDial.newStream()))
 
-      await listenConn.close()
+      await sleepAsync(100.millis)
+      await listenConn.closeWithEOF()
       await allFuturesThrowing(
         (dialStreams & listenStreams)
         .mapIt( it.join() ))
@@ -672,7 +672,7 @@ suite "Switch":
 
     proc acceptHandler() {.async, gcsafe.} =
      let conn = await transport.accept()
-      await conn.close()
+      await conn.closeWithEOF()
 
     let handlerWait = acceptHandler()
     let switch = newStandardSwitch(secureManagers = [SecureProtocol.Noise])