diff --git a/libp2p/connmanager.nim b/libp2p/connmanager.nim new file mode 100644 index 000000000..16d463ab6 --- /dev/null +++ b/libp2p/connmanager.nim @@ -0,0 +1,276 @@ +## Nim-LibP2P +## Copyright (c) 2020 Status Research & Development GmbH +## Licensed under either of +## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +## * MIT license ([LICENSE-MIT](LICENSE-MIT)) +## at your option. +## This file may not be copied, modified, or distributed except according to +## those terms. + +import tables, sequtils, sets +import chronos, chronicles, metrics +import peerinfo, + stream/connection, + muxers/muxer + +declareGauge(libp2p_peers, "total connected peers") + +const MaxConnectionsPerPeer = 5 + +type + TooManyConnections* = object of CatchableError + + MuxerHolder = object + muxer: Muxer + handle: Future[void] + + ConnManager* = ref object of RootObj + # NOTE: don't change to PeerInfo here + # the reference semantics on the PeerInfo + # object itself make it succeptible to + # copies and mangling by unrelated code. + conns: Table[PeerID, HashSet[Connection]] + muxed: Table[Connection, MuxerHolder] + cleanUpLock: Table[PeerInfo, AsyncLock] + maxConns: int + +proc newTooManyConnections(): ref TooManyConnections {.inline.} = + result = newException(TooManyConnections, "too many connections for peer") + +proc init*(C: type ConnManager, + maxConnsPerPeer: int = MaxConnectionsPerPeer): ConnManager = + C(maxConns: maxConnsPerPeer, + conns: initTable[PeerID, HashSet[Connection]](), + muxed: initTable[Connection, MuxerHolder]()) + +proc contains*(c: ConnManager, conn: Connection): bool = + ## checks if a connection is being tracked by the + ## connection manager + ## + + if isNil(conn): + return + + if isNil(conn.peerInfo): + return + + if conn.peerInfo.peerId notin c.conns: + return + + return conn in c.conns[conn.peerInfo.peerId] + +proc contains*(c: ConnManager, peerId: PeerID): bool = + peerId in c.conns + +proc contains*(c: ConnManager, muxer: Muxer): bool = + ## checks if a muxer is being tracked by the connection + ## manager + ## + + if isNil(muxer): + return + + let conn = muxer.connection + if conn notin c: + return + + if conn notin c.muxed: + return + + return muxer == c.muxed[conn].muxer + +proc cleanupConn(c: ConnManager, conn: Connection) {.async.} = + ## clean connection's resources such as muxers and streams + ## + + if isNil(conn): + return + + if isNil(conn.peerInfo): + return + + let peerInfo = conn.peerInfo + let lock = c.cleanUpLock.mgetOrPut(peerInfo, newAsyncLock()) + + try: + await lock.acquire() + trace "cleaning up connection for peer", peer = $peerInfo + if conn in c.muxed: + let muxerHolder = c.muxed[conn] + c.muxed.del(conn) + + await muxerHolder.muxer.close() + if not(isNil(muxerHolder.handle)): + await muxerHolder.handle + + if peerInfo.peerId in c.conns: + c.conns[peerInfo.peerId].excl(conn) + + if c.conns[peerInfo.peerId].len == 0: + c.conns.del(peerInfo.peerId) + + if not(conn.peerInfo.isClosed()): + conn.peerInfo.close() + + finally: + await conn.close() + libp2p_peers.set(c.conns.len.int64) + + if lock.locked(): + lock.release() + + trace "connection cleaned up" + +proc onClose(c: ConnManager, conn: Connection) {.async.} = + ## connection close even handler + ## + ## triggers the connections resource cleanup + ## + + await conn.closeEvent.wait() + trace "triggering connection cleanup" + await c.cleanupConn(conn) + +proc selectConn*(c: ConnManager, + peerInfo: PeerInfo, + dir: Direction): Connection = + ## Select a connection for the provided peer and direction + ## + + if isNil(peerInfo): + return + + let conns = toSeq( + c.conns.getOrDefault(peerInfo.peerId)) + .filterIt( it.dir == dir ) + + if conns.len > 0: + return conns[0] + +proc selectConn*(c: ConnManager, peerInfo: PeerInfo): Connection = + ## Select a connection for the provided giving priority + ## to outgoing connections + ## + + if isNil(peerInfo): + return + + var conn = c.selectConn(peerInfo, Direction.Out) + if isNil(conn): + conn = c.selectConn(peerInfo, Direction.In) + + return conn + +proc selectMuxer*(c: ConnManager, conn: Connection): Muxer = + ## select the muxer for the provided connection + ## + + if isNil(conn): + return + + if conn in c.muxed: + return c.muxed[conn].muxer + +proc storeConn*(c: ConnManager, conn: Connection) = + ## store a connection + ## + + if isNil(conn): + raise newException(CatchableError, "connection cannot be nil") + + if isNil(conn.peerInfo): + raise newException(CatchableError, "empty peer info") + + let peerInfo = conn.peerInfo + if c.conns.getOrDefault(peerInfo.peerId).len > c.maxConns: + trace "too many connections", peer = $conn.peerInfo, + conns = c.conns + .getOrDefault(peerInfo.peerId).len + + raise newTooManyConnections() + + if peerInfo.peerId notin c.conns: + c.conns[peerInfo.peerId] = initHashSet[Connection]() + + c.conns[peerInfo.peerId].incl(conn) + + # launch on close listener + asyncCheck c.onClose(conn) + libp2p_peers.set(c.conns.len.int64) + +proc storeOutgoing*(c: ConnManager, conn: Connection) = + conn.dir = Direction.Out + c.storeConn(conn) + +proc storeIncoming*(c: ConnManager, conn: Connection) = + conn.dir = Direction.In + c.storeConn(conn) + +proc storeMuxer*(c: ConnManager, + muxer: Muxer, + handle: Future[void] = nil) = + ## store the connection and muxer + ## + + if isNil(muxer): + raise newException(CatchableError, "muxer cannot be nil") + + if isNil(muxer.connection): + raise newException(CatchableError, "muxer's connection cannot be nil") + + c.muxed[muxer.connection] = MuxerHolder( + muxer: muxer, + handle: handle) + + trace "storred connection", connections = c.conns.len + +proc getMuxedStream*(c: ConnManager, + peerInfo: PeerInfo, + dir: Direction): Future[Connection] {.async, gcsafe.} = + ## get a muxed stream for the provided peer + ## with the given direction + ## + + let muxer = c.selectMuxer(c.selectConn(peerInfo, dir)) + if not(isNil(muxer)): + return await muxer.newStream() + +proc getMuxedStream*(c: ConnManager, + peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} = + ## get a muxed stream for the passed peer from any connection + ## + + let muxer = c.selectMuxer(c.selectConn(peerInfo)) + if not(isNil(muxer)): + return await muxer.newStream() + +proc getMuxedStream*(c: ConnManager, + conn: Connection): Future[Connection] {.async, gcsafe.} = + ## get a muxed stream for the passed connection + ## + + let muxer = c.selectMuxer(conn) + if not(isNil(muxer)): + return await muxer.newStream() + +proc dropPeer*(c: ConnManager, peerInfo: PeerInfo) {.async.} = + ## drop connections and cleanup resources for peer + ## + + for conn in c.conns.getOrDefault(peerInfo.peerId): + if not(isNil(conn)): + await c.cleanupConn(conn) + +proc close*(c: ConnManager) {.async.} = + ## cleanup resources for the connection + ## manager + ## + + for conns in toSeq(c.conns.values): + for conn in conns: + try: + await c.cleanupConn(conn) + except CancelledError as exc: + raise exc + except CatchableError as exc: + warn "error cleaning up connections" diff --git a/libp2p/peerinfo.nim b/libp2p/peerinfo.nim index 8c38d3f54..f86f9e983 100644 --- a/libp2p/peerinfo.nim +++ b/libp2p/peerinfo.nim @@ -7,7 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import options, sequtils +import options, sequtils, hashes import chronos, chronicles import peerid, multiaddress, crypto/crypto @@ -134,18 +134,5 @@ proc publicKey*(p: PeerInfo): Option[PublicKey] {.inline.} = else: result = some(p.privateKey.getKey().tryGet()) -func `==`*(a, b: PeerInfo): bool = - # override equiality to support both nil and peerInfo comparisons - # this in the future will allow us to recycle refs - let - aptr = cast[pointer](a) - bptr = cast[pointer](b) - - if isNil(aptr) and isNil(bptr): - return true - - if isNil(aptr) or isNil(bptr): - return false - - if aptr == bptr and a.peerId == b.peerId: - return true +func hash*(p: PeerInfo): Hash = + cast[pointer](p).hash diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim index a850b68d1..d94eec1d9 100644 --- a/libp2p/protocols/pubsub/pubsub.nim +++ b/libp2p/protocols/pubsub/pubsub.nim @@ -8,14 +8,14 @@ ## those terms. import std/[tables, sequtils, sets] -import chronos, chronicles +import chronos, chronicles, metrics import pubsubpeer, rpc/[message, messages], ../protocol, ../../stream/connection, ../../peerid, - ../../peerinfo -import metrics + ../../peerinfo, + ../../errors export PubSubPeer export PubSubObserver @@ -233,8 +233,11 @@ method subscribe*(p: PubSub, p.topics[topic].handler.add(handler) + var sent: seq[Future[void]] for peer in toSeq(p.peers.values): - await p.sendSubs(peer, @[topic], true) + sent.add(p.sendSubs(peer, @[topic], true)) + + checkFutures(await allFinished(sent)) # metrics libp2p_pubsub_topics.inc() diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 8458f1038..8515a7165 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -48,22 +48,6 @@ func hash*(p: PubSubPeer): Hash = # int is either 32/64, so intptr basically, pubsubpeer is a ref cast[pointer](p).hash -func `==`*(a, b: PubSubPeer): bool = - # override equiality to support both nil and peerInfo comparisons - # this in the future will allow us to recycle refs - let - aptr = cast[pointer](a) - bptr = cast[pointer](b) - - if isNil(aptr) and isNil(bptr): - return true - - if isNil(aptr) or isNil(bptr): - return false - - if aptr == bptr and a.peerInfo == b.peerInfo: - return true - proc id*(p: PubSubPeer): string = p.peerInfo.id proc inUse*(p: PubSubPeer): bool = diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 41c51f6ed..06d289d95 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -289,7 +289,7 @@ method close*(s: BufferStream) {.async, gcsafe.} = try: ## close the stream and clear the buffer if not s.isClosed: - trace "closing bufferstream", oid = s.oid + trace "closing bufferstream", oid = $s.oid s.isEof = true for r in s.readReqs: if not(isNil(r)) and not(r.finished()): diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index cb22e3fc0..60d5be845 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -7,6 +7,7 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. +import hashes import chronos, metrics import lpstream, ../multiaddress, @@ -18,9 +19,13 @@ const ConnectionTrackerName* = "libp2p.connection" type + Direction* {.pure.} = enum + None, In, Out + Connection* = ref object of LPStream peerInfo*: PeerInfo observedAddr*: Multiaddress + dir*: Direction ConnectionTracker* = ref object of TrackerBase opened*: uint64 @@ -50,9 +55,11 @@ proc setupConnectionTracker(): ConnectionTracker = result.isLeaked = leakTransport addTracker(ConnectionTrackerName, result) -proc init*[T: Connection](self: var T, peerInfo: PeerInfo): T = - new self - self.initStream() +proc init*(C: type Connection, + peerInfo: PeerInfo, + dir: Direction): Connection = + result = C(peerInfo: peerInfo, dir: dir) + result.initStream() method initStream*(s: Connection) = if s.objName.len == 0: @@ -63,9 +70,13 @@ method initStream*(s: Connection) = inc getConnectionTracker().opened method close*(s: Connection) {.async.} = - await procCall LPStream(s).close() - inc getConnectionTracker().closed + if not s.isClosed: + await procCall LPStream(s).close() + inc getConnectionTracker().closed proc `$`*(conn: Connection): string = if not isNil(conn.peerInfo): result = conn.peerInfo.id + +func hash*(p: Connection): Hash = + cast[pointer](p).hash diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 68bd47ef1..cea09339d 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -76,15 +76,6 @@ method initStream*(s: LPStream) {.base.} = libp2p_open_streams.inc(labelValues = [s.objName]) trace "stream created", oid = $s.oid, name = s.objName - # TODO: debuging aid to troubleshoot streams open/close - # try: - # echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"]) - # echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"]) - # # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >= - # # libp2p_open_streams.value(labelValues = ["SecureConn"])) - # except CatchableError: - # discard - proc join*(s: LPStream): Future[void] = s.closeEvent.wait() @@ -207,12 +198,3 @@ method close*(s: LPStream) {.base, async.} = s.closeEvent.fire() libp2p_open_streams.dec(labelValues = [s.objName]) trace "stream destroyed", oid = $s.oid, name = s.objName - - # TODO: debuging aid to troubleshoot streams open/close - # try: - # echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"]) - # echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"]) - # # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >= - # # libp2p_open_streams.value(labelValues = ["SecureConn"])) - # except CatchableError: - # discard diff --git a/libp2p/switch.nim b/libp2p/switch.nim index d9e047348..a45ce449e 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -11,7 +11,6 @@ import tables, sequtils, options, sets, - algorithm, oids import chronos, @@ -28,6 +27,7 @@ import stream/connection, protocols/identify, protocols/pubsub/pubsub, muxers/muxer, + connmanager, peerid logScope: @@ -39,33 +39,16 @@ logScope: # and only if the channel has been secured (i.e. if a secure manager has been # previously provided) -declareGauge(libp2p_peers, "total connected peers") declareCounter(libp2p_dialed_peers, "dialed peers") declareCounter(libp2p_failed_dials, "failed dials") declareCounter(libp2p_failed_upgrade, "peers failed upgrade") -const MaxConnectionsPerPeer = 5 - type NoPubSubException* = object of CatchableError - TooManyConnections* = object of CatchableError - - Direction {.pure.} = enum - In, Out - - ConnectionHolder = object - dir: Direction - conn: Connection - - MuxerHolder = object - dir: Direction - muxer: Muxer - handle: Future[void] Switch* = ref object of RootObj peerInfo*: PeerInfo - connections*: Table[string, seq[ConnectionHolder]] - muxed*: Table[string, seq[MuxerHolder]] + connManager: ConnManager transports*: seq[Transport] protocols*: seq[LPProtocol] muxers*: Table[string, MuxerProvider] @@ -75,90 +58,20 @@ type secureManagers*: seq[Secure] pubSub*: Option[PubSub] dialLock: Table[string, AsyncLock] - cleanUpLock: Table[string, AsyncLock] proc newNoPubSubException(): ref NoPubSubException {.inline.} = result = newException(NoPubSubException, "no pubsub provided!") -proc newTooManyConnections(): ref TooManyConnections {.inline.} = - result = newException(TooManyConnections, "too many connections for peer") - proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} -proc selectConn(s: Switch, peerInfo: PeerInfo): Connection = - ## select the "best" connection according to some criteria - ## - ## Ideally when the connection's stats are available - ## we'd select the fastest, but for now we simply pick an outgoing - ## connection first if none is available, we pick the first outgoing - ## +proc cleanupPubSubPeer(s: Switch, conn: Connection) {.async.} = + await conn.closeEvent.wait() + if s.pubSub.isSome: + await s.pubSub.get().unsubscribePeer(conn.peerInfo) - if isNil(peerInfo): - return - - let conns = s.connections - .getOrDefault(peerInfo.id) - # it should be OK to sort on each - # access as there should only be - # up to MaxConnectionsPerPeer entries - .sorted( - proc(a, b: ConnectionHolder): int = - if a.dir < b.dir: -1 - elif a.dir == b.dir: 0 - else: 1 - , SortOrder.Descending) - - if conns.len > 0: - return conns[0].conn - -proc selectMuxer(s: Switch, conn: Connection): Muxer = - ## select the muxer for the supplied connection - ## - - if isNil(conn): - return - - if not(isNil(conn.peerInfo)) and conn.peerInfo.id in s.muxed: - if s.muxed[conn.peerInfo.id].len > 0: - let muxers = s.muxed[conn.peerInfo.id] - .filterIt( it.muxer.connection == conn ) - if muxers.len > 0: - return muxers[0].muxer - -proc storeConn(s: Switch, - muxer: Muxer, - dir: Direction, - handle: Future[void] = nil) {.async.} = - ## store the connection and muxer - ## - if isNil(muxer): - return - - let conn = muxer.connection - if isNil(conn): - return - - let id = conn.peerInfo.id - if s.connections.getOrDefault(id).len > MaxConnectionsPerPeer: - warn "disconnecting peer, too many connections", peer = $conn.peerInfo, - conns = s.connections - .getOrDefault(id).len - await s.disconnect(conn.peerInfo) - raise newTooManyConnections() - - s.connections.mgetOrPut( - id, - newSeq[ConnectionHolder]()) - .add(ConnectionHolder(conn: conn, dir: dir)) - - s.muxed.mgetOrPut( - muxer.connection.peerInfo.id, - newSeq[MuxerHolder]()) - .add(MuxerHolder(muxer: muxer, handle: handle, dir: dir)) - - trace "storred connection", connections = s.connections.len - libp2p_peers.set(s.connections.len.int64) +proc isConnected*(s: Switch, peer: PeerInfo): bool = + peer.peerId in s.connManager proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = if s.secureManagers.len <= 0: @@ -170,9 +83,11 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = trace "securing connection", codec = manager let secureProtocol = s.secureManagers.filterIt(it.codec == manager) + # ms.select should deal with the correctness of this # let's avoid duplicating checks but detect if it fails to do it properly doAssert(secureProtocol.len > 0) + result = await secureProtocol[0].secure(conn, true) proc identify(s: Switch, conn: Connection) {.async, gcsafe.} = @@ -218,6 +133,7 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} = # create new muxer for connection let muxer = s.muxers[muxerName].newMuxer(conn) + s.connManager.storeMuxer(muxer) trace "found a muxer", name = muxerName, peer = $conn @@ -247,75 +163,10 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} = # store it in muxed connections if we have a peer for it trace "adding muxer for peer", peer = conn.peerInfo.id - await s.storeConn(muxer, Direction.Out, handlerFut) - -proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = - if isNil(conn): - return - - if isNil(conn.peerInfo): - return - - let id = conn.peerInfo.id - let lock = s.cleanUpLock.mgetOrPut(id, newAsyncLock()) - - try: - await lock.acquire() - trace "cleaning up connection for peer", peerId = id - if id in s.muxed: - let muxerHolder = s.muxed[id] - .filterIt( - it.muxer.connection == conn - ) - - if muxerHolder.len > 0: - await muxerHolder[0].muxer.close() - if not(isNil(muxerHolder[0].handle)): - await muxerHolder[0].handle - - if id in s.muxed: - s.muxed[id].keepItIf( - it.muxer.connection != conn - ) - - if s.muxed[id].len == 0: - s.muxed.del(id) - - if s.pubSub.isSome: - await s.pubSub.get() - .unsubscribePeer(conn.peerInfo) - - if id in s.connections: - s.connections[id].keepItIf( - it.conn != conn - ) - - if s.connections[id].len == 0: - s.connections.del(id) - - # TODO: Investigate cleanupConn() always called twice for one peer. - if not(conn.peerInfo.isClosed()): - conn.peerInfo.close() - finally: - await conn.close() - libp2p_peers.set(s.connections.len.int64) - - if lock.locked(): - lock.release() + s.connManager.storeMuxer(muxer, handlerFut) # update muxer with handler proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = - let connections = s.connections.getOrDefault(peer.id) - for connHolder in connections: - if not isNil(connHolder.conn): - await s.cleanupConn(connHolder.conn) - -proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} = - # if there is a muxer for the connection - # use it instead to create a muxed stream - - let muxer = s.selectMuxer(s.selectConn(peerInfo)) # always get the first muxer here - if not(isNil(muxer)): - return await muxer.newStream() + await s.connManager.dropPeer(peer) proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = logScope: @@ -388,52 +239,51 @@ proc internalConnect(s: Switch, var conn: Connection let lock = s.dialLock.mgetOrPut(id, newAsyncLock()) - defer: + try: + await lock.acquire() + trace "about to dial peer", peer = id + conn = s.connManager.selectConn(peer) + if conn.isNil or (conn.closed or conn.atEof): + trace "Dialing peer", peer = id + for t in s.transports: # for each transport + for a in peer.addrs: # for each address + if t.handles(a): # check if it can dial it + trace "Dialing address", address = $a, peer = id + try: + conn = await t.dial(a) + # make sure to assign the peer to the connection + conn.peerInfo = peer + + libp2p_dialed_peers.inc() + except CancelledError as exc: + trace "dialing canceled", exc = exc.msg + raise + except CatchableError as exc: + trace "dialing failed", exc = exc.msg + libp2p_failed_dials.inc() + continue + + try: + let uconn = await s.upgradeOutgoing(conn) + s.connManager.storeOutgoing(uconn) + conn = uconn + except CatchableError as exc: + if not(isNil(conn)): + await conn.close() + + trace "Unable to establish outgoing link", exc = exc.msg + raise exc + + if isNil(conn): + libp2p_failed_upgrade.inc() + continue + break + else: + trace "Reusing existing connection", oid = conn.oid + finally: if lock.locked(): lock.release() - await lock.acquire() - trace "about to dial peer", peer = id - conn = s.selectConn(peer) - if conn.isNil or conn.closed: - trace "Dialing peer", peer = id - for t in s.transports: # for each transport - for a in peer.addrs: # for each address - if t.handles(a): # check if it can dial it - trace "Dialing address", address = $a - try: - conn = await t.dial(a) - libp2p_dialed_peers.inc() - except CancelledError as exc: - trace "dialing canceled", exc = exc.msg - raise - except CatchableError as exc: - trace "dialing failed", exc = exc.msg - libp2p_failed_dials.inc() - continue - - # make sure to assign the peer to the connection - conn.peerInfo = peer - try: - conn = await s.upgradeOutgoing(conn) - except CatchableError as exc: - if not(isNil(conn)): - await conn.close() - - trace "Unable to establish outgoing link", exc = exc.msg - raise exc - - if isNil(conn): - libp2p_failed_upgrade.inc() - continue - - conn.closeEvent.wait() - .addCallback do(udata: pointer): - asyncCheck s.cleanupConn(conn) - break - else: - trace "Reusing existing connection", oid = conn.oid - if isNil(conn): raise newException(CatchableError, "Unable to establish outgoing link") @@ -443,13 +293,14 @@ proc internalConnect(s: Switch, raise newException(CatchableError, "Connection dead on arrival") - doAssert(conn.peerInfo.id in s.connections, - "connection not tracked!") + doAssert(conn in s.connManager, "connection not tracked!") trace "dial succesfull", oid = $conn.oid, peer = $conn.peerInfo await s.subscribePeer(peer) + asyncCheck s.cleanupPubSubPeer(conn) + return conn proc connect*(s: Switch, peer: PeerInfo) {.async.} = @@ -460,7 +311,7 @@ proc dial*(s: Switch, proto: string): Future[Connection] {.async.} = let conn = await s.internalConnect(peer) - let stream = await s.getMuxedStream(peer) + let stream = await s.connManager.getMuxedStream(conn) proc cleanup() {.async.} = if not(isNil(stream)): @@ -505,14 +356,14 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = try: - defer: - await s.cleanupConn(conn) - + conn.dir = Direction.In # tag connection with direction await s.upgradeIncoming(conn) # perform upgrade on incoming connection except CancelledError as exc: raise exc except CatchableError as exc: trace "Exception occurred in Switch.start", exc = exc.msg + finally: + await conn.close() var startFuts: seq[Future[void]] for t in s.transports: # for each transport @@ -537,14 +388,8 @@ proc stop*(s: Switch) {.async.} = if s.pubSub.isSome: await s.pubSub.get().stop() - for conns in toSeq(s.connections.values): - for conn in conns: - try: - await s.cleanupConn(conn.conn) - except CancelledError as exc: - raise exc - except CatchableError as exc: - warn "error cleaning up connections" + # close and cleanup all connections + await s.connManager.close() for t in s.transports: try: @@ -562,7 +407,18 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog() var stream: Connection try: - stream = await s.getMuxedStream(peerInfo) + stream = await s.connManager.getMuxedStream(peerInfo) + if isNil(stream): + trace "unable to subscribe to peer", peer = peerInfo.shortLog + return + + if not await s.ms.select(stream, s.pubSub.get().codec): + if not(isNil(stream)): + await stream.close() + return + + s.pubSub.get().subscribePeer(stream) + except CancelledError as exc: if not(isNil(stream)): await stream.close() @@ -574,44 +430,27 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = if not(isNil(stream)): await stream.close() - if isNil(stream): - trace "unable to subscribe to peer", peer = peerInfo.shortLog - return - - if not await s.ms.select(stream, s.pubSub.get().codec): - if not(isNil(stream)): - await stream.close() - return - - s.pubSub.get().subscribePeer(stream) - proc subscribe*(s: Switch, topic: string, - handler: TopicHandler): Future[void] = + handler: TopicHandler) {.async.} = ## subscribe to a pubsub topic if s.pubSub.isNone: - var retFuture = newFuture[void]("Switch.subscribe") - retFuture.fail(newNoPubSubException()) - return retFuture + raise newNoPubSubException() - return s.pubSub.get().subscribe(topic, handler) + await s.pubSub.get().subscribe(topic, handler) -proc unsubscribe*(s: Switch, topics: seq[TopicPair]): Future[void] = +proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} = ## unsubscribe from topics if s.pubSub.isNone: - var retFuture = newFuture[void]("Switch.unsubscribe") - retFuture.fail(newNoPubSubException()) - return retFuture + raise newNoPubSubException() - return s.pubSub.get().unsubscribe(topics) + await s.pubSub.get().unsubscribe(topics) -proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] = +proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.} = # pubslish to pubsub topic if s.pubSub.isNone: - var retFuture = newFuture[int]("Switch.publish") - retFuture.fail(newNoPubSubException()) - return retFuture + raise newNoPubSubException() - return s.pubSub.get().publish(topic, data) + return await s.pubSub.get().publish(topic, data) proc addValidator*(s: Switch, topics: varargs[string], @@ -647,17 +486,17 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} = muxer.connection.peerInfo = stream.peerInfo - # store muxer and muxed connection - await s.storeConn(muxer, Direction.In) + # store incoming connection + s.connManager.storeIncoming(muxer.connection) - muxer.connection.closeEvent.wait() - .addCallback do(udata: pointer): - asyncCheck s.cleanupConn(muxer.connection) + # store muxer and muxed connection + s.connManager.storeMuxer(muxer) trace "got new muxer", peer = $muxer.connection.peerInfo # try establishing a pubsub connection await s.subscribePeer(muxer.connection.peerInfo) + asyncCheck s.cleanupPubSubPeer(muxer.connection) except CancelledError as exc: await muxer.close() @@ -680,8 +519,7 @@ proc newSwitch*(peerInfo: PeerInfo, peerInfo: peerInfo, ms: newMultistream(), transports: transports, - connections: initTable[string, seq[ConnectionHolder]](), - muxed: initTable[string, seq[MuxerHolder]](), + connManager: ConnManager.init(), identity: identity, muxers: muxers, secureManagers: @secureManagers, diff --git a/tests/testconnmngr.nim b/tests/testconnmngr.nim new file mode 100644 index 000000000..568bc9187 --- /dev/null +++ b/tests/testconnmngr.nim @@ -0,0 +1,192 @@ +import unittest +import chronos +import ../libp2p/[connmanager, + stream/connection, + crypto/crypto, + muxers/muxer, + peerinfo] + +import helpers + +type + TestMuxer = ref object of Muxer + peerInfo: PeerInfo + +method newStream*( + m: TestMuxer, + name: string = "", + lazy: bool = false): + Future[Connection] {.async, gcsafe.} = + result = Connection.init(m.peerInfo, Direction.Out) + +suite "Connection Manager": + teardown: + for tracker in testTrackers(): + # echo tracker.dump() + check tracker.isLeaked() == false + + test "add and retrive a connection": + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + + connMngr.storeConn(conn) + check conn in connMngr + + let peerConn = connMngr.selectConn(peer) + check peerConn == conn + check peerConn.dir == Direction.In + + test "add and retrieve a muxer": + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + let muxer = new Muxer + muxer.connection = conn + + connMngr.storeConn(conn) + connMngr.storeMuxer(muxer) + check muxer in connMngr + + let peerMuxer = connMngr.selectMuxer(conn) + check peerMuxer == muxer + + test "get conn with direction": + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn1 = Connection.init(peer, Direction.Out) + let conn2 = Connection.init(peer, Direction.In) + + connMngr.storeConn(conn1) + connMngr.storeConn(conn2) + check conn1 in connMngr + check conn2 in connMngr + + let outConn = connMngr.selectConn(peer, Direction.Out) + let inConn = connMngr.selectConn(peer, Direction.In) + + check outConn != inConn + check outConn.dir == Direction.Out + check inConn.dir == Direction.In + + test "get muxed stream for peer": + proc test() {.async.} = + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + + let muxer = new TestMuxer + muxer.peerInfo = peer + muxer.connection = conn + + connMngr.storeConn(conn) + connMngr.storeMuxer(muxer) + check muxer in connMngr + + let stream = await connMngr.getMuxedStream(peer) + check not(isNil(stream)) + check stream.peerInfo == peer + + waitFor(test()) + + test "get stream from directed connection": + proc test() {.async.} = + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + + let muxer = new TestMuxer + muxer.peerInfo = peer + muxer.connection = conn + + connMngr.storeConn(conn) + connMngr.storeMuxer(muxer) + check muxer in connMngr + + check not(isNil((await connMngr.getMuxedStream(peer, Direction.In)))) + check isNil((await connMngr.getMuxedStream(peer, Direction.Out))) + + waitFor(test()) + + test "get stream from any connection": + proc test() {.async.} = + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + + let muxer = new TestMuxer + muxer.peerInfo = peer + muxer.connection = conn + + connMngr.storeConn(conn) + connMngr.storeMuxer(muxer) + check muxer in connMngr + + check not(isNil((await connMngr.getMuxedStream(conn)))) + + waitFor(test()) + + test "should raise on too many connections": + proc test() = + let connMngr = ConnManager.init(1) + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + + connMngr.storeConn(Connection.init(peer, Direction.In)) + connMngr.storeConn(Connection.init(peer, Direction.In)) + connMngr.storeConn(Connection.init(peer, Direction.In)) + + expect TooManyConnections: + test() + + test "cleanup on connection close": + proc test() {.async.} = + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + let conn = Connection.init(peer, Direction.In) + let muxer = new Muxer + muxer.connection = conn + + connMngr.storeConn(conn) + connMngr.storeMuxer(muxer) + + check conn in connMngr + check muxer in connMngr + + await conn.close() + await sleepAsync(10.millis) + + check conn notin connMngr + check muxer notin connMngr + + waitFor(test()) + + test "drop connections for peer": + proc test() {.async.} = + let connMngr = ConnManager.init() + let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet()) + + for i in 0..<2: + let dir = if i mod 2 == 0: + Direction.In else: + Direction.Out + + let conn = Connection.init(peer, dir) + let muxer = new Muxer + muxer.connection = conn + + connMngr.storeConn(conn) + connMngr.storeMuxer(muxer) + + check conn in connMngr + check muxer in connMngr + check not(isNil(connMngr.selectConn(peer, dir))) + + check peer in connMngr.peers + await connMngr.dropPeer(peer) + + check peer notin connMngr.peers + check isNil(connMngr.selectConn(peer, Direction.In)) + check isNil(connMngr.selectConn(peer, Direction.Out)) + check connMngr.peers.len == 0 + + waitFor(test()) diff --git a/tests/testswitch.nim b/tests/testswitch.nim index a486b4034..030c12274 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -1,6 +1,6 @@ {.used.} -import unittest, tables +import unittest import chronos import stew/byteutils import nimcrypto/sysrand @@ -56,6 +56,10 @@ suite "Switch": awaiters.add(await switch2.start()) let conn = await switch2.dial(switch1.peerInfo, TestCodec) + + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + await conn.writeLp("Hello!") let msg = string.fromBytes(await conn.readLp(1024)) check "Hello!" == msg @@ -69,6 +73,9 @@ suite "Switch": # this needs to go at end await allFuturesThrowing(awaiters) + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + waitFor(testSwitch()) test "e2e should not leak bufferstreams and connections on channel close": @@ -96,6 +103,10 @@ suite "Switch": awaiters.add(await switch2.start()) let conn = await switch2.dial(switch1.peerInfo, TestCodec) + + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + await conn.writeLp("Hello!") let msg = string.fromBytes(await conn.readLp(1024)) check "Hello!" == msg @@ -103,20 +114,20 @@ suite "Switch": await sleepAsync(2.seconds) # wait a little for cleanup to happen var bufferTracker = getTracker(BufferStreamTrackerName) - # echo bufferTracker.dump() + echo bufferTracker.dump() # plus 4 for the pubsub streams check (BufferStreamTracker(bufferTracker).opened == (BufferStreamTracker(bufferTracker).closed + 4.uint64)) - # var connTracker = getTracker(ConnectionTrackerName) - # echo connTracker.dump() + var connTracker = getTracker(ConnectionTrackerName) + echo connTracker.dump() # plus 8 is for the secured connection and the socket # and the pubsub streams that won't clean up until # `disconnect()` or `stop()` - # check (ConnectionTracker(connTracker).opened == - # (ConnectionTracker(connTracker).closed + 8.uint64)) + check (ConnectionTracker(connTracker).opened == + (ConnectionTracker(connTracker).closed + 8.uint64)) await allFuturesThrowing( done.wait(5.seconds), @@ -127,6 +138,9 @@ suite "Switch": # this needs to go at end await allFuturesThrowing(awaiters) + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + waitFor(testSwitch()) test "e2e use connect then dial": @@ -153,10 +167,11 @@ suite "Switch": awaiters.add(await switch2.start()) await switch2.connect(switch1.peerInfo) - check switch1.peerInfo.id in switch2.connections - let conn = await switch2.dial(switch1.peerInfo, TestCodec) + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + try: await conn.writeLp("Hello!") let msg = string.fromBytes(await conn.readLp(1024)) @@ -172,6 +187,9 @@ suite "Switch": ) await allFuturesThrowing(awaiters) + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + check: waitFor(testSwitch()) == true @@ -186,23 +204,23 @@ suite "Switch": await switch2.connect(switch1.peerInfo) - check switch1.connections[switch2.peerInfo.id].len > 0 - check switch2.connections[switch1.peerInfo.id].len > 0 + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) await sleepAsync(100.millis) await switch2.disconnect(switch1.peerInfo) - await sleepAsync(2.seconds) + + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + var bufferTracker = getTracker(BufferStreamTrackerName) # echo bufferTracker.dump() check bufferTracker.isLeaked() == false - # var connTracker = getTracker(ConnectionTrackerName) + var connTracker = getTracker(ConnectionTrackerName) # echo connTracker.dump() - # check connTracker.isLeaked() == false - - check switch2.peerInfo.id notin switch1.connections - check switch1.peerInfo.id notin switch2.connections + check connTracker.isLeaked() == false await allFuturesThrowing( switch1.stop(), @@ -210,47 +228,3 @@ suite "Switch": await allFuturesThrowing(awaiters) waitFor(testSwitch()) - - # test "e2e: handle read + secio fragmented": - # proc testListenerDialer(): Future[bool] {.async.} = - # let - # server: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") - # serverInfo = PeerInfo.init(PrivateKey.random(ECDSA), [server]) - # serverNoise = newSecio(serverInfo.privateKey) - # readTask = newFuture[void]() - - # var hugePayload = newSeq[byte](0x1200000) - # check randomBytes(hugePayload) == hugePayload.len - # trace "Sending huge payload", size = hugePayload.len - - # proc connHandler(conn: Connection) {.async, gcsafe.} = - # let sconn = await serverNoise.secure(conn) - # defer: - # await sconn.close() - # let msg = await sconn.read(0x1200000) - # check msg == hugePayload - # readTask.complete() - - # let - # transport1: TcpTransport = TcpTransport.init() - # asyncCheck await transport1.listen(server, connHandler) - - # let - # transport2: TcpTransport = TcpTransport.init() - # clientInfo = PeerInfo.init(PrivateKey.random(ECDSA), [transport1.ma]) - # clientNoise = newSecio(clientInfo.privateKey) - # conn = await transport2.dial(transport1.ma) - # sconn = await clientNoise.secure(conn) - - # await sconn.write(hugePayload) - # await readTask - - # await sconn.close() - # await conn.close() - # await transport2.close() - # await transport1.close() - - # result = true - - # check: - # waitFor(testListenerDialer()) == true