diff --git a/libp2p/switch.nim b/libp2p/switch.nim index 3acc507..70e998f 100644 --- a/libp2p/switch.nim +++ b/libp2p/switch.nim @@ -417,7 +417,7 @@ proc dial*(s: Switch, proto: string): Future[Connection] = dial(s, peerId, addrs, @[proto]) -proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = +proc mount*[T: LPProtocol](s: Switch, proto: T, matcher: Matcher = nil) {.gcsafe.} = if isNil(proto.handler): raise newException(CatchableError, "Protocol has to define a handle method or proc") @@ -426,7 +426,7 @@ proc mount*[T: LPProtocol](s: Switch, proto: T) {.gcsafe.} = raise newException(CatchableError, "Protocol has to define a codec string") - s.ms.addHandler(proto.codecs, proto) + s.ms.addHandler(proto.codecs, proto, matcher) proc start*(s: Switch): Future[seq[Future[void]]] {.async, gcsafe.} = trace "starting switch for peer", peerInfo = s.peerInfo diff --git a/tests/testswitch.nim b/tests/testswitch.nim index c75175e..2d5a1d1 100644 --- a/tests/testswitch.nim +++ b/tests/testswitch.nim @@ -75,6 +75,58 @@ suite "Switch": waitFor(testSwitch()) + test "e2e use switch dial proto string with custom matcher": + proc testSwitch() {.async, gcsafe.} = + let done = newFuture[void]() + proc handle(conn: Connection, proto: string) {.async, gcsafe.} = + try: + let msg = string.fromBytes(await conn.readLp(1024)) + check "Hello!" == msg + await conn.writeLp("Hello!") + finally: + await conn.close() + done.complete() + + let testProto = new TestProto + testProto.codec = TestCodec + testProto.handler = handle + + let callProto = TestCodec & "/pew" + + proc match(proto: string): bool {.gcsafe.} = + return proto == callProto + + let switch1 = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) + switch1.mount(testProto, match) + + let switch2 = newStandardSwitch(secureManagers = [SecureProtocol.Noise]) + var awaiters: seq[Future[void]] + awaiters.add(await switch1.start()) + awaiters.add(await switch2.start()) + + let conn = await switch2.dial(switch1.peerInfo, callProto) + + check switch1.isConnected(switch2.peerInfo) + check switch2.isConnected(switch1.peerInfo) + + await conn.writeLp("Hello!") + let msg = string.fromBytes(await conn.readLp(1024)) + check "Hello!" == msg + await conn.close() + + await allFuturesThrowing( + done.wait(5.seconds), + switch1.stop(), + switch2.stop()) + + # this needs to go at end + await allFuturesThrowing(awaiters) + + check not switch1.isConnected(switch2.peerInfo) + check not switch2.isConnected(switch1.peerInfo) + + waitFor(testSwitch()) + test "e2e should not leak bufferstreams and connections on channel close": proc testSwitch() {.async, gcsafe.} = let done = newFuture[void]()