added (ugly!) locking capabilities

This commit is contained in:
Dmitriy Ryajov 2023-09-15 16:40:46 -06:00
parent f6acaa6f32
commit bee79ffe72
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
2 changed files with 102 additions and 36 deletions

View File

@ -188,7 +188,7 @@ method query*(
var
iter = QueryIter.new()
let lock = newAsyncLock()
var lock = newAsyncLock() # serialize querying under threads
proc next(): Future[?!QueryResponse] {.async.} =
defer:
if lock.locked:

View File

@ -9,6 +9,7 @@ push: {.upraises: [].}
import std/atomics
import std/strutils
import std/tables
import std/sequtils
import pkg/chronos
import pkg/chronos/threadsync
@ -45,35 +46,66 @@ type
ds: Datastore
semaphore: AsyncSemaphore # semaphore is used for backpressure \
# to avoid exhausting file descriptors
tasks: seq[Future[void]]
locks: Table[Key, AsyncLock]
case withLocks: bool
of true:
tasks: Table[Key, Future[void]]
queryLock: AsyncLock # global query lock, this is only really \
# needed for the fsds, but it is expensive!
else:
futs: seq[Future[void]] # keep a list of the futures to the signals around
template withLocks(
self: ThreadDatastore,
ctx: TaskCtx,
key: ?Key = Key.none,
fut: Future[void],
body: untyped) =
try:
case self.withLocks:
of true:
if key.isSome and
key.get in self.tasks:
await self.tasks[key.get]
await self.queryLock.acquire() # lock query or wait to finish
self.tasks[key.get] = fut
else:
self.futs.add(fut)
body
finally:
case self.withLocks:
of true:
if key.isSome:
self.tasks.del(key.get)
if self.queryLock.locked:
self.queryLock.release()
else:
self.futs.keepItIf(it != fut)
template dispatchTask(
self: ThreadDatastore,
ctx: TaskCtx,
key: ?Key = Key.none,
runTask: proc): untyped =
let
fut = wait(ctx.signal)
try:
self.tasks.add(fut)
runTask()
await fut
withLocks(self, ctx, key, fut):
try:
runTask()
await fut
if ctx.res[].isErr:
result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed
except CancelledError as exc:
trace "Cancelling future!", exc = exc.msg
ctx.cancelled = true
await ctx.signal.fire()
raise exc
finally:
discard ctx.signal.close()
if (
let idx = self.tasks.find(fut);
idx != -1):
self.tasks.del(idx)
if ctx.res[].isErr:
result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed
except CancelledError as exc:
trace "Cancelling thread future!", exc = exc.msg
ctx.cancelled = true
await ctx.signal.fire()
raise exc
finally:
discard ctx.signal.close()
proc signalMonitor[T](ctx: ptr TaskCtx, fut: Future[T]) {.async.} =
## Monitor the signal and cancel the future if
@ -115,6 +147,7 @@ proc hasTask(ctx: ptr TaskCtx, key: ptr Key) =
try:
waitFor asyncHasTask(ctx, key)
except CatchableError as exc:
trace "Unexpected exception thrown in asyncHasTask", error = error.msg
raiseAssert exc.msg
method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} =
@ -136,7 +169,7 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} =
proc runTask() =
self.tp.spawn hasTask(addr ctx, unsafeAddr key)
self.dispatchTask(ctx, runTask)
self.dispatchTask(ctx, key.some, runTask)
return success(res.get())
proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} =
@ -148,6 +181,7 @@ proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} =
asyncSpawn signalMonitor(ctx, fut)
without res =? (await fut).catch, error:
trace "Error in asyncDelTask", error = error.msg
ctx[].res[].err(error)
return
@ -158,6 +192,7 @@ proc delTask(ctx: ptr TaskCtx, key: ptr Key) =
try:
waitFor asyncDelTask(ctx, key)
except CatchableError as exc:
trace "Unexpected exception thrown in asyncDelTask", error = error.msg
raiseAssert exc.msg
method delete*(
@ -181,7 +216,7 @@ method delete*(
proc runTask() =
self.tp.spawn delTask(addr ctx, unsafeAddr key)
self.dispatchTask(ctx, runTask)
self.dispatchTask(ctx, key.some, runTask)
return success()
method delete*(
@ -207,6 +242,7 @@ proc asyncPutTask(
asyncSpawn signalMonitor(ctx, fut)
without res =? (await fut).catch, error:
trace "Error in asyncPutTask", error = error.msg
ctx[].res[].err(error)
return
@ -221,8 +257,9 @@ proc putTask(
##
try:
waitFor asyncPutTask(ctx, key, data, len)
waitFor asyncPutTask(ctx, key, data, len)
except CatchableError as exc:
trace "Unexpected exception thrown in asyncPutTask", error = error.msg
raiseAssert exc.msg
method put*(
@ -251,7 +288,7 @@ method put*(
makeUncheckedArray(baseAddr data),
data.len)
self.dispatchTask(ctx, runTask)
self.dispatchTask(ctx, key.some, runTask)
return success()
method put*(
@ -274,8 +311,8 @@ proc asyncGetTask(
fut = ctx[].ds[].get(key[])
asyncSpawn signalMonitor(ctx, fut)
without res =?
(waitFor fut).catch and data =? res, error:
without res =? (await fut).catch and data =? res, error:
trace "Error in asyncGetTask", error = error.msg
ctx[].res[].err(error)
return
@ -290,6 +327,7 @@ proc getTask(
try:
waitFor asyncGetTask(ctx, key)
except CatchableError as exc:
trace "Unexpected exception thrown in asyncGetTask", error = error.msg
raiseAssert exc.msg
method get*(
@ -314,15 +352,20 @@ method get*(
proc runTask() =
self.tp.spawn getTask(addr ctx, unsafeAddr key)
self.dispatchTask(ctx, runTask)
self.dispatchTask(ctx, key.some, runTask)
if err =? res.errorOption:
return failure err
return success(@(res.get()))
method close*(self: ThreadDatastore): Future[?!void] {.async.} =
for task in self.tasks:
await task.cancelAndWait()
var futs = if self.withLocks:
self.tasks.values.toSeq # toSeq(...) doesn't work here???
else:
self.futs
for fut in futs:
await fut.cancelAndWait()
await self.ds.close()
@ -336,7 +379,8 @@ proc asyncQueryTask(
fut = iter[].next()
asyncSpawn signalMonitor(ctx, fut)
without ret =? (waitFor fut).catch and res =? ret, error:
without ret =? (await fut).catch and res =? ret, error:
trace "Error in asyncQueryTask", error = error.msg
ctx[].res[].err(error)
return
@ -357,6 +401,7 @@ proc queryTask(
try:
waitFor asyncQueryTask(ctx, iter)
except CatchableError as exc:
trace "Unexpected exception thrown in asyncQueryTask", error = error.msg
raiseAssert exc.msg
method query*(
@ -373,8 +418,17 @@ method query*(
defer:
locked = false
self.semaphore.release()
case self.withLocks:
of true:
if self.queryLock.locked:
self.queryLock.release()
else:
discard
trace "About to query"
await self.semaphore.acquire()
if self.withLocks:
await self.queryLock.acquire()
if locked:
return failure (ref DatastoreError)(msg: "Should always await query features")
@ -400,8 +454,9 @@ method query*(
proc runTask() =
self.tp.spawn queryTask(addr ctx, addr childIter)
self.dispatchTask(ctx, runTask)
self.dispatchTask(ctx, Key.none, runTask)
if err =? res.errorOption:
trace "Query failed", err = err
return failure err
let (ok, key, data) = res.get()
@ -414,13 +469,24 @@ method query*(
iter.next = next
return success iter
func new*(
proc new*(
self: type ThreadDatastore,
ds: Datastore,
withLocks = static false,
tp: Taskpool): ?!ThreadDatastore =
doAssert tp.numThreads > 1, "ThreadDatastore requires at least 2 threads"
success ThreadDatastore(
tp: tp,
ds: ds,
semaphore: AsyncSemaphore.new(tp.numThreads - 1))
case withLocks:
of true:
success ThreadDatastore(
tp: tp,
ds: ds,
withLocks: true,
queryLock: newAsyncLock(),
semaphore: AsyncSemaphore.new(tp.numThreads - 1))
else:
success ThreadDatastore(
tp: tp,
ds: ds,
withLocks: false,
semaphore: AsyncSemaphore.new(tp.numThreads - 1))