Limit number of streams per protocol per peer (#811)
This commit is contained in:
parent
31ad4ae205
commit
64cbbe1e0a
|
@ -32,7 +32,7 @@ proc new(T: typedesc[TestProto]): T =
|
||||||
# We must close the connections ourselves when we're done with it
|
# We must close the connections ourselves when we're done with it
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
return T(codecs: @[TestCodec], handler: handle)
|
return T.new(codecs = @[TestCodec], handler = handle)
|
||||||
|
|
||||||
## This is a constructor for our `TestProto`, that will specify our `codecs` and a `handler`, which will be called for each incoming peer asking for this protocol.
|
## This is a constructor for our `TestProto`, that will specify our `codecs` and a `handler`, which will be called for each incoming peer asking for this protocol.
|
||||||
## In our handle, we simply read a message from the connection and `echo` it.
|
## In our handle, we simply read a message from the connection and `echo` it.
|
||||||
|
|
|
@ -107,7 +107,7 @@ type
|
||||||
metricGetter: MetricCallback
|
metricGetter: MetricCallback
|
||||||
|
|
||||||
proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
|
proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
|
||||||
let res = MetricProto(metricGetter: cb)
|
var res: MetricProto
|
||||||
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
|
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||||
let
|
let
|
||||||
metrics = await res.metricGetter()
|
metrics = await res.metricGetter()
|
||||||
|
@ -115,8 +115,8 @@ proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
|
||||||
await conn.writeLp(asProtobuf.buffer)
|
await conn.writeLp(asProtobuf.buffer)
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
res.codecs = @["/metric-getter/1.0.0"]
|
res = MetricProto.new(@["/metric-getter/1.0.0"], handle)
|
||||||
res.handler = handle
|
res.metricGetter = cb
|
||||||
return res
|
return res
|
||||||
|
|
||||||
proc fetch(p: MetricProto, conn: Connection): Future[MetricList] {.async.} =
|
proc fetch(p: MetricProto, conn: Connection): Future[MetricList] {.async.} =
|
||||||
|
|
|
@ -36,7 +36,7 @@ proc new(T: typedesc[DumbProto], nodeNumber: int): T =
|
||||||
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
|
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||||
echo "Node", nodeNumber, " received: ", string.fromBytes(await conn.readLp(1024))
|
echo "Node", nodeNumber, " received: ", string.fromBytes(await conn.readLp(1024))
|
||||||
await conn.close()
|
await conn.close()
|
||||||
return T(codecs: @[DumbCodec], handler: handle)
|
return T.new(codecs = @[DumbCodec], handler = handle)
|
||||||
|
|
||||||
## ## Bootnodes
|
## ## Bootnodes
|
||||||
## The first time a p2p program is ran, he needs to know how to join
|
## The first time a p2p program is ran, he needs to know how to join
|
||||||
|
|
|
@ -157,7 +157,7 @@ proc new(T: typedesc[GameProto], g: Game): T =
|
||||||
# The handler of a protocol must wait for the stream to
|
# The handler of a protocol must wait for the stream to
|
||||||
# be finished before returning
|
# be finished before returning
|
||||||
await conn.join()
|
await conn.join()
|
||||||
return T(codecs: @["/tron/1.0.0"], handler: handle)
|
return T.new(codecs = @["/tron/1.0.0"], handler = handle)
|
||||||
|
|
||||||
proc networking(g: Game) {.async.} =
|
proc networking(g: Game) {.async.} =
|
||||||
# Create our switch, similar to the GossipSub example and
|
# Create our switch, similar to the GossipSub example and
|
||||||
|
|
|
@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
|
||||||
else:
|
else:
|
||||||
{.push raises: [].}
|
{.push raises: [].}
|
||||||
|
|
||||||
import std/[strutils, sequtils]
|
import std/[strutils, sequtils, tables]
|
||||||
import chronos, chronicles, stew/byteutils
|
import chronos, chronicles, stew/byteutils
|
||||||
import stream/connection,
|
import stream/connection,
|
||||||
protocols/protocol
|
protocols/protocol
|
||||||
|
@ -21,7 +21,7 @@ logScope:
|
||||||
topics = "libp2p multistream"
|
topics = "libp2p multistream"
|
||||||
|
|
||||||
const
|
const
|
||||||
MsgSize* = 64*1024
|
MsgSize* = 1024
|
||||||
Codec* = "/multistream/1.0.0"
|
Codec* = "/multistream/1.0.0"
|
||||||
|
|
||||||
MSCodec* = "\x13" & Codec & "\n"
|
MSCodec* = "\x13" & Codec & "\n"
|
||||||
|
@ -33,17 +33,20 @@ type
|
||||||
|
|
||||||
MultiStreamError* = object of LPError
|
MultiStreamError* = object of LPError
|
||||||
|
|
||||||
HandlerHolder* = object
|
HandlerHolder* = ref object
|
||||||
protos*: seq[string]
|
protos*: seq[string]
|
||||||
protocol*: LPProtocol
|
protocol*: LPProtocol
|
||||||
match*: Matcher
|
match*: Matcher
|
||||||
|
openedStreams: CountTable[PeerId]
|
||||||
|
|
||||||
MultistreamSelect* = ref object of RootObj
|
MultistreamSelect* = ref object of RootObj
|
||||||
handlers*: seq[HandlerHolder]
|
handlers*: seq[HandlerHolder]
|
||||||
codec*: string
|
codec*: string
|
||||||
|
|
||||||
proc new*(T: typedesc[MultistreamSelect]): T =
|
proc new*(T: typedesc[MultistreamSelect]): T =
|
||||||
T(codec: MSCodec)
|
T(
|
||||||
|
codec: MSCodec,
|
||||||
|
)
|
||||||
|
|
||||||
template validateSuffix(str: string): untyped =
|
template validateSuffix(str: string): untyped =
|
||||||
if str.endsWith("\n"):
|
if str.endsWith("\n"):
|
||||||
|
@ -169,9 +172,22 @@ proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.asy
|
||||||
for h in m.handlers:
|
for h in m.handlers:
|
||||||
if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms):
|
if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms):
|
||||||
trace "found handler", conn, protocol = ms
|
trace "found handler", conn, protocol = ms
|
||||||
|
|
||||||
|
var protocolHolder = h
|
||||||
|
let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams
|
||||||
|
if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams:
|
||||||
|
debug "Max streams for protocol reached, blocking new stream",
|
||||||
|
conn, protocol = ms, maxIncomingStreams
|
||||||
|
return
|
||||||
|
protocolHolder.openedStreams.inc(conn.peerId)
|
||||||
|
try:
|
||||||
await conn.writeLp(ms & "\n")
|
await conn.writeLp(ms & "\n")
|
||||||
conn.protocol = ms
|
conn.protocol = ms
|
||||||
await h.protocol.handler(conn, ms)
|
await protocolHolder.protocol.handler(conn, ms)
|
||||||
|
finally:
|
||||||
|
protocolHolder.openedStreams.inc(conn.peerId, -1)
|
||||||
|
if protocolHolder.openedStreams[conn.peerId] == 0:
|
||||||
|
protocolHolder.openedStreams.del(conn.peerId)
|
||||||
return
|
return
|
||||||
debug "no handlers", conn, protocol = ms
|
debug "no handlers", conn, protocol = ms
|
||||||
await conn.write(Na)
|
await conn.write(Na)
|
||||||
|
|
|
@ -12,9 +12,14 @@ when (NimMajor, NimMinor) < (1, 4):
|
||||||
else:
|
else:
|
||||||
{.push raises: [].}
|
{.push raises: [].}
|
||||||
|
|
||||||
import chronos
|
import chronos, stew/results
|
||||||
import ../stream/connection
|
import ../stream/connection
|
||||||
|
|
||||||
|
export results
|
||||||
|
|
||||||
|
const
|
||||||
|
DefaultMaxIncomingStreams* = 10
|
||||||
|
|
||||||
type
|
type
|
||||||
LPProtoHandler* = proc (
|
LPProtoHandler* = proc (
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
|
@ -26,11 +31,17 @@ type
|
||||||
codecs*: seq[string]
|
codecs*: seq[string]
|
||||||
handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator
|
handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator
|
||||||
started*: bool
|
started*: bool
|
||||||
|
maxIncomingStreams: Opt[int]
|
||||||
|
|
||||||
method init*(p: LPProtocol) {.base, gcsafe.} = discard
|
method init*(p: LPProtocol) {.base, gcsafe.} = discard
|
||||||
method start*(p: LPProtocol) {.async, base.} = p.started = true
|
method start*(p: LPProtocol) {.async, base.} = p.started = true
|
||||||
method stop*(p: LPProtocol) {.async, base.} = p.started = false
|
method stop*(p: LPProtocol) {.async, base.} = p.started = false
|
||||||
|
|
||||||
|
proc maxIncomingStreams*(p: LPProtocol): int =
|
||||||
|
p.maxIncomingStreams.get(DefaultMaxIncomingStreams)
|
||||||
|
|
||||||
|
proc `maxIncomingStreams=`*(p: LPProtocol, val: int) =
|
||||||
|
p.maxIncomingStreams = Opt.some(val)
|
||||||
|
|
||||||
func codec*(p: LPProtocol): string =
|
func codec*(p: LPProtocol): string =
|
||||||
assert(p.codecs.len > 0, "Codecs sequence was empty!")
|
assert(p.codecs.len > 0, "Codecs sequence was empty!")
|
||||||
|
@ -40,3 +51,16 @@ func `codec=`*(p: LPProtocol, codec: string) =
|
||||||
# always insert as first codec
|
# always insert as first codec
|
||||||
# if we use this abstraction
|
# if we use this abstraction
|
||||||
p.codecs.insert(codec, 0)
|
p.codecs.insert(codec, 0)
|
||||||
|
|
||||||
|
proc new*(
|
||||||
|
T: type LPProtocol,
|
||||||
|
codecs: seq[string],
|
||||||
|
handler: LPProtoHandler, # default(Opt[int]) or Opt.none(int) don't work on 1.2
|
||||||
|
maxIncomingStreams: Opt[int] | int = Opt[int]()): T =
|
||||||
|
T(
|
||||||
|
codecs: codecs,
|
||||||
|
handler: handler,
|
||||||
|
maxIncomingStreams:
|
||||||
|
when maxIncomingStreams is int: Opt.some(maxIncomingStreams)
|
||||||
|
else: maxIncomingStreams
|
||||||
|
)
|
||||||
|
|
|
@ -278,6 +278,79 @@ suite "Multistream select":
|
||||||
|
|
||||||
await handlerWait.wait(30.seconds)
|
await handlerWait.wait(30.seconds)
|
||||||
|
|
||||||
|
asyncTest "e2e - streams limit":
|
||||||
|
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
|
||||||
|
let blocker = newFuture[void]()
|
||||||
|
|
||||||
|
# Start 5 streams which are blocked by `blocker`
|
||||||
|
# Try to start a new one, which should fail
|
||||||
|
# Unblock the 5 streams, check that we can open a new one
|
||||||
|
proc testHandler(conn: Connection,
|
||||||
|
proto: string):
|
||||||
|
Future[void] {.async, gcsafe.} =
|
||||||
|
await blocker
|
||||||
|
await conn.writeLp("Hello!")
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
var protocol: LPProtocol = LPProtocol.new(
|
||||||
|
@["/test/proto/1.0.0"],
|
||||||
|
testHandler,
|
||||||
|
maxIncomingStreams = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
protocol.handler = testHandler
|
||||||
|
let msListen = MultistreamSelect.new()
|
||||||
|
msListen.addHandler("/test/proto/1.0.0", protocol)
|
||||||
|
|
||||||
|
let transport1 = TcpTransport.new(upgrade = Upgrade())
|
||||||
|
await transport1.start(ma)
|
||||||
|
|
||||||
|
proc acceptedOne(c: Connection) {.async.} =
|
||||||
|
await msListen.handle(c)
|
||||||
|
await c.close()
|
||||||
|
|
||||||
|
proc acceptHandler() {.async, gcsafe.} =
|
||||||
|
while true:
|
||||||
|
let conn = await transport1.accept()
|
||||||
|
asyncSpawn acceptedOne(conn)
|
||||||
|
|
||||||
|
var handlerWait = acceptHandler()
|
||||||
|
|
||||||
|
let msDial = MultistreamSelect.new()
|
||||||
|
let transport2 = TcpTransport.new(upgrade = Upgrade())
|
||||||
|
|
||||||
|
proc connector {.async.} =
|
||||||
|
let conn = await transport2.dial(transport1.addrs[0])
|
||||||
|
check: (await msDial.select(conn, "/test/proto/1.0.0")) == true
|
||||||
|
check: string.fromBytes(await conn.readLp(1024)) == "Hello!"
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
# Fill up the 5 allowed streams
|
||||||
|
var dialers: seq[Future[void]]
|
||||||
|
for _ in 0..<5:
|
||||||
|
dialers.add(connector())
|
||||||
|
|
||||||
|
# This one will fail during negotiation
|
||||||
|
expect(CatchableError):
|
||||||
|
try: waitFor(connector().wait(1.seconds))
|
||||||
|
except AsyncTimeoutError as exc:
|
||||||
|
check false
|
||||||
|
raise exc
|
||||||
|
# check that the dialers aren't finished
|
||||||
|
check: (await dialers[0].withTimeout(10.milliseconds)) == false
|
||||||
|
|
||||||
|
# unblock the dialers
|
||||||
|
blocker.complete()
|
||||||
|
await allFutures(dialers)
|
||||||
|
|
||||||
|
# now must work
|
||||||
|
waitFor(connector())
|
||||||
|
|
||||||
|
await transport2.stop()
|
||||||
|
await transport1.stop()
|
||||||
|
|
||||||
|
await handlerWait.cancelAndWait()
|
||||||
|
|
||||||
asyncTest "e2e - ls":
|
asyncTest "e2e - ls":
|
||||||
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
|
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue