Fix WS observed address (#631)

* Fix WS observed address

* Unify tcptransport & wstransport
This commit is contained in:
Tanguy 2021-10-14 13:16:34 +02:00 committed by GitHub
parent 75bfc1b5f7
commit 3669b90ceb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 9 deletions

View File

@ -71,10 +71,10 @@ proc connHandler*(self: TcpTransport,
try: try:
observedAddr = MultiAddress.init(client.remoteAddress).tryGet() observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
except CatchableError as exc: except CatchableError as exc:
trace "Connection setup failed", exc = exc.msg trace "Failed to create observedAddr", exc = exc.msg
if not(isNil(client) and client.closed): if not(isNil(client) and client.closed):
await client.closeWait() await client.closeWait()
raise exc raise exc
trace "Handling tcp connection", address = $observedAddr, trace "Handling tcp connection", address = $observedAddr,
dir = $dir, dir = $dir,

View File

@ -150,13 +150,36 @@ method stop*(self: WsTransport) {.async, gcsafe.} =
except CatchableError as exc: except CatchableError as exc:
trace "Error shutting down ws transport", exc = exc.msg trace "Error shutting down ws transport", exc = exc.msg
proc trackConnection(self: WsTransport, conn: WsStream, dir: Direction) = proc connHandler(self: WsTransport,
stream: WsSession,
dir: Direction): Future[Connection] {.async.} =
let observedAddr =
try:
let
codec =
if self.secure:
MultiAddress.init("/wss")
else:
MultiAddress.init("/ws")
remoteAddr = stream.stream.reader.tsource.remoteAddress
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet()
except CatchableError as exc:
trace "Failed to create observedAddr", exc = exc.msg
if not(isNil(stream) and stream.stream.reader.closed):
await stream.close()
raise exc
let conn = WsStream.init(stream, dir)
conn.observedAddr = observedAddr
self.connections[dir].add(conn) self.connections[dir].add(conn)
proc onClose() {.async.} = proc onClose() {.async.} =
await conn.session.stream.reader.join() await conn.session.stream.reader.join()
self.connections[dir].keepItIf(it != conn) self.connections[dir].keepItIf(it != conn)
trace "Cleaned up client" trace "Cleaned up client"
asyncSpawn onClose() asyncSpawn onClose()
return conn
method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} = method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} =
## accept a new WS connection ## accept a new WS connection
@ -169,10 +192,8 @@ method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} =
let let
req = await self.httpserver.accept() req = await self.httpserver.accept()
wstransp = await self.wsserver.handleRequest(req) wstransp = await self.wsserver.handleRequest(req)
stream = WsStream.init(wstransp, Direction.In)
self.trackConnection(stream, Direction.In) return await self.connHandler(wstransp, Direction.In)
return stream
except TransportOsError as exc: except TransportOsError as exc:
debug "OS Error", exc = exc.msg debug "OS Error", exc = exc.msg
except TransportTooManyError as exc: except TransportTooManyError as exc:
@ -199,10 +220,8 @@ method dial*(
"", "",
secure = secure, secure = secure,
flags = self.tlsFlags) flags = self.tlsFlags)
stream = WsStream.init(transp, Direction.Out)
self.trackConnection(stream, Direction.Out) return await self.connHandler(transp, Direction.Out)
return stream
method handles*(t: WsTransport, address: MultiAddress): bool {.gcsafe.} = method handles*(t: WsTransport, address: MultiAddress): bool {.gcsafe.} =
if procCall Transport(t).handles(address): if procCall Transport(t).handles(address):

View File

@ -25,6 +25,34 @@ proc commonTransportTest*(name: string, prov: TransportProvider, ma: string) =
check transport1.handles(transport1.ma) check transport1.handles(transport1.ma)
await transport1.stop() await transport1.stop()
asyncTest "e2e: handle observedAddr":
let ma: MultiAddress = Multiaddress.init(ma).tryGet()
let transport1 = prov()
await transport1.start(ma)
let transport2 = prov()
proc acceptHandler() {.async, gcsafe.} =
let conn = await transport1.accept()
check transport1.handles(conn.observedAddr)
await conn.close()
let handlerWait = acceptHandler()
let conn = await transport2.dial(transport1.ma)
check transport2.handles(conn.observedAddr)
await conn.close() #for some protocols, closing requires actively reading, so we must close here
await allFuturesThrowing(
allFinished(
transport1.stop(),
transport2.stop()))
await handlerWait.wait(1.seconds) # when no issues will not wait that long!
asyncTest "e2e: handle write": asyncTest "e2e: handle write":
let ma: MultiAddress = Multiaddress.init(ma).tryGet() let ma: MultiAddress = Multiaddress.init(ma).tryGet()