From d151c01cd810a4066127f216c3ab324b99e620ca Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Fri, 15 Sep 2023 13:08:38 -0600 Subject: [PATCH] enable cancellations --- datastore/threads/threadproxyds.nim | 200 ++++++++++++++++++-------- tests/datastore/testthreadproxyds.nim | 84 +++++++++++ 2 files changed, 228 insertions(+), 56 deletions(-) diff --git a/datastore/threads/threadproxyds.nim b/datastore/threads/threadproxyds.nim index 673e674..56c818a 100644 --- a/datastore/threads/threadproxyds.nim +++ b/datastore/threads/threadproxyds.nim @@ -8,6 +8,7 @@ push: {.upraises: [].} import std/atomics import std/strutils +import std/tables import pkg/chronos import pkg/chronos/threadsync @@ -16,30 +17,36 @@ import pkg/questionable/results import pkg/stew/ptrops import pkg/taskpools import pkg/stew/byteutils +import pkg/chronicles import ../key import ../query import ../datastore -import ./semaphore +import ./asyncsemaphore import ./databuffer type + ErrorEnum {.pure.} = enum + DatastoreErr, DatastoreKeyNotFoundErr, CatchableErr + ThreadTypes = void | bool | SomeInteger | DataBuffer | tuple | Atomic ThreadResult[T: ThreadTypes] = Result[T, DataBuffer] TaskCtx[T: ThreadTypes] = object ds: ptr Datastore res: ptr ThreadResult[T] - semaphore: ptr Semaphore + cancelled: bool + semaphore: AsyncSemaphore signal: ThreadSignalPtr ThreadDatastore* = ref object of Datastore tp: Taskpool ds: Datastore - semaphore: Semaphore # semaphore is used for backpressure \ - # to avoid exhausting file descriptors + semaphore: AsyncSemaphore # semaphore is used for backpressure \ + # to avoid exhausting file descriptors tasks: seq[Future[void]] + locks: Table[Key, AsyncLock] template dispatchTask( self: ThreadDatastore, @@ -57,7 +64,9 @@ template dispatchTask( if ctx.res[].isErr: result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed except CancelledError as exc: - echo "Cancelling future!" + trace "Cancelling future!", exc = exc.msg + ctx.cancelled = true + await ctx.signal.fire() raise exc finally: discard ctx.signal.close() @@ -66,23 +75,54 @@ template dispatchTask( idx != -1): self.tasks.del(idx) -proc hasTask( - ctx: ptr TaskCtx, - key: ptr Key) = +proc signalMonitor[T](ctx: ptr TaskCtx, fut: Future[T]) {.async.} = + ## Monitor the signal and cancel the future if + ## the cancellation flag is set + ## + try: + await ctx[].signal.wait() + trace "Received signal" + + if ctx[].cancelled: # there could eventually be other flags + trace "Cancelling future" + if not fut.finished: + await fut.cancelAndWait() # cancel the `has` future + + discard ctx[].signal.fireSync() + except CatchableError as exc: + trace "Exception in thread signal monitor", exc = exc.msg + ctx[].res[].err(exc) + discard ctx[].signal.fireSync() + +proc asyncHasTask( + ctx: ptr TaskCtx[bool], + key: ptr Key) {.async.} = defer: discard ctx[].signal.fireSync() - ctx[].semaphore[].release() - ctx[].semaphore[].acquire() - without ret =? - (waitFor ctx[].ds[].has(key[])).catch and res =? ret, error: + let + fut = ctx[].ds[].has(key[]) + + asyncSpawn signalMonitor(ctx, fut) + without ret =? (await fut).catch and res =? ret, error: ctx[].res[].err(error) return ctx[].res[].ok(res) +proc hasTask(ctx: ptr TaskCtx, key: ptr Key) = + try: + waitFor asyncHasTask(ctx, key) + except CatchableError as exc: + raiseAssert exc.msg + method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} = + defer: + self.semaphore.release() + + await self.semaphore.acquire() + var signal = ThreadSignalPtr.new().valueOr: return failure(error()) @@ -91,7 +131,6 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} = ctx = TaskCtx[bool]( ds: addr self.ds, res: addr res, - semaphore: addr self.semaphore, signal: signal) proc runTask() = @@ -100,21 +139,34 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} = self.dispatchTask(ctx, runTask) return success(res.get()) -proc delTask(ctx: ptr TaskCtx, key: ptr Key) = +proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} = defer: discard ctx[].signal.fireSync() - ctx[].semaphore[].release() - ctx[].semaphore[].acquire() - without res =? (waitFor ctx[].ds[].delete(key[])).catch, error: + let + fut = ctx[].ds[].delete(key[]) + + asyncSpawn signalMonitor(ctx, fut) + without res =? (await fut).catch, error: ctx[].res[].err(error) return ctx[].res[].ok() + return + +proc delTask(ctx: ptr TaskCtx, key: ptr Key) = + try: + waitFor asyncDelTask(ctx, key) + except CatchableError as exc: + raiseAssert exc.msg method delete*( self: ThreadDatastore, key: Key): Future[?!void] {.async.} = + defer: + self.semaphore.release() + + await self.semaphore.acquire() var signal = ThreadSignalPtr.new().valueOr: @@ -124,7 +176,6 @@ method delete*( ctx = TaskCtx[void]( ds: addr self.ds, res: addr res, - semaphore: addr self.semaphore, signal: signal) proc runTask() = @@ -143,30 +194,45 @@ method delete*( return success() -proc putTask( - ctx: ptr TaskCtx, +proc asyncPutTask( + ctx: ptr TaskCtx[void], key: ptr Key, - # data: DataBuffer, data: ptr UncheckedArray[byte], - len: int) = - ## run put in a thread task - ## - + len: int) {.async.} = defer: discard ctx[].signal.fireSync() - ctx[].semaphore[].release() - ctx[].semaphore[].acquire() - without res =? (waitFor ctx[].ds[].put(key[], @(data.toOpenArray(0, len - 1)))).catch, error: + let + fut = ctx[].ds[].put(key[], @(data.toOpenArray(0, len - 1))) + + asyncSpawn signalMonitor(ctx, fut) + without res =? (await fut).catch, error: ctx[].res[].err(error) return ctx[].res[].ok() +proc putTask( + ctx: ptr TaskCtx, + key: ptr Key, + data: ptr UncheckedArray[byte], + len: int) = + ## run put in a thread task + ## + + try: + waitFor asyncPutTask(ctx, key, data, len) + except CatchableError as exc: + raiseAssert exc.msg + method put*( self: ThreadDatastore, key: Key, data: seq[byte]): Future[?!void] {.async.} = + defer: + self.semaphore.release() + + await self.semaphore.acquire() var signal = ThreadSignalPtr.new().valueOr: @@ -176,7 +242,6 @@ method put*( ctx = TaskCtx[void]( ds: addr self.ds, res: addr res, - semaphore: addr self.semaphore, signal: signal) proc runTask() = @@ -199,27 +264,41 @@ method put*( return success() +proc asyncGetTask( + ctx: ptr TaskCtx[DataBuffer], + key: ptr Key) {.async.} = + defer: + discard ctx[].signal.fireSync() + + let + fut = ctx[].ds[].get(key[]) + + asyncSpawn signalMonitor(ctx, fut) + without res =? + (waitFor fut).catch and data =? res, error: + ctx[].res[].err(error) + return + + ctx[].res[].ok(DataBuffer.new(data)) + proc getTask( ctx: ptr TaskCtx, key: ptr Key) = ## Run get in a thread task ## - defer: - discard ctx[].signal.fireSync() - ctx[].semaphore[].release() - - ctx[].semaphore[].acquire() - without res =? - (waitFor ctx[].ds[].get(key[])).catch and data =? res, error: - ctx[].res[].err(error) - return - - ctx[].res[].ok(DataBuffer.new(data)) + try: + waitFor asyncGetTask(ctx, key) + except CatchableError as exc: + raiseAssert exc.msg method get*( self: ThreadDatastore, key: Key): Future[?!seq[byte]] {.async.} = + defer: + self.semaphore.release() + + await self.semaphore.acquire() var signal = ThreadSignalPtr.new().valueOr: @@ -230,7 +309,6 @@ method get*( ctx = TaskCtx[DataBuffer]( ds: addr self.ds, res: addr res, - semaphore: addr self.semaphore, signal: signal) proc runTask() = @@ -248,16 +326,17 @@ method close*(self: ThreadDatastore): Future[?!void] {.async.} = await self.ds.close() -proc queryTask( +proc asyncQueryTask( ctx: ptr TaskCtx, - iter: ptr QueryIter) = - + iter: ptr QueryIter) {.async.} = defer: discard ctx[].signal.fireSync() - ctx[].semaphore[].release() - ctx[].semaphore[].acquire() - without ret =? (waitFor iter[].next()).catch and res =? ret, error: + let + fut = iter[].next() + + asyncSpawn signalMonitor(ctx, fut) + without ret =? (waitFor fut).catch and res =? ret, error: ctx[].res[].err(error) return @@ -271,30 +350,40 @@ proc queryTask( ctx[].res[].ok((true, keyBuf, dataBuf)) +proc queryTask( + ctx: ptr TaskCtx, + iter: ptr QueryIter) = + + try: + waitFor asyncQueryTask(ctx, iter) + except CatchableError as exc: + raiseAssert exc.msg + method query*( self: ThreadDatastore, query: Query): Future[?!QueryIter] {.async.} = - without var childIter =? await self.ds.query(query), error: return failure error var iter = QueryIter.new() + locked = false - let lock = newAsyncLock() proc next(): Future[?!QueryResponse] {.async.} = defer: - if lock.locked: - lock.release() + locked = false + self.semaphore.release() - if lock.locked: + await self.semaphore.acquire() + + if locked: return failure (ref DatastoreError)(msg: "Should always await query features") + locked = true + if iter.finished == true: return failure (ref QueryEndedError)(msg: "Calling next on a finished query!") - await lock.acquire() - if iter.finished == true: return success (Key.none, EmptyBytes) @@ -306,7 +395,6 @@ method query*( ctx = TaskCtx[(bool, DataBuffer, DataBuffer)]( ds: addr self.ds, res: addr res, - semaphore: addr self.semaphore, signal: signal) proc runTask() = @@ -335,4 +423,4 @@ func new*( success ThreadDatastore( tp: tp, ds: ds, - semaphore: Semaphore.init((tp.numThreads - 1).uint)) + semaphore: AsyncSemaphore.new(tp.numThreads - 1)) diff --git a/tests/datastore/testthreadproxyds.nim b/tests/datastore/testthreadproxyds.nim index a0f7ddf..dac08fa 100644 --- a/tests/datastore/testthreadproxyds.nim +++ b/tests/datastore/testthreadproxyds.nim @@ -7,10 +7,12 @@ import std/importutils import pkg/asynctest import pkg/chronos +import pkg/chronos/threadsync import pkg/stew/results import pkg/stew/byteutils import pkg/taskpools import pkg/questionable/results +import pkg/chronicles import pkg/datastore/sql import pkg/datastore/fsds @@ -78,3 +80,85 @@ suite "Test Query ThreadDatastore with SQLite": # ds: ThreadDatastore # taskPool: Taskpool +suite "Test ThreadDatastore": + var + sqlStore: Datastore + ds: ThreadDatastore + taskPool: Taskpool + key = Key.init("/a/b").tryGet() + bytes = "some bytes".toBytes + otherBytes = "some other bytes".toBytes + + privateAccess(ThreadDatastore) # expose private fields + privateAccess(TaskCtx) # expose private fields + + setupAll: + sqlStore = SQLiteDatastore.new(Memory).tryGet() + taskPool = Taskpool.new(countProcessors() * 2) + ds = ThreadDatastore.new(sqlStore, taskPool).tryGet() + + test "should monitor signal for cancellations and cancel": + var + signal = ThreadSignalPtr.new().tryGet() + res = ThreadResult[void]() + ctx = TaskCtx[void]( + ds: addr sqlStore, + res: addr res, + signal: signal) + fut = newFuture[void]("signalMonitor") + threadArgs = (addr ctx, addr fut) + + var + thread: Thread[type threadArgs] + + proc threadTask(args: type threadArgs) = + var (ctx, fut) = args + proc asyncTask() {.async.} = + let + monitor = signalMonitor(ctx, fut[]) + + await monitor + + waitFor asyncTask() + + createThread(thread, threadTask, threadArgs) + ctx.cancelled = true + check: ctx.signal.fireSync.tryGet + + joinThreads(thread) + + check: fut.cancelled + check: ctx.signal.close().isOk + + test "should monitor signal for cancellations and not cancel": + var + signal = ThreadSignalPtr.new().tryGet() + res = ThreadResult[void]() + ctx = TaskCtx[void]( + ds: addr sqlStore, + res: addr res, + signal: signal) + fut = newFuture[void]("signalMonitor") + threadArgs = (addr ctx, addr fut) + + var + thread: Thread[type threadArgs] + + proc threadTask(args: type threadArgs) = + var (ctx, fut) = args + proc asyncTask() {.async.} = + let + monitor = signalMonitor(ctx, fut[]) + + await monitor + + waitFor asyncTask() + + createThread(thread, threadTask, threadArgs) + ctx.cancelled = false + check: ctx.signal.fireSync.tryGet + + joinThreads(thread) + + check: not fut.cancelled + check: ctx.signal.close().isOk