diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index 3d312b4d3..92fd8838b 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -40,7 +40,8 @@ proc newMultistream*(): MultisteamSelect = proc select*(m: MultisteamSelect, conn: Connection, - proto: seq[string]): Future[bool] {.async.} = + proto: seq[string]): + Future[string] {.async.} = ## select a remote protocol await conn.write(m.codec) # write handshake if proto.len() > 0: @@ -49,36 +50,41 @@ proc select*(m: MultisteamSelect, var ms = cast[string](await conn.readLp()) # read ms header ms.removeSuffix("\n") if ms != Codec: - return false + return "" if proto.len() == 0: # no protocols, must be a handshake call - return true + return "" ms = cast[string](await conn.readLp()) # read the first proto ms.removeSuffix("\n") - result = ms == proto[0] + if ms == proto[0]: + result = ms - if not result: + if not result.len > 0: for p in proto[1.. 0: m.select(conn, @[proto]) else: m.select(conn, @[]) + proto: string): Future[string] = + result = if proto.len > 0: m.select(conn, @[proto]) else: m.select(conn, @[]) + +proc select*(m: MultisteamSelect, + conn: Connection): Future[string] = + result = m.select(conn, @[]) proc list*(m: MultisteamSelect, conn: Connection): Future[seq[string]] {.async.} = ## list remote protos requests on connection - if not (await m.select(conn)): + if not (await m.select(conn)).len > 0: return - await conn.write(m.ls) # send ls + await conn.write(m.ls) # send ls var list = newSeq[string]() let ms = cast[string](await conn.readLp()) @@ -90,7 +96,7 @@ proc list*(m: MultisteamSelect, proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} = ## handle requests on connection - if not (await m.select(conn)): + if not (await m.select(conn)).len > 0: return while not conn.closed: diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 5215a7f59..fecb16acb 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -143,7 +143,7 @@ suite "Multistream select": proc testSelect(): Future[bool] {.async.} = let ms = newMultistream() let conn = newConnection(newTestSelectStream()) - result = await ms.select(conn, @["/test/proto/1.0.0"]) + result = (await ms.select(conn, @["/test/proto/1.0.0"])) == "/test/proto/1.0.0" check: waitFor(testSelect()) == true @@ -255,8 +255,7 @@ suite "Multistream select": let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) - let res = await msDial.select(conn, @["/test/proto/1.0.0"]) - check res == true + check (await msDial.select(conn, @["/test/proto/1.0.0"])) == "/test/proto/1.0.0" let hello = cast[string](await conn.readLp()) result = hello == "Hello!" @@ -328,8 +327,8 @@ suite "Multistream select": let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) - let res = await msDial.select(conn, @["/test/proto/1.0.0", "/test/no/proto/1.0.0"]) - check res == true + check (await msDial.select(conn, + @["/test/proto/1.0.0", "/test/no/proto/1.0.0"])) == "/test/proto/1.0.0" let hello = cast[string](await conn.readLp()) result = hello == "Hello!" @@ -367,11 +366,9 @@ suite "Multistream select": let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) - let res = await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"]) - check res == true + check (await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"])) == "/test/proto2/1.0.0" - let hello = cast[string](await conn.readLp()) - result = hello == "Hello from /test/proto2/1.0.0!" + result = cast[string](await conn.readLp()) == "Hello from /test/proto2/1.0.0!" await conn.close() check: