mirror of
https://github.com/logos-storage/nim-datastore.git
synced 2026-01-04 06:33:11 +00:00
protect from possible memory corruption on cancellation
This commit is contained in:
parent
ed09b9c936
commit
cf4eb38d4f
@ -34,9 +34,9 @@ type
|
|||||||
ThreadTypes = void | bool | SomeInteger | DataBuffer | tuple | Atomic
|
ThreadTypes = void | bool | SomeInteger | DataBuffer | tuple | Atomic
|
||||||
ThreadResult[T: ThreadTypes] = Result[T, DataBuffer]
|
ThreadResult[T: ThreadTypes] = Result[T, DataBuffer]
|
||||||
|
|
||||||
TaskCtx[T: ThreadTypes] = object
|
TaskCtx[T: ThreadTypes] = ref object
|
||||||
ds: ptr Datastore
|
ds: ptr Datastore
|
||||||
res: ptr ThreadResult[T]
|
res: ThreadResult[T]
|
||||||
cancelled: bool
|
cancelled: bool
|
||||||
semaphore: AsyncSemaphore
|
semaphore: AsyncSemaphore
|
||||||
signal: ThreadSignalPtr
|
signal: ThreadSignalPtr
|
||||||
@ -85,11 +85,12 @@ template dispatchTask(
|
|||||||
|
|
||||||
withLocks(self, ctx, key, fut):
|
withLocks(self, ctx, key, fut):
|
||||||
try:
|
try:
|
||||||
|
GC_ref(ctx)
|
||||||
runTask()
|
runTask()
|
||||||
await fut
|
await fut
|
||||||
|
|
||||||
if ctx.res[].isErr:
|
if ctx.res.isErr:
|
||||||
result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed
|
result = failure(ctx.res.error()) # TODO: fix this, result shouldn't be accessed
|
||||||
except CancelledError as exc:
|
except CancelledError as exc:
|
||||||
trace "Cancelling thread future!", exc = exc.msg
|
trace "Cancelling thread future!", exc = exc.msg
|
||||||
ctx.cancelled = true
|
ctx.cancelled = true
|
||||||
@ -115,7 +116,7 @@ proc signalMonitor[T](ctx: ptr TaskCtx, fut: Future[T]) {.async.} =
|
|||||||
discard ctx[].signal.fireSync()
|
discard ctx[].signal.fireSync()
|
||||||
except CatchableError as exc:
|
except CatchableError as exc:
|
||||||
trace "Exception in thread signal monitor", exc = exc.msg
|
trace "Exception in thread signal monitor", exc = exc.msg
|
||||||
ctx[].res[].err(exc)
|
ctx.res.err(exc)
|
||||||
discard ctx[].signal.fireSync()
|
discard ctx[].signal.fireSync()
|
||||||
|
|
||||||
proc asyncHasTask(
|
proc asyncHasTask(
|
||||||
@ -129,10 +130,10 @@ proc asyncHasTask(
|
|||||||
|
|
||||||
asyncSpawn signalMonitor(ctx, fut)
|
asyncSpawn signalMonitor(ctx, fut)
|
||||||
without ret =? (await fut).catch and res =? ret, error:
|
without ret =? (await fut).catch and res =? ret, error:
|
||||||
ctx[].res[].err(error)
|
ctx.res.err(error)
|
||||||
return
|
return
|
||||||
|
|
||||||
ctx[].res[].ok(res)
|
ctx.res.ok(res)
|
||||||
|
|
||||||
proc hasTask(ctx: ptr TaskCtx, key: ptr Key) =
|
proc hasTask(ctx: ptr TaskCtx, key: ptr Key) =
|
||||||
try:
|
try:
|
||||||
@ -151,17 +152,15 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} =
|
|||||||
signal = ThreadSignalPtr.new().valueOr:
|
signal = ThreadSignalPtr.new().valueOr:
|
||||||
return failure(error())
|
return failure(error())
|
||||||
|
|
||||||
res = ThreadResult[bool]()
|
|
||||||
ctx = TaskCtx[bool](
|
ctx = TaskCtx[bool](
|
||||||
ds: addr self.ds,
|
ds: addr self.ds,
|
||||||
res: addr res,
|
|
||||||
signal: signal)
|
signal: signal)
|
||||||
|
|
||||||
proc runTask() =
|
proc runTask() =
|
||||||
self.tp.spawn hasTask(addr ctx, unsafeAddr key)
|
self.tp.spawn hasTask(addr ctx, unsafeAddr key)
|
||||||
|
|
||||||
self.dispatchTask(ctx, key.some, runTask)
|
self.dispatchTask(ctx, key.some, runTask)
|
||||||
return success(res.get())
|
return success(ctx.res.get())
|
||||||
|
|
||||||
proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} =
|
proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} =
|
||||||
defer:
|
defer:
|
||||||
@ -173,10 +172,10 @@ proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} =
|
|||||||
asyncSpawn signalMonitor(ctx, fut)
|
asyncSpawn signalMonitor(ctx, fut)
|
||||||
without res =? (await fut).catch, error:
|
without res =? (await fut).catch, error:
|
||||||
trace "Error in asyncDelTask", error = error.msg
|
trace "Error in asyncDelTask", error = error.msg
|
||||||
ctx[].res[].err(error)
|
ctx.res.err(error)
|
||||||
return
|
return
|
||||||
|
|
||||||
ctx[].res[].ok()
|
ctx.res.ok()
|
||||||
return
|
return
|
||||||
|
|
||||||
proc delTask(ctx: ptr TaskCtx, key: ptr Key) =
|
proc delTask(ctx: ptr TaskCtx, key: ptr Key) =
|
||||||
@ -198,10 +197,8 @@ method delete*(
|
|||||||
signal = ThreadSignalPtr.new().valueOr:
|
signal = ThreadSignalPtr.new().valueOr:
|
||||||
return failure(error())
|
return failure(error())
|
||||||
|
|
||||||
res = ThreadResult[void]()
|
|
||||||
ctx = TaskCtx[void](
|
ctx = TaskCtx[void](
|
||||||
ds: addr self.ds,
|
ds: addr self.ds,
|
||||||
res: addr res,
|
|
||||||
signal: signal)
|
signal: signal)
|
||||||
|
|
||||||
proc runTask() =
|
proc runTask() =
|
||||||
@ -234,10 +231,10 @@ proc asyncPutTask(
|
|||||||
asyncSpawn signalMonitor(ctx, fut)
|
asyncSpawn signalMonitor(ctx, fut)
|
||||||
without res =? (await fut).catch, error:
|
without res =? (await fut).catch, error:
|
||||||
trace "Error in asyncPutTask", error = error.msg
|
trace "Error in asyncPutTask", error = error.msg
|
||||||
ctx[].res[].err(error)
|
ctx.res.err(error)
|
||||||
return
|
return
|
||||||
|
|
||||||
ctx[].res[].ok()
|
ctx.res.ok()
|
||||||
|
|
||||||
proc putTask(
|
proc putTask(
|
||||||
ctx: ptr TaskCtx,
|
ctx: ptr TaskCtx,
|
||||||
@ -266,10 +263,8 @@ method put*(
|
|||||||
signal = ThreadSignalPtr.new().valueOr:
|
signal = ThreadSignalPtr.new().valueOr:
|
||||||
return failure(error())
|
return failure(error())
|
||||||
|
|
||||||
res = ThreadResult[void]()
|
|
||||||
ctx = TaskCtx[void](
|
ctx = TaskCtx[void](
|
||||||
ds: addr self.ds,
|
ds: addr self.ds,
|
||||||
res: addr res,
|
|
||||||
signal: signal)
|
signal: signal)
|
||||||
|
|
||||||
proc runTask() =
|
proc runTask() =
|
||||||
@ -304,10 +299,10 @@ proc asyncGetTask(
|
|||||||
asyncSpawn signalMonitor(ctx, fut)
|
asyncSpawn signalMonitor(ctx, fut)
|
||||||
without res =? (await fut).catch and data =? res, error:
|
without res =? (await fut).catch and data =? res, error:
|
||||||
trace "Error in asyncGetTask", error = error.msg
|
trace "Error in asyncGetTask", error = error.msg
|
||||||
ctx[].res[].err(error)
|
ctx.res.err(error)
|
||||||
return
|
return
|
||||||
|
|
||||||
ctx[].res[].ok(DataBuffer.new(data))
|
ctx.res.ok(DataBuffer.new(data))
|
||||||
|
|
||||||
proc getTask(
|
proc getTask(
|
||||||
ctx: ptr TaskCtx,
|
ctx: ptr TaskCtx,
|
||||||
@ -334,10 +329,8 @@ method get*(
|
|||||||
return failure(error())
|
return failure(error())
|
||||||
|
|
||||||
var
|
var
|
||||||
res = ThreadResult[DataBuffer]()
|
|
||||||
ctx = TaskCtx[DataBuffer](
|
ctx = TaskCtx[DataBuffer](
|
||||||
ds: addr self.ds,
|
ds: addr self.ds,
|
||||||
res: addr res,
|
|
||||||
signal: signal)
|
signal: signal)
|
||||||
|
|
||||||
proc runTask() =
|
proc runTask() =
|
||||||
@ -367,18 +360,18 @@ proc asyncQueryTask(
|
|||||||
asyncSpawn signalMonitor(ctx, fut)
|
asyncSpawn signalMonitor(ctx, fut)
|
||||||
without ret =? (await fut).catch and res =? ret, error:
|
without ret =? (await fut).catch and res =? ret, error:
|
||||||
trace "Error in asyncQueryTask", error = error.msg
|
trace "Error in asyncQueryTask", error = error.msg
|
||||||
ctx[].res[].err(error)
|
ctx.res.err(error)
|
||||||
return
|
return
|
||||||
|
|
||||||
if res.key.isNone:
|
if res.key.isNone:
|
||||||
ctx[].res[].ok((false, default(DataBuffer), default(DataBuffer)))
|
ctx.res.ok((false, default(DataBuffer), default(DataBuffer)))
|
||||||
return
|
return
|
||||||
|
|
||||||
var
|
var
|
||||||
keyBuf = DataBuffer.new($(res.key.get()))
|
keyBuf = DataBuffer.new($(res.key.get()))
|
||||||
dataBuf = DataBuffer.new(res.data)
|
dataBuf = DataBuffer.new(res.data)
|
||||||
|
|
||||||
ctx[].res[].ok((true, keyBuf, dataBuf))
|
ctx.res.ok((true, keyBuf, dataBuf))
|
||||||
|
|
||||||
proc queryTask(
|
proc queryTask(
|
||||||
ctx: ptr TaskCtx,
|
ctx: ptr TaskCtx,
|
||||||
@ -422,21 +415,19 @@ method query*(
|
|||||||
signal = ThreadSignalPtr.new().valueOr:
|
signal = ThreadSignalPtr.new().valueOr:
|
||||||
return failure("Failed to create signal")
|
return failure("Failed to create signal")
|
||||||
|
|
||||||
res = ThreadResult[(bool, DataBuffer, DataBuffer)]()
|
|
||||||
ctx = TaskCtx[(bool, DataBuffer, DataBuffer)](
|
ctx = TaskCtx[(bool, DataBuffer, DataBuffer)](
|
||||||
ds: addr self.ds,
|
ds: addr self.ds,
|
||||||
res: addr res,
|
|
||||||
signal: signal)
|
signal: signal)
|
||||||
|
|
||||||
proc runTask() =
|
proc runTask() =
|
||||||
self.tp.spawn queryTask(addr ctx, addr childIter)
|
self.tp.spawn queryTask(addr ctx, addr childIter)
|
||||||
|
|
||||||
self.dispatchTask(ctx, Key.none, runTask)
|
self.dispatchTask(ctx, Key.none, runTask)
|
||||||
if err =? res.errorOption:
|
if err =? ctx.res.errorOption:
|
||||||
trace "Query failed", err = err
|
trace "Query failed", err = err
|
||||||
return failure err
|
return failure err
|
||||||
|
|
||||||
let (ok, key, data) = res.get()
|
let (ok, key, data) = ctx.res.get()
|
||||||
if not ok:
|
if not ok:
|
||||||
iter.finished = true
|
iter.finished = true
|
||||||
return success (Key.none, EmptyBytes)
|
return success (Key.none, EmptyBytes)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user