Connection manager (#277)

* splitting out connection management

* wip

* wip conn mngr tests

* set peerinfo in contructor

* comments and documentation

* tests

* wip

* add `None` to detect untagged connections

* use `PeerID` to index connections

* fix tests

* remove useless equality
This commit is contained in:
Dmitriy Ryajov 2020-07-17 09:36:48 -06:00 committed by GitHub
parent 170685f9c6
commit 0348773ec9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 620 additions and 373 deletions

276
libp2p/connmanager.nim Normal file
View File

@ -0,0 +1,276 @@
## Nim-LibP2P
## Copyright (c) 2020 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import tables, sequtils, sets
import chronos, chronicles, metrics
import peerinfo,
stream/connection,
muxers/muxer
declareGauge(libp2p_peers, "total connected peers")
const MaxConnectionsPerPeer = 5
type
TooManyConnections* = object of CatchableError
MuxerHolder = object
muxer: Muxer
handle: Future[void]
ConnManager* = ref object of RootObj
# NOTE: don't change to PeerInfo here
# the reference semantics on the PeerInfo
# object itself make it succeptible to
# copies and mangling by unrelated code.
conns: Table[PeerID, HashSet[Connection]]
muxed: Table[Connection, MuxerHolder]
cleanUpLock: Table[PeerInfo, AsyncLock]
maxConns: int
proc newTooManyConnections(): ref TooManyConnections {.inline.} =
result = newException(TooManyConnections, "too many connections for peer")
proc init*(C: type ConnManager,
maxConnsPerPeer: int = MaxConnectionsPerPeer): ConnManager =
C(maxConns: maxConnsPerPeer,
conns: initTable[PeerID, HashSet[Connection]](),
muxed: initTable[Connection, MuxerHolder]())
proc contains*(c: ConnManager, conn: Connection): bool =
## checks if a connection is being tracked by the
## connection manager
##
if isNil(conn):
return
if isNil(conn.peerInfo):
return
if conn.peerInfo.peerId notin c.conns:
return
return conn in c.conns[conn.peerInfo.peerId]
proc contains*(c: ConnManager, peerId: PeerID): bool =
peerId in c.conns
proc contains*(c: ConnManager, muxer: Muxer): bool =
## checks if a muxer is being tracked by the connection
## manager
##
if isNil(muxer):
return
let conn = muxer.connection
if conn notin c:
return
if conn notin c.muxed:
return
return muxer == c.muxed[conn].muxer
proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
## clean connection's resources such as muxers and streams
##
if isNil(conn):
return
if isNil(conn.peerInfo):
return
let peerInfo = conn.peerInfo
let lock = c.cleanUpLock.mgetOrPut(peerInfo, newAsyncLock())
try:
await lock.acquire()
trace "cleaning up connection for peer", peer = $peerInfo
if conn in c.muxed:
let muxerHolder = c.muxed[conn]
c.muxed.del(conn)
await muxerHolder.muxer.close()
if not(isNil(muxerHolder.handle)):
await muxerHolder.handle
if peerInfo.peerId in c.conns:
c.conns[peerInfo.peerId].excl(conn)
if c.conns[peerInfo.peerId].len == 0:
c.conns.del(peerInfo.peerId)
if not(conn.peerInfo.isClosed()):
conn.peerInfo.close()
finally:
await conn.close()
libp2p_peers.set(c.conns.len.int64)
if lock.locked():
lock.release()
trace "connection cleaned up"
proc onClose(c: ConnManager, conn: Connection) {.async.} =
## connection close even handler
##
## triggers the connections resource cleanup
##
await conn.closeEvent.wait()
trace "triggering connection cleanup"
await c.cleanupConn(conn)
proc selectConn*(c: ConnManager,
peerInfo: PeerInfo,
dir: Direction): Connection =
## Select a connection for the provided peer and direction
##
if isNil(peerInfo):
return
let conns = toSeq(
c.conns.getOrDefault(peerInfo.peerId))
.filterIt( it.dir == dir )
if conns.len > 0:
return conns[0]
proc selectConn*(c: ConnManager, peerInfo: PeerInfo): Connection =
## Select a connection for the provided giving priority
## to outgoing connections
##
if isNil(peerInfo):
return
var conn = c.selectConn(peerInfo, Direction.Out)
if isNil(conn):
conn = c.selectConn(peerInfo, Direction.In)
return conn
proc selectMuxer*(c: ConnManager, conn: Connection): Muxer =
## select the muxer for the provided connection
##
if isNil(conn):
return
if conn in c.muxed:
return c.muxed[conn].muxer
proc storeConn*(c: ConnManager, conn: Connection) =
## store a connection
##
if isNil(conn):
raise newException(CatchableError, "connection cannot be nil")
if isNil(conn.peerInfo):
raise newException(CatchableError, "empty peer info")
let peerInfo = conn.peerInfo
if c.conns.getOrDefault(peerInfo.peerId).len > c.maxConns:
trace "too many connections", peer = $conn.peerInfo,
conns = c.conns
.getOrDefault(peerInfo.peerId).len
raise newTooManyConnections()
if peerInfo.peerId notin c.conns:
c.conns[peerInfo.peerId] = initHashSet[Connection]()
c.conns[peerInfo.peerId].incl(conn)
# launch on close listener
asyncCheck c.onClose(conn)
libp2p_peers.set(c.conns.len.int64)
proc storeOutgoing*(c: ConnManager, conn: Connection) =
conn.dir = Direction.Out
c.storeConn(conn)
proc storeIncoming*(c: ConnManager, conn: Connection) =
conn.dir = Direction.In
c.storeConn(conn)
proc storeMuxer*(c: ConnManager,
muxer: Muxer,
handle: Future[void] = nil) =
## store the connection and muxer
##
if isNil(muxer):
raise newException(CatchableError, "muxer cannot be nil")
if isNil(muxer.connection):
raise newException(CatchableError, "muxer's connection cannot be nil")
c.muxed[muxer.connection] = MuxerHolder(
muxer: muxer,
handle: handle)
trace "storred connection", connections = c.conns.len
proc getMuxedStream*(c: ConnManager,
peerInfo: PeerInfo,
dir: Direction): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the provided peer
## with the given direction
##
let muxer = c.selectMuxer(c.selectConn(peerInfo, dir))
if not(isNil(muxer)):
return await muxer.newStream()
proc getMuxedStream*(c: ConnManager,
peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed peer from any connection
##
let muxer = c.selectMuxer(c.selectConn(peerInfo))
if not(isNil(muxer)):
return await muxer.newStream()
proc getMuxedStream*(c: ConnManager,
conn: Connection): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed connection
##
let muxer = c.selectMuxer(conn)
if not(isNil(muxer)):
return await muxer.newStream()
proc dropPeer*(c: ConnManager, peerInfo: PeerInfo) {.async.} =
## drop connections and cleanup resources for peer
##
for conn in c.conns.getOrDefault(peerInfo.peerId):
if not(isNil(conn)):
await c.cleanupConn(conn)
proc close*(c: ConnManager) {.async.} =
## cleanup resources for the connection
## manager
##
for conns in toSeq(c.conns.values):
for conn in conns:
try:
await c.cleanupConn(conn)
except CancelledError as exc:
raise exc
except CatchableError as exc:
warn "error cleaning up connections"

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import options, sequtils
import options, sequtils, hashes
import chronos, chronicles
import peerid, multiaddress, crypto/crypto
@ -134,18 +134,5 @@ proc publicKey*(p: PeerInfo): Option[PublicKey] {.inline.} =
else:
result = some(p.privateKey.getKey().tryGet())
func `==`*(a, b: PeerInfo): bool =
# override equiality to support both nil and peerInfo comparisons
# this in the future will allow us to recycle refs
let
aptr = cast[pointer](a)
bptr = cast[pointer](b)
if isNil(aptr) and isNil(bptr):
return true
if isNil(aptr) or isNil(bptr):
return false
if aptr == bptr and a.peerId == b.peerId:
return true
func hash*(p: PeerInfo): Hash =
cast[pointer](p).hash

View File

@ -8,14 +8,14 @@
## those terms.
import std/[tables, sequtils, sets]
import chronos, chronicles
import chronos, chronicles, metrics
import pubsubpeer,
rpc/[message, messages],
../protocol,
../../stream/connection,
../../peerid,
../../peerinfo
import metrics
../../peerinfo,
../../errors
export PubSubPeer
export PubSubObserver
@ -233,8 +233,11 @@ method subscribe*(p: PubSub,
p.topics[topic].handler.add(handler)
var sent: seq[Future[void]]
for peer in toSeq(p.peers.values):
await p.sendSubs(peer, @[topic], true)
sent.add(p.sendSubs(peer, @[topic], true))
checkFutures(await allFinished(sent))
# metrics
libp2p_pubsub_topics.inc()

View File

@ -48,22 +48,6 @@ func hash*(p: PubSubPeer): Hash =
# int is either 32/64, so intptr basically, pubsubpeer is a ref
cast[pointer](p).hash
func `==`*(a, b: PubSubPeer): bool =
# override equiality to support both nil and peerInfo comparisons
# this in the future will allow us to recycle refs
let
aptr = cast[pointer](a)
bptr = cast[pointer](b)
if isNil(aptr) and isNil(bptr):
return true
if isNil(aptr) or isNil(bptr):
return false
if aptr == bptr and a.peerInfo == b.peerInfo:
return true
proc id*(p: PubSubPeer): string = p.peerInfo.id
proc inUse*(p: PubSubPeer): bool =

View File

@ -289,7 +289,7 @@ method close*(s: BufferStream) {.async, gcsafe.} =
try:
## close the stream and clear the buffer
if not s.isClosed:
trace "closing bufferstream", oid = s.oid
trace "closing bufferstream", oid = $s.oid
s.isEof = true
for r in s.readReqs:
if not(isNil(r)) and not(r.finished()):

View File

@ -7,6 +7,7 @@
## This file may not be copied, modified, or distributed except according to
## those terms.
import hashes
import chronos, metrics
import lpstream,
../multiaddress,
@ -18,9 +19,13 @@ const
ConnectionTrackerName* = "libp2p.connection"
type
Direction* {.pure.} = enum
None, In, Out
Connection* = ref object of LPStream
peerInfo*: PeerInfo
observedAddr*: Multiaddress
dir*: Direction
ConnectionTracker* = ref object of TrackerBase
opened*: uint64
@ -50,9 +55,11 @@ proc setupConnectionTracker(): ConnectionTracker =
result.isLeaked = leakTransport
addTracker(ConnectionTrackerName, result)
proc init*[T: Connection](self: var T, peerInfo: PeerInfo): T =
new self
self.initStream()
proc init*(C: type Connection,
peerInfo: PeerInfo,
dir: Direction): Connection =
result = C(peerInfo: peerInfo, dir: dir)
result.initStream()
method initStream*(s: Connection) =
if s.objName.len == 0:
@ -63,9 +70,13 @@ method initStream*(s: Connection) =
inc getConnectionTracker().opened
method close*(s: Connection) {.async.} =
await procCall LPStream(s).close()
inc getConnectionTracker().closed
if not s.isClosed:
await procCall LPStream(s).close()
inc getConnectionTracker().closed
proc `$`*(conn: Connection): string =
if not isNil(conn.peerInfo):
result = conn.peerInfo.id
func hash*(p: Connection): Hash =
cast[pointer](p).hash

View File

@ -76,15 +76,6 @@ method initStream*(s: LPStream) {.base.} =
libp2p_open_streams.inc(labelValues = [s.objName])
trace "stream created", oid = $s.oid, name = s.objName
# TODO: debuging aid to troubleshoot streams open/close
# try:
# echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"])
# echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"])
# # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >=
# # libp2p_open_streams.value(labelValues = ["SecureConn"]))
# except CatchableError:
# discard
proc join*(s: LPStream): Future[void] =
s.closeEvent.wait()
@ -207,12 +198,3 @@ method close*(s: LPStream) {.base, async.} =
s.closeEvent.fire()
libp2p_open_streams.dec(labelValues = [s.objName])
trace "stream destroyed", oid = $s.oid, name = s.objName
# TODO: debuging aid to troubleshoot streams open/close
# try:
# echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"])
# echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"])
# # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >=
# # libp2p_open_streams.value(labelValues = ["SecureConn"]))
# except CatchableError:
# discard

View File

@ -11,7 +11,6 @@ import tables,
sequtils,
options,
sets,
algorithm,
oids
import chronos,
@ -28,6 +27,7 @@ import stream/connection,
protocols/identify,
protocols/pubsub/pubsub,
muxers/muxer,
connmanager,
peerid
logScope:
@ -39,33 +39,16 @@ logScope:
# and only if the channel has been secured (i.e. if a secure manager has been
# previously provided)
declareGauge(libp2p_peers, "total connected peers")
declareCounter(libp2p_dialed_peers, "dialed peers")
declareCounter(libp2p_failed_dials, "failed dials")
declareCounter(libp2p_failed_upgrade, "peers failed upgrade")
const MaxConnectionsPerPeer = 5
type
NoPubSubException* = object of CatchableError
TooManyConnections* = object of CatchableError
Direction {.pure.} = enum
In, Out
ConnectionHolder = object
dir: Direction
conn: Connection
MuxerHolder = object
dir: Direction
muxer: Muxer
handle: Future[void]
Switch* = ref object of RootObj
peerInfo*: PeerInfo
connections*: Table[string, seq[ConnectionHolder]]
muxed*: Table[string, seq[MuxerHolder]]
connManager: ConnManager
transports*: seq[Transport]
protocols*: seq[LPProtocol]
muxers*: Table[string, MuxerProvider]
@ -75,90 +58,20 @@ type
secureManagers*: seq[Secure]
pubSub*: Option[PubSub]
dialLock: Table[string, AsyncLock]
cleanUpLock: Table[string, AsyncLock]
proc newNoPubSubException(): ref NoPubSubException {.inline.} =
result = newException(NoPubSubException, "no pubsub provided!")
proc newTooManyConnections(): ref TooManyConnections {.inline.} =
result = newException(TooManyConnections, "too many connections for peer")
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.}
proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.}
proc selectConn(s: Switch, peerInfo: PeerInfo): Connection =
## select the "best" connection according to some criteria
##
## Ideally when the connection's stats are available
## we'd select the fastest, but for now we simply pick an outgoing
## connection first if none is available, we pick the first outgoing
##
proc cleanupPubSubPeer(s: Switch, conn: Connection) {.async.} =
await conn.closeEvent.wait()
if s.pubSub.isSome:
await s.pubSub.get().unsubscribePeer(conn.peerInfo)
if isNil(peerInfo):
return
let conns = s.connections
.getOrDefault(peerInfo.id)
# it should be OK to sort on each
# access as there should only be
# up to MaxConnectionsPerPeer entries
.sorted(
proc(a, b: ConnectionHolder): int =
if a.dir < b.dir: -1
elif a.dir == b.dir: 0
else: 1
, SortOrder.Descending)
if conns.len > 0:
return conns[0].conn
proc selectMuxer(s: Switch, conn: Connection): Muxer =
## select the muxer for the supplied connection
##
if isNil(conn):
return
if not(isNil(conn.peerInfo)) and conn.peerInfo.id in s.muxed:
if s.muxed[conn.peerInfo.id].len > 0:
let muxers = s.muxed[conn.peerInfo.id]
.filterIt( it.muxer.connection == conn )
if muxers.len > 0:
return muxers[0].muxer
proc storeConn(s: Switch,
muxer: Muxer,
dir: Direction,
handle: Future[void] = nil) {.async.} =
## store the connection and muxer
##
if isNil(muxer):
return
let conn = muxer.connection
if isNil(conn):
return
let id = conn.peerInfo.id
if s.connections.getOrDefault(id).len > MaxConnectionsPerPeer:
warn "disconnecting peer, too many connections", peer = $conn.peerInfo,
conns = s.connections
.getOrDefault(id).len
await s.disconnect(conn.peerInfo)
raise newTooManyConnections()
s.connections.mgetOrPut(
id,
newSeq[ConnectionHolder]())
.add(ConnectionHolder(conn: conn, dir: dir))
s.muxed.mgetOrPut(
muxer.connection.peerInfo.id,
newSeq[MuxerHolder]())
.add(MuxerHolder(muxer: muxer, handle: handle, dir: dir))
trace "storred connection", connections = s.connections.len
libp2p_peers.set(s.connections.len.int64)
proc isConnected*(s: Switch, peer: PeerInfo): bool =
peer.peerId in s.connManager
proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
if s.secureManagers.len <= 0:
@ -170,9 +83,11 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
trace "securing connection", codec = manager
let secureProtocol = s.secureManagers.filterIt(it.codec == manager)
# ms.select should deal with the correctness of this
# let's avoid duplicating checks but detect if it fails to do it properly
doAssert(secureProtocol.len > 0)
result = await secureProtocol[0].secure(conn, true)
proc identify(s: Switch, conn: Connection) {.async, gcsafe.} =
@ -218,6 +133,7 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
# create new muxer for connection
let muxer = s.muxers[muxerName].newMuxer(conn)
s.connManager.storeMuxer(muxer)
trace "found a muxer", name = muxerName, peer = $conn
@ -247,75 +163,10 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
# store it in muxed connections if we have a peer for it
trace "adding muxer for peer", peer = conn.peerInfo.id
await s.storeConn(muxer, Direction.Out, handlerFut)
proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
if isNil(conn):
return
if isNil(conn.peerInfo):
return
let id = conn.peerInfo.id
let lock = s.cleanUpLock.mgetOrPut(id, newAsyncLock())
try:
await lock.acquire()
trace "cleaning up connection for peer", peerId = id
if id in s.muxed:
let muxerHolder = s.muxed[id]
.filterIt(
it.muxer.connection == conn
)
if muxerHolder.len > 0:
await muxerHolder[0].muxer.close()
if not(isNil(muxerHolder[0].handle)):
await muxerHolder[0].handle
if id in s.muxed:
s.muxed[id].keepItIf(
it.muxer.connection != conn
)
if s.muxed[id].len == 0:
s.muxed.del(id)
if s.pubSub.isSome:
await s.pubSub.get()
.unsubscribePeer(conn.peerInfo)
if id in s.connections:
s.connections[id].keepItIf(
it.conn != conn
)
if s.connections[id].len == 0:
s.connections.del(id)
# TODO: Investigate cleanupConn() always called twice for one peer.
if not(conn.peerInfo.isClosed()):
conn.peerInfo.close()
finally:
await conn.close()
libp2p_peers.set(s.connections.len.int64)
if lock.locked():
lock.release()
s.connManager.storeMuxer(muxer, handlerFut) # update muxer with handler
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
let connections = s.connections.getOrDefault(peer.id)
for connHolder in connections:
if not isNil(connHolder.conn):
await s.cleanupConn(connHolder.conn)
proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} =
# if there is a muxer for the connection
# use it instead to create a muxed stream
let muxer = s.selectMuxer(s.selectConn(peerInfo)) # always get the first muxer here
if not(isNil(muxer)):
return await muxer.newStream()
await s.connManager.dropPeer(peer)
proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
logScope:
@ -388,52 +239,51 @@ proc internalConnect(s: Switch,
var conn: Connection
let lock = s.dialLock.mgetOrPut(id, newAsyncLock())
defer:
try:
await lock.acquire()
trace "about to dial peer", peer = id
conn = s.connManager.selectConn(peer)
if conn.isNil or (conn.closed or conn.atEof):
trace "Dialing peer", peer = id
for t in s.transports: # for each transport
for a in peer.addrs: # for each address
if t.handles(a): # check if it can dial it
trace "Dialing address", address = $a, peer = id
try:
conn = await t.dial(a)
# make sure to assign the peer to the connection
conn.peerInfo = peer
libp2p_dialed_peers.inc()
except CancelledError as exc:
trace "dialing canceled", exc = exc.msg
raise
except CatchableError as exc:
trace "dialing failed", exc = exc.msg
libp2p_failed_dials.inc()
continue
try:
let uconn = await s.upgradeOutgoing(conn)
s.connManager.storeOutgoing(uconn)
conn = uconn
except CatchableError as exc:
if not(isNil(conn)):
await conn.close()
trace "Unable to establish outgoing link", exc = exc.msg
raise exc
if isNil(conn):
libp2p_failed_upgrade.inc()
continue
break
else:
trace "Reusing existing connection", oid = conn.oid
finally:
if lock.locked():
lock.release()
await lock.acquire()
trace "about to dial peer", peer = id
conn = s.selectConn(peer)
if conn.isNil or conn.closed:
trace "Dialing peer", peer = id
for t in s.transports: # for each transport
for a in peer.addrs: # for each address
if t.handles(a): # check if it can dial it
trace "Dialing address", address = $a
try:
conn = await t.dial(a)
libp2p_dialed_peers.inc()
except CancelledError as exc:
trace "dialing canceled", exc = exc.msg
raise
except CatchableError as exc:
trace "dialing failed", exc = exc.msg
libp2p_failed_dials.inc()
continue
# make sure to assign the peer to the connection
conn.peerInfo = peer
try:
conn = await s.upgradeOutgoing(conn)
except CatchableError as exc:
if not(isNil(conn)):
await conn.close()
trace "Unable to establish outgoing link", exc = exc.msg
raise exc
if isNil(conn):
libp2p_failed_upgrade.inc()
continue
conn.closeEvent.wait()
.addCallback do(udata: pointer):
asyncCheck s.cleanupConn(conn)
break
else:
trace "Reusing existing connection", oid = conn.oid
if isNil(conn):
raise newException(CatchableError,
"Unable to establish outgoing link")
@ -443,13 +293,14 @@ proc internalConnect(s: Switch,
raise newException(CatchableError,
"Connection dead on arrival")
doAssert(conn.peerInfo.id in s.connections,
"connection not tracked!")
doAssert(conn in s.connManager, "connection not tracked!")
trace "dial succesfull", oid = $conn.oid,
peer = $conn.peerInfo
await s.subscribePeer(peer)
asyncCheck s.cleanupPubSubPeer(conn)
return conn
proc connect*(s: Switch, peer: PeerInfo) {.async.} =
@ -460,7 +311,7 @@ proc dial*(s: Switch,
proto: string):
Future[Connection] {.async.} =
let conn = await s.internalConnect(peer)
let stream = await s.getMuxedStream(peer)
let stream = await s.connManager.getMuxedStream(conn)
proc cleanup() {.async.} =
if not(isNil(stream)):
@ -505,14 +356,14 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
try:
defer:
await s.cleanupConn(conn)
conn.dir = Direction.In # tag connection with direction
await s.upgradeIncoming(conn) # perform upgrade on incoming connection
except CancelledError as exc:
raise exc
except CatchableError as exc:
trace "Exception occurred in Switch.start", exc = exc.msg
finally:
await conn.close()
var startFuts: seq[Future[void]]
for t in s.transports: # for each transport
@ -537,14 +388,8 @@ proc stop*(s: Switch) {.async.} =
if s.pubSub.isSome:
await s.pubSub.get().stop()
for conns in toSeq(s.connections.values):
for conn in conns:
try:
await s.cleanupConn(conn.conn)
except CancelledError as exc:
raise exc
except CatchableError as exc:
warn "error cleaning up connections"
# close and cleanup all connections
await s.connManager.close()
for t in s.transports:
try:
@ -562,7 +407,18 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog()
var stream: Connection
try:
stream = await s.getMuxedStream(peerInfo)
stream = await s.connManager.getMuxedStream(peerInfo)
if isNil(stream):
trace "unable to subscribe to peer", peer = peerInfo.shortLog
return
if not await s.ms.select(stream, s.pubSub.get().codec):
if not(isNil(stream)):
await stream.close()
return
s.pubSub.get().subscribePeer(stream)
except CancelledError as exc:
if not(isNil(stream)):
await stream.close()
@ -574,44 +430,27 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
if not(isNil(stream)):
await stream.close()
if isNil(stream):
trace "unable to subscribe to peer", peer = peerInfo.shortLog
return
if not await s.ms.select(stream, s.pubSub.get().codec):
if not(isNil(stream)):
await stream.close()
return
s.pubSub.get().subscribePeer(stream)
proc subscribe*(s: Switch, topic: string,
handler: TopicHandler): Future[void] =
handler: TopicHandler) {.async.} =
## subscribe to a pubsub topic
if s.pubSub.isNone:
var retFuture = newFuture[void]("Switch.subscribe")
retFuture.fail(newNoPubSubException())
return retFuture
raise newNoPubSubException()
return s.pubSub.get().subscribe(topic, handler)
await s.pubSub.get().subscribe(topic, handler)
proc unsubscribe*(s: Switch, topics: seq[TopicPair]): Future[void] =
proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} =
## unsubscribe from topics
if s.pubSub.isNone:
var retFuture = newFuture[void]("Switch.unsubscribe")
retFuture.fail(newNoPubSubException())
return retFuture
raise newNoPubSubException()
return s.pubSub.get().unsubscribe(topics)
await s.pubSub.get().unsubscribe(topics)
proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] =
proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.} =
# pubslish to pubsub topic
if s.pubSub.isNone:
var retFuture = newFuture[int]("Switch.publish")
retFuture.fail(newNoPubSubException())
return retFuture
raise newNoPubSubException()
return s.pubSub.get().publish(topic, data)
return await s.pubSub.get().publish(topic, data)
proc addValidator*(s: Switch,
topics: varargs[string],
@ -647,17 +486,17 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
muxer.connection.peerInfo = stream.peerInfo
# store muxer and muxed connection
await s.storeConn(muxer, Direction.In)
# store incoming connection
s.connManager.storeIncoming(muxer.connection)
muxer.connection.closeEvent.wait()
.addCallback do(udata: pointer):
asyncCheck s.cleanupConn(muxer.connection)
# store muxer and muxed connection
s.connManager.storeMuxer(muxer)
trace "got new muxer", peer = $muxer.connection.peerInfo
# try establishing a pubsub connection
await s.subscribePeer(muxer.connection.peerInfo)
asyncCheck s.cleanupPubSubPeer(muxer.connection)
except CancelledError as exc:
await muxer.close()
@ -680,8 +519,7 @@ proc newSwitch*(peerInfo: PeerInfo,
peerInfo: peerInfo,
ms: newMultistream(),
transports: transports,
connections: initTable[string, seq[ConnectionHolder]](),
muxed: initTable[string, seq[MuxerHolder]](),
connManager: ConnManager.init(),
identity: identity,
muxers: muxers,
secureManagers: @secureManagers,

192
tests/testconnmngr.nim Normal file
View File

@ -0,0 +1,192 @@
import unittest
import chronos
import ../libp2p/[connmanager,
stream/connection,
crypto/crypto,
muxers/muxer,
peerinfo]
import helpers
type
TestMuxer = ref object of Muxer
peerInfo: PeerInfo
method newStream*(
m: TestMuxer,
name: string = "",
lazy: bool = false):
Future[Connection] {.async, gcsafe.} =
result = Connection.init(m.peerInfo, Direction.Out)
suite "Connection Manager":
teardown:
for tracker in testTrackers():
# echo tracker.dump()
check tracker.isLeaked() == false
test "add and retrive a connection":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
connMngr.storeConn(conn)
check conn in connMngr
let peerConn = connMngr.selectConn(peer)
check peerConn == conn
check peerConn.dir == Direction.In
test "add and retrieve a muxer":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new Muxer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
let peerMuxer = connMngr.selectMuxer(conn)
check peerMuxer == muxer
test "get conn with direction":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn1 = Connection.init(peer, Direction.Out)
let conn2 = Connection.init(peer, Direction.In)
connMngr.storeConn(conn1)
connMngr.storeConn(conn2)
check conn1 in connMngr
check conn2 in connMngr
let outConn = connMngr.selectConn(peer, Direction.Out)
let inConn = connMngr.selectConn(peer, Direction.In)
check outConn != inConn
check outConn.dir == Direction.Out
check inConn.dir == Direction.In
test "get muxed stream for peer":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new TestMuxer
muxer.peerInfo = peer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
let stream = await connMngr.getMuxedStream(peer)
check not(isNil(stream))
check stream.peerInfo == peer
waitFor(test())
test "get stream from directed connection":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new TestMuxer
muxer.peerInfo = peer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
check not(isNil((await connMngr.getMuxedStream(peer, Direction.In))))
check isNil((await connMngr.getMuxedStream(peer, Direction.Out)))
waitFor(test())
test "get stream from any connection":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new TestMuxer
muxer.peerInfo = peer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
check not(isNil((await connMngr.getMuxedStream(conn))))
waitFor(test())
test "should raise on too many connections":
proc test() =
let connMngr = ConnManager.init(1)
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
connMngr.storeConn(Connection.init(peer, Direction.In))
connMngr.storeConn(Connection.init(peer, Direction.In))
connMngr.storeConn(Connection.init(peer, Direction.In))
expect TooManyConnections:
test()
test "cleanup on connection close":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new Muxer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check conn in connMngr
check muxer in connMngr
await conn.close()
await sleepAsync(10.millis)
check conn notin connMngr
check muxer notin connMngr
waitFor(test())
test "drop connections for peer":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
for i in 0..<2:
let dir = if i mod 2 == 0:
Direction.In else:
Direction.Out
let conn = Connection.init(peer, dir)
let muxer = new Muxer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check conn in connMngr
check muxer in connMngr
check not(isNil(connMngr.selectConn(peer, dir)))
check peer in connMngr.peers
await connMngr.dropPeer(peer)
check peer notin connMngr.peers
check isNil(connMngr.selectConn(peer, Direction.In))
check isNil(connMngr.selectConn(peer, Direction.Out))
check connMngr.peers.len == 0
waitFor(test())

View File

@ -1,6 +1,6 @@
{.used.}
import unittest, tables
import unittest
import chronos
import stew/byteutils
import nimcrypto/sysrand
@ -56,6 +56,10 @@ suite "Switch":
awaiters.add(await switch2.start())
let conn = await switch2.dial(switch1.peerInfo, TestCodec)
check switch1.isConnected(switch2.peerInfo)
check switch2.isConnected(switch1.peerInfo)
await conn.writeLp("Hello!")
let msg = string.fromBytes(await conn.readLp(1024))
check "Hello!" == msg
@ -69,6 +73,9 @@ suite "Switch":
# this needs to go at end
await allFuturesThrowing(awaiters)
check not switch1.isConnected(switch2.peerInfo)
check not switch2.isConnected(switch1.peerInfo)
waitFor(testSwitch())
test "e2e should not leak bufferstreams and connections on channel close":
@ -96,6 +103,10 @@ suite "Switch":
awaiters.add(await switch2.start())
let conn = await switch2.dial(switch1.peerInfo, TestCodec)
check switch1.isConnected(switch2.peerInfo)
check switch2.isConnected(switch1.peerInfo)
await conn.writeLp("Hello!")
let msg = string.fromBytes(await conn.readLp(1024))
check "Hello!" == msg
@ -103,20 +114,20 @@ suite "Switch":
await sleepAsync(2.seconds) # wait a little for cleanup to happen
var bufferTracker = getTracker(BufferStreamTrackerName)
# echo bufferTracker.dump()
echo bufferTracker.dump()
# plus 4 for the pubsub streams
check (BufferStreamTracker(bufferTracker).opened ==
(BufferStreamTracker(bufferTracker).closed + 4.uint64))
# var connTracker = getTracker(ConnectionTrackerName)
# echo connTracker.dump()
var connTracker = getTracker(ConnectionTrackerName)
echo connTracker.dump()
# plus 8 is for the secured connection and the socket
# and the pubsub streams that won't clean up until
# `disconnect()` or `stop()`
# check (ConnectionTracker(connTracker).opened ==
# (ConnectionTracker(connTracker).closed + 8.uint64))
check (ConnectionTracker(connTracker).opened ==
(ConnectionTracker(connTracker).closed + 8.uint64))
await allFuturesThrowing(
done.wait(5.seconds),
@ -127,6 +138,9 @@ suite "Switch":
# this needs to go at end
await allFuturesThrowing(awaiters)
check not switch1.isConnected(switch2.peerInfo)
check not switch2.isConnected(switch1.peerInfo)
waitFor(testSwitch())
test "e2e use connect then dial":
@ -153,10 +167,11 @@ suite "Switch":
awaiters.add(await switch2.start())
await switch2.connect(switch1.peerInfo)
check switch1.peerInfo.id in switch2.connections
let conn = await switch2.dial(switch1.peerInfo, TestCodec)
check switch1.isConnected(switch2.peerInfo)
check switch2.isConnected(switch1.peerInfo)
try:
await conn.writeLp("Hello!")
let msg = string.fromBytes(await conn.readLp(1024))
@ -172,6 +187,9 @@ suite "Switch":
)
await allFuturesThrowing(awaiters)
check not switch1.isConnected(switch2.peerInfo)
check not switch2.isConnected(switch1.peerInfo)
check:
waitFor(testSwitch()) == true
@ -186,23 +204,23 @@ suite "Switch":
await switch2.connect(switch1.peerInfo)
check switch1.connections[switch2.peerInfo.id].len > 0
check switch2.connections[switch1.peerInfo.id].len > 0
check switch1.isConnected(switch2.peerInfo)
check switch2.isConnected(switch1.peerInfo)
await sleepAsync(100.millis)
await switch2.disconnect(switch1.peerInfo)
await sleepAsync(2.seconds)
check not switch1.isConnected(switch2.peerInfo)
check not switch2.isConnected(switch1.peerInfo)
var bufferTracker = getTracker(BufferStreamTrackerName)
# echo bufferTracker.dump()
check bufferTracker.isLeaked() == false
# var connTracker = getTracker(ConnectionTrackerName)
var connTracker = getTracker(ConnectionTrackerName)
# echo connTracker.dump()
# check connTracker.isLeaked() == false
check switch2.peerInfo.id notin switch1.connections
check switch1.peerInfo.id notin switch2.connections
check connTracker.isLeaked() == false
await allFuturesThrowing(
switch1.stop(),
@ -210,47 +228,3 @@ suite "Switch":
await allFuturesThrowing(awaiters)
waitFor(testSwitch())
# test "e2e: handle read + secio fragmented":
# proc testListenerDialer(): Future[bool] {.async.} =
# let
# server: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
# serverInfo = PeerInfo.init(PrivateKey.random(ECDSA), [server])
# serverNoise = newSecio(serverInfo.privateKey)
# readTask = newFuture[void]()
# var hugePayload = newSeq[byte](0x1200000)
# check randomBytes(hugePayload) == hugePayload.len
# trace "Sending huge payload", size = hugePayload.len
# proc connHandler(conn: Connection) {.async, gcsafe.} =
# let sconn = await serverNoise.secure(conn)
# defer:
# await sconn.close()
# let msg = await sconn.read(0x1200000)
# check msg == hugePayload
# readTask.complete()
# let
# transport1: TcpTransport = TcpTransport.init()
# asyncCheck await transport1.listen(server, connHandler)
# let
# transport2: TcpTransport = TcpTransport.init()
# clientInfo = PeerInfo.init(PrivateKey.random(ECDSA), [transport1.ma])
# clientNoise = newSecio(clientInfo.privateKey)
# conn = await transport2.dial(transport1.ma)
# sconn = await clientNoise.secure(conn)
# await sconn.write(hugePayload)
# await readTask
# await sconn.close()
# await conn.close()
# await transport2.close()
# await transport1.close()
# result = true
# check:
# waitFor(testListenerDialer()) == true