diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index 57499da9d..914e4ea3e 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -40,6 +40,12 @@ proc newMultistream*(): MultistreamSelect = new result result.codec = MSCodec +template validateSuffix(str: string): untyped = + if str.endsWith("\n"): + str.removeSuffix("\n") + else: + raise newException(CatchableError, "MultistreamSelect failed, malformed message") + proc select*(m: MultistreamSelect, conn: Connection, proto: seq[string]): @@ -52,17 +58,20 @@ proc select*(m: MultistreamSelect, await conn.writeLp((proto[0] & "\n")) # select proto var s = string.fromBytes((await conn.readLp(1024))) # read ms header - s.removeSuffix("\n") + validateSuffix(s) + if s != Codec: - notice "handshake failed", codec = s.toHex() - return "" + notice "handshake failed", codec = s + raise newException(CatchableError, "MultistreamSelect handshake failed") + else: + trace "multistream handshake success" if proto.len() == 0: # no protocols, must be a handshake call return Codec else: s = string.fromBytes(await conn.readLp(1024)) # read the first proto + validateSuffix(s) trace "reading first requested proto" - s.removeSuffix("\n") if s == proto[0]: trace "successfully selected ", proto = proto[0] return proto[0] @@ -74,7 +83,7 @@ proc select*(m: MultistreamSelect, trace "selecting proto", proto = p await conn.writeLp((p & "\n")) # select proto s = string.fromBytes(await conn.readLp(1024)) # read the first proto - s.removeSuffix("\n") + validateSuffix(s) if s == p: trace "selected protocol", protocol = s return s @@ -110,12 +119,18 @@ proc list*(m: MultistreamSelect, result = list -proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = +proc handle*(m: MultistreamSelect, conn: Connection, active: bool = false) {.async, gcsafe.} = trace "handle: starting multistream handling" + var handshaked = active try: while not conn.closed: var ms = string.fromBytes(await conn.readLp(1024)) - ms.removeSuffix("\n") + validateSuffix(ms) + + if not handshaked and ms != Codec: + error "expected handshake message", instead=ms + raise newException(CatchableError, + "MultistreamSelect handling failed, invalid first message") trace "handle: got request for ", ms if ms.len() <= 0: @@ -128,23 +143,27 @@ proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = continue case ms: - of "ls": - trace "handle: listing protos" - var protos = "" - for h in m.handlers: - protos &= (h.proto & "\n") - await conn.writeLp(protos) - of Codec: + of "ls": + trace "handle: listing protos" + var protos = "" + for h in m.handlers: + protos &= (h.proto & "\n") + await conn.writeLp(protos) + of Codec: + if not handshaked: await conn.write(m.codec) + handshaked = true else: - for h in m.handlers: - if (not isNil(h.match) and h.match(ms)) or ms == h.proto: - trace "found handler for", protocol = ms - await conn.writeLp((h.proto & "\n")) - await h.protocol.handler(conn, ms) - return - debug "no handlers for ", protocol = ms - await conn.write(Na) + await conn.write(Na) + else: + for h in m.handlers: + if (not isNil(h.match) and h.match(ms)) or ms == h.proto: + trace "found handler for", protocol = ms + await conn.writeLp((h.proto & "\n")) + await h.protocol.handler(conn, ms) + return + debug "no handlers for ", protocol = ms + await conn.write(Na) except CancelledError as exc: await conn.close() raise exc diff --git a/libp2p/switch.nim b/libp2p/switch.nim index ccd671d98..4e5c27c15 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -260,7 +260,7 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = for muxer in s.muxers.values: ms.addHandler(muxer.codec, muxer) - # handle subsequent requests + # handle subsequent secure requests await ms.handle(sconn) except CancelledError as exc: @@ -273,8 +273,9 @@ proc upgradeIncoming(s: Switch, conn: Connection) {.async, gcsafe.} = for k in s.secureManagers: ms.addHandler(k.codec, securedHandler) - # handle secured connections - await ms.handle(conn) + # handle un-secured connections + # we handshaked above, set this ms handler as active + await ms.handle(conn, active = true) proc internalConnect(s: Switch, peer: PeerInfo): Future[Connection] {.async.} =