diff --git a/libp2p/multistream.nim b/libp2p/multistream.nim index 991e56bbf..8c2c308d6 100644 --- a/libp2p/multistream.nim +++ b/libp2p/multistream.nim @@ -15,7 +15,8 @@ import vbuffer, stream, pushable, asynciters, - lenprefixed] + lenprefixed, + utils] logScope: topic = "Multistream" @@ -48,25 +49,6 @@ proc newMultistreamHandshakeException*(): ref Exception {.inline.} = result = newException(MultistreamHandshakeException, "could not perform multistream handshake") -var appendNl: Through[seq[byte]] = proc (i: Source[seq[byte]]): Source[seq[byte]] {.gcsafe.} = - proc append(item: Future[seq[byte]]): Future[seq[byte]] {.async.} = - result = await item - result.add(byte('\n')) - - return iterator(): Future[seq[byte]] {.closure.} = - for item in i: - yield append(item) - -var stripNl: Through[seq[byte]] = proc (i: Source[seq[byte]]): Source[seq[byte]] {.gcsafe.} = - proc strip(item: Future[seq[byte]]): Future[seq[byte]] {.async.} = - result = await item - if result.len > 0 and result[^1] == byte('\n'): - result.setLen(result.high) - - return iterator(): Future[seq[byte]] {.closure.} = - for item in i: - yield strip(item) - proc init*(M: type[MultistreamSelect]): MultistreamSelect = M(codec: toSeq(Codec).mapIt( it.byte ), ls: Ls.toBytes(), @@ -77,15 +59,15 @@ proc select*(m: MultistreamSelect, conn: Connection, protos: seq[string]): Future[string] {.async.} = - trace "initiating handshake", codec = m.codec + trace "initiating handshake", codec = Codec, + proto = protos var pushable = Pushable[seq[byte]].init() # pushable source - var source = pipe(pushable, - appendNl, - m.lp.encoder, - conn.toThrough, - m.lp.decoder, - stripNl) + appendNl(), + m.lp.encoder, + conn.toThrough, + m.lp.decoder, + stripNl()) # handshake first await pushable.push(m.codec) @@ -138,11 +120,11 @@ proc list*(m: MultistreamSelect, var pushable = Pushable[seq[byte]].init() var source = pipe(pushable, - appendNl, + appendNl(), m.lp.encoder, conn.toThrough, m.lp.decoder, - stripNl) + stripNl()) await pushable.push(m.ls) # send ls @@ -156,15 +138,15 @@ proc list*(m: MultistreamSelect, result = list proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = - trace "handle: starting multistream handling" + trace "starting multistream handling" try: var pushable = Pushable[seq[byte]].init() var source = pipe(pushable, - appendNl, + appendNl(), m.lp.encoder, conn.toThrough, m.lp.decoder, - stripNl) + stripNl()) for chunk in source: var msg = string.fromBytes((await chunk)) @@ -181,8 +163,13 @@ proc handle*(m: MultistreamSelect, conn: Connection) {.async, gcsafe.} = case msg: of Ls: trace "listing protos" - for h in m.handlers: - await pushable.push(h.proto.toBytes()) + var protos: string + for i in 0..m.handlers.high: + protos &= m.handlers[i].proto + if i < m.handlers.high: + protos &= "\n" + + await pushable.push(protos.toBytes()) of Codec: trace "handling handshake" await pushable.push(m.codec) diff --git a/tests/testmultistream.nim b/tests/testmultistream.nim index 91a32e857..6a23db30e 100644 --- a/tests/testmultistream.nim +++ b/tests/testmultistream.nim @@ -1,184 +1,84 @@ -import unittest, strutils, sequtils, strformat, options -import chronos -import ../libp2p/connection, - ../libp2p/multistream, - ../libp2p/stream/lpstream, - ../libp2p/stream/bufferstream, - ../libp2p/connection, - ../libp2p/multiaddress, - ../libp2p/transports/transport, - ../libp2p/transports/tcptransport, - ../libp2p/protocols/protocol, - ../libp2p/crypto/crypto, - ../libp2p/peerinfo, - ../libp2p/peer +import unittest, strutils, sequtils +import chronos, stew/byteutils +import crypto/crypto, + streams/[stream, pushable, connection, utils, lenprefixed], + transports/[transport, tcptransport], + protocols/protocol, + multistream, + multiaddress, + peerinfo, + peer when defined(nimHasUsed): {.used.} -## Mock stream for select test -type - TestSelectStream = ref object of LPStream - step*: int +const + CodecString = "/multistream/1.0.0" + TestProtoString = "/test/proto/1.0.0" + TestString = "HELLO" -method readExactly*(s: TestSelectStream, - pbytes: pointer, - nbytes: int): Future[void] {.async, gcsafe.} = - case s.step: - of 1: - var buf = newSeq[byte](1) - buf[0] = 19 - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 2 - of 2: - var buf = "/multistream/1.0.0\n" - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 3 - of 3: - var buf = newSeq[byte](1) - buf[0] = 18 - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 4 - of 4: - var buf = "/test/proto/1.0.0\n" - copyMem(pbytes, addr buf[0], buf.len()) - else: - copyMem(pbytes, - cstring("\0x3na\n"), - "\0x3na\n".len()) + CodecBytes = @[19.byte, 47.byte, 109.byte, + 117.byte, 108.byte, 116.byte, + 105.byte, 115.byte, 116.byte, + 114.byte, 101.byte, 97.byte, + 109.byte, 47.byte, 49.byte, + 46.byte, 48.byte, 46.byte, + 48.byte, 10.byte] -method write*(s: TestSelectStream, msg: seq[byte], msglen = -1) - {.async, gcsafe.} = discard - -method write*(s: TestSelectStream, msg: string, msglen = -1) - {.async, gcsafe.} = discard - -method close(s: TestSelectStream) {.async, gcsafe.} = - s.isClosed = true - -proc newTestSelectStream(): TestSelectStream = - new result - result.step = 1 - -## Mock stream for handles `ls` test -type - LsHandler = proc(procs: seq[byte]): Future[void] {.gcsafe.} - - TestLsStream = ref object of LPStream - step*: int - ls*: LsHandler - -method readExactly*(s: TestLsStream, - pbytes: pointer, - nbytes: int): - Future[void] {.async.} = - case s.step: - of 1: - var buf = newSeq[byte](1) - buf[0] = 19 - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 2 - of 2: - var buf = "/multistream/1.0.0\n" - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 3 - of 3: - var buf = newSeq[byte](1) - buf[0] = 3 - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 4 - of 4: - var buf = "ls\n" - copyMem(pbytes, addr buf[0], buf.len()) - else: - copyMem(pbytes, cstring(Na), Na.len()) - -method write*(s: TestLsStream, msg: seq[byte], msglen = -1) {.async, gcsafe.} = - if s.step == 4: - await s.ls(msg) - -method write*(s: TestLsStream, msg: string, msglen = -1) - {.async, gcsafe.} = discard - -method close(s: TestLsStream) {.async, gcsafe.} = - s.isClosed = true - -proc newTestLsStream(ls: LsHandler): TestLsStream {.gcsafe.} = - new result - result.ls = ls - result.step = 1 - -## Mock stream for handles `na` test -type - NaHandler = proc(procs: string): Future[void] {.gcsafe.} - - TestNaStream = ref object of LPStream - step*: int - na*: NaHandler - -method readExactly*(s: TestNaStream, - pbytes: pointer, - nbytes: int): - Future[void] {.async, gcsafe.} = - case s.step: - of 1: - var buf = newSeq[byte](1) - buf[0] = 19 - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 2 - of 2: - var buf = "/multistream/1.0.0\n" - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 3 - of 3: - var buf = newSeq[byte](1) - buf[0] = 18 - copyMem(pbytes, addr buf[0], buf.len()) - s.step = 4 - of 4: - var buf = "/test/proto/1.0.0\n" - copyMem(pbytes, addr buf[0], buf.len()) - else: - copyMem(pbytes, - cstring("\0x3na\n"), - "\0x3na\n".len()) - -method write*(s: TestNaStream, msg: string, msglen = -1) {.async, gcsafe.} = - if s.step == 4: - await s.na(msg) - -method close(s: TestNaStream) {.async, gcsafe.} = - s.isClosed = true - -proc newTestNaStream(na: NaHandler): TestNaStream = - new result - result.na = na - result.step = 1 + TestProtoBytes = @[18.byte, 47.byte, 116.byte, + 101.byte, 115.byte, 116.byte, + 47.byte, 112.byte, 114.byte, + 111.byte, 116.byte, 111.byte, + 47.byte, 49.byte, 46.byte, 48.byte, + 46.byte, 48.byte, 10.byte] suite "Multistream select": test "test select custom proto": - proc testSelect(): Future[bool] {.async.} = - let ms = newMultistream() - let conn = newConnection(newTestSelectStream()) - result = (await ms.select(conn, @["/test/proto/1.0.0"])) == "/test/proto/1.0.0" + proc test(): Future[bool] {.async.} = + let pushable = Pushable[seq[byte]].init() + pushable.sinkImpl = proc(s: Stream[seq[byte]]): Sink[seq[byte]] {.gcsafe.} = + return proc(i: Source[seq[byte]]) {.async, gcsafe.} = + check: (await i()) == CodecBytes + check: (await i()) == TestProtoBytes + + await pushable.push(CodecBytes) + await pushable.push(TestProtoBytes) + await pushable.close() + + let conn = Connection.init(pushable) + var ms = MultistreamSelect.init() + + result = (await ms.select(conn, @[TestProtoString])) == TestProtoString check: - waitFor(testSelect()) == true + waitFor(test()) == true test "test handle custom proto": proc testHandle(): Future[bool] {.async.} = - let ms = newMultistream() - let conn = newConnection(newTestSelectStream()) + var ms = MultistreamSelect.init() + let pushable = Pushable[seq[byte]].init() + pushable.sinkImpl = proc(s: Stream[seq[byte]]): Sink[seq[byte]] {.gcsafe.} = + return proc(i: Source[seq[byte]]) {.async.} = + check: (await i()) == CodecBytes + check: (await i()) == TestProtoBytes + await pushable.close() + + let conn = Connection.init(pushable) var protocol: LPProtocol = new LPProtocol - proc testHandler(conn: Connection, - proto: string): + proc testHandler(conn: Connection, proto: string): Future[void] {.async, gcsafe.} = - check proto == "/test/proto/1.0.0" + check: proto == TestProtoString await conn.close() protocol.handler = testHandler - ms.addHandler("/test/proto/1.0.0", protocol) - await ms.handle(conn) + ms.addHandler(TestProtoString, protocol) + var handlerFut = ms.handle(conn) + + await pushable.push(CodecBytes) + await pushable.push(TestProtoBytes) + await pushable.close() + + await handlerFut result = true check: @@ -186,22 +86,31 @@ suite "Multistream select": test "test handle `ls`": proc testLs(): Future[bool] {.async.} = - let ms = newMultistream() + var ms = MultistreamSelect.init() + let pushable = Pushable[seq[byte]].init() + pushable.sinkImpl = proc(s: Stream[seq[byte]]): Sink[seq[byte]] {.gcsafe.} = + return proc(i: Source[seq[byte]]) {.async.} = + check: (await i()) == CodecBytes + check: (await i()) == ("\x26/test/proto1/1.0.0\n" & + "/test/proto2/1.0.0\n").toBytes() - proc testLsHandler(proto: seq[byte]) {.async, gcsafe.} # forward declaration - let conn = newConnection(newTestLsStream(testLsHandler)) - proc testLsHandler(proto: seq[byte]) {.async, gcsafe.} = - var strProto: string = cast[string](proto) - check strProto == "\x26/test/proto1/1.0.0\n/test/proto2/1.0.0\n" - await conn.close() - proc testHandler(conn: Connection, proto: string): Future[void] - {.async, gcsafe.} = discard + let conn = Connection.init(pushable) var protocol: LPProtocol = new LPProtocol - protocol.handler = testHandler + + protocol.handler = proc(conn: Connection, proto: string): + Future[void] {.async, gcsafe.} = discard + ms.addHandler("/test/proto1/1.0.0", protocol) ms.addHandler("/test/proto2/1.0.0", protocol) - await ms.handle(conn) + + var handlerFut = ms.handle(conn) + + await pushable.push(CodecBytes) # handshake + await pushable.push("\3ls\n".toBytes()) + await pushable.close() + + await handlerFut result = true check: @@ -209,23 +118,26 @@ suite "Multistream select": test "test handle `na`": proc testNa(): Future[bool] {.async.} = - let ms = newMultistream() - proc testNaHandler(msg: string): Future[void] {.async, gcsafe.} - let conn = newConnection(newTestNaStream(testNaHandler)) - - proc testNaHandler(msg: string): Future[void] {.async, gcsafe.} = - check cast[string](msg) == Na - await conn.close() + var ms = MultistreamSelect.init() + let pushable = Pushable[seq[byte]].init() + pushable.sinkImpl = proc(s: Stream[seq[byte]]): Sink[seq[byte]] {.gcsafe.} = + return proc(i: Source[seq[byte]]) {.async.} = + check: (await i()) == "\3na\n".toBytes() var protocol: LPProtocol = new LPProtocol - proc testHandler(conn: Connection, - proto: string): - Future[void] {.async, gcsafe.} = discard + proc testHandler(conn: Connection, proto: string): + Future[void] {.async, gcsafe.} = discard protocol.handler = testHandler - ms.addHandler("/unabvailable/proto/1.0.0", protocol) + ms.addHandler(TestProtoString, protocol) - await ms.handle(conn) + let conn = Connection.init(pushable) + var handlerFut = ms.handle(conn) + + await pushable.push("/invalid/proto".toBytes()) + await conn.close() + + await handlerFut result = true check: @@ -236,134 +148,146 @@ suite "Multistream select": let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") var protocol: LPProtocol = new LPProtocol - proc testHandler(conn: Connection, - proto: string): - Future[void] {.async, gcsafe.} = - check proto == "/test/proto/1.0.0" - await conn.writeLp("Hello!") + proc testHandler(conn: Connection, proto: string): + Future[void] {.async, gcsafe.} = + var pushable = Pushable[seq[byte]].init() + var lp = LenPrefixed.init() + var sink = pipe(pushable, lp.encoder, conn) + + check: proto == TestProtoString + await pushable.push(CodecString.toBytes()) + await pushable.push(TestProtoString.toBytes()) + await pushable.push(TestString.toBytes()) + await conn.close() + await sink protocol.handler = testHandler - let msListen = newMultistream() - msListen.addHandler("/test/proto/1.0.0", protocol) + var msListen = MultistreamSelect.init() + msListen.addHandler(TestProtoString, protocol) proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = await msListen.handle(conn) let transport1: TcpTransport = newTransport(TcpTransport) - asyncCheck transport1.listen(ma, connHandler) + let transportFut = await transport1.listen(ma, connHandler) - let msDial = newMultistream() + let msDial = MultistreamSelect.init() let transport2: TcpTransport = newTransport(TcpTransport) let conn = await transport2.dial(transport1.ma) - check (await msDial.select(conn, "/test/proto/1.0.0")) == true + check: (await msDial.select(conn, TestProtoString)) == true + var lp = LenPrefixed.init() + var source = pipe(conn, lp.decoder) + + let hello = string.fromBytes(await source()) + result = hello == TestString - let hello = cast[string](await conn.readLp()) - result = hello == "Hello!" await conn.close() + await transport1.close() + await transportFut check: waitFor(endToEnd()) == true - test "e2e - ls": - proc endToEnd(): Future[bool] {.async.} = - let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") + # test "e2e - ls": + # proc endToEnd(): Future[bool] {.async.} = + # let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") - let msListen = newMultistream() - var protocol: LPProtocol = new LPProtocol - protocol.handler = proc(conn: Connection, proto: string) {.async, gcsafe.} = - await conn.close() - proc testHandler(conn: Connection, - proto: string): - Future[void] {.async.} = discard - protocol.handler = testHandler - msListen.addHandler("/test/proto1/1.0.0", protocol) - msListen.addHandler("/test/proto2/1.0.0", protocol) + # let msListen = newMultistream() + # var protocol: LPProtocol = new LPProtocol + # protocol.handler = proc(conn: Connection, proto: string) {.async, gcsafe.} = + # await conn.close() + # proc testHandler(conn: Connection, + # proto: string): + # Future[void] {.async.} = discard + # protocol.handler = testHandler + # msListen.addHandler("/test/proto1/1.0.0", protocol) + # msListen.addHandler("/test/proto2/1.0.0", protocol) - let transport1: TcpTransport = newTransport(TcpTransport) - proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = - await msListen.handle(conn) - asyncCheck transport1.listen(ma, connHandler) + # let transport1: TcpTransport = newTransport(TcpTransport) + # proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = + # await msListen.handle(conn) + # asyncCheck transport1.listen(ma, connHandler) - let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) - let conn = await transport2.dial(transport1.ma) + # let msDial = newMultistream() + # let transport2: TcpTransport = newTransport(TcpTransport) + # let conn = await transport2.dial(transport1.ma) - let ls = await msDial.list(conn) - let protos: seq[string] = @["/test/proto1/1.0.0", "/test/proto2/1.0.0"] - await conn.close() - result = ls == protos + # let ls = await msDial.list(conn) + # let protos: seq[string] = @["/test/proto1/1.0.0", "/test/proto2/1.0.0"] + # await conn.close() + # result = ls == protos - check: - waitFor(endToEnd()) == true + # check: + # waitFor(endToEnd()) == true - test "e2e - select one from a list with unsupported protos": - proc endToEnd(): Future[bool] {.async.} = - let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") + # test "e2e - select one from a list with unsupported protos": + # proc endToEnd(): Future[bool] {.async.} = + # let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") - var protocol: LPProtocol = new LPProtocol - proc testHandler(conn: Connection, - proto: string): - Future[void] {.async, gcsafe.} = - check proto == "/test/proto/1.0.0" - await conn.writeLp("Hello!") - await conn.close() + # var protocol: LPProtocol = new LPProtocol + # proc testHandler(conn: Connection, + # proto: string): + # Future[void] {.async, gcsafe.} = + # check proto == "/test/proto/1.0.0" + # await conn.writeLp("Hello!") + # await conn.close() - protocol.handler = testHandler - let msListen = newMultistream() - msListen.addHandler("/test/proto/1.0.0", protocol) + # protocol.handler = testHandler + # let msListen = newMultistream() + # msListen.addHandler("/test/proto/1.0.0", protocol) - proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = - await msListen.handle(conn) + # proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = + # await msListen.handle(conn) - let transport1: TcpTransport = newTransport(TcpTransport) - asyncCheck transport1.listen(ma, connHandler) + # let transport1: TcpTransport = newTransport(TcpTransport) + # asyncCheck transport1.listen(ma, connHandler) - let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) - let conn = await transport2.dial(transport1.ma) + # let msDial = newMultistream() + # let transport2: TcpTransport = newTransport(TcpTransport) + # let conn = await transport2.dial(transport1.ma) - check (await msDial.select(conn, - @["/test/proto/1.0.0", "/test/no/proto/1.0.0"])) == "/test/proto/1.0.0" + # check (await msDial.select(conn, + # @["/test/proto/1.0.0", "/test/no/proto/1.0.0"])) == "/test/proto/1.0.0" - let hello = cast[string](await conn.readLp()) - result = hello == "Hello!" - await conn.close() + # let hello = cast[string](await conn.readLp()) + # result = hello == "Hello!" + # await conn.close() - check: - waitFor(endToEnd()) == true + # check: + # waitFor(endToEnd()) == true - test "e2e - select one with both valid": - proc endToEnd(): Future[bool] {.async.} = - let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") + # test "e2e - select one with both valid": + # proc endToEnd(): Future[bool] {.async.} = + # let ma: MultiAddress = Multiaddress.init("/ip4/0.0.0.0/tcp/0") - var protocol: LPProtocol = new LPProtocol - proc testHandler(conn: Connection, - proto: string): - Future[void] {.async, gcsafe.} = - await conn.writeLp(&"Hello from {proto}!") - await conn.close() + # var protocol: LPProtocol = new LPProtocol + # proc testHandler(conn: Connection, + # proto: string): + # Future[void] {.async, gcsafe.} = + # await conn.writeLp(&"Hello from {proto}!") + # await conn.close() - protocol.handler = testHandler - let msListen = newMultistream() - msListen.addHandler("/test/proto1/1.0.0", protocol) - msListen.addHandler("/test/proto2/1.0.0", protocol) + # protocol.handler = testHandler + # let msListen = newMultistream() + # msListen.addHandler("/test/proto1/1.0.0", protocol) + # msListen.addHandler("/test/proto2/1.0.0", protocol) - proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = - await msListen.handle(conn) + # proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = + # await msListen.handle(conn) - let transport1: TcpTransport = newTransport(TcpTransport) - asyncCheck transport1.listen(ma, connHandler) + # let transport1: TcpTransport = newTransport(TcpTransport) + # asyncCheck transport1.listen(ma, connHandler) - let msDial = newMultistream() - let transport2: TcpTransport = newTransport(TcpTransport) - let conn = await transport2.dial(transport1.ma) + # let msDial = newMultistream() + # let transport2: TcpTransport = newTransport(TcpTransport) + # let conn = await transport2.dial(transport1.ma) - check (await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"])) == "/test/proto2/1.0.0" + # check (await msDial.select(conn, @["/test/proto2/1.0.0", "/test/proto1/1.0.0"])) == "/test/proto2/1.0.0" - result = cast[string](await conn.readLp()) == "Hello from /test/proto2/1.0.0!" - await conn.close() + # result = cast[string](await conn.readLp()) == "Hello from /test/proto2/1.0.0!" + # await conn.close() - check: - waitFor(endToEnd()) == true + # check: + # waitFor(endToEnd()) == true