diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index b971c9017..c3aaa9dc7 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -16,7 +16,7 @@ ## This still needs to be implemented properly - I'm leaving it ## here to not forget that this needs to be fixed ASAP. -import tables, sequtils, strformat, options +import tables, sequtils, options, strformat import chronos import coder, types, channel, ../../varint, @@ -56,33 +56,50 @@ proc newStreamInternal*(m: Mplex, result = newChannel(id, m.connection, initiator, name) m.getChannelList(initiator)[id] = result -method handle*(m: Mplex): Future[void] {.async, gcsafe.} = +method handle*(m: Mplex) {.async, gcsafe.} = + try: while not m.connection.closed: - try: - let (id, msgType, data) = await m.connection.readMsg() - let initiator = bool(ord(msgType) and 1) - var channel: Channel - if MessageType(msgType) != MessageType.New: - let channels = m.getChannelList(initiator) - if not channels.contains(id): - raise newMplexNoSuchChannel(id, msgType) - channel = channels[id] + let msgRes = await m.connection.readMsg() + if msgRes.isNone: + await sleepAsync(100.millis) + continue - case msgType: - of MessageType.New: - channel = await m.newStreamInternal(false, id, cast[string](data)) - if not isNil(m.streamHandler): - await m.streamHandler(newConnection(channel)) - of MessageType.MsgIn, MessageType.MsgOut: - await channel.pushTo(data) - of MessageType.CloseIn, MessageType.CloseOut: - await channel.closedByRemote() - m.getChannelList(initiator).del(id) - of MessageType.ResetIn, MessageType.ResetOut: - await channel.resetByRemote() - else: raise newMplexUnknownMsgError() - finally: - await m.connection.close() + let (id, msgType, data) = msgRes.get() + let initiator = bool(ord(msgType) and 1) + var channel: Channel + if MessageType(msgType) != MessageType.New: + let channels = m.getChannelList(initiator) + if not channels.contains(id): + raise newMplexNoSuchChannel(id, msgType) + channel = channels[id] + + case msgType: + of MessageType.New: + channel = await m.newStreamInternal(false, id, cast[string](data)) + if not isNil(m.streamHandler): + let handlerFut = m.streamHandler(newConnection(channel)) + proc cleanUpChan(udata: pointer) {.gcsafe.} = + if handlerFut.finished: + channel.close().addCallback( + proc(udata: pointer) = + # TODO: is waitFor() OK here? + channel.cleanUp() + .addCallback(proc(udata: pointer) = + echo &"cleaned up channel {$id}") + ) + handlerFut.addCallback(cleanUpChan) + continue + of MessageType.MsgIn, MessageType.MsgOut: + await channel.pushTo(data) + of MessageType.CloseIn, MessageType.CloseOut: + await channel.closedByRemote() + m.getChannelList(initiator).del(id) + of MessageType.ResetIn, MessageType.ResetOut: + await channel.resetByRemote() + break + else: raise newMplexUnknownMsgError() + finally: + await m.connection.close() proc newMplex*(conn: Connection, maxChanns: uint = MaxChannels): Mplex = @@ -94,7 +111,7 @@ proc newMplex*(conn: Connection, method newStream*(m: Mplex, name: string = ""): Future[Connection] {.async, gcsafe.} = let channel = await m.newStreamInternal() - await m.connection.writeMsg(channel.id, MessageType.New, cast[seq[byte]](toSeq(name.items))) + await m.connection.writeMsg(channel.id, MessageType.New, name) result = newConnection(channel) method close*(m: Mplex) {.async, gcsafe.} =