diff --git a/chronos/apps/http/httpclient.nim b/chronos/apps/http/httpclient.nim index 6e261e8..83d5623 100644 --- a/chronos/apps/http/httpclient.nim +++ b/chronos/apps/http/httpclient.nim @@ -67,7 +67,8 @@ type Custom ## None of the above HttpClientRequestFlag* {.pure.} = enum - CloseConnection, ## Send `Connection: close` in request + DedicatedConnection, ## Create new HTTP connection for request + CloseConnection ## Send `Connection: close` in request HttpClientConnectionFlag* {.pure.} = enum Request, ## Connection has pending request @@ -110,6 +111,7 @@ type headersTimeout*: Duration connectionBufferSize*: int maxConnections*: int + connectionsCount*: int flags*: HttpClientFlags HttpAddress* = object @@ -581,11 +583,14 @@ proc connect(session: HttpSessionRef, # If all attempts to connect to the remote host have failed. raiseHttpConnectionError("Could not connect to remote host") -proc acquireConnection(session: HttpSessionRef, - ha: HttpAddress): Future[HttpClientConnectionRef] {. - async.} = +proc acquireConnection( + session: HttpSessionRef, + ha: HttpAddress, + flags: set[HttpClientRequestFlag] + ): Future[HttpClientConnectionRef] {.async.} = ## Obtain connection from ``session`` or establish a new one. - if HttpClientFlag.NewConnectionAlways in session.flags: + if (HttpClientFlag.NewConnectionAlways in session.flags) or + (HttpClientRequestFlag.DedicatedConnection in flags): var default: seq[HttpClientConnectionRef] let res = try: @@ -594,6 +599,7 @@ proc acquireConnection(session: HttpSessionRef, raiseHttpConnectionError("Connection timed out") res[].state = HttpClientConnectionState.Acquired session.connections.mgetOrPut(ha.id, default).add(res) + inc(session.connectionsCount) return res else: let conn = @@ -620,12 +626,22 @@ proc acquireConnection(session: HttpSessionRef, raiseHttpConnectionError("Connection timed out") res[].state = HttpClientConnectionState.Acquired session.connections.mgetOrPut(ha.id, default).add(res) + inc(session.connectionsCount) return res proc removeConnection(session: HttpSessionRef, conn: HttpClientConnectionRef) {.async.} = - session.connections.withValue(conn.remoteHostname, connections): - connections[].keepItIf(it != conn) + let removeHost = + block: + var res = false + session.connections.withValue(conn.remoteHostname, connections): + connections[].keepItIf(it != conn) + if len(connections[]) == 0: + res = true + res + if removeHost: + session.connections.del(conn.remoteHostname) + dec(session.connectionsCount) await conn.closeWait() proc releaseConnection(session: HttpSessionRef, @@ -807,7 +823,11 @@ proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte] res.connection.state = HttpClientConnectionState.ResponseHeadersReceived if nobodyFlag: res.connection.flags.incl(HttpClientConnectionFlag.NoBody) - if connectionFlag: + let newConnectionAlways = + HttpClientFlag.NewConnectionAlways in request.session.flags + let closeConnection = + HttpClientRequestFlag.CloseConnection in request.flags + if connectionFlag and not(newConnectionAlways) and not(closeConnection): res.connection.flags.incl(HttpClientConnectionFlag.KeepAlive) res.connection.flags.incl(HttpClientConnectionFlag.Response) trackHttpClientResponse(res) @@ -994,7 +1014,7 @@ proc send*(request: HttpClientRequestRef): Future[HttpClientResponseRef] {. "Request's state is " & $request.state) let connection = try: - await request.session.acquireConnection(request.address) + await request.session.acquireConnection(request.address, request.flags) except CancelledError as exc: request.setError(newHttpInterruptError()) raise exc @@ -1045,7 +1065,7 @@ proc open*(request: HttpClientRequestRef): Future[HttpBodyWriter] {. "Request should not have static body content (len(buffer) == 0)") let connection = try: - await request.session.acquireConnection(request.address) + await request.session.acquireConnection(request.address, request.flags) except CancelledError as exc: request.setError(newHttpInterruptError()) raise exc diff --git a/tests/testhttpclient.nim b/tests/testhttpclient.nim index 412546d..1017ccf 100644 --- a/tests/testhttpclient.nim +++ b/tests/testhttpclient.nim @@ -74,6 +74,9 @@ N8r5CwGcIX/XPC3lKazzbZ8baA== suite "HTTP client testing suite": + type + TestResponseTuple = tuple[status: int, data: string, count: int] + proc createBigMessage(message: string, size: int): seq[byte] = var res = newSeq[byte](size) for i in 0 ..< len(res): @@ -700,6 +703,164 @@ suite "HTTP client testing suite": else: return false + proc testConnectionManagement(address: TransportAddress): Future[bool] {. + async.} = + let + keepHa = getAddress(address, HttpClientScheme.NonSecure, "/keep") + dropHa = getAddress(address, HttpClientScheme.NonSecure, "/drop") + + proc test1( + a1: HttpAddress, + version: HttpVersion, + sessionFlags: set[HttpClientFlag], + requestFlags: set[HttpClientRequestFlag] + ): Future[TestResponseTuple] {.async.} = + let session = HttpSessionRef.new(flags = sessionFlags) + var + data: HttpResponseTuple + count = -1 + request = HttpClientRequestRef.new(session, a1, version = version, + flags = requestFlags) + try: + data = await request.fetch() + await request.closeWait() + count = session.connectionsCount + finally: + await session.closeWait() + return (data.status, data.data.bytesToString(), count) + + proc test2( + a1, a2: HttpAddress, + version: HttpVersion, + sessionFlags: set[HttpClientFlag], + requestFlags: set[HttpClientRequestFlag] + ): Future[seq[TestResponseTuple]] {.async.} = + let session = HttpSessionRef.new(flags = sessionFlags) + var + data1: HttpResponseTuple + data2: HttpResponseTuple + count: int = -1 + request1 = HttpClientRequestRef.new(session, a1, version = version, + flags = requestFlags) + request2 = HttpClientRequestRef.new(session, a2, version = version, + flags = requestFlags) + try: + data1 = await request1.fetch() + data2 = await request2.fetch() + await request1.closeWait() + await request2.closeWait() + count = session.connectionsCount + finally: + await session.closeWait() + return @[(data1.status, data1.data.bytesToString(), count), + (data2.status, data2.data.bytesToString(), count)] + + proc process(r: RequestFence): Future[HttpResponseRef] {.async.} = + if r.isOk(): + let request = r.get() + case request.uri.path + of "/keep": + let headers = HttpTable.init([("connection", "keep-alive")]) + return await request.respond(Http200, "ok", headers = headers) + of "/drop": + let headers = HttpTable.init([("connection", "close")]) + return await request.respond(Http200, "ok", headers = headers) + else: + return await request.respond(Http404, "Page not found") + else: + return dumbResponse() + + var server = createServer(address, process, false) + server.start() + + try: + let + r1 = await test1(keepHa, HttpVersion10, {}, {}) + r2 = await test1(keepHa, HttpVersion10, + {HttpClientFlag.NewConnectionAlways}, {}) + r3 = await test1(keepHa, HttpVersion10, {}, + {HttpClientRequestFlag.DedicatedConnection}) + r4 = await test1(keepHa, HttpVersion10, {}, + {HttpClientRequestFlag.DedicatedConnection, + HttpClientRequestFlag.CloseConnection}) + r5 = await test1(dropHa, HttpVersion10, {}, {}) + r6 = await test1(dropHa, HttpVersion10, + {HttpClientFlag.NewConnectionAlways}, {}) + r7 = await test1(dropHa, HttpVersion10, {}, + {HttpClientRequestFlag.DedicatedConnection}) + r8 = await test1(dropHa, HttpVersion10, {}, + {HttpClientRequestFlag.DedicatedConnection, + HttpClientRequestFlag.CloseConnection}) + check: + r1 == (200, "ok", 0) + r2 == (200, "ok", 0) + r3 == (200, "ok", 0) + r4 == (200, "ok", 0) + r5 == (200, "ok", 0) + r6 == (200, "ok", 0) + r7 == (200, "ok", 0) + r8 == (200, "ok", 0) + + let + d1 = await test2(keepHa, dropHa, HttpVersion10, {}, {}) + d2 = await test2(keepHa, dropHa, HttpVersion10, + {HttpClientFlag.NewConnectionAlways}, {}) + d3 = await test2(keepHa, dropHa, HttpVersion10, {}, + {HttpClientRequestFlag.DedicatedConnection}) + d4 = await test2(keepHa, dropHa, HttpVersion10, {}, + {HttpClientRequestFlag.DedicatedConnection, + HttpClientRequestFlag.CloseConnection}) + d5 = await test2(dropHa, keepHa, HttpVersion10, {}, {}) + d6 = await test2(dropHa, keepHa, HttpVersion10, + {HttpClientFlag.NewConnectionAlways}, {}) + d7 = await test2(dropHa, keepHa, HttpVersion10, {}, + {HttpClientRequestFlag.DedicatedConnection}) + d8 = await test2(dropHa, keepHa, HttpVersion10, {}, + {HttpClientRequestFlag.DedicatedConnection, + HttpClientRequestFlag.CloseConnection}) + check: + d1 == @[(200, "ok", 0), (200, "ok", 0)] + d2 == @[(200, "ok", 0), (200, "ok", 0)] + d3 == @[(200, "ok", 0), (200, "ok", 0)] + d4 == @[(200, "ok", 0), (200, "ok", 0)] + d5 == @[(200, "ok", 0), (200, "ok", 0)] + d6 == @[(200, "ok", 0), (200, "ok", 0)] + d7 == @[(200, "ok", 0), (200, "ok", 0)] + d8 == @[(200, "ok", 0), (200, "ok", 0)] + + let + n1 = await test1(keepHa, HttpVersion11, {}, {}) + n2 = await test2(keepHa, keepHa, HttpVersion11, {}, {}) + n3 = await test1(dropHa, HttpVersion11, {}, {}) + n4 = await test2(dropHa, dropHa, HttpVersion11, {}, {}) + n5 = await test1(keepHa, HttpVersion11, + {HttpClientFlag.NewConnectionAlways}, {}) + n6 = await test1(keepHa, HttpVersion11, {}, + {HttpClientRequestFlag.DedicatedConnection}) + n7 = await test1(keepHa, HttpVersion11, {}, + {HttpClientRequestFlag.DedicatedConnection, + HttpClientRequestFlag.CloseConnection}) + n8 = await test1(keepHa, HttpVersion11, {}, + {HttpClientRequestFlag.CloseConnection}) + n9 = await test1(keepHa, HttpVersion11, + {HttpClientFlag.NewConnectionAlways}, + {HttpClientRequestFlag.CloseConnection}) + check: + n1 == (200, "ok", 1) + n2 == @[(200, "ok", 2), (200, "ok", 2)] + n3 == (200, "ok", 0) + n4 == @[(200, "ok", 0), (200, "ok", 0)] + n5 == (200, "ok", 0) + n6 == (200, "ok", 1) + n7 == (200, "ok", 0) + n8 == (200, "ok", 0) + n9 == (200, "ok", 0) + finally: + await server.stop() + await server.closeWait() + + return true + test "HTTP all request methods test": let address = initTAddress("127.0.0.1:30080") check waitFor(testMethods(address, false)) == 18 @@ -767,6 +928,10 @@ suite "HTTP client testing suite": test "HTTPS basic authorization test": check waitFor(testBasicAuthorization()) == true + test "HTTP client connection management test": + let address = initTAddress("127.0.0.1:30080") + check waitFor(testConnectionManagement(address)) == true + test "Leaks test": proc getTrackerLeaks(tracker: string): bool = let tracker = getTracker(tracker)