Limit number of streams per protocol per peer (#811)
This commit is contained in:
parent
31ad4ae205
commit
64cbbe1e0a
|
@ -32,7 +32,7 @@ proc new(T: typedesc[TestProto]): T =
|
|||
# We must close the connections ourselves when we're done with it
|
||||
await conn.close()
|
||||
|
||||
return T(codecs: @[TestCodec], handler: handle)
|
||||
return T.new(codecs = @[TestCodec], handler = handle)
|
||||
|
||||
## This is a constructor for our `TestProto`, that will specify our `codecs` and a `handler`, which will be called for each incoming peer asking for this protocol.
|
||||
## In our handle, we simply read a message from the connection and `echo` it.
|
||||
|
|
|
@ -107,7 +107,7 @@ type
|
|||
metricGetter: MetricCallback
|
||||
|
||||
proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
|
||||
let res = MetricProto(metricGetter: cb)
|
||||
var res: MetricProto
|
||||
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||
let
|
||||
metrics = await res.metricGetter()
|
||||
|
@ -115,8 +115,8 @@ proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto =
|
|||
await conn.writeLp(asProtobuf.buffer)
|
||||
await conn.close()
|
||||
|
||||
res.codecs = @["/metric-getter/1.0.0"]
|
||||
res.handler = handle
|
||||
res = MetricProto.new(@["/metric-getter/1.0.0"], handle)
|
||||
res.metricGetter = cb
|
||||
return res
|
||||
|
||||
proc fetch(p: MetricProto, conn: Connection): Future[MetricList] {.async.} =
|
||||
|
|
|
@ -36,7 +36,7 @@ proc new(T: typedesc[DumbProto], nodeNumber: int): T =
|
|||
proc handle(conn: Connection, proto: string) {.async, gcsafe.} =
|
||||
echo "Node", nodeNumber, " received: ", string.fromBytes(await conn.readLp(1024))
|
||||
await conn.close()
|
||||
return T(codecs: @[DumbCodec], handler: handle)
|
||||
return T.new(codecs = @[DumbCodec], handler = handle)
|
||||
|
||||
## ## Bootnodes
|
||||
## The first time a p2p program is ran, he needs to know how to join
|
||||
|
|
|
@ -157,7 +157,7 @@ proc new(T: typedesc[GameProto], g: Game): T =
|
|||
# The handler of a protocol must wait for the stream to
|
||||
# be finished before returning
|
||||
await conn.join()
|
||||
return T(codecs: @["/tron/1.0.0"], handler: handle)
|
||||
return T.new(codecs = @["/tron/1.0.0"], handler = handle)
|
||||
|
||||
proc networking(g: Game) {.async.} =
|
||||
# Create our switch, similar to the GossipSub example and
|
||||
|
|
|
@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4):
|
|||
else:
|
||||
{.push raises: [].}
|
||||
|
||||
import std/[strutils, sequtils]
|
||||
import std/[strutils, sequtils, tables]
|
||||
import chronos, chronicles, stew/byteutils
|
||||
import stream/connection,
|
||||
protocols/protocol
|
||||
|
@ -21,7 +21,7 @@ logScope:
|
|||
topics = "libp2p multistream"
|
||||
|
||||
const
|
||||
MsgSize* = 64*1024
|
||||
MsgSize* = 1024
|
||||
Codec* = "/multistream/1.0.0"
|
||||
|
||||
MSCodec* = "\x13" & Codec & "\n"
|
||||
|
@ -33,17 +33,20 @@ type
|
|||
|
||||
MultiStreamError* = object of LPError
|
||||
|
||||
HandlerHolder* = object
|
||||
HandlerHolder* = ref object
|
||||
protos*: seq[string]
|
||||
protocol*: LPProtocol
|
||||
match*: Matcher
|
||||
openedStreams: CountTable[PeerId]
|
||||
|
||||
MultistreamSelect* = ref object of RootObj
|
||||
handlers*: seq[HandlerHolder]
|
||||
codec*: string
|
||||
|
||||
proc new*(T: typedesc[MultistreamSelect]): T =
|
||||
T(codec: MSCodec)
|
||||
T(
|
||||
codec: MSCodec,
|
||||
)
|
||||
|
||||
template validateSuffix(str: string): untyped =
|
||||
if str.endsWith("\n"):
|
||||
|
@ -169,9 +172,22 @@ proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.asy
|
|||
for h in m.handlers:
|
||||
if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms):
|
||||
trace "found handler", conn, protocol = ms
|
||||
await conn.writeLp(ms & "\n")
|
||||
conn.protocol = ms
|
||||
await h.protocol.handler(conn, ms)
|
||||
|
||||
var protocolHolder = h
|
||||
let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams
|
||||
if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams:
|
||||
debug "Max streams for protocol reached, blocking new stream",
|
||||
conn, protocol = ms, maxIncomingStreams
|
||||
return
|
||||
protocolHolder.openedStreams.inc(conn.peerId)
|
||||
try:
|
||||
await conn.writeLp(ms & "\n")
|
||||
conn.protocol = ms
|
||||
await protocolHolder.protocol.handler(conn, ms)
|
||||
finally:
|
||||
protocolHolder.openedStreams.inc(conn.peerId, -1)
|
||||
if protocolHolder.openedStreams[conn.peerId] == 0:
|
||||
protocolHolder.openedStreams.del(conn.peerId)
|
||||
return
|
||||
debug "no handlers", conn, protocol = ms
|
||||
await conn.write(Na)
|
||||
|
|
|
@ -12,9 +12,14 @@ when (NimMajor, NimMinor) < (1, 4):
|
|||
else:
|
||||
{.push raises: [].}
|
||||
|
||||
import chronos
|
||||
import chronos, stew/results
|
||||
import ../stream/connection
|
||||
|
||||
export results
|
||||
|
||||
const
|
||||
DefaultMaxIncomingStreams* = 10
|
||||
|
||||
type
|
||||
LPProtoHandler* = proc (
|
||||
conn: Connection,
|
||||
|
@ -26,11 +31,17 @@ type
|
|||
codecs*: seq[string]
|
||||
handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator
|
||||
started*: bool
|
||||
maxIncomingStreams: Opt[int]
|
||||
|
||||
method init*(p: LPProtocol) {.base, gcsafe.} = discard
|
||||
method start*(p: LPProtocol) {.async, base.} = p.started = true
|
||||
method stop*(p: LPProtocol) {.async, base.} = p.started = false
|
||||
|
||||
proc maxIncomingStreams*(p: LPProtocol): int =
|
||||
p.maxIncomingStreams.get(DefaultMaxIncomingStreams)
|
||||
|
||||
proc `maxIncomingStreams=`*(p: LPProtocol, val: int) =
|
||||
p.maxIncomingStreams = Opt.some(val)
|
||||
|
||||
func codec*(p: LPProtocol): string =
|
||||
assert(p.codecs.len > 0, "Codecs sequence was empty!")
|
||||
|
@ -40,3 +51,16 @@ func `codec=`*(p: LPProtocol, codec: string) =
|
|||
# always insert as first codec
|
||||
# if we use this abstraction
|
||||
p.codecs.insert(codec, 0)
|
||||
|
||||
proc new*(
|
||||
T: type LPProtocol,
|
||||
codecs: seq[string],
|
||||
handler: LPProtoHandler, # default(Opt[int]) or Opt.none(int) don't work on 1.2
|
||||
maxIncomingStreams: Opt[int] | int = Opt[int]()): T =
|
||||
T(
|
||||
codecs: codecs,
|
||||
handler: handler,
|
||||
maxIncomingStreams:
|
||||
when maxIncomingStreams is int: Opt.some(maxIncomingStreams)
|
||||
else: maxIncomingStreams
|
||||
)
|
||||
|
|
|
@ -278,6 +278,79 @@ suite "Multistream select":
|
|||
|
||||
await handlerWait.wait(30.seconds)
|
||||
|
||||
asyncTest "e2e - streams limit":
|
||||
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
|
||||
let blocker = newFuture[void]()
|
||||
|
||||
# Start 5 streams which are blocked by `blocker`
|
||||
# Try to start a new one, which should fail
|
||||
# Unblock the 5 streams, check that we can open a new one
|
||||
proc testHandler(conn: Connection,
|
||||
proto: string):
|
||||
Future[void] {.async, gcsafe.} =
|
||||
await blocker
|
||||
await conn.writeLp("Hello!")
|
||||
await conn.close()
|
||||
|
||||
var protocol: LPProtocol = LPProtocol.new(
|
||||
@["/test/proto/1.0.0"],
|
||||
testHandler,
|
||||
maxIncomingStreams = 5
|
||||
)
|
||||
|
||||
protocol.handler = testHandler
|
||||
let msListen = MultistreamSelect.new()
|
||||
msListen.addHandler("/test/proto/1.0.0", protocol)
|
||||
|
||||
let transport1 = TcpTransport.new(upgrade = Upgrade())
|
||||
await transport1.start(ma)
|
||||
|
||||
proc acceptedOne(c: Connection) {.async.} =
|
||||
await msListen.handle(c)
|
||||
await c.close()
|
||||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
while true:
|
||||
let conn = await transport1.accept()
|
||||
asyncSpawn acceptedOne(conn)
|
||||
|
||||
var handlerWait = acceptHandler()
|
||||
|
||||
let msDial = MultistreamSelect.new()
|
||||
let transport2 = TcpTransport.new(upgrade = Upgrade())
|
||||
|
||||
proc connector {.async.} =
|
||||
let conn = await transport2.dial(transport1.addrs[0])
|
||||
check: (await msDial.select(conn, "/test/proto/1.0.0")) == true
|
||||
check: string.fromBytes(await conn.readLp(1024)) == "Hello!"
|
||||
await conn.close()
|
||||
|
||||
# Fill up the 5 allowed streams
|
||||
var dialers: seq[Future[void]]
|
||||
for _ in 0..<5:
|
||||
dialers.add(connector())
|
||||
|
||||
# This one will fail during negotiation
|
||||
expect(CatchableError):
|
||||
try: waitFor(connector().wait(1.seconds))
|
||||
except AsyncTimeoutError as exc:
|
||||
check false
|
||||
raise exc
|
||||
# check that the dialers aren't finished
|
||||
check: (await dialers[0].withTimeout(10.milliseconds)) == false
|
||||
|
||||
# unblock the dialers
|
||||
blocker.complete()
|
||||
await allFutures(dialers)
|
||||
|
||||
# now must work
|
||||
waitFor(connector())
|
||||
|
||||
await transport2.stop()
|
||||
await transport1.stop()
|
||||
|
||||
await handlerWait.cancelAndWait()
|
||||
|
||||
asyncTest "e2e - ls":
|
||||
let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]
|
||||
|
||||
|
|
Loading…
Reference in New Issue