mirror of https://github.com/vacp2p/nim-libp2p.git
Concurrent upgrades (#489)
* adding an upgraded event to conn * set stopped flag asap * trigger upgradded event on conn * set concurrency limit for accepts * backporting semaphore from tcp-limits2 * export unittests module * make params explicit * tone down debug logs * adding semaphore tests * use semaphore to throttle concurent upgrades * add libp2p scope * trigger upgraded event before any other events * add event handler for connection upgrade * cleanup upgraded event on conn close * make upgrades slot release rebust * dont forget to release slot on nil connection * misc * make sure semaphore is always released * minor improvements and a nil check * removing unneeded comment * make upgradeMonitor a non-closure proc * make sure the `upgraded` event is initialized * handle exceptions in accepts when stopping * don't leak exceptions when stopping accept loops
This commit is contained in:
parent
e1bc9e7a44
commit
b2ea5a3c77
|
@ -19,7 +19,8 @@ logScope:
|
||||||
|
|
||||||
declareGauge(libp2p_peers, "total connected peers")
|
declareGauge(libp2p_peers, "total connected peers")
|
||||||
|
|
||||||
const MaxConnectionsPerPeer = 5
|
const
|
||||||
|
MaxConnectionsPerPeer = 5
|
||||||
|
|
||||||
type
|
type
|
||||||
TooManyConnections* = object of CatchableError
|
TooManyConnections* = object of CatchableError
|
||||||
|
@ -236,12 +237,17 @@ proc cleanupConn(c: ConnManager, conn: Connection) {.async.} =
|
||||||
|
|
||||||
trace "Connection cleaned up", conn
|
trace "Connection cleaned up", conn
|
||||||
|
|
||||||
proc peerStartup(c: ConnManager, conn: Connection) {.async.} =
|
proc onConnUpgraded(c: ConnManager, conn: Connection) {.async.} =
|
||||||
try:
|
try:
|
||||||
trace "Triggering connect events", conn
|
trace "Triggering connect events", conn
|
||||||
|
doAssert(not isNil(conn.upgraded),
|
||||||
|
"The `upgraded` event hasn't been properly initialized!")
|
||||||
|
conn.upgraded.complete()
|
||||||
|
|
||||||
let peerId = conn.peerInfo.peerId
|
let peerId = conn.peerInfo.peerId
|
||||||
await c.triggerPeerEvents(
|
await c.triggerPeerEvents(
|
||||||
peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out))
|
peerId, PeerEvent(kind: PeerEventKind.Joined, initiator: conn.dir == Direction.Out))
|
||||||
|
|
||||||
await c.triggerConnEvent(
|
await c.triggerConnEvent(
|
||||||
peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In))
|
peerId, ConnEvent(kind: ConnEventKind.Connected, incoming: conn.dir == Direction.In))
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
|
@ -384,7 +390,7 @@ proc storeMuxer*(c: ConnManager,
|
||||||
trace "Stored muxer",
|
trace "Stored muxer",
|
||||||
muxer, handle = not handle.isNil, connections = c.conns.len
|
muxer, handle = not handle.isNil, connections = c.conns.len
|
||||||
|
|
||||||
asyncSpawn c.peerStartup(muxer.connection)
|
asyncSpawn c.onConnUpgraded(muxer.connection)
|
||||||
|
|
||||||
proc getStream*(c: ConnManager,
|
proc getStream*(c: ConnManager,
|
||||||
peerId: PeerID,
|
peerId: PeerID,
|
||||||
|
|
|
@ -46,6 +46,7 @@ proc init*(T: type SecureConn,
|
||||||
peerInfo: peerInfo,
|
peerInfo: peerInfo,
|
||||||
observedAddr: observedAddr,
|
observedAddr: observedAddr,
|
||||||
closeEvent: conn.closeEvent,
|
closeEvent: conn.closeEvent,
|
||||||
|
upgraded: conn.upgraded,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
dir: conn.dir)
|
dir: conn.dir)
|
||||||
result.initStream()
|
result.initStream()
|
||||||
|
|
|
@ -32,6 +32,7 @@ type
|
||||||
timeoutHandler*: TimeoutHandler # timeout handler
|
timeoutHandler*: TimeoutHandler # timeout handler
|
||||||
peerInfo*: PeerInfo
|
peerInfo*: PeerInfo
|
||||||
observedAddr*: Multiaddress
|
observedAddr*: Multiaddress
|
||||||
|
upgraded*: Future[void]
|
||||||
|
|
||||||
proc timeoutMonitor(s: Connection) {.async, gcsafe.}
|
proc timeoutMonitor(s: Connection) {.async, gcsafe.}
|
||||||
|
|
||||||
|
@ -49,6 +50,9 @@ method initStream*(s: Connection) =
|
||||||
|
|
||||||
doAssert(isNil(s.timerTaskFut))
|
doAssert(isNil(s.timerTaskFut))
|
||||||
|
|
||||||
|
if isNil(s.upgraded):
|
||||||
|
s.upgraded = newFuture[void]()
|
||||||
|
|
||||||
if s.timeout > 0.millis:
|
if s.timeout > 0.millis:
|
||||||
trace "Monitoring for timeout", s, timeout = s.timeout
|
trace "Monitoring for timeout", s, timeout = s.timeout
|
||||||
|
|
||||||
|
@ -61,10 +65,15 @@ method initStream*(s: Connection) =
|
||||||
method closeImpl*(s: Connection): Future[void] =
|
method closeImpl*(s: Connection): Future[void] =
|
||||||
# Cleanup timeout timer
|
# Cleanup timeout timer
|
||||||
trace "Closing connection", s
|
trace "Closing connection", s
|
||||||
|
|
||||||
if not isNil(s.timerTaskFut) and not s.timerTaskFut.finished:
|
if not isNil(s.timerTaskFut) and not s.timerTaskFut.finished:
|
||||||
s.timerTaskFut.cancel()
|
s.timerTaskFut.cancel()
|
||||||
s.timerTaskFut = nil
|
s.timerTaskFut = nil
|
||||||
|
|
||||||
|
if not isNil(s.upgraded) and not s.upgraded.finished:
|
||||||
|
s.upgraded.cancel()
|
||||||
|
s.upgraded = nil
|
||||||
|
|
||||||
trace "Closed connection", s
|
trace "Closed connection", s
|
||||||
|
|
||||||
procCall LPStream(s).closeImpl()
|
procCall LPStream(s).closeImpl()
|
||||||
|
|
|
@ -125,7 +125,7 @@ method initStream*(s: LPStream) {.base.} =
|
||||||
|
|
||||||
libp2p_open_streams.inc(labelValues = [s.objName, $s.dir])
|
libp2p_open_streams.inc(labelValues = [s.objName, $s.dir])
|
||||||
inc getStreamTracker(s.objName).opened
|
inc getStreamTracker(s.objName).opened
|
||||||
debug "Stream created", s, objName = s.objName, dir = $s.dir
|
trace "Stream created", s, objName = s.objName, dir = $s.dir
|
||||||
|
|
||||||
proc join*(s: LPStream): Future[void] =
|
proc join*(s: LPStream): Future[void] =
|
||||||
s.closeEvent.wait()
|
s.closeEvent.wait()
|
||||||
|
@ -258,7 +258,7 @@ method closeImpl*(s: LPStream): Future[void] {.async, base.} =
|
||||||
s.closeEvent.fire()
|
s.closeEvent.fire()
|
||||||
libp2p_open_streams.dec(labelValues = [s.objName, $s.dir])
|
libp2p_open_streams.dec(labelValues = [s.objName, $s.dir])
|
||||||
inc getStreamTracker(s.objName).closed
|
inc getStreamTracker(s.objName).closed
|
||||||
debug "Closed stream", s, objName = s.objName, dir = $s.dir
|
trace "Closed stream", s, objName = s.objName, dir = $s.dir
|
||||||
|
|
||||||
method close*(s: LPStream): Future[void] {.base, async.} = # {.raises [Defect].}
|
method close*(s: LPStream): Future[void] {.base, async.} = # {.raises [Defect].}
|
||||||
## close the stream - this may block, but will not raise exceptions
|
## close the stream - this may block, but will not raise exceptions
|
||||||
|
|
|
@ -26,6 +26,7 @@ import stream/connection,
|
||||||
peerinfo,
|
peerinfo,
|
||||||
protocols/identify,
|
protocols/identify,
|
||||||
muxers/muxer,
|
muxers/muxer,
|
||||||
|
utils/semaphore,
|
||||||
connmanager,
|
connmanager,
|
||||||
peerid,
|
peerid,
|
||||||
errors
|
errors
|
||||||
|
@ -45,6 +46,9 @@ declareCounter(libp2p_dialed_peers, "dialed peers")
|
||||||
declareCounter(libp2p_failed_dials, "failed dials")
|
declareCounter(libp2p_failed_dials, "failed dials")
|
||||||
declareCounter(libp2p_failed_upgrade, "peers failed upgrade")
|
declareCounter(libp2p_failed_upgrade, "peers failed upgrade")
|
||||||
|
|
||||||
|
const
|
||||||
|
ConcurrentUpgrades* = 4
|
||||||
|
|
||||||
type
|
type
|
||||||
UpgradeFailedError* = object of CatchableError
|
UpgradeFailedError* = object of CatchableError
|
||||||
DialFailedError* = object of CatchableError
|
DialFailedError* = object of CatchableError
|
||||||
|
@ -223,23 +227,26 @@ proc upgradeIncoming(s: Switch, incomingConn: Connection) {.async, gcsafe.} = #
|
||||||
trace "Starting secure handler", conn
|
trace "Starting secure handler", conn
|
||||||
let secure = s.secureManagers.filterIt(it.codec == proto)[0]
|
let secure = s.secureManagers.filterIt(it.codec == proto)[0]
|
||||||
|
|
||||||
var sconn: Connection
|
var cconn = conn
|
||||||
try:
|
try:
|
||||||
sconn = await secure.secure(conn, false)
|
var sconn = await secure.secure(cconn, false)
|
||||||
if isNil(sconn):
|
if isNil(sconn):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
cconn = sconn
|
||||||
# add the muxer
|
# add the muxer
|
||||||
for muxer in s.muxers.values:
|
for muxer in s.muxers.values:
|
||||||
ms.addHandler(muxer.codecs, muxer)
|
ms.addHandler(muxer.codecs, muxer)
|
||||||
|
|
||||||
# handle subsequent secure requests
|
# handle subsequent secure requests
|
||||||
await ms.handle(sconn)
|
await ms.handle(cconn)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
debug "Exception in secure handler during incoming upgrade", msg = exc.msg, conn
|
debug "Exception in secure handler during incoming upgrade", msg = exc.msg, conn
|
||||||
|
if not cconn.upgraded.finished:
|
||||||
|
cconn.upgraded.fail(exc)
|
||||||
finally:
|
finally:
|
||||||
if not isNil(sconn):
|
if not isNil(cconn):
|
||||||
await sconn.close()
|
await cconn.close()
|
||||||
|
|
||||||
trace "Stopped secure handler", conn
|
trace "Stopped secure handler", conn
|
||||||
|
|
||||||
|
@ -254,6 +261,8 @@ proc upgradeIncoming(s: Switch, incomingConn: Connection) {.async, gcsafe.} = #
|
||||||
await ms.handle(incomingConn, active = true)
|
await ms.handle(incomingConn, active = true)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
debug "Exception upgrading incoming", exc = exc.msg
|
debug "Exception upgrading incoming", exc = exc.msg
|
||||||
|
if not incomingConn.upgraded.finished:
|
||||||
|
incomingConn.upgraded.fail(exc)
|
||||||
finally:
|
finally:
|
||||||
await incomingConn.close()
|
await incomingConn.close()
|
||||||
|
|
||||||
|
@ -416,31 +425,61 @@ proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil) {.gcsafe
|
||||||
|
|
||||||
s.ms.addHandler(proto.codecs, proto, matcher)
|
s.ms.addHandler(proto.codecs, proto, matcher)
|
||||||
|
|
||||||
|
proc upgradeMonitor(conn: Connection, upgrades: AsyncSemaphore) {.async.} =
|
||||||
|
## monitor connection for upgrades
|
||||||
|
##
|
||||||
|
try:
|
||||||
|
# Since we don't control the flow of the
|
||||||
|
# upgrade, this timeout guarantees that a
|
||||||
|
# "hanged" remote doesn't hold the upgrade
|
||||||
|
# forever
|
||||||
|
await conn.upgraded.wait(30.seconds) # wait for connection to be upgraded
|
||||||
|
trace "Connection upgrade succeeded"
|
||||||
|
except CatchableError as exc:
|
||||||
|
# if not isNil(conn): # for some reason, this can be nil
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
trace "Exception awaiting connection upgrade", exc = exc.msg, conn
|
||||||
|
finally:
|
||||||
|
upgrades.release() # don't forget to release the slot!
|
||||||
|
|
||||||
proc accept(s: Switch, transport: Transport) {.async.} = # noraises
|
proc accept(s: Switch, transport: Transport) {.async.} = # noraises
|
||||||
## transport's accept loop
|
## switch accept loop, ran for every transport
|
||||||
##
|
##
|
||||||
|
|
||||||
|
let upgrades = AsyncSemaphore.init(ConcurrentUpgrades)
|
||||||
while transport.running:
|
while transport.running:
|
||||||
var conn: Connection
|
var conn: Connection
|
||||||
try:
|
try:
|
||||||
debug "About to accept incoming connection"
|
debug "About to accept incoming connection"
|
||||||
conn = await transport.accept()
|
# remember to always release the slot when
|
||||||
if not isNil(conn):
|
# the upgrade succeeds or fails, this is
|
||||||
debug "Accepted an incoming connection", conn
|
# currently done by the `upgradeMonitor`
|
||||||
asyncSpawn s.upgradeIncoming(conn) # perform upgrade on incoming connection
|
await upgrades.acquire() # first wait for an upgrade slot to become available
|
||||||
else:
|
conn = await transport.accept() # next attempt to get a connection
|
||||||
|
if isNil(conn):
|
||||||
# A nil connection means that we might have hit a
|
# A nil connection means that we might have hit a
|
||||||
# file-handle limit (or another non-fatal error),
|
# file-handle limit (or another non-fatal error),
|
||||||
# we can get one on the next try, but we should
|
# we can get one on the next try, but we should
|
||||||
# be careful to not end up in a thigh loop that
|
# be careful to not end up in a thigh loop that
|
||||||
# will starve the main event loop, thus we sleep
|
# will starve the main event loop, thus we sleep
|
||||||
# here before retrying.
|
# here before retrying.
|
||||||
|
trace "Unable to get a connection, sleeping"
|
||||||
await sleepAsync(100.millis) # TODO: should be configurable?
|
await sleepAsync(100.millis) # TODO: should be configurable?
|
||||||
|
upgrades.release()
|
||||||
|
continue
|
||||||
|
|
||||||
|
debug "Accepted an incoming connection", conn
|
||||||
|
asyncSpawn upgradeMonitor(conn, upgrades)
|
||||||
|
asyncSpawn s.upgradeIncoming(conn)
|
||||||
|
except CancelledError as exc:
|
||||||
|
trace "releasing semaphore on cancellation"
|
||||||
|
upgrades.release() # always release the slot
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
debug "Exception in accept loop, exiting", exc = exc.msg
|
debug "Exception in accept loop, exiting", exc = exc.msg
|
||||||
|
upgrades.release() # always release the slot
|
||||||
if not isNil(conn):
|
if not isNil(conn):
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
|
proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
|
||||||
|
@ -460,13 +499,6 @@ proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} =
|
||||||
proc stop*(s: Switch) {.async.} =
|
proc stop*(s: Switch) {.async.} =
|
||||||
trace "Stopping switch"
|
trace "Stopping switch"
|
||||||
|
|
||||||
for a in s.acceptFuts:
|
|
||||||
if not a.finished:
|
|
||||||
a.cancel()
|
|
||||||
|
|
||||||
checkFutures(
|
|
||||||
await allFinished(s.acceptFuts))
|
|
||||||
|
|
||||||
# close and cleanup all connections
|
# close and cleanup all connections
|
||||||
await s.connManager.close()
|
await s.connManager.close()
|
||||||
|
|
||||||
|
@ -478,6 +510,18 @@ proc stop*(s: Switch) {.async.} =
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
warn "error cleaning up transports", msg = exc.msg
|
warn "error cleaning up transports", msg = exc.msg
|
||||||
|
|
||||||
|
try:
|
||||||
|
await allFutures(s.acceptFuts)
|
||||||
|
.wait(1.seconds)
|
||||||
|
except CatchableError as exc:
|
||||||
|
trace "Exception while stopping accept loops", exc = exc.msg
|
||||||
|
|
||||||
|
# check that all futures were properly
|
||||||
|
# stopped and otherwise cancel them
|
||||||
|
for a in s.acceptFuts:
|
||||||
|
if not a.finished:
|
||||||
|
a.cancel()
|
||||||
|
|
||||||
trace "Switch stopped"
|
trace "Switch stopped"
|
||||||
|
|
||||||
proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
|
proc muxerHandler(s: Switch, muxer: Muxer) {.async, gcsafe.} =
|
||||||
|
|
|
@ -71,7 +71,7 @@ proc connHandler*(t: TcpTransport,
|
||||||
await client.closeWait()
|
await client.closeWait()
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
debug "Handling tcp connection", address = $observedAddr,
|
trace "Handling tcp connection", address = $observedAddr,
|
||||||
dir = $dir,
|
dir = $dir,
|
||||||
clients = t.clients[Direction.In].len +
|
clients = t.clients[Direction.In].len +
|
||||||
t.clients[Direction.Out].len
|
t.clients[Direction.Out].len
|
||||||
|
@ -130,7 +130,10 @@ method start*(t: TcpTransport, ma: MultiAddress) {.async.} =
|
||||||
await procCall Transport(t).start(ma)
|
await procCall Transport(t).start(ma)
|
||||||
trace "Starting TCP transport"
|
trace "Starting TCP transport"
|
||||||
|
|
||||||
t.server = createStreamServer(t.ma, t.flags, t)
|
t.server = createStreamServer(
|
||||||
|
ma = t.ma,
|
||||||
|
flags = t.flags,
|
||||||
|
udata = t)
|
||||||
|
|
||||||
# always get the resolved address in case we're bound to 0.0.0.0:0
|
# always get the resolved address in case we're bound to 0.0.0.0:0
|
||||||
t.ma = MultiAddress.init(t.server.sock.getLocalAddress()).tryGet()
|
t.ma = MultiAddress.init(t.server.sock.getLocalAddress()).tryGet()
|
||||||
|
@ -142,6 +145,8 @@ method stop*(t: TcpTransport) {.async, gcsafe.} =
|
||||||
## stop the transport
|
## stop the transport
|
||||||
##
|
##
|
||||||
|
|
||||||
|
t.running = false # mark stopped as soon as possible
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trace "Stopping TCP transport"
|
trace "Stopping TCP transport"
|
||||||
await procCall Transport(t).stop() # call base
|
await procCall Transport(t).stop() # call base
|
||||||
|
@ -160,8 +165,6 @@ method stop*(t: TcpTransport) {.async, gcsafe.} =
|
||||||
inc getTcpTransportTracker().closed
|
inc getTcpTransportTracker().closed
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
trace "Error shutting down tcp transport", exc = exc.msg
|
trace "Error shutting down tcp transport", exc = exc.msg
|
||||||
finally:
|
|
||||||
t.running = false
|
|
||||||
|
|
||||||
method accept*(t: TcpTransport): Future[Connection] {.async, gcsafe.} =
|
method accept*(t: TcpTransport): Future[Connection] {.async, gcsafe.} =
|
||||||
## accept a new TCP connection
|
## accept a new TCP connection
|
||||||
|
@ -179,12 +182,12 @@ method accept*(t: TcpTransport): Future[Connection] {.async, gcsafe.} =
|
||||||
# that can't.
|
# that can't.
|
||||||
debug "OS Error", exc = exc.msg
|
debug "OS Error", exc = exc.msg
|
||||||
except TransportTooManyError as exc:
|
except TransportTooManyError as exc:
|
||||||
warn "Too many files opened", exc = exc.msg
|
debug "Too many files opened", exc = exc.msg
|
||||||
except TransportUseClosedError as exc:
|
except TransportUseClosedError as exc:
|
||||||
info "Server was closed", exc = exc.msg
|
debug "Server was closed", exc = exc.msg
|
||||||
raise newTransportClosedError(exc)
|
raise newTransportClosedError(exc)
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
trace "Unexpected error creating connection", exc = exc.msg
|
warn "Unexpected error creating connection", exc = exc.msg
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
method dial*(t: TcpTransport,
|
method dial*(t: TcpTransport,
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
## Nim-Libp2p
|
||||||
|
## Copyright (c) 2020 Status Research & Development GmbH
|
||||||
|
## Licensed under either of
|
||||||
|
## * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
|
||||||
|
## * MIT license ([LICENSE-MIT](LICENSE-MIT))
|
||||||
|
## at your option.
|
||||||
|
## This file may not be copied, modified, or distributed except according to
|
||||||
|
## those terms.
|
||||||
|
|
||||||
|
import sequtils
|
||||||
|
import chronos, chronicles
|
||||||
|
|
||||||
|
# TODO: this should probably go in chronos
|
||||||
|
|
||||||
|
logScope:
|
||||||
|
topics = "libp2p semaphore"
|
||||||
|
|
||||||
|
type
|
||||||
|
AsyncSemaphore* = ref object of RootObj
|
||||||
|
size*: int
|
||||||
|
count*: int
|
||||||
|
queue*: seq[Future[void]]
|
||||||
|
|
||||||
|
proc init*(T: type AsyncSemaphore, size: int): T =
|
||||||
|
T(size: size, count: size)
|
||||||
|
|
||||||
|
proc tryAcquire*(s: AsyncSemaphore): bool =
|
||||||
|
## Attempts to acquire a resource, if successful
|
||||||
|
## returns true, otherwise false
|
||||||
|
##
|
||||||
|
|
||||||
|
if s.count > 0 and s.queue.len == 0:
|
||||||
|
s.count.dec
|
||||||
|
trace "Acquired slot", available = s.count, queue = s.queue.len
|
||||||
|
return true
|
||||||
|
|
||||||
|
proc acquire*(s: AsyncSemaphore): Future[void] =
|
||||||
|
## Acquire a resource and decrement the resource
|
||||||
|
## counter. If no more resources are available,
|
||||||
|
## the returned future will not complete until
|
||||||
|
## the resource count goes above 0 again.
|
||||||
|
##
|
||||||
|
|
||||||
|
let fut = newFuture[void]("AsyncSemaphore.acquire")
|
||||||
|
if s.tryAcquire():
|
||||||
|
fut.complete()
|
||||||
|
return fut
|
||||||
|
|
||||||
|
s.queue.add(fut)
|
||||||
|
s.count.dec
|
||||||
|
trace "Queued slot", available = s.count, queue = s.queue.len
|
||||||
|
return fut
|
||||||
|
|
||||||
|
proc release*(s: AsyncSemaphore) =
|
||||||
|
## Release a resource from the semaphore,
|
||||||
|
## by picking the first future from the queue
|
||||||
|
## and completing it and incrementing the
|
||||||
|
## internal resource count
|
||||||
|
##
|
||||||
|
|
||||||
|
doAssert(s.count <= s.size)
|
||||||
|
|
||||||
|
if s.count < s.size:
|
||||||
|
trace "Releasing slot", available = s.count,
|
||||||
|
queue = s.queue.len
|
||||||
|
|
||||||
|
if s.queue.len > 0:
|
||||||
|
var fut = s.queue.pop()
|
||||||
|
if not fut.finished():
|
||||||
|
fut.complete()
|
||||||
|
|
||||||
|
s.count.inc # increment the resource count
|
||||||
|
trace "Released slot", available = s.count,
|
||||||
|
queue = s.queue.len
|
||||||
|
return
|
|
@ -9,6 +9,8 @@ import ../libp2p/stream/lpstream
|
||||||
import ../libp2p/muxers/mplex/lpchannel
|
import ../libp2p/muxers/mplex/lpchannel
|
||||||
import ../libp2p/protocols/secure/secure
|
import ../libp2p/protocols/secure/secure
|
||||||
|
|
||||||
|
export unittest
|
||||||
|
|
||||||
const
|
const
|
||||||
StreamTransportTrackerName = "stream.transport"
|
StreamTransportTrackerName = "stream.transport"
|
||||||
StreamServerTrackerName = "stream.server"
|
StreamServerTrackerName = "stream.server"
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import testvarint,
|
import testvarint,
|
||||||
testminprotobuf,
|
testminprotobuf,
|
||||||
teststreamseq
|
teststreamseq,
|
||||||
|
testsemaphore
|
||||||
|
|
||||||
import testminasn1,
|
import testminasn1,
|
||||||
testrsa,
|
testrsa,
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
import random
|
||||||
|
import chronos
|
||||||
|
import ../libp2p/utils/semaphore
|
||||||
|
|
||||||
|
import ./helpers
|
||||||
|
|
||||||
|
randomize()
|
||||||
|
|
||||||
|
suite "AsyncSemaphore":
|
||||||
|
asyncTest "should acquire":
|
||||||
|
let sema = AsyncSemaphore.init(3)
|
||||||
|
|
||||||
|
await sema.acquire()
|
||||||
|
await sema.acquire()
|
||||||
|
await sema.acquire()
|
||||||
|
|
||||||
|
check sema.count == 0
|
||||||
|
|
||||||
|
asyncTest "should release":
|
||||||
|
let sema = AsyncSemaphore.init(3)
|
||||||
|
|
||||||
|
await sema.acquire()
|
||||||
|
await sema.acquire()
|
||||||
|
await sema.acquire()
|
||||||
|
|
||||||
|
check sema.count == 0
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
check sema.count == 3
|
||||||
|
|
||||||
|
asyncTest "should queue acquire":
|
||||||
|
let sema = AsyncSemaphore.init(1)
|
||||||
|
|
||||||
|
await sema.acquire()
|
||||||
|
let fut = sema.acquire()
|
||||||
|
|
||||||
|
check sema.count == -1
|
||||||
|
check sema.queue.len == 1
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
check sema.count == 1
|
||||||
|
|
||||||
|
await sleepAsync(10.millis)
|
||||||
|
check fut.finished()
|
||||||
|
|
||||||
|
asyncTest "should keep count == size":
|
||||||
|
let sema = AsyncSemaphore.init(1)
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
check sema.count == 1
|
||||||
|
|
||||||
|
asyncTest "should tryAcquire":
|
||||||
|
let sema = AsyncSemaphore.init(1)
|
||||||
|
await sema.acquire()
|
||||||
|
check sema.tryAcquire() == false
|
||||||
|
|
||||||
|
asyncTest "should tryAcquire and acquire":
|
||||||
|
let sema = AsyncSemaphore.init(4)
|
||||||
|
check sema.tryAcquire() == true
|
||||||
|
check sema.tryAcquire() == true
|
||||||
|
check sema.tryAcquire() == true
|
||||||
|
check sema.tryAcquire() == true
|
||||||
|
check sema.count == 0
|
||||||
|
|
||||||
|
let fut = sema.acquire()
|
||||||
|
check fut.finished == false
|
||||||
|
check sema.count == -1
|
||||||
|
# queue is only used when count is < 0
|
||||||
|
check sema.queue.len == 1
|
||||||
|
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
sema.release()
|
||||||
|
|
||||||
|
check fut.finished == true
|
||||||
|
check sema.count == 4
|
||||||
|
check sema.queue.len == 0
|
||||||
|
|
||||||
|
asyncTest "should restrict resource access":
|
||||||
|
let sema = AsyncSemaphore.init(3)
|
||||||
|
var resource = 0
|
||||||
|
|
||||||
|
proc task() {.async.} =
|
||||||
|
try:
|
||||||
|
await sema.acquire()
|
||||||
|
resource.inc()
|
||||||
|
check resource > 0 and resource <= 3
|
||||||
|
let sleep = rand(0..10).millis
|
||||||
|
# echo sleep
|
||||||
|
await sleepAsync(sleep)
|
||||||
|
finally:
|
||||||
|
resource.dec()
|
||||||
|
sema.release()
|
||||||
|
|
||||||
|
var tasks: seq[Future[void]]
|
||||||
|
for i in 0..<10:
|
||||||
|
tasks.add(task())
|
||||||
|
|
||||||
|
await allFutures(tasks)
|
Loading…
Reference in New Issue