diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index eced193..3d312b4 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -40,25 +40,37 @@ proc newMultistream*(): MultisteamSelect = proc select*(m: MultisteamSelect, conn: Connection, - proto: string = ""): Future[bool] {.async.} = + proto: seq[string]): Future[bool] {.async.} = ## select a remote protocol - ## TODO: select should support a list of protos to be selected - await conn.write(m.codec) # write handshake if proto.len() > 0: - await conn.writeLp(proto) # select proto + await conn.writeLp(proto[0]) # select proto - var ms = cast[string](await conn.readLp()) + var ms = cast[string](await conn.readLp()) # read ms header ms.removeSuffix("\n") if ms != Codec: return false - if proto.len() <= 0: + if proto.len() == 0: # no protocols, must be a handshake call return true - ms = cast[string](await conn.readLp()) + ms = cast[string](await conn.readLp()) # read the first proto ms.removeSuffix("\n") - result = ms == proto + result = ms == proto[0] + + if not result: + for p in proto[1.. 0: m.select(conn, @[proto]) else: m.select(conn, @[]) proc list*(m: MultisteamSelect, conn: Connection): Future[seq[string]] {.async.} = diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 826ef72..5215a7f 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -1,4 +1,4 @@ -import unittest, strutils, sequtils, sugar +import unittest, strutils, sequtils, sugar, strformat import chronos import ../libp2p/connection, ../libp2p/multistream, ../libp2p/stream/lpstream, ../libp2p/connection, @@ -143,7 +143,7 @@ suite "Multistream select": proc testSelect(): Future[bool] {.async.} = let ms = newMultistream() let conn = newConnection(newTestSelectStream()) - result = await ms.select(conn, "/test/proto/1.0.0") + result = await ms.select(conn, @["/test/proto/1.0.0"]) check: waitFor(testSelect()) == true @@ -255,7 +255,7 @@ suite "Multistream select": let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(ma) - let res = await msDial.select(conn, "/test/proto/1.0.0") + let res = await msDial.select(conn, @["/test/proto/1.0.0"]) check res == true let hello = cast[string](await conn.readLp()) @@ -298,3 +298,81 @@ suite "Multistream select": check: waitFor(endToEnd()) == true + + test "e2e - select one of one invalid": + proc endToEnd(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53352") + + let seckey = PrivateKey.random(RSA) + var peerInfo: PeerInfo + peerInfo.peerId = PeerID.init(seckey) + var protocol: LPProtocol = new LPProtocol + proc testHandler(conn: Connection, + proto: string): + Future[void] {.async, gcsafe.} = + check proto == "/test/proto/1.0.0" + await conn.writeLp("Hello!") + await conn.close() + + protocol.handler = testHandler + let msListen = newMultistream() + msListen.addHandler("/test/proto/1.0.0", protocol) + + proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = + await msListen.handle(conn) + + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) + + let msDial = newMultistream() + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) + + let res = await msDial.select(conn, @["/test/proto/1.0.0", "/test/no/proto/1.0.0"]) + check res == true + + let hello = cast[string](await conn.readLp()) + result = hello == "Hello!" + await conn.close() + + check: + waitFor(endToEnd()) == true + + test "e2e - select one with both valid": + proc endToEnd(): Future[bool] {.async.} = + let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53353") + + let seckey = PrivateKey.random(RSA) + var peerInfo: PeerInfo + peerInfo.peerId = PeerID.init(seckey) + var protocol: LPProtocol = new LPProtocol + proc testHandler(conn: Connection, + proto: string): + Future[void] {.async, gcsafe.} = + await conn.writeLp(&"Hello from {proto}!") + await conn.close() + + protocol.handler = testHandler + let msListen = newMultistream() + msListen.addHandler("/test/proto1/1.0.0", protocol) + msListen.addHandler("/test/proto2/1.0.0", protocol) + + proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = + await msListen.handle(conn) + + let transport1: TcpTransport = newTransport(TcpTransport) + await transport1.listen(ma, connHandler) + + let msDial = newMultistream() + let transport2: TcpTransport = newTransport(TcpTransport) + let conn = await transport2.dial(ma) + + let res = await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"]) + check res == true + + let hello = cast[string](await conn.readLp()) + result = hello == "Hello from /test/proto2/1.0.0!" + await conn.close() + + check: + waitFor(endToEnd()) == true \ No newline at end of file