diff --git a/libp2p/transports/tcptransport.nim b/libp2p/transports/tcptransport.nim index 7ed322f48..b0527dfbd 100644 --- a/libp2p/transports/tcptransport.nim +++ b/libp2p/transports/tcptransport.nim @@ -71,10 +71,10 @@ proc connHandler*(self: TcpTransport, try: observedAddr = MultiAddress.init(client.remoteAddress).tryGet() except CatchableError as exc: - trace "Connection setup failed", exc = exc.msg + trace "Failed to create observedAddr", exc = exc.msg if not(isNil(client) and client.closed): await client.closeWait() - raise exc + raise exc trace "Handling tcp connection", address = $observedAddr, dir = $dir, diff --git a/libp2p/transports/wstransport.nim b/libp2p/transports/wstransport.nim index ac922f27e..a75cafc21 100644 --- a/libp2p/transports/wstransport.nim +++ b/libp2p/transports/wstransport.nim @@ -150,13 +150,36 @@ method stop*(self: WsTransport) {.async, gcsafe.} = except CatchableError as exc: trace "Error shutting down ws transport", exc = exc.msg -proc trackConnection(self: WsTransport, conn: WsStream, dir: Direction) = +proc connHandler(self: WsTransport, + stream: WsSession, + dir: Direction): Future[Connection] {.async.} = + let observedAddr = + try: + let + codec = + if self.secure: + MultiAddress.init("/wss") + else: + MultiAddress.init("/ws") + remoteAddr = stream.stream.reader.tsource.remoteAddress + + MultiAddress.init(remoteAddr).tryGet() & codec.tryGet() + except CatchableError as exc: + trace "Failed to create observedAddr", exc = exc.msg + if not(isNil(stream) and stream.stream.reader.closed): + await stream.close() + raise exc + + let conn = WsStream.init(stream, dir) + conn.observedAddr = observedAddr + self.connections[dir].add(conn) proc onClose() {.async.} = await conn.session.stream.reader.join() self.connections[dir].keepItIf(it != conn) trace "Cleaned up client" asyncSpawn onClose() + return conn method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} = ## accept a new WS connection @@ -169,10 +192,8 @@ method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} = let req = await self.httpserver.accept() wstransp = await self.wsserver.handleRequest(req) - stream = WsStream.init(wstransp, Direction.In) - self.trackConnection(stream, Direction.In) - return stream + return await self.connHandler(wstransp, Direction.In) except TransportOsError as exc: debug "OS Error", exc = exc.msg except TransportTooManyError as exc: @@ -199,10 +220,8 @@ method dial*( "", secure = secure, flags = self.tlsFlags) - stream = WsStream.init(transp, Direction.Out) - self.trackConnection(stream, Direction.Out) - return stream + return await self.connHandler(transp, Direction.Out) method handles*(t: WsTransport, address: MultiAddress): bool {.gcsafe.} = if procCall Transport(t).handles(address): diff --git a/tests/commontransport.nim b/tests/commontransport.nim index e8f583976..70db2a211 100644 --- a/tests/commontransport.nim +++ b/tests/commontransport.nim @@ -25,6 +25,34 @@ proc commonTransportTest*(name: string, prov: TransportProvider, ma: string) = check transport1.handles(transport1.ma) await transport1.stop() + asyncTest "e2e: handle observedAddr": + let ma: MultiAddress = Multiaddress.init(ma).tryGet() + + let transport1 = prov() + await transport1.start(ma) + + let transport2 = prov() + + proc acceptHandler() {.async, gcsafe.} = + let conn = await transport1.accept() + check transport1.handles(conn.observedAddr) + await conn.close() + + let handlerWait = acceptHandler() + + let conn = await transport2.dial(transport1.ma) + + check transport2.handles(conn.observedAddr) + + await conn.close() #for some protocols, closing requires actively reading, so we must close here + + await allFuturesThrowing( + allFinished( + transport1.stop(), + transport2.stop())) + + await handlerWait.wait(1.seconds) # when no issues will not wait that long! + asyncTest "e2e: handle write": let ma: MultiAddress = Multiaddress.init(ma).tryGet()