Concurrent upgrades (#489)

* adding an upgraded event to conn

* set stopped flag asap

* trigger upgraded event on conn

* set concurrency limit for accepts

* backporting semaphore from tcp-limits2

* export unittests module

* make params explicit

* tone down debug logs

* adding semaphore tests

* use semaphore to throttle concurrent upgrades (see the sketch after the commit metadata below)

* add libp2p scope

* trigger upgraded event before any other events

* add event handler for connection upgrade

* cleanup upgraded event on conn close

* make upgrades slot release robust

* don't forget to release slot on nil connection

* misc

* make sure semaphore is always released

* minor improvements and a nil check

* removing unneeded comment

* make upgradeMonitor a non-closure proc

* make sure the `upgraded` event is initialized

* handle exceptions in accepts when stopping

* don't leak exceptions when stopping accept loops
Dmitriy Ryajov 2021-01-04 12:59:05 -06:00 committed by GitHub
parent e1bc9e7a44
commit b2ea5a3c77
10 changed files with 276 additions and 32 deletions
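
The gist of the change: incoming connections are only accepted while a free "upgrade slot" exists, and every accepted connection must hand its slot back once its upgrade finishes or fails. Below is a minimal sketch of that pattern outside the real switch code — `acceptOne` and `doUpgrade` are hypothetical stand-ins for `transport.accept()` and the secure/mux upgrade, and the import path assumes the new `AsyncSemaphore` module added by this commit:

import chronos
import libp2p/utils/semaphore   # AsyncSemaphore introduced by this commit (path assumed)

const ConcurrentUpgrades = 4

proc acceptOne(id: int): Future[int] {.async.} =
  # hypothetical stand-in for transport.accept()
  await sleepAsync(10.millis)
  return id

proc doUpgrade(conn: int, slots: AsyncSemaphore) {.async.} =
  # hypothetical stand-in for the secure/mux upgrade
  try:
    await sleepAsync(50.millis)
  finally:
    slots.release()             # always hand the slot back, success or failure

proc acceptLoop() {.async.} =
  let slots = AsyncSemaphore.init(ConcurrentUpgrades)
  for i in 0 ..< 10:
    await slots.acquire()       # suspends once 4 upgrades are in flight
    let conn = await acceptOne(i)
    asyncSpawn doUpgrade(conn, slots)
  await sleepAsync(200.millis)  # let the spawned upgrades drain

waitFor acceptLoop()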

==== libp2p/connmanager.nim ====

@@ -19,7 +19,8 @@ logScope:
 declareGauge(libp2p_peers, "total connected peers")

-const MaxConnectionsPerPeer = 5
+const
+  MaxConnectionsPerPeer = 5

 type
   TooManyConnections* = object of CatchableError

@@ -236,12 +237,17 @@ proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
   trace "Connection cleaned up", conn

-proc peerStartup(c: ConnManager, conn: Connection) {.async.} =
+proc onConnUpgraded(c: ConnManager, conn: Connection) {.async.} =
   try:
     trace "Triggering connect events", conn
+    doAssert(not isNil(conn.upgraded),
+      "The `upgraded` event hasn't been properly initialized!")
+    conn.upgraded.complete()

     let peerId = conn.peerInfo.peerId
     await c.triggerPeerEvents(
       peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out))

     await c.triggerConnEvent(
       peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In))
   except CatchableError as exc:

@@ -384,7 +390,7 @@ proc storeMuxer*(c: ConnManager,
   trace "Stored muxer",
     muxer, handle = not handle.isNil, connections = c.conns.len

-  asyncSpawn c.peerStartup(muxer.connection)
+  asyncSpawn c.onConnUpgraded(muxer.connection)

 proc getStream*(c: ConnManager,
                 peerId: PeerID,

==== libp2p/protocols/secure/secure.nim ====

@@ -46,6 +46,7 @@ proc init*(T: type SecureConn,
     peerInfo: peerInfo,
     observedAddr: observedAddr,
     closeEvent: conn.closeEvent,
+    upgraded: conn.upgraded,
     timeout: timeout,
     dir: conn.dir)
   result.initStream()

==== libp2p/stream/connection.nim ====

@@ -32,6 +32,7 @@ type
     timeoutHandler*: TimeoutHandler # timeout handler
     peerInfo*: PeerInfo
     observedAddr*: Multiaddress
+    upgraded*: Future[void]

 proc timeoutMonitor(s: Connection) {.async, gcsafe.}

@@ -49,6 +50,9 @@ method initStream*(s: Connection) =
   doAssert(isNil(s.timerTaskFut))

+  if isNil(s.upgraded):
+    s.upgraded = newFuture[void]()
+
   if s.timeout > 0.millis:
     trace "Monitoring for timeout", s, timeout = s.timeout

@@ -61,10 +65,15 @@ method initStream*(s: Connection) =
 method closeImpl*(s: Connection): Future[void] =
   # Cleanup timeout timer
   trace "Closing connection", s

   if not isNil(s.timerTaskFut) and not s.timerTaskFut.finished:
     s.timerTaskFut.cancel()
     s.timerTaskFut = nil

+  if not isNil(s.upgraded) and not s.upgraded.finished:
+    s.upgraded.cancel()
+    s.upgraded = nil
+
   trace "Closed connection", s

   procCall LPStream(s).closeImpl()

==== libp2p/stream/lpstream.nim ====

@@ -125,7 +125,7 @@ method initStream*(s: LPStream) {.base.} =
   libp2p_open_streams.inc(labelValues = [s.objName, $s.dir])
   inc getStreamTracker(s.objName).opened
-  debug "Stream created", s, objName = s.objName, dir = $s.dir
+  trace "Stream created", s, objName = s.objName, dir = $s.dir

 proc join*(s: LPStream): Future[void] =
   s.closeEvent.wait()

@@ -258,7 +258,7 @@ method closeImpl*(s: LPStream): Future[void] {.async, base.} =
   s.closeEvent.fire()
   libp2p_open_streams.dec(labelValues = [s.objName, $s.dir])
   inc getStreamTracker(s.objName).closed
-  debug "Closed stream", s, objName = s.objName, dir = $s.dir
+  trace "Closed stream", s, objName = s.objName, dir = $s.dir

 method close*(s: LPStream): Future[void] {.base, async.} = # {.raises [Defect].}
   ## close the stream - this may block, but will not raise exceptions

==== libp2p/switch.nim ====

@@ -26,6 +26,7 @@ import stream/connection,
        peerinfo,
        protocols/identify,
        muxers/muxer,
+       utils/semaphore,
        connmanager,
        peerid,
        errors

@@ -45,6 +46,9 @@ declareCounter(libp2p_dialed_peers, "dialed peers")
 declareCounter(libp2p_failed_dials, "failed dials")
 declareCounter(libp2p_failed_upgrade, "peers failed upgrade")

+const
+  ConcurrentUpgrades* = 4
+
 type
   UpgradeFailedError* = object of CatchableError
   DialFailedError* = object of CatchableError

@@ -223,23 +227,26 @@ proc upgradeIncoming(s: Switch, incomingConn: Connection) {.async, gcsafe.} = #
     trace "Starting secure handler", conn
     let secure = s.secureManagers.filterIt(it.codec == proto)[0]

-    var sconn: Connection
+    var cconn = conn
     try:
-      sconn = await secure.secure(conn, false)
+      var sconn = await secure.secure(cconn, false)
       if isNil(sconn):
         return

+      cconn = sconn
+
       # add the muxer
       for muxer in s.muxers.values:
         ms.addHandler(muxer.codecs, muxer)

       # handle subsequent secure requests
-      await ms.handle(sconn)
+      await ms.handle(cconn)
     except CatchableError as exc:
       debug "Exception in secure handler during incoming upgrade", msg = exc.msg, conn
+      if not cconn.upgraded.finished:
+        cconn.upgraded.fail(exc)
     finally:
-      if not isNil(sconn):
-        await sconn.close()
+      if not isNil(cconn):
+        await cconn.close()

     trace "Stopped secure handler", conn

@@ -254,6 +261,8 @@ proc upgradeIncoming(s: Switch, incomingConn: Connection) {.async, gcsafe.} = #
     await ms.handle(incomingConn, active = true)
   except CatchableError as exc:
     debug "Exception upgrading incoming", exc = exc.msg
+    if not incomingConn.upgraded.finished:
+      incomingConn.upgraded.fail(exc)
   finally:
     await incomingConn.close()
@@ -416,31 +425,61 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil) {.gcsafe
   s.ms.addHandler(proto.codecs, proto, matcher)

+proc upgradeMonitor(conn: Connection, upgrades: AsyncSemaphore) {.async.} =
+  ## monitor connection for upgrades
+  ##
+  try:
+    # Since we don't control the flow of the
+    # upgrade, this timeout guarantees that a
+    # "hanged" remote doesn't hold the upgrade
+    # forever
+    await conn.upgraded.wait(30.seconds) # wait for connection to be upgraded
+    trace "Connection upgrade succeeded"
+  except CatchableError as exc:
+    # if not isNil(conn): # for some reason, this can be nil
+    await conn.close()
+    trace "Exception awaiting connection upgrade", exc = exc.msg, conn
+  finally:
+    upgrades.release() # don't forget to release the slot!
+
 proc accept(s: Switch, transport: Transport) {.async.} = # noraises
-  ## transport's accept loop
+  ## switch accept loop, run for every transport
   ##

+  let upgrades = AsyncSemaphore.init(ConcurrentUpgrades)
   while transport.running:
     var conn: Connection
     try:
       debug "About to accept incoming connection"
-      conn = await transport.accept()
-      if not isNil(conn):
-        debug "Accepted an incoming connection", conn
-        asyncSpawn s.upgradeIncoming(conn) # perform upgrade on incoming connection
-      else:
+      # remember to always release the slot when
+      # the upgrade succeeds or fails, this is
+      # currently done by the `upgradeMonitor`
+      await upgrades.acquire()        # first wait for an upgrade slot to become available
+      conn = await transport.accept() # next attempt to get a connection
+      if isNil(conn):
         # A nil connection means that we might have hit a
         # file-handle limit (or another non-fatal error),
         # we can get one on the next try, but we should
         # be careful to not end up in a tight loop that
         # will starve the main event loop, thus we sleep
         # here before retrying.
+        trace "Unable to get a connection, sleeping"
         await sleepAsync(100.millis) # TODO: should be configurable?
+        upgrades.release()
+        continue
+
+      debug "Accepted an incoming connection", conn
+      asyncSpawn upgradeMonitor(conn, upgrades)
+      asyncSpawn s.upgradeIncoming(conn)
+    except CancelledError as exc:
+      trace "releasing semaphore on cancellation"
+      upgrades.release() # always release the slot
     except CatchableError as exc:
       debug "Exception in accept loop, exiting", exc = exc.msg
+      upgrades.release() # always release the slot
       if not isNil(conn):
         await conn.close()
       return

 proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
@@ -460,13 +499,6 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
 proc stop*(s: Switch) {.async.} =
   trace "Stopping switch"

-  for a in s.acceptFuts:
-    if not a.finished:
-      a.cancel()
-
-  checkFutures(
-    await allFinished(s.acceptFuts))
-
   # close and cleanup all connections
   await s.connManager.close()

@@ -478,6 +510,18 @@ proc stop*(s: Switch) {.async.} =
     except CatchableError as exc:
       warn "error cleaning up transports", msg = exc.msg

+  try:
+    await allFutures(s.acceptFuts)
+      .wait(1.seconds)
+  except CatchableError as exc:
+    trace "Exception while stopping accept loops", exc = exc.msg
+
+  # check that all futures were properly
+  # stopped and otherwise cancel them
+  for a in s.acceptFuts:
+    if not a.finished:
+      a.cancel()
+
   trace "Switch stopped"

 proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
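
The new stop sequence above first gives the accept loops a short grace period to finish on their own and only then cancels the stragglers, so stopping the switch no longer leaks exceptions from cancelled accepts. A rough chronos-only sketch of the same ordering (the nested `loop` proc is a hypothetical stand-in for `accept`):

import chronos

proc stopAll() {.async.} =
  var stopping = false

  proc loop() {.async.} =
    # hypothetical stand-in for the switch accept loop
    while not stopping:
      await sleepAsync(10.millis)

  let futs = @[loop(), loop()]

  stopping = true                         # analogous to `t.running = false` in the transport
  try:
    # give the loops a chance to finish on their own...
    await allFutures(futs).wait(1.seconds)
  except CatchableError as exc:
    echo "loops didn't stop in time: ", exc.msg

  # ...then cancel whatever is still pending
  for f in futs:
    if not f.finished:
      f.cancel()

waitFor stopAll()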

==== libp2p/transports/tcptransport.nim ====

@@ -71,7 +71,7 @@ proc connHandler*(t: TcpTransport,
       await client.closeWait()
       raise exc

-  debug "Handling tcp connection", address = $observedAddr,
+  trace "Handling tcp connection", address = $observedAddr,
         dir = $dir,
         clients = t.clients[Direction.In].len +
                   t.clients[Direction.Out].len

@@ -130,7 +130,10 @@ method start*(t: TcpTransport, ma: MultiAddress) {.async.} =
   await procCall Transport(t).start(ma)

   trace "Starting TCP transport"
-  t.server = createStreamServer(t.ma, t.flags, t)
+  t.server = createStreamServer(
+    ma = t.ma,
+    flags = t.flags,
+    udata = t)

   # always get the resolved address in case we're bound to 0.0.0.0:0
   t.ma = MultiAddress.init(t.server.sock.getLocalAddress()).tryGet()

@@ -142,6 +145,8 @@ method stop*(t: TcpTransport) {.async, gcsafe.} =
   ## stop the transport
   ##

+  t.running = false # mark stopped as soon as possible
+
   try:
     trace "Stopping TCP transport"
     await procCall Transport(t).stop() # call base

@@ -160,8 +165,6 @@ method stop*(t: TcpTransport) {.async, gcsafe.} =
     inc getTcpTransportTracker().closed
   except CatchableError as exc:
     trace "Error shutting down tcp transport", exc = exc.msg
-  finally:
-    t.running = false

 method accept*(t: TcpTransport): Future[Connection] {.async, gcsafe.} =
   ## accept a new TCP connection

@@ -179,12 +182,12 @@ method accept*(t: TcpTransport): Future[Connection] {.async, gcsafe.} =
     # that can't.
     debug "OS Error", exc = exc.msg
   except TransportTooManyError as exc:
-    warn "Too many files opened", exc = exc.msg
+    debug "Too many files opened", exc = exc.msg
   except TransportUseClosedError as exc:
-    info "Server was closed", exc = exc.msg
+    debug "Server was closed", exc = exc.msg
     raise newTransportClosedError(exc)
   except CatchableError as exc:
-    trace "Unexpected error creating connection", exc = exc.msg
+    warn "Unexpected error creating connection", exc = exc.msg
     raise exc

 method dial*(t: TcpTransport,

==== libp2p/utils/semaphore.nim (new file, 75 lines) ====

@@ -0,0 +1,75 @@
## Nim-Libp2p
## Copyright (c) 2020 Status Research & Development GmbH
## Licensed under either of
##  * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
##  * MIT license ([LICENSE-MIT](LICENSE-MIT))
## at your option.
## This file may not be copied, modified, or distributed except according to
## those terms.

import sequtils
import chronos, chronicles

# TODO: this should probably go in chronos

logScope:
  topics = "libp2p semaphore"

type
  AsyncSemaphore* = ref object of RootObj
    size*: int
    count*: int
    queue*: seq[Future[void]]

proc init*(T: type AsyncSemaphore, size: int): T =
  T(size: size, count: size)

proc tryAcquire*(s: AsyncSemaphore): bool =
  ## Attempts to acquire a resource, if successful
  ## returns true, otherwise false
  ##

  if s.count > 0 and s.queue.len == 0:
    s.count.dec
    trace "Acquired slot", available = s.count, queue = s.queue.len
    return true

proc acquire*(s: AsyncSemaphore): Future[void] =
  ## Acquire a resource and decrement the resource
  ## counter. If no more resources are available,
  ## the returned future will not complete until
  ## the resource count goes above 0 again.
  ##

  let fut = newFuture[void]("AsyncSemaphore.acquire")
  if s.tryAcquire():
    fut.complete()
    return fut

  s.queue.add(fut)
  s.count.dec

  trace "Queued slot", available = s.count, queue = s.queue.len
  return fut

proc release*(s: AsyncSemaphore) =
  ## Release a resource from the semaphore,
  ## by picking the first future from the queue
  ## and completing it and incrementing the
  ## internal resource count
  ##

  doAssert(s.count <= s.size)
  if s.count < s.size:
    trace "Releasing slot", available = s.count,
                            queue = s.queue.len

    if s.queue.len > 0:
      var fut = s.queue.pop()
      if not fut.finished():
        fut.complete()

    s.count.inc # increment the resource count
    trace "Released slot", available = s.count,
                           queue = s.queue.len

    return
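
Typical use of the new AsyncSemaphore pairs every acquire with a release in a finally block, so a failing task can never leak a slot — the same discipline `upgradeMonitor` relies on. A small self-contained sketch (not part of the commit; the import path is assumed):

import chronos
import libp2p/utils/semaphore

proc worker(id: int, sem: AsyncSemaphore) {.async.} =
  await sem.acquire()           # suspends once both slots are taken
  try:
    echo "worker ", id, " has a slot"
    await sleepAsync(20.millis)
  finally:
    sem.release()               # the slot always comes back, even on error

proc main() {.async.} =
  let sem = AsyncSemaphore.init(2)   # at most two workers run concurrently
  var futs: seq[Future[void]]
  for i in 0 ..< 6:
    futs.add(worker(i, sem))
  await allFutures(futs)

waitFor main()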

==== tests/helpers.nim ====

@@ -9,6 +9,8 @@ import ../libp2p/stream/lpstream
 import ../libp2p/muxers/mplex/lpchannel
 import ../libp2p/protocols/secure/secure

+export unittest
+
 const
   StreamTransportTrackerName = "stream.transport"
   StreamServerTrackerName = "stream.server"

==== tests/testnative.nim ====

@@ -1,6 +1,7 @@
 import testvarint,
        testminprotobuf,
-       teststreamseq
+       teststreamseq,
+       testsemaphore

 import testminasn1,
        testrsa,
==== tests/testsemaphore.nim (new file, 103 lines) ====

@@ -0,0 +1,103 @@
import random

import chronos

import ../libp2p/utils/semaphore

import ./helpers

randomize()

suite "AsyncSemaphore":
  asyncTest "should acquire":
    let sema = AsyncSemaphore.init(3)

    await sema.acquire()
    await sema.acquire()
    await sema.acquire()

    check sema.count == 0

  asyncTest "should release":
    let sema = AsyncSemaphore.init(3)

    await sema.acquire()
    await sema.acquire()
    await sema.acquire()

    check sema.count == 0
    sema.release()
    sema.release()
    sema.release()
    check sema.count == 3

  asyncTest "should queue acquire":
    let sema = AsyncSemaphore.init(1)

    await sema.acquire()
    let fut = sema.acquire()

    check sema.count == -1
    check sema.queue.len == 1
    sema.release()
    sema.release()
    check sema.count == 1

    await sleepAsync(10.millis)
    check fut.finished()

  asyncTest "should keep count == size":
    let sema = AsyncSemaphore.init(1)
    sema.release()
    sema.release()
    sema.release()
    check sema.count == 1

  asyncTest "should tryAcquire":
    let sema = AsyncSemaphore.init(1)
    await sema.acquire()
    check sema.tryAcquire() == false

  asyncTest "should tryAcquire and acquire":
    let sema = AsyncSemaphore.init(4)
    check sema.tryAcquire() == true
    check sema.tryAcquire() == true
    check sema.tryAcquire() == true
    check sema.tryAcquire() == true
    check sema.count == 0

    let fut = sema.acquire()
    check fut.finished == false
    check sema.count == -1
    # queue is only used when count is < 0
    check sema.queue.len == 1

    sema.release()
    sema.release()
    sema.release()
    sema.release()
    sema.release()

    check fut.finished == true
    check sema.count == 4
    check sema.queue.len == 0

  asyncTest "should restrict resource access":
    let sema = AsyncSemaphore.init(3)
    var resource = 0

    proc task() {.async.} =
      try:
        await sema.acquire()
        resource.inc()
        check resource > 0 and resource <= 3
        let sleep = rand(0..10).millis
        # echo sleep
        await sleepAsync(sleep)
      finally:
        resource.dec()
        sema.release()

    var tasks: seq[Future[void]]
    for i in 0..<10:
      tasks.add(task())

    await allFutures(tasks)