From 7ab114e1596f553dc09973927807f5580d5fa6bf Mon Sep 17 00:00:00 2001 From: Zahary Karadjov Date: Thu, 2 Dec 2021 17:59:44 +0200 Subject: [PATCH] Add templateutils.evalTemplateParamOnce --- stew/templateutils.nim | 53 +++++++++++++++++++++++++++++ tests/all_tests.nim | 1 + tests/test_templateutils.nim | 64 ++++++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 stew/templateutils.nim create mode 100644 tests/test_templateutils.nim diff --git a/stew/templateutils.nim b/stew/templateutils.nim new file mode 100644 index 0000000..f3af705 --- /dev/null +++ b/stew/templateutils.nim @@ -0,0 +1,53 @@ +type CppVar[T] = distinct ptr T + +iterator evalTemplateParamOnceImpl[T](x: T): lent T = + yield x + +when defined(cpp): + # TODO `nim cpp` miscompiles iterators returning `var`, + # so we need to emulate them in terms of pointers: + iterator evalTemplateParamOnceImpl[T](x: var T): CppVar[T] = + yield CppVar[T](addr(x)) + + template stripCppVar[T](p: CppVar[T]): var T = + ((ptr T)(p))[] +else: + iterator evalTemplateParamOnceImpl[T](x: var T): var T = + yield x + +template evalTemplateParamOnce*(templateParam, newName, blk: untyped) = + ## This can be used in templates to avoid the problem of multiple + ## evaluation of template parameters. Compared to the naive approach + ## of introducing an additional local variable, it has two benefits: + ## + ## * It avoids copying whenever possible. + ## * It works for var parameters. + ## + ## Usage example: + ## + ## template foo(xParam: SomeType) = + ## evalTemplateParamOnce(xParam, x): + ## echo x + ## echo x + ## + ## A currently existing limitation is that the `evalTemplateParamOnce` + ## block is considered a `void` expression, so templates returning + ## expressions may find it difficult to benefit fully from the construct. + ## + ## Please also note that using conrol-flow statements such as `return`, + ## `continue` and `break` within the template code is possible, but + ## extra care must be taken to ensure that they are not referring to the + ## inserted `for` loop (you may need to introduce enclosing named blocks + ## for correct implementation of both `break` and `continue`). + ## + ## Both limitations will be lifted in a future implementation based on + ## view types. + block: + for paramAddr in evalTemplateParamOnceImpl(templateParam): + template newName: auto = + when paramAddr is CppVar: + stripCppVar(paramAddr) + else: + paramAddr + + blk diff --git a/tests/all_tests.nim b/tests/all_tests.nim index 541959c..1be57e3 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -27,6 +27,7 @@ import test_ptrops, test_sequtils2, test_sorted_set, + test_templateutils, test_results, test_varints, test_winacl diff --git a/tests/test_templateutils.nim b/tests/test_templateutils.nim new file mode 100644 index 0000000..b597739 --- /dev/null +++ b/tests/test_templateutils.nim @@ -0,0 +1,64 @@ +import + unittest, + ../stew/templateutils + +var computations = newSeq[string]() +var templateParamAddresses = newSeq[pointer]() + +type + ObjectHoldingSeq = object + data: seq[int] + +proc accessSeq(x: var ObjectHoldingSeq): var seq[int] = + computations.add("accessor") + x.data + +proc expensiveComputation(evaluationLabel: string): seq[int] = + computations.add evaluationLabel + return @[1, 2, 3] + +template reject(code: untyped) = + static: assert(not compiles(code)) + +template evalManyTimes(xParam: untyped, shouldBeMutable: bool): string = + var res: string + evalTemplateParamOnce(xParam, x): + res = $x + + when shouldBeMutable: + x.add 10 + else: + reject: + x.add(10) + + res.add " => " + res.add $x + + templateParamAddresses.add(unsafeAddr x) + res + +test "Template utils": + # Pass function call + check "@[1, 2, 3] => @[1, 2, 3]" == evalManyTimes( + expensiveComputation("call"), shouldBeMutable = false) + + # Pass var symbol + var s1 = expensiveComputation("var") + check "@[1, 2, 3] => @[1, 2, 3, 10]" == evalManyTimes( + s1, shouldBeMutable = true) + + # Pass let symbol: + let s2 = expensiveComputation("let") + check "@[1, 2, 3] => @[1, 2, 3]" == evalManyTimes( + s2, shouldBeMutable = false) + + var o = ObjectHoldingSeq(data: @[1, 2, 3]) + check "@[1, 2, 3] => @[1, 2, 3, 10]" == evalManyTimes( + o.accessSeq, shouldBeMutable = true) + + check computations == ["call", "var", "let", "accessor"] + + check: + templateParamAddresses[1] == addr s1 + templateParamAddresses[2] == unsafeAddr s2 + templateParamAddresses[3] == addr o.accessSeq