diff --git a/libp2p/multistreamselect.nim b/libp2p/multistreamselect.nim index ac59e3d9e..b606050b7 100644 --- a/libp2p/multistreamselect.nim +++ b/libp2p/multistreamselect.nim @@ -9,7 +9,7 @@ import sequtils, strutils import chronos -import connection, varint, vbuffer +import connection, varint, vbuffer, protocol const MsgSize* = 64*1024 const Codec* = "/multistream/1.0.0" @@ -19,13 +19,12 @@ const Ls = "\x03ls\n" type MultisteamSelectException = object of CatchableError - - Handler* = proc (conn: Connection, proto: string): Future[void] {.gcsafe.} Matcher* = proc (proto: string): bool HandlerHolder* = object proto: string - handler: Handler + protocol: LPProtocol + handler: LPProtoHandler match: Matcher MultisteamSelect* = ref object of RootObj @@ -78,7 +77,7 @@ proc list*(m: MultisteamSelect, result = list -proc handle*(m: MultisteamSelect, conn: Connection) {.async.} = +proc handle*(m: MultisteamSelect, conn: Connection) {.async, gcsafe.} = ## handle requests on connection if not (await m.select(conn)): return @@ -103,15 +102,17 @@ proc handle*(m: MultisteamSelect, conn: Connection) {.async.} = for h in m.handlers: if (not isNil(h.match) and h.match(ms)) or ms == h.proto: await conn.writeLp(h.proto & "\n") - await h.handler(conn, ms) + await h.handler(h.protocol, conn, ms) return await conn.write(m.na) -proc addHandler*(m: MultisteamSelect, - proto: string, - handler: Handler, - matcher: Matcher = nil) = +proc addHandler*[T: LPProtocol](m: MultisteamSelect, + proto: string, + protocol: T, + handler: LPProtoHandler, + matcher: Matcher = nil) = ## register a handler for the protocol m.handlers.add(HandlerHolder(proto: proto, handler: handler, + protocol: protocol, match: matcher)) diff --git a/tests/testmultistreamselect.nim b/tests/testmultistreamselect.nim index 79ac6c8f0..3a758a1c1 100644 --- a/tests/testmultistreamselect.nim +++ b/tests/testmultistreamselect.nim @@ -2,7 +2,7 @@ import unittest, strutils, sequtils, sugar import chronos import ../libp2p/connection, ../libp2p/multistreamselect, ../libp2p/stream, ../libp2p/connection, ../libp2p/multiaddress, - ../libp2p/transport, ../libp2p/tcptransport + ../libp2p/transport, ../libp2p/tcptransport, ../libp2p/protocol ## Mock stream for select test type @@ -140,11 +140,13 @@ suite "Multistream select": let ms = newMultistream() let conn = newConnection(newTestSelectStream()) - proc testHandler(conn: Connection, + var protocol: LPProtocol + proc testHandler(protocol: LPProtocol, + conn: Connection, proto: string): Future[void] {.async, gcsafe.} = check proto == "/test/proto/1.0.0" - ms.addHandler("/test/proto/1.0.0", testHandler) + ms.addHandler("/test/proto/1.0.0", protocol, testHandler) await ms.handle(conn) result = true @@ -163,10 +165,12 @@ suite "Multistream select": check strProto == "\x26/test/proto1/1.0.0\n/test/proto2/1.0.0\n" await conn.close() - proc testHandler(conn: Connection, + var protocol: LPProtocol + proc testHandler(protocol: LPProtocol, + conn: Connection, proto: string): Future[void] {.async, gcsafe.} = discard - ms.addHandler("/test/proto1/1.0.0", testHandler) - ms.addHandler("/test/proto2/1.0.0", testHandler) + ms.addHandler("/test/proto1/1.0.0", protocol, testHandler) + ms.addHandler("/test/proto2/1.0.0", protocol, testHandler) await ms.handle(conn) result = true @@ -184,9 +188,11 @@ suite "Multistream select": check cast[string](msg) == "\x3na\n" await conn.close() - proc testHandler(conn: Connection, + var protocol: LPProtocol + proc testHandler(protocol: LPProtocol, + conn: Connection, proto: string): Future[void] {.async, gcsafe.} = discard - ms.addHandler("/unabvailable/proto/1.0.0", testHandler) + ms.addHandler("/unabvailable/proto/1.0.0", protocol, testHandler) await ms.handle(conn) result = true @@ -197,14 +203,16 @@ suite "Multistream select": test "e2e - handle": proc endToEnd(): Future[bool] {.async.} = let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53350") - proc testHandler(conn: Connection, + var protocol: LPProtocol + proc testHandler(protocol: LPProtocol, + conn: Connection, proto: string): Future[void] {.async, gcsafe.} = check proto == "/test/proto/1.0.0" await conn.writeLp("Hello!") await conn.close() let msListen = newMultistream() - msListen.addHandler("/test/proto/1.0.0", testHandler) + msListen.addHandler("/test/proto/1.0.0", protocol, testHandler) proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} = await msListen.handle(conn) @@ -231,10 +239,12 @@ suite "Multistream select": let ma: MultiAddress = Multiaddress.init("/ip4/127.0.0.1/tcp/53351") let msListen = newMultistream() - proc testHandler(conn: Connection, + var protocol: LPProtocol + proc testHandler(protocol: LPProtocol, + conn: Connection, proto: string): Future[void] {.async.} = discard - msListen.addHandler("/test/proto1/1.0.0", testHandler) - msListen.addHandler("/test/proto2/1.0.0", testHandler) + msListen.addHandler("/test/proto1/1.0.0", protocol, testHandler) + msListen.addHandler("/test/proto2/1.0.0", protocol, testHandler) let transport1: TcpTransport = newTransport(TcpTransport) proc connHandler(conn: Connection): Future[void] {.async, gcsafe.} =