From 8e8d071ac73d16fafbf30daa3c2ccb1b2bc05e8f Mon Sep 17 00:00:00 2001 From: Tomasz Bekas Date: Tue, 27 Feb 2024 18:32:58 +0100 Subject: [PATCH] Fix missing rollbacks --- datastore/sql/sqliteds.nim | 42 +++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/datastore/sql/sqliteds.nim b/datastore/sql/sqliteds.nim index d959028..c764496 100644 --- a/datastore/sql/sqliteds.nim +++ b/datastore/sql/sqliteds.nim @@ -31,6 +31,16 @@ proc timestamp*(t = epochTime()): int64 = const initVersion* = 0.int64 +type RollbackError* = object of CatchableError + +proc newRollbackError(rbErr: ref CatchableError, opErrMsg: string): ref RollbackError = + let + msg = "Rollback initiated because of: " & opErrMsg & ". Rollback failed because of: " & rbErr.msg + return newException(RollbackError, msg, parentException = rbErr) + +proc newRollbackError(rbErr: ref CatchableError, opErr: ref CatchableError): ref RollbackError = + return newRollbackError(rbErr, opErr) + method modifyGet*(self: SQLiteDatastore, key: Key, fn: ModifyGet): Future[?!seq[byte]] {.async.} = var retriesLeft = 100 # allows reasonable concurrency, avoids infinite loop @@ -71,6 +81,8 @@ method modifyGet*(self: SQLiteDatastore, key: Key, fn: ModifyGet): Future[?!seq[ currentVersion ) if err =? (self.db.updateVersionedStmt.exec(updateParams)).errorOption: + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, err)) return failure(err) elif currentData =? maybeCurrentData: let deleteParams = ( @@ -78,6 +90,8 @@ method modifyGet*(self: SQLiteDatastore, key: Key, fn: ModifyGet): Future[?!seq[ currentVersion ) if err =? (self.db.deleteVersionedStmt.exec(deleteParams)).errorOption: + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, err)) return failure(err) elif newData =? maybeNewData: let insertParams = ( @@ -87,6 +101,8 @@ method modifyGet*(self: SQLiteDatastore, key: Key, fn: ModifyGet): Future[?!seq[ timestamp() ) if err =? (self.db.insertVersionedStmt.exec(insertParams)).errorOption: + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, err)) return failure(err) var changes = 0.int64 @@ -94,23 +110,25 @@ method modifyGet*(self: SQLiteDatastore, key: Key, fn: ModifyGet): Future[?!seq[ changes = changesCol(s, 0)() if err =? self.db.getChangesStmt.query((), onChangesResult).errorOption: - if err =? self.db.rollbackStmt.exec().errorOption: - return failure(err) + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, err)) return failure(err) if changes == 1: if err =? self.db.endStmt.exec().errorOption: + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, err)) return failure(err) break elif changes == 0: - # race condition detected - if err =? self.db.rollbackStmt.exec().errorOption: - return failure(err) + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, "Unable to retry after race condition was detected")) retriesLeft.dec else: - if err =? self.db.rollbackStmt.exec().errorOption: - return failure(err) - return failure("Unexpected number of changes, expected either 0 or 1, was " & $changes) + let msg = "Unexpected number of changes, expected either 0 or 1, was " & $changes + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, msg)) + return failure(msg) if retriesLeft == 0: return failure("Retry limit exceeded") @@ -150,8 +168,8 @@ method delete*(self: SQLiteDatastore, keys: seq[Key]): Future[?!void] {.async.} for key in keys: if err =? self.db.deleteStmt.exec((key.id)).errorOption: - if err =? self.db.rollbackStmt.exec().errorOption: - return failure err.msg + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, err)) return failure err.msg @@ -189,8 +207,8 @@ method put*(self: SQLiteDatastore, batch: seq[BatchEntry]): Future[?!void] {.asy for entry in batch: if err =? self.db.putStmt.exec((entry.key.id, entry.data, initVersion, timestamp())).errorOption: - if err =? self.db.rollbackStmt.exec().errorOption: - return failure err + if rbErr =? self.db.rollbackStmt.exec().errorOption: + return failure(newRollbackError(rbErr, err)) return failure err