diff --git a/presto/client.nim b/presto/client.nim index eacf3fc..860f0ea 100644 --- a/presto/client.nim +++ b/presto/client.nim @@ -54,6 +54,7 @@ const RestContentTypeArg = "restContentType" RestAcceptTypeArg = "restAcceptType" RestClientArg = "restClient" + ExtraHeadersArg = "extraHeaders" NotAllowedArgumentNames = [RestClientArg, RestContentTypeArg, RestAcceptTypeArg] @@ -119,30 +120,35 @@ proc closeWait*(client: RestClientRef) {.async.} = proc createPostRequest*(client: RestClientRef, path: string, query: string, contentType: string, acceptType: string, + extraHeaders: openArray[HttpHeaderTuple], meth: HttpMethod, contentLength: uint64): HttpClientRequestRef = var address = client.address address.path = path address.query = query - let headers = - [ - ("content-type", contentType), - ("content-length", Base10.toString(contentLength)), - ("accept", acceptType), - ("user-agent", client.agent) - ] + + var headers = newSeqOfCap[HttpHeaderTuple](4 + extraHeaders.len) + headers.add(("content-type", contentType)) + headers.add(("content-length", Base10.toString(contentLength))) + headers.add(("accept", acceptType)) + headers.add(("user-agent", client.agent)) + headers.add extraHeaders + HttpClientRequestRef.new(client.session, address, meth, headers = headers) proc createGetRequest*(client: RestClientRef, path: string, query: string, contentType: string, acceptType: string, + extraHeaders: openArray[HttpHeaderTuple], meth: HttpMethod): HttpClientRequestRef = var address = client.address address.path = path address.query = query - let headers = [ - ("accept", acceptType), - ("user-agent", client.agent) - ] + + var headers = newSeqOfCap[HttpHeaderTuple](2 + extraHeaders.len) + headers.add(("accept", acceptType)) + headers.add(("user-agent", client.agent)) + headers.add extraHeaders + HttpClientRequestRef.new(client.session, address, meth, headers = headers) proc getEndpointOrDefault(prc: NimNode, @@ -269,6 +275,7 @@ proc isPostMethod(node: NimNode): bool {.compileTime.} = proc transformProcDefinition(prc: NimNode, clientIdent: NimNode, contentIdent: NimNode, acceptIdent: NimNode, + extraHeadersIdent: NimNode, acceptValue: NimNode, stmtList: NimNode): NimNode {.compileTime.} = var procdef = copyNimTree(prc) @@ -280,6 +287,10 @@ proc transformProcDefinition(prc: NimNode, clientIdent: NimNode, let contentTypeArg = newTree(nnkIdentDefs, contentIdent, newIdentNode("string"), newStrLitNode("application/json")) + let extraHeadersArg = + newTree(nnkIdentDefs, extraHeadersIdent, + newTree(nnkBracketExpr, ident"seq", ident"HttpHeaderTuple"), + newTree(nnkPrefix, ident"@", newTree(nnkBracket))) let acceptTypeArg = newTree(nnkIdentDefs, acceptIdent, newIdentNode("string"), acceptValue) @@ -308,6 +319,7 @@ proc transformProcDefinition(prc: NimNode, clientIdent: NimNode, res.insert(clientArg, 1) res.add(contentTypeArg) res.add(acceptTypeArg) + res.add(extraHeadersArg) res var newPragmas = @@ -597,6 +609,7 @@ proc restSingleProc(prc: NimNode): NimNode {.compileTime.} = clientIdent = newIdentNode(RestClientArg) contentTypeIdent = newIdentNode(RestContentTypeArg) acceptTypeIdent = newIdentNode(RestAcceptTypeArg) + extraHeadersIdent = newIdentNode(ExtraHeadersArg) var statements = newStmtList() @@ -614,7 +627,7 @@ proc restSingleProc(prc: NimNode): NimNode {.compileTime.} = block: parameters.expectMinLen(1) if parameters[0].kind == nnkEmpty: - error("REST procedure should no\\\ave empty return value", parameters) + error("REST procedure should non have empty return value", parameters) let node = copyNimTree(parameters[0]) case node.kind of nnkIdent: @@ -848,7 +861,8 @@ proc restSingleProc(prc: NimNode): NimNode {.compileTime.} = let chunkSize = `clientIdent`.session.connectionBufferSize let `requestIdent` = createPostRequest( `clientIdent`, `requestPath`, `requestQuery`, - `contentTypeIdent`, `acceptTypeIdent`, `meth`, + `contentTypeIdent`, `acceptTypeIdent`, + `extraHeadersIdent`, `meth`, uint64(len(`bodyIdent`)) ) await requestWithBody(`requestIdent`, @@ -861,7 +875,8 @@ proc restSingleProc(prc: NimNode): NimNode {.compileTime.} = block: let `requestIdent` = createGetRequest( `clientIdent`, `requestPath`, `requestQuery`, - `contentTypeIdent`, `acceptTypeIdent`, `meth` + `contentTypeIdent`, `acceptTypeIdent`, + `extraHeadersIdent`, `meth` ) await requestWithoutBody(`requestIdent`, `requestFlagsIdent`) @@ -906,8 +921,8 @@ proc restSingleProc(prc: NimNode): NimNode {.compileTime.} = return `responseResultIdent` let res = transformProcDefinition(prc, clientIdent, contentTypeIdent, - acceptTypeIdent, newStrLitNode(accept), - statements) + acceptTypeIdent, extraHeadersIdent, + newStrLitNode(accept), statements) res macro rest*(prc: untyped): untyped = diff --git a/tests/testclient.nim b/tests/testclient.nim index 3908e2f..b48c801 100644 --- a/tests/testclient.nim +++ b/tests/testclient.nim @@ -57,6 +57,12 @@ suite "REST API client test suite": "nobody" return RestApiResponse.response(obody) + router.api(MethodGet, "/test/echo-authorization") do () -> RestApiResponse: + return RestApiResponse.response(request.headers.getString("Authorization")) + + router.api(MethodPost, "/test/echo-authorization") do () -> RestApiResponse: + return RestApiResponse.response(request.headers.getString("Authorization")) + let serverFlags = {HttpServerFlags.NotifyDisconnect, HttpServerFlags.QueryCommaSeparatedArray} var sres = RestServerRef.new(router, serverAddress, @@ -75,6 +81,12 @@ suite "REST API client test suite": proc testSimple6(body: string): string {.rest, endpoint: "/test/simple/6", meth: HttpMethod.MethodPost.} + proc testEchoAuthorizationPost(body: string): string + {.rest, endpoint: "/test/echo-authorization", meth: HttpMethod.MethodPost.} + + proc testEchoAuthorizationGet(): string + {.rest, endpoint: "/test/echo-authorization", meth: HttpMethod.MethodGet.} + var client = RestClientRef.new(serverAddress, HttpClientScheme.NonSecure) let res1 = await client.testSimple1() let res2 = await client.testSimple2("ok-2", restContentType = "text/text") @@ -117,6 +129,15 @@ suite "REST API client test suite": message == "Different error" contentType == "application/error" + block: + let postRes = await client.testEchoAuthorizationPost( + body = "{}", + extraHeaders = @[("Authorization", "Bearer XXX")]) + check postRes == "Bearer XXX" + + let getRes = await client.testEchoAuthorizationGet(extraHeaders = @[("Authorization", "Bearer XYZ")]) + check getRes == "Bearer XYZ" + await client.closeWait() await server.stop() await server.closeWait()