diff --git a/libp2p/muxers/mplex.nim b/libp2p/muxers/mplex.nim index 8b2ae5d..31e2067 100644 --- a/libp2p/muxers/mplex.nim +++ b/libp2p/muxers/mplex.nim @@ -7,11 +7,11 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -import sequtils +import tables, sequtils import chronos import ../varint, ../connection, ../vbuffer, ../protocol, - ../stream/bufferstream, + ../stream/bufferstream, ../stream/lpstream, muxer const MaxMsgSize* = 1 shl 20 # 1mb @@ -31,8 +31,8 @@ type StreamHandler = proc(conn: Connection): Future[void] {.gcsafe.} Mplex* = ref object of Muxer - remote*: seq[Channel] - local*: seq[Channel] + remote*: Table[int, Channel] + local*: Table[int, Channel] currentId*: int maxChannels*: uint streamHandler*: StreamHandler @@ -67,7 +67,7 @@ proc readHeader*(conn: Connection): Future[(uint, MessageType)] {.async, gcsafe. if res != VarintStatus.Success: buffer.setLen(0) return - except TransportIncompleteError: + except LPStreamIncompleteError: buffer.setLen(0) proc writeHeader*(conn: Connection, @@ -77,7 +77,8 @@ proc writeHeader*(conn: Connection, ## write lenght prefixed var buf = initVBuffer() buf.writeVarint(LPSomeUVarint(id.uint shl 3 or msgType.uint)) - buf.writeVarint(LPSomeUVarint(size.uint)) + if size > 0: + buf.writeVarint(LPSomeUVarint(size.uint)) buf.finish() result = conn.write(buf.buffer) @@ -94,55 +95,64 @@ proc newChannel*(mplex: Mplex, result.id = id result.mplex = mplex result.initiator = initiator - result.writeHandler = handler - result.maxSize = size + result.initBufferStream(handler, size) proc closed*(s: Channel): bool = s.closedLocal and s.closedRemote proc close*(s: Channel) {.async.} = discard proc reset*(s: Channel) {.async.} = discard ########################################## -## ## Mplex -## ########################################## -proc getChannelList(m: Mplex, initiator: bool): var seq[Channel] = - if initiator: +proc getChannelList(m: Mplex, initiator: bool): var Table[int, Channel] = + if initiator: result = m.remote else: result = m.local -proc newStream*(m: Mplex, - chanId: int = -1, - initiator: bool = true): - Future[Connection] {.async, gcsafe.} = +proc newStreamInternal*(m: Mplex, + initiator: bool = true, + chanId: int): + Future[Channel] {.async, gcsafe.} = ## create new channel/stream - defer: inc(m.currentId) - let id = if chanId > -1: chanId else: m.currentId + let id = if initiator: m.currentId.inc(); m.currentId else: chanId proc writeHandler(data: seq[byte]): Future[void] {.async, gcsafe.} = - let msgType = if initiator: MessageType.MsgIn else: MessageType.MsgOut + let msgType = if initiator: MessageType.MsgOut else: MessageType.MsgIn await m.connection.writeHeader(id, msgType, data.len) # write header await m.connection.write(data) # write data - let channel = newChannel(m, id, initiator, writeHandler) - m.getChannelList(initiator)[id] = channel - result = newConnection(channel) + result = newChannel(m, id, initiator, writeHandler) + m.getChannelList(initiator)[id] = result -proc handle*(m: Mplex) {.async, gcsafe.} = - while not m.connection.closed: - let (id, msgType) = await m.connection.readHeader() - let initiator = bool(ord(msgType) and 1) - case msgType: - of MessageType.New: - await m.streamHandler(await m.newStream(id.int, false)) - of MessageType.MsgIn, MessageType.MsgOut: - await m.getChannelList(initiator)[id.int].pushTo(await m.connection.readLp()) - of MessageType.CloseIn, MessageType.CloseOut: - await m.getChannelList(initiator)[id.int].close() - of MessageType.ResetIn, MessageType.ResetOut: - await m.getChannelList(initiator)[id.int].reset() - else: raise newMplexUnknownMsgError() +proc newStreamInternal*(m: Mplex): Future[Channel] {.gcsafe.} = + result = m.newStreamInternal(true, 0) + +proc handle*(m: Mplex): Future[void] {.async, gcsafe.} = + try: + while not m.connection.closed: + let (id, msgType) = await m.connection.readHeader() + let initiator = bool(ord(msgType) and 1) + case msgType: + of MessageType.New: + let channel = await m.newStreamInternal(false, id.int) + await m.streamHandler(newConnection(channel)) + of MessageType.MsgIn, MessageType.MsgOut: + let channel = m.getChannelList(initiator)[id.int] + let msg = await m.connection.readLp() + await channel.pushTo(msg) + of MessageType.CloseIn, MessageType.CloseOut: + let channel = m.getChannelList(initiator)[id.int] + await channel.close() + of MessageType.ResetIn, MessageType.ResetOut: + let channel = m.getChannelList(initiator)[id.int] + await channel.reset() + else: raise newMplexUnknownMsgError() + except Exception as exc: + #TODO: add proper loging + discard + finally: + await m.connection.close() proc newMplex*(conn: Connection, streamHandler: StreamHandler, @@ -151,12 +161,15 @@ proc newMplex*(conn: Connection, result.connection = conn result.maxChannels = maxChanns result.streamHandler = streamHandler + result.remote = initTable[int, Channel]() + result.local = initTable[int, Channel]() -method newStream*(m: Mplex): Future[Connection] {.gcsafe.} = - result = m.newStream(true) +method newStream*(m: Mplex): Future[Connection] {.async, gcsafe.} = + let channel = await m.newStreamInternal() + await m.connection.writeHeader(channel.id, MessageType.New, 0) + result = newConnection(channel) -method close(m: Mplex) {.async, gcsafe.} = - let futs = @[allFutures(m.remote.mapIt(it.close())), - allFutures(m.local.mapIt(it.close()))] - await allFutures(futs) +method close*(m: Mplex) {.async, gcsafe.} = + await allFutures(@[allFutures(toSeq(m.remote.values).mapIt(it.close())), + allFutures(toSeq(m.local.values).mapIt(it.close()))]) await m.connection.close() diff --git a/tests/testmplex.nim b/tests/testmplex.nim index ee83867..a4d4f53 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -43,7 +43,7 @@ suite "Mplex": # check msg == fromHex("880102") # let conn = newConnection(newTestEncodeStream(encHandler)) - # await conn.writeHeader(uint(17), MessageType.New, 2) + # await conn.writeHeader(17, MessageType.New, 2) # result = true # check: @@ -63,15 +63,16 @@ suite "Mplex": # check: # waitFor(testDecodeHeader()) == true - test "e2e - new stream": + test "e2e - read/write initiator": proc testNewStream(): Future[bool] {.async.} = - let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53351") + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53380") proc connHandler(conn: Connection) {.async, gcsafe.} = - proc handleListen(stream: Connection) {.async, gcsafe.} = + proc handleMplexListen(stream: Connection) {.async, gcsafe.} = await stream.writeLp("Hello from stream!") + await stream.close() - let mplexListen = newMplex(conn, handleListen) + let mplexListen = newMplex(conn, handleMplexListen) await mplexListen.handle() let transport1: TcpTransport = newTransport(TcpTransport) @@ -79,11 +80,45 @@ suite "Mplex": let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) - proc handleDial(stream: Connection) {.async, gcsafe.} = - let msg = await stream.readLp() + proc handleDial(stream: Connection) {.async, gcsafe.} = discard let mplexDial = newMplex(conn, handleDial) - let handleFut = mplexDial.handle() + let dialFut = mplexDial.handle() + let stream = await mplexDial.newStream() + check cast[string](await stream.readLp()) == "Hello from stream!" + + await conn.close() + await dialFut + result = true + + check: + waitFor(testNewStream()) == true + + test "e2e - read/write receiver": + proc testNewStream(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53381") + + proc connHandler(conn: Connection) {.async, gcsafe.} = + proc handleMplexListen(stream: Connection) {.async, gcsafe.} = + let msg = await stream.readLp() + check cast[string](msg) == "Hello from stream!" + await stream.close() + + let mplexListen = newMplex(conn, handleMplexListen) + await mplexListen.handle() + + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) + + proc handleDial(stream: Connection) {.async, gcsafe.} = discard + let mplexDial = newMplex(conn, handleDial) + let dialFut = mplexDial.handle() + let stream = await mplexDial.newStream() + await stream.writeLp("Hello from stream!") + await dialFut result = true check: