fix(transport): various tcp transport races (#1095)

Co-authored-by: diegomrsantos <diego@status.im>
This commit is contained in:
Jacek Sieka 2024-05-14 07:10:34 +02:00 committed by GitHub
parent 1b91b97499
commit 3ca49a2f40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 273 additions and 219 deletions

View File

@ -81,16 +81,18 @@ proc dialAndUpgrade(
if dialed.dir != dir:
dialed.dir = dir
await transport.upgrade(dialed, peerId)
except CancelledError as exc:
await dialed.close()
raise exc
except CatchableError as exc:
# If we failed to establish the connection through one transport,
# we won't succeed through another - no use in trying again
await dialed.close()
debug "Connection upgrade failed", err = exc.msg, peerId = peerId.get(default(PeerId))
if exc isnot CancelledError:
if dialed.dir == Direction.Out:
libp2p_failed_upgrades_outgoing.inc()
else:
libp2p_failed_upgrades_incoming.inc()
if dialed.dir == Direction.Out:
libp2p_failed_upgrades_outgoing.inc()
else:
libp2p_failed_upgrades_incoming.inc()
# Try other address
return nil

View File

@ -44,12 +44,3 @@ macro checkFutures*[F](futs: seq[F], exclude: untyped = []): untyped =
# We still don't abort but warn
debug "A future has failed, enable trace logging for details", error=exc.name
trace "Exception details", msg=exc.msg
template tryAndWarn*(message: static[string]; body: untyped): untyped =
try:
body
except CancelledError as exc:
raise exc
except CatchableError as exc:
debug "An exception has ocurred, enable trace logging for details", name = exc.name, msg = message
trace "Exception details", exc = exc.msg

View File

@ -273,6 +273,7 @@ proc accept(s: Switch, transport: Transport) {.async.} = # noraises
except CancelledError as exc:
trace "releasing semaphore on cancellation"
upgrades.release() # always release the slot
return
except CatchableError as exc:
error "Exception in accept loop, exiting", exc = exc.msg
upgrades.release() # always release the slot
@ -288,6 +289,12 @@ proc stop*(s: Switch) {.async, public.} =
s.started = false
try:
# Stop accepting incoming connections
await allFutures(s.acceptFuts.mapIt(it.cancelAndWait())).wait(1.seconds)
except CatchableError as exc:
debug "Cannot cancel accepts", error = exc.msg
for service in s.services:
discard await service.stop(s)
@ -302,18 +309,6 @@ proc stop*(s: Switch) {.async, public.} =
except CatchableError as exc:
warn "error cleaning up transports", msg = exc.msg
try:
await allFutures(s.acceptFuts)
.wait(1.seconds)
except CatchableError as exc:
trace "Exception while stopping accept loops", exc = exc.msg
# check that all futures were properly
# stopped and otherwise cancel them
for a in s.acceptFuts:
if not a.finished:
a.cancel()
for service in s.services:
discard await service.stop(s)

View File

@ -12,262 +12,327 @@
{.push raises: [].}
import std/[sequtils]
import stew/results
import chronos, chronicles
import transport,
../errors,
../wire,
../multicodec,
../connmanager,
../multiaddress,
../stream/connection,
../stream/chronosstream,
../upgrademngrs/upgrade,
../utility
import
./transport,
../wire,
../multiaddress,
../stream/connection,
../stream/chronosstream,
../upgrademngrs/upgrade,
../utility
logScope:
topics = "libp2p tcptransport"
export transport, results
export transport, connection, upgrade
const
TcpTransportTrackerName* = "libp2p.tcptransport"
const TcpTransportTrackerName* = "libp2p.tcptransport"
type
AcceptFuture = typeof(default(StreamServer).accept())
TcpTransport* = ref object of Transport
servers*: seq[StreamServer]
clients: array[Direction, seq[StreamTransport]]
flags: set[ServerFlags]
clientFlags: set[SocketFlags]
acceptFuts: seq[Future[StreamTransport]]
acceptFuts: seq[AcceptFuture]
connectionsTimeout: Duration
stopping: bool
TcpTransportError* = object of transport.TransportError
proc connHandler*(self: TcpTransport,
client: StreamTransport,
observedAddr: Opt[MultiAddress],
dir: Direction): Future[Connection] {.async.} =
trace "Handling tcp connection", address = $observedAddr,
dir = $dir,
clients = self.clients[Direction.In].len +
self.clients[Direction.Out].len
proc connHandler*(
self: TcpTransport,
client: StreamTransport,
observedAddr: Opt[MultiAddress],
dir: Direction,
): Connection =
trace "Handling tcp connection",
address = $observedAddr,
dir = $dir,
clients = self.clients[Direction.In].len + self.clients[Direction.Out].len
let conn = Connection(
ChronosStream.init(
client = client,
dir = dir,
observedAddr = observedAddr,
timeout = self.connectionsTimeout
))
timeout = self.connectionsTimeout,
)
)
proc onClose() {.async: (raises: []).} =
try:
block:
let
fut1 = client.join()
fut2 = conn.join()
try: # https://github.com/status-im/nim-chronos/issues/516
discard await race(fut1, fut2)
except ValueError: raiseAssert("Futures list is not empty")
# at least one join() completed, cancel pending one, if any
if not fut1.finished: await fut1.cancelAndWait()
if not fut2.finished: await fut2.cancelAndWait()
await noCancel client.join()
trace "Cleaning up client", addrs = $client.remoteAddress,
conn
trace "Cleaning up client", addrs = $client.remoteAddress, conn
self.clients[dir].keepItIf( it != client )
self.clients[dir].keepItIf(it != client)
block:
let
fut1 = conn.close()
fut2 = client.closeWait()
await allFutures(fut1, fut2)
if fut1.failed:
let err = fut1.error()
debug "Error cleaning up client", errMsg = err.msg, conn
static: doAssert typeof(fut2).E is void # Cannot fail
# Propagate the chronos client being closed to the connection
# TODO This is somewhat dubious since it's the connection that owns the
# client, but it allows the transport to close all connections when
# shutting down (also dubious! it would make more sense that the owner
# of all connections closes them, or the next read detects the closed
# socket and does the right thing..)
trace "Cleaned up client", addrs = $client.remoteAddress,
conn
await conn.close()
except CancelledError as exc:
let useExc {.used.} = exc
debug "Error cleaning up client", errMsg = exc.msg, conn
trace "Cleaned up client", addrs = $client.remoteAddress, conn
self.clients[dir].add(client)
asyncSpawn onClose()
return conn
proc new*(
T: typedesc[TcpTransport],
flags: set[ServerFlags] = {},
upgrade: Upgrade,
connectionsTimeout = 10.minutes): T {.public.} =
T: typedesc[TcpTransport],
flags: set[ServerFlags] = {},
upgrade: Upgrade,
connectionsTimeout = 10.minutes,
): T {.public.} =
T(
flags: flags,
clientFlags:
if ServerFlags.TcpNoDelay in flags:
{SocketFlags.TcpNoDelay}
else:
default(set[SocketFlags])
,
upgrader: upgrade,
networkReachability: NetworkReachability.Unknown,
connectionsTimeout: connectionsTimeout,
)
let
transport = T(
flags: flags,
clientFlags:
if ServerFlags.TcpNoDelay in flags:
compilesOr:
{SocketFlags.TcpNoDelay}
do:
doAssert(false)
default(set[SocketFlags])
else:
default(set[SocketFlags]),
upgrader: upgrade,
networkReachability: NetworkReachability.Unknown,
connectionsTimeout: connectionsTimeout)
method start*(self: TcpTransport, addrs: seq[MultiAddress]): Future[void] =
## Start transport listening to the given addresses - for dial-only transports,
## start with an empty list
return transport
# TODO remove `impl` indirection throughout when `raises` is added to base
method start*(
self: TcpTransport,
addrs: seq[MultiAddress]) {.async.} =
## listen on the transport
##
proc impl(
self: TcpTransport, addrs: seq[MultiAddress]
): Future[void] {.async: (raises: [transport.TransportError, CancelledError]).} =
if self.running:
warn "TCP transport already running"
return
if self.running:
warn "TCP transport already running"
return
await procCall Transport(self).start(addrs)
trace "Starting TCP transport"
trackCounter(TcpTransportTrackerName)
for i, ma in addrs:
if not self.handles(ma):
trace "Invalid address detected, skipping!", address = ma
continue
trace "Starting TCP transport"
self.flags.incl(ServerFlags.ReusePort)
let server = createStreamServer(
ma = ma,
flags = self.flags,
udata = self)
# always get the resolved address in case we're bound to 0.0.0.0:0
self.addrs[i] = MultiAddress.init(
server.sock.getLocalAddress()
).tryGet()
var supported: seq[MultiAddress]
var initialized = false
try:
for i, ma in addrs:
if not self.handles(ma):
trace "Invalid address detected, skipping!", address = ma
continue
self.servers &= server
let
ta = initTAddress(ma).expect("valid address per handles check above")
server =
try:
createStreamServer(ta, flags = self.flags)
except common.TransportError as exc:
raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
trace "Listening on", address = ma
self.servers &= server
method stop*(self: TcpTransport) {.async.} =
## stop the transport
##
try:
trace "Listening on", address = ma
supported.add(
MultiAddress.init(server.sock.getLocalAddress()).expect(
"Can init from local address"
)
)
initialized = true
finally:
if not initialized:
# Clean up partial success on exception
await noCancel allFutures(self.servers.mapIt(it.closeWait()))
reset(self.servers)
try:
await procCall Transport(self).start(supported)
except CatchableError:
raiseAssert "Base method does not raise"
trackCounter(TcpTransportTrackerName)
impl(self, addrs)
method stop*(self: TcpTransport): Future[void] =
## Stop the transport and close all connections it created
proc impl(self: TcpTransport) {.async: (raises: []).} =
trace "Stopping TCP transport"
self.stopping = true
defer:
self.stopping = false
checkFutures(
await allFinished(
self.clients[Direction.In].mapIt(it.closeWait()) &
self.clients[Direction.Out].mapIt(it.closeWait())))
if self.running:
# Reset the running flag
try:
await noCancel procCall Transport(self).stop()
except CatchableError: # TODO remove when `accept` is annotated with raises
raiseAssert "doesn't actually raise"
if not self.running:
# Stop each server by closing the socket - this will cause all accept loops
# to fail - since the running flag has been reset, it's also safe to close
# all known clients since no more of them will be added
await noCancel allFutures(
self.servers.mapIt(it.closeWait()) &
self.clients[Direction.In].mapIt(it.closeWait()) &
self.clients[Direction.Out].mapIt(it.closeWait())
)
self.servers = @[]
for acceptFut in self.acceptFuts:
if acceptFut.completed():
await acceptFut.value().closeWait()
self.acceptFuts = @[]
if self.clients[Direction.In].len != 0 or self.clients[Direction.Out].len != 0:
# Future updates could consider turning this warn into an assert since
# it should never happen if the shutdown code is correct
warn "Couldn't clean up clients",
len = self.clients[Direction.In].len + self.clients[Direction.Out].len
trace "Transport stopped"
untrackCounter(TcpTransportTrackerName)
else:
# For legacy reasons, `stop` on a transport that wasn't started is
# expected to close outgoing connections created by the transport
warn "TCP transport already stopped"
return
await procCall Transport(self).stop() # call base
var toWait: seq[Future[void]]
for fut in self.acceptFuts:
if not fut.finished:
toWait.add(fut.cancelAndWait())
elif fut.done:
toWait.add(fut.read().closeWait())
doAssert self.clients[Direction.In].len == 0,
"No incoming connections possible without start"
await noCancel allFutures(self.clients[Direction.Out].mapIt(it.closeWait()))
for server in self.servers:
server.stop()
toWait.add(server.closeWait())
impl(self)
await allFutures(toWait)
self.servers = @[]
self.acceptFuts = @[]
trace "Transport stopped"
untrackCounter(TcpTransportTrackerName)
except CatchableError as exc:
trace "Error shutting down tcp transport", exc = exc.msg
method accept*(self: TcpTransport): Future[Connection] {.async.} =
## accept a new TCP connection
method accept*(self: TcpTransport): Future[Connection] =
## accept a new TCP connection, returning nil on non-fatal errors
##
if not self.running:
raise newTransportClosedError()
try:
if self.acceptFuts.len <= 0:
self.acceptFuts = self.servers.mapIt(Future[StreamTransport](it.accept()))
## Raises an exception when the transport is broken and cannot be used for
## accepting further connections
# TODO returning nil for non-fatal errors is problematic in that error
# information is lost and must be logged here instead of being
# available to the caller - further refactoring should propagate errors
# to the caller instead
proc impl(
self: TcpTransport
): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
if not self.running:
raise newTransportClosedError()
if self.acceptFuts.len <= 0:
return
self.acceptFuts = self.servers.mapIt(it.accept())
let
finished = await one(self.acceptFuts)
finished =
try:
await one(self.acceptFuts)
except ValueError:
raise (ref TcpTransportError)(msg: "No listeners configured")
index = self.acceptFuts.find(finished)
transp =
try:
await finished
except TransportTooManyError as exc:
debug "Too many files opened", exc = exc.msg
return nil
except TransportAbortedError as exc:
debug "Connection aborted", exc = exc.msg
return nil
except TransportUseClosedError as exc:
raise newTransportClosedError(exc)
except TransportOsError as exc:
raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
except common.TransportError as exc: # Needed for chronos 4.0.0 support
raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
except CancelledError as exc:
raise exc
if not self.running: # Stopped while waiting
await transp.closeWait()
raise newTransportClosedError()
self.acceptFuts[index] = self.servers[index].accept()
let transp = await finished
try:
let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet()
return await self.connHandler(transp, Opt.some(observedAddr), Direction.In)
except CancelledError as exc:
debug "CancelledError", exc = exc.msg
transp.close()
raise exc
except CatchableError as exc:
debug "Failed to handle connection", exc = exc.msg
transp.close()
except TransportTooManyError as exc:
debug "Too many files opened", exc = exc.msg
except TransportAbortedError as exc:
debug "Connection aborted", exc = exc.msg
except TransportUseClosedError as exc:
debug "Server was closed", exc = exc.msg
raise newTransportClosedError(exc)
except CancelledError as exc:
raise exc
except TransportOsError as exc:
info "OS Error", exc = exc.msg
raise exc
except CatchableError as exc:
info "Unexpected error accepting connection", exc = exc.msg
raise exc
let remote =
try:
transp.remoteAddress
except TransportOsError as exc:
# The connection had errors / was closed before `await` returned control
await transp.closeWait()
debug "Cannot read remote address", exc = exc.msg
return nil
let observedAddr =
MultiAddress.init(remote).expect("Can initialize from remote address")
self.connHandler(transp, Opt.some(observedAddr), Direction.In)
impl(self)
method dial*(
self: TcpTransport,
hostname: string,
address: MultiAddress,
peerId: Opt[PeerId] = Opt.none(PeerId)): Future[Connection] {.async.} =
self: TcpTransport,
hostname: string,
address: MultiAddress,
peerId: Opt[PeerId] = Opt.none(PeerId),
): Future[Connection] =
## dial a peer
##
proc impl(
self: TcpTransport, hostname: string, address: MultiAddress, peerId: Opt[PeerId]
): Future[Connection] {.async: (raises: [transport.TransportError, CancelledError]).} =
if self.stopping:
raise newTransportClosedError()
trace "Dialing remote peer", address = $address
let transp =
if self.networkReachability == NetworkReachability.NotReachable and self.addrs.len > 0:
self.clientFlags.incl(SocketFlags.ReusePort)
await connect(address, flags = self.clientFlags, localAddress = Opt.some(self.addrs[0]))
else:
await connect(address, flags = self.clientFlags)
let ta = initTAddress(address).valueOr:
raise (ref TcpTransportError)(msg: "Unsupported address: " & $address)
try:
let observedAddr = MultiAddress.init(transp.remoteAddress).tryGet()
return await self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
except CatchableError as err:
await transp.closeWait()
raise err
trace "Dialing remote peer", address = $address
let transp =
try:
await(
if self.networkReachability == NetworkReachability.NotReachable and
self.addrs.len > 0:
let local = initTAddress(self.addrs[0]).expect("self address is valid")
self.clientFlags.incl(SocketFlags.ReusePort)
connect(ta, flags = self.clientFlags, localAddress = local)
else:
connect(ta, flags = self.clientFlags)
)
except CancelledError as exc:
raise exc
except CatchableError as exc:
raise (ref TcpTransportError)(msg: exc.msg, parent: exc)
method handles*(t: TcpTransport, address: MultiAddress): bool {.gcsafe.} =
# If `stop` is called after `connect` but before `await` returns, we might
# end up with a race condition where `stop` returns but not all connections
# have been closed - we drop connections in this case in order not to leak
# them
if self.stopping:
# Stopped while waiting for new connection
await transp.closeWait()
raise newTransportClosedError()
let observedAddr =
try:
MultiAddress.init(transp.remoteAddress).expect("remote address is valid")
except TransportOsError as exc:
await transp.closeWait()
raise (ref TcpTransportError)(msg: exc.msg)
self.connHandler(transp, Opt.some(observedAddr), Direction.Out)
impl(self, hostname, address, peerId)
method handles*(t: TcpTransport, address: MultiAddress): bool =
if procCall Transport(t).handles(address):
if address.protocols.isOk:
return TCP.match(address)

View File

@ -200,7 +200,7 @@ method dial*(
try:
await dialPeer(transp, address)
return await self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
return self.tcpTransport.connHandler(transp, Opt.none(MultiAddress), Direction.Out)
except CatchableError as err:
await transp.closeWait()
raise err

View File

@ -35,7 +35,7 @@ type
upgrader*: Upgrade
networkReachability*: NetworkReachability
proc newTransportClosedError*(parent: ref Exception = nil): ref LPError =
proc newTransportClosedError*(parent: ref Exception = nil): ref TransportError =
newException(TransportClosedError,
"Transport closed, no more connections!", parent)

View File

@ -13,6 +13,8 @@
import chronos, stew/endians2
import multiaddress, multicodec, errors, utility
export multiaddress, chronos
when defined(windows):
import winlean
else:
@ -30,7 +32,6 @@ const
UDP,
)
proc initTAddress*(ma: MultiAddress): MaResult[TransportAddress] =
## Initialize ``TransportAddress`` with MultiAddress ``ma``.
##
@ -76,7 +77,7 @@ proc connect*(
child: StreamTransport = nil,
flags = default(set[SocketFlags]),
localAddress: Opt[MultiAddress] = Opt.none(MultiAddress)): Future[StreamTransport]
{.raises: [LPError, MaInvalidAddress].} =
{.async.} =
## Open new connection to remote peer with address ``ma`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport.
@ -88,12 +89,12 @@ proc connect*(
let transportAddress = initTAddress(ma).tryGet()
compilesOr:
return connect(transportAddress, bufferSize, child,
return await connect(transportAddress, bufferSize, child,
if localAddress.isSome(): initTAddress(localAddress.expect("just checked")).tryGet() else: TransportAddress(),
flags)
do:
# support for older chronos versions
return connect(transportAddress, bufferSize, child)
return await connect(transportAddress, bufferSize, child)
proc createStreamServer*[T](ma: MultiAddress,
cbproc: StreamCallback,