don't pass stream handler through contructor

This commit is contained in:
Dmitriy Ryajov 2019-09-04 15:22:23 -06:00
parent 3cd19ddc47
commit 0b784c5b58
4 changed files with 21 additions and 17 deletions

View File

@ -7,12 +7,14 @@
## This file may not be copied, modified, or distributed except according to ## This file may not be copied, modified, or distributed except according to
## those terms. ## those terms.
## TODO: I have to be carefull to clean up channels correctly ## TODO: I have to be carefull to clean up channels correctly,
## both by removing them from the internal tables as well as ## both by removing them from the internal tables as well as
## cleaning up when the channel is completelly finished, this ## cleaning up when the channel is completelly finished. This
## is complicated because half closed makes it non-deterministic. ## is complicated because half-closed streams makes closing
## This still needs to be implemented properly - I'm leaving it here ## channels non non-deterministic.
## to not forget that this needs to be fixed ASAP. ##
## This still needs to be implemented properly - I'm leaving it
## here to not forget that this needs to be fixed ASAP.
import tables, sequtils import tables, sequtils
import chronos import chronos
@ -28,7 +30,6 @@ type
local*: Table[int, Channel] local*: Table[int, Channel]
currentId*: int currentId*: int
maxChannels*: uint maxChannels*: uint
streamHandler*: StreamHandler
proc newMplexUnknownMsgError(): ref MplexUnknownMsgError = proc newMplexUnknownMsgError(): ref MplexUnknownMsgError =
result = newException(MplexUnknownMsgError, "Unknown mplex message type") result = newException(MplexUnknownMsgError, "Unknown mplex message type")
@ -59,6 +60,7 @@ method handle*(m: Mplex): Future[void] {.async, gcsafe.} =
case msgType: case msgType:
of MessageType.New: of MessageType.New:
let channel = await m.newStreamInternal(false, id.int) let channel = await m.newStreamInternal(false, id.int)
if not isNil(m.streamHandler):
channel.handlerFuture = m.streamHandler(newConnection(channel)) channel.handlerFuture = m.streamHandler(newConnection(channel))
of MessageType.MsgIn, MessageType.MsgOut: of MessageType.MsgIn, MessageType.MsgOut:
let channel = m.getChannelList(initiator)[id.int] let channel = m.getChannelList(initiator)[id.int]
@ -79,12 +81,10 @@ method handle*(m: Mplex): Future[void] {.async, gcsafe.} =
await m.connection.close() await m.connection.close()
proc newMplex*(conn: Connection, proc newMplex*(conn: Connection,
streamHandler: StreamHandler,
maxChanns: uint = MaxChannels): Mplex = maxChanns: uint = MaxChannels): Mplex =
new result new result
result.connection = conn result.connection = conn
result.maxChannels = maxChanns result.maxChannels = maxChanns
result.streamHandler = streamHandler
result.remote = initTable[int, Channel]() result.remote = initTable[int, Channel]()
result.local = initTable[int, Channel]() result.local = initTable[int, Channel]()

View File

@ -25,5 +25,3 @@ type
CloseOut, CloseOut,
ResetIn, ResetIn,
ResetOut ResetOut
StreamHandler* = proc(conn: Connection): Future[void] {.gcsafe.}

View File

@ -11,7 +11,10 @@ import chronos
import ../protocol, ../connection import ../protocol, ../connection
type type
StreamHandler* = proc(conn: Connection): Future[void] {.gcsafe.}
Muxer* = ref object of RootObj Muxer* = ref object of RootObj
streamHandler*: StreamHandler
connection*: Connection connection*: Connection
MuxerCreator* = proc(conn: Connection): Muxer {.gcsafe, closure.} MuxerCreator* = proc(conn: Connection): Muxer {.gcsafe, closure.}
@ -33,3 +36,6 @@ method init(c: MuxerProvider) =
method newStream*(m: Muxer): Future[Connection] {.base, async, gcsafe.} = discard method newStream*(m: Muxer): Future[Connection] {.base, async, gcsafe.} = discard
method close*(m: Muxer) {.base, async, gcsafe.} = discard method close*(m: Muxer) {.base, async, gcsafe.} = discard
method handle*(m: Muxer): Future[void] {.base, async, gcsafe.} = discard method handle*(m: Muxer): Future[void] {.base, async, gcsafe.} = discard
method `=streamHandler`*(m: Muxer, handler: StreamHandler) {.base, gcsafe.} =
m.streamHandler = handler

View File

@ -79,7 +79,8 @@ suite "Mplex":
await stream.writeLp("Hello from stream!") await stream.writeLp("Hello from stream!")
await stream.close() await stream.close()
let mplexListen = newMplex(conn, handleMplexListen) let mplexListen = newMplex(conn)
mplexListen.streamHandler = handleMplexListen
await mplexListen.handle() await mplexListen.handle()
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
@ -88,8 +89,7 @@ suite "Mplex":
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(ma) let conn = await transport2.dial(ma)
proc handleDial(stream: Connection) {.async, gcsafe.} = discard let mplexDial = newMplex(conn)
let mplexDial = newMplex(conn, handleDial)
let dialFut = mplexDial.handle() let dialFut = mplexDial.handle()
let stream = await mplexDial.newStream() let stream = await mplexDial.newStream()
check cast[string](await stream.readLp()) == "Hello from stream!" check cast[string](await stream.readLp()) == "Hello from stream!"
@ -110,7 +110,8 @@ suite "Mplex":
check cast[string](msg) == "Hello from stream!" check cast[string](msg) == "Hello from stream!"
await stream.close() await stream.close()
let mplexListen = newMplex(conn, handleMplexListen) let mplexListen = newMplex(conn)
mplexListen.streamHandler = handleMplexListen
await mplexListen.handle() await mplexListen.handle()
let transport1: TcpTransport = newTransport(TcpTransport) let transport1: TcpTransport = newTransport(TcpTransport)
@ -119,8 +120,7 @@ suite "Mplex":
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(ma) let conn = await transport2.dial(ma)
proc handleDial(stream: Connection) {.async, gcsafe.} = discard let mplexDial = newMplex(conn)
let mplexDial = newMplex(conn, handleDial)
let dialFut = mplexDial.handle() let dialFut = mplexDial.handle()
let stream = await mplexDial.newStream() let stream = await mplexDial.newStream()
await stream.writeLp("Hello from stream!") await stream.writeLp("Hello from stream!")