diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 5f27acc80..e16913d54 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -7,12 +7,14 @@ ## This file may not be copied, modified, or distributed except according to ## those terms. -## TODO: I have to be carefull to clean up channels correctly +## TODO: I have to be carefull to clean up channels correctly, ## both by removing them from the internal tables as well as -## cleaning up when the channel is completelly finished, this -## is complicated because half closed makes it non-deterministic. -## This still needs to be implemented properly - I'm leaving it here -## to not forget that this needs to be fixed ASAP. +## cleaning up when the channel is completelly finished. This +## is complicated because half-closed streams makes closing +## channels non non-deterministic. +## +## This still needs to be implemented properly - I'm leaving it +## here to not forget that this needs to be fixed ASAP. import tables, sequtils import chronos @@ -28,7 +30,6 @@ type local*: Table[int, Channel] currentId*: int maxChannels*: uint - streamHandler*: StreamHandler proc newMplexUnknownMsgError(): ref MplexUnknownMsgError = result = newException(MplexUnknownMsgError, "Unknown mplex message type") @@ -59,7 +60,8 @@ method handle*(m: Mplex): Future[void] {.async, gcsafe.} = case msgType: of MessageType.New: let channel = await m.newStreamInternal(false, id.int) - channel.handlerFuture = m.streamHandler(newConnection(channel)) + if not isNil(m.streamHandler): + channel.handlerFuture = m.streamHandler(newConnection(channel)) of MessageType.MsgIn, MessageType.MsgOut: let channel = m.getChannelList(initiator)[id.int] let msg = await m.connection.readLp() @@ -79,12 +81,10 @@ method handle*(m: Mplex): Future[void] {.async, gcsafe.} = await m.connection.close() proc newMplex*(conn: Connection, - streamHandler: StreamHandler, maxChanns: uint = MaxChannels): Mplex = new result result.connection = conn result.maxChannels = maxChanns - result.streamHandler = streamHandler result.remote = initTable[int, Channel]() result.local = initTable[int, Channel]() diff --git a/libp2p/muxers/mplex/types.nim b/libp2p/muxers/mplex/types.nim index fe29a82ae..f83bb37a7 100644 --- a/libp2p/muxers/mplex/types.nim +++ b/libp2p/muxers/mplex/types.nim @@ -25,5 +25,3 @@ type CloseOut, ResetIn, ResetOut - - StreamHandler* = proc(conn: Connection): Future[void] {.gcsafe.} diff --git a/libp2p/muxers/muxer.nim b/libp2p/muxers/muxer.nim index 5623169a2..03d5dcfce 100644 --- a/libp2p/muxers/muxer.nim +++ b/libp2p/muxers/muxer.nim @@ -11,7 +11,10 @@ import chronos import ../protocol, ../connection type + StreamHandler* = proc(conn: Connection): Future[void] {.gcsafe.} + Muxer* = ref object of RootObj + streamHandler*: StreamHandler connection*: Connection MuxerCreator* = proc(conn: Connection): Muxer {.gcsafe, closure.} @@ -33,3 +36,6 @@ method init(c: MuxerProvider) = method newStream*(m: Muxer): Future[Connection] {.base, async, gcsafe.} = discard method close*(m: Muxer) {.base, async, gcsafe.} = discard method handle*(m: Muxer): Future[void] {.base, async, gcsafe.} = discard + +method `=streamHandler`*(m: Muxer, handler: StreamHandler) {.base, gcsafe.} = + m.streamHandler = handler diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 6a2fbe922..56de312fe 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -79,7 +79,8 @@ suite "Mplex": await stream.writeLp("Hello from stream!") await stream.close() - let mplexListen = newMplex(conn, handleMplexListen) + let mplexListen = newMplex(conn) + mplexListen.streamHandler = handleMplexListen await mplexListen.handle() let transport1: TcpTransport = newTransport(TcpTransport) @@ -88,8 +89,7 @@ suite "Mplex": let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) - proc handleDial(stream: Connection) {.async, gcsafe.} = discard - let mplexDial = newMplex(conn, handleDial) + let mplexDial = newMplex(conn) let dialFut = mplexDial.handle() let stream = await mplexDial.newStream() check cast[string](await stream.readLp()) == "Hello from stream!" @@ -110,7 +110,8 @@ suite "Mplex": check cast[string](msg) == "Hello from stream!" await stream.close() - let mplexListen = newMplex(conn, handleMplexListen) + let mplexListen = newMplex(conn) + mplexListen.streamHandler = handleMplexListen await mplexListen.handle() let transport1: TcpTransport = newTransport(TcpTransport) @@ -119,8 +120,7 @@ suite "Mplex": let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) - proc handleDial(stream: Connection) {.async, gcsafe.} = discard - let mplexDial = newMplex(conn, handleDial) + let mplexDial = newMplex(conn) let dialFut = mplexDial.handle() let stream = await mplexDial.newStream() await stream.writeLp("Hello from stream!")