enable cancellations

This commit is contained in:
Dmitriy Ryajov 2023-09-15 13:08:38 -06:00
parent 1713c7674c
commit d151c01cd8
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
2 changed files with 228 additions and 56 deletions

View File

@ -8,6 +8,7 @@ push: {.upraises: [].}
import std/atomics
import std/strutils
import std/tables
import pkg/chronos
import pkg/chronos/threadsync
@ -16,30 +17,36 @@ import pkg/questionable/results
import pkg/stew/ptrops
import pkg/taskpools
import pkg/stew/byteutils
import pkg/chronicles
import ../key
import ../query
import ../datastore
import ./semaphore
import ./asyncsemaphore
import ./databuffer
type
ErrorEnum {.pure.} = enum
DatastoreErr, DatastoreKeyNotFoundErr, CatchableErr
ThreadTypes = void | bool | SomeInteger | DataBuffer | tuple | Atomic
ThreadResult[T: ThreadTypes] = Result[T, DataBuffer]
TaskCtx[T: ThreadTypes] = object
ds: ptr Datastore
res: ptr ThreadResult[T]
semaphore: ptr Semaphore
cancelled: bool
semaphore: AsyncSemaphore
signal: ThreadSignalPtr
ThreadDatastore* = ref object of Datastore
tp: Taskpool
ds: Datastore
semaphore: Semaphore # semaphore is used for backpressure \
# to avoid exhausting file descriptors
semaphore: AsyncSemaphore # semaphore is used for backpressure \
# to avoid exhausting file descriptors
tasks: seq[Future[void]]
locks: Table[Key, AsyncLock]
template dispatchTask(
self: ThreadDatastore,
@ -57,7 +64,9 @@ template dispatchTask(
if ctx.res[].isErr:
result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed
except CancelledError as exc:
echo "Cancelling future!"
trace "Cancelling future!", exc = exc.msg
ctx.cancelled = true
await ctx.signal.fire()
raise exc
finally:
discard ctx.signal.close()
@ -66,23 +75,54 @@ template dispatchTask(
idx != -1):
self.tasks.del(idx)
proc hasTask(
ctx: ptr TaskCtx,
key: ptr Key) =
proc signalMonitor[T](ctx: ptr TaskCtx, fut: Future[T]) {.async.} =
## Monitor the signal and cancel the future if
## the cancellation flag is set
##
try:
await ctx[].signal.wait()
trace "Received signal"
if ctx[].cancelled: # there could eventually be other flags
trace "Cancelling future"
if not fut.finished:
await fut.cancelAndWait() # cancel the `has` future
discard ctx[].signal.fireSync()
except CatchableError as exc:
trace "Exception in thread signal monitor", exc = exc.msg
ctx[].res[].err(exc)
discard ctx[].signal.fireSync()
proc asyncHasTask(
ctx: ptr TaskCtx[bool],
key: ptr Key) {.async.} =
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without ret =?
(waitFor ctx[].ds[].has(key[])).catch and res =? ret, error:
let
fut = ctx[].ds[].has(key[])
asyncSpawn signalMonitor(ctx, fut)
without ret =? (await fut).catch and res =? ret, error:
ctx[].res[].err(error)
return
ctx[].res[].ok(res)
proc hasTask(ctx: ptr TaskCtx, key: ptr Key) =
try:
waitFor asyncHasTask(ctx, key)
except CatchableError as exc:
raiseAssert exc.msg
method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} =
defer:
self.semaphore.release()
await self.semaphore.acquire()
var
signal = ThreadSignalPtr.new().valueOr:
return failure(error())
@ -91,7 +131,6 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} =
ctx = TaskCtx[bool](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
@ -100,21 +139,34 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} =
self.dispatchTask(ctx, runTask)
return success(res.get())
proc delTask(ctx: ptr TaskCtx, key: ptr Key) =
proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} =
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without res =? (waitFor ctx[].ds[].delete(key[])).catch, error:
let
fut = ctx[].ds[].delete(key[])
asyncSpawn signalMonitor(ctx, fut)
without res =? (await fut).catch, error:
ctx[].res[].err(error)
return
ctx[].res[].ok()
return
proc delTask(ctx: ptr TaskCtx, key: ptr Key) =
try:
waitFor asyncDelTask(ctx, key)
except CatchableError as exc:
raiseAssert exc.msg
method delete*(
self: ThreadDatastore,
key: Key): Future[?!void] {.async.} =
defer:
self.semaphore.release()
await self.semaphore.acquire()
var
signal = ThreadSignalPtr.new().valueOr:
@ -124,7 +176,6 @@ method delete*(
ctx = TaskCtx[void](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
@ -143,30 +194,45 @@ method delete*(
return success()
proc putTask(
ctx: ptr TaskCtx,
proc asyncPutTask(
ctx: ptr TaskCtx[void],
key: ptr Key,
# data: DataBuffer,
data: ptr UncheckedArray[byte],
len: int) =
## run put in a thread task
##
len: int) {.async.} =
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without res =? (waitFor ctx[].ds[].put(key[], @(data.toOpenArray(0, len - 1)))).catch, error:
let
fut = ctx[].ds[].put(key[], @(data.toOpenArray(0, len - 1)))
asyncSpawn signalMonitor(ctx, fut)
without res =? (await fut).catch, error:
ctx[].res[].err(error)
return
ctx[].res[].ok()
proc putTask(
ctx: ptr TaskCtx,
key: ptr Key,
data: ptr UncheckedArray[byte],
len: int) =
## run put in a thread task
##
try:
waitFor asyncPutTask(ctx, key, data, len)
except CatchableError as exc:
raiseAssert exc.msg
method put*(
self: ThreadDatastore,
key: Key,
data: seq[byte]): Future[?!void] {.async.} =
defer:
self.semaphore.release()
await self.semaphore.acquire()
var
signal = ThreadSignalPtr.new().valueOr:
@ -176,7 +242,6 @@ method put*(
ctx = TaskCtx[void](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
@ -199,27 +264,41 @@ method put*(
return success()
proc asyncGetTask(
ctx: ptr TaskCtx[DataBuffer],
key: ptr Key) {.async.} =
defer:
discard ctx[].signal.fireSync()
let
fut = ctx[].ds[].get(key[])
asyncSpawn signalMonitor(ctx, fut)
without res =?
(waitFor fut).catch and data =? res, error:
ctx[].res[].err(error)
return
ctx[].res[].ok(DataBuffer.new(data))
proc getTask(
ctx: ptr TaskCtx,
key: ptr Key) =
## Run get in a thread task
##
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without res =?
(waitFor ctx[].ds[].get(key[])).catch and data =? res, error:
ctx[].res[].err(error)
return
ctx[].res[].ok(DataBuffer.new(data))
try:
waitFor asyncGetTask(ctx, key)
except CatchableError as exc:
raiseAssert exc.msg
method get*(
self: ThreadDatastore,
key: Key): Future[?!seq[byte]] {.async.} =
defer:
self.semaphore.release()
await self.semaphore.acquire()
var
signal = ThreadSignalPtr.new().valueOr:
@ -230,7 +309,6 @@ method get*(
ctx = TaskCtx[DataBuffer](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
@ -248,16 +326,17 @@ method close*(self: ThreadDatastore): Future[?!void] {.async.} =
await self.ds.close()
proc queryTask(
proc asyncQueryTask(
ctx: ptr TaskCtx,
iter: ptr QueryIter) =
iter: ptr QueryIter) {.async.} =
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without ret =? (waitFor iter[].next()).catch and res =? ret, error:
let
fut = iter[].next()
asyncSpawn signalMonitor(ctx, fut)
without ret =? (waitFor fut).catch and res =? ret, error:
ctx[].res[].err(error)
return
@ -271,30 +350,40 @@ proc queryTask(
ctx[].res[].ok((true, keyBuf, dataBuf))
proc queryTask(
ctx: ptr TaskCtx,
iter: ptr QueryIter) =
try:
waitFor asyncQueryTask(ctx, iter)
except CatchableError as exc:
raiseAssert exc.msg
method query*(
self: ThreadDatastore,
query: Query): Future[?!QueryIter] {.async.} =
without var childIter =? await self.ds.query(query), error:
return failure error
var
iter = QueryIter.new()
locked = false
let lock = newAsyncLock()
proc next(): Future[?!QueryResponse] {.async.} =
defer:
if lock.locked:
lock.release()
locked = false
self.semaphore.release()
if lock.locked:
await self.semaphore.acquire()
if locked:
return failure (ref DatastoreError)(msg: "Should always await query features")
locked = true
if iter.finished == true:
return failure (ref QueryEndedError)(msg: "Calling next on a finished query!")
await lock.acquire()
if iter.finished == true:
return success (Key.none, EmptyBytes)
@ -306,7 +395,6 @@ method query*(
ctx = TaskCtx[(bool, DataBuffer, DataBuffer)](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
@ -335,4 +423,4 @@ func new*(
success ThreadDatastore(
tp: tp,
ds: ds,
semaphore: Semaphore.init((tp.numThreads - 1).uint))
semaphore: AsyncSemaphore.new(tp.numThreads - 1))

View File

@ -7,10 +7,12 @@ import std/importutils
import pkg/asynctest
import pkg/chronos
import pkg/chronos/threadsync
import pkg/stew/results
import pkg/stew/byteutils
import pkg/taskpools
import pkg/questionable/results
import pkg/chronicles
import pkg/datastore/sql
import pkg/datastore/fsds
@ -78,3 +80,85 @@ suite "Test Query ThreadDatastore with SQLite":
# ds: ThreadDatastore
# taskPool: Taskpool
suite "Test ThreadDatastore":
var
sqlStore: Datastore
ds: ThreadDatastore
taskPool: Taskpool
key = Key.init("/a/b").tryGet()
bytes = "some bytes".toBytes
otherBytes = "some other bytes".toBytes
privateAccess(ThreadDatastore) # expose private fields
privateAccess(TaskCtx) # expose private fields
setupAll:
sqlStore = SQLiteDatastore.new(Memory).tryGet()
taskPool = Taskpool.new(countProcessors() * 2)
ds = ThreadDatastore.new(sqlStore, taskPool).tryGet()
test "should monitor signal for cancellations and cancel":
var
signal = ThreadSignalPtr.new().tryGet()
res = ThreadResult[void]()
ctx = TaskCtx[void](
ds: addr sqlStore,
res: addr res,
signal: signal)
fut = newFuture[void]("signalMonitor")
threadArgs = (addr ctx, addr fut)
var
thread: Thread[type threadArgs]
proc threadTask(args: type threadArgs) =
var (ctx, fut) = args
proc asyncTask() {.async.} =
let
monitor = signalMonitor(ctx, fut[])
await monitor
waitFor asyncTask()
createThread(thread, threadTask, threadArgs)
ctx.cancelled = true
check: ctx.signal.fireSync.tryGet
joinThreads(thread)
check: fut.cancelled
check: ctx.signal.close().isOk
test "should monitor signal for cancellations and not cancel":
var
signal = ThreadSignalPtr.new().tryGet()
res = ThreadResult[void]()
ctx = TaskCtx[void](
ds: addr sqlStore,
res: addr res,
signal: signal)
fut = newFuture[void]("signalMonitor")
threadArgs = (addr ctx, addr fut)
var
thread: Thread[type threadArgs]
proc threadTask(args: type threadArgs) =
var (ctx, fut) = args
proc asyncTask() {.async.} =
let
monitor = signalMonitor(ctx, fut[])
await monitor
waitFor asyncTask()
createThread(thread, threadTask, threadArgs)
ctx.cancelled = false
check: ctx.signal.fireSync.tryGet
joinThreads(thread)
check: not fut.cancelled
check: ctx.signal.close().isOk