mirror of
https://github.com/status-im/nim-chronos.git
synced 2025-02-03 15:04:38 +00:00
HTTP client: Allow request connection management. (#323)
* Allow per-request connection management. Fix NewConnectionAlways leak issue. * Address review comment.
This commit is contained in:
parent
9df76c39df
commit
266e2c0ed2
@ -67,7 +67,8 @@ type
|
||||
Custom ## None of the above
|
||||
|
||||
HttpClientRequestFlag* {.pure.} = enum
|
||||
CloseConnection, ## Send `Connection: close` in request
|
||||
DedicatedConnection, ## Create new HTTP connection for request
|
||||
CloseConnection ## Send `Connection: close` in request
|
||||
|
||||
HttpClientConnectionFlag* {.pure.} = enum
|
||||
Request, ## Connection has pending request
|
||||
@ -110,6 +111,7 @@ type
|
||||
headersTimeout*: Duration
|
||||
connectionBufferSize*: int
|
||||
maxConnections*: int
|
||||
connectionsCount*: int
|
||||
flags*: HttpClientFlags
|
||||
|
||||
HttpAddress* = object
|
||||
@ -581,11 +583,14 @@ proc connect(session: HttpSessionRef,
|
||||
# If all attempts to connect to the remote host have failed.
|
||||
raiseHttpConnectionError("Could not connect to remote host")
|
||||
|
||||
proc acquireConnection(session: HttpSessionRef,
|
||||
ha: HttpAddress): Future[HttpClientConnectionRef] {.
|
||||
async.} =
|
||||
proc acquireConnection(
|
||||
session: HttpSessionRef,
|
||||
ha: HttpAddress,
|
||||
flags: set[HttpClientRequestFlag]
|
||||
): Future[HttpClientConnectionRef] {.async.} =
|
||||
## Obtain connection from ``session`` or establish a new one.
|
||||
if HttpClientFlag.NewConnectionAlways in session.flags:
|
||||
if (HttpClientFlag.NewConnectionAlways in session.flags) or
|
||||
(HttpClientRequestFlag.DedicatedConnection in flags):
|
||||
var default: seq[HttpClientConnectionRef]
|
||||
let res =
|
||||
try:
|
||||
@ -594,6 +599,7 @@ proc acquireConnection(session: HttpSessionRef,
|
||||
raiseHttpConnectionError("Connection timed out")
|
||||
res[].state = HttpClientConnectionState.Acquired
|
||||
session.connections.mgetOrPut(ha.id, default).add(res)
|
||||
inc(session.connectionsCount)
|
||||
return res
|
||||
else:
|
||||
let conn =
|
||||
@ -620,12 +626,22 @@ proc acquireConnection(session: HttpSessionRef,
|
||||
raiseHttpConnectionError("Connection timed out")
|
||||
res[].state = HttpClientConnectionState.Acquired
|
||||
session.connections.mgetOrPut(ha.id, default).add(res)
|
||||
inc(session.connectionsCount)
|
||||
return res
|
||||
|
||||
proc removeConnection(session: HttpSessionRef,
|
||||
conn: HttpClientConnectionRef) {.async.} =
|
||||
session.connections.withValue(conn.remoteHostname, connections):
|
||||
connections[].keepItIf(it != conn)
|
||||
let removeHost =
|
||||
block:
|
||||
var res = false
|
||||
session.connections.withValue(conn.remoteHostname, connections):
|
||||
connections[].keepItIf(it != conn)
|
||||
if len(connections[]) == 0:
|
||||
res = true
|
||||
res
|
||||
if removeHost:
|
||||
session.connections.del(conn.remoteHostname)
|
||||
dec(session.connectionsCount)
|
||||
await conn.closeWait()
|
||||
|
||||
proc releaseConnection(session: HttpSessionRef,
|
||||
@ -807,7 +823,11 @@ proc prepareResponse(request: HttpClientRequestRef, data: openArray[byte]
|
||||
res.connection.state = HttpClientConnectionState.ResponseHeadersReceived
|
||||
if nobodyFlag:
|
||||
res.connection.flags.incl(HttpClientConnectionFlag.NoBody)
|
||||
if connectionFlag:
|
||||
let newConnectionAlways =
|
||||
HttpClientFlag.NewConnectionAlways in request.session.flags
|
||||
let closeConnection =
|
||||
HttpClientRequestFlag.CloseConnection in request.flags
|
||||
if connectionFlag and not(newConnectionAlways) and not(closeConnection):
|
||||
res.connection.flags.incl(HttpClientConnectionFlag.KeepAlive)
|
||||
res.connection.flags.incl(HttpClientConnectionFlag.Response)
|
||||
trackHttpClientResponse(res)
|
||||
@ -994,7 +1014,7 @@ proc send*(request: HttpClientRequestRef): Future[HttpClientResponseRef] {.
|
||||
"Request's state is " & $request.state)
|
||||
let connection =
|
||||
try:
|
||||
await request.session.acquireConnection(request.address)
|
||||
await request.session.acquireConnection(request.address, request.flags)
|
||||
except CancelledError as exc:
|
||||
request.setError(newHttpInterruptError())
|
||||
raise exc
|
||||
@ -1045,7 +1065,7 @@ proc open*(request: HttpClientRequestRef): Future[HttpBodyWriter] {.
|
||||
"Request should not have static body content (len(buffer) == 0)")
|
||||
let connection =
|
||||
try:
|
||||
await request.session.acquireConnection(request.address)
|
||||
await request.session.acquireConnection(request.address, request.flags)
|
||||
except CancelledError as exc:
|
||||
request.setError(newHttpInterruptError())
|
||||
raise exc
|
||||
|
@ -74,6 +74,9 @@ N8r5CwGcIX/XPC3lKazzbZ8baA==
|
||||
|
||||
suite "HTTP client testing suite":
|
||||
|
||||
type
|
||||
TestResponseTuple = tuple[status: int, data: string, count: int]
|
||||
|
||||
proc createBigMessage(message: string, size: int): seq[byte] =
|
||||
var res = newSeq[byte](size)
|
||||
for i in 0 ..< len(res):
|
||||
@ -700,6 +703,164 @@ suite "HTTP client testing suite":
|
||||
else:
|
||||
return false
|
||||
|
||||
proc testConnectionManagement(address: TransportAddress): Future[bool] {.
|
||||
async.} =
|
||||
let
|
||||
keepHa = getAddress(address, HttpClientScheme.NonSecure, "/keep")
|
||||
dropHa = getAddress(address, HttpClientScheme.NonSecure, "/drop")
|
||||
|
||||
proc test1(
|
||||
a1: HttpAddress,
|
||||
version: HttpVersion,
|
||||
sessionFlags: set[HttpClientFlag],
|
||||
requestFlags: set[HttpClientRequestFlag]
|
||||
): Future[TestResponseTuple] {.async.} =
|
||||
let session = HttpSessionRef.new(flags = sessionFlags)
|
||||
var
|
||||
data: HttpResponseTuple
|
||||
count = -1
|
||||
request = HttpClientRequestRef.new(session, a1, version = version,
|
||||
flags = requestFlags)
|
||||
try:
|
||||
data = await request.fetch()
|
||||
await request.closeWait()
|
||||
count = session.connectionsCount
|
||||
finally:
|
||||
await session.closeWait()
|
||||
return (data.status, data.data.bytesToString(), count)
|
||||
|
||||
proc test2(
|
||||
a1, a2: HttpAddress,
|
||||
version: HttpVersion,
|
||||
sessionFlags: set[HttpClientFlag],
|
||||
requestFlags: set[HttpClientRequestFlag]
|
||||
): Future[seq[TestResponseTuple]] {.async.} =
|
||||
let session = HttpSessionRef.new(flags = sessionFlags)
|
||||
var
|
||||
data1: HttpResponseTuple
|
||||
data2: HttpResponseTuple
|
||||
count: int = -1
|
||||
request1 = HttpClientRequestRef.new(session, a1, version = version,
|
||||
flags = requestFlags)
|
||||
request2 = HttpClientRequestRef.new(session, a2, version = version,
|
||||
flags = requestFlags)
|
||||
try:
|
||||
data1 = await request1.fetch()
|
||||
data2 = await request2.fetch()
|
||||
await request1.closeWait()
|
||||
await request2.closeWait()
|
||||
count = session.connectionsCount
|
||||
finally:
|
||||
await session.closeWait()
|
||||
return @[(data1.status, data1.data.bytesToString(), count),
|
||||
(data2.status, data2.data.bytesToString(), count)]
|
||||
|
||||
proc process(r: RequestFence): Future[HttpResponseRef] {.async.} =
|
||||
if r.isOk():
|
||||
let request = r.get()
|
||||
case request.uri.path
|
||||
of "/keep":
|
||||
let headers = HttpTable.init([("connection", "keep-alive")])
|
||||
return await request.respond(Http200, "ok", headers = headers)
|
||||
of "/drop":
|
||||
let headers = HttpTable.init([("connection", "close")])
|
||||
return await request.respond(Http200, "ok", headers = headers)
|
||||
else:
|
||||
return await request.respond(Http404, "Page not found")
|
||||
else:
|
||||
return dumbResponse()
|
||||
|
||||
var server = createServer(address, process, false)
|
||||
server.start()
|
||||
|
||||
try:
|
||||
let
|
||||
r1 = await test1(keepHa, HttpVersion10, {}, {})
|
||||
r2 = await test1(keepHa, HttpVersion10,
|
||||
{HttpClientFlag.NewConnectionAlways}, {})
|
||||
r3 = await test1(keepHa, HttpVersion10, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection})
|
||||
r4 = await test1(keepHa, HttpVersion10, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection,
|
||||
HttpClientRequestFlag.CloseConnection})
|
||||
r5 = await test1(dropHa, HttpVersion10, {}, {})
|
||||
r6 = await test1(dropHa, HttpVersion10,
|
||||
{HttpClientFlag.NewConnectionAlways}, {})
|
||||
r7 = await test1(dropHa, HttpVersion10, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection})
|
||||
r8 = await test1(dropHa, HttpVersion10, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection,
|
||||
HttpClientRequestFlag.CloseConnection})
|
||||
check:
|
||||
r1 == (200, "ok", 0)
|
||||
r2 == (200, "ok", 0)
|
||||
r3 == (200, "ok", 0)
|
||||
r4 == (200, "ok", 0)
|
||||
r5 == (200, "ok", 0)
|
||||
r6 == (200, "ok", 0)
|
||||
r7 == (200, "ok", 0)
|
||||
r8 == (200, "ok", 0)
|
||||
|
||||
let
|
||||
d1 = await test2(keepHa, dropHa, HttpVersion10, {}, {})
|
||||
d2 = await test2(keepHa, dropHa, HttpVersion10,
|
||||
{HttpClientFlag.NewConnectionAlways}, {})
|
||||
d3 = await test2(keepHa, dropHa, HttpVersion10, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection})
|
||||
d4 = await test2(keepHa, dropHa, HttpVersion10, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection,
|
||||
HttpClientRequestFlag.CloseConnection})
|
||||
d5 = await test2(dropHa, keepHa, HttpVersion10, {}, {})
|
||||
d6 = await test2(dropHa, keepHa, HttpVersion10,
|
||||
{HttpClientFlag.NewConnectionAlways}, {})
|
||||
d7 = await test2(dropHa, keepHa, HttpVersion10, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection})
|
||||
d8 = await test2(dropHa, keepHa, HttpVersion10, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection,
|
||||
HttpClientRequestFlag.CloseConnection})
|
||||
check:
|
||||
d1 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
d2 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
d3 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
d4 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
d5 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
d6 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
d7 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
d8 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
|
||||
let
|
||||
n1 = await test1(keepHa, HttpVersion11, {}, {})
|
||||
n2 = await test2(keepHa, keepHa, HttpVersion11, {}, {})
|
||||
n3 = await test1(dropHa, HttpVersion11, {}, {})
|
||||
n4 = await test2(dropHa, dropHa, HttpVersion11, {}, {})
|
||||
n5 = await test1(keepHa, HttpVersion11,
|
||||
{HttpClientFlag.NewConnectionAlways}, {})
|
||||
n6 = await test1(keepHa, HttpVersion11, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection})
|
||||
n7 = await test1(keepHa, HttpVersion11, {},
|
||||
{HttpClientRequestFlag.DedicatedConnection,
|
||||
HttpClientRequestFlag.CloseConnection})
|
||||
n8 = await test1(keepHa, HttpVersion11, {},
|
||||
{HttpClientRequestFlag.CloseConnection})
|
||||
n9 = await test1(keepHa, HttpVersion11,
|
||||
{HttpClientFlag.NewConnectionAlways},
|
||||
{HttpClientRequestFlag.CloseConnection})
|
||||
check:
|
||||
n1 == (200, "ok", 1)
|
||||
n2 == @[(200, "ok", 2), (200, "ok", 2)]
|
||||
n3 == (200, "ok", 0)
|
||||
n4 == @[(200, "ok", 0), (200, "ok", 0)]
|
||||
n5 == (200, "ok", 0)
|
||||
n6 == (200, "ok", 1)
|
||||
n7 == (200, "ok", 0)
|
||||
n8 == (200, "ok", 0)
|
||||
n9 == (200, "ok", 0)
|
||||
finally:
|
||||
await server.stop()
|
||||
await server.closeWait()
|
||||
|
||||
return true
|
||||
|
||||
test "HTTP all request methods test":
|
||||
let address = initTAddress("127.0.0.1:30080")
|
||||
check waitFor(testMethods(address, false)) == 18
|
||||
@ -767,6 +928,10 @@ suite "HTTP client testing suite":
|
||||
test "HTTPS basic authorization test":
|
||||
check waitFor(testBasicAuthorization()) == true
|
||||
|
||||
test "HTTP client connection management test":
|
||||
let address = initTAddress("127.0.0.1:30080")
|
||||
check waitFor(testConnectionManagement(address)) == true
|
||||
|
||||
test "Leaks test":
|
||||
proc getTrackerLeaks(tracker: string): bool =
|
||||
let tracker = getTracker(tracker)
|
||||
|
Loading…
x
Reference in New Issue
Block a user