mirror of https://github.com/vacp2p/nim-libp2p.git
Audit multistream fixes (#291)
* Don't ignore missing \n in multistream requests Also make sure to except and quit an existing connection if multistream handshake fails * solve handshake tracking in ms handler
This commit is contained in:
parent
f7fdf31365
commit
0f06ae5a1d
|
@ -40,6 +40,12 @@ proc newMultistream*(): MultistreamSelect =
|
||||||
new result
|
new result
|
||||||
result.codec = MSCodec
|
result.codec = MSCodec
|
||||||
|
|
||||||
|
template validateSuffix(str: string): untyped =
|
||||||
|
if str.endsWith("\n"):
|
||||||
|
str.removeSuffix("\n")
|
||||||
|
else:
|
||||||
|
raise newException(CatchableError, "MultistreamSelect failed, malformed message")
|
||||||
|
|
||||||
proc select*(m: MultistreamSelect,
|
proc select*(m: MultistreamSelect,
|
||||||
conn: Connection,
|
conn: Connection,
|
||||||
proto: seq[string]):
|
proto: seq[string]):
|
||||||
|
@ -52,17 +58,20 @@ proc select*(m: MultistreamSelect,
|
||||||
await conn.writeLp((proto[0] & "\n")) # select proto
|
await conn.writeLp((proto[0] & "\n")) # select proto
|
||||||
|
|
||||||
var s = string.fromBytes((await conn.readLp(1024))) # read ms header
|
var s = string.fromBytes((await conn.readLp(1024))) # read ms header
|
||||||
s.removeSuffix("\n")
|
validateSuffix(s)
|
||||||
|
|
||||||
if s != Codec:
|
if s != Codec:
|
||||||
notice "handshake failed", codec = s.toHex()
|
notice "handshake failed", codec = s
|
||||||
return ""
|
raise newException(CatchableError, "MultistreamSelect handshake failed")
|
||||||
|
else:
|
||||||
|
trace "multistream handshake success"
|
||||||
|
|
||||||
if proto.len() == 0: # no protocols, must be a handshake call
|
if proto.len() == 0: # no protocols, must be a handshake call
|
||||||
return Codec
|
return Codec
|
||||||
else:
|
else:
|
||||||
s = string.fromBytes(await conn.readLp(1024)) # read the first proto
|
s = string.fromBytes(await conn.readLp(1024)) # read the first proto
|
||||||
|
validateSuffix(s)
|
||||||
trace "reading first requested proto"
|
trace "reading first requested proto"
|
||||||
s.removeSuffix("\n")
|
|
||||||
if s == proto[0]:
|
if s == proto[0]:
|
||||||
trace "successfully selected ", proto = proto[0]
|
trace "successfully selected ", proto = proto[0]
|
||||||
return proto[0]
|
return proto[0]
|
||||||
|
@ -74,7 +83,7 @@ proc select*(m: MultistreamSelect,
|
||||||
trace "selecting proto", proto = p
|
trace "selecting proto", proto = p
|
||||||
await conn.writeLp((p & "\n")) # select proto
|
await conn.writeLp((p & "\n")) # select proto
|
||||||
s = string.fromBytes(await conn.readLp(1024)) # read the first proto
|
s = string.fromBytes(await conn.readLp(1024)) # read the first proto
|
||||||
s.removeSuffix("\n")
|
validateSuffix(s)
|
||||||
if s == p:
|
if s == p:
|
||||||
trace "selected protocol", protocol = s
|
trace "selected protocol", protocol = s
|
||||||
return s
|
return s
|
||||||
|
@ -110,12 +119,18 @@ proc list*(m: MultistreamSelect,
|
||||||
|
|
||||||
result = list
|
result = list
|
||||||
|
|
||||||
proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} =
|
proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.async, gcsafe.} =
|
||||||
trace "handle: starting multistream handling"
|
trace "handle: starting multistream handling"
|
||||||
|
var handshaked = active
|
||||||
try:
|
try:
|
||||||
while not conn.closed:
|
while not conn.closed:
|
||||||
var ms = string.fromBytes(await conn.readLp(1024))
|
var ms = string.fromBytes(await conn.readLp(1024))
|
||||||
ms.removeSuffix("\n")
|
validateSuffix(ms)
|
||||||
|
|
||||||
|
if not handshaked and ms != Codec:
|
||||||
|
error "expected handshake message", instead=ms
|
||||||
|
raise newException(CatchableError,
|
||||||
|
"MultistreamSelect handling failed, invalid first message")
|
||||||
|
|
||||||
trace "handle: got request for ", ms
|
trace "handle: got request for ", ms
|
||||||
if ms.len() <= 0:
|
if ms.len() <= 0:
|
||||||
|
@ -128,23 +143,27 @@ proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} =
|
||||||
continue
|
continue
|
||||||
|
|
||||||
case ms:
|
case ms:
|
||||||
of "ls":
|
of "ls":
|
||||||
trace "handle: listing protos"
|
trace "handle: listing protos"
|
||||||
var protos = ""
|
var protos = ""
|
||||||
for h in m.handlers:
|
for h in m.handlers:
|
||||||
protos &= (h.proto & "\n")
|
protos &= (h.proto & "\n")
|
||||||
await conn.writeLp(protos)
|
await conn.writeLp(protos)
|
||||||
of Codec:
|
of Codec:
|
||||||
|
if not handshaked:
|
||||||
await conn.write(m.codec)
|
await conn.write(m.codec)
|
||||||
|
handshaked = true
|
||||||
else:
|
else:
|
||||||
for h in m.handlers:
|
|
||||||
if (not isNil(h.match) and h.match(ms)) or ms == h.proto:
|
|
||||||
trace "found handler for", protocol = ms
|
|
||||||
await conn.writeLp((h.proto & "\n"))
|
|
||||||
await h.protocol.handler(conn, ms)
|
|
||||||
return
|
|
||||||
debug "no handlers for ", protocol = ms
|
|
||||||
await conn.write(Na)
|
await conn.write(Na)
|
||||||
|
else:
|
||||||
|
for h in m.handlers:
|
||||||
|
if (not isNil(h.match) and h.match(ms)) or ms == h.proto:
|
||||||
|
trace "found handler for", protocol = ms
|
||||||
|
await conn.writeLp((h.proto & "\n"))
|
||||||
|
await h.protocol.handler(conn, ms)
|
||||||
|
return
|
||||||
|
debug "no handlers for ", protocol = ms
|
||||||
|
await conn.write(Na)
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
await conn.close()
|
await conn.close()
|
||||||
raise exc
|
raise exc
|
||||||
|
|
|
@ -260,7 +260,7 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
||||||
for muxer in s.muxers.values:
|
for muxer in s.muxers.values:
|
||||||
ms.addHandler(muxer.codec, muxer)
|
ms.addHandler(muxer.codec, muxer)
|
||||||
|
|
||||||
# handle subsequent requests
|
# handle subsequent secure requests
|
||||||
await ms.handle(sconn)
|
await ms.handle(sconn)
|
||||||
|
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
|
@ -273,8 +273,9 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} =
|
||||||
for k in s.secureManagers:
|
for k in s.secureManagers:
|
||||||
ms.addHandler(k.codec, securedHandler)
|
ms.addHandler(k.codec, securedHandler)
|
||||||
|
|
||||||
# handle secured connections
|
# handle un-secured connections
|
||||||
await ms.handle(conn)
|
# we handshaked above, set this ms handler as active
|
||||||
|
await ms.handle(conn, active = true)
|
||||||
|
|
||||||
proc internalConnect(s: Switch,
|
proc internalConnect(s: Switch,
|
||||||
peer: PeerInfo): Future[Connection] {.async.} =
|
peer: PeerInfo): Future[Connection] {.async.} =
|
||||||
|
|
Loading…
Reference in New Issue