multistream select make sure to not report NA (#235)

* multistream select make sure to not report NA but rather empty string if all fails

Also re-enable tests

* avoid using bad constructs, make multistream.select flow crystal clear
This commit is contained in:
Giovanni Petrantoni 2020-06-23 06:38:48 +09:00 committed by GitHub
parent 6331b04cb4
commit ee6e545878
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 29 deletions

View File

@ -58,40 +58,45 @@ proc select*(m: MultistreamSelect,
trace "selecting proto", proto = proto[0] trace "selecting proto", proto = proto[0]
await conn.writeLp((proto[0] & "\n")) # select proto await conn.writeLp((proto[0] & "\n")) # select proto
result = string.fromBytes((await conn.readLp(1024))) # read ms header var s = string.fromBytes((await conn.readLp(1024))) # read ms header
result.removeSuffix("\n") s.removeSuffix("\n")
if result != Codec: if s != Codec:
notice "handshake failed", codec = result.toHex() notice "handshake failed", codec = s.toHex()
raise newMultistreamHandshakeException() raise newMultistreamHandshakeException()
if proto.len() == 0: # no protocols, must be a handshake call if proto.len() == 0: # no protocols, must be a handshake call
return return Codec
else:
result = string.fromBytes(await conn.readLp(1024)) # read the first proto s = string.fromBytes(await conn.readLp(1024)) # read the first proto
trace "reading first requested proto" trace "reading first requested proto"
result.removeSuffix("\n") s.removeSuffix("\n")
if result == proto[0]: if s == proto[0]:
trace "successfully selected ", proto = proto[0] trace "successfully selected ", proto = proto[0]
return return proto[0]
elif proto.len > 1:
# Try to negotiate alternatives
let protos = proto[1..<proto.len()] let protos = proto[1..<proto.len()]
trace "selecting one of several protos", protos = protos trace "selecting one of several protos", protos = protos
for p in protos: for p in protos:
trace "selecting proto", proto = p trace "selecting proto", proto = p
await conn.writeLp((p & "\n")) # select proto await conn.writeLp((p & "\n")) # select proto
result = string.fromBytes(await conn.readLp(1024)) # read the first proto s = string.fromBytes(await conn.readLp(1024)) # read the first proto
result.removeSuffix("\n") s.removeSuffix("\n")
if result == p: if s == p:
trace "selected protocol", protocol = result trace "selected protocol", protocol = s
break return s
return ""
else:
# No alternatives, fail
return ""
proc select*(m: MultistreamSelect, proc select*(m: MultistreamSelect,
conn: Connection, conn: Connection,
proto: string): Future[bool] {.async.} = proto: string): Future[bool] {.async.} =
if proto.len > 0: if proto.len > 0:
result = (await m.select(conn, @[proto])) == proto return (await m.select(conn, @[proto])) == proto
else: else:
result = (await m.select(conn, @[])) == Codec return (await m.select(conn, @[])) == Codec
proc select*(m: MultistreamSelect, conn: Connection): Future[bool] = proc select*(m: MultistreamSelect, conn: Connection): Future[bool] =
m.select(conn, "") m.select(conn, "")

View File

@ -63,7 +63,7 @@ proc secure(s: Switch, conn: Connection): Future[Connection] {.async, gcsafe.} =
raise newException(CatchableError, "No secure managers registered!") raise newException(CatchableError, "No secure managers registered!")
let manager = await s.ms.select(conn, s.secureManagers.mapIt(it.codec)) let manager = await s.ms.select(conn, s.secureManagers.mapIt(it.codec))
if manager.len == 0 or manager == "na": if manager.len == 0:
raise newException(CatchableError, "Unable to negotiate a secure channel!") raise newException(CatchableError, "Unable to negotiate a secure channel!")
trace "securing connection", codec=manager trace "securing connection", codec=manager

View File

@ -57,7 +57,7 @@ proc newTestSelectStream(): TestSelectStream =
type type
LsHandler = proc(procs: seq[byte]): Future[void] {.gcsafe.} LsHandler = proc(procs: seq[byte]): Future[void] {.gcsafe.}
TestLsStream = ref object of LPStream TestLsStream = ref object of Connection
step*: int step*: int
ls*: LsHandler ls*: LsHandler
@ -103,7 +103,7 @@ proc newTestLsStream(ls: LsHandler): TestLsStream {.gcsafe.} =
type type
NaHandler = proc(procs: string): Future[void] {.gcsafe.} NaHandler = proc(procs: string): Future[void] {.gcsafe.}
TestNaStream = ref object of LPStream TestNaStream = ref object of Connection
step*: int step*: int
na*: NaHandler na*: NaHandler

View File

@ -14,7 +14,7 @@ import testmultibase,
testpeer testpeer
import testtransport, import testtransport,
# testmultistream, testmultistream,
testbufferstream, testbufferstream,
testidentify, testidentify,
testswitch, testswitch,