diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 214d99d2b..e48b52038 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -23,10 +23,16 @@ export muxer logScope: topics = "mplex" +const + MaxChannelCount = 200 + when defined(libp2p_expensive_metrics): - declareGauge(libp2p_mplex_channels, "mplex channels", labels = ["initiator", "peer"]) + declareGauge(libp2p_mplex_channels, + "mplex channels", labels = ["initiator", "peer"]) type + TooManyChannels* = object of CatchableError + Mplex* = ref object of Muxer remote: Table[uint64, LPChannel] local: Table[uint64, LPChannel] @@ -36,6 +42,10 @@ type outChannTimeout: Duration isClosed: bool oid*: Oid + maxChannCount: int + +proc newTooManyChannels(): ref TooManyChannels = + newException(TooManyChannels, "max allowed channel count exceeded") proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] = if initiator: @@ -151,6 +161,10 @@ method handle*(m: Mplex) {.async, gcsafe.} = case msgType: of MessageType.New: let name = string.fromBytes(data) + if m.getChannelList(false).len > m.maxChannCount - 1: + warn "too many channels created by remote peer", allowedMax = MaxChannelCount + raise newTooManyChannels() + channel = await m.newStreamInternal( false, id, @@ -208,14 +222,16 @@ method handle*(m: Mplex) {.async, gcsafe.} = proc init*(M: type Mplex, conn: Connection, maxChanns: uint = MaxChannels, - inTimeout, outTimeout: Duration = DefaultChanTimeout): Mplex = + inTimeout, outTimeout: Duration = DefaultChanTimeout, + maxChannCount: int = MaxChannelCount): Mplex = M(connection: conn, maxChannels: maxChanns, inChannTimeout: inTimeout, outChannTimeout: outTimeout, remote: initTable[uint64, LPChannel](), local: initTable[uint64, LPChannel](), - oid: genOid()) + oid: genOid(), + maxChannCount: maxChannCount) method newStream*(m: Mplex, name: string = "",