Limit number of streams per protocol per peer (#811)

This commit is contained in:
Tanguy 2022-12-01 12:20:40 +01:00 committed by GitHub
parent 31ad4ae205
commit 64cbbe1e0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 127 additions and 14 deletions

View File

@ -32,7 +32,7 @@ proc new(T: typedesc[TestProto]): T =
# We must close the connections ourselves when we're done with it # We must close the connections ourselves when we're done with it
await conn.close() await conn.close()
return T(codecs: @[TestCodec], handler: handle) return T.new(codecs = @[TestCodec], handler = handle)
## This is a constructor for our `TestProto`, that will specify our `codecs` and a `handler`, which will be called for each incoming peer asking for this protocol. ## This is a constructor for our `TestProto`, that will specify our `codecs` and a `handler`, which will be called for each incoming peer asking for this protocol.
## In our handle, we simply read a message from the connection and `echo` it. ## In our handle, we simply read a message from the connection and `echo` it.

View File

@ -107,7 +107,7 @@ type
metricGetter: MetricCallback metricGetter: MetricCallback
proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto = proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
let res = MetricProto(metricGetter: cb) var res: MetricProto
proc handle(conn: Connection, proto: string) {.async, gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
let let
metrics = await res.metricGetter() metrics = await res.metricGetter()
@ -115,8 +115,8 @@ proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
await conn.writeLp(asProtobuf.buffer) await conn.writeLp(asProtobuf.buffer)
await conn.close() await conn.close()
res.codecs = @["/metric-getter/1.0.0"] res = MetricProto.new(@["/metric-getter/1.0.0"], handle)
res.handler = handle res.metricGetter = cb
return res return res
proc fetch(p: MetricProto, conn: Connection): Future[MetricList] {.async.} = proc fetch(p: MetricProto, conn: Connection): Future[MetricList] {.async.} =

View File

@ -36,7 +36,7 @@ proc new(T: typedesc[DumbProto], nodeNumber: int): T =
proc handle(conn: Connection, proto: string) {.async, gcsafe.} = proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
echo "Node", nodeNumber, " received: ", string.fromBytes(await conn.readLp(1024)) echo "Node", nodeNumber, " received: ", string.fromBytes(await conn.readLp(1024))
await conn.close() await conn.close()
return T(codecs: @[DumbCodec], handler: handle) return T.new(codecs = @[DumbCodec], handler = handle)
## ## Bootnodes ## ## Bootnodes
## The first time a p2p program is ran, he needs to know how to join ## The first time a p2p program is ran, he needs to know how to join

View File

@ -157,7 +157,7 @@ proc new(T: typedesc[GameProto], g: Game): T =
# The handler of a protocol must wait for the stream to # The handler of a protocol must wait for the stream to
# be finished before returning # be finished before returning
await conn.join() await conn.join()
return T(codecs: @["/tron/1.0.0"], handler: handle) return T.new(codecs = @["/tron/1.0.0"], handler = handle)
proc networking(g: Game) {.async.} = proc networking(g: Game) {.async.} =
# Create our switch, similar to the GossipSub example and # Create our switch, similar to the GossipSub example and

View File

@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
else: else:
{.push raises: [].} {.push raises: [].}
import std/[strutils, sequtils] import std/[strutils, sequtils, tables]
import chronos, chronicles, stew/byteutils import chronos, chronicles, stew/byteutils
import stream/connection, import stream/connection,
protocols/protocol protocols/protocol
@ -21,7 +21,7 @@ logScope:
topics = "libp2p multistream" topics = "libp2p multistream"
const const
MsgSize* = 64*1024 MsgSize* = 1024
Codec* = "/multistream/1.0.0" Codec* = "/multistream/1.0.0"
MSCodec* = "\x13" & Codec & "\n" MSCodec* = "\x13" & Codec & "\n"
@ -33,17 +33,20 @@ type
MultiStreamError* = object of LPError MultiStreamError* = object of LPError
HandlerHolder* = object HandlerHolder* = ref object
protos*: seq[string] protos*: seq[string]
protocol*: LPProtocol protocol*: LPProtocol
match*: Matcher match*: Matcher
openedStreams: CountTable[PeerId]
MultistreamSelect* = ref object of RootObj MultistreamSelect* = ref object of RootObj
handlers*: seq[HandlerHolder] handlers*: seq[HandlerHolder]
codec*: string codec*: string
proc new*(T: typedesc[MultistreamSelect]): T = proc new*(T: typedesc[MultistreamSelect]): T =
T(codec: MSCodec) T(
codec: MSCodec,
)
template validateSuffix(str: string): untyped = template validateSuffix(str: string): untyped =
if str.endsWith("\n"): if str.endsWith("\n"):
@ -169,9 +172,22 @@ proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.asy
for h in m.handlers: for h in m.handlers:
if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms): if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms):
trace "found handler", conn, protocol = ms trace "found handler", conn, protocol = ms
var protocolHolder = h
let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams
if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams:
debug "Max streams for protocol reached, blocking new stream",
conn, protocol = ms, maxIncomingStreams
return
protocolHolder.openedStreams.inc(conn.peerId)
try:
await conn.writeLp(ms & "\n") await conn.writeLp(ms & "\n")
conn.protocol = ms conn.protocol = ms
await h.protocol.handler(conn, ms) await protocolHolder.protocol.handler(conn, ms)
finally:
protocolHolder.openedStreams.inc(conn.peerId, -1)
if protocolHolder.openedStreams[conn.peerId] == 0:
protocolHolder.openedStreams.del(conn.peerId)
return return
debug "no handlers", conn, protocol = ms debug "no handlers", conn, protocol = ms
await conn.write(Na) await conn.write(Na)

View File

@ -12,9 +12,14 @@ when (NimMajor, NimMinor) < (1, 4):
else: else:
{.push raises: [].} {.push raises: [].}
import chronos import chronos, stew/results
import ../stream/connection import ../stream/connection
export results
const
DefaultMaxIncomingStreams* = 10
type type
LPProtoHandler* = proc ( LPProtoHandler* = proc (
conn: Connection, conn: Connection,
@ -26,11 +31,17 @@ type
codecs*: seq[string] codecs*: seq[string]
handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator
started*: bool started*: bool
maxIncomingStreams: Opt[int]
method init*(p: LPProtocol) {.base, gcsafe.} = discard method init*(p: LPProtocol) {.base, gcsafe.} = discard
method start*(p: LPProtocol) {.async, base.} = p.started = true method start*(p: LPProtocol) {.async, base.} = p.started = true
method stop*(p: LPProtocol) {.async, base.} = p.started = false method stop*(p: LPProtocol) {.async, base.} = p.started = false
proc maxIncomingStreams*(p: LPProtocol): int =
p.maxIncomingStreams.get(DefaultMaxIncomingStreams)
proc `maxIncomingStreams=`*(p: LPProtocol, val: int) =
p.maxIncomingStreams = Opt.some(val)
func codec*(p: LPProtocol): string = func codec*(p: LPProtocol): string =
assert(p.codecs.len > 0, "Codecs sequence was empty!") assert(p.codecs.len > 0, "Codecs sequence was empty!")
@ -40,3 +51,16 @@ func `codec=`*(p: LPProtocol, codec: string) =
# always insert as first codec # always insert as first codec
# if we use this abstraction # if we use this abstraction
p.codecs.insert(codec, 0) p.codecs.insert(codec, 0)
proc new*(
T: type LPProtocol,
codecs: seq[string],
handler: LPProtoHandler, # default(Opt[int]) or Opt.none(int) don't work on 1.2
maxIncomingStreams: Opt[int] | int = Opt[int]()): T =
T(
codecs: codecs,
handler: handler,
maxIncomingStreams:
when maxIncomingStreams is int: Opt.some(maxIncomingStreams)
else: maxIncomingStreams
)

View File

@ -278,6 +278,79 @@ suite "Multistream select":
await handlerWait.wait(30.seconds) await handlerWait.wait(30.seconds)
asyncTest "e2e - streams limit":
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
let blocker = newFuture[void]()
# Start 5 streams which are blocked by `blocker`
# Try to start a new one, which should fail
# Unblock the 5 streams, check that we can open a new one
proc testHandler(conn: Connection,
proto: string):
Future[void] {.async, gcsafe.} =
await blocker
await conn.writeLp("Hello!")
await conn.close()
var protocol: LPProtocol = LPProtocol.new(
@["/test/proto/1.0.0"],
testHandler,
maxIncomingStreams = 5
)
protocol.handler = testHandler
let msListen = MultistreamSelect.new()
msListen.addHandler("/test/proto/1.0.0", protocol)
let transport1 = TcpTransport.new(upgrade = Upgrade())
await transport1.start(ma)
proc acceptedOne(c: Connection) {.async.} =
await msListen.handle(c)
await c.close()
proc acceptHandler() {.async, gcsafe.} =
while true:
let conn = await transport1.accept()
asyncSpawn acceptedOne(conn)
var handlerWait = acceptHandler()
let msDial = MultistreamSelect.new()
let transport2 = TcpTransport.new(upgrade = Upgrade())
proc connector {.async.} =
let conn = await transport2.dial(transport1.addrs[0])
check: (await msDial.select(conn, "/test/proto/1.0.0")) == true
check: string.fromBytes(await conn.readLp(1024)) == "Hello!"
await conn.close()
# Fill up the 5 allowed streams
var dialers: seq[Future[void]]
for _ in 0..<5:
dialers.add(connector())
# This one will fail during negotiation
expect(CatchableError):
try: waitFor(connector().wait(1.seconds))
except AsyncTimeoutError as exc:
check false
raise exc
# check that the dialers aren't finished
check: (await dialers[0].withTimeout(10.milliseconds)) == false
# unblock the dialers
blocker.complete()
await allFutures(dialers)
# now must work
waitFor(connector())
await transport2.stop()
await transport1.stop()
await handlerWait.cancelAndWait()
asyncTest "e2e - ls": asyncTest "e2e - ls":
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()] let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]