diff --git a/libp2p/muxers/mplex/coder.nim b/libp2p/muxers/mplex/coder.nim index a0a240131..bd15ef796 100644 --- a/libp2p/muxers/mplex/coder.nim +++ b/libp2p/muxers/mplex/coder.nim @@ -26,6 +26,11 @@ type msgType: MessageType data: seq[byte] + InvalidMplexMsgType = object of CatchableError + +proc newInvalidMplexMsgType*(): ref InvalidMplexMsgType = + newException(InvalidMplexMsgType, "invalid message type") + proc readMplexVarint(conn: Connection): Future[uint64] {.async, gcsafe.} = var varint: uint @@ -41,27 +46,31 @@ proc readMplexVarint(conn: Connection): Future[uint64] {.async, gcsafe.} = break if res != VarintStatus.Success: raise newInvalidVarintException() - if varint.int > DefaultReadSize: - raise newInvalidVarintSizeException() return varint except LPStreamIncompleteError as exc: trace "unable to read varint", exc = exc.msg raise exc proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} = - let headerVarint = await conn.readMplexVarint() - trace "read header varint", varint = headerVarint + let header = await conn.readMplexVarint() + trace "read header varint", varint = header let dataLenVarint = await conn.readMplexVarint() trace "read data len varint", varint = dataLenVarint + if dataLenVarint.int > DefaultReadSize: + raise newInvalidVarintSizeException() + var data: seq[byte] = newSeq[byte](dataLenVarint.int) if dataLenVarint.int > 0: await conn.readExactly(addr data[0], dataLenVarint.int) trace "read data", data = data.len - let header = headerVarint - result = (uint64(header shr 3), MessageType(header and 0x7), data) + let msgType = header and 0x7 + if msgType.int > ord(MessageType.ResetOut): + raise newInvalidMplexMsgType() + + result = (uint64(header shr 3), MessageType(msgType), data) proc writeMsg*(conn: Connection, id: uint64, diff --git a/libp2p/muxers/mplex/lpchannel.nim b/libp2p/muxers/mplex/lpchannel.nim index 6ffed67d8..6e1b2a19f 100644 --- a/libp2p/muxers/mplex/lpchannel.nim +++ b/libp2p/muxers/mplex/lpchannel.nim @@ -158,11 +158,12 @@ method readUntil*(s: LPChannel, await s.tryCleanup() template writePrefix: untyped = - if s.isLazy and not s.isOpen: - await s.open() if s.closedLocal or s.isReset: raise newLPStreamEOFError() + if s.isLazy and not s.isOpen: + await s.open() + method write*(s: LPChannel, pbytes: pointer, nbytes: int) {.async.} = writePrefix() await procCall write(BufferStream(s), pbytes, nbytes) diff --git a/libp2p/muxers/mplex/mplex.nim b/libp2p/muxers/mplex/mplex.nim index 5e758ce8a..54ba181d7 100644 --- a/libp2p/muxers/mplex/mplex.nim +++ b/libp2p/muxers/mplex/mplex.nim @@ -11,7 +11,7 @@ ## Timeouts and message limits are still missing ## they need to be added ASAP -import tables, sequtils, options +import tables, sequtils import chronos, chronicles import ../muxer, ../../connection, diff --git a/libp2p/switch.nim b/libp2p/switch.nim index d6715d0fe..f4a7bbdce 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -237,7 +237,9 @@ proc internalConnect(s: Switch, else: trace "Reusing existing connection" - await s.subscribeToPeer(peer) + if not isNil(conn): + await s.subscribeToPeer(peer) + result = conn proc connect*(s: Switch, peer: PeerInfo) {.async.} = @@ -323,6 +325,7 @@ proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} = await s.pubSub.get().subscribeToPeer(conn) except CatchableError as exc: warn "unable to initiate pubsub", exc = exc.msg + finally: s.dialedPubSubPeers.excl(peerInfo.id) proc subscribe*(s: Switch, topic: string, diff --git a/tests/testinterop.nim b/tests/testinterop.nim index 1464c4bd7..f519f6397 100644 --- a/tests/testinterop.nim +++ b/tests/testinterop.nim @@ -390,6 +390,7 @@ suite "Interop": inc(count2) result = 10 == (await wait(testFuture, 10.secs)) + await stream.close() await nativeNode.stop() await allFutures(awaiters) await daemonNode.close()