Audit multistream fixes (#291)

* Don't ignore missing \n in multistream requests

Also make sure to except and quit an existing connection if multistream handshake fails

* solve handshake tracking in ms handler
This commit is contained in:
Giovanni Petrantoni 2020-07-28 23:03:22 +09:00 committed by GitHub
parent f7fdf31365
commit 0f06ae5a1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 25 deletions

View File

@ -40,6 +40,12 @@ proc newMultistream*(): MultistreamSelect =
new result new result
result.codec = MSCodec result.codec = MSCodec
template validateSuffix(str: string): untyped =
if str.endsWith("\n"):
str.removeSuffix("\n")
else:
raise newException(CatchableError, "MultistreamSelect failed, malformed message")
proc select*(m: MultistreamSelect, proc select*(m: MultistreamSelect,
conn: Connection, conn: Connection,
proto: seq[string]): proto: seq[string]):
@ -52,17 +58,20 @@ proc select*(m: MultistreamSelect,
await conn.writeLp((proto[0] & "\n")) # select proto await conn.writeLp((proto[0] & "\n")) # select proto
var s = string.fromBytes((await conn.readLp(1024))) # read ms header var s = string.fromBytes((await conn.readLp(1024))) # read ms header
s.removeSuffix("\n") validateSuffix(s)
if s != Codec: if s != Codec:
notice "handshake failed", codec = s.toHex() notice "handshake failed", codec = s
return "" raise newException(CatchableError, "MultistreamSelect handshake failed")
else:
trace "multistream handshake success"
if proto.len() == 0: # no protocols, must be a handshake call if proto.len() == 0: # no protocols, must be a handshake call
return Codec return Codec
else: else:
s = string.fromBytes(await conn.readLp(1024)) # read the first proto s = string.fromBytes(await conn.readLp(1024)) # read the first proto
validateSuffix(s)
trace "reading first requested proto" trace "reading first requested proto"
s.removeSuffix("\n")
if s == proto[0]: if s == proto[0]:
trace "successfully selected ", proto = proto[0] trace "successfully selected ", proto = proto[0]
return proto[0] return proto[0]
@ -74,7 +83,7 @@ proc select*(m: MultistreamSelect,
trace "selecting proto", proto = p trace "selecting proto", proto = p
await conn.writeLp((p & "\n")) # select proto await conn.writeLp((p & "\n")) # select proto
s = string.fromBytes(await conn.readLp(1024)) # read the first proto s = string.fromBytes(await conn.readLp(1024)) # read the first proto
s.removeSuffix("\n") validateSuffix(s)
if s == p: if s == p:
trace "selected protocol", protocol = s trace "selected protocol", protocol = s
return s return s
@ -110,12 +119,18 @@ proc list*(m: MultistreamSelect,
result = list result = list
proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.async, gcsafe.} =
trace "handle: starting multistream handling" trace "handle: starting multistream handling"
var handshaked = active
try: try:
while not conn.closed: while not conn.closed:
var ms = string.fromBytes(await conn.readLp(1024)) var ms = string.fromBytes(await conn.readLp(1024))
ms.removeSuffix("\n") validateSuffix(ms)
if not handshaked and ms != Codec:
error "expected handshake message", instead=ms
raise newException(CatchableError,
"MultistreamSelect handling failed, invalid first message")
trace "handle: got request for ", ms trace "handle: got request for ", ms
if ms.len() <= 0: if ms.len() <= 0:
@ -128,23 +143,27 @@ proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} =
continue continue
case ms: case ms:
of "ls": of "ls":
trace "handle: listing protos" trace "handle: listing protos"
var protos = "" var protos = ""
for h in m.handlers: for h in m.handlers:
protos &= (h.proto & "\n") protos &= (h.proto & "\n")
await conn.writeLp(protos) await conn.writeLp(protos)
of Codec: of Codec:
if not handshaked:
await conn.write(m.codec) await conn.write(m.codec)
handshaked = true
else: else:
for h in m.handlers: await conn.write(Na)
if (not isNil(h.match) and h.match(ms)) or ms == h.proto: else:
trace "found handler for", protocol = ms for h in m.handlers:
await conn.writeLp((h.proto & "\n")) if (not isNil(h.match) and h.match(ms)) or ms == h.proto:
await h.protocol.handler(conn, ms) trace "found handler for", protocol = ms
return await conn.writeLp((h.proto & "\n"))
debug "no handlers for ", protocol = ms await h.protocol.handler(conn, ms)
await conn.write(Na) return
debug "no handlers for ", protocol = ms
await conn.write(Na)
except CancelledError as exc: except CancelledError as exc:
await conn.close() await conn.close()
raise exc raise exc

View File

@ -260,7 +260,7 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
for muxer in s.muxers.values: for muxer in s.muxers.values:
ms.addHandler(muxer.codec, muxer) ms.addHandler(muxer.codec, muxer)
# handle subsequent requests # handle subsequent secure requests
await ms.handle(sconn) await ms.handle(sconn)
except CancelledError as exc: except CancelledError as exc:
@ -273,8 +273,9 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
for k in s.secureManagers: for k in s.secureManagers:
ms.addHandler(k.codec, securedHandler) ms.addHandler(k.codec, securedHandler)
# handle secured connections # handle un-secured connections
await ms.handle(conn) # we handshaked above, set this ms handler as active
await ms.handle(conn, active = true)
proc internalConnect(s: Switch, proc internalConnect(s: Switch,
peer: PeerInfo): Future[Connection] {.async.} = peer: PeerInfo): Future[Connection] {.async.} =