diff --git a/presto/macrocommon.nim b/presto/macrocommon.nim index 62e2f21..437d2b0 100644 --- a/presto/macrocommon.nim +++ b/presto/macrocommon.nim @@ -35,7 +35,7 @@ proc makeProcName*(m, s: string): string = proc getRestReturnType*(params: NimNode): NimNode = if not(isNil(params)) and (len(params) > 0) and not(isNil(params[0])) and - (params[0].kind == nnkIdent): + (params[0].kind in {nnkIdent, nnkSym}): params[0] else: nil @@ -47,40 +47,53 @@ iterator paramsIter*(params: NimNode): tuple[name, ntype: NimNode] = for j in 0 ..< arg.len-2: yield (arg[j], argType) -proc isSimpleType*(typeNode: NimNode): bool = - typeNode.kind == nnkIdent +proc isKnownType*(typeNode: NimNode, typeNames: varargs[string]): bool = + typeNode.kind in {nnkIdent, nnkSym} and + $typeNode in typeNames + +proc isBracketExpr(n: NimNode, nodes: varargs[string]): bool = + let leadingIdx = if n.kind == nnkBracketExpr: + 0 + elif n.kind == nnkCall and + n[0].kind in {nnkOpenSymChoice, nnkClosedSymChoice} and + n[0].len > 0 and + $n[0][0] == "[]": + 1 + else: + return false + + for idx, types in nodes: + let actualIdx = leadingIdx + idx + if actualIdx > n.len: + return false + if not isKnownType(n[actualIdx], types.split("|")): + return false + + return true proc isOptionalArg*(typeNode: NimNode): bool = - (typeNode.kind == nnkBracketExpr) and (typeNode[0].kind == nnkIdent) and - (typeNode[0].strVal == "Option") + typeNode.isBracketExpr "Option" proc isBytesArg*(typeNode: NimNode): bool = - (typeNode.kind == nnkBracketExpr) and (typeNode[0].kind == nnkIdent) and - (typeNode[0].strVal == "seq") and (typeNode[1].kind == nnkIdent) and - ((typeNode[1].strVal == "byte") or (typeNode[1].strVal == "uint8")) + typeNode.isBracketExpr("seq", "byte|uint8") proc isSequenceArg*(typeNode: NimNode): bool = - (typeNode.kind == nnkBracketExpr) and (typeNode[0].kind == nnkIdent) and - (typeNode[0].strVal == "seq") + typeNode.isBracketExpr("seq") proc isContentBodyArg*(typeNode: NimNode): bool = - (typeNode.kind == nnkBracketExpr) and (typeNode[0].kind == nnkIdent) and - (typeNode[0].strVal == "Option") and (typeNode[1].kind == nnkIdent) and - (typeNode[1].strVal == "ContentBody") + typeNode.isBracketExpr("Option", "ContentBody") proc isResponseArg*(typeNode: NimNode): bool = - (typeNode.kind == nnkIdent) and (typeNode.strVal == "HttpResponseRef") + typeNode.isKnownType "HttpResponseRef" proc getSequenceType*(typeNode: NimNode): NimNode = - if (typeNode.kind == nnkBracketExpr) and (typeNode[0].kind == nnkIdent) and - (typeNode[0].strVal == "seq"): + if typeNode.isBracketExpr("seq"): typeNode[1] else: nil proc getOptionType*(typeNode: NimNode): NimNode = - if (typeNode.kind == nnkBracketExpr) and (typeNode[0].kind == nnkIdent) and - (typeNode[0].strVal == "Option"): + if typeNode.isBracketExpr("Option"): typeNode[1] else: nil diff --git a/presto/route.nim b/presto/route.nim index 465b715..afdd3a3 100644 --- a/presto/route.nim +++ b/presto/route.nim @@ -210,8 +210,7 @@ proc processApiCall(router: NimNode, meth: HttpMethod, for paramName, paramType in parameters.paramsIter(): let index = patterns.find($paramName) if isPathArg(paramType): - if isSimpleType(paramType) and - (paramType.strVal == "HttpResponseRef"): + if paramType.isKnownType("HttpResponseRef"): if isNil(respRes): respRes = paramName else: @@ -257,7 +256,7 @@ proc processApiCall(router: NimNode, meth: HttpMethod, error("Return value must not be empty and equal to [RestApiResponse]", parameters) else: - if returnType.strVal != "RestApiResponse": + if not returnType.isKnownType("RestApiResponse"): error("Return value must be equal to [RestApiResponse]", returnType) # "path" (required) arguments unmarshalling code. diff --git a/tests/testroute.nim b/tests/testroute.nim index 63ce181..bfb7724 100644 --- a/tests/testroute.nim +++ b/tests/testroute.nim @@ -1,4 +1,4 @@ -import std/[unittest, strutils, parseutils] +import std/[unittest, strutils, parseutils, typetraits] import helpers import chronos, chronos/apps import stew/byteutils @@ -151,6 +151,25 @@ suite "REST API router & macro tests": r2.kind == RestApiResponseKind.Content bytesToString(r2.content.data) == "ok-2" + test "Routes installation from generic proc": + proc addGenericRoute(router: var RestRouter, T: type) = + const typeName = typetraits.name(T) + router.api(MethodGet, "/test/" & typeName) do () -> RestApiResponse: + return RestApiResponse.response(typeName, contentType = "text/plain") + + var router = RestRouter.init(testValidate) + router.addGenericRoute(string) + router.addGenericRoute(int) + + let r1 = router.sendMockRequest(MethodGet, "http://l.to/test/string") + let r2 = router.sendMockRequest(MethodGet, "http://l.to/test/int") + + check: + r1.kind == RestApiResponseKind.Content + r2.kind == RestApiResponseKind.Content + bytesToString(r1.content.data) == "string" + bytesToString(r2.content.data) == "int" + test "Custom types as parameters test": var router = RestRouter.init(testValidate) router.api(MethodPost,