mirror of https://github.com/vacp2p/nim-libp2p.git

Connection manager (#277)

* splitting out connection management
* wip
* wip conn mngr tests
* set peerinfo in constructor
* comments and documentation
* tests
* wip
* add `None` to detect untagged connections
* use `PeerID` to index connections
* fix tests
* remove useless equality

This commit is contained in:
parent 170685f9c6
commit 0348773ec9
@@ -0,0 +1,276 @@ (new file: libp2p/connmanager.nim)

## Nim-LibP2P
## Copyright (c) 2020 Status Research & Development GmbH
## Licensed under either of
##  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
##  * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.

import tables, sequtils, sets
import chronos, chronicles, metrics
import peerinfo,
       stream/connection,
       muxers/muxer

declareGauge(libp2p_peers, "total connected peers")

const MaxConnectionsPerPeer = 5

type
  TooManyConnections* = object of CatchableError

  MuxerHolder = object
    muxer: Muxer
    handle: Future[void]

  ConnManager* = ref object of RootObj
    # NOTE: don't change to PeerInfo here
    # the reference semantics on the PeerInfo
    # object itself make it susceptible to
    # copies and mangling by unrelated code.
    conns: Table[PeerID, HashSet[Connection]]
    muxed: Table[Connection, MuxerHolder]
    cleanUpLock: Table[PeerInfo, AsyncLock]
    maxConns: int

proc newTooManyConnections(): ref TooManyConnections {.inline.} =
  result = newException(TooManyConnections, "too many connections for peer")

proc init*(C: type ConnManager,
           maxConnsPerPeer: int = MaxConnectionsPerPeer): ConnManager =
  C(maxConns: maxConnsPerPeer,
    conns: initTable[PeerID, HashSet[Connection]](),
    muxed: initTable[Connection, MuxerHolder]())

proc contains*(c: ConnManager, conn: Connection): bool =
  ## checks if a connection is being tracked by the
  ## connection manager
  ##

  if isNil(conn):
    return

  if isNil(conn.peerInfo):
    return

  if conn.peerInfo.peerId notin c.conns:
    return

  return conn in c.conns[conn.peerInfo.peerId]

proc contains*(c: ConnManager, peerId: PeerID): bool =
  peerId in c.conns

proc contains*(c: ConnManager, muxer: Muxer): bool =
  ## checks if a muxer is being tracked by the connection
  ## manager
  ##

  if isNil(muxer):
    return

  let conn = muxer.connection
  if conn notin c:
    return

  if conn notin c.muxed:
    return

  return muxer == c.muxed[conn].muxer

proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
  ## clean up the connection's resources, such as muxers and streams
  ##

  if isNil(conn):
    return

  if isNil(conn.peerInfo):
    return

  let peerInfo = conn.peerInfo
  let lock = c.cleanUpLock.mgetOrPut(peerInfo, newAsyncLock())

  try:
    await lock.acquire()
    trace "cleaning up connection for peer", peer = $peerInfo
    if conn in c.muxed:
      let muxerHolder = c.muxed[conn]
      c.muxed.del(conn)

      await muxerHolder.muxer.close()
      if not(isNil(muxerHolder.handle)):
        await muxerHolder.handle

    if peerInfo.peerId in c.conns:
      c.conns[peerInfo.peerId].excl(conn)

      if c.conns[peerInfo.peerId].len == 0:
        c.conns.del(peerInfo.peerId)

    if not(conn.peerInfo.isClosed()):
      conn.peerInfo.close()

  finally:
    await conn.close()
    libp2p_peers.set(c.conns.len.int64)

    if lock.locked():
      lock.release()

  trace "connection cleaned up"

proc onClose(c: ConnManager, conn: Connection) {.async.} =
  ## connection close event handler
  ##
  ## triggers the connection's resource cleanup
  ##

  await conn.closeEvent.wait()
  trace "triggering connection cleanup"
  await c.cleanupConn(conn)

proc selectConn*(c: ConnManager,
                 peerInfo: PeerInfo,
                 dir: Direction): Connection =
  ## Select a connection for the provided peer and direction
  ##

  if isNil(peerInfo):
    return

  let conns = toSeq(
    c.conns.getOrDefault(peerInfo.peerId))
    .filterIt( it.dir == dir )

  if conns.len > 0:
    return conns[0]

proc selectConn*(c: ConnManager, peerInfo: PeerInfo): Connection =
  ## Select a connection for the provided peer, giving priority
  ## to outgoing connections
  ##

  if isNil(peerInfo):
    return

  var conn = c.selectConn(peerInfo, Direction.Out)
  if isNil(conn):
    conn = c.selectConn(peerInfo, Direction.In)

  return conn

proc selectMuxer*(c: ConnManager, conn: Connection): Muxer =
  ## select the muxer for the provided connection
  ##

  if isNil(conn):
    return

  if conn in c.muxed:
    return c.muxed[conn].muxer

proc storeConn*(c: ConnManager, conn: Connection) =
  ## store a connection
  ##

  if isNil(conn):
    raise newException(CatchableError, "connection cannot be nil")

  if isNil(conn.peerInfo):
    raise newException(CatchableError, "empty peer info")

  let peerInfo = conn.peerInfo
  if c.conns.getOrDefault(peerInfo.peerId).len > c.maxConns:
    trace "too many connections", peer = $conn.peerInfo,
                                  conns = c.conns
                                  .getOrDefault(peerInfo.peerId).len

    raise newTooManyConnections()

  if peerInfo.peerId notin c.conns:
    c.conns[peerInfo.peerId] = initHashSet[Connection]()

  c.conns[peerInfo.peerId].incl(conn)

  # launch on close listener
  asyncCheck c.onClose(conn)
  libp2p_peers.set(c.conns.len.int64)

proc storeOutgoing*(c: ConnManager, conn: Connection) =
  conn.dir = Direction.Out
  c.storeConn(conn)

proc storeIncoming*(c: ConnManager, conn: Connection) =
  conn.dir = Direction.In
  c.storeConn(conn)

proc storeMuxer*(c: ConnManager,
                 muxer: Muxer,
                 handle: Future[void] = nil) =
  ## store the muxer and its handler future
  ##

  if isNil(muxer):
    raise newException(CatchableError, "muxer cannot be nil")

  if isNil(muxer.connection):
    raise newException(CatchableError, "muxer's connection cannot be nil")

  c.muxed[muxer.connection] = MuxerHolder(
    muxer: muxer,
    handle: handle)

  trace "stored muxer", connections = c.conns.len

proc getMuxedStream*(c: ConnManager,
                     peerInfo: PeerInfo,
                     dir: Direction): Future[Connection] {.async, gcsafe.} =
  ## get a muxed stream for the provided peer
  ## with the given direction
  ##

  let muxer = c.selectMuxer(c.selectConn(peerInfo, dir))
  if not(isNil(muxer)):
    return await muxer.newStream()

proc getMuxedStream*(c: ConnManager,
                     peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} =
  ## get a muxed stream for the passed peer from any connection
  ##

  let muxer = c.selectMuxer(c.selectConn(peerInfo))
  if not(isNil(muxer)):
    return await muxer.newStream()

proc getMuxedStream*(c: ConnManager,
                     conn: Connection): Future[Connection] {.async, gcsafe.} =
  ## get a muxed stream for the passed connection
  ##

  let muxer = c.selectMuxer(conn)
  if not(isNil(muxer)):
    return await muxer.newStream()

proc dropPeer*(c: ConnManager, peerInfo: PeerInfo) {.async.} =
  ## drop connections and clean up resources for peer
  ##

  for conn in c.conns.getOrDefault(peerInfo.peerId):
    if not(isNil(conn)):
      await c.cleanupConn(conn)

proc close*(c: ConnManager) {.async.} =
  ## clean up resources for the connection
  ## manager
  ##

  for conns in toSeq(c.conns.values):
    for conn in conns:
      try:
        await c.cleanupConn(conn)
      except CancelledError as exc:
        raise exc
      except CatchableError as exc:
        warn "error cleaning up connections"
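For orientation, a minimal usage sketch of the new API (not part of the diff; the peer is constructed the same way as in the tests added later in this commit):

    import chronos
    import libp2p/[connmanager, peerinfo, crypto/crypto, stream/connection]

    proc example() {.async.} =
      let connMngr = ConnManager.init()   # default cap: 5 connections per peer
      let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())

      # tag the connection as incoming and start tracking it; storeConn
      # launches an on-close listener that cleans up automatically
      connMngr.storeIncoming(Connection.init(peer, Direction.In))

      # lookups are indexed by PeerID and can filter by direction
      let conn = connMngr.selectConn(peer, Direction.In)
      assert not(isNil(conn))

      await connMngr.dropPeer(peer)       # drop and clean up everything for the peer

    waitFor(example())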
@@ -7,7 +7,7 @@ (libp2p/peerinfo.nim)
 ## This file may not be copied, modified, or distributed except according to
 ## those terms.

-import options, sequtils
+import options, sequtils, hashes
 import chronos, chronicles
 import peerid, multiaddress, crypto/crypto

@@ -134,18 +134,5 @@ proc publicKey*(p: PeerInfo): Option[PublicKey] {.inline.} =
   else:
     result = some(p.privateKey.getKey().tryGet())

-func `==`*(a, b: PeerInfo): bool =
-  # override equality to support both nil and peerInfo comparisons
-  # this in the future will allow us to recycle refs
-  let
-    aptr = cast[pointer](a)
-    bptr = cast[pointer](b)
-
-  if isNil(aptr) and isNil(bptr):
-    return true
-
-  if isNil(aptr) or isNil(bptr):
-    return false
-
-  if aptr == bptr and a.peerId == b.peerId:
-    return true
+func hash*(p: PeerInfo): Hash =
+  cast[pointer](p).hash
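The custom `==` is dropped ("remove useless equality" in the commit message) and an identity-based `hash` takes its place, which is what lets `PeerInfo` refs key tables such as `ConnManager.cleanUpLock`. An illustrative standalone sketch of the idea, using a hypothetical `Thing` type:

    import hashes, tables

    type Thing = ref object
      name: string

    func hash(t: Thing): Hash =
      # hash the ref's address: two refs collide only if they are
      # the same object (same trick as PeerInfo, PubSubPeer, Connection)
      cast[pointer](t).hash

    var lookup = initTable[Thing, int]()
    let a = Thing(name: "x")
    let b = Thing(name: "x")   # same contents, different identity
    lookup[a] = 1
    assert b notin lookup      # identity hashing, not structural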
@@ -8,14 +8,14 @@ (libp2p/protocols/pubsub/pubsub.nim)
 ## those terms.

 import std/[tables, sequtils, sets]
-import chronos, chronicles
+import chronos, chronicles, metrics
 import pubsubpeer,
        rpc/[message, messages],
        ../protocol,
        ../../stream/connection,
        ../../peerid,
-       ../../peerinfo
-import metrics
+       ../../peerinfo,
+       ../../errors

 export PubSubPeer
 export PubSubObserver
@@ -233,8 +233,11 @@ method subscribe*(p: PubSub,
   p.topics[topic].handler.add(handler)

+  var sent: seq[Future[void]]
   for peer in toSeq(p.peers.values):
-    await p.sendSubs(peer, @[topic], true)
+    sent.add(p.sendSubs(peer, @[topic], true))
+
+  checkFutures(await allFinished(sent))

   # metrics
   libp2p_pubsub_topics.inc()
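The subscription announcements now go out concurrently instead of one peer at a time. A sketch of the pattern with Chronos (the `send` proc is hypothetical; `checkFutures` comes from the newly imported ../../errors module and is inlined here for illustration):

    import chronos

    proc send(i: int): Future[void] {.async.} =
      await sleepAsync(10.millis)

    proc broadcast() {.async.} =
      # fire off all sends without awaiting each one in turn
      var sent: seq[Future[void]]
      for i in 0 ..< 5:
        sent.add(send(i))

      # allFinished waits for every future, failed or not; then
      # surface any stored errors, as libp2p's checkFutures does
      for fut in await allFinished(sent):
        if fut.failed:
          echo "send failed: ", fut.readError().msg

    waitFor(broadcast())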
@@ -48,22 +48,6 @@ func hash*(p: PubSubPeer): Hash = (libp2p/protocols/pubsub/pubsubpeer.nim)
   # int is either 32/64, so intptr basically, pubsubpeer is a ref
   cast[pointer](p).hash

-func `==`*(a, b: PubSubPeer): bool =
-  # override equality to support both nil and peerInfo comparisons
-  # this in the future will allow us to recycle refs
-  let
-    aptr = cast[pointer](a)
-    bptr = cast[pointer](b)
-
-  if isNil(aptr) and isNil(bptr):
-    return true
-
-  if isNil(aptr) or isNil(bptr):
-    return false
-
-  if aptr == bptr and a.peerInfo == b.peerInfo:
-    return true

 proc id*(p: PubSubPeer): string = p.peerInfo.id

 proc inUse*(p: PubSubPeer): bool =
@@ -289,7 +289,7 @@ method close*(s: BufferStream) {.async, gcsafe.} = (libp2p/stream/bufferstream.nim)
   try:
     ## close the stream and clear the buffer
     if not s.isClosed:
-      trace "closing bufferstream", oid = s.oid
+      trace "closing bufferstream", oid = $s.oid
       s.isEof = true
       for r in s.readReqs:
         if not(isNil(r)) and not(r.finished()):
@@ -7,6 +7,7 @@ (libp2p/stream/connection.nim)
 ## This file may not be copied, modified, or distributed except according to
 ## those terms.

+import hashes
 import chronos, metrics
 import lpstream,
        ../multiaddress,

@@ -18,9 +19,13 @@ const
   ConnectionTrackerName* = "libp2p.connection"

 type
+  Direction* {.pure.} = enum
+    None, In, Out
+
   Connection* = ref object of LPStream
     peerInfo*: PeerInfo
     observedAddr*: Multiaddress
+    dir*: Direction

   ConnectionTracker* = ref object of TrackerBase
     opened*: uint64

@@ -50,9 +55,11 @@ proc setupConnectionTracker(): ConnectionTracker =
   result.isLeaked = leakTransport
   addTracker(ConnectionTrackerName, result)

-proc init*[T: Connection](self: var T, peerInfo: PeerInfo): T =
-  new self
-  self.initStream()
+proc init*(C: type Connection,
+           peerInfo: PeerInfo,
+           dir: Direction): Connection =
+  result = C(peerInfo: peerInfo, dir: dir)
+  result.initStream()

 method initStream*(s: Connection) =
   if s.objName.len == 0:

@@ -63,9 +70,13 @@ method initStream*(s: Connection) =
   inc getConnectionTracker().opened

 method close*(s: Connection) {.async.} =
-  await procCall LPStream(s).close()
-  inc getConnectionTracker().closed
+  if not s.isClosed:
+    await procCall LPStream(s).close()
+    inc getConnectionTracker().closed

 proc `$`*(conn: Connection): string =
   if not isNil(conn.peerInfo):
     result = conn.peerInfo.id

+func hash*(p: Connection): Hash =
+  cast[pointer](p).hash
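Because `None` is the enum's first value, it is the default for any `Connection` that was never tagged, which is how untagged connections can be detected (per the commit message). A small sketch, assuming a `peer: PeerInfo` built as in the tests:

    let tagged = Connection.init(peer, Direction.Out)
    assert tagged.dir == Direction.Out

    # constructed directly, bypassing init: dir stays at the default
    let untagged = Connection(peerInfo: peer)
    assert untagged.dir == Direction.None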
@@ -76,15 +76,6 @@ method initStream*(s: LPStream) {.base.} = (libp2p/stream/lpstream.nim)
   libp2p_open_streams.inc(labelValues = [s.objName])
   trace "stream created", oid = $s.oid, name = s.objName

-  # TODO: debugging aid to troubleshoot streams open/close
-  # try:
-  #   echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"])
-  #   echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"])
-  #   # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >=
-  #   #   libp2p_open_streams.value(labelValues = ["SecureConn"]))
-  # except CatchableError:
-  #   discard

 proc join*(s: LPStream): Future[void] =
   s.closeEvent.wait()
@@ -207,12 +198,3 @@ method close*(s: LPStream) {.base, async.} =
   s.closeEvent.fire()
   libp2p_open_streams.dec(labelValues = [s.objName])
   trace "stream destroyed", oid = $s.oid, name = s.objName

-  # TODO: debugging aid to troubleshoot streams open/close
-  # try:
-  #   echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"])
-  #   echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"])
-  #   # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >=
-  #   #   libp2p_open_streams.value(labelValues = ["SecureConn"]))
-  # except CatchableError:
-  #   discard
@@ -11,7 +11,6 @@ import tables, (libp2p/switch.nim)
        sequtils,
        options,
        sets,
-       algorithm,
        oids

 import chronos,

@@ -28,6 +27,7 @@ import stream/connection,
        protocols/identify,
        protocols/pubsub/pubsub,
        muxers/muxer,
+       connmanager,
        peerid

 logScope:
@@ -39,33 +39,16 @@ logScope:
 # and only if the channel has been secured (i.e. if a secure manager has been
 # previously provided)

-declareGauge(libp2p_peers, "total connected peers")
 declareCounter(libp2p_dialed_peers, "dialed peers")
 declareCounter(libp2p_failed_dials, "failed dials")
 declareCounter(libp2p_failed_upgrade, "peers failed upgrade")

-const MaxConnectionsPerPeer = 5
-
 type
   NoPubSubException* = object of CatchableError
-  TooManyConnections* = object of CatchableError
-
-  Direction {.pure.} = enum
-    In, Out
-
-  ConnectionHolder = object
-    dir: Direction
-    conn: Connection
-
-  MuxerHolder = object
-    dir: Direction
-    muxer: Muxer
-    handle: Future[void]

   Switch* = ref object of RootObj
     peerInfo*: PeerInfo
-    connections*: Table[string, seq[ConnectionHolder]]
-    muxed*: Table[string, seq[MuxerHolder]]
+    connManager: ConnManager
     transports*: seq[Transport]
     protocols*: seq[LPProtocol]
     muxers*: Table[string, MuxerProvider]
@@ -75,90 +58,20 @@ type
     secureManagers*: seq[Secure]
     pubSub*: Option[PubSub]
     dialLock: Table[string, AsyncLock]
-    cleanUpLock: Table[string, AsyncLock]

 proc newNoPubSubException(): ref NoPubSubException {.inline.} =
   result = newException(NoPubSubException, "no pubsub provided!")

-proc newTooManyConnections(): ref TooManyConnections {.inline.} =
-  result = newException(TooManyConnections, "too many connections for peer")
-
 proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.}
 proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.}

-proc selectConn(s: Switch, peerInfo: PeerInfo): Connection =
-  ## select the "best" connection according to some criteria
-  ##
-  ## Ideally when the connection's stats are available
-  ## we'd select the fastest, but for now we simply pick an outgoing
-  ## connection first; if none is available, we pick the first incoming
-  ##
-
-  if isNil(peerInfo):
-    return
-
-  let conns = s.connections
-    .getOrDefault(peerInfo.id)
-    # it should be OK to sort on each
-    # access as there should only be
-    # up to MaxConnectionsPerPeer entries
-    .sorted(
-      proc(a, b: ConnectionHolder): int =
-        if a.dir < b.dir: -1
-        elif a.dir == b.dir: 0
-        else: 1
-      , SortOrder.Descending)
-
-  if conns.len > 0:
-    return conns[0].conn
-
-proc selectMuxer(s: Switch, conn: Connection): Muxer =
-  ## select the muxer for the supplied connection
-  ##
-
-  if isNil(conn):
-    return
-
-  if not(isNil(conn.peerInfo)) and conn.peerInfo.id in s.muxed:
-    if s.muxed[conn.peerInfo.id].len > 0:
-      let muxers = s.muxed[conn.peerInfo.id]
-        .filterIt( it.muxer.connection == conn )
-      if muxers.len > 0:
-        return muxers[0].muxer
-
-proc storeConn(s: Switch,
-               muxer: Muxer,
-               dir: Direction,
-               handle: Future[void] = nil) {.async.} =
-  ## store the connection and muxer
-  ##
-  if isNil(muxer):
-    return
-
-  let conn = muxer.connection
-  if isNil(conn):
-    return
-
-  let id = conn.peerInfo.id
-  if s.connections.getOrDefault(id).len > MaxConnectionsPerPeer:
-    warn "disconnecting peer, too many connections", peer = $conn.peerInfo,
-                                                     conns = s.connections
-                                                     .getOrDefault(id).len
-    await s.disconnect(conn.peerInfo)
-    raise newTooManyConnections()
-
-  s.connections.mgetOrPut(
-    id,
-    newSeq[ConnectionHolder]())
-    .add(ConnectionHolder(conn: conn, dir: dir))
-
-  s.muxed.mgetOrPut(
-    muxer.connection.peerInfo.id,
-    newSeq[MuxerHolder]())
-    .add(MuxerHolder(muxer: muxer, handle: handle, dir: dir))
-
-  trace "stored connection", connections = s.connections.len
-  libp2p_peers.set(s.connections.len.int64)
+proc cleanupPubSubPeer(s: Switch, conn: Connection) {.async.} =
+  await conn.closeEvent.wait()
+  if s.pubSub.isSome:
+    await s.pubSub.get().unsubscribePeer(conn.peerInfo)
+
+proc isConnected*(s: Switch, peer: PeerInfo): bool =
+  peer.peerId in s.connManager

 proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
   if s.secureManagers.len <= 0:
@@ -170,9 +83,11 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
   trace "securing connection", codec = manager
   let secureProtocol = s.secureManagers.filterIt(it.codec == manager)

   # ms.select should deal with the correctness of this
   # let's avoid duplicating checks but detect if it fails to do it properly
   doAssert(secureProtocol.len > 0)

   result = await secureProtocol[0].secure(conn, true)

 proc identify(s: Switch, conn: Connection) {.async, gcsafe.} =

@@ -218,6 +133,7 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
   # create new muxer for connection
   let muxer = s.muxers[muxerName].newMuxer(conn)
+  s.connManager.storeMuxer(muxer)

   trace "found a muxer", name = muxerName, peer = $conn
@@ -247,75 +163,10 @@ proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
   # store it in muxed connections if we have a peer for it
   trace "adding muxer for peer", peer = conn.peerInfo.id
-  await s.storeConn(muxer, Direction.Out, handlerFut)
+  s.connManager.storeMuxer(muxer, handlerFut) # update muxer with handler

-proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
-  if isNil(conn):
-    return
-
-  if isNil(conn.peerInfo):
-    return
-
-  let id = conn.peerInfo.id
-  let lock = s.cleanUpLock.mgetOrPut(id, newAsyncLock())
-
-  try:
-    await lock.acquire()
-    trace "cleaning up connection for peer", peerId = id
-    if id in s.muxed:
-      let muxerHolder = s.muxed[id]
-        .filterIt(
-          it.muxer.connection == conn
-        )
-
-      if muxerHolder.len > 0:
-        await muxerHolder[0].muxer.close()
-        if not(isNil(muxerHolder[0].handle)):
-          await muxerHolder[0].handle
-
-      if id in s.muxed:
-        s.muxed[id].keepItIf(
-          it.muxer.connection != conn
-        )
-
-        if s.muxed[id].len == 0:
-          s.muxed.del(id)
-
-    if s.pubSub.isSome:
-      await s.pubSub.get()
-        .unsubscribePeer(conn.peerInfo)
-
-    if id in s.connections:
-      s.connections[id].keepItIf(
-        it.conn != conn
-      )
-
-      if s.connections[id].len == 0:
-        s.connections.del(id)
-
-    # TODO: Investigate cleanupConn() always called twice for one peer.
-    if not(conn.peerInfo.isClosed()):
-      conn.peerInfo.close()
-  finally:
-    await conn.close()
-    libp2p_peers.set(s.connections.len.int64)
-
-    if lock.locked():
-      lock.release()

 proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
-  let connections = s.connections.getOrDefault(peer.id)
-  for connHolder in connections:
-    if not isNil(connHolder.conn):
-      await s.cleanupConn(connHolder.conn)
+  await s.connManager.dropPeer(peer)

-proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} =
-  # if there is a muxer for the connection
-  # use it instead to create a muxed stream
-
-  let muxer = s.selectMuxer(s.selectConn(peerInfo)) # always get the first muxer here
-  if not(isNil(muxer)):
-    return await muxer.newStream()

 proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
   logScope:
@@ -388,52 +239,51 @@ proc internalConnect(s: Switch,
   var conn: Connection
   let lock = s.dialLock.mgetOrPut(id, newAsyncLock())

-  defer:
-    if lock.locked():
-      lock.release()
-
-  await lock.acquire()
-  trace "about to dial peer", peer = id
-  conn = s.selectConn(peer)
-  if conn.isNil or conn.closed:
-    trace "Dialing peer", peer = id
-    for t in s.transports: # for each transport
-      for a in peer.addrs: # for each address
-        if t.handles(a): # check if it can dial it
-          trace "Dialing address", address = $a
-          try:
-            conn = await t.dial(a)
-            libp2p_dialed_peers.inc()
-          except CancelledError as exc:
-            trace "dialing canceled", exc = exc.msg
-            raise
-          except CatchableError as exc:
-            trace "dialing failed", exc = exc.msg
-            libp2p_failed_dials.inc()
-            continue
-
-          # make sure to assign the peer to the connection
-          conn.peerInfo = peer
-          try:
-            conn = await s.upgradeOutgoing(conn)
-          except CatchableError as exc:
-            if not(isNil(conn)):
-              await conn.close()
-
-            trace "Unable to establish outgoing link", exc = exc.msg
-            raise exc
-
-          if isNil(conn):
-            libp2p_failed_upgrade.inc()
-            continue
-
-          conn.closeEvent.wait()
-            .addCallback do(udata: pointer):
-              asyncCheck s.cleanupConn(conn)
-          break
-  else:
-    trace "Reusing existing connection", oid = conn.oid
+  try:
+    await lock.acquire()
+    trace "about to dial peer", peer = id
+    conn = s.connManager.selectConn(peer)
+    if conn.isNil or (conn.closed or conn.atEof):
+      trace "Dialing peer", peer = id
+      for t in s.transports: # for each transport
+        for a in peer.addrs: # for each address
+          if t.handles(a): # check if it can dial it
+            trace "Dialing address", address = $a, peer = id
+            try:
+              conn = await t.dial(a)
+              # make sure to assign the peer to the connection
+              conn.peerInfo = peer
+
+              libp2p_dialed_peers.inc()
+            except CancelledError as exc:
+              trace "dialing canceled", exc = exc.msg
+              raise
+            except CatchableError as exc:
+              trace "dialing failed", exc = exc.msg
+              libp2p_failed_dials.inc()
+              continue
+
+            try:
+              let uconn = await s.upgradeOutgoing(conn)
+              s.connManager.storeOutgoing(uconn)
+              conn = uconn
+            except CatchableError as exc:
+              if not(isNil(conn)):
+                await conn.close()
+
+              trace "Unable to establish outgoing link", exc = exc.msg
+              raise exc
+
+            if isNil(conn):
+              libp2p_failed_upgrade.inc()
+              continue
+            break
+    else:
+      trace "Reusing existing connection", oid = conn.oid
+  finally:
+    if lock.locked():
+      lock.release()

   if isNil(conn):
     raise newException(CatchableError,
       "Unable to establish outgoing link")

@@ -443,13 +293,14 @@ proc internalConnect(s: Switch,
     raise newException(CatchableError,
       "Connection dead on arrival")

-  doAssert(conn.peerInfo.id in s.connections,
-           "connection not tracked!")
+  doAssert(conn in s.connManager, "connection not tracked!")

   trace "dial successful", oid = $conn.oid,
                            peer = $conn.peerInfo

   await s.subscribePeer(peer)
+  asyncCheck s.cleanupPubSubPeer(conn)

   return conn

 proc connect*(s: Switch, peer: PeerInfo) {.async.} =
@@ -460,7 +311,7 @@ proc dial*(s: Switch,
            proto: string):
            Future[Connection] {.async.} =
   let conn = await s.internalConnect(peer)
-  let stream = await s.getMuxedStream(peer)
+  let stream = await s.connManager.getMuxedStream(conn)

   proc cleanup() {.async.} =
     if not(isNil(stream)):
@@ -505,14 +356,14 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
   proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
     try:
-      defer:
-        await s.cleanupConn(conn)
+      conn.dir = Direction.In # tag connection with direction
       await s.upgradeIncoming(conn) # perform upgrade on incoming connection
     except CancelledError as exc:
       raise exc
     except CatchableError as exc:
       trace "Exception occurred in Switch.start", exc = exc.msg
+    finally:
+      await conn.close()

   var startFuts: seq[Future[void]]
   for t in s.transports: # for each transport
@@ -537,14 +388,8 @@ proc stop*(s: Switch) {.async.} =
   if s.pubSub.isSome:
     await s.pubSub.get().stop()

-  for conns in toSeq(s.connections.values):
-    for conn in conns:
-      try:
-        await s.cleanupConn(conn.conn)
-      except CancelledError as exc:
-        raise exc
-      except CatchableError as exc:
-        warn "error cleaning up connections"
+  # close and cleanup all connections
+  await s.connManager.close()

   for t in s.transports:
     try:
@@ -562,7 +407,18 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
   trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog()
   var stream: Connection
   try:
-    stream = await s.getMuxedStream(peerInfo)
+    stream = await s.connManager.getMuxedStream(peerInfo)
+    if isNil(stream):
+      trace "unable to subscribe to peer", peer = peerInfo.shortLog
+      return
+
+    if not await s.ms.select(stream, s.pubSub.get().codec):
+      if not(isNil(stream)):
+        await stream.close()
+      return
+
+    s.pubSub.get().subscribePeer(stream)
+
   except CancelledError as exc:
     if not(isNil(stream)):
       await stream.close()
@@ -574,44 +430,27 @@ proc subscribePeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
     if not(isNil(stream)):
       await stream.close()

-  if isNil(stream):
-    trace "unable to subscribe to peer", peer = peerInfo.shortLog
-    return
-
-  if not await s.ms.select(stream, s.pubSub.get().codec):
-    if not(isNil(stream)):
-      await stream.close()
-    return
-
-  s.pubSub.get().subscribePeer(stream)

 proc subscribe*(s: Switch, topic: string,
-                handler: TopicHandler): Future[void] =
+                handler: TopicHandler) {.async.} =
   ## subscribe to a pubsub topic
   if s.pubSub.isNone:
-    var retFuture = newFuture[void]("Switch.subscribe")
-    retFuture.fail(newNoPubSubException())
-    return retFuture
+    raise newNoPubSubException()

-  return s.pubSub.get().subscribe(topic, handler)
+  await s.pubSub.get().subscribe(topic, handler)

-proc unsubscribe*(s: Switch, topics: seq[TopicPair]): Future[void] =
+proc unsubscribe*(s: Switch, topics: seq[TopicPair]) {.async.} =
   ## unsubscribe from topics
   if s.pubSub.isNone:
-    var retFuture = newFuture[void]("Switch.unsubscribe")
-    retFuture.fail(newNoPubSubException())
-    return retFuture
+    raise newNoPubSubException()

-  return s.pubSub.get().unsubscribe(topics)
+  await s.pubSub.get().unsubscribe(topics)

-proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] =
+proc publish*(s: Switch, topic: string, data: seq[byte]): Future[int] {.async.} =
   # publish to pubsub topic
   if s.pubSub.isNone:
-    var retFuture = newFuture[int]("Switch.publish")
-    retFuture.fail(newNoPubSubException())
-    return retFuture
+    raise newNoPubSubException()

-  return s.pubSub.get().publish(topic, data)
+  return await s.pubSub.get().publish(topic, data)

 proc addValidator*(s: Switch,
                    topics: varargs[string],
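With these procs converted from manually constructed futures to plain {.async.}, a missing pubsub now surfaces as an exception at the await rather than a pre-failed future. A hedged call-site sketch (`mySwitch` and the handler body are assumptions; `toBytes` is from stew/byteutils, as used in the tests):

    proc handler(topic: string, data: seq[byte]) {.async.} =
      echo "got ", data.len, " bytes on ", topic

    proc run(mySwitch: Switch) {.async.} =
      try:
        await mySwitch.subscribe("my-topic", handler)
        discard await mySwitch.publish("my-topic", "hi".toBytes())
      except NoPubSubException:
        echo "this switch was built without pubsub"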
@@ -647,17 +486,17 @@ proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
   muxer.connection.peerInfo = stream.peerInfo

-  # store muxer and muxed connection
-  await s.storeConn(muxer, Direction.In)
+  # store incoming connection
+  s.connManager.storeIncoming(muxer.connection)

-  muxer.connection.closeEvent.wait()
-    .addCallback do(udata: pointer):
-      asyncCheck s.cleanupConn(muxer.connection)
+  # store muxer and muxed connection
+  s.connManager.storeMuxer(muxer)

   trace "got new muxer", peer = $muxer.connection.peerInfo

   # try establishing a pubsub connection
   await s.subscribePeer(muxer.connection.peerInfo)
+  asyncCheck s.cleanupPubSubPeer(muxer.connection)

 except CancelledError as exc:
   await muxer.close()
@@ -680,8 +519,7 @@ proc newSwitch*(peerInfo: PeerInfo,
     peerInfo: peerInfo,
     ms: newMultistream(),
     transports: transports,
-    connections: initTable[string, seq[ConnectionHolder]](),
-    muxed: initTable[string, seq[MuxerHolder]](),
+    connManager: ConnManager.init(),
     identity: identity,
     muxers: muxers,
     secureManagers: @secureManagers,
@@ -0,0 +1,192 @@ (new file: tests/testconnmngr.nim)

import unittest
import chronos
import ../libp2p/[connmanager,
                  stream/connection,
                  crypto/crypto,
                  muxers/muxer,
                  peerinfo]

import helpers

type
  TestMuxer = ref object of Muxer
    peerInfo: PeerInfo

method newStream*(
  m: TestMuxer,
  name: string = "",
  lazy: bool = false):
  Future[Connection] {.async, gcsafe.} =
  result = Connection.init(m.peerInfo, Direction.Out)

suite "Connection Manager":
  teardown:
    for tracker in testTrackers():
      # echo tracker.dump()
      check tracker.isLeaked() == false

  test "add and retrieve a connection":
    let connMngr = ConnManager.init()
    let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
    let conn = Connection.init(peer, Direction.In)

    connMngr.storeConn(conn)
    check conn in connMngr

    let peerConn = connMngr.selectConn(peer)
    check peerConn == conn
    check peerConn.dir == Direction.In

  test "add and retrieve a muxer":
    let connMngr = ConnManager.init()
    let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
    let conn = Connection.init(peer, Direction.In)
    let muxer = new Muxer
    muxer.connection = conn

    connMngr.storeConn(conn)
    connMngr.storeMuxer(muxer)
    check muxer in connMngr

    let peerMuxer = connMngr.selectMuxer(conn)
    check peerMuxer == muxer

  test "get conn with direction":
    let connMngr = ConnManager.init()
    let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
    let conn1 = Connection.init(peer, Direction.Out)
    let conn2 = Connection.init(peer, Direction.In)

    connMngr.storeConn(conn1)
    connMngr.storeConn(conn2)
    check conn1 in connMngr
    check conn2 in connMngr

    let outConn = connMngr.selectConn(peer, Direction.Out)
    let inConn = connMngr.selectConn(peer, Direction.In)

    check outConn != inConn
    check outConn.dir == Direction.Out
    check inConn.dir == Direction.In

  test "get muxed stream for peer":
    proc test() {.async.} =
      let connMngr = ConnManager.init()
      let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
      let conn = Connection.init(peer, Direction.In)

      let muxer = new TestMuxer
      muxer.peerInfo = peer
      muxer.connection = conn

      connMngr.storeConn(conn)
      connMngr.storeMuxer(muxer)
      check muxer in connMngr

      let stream = await connMngr.getMuxedStream(peer)
      check not(isNil(stream))
      check stream.peerInfo == peer

    waitFor(test())

  test "get stream from directed connection":
    proc test() {.async.} =
      let connMngr = ConnManager.init()
      let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
      let conn = Connection.init(peer, Direction.In)

      let muxer = new TestMuxer
      muxer.peerInfo = peer
      muxer.connection = conn

      connMngr.storeConn(conn)
      connMngr.storeMuxer(muxer)
      check muxer in connMngr

      check not(isNil((await connMngr.getMuxedStream(peer, Direction.In))))
      check isNil((await connMngr.getMuxedStream(peer, Direction.Out)))

    waitFor(test())

  test "get stream from any connection":
    proc test() {.async.} =
      let connMngr = ConnManager.init()
      let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
      let conn = Connection.init(peer, Direction.In)

      let muxer = new TestMuxer
      muxer.peerInfo = peer
      muxer.connection = conn

      connMngr.storeConn(conn)
      connMngr.storeMuxer(muxer)
      check muxer in connMngr

      check not(isNil((await connMngr.getMuxedStream(conn))))

    waitFor(test())

  test "should raise on too many connections":
    proc test() =
      let connMngr = ConnManager.init(1)
      let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())

      connMngr.storeConn(Connection.init(peer, Direction.In))
      connMngr.storeConn(Connection.init(peer, Direction.In))
      connMngr.storeConn(Connection.init(peer, Direction.In))

    expect TooManyConnections:
      test()

  test "cleanup on connection close":
    proc test() {.async.} =
      let connMngr = ConnManager.init()
      let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())
      let conn = Connection.init(peer, Direction.In)
      let muxer = new Muxer
      muxer.connection = conn

      connMngr.storeConn(conn)
      connMngr.storeMuxer(muxer)

      check conn in connMngr
      check muxer in connMngr

      await conn.close()
      await sleepAsync(10.millis)

      check conn notin connMngr
      check muxer notin connMngr

    waitFor(test())

  test "drop connections for peer":
    proc test() {.async.} =
      let connMngr = ConnManager.init()
      let peer = PeerInfo.init(PrivateKey.random(ECDSA, (newRng())[]).tryGet())

      for i in 0..<2:
        let dir = if i mod 2 == 0:
            Direction.In else:
            Direction.Out

        let conn = Connection.init(peer, dir)
        let muxer = new Muxer
        muxer.connection = conn

        connMngr.storeConn(conn)
        connMngr.storeMuxer(muxer)

        check conn in connMngr
        check muxer in connMngr
        check not(isNil(connMngr.selectConn(peer, dir)))

      check peer in connMngr.peers
      await connMngr.dropPeer(peer)

      check peer notin connMngr.peers
      check isNil(connMngr.selectConn(peer, Direction.In))
      check isNil(connMngr.selectConn(peer, Direction.Out))
      check connMngr.peers.len == 0

    waitFor(test())
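To run just this new suite locally, the stock Nim invocation should suffice (assuming the repository's usual test layout):

    nim c -r tests/testconnmngr.nim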
@@ -1,6 +1,6 @@ (tests/testswitch.nim)
 {.used.}

-import unittest, tables
+import unittest
 import chronos
 import stew/byteutils
 import nimcrypto/sysrand

@@ -56,6 +56,10 @@ suite "Switch":
   awaiters.add(await switch2.start())

   let conn = await switch2.dial(switch1.peerInfo, TestCodec)

+  check switch1.isConnected(switch2.peerInfo)
+  check switch2.isConnected(switch1.peerInfo)
+
   await conn.writeLp("Hello!")
   let msg = string.fromBytes(await conn.readLp(1024))
   check "Hello!" == msg

@@ -69,6 +73,9 @@ suite "Switch":
   # this needs to go at end
   await allFuturesThrowing(awaiters)

+  check not switch1.isConnected(switch2.peerInfo)
+  check not switch2.isConnected(switch1.peerInfo)
+
   waitFor(testSwitch())

 test "e2e should not leak bufferstreams and connections on channel close":

@@ -96,6 +103,10 @@ suite "Switch":
   awaiters.add(await switch2.start())

   let conn = await switch2.dial(switch1.peerInfo, TestCodec)

+  check switch1.isConnected(switch2.peerInfo)
+  check switch2.isConnected(switch1.peerInfo)
+
   await conn.writeLp("Hello!")
   let msg = string.fromBytes(await conn.readLp(1024))
   check "Hello!" == msg

@@ -103,20 +114,20 @@ suite "Switch":
   await sleepAsync(2.seconds) # wait a little for cleanup to happen
   var bufferTracker = getTracker(BufferStreamTrackerName)
-  # echo bufferTracker.dump()
+  echo bufferTracker.dump()

   # plus 4 for the pubsub streams
   check (BufferStreamTracker(bufferTracker).opened ==
     (BufferStreamTracker(bufferTracker).closed + 4.uint64))

-  # var connTracker = getTracker(ConnectionTrackerName)
-  # echo connTracker.dump()
+  var connTracker = getTracker(ConnectionTrackerName)
+  echo connTracker.dump()

   # plus 8 is for the secured connection and the socket
   # and the pubsub streams that won't clean up until
   # `disconnect()` or `stop()`
-  # check (ConnectionTracker(connTracker).opened ==
-  #   (ConnectionTracker(connTracker).closed + 8.uint64))
+  check (ConnectionTracker(connTracker).opened ==
+    (ConnectionTracker(connTracker).closed + 8.uint64))

   await allFuturesThrowing(
     done.wait(5.seconds),

@@ -127,6 +138,9 @@ suite "Switch":
   # this needs to go at end
   await allFuturesThrowing(awaiters)

+  check not switch1.isConnected(switch2.peerInfo)
+  check not switch2.isConnected(switch1.peerInfo)
+
   waitFor(testSwitch())

 test "e2e use connect then dial":

@@ -153,10 +167,11 @@ suite "Switch":
   awaiters.add(await switch2.start())

   await switch2.connect(switch1.peerInfo)
-  check switch1.peerInfo.id in switch2.connections

   let conn = await switch2.dial(switch1.peerInfo, TestCodec)

+  check switch1.isConnected(switch2.peerInfo)
+  check switch2.isConnected(switch1.peerInfo)
+
   try:
     await conn.writeLp("Hello!")
     let msg = string.fromBytes(await conn.readLp(1024))

@@ -172,6 +187,9 @@ suite "Switch":
   )
   await allFuturesThrowing(awaiters)

+  check not switch1.isConnected(switch2.peerInfo)
+  check not switch2.isConnected(switch1.peerInfo)
+
   check:
     waitFor(testSwitch()) == true

@@ -186,23 +204,23 @@ suite "Switch":
   await switch2.connect(switch1.peerInfo)

-  check switch1.connections[switch2.peerInfo.id].len > 0
-  check switch2.connections[switch1.peerInfo.id].len > 0
+  check switch1.isConnected(switch2.peerInfo)
+  check switch2.isConnected(switch1.peerInfo)

   await sleepAsync(100.millis)
   await switch2.disconnect(switch1.peerInfo)

   await sleepAsync(2.seconds)

+  check not switch1.isConnected(switch2.peerInfo)
+  check not switch2.isConnected(switch1.peerInfo)
+
   var bufferTracker = getTracker(BufferStreamTrackerName)
   # echo bufferTracker.dump()
   check bufferTracker.isLeaked() == false

-  # var connTracker = getTracker(ConnectionTrackerName)
+  var connTracker = getTracker(ConnectionTrackerName)
   # echo connTracker.dump()
-  # check connTracker.isLeaked() == false
+  check connTracker.isLeaked() == false

-  check switch2.peerInfo.id notin switch1.connections
-  check switch1.peerInfo.id notin switch2.connections

   await allFuturesThrowing(
     switch1.stop(),

@@ -210,47 +228,3 @@ suite "Switch":
   await allFuturesThrowing(awaiters)

   waitFor(testSwitch())
-
-# test "e2e: handle read + secio fragmented":
-#   proc testListenerDialer(): Future[bool] {.async.} =
-#     let
-#       server: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0")
-#       serverInfo = PeerInfo.init(PrivateKey.random(ECDSA), [server])
-#       serverNoise = newSecio(serverInfo.privateKey)
-#       readTask = newFuture[void]()
-
-#     var hugePayload = newSeq[byte](0x1200000)
-#     check randomBytes(hugePayload) == hugePayload.len
-#     trace "Sending huge payload", size = hugePayload.len
-
-#     proc connHandler(conn: Connection) {.async, gcsafe.} =
-#       let sconn = await serverNoise.secure(conn)
-#       defer:
-#         await sconn.close()
-#       let msg = await sconn.read(0x1200000)
-#       check msg == hugePayload
-#       readTask.complete()
-
-#     let
-#       transport1: TcpTransport = TcpTransport.init()
-#     asyncCheck await transport1.listen(server, connHandler)
-
-#     let
-#       transport2: TcpTransport = TcpTransport.init()
-#       clientInfo = PeerInfo.init(PrivateKey.random(ECDSA), [transport1.ma])
-#       clientNoise = newSecio(clientInfo.privateKey)
-#       conn = await transport2.dial(transport1.ma)
-#       sconn = await clientNoise.secure(conn)
-
-#     await sconn.write(hugePayload)
-#     await readTask
-
-#     await sconn.close()
-#     await conn.close()
-#     await transport2.close()
-#     await transport1.close()
-
-#     result = true
-
-#   check:
-#     waitFor(testListenerDialer()) == true