Connection manager (#277)

* splitting out connection management

* wip

* wip conn mngr tests

* set peerinfo in contructor

* comments and documentation

* tests

* wip

* add `None` to detect untagged connections

* use `PeerID` to index connections

* fix tests

* remove useless equality
This commit is contained in:
Dmitriy Ryajov 2020-07-17 09:36:48 -06:00 committed by GitHub
parent 170685f9c6
commit 0348773ec9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 620 additions and 373 deletions

276
libp2p/connmanager.nim Normal file
View File

@ -0,0 +1,276 @@
## Nim-LibP2P
## Copyright (c) 2020 Status Research & Development GmbH
## Licensed under either of
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.
import tables, sequtils, sets
import chronos, chronicles, metrics
import peerinfo,
stream/connection,
muxers/muxer
declareGauge(libp2p_peers, "total connected peers")
const MaxConnectionsPerPeer = 5
type
TooManyConnections* = object of CatchableError
MuxerHolder = object
muxer: Muxer
handle: Future[void]
ConnManager* = ref object of RootObj
# NOTE: don't change to PeerInfo here
# the reference semantics on the PeerInfo
# object itself make it succeptible to
# copies and mangling by unrelated code.
conns: Table[PeerID, HashSet[Connection]]
muxed: Table[Connection, MuxerHolder]
cleanUpLock: Table[PeerInfo, AsyncLock]
maxConns: int
proc newTooManyConnections(): ref TooManyConnections {.inline.} =
result = newException(TooManyConnections, "too many connections for peer")
proc init*(C: type ConnManager,
maxConnsPerPeer: int = MaxConnectionsPerPeer): ConnManager =
C(maxConns: maxConnsPerPeer,
conns: initTable[PeerID, HashSet[Connection]](),
muxed: initTable[Connection, MuxerHolder]())
proc contains*(c: ConnManager, conn: Connection): bool =
## checks if a connection is being tracked by the
## connection manager
##
if isNil(conn):
return
if isNil(conn.peerInfo):
return
if conn.peerInfo.peerId notin c.conns:
return
return conn in c.conns[conn.peerInfo.peerId]
proc contains*(c: ConnManager, peerId: PeerID): bool =
peerId in c.conns
proc contains*(c: ConnManager, muxer: Muxer): bool =
## checks if a muxer is being tracked by the connection
## manager
##
if isNil(muxer):
return
let conn = muxer.connection
if conn notin c:
return
if conn notin c.muxed:
return
return muxer == c.muxed[conn].muxer
proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
## clean connection's resources such as muxers and streams
##
if isNil(conn):
return
if isNil(conn.peerInfo):
return
let peerInfo = conn.peerInfo
let lock = c.cleanUpLock.mgetOrPut(peerInfo, newAsyncLock())
try:
await lock.acquire()
trace "cleaning up connection for peer", peer = $peerInfo
if conn in c.muxed:
let muxerHolder = c.muxed[conn]
c.muxed.del(conn)
await muxerHolder.muxer.close()
if not(isNil(muxerHolder.handle)):
await muxerHolder.handle
if peerInfo.peerId in c.conns:
c.conns[peerInfo.peerId].excl(conn)
if c.conns[peerInfo.peerId].len == 0:
c.conns.del(peerInfo.peerId)
if not(conn.peerInfo.isClosed()):
conn.peerInfo.close()
finally:
await conn.close()
libp2p_peers.set(c.conns.len.int64)
if lock.locked():
lock.release()
trace "connection cleaned up"
proc onClose(c: ConnManager, conn: Connection) {.async.} =
## connection close even handler
##
## triggers the connections resource cleanup
##
await conn.closeEvent.wait()
trace "triggering connection cleanup"
await c.cleanupConn(conn)
proc selectConn*(c: ConnManager,
peerInfo: PeerInfo,
dir: Direction): Connection =
## Select a connection for the provided peer and direction
##
if isNil(peerInfo):
return
let conns = toSeq(
c.conns.getOrDefault(peerInfo.peerId))
.filterIt( it.dir == dir )
if conns.len > 0:
return conns[0]
proc selectConn*(c: ConnManager, peerInfo: PeerInfo): Connection =
## Select a connection for the provided giving priority
## to outgoing connections
##
if isNil(peerInfo):
return
var conn = c.selectConn(peerInfo, Direction.Out)
if isNil(conn):
conn = c.selectConn(peerInfo, Direction.In)
return conn
proc selectMuxer*(c: ConnManager, conn: Connection): Muxer =
## select the muxer for the provided connection
##
if isNil(conn):
return
if conn in c.muxed:
return c.muxed[conn].muxer
proc storeConn*(c: ConnManager, conn: Connection) =
## store a connection
##
if isNil(conn):
raise newException(CatchableError, "connection cannot be nil")
if isNil(conn.peerInfo):
raise newException(CatchableError, "empty peer info")
let peerInfo = conn.peerInfo
if c.conns.getOrDefault(peerInfo.peerId).len > c.maxConns:
trace "too many connections", peer = $conn.peerInfo,
conns = c.conns
.getOrDefault(peerInfo.peerId).len
raise newTooManyConnections()
if peerInfo.peerId notin c.conns:
c.conns[peerInfo.peerId] = initHashSet[Connection]()
c.conns[peerInfo.peerId].incl(conn)
# launch on close listener
asyncCheck c.onClose(conn)
libp2p_peers.set(c.conns.len.int64)
proc storeOutgoing*(c: ConnManager, conn: Connection) =
conn.dir = Direction.Out
c.storeConn(conn)
proc storeIncoming*(c: ConnManager, conn: Connection) =
conn.dir = Direction.In
c.storeConn(conn)
proc storeMuxer*(c: ConnManager,
muxer: Muxer,
handle: Future[void] = nil) =
## store the connection and muxer
##
if isNil(muxer):
raise newException(CatchableError, "muxer cannot be nil")
if isNil(muxer.connection):
raise newException(CatchableError, "muxer's connection cannot be nil")
c.muxed[muxer.connection] = MuxerHolder(
muxer: muxer,
handle: handle)
trace "storred connection", connections = c.conns.len
proc getMuxedStream*(c: ConnManager,
peerInfo: PeerInfo,
dir: Direction): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the provided peer
## with the given direction
##
let muxer = c.selectMuxer(c.selectConn(peerInfo, dir))
if not(isNil(muxer)):
return await muxer.newStream()
proc getMuxedStream*(c: ConnManager,
peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed peer from any connection
##
let muxer = c.selectMuxer(c.selectConn(peerInfo))
if not(isNil(muxer)):
return await muxer.newStream()
proc getMuxedStream*(c: ConnManager,
conn: Connection): Future[Connection] {.async, gcsafe.} =
## get a muxed stream for the passed connection
##
let muxer = c.selectMuxer(conn)
if not(isNil(muxer)):
return await muxer.newStream()
proc dropPeer*(c: ConnManager, peerInfo: PeerInfo) {.async.} =
## drop connections and cleanup resources for peer
##
for conn in c.conns.getOrDefault(peerInfo.peerId):
if not(isNil(conn)):
await c.cleanupConn(conn)
proc close*(c: ConnManager) {.async.} =
## cleanup resources for the connection
## manager
##
for conns in toSeq(c.conns.values):
for conn in conns:
try:
await c.cleanupConn(conn)
except CancelledError as exc:
raise exc
except CatchableError as exc:
warn "error cleaning up connections"

View File

@ -7,7 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import options, sequtils import options, sequtils, hashes
import chronos, chronicles import chronos, chronicles
import peerid, multiaddress, crypto/crypto import peerid, multiaddress, crypto/crypto
@ -134,18 +134,5 @@ proc publicKey*(p: PeerInfo): Option[PublicKey] {.inline.} =
else: else:
result = some(p.privateKey.getKey().tryGet()) result = some(p.privateKey.getKey().tryGet())
func `==`*(a, b: PeerInfo): bool = func hash*(p: PeerInfo): Hash =
# override equiality to support both nil and peerInfo comparisons cast[pointer](p).hash
# this in the future will allow us to recycle refs
let
aptr = cast[pointer](a)
bptr = cast[pointer](b)
if isNil(aptr) and isNil(bptr):
return true
if isNil(aptr) or isNil(bptr):
return false
if aptr == bptr and a.peerId == b.peerId:
return true

View File

@ -8,14 +8,14 @@
## those terms. ## those terms.
import std/[tables, sequtils, sets] import std/[tables, sequtils, sets]
import chronos, chronicles import chronos, chronicles, metrics
import pubsubpeer, import pubsubpeer,
rpc/[message, messages], rpc/[message, messages],
../protocol, ../protocol,
../../stream/connection, ../../stream/connection,
../../peerid, ../../peerid,
../../peerinfo ../../peerinfo,
import metrics ../../errors
export PubSubPeer export PubSubPeer
export PubSubObserver export PubSubObserver
@ -233,8 +233,11 @@ method subscribe*(p: PubSub,
p.topics[topic].handler.add(handler) p.topics[topic].handler.add(handler)
var sent: seq[Future[void]]
for peer in toSeq(p.peers.values): for peer in toSeq(p.peers.values):
await p.sendSubs(peer, @[topic], true) sent.add(p.sendSubs(peer, @[topic], true))
checkFutures(await allFinished(sent))
# metrics # metrics
libp2p_pubsub_topics.inc() libp2p_pubsub_topics.inc()

View File

@ -48,22 +48,6 @@ func hash*(p: PubSubPeer): Hash =
# int is either 32/64, so intptr basically, pubsubpeer is a ref # int is either 32/64, so intptr basically, pubsubpeer is a ref
cast[pointer](p).hash cast[pointer](p).hash
func `==`*(a, b: PubSubPeer): bool =
# override equiality to support both nil and peerInfo comparisons
# this in the future will allow us to recycle refs
let
aptr = cast[pointer](a)
bptr = cast[pointer](b)
if isNil(aptr) and isNil(bptr):
return true
if isNil(aptr) or isNil(bptr):
return false
if aptr == bptr and a.peerInfo == b.peerInfo:
return true
proc id*(p: PubSubPeer): string = p.peerInfo.id proc id*(p: PubSubPeer): string = p.peerInfo.id
proc inUse*(p: PubSubPeer): bool = proc inUse*(p: PubSubPeer): bool =

View File

@ -289,7 +289,7 @@ method close*(s: BufferStream) {.async, gcsafe.} =
try: try:
## close the stream and clear the buffer ## close the stream and clear the buffer
if not s.isClosed: if not s.isClosed:
trace "closing bufferstream", oid = s.oid trace "closing bufferstream", oid = $s.oid
s.isEof = true s.isEof = true
for r in s.readReqs: for r in s.readReqs:
if not(isNil(r)) and not(r.finished()): if not(isNil(r)) and not(r.finished()):

View File

@ -7,6 +7,7 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
import hashes
import chronos, metrics import chronos, metrics
import lpstream, import lpstream,
../multiaddress, ../multiaddress,
@ -18,9 +19,13 @@ const
ConnectionTrackerName* = "libp2p.connection" ConnectionTrackerName* = "libp2p.connection"
type type
Direction* {.pure.} = enum
None, In, Out
Connection* = ref object of LPStream Connection* = ref object of LPStream
peerInfo*: PeerInfo peerInfo*: PeerInfo
observedAddr*: Multiaddress observedAddr*: Multiaddress
dir*: Direction
ConnectionTracker* = ref object of TrackerBase ConnectionTracker* = ref object of TrackerBase
opened*: uint64 opened*: uint64
@ -50,9 +55,11 @@ proc setupConnectionTracker(): ConnectionTracker =
result.isLeaked = leakTransport result.isLeaked = leakTransport
addTracker(ConnectionTrackerName, result) addTracker(ConnectionTrackerName, result)
proc init*[T: Connection](self: var T, peerInfo: PeerInfo): T = proc init*(C: type Connection,
new self peerInfo: PeerInfo,
self.initStream() dir: Direction): Connection =
result = C(peerInfo: peerInfo, dir: dir)
result.initStream()
method initStream*(s: Connection) = method initStream*(s: Connection) =
if s.objName.len == 0: if s.objName.len == 0:
@ -63,9 +70,13 @@ method initStream*(s: Connection) =
inc getConnectionTracker().opened inc getConnectionTracker().opened
method close*(s: Connection) {.async.} = method close*(s: Connection) {.async.} =
await procCall LPStream(s).close() if not s.isClosed:
inc getConnectionTracker().closed await procCall LPStream(s).close()
inc getConnectionTracker().closed
proc `$`*(conn: Connection): string = proc `$`*(conn: Connection): string =
if not isNil(conn.peerInfo): if not isNil(conn.peerInfo):
result = conn.peerInfo.id result = conn.peerInfo.id
func hash*(p: Connection): Hash =
cast[pointer](p).hash

View File

@ -76,15 +76,6 @@ method initStream*(s: LPStream) {.base.} =
libp2p_open_streams.inc(labelValues = [s.objName]) libp2p_open_streams.inc(labelValues = [s.objName])
trace "stream created", oid = $s.oid, name = s.objName trace "stream created", oid = $s.oid, name = s.objName
# TODO: debuging aid to troubleshoot streams open/close
# try:
# echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"])
# echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"])
# # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >=
# # libp2p_open_streams.value(labelValues = ["SecureConn"]))
# except CatchableError:
# discard
proc join*(s: LPStream): Future[void] = proc join*(s: LPStream): Future[void] =
s.closeEvent.wait() s.closeEvent.wait()
@ -207,12 +198,3 @@ method close*(s: LPStream) {.base, async.} =
s.closeEvent.fire() s.closeEvent.fire()
libp2p_open_streams.dec(labelValues = [s.objName]) libp2p_open_streams.dec(labelValues = [s.objName])
trace "stream destroyed", oid = $s.oid, name = s.objName trace "stream destroyed", oid = $s.oid, name = s.objName
# TODO: debuging aid to troubleshoot streams open/close
# try:
# echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"])
# echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"])
# # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >=
# # libp2p_open_streams.value(labelValues = ["SecureConn"]))
# except CatchableError:
# discard

View File

@ -11,7 +11,6 @@ import tables,
sequtils, sequtils,
options, options,
sets, sets,
algorithm,
oids oids
import chronos, import chronos,
@ -28,6 +27,7 @@ import stream/connection,
protocols/identify, protocols/identify,
protocols/pubsub/pubsub, protocols/pubsub/pubsub,
muxers/muxer, muxers/muxer,
connmanager,
peerid peerid
logScope: logScope:
@ -39,33 +39,16 @@ logScope:
# and only if the channel has been secured (i.e. if a secure manager has been # and only if the channel has been secured (i.e. if a secure manager has been
# previously provided) # previously provided)
declareGauge(libp2p_peers, "total connected peers")
declareCounter(libp2p_dialed_peers, "dialed peers") declareCounter(libp2p_dialed_peers, "dialed peers")
declareCounter(libp2p_failed_dials, "failed dials") declareCounter(libp2p_failed_dials, "failed dials")
declareCounter(libp2p_failed_upgrade, "peers failed upgrade") declareCounter(libp2p_failed_upgrade, "peers failed upgrade")
const MaxConnectionsPerPeer = 5
type type
NoPubSubException* = object of CatchableError NoPubSubException* = object of CatchableError
TooManyConnections* = object of CatchableError
Direction {.pure.} = enum
In, Out
ConnectionHolder = object
dir: Direction
conn: Connection
MuxerHolder = object
dir: Direction
muxer: Muxer
handle: Future[void]
Switch* = ref object of RootObj Switch* = ref object of RootObj
peerInfo*: PeerInfo peerInfo*: PeerInfo
connections*: Table[string, seq[ConnectionHolder]] connManager: ConnManager
muxed*: Table[string, seq[MuxerHolder]]
transports*: seq[Transport] transports*: seq[Transport]
protocols*: seq[LPProtocol] protocols*: seq[LPProtocol]
muxers*: Table[string, MuxerProvider] muxers*: Table[string, MuxerProvider]
@ -75,90 +58,20 @@ type
secureManagers*: seq[Secure] secureManagers*: seq[Secure]
pubSub*: Option[PubSub] pubSub*: Option[PubSub]
dialLock: Table[string, AsyncLock] dialLock: Table[string, AsyncLock]
cleanUpLock: Table[string, AsyncLock]
proc newNoPubSubException(): ref NoPubSubException {.inline.} = proc newNoPubSubException(): ref NoPubSubException {.inline.} =
result = newException(NoPubSubException, "no pubsub provided!") result = newException(NoPubSubException, "no pubsub provided!")
proc newTooManyConnections(): ref TooManyConnections {.inline.} =
result = newException(TooManyConnections, "too many connections for peer")
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.}
proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.}
proc selectConn(s: Switch, peerInfo: PeerInfo): Connection = proc cleanupPubSubPeer(s: Switch, conn: Connection) {.async.} =
## select the "best" connection according to some criteria await conn.closeEvent.wait()
## if s.pubSub.isSome:
## Ideally when the connection's stats are available await s.pubSub.get().unsubscribePeer(conn.peerInfo)
## we'd select the fastest, but for now we simply pick an outgoing
## connection first if none is available, we pick the first outgoing
##
if isNil(peerInfo): proc isConnected*(s: Switch, peer: PeerInfo): bool =
return peer.peerId in s.connManager
let conns = s.connections
.getOrDefault(peerInfo.id)
# it should be OK to sort on each
# access as there should only be
# up to MaxConnectionsPerPeer entries
.sorted(
proc(a, b: ConnectionHolder): int =
if a.dir < b.dir: -1
elif a.dir == b.dir: 0
else: 1
, SortOrder.Descending)
if conns.len > 0:
return conns[0].conn
proc selectMuxer(s: Switch, conn: Connection): Muxer =
## select the muxer for the supplied connection
##
if isNil(conn):
return
if not(isNil(conn.peerInfo)) and conn.peerInfo.id in s.muxed:
if s.muxed[conn.peerInfo.id].len > 0:
let muxers = s.muxed[conn.peerInfo.id]
.filterIt( it.muxer.connection == conn )
if muxers.len > 0:
return muxers[0].muxer
proc storeConn(s: Switch,
muxer: Muxer,
dir: Direction,
handle: Future[void] = nil) {.async.} =
## store the connection and muxer
##
if isNil(muxer):
return
let conn = muxer.connection
if isNil(conn):
return
let id = conn.peerInfo.id
if s.connections.getOrDefault(id).len > MaxConnectionsPerPeer:
warn "disconnecting peer, too many connections", peer = $conn.peerInfo,
conns = s.connections
.getOrDefault(id).len
await s.disconnect(conn.peerInfo)
raise newTooManyConnections()
s.connections.mgetOrPut(
id,
newSeq[ConnectionHolder]())
.add(ConnectionHolder(conn: conn, dir: dir))
s.muxed.mgetOrPut(
muxer.connection.peerInfo.id,
newSeq[MuxerHolder]())
.add(MuxerHolder(muxer: muxer, handle: handle, dir: dir))
trace "storred connection", connections = s.connections.len
libp2p_peers.set(s.connections.len.int64)
proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
if s.secureManagers.len <= 0: if s.secureManagers.len <= 0:
@ -170,9 +83,11 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
trace "securing connection", codec = manager trace "securing connection", codec = manager
let secureProtocol = s.secureManagers.filterIt(it.codec == manager) let secureProtocol = s.secureManagers.filterIt(it.codec == manager)
# ms.select should deal with the correctness of this # ms.select should deal with the correctness of this
# let's avoid duplicating checks but detect if it fails to do it properly # let's avoid duplicating checks but detect if it fails to do it properly
doAssert(secureProtocol.len > 0) doAssert(secureProtocol.len > 0)
result = await secureProtocol[0].secure(conn, true) result = await secureProtocol[0].secure(conn, true)
proc identify(s: Switch, conn: Connection) {.async, gcsafe.} = proc identify(s: Switch, conn: Connection) {.async, gcsafe.} =
@ -218,6 +133,7 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
# create new muxer for connection # create new muxer for connection
let muxer = s.muxers[muxerName].newMuxer(conn) let muxer = s.muxers[muxerName].newMuxer(conn)
s.connManager.storeMuxer(muxer)
trace "found a muxer", name = muxerName, peer = $conn trace "found a muxer", name = muxerName, peer = $conn
@ -247,75 +163,10 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
# store it in muxed connections if we have a peer for it # store it in muxed connections if we have a peer for it
trace "adding muxer for peer", peer = conn.peerInfo.id trace "adding muxer for peer", peer = conn.peerInfo.id
await s.storeConn(muxer, Direction.Out, handlerFut) s.connManager.storeMuxer(muxer, handlerFut) # update muxer with handler
proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
if isNil(conn):
return
if isNil(conn.peerInfo):
return
let id = conn.peerInfo.id
let lock = s.cleanUpLock.mgetOrPut(id, newAsyncLock())
try:
await lock.acquire()
trace "cleaning up connection for peer", peerId = id
if id in s.muxed:
let muxerHolder = s.muxed[id]
.filterIt(
it.muxer.connection == conn
)
if muxerHolder.len > 0:
await muxerHolder[0].muxer.close()
if not(isNil(muxerHolder[0].handle)):
await muxerHolder[0].handle
if id in s.muxed:
s.muxed[id].keepItIf(
it.muxer.connection != conn
)
if s.muxed[id].len == 0:
s.muxed.del(id)
if s.pubSub.isSome:
await s.pubSub.get()
.unsubscribePeer(conn.peerInfo)
if id in s.connections:
s.connections[id].keepItIf(
it.conn != conn
)
if s.connections[id].len == 0:
s.connections.del(id)
# TODO: Investigate cleanupConn() always called twice for one peer.
if not(conn.peerInfo.isClosed()):
conn.peerInfo.close()
finally:
await conn.close()
libp2p_peers.set(s.connections.len.int64)
if lock.locked():
lock.release()
proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
let connections = s.connections.getOrDefault(peer.id) await s.connManager.dropPeer(peer)
for connHolder in connections:
if not isNil(connHolder.conn):
await s.cleanupConn(connHolder.conn)
proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} =
# if there is a muxer for the connection
# use it instead to create a muxed stream
let muxer = s.selectMuxer(s.selectConn(peerInfo)) # always get the first muxer here
if not(isNil(muxer)):
return await muxer.newStream()
proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
logScope: logScope:
@ -388,52 +239,51 @@ proc internalConnect(s: Switch,
var conn: Connection var conn: Connection
let lock = s.dialLock.mgetOrPut(id, newAsyncLock()) let lock = s.dialLock.mgetOrPut(id, newAsyncLock())
defer: try:
await lock.acquire()
trace "about to dial peer", peer = id
conn = s.connManager.selectConn(peer)
if conn.isNil or (conn.closed or conn.atEof):
trace "Dialing peer", peer = id
for t in s.transports: # for each transport
for a in peer.addrs: # for each address
if t.handles(a): # check if it can dial it
trace "Dialing address", address = $a, peer = id
try:
conn = await t.dial(a)
# make sure to assign the peer to the connection
conn.peerInfo = peer
libp2p_dialed_peers.inc()
except CancelledError as exc:
trace "dialing canceled", exc = exc.msg
raise
except CatchableError as exc:
trace "dialing failed", exc = exc.msg
libp2p_failed_dials.inc()
continue
try:
let uconn = await s.upgradeOutgoing(conn)
s.connManager.storeOutgoing(uconn)
conn = uconn
except CatchableError as exc:
if not(isNil(conn)):
await conn.close()
trace "Unable to establish outgoing link", exc = exc.msg
raise exc
if isNil(conn):
libp2p_failed_upgrade.inc()
continue
break
else:
trace "Reusing existing connection", oid = conn.oid
finally:
if lock.locked(): if lock.locked():
lock.release() lock.release()
await lock.acquire()
trace "about to dial peer", peer = id
conn = s.selectConn(peer)
if conn.isNil or conn.closed:
trace "Dialing peer", peer = id
for t in s.transports: # for each transport
for a in peer.addrs: # for each address
if t.handles(a): # check if it can dial it
trace "Dialing address", address = $a
try:
conn = await t.dial(a)
libp2p_dialed_peers.inc()
except CancelledError as exc:
trace "dialing canceled", exc = exc.msg
raise
except CatchableError as exc:
trace "dialing failed", exc = exc.msg
libp2p_failed_dials.inc()
continue
# make sure to assign the peer to the connection
conn.peerInfo = peer
try:
conn = await s.upgradeOutgoing(conn)
except CatchableError as exc:
if not(isNil(conn)):
await conn.close()
trace "Unable to establish outgoing link", exc = exc.msg
raise exc
if isNil(conn):
libp2p_failed_upgrade.inc()
continue
conn.closeEvent.wait()
.addCallback do(udata: pointer):
asyncCheck s.cleanupConn(conn)
break
else:
trace "Reusing existing connection", oid = conn.oid
if isNil(conn): if isNil(conn):
raise newException(CatchableError, raise newException(CatchableError,
"Unable to establish outgoing link") "Unable to establish outgoing link")
@ -443,13 +293,14 @@ proc internalConnect(s: Switch,
raise newException(CatchableError, raise newException(CatchableError,
"Connection dead on arrival") "Connection dead on arrival")
doAssert(conn.peerInfo.id in s.connections, doAssert(conn in s.connManager, "connection not tracked!")
"connection not tracked!")
trace "dial succesfull", oid = $conn.oid, trace "dial succesfull", oid = $conn.oid,
peer = $conn.peerInfo peer = $conn.peerInfo
await s.subscribePeer(peer) await s.subscribePeer(peer)
asyncCheck s.cleanupPubSubPeer(conn)
return conn return conn
proc connect*(s: Switch, peer: PeerInfo) {.async.} = proc connect*(s: Switch, peer: PeerInfo) {.async.} =
@ -460,7 +311,7 @@ proc dial*(s: Switch,
proto: string): proto: string):
Future[Connection] {.async.} = Future[Connection] {.async.} =
let conn = await s.internalConnect(peer) let conn = await s.internalConnect(peer)
let stream = await s.getMuxedStream(peer) let stream = await s.connManager.getMuxedStream(conn)
proc cleanup() {.async.} = proc cleanup() {.async.} =
if not(isNil(stream)): if not(isNil(stream)):
@ -505,14 +356,14 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
try: try:
defer: conn.dir = Direction.In # tag connection with direction
await s.cleanupConn(conn)
await s.upgradeIncoming(conn) # perform upgrade on incoming connection await s.upgradeIncoming(conn) # perform upgrade on incoming connection
except CancelledError as exc: except CancelledError as exc:
raise exc raise exc
except CatchableError as exc: except CatchableError as exc:
trace "Exception occurred in Switch.start", exc = exc.msg trace "Exception occurred in Switch.start", exc = exc.msg
finally:
await conn.close()
var startFuts: seq[Future[void]] var startFuts: seq[Future[void]]
for t in s.transports: # for each transport for t in s.transports: # for each transport
@ -537,14 +388,8 @@ proc stop*(s: Switch) {.async.} =
if s.pubSub.isSome: if s.pubSub.isSome:
await s.pubSub.get().stop() await s.pubSub.get().stop()
for conns in toSeq(s.connections.values): # close and cleanup all connections
for conn in conns: await s.connManager.close()
try:
await s.cleanupConn(conn.conn)
except CancelledError as exc:
raise exc
except CatchableError as exc:
warn "error cleaning up connections"
for t in s.transports: for t in s.transports:
try: try:
@ -562,7 +407,18 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog() trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog()
var stream: Connection var stream: Connection
try: try:
stream = await s.getMuxedStream(peerInfo) stream = await s.connManager.getMuxedStream(peerInfo)
if isNil(stream):
trace "unable to subscribe to peer", peer = peerInfo.shortLog
return
if not await s.ms.select(stream, s.pubSub.get().codec):
if not(isNil(stream)):
await stream.close()
return
s.pubSub.get().subscribePeer(stream)
except CancelledError as exc: except CancelledError as exc:
if not(isNil(stream)): if not(isNil(stream)):
await stream.close() await stream.close()
@ -574,44 +430,27 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
if not(isNil(stream)): if not(isNil(stream)):
await stream.close() await stream.close()
if isNil(stream):
trace "unable to subscribe to peer", peer = peerInfo.shortLog
return
if not await s.ms.select(stream, s.pubSub.get().codec):
if not(isNil(stream)):
await stream.close()
return
s.pubSub.get().subscribePeer(stream)
proc subscribe*(s: Switch, topic: string, proc subscribe*(s: Switch, topic: string,
handler: TopicHandler): Future[void] = handler: TopicHandler) {.async.} =
## subscribe to a pubsub topic ## subscribe to a pubsub topic
if s.pubSub.isNone: if s.pubSub.isNone:
var retFuture = newFuture[void]("Switch.subscribe") raise newNoPubSubException()
retFuture.fail(newNoPubSubException())
return retFuture
return s.pubSub.get().subscribe(topic, handler) await s.pubSub.get().subscribe(topic, handler)
proc unsubscribe*(s: Switch, topics: seq[TopicPair]): Future[void] = proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} =
## unsubscribe from topics ## unsubscribe from topics
if s.pubSub.isNone: if s.pubSub.isNone:
var retFuture = newFuture[void]("Switch.unsubscribe") raise newNoPubSubException()
retFuture.fail(newNoPubSubException())
return retFuture
return s.pubSub.get().unsubscribe(topics) await s.pubSub.get().unsubscribe(topics)
proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] = proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.} =
# pubslish to pubsub topic # pubslish to pubsub topic
if s.pubSub.isNone: if s.pubSub.isNone:
var retFuture = newFuture[int]("Switch.publish") raise newNoPubSubException()
retFuture.fail(newNoPubSubException())
return retFuture
return s.pubSub.get().publish(topic, data) return await s.pubSub.get().publish(topic, data)
proc addValidator*(s: Switch, proc addValidator*(s: Switch,
topics: varargs[string], topics: varargs[string],
@ -647,17 +486,17 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
muxer.connection.peerInfo = stream.peerInfo muxer.connection.peerInfo = stream.peerInfo
# store muxer and muxed connection # store incoming connection
await s.storeConn(muxer, Direction.In) s.connManager.storeIncoming(muxer.connection)
muxer.connection.closeEvent.wait() # store muxer and muxed connection
.addCallback do(udata: pointer): s.connManager.storeMuxer(muxer)
asyncCheck s.cleanupConn(muxer.connection)
trace "got new muxer", peer = $muxer.connection.peerInfo trace "got new muxer", peer = $muxer.connection.peerInfo
# try establishing a pubsub connection # try establishing a pubsub connection
await s.subscribePeer(muxer.connection.peerInfo) await s.subscribePeer(muxer.connection.peerInfo)
asyncCheck s.cleanupPubSubPeer(muxer.connection)
except CancelledError as exc: except CancelledError as exc:
await muxer.close() await muxer.close()
@ -680,8 +519,7 @@ proc newSwitch*(peerInfo: PeerInfo,
peerInfo: peerInfo, peerInfo: peerInfo,
ms: newMultistream(), ms: newMultistream(),
transports: transports, transports: transports,
connections: initTable[string, seq[ConnectionHolder]](), connManager: ConnManager.init(),
muxed: initTable[string, seq[MuxerHolder]](),
identity: identity, identity: identity,
muxers: muxers, muxers: muxers,
secureManagers: @secureManagers, secureManagers: @secureManagers,

192
tests/testconnmngr.nim Normal file
View File

@ -0,0 +1,192 @@
import unittest
import chronos
import ../libp2p/[connmanager,
stream/connection,
crypto/crypto,
muxers/muxer,
peerinfo]
import helpers
type
TestMuxer = ref object of Muxer
peerInfo: PeerInfo
method newStream*(
m: TestMuxer,
name: string = "",
lazy: bool = false):
Future[Connection] {.async, gcsafe.} =
result = Connection.init(m.peerInfo, Direction.Out)
suite "Connection Manager":
teardown:
for tracker in testTrackers():
# echo tracker.dump()
check tracker.isLeaked() == false
test "add and retrive a connection":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
connMngr.storeConn(conn)
check conn in connMngr
let peerConn = connMngr.selectConn(peer)
check peerConn == conn
check peerConn.dir == Direction.In
test "add and retrieve a muxer":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new Muxer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
let peerMuxer = connMngr.selectMuxer(conn)
check peerMuxer == muxer
test "get conn with direction":
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn1 = Connection.init(peer, Direction.Out)
let conn2 = Connection.init(peer, Direction.In)
connMngr.storeConn(conn1)
connMngr.storeConn(conn2)
check conn1 in connMngr
check conn2 in connMngr
let outConn = connMngr.selectConn(peer, Direction.Out)
let inConn = connMngr.selectConn(peer, Direction.In)
check outConn != inConn
check outConn.dir == Direction.Out
check inConn.dir == Direction.In
test "get muxed stream for peer":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new TestMuxer
muxer.peerInfo = peer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
let stream = await connMngr.getMuxedStream(peer)
check not(isNil(stream))
check stream.peerInfo == peer
waitFor(test())
test "get stream from directed connection":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new TestMuxer
muxer.peerInfo = peer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
check not(isNil((await connMngr.getMuxedStream(peer, Direction.In))))
check isNil((await connMngr.getMuxedStream(peer, Direction.Out)))
waitFor(test())
test "get stream from any connection":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new TestMuxer
muxer.peerInfo = peer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check muxer in connMngr
check not(isNil((await connMngr.getMuxedStream(conn))))
waitFor(test())
test "should raise on too many connections":
proc test() =
let connMngr = ConnManager.init(1)
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
connMngr.storeConn(Connection.init(peer, Direction.In))
connMngr.storeConn(Connection.init(peer, Direction.In))
connMngr.storeConn(Connection.init(peer, Direction.In))
expect TooManyConnections:
test()
test "cleanup on connection close":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
let conn = Connection.init(peer, Direction.In)
let muxer = new Muxer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check conn in connMngr
check muxer in connMngr
await conn.close()
await sleepAsync(10.millis)
check conn notin connMngr
check muxer notin connMngr
waitFor(test())
test "drop connections for peer":
proc test() {.async.} =
let connMngr = ConnManager.init()
let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
for i in 0..<2:
let dir = if i mod 2 == 0:
Direction.In else:
Direction.Out
let conn = Connection.init(peer, dir)
let muxer = new Muxer
muxer.connection = conn
connMngr.storeConn(conn)
connMngr.storeMuxer(muxer)
check conn in connMngr
check muxer in connMngr
check not(isNil(connMngr.selectConn(peer, dir)))
check peer in connMngr.peers
await connMngr.dropPeer(peer)
check peer notin connMngr.peers
check isNil(connMngr.selectConn(peer, Direction.In))
check isNil(connMngr.selectConn(peer, Direction.Out))
check connMngr.peers.len == 0
waitFor(test())

View File

@ -1,6 +1,6 @@
{.used.} {.used.}
import unittest, tables import unittest
import chronos import chronos
import stew/byteutils import stew/byteutils
import nimcrypto/sysrand import nimcrypto/sysrand
@ -56,6 +56,10 @@ suite "Switch":
awaiters.add(await switch2.start()) awaiters.add(await switch2.start())
let conn = await switch2.dial(switch1.peerInfo, TestCodec) let conn = await switch2.dial(switch1.peerInfo, TestCodec)
check switch1.isConnected(switch2.peerInfo)
check switch2.isConnected(switch1.peerInfo)
await conn.writeLp("Hello!") await conn.writeLp("Hello!")
let msg = string.fromBytes(await conn.readLp(1024)) let msg = string.fromBytes(await conn.readLp(1024))
check "Hello!" == msg check "Hello!" == msg
@ -69,6 +73,9 @@ suite "Switch":
# this needs to go at end # this needs to go at end
await allFuturesThrowing(awaiters) await allFuturesThrowing(awaiters)
check not switch1.isConnected(switch2.peerInfo)
check not switch2.isConnected(switch1.peerInfo)
waitFor(testSwitch()) waitFor(testSwitch())
test "e2e should not leak bufferstreams and connections on channel close": test "e2e should not leak bufferstreams and connections on channel close":
@ -96,6 +103,10 @@ suite "Switch":
awaiters.add(await switch2.start()) awaiters.add(await switch2.start())
let conn = await switch2.dial(switch1.peerInfo, TestCodec) let conn = await switch2.dial(switch1.peerInfo, TestCodec)
check switch1.isConnected(switch2.peerInfo)
check switch2.isConnected(switch1.peerInfo)
await conn.writeLp("Hello!") await conn.writeLp("Hello!")
let msg = string.fromBytes(await conn.readLp(1024)) let msg = string.fromBytes(await conn.readLp(1024))
check "Hello!" == msg check "Hello!" == msg
@ -103,20 +114,20 @@ suite "Switch":
await sleepAsync(2.seconds) # wait a little for cleanup to happen await sleepAsync(2.seconds) # wait a little for cleanup to happen
var bufferTracker = getTracker(BufferStreamTrackerName) var bufferTracker = getTracker(BufferStreamTrackerName)
# echo bufferTracker.dump() echo bufferTracker.dump()
# plus 4 for the pubsub streams # plus 4 for the pubsub streams
check (BufferStreamTracker(bufferTracker).opened == check (BufferStreamTracker(bufferTracker).opened ==
(BufferStreamTracker(bufferTracker).closed + 4.uint64)) (BufferStreamTracker(bufferTracker).closed + 4.uint64))
# var connTracker = getTracker(ConnectionTrackerName) var connTracker = getTracker(ConnectionTrackerName)
# echo connTracker.dump() echo connTracker.dump()
# plus 8 is for the secured connection and the socket # plus 8 is for the secured connection and the socket
# and the pubsub streams that won't clean up until # and the pubsub streams that won't clean up until
# `disconnect()` or `stop()` # `disconnect()` or `stop()`
# check (ConnectionTracker(connTracker).opened == check (ConnectionTracker(connTracker).opened ==
# (ConnectionTracker(connTracker).closed + 8.uint64)) (ConnectionTracker(connTracker).closed + 8.uint64))
await allFuturesThrowing( await allFuturesThrowing(
done.wait(5.seconds), done.wait(5.seconds),
@ -127,6 +138,9 @@ suite "Switch":
# this needs to go at end # this needs to go at end
await allFuturesThrowing(awaiters) await allFuturesThrowing(awaiters)
check not switch1.isConnected(switch2.peerInfo)
check not switch2.isConnected(switch1.peerInfo)
waitFor(testSwitch()) waitFor(testSwitch())
test "e2e use connect then dial": test "e2e use connect then dial":
@ -153,10 +167,11 @@ suite "Switch":
awaiters.add(await switch2.start()) awaiters.add(await switch2.start())
await switch2.connect(switch1.peerInfo) await switch2.connect(switch1.peerInfo)
check switch1.peerInfo.id in switch2.connections
let conn = await switch2.dial(switch1.peerInfo, TestCodec) let conn = await switch2.dial(switch1.peerInfo, TestCodec)
check switch1.isConnected(switch2.peerInfo)
check switch2.isConnected(switch1.peerInfo)
try: try:
await conn.writeLp("Hello!") await conn.writeLp("Hello!")
let msg = string.fromBytes(await conn.readLp(1024)) let msg = string.fromBytes(await conn.readLp(1024))
@ -172,6 +187,9 @@ suite "Switch":
) )
await allFuturesThrowing(awaiters) await allFuturesThrowing(awaiters)
check not switch1.isConnected(switch2.peerInfo)
check not switch2.isConnected(switch1.peerInfo)
check: check:
waitFor(testSwitch()) == true waitFor(testSwitch()) == true
@ -186,23 +204,23 @@ suite "Switch":
await switch2.connect(switch1.peerInfo) await switch2.connect(switch1.peerInfo)
check switch1.connections[switch2.peerInfo.id].len > 0 check switch1.isConnected(switch2.peerInfo)
check switch2.connections[switch1.peerInfo.id].len > 0 check switch2.isConnected(switch1.peerInfo)
await sleepAsync(100.millis) await sleepAsync(100.millis)
await switch2.disconnect(switch1.peerInfo) await switch2.disconnect(switch1.peerInfo)
await sleepAsync(2.seconds) await sleepAsync(2.seconds)
check not switch1.isConnected(switch2.peerInfo)
check not switch2.isConnected(switch1.peerInfo)
var bufferTracker = getTracker(BufferStreamTrackerName) var bufferTracker = getTracker(BufferStreamTrackerName)
# echo bufferTracker.dump() # echo bufferTracker.dump()
check bufferTracker.isLeaked() == false check bufferTracker.isLeaked() == false
# var connTracker = getTracker(ConnectionTrackerName) var connTracker = getTracker(ConnectionTrackerName)
# echo connTracker.dump() # echo connTracker.dump()
# check connTracker.isLeaked() == false check connTracker.isLeaked() == false
check switch2.peerInfo.id notin switch1.connections
check switch1.peerInfo.id notin switch2.connections
await allFuturesThrowing( await allFuturesThrowing(
switch1.stop(), switch1.stop(),
@ -210,47 +228,3 @@ suite "Switch":
await allFuturesThrowing(awaiters) await allFuturesThrowing(awaiters)
waitFor(testSwitch()) waitFor(testSwitch())
# test "e2e: handle read + secio fragmented":
# proc testListenerDialer(): Future[bool] {.async.} =
# let
# server: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
# serverInfo = PeerInfo.init(PrivateKey.random(ECDSA), [server])
# serverNoise = newSecio(serverInfo.privateKey)
# readTask = newFuture[void]()
# var hugePayload = newSeq[byte](0x1200000)
# check randomBytes(hugePayload) == hugePayload.len
# trace "Sending huge payload", size = hugePayload.len
# proc connHandler(conn: Connection) {.async, gcsafe.} =
# let sconn = await serverNoise.secure(conn)
# defer:
# await sconn.close()
# let msg = await sconn.read(0x1200000)
# check msg == hugePayload
# readTask.complete()
# let
# transport1: TcpTransport = TcpTransport.init()
# asyncCheck await transport1.listen(server, connHandler)
# let
# transport2: TcpTransport = TcpTransport.init()
# clientInfo = PeerInfo.init(PrivateKey.random(ECDSA), [transport1.ma])
# clientNoise = newSecio(clientInfo.privateKey)
# conn = await transport2.dial(transport1.ma)
# sconn = await clientNoise.secure(conn)
# await sconn.write(hugePayload)
# await readTask
# await sconn.close()
# await conn.close()
# await transport2.close()
# await transport1.close()
# result = true
# check:
# waitFor(testListenerDialer()) == true