diff --git a/libp2p/connection.nim b/libp2p/connection.nim index 4361bb082..80e95a518 100644 --- a/libp2p/connection.nim +++ b/libp2p/connection.nim @@ -7,7 +7,6 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import oids import chronos, chronicles, metrics import peerinfo, errors, @@ -15,6 +14,9 @@ import peerinfo, stream/lpstream, peerinfo +when chronicles.enabledLogLevel == LogLevel.TRACE: + import oids + export lpstream logScope: @@ -66,51 +68,57 @@ proc `$`*(conn: Connection): string = if not isNil(conn.peerInfo): result = $(conn.peerInfo) -proc bindStreamClose(conn: Connection) {.async.} = - # bind stream's close event to connection's close - # to ensure correct close propagation - if not isNil(conn.stream.closeEvent): - await conn.stream.closeEvent.wait() - trace "wrapped stream closed, about to close conn", - closed = conn.isClosed, conn = $conn - if not conn.isClosed: - trace "wrapped stream closed, closing conn", - closed = conn.isClosed, conn = $conn - await conn.close() - proc init[T: Connection](self: var T, stream: LPStream): T = ## create a new Connection for the specified async reader/writer new self self.stream = stream - self.closeEvent = newAsyncEvent() - when chronicles.enabledLogLevel == LogLevel.TRACE: - self.oid = genOid() - asyncCheck self.bindStreamClose() - inc getConnectionTracker().opened - libp2p_open_connection.inc() - + self.initStream() return self proc newConnection*(stream: LPStream): Connection = ## create a new Connection for the specified async reader/writer result.init(stream) +method initStream*(s: Connection) = + procCall LPStream(s).initStream() + trace "created connection", oid = s.oid + inc getConnectionTracker().opened + libp2p_open_connection.inc() + method readExactly*(s: Connection, pbytes: pointer, nbytes: int): - Future[void] {.gcsafe.} = - s.stream.readExactly(pbytes, nbytes) + Future[void] {.async, gcsafe.} = + try: + await s.stream.readExactly(pbytes, nbytes) + except CatchableError as exc: + await s.close() + raise exc method readOnce*(s: Connection, pbytes: pointer, nbytes: int): - Future[int] {.gcsafe.} = - s.stream.readOnce(pbytes, nbytes) + Future[int] {.async, gcsafe.} = + try: + result = await s.stream.readOnce(pbytes, nbytes) + except CatchableError as exc: + await s.close() + raise exc method write*(s: Connection, msg: seq[byte]): - Future[void] {.gcsafe.} = - s.stream.write(msg) + Future[void] {.async, gcsafe.} = + try: + await s.stream.write(msg) + except CatchableError as exc: + await s.close() + raise exc + +method atEof*(s: Connection): bool {.inline.} = + if isNil(s.stream): + return true + + s.stream.atEof method closed*(s: Connection): bool = if isNil(s.stream): @@ -119,30 +127,37 @@ method closed*(s: Connection): bool = result = s.stream.closed method close*(s: Connection) {.async, gcsafe.} = - trace "about to close connection", closed = s.closed, conn = $s + try: + if not s.isClosed: + s.isClosed = true - if not s.isClosed: - s.isClosed = true - inc getConnectionTracker().closed + trace "about to close connection", closed = s.closed, + conn = $s, + oid = s.oid - if not isNil(s.stream) and not s.stream.closed: - trace "closing child stream", closed = s.closed, conn = $s - try: + + if not isNil(s.stream) and not s.stream.closed: + trace "closing child stream", closed = s.closed, + conn = $s, + oid = s.stream.oid await s.stream.close() - except CancelledError as exc: - raise exc - except CatchableError as exc: - debug "Error while closing child stream", err = exc.msg + # s.stream = nil - s.closeEvent.fire() + s.closeEvent.fire() + trace "waiting readloops", count=s.readLoops.len, + conn = $s, + oid = s.oid + await all(s.readLoops) + s.readLoops = @[] - trace "waiting readloops", count=s.readLoops.len, conn = $s - let loopFuts = await allFinished(s.readLoops) - checkFutures(loopFuts) - s.readLoops = @[] + trace "connection closed", closed = s.closed, + conn = $s, + oid = s.oid - trace "connection closed", closed = s.closed, conn = $s - libp2p_open_connection.dec() + inc getConnectionTracker().closed + libp2p_open_connection.dec() + except CatchableError as exc: + trace "exception closing connections", exc = exc.msg method getObservedAddrs*(c: Connection): Future[MultiAddress] {.base, async, gcsafe.} = ## get resolved multiaddresses for the connection