Fix WS observed address (#631)
* Fix WS observed address * Unify tcptransport & wstransport
This commit is contained in:
parent
75bfc1b5f7
commit
3669b90ceb
|
@ -71,10 +71,10 @@ proc connHandler*(self: TcpTransport,
|
|||
try:
|
||||
observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
|
||||
except CatchableError as exc:
|
||||
trace "Connection setup failed", exc = exc.msg
|
||||
trace "Failed to create observedAddr", exc = exc.msg
|
||||
if not(isNil(client) and client.closed):
|
||||
await client.closeWait()
|
||||
raise exc
|
||||
raise exc
|
||||
|
||||
trace "Handling tcp connection", address = $observedAddr,
|
||||
dir = $dir,
|
||||
|
|
|
@ -150,13 +150,36 @@ method stop*(self: WsTransport) {.async, gcsafe.} =
|
|||
except CatchableError as exc:
|
||||
trace "Error shutting down ws transport", exc = exc.msg
|
||||
|
||||
proc trackConnection(self: WsTransport, conn: WsStream, dir: Direction) =
|
||||
proc connHandler(self: WsTransport,
|
||||
stream: WsSession,
|
||||
dir: Direction): Future[Connection] {.async.} =
|
||||
let observedAddr =
|
||||
try:
|
||||
let
|
||||
codec =
|
||||
if self.secure:
|
||||
MultiAddress.init("/wss")
|
||||
else:
|
||||
MultiAddress.init("/ws")
|
||||
remoteAddr = stream.stream.reader.tsource.remoteAddress
|
||||
|
||||
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet()
|
||||
except CatchableError as exc:
|
||||
trace "Failed to create observedAddr", exc = exc.msg
|
||||
if not(isNil(stream) and stream.stream.reader.closed):
|
||||
await stream.close()
|
||||
raise exc
|
||||
|
||||
let conn = WsStream.init(stream, dir)
|
||||
conn.observedAddr = observedAddr
|
||||
|
||||
self.connections[dir].add(conn)
|
||||
proc onClose() {.async.} =
|
||||
await conn.session.stream.reader.join()
|
||||
self.connections[dir].keepItIf(it != conn)
|
||||
trace "Cleaned up client"
|
||||
asyncSpawn onClose()
|
||||
return conn
|
||||
|
||||
method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} =
|
||||
## accept a new WS connection
|
||||
|
@ -169,10 +192,8 @@ method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} =
|
|||
let
|
||||
req = await self.httpserver.accept()
|
||||
wstransp = await self.wsserver.handleRequest(req)
|
||||
stream = WsStream.init(wstransp, Direction.In)
|
||||
|
||||
self.trackConnection(stream, Direction.In)
|
||||
return stream
|
||||
return await self.connHandler(wstransp, Direction.In)
|
||||
except TransportOsError as exc:
|
||||
debug "OS Error", exc = exc.msg
|
||||
except TransportTooManyError as exc:
|
||||
|
@ -199,10 +220,8 @@ method dial*(
|
|||
"",
|
||||
secure = secure,
|
||||
flags = self.tlsFlags)
|
||||
stream = WsStream.init(transp, Direction.Out)
|
||||
|
||||
self.trackConnection(stream, Direction.Out)
|
||||
return stream
|
||||
return await self.connHandler(transp, Direction.Out)
|
||||
|
||||
method handles*(t: WsTransport, address: MultiAddress): bool {.gcsafe.} =
|
||||
if procCall Transport(t).handles(address):
|
||||
|
|
|
@ -25,6 +25,34 @@ proc commonTransportTest*(name: string, prov: TransportProvider, ma: string) =
|
|||
check transport1.handles(transport1.ma)
|
||||
await transport1.stop()
|
||||
|
||||
asyncTest "e2e: handle observedAddr":
|
||||
let ma: MultiAddress = Multiaddress.init(ma).tryGet()
|
||||
|
||||
let transport1 = prov()
|
||||
await transport1.start(ma)
|
||||
|
||||
let transport2 = prov()
|
||||
|
||||
proc acceptHandler() {.async, gcsafe.} =
|
||||
let conn = await transport1.accept()
|
||||
check transport1.handles(conn.observedAddr)
|
||||
await conn.close()
|
||||
|
||||
let handlerWait = acceptHandler()
|
||||
|
||||
let conn = await transport2.dial(transport1.ma)
|
||||
|
||||
check transport2.handles(conn.observedAddr)
|
||||
|
||||
await conn.close() #for some protocols, closing requires actively reading, so we must close here
|
||||
|
||||
await allFuturesThrowing(
|
||||
allFinished(
|
||||
transport1.stop(),
|
||||
transport2.stop()))
|
||||
|
||||
await handlerWait.wait(1.seconds) # when no issues will not wait that long!
|
||||
|
||||
asyncTest "e2e: handle write":
|
||||
let ma: MultiAddress = Multiaddress.init(ma).tryGet()
|
||||
|
||||
|
|
Loading…
Reference in New Issue