diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 4d0963b82..0bb2df4d3 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -8,7 +8,7 @@ ## those terms. import oids -import chronos, chronicles, sequtils +import chronos, chronicles, sequtils, sets import transport, ../errors, ../wire, @@ -22,14 +22,19 @@ logScope: const TcpTransportTrackerName* = "libp2p.tcptransport" + MaxTCPConnections* = 50 type + TooManyConnections* = object of CatchableError + TcpTransport* = ref object of Transport server*: StreamServer - clients: seq[StreamTransport] + conns: HashSet[ChronosStream] flags: set[ServerFlags] cleanups*: seq[Future[void]] handlers*: seq[Future[void]] + maxIncoming: int + maxOutgoing: int TcpTransportTracker* = ref object of TrackerBase opened*: uint64 @@ -59,12 +64,38 @@ proc setupTcpTransportTracker(): TcpTransportTracker = result.isLeaked = leakTransport addTracker(TcpTransportTrackerName, result) +proc newTooManyConnections(): ref TooManyConnections = + newException(TooManyConnections, "too many inbound connections") + +proc cleanup(t: TcpTransport, conn: ChronosStream) {.async.} = + try: + await conn.closeEvent.wait() + trace "cleaning up socket", addrs = $conn.client.remoteAddress, + connoid = $conn.oid + if not(isNil(conn)): + await conn.close() + + t.conns.excl(conn) + + let inLen = toSeq(t.conns).filterIt( it.dir == Direction.In ).len + if inLen < t.maxIncoming: + if not isNil(t.server): + trace "restarting accept loop", limit = inLen + t.server.start() + + except CatchableError as exc: + trace "error cleaning up socket", exc = exc.msg + proc connHandler*(t: TcpTransport, client: StreamTransport, initiator: bool): Connection = trace "handling connection", address = $client.remoteAddress - let conn: Connection = Connection(ChronosStream.init(client)) - conn.observedAddr = MultiAddress.init(client.remoteAddress).tryGet() + let stream = ChronosStream.init(client, + dir = if initiator: Direction.Out + else: Direction.In) + + let conn = Connection(stream) + stream.observedAddr = MultiAddress.init(client.remoteAddress).tryGet() if not initiator: if not isNil(t.handler): t.handlers &= t.handler(conn) @@ -90,17 +121,31 @@ proc connCb(server: StreamServer, trace "incoming connection", address = $client.remoteAddress try: let t = cast[TcpTransport](server.udata) + let inLen = toSeq(t.conns).filterIt( it.dir == Direction.In ).len + if inLen + 1 >= t.maxIncoming: + trace "connection limit reached", limit = t.maxIncoming, + dir = $Direction.In + server.stop() + await client.closeWait() + return + # we don't need result connection in this case # as it's added inside connHandler - discard t.connHandler(client, false) + discard connHandler(t, client, false) + except CancelledError as exc: raise exc except CatchableError as err: debug "Connection setup failed", err = err.msg client.close() -proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T = - result = T(flags: flags) +proc init*(T: type TcpTransport, + flags: set[ServerFlags] = {}, + maxIncoming, maxOutgoing = MaxTCPConnections): T = + result = T(flags: flags, + maxIncoming: maxIncoming, + maxOutgoing: maxOutgoing) + result.initTransport() method initTransport*(t: TcpTransport) = @@ -114,7 +159,7 @@ method close*(t: TcpTransport) {.async, gcsafe.} = await procCall Transport(t).close() # call base checkFutures(await allFinished( - t.clients.mapIt(it.closeWait()))) + toSeq(t.conns).mapIt(it.client.closeWait()))) # server can be nil if not isNil(t.server): @@ -153,7 +198,13 @@ method listen*(t: TcpTransport, discard await procCall Transport(t).listen(ma, handler) # call base ## listen on the transport - t.server = createStreamServer(t.ma, connCb, t.flags, t) + t.server = createStreamServer( + t.ma, + connCb, + t.flags, + t, + backlog = t.maxIncoming) + t.server.start() # always get the resolved address in case we're bound to 0.0.0.0:0 @@ -165,10 +216,18 @@ method dial*(t: TcpTransport, address: MultiAddress): Future[Connection] {.async, gcsafe.} = trace "dialing remote peer", address = $address + + let outLen = toSeq(t.conns).filterIt( it.dir == Direction.Out ).len + if outLen + 1 >= t.maxOutgoing: + trace "connection limit reached", limit = t.maxOutgoing, + dir = $Direction.Out + raise newTooManyConnections() + ## dial a peer - let client: StreamTransport = await connect(address) - result = t.connHandler(client, true) + let client = await connect(address) + return t.connHandler(client, true) method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} = if procCall Transport(t).handles(address): - result = address.protocols.tryGet().filterIt( it == multiCodec("tcp") ).len > 0 + return (address.protocols.tryGet() + .filterIt( it == multiCodec("tcp") ).len > 0) diff --git a/tests/testtransport.nim b/tests/testtransport.nim index df5ee4b69..1a41268f3 100644 --- a/tests/testtransport.nim +++ b/tests/testtransport.nim @@ -1,12 +1,13 @@ {.used.} -import unittest +import unittest, sequtils import chronos, stew/byteutils import ../libp2p/[stream/connection, transports/transport, transports/tcptransport, multiaddress, - wire] + wire, + errors] import ./helpers suite "TCP transport": @@ -192,3 +193,51 @@ suite "TCP transport": check: waitFor(testListenerDialer()) == true + + test "e2e: should limit incoming connections": + proc test() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + var times = 1 + proc connHandler(conn: Connection) {.async, gcsafe.} = + times.inc() + + var transports: seq[TcpTransport] + transports.add(TcpTransport.init(maxIncoming = 2)) + asyncCheck transports[0].listen(ma, connHandler) + + try: + for i in 0..10: + let transport = TcpTransport.init() + transports.add(transport) + discard await transport.dial(transports[0].ma).wait(10.millis) + except AsyncTimeoutError: + check times == 2 + + await allFuturesThrowing( + transports.mapIt(it.close())) + + waitFor(test()) + + test "e2e: should limit outgoing connections": + proc test() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + var times = 1 + proc connHandler(conn: Connection) {.async, gcsafe.} = + times.inc() + + var transports: seq[TcpTransport] + transports.add(TcpTransport.init()) + asyncCheck transports[0].listen(ma, connHandler) + + try: + let transport = TcpTransport.init(maxOutgoing = 2) + transports.add(transport) + for i in 0..10: + discard await transport.dial(transports[0].ma) + except TooManyConnections: + check times == 2 + + await allFuturesThrowing( + transports.mapIt(it.close())) + + waitFor(test())