diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 0b30f4fc7..f11ae6014 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -25,7 +25,6 @@ type Mplex* = ref object of Muxer remote: Table[uint64, LPChannel] local: Table[uint64, LPChannel] - handlerFuts: seq[Future[void]] currentId*: uint64 maxChannels*: uint64 isClosed: bool @@ -66,6 +65,15 @@ proc newStreamInternal*(m: Mplex, m.getChannelList(initiator)[id] = result +proc handleStream(m: Muxer, chann: LPChannel) {.async.} = + try: + await m.streamHandler(chann) + trace "finished handling stream" + doAssert(chann.closed, "connection not closed by handler!") + except CatchableError as exc: + trace "exception in stream handler", exc = exc.msg + await chann.reset() + method handle*(m: Mplex) {.async, gcsafe.} = trace "starting mplex main loop", oid = m.oid try: @@ -96,7 +104,7 @@ method handle*(m: Mplex) {.async, gcsafe.} = initiator = initiator msgType = msgType size = data.len - oid = m.oid + muxer_oid = m.oid case msgType: of MessageType.New: @@ -104,27 +112,16 @@ method handle*(m: Mplex) {.async, gcsafe.} = channel = await m.newStreamInternal(false, id, name) trace "created channel", name = channel.name, - chann_iod = channel.oid + oid = channel.oid if not isNil(m.streamHandler): - var fut = newFuture[void]() - proc handler() {.async.} = - try: - await m.streamHandler(channel) - trace "finished handling stream" - # doAssert(channel.closed, "connection not closed by handler!") - except CatchableError as exc: - trace "exception in stream handler", exc = exc.msg - await channel.reset() - finally: - m.handlerFuts.keepItIf(it != fut) - - fut = handler() + # launch handler task + asyncCheck m.handleStream(channel) of MessageType.MsgIn, MessageType.MsgOut: logScope: name = channel.name - chann_iod = channel.oid + oid = channel.oid trace "pushing data to channel" @@ -134,7 +131,7 @@ method handle*(m: Mplex) {.async, gcsafe.} = of MessageType.CloseIn, MessageType.CloseOut: logScope: name = channel.name - chann_iod = channel.oid + oid = channel.oid trace "closing channel" @@ -144,7 +141,7 @@ method handle*(m: Mplex) {.async, gcsafe.} = of MessageType.ResetIn, MessageType.ResetOut: logScope: name = channel.name - chann_iod = channel.oid + oid = channel.oid trace "resetting channel" @@ -201,12 +198,9 @@ method close*(m: Mplex) {.async, gcsafe.} = except CatchableError as exc: warn "error resetting channel", exc = exc.msg - checkFutures( - await allFinished(m.handlerFuts)) - await m.connection.close() finally: m.remote.clear() m.local.clear() - m.handlerFuts = @[] + # m.handlerFuts = @[] m.isClosed = true diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim index 8dd478142..681c511cf 100644 --- a/libp2p/protocols/pubsub/pubsubpeer.nim +++ b/libp2p/protocols/pubsub/pubsubpeer.nim @@ -60,13 +60,15 @@ proc recvObservers(p: PubSubPeer, msg: var RPCMsg) = # trigger hooks if not(isNil(p.observers)) and p.observers[].len > 0: for obs in p.observers[]: - obs.onRecv(p, msg) + if not(isNil(obs)): # TODO: should never be nil, but... + obs.onRecv(p, msg) proc sendObservers(p: PubSubPeer, msg: var RPCMsg) = # trigger hooks if not(isNil(p.observers)) and p.observers[].len > 0: for obs in p.observers[]: - obs.onSend(p, msg) + if not(isNil(obs)): # TODO: should never be nil, but... + obs.onSend(p, msg) proc handle*(p: PubSubPeer, conn: Connection) {.async.} = trace "handling pubsub rpc", peer = p.id, closed = conn.closed diff --git a/libp2p/protocols/secure/noise.nim b/libp2p/protocols/secure/noise.nim index 59d88af27..0e286b90d 100644 --- a/libp2p/protocols/secure/noise.nim +++ b/libp2p/protocols/secure/noise.nim @@ -467,13 +467,9 @@ method handshake*(p: Noise, conn: Connection, initiator: bool): Future[SecureCon debug "Noise handshake, peer infos don't match!", initiator, dealt_peer = $conn.peerInfo.id, dealt_key = $failedKey, received_peer = $pid, received_key = $remotePubKey raise newException(NoiseHandshakeError, "Noise handshake, peer infos don't match! " & $pid & " != " & $conn.peerInfo.peerId) - var secure = new NoiseConnection - secure.initStream() - - secure.stream = conn - secure.peerInfo = PeerInfo.init(remotePubKey) - secure.observedAddr = conn.observedAddr - + var secure = NoiseConnection.init(conn, + PeerInfo.init(remotePubKey), + conn.observedAddr) if initiator: secure.readCs = handshakeRes.cs2 secure.writeCs = handshakeRes.cs1 diff --git a/libp2p/protocols/secure/secio.nim b/libp2p/protocols/secure/secio.nim index 60f042412..3d369b92d 100644 --- a/libp2p/protocols/secure/secio.nim +++ b/libp2p/protocols/secure/secio.nim @@ -245,9 +245,9 @@ proc newSecioConn(conn: Connection, ## Create new secure stream/lpstream, using specified hash algorithm ``hash``, ## cipher algorithm ``cipher``, stretched keys ``secrets`` and order ## ``order``. - new result - result.initStream() - result.stream = conn + result = SecioConn.init(conn, + PeerInfo.init(remotePubKey), + conn.observedAddr) let i0 = if order < 0: 1 else: 0 let i1 = if order < 0: 0 else: 1 @@ -265,9 +265,6 @@ proc newSecioConn(conn: Connection, result.readerCoder.init(cipher, secrets.keyOpenArray(i1), secrets.ivOpenArray(i1)) - result.peerInfo = PeerInfo.init(remotePubKey) - result.observedAddr = conn.observedAddr - proc transactMessage(conn: Connection, msg: seq[byte]): Future[seq[byte]] {.async.} = trace "Sending message", message = msg.shortLog, length = len(msg) diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim index a12c35f97..5a594f161 100644 --- a/libp2p/protocols/secure/secure.nim +++ b/libp2p/protocols/secure/secure.nim @@ -12,6 +12,7 @@ import chronos, chronicles import ../protocol, ../../stream/streamseq, ../../stream/connection, + ../../multiaddress, ../../peerinfo logScope: @@ -24,6 +25,16 @@ type stream*: Connection buf: StreamSeq +proc init*[T: SecureConn](C: type T, + conn: Connection, + peerInfo: PeerInfo, + observedAddr: Multiaddress): T = + result = C(stream: conn, + peerInfo: peerInfo, + observedAddr: observedAddr, + closeEvent: conn.closeEvent) + result.initStream() + method initStream*(s: SecureConn) = if s.objName.len == 0: s.objName = "SecureConn" @@ -31,11 +42,11 @@ method initStream*(s: SecureConn) = procCall Connection(s).initStream() method close*(s: SecureConn) {.async.} = + await procCall Connection(s).close() + if not(isNil(s.stream)): await s.stream.close() - await procCall Connection(s).close() - method readMessage*(c: SecureConn): Future[seq[byte]] {.async, base.} = doAssert(false, "Not implemented!") @@ -47,11 +58,12 @@ method handshake(s: Secure, proc handleConn*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, gcsafe.} = var sconn = await s.handshake(conn, initiator) - result = sconn - result.observedAddr = conn.observedAddr + conn.closeEvent.wait() + .addCallback do(udata: pointer = nil): + if not(isNil(sconn)): + asyncCheck sconn.close() - if not isNil(sconn.peerInfo) and sconn.peerInfo.publicKey.isSome: - result.peerInfo = PeerInfo.init(sconn.peerInfo.publicKey.get()) + return sconn method init*(s: Secure) {.gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} = @@ -94,7 +106,7 @@ method readExactly*(s: SecureConn, let consumed = s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1)) doAssert consumed == nbytes, "checked above" except CatchableError as exc: - trace "exception reading from secure connection", exc = exc.msg + trace "exception reading from secure connection", exc = exc.msg, oid = s.oid await s.close() # make sure to close the wrapped connection raise exc @@ -115,6 +127,6 @@ method readOnce*(s: SecureConn, var p = cast[ptr UncheckedArray[byte]](pbytes) return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1)) except CatchableError as exc: - trace "exception reading from secure connection", exc = exc.msg + trace "exception reading from secure connection", exc = exc.msg, oid = s.oid await s.close() # make sure to close the wrapped connection raise exc diff --git a/libp2p/stream/chronosstream.nim b/libp2p/stream/chronosstream.nim index 05fbb8517..f5b239b80 100644 --- a/libp2p/stream/chronosstream.nim +++ b/libp2p/stream/chronosstream.nim @@ -82,12 +82,11 @@ method atEof*(s: ChronosStream): bool {.inline.} = method close*(s: ChronosStream) {.async.} = try: if not s.isClosed: - s.isClosed = true + await procCall Connection(s).close() - trace "shutting down chronos stream", address = $s.client.remoteAddress() + trace "shutting down chronos stream", address = $s.client.remoteAddress(), oid = s.oid if not s.client.closed(): await s.client.closeWait() - await procCall Connection(s).close() except CatchableError as exc: trace "error closing chronosstream", exc = exc.msg diff --git a/libp2p/stream/connection.nim b/libp2p/stream/connection.nim index 9e6ad9577..cb22e3fc0 100644 --- a/libp2p/stream/connection.nim +++ b/libp2p/stream/connection.nim @@ -21,7 +21,6 @@ type Connection* = ref object of LPStream peerInfo*: PeerInfo observedAddr*: Multiaddress - closeEvent*: AsyncEvent ConnectionTracker* = ref object of TrackerBase opened*: uint64 @@ -65,8 +64,6 @@ method initStream*(s: Connection) = method close*(s: Connection) {.async.} = await procCall LPStream(s).close() - - s.closeEvent.fire() inc getConnectionTracker().closed proc `$`*(conn: Connection): string = diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 23879a176..d6969a76a 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -18,6 +18,7 @@ declareGauge(libp2p_open_streams, "open stream instances", labels = ["type"]) type LPStream* = ref object of RootObj + closeEvent*: AsyncEvent isClosed*: bool isEof*: bool objName*: string @@ -73,7 +74,19 @@ method initStream*(s: LPStream) {.base.} = s.oid = genOid() libp2p_open_streams.inc(labelValues = [s.objName]) - trace "stream created", oid = s.oid + trace "stream created", oid = s.oid, name = s.objName + + # TODO: debuging aid to troubleshoot streams open/close + # try: + # echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"]) + # echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"]) + # # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >= + # # libp2p_open_streams.value(labelValues = ["SecureConn"])) + # except CatchableError: + # discard + +proc join*(s: LPStream): Future[void] = + s.closeEvent.wait() method closed*(s: LPStream): bool {.base, inline.} = s.isClosed @@ -169,6 +182,16 @@ proc write*(s: LPStream, msg: string): Future[void] = method close*(s: LPStream) {.base, async.} = if not s.isClosed: - libp2p_open_streams.dec(labelValues = [s.objName]) s.isClosed = true - trace "stream destroyed", oid = s.oid + s.closeEvent.fire() + libp2p_open_streams.dec(labelValues = [s.objName]) + trace "stream destroyed", oid = s.oid, name = s.objName + + # TODO: debuging aid to troubleshoot streams open/close + # try: + # echo "ChronosStream ", libp2p_open_streams.value(labelValues = ["ChronosStream"]) + # echo "SecureConn ", libp2p_open_streams.value(labelValues = ["SecureConn"]) + # # doAssert(libp2p_open_streams.value(labelValues = ["ChronosStream"]) >= + # # libp2p_open_streams.value(labelValues = ["SecureConn"])) + # except CatchableError: + # discard diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 909a766e0..bf0463156 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -7,8 +7,17 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import tables, sequtils, options, strformat, sets -import chronos, chronicles, metrics +import tables, + sequtils, + options, + strformat, + sets, + algorithm + +import chronos, + chronicles, + metrics + import stream/connection, stream/chronosstream, transports/transport, @@ -38,13 +47,28 @@ declareCounter(libp2p_dialed_peers, "dialed peers") declareCounter(libp2p_failed_dials, "failed dials") declareCounter(libp2p_failed_upgrade, "peers failed upgrade") +const MaxConnectionsPerPeer = 5 + type - NoPubSubException = object of CatchableError + NoPubSubException* = object of CatchableError + TooManyConnections* = object of CatchableError + + Direction {.pure.} = enum + In, Out + + ConnectionHolder = object + dir: Direction + conn: Connection + + MuxerHolder = object + dir: Direction + muxer: Muxer + handle: Future[void] Switch* = ref object of RootObj peerInfo*: PeerInfo - connections*: Table[string, Connection] - muxed*: Table[string, Muxer] + connections*: Table[string, seq[ConnectionHolder]] + muxed*: Table[string, seq[MuxerHolder]] transports*: seq[Transport] protocols*: seq[LPProtocol] muxers*: Table[string, MuxerProvider] @@ -54,10 +78,84 @@ type secureManagers*: seq[Secure] pubSub*: Option[PubSub] dialedPubSubPeers: HashSet[string] + dialLock: Table[string, AsyncLock] -proc newNoPubSubException(): ref CatchableError {.inline.} = +proc newNoPubSubException(): ref NoPubSubException {.inline.} = result = newException(NoPubSubException, "no pubsub provided!") +proc newTooManyConnections(): ref TooManyConnections {.inline.} = + result = newException(TooManyConnections, "too many connections for peer") + +proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} + +proc selectConn(s: Switch, peerInfo: PeerInfo): Connection = + ## select the "best" connection according to some criteria + ## + ## Ideally when the connection's stats are available + ## we'd select the fastest, but for now we simply pick an outgoing + ## connection first if none is available, we pick the first outgoing + ## + + if isNil(peerInfo): + return + + let conns = s.connections + .getOrDefault(peerInfo.id) + # it should be OK to sort on each + # access as there should only be + # up to MaxConnectionsPerPeer entries + .sorted( + proc(a, b: ConnectionHolder): int = + if a.dir < b.dir: -1 + elif a.dir == b.dir: 0 + else: 1 + , SortOrder.Descending) + + if conns.len > 0: + return conns[0].conn + +proc selectMuxer(s: Switch, conn: Connection): Muxer = + ## select the muxer for the supplied connection + ## + + if isNil(conn): + return + + if not(isNil(conn.peerInfo)) and conn.peerInfo.id in s.muxed: + if s.muxed[conn.peerInfo.id].len > 0: + let muxers = s.muxed[conn.peerInfo.id] + .filterIt( it.muxer.connection == conn ) + if muxers.len > 0: + return muxers[0].muxer + +proc storeConn(s: Switch, + muxer: Muxer, + dir: Direction, + handle: Future[void] = nil) {.async.} = + ## store the connection and muxer + ## + if not(isNil(muxer)): + let conn = muxer.connection + if not(isNil(conn)): + let id = conn.peerInfo.id + if s.connections.getOrDefault(id).len > MaxConnectionsPerPeer: + warn "disconnecting peer, too many connections", peer = $conn.peerInfo, + conns = s.connections + .getOrDefault(id).len + await muxer.close() + await s.disconnect(conn.peerInfo) + raise newTooManyConnections() + + s.connections.mgetOrPut( + id, + newSeq[ConnectionHolder]()) + .add(ConnectionHolder(conn: conn, dir: dir)) + + s.muxed.mgetOrPut( + muxer.connection.peerInfo.id, + newSeq[MuxerHolder]()) + .add(MuxerHolder(muxer: muxer, handle: handle, dir: dir)) + proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = if s.secureManagers.len <= 0: raise newException(CatchableError, "No secure managers registered!") @@ -137,11 +235,6 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = # not end until muxer ends let handlerFut = muxer.handle() - # add muxer handler cleanup proc - handlerFut.addCallback do (udata: pointer = nil): - trace "muxer handler completed for peer", - peer = conn.peerInfo.id - try: # do identify first, so that we have a # PeerInfo in case we didn't before @@ -149,10 +242,13 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} = finally: await stream.close() # close identify stream + if isNil(conn.peerInfo): + await muxer.close() + return + # store it in muxed connections if we have a peer for it - if not isNil(conn.peerInfo): - trace "adding muxer for peer", peer = conn.peerInfo.id - s.muxed[conn.peerInfo.id] = muxer + trace "adding muxer for peer", peer = conn.peerInfo.id + await s.storeConn(muxer, Direction.Out, handlerFut) proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = try: @@ -160,55 +256,82 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} = let id = conn.peerInfo.id trace "cleaning up connection for peer", peerId = id if id in s.muxed: - await s.muxed[id].close() - s.muxed.del(id) + let muxerHolder = s.muxed[id] + .filterIt( + it.muxer.connection == conn + ) + + if muxerHolder.len > 0: + await muxerHolder[0].muxer.close() + if not(isNil(muxerHolder[0].handle)): + await muxerHolder[0].handle + + s.muxed[id].keepItIf( + it.muxer.connection != conn + ) + + if s.muxed[id].len == 0: + s.muxed.del(id) if id in s.connections: - s.connections.del(id) + s.connections[id].keepItIf( + it.conn != conn + ) - await conn.close() + if s.connections[id].len == 0: + s.connections.del(id) - s.dialedPubSubPeers.excl(id) + await conn.close() + s.dialedPubSubPeers.excl(id) - libp2p_peers.dec() # TODO: Investigate cleanupConn() always called twice for one peer. if not(conn.peerInfo.isClosed()): conn.peerInfo.close() + except CatchableError as exc: trace "exception cleaning up connection", exc = exc.msg + finally: + libp2p_peers.set(s.connections.len.int64) proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} = - let conn = s.connections.getOrDefault(peer.id) - if not isNil(conn): - trace "disconnecting peer", peer = $peer - await s.cleanupConn(conn) + let connections = s.connections.getOrDefault(peer.id) + for connHolder in connections: + if not isNil(connHolder.conn): + await s.cleanupConn(connHolder.conn) proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async, gcsafe.} = # if there is a muxer for the connection # use it instead to create a muxed stream - if peerInfo.id in s.muxed: - trace "connection is muxed, setting up a stream" - let muxer = s.muxed[peerInfo.id] - let conn = await muxer.newStream() - result = conn + + let muxer = s.selectMuxer(s.selectConn(peerInfo)) # always get the first muxer here + if not(isNil(muxer)): + return await muxer.newStream() proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = - trace "handling connection", conn = $conn - result = conn + trace "handling connection", conn = $conn, oid = conn.oid - # don't mux/secure twise - if conn.peerInfo.id in s.muxed: + let sconn = await s.secure(conn) # secure the connection + if isNil(sconn): + trace "unable to secure connection, stopping upgrade", conn = $conn, + oid = conn.oid + await conn.close() return - result = await s.secure(result) # secure the connection - if isNil(result): + await s.mux(sconn) # mux it if possible + if isNil(conn.peerInfo): + trace "unable to mux connection, stopping upgrade", conn = $conn, + oid = conn.oid + await sconn.close() return - await s.mux(result) # mux it if possible - s.connections[conn.peerInfo.id] = result + libp2p_peers.set(s.connections.len.int64) + trace "succesfully upgraded outgoing connection", conn = $conn, + oid = conn.oid, + uoid = sconn.oid + result = sconn proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = - trace "upgrading incoming connection", conn = $conn + trace "upgrading incoming connection", conn = $conn, oid = conn.oid let ms = newMultistream() # secure incoming connections @@ -216,7 +339,7 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = proto: string) {.async, gcsafe, closure.} = try: - trace "Securing connection" + trace "Securing connection", oid = conn.oid let secure = s.secureManagers.filterIt(it.codec == proto)[0] let sconn = await secure.secure(conn, false) if sconn.isNil: @@ -257,43 +380,60 @@ proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} proc internalConnect(s: Switch, peer: PeerInfo): Future[Connection] {.async.} = + + if s.peerInfo.peerId == peer.peerId: + raise newException(CatchableError, "can't dial self!") + let id = peer.id - trace "Dialing peer", peer = id - var conn = s.connections.getOrDefault(id) - if conn.isNil or conn.closed: - for t in s.transports: # for each transport - for a in peer.addrs: # for each address - if t.handles(a): # check if it can dial it - trace "Dialing address", address = $a - try: - conn = await t.dial(a) - libp2p_dialed_peers.inc() - except CatchableError as exc: - trace "dialing failed", exc = exc.msg - libp2p_failed_dials.inc() - continue + let lock = s.dialLock.mgetOrPut(id, newAsyncLock()) + var conn: Connection - # make sure to assign the peer to the connection - conn.peerInfo = peer + try: + await lock.acquire() + trace "about to dial peer", peer = id + conn = s.selectConn(peer) + if conn.isNil or conn.closed: + trace "Dialing peer", peer = id + for t in s.transports: # for each transport + for a in peer.addrs: # for each address + if t.handles(a): # check if it can dial it + trace "Dialing address", address = $a + try: + conn = await t.dial(a) + libp2p_dialed_peers.inc() + except CatchableError as exc: + trace "dialing failed", exc = exc.msg + libp2p_failed_dials.inc() + continue - conn = await s.upgradeOutgoing(conn) - if isNil(conn): - libp2p_failed_upgrade.inc() - continue + # make sure to assign the peer to the connection + conn.peerInfo = peer + conn = await s.upgradeOutgoing(conn) + if isNil(conn): + libp2p_failed_upgrade.inc() + continue - conn.closeEvent.wait() - .addCallback do(udata: pointer): - asyncCheck s.cleanupConn(conn) + conn.closeEvent.wait() + .addCallback do(udata: pointer): + asyncCheck s.cleanupConn(conn) + break + else: + trace "Reusing existing connection", oid = conn.oid + except CatchableError as exc: + trace "exception connecting to peer", exc = exc.msg + if not(isNil(conn)): + await conn.close() - libp2p_peers.inc() - break - else: - trace "Reusing existing connection" + raise exc # re-raise + finally: + if lock.locked(): + lock.release() if not isNil(conn): + doAssert(conn.peerInfo.id in s.connections, "connection not tracked!") + trace "dial succesfull", oid = conn.oid await s.subscribeToPeer(peer) - - result = conn + result = conn proc connect*(s: Switch, peer: PeerInfo) {.async.} = var conn = await s.internalConnect(peer) @@ -314,9 +454,9 @@ proc dial*(s: Switch, result = conn let stream = await s.getMuxedStream(peer) if not isNil(stream): - trace "Connection is muxed, return muxed stream" + trace "Connection is muxed, return muxed stream", oid = conn.oid result = stream - trace "Attempting to select remote", proto = proto + trace "Attempting to select remote", proto = proto, oid = conn.oid if not await s.ms.select(result, proto): raise newException(CatchableError, "Unable to select sub-protocol " & proto) @@ -338,7 +478,6 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} = try: try: - libp2p_peers.inc() await s.upgradeIncoming(conn) # perform upgrade on incoming connection finally: await s.cleanupConn(conn) @@ -358,6 +497,7 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = if s.pubSub.isSome: await s.pubSub.get().start() + info "started libp2p node", peer = $s.peerInfo, addrs = s.peerInfo.addrs result = startFuts # listen for incoming connections proc stop*(s: Switch) {.async.} = @@ -370,11 +510,12 @@ proc stop*(s: Switch) {.async.} = if s.pubSub.isSome: await s.pubSub.get().stop() - for conn in toSeq(s.connections.values): - try: - await s.cleanupConn(conn) - except CatchableError as exc: - warn "error cleaning up connections" + for conns in toSeq(s.connections.values): + for conn in conns: + try: + await s.cleanupConn(conn.conn) + except CatchableError as exc: + warn "error cleaning up connections" for t in s.transports: try: @@ -463,8 +604,8 @@ proc newSwitch*(peerInfo: PeerInfo, result.peerInfo = peerInfo result.ms = newMultistream() result.transports = transports - result.connections = initTable[string, Connection]() - result.muxed = initTable[string, Muxer]() + result.connections = initTable[string, seq[ConnectionHolder]]() + result.muxed = initTable[string, seq[MuxerHolder]]() result.identity = identity result.muxers = muxers result.secureManagers = @secureManagers @@ -494,11 +635,9 @@ proc newSwitch*(peerInfo: PeerInfo, # identify it muxer.connection.peerInfo = await s.identify(stream) - # store muxer for connection - s.muxed[muxer.connection.peerInfo.id] = muxer - - # store muxed connection - s.connections[muxer.connection.peerInfo.id] = muxer.connection + # store muxer and muxed connection + await s.storeConn(muxer, Direction.In) + libp2p_peers.set(s.connections.len.int64) muxer.connection.closeEvent.wait() .addCallback do(udata: pointer): @@ -506,6 +645,7 @@ proc newSwitch*(peerInfo: PeerInfo, # try establishing a pubsub connection await s.subscribeToPeer(muxer.connection.peerInfo) + except CatchableError as exc: libp2p_failed_upgrade.inc() trace "exception in muxer handler", exc = exc.msg diff --git a/tests/pubsub/testgossipsub.nim b/tests/pubsub/testgossipsub.nim index a6d4b927b..a72d67c3d 100644 --- a/tests/pubsub/testgossipsub.nim +++ b/tests/pubsub/testgossipsub.nim @@ -46,6 +46,7 @@ proc waitSub(sender, receiver: auto; key: string) {.async, gcsafe.} = suite "GossipSub": teardown: for tracker in testTrackers(): + # echo tracker.dump() check tracker.isLeaked() == false test "GossipSub validation should succeed": diff --git a/tests/testinterop.nim b/tests/testinterop.nim index 1648aa81a..e3dea9350 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -189,6 +189,7 @@ suite "Interop": check string.fromBytes(await stream.transp.readLp()) == "test 3" asyncDiscard stream.transp.writeLp("test 4") testFuture.complete() + await stream.close() await daemonNode.addHandler(protos, daemonHandler) let conn = await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer, @@ -240,6 +241,7 @@ suite "Interop": var line = await stream.transp.readLine() check line == expect testFuture.complete(line) + await stream.close() await daemonNode.addHandler(protos, daemonHandler) let conn = await nativeNode.dial(NativePeerInfo.init(daemonPeer.peer, @@ -285,9 +287,12 @@ suite "Interop": discard await stream.transp.writeLp(test) result = test == (await wait(testFuture, 10.secs)) + + await stream.close() await nativeNode.stop() await allFutures(awaiters) await daemonNode.close() + await sleepAsync(1.seconds) check: waitFor(runTests()) == true @@ -331,6 +336,7 @@ suite "Interop": await wait(testFuture, 10.secs) result = true + await stream.close() await nativeNode.stop() await allFutures(awaiters) await daemonNode.close() diff --git a/tests/testswitch.nim b/tests/testswitch.nim index cf9622208..e06212504 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -192,8 +192,8 @@ suite "Switch": await switch2.connect(switch1.peerInfo) - check switch1.connections.len > 0 - check switch2.connections.len > 0 + check switch1.connections[switch2.peerInfo.id].len > 0 + check switch2.connections[switch1.peerInfo.id].len > 0 await sleepAsync(100.millis) await switch2.disconnect(switch1.peerInfo) @@ -207,8 +207,8 @@ suite "Switch": # echo connTracker.dump() # check connTracker.isLeaked() == false - check switch1.connections.len == 0 - check switch2.connections.len == 0 + check switch2.peerInfo.id notin switch1.connections + check switch1.peerInfo.id notin switch2.connections await allFuturesThrowing( switch1.stop(),