From bee79ffe72eb92a6331af6060b543d1ca89abc03 Mon Sep 17 00:00:00 2001 From: Dmitriy Ryajov Date: Fri, 15 Sep 2023 16:40:46 -0600 Subject: [PATCH] added (ugly!) locking capabilities --- datastore/fsds.nim | 2 +- datastore/threads/threadproxyds.nim | 136 +++++++++++++++++++++------- 2 files changed, 102 insertions(+), 36 deletions(-) diff --git a/datastore/fsds.nim b/datastore/fsds.nim index 5526d5c..6c695bb 100644 --- a/datastore/fsds.nim +++ b/datastore/fsds.nim @@ -188,7 +188,7 @@ method query*( var iter = QueryIter.new() - let lock = newAsyncLock() + var lock = newAsyncLock() # serialize querying under threads proc next(): Future[?!QueryResponse] {.async.} = defer: if lock.locked: diff --git a/datastore/threads/threadproxyds.nim b/datastore/threads/threadproxyds.nim index 56c818a..1cf4d46 100644 --- a/datastore/threads/threadproxyds.nim +++ b/datastore/threads/threadproxyds.nim @@ -9,6 +9,7 @@ push: {.upraises: [].} import std/atomics import std/strutils import std/tables +import std/sequtils import pkg/chronos import pkg/chronos/threadsync @@ -45,35 +46,66 @@ type ds: Datastore semaphore: AsyncSemaphore # semaphore is used for backpressure \ # to avoid exhausting file descriptors - tasks: seq[Future[void]] - locks: Table[Key, AsyncLock] + case withLocks: bool + of true: + tasks: Table[Key, Future[void]] + queryLock: AsyncLock # global query lock, this is only really \ + # needed for the fsds, but it is expensive! + else: + futs: seq[Future[void]] # keep a list of the futures to the signals around + +template withLocks( + self: ThreadDatastore, + ctx: TaskCtx, + key: ?Key = Key.none, + fut: Future[void], + body: untyped) = + try: + case self.withLocks: + of true: + if key.isSome and + key.get in self.tasks: + await self.tasks[key.get] + await self.queryLock.acquire() # lock query or wait to finish + + self.tasks[key.get] = fut + else: + self.futs.add(fut) + + body + finally: + case self.withLocks: + of true: + if key.isSome: + self.tasks.del(key.get) + if self.queryLock.locked: + self.queryLock.release() + else: + self.futs.keepItIf(it != fut) template dispatchTask( self: ThreadDatastore, ctx: TaskCtx, + key: ?Key = Key.none, runTask: proc): untyped = let fut = wait(ctx.signal) - try: - self.tasks.add(fut) - runTask() - await fut + withLocks(self, ctx, key, fut): + try: + runTask() + await fut - if ctx.res[].isErr: - result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed - except CancelledError as exc: - trace "Cancelling future!", exc = exc.msg - ctx.cancelled = true - await ctx.signal.fire() - raise exc - finally: - discard ctx.signal.close() - if ( - let idx = self.tasks.find(fut); - idx != -1): - self.tasks.del(idx) + if ctx.res[].isErr: + result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed + except CancelledError as exc: + trace "Cancelling thread future!", exc = exc.msg + ctx.cancelled = true + await ctx.signal.fire() + raise exc + finally: + discard ctx.signal.close() proc signalMonitor[T](ctx: ptr TaskCtx, fut: Future[T]) {.async.} = ## Monitor the signal and cancel the future if @@ -115,6 +147,7 @@ proc hasTask(ctx: ptr TaskCtx, key: ptr Key) = try: waitFor asyncHasTask(ctx, key) except CatchableError as exc: + trace "Unexpected exception thrown in asyncHasTask", error = error.msg raiseAssert exc.msg method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} = @@ -136,7 +169,7 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} = proc runTask() = self.tp.spawn hasTask(addr ctx, unsafeAddr key) - self.dispatchTask(ctx, runTask) + self.dispatchTask(ctx, key.some, runTask) return success(res.get()) proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} = @@ -148,6 +181,7 @@ proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} = asyncSpawn signalMonitor(ctx, fut) without res =? (await fut).catch, error: + trace "Error in asyncDelTask", error = error.msg ctx[].res[].err(error) return @@ -158,6 +192,7 @@ proc delTask(ctx: ptr TaskCtx, key: ptr Key) = try: waitFor asyncDelTask(ctx, key) except CatchableError as exc: + trace "Unexpected exception thrown in asyncDelTask", error = error.msg raiseAssert exc.msg method delete*( @@ -181,7 +216,7 @@ method delete*( proc runTask() = self.tp.spawn delTask(addr ctx, unsafeAddr key) - self.dispatchTask(ctx, runTask) + self.dispatchTask(ctx, key.some, runTask) return success() method delete*( @@ -207,6 +242,7 @@ proc asyncPutTask( asyncSpawn signalMonitor(ctx, fut) without res =? (await fut).catch, error: + trace "Error in asyncPutTask", error = error.msg ctx[].res[].err(error) return @@ -221,8 +257,9 @@ proc putTask( ## try: - waitFor asyncPutTask(ctx, key, data, len) + waitFor asyncPutTask(ctx, key, data, len) except CatchableError as exc: + trace "Unexpected exception thrown in asyncPutTask", error = error.msg raiseAssert exc.msg method put*( @@ -251,7 +288,7 @@ method put*( makeUncheckedArray(baseAddr data), data.len) - self.dispatchTask(ctx, runTask) + self.dispatchTask(ctx, key.some, runTask) return success() method put*( @@ -274,8 +311,8 @@ proc asyncGetTask( fut = ctx[].ds[].get(key[]) asyncSpawn signalMonitor(ctx, fut) - without res =? - (waitFor fut).catch and data =? res, error: + without res =? (await fut).catch and data =? res, error: + trace "Error in asyncGetTask", error = error.msg ctx[].res[].err(error) return @@ -290,6 +327,7 @@ proc getTask( try: waitFor asyncGetTask(ctx, key) except CatchableError as exc: + trace "Unexpected exception thrown in asyncGetTask", error = error.msg raiseAssert exc.msg method get*( @@ -314,15 +352,20 @@ method get*( proc runTask() = self.tp.spawn getTask(addr ctx, unsafeAddr key) - self.dispatchTask(ctx, runTask) + self.dispatchTask(ctx, key.some, runTask) if err =? res.errorOption: return failure err return success(@(res.get())) method close*(self: ThreadDatastore): Future[?!void] {.async.} = - for task in self.tasks: - await task.cancelAndWait() + var futs = if self.withLocks: + self.tasks.values.toSeq # toSeq(...) doesn't work here??? + else: + self.futs + + for fut in futs: + await fut.cancelAndWait() await self.ds.close() @@ -336,7 +379,8 @@ proc asyncQueryTask( fut = iter[].next() asyncSpawn signalMonitor(ctx, fut) - without ret =? (waitFor fut).catch and res =? ret, error: + without ret =? (await fut).catch and res =? ret, error: + trace "Error in asyncQueryTask", error = error.msg ctx[].res[].err(error) return @@ -357,6 +401,7 @@ proc queryTask( try: waitFor asyncQueryTask(ctx, iter) except CatchableError as exc: + trace "Unexpected exception thrown in asyncQueryTask", error = error.msg raiseAssert exc.msg method query*( @@ -373,8 +418,17 @@ method query*( defer: locked = false self.semaphore.release() + case self.withLocks: + of true: + if self.queryLock.locked: + self.queryLock.release() + else: + discard + trace "About to query" await self.semaphore.acquire() + if self.withLocks: + await self.queryLock.acquire() if locked: return failure (ref DatastoreError)(msg: "Should always await query features") @@ -400,8 +454,9 @@ method query*( proc runTask() = self.tp.spawn queryTask(addr ctx, addr childIter) - self.dispatchTask(ctx, runTask) + self.dispatchTask(ctx, Key.none, runTask) if err =? res.errorOption: + trace "Query failed", err = err return failure err let (ok, key, data) = res.get() @@ -414,13 +469,24 @@ method query*( iter.next = next return success iter -func new*( +proc new*( self: type ThreadDatastore, ds: Datastore, + withLocks = static false, tp: Taskpool): ?!ThreadDatastore = doAssert tp.numThreads > 1, "ThreadDatastore requires at least 2 threads" - success ThreadDatastore( - tp: tp, - ds: ds, - semaphore: AsyncSemaphore.new(tp.numThreads - 1)) + case withLocks: + of true: + success ThreadDatastore( + tp: tp, + ds: ds, + withLocks: true, + queryLock: newAsyncLock(), + semaphore: AsyncSemaphore.new(tp.numThreads - 1)) + else: + success ThreadDatastore( + tp: tp, + ds: ds, + withLocks: false, + semaphore: AsyncSemaphore.new(tp.numThreads - 1))