staticfor: compile-time loop unrolling (#232)

* staticfor: compile-time loop unrolling

* better code, preserve line info

* one more line info

* license
This commit is contained in:
Jacek Sieka 2024-09-24 10:36:50 +02:00 committed by GitHub
parent 68e8ae6413
commit 41f48efee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 0 deletions

43
stew/staticfor.nim Normal file
View File

@ -0,0 +1,43 @@
# stew
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or
# http://www.apache.org/licenses/LICENSE-2.0)
# * MIT license ([LICENSE-MIT](LICENSE-MIT) or
# http://opensource.org/licenses/MIT)
# at your option. This file may not be copied, modified, or distributed except
# according to those terms.
import std/macros
proc replaceNodes(ast: NimNode, what: NimNode, by: NimNode): NimNode =
# Replace "what" ident node by "by"
proc inspect(node: NimNode): NimNode =
case node.kind:
of {nnkIdent, nnkSym}:
if node.eqIdent(what):
by
else:
node
of nnkEmpty, nnkLiterals:
node
else:
let rTree = newNimNode(node.kind, lineInfoFrom = node)
for child in node:
rTree.add inspect(child)
rTree
inspect(ast)
macro staticFor*(idx: untyped{nkIdent}, slice: static Slice[int], body: untyped): untyped =
## Unrolled `for` loop over the given range:
##
## ```nim
## staticFor(i, 0..<2):
## echo default(array[i, byte])
## ```
result = newNimNode(nnkStmtList, lineInfoFrom = body)
for i in slice:
result.add nnkBlockStmt.newTree(
ident(":staticFor" & $idx & $i),
body.replaceNodes(idx, newLit i)
)

View File

@ -32,6 +32,7 @@ import
test_ptrops, test_ptrops,
test_sequtils2, test_sequtils2,
test_sets, test_sets,
test_staticfor,
test_strformat, test_strformat,
test_templateutils, test_templateutils,
test_winacl test_winacl

17
tests/test_staticfor.nim Normal file
View File

@ -0,0 +1,17 @@
{.used.}
import unittest2, ../stew/staticfor
suite "staticfor":
test "basics":
var
a = 0
b = 0
for i in 0..10:
a += i
staticFor i, 0..10:
b += default(array[i, byte]).len
check: a == b