mirror of
https://github.com/vacp2p/nim-libp2p.git
synced 2025-03-01 16:40:32 +00:00
tcp limits
This commit is contained in:
parent
5c986cf657
commit
41d103bb9b
@ -8,7 +8,7 @@
|
||||
## those terms.
|
||||
|
||||
import oids
|
||||
import chronos, chronicles, sequtils
|
||||
import chronos, chronicles, sequtils, sets
|
||||
import transport,
|
||||
../errors,
|
||||
../wire,
|
||||
@ -22,14 +22,19 @@ logScope:
|
||||
|
||||
const
|
||||
TcpTransportTrackerName* = "libp2p.tcptransport"
|
||||
MaxTCPConnections* = 50
|
||||
|
||||
type
|
||||
TooManyConnections* = object of CatchableError
|
||||
|
||||
TcpTransport* = ref object of Transport
|
||||
server*: StreamServer
|
||||
clients: seq[StreamTransport]
|
||||
conns: HashSet[ChronosStream]
|
||||
flags: set[ServerFlags]
|
||||
cleanups*: seq[Future[void]]
|
||||
handlers*: seq[Future[void]]
|
||||
maxIncoming: int
|
||||
maxOutgoing: int
|
||||
|
||||
TcpTransportTracker* = ref object of TrackerBase
|
||||
opened*: uint64
|
||||
@ -59,12 +64,38 @@ proc setupTcpTransportTracker(): TcpTransportTracker =
|
||||
result.isLeaked = leakTransport
|
||||
addTracker(TcpTransportTrackerName, result)
|
||||
|
||||
proc newTooManyConnections(): ref TooManyConnections =
|
||||
newException(TooManyConnections, "too many inbound connections")
|
||||
|
||||
proc cleanup(t: TcpTransport, conn: ChronosStream) {.async.} =
|
||||
try:
|
||||
await conn.closeEvent.wait()
|
||||
trace "cleaning up socket", addrs = $conn.client.remoteAddress,
|
||||
connoid = $conn.oid
|
||||
if not(isNil(conn)):
|
||||
await conn.close()
|
||||
|
||||
t.conns.excl(conn)
|
||||
|
||||
let inLen = toSeq(t.conns).filterIt( it.dir == Direction.In ).len
|
||||
if inLen < t.maxIncoming:
|
||||
if not isNil(t.server):
|
||||
trace "restarting accept loop", limit = inLen
|
||||
t.server.start()
|
||||
|
||||
except CatchableError as exc:
|
||||
trace "error cleaning up socket", exc = exc.msg
|
||||
|
||||
proc connHandler*(t: TcpTransport,
|
||||
client: StreamTransport,
|
||||
initiator: bool): Connection =
|
||||
trace "handling connection", address = $client.remoteAddress
|
||||
let conn: Connection = Connection(ChronosStream.init(client))
|
||||
conn.observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
|
||||
let stream = ChronosStream.init(client,
|
||||
dir = if initiator: Direction.Out
|
||||
else: Direction.In)
|
||||
|
||||
let conn = Connection(stream)
|
||||
stream.observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
|
||||
if not initiator:
|
||||
if not isNil(t.handler):
|
||||
t.handlers &= t.handler(conn)
|
||||
@ -90,17 +121,31 @@ proc connCb(server: StreamServer,
|
||||
trace "incoming connection", address = $client.remoteAddress
|
||||
try:
|
||||
let t = cast[TcpTransport](server.udata)
|
||||
let inLen = toSeq(t.conns).filterIt( it.dir == Direction.In ).len
|
||||
if inLen + 1 >= t.maxIncoming:
|
||||
trace "connection limit reached", limit = t.maxIncoming,
|
||||
dir = $Direction.In
|
||||
server.stop()
|
||||
await client.closeWait()
|
||||
return
|
||||
|
||||
# we don't need result connection in this case
|
||||
# as it's added inside connHandler
|
||||
discard t.connHandler(client, false)
|
||||
discard connHandler(t, client, false)
|
||||
|
||||
except CancelledError as exc:
|
||||
raise exc
|
||||
except CatchableError as err:
|
||||
debug "Connection setup failed", err = err.msg
|
||||
client.close()
|
||||
|
||||
proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T =
|
||||
result = T(flags: flags)
|
||||
proc init*(T: type TcpTransport,
|
||||
flags: set[ServerFlags] = {},
|
||||
maxIncoming, maxOutgoing = MaxTCPConnections): T =
|
||||
result = T(flags: flags,
|
||||
maxIncoming: maxIncoming,
|
||||
maxOutgoing: maxOutgoing)
|
||||
|
||||
result.initTransport()
|
||||
|
||||
method initTransport*(t: TcpTransport) =
|
||||
@ -114,7 +159,7 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
|
||||
await procCall Transport(t).close() # call base
|
||||
|
||||
checkFutures(await allFinished(
|
||||
t.clients.mapIt(it.closeWait())))
|
||||
toSeq(t.conns).mapIt(it.client.closeWait())))
|
||||
|
||||
# server can be nil
|
||||
if not isNil(t.server):
|
||||
@ -153,7 +198,13 @@ method listen*(t: TcpTransport,
|
||||
discard await procCall Transport(t).listen(ma, handler) # call base
|
||||
|
||||
## listen on the transport
|
||||
t.server = createStreamServer(t.ma, connCb, t.flags, t)
|
||||
t.server = createStreamServer(
|
||||
t.ma,
|
||||
connCb,
|
||||
t.flags,
|
||||
t,
|
||||
backlog = t.maxIncoming)
|
||||
|
||||
t.server.start()
|
||||
|
||||
# always get the resolved address in case we're bound to 0.0.0.0:0
|
||||
@ -165,10 +216,18 @@ method dial*(t: TcpTransport,
|
||||
address: MultiAddress):
|
||||
Future[Connection] {.async, gcsafe.} =
|
||||
trace "dialing remote peer", address = $address
|
||||
|
||||
let outLen = toSeq(t.conns).filterIt( it.dir == Direction.Out ).len
|
||||
if outLen + 1 >= t.maxOutgoing:
|
||||
trace "connection limit reached", limit = t.maxOutgoing,
|
||||
dir = $Direction.Out
|
||||
raise newTooManyConnections()
|
||||
|
||||
## dial a peer
|
||||
let client: StreamTransport = await connect(address)
|
||||
result = t.connHandler(client, true)
|
||||
let client = await connect(address)
|
||||
return t.connHandler(client, true)
|
||||
|
||||
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
|
||||
if procCall Transport(t).handles(address):
|
||||
result = address.protocols.tryGet().filterIt( it == multiCodec("tcp") ).len > 0
|
||||
return (address.protocols.tryGet()
|
||||
.filterIt( it == multiCodec("tcp") ).len > 0)
|
||||
|
@ -1,12 +1,13 @@
|
||||
{.used.}
|
||||
|
||||
import unittest
|
||||
import unittest, sequtils
|
||||
import chronos, stew/byteutils
|
||||
import ../libp2p/[stream/connection,
|
||||
transports/transport,
|
||||
transports/tcptransport,
|
||||
multiaddress,
|
||||
wire]
|
||||
wire,
|
||||
errors]
|
||||
import ./helpers
|
||||
|
||||
suite "TCP transport":
|
||||
@ -192,3 +193,51 @@ suite "TCP transport":
|
||||
|
||||
check:
|
||||
waitFor(testListenerDialer()) == true
|
||||
|
||||
test "e2e: should limit incoming connections":
|
||||
proc test() {.async.} =
|
||||
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
|
||||
var times = 1
|
||||
proc connHandler(conn: Connection) {.async, gcsafe.} =
|
||||
times.inc()
|
||||
|
||||
var transports: seq[TcpTransport]
|
||||
transports.add(TcpTransport.init(maxIncoming = 2))
|
||||
asyncCheck transports[0].listen(ma, connHandler)
|
||||
|
||||
try:
|
||||
for i in 0..10:
|
||||
let transport = TcpTransport.init()
|
||||
transports.add(transport)
|
||||
discard await transport.dial(transports[0].ma).wait(10.millis)
|
||||
except AsyncTimeoutError:
|
||||
check times == 2
|
||||
|
||||
await allFuturesThrowing(
|
||||
transports.mapIt(it.close()))
|
||||
|
||||
waitFor(test())
|
||||
|
||||
test "e2e: should limit outgoing connections":
|
||||
proc test() {.async.} =
|
||||
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
|
||||
var times = 1
|
||||
proc connHandler(conn: Connection) {.async, gcsafe.} =
|
||||
times.inc()
|
||||
|
||||
var transports: seq[TcpTransport]
|
||||
transports.add(TcpTransport.init())
|
||||
asyncCheck transports[0].listen(ma, connHandler)
|
||||
|
||||
try:
|
||||
let transport = TcpTransport.init(maxOutgoing = 2)
|
||||
transports.add(transport)
|
||||
for i in 0..10:
|
||||
discard await transport.dial(transports[0].ma)
|
||||
except TooManyConnections:
|
||||
check times == 2
|
||||
|
||||
await allFuturesThrowing(
|
||||
transports.mapIt(it.close()))
|
||||
|
||||
waitFor(test())
|
||||
|
Loading…
x
Reference in New Issue
Block a user