fix(transport): various tcp transport races (#1095)

Co-authored-by: diegomrsantos <diego@status.im>
Authored by Jacek Sieka on 2024-05-14 07:10:34 +02:00, committed by GitHub
parent 1b91b97499
commit 3ca49a2f40
7 changed files with 273 additions and 219 deletions


@@ -81,12 +81,14 @@ proc dialAndUpgrade(
       if dialed.dir != dir:
         dialed.dir = dir
       await transport.upgrade(dialed, peerId)
+    except CancelledError as exc:
+      await dialed.close()
+      raise exc
     except CatchableError as exc:
       # If we failed to establish the connection through one transport,
       # we won't succeeded through another - no use in trying again
       await dialed.close()
       debug "Connection upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId))
-      if exc isnot CancelledError:
-        if dialed.dir == Direction.Out:
-          libp2p_failed_upgrades_outgoing.inc()
-        else:
+      if dialed.dir == Direction.Out:
+        libp2p_failed_upgrades_outgoing.inc()
+      else:
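
chronos signals cancellation with CancelledError, which is itself a CatchableError, so the old single handler cleaned up but also swallowed the cancellation. The fix handles cancellation first and re-raises it. A minimal, self-contained sketch of that ordering (the proc names below are illustrative, not the dialer's actual API):

import chronos

proc upgradeConn() {.async.} =
  # stand-in for transport.upgrade(dialed, peerId)
  await sleepAsync(10.milliseconds)

proc dialOnce() {.async.} =
  try:
    await upgradeConn()
  except CancelledError as exc:
    # cancellation: clean up local state, then keep propagating it
    echo "cancelled, closing dialed connection"
    raise exc
  except CatchableError as exc:
    # genuine upgrade failure: clean up and record the failure
    echo "upgrade failed: ", exc.msg

when isMainModule:
  waitFor dialOnce()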


@@ -44,12 +44,3 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
         # We still don't abort but warn
         debug "A future has failed, enable trace logging for details", error=exc.name
         trace "Exception details", msg=exc.msg
-
-template tryAndWarn*(message: static[string]; body: untyped): untyped =
-  try:
-    body
-  except CancelledError as exc:
-    raise exc
-  except CatchableError as exc:
-    debug "An exception has ocurred, enable trace logging for details", name = exc.name, msg = message
-    trace "Exception details", exc = exc.msg


@@ -273,6 +273,7 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises
       except CancelledError as exc:
         trace "releasing semaphore on cancellation"
         upgrades.release() # always release the slot
+        return
       except CatchableError as exc:
         error "Exception in accept loop, exiting", exc = exc.msg
         upgrades.release() # always release the slot
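
The added `return` matters because this handler sits inside the switch's accept loop: on cancellation the loop must release its upgrade slot and then stop iterating rather than fall through to the next round. A rough standalone sketch of that control flow, with the semaphore modelled as a plain counter rather than libp2p's actual semaphore type:

import chronos, chronicles

var slots = 0 # stand-in for the `upgrades` semaphore

proc handleIncoming() {.async.} =
  await sleepAsync(10.milliseconds)

proc acceptLoop() {.async.} =
  while true:
    inc slots # take a slot (the real code awaits a semaphore here)
    try:
      await handleIncoming()
    except CancelledError:
      trace "releasing semaphore on cancellation"
      dec slots # always release the slot
      return    # without this, the cancelled loop would keep spinning
    except CatchableError as exc:
      error "Exception in accept loop, exiting", exc = exc.msg
      dec slots # always release the slot
      return
    dec slots
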
@@ -288,6 +289,12 @@ proc stop*(s: Switch) {.async, public.} =
   s.started = false

+  try:
+    # Stop accepting incoming connections
+    await allFutures(s.acceptFuts.mapIt(it.cancelAndWait())).wait(1.seconds)
+  except CatchableError as exc:
+    debug "Cannot cancel accepts", error = exc.msg
+
   for service in s.services:
     discard await service.stop(s)
@@ -302,18 +309,6 @@ proc stop*(s: Switch) {.async, public.} =
     except CatchableError as exc:
       warn "error cleaning up transports", msg = exc.msg

-  try:
-    await allFutures(s.acceptFuts)
-      .wait(1.seconds)
-  except CatchableError as exc:
-    trace "Exception while stopping accept loops", exc = exc.msg
-
-  # check that all futures were properly
-  # stopped and otherwise cancel them
-  for a in s.acceptFuts:
-    if not a.finished:
-      a.cancel()
-
   for service in s.services:
     discard await service.stop(s)
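
Cancelling the accept loops is now the first thing `stop` does, before services and transports are torn down, and the old poll-then-cancel sequence is gone. A minimal sketch of the new shutdown step using plain chronos primitives (the `stopAcceptLoops` name is made up for illustration):

import chronos, chronicles
import std/sequtils

proc stopAcceptLoops(acceptFuts: seq[Future[void]]) {.async.} =
  try:
    # cancel every accept loop and give them a bounded grace period to finish
    await allFutures(acceptFuts.mapIt(it.cancelAndWait())).wait(1.seconds)
  except CatchableError as exc:
    debug "Cannot cancel accepts", error = exc.msg

The catch-all mainly covers the one-second timeout; `cancelAndWait` itself completes once the target future has finished.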


@@ -12,13 +12,10 @@
 {.push raises: [].}

 import std/[sequtils]
-import stew/results
 import chronos, chronicles
-import transport,
-  ../errors,
+import
+  ./transport,
   ../wire,
-  ../multicodec,
-  ../connmanager,
   ../multiaddress,
   ../stream/connection,
   ../stream/chronosstream,
@@ -28,76 +25,64 @@ import transport,
 logScope:
   topics = "libp2p tcptransport"

-export transport, results
+export transport, connection, upgrade

-const
-  TcpTransportTrackerName* = "libp2p.tcptransport"
+const TcpTransportTrackerName* = "libp2p.tcptransport"

 type
+  AcceptFuture = typeof(default(StreamServer).accept())
+
   TcpTransport* = ref object of Transport
     servers*: seq[StreamServer]
     clients: array[Direction, seq[StreamTransport]]
     flags: set[ServerFlags]
     clientFlags: set[SocketFlags]
-    acceptFuts: seq[Future[StreamTransport]]
+    acceptFuts: seq[AcceptFuture]
     connectionsTimeout: Duration
+    stopping: bool

   TcpTransportError* = object of transport.TransportError

-proc connHandler*(self: TcpTransport,
-                  client: StreamTransport,
-                  observedAddr: Opt[MultiAddress],
-                  dir: Direction): Future[Connection] {.async.} =
-  trace "Handling tcp connection", address = $observedAddr,
-                                   dir = $dir,
-                                   clients = self.clients[Direction.In].len +
-                                     self.clients[Direction.Out].len
+proc connHandler*(
+    self: TcpTransport,
+    client: StreamTransport,
+    observedAddr: Opt[MultiAddress],
+    dir: Direction,
+): Connection =
+  trace "Handling tcp connection",
+    address = $observedAddr,
+    dir = $dir,
+    clients = self.clients[Direction.In].len + self.clients[Direction.Out].len

   let conn = Connection(
     ChronosStream.init(
       client = client,
       dir = dir,
       observedAddr = observedAddr,
-      timeout = self.connectionsTimeout
-    ))
+      timeout = self.connectionsTimeout,
+    )
+  )

   proc onClose() {.async: (raises: []).} =
-    try:
-      block:
-        let
-          fut1 = client.join()
-          fut2 = conn.join()
-        try: # https://github.com/status-im/nim-chronos/issues/516
-          discard await race(fut1, fut2)
-        except ValueError: raiseAssert("Futures list is not empty")
-        # at least one join() completed, cancel pending one, if any
-        if not fut1.finished: await fut1.cancelAndWait()
-        if not fut2.finished: await fut2.cancelAndWait()
-
-      trace "Cleaning up client", addrs = $client.remoteAddress,
-                                  conn
-
-      self.clients[dir].keepItIf(it != client)
-
-      block:
-        let
-          fut1 = conn.close()
-          fut2 = client.closeWait()
-        await allFutures(fut1, fut2)
-        if fut1.failed:
-          let err = fut1.error()
-          debug "Error cleaning up client", errMsg = err.msg, conn
-        static: doAssert typeof(fut2).E is void # Cannot fail
-
-      trace "Cleaned up client", addrs = $client.remoteAddress,
-                                 conn
-    except CancelledError as exc:
-      let useExc {.used.} = exc
-      debug "Error cleaning up client", errMsg = exc.msg, conn
+    await noCancel client.join()
+
+    trace "Cleaning up client", addrs = $client.remoteAddress, conn
+
+    self.clients[dir].keepItIf(it != client)
+
+    # Propagate the chronos client being closed to the connection
+    # TODO This is somewhat dubious since it's the connection that owns the
+    # client, but it allows the transport to close all connections when
+    # shutting down (also dubious! it would make more sense that the owner
+    # of all connections closes them, or the next read detects the closed
+    # socket and does the right thing..)
+    await conn.close()
+
+    trace "Cleaned up client", addrs = $client.remoteAddress, conn

   self.clients[dir].add(client)
   asyncSpawn onClose()
   return conn
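
The rewritten `onClose` leans on chronos `noCancel`: cleanup only needs to wait for the socket to go away, and shielding that wait means cancelling the cleanup task can no longer abandon it half-way. A reduced sketch of the idea using only chronos types (`watchAndClose` and `onGone` are illustrative names, not part of the transport):

import chronos

proc watchAndClose(client: StreamTransport,
                   onGone: proc() {.gcsafe, raises: [].}) {.async: (raises: []).} =
  # wait for the transport to close; the wait itself cannot be cancelled away
  await noCancel client.join()
  onGone() # e.g. remove the client from a tracking list
  await noCancel client.closeWait()

In the transport this runs via `asyncSpawn`, one task per accepted or dialed client, mirroring `onClose` above.
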
@@ -106,168 +91,248 @@ proc new*(
     T: typedesc[TcpTransport],
     flags: set[ServerFlags] = {},
     upgrade: Upgrade,
-    connectionsTimeout = 10.minutes): T {.public.} =
-  let
-    transport = T(
-      flags: flags,
-      clientFlags:
-        if ServerFlags.TcpNoDelay in flags:
-          compilesOr:
-            {SocketFlags.TcpNoDelay}
-          do:
-            doAssert(false)
-            default(set[SocketFlags])
-        else:
-          default(set[SocketFlags]),
-      upgrader: upgrade,
-      networkReachability: NetworkReachability.Unknown,
-      connectionsTimeout: connectionsTimeout)
-
-  return transport
+    connectionsTimeout = 10.minutes,
+): T {.public.} =
+  T(
+    flags: flags,
+    clientFlags:
+      if ServerFlags.TcpNoDelay in flags:
+        {SocketFlags.TcpNoDelay}
+      else:
+        default(set[SocketFlags])
+    ,
+    upgrader: upgrade,
+    networkReachability: NetworkReachability.Unknown,
+    connectionsTimeout: connectionsTimeout,
+  )

-method start*(
-  self: TcpTransport,
-  addrs: seq[MultiAddress]) {.async.} =
-  ## listen on the transport
-  ##
-  if self.running:
-    warn "TCP transport already running"
-    return
-
-  await procCall Transport(self).start(addrs)
-  trace "Starting TCP transport"
-  trackCounter(TcpTransportTrackerName)
-
-  for i, ma in addrs:
-    if not self.handles(ma):
-      trace "Invalid address detected, skipping!", address = ma
-      continue
-
-    self.flags.incl(ServerFlags.ReusePort)
-    let server = createStreamServer(
-      ma = ma,
-      flags = self.flags,
-      udata = self)
-
-    # always get the resolved address in case we're bound to 0.0.0.0:0
-    self.addrs[i] = MultiAddress.init(
-      server.sock.getLocalAddress()
-    ).tryGet()
-
-    self.servers &= server
-
-    trace "Listening on", address = ma
+method start*(self: TcpTransport, addrs: seq[MultiAddress]): Future[void] =
+  ## Start transport listening to the given addresses - for dial-only transports,
+  ## start with an empty list
+  # TODO remove `impl` indirection throughout when `raises` is added to base
+  proc impl(
+      self: TcpTransport, addrs: seq[MultiAddress]
+  ): Future[void] {.async: (raises: [transport.TransportError, CancelledError]).} =
+    if self.running:
+      warn "TCP transport already running"
+      return
+
+    trace "Starting TCP transport"
+
+    self.flags.incl(ServerFlags.ReusePort)
+    var supported: seq[MultiAddress]
+    var initialized = false
+    try:
+      for i, ma in addrs:
+        if not self.handles(ma):
+          trace "Invalid address detected, skipping!", address = ma
+          continue
+
+        let
+          ta = initTAddress(ma).expect("valid address per handles check above")
+          server =
+            try:
+              createStreamServer(ta, flags = self.flags)
+            except common.TransportError as exc:
+              raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
+
+        self.servers &= server
+
+        trace "Listening on", address = ma
+        supported.add(
+          MultiAddress.init(server.sock.getLocalAddress()).expect(
+            "Can init from local address"
+          )
+        )
+
+      initialized = true
+    finally:
+      if not initialized:
+        # Clean up partial success on exception
+        await noCancel allFutures(self.servers.mapIt(it.closeWait()))
+        reset(self.servers)
+
+    try:
+      await procCall Transport(self).start(supported)
+    except CatchableError:
+      raiseAssert "Base method does not raise"
+
+    trackCounter(TcpTransportTrackerName)
+
+  impl(self, addrs)

-method stop*(self: TcpTransport) {.async.} =
-  ## stop the transport
-  ##
-  try:
-    trace "Stopping TCP transport"
-
-    checkFutures(
-      await allFinished(
-        self.clients[Direction.In].mapIt(it.closeWait()) &
-        self.clients[Direction.Out].mapIt(it.closeWait())))
-
-    if not self.running:
-      warn "TCP transport already stopped"
-      return
-
-    await procCall Transport(self).stop() # call base
-
-    var toWait: seq[Future[void]]
-    for fut in self.acceptFuts:
-      if not fut.finished:
-        toWait.add(fut.cancelAndWait())
-      elif fut.done:
-        toWait.add(fut.read().closeWait())
-
-    for server in self.servers:
-      server.stop()
-      toWait.add(server.closeWait())
-
-    await allFutures(toWait)
-
-    self.servers = @[]
-    self.acceptFuts = @[]
-
-    trace "Transport stopped"
-    untrackCounter(TcpTransportTrackerName)
-  except CatchableError as exc:
-    trace "Error shutting down tcp transport", exc = exc.msg
+method stop*(self: TcpTransport): Future[void] =
+  ## Stop the transport and close all connections it created
+  proc impl(self: TcpTransport) {.async: (raises: []).} =
+    trace "Stopping TCP transport"
+    self.stopping = true
+    defer:
+      self.stopping = false
+
+    if self.running:
+      # Reset the running flag
+      try:
+        await noCancel procCall Transport(self).stop()
+      except CatchableError: # TODO remove when `accept` is annotated with raises
+        raiseAssert "doesn't actually raise"
+
+      # Stop each server by closing the socket - this will cause all accept loops
+      # to fail - since the running flag has been reset, it's also safe to close
+      # all known clients since no more of them will be added
+      await noCancel allFutures(
+        self.servers.mapIt(it.closeWait()) &
+          self.clients[Direction.In].mapIt(it.closeWait()) &
+          self.clients[Direction.Out].mapIt(it.closeWait())
+      )
+
+      self.servers = @[]
+
+      for acceptFut in self.acceptFuts:
+        if acceptFut.completed():
+          await acceptFut.value().closeWait()
+      self.acceptFuts = @[]
+
+      if self.clients[Direction.In].len != 0 or self.clients[Direction.Out].len != 0:
+        # Future updates could consider turning this warn into an assert since
+        # it should never happen if the shutdown code is correct
+        warn "Couldn't clean up clients",
+          len = self.clients[Direction.In].len + self.clients[Direction.Out].len
+
+      trace "Transport stopped"
+      untrackCounter(TcpTransportTrackerName)
+    else:
+      # For legacy reasons, `stop` on a transpart that wasn't started is
+      # expected to close outgoing connections created by the transport
+      warn "TCP transport already stopped"
+
+      doAssert self.clients[Direction.In].len == 0,
+        "No incoming connections possible without start"
+      await noCancel allFutures(self.clients[Direction.Out].mapIt(it.closeWait()))
+
+  impl(self)

-method accept*(self: TcpTransport): Future[Connection] {.async.} =
-  ## accept a new TCP connection
-  ##
-  if not self.running:
-    raise newTransportClosedError()
-
-  try:
-    if self.acceptFuts.len <= 0:
-      self.acceptFuts = self.servers.mapIt(Future[StreamTransport](it.accept()))
-
-    if self.acceptFuts.len <= 0:
-      return
-
-    let
-      finished = await one(self.acceptFuts)
-      index = self.acceptFuts.find(finished)
-
-    self.acceptFuts[index] = self.servers[index].accept()
-    let transp = await finished
-    try:
-      let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet()
-      return await self.connHandler(transp, Opt.some(observedAddr), Direction.In)
-    except CancelledError as exc:
-      debug "CancelledError", exc = exc.msg
-      transp.close()
-      raise exc
-    except CatchableError as exc:
-      debug "Failed to handle connection", exc = exc.msg
-      transp.close()
-  except TransportTooManyError as exc:
-    debug "Too many files opened", exc = exc.msg
-  except TransportAbortedError as exc:
-    debug "Connection aborted", exc = exc.msg
-  except TransportUseClosedError as exc:
-    debug "Server was closed", exc = exc.msg
-    raise newTransportClosedError(exc)
-  except CancelledError as exc:
-    raise exc
-  except TransportOsError as exc:
-    info "OS Error", exc = exc.msg
-    raise exc
-  except CatchableError as exc:
-    info "Unexpected error accepting connection", exc = exc.msg
-    raise exc
+method accept*(self: TcpTransport): Future[Connection] =
+  ## accept a new TCP connection, returning nil on non-fatal errors
+  ##
+  ## Raises an exception when the transport is broken and cannot be used for
+  ## accepting further connections
+  # TODO returning nil for non-fatal errors is problematic in that error
+  # information is lost and must be logged here instead of being
+  # available to the caller - further refactoring should propagate errors
+  # to the caller instead
+  proc impl(
+      self: TcpTransport
+  ): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
+    if not self.running:
+      raise newTransportClosedError()
+
+    if self.acceptFuts.len <= 0:
+      self.acceptFuts = self.servers.mapIt(it.accept())
+
+    let
+      finished =
+        try:
+          await one(self.acceptFuts)
+        except ValueError:
+          raise (ref TcpTransportError)(msg: "No listeners configured")
+      index = self.acceptFuts.find(finished)
+      transp =
+        try:
+          await finished
+        except TransportTooManyError as exc:
+          debug "Too many files opened", exc = exc.msg
+          return nil
+        except TransportAbortedError as exc:
+          debug "Connection aborted", exc = exc.msg
+          return nil
+        except TransportUseClosedError as exc:
+          raise newTransportClosedError(exc)
+        except TransportOsError as exc:
+          raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
+        except common.TransportError as exc: # Needed for chronos 4.0.0 support
+          raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
+        except CancelledError as exc:
+          raise exc
+
+    if not self.running: # Stopped while waiting
+      await transp.closeWait()
+      raise newTransportClosedError()
+
+    self.acceptFuts[index] = self.servers[index].accept()
+
+    let remote =
+      try:
+        transp.remoteAddress
+      except TransportOsError as exc:
+        # The connection had errors / was closed before `await` returned control
+        await transp.closeWait()
+        debug "Cannot read remote address", exc = exc.msg
+        return nil
+
+    let observedAddr =
+      MultiAddress.init(remote).expect("Can initialize from remote address")
+    self.connHandler(transp, Opt.some(observedAddr), Direction.In)
+
+  impl(self)

 method dial*(
   self: TcpTransport,
   hostname: string,
   address: MultiAddress,
-  peerId: Opt[PeerId] = Opt.none(PeerId)): Future[Connection] {.async.} =
+  peerId: Opt[PeerId] = Opt.none(PeerId),
+): Future[Connection] =
   ## dial a peer
-  ##
-  trace "Dialing remote peer", address = $address
-  let transp =
-    if self.networkReachability == NetworkReachability.NotReachable and self.addrs.len > 0:
-      self.clientFlags.incl(SocketFlags.ReusePort)
-      await connect(address, flags = self.clientFlags, localAddress = Opt.some(self.addrs[0]))
-    else:
-      await connect(address, flags = self.clientFlags)
-  try:
-    let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet()
-    return await self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
-  except CatchableError as err:
-    await transp.closeWait()
-    raise err
+  proc impl(
+      self: TcpTransport, hostname: string, address: MultiAddress, peerId: Opt[PeerId]
+  ): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
+    if self.stopping:
+      raise newTransportClosedError()
+
+    let ta = initTAddress(address).valueOr:
+      raise (ref TcpTransportError)(msg: "Unsupported address: " & $address)
+
+    trace "Dialing remote peer", address = $address
+    let transp =
+      try:
+        await(
+          if self.networkReachability == NetworkReachability.NotReachable and
+              self.addrs.len > 0:
+            let local = initTAddress(self.addrs[0]).expect("self address is valid")
+            self.clientFlags.incl(SocketFlags.ReusePort)
+            connect(ta, flags = self.clientFlags, localAddress = local)
+          else:
+            connect(ta, flags = self.clientFlags)
+        )
+      except CancelledError as exc:
+        raise exc
+      except CatchableError as exc:
+        raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
+
+    # If `stop` is called after `connect` but before `await` returns, we might
+    # end up with a race condition where `stop` returns but not all connections
+    # have been closed - we drop connections in this case in order not to leak
+    # them
+    if self.stopping:
+      # Stopped while waiting for new connection
+      await transp.closeWait()
+      raise newTransportClosedError()
+
+    let observedAddr =
+      try:
+        MultiAddress.init(transp.remoteAddress).expect("remote address is valid")
+      except TransportOsError as exc:
+        await transp.closeWait()
+        raise (ref TcpTransportError)(msg: exc.msg)
+
+    self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
+
+  impl(self, hostname, address, peerId)

-method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
+method handles*(t: TcpTransport, address: MultiAddress): bool =
   if procCall Transport(t).handles(address):
     if address.protocols.isOk:
       return TCP.match(address)
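
Every rewritten async method above uses the same `impl` indirection: the public `method` keeps the plain `Future[...]` signature that the unannotated base method requires, while the inner `proc` carries the precise `{.async: (raises: [...]).}` list so the body is exception-checked. Reduced to the bare pattern (the `Service` and `MyError` names are placeholders):

import chronos

type
  MyError = object of CatchableError
  Service = ref object of RootObj

method stop*(s: Service): Future[void] {.base.} =
  # public signature stays compatible with the (unannotated) base method
  proc impl(s: Service) {.async: (raises: [MyError, CancelledError]).} =
    # the body is checked against the raises list above
    await sleepAsync(1.milliseconds)

  impl(s)

when isMainModule:
  waitFor Service().stop()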


@@ -200,7 +200,7 @@ method dial*(
   try:
     await dialPeer(transp, address)
-    return await self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
+    return self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
   except CatchableError as err:
     await transp.closeWait()
     raise err


@@ -35,7 +35,7 @@ type
     upgrader*: Upgrade
     networkReachability*: NetworkReachability

-proc newTransportClosedError*(parent: ref Exception = nil): ref LPError =
+proc newTransportClosedError*(parent: ref Exception = nil): ref TransportError =
   newException(TransportClosedError,
     "Transport closed, no more connections!", parent)


@@ -13,6 +13,8 @@
 import chronos, stew/endians2
 import multiaddress, multicodec, errors, utility

+export multiaddress, chronos
+
 when defined(windows):
   import winlean
 else:
@@ -30,7 +32,6 @@ const
     UDP,
   )

-
 proc initTAddress*(ma: MultiAddress): MaResult[TransportAddress] =
   ## Initialize ``TransportAddress`` with MultiAddress ``ma``.
   ##
@@ -76,7 +77,7 @@ proc connect*(
   child: StreamTransport = nil,
   flags = default(set[SocketFlags]),
   localAddress: Opt[MultiAddress] = Opt.none(MultiAddress)): Future[StreamTransport]
-    {.raises: [LPError, MaInvalidAddress].} =
+    {.async.} =
   ## Open new connection to remote peer with address ``ma`` and create
   ## new transport object ``StreamTransport`` for established connection.
   ## ``bufferSize`` is size of internal buffer for transport.
@@ -88,12 +89,12 @@ proc connect*(
   let transportAddress = initTAddress(ma).tryGet()

   compilesOr:
-    return connect(transportAddress, bufferSize, child,
+    return await connect(transportAddress, bufferSize, child,
       if localAddress.isSome(): initTAddress(localAddress.expect("just checked")).tryGet() else: TransportAddress(),
       flags)
   do:
     # support for older chronos versions
-    return connect(transportAddress, bufferSize, child)
+    return await connect(transportAddress, bufferSize, child)

 proc createStreamServer*[T](ma: MultiAddress,
                             cbproc: StreamCallback,
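
Turning `wire.connect` into an `{.async.}` proc means its failures now surface at the `await` in the caller instead of leaking through a `{.raises.}` annotation, which is what lets `TcpTransport.dial` above translate them into `TcpTransportError`. A small usage sketch (the address is only an example and assumes a listener is actually there):

import chronos
import libp2p/[multiaddress, wire]

proc dialExample() {.async.} =
  let ma = MultiAddress.init("/ip4/127.0.0.1/tcp/4001").tryGet()
  try:
    let transp = await connect(ma) # async now; transport errors raise here
    await transp.closeWait()
  except CatchableError as exc:
    echo "dial failed: ", exc.msg

when isMainModule:
  waitFor dialExample()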