diff --git a/examples/tutorial_2_customproto.nim b/examples/tutorial_2_customproto.nim index be418a7a1..0bf6c197a 100644 --- a/examples/tutorial_2_customproto.nim +++ b/examples/tutorial_2_customproto.nim @@ -32,7 +32,7 @@ proc new(T: typedesc[TestProto]): T = # We must close the connections ourselves when we're done with it await conn.close() - return T(codecs: @[TestCodec], handler: handle) + return T.new(codecs = @[TestCodec], handler = handle) ## This is a constructor for our `TestProto`, that will specify our `codecs` and a `handler`, which will be called for each incoming peer asking for this protocol. ## In our handle, we simply read a message from the connection and `echo` it. diff --git a/examples/tutorial_3_protobuf.nim b/examples/tutorial_3_protobuf.nim index 2af7efe61..4ba7ac98f 100644 --- a/examples/tutorial_3_protobuf.nim +++ b/examples/tutorial_3_protobuf.nim @@ -107,7 +107,7 @@ type metricGetter: MetricCallback proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto = - let res = MetricProto(metricGetter: cb) + var res: MetricProto proc handle(conn: Connection, proto: string) {.async, gcsafe.} = let metrics = await res.metricGetter() @@ -115,8 +115,8 @@ proc new(_: typedesc[MetricProto], cb: MetricCallback): MetricProto = await conn.writeLp(asProtobuf.buffer) await conn.close() - res.codecs = @["/metric-getter/1.0.0"] - res.handler = handle + res = MetricProto.new(@["/metric-getter/1.0.0"], handle) + res.metricGetter = cb return res proc fetch(p: MetricProto, conn: Connection): Future[MetricList] {.async.} = diff --git a/examples/tutorial_5_discovery.nim b/examples/tutorial_5_discovery.nim index ce02e19df..889087736 100644 --- a/examples/tutorial_5_discovery.nim +++ b/examples/tutorial_5_discovery.nim @@ -36,7 +36,7 @@ proc new(T: typedesc[DumbProto], nodeNumber: int): T = proc handle(conn: Connection, proto: string) {.async, gcsafe.} = echo "Node", nodeNumber, " received: ", string.fromBytes(await conn.readLp(1024)) await conn.close() - return T(codecs: @[DumbCodec], handler: handle) + return T.new(codecs = @[DumbCodec], handler = handle) ## ## Bootnodes ## The first time a p2p program is ran, he needs to know how to join diff --git a/examples/tutorial_6_game.nim b/examples/tutorial_6_game.nim index ffbf09a7b..f3be6d372 100644 --- a/examples/tutorial_6_game.nim +++ b/examples/tutorial_6_game.nim @@ -157,7 +157,7 @@ proc new(T: typedesc[GameProto], g: Game): T = # The handler of a protocol must wait for the stream to # be finished before returning await conn.join() - return T(codecs: @["/tron/1.0.0"], handler: handle) + return T.new(codecs = @["/tron/1.0.0"], handler = handle) proc networking(g: Game) {.async.} = # Create our switch, similar to the GossipSub example and diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index e2797a6a3..bded5a532 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -12,7 +12,7 @@ when (NimMajor, NimMinor) < (1, 4): else: {.push raises: [].} -import std/[strutils, sequtils] +import std/[strutils, sequtils, tables] import chronos, chronicles, stew/byteutils import stream/connection, protocols/protocol @@ -21,7 +21,7 @@ logScope: topics = "libp2p multistream" const - MsgSize* = 64*1024 + MsgSize* = 1024 Codec* = "/multistream/1.0.0" MSCodec* = "\x13" & Codec & "\n" @@ -33,17 +33,20 @@ type MultiStreamError* = object of LPError - HandlerHolder* = object + HandlerHolder* = ref object protos*: seq[string] protocol*: LPProtocol match*: Matcher + openedStreams: CountTable[PeerId] MultistreamSelect* = ref object of RootObj handlers*: seq[HandlerHolder] codec*: string proc new*(T: typedesc[MultistreamSelect]): T = - T(codec: MSCodec) + T( + codec: MSCodec, + ) template validateSuffix(str: string): untyped = if str.endsWith("\n"): @@ -169,9 +172,22 @@ proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.asy for h in m.handlers: if (not isNil(h.match) and h.match(ms)) or h.protos.contains(ms): trace "found handler", conn, protocol = ms - await conn.writeLp(ms & "\n") - conn.protocol = ms - await h.protocol.handler(conn, ms) + + var protocolHolder = h + let maxIncomingStreams = protocolHolder.protocol.maxIncomingStreams + if protocolHolder.openedStreams.getOrDefault(conn.peerId) >= maxIncomingStreams: + debug "Max streams for protocol reached, blocking new stream", + conn, protocol = ms, maxIncomingStreams + return + protocolHolder.openedStreams.inc(conn.peerId) + try: + await conn.writeLp(ms & "\n") + conn.protocol = ms + await protocolHolder.protocol.handler(conn, ms) + finally: + protocolHolder.openedStreams.inc(conn.peerId, -1) + if protocolHolder.openedStreams[conn.peerId] == 0: + protocolHolder.openedStreams.del(conn.peerId) return debug "no handlers", conn, protocol = ms await conn.write(Na) diff --git a/libp2p/protocols/protocol.nim b/libp2p/protocols/protocol.nim index ee3c39a62..5103264ad 100644 --- a/libp2p/protocols/protocol.nim +++ b/libp2p/protocols/protocol.nim @@ -12,9 +12,14 @@ when (NimMajor, NimMinor) < (1, 4): else: {.push raises: [].} -import chronos +import chronos, stew/results import ../stream/connection +export results + +const + DefaultMaxIncomingStreams* = 10 + type LPProtoHandler* = proc ( conn: Connection, @@ -26,11 +31,17 @@ type codecs*: seq[string] handler*: LPProtoHandler ## this handler gets invoked by the protocol negotiator started*: bool + maxIncomingStreams: Opt[int] method init*(p: LPProtocol) {.base, gcsafe.} = discard method start*(p: LPProtocol) {.async, base.} = p.started = true method stop*(p: LPProtocol) {.async, base.} = p.started = false +proc maxIncomingStreams*(p: LPProtocol): int = + p.maxIncomingStreams.get(DefaultMaxIncomingStreams) + +proc `maxIncomingStreams=`*(p: LPProtocol, val: int) = + p.maxIncomingStreams = Opt.some(val) func codec*(p: LPProtocol): string = assert(p.codecs.len > 0, "Codecs sequence was empty!") @@ -40,3 +51,16 @@ func `codec=`*(p: LPProtocol, codec: string) = # always insert as first codec # if we use this abstraction p.codecs.insert(codec, 0) + +proc new*( + T: type LPProtocol, + codecs: seq[string], + handler: LPProtoHandler, # default(Opt[int]) or Opt.none(int) don't work on 1.2 + maxIncomingStreams: Opt[int] | int = Opt[int]()): T = + T( + codecs: codecs, + handler: handler, + maxIncomingStreams: + when maxIncomingStreams is int: Opt.some(maxIncomingStreams) + else: maxIncomingStreams + ) diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 6bdf1aa40..a29993d53 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -278,6 +278,79 @@ suite "Multistream select": await handlerWait.wait(30.seconds) + asyncTest "e2e - streams limit": + let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()] + let blocker = newFuture[void]() + + # Start 5 streams which are blocked by `blocker` + # Try to start a new one, which should fail + # Unblock the 5 streams, check that we can open a new one + proc testHandler(conn: Connection, + proto: string): + Future[void] {.async, gcsafe.} = + await blocker + await conn.writeLp("Hello!") + await conn.close() + + var protocol: LPProtocol = LPProtocol.new( + @["/test/proto/1.0.0"], + testHandler, + maxIncomingStreams = 5 + ) + + protocol.handler = testHandler + let msListen = MultistreamSelect.new() + msListen.addHandler("/test/proto/1.0.0", protocol) + + let transport1 = TcpTransport.new(upgrade = Upgrade()) + await transport1.start(ma) + + proc acceptedOne(c: Connection) {.async.} = + await msListen.handle(c) + await c.close() + + proc acceptHandler() {.async, gcsafe.} = + while true: + let conn = await transport1.accept() + asyncSpawn acceptedOne(conn) + + var handlerWait = acceptHandler() + + let msDial = MultistreamSelect.new() + let transport2 = TcpTransport.new(upgrade = Upgrade()) + + proc connector {.async.} = + let conn = await transport2.dial(transport1.addrs[0]) + check: (await msDial.select(conn, "/test/proto/1.0.0")) == true + check: string.fromBytes(await conn.readLp(1024)) == "Hello!" + await conn.close() + + # Fill up the 5 allowed streams + var dialers: seq[Future[void]] + for _ in 0..<5: + dialers.add(connector()) + + # This one will fail during negotiation + expect(CatchableError): + try: waitFor(connector().wait(1.seconds)) + except AsyncTimeoutError as exc: + check false + raise exc + # check that the dialers aren't finished + check: (await dialers[0].withTimeout(10.milliseconds)) == false + + # unblock the dialers + blocker.complete() + await allFutures(dialers) + + # now must work + waitFor(connector()) + + await transport2.stop() + await transport1.stop() + + await handlerWait.cancelAndWait() + asyncTest "e2e - ls": let ma = @[MultiAddress.init("/ip4/0.0.0.0/tcp/0").tryGet()]