diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim
index fbf1981..4d40068 100644
--- a/libp2p/multistream.nim
+++ b/libp2p/multistream.nim
@@ -37,12 +37,6 @@ type
     handlers*: seq[HandlerHolder]
     codec*: string
 
-  MultistreamHandshakeException* = object of CatchableError
-
-proc newMultistreamHandshakeException*(): ref CatchableError {.inline.} =
-  result = newException(MultistreamHandshakeException,
-    "could not perform multistream handshake")
-
 proc newMultistream*(): MultistreamSelect =
   new result
   result.codec = MSCodec
@@ -62,7 +56,7 @@ proc select*(m: MultistreamSelect,
     s.removeSuffix("\n")
     if s != Codec:
       notice "handshake failed", codec = s.toHex()
-      raise newMultistreamHandshakeException()
+      return ""
 
   if proto.len() == 0: # no protocols, must be a handshake call
     return Codec
@@ -152,8 +146,12 @@ proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} =
           return
       debug "no handlers for ", protocol = ms
       await conn.write(Na)
+  except CancelledError as exc:
+    await conn.close()
+    raise exc
  except CatchableError as exc:
    trace "exception in multistream", exc = exc.msg
+    await conn.close()
  finally:
    trace "leaving multistream loop"
 
diff --git a/libp2p/muxers/mplex/coder.nim b/libp2p/muxers/mplex/coder.nim
index 6e1520f..6164aa9 100644
--- a/libp2p/muxers/mplex/coder.nim
+++ b/libp2p/muxers/mplex/coder.nim
@@ -47,8 +47,8 @@ proc writeMsg*(conn: Connection,
                msgType: MessageType,
                data: seq[byte] = @[]) {.async, gcsafe.} =
   trace "sending data over mplex", id,
-                                  msgType,
-                                  data = data.len
+                                   msgType,
+                                   data = data.len
   var
     left = data.len
     offset = 0
diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim
index 2486572..771a4e3 100644
--- a/libp2p/muxers/mplex/lpchannel.nim
+++ b/libp2p/muxers/mplex/lpchannel.nim
@@ -15,7 +15,8 @@ import types,
        ../../stream/connection,
        ../../stream/bufferstream,
        ../../utility,
-       ../../errors
+       ../../errors,
+       ../../peerinfo
 
 export connection
 
@@ -90,87 +91,104 @@ proc newChannel*(id: uint64,
                  name: string = "",
                  size: int = DefaultBufferSize,
                  lazy: bool = false): LPChannel =
-  new result
-  result.id = id
-  result.name = name
-  result.conn = conn
-  result.initiator = initiator
-  result.msgCode = if initiator: MessageType.MsgOut else: MessageType.MsgIn
-  result.closeCode = if initiator: MessageType.CloseOut else: MessageType.CloseIn
-  result.resetCode = if initiator: MessageType.ResetOut else: MessageType.ResetIn
-  result.isLazy = lazy
+  result = LPChannel(id: id,
+                     name: name,
+                     conn: conn,
+                     initiator: initiator,
+                     msgCode: if initiator: MessageType.MsgOut else: MessageType.MsgIn,
+                     closeCode: if initiator: MessageType.CloseOut else: MessageType.CloseIn,
+                     resetCode: if initiator: MessageType.ResetOut else: MessageType.ResetIn,
+                     isLazy: lazy)
 
   let chan = result
+  logScope:
+    id = chan.id
+    initiator = chan.initiator
+    name = chan.name
+    oid = $chan.oid
+    peer = $chan.conn.peerInfo
+    # stack = getStackTrace()
+
   proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} =
     try:
       if chan.isLazy and not(chan.isOpen):
         await chan.open()
 
       # writes should happen in sequence
-      trace "sending data", data = data.shortLog,
-                            id = chan.id,
-                            initiator = chan.initiator,
-                            name = chan.name,
-                            oid = chan.oid
+      trace "sending data"
 
-      try:
-        await conn.writeMsg(chan.id,
-                            chan.msgCode,
-                            data).wait(2.minutes) # write header
-      except AsyncTimeoutError:
-        trace "timeout writing channel, resetting"
-        asyncCheck chan.reset()
+      await conn.writeMsg(chan.id,
+                          chan.msgCode,
+                          data).wait(2.minutes) # write header
     except CatchableError as exc:
-      trace "unable to write in bufferstream handler", exc = exc.msg
trace "unable to write in bufferstream handler", exc = exc.msg + trace "exception in lpchannel write handler", exc = exc.msg + await chan.reset() + raise exc result.initBufferStream(writeHandler, size) when chronicles.enabledLogLevel == LogLevel.TRACE: result.name = if result.name.len > 0: result.name else: $result.oid - trace "created new lpchannel", id = result.id, - oid = result.oid, - initiator = result.initiator, - name = result.name + trace "created new lpchannel" proc closeMessage(s: LPChannel) {.async.} = + logScope: + id = s.id + initiator = s.initiator + name = s.name + oid = $s.oid + peer = $s.conn.peerInfo + # stack = getStackTrace() + ## send close message - this will not raise ## on EOF or Closed - withEOFExceptions: - withWriteLock(s.writeLock): - trace "sending close message", id = s.id, - initiator = s.initiator, - name = s.name, - oid = s.oid + withWriteLock(s.writeLock): + trace "sending close message" - await s.conn.writeMsg(s.id, s.closeCode) # write close + await s.conn.writeMsg(s.id, s.closeCode) # write close proc resetMessage(s: LPChannel) {.async.} = + logScope: + id = s.id + initiator = s.initiator + name = s.name + oid = $s.oid + peer = $s.conn.peerInfo + # stack = getStackTrace() + ## send reset message - this will not raise withEOFExceptions: withWriteLock(s.writeLock): - trace "sending reset message", id = s.id, - initiator = s.initiator, - name = s.name, - oid = s.oid + trace "sending reset message" await s.conn.writeMsg(s.id, s.resetCode) # write reset proc open*(s: LPChannel) {.async, gcsafe.} = + logScope: + id = s.id + initiator = s.initiator + name = s.name + oid = $s.oid + peer = $s.conn.peerInfo + # stack = getStackTrace() + ## NOTE: Don't call withExcAndLock or withWriteLock, ## because this already gets called from writeHandler ## which is locked - withEOFExceptions: - await s.conn.writeMsg(s.id, MessageType.New, s.name) - trace "opened channel", oid = s.oid, - name = s.name, - initiator = s.initiator - s.isOpen = true + await s.conn.writeMsg(s.id, MessageType.New, s.name) + trace "opened channel" + s.isOpen = true proc closeRemote*(s: LPChannel) {.async.} = - trace "got EOF, closing channel", id = s.id, - initiator = s.initiator, - name = s.name, - oid = s.oid + logScope: + id = s.id + initiator = s.initiator + name = s.name + oid = $s.oid + peer = $s.conn.peerInfo + # stack = getStackTrace() + + trace "got EOF, closing channel" # wait for all data in the buffer to be consumed while s.len > 0: @@ -181,11 +199,7 @@ proc closeRemote*(s: LPChannel) {.async.} = await s.close() # close local end # call to avoid leaks await procCall BufferStream(s).close() # close parent bufferstream - - trace "channel closed on EOF", id = s.id, - initiator = s.initiator, - oid = s.oid, - name = s.name + trace "channel closed on EOF" method closed*(s: LPChannel): bool = ## this emulates half-closed behavior @@ -195,6 +209,20 @@ method closed*(s: LPChannel): bool = s.closedLocal method reset*(s: LPChannel) {.base, async, gcsafe.} = + logScope: + id = s.id + initiator = s.initiator + name = s.name + oid = $s.oid + peer = $s.conn.peerInfo + # stack = getStackTrace() + + trace "resetting channel" + + if s.closedLocal and s.isEof: + trace "channel already closed or reset" + return + # we asyncCheck here because the other end # might be dead already - reset is always # optimistic @@ -203,33 +231,36 @@ method reset*(s: LPChannel) {.base, async, gcsafe.} = s.isEof = true s.closedLocal = true + trace "channel reset" + method close*(s: LPChannel) {.async, gcsafe.} = + logScope: + 
+    initiator = s.initiator
+    name = s.name
+    oid = $s.oid
+    peer = $s.conn.peerInfo
+    # stack = getStackTrace()
+
   if s.closedLocal:
-    trace "channel already closed", id = s.id,
-                                    initiator = s.initiator,
-                                    name = s.name,
-                                    oid = s.oid
+    trace "channel already closed"
     return
 
-  proc closeRemote() {.async.} =
+  trace "closing local lpchannel"
+
+  proc closeInternal() {.async.} =
     try:
-      trace "closing local lpchannel", id = s.id,
-                                       initiator = s.initiator,
-                                       name = s.name,
-                                       oid = s.oid
       await s.closeMessage().wait(2.minutes)
       if s.atEof: # already closed by remote close parent buffer immediately
         await procCall BufferStream(s).close()
-    except AsyncTimeoutError:
-      trace "close timed out, reset channel"
-      asyncCheck s.reset() # reset on timeout
+    except CancelledError as exc:
+      await s.reset() # reset on cancellation
+      raise exc
    except CatchableError as exc:
      trace "exception closing channel"
+      await s.reset() # reset on error
 
-    trace "lpchannel closed local", id = s.id,
-                                    initiator = s.initiator,
-                                    name = s.name,
-                                    oid = s.oid
+    trace "lpchannel closed local"
 
   s.closedLocal = true
-  asyncCheck closeRemote()
+  asyncCheck closeInternal()
diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim
index eee7a56..6966641 100644
--- a/libp2p/muxers/mplex/mplex.nim
+++ b/libp2p/muxers/mplex/mplex.nim
@@ -8,12 +8,13 @@
 ## those terms.
 
 import tables, sequtils, oids
-import chronos, chronicles, stew/byteutils
+import chronos, chronicles, stew/byteutils, metrics
 
 import ../muxer,
        ../../stream/connection,
       ../../stream/bufferstream,
       ../../utility,
       ../../errors,
+       ../../peerinfo,
       coder,
       types,
       lpchannel
@@ -21,6 +22,8 @@ import ../muxer,
 logScope:
   topics = "mplex"
 
+declareGauge(libp2p_mplex_channels, "mplex channels", labels = ["initiator", "peer"])
+
 type
   Mplex* = ref object of Muxer
     remote: Table[uint64, LPChannel]
@@ -33,10 +36,10 @@ type
 proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] =
   if initiator:
-    trace "picking local channels", initiator = initiator, oid = m.oid
+    trace "picking local channels", initiator = initiator, oid = $m.oid
     result = m.local
   else:
-    trace "picking remote channels", initiator = initiator, oid = m.oid
+    trace "picking remote channels", initiator = initiator, oid = $m.oid
     result = m.remote
 
 proc newStreamInternal*(m: Mplex,
@@ -46,6 +49,7 @@ proc newStreamInternal*(m: Mplex,
                         lazy: bool = false):
                         Future[LPChannel] {.async, gcsafe.} =
   ## create new channel/stream
+  ##
   let id = if initiator:
     m.currentId.inc(); m.currentId
     else: chanId
@@ -53,7 +57,7 @@ proc newStreamInternal*(m: Mplex,
   trace "creating new channel", channelId = id,
                                 initiator = initiator,
                                 name = name,
-                                oid = m.oid
+                                oid = $m.oid
   result = newChannel(id,
                       m.connection,
                       initiator,
@@ -63,96 +67,128 @@ proc newStreamInternal*(m: Mplex,
   result.peerInfo = m.connection.peerInfo
   result.observedAddr = m.connection.observedAddr
 
-  m.getChannelList(initiator)[id] = result
+  doAssert(id notin m.getChannelList(initiator),
+           "channel slot already taken!")
 
-proc handleStream(m: Muxer, chann: LPChannel) {.async.} =
+  m.getChannelList(initiator)[id] = result
+  libp2p_mplex_channels.set(
+    m.getChannelList(initiator).len.int64,
+    labelValues = [$initiator,
+                   $m.connection.peerInfo])
+
+proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} =
+  ## remove the local channel from the internal tables
+  ##
+  await chann.closeEvent.wait()
+  if not isNil(chann):
+    m.getChannelList(chann.initiator).del(chann.id)
+    trace "cleaned up channel", id = chann.id
+
+    libp2p_mplex_channels.set(
+      m.getChannelList(chann.initiator).len.int64,
+      labelValues = [$chann.initiator,
+                     $m.connection.peerInfo])
+
+proc handleStream(m: Mplex, chann: LPChannel) {.async.} =
+  ## call the muxer stream handler for this channel
+  ##
   try:
     await m.streamHandler(chann)
     trace "finished handling stream"
     doAssert(chann.closed, "connection not closed by handler!")
+  except CancelledError as exc:
+    trace "canceling stream handler", exc = exc.msg
+    await chann.reset()
+    raise
   except CatchableError as exc:
     trace "exception in stream handler", exc = exc.msg
     await chann.reset()
+    await m.cleanupChann(chann)
 
 method handle*(m: Mplex) {.async, gcsafe.} =
-  trace "starting mplex main loop", oid = m.oid
+  trace "starting mplex main loop", oid = $m.oid
   try:
-    try:
-      while not m.connection.closed:
-        trace "waiting for data", oid = m.oid
-        let (id, msgType, data) = await m.connection.readMsg()
-        trace "read message from connection", id = id,
-                                              msgType = msgType,
-                                              data = data.shortLog,
-                                              oid = m.oid
-
-        let initiator = bool(ord(msgType) and 1)
-        var channel: LPChannel
-        if MessageType(msgType) != MessageType.New:
-          let channels = m.getChannelList(initiator)
-          if id notin channels:
-
-            trace "Channel not found, skipping", id = id,
-                                                 initiator = initiator,
-                                                 msg = msgType,
-                                                 oid = m.oid
-            continue
-          channel = channels[id]
-
-        logScope:
-          id = id
-          initiator = initiator
-          msgType = msgType
-          size = data.len
-          muxer_oid = m.oid
-
-        case msgType:
-          of MessageType.New:
-            let name = string.fromBytes(data)
-            channel = await m.newStreamInternal(false, id, name)
-
-            trace "created channel", name = channel.name,
-                                     oid = channel.oid
-
-            if not isNil(m.streamHandler):
-              # launch handler task
-              asyncCheck m.handleStream(channel)
-
-          of MessageType.MsgIn, MessageType.MsgOut:
-            logScope:
-              name = channel.name
-              oid = channel.oid
-
-            trace "pushing data to channel"
-
-            if data.len > MaxMsgSize:
-              raise newLPStreamLimitError()
-            await channel.pushTo(data)
-          of MessageType.CloseIn, MessageType.CloseOut:
-            logScope:
-              name = channel.name
-              oid = channel.oid
-
-            trace "closing channel"
-
-            await channel.closeRemote()
-            m.getChannelList(initiator).del(id)
-            trace "deleted channel"
-          of MessageType.ResetIn, MessageType.ResetOut:
-            logScope:
-              name = channel.name
-              oid = channel.oid
-
-            trace "resetting channel"
-
-            await channel.reset()
-            m.getChannelList(initiator).del(id)
-            trace "deleted channel"
-    finally:
-      trace "stopping mplex main loop", oid = m.oid
+    defer:
+      trace "stopping mplex main loop", oid = $m.oid
       await m.close()
+
+    while not m.connection.closed:
+      trace "waiting for data", oid = $m.oid
+      let (id, msgType, data) = await m.connection.readMsg()
+      trace "read message from connection", id = id,
+                                            msgType = msgType,
+                                            data = data.shortLog,
+                                            oid = $m.oid
+
+      let initiator = bool(ord(msgType) and 1)
+      var channel: LPChannel
+      if MessageType(msgType) != MessageType.New:
+        let channels = m.getChannelList(initiator)
+        if id notin channels:
+
+          trace "Channel not found, skipping", id = id,
+                                               initiator = initiator,
+                                               msg = msgType,
+                                               oid = $m.oid
+          continue
+        channel = channels[id]
+
+      logScope:
+        id = id
+        initiator = initiator
+        msgType = msgType
+        size = data.len
+        muxer_oid = $m.oid
+
+      case msgType:
+        of MessageType.New:
+          let name = string.fromBytes(data)
+          channel = await m.newStreamInternal(false, id, name)
+
+          trace "created channel", name = channel.name,
+                                   oid = $channel.oid
+
+          if not isNil(m.streamHandler):
+            # launch handler task
+            asyncCheck m.handleStream(channel)
+
+        of MessageType.MsgIn, MessageType.MsgOut:
+          logScope:
+            name = channel.name
+            oid = $channel.oid
+
+          trace "pushing data to channel"
+
+          if data.len > MaxMsgSize:
+            raise newLPStreamLimitError()
+          await channel.pushTo(data)
+
+        of MessageType.CloseIn, MessageType.CloseOut:
+          logScope:
+            name = channel.name
+            oid = $channel.oid
+
+          trace "closing channel"
+
+          await channel.closeRemote()
+          await m.cleanupChann(channel)
+
+          trace "deleted channel"
+        of MessageType.ResetIn, MessageType.ResetOut:
+          logScope:
+            name = channel.name
+            oid = $channel.oid
+
+          trace "resetting channel"
+
+          await channel.reset()
+          await m.cleanupChann(channel)
+
+          trace "deleted channel"
+  except CancelledError as exc:
+    raise exc
   except CatchableError as exc:
-    trace "Exception occurred", exception = exc.msg, oid = m.oid
+    trace "Exception occurred", exception = exc.msg, oid = $m.oid
 
 proc newMplex*(conn: Connection,
                maxChanns: uint = MaxChannels): Mplex =
@@ -165,14 +201,6 @@ proc newMplex*(conn: Connection,
   when chronicles.enabledLogLevel == LogLevel.TRACE:
     result.oid = genOid()
 
-proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} =
-  ## remove the local channel from the internal tables
-  ##
-  await chann.closeEvent.wait()
-  if not isNil(chann):
-    m.getChannelList(true).del(chann.id)
-    trace "cleaned up channel", id = chann.id
-
 method newStream*(m: Mplex,
                   name: string = "",
                   lazy: bool = false): Future[Connection] {.async, gcsafe.} =
@@ -187,19 +215,17 @@ method close*(m: Mplex) {.async, gcsafe.} =
   if m.isClosed:
     return
 
-  try:
-    trace "closing mplex muxer", oid = m.oid
-    let channs = toSeq(m.remote.values) &
-                 toSeq(m.local.values)
-
-    for chann in channs:
-      try:
-        await chann.reset()
-      except CatchableError as exc:
-        warn "error resetting channel", exc = exc.msg
-
-    await m.connection.close()
-  finally:
+  defer:
     m.remote.clear()
     m.local.clear()
     m.isClosed = true
+
+  trace "closing mplex muxer", oid = $m.oid
+  let channs = toSeq(m.remote.values) &
+               toSeq(m.local.values)
+
+  for chann in channs:
+    await chann.reset()
+    await m.cleanupChann(chann)
+
+  await m.connection.close()
diff --git a/libp2p/muxers/muxer.nim b/libp2p/muxers/muxer.nim
index 999188b..2d61160 100644
--- a/libp2p/muxers/muxer.nim
+++ b/libp2p/muxers/muxer.nim
@@ -63,8 +63,12 @@ method init(c: MuxerProvider) =
       futs &= c.muxerHandler(muxer)
 
       checkFutures(await allFinished(futs))
+    except CancelledError as exc:
+      raise exc
    except CatchableError as exc:
      trace "exception in muxer handler", exc = exc.msg, peer = $conn, proto=proto
+    finally:
+      await conn.close()
 
   c.handler = handler
 
diff --git a/libp2p/peerinfo.nim b/libp2p/peerinfo.nim
index 31c2a6b..bfe5048 100644
--- a/libp2p/peerinfo.nim
+++ b/libp2p/peerinfo.nim
@@ -38,7 +38,8 @@ type
     key: Option[PublicKey]
 
 proc id*(p: PeerInfo): string =
-  p.peerId.pretty()
+  if not(isNil(p)):
+    return p.peerId.pretty()
 
 proc `$`*(p: PeerInfo): string = p.id
 
diff --git a/libp2p/protocols/identify.nim b/libp2p/protocols/identify.nim
index 3c25f09..b2d3c50 100644
--- a/libp2p/protocols/identify.nim
+++ b/libp2p/protocols/identify.nim
@@ -113,13 +113,15 @@ proc newIdentify*(peerInfo: PeerInfo): Identify =
 method init*(p: Identify) =
   proc handle(conn: Connection, proto: string) {.async, gcsafe, closure.} =
     try:
-      try:
-        trace "handling identify request", oid = conn.oid
-        var pb = encodeMsg(p.peerInfo, conn.observedAddr)
-        await conn.writeLp(pb.buffer)
-      finally:
+      defer:
        trace "exiting identify handler", oid = conn.oid
        await conn.close()
+
+      trace "handling identify request", oid = conn.oid
+      var pb = encodeMsg(p.peerInfo, conn.observedAddr)
+      await conn.writeLp(pb.buffer)
+    except CancelledError as exc:
+      raise exc
    except CatchableError as exc:
      trace "exception in identify handler", exc = exc.msg
 
diff --git a/libp2p/protocols/pubsub/gossipsub.nim b/libp2p/protocols/pubsub/gossipsub.nim
index 028a6ce..64e8a82 100644
--- a/libp2p/protocols/pubsub/gossipsub.nim
+++ b/libp2p/protocols/pubsub/gossipsub.nim
@@ -52,7 +52,7 @@ type
     gossip*: Table[string, seq[ControlIHave]] # pending gossip
     control*: Table[string, ControlMessage]   # pending control messages
     mcache*: MCache                           # messages cache
-    heartbeatFut: Future[void]                 # cancellation future for heartbeat interval
+    heartbeatFut: Future[void]                # cancellation future for heartbeat interval
     heartbeatRunning: bool
     heartbeatLock: AsyncLock                  # heartbeat lock to prevent two consecutive concurrent heartbeats
 
@@ -159,6 +159,8 @@ proc rebalanceMesh(g: GossipSub, topic: string) {.async.} =
     trace "mesh balanced, got peers", peers = g.mesh.getOrDefault(topic).len,
                                       topicId = topic
+  except CancelledError as exc:
+    raise exc
  except CatchableError as exc:
    trace "exception occurred re-balancing mesh", exc = exc.msg
 
@@ -227,12 +229,10 @@ proc heartbeat(g: GossipSub) {.async.} =
       checkFutures(await allFinished(sent))
 
       g.mcache.shift() # shift the cache
+    except CancelledError as exc:
+      raise exc
    except CatchableError as exc:
      trace "exception ocurred in gossipsub heartbeat", exc = exc.msg
-      # sleep less in the case of an error
-      # but still throttle
-      await sleepAsync(100.millis)
-      continue
 
     await sleepAsync(1.seconds)
 
diff --git a/libp2p/protocols/pubsub/pubsub.nim b/libp2p/protocols/pubsub/pubsub.nim
index 8a0bbab..da1c24a 100644
--- a/libp2p/protocols/pubsub/pubsub.nim
+++ b/libp2p/protocols/pubsub/pubsub.nim
@@ -104,13 +104,13 @@ method handleDisconnect*(p: PubSub, peer: PubSubPeer) {.async, base.} =
     p.peers.del(peer.id)
 
     # metrics
-    libp2p_pubsub_peers.dec()
+    libp2p_pubsub_peers.set(p.peers.len.int64)
 
 proc cleanUpHelper(p: PubSub, peer: PubSubPeer) {.async.} =
   try:
     await p.cleanupLock.acquire()
     peer.refs.dec() # decrement refcount
-    if peer.refs == 0:
+    if peer.refs <= 0:
       await p.handleDisconnect(peer)
   finally:
     p.cleanupLock.release()
@@ -119,24 +119,23 @@ proc getPeer(p: PubSub,
              peerInfo: PeerInfo,
              proto: string): PubSubPeer =
   if peerInfo.id in p.peers:
-    result = p.peers[peerInfo.id]
-    return
+    return p.peers[peerInfo.id]
 
   # create new pubsub peer
   let peer = newPubSubPeer(peerInfo, proto)
   trace "created new pubsub peer", peerId = peer.id
 
   # metrics
-  libp2p_pubsub_peers.inc()
   p.peers[peer.id] = peer
   peer.refs.inc # increment reference count
   peer.observers = p.observers
-  result = peer
+  libp2p_pubsub_peers.set(p.peers.len.int64)
+  return peer
 
 proc internalCleanup(p: PubSub, conn: Connection) {.async.} =
   # handle connection close
-  if conn.closed:
+  if isNil(conn):
     return
 
   var peer = p.getPeer(conn.peerInfo, p.codec)
@@ -168,6 +167,7 @@ method handleConn*(p: PubSub,
       # call pubsub rpc handler
       await p.rpcHandler(peer, msgs)
 
+    asyncCheck p.internalCleanup(conn)
    let peer = p.getPeer(conn.peerInfo, proto)
    let topics = toSeq(p.topics.keys)
    if topics.len > 0:
@@ -176,18 +176,27 @@ method handleConn*(p: PubSub,
     peer.handler = handler
     await peer.handle(conn) # spawn peer read loop
     trace "pubsub peer handler ended, cleaning up"
-    await p.internalCleanup(conn)
+  except CancelledError as exc:
+    await conn.close()
+    raise exc
  except CatchableError as exc:
    trace "exception ocurred in pubsub handle", exc = exc.msg
+    await conn.close()
 
 method subscribeToPeer*(p: PubSub,
                         conn: Connection) {.base, async.} =
-  var peer = p.getPeer(conn.peerInfo, p.codec)
-  trace "setting connection for peer", peerId = conn.peerInfo.id
-  if not peer.isConnected:
-    peer.conn = conn
+  if not(isNil(conn)):
+    let peer = p.getPeer(conn.peerInfo, p.codec)
+    trace "setting connection for peer", peerId = conn.peerInfo.id
+    if not peer.connected:
+      peer.conn = conn
 
-  asyncCheck p.internalCleanup(conn)
+    asyncCheck p.internalCleanup(conn)
+
+proc connected*(p: PubSub, peer: PeerInfo): bool =
+  let peer = p.getPeer(peer, p.codec)
+  if not(isNil(peer)):
+    return peer.connected
 
 method unsubscribe*(p: PubSub,
                     topics: seq[TopicPair]) {.base, async.} =
@@ -309,7 +318,8 @@ proc newPubSub*(P: typedesc[PubSub],
                        msgIdProvider: msgIdProvider)
   result.initPubSub()
 
-proc addObserver*(p: PubSub; observer: PubSubObserver) = p.observers[] &= observer
+proc addObserver*(p: PubSub; observer: PubSubObserver) =
+  p.observers[] &= observer
 
 proc removeObserver*(p: PubSub; observer: PubSubObserver) =
   let idx = p.observers[].find(observer)
diff --git a/libp2p/protocols/pubsub/pubsubpeer.nim b/libp2p/protocols/pubsub/pubsubpeer.nim
index 3fb4377..1ea9200 100644
--- a/libp2p/protocols/pubsub/pubsubpeer.nim
+++ b/libp2p/protocols/pubsub/pubsubpeer.nim
@@ -47,8 +47,8 @@ type
 proc id*(p: PubSubPeer): string = p.peerInfo.id
 
-proc isConnected*(p: PubSubPeer): bool =
-  (not isNil(p.sendConn))
+proc connected*(p: PubSubPeer): bool =
+  not(isNil(p.sendConn))
 
 proc `conn=`*(p: PubSubPeer, conn: Connection) =
   if not(isNil(conn)):
@@ -126,8 +126,10 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
     try:
       trace "about to send message", peer = p.id,
                                      encoded = digest
-      await p.onConnect.wait()
-      if p.isConnected: # this can happen if the remote disconnected
+      if not p.onConnect.isSet:
+        await p.onConnect.wait()
+
+      if p.connected: # this can happen if the remote disconnected
        trace "sending encoded msgs to peer", peer = p.id,
                                              encoded = encoded.buffer.shortLog
        await p.sendConn.writeLp(encoded.buffer)
@@ -139,10 +141,14 @@ proc send*(p: PubSubPeer, msgs: seq[RPCMsg]) {.async.} =
 
           # metrics
           libp2p_pubsub_sent_messages.inc(labelValues = [p.id, t])
+    except CancelledError as exc:
+      raise exc
    except CatchableError as exc:
      trace "unable to send to remote", exc = exc.msg
-      p.sendConn = nil
-      p.onConnect.clear()
+      if not(isNil(p.sendConn)):
+        await p.sendConn.close()
+        p.sendConn = nil
+        p.onConnect.clear()
 
   # if no connection has been set,
   # queue messages until a connection
diff --git a/libp2p/protocols/secure/secure.nim b/libp2p/protocols/secure/secure.nim
index baeaece..3aa354f 100644
--- a/libp2p/protocols/secure/secure.nim
+++ b/libp2p/protocols/secure/secure.nim
@@ -72,6 +72,10 @@ method init*(s: Secure) {.gcsafe.} =
       # We don't need the result but we definitely need to await the handshake
       discard await s.handleConn(conn, false)
       trace "connection secured"
+    except CancelledError as exc:
+      warn "securing connection canceled"
+      await conn.close()
+      raise
    except CatchableError as exc:
      warn "securing connection failed", msg = exc.msg
      await conn.close()
@@ -79,31 +83,20 @@ method init*(s: Secure) {.gcsafe.} =
   s.handler = handle
 
 method secure*(s: Secure, conn: Connection, initiator: bool): Future[Connection] {.async, base, gcsafe.} =
-  try:
-    result = await s.handleConn(conn, initiator)
-  except CancelledError as exc:
-    raise exc
-  except CatchableError as exc:
-    warn "securing connection failed", msg = exc.msg
-    return nil
+  result = await s.handleConn(conn, initiator)
 
 method readOnce*(s: SecureConn,
                  pbytes: pointer,
                  nbytes: int):
                  Future[int] {.async, gcsafe.} =
-  try:
-    if nbytes == 0:
-      return 0
+  if nbytes == 0:
+    return 0
 
-    if s.buf.data().len() == 0:
-      let buf = await s.readMessage()
-      if buf.len == 0:
-        raise newLPStreamIncompleteError()
-      s.buf.add(buf)
+  if s.buf.data().len() == 0:
+    let buf = await s.readMessage()
+    if buf.len == 0:
+      raise newLPStreamIncompleteError()
+    s.buf.add(buf)
 
-    var p = cast[ptr UncheckedArray[byte]](pbytes)
-    return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
-  except CatchableError as exc:
-    trace "exception reading from secure connection", exc = exc.msg, oid = s.oid
-    await s.close() # make sure to close the wrapped connection
-    raise exc
+  var p = cast[ptr UncheckedArray[byte]](pbytes)
+  return s.buf.consumeTo(toOpenArray(p, 0, nbytes - 1))
diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim
index e879766..c15fa7b 100644
--- a/libp2p/stream/bufferstream.nim
+++ b/libp2p/stream/bufferstream.nim
@@ -128,19 +128,19 @@ proc initBufferStream*(s: BufferStream,
 
   if not(isNil(handler)):
     s.writeHandler = proc (data: seq[byte]) {.async, gcsafe.} =
-      try:
-        # Using a lock here to guarantee
-        # proper write ordering. This is
-        # specially important when
-        # implementing half-closed in mplex
-        # or other functionality that requires
-        # strict message ordering
-        await s.writeLock.acquire()
-        await handler(data)
-      finally:
+      defer:
        s.writeLock.release()
 
-  trace "created bufferstream", oid = s.oid
+      # Using a lock here to guarantee
+      # proper write ordering. This is
+      # especially important when
+      # implementing half-closed in mplex
+      # or other functionality that requires
+      # strict message ordering
+      await s.writeLock.acquire()
+      await handler(data)
+
+  trace "created bufferstream", oid = $s.oid
 
 proc newBufferStream*(handler: WriteHandler = nil,
                       size: int = DefaultBufferSize): BufferStream =
@@ -173,31 +173,31 @@ method pushTo*(s: BufferStream, data: seq[byte]) {.base, async.} =
   if s.atEof:
     raise newLPStreamEOFError()
 
-  try:
-    await s.lock.acquire()
-    var index = 0
-    while not s.closed():
-      while index < data.len and s.readBuf.len < s.maxSize:
-        s.readBuf.addLast(data[index])
-        inc(index)
-      # trace "pushTo()", msg = "added " & $s.len & " bytes to readBuf", oid = s.oid
-
-      # resolve the next queued read request
-      if s.readReqs.len > 0:
-        s.readReqs.popFirst().complete()
-        # trace "pushTo(): completed a readReqs future", oid = s.oid
-
-      if index >= data.len:
-        return
-
-      # if we couldn't transfer all the data to the
-      # internal buf wait on a read event
-      await s.dataReadEvent.wait()
-      s.dataReadEvent.clear()
-  finally:
+  defer:
    # trace "ended", size = s.len
    s.lock.release()
 
+  await s.lock.acquire()
+  var index = 0
+  while not s.closed():
+    while index < data.len and s.readBuf.len < s.maxSize:
+      s.readBuf.addLast(data[index])
+      inc(index)
+    # trace "pushTo()", msg = "added " & $s.len & " bytes to readBuf", oid = s.oid
+
+    # resolve the next queued read request
+    if s.readReqs.len > 0:
+      s.readReqs.popFirst().complete()
+      # trace "pushTo(): completed a readReqs future", oid = s.oid
+
+    if index >= data.len:
+      return
+
+    # if we couldn't transfer all the data to the
+    # internal buf wait on a read event
+    await s.dataReadEvent.wait()
+    s.dataReadEvent.clear()
+
 method readOnce*(s: BufferStream,
                  pbytes: pointer,
                  nbytes: int):
@@ -290,8 +290,10 @@ method close*(s: BufferStream) {.async, gcsafe.} =
       await procCall Connection(s).close()
       inc getBufferStreamTracker().closed
-      trace "bufferstream closed", oid = s.oid
+      trace "bufferstream closed", oid = $s.oid
    else:
      trace "attempt to close an already closed bufferstream", trace = getStackTrace()
bufferstream", trace = getStackTrace() + except CancelledError as exc: + raise except CatchableError as exc: trace "error closing buffer stream", exc = exc.msg diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 410d3ed..f7a4eda 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -74,7 +74,7 @@ method initStream*(s: LPStream) {.base.} = s.oid = genOid() libp2p_open_streams.inc(labelValues = [s.objName]) - trace "stream created", oid = s.oid, name = s.objName + trace "stream created", oid = $s.oid, name = s.objName # TODO: debuging aid to troubleshoot streams open/close # try: @@ -150,7 +150,6 @@ proc readVarint*(conn: LPStream): Future[uint64] {.async, gcsafe.} = for i in 0.. MaxConnectionsPerPeer: - warn "disconnecting peer, too many connections", peer = $conn.peerInfo, - conns = s.connections - .getOrDefault(id).len - await muxer.close() - await s.disconnect(conn.peerInfo) - raise newTooManyConnections() + if isNil(muxer): + return - s.connections.mgetOrPut( - id, - newSeq[ConnectionHolder]()) - .add(ConnectionHolder(conn: conn, dir: dir)) + let conn = muxer.connection + if isNil(conn): + return - s.muxed.mgetOrPut( - muxer.connection.peerInfo.id, - newSeq[MuxerHolder]()) - .add(MuxerHolder(muxer: muxer, handle: handle, dir: dir)) + let id = conn.peerInfo.id + if s.connections.getOrDefault(id).len > MaxConnectionsPerPeer: + warn "disconnecting peer, too many connections", peer = $conn.peerInfo, + conns = s.connections + .getOrDefault(id).len + await s.disconnect(conn.peerInfo) + raise newTooManyConnections() + + s.connections.mgetOrPut( + id, + newSeq[ConnectionHolder]()) + .add(ConnectionHolder(conn: conn, dir: dir)) + + s.muxed.mgetOrPut( + muxer.connection.peerInfo.id, + newSeq[MuxerHolder]()) + .add(MuxerHolder(muxer: muxer, handle: handle, dir: dir)) proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = if s.secureManagers.len <= 0: @@ -164,50 +167,41 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} = if manager.len == 0: raise newException(CatchableError, "Unable to negotiate a secure channel!") - trace "securing connection", codec=manager + trace "securing connection", codec = manager let secureProtocol = s.secureManagers.filterIt(it.codec == manager) # ms.select should deal with the correctness of this # let's avoid duplicating checks but detect if it fails to do it properly doAssert(secureProtocol.len > 0) result = await secureProtocol[0].secure(conn, true) -proc identify(s: Switch, conn: Connection): Future[PeerInfo] {.async, gcsafe.} = +proc identify(s: Switch, conn: Connection) {.async, gcsafe.} = ## identify the connection - if not isNil(conn.peerInfo): - result = conn.peerInfo + if (await s.ms.select(conn, s.identity.codec)): + let info = await s.identity.identify(conn, conn.peerInfo) - try: - if (await s.ms.select(conn, s.identity.codec)): - let info = await s.identity.identify(conn, conn.peerInfo) + if info.pubKey.isNone and isNil(conn): + raise newException(CatchableError, + "no public key provided and no existing peer identity found") - if info.pubKey.isNone and isNil(result): - raise newException(CatchableError, - "no public key provided and no existing peer identity found") + if isNil(conn.peerInfo): + conn.peerInfo = PeerInfo.init(info.pubKey.get()) - if info.pubKey.isSome: - result = PeerInfo.init(info.pubKey.get()) - trace "identify: identified remote peer", peer = result.id + if info.addrs.len > 0: + conn.peerInfo.addrs = info.addrs - if info.addrs.len > 0: - 
+    if info.agentVersion.isSome:
+      conn.peerInfo.agentVersion = info.agentVersion.get()
 
-      if info.agentVersion.isSome:
-        result.agentVersion = info.agentVersion.get()
+    if info.protoVersion.isSome:
+      conn.peerInfo.protoVersion = info.protoVersion.get()
 
-      if info.protoVersion.isSome:
-        result.protoVersion = info.protoVersion.get()
+    if info.protos.len > 0:
+      conn.peerInfo.protocols = info.protos
 
-      if info.protos.len > 0:
-        result.protocols = info.protos
+    trace "identify: identified remote peer", peer = $conn.peerInfo
 
-    trace "identify", info = shortLog(result)
-  except IdentityInvalidMsgError as exc:
-    debug "identify: invalid message", msg = exc.msg
-  except IdentityNoMatchError as exc:
-    debug "identify: peer's public keys don't match ", msg = exc.msg
-
-proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
+proc mux(s: Switch, conn: Connection) {.async, gcsafe.} =
   ## mux incoming connection
 
   trace "muxing connection", peer = $conn
@@ -224,48 +218,61 @@ proc mux(s: Switch, conn: Connection): Future[void] {.async, gcsafe.} =
 
   # create new muxer for connection
   let muxer = s.muxers[muxerName].newMuxer(conn)
-  trace "found a muxer", name=muxerName, peer = $conn
+  trace "found a muxer", name = muxerName, peer = $conn
 
   # install stream handler
   muxer.streamHandler = s.streamHandler
 
   # new stream for identify
   var stream = await muxer.newStream()
+  var handlerFut: Future[void]
+
+  defer:
+    if not(isNil(stream)):
+      await stream.close() # close identify stream
 
   # call muxer handler, this should
   # not end until muxer ends
-  let handlerFut = muxer.handle()
+  handlerFut = muxer.handle()
 
-  try:
-    # do identify first, so that we have a
-    # PeerInfo in case we didn't before
-    conn.peerInfo = await s.identify(stream)
-  finally:
-    await stream.close() # close identify stream
+  # do identify first, so that we have a
+  # PeerInfo in case we didn't before
+  await s.identify(stream)
 
   if isNil(conn.peerInfo):
     await muxer.close()
-    return
+    raise newException(CatchableError,
+      "unable to identify peer, aborting upgrade")
 
   # store it in muxed connections if we have a peer for it
   trace "adding muxer for peer", peer = conn.peerInfo.id
   await s.storeConn(muxer, Direction.Out, handlerFut)
 
 proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
-  try:
-    if not isNil(conn.peerInfo):
-      let id = conn.peerInfo.id
-      trace "cleaning up connection for peer", peerId = id
+  if isNil(conn):
+    return
+
+  defer:
+    await conn.close()
+    libp2p_peers.set(s.connections.len.int64)
+
+  if isNil(conn.peerInfo):
+    return
+
+  let id = conn.peerInfo.id
+  trace "cleaning up connection for peer", peerId = id
+  if id in s.muxed:
+    let muxerHolder = s.muxed[id]
+      .filterIt(
+        it.muxer.connection == conn
+      )
+
+    if muxerHolder.len > 0:
+      await muxerHolder[0].muxer.close()
+      if not(isNil(muxerHolder[0].handle)):
+        await muxerHolder[0].handle
+
   if id in s.muxed:
-        let muxerHolder = s.muxed[id]
-          .filterIt(
-            it.muxer.connection == conn
-          )
-
-        if muxerHolder.len > 0:
-          await muxerHolder[0].muxer.close()
-          if not(isNil(muxerHolder[0].handle)):
-            await muxerHolder[0].handle
-
    s.muxed[id].keepItIf(
      it.muxer.connection != conn
    )
@@ -273,25 +280,17 @@ proc cleanupConn(s: Switch, conn: Connection) {.async, gcsafe.} =
     if s.muxed[id].len == 0:
       s.muxed.del(id)
 
-      if id in s.connections:
-        s.connections[id].keepItIf(
-          it.conn != conn
-        )
+  if id in s.connections:
+    s.connections[id].keepItIf(
+      it.conn != conn
+    )
 
-        if s.connections[id].len == 0:
-          s.connections.del(id)
+    if s.connections[id].len == 0:
+      s.connections.del(id)
-      await conn.close()
-      s.dialedPubSubPeers.excl(id)
-
-      # TODO: Investigate cleanupConn() always called twice for one peer.
-      if not(conn.peerInfo.isClosed()):
-        conn.peerInfo.close()
-
-  except CatchableError as exc:
-    trace "exception cleaning up connection", exc = exc.msg
-  finally:
-    libp2p_peers.set(s.connections.len.int64)
+  # TODO: Investigate cleanupConn() always called twice for one peer.
+  if not(conn.peerInfo.isClosed()):
+    conn.peerInfo.close()
 
 proc disconnect*(s: Switch, peer: PeerInfo) {.async, gcsafe.} =
   let connections = s.connections.getOrDefault(peer.id)
@@ -308,27 +307,25 @@ proc getMuxedStream(s: Switch, peerInfo: PeerInfo): Future[Connection] {.async,
     return await muxer.newStream()
 
 proc upgradeOutgoing(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
-  trace "handling connection", conn = $conn, oid = conn.oid
+  logScope:
+    conn = $conn
+    oid = $conn.oid
 
   let sconn = await s.secure(conn) # secure the connection
   if isNil(sconn):
-    trace "unable to secure connection, stopping upgrade", conn = $conn,
-                                                           oid = conn.oid
-    await conn.close()
-    return
+    raise newException(CatchableError,
+      "unable to secure connection, stopping upgrade")
 
+  trace "upgrading connection"
   await s.mux(sconn) # mux it if possible
-  if isNil(conn.peerInfo):
-    trace "unable to mux connection, stopping upgrade", conn = $conn,
-                                                        oid = conn.oid
+  if isNil(sconn.peerInfo):
    await sconn.close()
-    return
+    raise newException(CatchableError,
+      "unable to mux connection, stopping upgrade")
 
   libp2p_peers.set(s.connections.len.int64)
-  trace "succesfully upgraded outgoing connection", conn = $conn,
-                                                    oid = conn.oid,
-                                                    uoid = sconn.oid
-  result = sconn
+  trace "successfully upgraded outgoing connection", uoid = sconn.oid
+  return sconn
 
 proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
   trace "upgrading incoming connection", conn = $conn, oid = conn.oid
@@ -338,45 +335,38 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
   proc securedHandler (conn: Connection,
                        proto: string)
                        {.async, gcsafe, closure.} =
+
+    var sconn: Connection
+    trace "Securing connection", oid = conn.oid
+    let secure = s.secureManagers.filterIt(it.codec == proto)[0]
+
     try:
-      trace "Securing connection", oid = conn.oid
-      let secure = s.secureManagers.filterIt(it.codec == proto)[0]
-      let sconn = await secure.secure(conn, false)
-      if sconn.isNil:
+      sconn = await secure.secure(conn, false)
+      if isNil(sconn):
        return
 
+      defer:
+        await sconn.close()
+
      # add the muxer
      for muxer in s.muxers.values:
        ms.addHandler(muxer.codec, muxer)
 
      # handle subsequent requests
-      try:
-        await ms.handle(sconn)
-      finally:
-        await sconn.close()
+      await ms.handle(sconn)
 
    except CancelledError as exc:
      raise exc
    except CatchableError as exc:
      debug "ending secured handler", err = exc.msg
 
-  try:
-    try:
-      if (await ms.select(conn)): # just handshake
-        # add the secure handlers
-        for k in s.secureManagers:
-          ms.addHandler(k.codec, securedHandler)
+  if (await ms.select(conn)): # just handshake
+    # add the secure handlers
+    for k in s.secureManagers:
+      ms.addHandler(k.codec, securedHandler)
 
-        # handle secured connections
-        await ms.handle(conn)
-    finally:
-      await conn.close()
-  except CancelledError as exc:
-    raise exc
-  except CatchableError as exc:
-    trace "error in multistream", err = exc.msg
-
-proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.}
+    # handle secured connections
+    await ms.handle(conn)
 
 proc internalConnect(s: Switch,
                      peer: PeerInfo): Future[Connection] {.async.} =
@@ -388,79 +378,88 @@ proc internalConnect(s: Switch,
   let lock = s.dialLock.mgetOrPut(id, newAsyncLock())
   var conn: Connection
 
-  try:
-    await lock.acquire()
-    trace "about to dial peer", peer = id
-    conn = s.selectConn(peer)
-    if conn.isNil or conn.closed:
-      trace "Dialing peer", peer = id
-      for t in s.transports: # for each transport
-        for a in peer.addrs: # for each address
-          if t.handles(a): # check if it can dial it
-            trace "Dialing address", address = $a
-            try:
-              conn = await t.dial(a)
-              libp2p_dialed_peers.inc()
-            except CatchableError as exc:
-              trace "dialing failed", exc = exc.msg
-              libp2p_failed_dials.inc()
-              continue
-
-            # make sure to assign the peer to the connection
-            conn.peerInfo = peer
-            conn = await s.upgradeOutgoing(conn)
-            if isNil(conn):
-              libp2p_failed_upgrade.inc()
-              continue
-
-            conn.closeEvent.wait()
-              .addCallback do(udata: pointer):
-                asyncCheck s.cleanupConn(conn)
-            break
-    else:
-      trace "Reusing existing connection", oid = conn.oid
-  except CatchableError as exc:
-    trace "exception connecting to peer", exc = exc.msg
-    if not(isNil(conn)):
-      await conn.close()
-
-    raise exc # re-raise
-  finally:
+  defer:
    if lock.locked():
      lock.release()
 
-  if not isNil(conn):
-    doAssert(conn.peerInfo.id in s.connections, "connection not tracked!")
-    trace "dial succesfull", oid = conn.oid
-    await s.subscribeToPeer(peer)
-    result = conn
+  await lock.acquire()
+  trace "about to dial peer", peer = id
+  conn = s.selectConn(peer)
+  if conn.isNil or conn.closed:
+    trace "Dialing peer", peer = id
+    for t in s.transports: # for each transport
+      for a in peer.addrs: # for each address
+        if t.handles(a): # check if it can dial it
+          trace "Dialing address", address = $a
+          try:
+            conn = await t.dial(a)
+            libp2p_dialed_peers.inc()
+          except CancelledError as exc:
+            trace "dialing canceled", exc = exc.msg
+            raise
+          except CatchableError as exc:
+            trace "dialing failed", exc = exc.msg
+            libp2p_failed_dials.inc()
+            continue
+
+          # make sure to assign the peer to the connection
+          conn.peerInfo = peer
+          try:
+            conn = await s.upgradeOutgoing(conn)
+          except CatchableError as exc:
+            if not(isNil(conn)):
+              await conn.close()
+
+            trace "Unable to establish outgoing link", exc = exc.msg
+            raise exc
+
+          if isNil(conn):
+            libp2p_failed_upgrade.inc()
+            continue
+
+          conn.closeEvent.wait()
+            .addCallback do(udata: pointer):
+              asyncCheck s.cleanupConn(conn)
+          break
+  else:
+    trace "Reusing existing connection", oid = conn.oid
+
+  if isNil(conn):
+    raise newException(CatchableError,
+      "Unable to establish outgoing link")
+
+  if conn.closed or conn.atEof:
+    await conn.close()
+    raise newException(CatchableError,
+      "Connection dead on arrival")
+
+  doAssert(conn.peerInfo.id in s.connections,
+           "connection not tracked!")
+
+  trace "dial successful", oid = conn.oid
+  await s.subscribeToPeer(peer)
+  return conn
 
 proc connect*(s: Switch, peer: PeerInfo) {.async.} =
   var conn = await s.internalConnect(peer)
-  if isNil(conn):
-    raise newException(CatchableError, "Unable to connect to peer")
 
 proc dial*(s: Switch,
            peer: PeerInfo,
           proto: string):
           Future[Connection] {.async.} =
   var conn = await s.internalConnect(peer)
-  if isNil(conn):
-    raise newException(CatchableError, "Unable to establish outgoing link")
-
-  if conn.closed:
-    raise newException(CatchableError, "Connection dead on arrival")
-
-  result = conn
   let stream = await s.getMuxedStream(peer)
-  if not isNil(stream):
-    trace "Connection is muxed, return muxed stream", oid = conn.oid
-    result = stream
-    trace "Attempting to select remote", proto = proto, oid = conn.oid
+  if isNil(stream):
+    await conn.close()
+    raise newException(CatchableError, "Couldn't get muxed stream")
 
-  if not await s.ms.select(result, proto):
+  trace "Attempting to select remote", proto = proto, oid = conn.oid
+  if not await s.ms.select(stream, proto):
+    await stream.close()
    raise newException(CatchableError,
      "Unable to select sub-protocol " & proto)
 
+  return stream
+
 proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} =
   if isNil(proto.handler):
     raise newException(CatchableError,
@@ -477,10 +476,10 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
 
   proc handle(conn: Connection): Future[void] {.async, closure, gcsafe.} =
     try:
-      try:
-        await s.upgradeIncoming(conn) # perform upgrade on incoming connection
-      finally:
+      defer:
        await s.cleanupConn(conn)
+
+      await s.upgradeIncoming(conn) # perform upgrade on incoming connection
    except CancelledError as exc:
      raise exc
    except CatchableError as exc:
@@ -501,52 +500,61 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
   result = startFuts # listen for incoming connections
 
 proc stop*(s: Switch) {.async.} =
-  try:
-    trace "stopping switch"
+  trace "stopping switch"
 
-    # we want to report errors but we do not want to fail
-    # or crash here, cos we need to clean possibly MANY items
-    # and any following conn/transport won't be cleaned up
-    if s.pubSub.isSome:
-      await s.pubSub.get().stop()
+  # we want to report errors but we do not want to fail
+  # or crash here, cos we need to clean possibly MANY items
+  # and any following conn/transport won't be cleaned up
+  if s.pubSub.isSome:
+    await s.pubSub.get().stop()
 
-    for conns in toSeq(s.connections.values):
-      for conn in conns:
-        try:
-          await s.cleanupConn(conn.conn)
-        except CatchableError as exc:
-          warn "error cleaning up connections"
-
-    for t in s.transports:
+  for conns in toSeq(s.connections.values):
+    for conn in conns:
      try:
-        await t.close()
+        await s.cleanupConn(conn.conn)
+      except CancelledError as exc:
+        raise exc
      except CatchableError as exc:
-        warn "error cleaning up transports"
+        warn "error cleaning up connections"
 
-    trace "switch stopped"
-  except CatchableError as exc:
-    warn "error stopping switch", exc = exc.msg
+  for t in s.transports:
+    try:
+      await t.close()
+    except CancelledError as exc:
+      raise exc
+    except CatchableError as exc:
+      warn "error cleaning up transports"
+
+  trace "switch stopped"
 
 proc subscribeToPeer*(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
-  trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog()
   ## Subscribe to pub sub peer
-  if s.pubSub.isSome and (peerInfo.id notin s.dialedPubSubPeers):
-    let conn = await s.getMuxedStream(peerInfo)
-    if isNil(conn):
+  if s.pubSub.isSome and not(s.pubSub.get().connected(peerInfo)):
+    trace "about to subscribe to pubsub peer", peer = peerInfo.shortLog()
+    var stream: Connection
+    try:
+      stream = await s.getMuxedStream(peerInfo)
+    except CancelledError as exc:
+      if not(isNil(stream)):
+        await stream.close()
+
+      raise exc
+    except CatchableError as exc:
+      trace "exception in subscribe to peer", peer = peerInfo.shortLog,
+                                              exc = exc.msg
+      if not(isNil(stream)):
+        await stream.close()
+
+    if isNil(stream):
      trace "unable to subscribe to peer", peer = peerInfo.shortLog
      return
 
-    s.dialedPubSubPeers.incl(peerInfo.id)
-    try:
-      if (await s.ms.select(conn, s.pubSub.get().codec)):
-        await s.pubSub.get().subscribeToPeer(conn)
-      else:
-        await conn.close()
-    except CatchableError as exc:
-      trace "exception in subscribe to peer", peer = peerInfo.shortLog, exc = exc.msg
-      await conn.close()
-    finally:
-      s.dialedPubSubPeers.excl(peerInfo.id)
+    if not await s.ms.select(stream, s.pubSub.get().codec):
+      if not(isNil(stream)):
+        await stream.close()
+      return
+
+    await s.pubSub.get().subscribeToPeer(stream)
 
 proc subscribe*(s: Switch, topic: string,
                 handler: TopicHandler): Future[void] =
@@ -594,6 +602,43 @@ proc removeValidator*(s: Switch,
 
   s.pubSub.get().removeValidator(topics, hook)
 
+proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
+  var stream = await muxer.newStream()
+  defer:
+    if not(isNil(stream)):
+      await stream.close()
+
+  trace "got new muxer"
+
+  try:
+    # once we got a muxed connection, attempt to
+    # identify it
+    await s.identify(stream)
+    if isNil(stream.peerInfo):
+      await muxer.close()
+      return
+
+    muxer.connection.peerInfo = stream.peerInfo
+
+    # store muxer and muxed connection
+    await s.storeConn(muxer, Direction.In)
+    libp2p_peers.set(s.connections.len.int64)
+
+    muxer.connection.closeEvent.wait()
+      .addCallback do(udata: pointer):
+        asyncCheck s.cleanupConn(muxer.connection)
+
+    # try establishing a pubsub connection
+    await s.subscribeToPeer(muxer.connection.peerInfo)
+
+  except CancelledError as exc:
+    await muxer.close()
+    raise exc
+  except CatchableError as exc:
+    await muxer.close()
+    libp2p_failed_upgrade.inc()
+    trace "exception in muxer handler", exc = exc.msg
+
 proc newSwitch*(peerInfo: PeerInfo,
                 transports: seq[Transport],
                 identity: Identify,
@@ -609,49 +654,25 @@ proc newSwitch*(peerInfo: PeerInfo,
   result.identity = identity
   result.muxers = muxers
   result.secureManagers = @secureManagers
-  result.dialedPubSubPeers = initHashSet[string]()
 
   let s = result # can't capture result
   result.streamHandler = proc(stream: Connection) {.async, gcsafe.} =
     try:
       trace "handling connection for", peerInfo = $stream
-      try:
-        await s.ms.handle(stream) # handle incoming connection
-      finally:
-        if not(stream.closed):
+      defer:
+        if not(isNil(stream)):
          await stream.close()
+
+      await s.ms.handle(stream) # handle incoming connection
+    except CancelledError as exc:
+      raise exc
    except CatchableError as exc:
      trace "exception in stream handler", exc = exc.msg
 
   result.mount(identity)
 
   for key, val in muxers:
     val.streamHandler = result.streamHandler
-    val.muxerHandler = proc(muxer: Muxer) {.async, gcsafe.} =
-      var stream: Connection
-      try:
-        trace "got new muxer"
-        stream = await muxer.newStream()
-        # once we got a muxed connection, attempt to
-        # identify it
-        muxer.connection.peerInfo = await s.identify(stream)
-
-        # store muxer and muxed connection
-        await s.storeConn(muxer, Direction.In)
-        libp2p_peers.set(s.connections.len.int64)
-
-        muxer.connection.closeEvent.wait()
-          .addCallback do(udata: pointer):
-            asyncCheck s.cleanupConn(muxer.connection)
-
-        # try establishing a pubsub connection
-        await s.subscribeToPeer(muxer.connection.peerInfo)
-
-      except CatchableError as exc:
-        libp2p_failed_upgrade.inc()
-        trace "exception in muxer handler", exc = exc.msg
-      finally:
-        if not(isNil(stream)):
-          await stream.close()
+    val.muxerHandler = proc(muxer: Muxer): Future[void] =
+      s.muxerHandler(muxer)
 
   if result.secureManagers.len <= 0:
     # use plain text if no secure managers are provided
diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim
index a832924..6edb510 100644
--- a/libp2p/transports/tcptransport.nim
+++ b/libp2p/transports/tcptransport.nim
@@ -97,14 +97,7 @@ proc connCb(server: StreamServer,
     raise exc
   except CatchableError as err:
     debug "Connection setup failed", err = err.msg
-    if not client.closed:
-      try:
-        client.close()
-      except CancelledError as err:
-        raise err
-      except CatchableError as err:
-        # shouldn't happen but..
-        warn "Error closing connection", err = err.msg
+    client.close()
 
 proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T =
   result = T(flags: flags)
diff --git a/tests/testmplex.nim b/tests/testmplex.nim
index 613f1e3..987d8a7 100644
--- a/tests/testmplex.nim
+++ b/tests/testmplex.nim
@@ -243,7 +243,7 @@ suite "Mplex":
       await done.wait(1.seconds)
       await conn.close()
-      await mplexDialFut
+      await mplexDialFut.wait(1.seconds)
       await allFuturesThrowing(
         transport1.close(),
         transport2.close())
diff --git a/tests/testnoise.nim b/tests/testnoise.nim
index d3a7769..b3f21b4 100644
--- a/tests/testnoise.nim
+++ b/tests/testnoise.nim
@@ -71,8 +71,8 @@ proc createSwitch(ma: MultiAddress; outgoing: bool): (Switch, PeerInfo) =
 suite "Noise":
   teardown:
     for tracker in testTrackers():
-      echo tracker.dump()
-      # check tracker.isLeaked() == false
+      # echo tracker.dump()
+      check tracker.isLeaked() == false
 
   test "e2e: handle write + noise":
     proc testListenerDialer(): Future[bool] {.async.} =
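The recurring convention this patch applies across handlers — re-raise CancelledError so cancellation propagates to the caller, handle CatchableError locally, and run cleanup in a defer so it fires on success, failure, and cancellation alike — reduced to a minimal self-contained chronos sketch. DummyConn and work are invented stand-ins for libp2p's Connection type and real I/O; this is an illustration of the pattern, not code from the patch:

import chronos

type DummyConn = ref object
  closed: bool

proc close(c: DummyConn) {.async.} =
  # stand-in for Connection.close()
  c.closed = true

proc work(c: DummyConn) {.async.} =
  # stand-in for any awaited I/O that may raise or be canceled
  await sleepAsync(10.millis)

proc handle(c: DummyConn) {.async.} =
  try:
    defer:
      await c.close() # cleanup runs on success, error and cancellation
    await work(c)
  except CancelledError as exc:
    raise exc         # never swallow cancellation
  except CatchableError as exc:
    echo "exception in handler: ", exc.msg

when isMainModule:
  waitFor handle(DummyConn())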