resolve several races in connmanager (#302)

* resolve several races in connmanager

collections may change while doing await

* close conn

* simplify connmanager API

PeerID avoids nil and ref issues

* remove silly condition
This commit is contained in:
Jacek Sieka 2020-08-01 22:50:40 +02:00 committed by GitHub
parent afcfd27aa0
commit d544b64010
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 76 deletions

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import tables, sequtils, sets import std/[options, tables, sequtils, sets]
import chronos, chronicles, metrics import chronos, chronicles, metrics
import peerinfo, import peerinfo,
stream/connection, stream/connection,
@ -31,7 +31,6 @@ type
# copies and mangling by unrelated code. # copies and mangling by unrelated code.
conns: Table[PeerID, HashSet[Connection]] conns: Table[PeerID, HashSet[Connection]]
muxed: Table[Connection, MuxerHolder] muxed: Table[Connection, MuxerHolder]
cleanUpLock: Table[PeerInfo, AsyncLock]
maxConns: int maxConns: int
proc newTooManyConnections(): ref TooManyConnections {.inline.} = proc newTooManyConnections(): ref TooManyConnections {.inline.} =
@ -54,9 +53,6 @@ proc contains*(c: ConnManager, conn: Connection): bool =
if isNil(conn.peerInfo): if isNil(conn.peerInfo):
return return
if conn.peerInfo.peerId notin c.conns:
return
return conn in c.conns[conn.peerInfo.peerId] return conn in c.conns[conn.peerInfo.peerId]
proc contains*(c: ConnManager, peerId: PeerID): bool = proc contains*(c: ConnManager, peerId: PeerID): bool =
@ -79,9 +75,24 @@ proc contains*(c: ConnManager, muxer: Muxer): bool =
return muxer == c.muxed[conn].muxer return muxer == c.muxed[conn].muxer
proc closeMuxerHolder(muxerHolder: MuxerHolder) {.async.} =
trace "cleaning up muxer for peer"
await muxerHolder.muxer.close()
if not(isNil(muxerHolder.handle)):
await muxerHolder.handle # TODO noraises?
proc delConn(c: ConnManager, conn: Connection) =
let peerId = conn.peerInfo.peerId
if peerId in c.conns:
c.conns[peerId].excl(conn)
if c.conns[peerId].len == 0:
c.conns.del(peerId)
libp2p_peers.set(c.conns.len.int64)
proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
## clean connection's resources such as muxers and streams ## clean connection's resources such as muxers and streams
##
if isNil(conn): if isNil(conn):
return return
@ -89,37 +100,20 @@ proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
if isNil(conn.peerInfo): if isNil(conn.peerInfo):
return return
let peerInfo = conn.peerInfo # Remove connection from all tables without async breaks
let lock = c.cleanUpLock.mgetOrPut(peerInfo, newAsyncLock()) var muxer = some(MuxerHolder())
if not c.muxed.pop(conn, muxer.get()):
muxer = none(MuxerHolder)
delConn(c, conn)
try: try:
await lock.acquire() if muxer.isSome:
trace "cleaning up connection for peer", peer = $peerInfo await closeMuxerHolder(muxer.get())
if conn in c.muxed:
let muxerHolder = c.muxed[conn]
c.muxed.del(conn)
await muxerHolder.muxer.close()
if not(isNil(muxerHolder.handle)):
await muxerHolder.handle
if peerInfo.peerId in c.conns:
c.conns[peerInfo.peerId].excl(conn)
if c.conns[peerInfo.peerId].len == 0:
c.conns.del(peerInfo.peerId)
if not(conn.peerInfo.isClosed()):
conn.peerInfo.close()
finally: finally:
await conn.close() await conn.close()
libp2p_peers.set(c.conns.len.int64)
if lock.locked(): trace "connection cleaned up", peer = $conn.peerInfo
lock.release()
trace "connection cleaned up"
proc onClose(c: ConnManager, conn: Connection) {.async.} = proc onClose(c: ConnManager, conn: Connection) {.async.} =
## connection close even handler ## connection close even handler
@ -132,32 +126,25 @@ proc onClose(c: ConnManager, conn: Connection) {.async.} =
await c.cleanupConn(conn) await c.cleanupConn(conn)
proc selectConn*(c: ConnManager, proc selectConn*(c: ConnManager,
peerInfo: PeerInfo, peerId: PeerID,
dir: Direction): Connection = dir: Direction): Connection =
## Select a connection for the provided peer and direction ## Select a connection for the provided peer and direction
## ##
if isNil(peerInfo):
return
let conns = toSeq( let conns = toSeq(
c.conns.getOrDefault(peerInfo.peerId)) c.conns.getOrDefault(peerId))
.filterIt( it.dir == dir ) .filterIt( it.dir == dir )
if conns.len > 0: if conns.len > 0:
return conns[0] return conns[0]
proc selectConn*(c: ConnManager, peerInfo: PeerInfo): Connection = proc selectConn*(c: ConnManager, peerId: PeerID): Connection =
## Select a connection for the provided giving priority ## Select a connection for the provided giving priority
## to outgoing connections ## to outgoing connections
## ##
if isNil(peerInfo): var conn = c.selectConn(peerId, Direction.Out)
return
var conn = c.selectConn(peerInfo, Direction.Out)
if isNil(conn): if isNil(conn):
conn = c.selectConn(peerInfo, Direction.In) conn = c.selectConn(peerId, Direction.In)
return conn return conn
@ -181,18 +168,18 @@ proc storeConn*(c: ConnManager, conn: Connection) =
if isNil(conn.peerInfo): if isNil(conn.peerInfo):
raise newException(CatchableError, "empty peer info") raise newException(CatchableError, "empty peer info")
let peerInfo = conn.peerInfo let peerId = conn.peerInfo.peerId
if c.conns.getOrDefault(peerInfo.peerId).len > c.maxConns: if c.conns.getOrDefault(peerId).len > c.maxConns:
trace "too many connections", peer = $conn.peerInfo, trace "too many connections", peer = $peerId,
conns = c.conns conns = c.conns
.getOrDefault(peerInfo.peerId).len .getOrDefault(peerId).len
raise newTooManyConnections() raise newTooManyConnections()
if peerInfo.peerId notin c.conns: if peerId notin c.conns:
c.conns[peerInfo.peerId] = initHashSet[Connection]() c.conns[peerId] = initHashSet[Connection]()
c.conns[peerInfo.peerId].incl(conn) c.conns[peerId].incl(conn)
# launch on close listener # launch on close listener
asyncCheck c.onClose(conn) asyncCheck c.onClose(conn)
@ -222,25 +209,25 @@ proc storeMuxer*(c: ConnManager,
muxer: muxer, muxer: muxer,
handle: handle) handle: handle)
trace "storred connection", connections = c.conns.len trace "stored connection", connections = c.conns.len
proc getMuxedStream*(c: ConnManager, proc getMuxedStream*(c: ConnManager,
peerInfo: PeerInfo, peerId: PeerID,
dir: Direction): Future[Connection] {.async, gcsafe.} = dir: Direction): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the provided peer ## get a muxed stream for the provided peer
## with the given direction ## with the given direction
## ##
let muxer = c.selectMuxer(c.selectConn(peerInfo, dir)) let muxer = c.selectMuxer(c.selectConn(peerId, dir))
if not(isNil(muxer)): if not(isNil(muxer)):
return await muxer.newStream() return await muxer.newStream()
proc getMuxedStream*(c: ConnManager, proc getMuxedStream*(c: ConnManager,
peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} = peerId: PeerID): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed peer from any connection ## get a muxed stream for the passed peer from any connection
## ##
let muxer = c.selectMuxer(c.selectConn(peerInfo)) let muxer = c.selectMuxer(c.selectConn(peerId))
if not(isNil(muxer)): if not(isNil(muxer)):
return await muxer.newStream() return await muxer.newStream()
@ -253,24 +240,38 @@ proc getMuxedStream*(c: ConnManager,
if not(isNil(muxer)): if not(isNil(muxer)):
return await muxer.newStream() return await muxer.newStream()
proc dropPeer*(c: ConnManager, peerInfo: PeerInfo) {.async.} = proc dropPeer*(c: ConnManager, peerId: PeerID) {.async.} =
## drop connections and cleanup resources for peer ## drop connections and cleanup resources for peer
## ##
let conns = c.conns.getOrDefault(peerId)
for conn in conns:
delConn(c, conn)
for conn in c.conns.getOrDefault(peerInfo.peerId): var muxers: seq[MuxerHolder]
if not(isNil(conn)): for conn in conns:
await c.cleanupConn(conn) if conn in c.muxed:
muxers.add c.muxed[conn]
c.muxed.del(conn)
for muxer in muxers:
await closeMuxerHolder(muxer)
for conn in conns:
await conn.close()
proc close*(c: ConnManager) {.async.} = proc close*(c: ConnManager) {.async.} =
## cleanup resources for the connection ## cleanup resources for the connection
## manager ## manager
## ##
let conns = c.conns
c.conns.clear()
for conns in toSeq(c.conns.values): let muxed = c.muxed
for conn in conns: c.muxed.clear()
try:
await c.cleanupConn(conn) for _, muxer in muxed:
except CancelledError as exc: await closeMuxerHolder(muxer)
raise exc
except CatchableError as exc: for _, conns2 in conns:
warn "error cleaning up connections" for conn in conns2:
await conn.close()

View File

@ -188,7 +188,6 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
# new stream for identify # new stream for identify
var stream = await muxer.newStream() var stream = await muxer.newStream()
var handlerFut: Future[void]
defer: defer:
if not(isNil(stream)): if not(isNil(stream)):
@ -196,7 +195,7 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
# call muxer handler, this should # call muxer handler, this should
# not end until muxer ends # not end until muxer ends
handlerFut = muxer.handle() let handlerFut = muxer.handle()
# do identify first, so that we have a # do identify first, so that we have a
# PeerInfo in case we didn't before # PeerInfo in case we didn't before
@ -212,7 +211,8 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
s.connManager.storeMuxer(muxer, handlerFut) # update muxer with handler s.connManager.storeMuxer(muxer, handlerFut) # update muxer with handler
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
await s.connManager.dropPeer(peer) if not peer.isNil:
await s.connManager.dropPeer(peer.peerId)
proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
logScope: logScope:
@ -231,7 +231,7 @@ proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, g
raise newException(CatchableError, raise newException(CatchableError,
"unable to identify connection, stopping upgrade") "unable to identify connection, stopping upgrade")
trace "succesfully upgraded outgoing connection", oid = sconn.oid trace "successfully upgraded outgoing connection", oid = sconn.oid
return sconn return sconn
@ -290,7 +290,7 @@ proc internalConnect(s: Switch,
try: try:
await lock.acquire() await lock.acquire()
trace "about to dial peer", peer = id trace "about to dial peer", peer = id
conn = s.connManager.selectConn(peer) conn = s.connManager.selectConn(peer.peerId)
if conn.isNil or (conn.closed or conn.atEof): if conn.isNil or (conn.closed or conn.atEof):
trace "Dialing peer", peer = id trace "Dialing peer", peer = id
for t in s.transports: # for each transport for t in s.transports: # for each transport
@ -323,7 +323,7 @@ proc internalConnect(s: Switch,
s.connManager.storeOutgoing(uconn) s.connManager.storeOutgoing(uconn)
asyncCheck s.triggerHooks(uconn.peerInfo, Lifecycle.Upgraded) asyncCheck s.triggerHooks(uconn.peerInfo, Lifecycle.Upgraded)
conn = uconn conn = uconn
trace "dial succesfull", oid = $conn.oid, peer = $conn.peerInfo trace "dial successful", oid = $conn.oid, peer = $conn.peerInfo
except CatchableError as exc: except CatchableError as exc:
if not(isNil(conn)): if not(isNil(conn)):
await conn.close() await conn.close()
@ -354,7 +354,7 @@ proc internalConnect(s: Switch,
doAssert(conn in s.connManager, "connection not tracked!") doAssert(conn in s.connManager, "connection not tracked!")
trace "dial succesfull", oid = $conn.oid, trace "dial successful", oid = $conn.oid,
peer = $conn.peerInfo peer = $conn.peerInfo
await s.subscribePeer(peer) await s.subscribePeer(peer)
@ -475,7 +475,7 @@ proc subscribePeerInternal(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog() trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog()
var stream: Connection var stream: Connection
try: try:
stream = await s.connManager.getMuxedStream(peerInfo) stream = await s.connManager.getMuxedStream(peerInfo.peerId)
if isNil(stream): if isNil(stream):
trace "unable to subscribe to peer", peer = peerInfo.shortLog trace "unable to subscribe to peer", peer = peerInfo.shortLog
return return