tcp limits

This commit is contained in:
Dmitriy Ryajov 2020-08-04 23:57:04 -06:00
parent 5c986cf657
commit 41d103bb9b
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
2 changed files with 122 additions and 14 deletions

View File

@ -8,7 +8,7 @@
## those terms.
import oids
import chronos, chronicles, sequtils
import chronos, chronicles, sequtils, sets
import transport,
../errors,
../wire,
@ -22,14 +22,19 @@ logScope:
const
TcpTransportTrackerName* = "libp2p.tcptransport"
MaxTCPConnections* = 50
type
TooManyConnections* = object of CatchableError
TcpTransport* = ref object of Transport
server*: StreamServer
clients: seq[StreamTransport]
conns: HashSet[ChronosStream]
flags: set[ServerFlags]
cleanups*: seq[Future[void]]
handlers*: seq[Future[void]]
maxIncoming: int
maxOutgoing: int
TcpTransportTracker* = ref object of TrackerBase
opened*: uint64
@ -59,12 +64,38 @@ proc setupTcpTransportTracker(): TcpTransportTracker =
result.isLeaked = leakTransport
addTracker(TcpTransportTrackerName, result)
proc newTooManyConnections(): ref TooManyConnections =
newException(TooManyConnections, "too many inbound connections")
proc cleanup(t: TcpTransport, conn: ChronosStream) {.async.} =
try:
await conn.closeEvent.wait()
trace "cleaning up socket", addrs = $conn.client.remoteAddress,
connoid = $conn.oid
if not(isNil(conn)):
await conn.close()
t.conns.excl(conn)
let inLen = toSeq(t.conns).filterIt( it.dir == Direction.In ).len
if inLen < t.maxIncoming:
if not isNil(t.server):
trace "restarting accept loop", limit = inLen
t.server.start()
except CatchableError as exc:
trace "error cleaning up socket", exc = exc.msg
proc connHandler*(t: TcpTransport,
client: StreamTransport,
initiator: bool): Connection =
trace "handling connection", address = $client.remoteAddress
let conn: Connection = Connection(ChronosStream.init(client))
conn.observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
let stream = ChronosStream.init(client,
dir = if initiator: Direction.Out
else: Direction.In)
let conn = Connection(stream)
stream.observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
if not initiator:
if not isNil(t.handler):
t.handlers &= t.handler(conn)
@ -90,17 +121,31 @@ proc connCb(server: StreamServer,
trace "incoming connection", address = $client.remoteAddress
try:
let t = cast[TcpTransport](server.udata)
let inLen = toSeq(t.conns).filterIt( it.dir == Direction.In ).len
if inLen + 1 >= t.maxIncoming:
trace "connection limit reached", limit = t.maxIncoming,
dir = $Direction.In
server.stop()
await client.closeWait()
return
# we don't need result connection in this case
# as it's added inside connHandler
discard t.connHandler(client, false)
discard connHandler(t, client, false)
except CancelledError as exc:
raise exc
except CatchableError as err:
debug "Connection setup failed", err = err.msg
client.close()
proc init*(T: type TcpTransport, flags: set[ServerFlags] = {}): T =
result = T(flags: flags)
proc init*(T: type TcpTransport,
flags: set[ServerFlags] = {},
maxIncoming, maxOutgoing = MaxTCPConnections): T =
result = T(flags: flags,
maxIncoming: maxIncoming,
maxOutgoing: maxOutgoing)
result.initTransport()
method initTransport*(t: TcpTransport) =
@ -114,7 +159,7 @@ method close*(t: TcpTransport) {.async, gcsafe.} =
await procCall Transport(t).close() # call base
checkFutures(await allFinished(
t.clients.mapIt(it.closeWait())))
toSeq(t.conns).mapIt(it.client.closeWait())))
# server can be nil
if not isNil(t.server):
@ -153,7 +198,13 @@ method listen*(t: TcpTransport,
discard await procCall Transport(t).listen(ma, handler) # call base
## listen on the transport
t.server = createStreamServer(t.ma, connCb, t.flags, t)
t.server = createStreamServer(
t.ma,
connCb,
t.flags,
t,
backlog = t.maxIncoming)
t.server.start()
# always get the resolved address in case we're bound to 0.0.0.0:0
@ -165,10 +216,18 @@ method dial*(t: TcpTransport,
address: MultiAddress):
Future[Connection] {.async, gcsafe.} =
trace "dialing remote peer", address = $address
let outLen = toSeq(t.conns).filterIt( it.dir == Direction.Out ).len
if outLen + 1 >= t.maxOutgoing:
trace "connection limit reached", limit = t.maxOutgoing,
dir = $Direction.Out
raise newTooManyConnections()
## dial a peer
let client: StreamTransport = await connect(address)
result = t.connHandler(client, true)
let client = await connect(address)
return t.connHandler(client, true)
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
if procCall Transport(t).handles(address):
result = address.protocols.tryGet().filterIt( it == multiCodec("tcp") ).len > 0
return (address.protocols.tryGet()
.filterIt( it == multiCodec("tcp") ).len > 0)

View File

@ -1,12 +1,13 @@
{.used.}
import unittest
import unittest, sequtils
import chronos, stew/byteutils
import ../libp2p/[stream/connection,
transports/transport,
transports/tcptransport,
multiaddress,
wire]
wire,
errors]
import ./helpers
suite "TCP transport":
@ -192,3 +193,51 @@ suite "TCP transport":
check:
waitFor(testListenerDialer()) == true
test "e2e: should limit incoming connections":
proc test() {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
var times = 1
proc connHandler(conn: Connection) {.async, gcsafe.} =
times.inc()
var transports: seq[TcpTransport]
transports.add(TcpTransport.init(maxIncoming = 2))
asyncCheck transports[0].listen(ma, connHandler)
try:
for i in 0..10:
let transport = TcpTransport.init()
transports.add(transport)
discard await transport.dial(transports[0].ma).wait(10.millis)
except AsyncTimeoutError:
check times == 2
await allFuturesThrowing(
transports.mapIt(it.close()))
waitFor(test())
test "e2e: should limit outgoing connections":
proc test() {.async.} =
let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet()
var times = 1
proc connHandler(conn: Connection) {.async, gcsafe.} =
times.inc()
var transports: seq[TcpTransport]
transports.add(TcpTransport.init())
asyncCheck transports[0].listen(ma, connHandler)
try:
let transport = TcpTransport.init(maxOutgoing = 2)
transports.add(transport)
for i in 0..10:
discard await transport.dial(transports[0].ma)
except TooManyConnections:
check times == 2
await allFuturesThrowing(
transports.mapIt(it.close()))
waitFor(test())