From 853299e399746eff4096870067cbc61861ecd534 Mon Sep 17 00:00:00 2001 From: Tanguy Date: Wed, 9 Mar 2022 14:38:45 +0100 Subject: [PATCH] Accept timeout (#102) --- websock/http/common.nim | 5 +++ websock/http/server.nim | 90 +++++++++++++++++++++++------------------ 2 files changed, 55 insertions(+), 40 deletions(-) diff --git a/websock/http/common.nim b/websock/http/common.nim index 04c42427..fbfcc992 100644 --- a/websock/http/common.nim +++ b/websock/http/common.nim @@ -67,6 +67,11 @@ proc closeWait*(stream: AsyncStream) {.async.} = stream.writer.closeStream(), stream.reader.tsource.closeTransp()) +proc close*(stream: AsyncStream) = + stream.reader.close() + stream.writer.close() + stream.reader.tsource.close() + proc sendResponse*( request: HttpRequest, code: HttpCode, diff --git a/websock/http/server.nim b/websock/http/server.nim index e0cb0c49..a3230ea2 100644 --- a/websock/http/server.nim +++ b/websock/http/server.nim @@ -29,6 +29,8 @@ type HttpServer* = ref object of StreamServer handler*: HttpAsyncCallback + handshakeTimeout*: Duration + headersTimeout*: Duration case secure*: bool: of true: tlsFlags*: set[TLSFlags] @@ -72,7 +74,7 @@ proc parseRequest( try: let hlenfut = stream.reader.readUntil( addr buffer[0], MaxHttpHeadersSize, sep = HeaderSep) - let ores = await withTimeout(hlenfut, HttpHeadersTimeout) + let ores = await withTimeout(hlenfut, server.headersTimeout) if not ores: # Timeout trace "Timeout expired while receiving headers", address = $remoteAddr @@ -191,25 +193,48 @@ proc accept*(server: HttpServer): Future[HttpRequest] trace "Got new request", isTls = server.secure try: - return await server.parseRequest(stream) + let + parseFut = server.parseRequest(stream) + if await withTimeout(parseFut, server.handshakeTimeout): + return parseFut.read() + raise newException(HttpError, "Timed out parsing request") except CatchableError as exc: - await stream.closeWait() + # Can't hold up the accept loop + stream.close() raise exc proc create*( _: typedesc[HttpServer], - address: TransportAddress, + address: TransportAddress | string, handler: HttpAsyncCallback = nil, - flags: set[ServerFlags] = {}): HttpServer + flags: set[ServerFlags] = {}, + headersTimeout = HttpHeadersTimeout, + handshakeTimeout = 0.seconds + ): HttpServer {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError ## Make a new HTTP Server ## - var server = HttpServer(handler: handler) + var server = HttpServer( + handler: handler, + headersTimeout: headersTimeout, + handshakeTimeout: + if handshakeTimeout == 0.seconds: + # default to headersTimeout * 1.05 + headersTimeout + (headersTimeout div 20) + else: handshakeTimeout, + ) + + let localAddress = + when address is string: + initTAddress(address) + else: + address + server = HttpServer( createStreamServer( - address, + localAddress, handleConnCb, flags, child = StreamServer(server))) @@ -218,30 +243,28 @@ proc create*( return server -proc create*( - _: typedesc[HttpServer], - host: string, - handler: HttpAsyncCallback = nil, - flags: set[ServerFlags] = {}): HttpServer - {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError - ## Make a new HTTP Server - ## - - return HttpServer.create(initTAddress(host), handler, flags) - proc create*( _: typedesc[TlsHttpServer], - address: TransportAddress, + address: TransportAddress | string, tlsPrivateKey: TLSPrivateKey, tlsCertificate: TLSCertificate, handler: HttpAsyncCallback = nil, flags: set[ServerFlags] = {}, tlsFlags: set[TLSFlags] = {}, tlsMinVersion = TLSVersion.TLS12, - tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer + tlsMaxVersion = TLSVersion.TLS12, + headersTimeout = HttpHeadersTimeout, + handshakeTimeout = 0.seconds + ): TlsHttpServer {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError var server = TlsHttpServer( + headersTimeout: headersTimeout, + handshakeTimeout: + if handshakeTimeout == 0.seconds: + # default to headersTimeout * 1.05 + headersTimeout + (headersTimeout div 20) + else: handshakeTimeout, secure: true, handler: handler, tlsPrivateKey: tlsPrivateKey, @@ -249,9 +272,15 @@ proc create*( minVersion: tlsMinVersion, maxVersion: tlsMaxVersion) + let localAddress = + when address is string: + initTAddress(address) + else: + address + server = TlsHttpServer( createStreamServer( - address, + localAddress, handleTlsConnCb, flags, child = StreamServer(server))) @@ -259,22 +288,3 @@ proc create*( trace "Created TLS HTTP Server", host = $server.localAddress() return server - -proc create*( - _: typedesc[TlsHttpServer], - host: string, - tlsPrivateKey: TLSPrivateKey, - tlsCertificate: TLSCertificate, - handler: HttpAsyncCallback = nil, - flags: set[ServerFlags] = {}, - tlsFlags: set[TLSFlags] = {}, - tlsMinVersion = TLSVersion.TLS12, - tlsMaxVersion = TLSVersion.TLS12): TlsHttpServer - {.raises: [Defect, CatchableError].} = # TODO: remove CatchableError - TlsHttpServer.create( - address = initTAddress(host), - handler = handler, - tlsPrivateKey = tlsPrivateKey, - tlsCertificate = tlsCertificate, - flags = flags, - tlsFlags = tlsFlags)