Fix WS observed address (#631)
* Fix WS observed address * Unify tcptransport & wstransport
This commit is contained in:
parent
75bfc1b5f7
commit
3669b90ceb
|
@ -71,7 +71,7 @@ proc connHandler*(self: TcpTransport,
|
||||||
try:
|
try:
|
||||||
observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
|
observedAddr = MultiAddress.init(client.remoteAddress).tryGet()
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
trace "Connection setup failed", exc = exc.msg
|
trace "Failed to create observedAddr", exc = exc.msg
|
||||||
if not(isNil(client) and client.closed):
|
if not(isNil(client) and client.closed):
|
||||||
await client.closeWait()
|
await client.closeWait()
|
||||||
raise exc
|
raise exc
|
||||||
|
|
|
@ -150,13 +150,36 @@ method stop*(self: WsTransport) {.async, gcsafe.} =
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
trace "Error shutting down ws transport", exc = exc.msg
|
trace "Error shutting down ws transport", exc = exc.msg
|
||||||
|
|
||||||
proc trackConnection(self: WsTransport, conn: WsStream, dir: Direction) =
|
proc connHandler(self: WsTransport,
|
||||||
|
stream: WsSession,
|
||||||
|
dir: Direction): Future[Connection] {.async.} =
|
||||||
|
let observedAddr =
|
||||||
|
try:
|
||||||
|
let
|
||||||
|
codec =
|
||||||
|
if self.secure:
|
||||||
|
MultiAddress.init("/wss")
|
||||||
|
else:
|
||||||
|
MultiAddress.init("/ws")
|
||||||
|
remoteAddr = stream.stream.reader.tsource.remoteAddress
|
||||||
|
|
||||||
|
MultiAddress.init(remoteAddr).tryGet() & codec.tryGet()
|
||||||
|
except CatchableError as exc:
|
||||||
|
trace "Failed to create observedAddr", exc = exc.msg
|
||||||
|
if not(isNil(stream) and stream.stream.reader.closed):
|
||||||
|
await stream.close()
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
let conn = WsStream.init(stream, dir)
|
||||||
|
conn.observedAddr = observedAddr
|
||||||
|
|
||||||
self.connections[dir].add(conn)
|
self.connections[dir].add(conn)
|
||||||
proc onClose() {.async.} =
|
proc onClose() {.async.} =
|
||||||
await conn.session.stream.reader.join()
|
await conn.session.stream.reader.join()
|
||||||
self.connections[dir].keepItIf(it != conn)
|
self.connections[dir].keepItIf(it != conn)
|
||||||
trace "Cleaned up client"
|
trace "Cleaned up client"
|
||||||
asyncSpawn onClose()
|
asyncSpawn onClose()
|
||||||
|
return conn
|
||||||
|
|
||||||
method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} =
|
method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} =
|
||||||
## accept a new WS connection
|
## accept a new WS connection
|
||||||
|
@ -169,10 +192,8 @@ method accept*(self: WsTransport): Future[Connection] {.async, gcsafe.} =
|
||||||
let
|
let
|
||||||
req = await self.httpserver.accept()
|
req = await self.httpserver.accept()
|
||||||
wstransp = await self.wsserver.handleRequest(req)
|
wstransp = await self.wsserver.handleRequest(req)
|
||||||
stream = WsStream.init(wstransp, Direction.In)
|
|
||||||
|
|
||||||
self.trackConnection(stream, Direction.In)
|
return await self.connHandler(wstransp, Direction.In)
|
||||||
return stream
|
|
||||||
except TransportOsError as exc:
|
except TransportOsError as exc:
|
||||||
debug "OS Error", exc = exc.msg
|
debug "OS Error", exc = exc.msg
|
||||||
except TransportTooManyError as exc:
|
except TransportTooManyError as exc:
|
||||||
|
@ -199,10 +220,8 @@ method dial*(
|
||||||
"",
|
"",
|
||||||
secure = secure,
|
secure = secure,
|
||||||
flags = self.tlsFlags)
|
flags = self.tlsFlags)
|
||||||
stream = WsStream.init(transp, Direction.Out)
|
|
||||||
|
|
||||||
self.trackConnection(stream, Direction.Out)
|
return await self.connHandler(transp, Direction.Out)
|
||||||
return stream
|
|
||||||
|
|
||||||
method handles*(t: WsTransport, address: MultiAddress): bool {.gcsafe.} =
|
method handles*(t: WsTransport, address: MultiAddress): bool {.gcsafe.} =
|
||||||
if procCall Transport(t).handles(address):
|
if procCall Transport(t).handles(address):
|
||||||
|
|
|
@ -25,6 +25,34 @@ proc commonTransportTest*(name: string, prov: TransportProvider, ma: string) =
|
||||||
check transport1.handles(transport1.ma)
|
check transport1.handles(transport1.ma)
|
||||||
await transport1.stop()
|
await transport1.stop()
|
||||||
|
|
||||||
|
asyncTest "e2e: handle observedAddr":
|
||||||
|
let ma: MultiAddress = Multiaddress.init(ma).tryGet()
|
||||||
|
|
||||||
|
let transport1 = prov()
|
||||||
|
await transport1.start(ma)
|
||||||
|
|
||||||
|
let transport2 = prov()
|
||||||
|
|
||||||
|
proc acceptHandler() {.async, gcsafe.} =
|
||||||
|
let conn = await transport1.accept()
|
||||||
|
check transport1.handles(conn.observedAddr)
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
let handlerWait = acceptHandler()
|
||||||
|
|
||||||
|
let conn = await transport2.dial(transport1.ma)
|
||||||
|
|
||||||
|
check transport2.handles(conn.observedAddr)
|
||||||
|
|
||||||
|
await conn.close() #for some protocols, closing requires actively reading, so we must close here
|
||||||
|
|
||||||
|
await allFuturesThrowing(
|
||||||
|
allFinished(
|
||||||
|
transport1.stop(),
|
||||||
|
transport2.stop()))
|
||||||
|
|
||||||
|
await handlerWait.wait(1.seconds) # when no issues will not wait that long!
|
||||||
|
|
||||||
asyncTest "e2e: handle write":
|
asyncTest "e2e: handle write":
|
||||||
let ma: MultiAddress = Multiaddress.init(ma).tryGet()
|
let ma: MultiAddress = Multiaddress.init(ma).tryGet()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue