From cf4eb38d4f930a7f8f22fe32a9306f9a5ee665ca Mon Sep 17 00:00:00 2001 From: Jaremy Creechley Date: Mon, 18 Sep 2023 13:04:13 -0700 Subject: [PATCH] protect from possible memory corruption on cancellation --- datastore/threads/threadproxyds.nim | 49 ++++++++++++----------------- 1 file changed, 20 insertions(+), 29 deletions(-) diff --git a/datastore/threads/threadproxyds.nim b/datastore/threads/threadproxyds.nim index c9b0408..9f145de 100644 --- a/datastore/threads/threadproxyds.nim +++ b/datastore/threads/threadproxyds.nim @@ -34,9 +34,9 @@ type ThreadTypes = void | bool | SomeInteger | DataBuffer | tuple | Atomic ThreadResult[T: ThreadTypes] = Result[T, DataBuffer] - TaskCtx[T: ThreadTypes] = object + TaskCtx[T: ThreadTypes] = ref object ds: ptr Datastore - res: ptr ThreadResult[T] + res: ThreadResult[T] cancelled: bool semaphore: AsyncSemaphore signal: ThreadSignalPtr @@ -85,11 +85,12 @@ template dispatchTask( withLocks(self, ctx, key, fut): try: + GC_ref(ctx) runTask() await fut - if ctx.res[].isErr: - result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed + if ctx.res.isErr: + result = failure(ctx.res.error()) # TODO: fix this, result shouldn't be accessed except CancelledError as exc: trace "Cancelling thread future!", exc = exc.msg ctx.cancelled = true @@ -115,7 +116,7 @@ proc signalMonitor[T](ctx: ptr TaskCtx, fut: Future[T]) {.async.} = discard ctx[].signal.fireSync() except CatchableError as exc: trace "Exception in thread signal monitor", exc = exc.msg - ctx[].res[].err(exc) + ctx.res.err(exc) discard ctx[].signal.fireSync() proc asyncHasTask( @@ -129,10 +130,10 @@ proc asyncHasTask( asyncSpawn signalMonitor(ctx, fut) without ret =? (await fut).catch and res =? ret, error: - ctx[].res[].err(error) + ctx.res.err(error) return - ctx[].res[].ok(res) + ctx.res.ok(res) proc hasTask(ctx: ptr TaskCtx, key: ptr Key) = try: @@ -151,17 +152,15 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} = signal = ThreadSignalPtr.new().valueOr: return failure(error()) - res = ThreadResult[bool]() ctx = TaskCtx[bool]( ds: addr self.ds, - res: addr res, signal: signal) proc runTask() = self.tp.spawn hasTask(addr ctx, unsafeAddr key) self.dispatchTask(ctx, key.some, runTask) - return success(res.get()) + return success(ctx.res.get()) proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} = defer: @@ -173,10 +172,10 @@ proc asyncDelTask(ctx: ptr TaskCtx[void], key: ptr Key) {.async.} = asyncSpawn signalMonitor(ctx, fut) without res =? (await fut).catch, error: trace "Error in asyncDelTask", error = error.msg - ctx[].res[].err(error) + ctx.res.err(error) return - ctx[].res[].ok() + ctx.res.ok() return proc delTask(ctx: ptr TaskCtx, key: ptr Key) = @@ -198,10 +197,8 @@ method delete*( signal = ThreadSignalPtr.new().valueOr: return failure(error()) - res = ThreadResult[void]() ctx = TaskCtx[void]( ds: addr self.ds, - res: addr res, signal: signal) proc runTask() = @@ -234,10 +231,10 @@ proc asyncPutTask( asyncSpawn signalMonitor(ctx, fut) without res =? (await fut).catch, error: trace "Error in asyncPutTask", error = error.msg - ctx[].res[].err(error) + ctx.res.err(error) return - ctx[].res[].ok() + ctx.res.ok() proc putTask( ctx: ptr TaskCtx, @@ -266,10 +263,8 @@ method put*( signal = ThreadSignalPtr.new().valueOr: return failure(error()) - res = ThreadResult[void]() ctx = TaskCtx[void]( ds: addr self.ds, - res: addr res, signal: signal) proc runTask() = @@ -304,10 +299,10 @@ proc asyncGetTask( asyncSpawn signalMonitor(ctx, fut) without res =? (await fut).catch and data =? res, error: trace "Error in asyncGetTask", error = error.msg - ctx[].res[].err(error) + ctx.res.err(error) return - ctx[].res[].ok(DataBuffer.new(data)) + ctx.res.ok(DataBuffer.new(data)) proc getTask( ctx: ptr TaskCtx, @@ -334,10 +329,8 @@ method get*( return failure(error()) var - res = ThreadResult[DataBuffer]() ctx = TaskCtx[DataBuffer]( ds: addr self.ds, - res: addr res, signal: signal) proc runTask() = @@ -367,18 +360,18 @@ proc asyncQueryTask( asyncSpawn signalMonitor(ctx, fut) without ret =? (await fut).catch and res =? ret, error: trace "Error in asyncQueryTask", error = error.msg - ctx[].res[].err(error) + ctx.res.err(error) return if res.key.isNone: - ctx[].res[].ok((false, default(DataBuffer), default(DataBuffer))) + ctx.res.ok((false, default(DataBuffer), default(DataBuffer))) return var keyBuf = DataBuffer.new($(res.key.get())) dataBuf = DataBuffer.new(res.data) - ctx[].res[].ok((true, keyBuf, dataBuf)) + ctx.res.ok((true, keyBuf, dataBuf)) proc queryTask( ctx: ptr TaskCtx, @@ -422,21 +415,19 @@ method query*( signal = ThreadSignalPtr.new().valueOr: return failure("Failed to create signal") - res = ThreadResult[(bool, DataBuffer, DataBuffer)]() ctx = TaskCtx[(bool, DataBuffer, DataBuffer)]( ds: addr self.ds, - res: addr res, signal: signal) proc runTask() = self.tp.spawn queryTask(addr ctx, addr childIter) self.dispatchTask(ctx, Key.none, runTask) - if err =? res.errorOption: + if err =? ctx.res.errorOption: trace "Query failed", err = err return failure err - let (ok, key, data) = res.get() + let (ok, key, data) = ctx.res.get() if not ok: iter.finished = true return success (Key.none, EmptyBytes)