diff --git a/libp2p/muxers/mplex/coder.nim b/libp2p/muxers/mplex/coder.nim index 2f05e711f..3b7ecbe4d 100644 --- a/libp2p/muxers/mplex/coder.nim +++ b/libp2p/muxers/mplex/coder.nim @@ -46,7 +46,8 @@ proc writeMsg*(conn: Connection, id: uint64, msgType: MessageType, data: seq[byte] = @[]) {.async, gcsafe.} = - trace "sending data over mplex", id, + trace "sending data over mplex", oid = $conn.oid, + id, msgType, data = data.len var @@ -55,15 +56,14 @@ proc writeMsg*(conn: Connection, while left > 0 or data.len == 0: let chunkSize = if left > MaxMsgSize: MaxMsgSize - 64 else: left - chunk = if chunkSize > 0 : data[offset..(offset + chunkSize - 1)] else: data ## write length prefixed var buf = initVBuffer() buf.writePBVarint(id shl 3 or ord(msgType).uint64) - buf.writePBVarint(chunkSize.uint64) # size should be always sent + buf.writeSeq(data.toOpenArray(offset, offset + chunkSize - 1)) buf.finish() left = left - chunkSize offset = offset + chunkSize - await conn.write(buf.buffer & chunk) + await conn.write(buf.buffer) if data.len == 0: return diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 27e653e44..48eadb97d 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -268,7 +268,7 @@ proc init*( await chann.open() # writes should happen in sequence - trace "sending data" + trace "sending data", len = data.len await conn.writeMsg(chann.id, chann.msgCode, diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index e6b519d04..5915f0042 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -34,10 +34,8 @@ type TooManyChannels* = object of CatchableError Mplex* = ref object of Muxer - remote: Table[uint64, LPChannel] - local: Table[uint64, LPChannel] - currentId*: uint64 - maxChannels*: uint64 + channels: array[bool, Table[uint64, LPChannel]] + currentId: uint64 inChannTimeout: Duration outChannTimeout: Duration isClosed: bool @@ -47,13 +45,17 @@ type proc newTooManyChannels(): ref TooManyChannels = newException(TooManyChannels, "max allowed channel count exceeded") -proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] = - if initiator: - trace "picking local channels", initiator = initiator, oid = $m.oid - result = m.local - else: - trace "picking remote channels", initiator = initiator, oid = $m.oid - result = m.remote +proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} = + ## remove the local channel from the internal tables + ## + await chann.join() + m.channels[chann.initiator].del(chann.id) + trace "cleaned up channel", id = chann.id, oid = $chann.oid + + when defined(libp2p_expensive_metrics): + libp2p_mplex_channels.set( + m.channels[chann.initiator].len.int64, + labelValues = [$chann.initiator, $m.connection.peerInfo]) proc newStreamInternal*(m: Mplex, initiator: bool = true, @@ -61,7 +63,7 @@ proc newStreamInternal*(m: Mplex, name: string = "", lazy: bool = false, timeout: Duration): - Future[LPChannel] {.async, gcsafe.} = + LPChannel {.gcsafe.} = ## create new channel/stream ## let id = if initiator: @@ -83,29 +85,17 @@ proc newStreamInternal*(m: Mplex, result.peerInfo = m.connection.peerInfo result.observedAddr = m.connection.observedAddr - doAssert(id notin m.getChannelList(initiator), + doAssert(id notin m.channels[initiator], "channel slot already taken!") - m.getChannelList(initiator)[id] = result - when defined(libp2p_expensive_metrics): - libp2p_mplex_channels.set( - m.getChannelList(initiator).len.int64, - labelValues = [$initiator, - $m.connection.peerInfo]) + m.channels[initiator][id] = result -proc cleanupChann(m: Mplex, chann: LPChannel) {.async, inline.} = - ## remove the local channel from the internal tables - ## - await chann.join() - if not isNil(chann): - m.getChannelList(chann.initiator).del(chann.id) - trace "cleaned up channel", id = chann.id + asyncCheck m.cleanupChann(result) when defined(libp2p_expensive_metrics): libp2p_mplex_channels.set( - m.getChannelList(chann.initiator).len.int64, - labelValues = [$chann.initiator, - $m.connection.peerInfo]) + m.channels[initiator].len.int64, + labelValues = [$initiator, $m.connection.peerInfo]) proc handleStream(m: Mplex, chann: LPChannel) {.async.} = ## call the muxer stream handler for this channel @@ -121,99 +111,75 @@ proc handleStream(m: Mplex, chann: LPChannel) {.async.} = except CatchableError as exc: trace "exception in stream handler", exc = exc.msg await chann.reset() - await m.cleanupChann(chann) method handle*(m: Mplex) {.async, gcsafe.} = - trace "starting mplex main loop", oid = $m.oid + logScope: moid = $m.oid + + trace "starting mplex main loop" try: defer: - trace "stopping mplex main loop", oid = $m.oid + trace "stopping mplex main loop" await m.close() while not m.connection.atEof: - trace "waiting for data", oid = $m.oid - let (id, msgType, data) = await m.connection.readMsg() - trace "read message from connection", id = id, - msgType = msgType, - data = data.shortLog, - oid = $m.oid - - let initiator = bool(ord(msgType) and 1) - var channel: LPChannel - if MessageType(msgType) != MessageType.New: - let channels = m.getChannelList(initiator) - if id notin channels: - - trace "Channel not found, skipping", id = id, - initiator = initiator, - msg = msgType, - oid = $m.oid - continue - channel = channels[id] + trace "waiting for data" + let + (id, msgType, data) = await m.connection.readMsg() + initiator = bool(ord(msgType) and 1) logScope: id = id initiator = initiator msgType = msgType size = data.len - muxer_oid = $m.oid - case msgType: - of MessageType.New: - let name = string.fromBytes(data) - if m.getChannelList(false).len > m.maxChannCount - 1: + trace "read message from connection", data = data.shortLog + + var channel = + if MessageType(msgType) != MessageType.New: + let tmp = m.channels[initiator].getOrDefault(id, nil) + if tmp == nil: + trace "Channel not found, skipping" + continue + + tmp + else: + if m.channels[false].len > m.maxChannCount - 1: warn "too many channels created by remote peer", allowedMax = MaxChannelCount raise newTooManyChannels() - channel = await m.newStreamInternal( - false, - id, - name, - timeout = m.outChannTimeout) + let name = string.fromBytes(data) + m.newStreamInternal(false, id, name, timeout = m.outChannTimeout) - trace "created channel", name = channel.name, - oid = $channel.oid + logScope: + name = channel.name + oid = $channel.oid + + case msgType: + of MessageType.New: + trace "created channel" if not isNil(m.streamHandler): # launch handler task asyncCheck m.handleStream(channel) of MessageType.MsgIn, MessageType.MsgOut: - logScope: - name = channel.name - oid = $channel.oid - - trace "pushing data to channel" - if data.len > MaxMsgSize: - warn "attempting to send a packet larger than allowed", allowed = MaxMsgSize, - sending = data.len + warn "attempting to send a packet larger than allowed", allowed = MaxMsgSize raise newLPStreamLimitError() + trace "pushing data to channel" await channel.pushTo(data) + trace "pushed data to channel" of MessageType.CloseIn, MessageType.CloseOut: - logScope: - name = channel.name - oid = $channel.oid - trace "closing channel" - await channel.closeRemote() - await m.cleanupChann(channel) - - trace "deleted channel" + trace "closed channel" of MessageType.ResetIn, MessageType.ResetOut: - logScope: - name = channel.name - oid = $channel.oid - trace "resetting channel" - await channel.reset() - await m.cleanupChann(channel) - - trace "deleted channel" + trace "reset channel" except CancelledError as exc: raise exc except CatchableError as exc: @@ -221,45 +187,41 @@ method handle*(m: Mplex) {.async, gcsafe.} = proc init*(M: type Mplex, conn: Connection, - maxChanns: uint = MaxChannels, inTimeout, outTimeout: Duration = DefaultChanTimeout, maxChannCount: int = MaxChannelCount): Mplex = M(connection: conn, - maxChannels: maxChanns, inChannTimeout: inTimeout, outChannTimeout: outTimeout, - remote: initTable[uint64, LPChannel](), - local: initTable[uint64, LPChannel](), oid: genOid(), maxChannCount: maxChannCount) method newStream*(m: Mplex, name: string = "", lazy: bool = false): Future[Connection] {.async, gcsafe.} = - let channel = await m.newStreamInternal( + let channel = m.newStreamInternal( lazy = lazy, timeout = m.inChannTimeout) if not lazy: await channel.open() - asyncCheck m.cleanupChann(channel) return Connection(channel) method close*(m: Mplex) {.async, gcsafe.} = if m.isClosed: return - defer: - m.remote.clear() - m.local.clear() - m.isClosed = true + trace "closing mplex muxer", moid = $m.oid - trace "closing mplex muxer", oid = $m.oid - let channs = toSeq(m.remote.values) & - toSeq(m.local.values) + m.isClosed = true + + let channs = toSeq(m.channels[false].values) & toSeq(m.channels[true].values) for chann in channs: await chann.reset() - await m.cleanupChann(chann) await m.connection.close() + + # TODO while we're resetting, new channels may be created that will not be + # closed properly + m.channels[false].clear() + m.channels[true].clear() diff --git a/libp2p/muxers/mplex/types.nim b/libp2p/muxers/mplex/types.nim index 86709effc..680dd10fb 100644 --- a/libp2p/muxers/mplex/types.nim +++ b/libp2p/muxers/mplex/types.nim @@ -11,7 +11,6 @@ import chronos # https://github.com/libp2p/specs/tree/master/mplex#writing-to-a-stream const MaxMsgSize* = 1 shl 20 # 1mb -const MaxChannels* = 1000 const MplexCodec* = "/mplex/6.7.0" const MaxReadWriteTime* = 5.seconds diff --git a/libp2p/stream/bufferstream.nim b/libp2p/stream/bufferstream.nim index 00547d24b..6fcaab52a 100644 --- a/libp2p/stream/bufferstream.nim +++ b/libp2p/stream/bufferstream.nim @@ -204,7 +204,7 @@ proc drainBuffer*(s: BufferStream) {.async.} = ## wait for all data in the buffer to be consumed ## - trace "draining buffer", len = s.len + trace "draining buffer", len = s.len, oid = $s.oid while s.len > 0: await s.dataReadEvent.wait() s.dataReadEvent.clear() @@ -306,7 +306,8 @@ method close*(s: BufferStream) {.async, gcsafe.} = inc getBufferStreamTracker().closed trace "bufferstream closed", oid = $s.oid else: - trace "attempt to close an already closed bufferstream", trace = getStackTrace() + trace "attempt to close an already closed bufferstream", + trace = getStackTrace(), oid = $s.oid except CancelledError as exc: raise exc except CatchableError as exc: