diff --git a/libp2p/muxers/mplex/channel.nim b/libp2p/muxers/mplex/channel.nim index ecf794036..b789d4be2 100644 --- a/libp2p/muxers/mplex/channel.nim +++ b/libp2p/muxers/mplex/channel.nim @@ -8,27 +8,123 @@ ## those terms. import chronos -import ../../stream/bufferstream -import types +import ../../stream/bufferstream, + ../../stream/lpstream, + types, coder, ../../connection type Channel* = ref object of BufferStream id*: int + conn*: Connection initiator*: bool isReset*: bool closedLocal*: bool closedRemote*: bool handlerFuture*: Future[void] + msgCode*: MessageType + closeCode*: MessageType + resetCode*: MessageType proc newChannel*(id: int, + conn: Connection, initiator: bool, - handler: WriteHandler, size: int = MaxMsgSize): Channel = new result result.id = id + result.conn = conn result.initiator = initiator - result.initBufferStream(handler, size) + result.msgCode = if initiator: MessageType.MsgOut else: MessageType.MsgIn + result.closeCode = if initiator: MessageType.CloseOut else: MessageType.CloseIn + result.resetCode = if initiator: MessageType.ResetOut else: MessageType.ResetIn -proc closed*(s: Channel): bool = s.closedLocal and s.closedRemote -proc close*(s: Channel) {.async.} = discard -proc reset*(s: Channel) {.async.} = discard + let chan = result + proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} = + await conn.writeHeader(id, chan.msgCode, data.len) # write header + await conn.write(data) + + result.initBufferStream(writeHandler, size) + +proc closeMessage(s: Channel) {.async, gcsafe.} = + await s.conn.writeHeader(s.id, s.closeCode, 0) # write header + +proc closed*(s: Channel): bool = + s.closedLocal + +proc closeRemote*(s: Channel) {.async.} = + s.closedRemote = true + +method close*(s: Channel) {.async, gcsafe.} = + s.closedLocal = true + await s.closeMessage() + +proc resetMessage(s: Channel) {.async, gcsafe.} = + await s.conn.writeHeader(s.id, s.resetCode, 0) # write header + +proc remoteReset*(s: Channel) {.async, gcsafe.} = + await allFutures(s.close(), s.closeRemote()) + s.isReset = true + +proc reset*(s: Channel) {.async.} = + await allFutures(s.resetMessage(), s.remoteReset()) + +proc isReadEof(s: Channel): bool = + bool((s.closedRemote or s.closedLocal) and s.len() <= 0) + +method pushTo*(s: Channel, data: seq[byte]): Future[void] {.gcsafe.} = + if s.closedRemote: + raise newLPStreamClosedError() + result = procCall pushTo(BufferStream(s), data) + +method read*(s: Channel, n = -1): Future[seq[byte]] {.gcsafe.} = + if s.isReadEof(): + raise newLPStreamClosedError() + result = procCall read(BufferStream(s), n) + +method readExactly*(s: Channel, + pbytes: pointer, + nbytes: int): + Future[void] {.gcsafe.} = + if s.isReadEof(): + raise newLPStreamClosedError() + result = procCall readExactly(BufferStream(s), pbytes, nbytes) + +method readLine*(s: Channel, + limit = 0, + sep = "\r\n"): + Future[string] {.gcsafe.} = + if s.isReadEof(): + raise newLPStreamClosedError() + result = procCall readLine(BufferStream(s), limit, sep) + +method readOnce*(s: Channel, + pbytes: pointer, + nbytes: int): + Future[int] {.gcsafe.} = + if s.isReadEof(): + raise newLPStreamClosedError() + result = procCall readOnce(BufferStream(s), pbytes, nbytes) + +method readUntil*(s: Channel, + pbytes: pointer, nbytes: int, + sep: seq[byte]): + Future[int] {.gcsafe.} = + if s.isReadEof(): + raise newLPStreamClosedError() + result = procCall readOnce(BufferStream(s), pbytes, nbytes) + +method write*(s: Channel, + pbytes: pointer, + nbytes: int): Future[void] {.gcsafe.} = + if s.closedLocal: + raise newLPStreamClosedError() + result = procCall write(BufferStream(s), pbytes, nbytes) + +method write*(s: Channel, msg: string, msglen = -1) {.async, gcsafe.} = + if s.closedLocal: + raise newLPStreamClosedError() + result = procCall write(BufferStream(s), msg, msglen) + +method write*(s: Channel, msg: seq[byte], msglen = -1) {.async, gcsafe.} = + if s.closedLocal: + raise newLPStreamClosedError() + result = procCall write(BufferStream(s), msg, msglen) diff --git a/libp2p/muxers/mplex/coder.nim b/libp2p/muxers/mplex/coder.nim index 141e2a2fc..7ba8f2c53 100644 --- a/libp2p/muxers/mplex/coder.nim +++ b/libp2p/muxers/mplex/coder.nim @@ -9,7 +9,7 @@ import chronos import ../../connection, ../../varint, - ../../vbuffer, mplex, types, + ../../vbuffer, types, ../../stream/lpstream proc readHeader*(conn: Connection): Future[(uint, MessageType)] {.async, gcsafe.} = diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 62b4bf42d..8b956b357 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -26,10 +26,6 @@ type proc newMplexUnknownMsgError(): ref MplexUnknownMsgError = result = newException(MplexUnknownMsgError, "Unknown mplex message type") -########################################## -## Mplex -########################################## - proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] = if initiator: result = m.remote @@ -42,12 +38,7 @@ proc newStreamInternal*(m: Mplex, Future[Channel] {.async, gcsafe.} = ## create new channel/stream let id = if initiator: m.currentId.inc(); m.currentId else: chanId - proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} = - let msgType = if initiator: MessageType.MsgOut else: MessageType.MsgIn - await m.connection.writeHeader(id, msgType, data.len) # write header - await m.connection.write(data) # write data - - result = newChannel(id, initiator, writeHandler) + result = newChannel(id, m.connection, initiator) m.getChannelList(initiator)[id] = result proc newStreamInternal*(m: Mplex): Future[Channel] {.gcsafe.} = @@ -68,7 +59,7 @@ proc handle*(m: Mplex): Future[void] {.async, gcsafe.} = await channel.pushTo(msg) of MessageType.CloseIn, MessageType.CloseOut: let channel = m.getChannelList(initiator)[id.int] - await channel.close() + await channel.closeRemote() of MessageType.ResetIn, MessageType.ResetOut: let channel = m.getChannelList(initiator)[id.int] await channel.reset() diff --git a/libp2p/muxers/mplex/types.nim b/libp2p/muxers/mplex/types.nim index 8a7067ffe..fe29a82ae 100644 --- a/libp2p/muxers/mplex/types.nim +++ b/libp2p/muxers/mplex/types.nim @@ -13,6 +13,7 @@ import ../../connection const MaxMsgSize* = 1 shl 20 # 1mb const MaxChannels* = 1000 const MplexCodec* = "/mplex/6.7.0" +const MaxReadWriteTime* = 5.seconds type MplexUnknownMsgError* = object of CatchableError diff --git a/libp2p/stream/lpstream.nim b/libp2p/stream/lpstream.nim index 9c947dc8b..f763928a5 100644 --- a/libp2p/stream/lpstream.nim +++ b/libp2p/stream/lpstream.nim @@ -21,6 +21,7 @@ type par*: ref Exception LPStreamWriteError* = object of LPStreamError par*: ref Exception + LPStreamClosedError* = object of LPStreamError proc newLPStreamReadError*(p: ref Exception): ref Exception {.inline.} = var w = newException(LPStreamReadError, "Read stream failed") @@ -43,6 +44,9 @@ proc newLPStreamLimitError*(): ref Exception {.inline.} = proc newLPStreamIncorrectError*(m: string): ref Exception {.inline.} = result = newException(LPStreamIncorrectError, m) +proc newLPStreamClosedError*(): ref Exception {.inline.} = + result = newException(LPStreamClosedError, "Stream closed!") + method read*(s: LPStream, n = -1): Future[seq[byte]] {.base, async, gcsafe.} = assert(false, "not implemented!") diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 2899c25b4..7cbdbbcdc 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -2,6 +2,7 @@ import unittest, sequtils, sugar import chronos, nimcrypto/utils import ../libp2p/connection, ../libp2p/stream/lpstream, + ../libp2p/stream/bufferstream, ../libp2p/tcptransport, ../libp2p/transport, ../libp2p/multiaddress, @@ -69,64 +70,111 @@ suite "Mplex": check: waitFor(testDecodeHeader()) == true - test "e2e - read/write initiator": - proc testNewStream(): Future[bool] {.async.} = - let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53380") + test "e2e - read/write initiator": + proc testNewStream(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53380") - proc connHandler(conn: Connection) {.async, gcsafe.} = - proc handleMplexListen(stream: Connection) {.async, gcsafe.} = - await stream.writeLp("Hello from stream!") - await stream.close() + proc connHandler(conn: Connection) {.async, gcsafe.} = + proc handleMplexListen(stream: Connection) {.async, gcsafe.} = + await stream.writeLp("Hello from stream!") + await stream.close() - let mplexListen = newMplex(conn, handleMplexListen) - await mplexListen.handle() + let mplexListen = newMplex(conn, handleMplexListen) + await mplexListen.handle() - let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) - let transport2: TcpTransport = newTransport(TcpTransport) - let conn = await transport2.dial(ma) + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) - proc handleDial(stream: Connection) {.async, gcsafe.} = discard - let mplexDial = newMplex(conn, handleDial) - let dialFut = mplexDial.handle() - let stream = await mplexDial.newStream() - check cast[string](await stream.readLp()) == "Hello from stream!" + proc handleDial(stream: Connection) {.async, gcsafe.} = discard + let mplexDial = newMplex(conn, handleDial) + let dialFut = mplexDial.handle() + let stream = await mplexDial.newStream() + check cast[string](await stream.readLp()) == "Hello from stream!" + await conn.close() + await dialFut + result = true - await conn.close() - await dialFut - result = true + check: + waitFor(testNewStream()) == true - check: - waitFor(testNewStream()) == true + test "e2e - read/write receiver": + proc testNewStream(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53381") - test "e2e - read/write receiver": - proc testNewStream(): Future[bool] {.async.} = - let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53381") + proc connHandler(conn: Connection) {.async, gcsafe.} = + proc handleMplexListen(stream: Connection) {.async, gcsafe.} = + let msg = await stream.readLp() + check cast[string](msg) == "Hello from stream!" + await stream.close() - proc connHandler(conn: Connection) {.async, gcsafe.} = - proc handleMplexListen(stream: Connection) {.async, gcsafe.} = - let msg = await stream.readLp() - check cast[string](msg) == "Hello from stream!" - await stream.close() + let mplexListen = newMplex(conn, handleMplexListen) + await mplexListen.handle() - let mplexListen = newMplex(conn, handleMplexListen) - await mplexListen.handle() + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) - let transport1: TcpTransport = newTransport(TcpTransport) - await transport1.listen(ma, connHandler) + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) - let transport2: TcpTransport = newTransport(TcpTransport) - let conn = await transport2.dial(ma) + proc handleDial(stream: Connection) {.async, gcsafe.} = discard + let mplexDial = newMplex(conn, handleDial) + let dialFut = mplexDial.handle() + let stream = await mplexDial.newStream() + await stream.writeLp("Hello from stream!") + await conn.close() + await dialFut + result = true - proc handleDial(stream: Connection) {.async, gcsafe.} = discard - let mplexDial = newMplex(conn, handleDial) - let dialFut = mplexDial.handle() - let stream = await mplexDial.newStream() - await stream.writeLp("Hello from stream!") - await conn.close() - await dialFut - result = true + check: + waitFor(testNewStream()) == true - check: - waitFor(testNewStream()) == true + test "half closed - channel should close for write": + proc testClosedForWrite(): Future[void] {.async.} = + let chann = newChannel(1, newConnection(new LPStream), true) + await chann.close() + await chann.write("Hello") + + expect LPStreamClosedError: + waitFor(testClosedForWrite()) + + test "half closed - channel should close for read": + proc testClosedForRead(): Future[void] {.async.} = + let chann = newChannel(1, newConnection(new LPStream), true) + await chann.closeRemote() + asyncDiscard chann.read() + + expect LPStreamClosedError: + waitFor(testClosedForRead()) + + test "half closed - channel should close for read after eof": + proc testClosedForRead(): Future[void] {.async.} = + let chann = newChannel(1, newConnection(new LPStream), true) + + await chann.pushTo(cast[seq[byte]](toSeq("Hello!".items))) + await chann.close() + let msg = await chann.read() + asyncDiscard chann.read() + + expect LPStreamClosedError: + waitFor(testClosedForRead()) + + test "reset - channel should fail reading": + proc testResetRead(): Future[void] {.async.} = + let chann = newChannel(1, newConnection(new LPStream), true) + await chann.reset() + asyncDiscard chann.read() + + expect LPStreamClosedError: + waitFor(testResetRead()) + + test "reset - channel should fail writing": + proc testResetWrite(): Future[void] {.async.} = + let chann = newChannel(1, newConnection(new LPStream), true) + await chann.reset() + asyncDiscard chann.read() + + expect LPStreamClosedError: + waitFor(testResetWrite())