adding channel limits to mplex (#309)

This commit is contained in:
Dmitriy Ryajov 2020-08-04 23:16:04 -06:00 committed by GitHub
parent 145657895f
commit 74a6dccd80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 3 deletions

View File

@ -23,10 +23,16 @@ export muxer
logScope: logScope:
topics = "mplex" topics = "mplex"
const
MaxChannelCount = 200
when defined(libp2p_expensive_metrics): when defined(libp2p_expensive_metrics):
declareGauge(libp2p_mplex_channels, "mplex channels", labels = ["initiator", "peer"]) declareGauge(libp2p_mplex_channels,
"mplex channels", labels = ["initiator", "peer"])
type type
TooManyChannels* = object of CatchableError
Mplex* = ref object of Muxer Mplex* = ref object of Muxer
remote: Table[uint64, LPChannel] remote: Table[uint64, LPChannel]
local: Table[uint64, LPChannel] local: Table[uint64, LPChannel]
@ -36,6 +42,10 @@ type
outChannTimeout: Duration outChannTimeout: Duration
isClosed: bool isClosed: bool
oid*: Oid oid*: Oid
maxChannCount: int
proc newTooManyChannels(): ref TooManyChannels =
newException(TooManyChannels, "max allowed channel count exceeded")
proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] = proc getChannelList(m: Mplex, initiator: bool): var Table[uint64, LPChannel] =
if initiator: if initiator:
@ -151,6 +161,10 @@ method handle*(m: Mplex) {.async, gcsafe.} =
case msgType: case msgType:
of MessageType.New: of MessageType.New:
let name = string.fromBytes(data) let name = string.fromBytes(data)
if m.getChannelList(false).len > m.maxChannCount - 1:
warn "too many channels created by remote peer", allowedMax = MaxChannelCount
raise newTooManyChannels()
channel = await m.newStreamInternal( channel = await m.newStreamInternal(
false, false,
id, id,
@ -208,14 +222,16 @@ method handle*(m: Mplex) {.async, gcsafe.} =
proc init*(M: type Mplex, proc init*(M: type Mplex,
conn: Connection, conn: Connection,
maxChanns: uint = MaxChannels, maxChanns: uint = MaxChannels,
inTimeout, outTimeout: Duration = DefaultChanTimeout): Mplex = inTimeout, outTimeout: Duration = DefaultChanTimeout,
maxChannCount: int = MaxChannelCount): Mplex =
M(connection: conn, M(connection: conn,
maxChannels: maxChanns, maxChannels: maxChanns,
inChannTimeout: inTimeout, inChannTimeout: inTimeout,
outChannTimeout: outTimeout, outChannTimeout: outTimeout,
remote: initTable[uint64, LPChannel](), remote: initTable[uint64, LPChannel](),
local: initTable[uint64, LPChannel](), local: initTable[uint64, LPChannel](),
oid: genOid()) oid: genOid(),
maxChannCount: maxChannCount)
method newStream*(m: Mplex, method newStream*(m: Mplex,
name: string = "", name: string = "",