diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 0d64f05..b031cdc 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -172,7 +172,7 @@ method readOnce*(s: LPChannel, method write*(s: LPChannel, msg: seq[byte]): Future[void] {.async.} = ## Write to mplex channel - there may be up to MaxWrite concurrent writes - ## pending after which the peer is disconencted + ## pending after which the peer is disconnected if s.closedLocal or s.conn.closed: raise newLPStreamClosedError() @@ -182,6 +182,7 @@ method write*(s: LPChannel, msg: seq[byte]): Future[void] {.async.} = if s.writes >= MaxWrites: debug "Closing connection, too many in-flight writes on channel", s, conn = s.conn, writes = s.writes + await s.reset() await s.conn.close() return @@ -197,6 +198,7 @@ method write*(s: LPChannel, msg: seq[byte]): Future[void] {.async.} = s.activity = true except CatchableError as exc: trace "exception in lpchannel write handler", s, msg = exc.msg + await s.reset() await s.conn.close() raise exc finally: diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index a856420..79756b1 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -27,7 +27,6 @@ const MplexCodec* = "/mplex/6.7.0" const MaxChannelCount = 200 - when defined(libp2p_expensive_metrics): declareGauge(libp2p_mplex_channels, "mplex channels", labels = ["initiator", "peer"]) diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 655a891..c9cb7f1 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -60,8 +60,8 @@ proc setupStreamTracker(name: string): StreamTracker = let tracker = new StreamTracker proc dumpTracking(): string {.gcsafe.} = - return "Opened " & tracker.id & " :" & $tracker.opened & "\n" & - "Closed " & tracker.id & " :" & $tracker.closed + return "Opened " & tracker.id & ": " & $tracker.opened & "\n" & + "Closed " & tracker.id & ": " & $tracker.closed proc leakTransport(): bool {.gcsafe.} = return (tracker.opened != tracker.closed) diff --git a/tests/helpers.nim b/tests/helpers.nim index 561a943..9bf6d2c 100644 --- a/tests/helpers.nim +++ b/tests/helpers.nim @@ -32,6 +32,12 @@ iterator testTrackers*(extras: openArray[string] = []): TrackerBase = let t = getTracker(name) if not isNil(t): yield t +template checkTracker*(name: string) = + var tracker = getTracker(LPChannelTrackerName) + if tracker.isLeaked(): + checkpoint tracker.dump() + fail() + template checkTrackers*() = for tracker in testTrackers(): if tracker.isLeaked(): diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 20abc52..eac138e 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -1,4 +1,4 @@ -import unittest, strformat, strformat, random, oids +import unittest, strformat, strformat, random, oids, sequtils import chronos, nimcrypto/utils, chronicles, stew/byteutils import ../libp2p/[errors, stream/connection, @@ -594,6 +594,347 @@ suite "Mplex": waitFor(testNewStream()) + test "e2e - channel closes listener with EOF": + proc testNewStream() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + var listenStreams: seq[Connection] + proc connHandler(conn: Connection) {.async, gcsafe.} = + let mplexListen = Mplex.init(conn) + mplexListen.streamHandler = proc(stream: Connection) + {.async, gcsafe.} = + listenStreams.add(stream) + try: + discard await stream.readLp(1024) + except LPStreamEOFError: + await stream.close() + return + + check false + + await mplexListen.handle() + await mplexListen.close() + + let transport1 = TcpTransport.init() + let listenFut = await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = TcpTransport.init() + let conn = await transport2.dial(transport1.ma) + + let mplexDial = Mplex.init(conn) + let mplexDialFut = mplexDial.handle() + var dialStreams: seq[Connection] + for i in 0..9: + dialStreams.add((await mplexDial.newStream())) + + for i, s in dialStreams: + await s.closeWithEOF() + check listenStreams[i].closed + check s.closed + + checkTracker(LPChannelTrackerName) + + await conn.close() + await mplexDialFut + await allFuturesThrowing( + transport1.close(), + transport2.close()) + await listenFut + + waitFor(testNewStream()) + + test "e2e - channel closes dialer with EOF": + proc testNewStream() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + var listenStreams: seq[Connection] + var count = 0 + var done = newFuture[void]() + proc connHandler(conn: Connection) {.async, gcsafe.} = + let mplexListen = Mplex.init(conn) + mplexListen.streamHandler = proc(stream: Connection) + {.async, gcsafe.} = + listenStreams.add(stream) + count.inc() + if count == 10: + done.complete() + + await stream.join() + + await mplexListen.handle() + await mplexListen.close() + + let transport1 = TcpTransport.init() + let listenFut = await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = TcpTransport.init() + let conn = await transport2.dial(transport1.ma) + + let mplexDial = Mplex.init(conn) + let mplexDialFut = mplexDial.handle() + var dialStreams: seq[Connection] + for i in 0..9: + dialStreams.add((await mplexDial.newStream())) + + proc dialReadLoop() {.async.} = + for s in dialStreams: + try: + discard await s.readLp(1024) + check false + except LPStreamEOFError: + await s.close() + continue + + check false + + await done + let readLoop = dialReadLoop() + for s in listenStreams: + await s.closeWithEOF() + check s.closed + + await readLoop + await allFuturesThrowing( + allFinished( + (dialStreams & listenStreams) + .mapIt( it.join() ))) + + checkTracker(LPChannelTrackerName) + + await conn.close() + await mplexDialFut + await allFuturesThrowing( + transport1.close(), + transport2.close()) + await listenFut + + waitFor(testNewStream()) + + test "e2e - dialing mplex closes both ends": + proc testNewStream() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + var listenStreams: seq[Connection] + proc connHandler(conn: Connection) {.async, gcsafe.} = + let mplexListen = Mplex.init(conn) + mplexListen.streamHandler = proc(stream: Connection) + {.async, gcsafe.} = + listenStreams.add(stream) + await stream.join() + + await mplexListen.handle() + await mplexListen.close() + + let transport1 = TcpTransport.init() + let listenFut = await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = TcpTransport.init() + let conn = await transport2.dial(transport1.ma) + + let mplexDial = Mplex.init(conn) + let mplexDialFut = mplexDial.handle() + var dialStreams: seq[Connection] + for i in 0..9: + dialStreams.add((await mplexDial.newStream())) + + await mplexDial.close() + await allFuturesThrowing( + allFinished( + (dialStreams & listenStreams) + .mapIt( it.join() ))) + + checkTracker(LPChannelTrackerName) + + await conn.close() + await mplexDialFut + await allFuturesThrowing( + transport1.close(), + transport2.close()) + await listenFut + + waitFor(testNewStream()) + + test "e2e - listening mplex closes both ends": + proc testNewStream() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + var mplexListen: Mplex + var listenStreams: seq[Connection] + proc connHandler(conn: Connection) {.async, gcsafe.} = + mplexListen = Mplex.init(conn) + mplexListen.streamHandler = proc(stream: Connection) + {.async, gcsafe.} = + listenStreams.add(stream) + await stream.join() + + await mplexListen.handle() + await mplexListen.close() + + let transport1 = TcpTransport.init() + let listenFut = await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = TcpTransport.init() + let conn = await transport2.dial(transport1.ma) + + let mplexDial = Mplex.init(conn) + let mplexDialFut = mplexDial.handle() + var dialStreams: seq[Connection] + for i in 0..9: + dialStreams.add((await mplexDial.newStream())) + + await mplexListen.close() + await allFuturesThrowing( + allFinished( + (dialStreams & listenStreams) + .mapIt( it.join() ))) + + checkTracker(LPChannelTrackerName) + + await conn.close() + await mplexDialFut + await allFuturesThrowing( + transport1.close(), + transport2.close()) + await listenFut + + waitFor(testNewStream()) + + test "e2e - canceling mplex handler closes both ends": + proc testNewStream() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + var mplexHandle: Future[void] + var listenStreams: seq[Connection] + proc connHandler(conn: Connection) {.async, gcsafe.} = + let mplexListen = Mplex.init(conn) + mplexListen.streamHandler = proc(stream: Connection) + {.async, gcsafe.} = + listenStreams.add(stream) + await stream.join() + + mplexHandle = mplexListen.handle() + await mplexHandle + await mplexListen.close() + + let transport1 = TcpTransport.init() + let listenFut = await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = TcpTransport.init() + let conn = await transport2.dial(transport1.ma) + + let mplexDial = Mplex.init(conn) + let mplexDialFut = mplexDial.handle() + var dialStreams: seq[Connection] + for i in 0..9: + dialStreams.add((await mplexDial.newStream())) + + mplexHandle.cancel() + await allFuturesThrowing( + allFinished( + (dialStreams & listenStreams) + .mapIt( it.join() ))) + + checkTracker(LPChannelTrackerName) + + await conn.close() + await mplexDialFut + await allFuturesThrowing( + transport1.close(), + transport2.close()) + await listenFut + + waitFor(testNewStream()) + + test "e2e - closing dialing connection should close both ends": + proc testNewStream() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + var listenStreams: seq[Connection] + proc connHandler(conn: Connection) {.async, gcsafe.} = + let mplexListen = Mplex.init(conn) + mplexListen.streamHandler = proc(stream: Connection) + {.async, gcsafe.} = + listenStreams.add(stream) + await stream.join() + + await mplexListen.handle() + await mplexListen.close() + + let transport1 = TcpTransport.init() + let listenFut = await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = TcpTransport.init() + let conn = await transport2.dial(transport1.ma) + + let mplexDial = Mplex.init(conn) + let mplexDialFut = mplexDial.handle() + var dialStreams: seq[Connection] + for i in 0..9: + dialStreams.add((await mplexDial.newStream())) + + await conn.close() + await allFuturesThrowing( + allFinished( + (dialStreams & listenStreams) + .mapIt( it.join() ))) + + checkTracker(LPChannelTrackerName) + + await conn.close() + await mplexDialFut + await allFuturesThrowing( + transport1.close(), + transport2.close()) + await listenFut + + waitFor(testNewStream()) + + test "e2e - canceling listening connection should close both ends": + proc testNewStream() {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() + + var listenConn: Connection + var listenStreams: seq[Connection] + proc connHandler(conn: Connection) {.async, gcsafe.} = + listenConn = conn + let mplexListen = Mplex.init(conn) + mplexListen.streamHandler = proc(stream: Connection) + {.async, gcsafe.} = + listenStreams.add(stream) + await stream.join() + + await mplexListen.handle() + await mplexListen.close() + + let transport1 = TcpTransport.init() + let listenFut = await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = TcpTransport.init() + let conn = await transport2.dial(transport1.ma) + + let mplexDial = Mplex.init(conn) + let mplexDialFut = mplexDial.handle() + var dialStreams: seq[Connection] + for i in 0..9: + dialStreams.add((await mplexDial.newStream())) + + await listenConn.close() + await allFuturesThrowing( + allFinished( + (dialStreams & listenStreams) + .mapIt( it.join() ))) + + checkTracker(LPChannelTrackerName) + + await conn.close() + await mplexDialFut + await allFuturesThrowing( + transport1.close(), + transport2.close()) + await listenFut + + waitFor(testNewStream()) + test "jitter - channel should be able to handle erratic read/writes": proc test() {.async.} = let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0").tryGet() @@ -628,7 +969,7 @@ suite "Mplex": for _ in 0..