diff --git a/libp2p/muxers/mplex/channel.nim b/libp2p/muxers/mplex/channel.nim index f501e97fc..638076bb9 100644 --- a/libp2p/muxers/mplex/channel.nim +++ b/libp2p/muxers/mplex/channel.nim @@ -18,7 +18,7 @@ const DefaultChannelSize* = DefaultBufferSize * 64 # 64kb type Channel* = ref object of BufferStream - id*: int + id*: uint name*: string conn*: Connection initiator*: bool @@ -30,7 +30,7 @@ type closeCode*: MessageType resetCode*: MessageType -proc newChannel*(id: int, +proc newChannel*(id: uint, conn: Connection, initiator: bool, name: string = "", @@ -46,13 +46,12 @@ proc newChannel*(id: int, let chan = result proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} = - await conn.writeHeader(chan.id, chan.msgCode, data.len) # write header - await conn.write(data) + await conn.writeMsg(chan.id, chan.msgCode, data) # write header result.initBufferStream(writeHandler, size) proc closeMessage(s: Channel) {.async, gcsafe.} = - await s.conn.writeHeader(s.id, s.closeCode) # write header + await s.conn.writeMsg(s.id, s.closeCode) # write header proc closed*(s: Channel): bool = s.closedLocal @@ -65,7 +64,7 @@ method close*(s: Channel) {.async, gcsafe.} = await s.closeMessage() proc resetMessage(s: Channel) {.async, gcsafe.} = - await s.conn.writeHeader(s.id, s.resetCode) + await s.conn.writeMsg(s.id, s.resetCode) proc resetByRemote*(s: Channel) {.async, gcsafe.} = await allFutures(s.close(), s.closedByRemote()) diff --git a/libp2p/muxers/mplex/coder.nim b/libp2p/muxers/mplex/coder.nim index 93648ff4f..a4cd4b021 100644 --- a/libp2p/muxers/mplex/coder.nim +++ b/libp2p/muxers/mplex/coder.nim @@ -18,33 +18,39 @@ import types, type Phase = enum Header, Size -proc readHeader*(conn: Connection): Future[(uint, MessageType)] {.async, gcsafe.} = +proc readMplexVarint(conn: Connection): Future[uint] {.async, gcsafe.} = var - header: uint + varint: uint length: int res: VarintStatus var buffer = newSeq[byte](10) try: for i in 0.. 0.uint: + data = await conn.read(dataLen.int) + result = (header shr 3, MessageType(header and 0x7), data) + +proc writeMsg*(conn: Connection, + id: uint, msgType: MessageType, - size: int = 0) {.async, gcsafe.} = + data: seq[byte] = @[]) {.async, gcsafe.} = ## write lenght prefixed var buf = initVBuffer() - buf.writeVarint((id.uint shl 3) or msgType.uint) - buf.writeVarint(size.uint) # size should be always sent + buf.writeVarint((id shl 3) or ord(msgType).uint) + buf.writeVarint(data.len().uint) # size should be always sent buf.finish() - await conn.write(buf.buffer) + await conn.write(buf.buffer & data) diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 8a975e134..b971c9017 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -29,18 +29,18 @@ import coder, types, channel, type Mplex* = ref object of Muxer - remote*: Table[int, Channel] - local*: Table[int, Channel] - currentId*: int + remote*: Table[uint, Channel] + local*: Table[uint, Channel] + currentId*: uint maxChannels*: uint -proc newMplexNoSuchChannel(id: int, msgType: MessageType): ref MplexNoSuchChannel = +proc newMplexNoSuchChannel(id: uint, msgType: MessageType): ref MplexNoSuchChannel = result = newException(MplexNoSuchChannel, &"No such channel id {$id} and message {$msgType}") proc newMplexUnknownMsgError(): ref MplexUnknownMsgError = result = newException(MplexUnknownMsgError, "Unknown mplex message type") -proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] = +proc getChannelList(m: Mplex, initiator: bool): var Table[uint, Channel] = if initiator: result = m.remote else: @@ -48,7 +48,7 @@ proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] = proc newStreamInternal*(m: Mplex, initiator: bool = true, - chanId: int, + chanId: uint = 0, name: string = ""): Future[Channel] {.async, gcsafe.} = ## create new channel/stream @@ -56,60 +56,45 @@ proc newStreamInternal*(m: Mplex, result = newChannel(id, m.connection, initiator, name) m.getChannelList(initiator)[id] = result -proc newStreamInternal*(m: Mplex): Future[Channel] {.gcsafe.} = - result = m.newStreamInternal(true, 0) - method handle*(m: Mplex): Future[void] {.async, gcsafe.} = - try: while not m.connection.closed: - let (id, msgType) = await m.connection.readHeader() - let initiator = bool(ord(msgType) and 1) - var channel: Channel - if MessageType(msgType) != MessageType.New: - let channels = m.getChannelList(initiator) - if not channels.contains(id.int): - raise newMplexNoSuchChannel(id.int, msgType) - channel = channels[id.int] + try: + let (id, msgType, data) = await m.connection.readMsg() + let initiator = bool(ord(msgType) and 1) + var channel: Channel + if MessageType(msgType) != MessageType.New: + let channels = m.getChannelList(initiator) + if not channels.contains(id): + raise newMplexNoSuchChannel(id, msgType) + channel = channels[id] - case msgType: - of MessageType.New: - var name: seq[byte] - try: - name = await m.connection.readLp() - except LPStreamIncompleteError as exc: - echo exc.msg - except Exception as exc: - echo exc.msg - raise - - let channel = await m.newStreamInternal(false, id.int, cast[string](name)) - if not isNil(m.streamHandler): - channel.handlerFuture = m.streamHandler(newConnection(channel)) - of MessageType.MsgIn, MessageType.MsgOut: - let msg = await m.connection.readLp() - await channel.pushTo(msg) - of MessageType.CloseIn, MessageType.CloseOut: - await channel.closedByRemote() - m.getChannelList(initiator).del(id.int) - of MessageType.ResetIn, MessageType.ResetOut: - await channel.resetByRemote() - else: raise newMplexUnknownMsgError() - finally: - await m.connection.close() + case msgType: + of MessageType.New: + channel = await m.newStreamInternal(false, id, cast[string](data)) + if not isNil(m.streamHandler): + await m.streamHandler(newConnection(channel)) + of MessageType.MsgIn, MessageType.MsgOut: + await channel.pushTo(data) + of MessageType.CloseIn, MessageType.CloseOut: + await channel.closedByRemote() + m.getChannelList(initiator).del(id) + of MessageType.ResetIn, MessageType.ResetOut: + await channel.resetByRemote() + else: raise newMplexUnknownMsgError() + finally: + await m.connection.close() proc newMplex*(conn: Connection, maxChanns: uint = MaxChannels): Mplex = new result result.connection = conn result.maxChannels = maxChanns - result.remote = initTable[int, Channel]() - result.local = initTable[int, Channel]() + result.remote = initTable[uint, Channel]() + result.local = initTable[uint, Channel]() method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} = let channel = await m.newStreamInternal() - await m.connection.writeHeader(channel.id, MessageType.New, len(name)) - if name.len > 0: - await m.connection.write(name) + await m.connection.writeMsg(channel.id, MessageType.New, cast[seq[byte]](toSeq(name.items))) result = newConnection(channel) method close*(m: Mplex) {.async, gcsafe.} = diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 52dbd3fb2..1a6bb6be5 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -12,65 +12,109 @@ import ../libp2p/connection, ../libp2p/muxers/mplex/types, ../libp2p/muxers/mplex/channel -type - TestEncodeStream = ref object of LPStream - handler*: proc(data: seq[byte]) - -method write*(s: TestEncodeStream, - msg: seq[byte], - msglen = -1): - Future[void] {.gcsafe.} = - s.handler(msg) - -proc newTestEncodeStream(handler: proc(data: seq[byte])): TestEncodeStream = - new result - result.handler = handler - -type - TestDecodeStream = ref object of LPStream - handler*: proc(data: seq[byte]) - step*: int - msg*: seq[byte] - -method readExactly*(s: TestDecodeStream, - pbytes: pointer, - nbytes: int): Future[void] {.async, gcsafe.} = - let buff: seq[byte] = s.msg - copyMem(pbytes, unsafeAddr buff[s.step], nbytes) - s.step += nbytes - -proc newTestDecodeStream(): TestDecodeStream = - new result - result.step = 0 - result.msg = fromHex("8801023137") - suite "Mplex": - test "encode header": + test "encode header with channel id 0": proc testEncodeHeader(): Future[bool] {.async.} = - proc encHandler(msg: seq[byte]) = - check msg == fromHex("886f04") + proc encHandler(msg: seq[byte]) {.async.} = + check msg == fromHex("000873747265616d2031") - let conn = newConnection(newTestEncodeStream(encHandler)) - await conn.writeHeader(1777, MessageType.New, 4) + let stream = newBufferStream(encHandler) + let conn = newConnection(stream) + await conn.writeMsg(0, MessageType.New, cast[seq[byte]](toSeq("stream 1".items))) result = true check: waitFor(testEncodeHeader()) == true - test "decode header": - proc testDecodeHeader(): Future[bool] {.async.} = - let conn = newConnection(newTestDecodeStream()) - let (id, msgType) = await conn.readHeader() + test "encode header with channel id other than 0": + proc testEncodeHeader(): Future[bool] {.async.} = + proc encHandler(msg: seq[byte]) {.async.} = + check msg == fromHex("88010873747265616d2031") + + let stream = newBufferStream(encHandler) + let conn = newConnection(stream) + await conn.writeMsg(17, MessageType.New, cast[seq[byte]](toSeq("stream 1".items))) + result = true - check id == 17 + check: + waitFor(testEncodeHeader()) == true + + test "encode header and body with channel id 0": + proc testEncodeHeaderBody(): Future[bool] {.async.} = + var step = 0 + proc encHandler(msg: seq[byte]) {.async.} = + check msg == fromHex("020873747265616d2031") + + let stream = newBufferStream(encHandler) + let conn = newConnection(stream) + await conn.writeMsg(0, MessageType.MsgOut, cast[seq[byte]](toSeq("stream 1".items))) + result = true + + check: + waitFor(testEncodeHeaderBody()) == true + + test "encode header and body with channel id other than 0": + proc testEncodeHeaderBody(): Future[bool] {.async.} = + var step = 0 + proc encHandler(msg: seq[byte]) {.async.} = + check msg == fromHex("8a010873747265616d2031") + + let stream = newBufferStream(encHandler) + let conn = newConnection(stream) + await conn.writeMsg(17, MessageType.MsgOut, cast[seq[byte]](toSeq("stream 1".items))) + await conn.close() + result = true + + check: + waitFor(testEncodeHeaderBody()) == true + + test "decode header with channel id 0": + proc testDecodeHeader(): Future[bool] {.async.} = + proc encHandler(msg: seq[byte]) {.async.} = discard + let stream = newBufferStream(encHandler) + let conn = newConnection(stream) + await stream.pushTo(fromHex("000873747265616d2031")) + let (id, msgType, data) = await conn.readMsg() + + check id == 0 check msgType == MessageType.New - let data = await conn.readLp() - check cast[string](data) == "17" result = true check: waitFor(testDecodeHeader()) == true - + + test "decode header and body with channel id 0": + proc testDecodeHeader(): Future[bool] {.async.} = + proc encHandler(msg: seq[byte]) {.async.} = discard + let stream = newBufferStream(encHandler) + let conn = newConnection(stream) + await stream.pushTo(fromHex("021668656C6C6F2066726F6D206368616E6E656C20302121")) + let (id, msgType, data) = await conn.readMsg() + + check id == 0 + check msgType == MessageType.MsgOut + check cast[string](data) == "hello from channel 0!!" + result = true + + check: + waitFor(testDecodeHeader()) == true + + test "decode header and body with channel id other than 0": + proc testDecodeHeader(): Future[bool] {.async.} = + proc encHandler(msg: seq[byte]) {.async.} = discard + let stream = newBufferStream(encHandler) + let conn = newConnection(stream) + await stream.pushTo(fromHex("8a011668656C6C6F2066726F6D206368616E6E656C20302121")) + let (id, msgType, data) = await conn.readMsg() + + check id == 17 + check msgType == MessageType.MsgOut + check cast[string](data) == "hello from channel 0!!" + result = true + + check: + waitFor(testDecodeHeader()) == true + test "e2e - read/write initiator": proc testNewStream(): Future[bool] {.async.} = let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53380") @@ -92,10 +136,11 @@ suite "Mplex": let mplexDial = newMplex(conn) let dialFut = mplexDial.handle() - let stream = await mplexDial.newStream() - check cast[string](await stream.readLp()) == "Hello from stream!" + let stream = await mplexDial.newStream("DIALER") + let msg = cast[string](await stream.readLp()) + check msg == "Hello from stream!" await conn.close() - await dialFut + # await dialFut result = true check: @@ -122,11 +167,9 @@ suite "Mplex": let conn = await transport2.dial(ma) let mplexDial = newMplex(conn) - let dialFut = mplexDial.handle() let stream = await mplexDial.newStream() await stream.writeLp("Hello from stream!") await conn.close() - await dialFut result = true check: @@ -136,16 +179,13 @@ suite "Mplex": proc testNewStream(): Future[bool] {.async.} = let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53382") - var count = 0 - var completionFut: Future[void] = newFuture[void]() + var count = 1 proc connHandler(conn: Connection) {.async, gcsafe.} = proc handleMplexListen(stream: Connection) {.async, gcsafe.} = let msg = await stream.readLp() - check cast[string](msg) == &"Hello from stream {count}!" + check cast[string](msg) == &"stream {count}!" count.inc await stream.close() - if count == 11: - completionFut.complete() let mplexListen = newMplex(conn) mplexListen.streamHandler = handleMplexListen @@ -158,18 +198,12 @@ suite "Mplex": let conn = await transport2.dial(ma) let mplexDial = newMplex(conn) - asyncCheck mplexDial.handle() - - for i in 0..10: + for i in 1..<10: let stream = await mplexDial.newStream() - await stream.writeLp(&"Hello from stream {i}!") - - await completionFut - # closing the connection doesn't transfer all the data - # this seems to be a bug in chronos - # await conn.close() - check count == 11 + await stream.writeLp(&"stream {i}!") + await stream.close() + await conn.close() result = true check: @@ -177,7 +211,8 @@ suite "Mplex": test "half closed - channel should close for write": proc testClosedForWrite(): Future[void] {.async.} = - let chann = newChannel(1, newConnection(new LPStream), true) + proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard + let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) await chann.close() await chann.write("Hello") @@ -186,7 +221,8 @@ suite "Mplex": test "half closed - channel should close for read": proc testClosedForRead(): Future[void] {.async.} = - let chann = newChannel(1, newConnection(new LPStream), true) + proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard + let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) await chann.closedByRemote() asyncDiscard chann.read() @@ -195,7 +231,8 @@ suite "Mplex": test "half closed - channel should close for read after eof": proc testClosedForRead(): Future[void] {.async.} = - let chann = newChannel(1, newConnection(new LPStream), true) + proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard + let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) await chann.pushTo(cast[seq[byte]](toSeq("Hello!".items))) await chann.close() @@ -207,7 +244,8 @@ suite "Mplex": test "reset - channel should fail reading": proc testResetRead(): Future[void] {.async.} = - let chann = newChannel(1, newConnection(new LPStream), true) + proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard + let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) await chann.reset() asyncDiscard chann.read() @@ -216,7 +254,8 @@ suite "Mplex": test "reset - channel should fail writing": proc testResetWrite(): Future[void] {.async.} = - let chann = newChannel(1, newConnection(new LPStream), true) + proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard + let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) await chann.reset() asyncDiscard chann.read() @@ -225,7 +264,8 @@ suite "Mplex": test "should not allow pushing data to channel when remote end closed": proc testResetWrite(): Future[void] {.async.} = - let chann = newChannel(1, newConnection(new LPStream), true) + proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard + let chann = newChannel(1, newConnection(newBufferStream(writeHandler)), true) await chann.closedByRemote() await chann.pushTo(@[byte(1)])