diff --git a/tests/testmplex.nim b/tests/testmplex.nim index 1a6bb6b..9c5bfd3 100644 --- a/tests/testmplex.nim +++ b/tests/testmplex.nim @@ -1,4 +1,4 @@ -import unittest, sequtils, sugar, strformat +import unittest, sequtils, sugar, strformat, options import chronos, nimcrypto/utils import ../libp2p/connection, ../libp2p/stream/lpstream, @@ -10,7 +10,8 @@ import ../libp2p/connection, ../libp2p/muxers/mplex/mplex, ../libp2p/muxers/mplex/coder, ../libp2p/muxers/mplex/types, - ../libp2p/muxers/mplex/channel + ../libp2p/muxers/mplex/channel, + ../libp2p/helpers/debug suite "Mplex": test "encode header with channel id 0": @@ -74,11 +75,12 @@ suite "Mplex": let stream = newBufferStream(encHandler) let conn = newConnection(stream) await stream.pushTo(fromHex("000873747265616d2031")) - let (id, msgType, data) = await conn.readMsg() + let msg = await conn.readMsg() - check id == 0 - check msgType == MessageType.New - result = true + if msg.isSome: + check msg.get().id == 0 + check msg.get().msgType == MessageType.New + result = true check: waitFor(testDecodeHeader()) == true @@ -89,12 +91,13 @@ suite "Mplex": let stream = newBufferStream(encHandler) let conn = newConnection(stream) await stream.pushTo(fromHex("021668656C6C6F2066726F6D206368616E6E656C20302121")) - let (id, msgType, data) = await conn.readMsg() + let msg = await conn.readMsg() - check id == 0 - check msgType == MessageType.MsgOut - check cast[string](data) == "hello from channel 0!!" - result = true + if msg.isSome: + check msg.get().id == 0 + check msg.get().msgType == MessageType.MsgOut + check cast[string](msg.get().data) == "hello from channel 0!!" + result = true check: waitFor(testDecodeHeader()) == true @@ -105,12 +108,13 @@ suite "Mplex": let stream = newBufferStream(encHandler) let conn = newConnection(stream) await stream.pushTo(fromHex("8a011668656C6C6F2066726F6D206368616E6E656C20302121")) - let (id, msgType, data) = await conn.readMsg() + let msg = await conn.readMsg() - check id == 17 - check msgType == MessageType.MsgOut - check cast[string](data) == "hello from channel 0!!" - result = true + if msg.isSome: + check msg.get().id == 17 + check msg.get().msgType == MessageType.MsgOut + check cast[string](msg.get().data) == "hello from channel 0!!" + result = true check: waitFor(testDecodeHeader()) == true @@ -209,6 +213,51 @@ suite "Mplex": check: waitFor(testNewStream()) == true + test "e2e - multiple read/write streams": + proc testNewStream(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53383") + + var count = 1 + var listenFut: Future[void] + proc connHandler(conn: Connection) {.async, gcsafe.} = + proc handleMplexListen(stream: Connection) {.async, gcsafe.} = + let msg = await stream.readLp() + check cast[string](msg) == &"stream {count} from dialer!" + await stream.writeLp(&"stream {count} from listener!") + count.inc + await stream.close() + + let mplexListen = newMplex(conn) + mplexListen.streamHandler = handleMplexListen + listenFut = mplexListen.handle() + listenFut.addCallback(proc(udata: pointer) {.gcsafe.} + = debug "completed listener") + + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) + + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) + + let mplexDial = newMplex(conn) + let dialFut = mplexDial.handle() + dialFut.addCallback(proc(udata: pointer = nil) {.gcsafe.} + = debug "completed dialer") + for i in 1..10: + let stream = await mplexDial.newStream("dialer stream") + await stream.writeLp(&"stream {i} from dialer!") + let msg = await stream.readLp() + check cast[string](msg) == &"stream {i} from listener!" + await stream.close() + + await conn.close() + listenFut.complete() + dialFut.complete() + result = true + + check: + waitFor(testNewStream()) == true + test "half closed - channel should close for write": proc testClosedForWrite(): Future[void] {.async.} = proc writeHandler(data: seq[byte]) {.async, gcsafe.} = discard