* only check for payload size

* only subscribe if connection succeeded

* fix failing test

* check that the strem is active before openning

* msg type should not be > than 0x7

* fix tests

* check max against enum val
This commit is contained in:
Dmitriy Ryajov 2020-03-29 08:28:48 -06:00 committed by GitHub
parent 6bb4e91a39
commit 5285f0d091
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 24 additions and 10 deletions

View File

@ -26,6 +26,11 @@ type
msgType: MessageType msgType: MessageType
data: seq[byte] data: seq[byte]
InvalidMplexMsgType = object of CatchableError
proc newInvalidMplexMsgType*(): ref InvalidMplexMsgType =
newException(InvalidMplexMsgType, "invalid message type")
proc readMplexVarint(conn: Connection): Future[uint64] {.async, gcsafe.} = proc readMplexVarint(conn: Connection): Future[uint64] {.async, gcsafe.} =
var var
varint: uint varint: uint
@ -41,27 +46,31 @@ proc readMplexVarint(conn: Connection): Future[uint64] {.async, gcsafe.} =
break break
if res != VarintStatus.Success: if res != VarintStatus.Success:
raise newInvalidVarintException() raise newInvalidVarintException()
if varint.int > DefaultReadSize:
raise newInvalidVarintSizeException()
return varint return varint
except LPStreamIncompleteError as exc: except LPStreamIncompleteError as exc:
trace "unable to read varint", exc = exc.msg trace "unable to read varint", exc = exc.msg
raise exc raise exc
proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} = proc readMsg*(conn: Connection): Future[Msg] {.async, gcsafe.} =
let headerVarint = await conn.readMplexVarint() let header = await conn.readMplexVarint()
trace "read header varint", varint = headerVarint trace "read header varint", varint = header
let dataLenVarint = await conn.readMplexVarint() let dataLenVarint = await conn.readMplexVarint()
trace "read data len varint", varint = dataLenVarint trace "read data len varint", varint = dataLenVarint
if dataLenVarint.int > DefaultReadSize:
raise newInvalidVarintSizeException()
var data: seq[byte] = newSeq[byte](dataLenVarint.int) var data: seq[byte] = newSeq[byte](dataLenVarint.int)
if dataLenVarint.int > 0: if dataLenVarint.int > 0:
await conn.readExactly(addr data[0], dataLenVarint.int) await conn.readExactly(addr data[0], dataLenVarint.int)
trace "read data", data = data.len trace "read data", data = data.len
let header = headerVarint let msgType = header and 0x7
result = (uint64(header shr 3), MessageType(header and 0x7), data) if msgType.int > ord(MessageType.ResetOut):
raise newInvalidMplexMsgType()
result = (uint64(header shr 3), MessageType(msgType), data)
proc writeMsg*(conn: Connection, proc writeMsg*(conn: Connection,
id: uint64, id: uint64,

View File

@ -158,11 +158,12 @@ method readUntil*(s: LPChannel,
await s.tryCleanup() await s.tryCleanup()
template writePrefix: untyped = template writePrefix: untyped =
if s.isLazy and not s.isOpen:
await s.open()
if s.closedLocal or s.isReset: if s.closedLocal or s.isReset:
raise newLPStreamEOFError() raise newLPStreamEOFError()
if s.isLazy and not s.isOpen:
await s.open()
method write*(s: LPChannel, pbytes: pointer, nbytes: int) {.async.} = method write*(s: LPChannel, pbytes: pointer, nbytes: int) {.async.} =
writePrefix() writePrefix()
await procCall write(BufferStream(s), pbytes, nbytes) await procCall write(BufferStream(s), pbytes, nbytes)

View File

@ -11,7 +11,7 @@
## Timeouts and message limits are still missing ## Timeouts and message limits are still missing
## they need to be added ASAP ## they need to be added ASAP
import tables, sequtils, options import tables, sequtils
import chronos, chronicles import chronos, chronicles
import ../muxer, import ../muxer,
../../connection, ../../connection,

View File

@ -237,7 +237,9 @@ proc internalConnect(s: Switch,
else: else:
trace "Reusing existing connection" trace "Reusing existing connection"
await s.subscribeToPeer(peer) if not isNil(conn):
await s.subscribeToPeer(peer)
result = conn result = conn
proc connect*(s: Switch, peer: PeerInfo) {.async.} = proc connect*(s: Switch, peer: PeerInfo) {.async.} =
@ -323,6 +325,7 @@ proc subscribeToPeer(s: Switch, peerInfo: PeerInfo) {.async, gcsafe.} =
await s.pubSub.get().subscribeToPeer(conn) await s.pubSub.get().subscribeToPeer(conn)
except CatchableError as exc: except CatchableError as exc:
warn "unable to initiate pubsub", exc = exc.msg warn "unable to initiate pubsub", exc = exc.msg
finally:
s.dialedPubSubPeers.excl(peerInfo.id) s.dialedPubSubPeers.excl(peerInfo.id)
proc subscribe*(s: Switch, topic: string, proc subscribe*(s: Switch, topic: string,

View File

@ -390,6 +390,7 @@ suite "Interop":
inc(count2) inc(count2)
result = 10 == (await wait(testFuture, 10.secs)) result = 10 == (await wait(testFuture, 10.secs))
await stream.close()
await nativeNode.stop() await nativeNode.stop()
await allFutures(awaiters) await allFutures(awaiters)
await daemonNode.close() await daemonNode.close()