fix(NimThreadpool): Implement lightweight threadpool
Motivation: reduce memory usage. The previous implementation was using 400+ mb of memory for a threadpool with 16 threads.
This commit is contained in:
parent
93ef76c3e4
commit
4d8757a128
|
@ -119,3 +119,6 @@
|
|||
[submodule "vendor/SortFilterProxyModel"]
|
||||
path = ui/StatusQ/vendor/SortFilterProxyModel
|
||||
url = https://github.com/status-im/SortFilterProxyModel.git
|
||||
[submodule "vendor/nim-taskpools"]
|
||||
path = vendor/nim-taskpools
|
||||
url = https://github.com/status-im/nim-taskpools.git
|
||||
|
|
|
@ -1,260 +1,61 @@
|
|||
import # std libs
|
||||
atomics, json, sequtils, tables
|
||||
std/cpuinfo
|
||||
|
||||
import # vendor libs
|
||||
chronicles, chronos, json_serialization, task_runner
|
||||
json_serialization, json, chronicles, taskpools
|
||||
|
||||
import # status-desktop libs
|
||||
./common
|
||||
|
||||
export
|
||||
chronos, common, json_serialization
|
||||
export common, json_serialization
|
||||
|
||||
logScope:
|
||||
topics = "task-threadpool"
|
||||
|
||||
type
|
||||
ThreadPool* = ref object
|
||||
chanRecvFromPool: AsyncChannel[ThreadSafeString]
|
||||
chanSendToPool: AsyncChannel[ThreadSafeString]
|
||||
thread: Thread[PoolThreadArg]
|
||||
size: int
|
||||
running*: Atomic[bool]
|
||||
PoolThreadArg = object
|
||||
chanSendToMain: AsyncChannel[ThreadSafeString]
|
||||
chanRecvFromMain: AsyncChannel[ThreadSafeString]
|
||||
size: int
|
||||
TaskThreadArg = object
|
||||
id: int
|
||||
chanRecvFromPool: AsyncChannel[ThreadSafeString]
|
||||
chanSendToPool: AsyncChannel[ThreadSafeString]
|
||||
ThreadNotification = object
|
||||
id: int
|
||||
notice: string
|
||||
pool: Taskpool
|
||||
ThreadSafeTaskArg* = distinct cstring
|
||||
|
||||
# forward declarations
|
||||
proc poolThread(arg: PoolThreadArg) {.thread.}
|
||||
proc safe*[T: TaskArg](taskArg: T): ThreadSafeTaskArg =
|
||||
var
|
||||
strArgs = taskArg.encode()
|
||||
res = cast[cstring](allocShared(strArgs.len + 1))
|
||||
|
||||
const MaxThreadPoolSize = 16
|
||||
copyMem(res, strArgs.cstring, strArgs.len)
|
||||
res[strArgs.len] = '\0'
|
||||
res.ThreadSafeTaskArg
|
||||
|
||||
proc init(self: ThreadPool) =
|
||||
self.chanRecvFromPool.open()
|
||||
self.chanSendToPool.open()
|
||||
let arg = PoolThreadArg(
|
||||
chanSendToMain: self.chanRecvFromPool,
|
||||
chanRecvFromMain: self.chanSendToPool,
|
||||
size: self.size
|
||||
)
|
||||
createThread(self.thread, poolThread, arg)
|
||||
# block until we receive "ready"
|
||||
discard $(self.chanRecvFromPool.recvSync())
|
||||
|
||||
proc newThreadPool*(size: int = MaxThreadPoolSize): ThreadPool =
|
||||
new(result)
|
||||
result.chanRecvFromPool = newAsyncChannel[ThreadSafeString](-1)
|
||||
result.chanSendToPool = newAsyncChannel[ThreadSafeString](-1)
|
||||
result.thread = Thread[PoolThreadArg]()
|
||||
result.size = size
|
||||
result.running.store(false)
|
||||
result.init()
|
||||
proc toString*(input: ThreadSafeTaskArg): string =
|
||||
result = $(input.cstring)
|
||||
deallocShared input.cstring
|
||||
|
||||
proc teardown*(self: ThreadPool) =
|
||||
self.running.store(false)
|
||||
self.chanSendToPool.sendSync("shutdown".safe)
|
||||
self.chanRecvFromPool.close()
|
||||
self.chanSendToPool.close()
|
||||
trace "[threadpool] waiting for the control thread to stop"
|
||||
joinThread(self.thread)
|
||||
self.pool.syncAll()
|
||||
self.pool.shutdown()
|
||||
|
||||
proc start*[T: TaskArg](self: Threadpool, arg: T) =
|
||||
self.chanSendToPool.sendSync(arg.encode.safe)
|
||||
self.running.store(true)
|
||||
|
||||
proc runner(arg: TaskThreadArg) {.async.} =
|
||||
arg.chanRecvFromPool.open()
|
||||
arg.chanSendToPool.open()
|
||||
|
||||
let noticeToPool = ThreadNotification(id: arg.id, notice: "ready")
|
||||
trace "[threadpool task thread] sending 'ready'", threadid=arg.id
|
||||
await arg.chanSendToPool.send(noticeToPool.encode.safe)
|
||||
|
||||
while true:
|
||||
trace "[threadpool task thread] waiting for message"
|
||||
let received = $(await arg.chanRecvFromPool.recv())
|
||||
|
||||
if received == "shutdown":
|
||||
trace "[threadpool task thread] received 'shutdown'"
|
||||
break
|
||||
proc newThreadPool*(): ThreadPool =
|
||||
new(result)
|
||||
var nthreads = countProcessors()
|
||||
result.pool = Taskpool.new(num_threads = nthreads)
|
||||
|
||||
proc runTask(safeTaskArg: ThreadSafeTaskArg) {.gcsafe, nimcall.} =
|
||||
let
|
||||
parsed = parseJson(received)
|
||||
taskArg = safeTaskArg.toString()
|
||||
parsed = parseJson(taskArg)
|
||||
messageType = parsed{"$type"}.getStr
|
||||
|
||||
debug "[threadpool task thread] initiating task", messageType=messageType,
|
||||
threadid=arg.id, task=received
|
||||
threadid=getThreadId(), task=taskArg
|
||||
|
||||
try:
|
||||
let task = cast[Task](parsed{"tptr"}.getInt)
|
||||
try:
|
||||
task(received)
|
||||
except Exception as e:
|
||||
task(taskArg)
|
||||
except CatchableError as e:
|
||||
error "[threadpool task thread] exception", error=e.msg
|
||||
except Exception as e:
|
||||
error "[threadpool task thread] unknown message", message=received
|
||||
except CatchableError as e:
|
||||
error "[threadpool task thread] unknown message", message=taskArg
|
||||
|
||||
let noticeToPool = ThreadNotification(id: arg.id, notice: "done")
|
||||
debug "[threadpool task thread] sending 'done' notice to pool",
|
||||
threadid=arg.id, task=received
|
||||
await arg.chanSendToPool.send(noticeToPool.encode.safe)
|
||||
|
||||
arg.chanRecvFromPool.close()
|
||||
arg.chanSendToPool.close()
|
||||
|
||||
proc taskThread(arg: TaskThreadArg) {.thread.} =
|
||||
waitFor runner(arg)
|
||||
|
||||
proc pool(arg: PoolThreadArg) {.async.} =
|
||||
let
|
||||
chanSendToMain = arg.chanSendToMain
|
||||
chanRecvFromMainOrTask = arg.chanRecvFromMain
|
||||
var threadsBusy = newTable[int, tuple[thr: Thread[TaskThreadArg],
|
||||
chanSendToTask: AsyncChannel[ThreadSafeString]]]()
|
||||
var threadsIdle = newSeq[tuple[id: int, thr: Thread[TaskThreadArg],
|
||||
chanSendToTask: AsyncChannel[ThreadSafeString]]](arg.size)
|
||||
var taskQueue: seq[string] = @[] # FIFO queue
|
||||
var allReady = 0
|
||||
chanSendToMain.open()
|
||||
chanRecvFromMainOrTask.open()
|
||||
|
||||
trace "[threadpool] sending 'ready' to main thread"
|
||||
await chanSendToMain.send("ready".safe)
|
||||
|
||||
for i in 0..<arg.size:
|
||||
let id = i + 1
|
||||
let chanSendToTask = newAsyncChannel[ThreadSafeString](-1)
|
||||
chanSendToTask.open()
|
||||
trace "[threadpool] adding to threadsIdle", threadid=id
|
||||
threadsIdle[i].id = id
|
||||
createThread(
|
||||
threadsIdle[i].thr,
|
||||
taskThread,
|
||||
TaskThreadArg(id: id, chanRecvFromPool: chanSendToTask,
|
||||
chanSendToPool: chanRecvFromMainOrTask
|
||||
)
|
||||
)
|
||||
threadsIdle[i].chanSendToTask = chanSendToTask
|
||||
|
||||
# when task received and number of busy threads == MaxThreadPoolSize,
|
||||
# then put the task in a queue
|
||||
|
||||
# when task received and number of busy threads < MaxThreadPoolSize, pop
|
||||
# a thread from threadsIdle, track that thread in threadsBusy, and run
|
||||
# task in that thread
|
||||
|
||||
# if "done" received from a thread, remove thread from threadsBusy, and
|
||||
# push thread into threadsIdle
|
||||
|
||||
while true:
|
||||
trace "[threadpool] waiting for message"
|
||||
var task = $(await chanRecvFromMainOrTask.recv())
|
||||
|
||||
if task == "shutdown":
|
||||
trace "[threadpool] sending 'shutdown' to all task threads"
|
||||
for tpl in threadsIdle:
|
||||
await tpl.chanSendToTask.send("shutdown".safe)
|
||||
for tpl in threadsBusy.values:
|
||||
await tpl.chanSendToTask.send("shutdown".safe)
|
||||
break
|
||||
|
||||
let
|
||||
jsonNode = parseJson(task)
|
||||
messageType = jsonNode{"$type"}.getStr
|
||||
trace "[threadpool] determined message type", messageType=messageType
|
||||
|
||||
case messageType
|
||||
of "ThreadNotification":
|
||||
try:
|
||||
let notification = decode[ThreadNotification](task)
|
||||
trace "[threadpool] received notification",
|
||||
notice=notification.notice, threadid=notification.id
|
||||
|
||||
if notification.notice == "ready":
|
||||
trace "[threadpool] received 'ready' from a task thread"
|
||||
allReady = allReady + 1
|
||||
|
||||
elif notification.notice == "done":
|
||||
let tpl = threadsBusy[notification.id]
|
||||
trace "[threadpool] adding to threadsIdle",
|
||||
newlength=(threadsIdle.len + 1)
|
||||
threadsIdle.add (notification.id, tpl.thr, tpl.chanSendToTask)
|
||||
trace "[threadpool] removing from threadsBusy",
|
||||
newlength=(threadsBusy.len - 1), threadid=notification.id
|
||||
threadsBusy.del notification.id
|
||||
|
||||
if taskQueue.len > 0:
|
||||
trace "[threadpool] removing from taskQueue",
|
||||
newlength=(taskQueue.len - 1)
|
||||
task = taskQueue[0]
|
||||
taskQueue.delete 0, 0
|
||||
|
||||
trace "[threadpool] removing from threadsIdle",
|
||||
newlength=(threadsIdle.len - 1)
|
||||
let tpl = threadsIdle[0]
|
||||
threadsIdle.delete 0, 0
|
||||
trace "[threadpool] adding to threadsBusy",
|
||||
newlength=(threadsBusy.len + 1), threadid=tpl.id
|
||||
threadsBusy.add tpl.id, (tpl.thr, tpl.chanSendToTask)
|
||||
await tpl.chanSendToTask.send(task.safe)
|
||||
|
||||
else:
|
||||
error "[threadpool] unknown notification", notice=notification.notice
|
||||
except Exception as e:
|
||||
warn "[threadpool] unknown error in thread notification", message=task, error=e.msg
|
||||
|
||||
else: # must be a request to do task work
|
||||
if allReady < arg.size or threadsBusy.len == arg.size:
|
||||
# add to queue
|
||||
trace "[threadpool] adding to taskQueue",
|
||||
newlength=(taskQueue.len + 1)
|
||||
taskQueue.add task
|
||||
|
||||
# do we have available threads in the threadpool?
|
||||
elif threadsBusy.len < arg.size:
|
||||
# check if we have tasks waiting on queue
|
||||
if taskQueue.len > 0:
|
||||
# remove first element from the task queue
|
||||
trace "[threadpool] adding to taskQueue",
|
||||
newlength=(taskQueue.len + 1)
|
||||
taskQueue.add task
|
||||
trace "[threadpool] removing from taskQueue",
|
||||
newlength=(taskQueue.len - 1)
|
||||
task = taskQueue[0]
|
||||
taskQueue.delete 0, 0
|
||||
|
||||
trace "[threadpool] removing from threadsIdle",
|
||||
newlength=(threadsIdle.len - 1)
|
||||
let tpl = threadsIdle[0]
|
||||
threadsIdle.delete 0, 0
|
||||
trace "[threadpool] adding to threadsBusy",
|
||||
newlength=(threadsBusy.len + 1), threadid=tpl.id
|
||||
threadsBusy.add tpl.id, (tpl.thr, tpl.chanSendToTask)
|
||||
await tpl.chanSendToTask.send(task.safe)
|
||||
|
||||
var allTaskThreads = newSeq[tuple[id: int, thr: Thread[TaskThreadArg]]]()
|
||||
|
||||
for tpl in threadsIdle:
|
||||
tpl.chanSendToTask.close()
|
||||
allTaskThreads.add (tpl.id, tpl.thr)
|
||||
for id, tpl in threadsBusy.pairs:
|
||||
tpl.chanSendToTask.close()
|
||||
allTaskThreads.add (id, tpl.thr)
|
||||
|
||||
chanSendToMain.close()
|
||||
chanRecvFromMainOrTask.close()
|
||||
|
||||
trace "[threadpool] waiting for all task threads to stop"
|
||||
for tpl in allTaskThreads:
|
||||
debug "[threadpool] join thread", threadid=tpl.id
|
||||
joinThread(tpl.thr)
|
||||
|
||||
proc poolThread(arg: PoolThreadArg) {.thread.} =
|
||||
waitFor pool(arg)
|
||||
proc start*[T: TaskArg](self: ThreadPool, arg: T) =
|
||||
self.pool.spawn runTask(arg.safe())
|
|
@ -9,13 +9,12 @@ type
|
|||
uuid*: string
|
||||
ObtainMarketStickerPacksTaskArg = ref object of QObjectTaskArg
|
||||
chainId*: int
|
||||
running*: ByteAddress # pointer to threadpool's `.running` Atomic[bool]
|
||||
InstallStickerPackTaskArg = ref object of QObjectTaskArg
|
||||
packId*: string
|
||||
chainId*: int
|
||||
hasKey*: bool
|
||||
|
||||
proc getMarketStickerPacks*(running: var Atomic[bool], chainId: int):
|
||||
proc getMarketStickerPacks*(chainId: int):
|
||||
tuple[stickers: Table[string, StickerPackDto], error: string] =
|
||||
result = (initTable[string, StickerPackDto](), "")
|
||||
try:
|
||||
|
@ -48,8 +47,7 @@ const estimateTask: Task = proc(argEncoded: string) {.gcsafe, nimcall.} =
|
|||
|
||||
const obtainMarketStickerPacksTask: Task = proc(argEncoded: string) {.gcsafe, nimcall.} =
|
||||
let arg = decode[ObtainMarketStickerPacksTaskArg](argEncoded)
|
||||
var running = cast[ptr Atomic[bool]](arg.running)
|
||||
let (marketStickerPacks, error) = getMarketStickerPacks(running[], arg.chainId)
|
||||
let (marketStickerPacks, error) = getMarketStickerPacks(arg.chainId)
|
||||
var packs: seq[StickerPackDto] = @[]
|
||||
for packId, stickerPack in marketStickerPacks.pairs:
|
||||
packs.add(stickerPack)
|
||||
|
|
|
@ -286,7 +286,6 @@ QtObject:
|
|||
vptr: cast[ByteAddress](self.vptr),
|
||||
slot: "setMarketStickerPacks",
|
||||
chainId: chainId,
|
||||
running: cast[ByteAddress](addr self.threadpool.running)
|
||||
)
|
||||
self.threadpool.start(arg)
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 4bc0b592e8f71403c19b43ae6f4920c9a2380205
|
Loading…
Reference in New Issue