diff --git a/libp2p/multistreamselect.nim b/libp2p/multistreamselect.nim index 0d2ebf7..46127c5 100644 --- a/libp2p/multistreamselect.nim +++ b/libp2p/multistreamselect.nim @@ -45,12 +45,13 @@ proc select*(m: MultisteamSelect, conn: Connection, proto: string = ""): Future[ ## TODO: select should support a list of protos to be selected await conn.write(m.codec) # write handshake - await conn.writeLp(proto) # select proto + if proto.len() > 0: + await conn.writeLp(proto) # select proto + var ms = cast[string](await conn.readLp()) ms.removeSuffix("\n") if ms != Codec: - raise newException(MultisteamSelectException, - "Error: invalid multistream codec " & "\"" & ms & "\"") + return false if proto.len() <= 0: return true @@ -59,16 +60,32 @@ proc select*(m: MultisteamSelect, conn: Connection, proto: string = ""): Future[ ms.removeSuffix("\n") result = ms == proto +proc list*(m: MultisteamSelect, conn: Connection): Future[seq[string]] {.async.} = + ## list remote protos requests on connection + if not (await m.select(conn)): + return + + var list = newSeq[string]() + let ms = cast[string](await conn.readLp()) + for s in ms.split("\n"): + list.add(s) + + result = list + proc handle*(m: MultisteamSelect, conn: Connection) {.async.} = ## handle requests on connection if not (await m.select(conn)): return - + while not conn.closed: var ms = cast[string](await conn.readLp()) ms.removeSuffix("\n") if ms.len() <= 0: - await conn.writeLp(Na) + await conn.write(m.na) + + if m.handlers.len() == 0: + await conn.write(m.na) + continue case ms: of "ls": @@ -80,7 +97,7 @@ proc handle*(m: MultisteamSelect, conn: Connection) {.async.} = await h.handler(conn, ms) return else: - await conn.write(Na) + await conn.write(m.na) proc addHandler*(m: MultisteamSelect, proto: string, diff --git a/tests/testmultistreamselect.nim b/tests/testmultistreamselect.nim index f8f991f..c389ee1 100644 --- a/tests/testmultistreamselect.nim +++ b/tests/testmultistreamselect.nim @@ -65,13 +65,13 @@ proc newTestHandlesStream(): TestHandlesStream = new result result.step = 1 -## Mock stream for handles test +## Mock stream for handles `ls` test type LsHandler = proc(procs: seq[byte]): Future[void] TestLsStream = ref object of ReadWrite step*: int - ls*: proc(procs: seq[byte]): Future[void] + ls*: LsHandler method readExactly*(s: TestLsStream, pbytes: pointer, nbytes: int): Future[void] {.async.} = case s.step: @@ -104,6 +104,45 @@ proc newTestLsStream(ls: LsHandler): TestLsStream = result.ls = ls result.step = 1 +## Mock stream for handles `na` test +type + NaHandler = proc(procs: string): Future[void] + + TestNaStream = ref object of ReadWrite + step*: int + na*: NaHandler + +method readExactly*(s: TestNaStream, pbytes: pointer, nbytes: int): Future[void] {.async.} = + case s.step: + of 1: + var buf = newSeq[byte](1) + buf[0] = 19 + copyMem(cast[pointer](cast[uint](pbytes)), addr buf[0], buf.len()) + s.step = 2 + of 2: + var buf = "/multistream/1.0.0\n" + copyMem(cast[pointer](cast[uint](pbytes)), addr buf[0], buf.len()) + s.step = 3 + of 3: + var buf = newSeq[byte](1) + buf[0] = 18 + copyMem(cast[pointer](cast[uint](pbytes)), addr buf[0], buf.len()) + s.step = 4 + of 4: + var buf = "/test/proto/1.0.0\n" + copyMem(cast[pointer](cast[uint](pbytes)), addr buf[0], buf.len()) + else: + copyMem(cast[pointer](cast[uint](pbytes)), cstring("\0x3na\n"), "\0x3na\n".len()) + +method write*(s: TestNaStream, msg: string, msglen = -1) {.async.} = + if s.step == 4: + await s.na(msg) + +proc newTestNaStream(na: NaHandler): TestNaStream = + new result + result.na = na + result.step = 1 + suite "Multistream select": test "test select custom proto": proc testSelect(): Future[bool] {.async.} = @@ -150,3 +189,23 @@ suite "Multistream select": check: waitFor(testLs()) == true + + test "test handle `na`": + proc testNa(): Future[bool] {.async.} = + let ms = newMultistream() + + proc testNa(msg: string): Future[void] {.async.} + let conn = newConnection(newTestNaStream(testNa)) + + proc testNa(msg: string): Future[void] {.async.} = + check cast[string](msg) == "\x3na\n" + await conn.close() + + proc testHandler(conn: Connection, proto: string): Future[void] {.async.} = discard + ms.addHandler("/unabvailable/proto/1.0.0", testHandler) + + await ms.handle(conn) + result = true + + check: + waitFor(testNa()) == true