# Nim-Taskpools
# Copyright (c) 2021 Status Research & Development GmbH
# Licensed and distributed under either of
#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

# Taskpools
#
# This file implements a taskpool.
#
# Implementation:
#
# It is a simple shared-memory-based work-stealing threadpool.
# The primary focus is:
# - Delegate compute-intensive tasks to the threadpool.
# - Simple to audit by staying close to foundational papers
#   and using simple datastructures otherwise.
# - Low energy consumption:
#   threads should be put to sleep ASAP
#   instead of polling/spinning (an energy vs latency tradeoff).
# - Decent performance:
#   work-stealing has optimal asymptotic parallel speedup,
#   and it has significantly reduced contention when many tasks are created,
#   for example by divide-and-conquer algorithms, compared to a global task queue.
#
# Not a priority:
# - Handling trillions of very short tasks (less than 100µs).
# - Advanced task dependencies or events API.
# - Unbalanced parallel-for loops.
# - Handling services that should run for the lifetime of the program.
#
# Doing IO on a compute threadpool should be avoided:
# if a thread blocks on IO, other threads can steal pending tasks from that thread,
# but if all threads block on IO, the threadpool makes no progress and is soft-locked.
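
# Example usage (an illustrative sketch, not part of the upstream documentation;
# `double` is a placeholder proc and the module is assumed to be importable as
# `taskpools`). A runnable variant sits at the end of this module under
# `when isMainModule`.
#
#   import taskpools
#
#   proc double(x: int): int = 2*x
#
#   var tp = Taskpool.new(numThreads = 4) # 3 worker threads + the calling (root) thread
#   let fv = tp.spawn double(21)          # returns a Flowvar[int]
#   echo fv.sync()                        # blocks until the result is available: 42
#   tp.syncAll()                          # wait for all outstanding tasks
#   tp.shutdown()                         # wake, join and free the workers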

{.push raises: [AssertionDefect].} # Ensure no exceptions can happen

import
  system/ansi_c,
  std/[random, cpuinfo, atomics, macros],
  ./channels_spsc_single,
  ./chase_lev_deques,
  ./event_notifiers,
  ./primitives/[barriers, allocs],
  ./instrumentation/[contracts, loggers],
  ./sparsesets,
  ./flowvars,
  ./ast_utils

export
  # flowvars
  Flowvar, isSpawned, isReady, sync

import std/[isolation, tasks]
export isolation

type
  WorkerID = int32

  TaskNode = ptr object
    # Linked list of tasks
    parent: TaskNode
    task: Task

  Signal = object
    terminate {.align: 64.}: Atomic[bool]

  WorkerContext = object
    ## Thread-local worker context

    # Params
    id: WorkerID
    taskpool: Taskpool

    # Tasks
    taskDeque: ptr ChaseLevDeque[TaskNode] # owned task deque
    currentTask: TaskNode

    # Synchronization
    eventNotifier: ptr EventNotifier       # shared event notifier
    signal: ptr Signal                     # owned signal

    # Thefts
    rng: Rand                              # RNG state to select victims
    otherDeques: ptr UncheckedArray[ChaseLevDeque[TaskNode]]
    victims: SparseSet

  Taskpool* = ptr object
    ## A taskpool schedules procedures to be executed in parallel
    barrier: SyncBarrier
      ## Barrier for initialization and teardown
    # --- Align: 64
    eventNotifier: EventNotifier
      ## Puts thread to sleep

    numThreads*{.align: 64.}: int
    workerDeques: ptr UncheckedArray[ChaseLevDeque[TaskNode]]
      ## Direct access for task stealing
    workers: ptr UncheckedArray[Thread[(Taskpool, WorkerID)]]
    workerSignals: ptr UncheckedArray[Signal]
      ## Access signaledTerminate

# Thread-local config
# ---------------------------------------------

var workerContext {.threadvar.}: WorkerContext
  ## Thread-local Worker context

proc setupWorker() =
  ## Initialize the thread-local context of a worker
  ## Requires the ID and taskpool fields to be initialized
  template ctx: untyped = workerContext

  preCondition: not ctx.taskpool.isNil()
  preCondition: 0 <= ctx.id and ctx.id < ctx.taskpool.numThreads
  preCondition: not ctx.taskpool.workerDeques.isNil()
  preCondition: not ctx.taskpool.workerSignals.isNil()

  # Thefts
  ctx.rng = initRand(0xEFFACED + ctx.id)
  ctx.otherDeques = ctx.taskpool.workerDeques
  ctx.victims.allocate(ctx.taskpool.numThreads)

  # Synchronization
  ctx.eventNotifier = addr ctx.taskpool.eventNotifier
  ctx.signal = addr ctx.taskpool.workerSignals[ctx.id]
  ctx.signal.terminate.store(false, moRelaxed)

  # Tasks
  ctx.taskDeque = addr ctx.taskpool.workerDeques[ctx.id]
  ctx.currentTask = nil

  # Init
  ctx.taskDeque[].init(initialCapacity = 32)

proc teardownWorker() =
  ## Cleanup the thread-local context of a worker
  template ctx: untyped = workerContext
  ctx.taskDeque[].teardown()
  ctx.victims.delete()

proc eventLoop(ctx: var WorkerContext) {.raises:[Exception].}

proc workerEntryFn(params: tuple[taskpool: Taskpool, id: WorkerID])
       {.raises: [Exception].} =
  ## At the start of the threadpool, workers execute this
  ## until they receive a termination signal.
  # We assume that thread-local variables start at their binary zero value
  preCondition: workerContext == default(WorkerContext)

  template ctx: untyped = workerContext

  # If the following crashes, you need --tlsEmulation:off
  ctx.id = params.id
  ctx.taskpool = params.taskpool

  setupWorker()

  # 1 matching barrier in Taskpool.new() for root thread
  discard params.taskpool.barrier.wait()

  {.gcsafe.}: # Not GC-safe when multi-threaded due to thread-local variables
    ctx.eventLoop()

  debugTermination:
    log(">>> Worker %2d shutting down <<<\n", ctx.id)

  # 1 matching barrier in taskpool.shutdown() for root thread
  discard params.taskpool.barrier.wait()

  teardownWorker()

# Tasks
# ---------------------------------------------

proc new(T: type TaskNode, parent: TaskNode, task: sink Task): T =
  var tn = tp_allocPtr(TaskNode)
  tn.parent = parent
  wasMoved(tn.task) # tn.task is uninitialized, prevent Nim from running the Task destructor
  tn.task = task
  return tn

proc runTask(tn: var TaskNode) {.raises:[Exception], inline.} =
  ## Run a task and consume its TaskNode
  tn.task.invoke()
  {.gcsafe.}: # Upstream missing tagging `=destroy` as gcsafe
    tn.task.`=destroy`()
  tn.c_free()

proc schedule(ctx: WorkerContext, tn: sink TaskNode) {.inline.} =
  ## Schedule a task in the taskpool
  debug: log("Worker %2d: schedule task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, tn, tn.parent, ctx.currentTask)
  ctx.taskDeque[].push(tn)
  ctx.taskpool.eventNotifier.notify()

# Scheduler
# ---------------------------------------------

proc trySteal(ctx: var WorkerContext): TaskNode =
  ## Try to steal a task.

  ctx.victims.refill()
  ctx.victims.excl(ctx.id)

  while not ctx.victims.isEmpty():
    let target = ctx.victims.randomPick(ctx.rng)

    let stolenTask = ctx.otherDeques[target].steal()
    if not stolenTask.isNil:
      return stolenTask

    ctx.victims.excl(target)

  return nil

proc eventLoop(ctx: var WorkerContext) {.raises:[Exception].} =
  ## Each worker thread executes this loop over and over.
  while not ctx.signal.terminate.load(moRelaxed):
    # 1. Pick from local deque
    debug: log("Worker %2d: eventLoop 1 - searching task from local deque\n", ctx.id)
    while (var taskNode = ctx.taskDeque[].pop(); not taskNode.isNil):
      debug: log("Worker %2d: eventLoop 1 - running task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, taskNode, taskNode.parent, ctx.currentTask)
      taskNode.runTask()

    # 2. Run out of tasks, become a thief
    debug: log("Worker %2d: eventLoop 2 - becoming a thief\n", ctx.id)
    var stolenTask = ctx.trySteal()
    if not stolenTask.isNil:
      # 2.a Run task
      debug: log("Worker %2d: eventLoop 2.a - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, stolenTask, stolenTask.parent, ctx.currentTask)
      stolenTask.runTask()
    else:
      # 2.b Park the thread until a new task enters the taskpool
      debug: log("Worker %2d: eventLoop 2.b - sleeping\n", ctx.id)
      ctx.eventNotifier[].park()
      debug: log("Worker %2d: eventLoop 2.b - waking\n", ctx.id)

# Tasking
# ---------------------------------------------

const RootTask = default(Task) # TODO: sentinel value different from null task

template isRootTask(task: Task): bool =
  task == RootTask

proc forceFuture*[T](fv: Flowvar[T], parentResult: var T) {.raises:[Exception].} =
  ## Eagerly complete an awaited FlowVar

  template ctx: untyped = workerContext

  template isFutReady(): untyped =
    fv.chan[].tryRecv(parentResult)

  if isFutReady():
    return

  ## 1. Process all children of the current task.
  ##    This ensures that we can give control back ASAP.
  debug: log("Worker %2d: sync 1 - searching task from local deque\n", ctx.id)
  while (var taskNode = ctx.taskDeque[].pop(); not taskNode.isNil):
    if taskNode.parent != ctx.currentTask:
      debug: log("Worker %2d: sync 1 - skipping non-direct descendant task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, taskNode, taskNode.parent, ctx.currentTask)
      ctx.schedule(taskNode)
      break
    debug: log("Worker %2d: sync 1 - running task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, taskNode, taskNode.parent, ctx.currentTask)
    taskNode.runTask()
    if isFutReady():
      debug: log("Worker %2d: sync 1 - future ready, exiting\n", ctx.id)
      return

  ## 2. We ran out of tasks, or out of direct children of the awaited task.
  ##    The task is bottlenecked by dependencies in other threads,
  ##    hence we abandon our enqueued work and steal from the others' queues
  ##    in the hope that this advances our awaited task. This prioritizes latency over throughput.
  debug: log("Worker %2d: sync 2 - future not ready, becoming a thief (currentTask 0x%.08x)\n", ctx.id, ctx.currentTask)
  while not isFutReady():
    var taskNode = ctx.trySteal()

    if not taskNode.isNil:
      # We stole a task; hopefully it advances our awaited task
      debug: log("Worker %2d: sync 2.1 - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, taskNode, taskNode.parent, ctx.currentTask)
      taskNode.runTask()
    # elif (taskNode = ctx.taskDeque[].pop(); not taskNode.isNil):
    #   # We advance our own queue, this increases throughput but may impact latency on the awaited task
    #   debug: log("Worker %2d: sync 2.2 - couldn't steal, running own task\n", ctx.id)
    #   taskNode.runTask()
    else:
      # We don't park as there is no notif for task completion
      cpuRelax()

proc syncAll*(tp: Taskpool) {.raises: [Exception].} =
  ## Blocks until all pending tasks are completed
  ## This MUST only be called from
  ## the root scope that created the taskpool
  template ctx: untyped = workerContext

  debugTermination:
    log(">>> Worker %2d enters barrier <<<\n", ctx.id)

  preCondition: ctx.id == 0
  preCondition: ctx.currentTask.task.isRootTask()

  # Empty all tasks
  var foreignThreadsParked = false
  while not foreignThreadsParked:
    # 1. Empty local tasks
    debug: log("Worker %2d: syncAll 1 - searching task from local deque\n", ctx.id)
    while (var taskNode = ctx.taskDeque[].pop(); not taskNode.isNil):
      debug: log("Worker %2d: syncAll 1 - running task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, taskNode, taskNode.parent, ctx.currentTask)
      taskNode.runTask()

    if tp.numThreads == 1 or foreignThreadsParked:
      break

    # 2. Help other threads
    debug: log("Worker %2d: syncAll 2 - becoming a thief\n", ctx.id)
    var taskNode = ctx.trySteal()

    if not taskNode.isNil:
      # 2.1 We stole some task
      debug: log("Worker %2d: syncAll 2.1 - stole task 0x%.08x (parent 0x%.08x, current 0x%.08x)\n", ctx.id, taskNode, taskNode.parent, ctx.currentTask)
      taskNode.runTask()
    else:
      # 2.2 No task to steal
      if tp.eventNotifier.getParked() == tp.numThreads - 1:
        # 2.2.1 all threads besides the current are parked
        debugTermination:
          log("Worker %2d: syncAll 2.2.1 - termination, all other threads sleeping\n", ctx.id)
        foreignThreadsParked = true
      else:
        # 2.2.2 We don't park as there is no notif for task completion
        cpuRelax()

  debugTermination:
    log(">>> Worker %2d leaves barrier <<<\n", ctx.id)

# Runtime
# ---------------------------------------------

proc new*(T: type Taskpool, numThreads = countProcessors()): T {.raises: [Exception].} =
  ## Initialize a threadpool that manages `numThreads` threads.
  ## Defaults to the number of logical processors available.

  type TpObj = typeof(default(Taskpool)[])
  # Event notifier requires an extra 64 bytes for alignment
  var tp = tp_allocAligned(TpObj, sizeof(TpObj) + 64, 64)

  tp.barrier.init(numThreads.int32)
  tp.eventNotifier.initialize()
  tp.numThreads = numThreads
  tp.workerDeques = tp_allocArrayAligned(ChaseLevDeque[TaskNode], numThreads, alignment = 64)
  tp.workers = tp_allocArrayAligned(Thread[(Taskpool, WorkerID)], numThreads, alignment = 64)
  tp.workerSignals = tp_allocArrayAligned(Signal, numThreads, alignment = 64)

  # Setup master thread
  workerContext.id = 0
  workerContext.taskpool = tp

  # Start worker threads
  for i in 1 ..< numThreads:
    createThread(tp.workers[i], workerEntryFn, (tp, WorkerID(i)))

  # Root worker
  setupWorker()

  # Root task, this is a sentinel task that is never called.
  workerContext.currentTask = TaskNode.new(
    parent = nil,
    task = default(Task) # TODO RootTask, somehow this uses `=copy`
  )

  # Wait for the child threads
  discard tp.barrier.wait()
  return tp

proc cleanup(tp: var Taskpool) {.raises: [AssertionDefect, OSError].} =
  ## Cleanup all resources allocated by the taskpool
  preCondition: workerContext.currentTask.task.isRootTask()

  for i in 1 ..< tp.numThreads:
    joinThread(tp.workers[i])

  tp.workerSignals.tp_freeAligned()
  tp.workers.tp_freeAligned()
  tp.workerDeques.tp_freeAligned()
  `=destroy`(tp.eventNotifier)
  tp.barrier.delete()

  tp.tp_freeAligned()

proc shutdown*(tp: var Taskpool) {.raises:[Exception].} =
  ## Wait until all tasks are processed and then shut down the taskpool
  preCondition: workerContext.currentTask.task.isRootTask()
  tp.syncAll()

  # Signal termination to all threads
  for i in 0 ..< tp.numThreads:
    tp.workerSignals[i].terminate.store(true, moRelaxed)

  let parked = tp.eventNotifier.getParked()
  for i in 0 ..< parked:
    tp.eventNotifier.notify()

  # 1 matching barrier in workerEntryFn
  discard tp.barrier.wait()

  teardownWorker()
  tp.cleanup()

  # Dealloc dummy task
  workerContext.currentTask.c_free()

# Task parallelism
# ---------------------------------------------
{.pop.} # raises:[]

macro spawn*(tp: Taskpool, fnCall: typed): untyped =
  ## Spawns the input function call asynchronously, potentially on another thread of execution.
  ##
  ## If the function call returns a result, spawn will wrap it in a Flowvar.
  ## You can use `sync` to block the current thread and extract the asynchronous result from the flowvar.
  ## You can use `isReady` to check whether the result is available,
  ## i.e. whether a subsequent `sync` would return immediately.
  ##
  ## Tasks are processed approximately in Last-In-First-Out (LIFO) order
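  ##
  ## Example (an illustrative sketch; `report` and `double` are placeholder procs,
  ## not part of this module):
  ##
  ## .. code-block:: nim
  ##   proc report(x: int) = echo x
  ##   proc double(x: int): int = 2*x
  ##
  ##   var tp = Taskpool.new()
  ##   tp.spawn report(123)          # void return: fire-and-forget, wait with tp.syncAll()
  ##   let fv = tp.spawn double(21)  # non-void return: wrapped in a Flowvar[int]
  ##   echo fv.sync()                # blocks until the result is available: 42
  ##   tp.shutdown()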

  result = newStmtList()

  let fn = fnCall[0]
  let fnName = $fn

  # Get the return type if any
  let retType = fnCall[0].getImpl[3][0]
  let needFuture = retType.kind != nnkEmpty

  # Package in a task
  let taskNode = ident("taskNode")
  if not needFuture:
    result.add quote do:
      let `taskNode` = TaskNode.new(workerContext.currentTask, toTask(`fnCall`))
      schedule(workerContext, `taskNode`)

  else:
    # `std/tasks` tasks have no return value, so for a call with a result:
    # 1. We create a channel/flowvar to transmit the return value to awaiter/sync
    # 2. We create a wrapper async_fn without return value that sends the return value in the channel
    # 3. We package that wrapper function in a task
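    #
    # Roughly, for `tp.spawn double(21)` with `proc double(x: int): int`,
    # the generated code looks like the following (an illustrative sketch,
    # not the exact expansion; `double` is a placeholder proc):
    #
    #   block:
    #     let fut = newFlowVar(type int)
    #     proc taskpool_double(x: int, fut: Flowvar[int]) {.nimcall.} =
    #       let res = double(x)
    #       readyWith(fut, res)
    #     let taskNode = TaskNode.new(workerContext.currentTask, toTask(taskpool_double(21, fut)))
    #     schedule(workerContext, taskNode)
    #     fut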

    # 1. Create the channel
    let fut = ident("fut")
    let futTy = nnkBracketExpr.newTree(
      bindSym"FlowVar",
      retType
    )
    result.add quote do:
      let `fut` = newFlowVar(type `retType`)

    # 2. Create a wrapper function that sends the result to the channel
    # TODO, upstream "getImpl" doesn't return the generic params
    let genericParams = fn.getImpl()[2].replaceSymsByIdents()
    let formalParams = fn.getImpl()[3].replaceSymsByIdents()

    var asyncParams = nnkFormalParams.newTree(
      newEmptyNode()
    )
    var fnCallIdents = nnkCall.newTree(
      fnCall[0]
    )
    for i in 1 ..< formalParams.len:
      let ident = formalParams[i].replaceSymsByIdents()
      asyncParams.add ident
      for j in 0 ..< ident.len - 2:
        # Handle "a, b: int"
        fnCallIdents.add ident[j]

    let futFnParam = ident("fut")
    asyncParams.add newIdentDefs(futFnParam, futTy)

    let asyncBody = quote do:
      # XXX: can't test that when the RootTask is default(Task) instead of a sentinel value
      # preCondition: not isRootTask(workerContext.currentTask.task)

      let res = `fnCallIdents`
      readyWith(`futFnParam`, res)

    let asyncFn = ident("taskpool_" & fnName)
    result.add nnkProcDef.newTree(
      asyncFn,
      newEmptyNode(),
      genericParams,
      asyncParams,
      nnkPragma.newTree(ident("nimcall")),
      newEmptyNode(),
      asyncBody
    )

    var asyncCall = newCall(asyncFn)
    for i in 1 ..< fnCall.len:
      asyncCall.add fnCall[i].replaceSymsByIdents()
    asyncCall.add fut

    result.add quote do:
      let `taskNode` = TaskNode.new(workerContext.currentTask, toTask(`asyncCall`))
      schedule(workerContext, `taskNode`)

      # Return the future / flowvar
      `fut`

  # Wrap in a block for namespacing
  result = nnkBlockStmt.newTree(newEmptyNode(), result)
  # echo result.toStrLit()
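
when isMainModule:
  # Minimal smoke test / usage sketch (an addition for illustration, not part of
  # the upstream module). It exercises spawn for void and non-void calls, syncAll
  # and shutdown. Compile with --threads:on.
  proc report(x: int) = echo "got ", x
  proc double(x: int): int = 2*x

  var tp = Taskpool.new(numThreads = 4)
  tp.spawn report(123)           # fire-and-forget task
  let fv = tp.spawn double(21)   # returns a Flowvar[int]
  echo fv.sync()                 # 42
  tp.syncAll()                   # wait for the fire-and-forget task too
  tp.shutdown()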