return proto string from select, instead of bool

This commit is contained in:
Dmitriy Ryajov 2019-09-04 14:15:55 -06:00
parent cc595f7947
commit 9889bd9cbf
2 changed files with 24 additions and 21 deletions

View File

@ -40,7 +40,8 @@ proc newMultistream*(): MultisteamSelect =
proc select*(m: MultisteamSelect, proc select*(m: MultisteamSelect,
conn: Connection, conn: Connection,
proto: seq[string]): Future[bool] {.async.} = proto: seq[string]):
Future[string] {.async.} =
## select a remote protocol ## select a remote protocol
await conn.write(m.codec) # write handshake await conn.write(m.codec) # write handshake
if proto.len() > 0: if proto.len() > 0:
@ -49,36 +50,41 @@ proc select*(m: MultisteamSelect,
var ms = cast[string](await conn.readLp()) # read ms header var ms = cast[string](await conn.readLp()) # read ms header
ms.removeSuffix("\n") ms.removeSuffix("\n")
if ms != Codec: if ms != Codec:
return false return ""
if proto.len() == 0: # no protocols, must be a handshake call if proto.len() == 0: # no protocols, must be a handshake call
return true return ""
ms = cast[string](await conn.readLp()) # read the first proto ms = cast[string](await conn.readLp()) # read the first proto
ms.removeSuffix("\n") ms.removeSuffix("\n")
result = ms == proto[0] if ms == proto[0]:
result = ms
if not result: if not result.len > 0:
for p in proto[1..<proto.len()]: for p in proto[1..<proto.len()]:
await conn.writeLp(p) # select proto await conn.writeLp(p) # select proto
ms = cast[string](await conn.readLp()) # read the first proto ms = cast[string](await conn.readLp()) # read the first proto
ms.removeSuffix("\n") ms.removeSuffix("\n")
result = ms == p if ms == p:
if result: result = p
break break
proc select*(m: MultisteamSelect, proc select*(m: MultisteamSelect,
conn: Connection, conn: Connection,
proto: string = ""): Future[bool] = proto: string): Future[string] =
result = if proto.len > 0: m.select(conn, @[proto]) else: m.select(conn, @[]) result = if proto.len > 0: m.select(conn, @[proto]) else: m.select(conn, @[])
proc select*(m: MultisteamSelect,
conn: Connection): Future[string] =
result = m.select(conn, @[])
proc list*(m: MultisteamSelect, proc list*(m: MultisteamSelect,
conn: Connection): Future[seq[string]] {.async.} = conn: Connection): Future[seq[string]] {.async.} =
## list remote protos requests on connection ## list remote protos requests on connection
if not (await m.select(conn)): if not (await m.select(conn)).len > 0:
return return
await conn.write(m.ls) # send ls await conn.write(m.ls) # send ls
var list = newSeq[string]() var list = newSeq[string]()
let ms = cast[string](await conn.readLp()) let ms = cast[string](await conn.readLp())
@ -90,7 +96,7 @@ proc list*(m: MultisteamSelect,
proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} = proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} =
## handle requests on connection ## handle requests on connection
if not (await m.select(conn)): if not (await m.select(conn)).len > 0:
return return
while not conn.closed: while not conn.closed:

View File

@ -143,7 +143,7 @@ suite "Multistream select":
proc testSelect(): Future[bool] {.async.} = proc testSelect(): Future[bool] {.async.} =
let ms = newMultistream() let ms = newMultistream()
let conn = newConnection(newTestSelectStream()) let conn = newConnection(newTestSelectStream())
result = await ms.select(conn, @["/test/proto/1.0.0"]) result = (await ms.select(conn, @["/test/proto/1.0.0"])) == "/test/proto/1.0.0"
check: check:
waitFor(testSelect()) == true waitFor(testSelect()) == true
@ -255,8 +255,7 @@ suite "Multistream select":
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(ma) let conn = await transport2.dial(ma)
let res = await msDial.select(conn, @["/test/proto/1.0.0"]) check (await msDial.select(conn, @["/test/proto/1.0.0"])) == "/test/proto/1.0.0"
check res == true
let hello = cast[string](await conn.readLp()) let hello = cast[string](await conn.readLp())
result = hello == "Hello!" result = hello == "Hello!"
@ -328,8 +327,8 @@ suite "Multistream select":
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(ma) let conn = await transport2.dial(ma)
let res = await msDial.select(conn, @["/test/proto/1.0.0", "/test/no/proto/1.0.0"]) check (await msDial.select(conn,
check res == true @["/test/proto/1.0.0", "/test/no/proto/1.0.0"])) == "/test/proto/1.0.0"
let hello = cast[string](await conn.readLp()) let hello = cast[string](await conn.readLp())
result = hello == "Hello!" result = hello == "Hello!"
@ -367,11 +366,9 @@ suite "Multistream select":
let transport2: TcpTransport = newTransport(TcpTransport) let transport2: TcpTransport = newTransport(TcpTransport)
let conn = await transport2.dial(ma) let conn = await transport2.dial(ma)
let res = await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"]) check (await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"])) == "/test/proto2/1.0.0"
check res == true
let hello = cast[string](await conn.readLp()) result = cast[string](await conn.readLp()) == "Hello from /test/proto2/1.0.0!"
result = hello == "Hello from /test/proto2/1.0.0!"
await conn.close() await conn.close()
check: check: