From 52bbe9d429a355e8e6f88a2650e750ad20c71c37 Mon Sep 17 00:00:00 2001
From: Tomasz Bekas
Date: Wed, 15 Nov 2023 13:42:37 +0100
Subject: [PATCH] Concurrent datastore interface & sqlite implementation

---
 datastore.nim                          |   3 +-
 datastore/concurrentds.nim             |  50 ++++++++++++
 datastore/sql/sqliteds.nim             | 103 +++++++++++++++++++++++--
 datastore/sql/sqlitedsdb.nim           | 102 +++++++++++++++++++++++-
 datastore/types.nim                    |   1 +
 tests/datastore/concurrentdstests.nim  |  96 +++++++++++++++++++++++
 tests/datastore/sql/testsqliteds.nim   |   2 +
 tests/datastore/sql/testsqlitedsdb.nim |   8 +-
 8 files changed, 353 insertions(+), 12 deletions(-)
 create mode 100644 datastore/concurrentds.nim
 create mode 100644 tests/datastore/concurrentdstests.nim

diff --git a/datastore.nim b/datastore.nim
index 6d43a20..deac2a6 100644
--- a/datastore.nim
+++ b/datastore.nim
@@ -1,7 +1,8 @@
 import ./datastore/datastore
+import ./datastore/concurrentds
 import ./datastore/fsds
 import ./datastore/sql
 import ./datastore/mountedds
 import ./datastore/tieredds
 
-export datastore, fsds, mountedds, tieredds, sql
+export datastore, concurrentds, fsds, mountedds, tieredds, sql
diff --git a/datastore/concurrentds.nim b/datastore/concurrentds.nim
new file mode 100644
index 0000000..da4170c
--- /dev/null
+++ b/datastore/concurrentds.nim
@@ -0,0 +1,50 @@
+import pkg/chronos
+import pkg/questionable
+import pkg/questionable/results
+import pkg/upraises
+
+import ./key
+import ./query
+import ./types
+import ./datastore
+
+export key, query, types, datastore
+
+push: {.upraises: [].}
+
+type
+  Function*[T, U] = proc(value: T): U {.upraises: [], gcsafe, closure.}
+  Modify* = Function[?seq[byte], ?seq[byte]]
+  ModifyAsync* = Function[?seq[byte], Future[?seq[byte]]]
+
+method modify*(self: ConcurrentDatastore, key: Key, fn: Modify): Future[?!void] {.base, locks: "unknown".} =
+  ## Concurrency-safe way of modifying the value associated with `key`.
+  ##
+  ## This method first reads the value stored under `key`; if such a value exists it is wrapped into `some`
+  ## and passed as the only argument to `fn`, otherwise `none` is passed.
+  ##
+  ## When `fn` returns `some`, the returned value is put into the store, but only if it differs from
+  ## the existing value; otherwise nothing happens.
+  ## When `fn` returns `none`, the existing value is deleted from the store; if no value existed before,
+  ## nothing happens.
+  ##
+  ## Note that `fn` can be called multiple times (whenever a concurrent modification of the value is detected).
+  ## This base method only fails with `raiseAssert("Not implemented!")`.
+
+  raiseAssert("Not implemented!")
+
+method modify*(self: ConcurrentDatastore, key: Key, fn: ModifyAsync): Future[?!void] {.base, locks: "unknown".} =
+  ## Concurrency-safe way of modifying the value associated with `key`, using a `fn` that returns a `Future`.
+  ##
+  ## This method first reads the value stored under `key`; if such a value exists it is wrapped into `some`
+  ## and passed as the only argument to `fn`, otherwise `none` is passed.
+  ##
+  ## When `fn` returns `some`, the returned value is put into the store, but only if it differs from
+  ## the existing value; otherwise nothing happens.
+  ## When `fn` returns `none`, the existing value is deleted from the store; if no value existed before,
+  ## nothing happens.
+  ##
+  ## Note that `fn` can be called multiple times (whenever a concurrent modification of the value is detected).
+  ##
+
+  raiseAssert("Not implemented!")
diff --git a/datastore/sql/sqliteds.nim b/datastore/sql/sqliteds.nim
index aa63274..a5b0b61 100644
--- a/datastore/sql/sqliteds.nim
+++ b/datastore/sql/sqliteds.nim
@@ -8,15 +8,15 @@ import pkg/sqlite3_abi
 from pkg/stew/results as stewResults import isErr
 import pkg/upraises
 
-import ../datastore
+import ../concurrentds
 import ./sqlitedsdb
 
-export datastore, sqlitedsdb
+export concurrentds, sqlitedsdb
 
 push: {.upraises: [].}
 
 type
-  SQLiteDatastore* = ref object of Datastore
+  SQLiteDatastore* = ref object of ConcurrentDatastore
     readOnly: bool
     db: SQLiteDsDb
 
@@ -29,6 +29,129 @@ proc `readOnly=`*(self: SQLiteDatastore): bool
 proc timestamp*(t = epochTime()): int64 =
   (t * 1_000_000).int64
 
+const initVersion* = 0.int64
+
+proc rollbackAndFail(self: SQLiteDatastore, cause: ref CatchableError): ?!void =
+  ## Rolls back the transaction currently open on `self.db`. Yields the
+  ## rollback error when the rollback itself fails, otherwise `cause`.
+  if err =? self.db.rollbackStmt.exec().errorOption:
+    return failure(err)
+  return failure(cause)
+
+method modify*(self: SQLiteDatastore, key: Key, fn: ModifyAsync): Future[?!void] {.async.} =
+  ## Optimistic read-modify-write of the value stored under `key`.
+  ##
+  ## Each attempt reads the row's data and version, awaits `fn` with the
+  ## current value (`none` when no row was read) and then, inside a
+  ## transaction, updates or deletes the row only if it still carries the
+  ## version that was read, or inserts it when no row was read. The number of
+  ## changed rows decides the outcome: 1 commits, 0 is treated as a lost race
+  ## (rollback and retry, at most 100 attempts), anything else is rolled back
+  ## and reported as a failure.
+  ##
+  ## Fails when `fn` raises, when a statement fails or when the retry limit is
+  ## exceeded. A failed write is rolled back before the failure is returned.
+  ##
+  ## NOTE(review): `put` stores rows with `initVersion`, so a `put` racing with
+  ## this method is presumably not detected when the version that was read is
+  ## `initVersion` as well - confirm.
+  var
+    retriesLeft = 100 # allows reasonable concurrency, avoids infinite loop
+
+  while retriesLeft > 0:
+    var
+      found = false
+      currentData: seq[byte]
+      currentVersion: int64
+
+    proc onData(s: RawStmtPtr) =
+      found = true
+      currentData = dataCol(s, GetVersionedStmtDataCol)()
+      currentVersion = versionCol(s, GetVersionedStmtVersionCol)()
+
+    if err =? self.db.getVersionedStmt.query((key.id), onData).errorOption:
+      return failure(err)
+
+    # presumably `onData` runs once per returned row, so `found` tells a
+    # missing row apart from a row that holds empty data
+    let maybeCurrentData = if found: some(currentData) else: seq[byte].none
+    var maybeNewData: ?seq[byte]
+
+    try:
+      maybeNewData = await fn(maybeCurrentData)
+    except CatchableError as err:
+      return failure("Error running modify function: " & err.msg)
+
+    if maybeCurrentData == maybeNewData:
+      # no need to change any stored value
+      break
+
+    if err =? self.db.beginStmt.exec().errorOption:
+      return failure(err)
+
+    # from here on a transaction is open: every failing write has to roll it
+    # back, otherwise the next `BEGIN` on this connection fails
+    if currentData =? maybeCurrentData and newData =? maybeNewData:
+      let updateParams = (
+        newData,
+        currentVersion + 1,
+        timestamp(),
+        key.id,
+        currentVersion
+      )
+      if err =? (self.db.updateVersionedStmt.exec(updateParams)).errorOption:
+        return self.rollbackAndFail(err)
+    elif currentData =? maybeCurrentData:
+      let deleteParams = (
+        key.id,
+        currentVersion
+      )
+      if err =? (self.db.deleteVersionedStmt.exec(deleteParams)).errorOption:
+        return self.rollbackAndFail(err)
+    elif newData =? maybeNewData:
+      let insertParams = (
+        key.id,
+        newData,
+        initVersion,
+        timestamp()
+      )
+      if err =? (self.db.insertVersionedStmt.exec(insertParams)).errorOption:
+        return self.rollbackAndFail(err)
+
+    var changes = 0.int64
+    proc onChangesResult(s: RawStmtPtr) =
+      changes = changesCol(s, 0)()
+
+    if err =? self.db.getChangesStmt.query((), onChangesResult).errorOption:
+      return self.rollbackAndFail(err)
+
+    if changes == 1:
+      if err =? self.db.endStmt.exec().errorOption:
+        return failure(err)
+      break
+    elif changes == 0:
+      # race condition detected
+      if err =? self.db.rollbackStmt.exec().errorOption:
+        return failure(err)
+      retriesLeft.dec
+    else:
+      if err =? self.db.rollbackStmt.exec().errorOption:
+        return failure(err)
+      return failure("Unexpected number of changes, expected either 0 or 1, was " & $changes)
+
+  if retriesLeft == 0:
+    return failure("Retry limit exceeded")
+
+  return success()
+
+method modify*(self: SQLiteDatastore, key: Key, fn: Modify): Future[?!void] {.async.} =
+  ## Synchronous-callback overload: wraps `fn` into an async closure and
+  ## delegates to the `ModifyAsync` overload.
+  proc wrappedFn(maybeValue: ?seq[byte]): Future[(?seq[byte])] {.async.} =
+    return fn(maybeValue)
+
+  return await self.modify(key, wrappedFn)
+
 method has*(self: SQLiteDatastore, key: Key): Future[?!bool] {.async.} =
   var
     exists = false
@@ -81,14 +204,14 @@ method get*(self: SQLiteDatastore, key: Key): Future[?!seq[byte]] {.async.} =
   return success bytes
 
 method put*(self: SQLiteDatastore, key: Key, data: seq[byte]): Future[?!void] {.async.} =
-  return self.db.putStmt.exec((key.id, data, timestamp()))
+  return self.db.putStmt.exec((key.id, data, initVersion, timestamp()))
 
 method put*(self: SQLiteDatastore, batch: seq[BatchEntry]): Future[?!void] {.async.} =
   if err =? self.db.beginStmt.exec().errorOption:
     return failure err
 
   for entry in batch:
-    if err =? self.db.putStmt.exec((entry.key.id, entry.data, timestamp())).errorOption:
+    if err =? self.db.putStmt.exec((entry.key.id, entry.data, initVersion, timestamp())).errorOption:
       if err =? self.db.rollbackStmt.exec().errorOption:
         return failure err
 
diff --git a/datastore/sql/sqlitedsdb.nim b/datastore/sql/sqlitedsdb.nim
index 503dea4..f35e087 100644
--- a/datastore/sql/sqlitedsdb.nim
+++ b/datastore/sql/sqlitedsdb.nim
@@ -1,4 +1,5 @@
 import std/os
+import std/strformat
 
 import pkg/questionable
 import pkg/questionable/results
@@ -10,6 +11,7 @@ export sqliteutils
 
 type
   BoundIdCol* = proc (): string {.closure, gcsafe, upraises: [].}
+  BoundVersionCol* = proc (): int64 {.closure, gcsafe, upraises: [].}
   BoundDataCol* = proc (): seq[byte] {.closure, gcsafe, upraises: [].}
   BoundTimestampCol* = proc (): int64 {.closure, gcsafe, upraises: [].}
 
@@ -19,8 +21,13 @@ type
   ContainsStmt* = SQLiteStmt[(string), void]
   DeleteStmt* = SQLiteStmt[(string), void]
   GetStmt* = SQLiteStmt[(string), void]
-  PutStmt* = SQLiteStmt[(string, seq[byte], int64), void]
+  PutStmt* = SQLiteStmt[(string, seq[byte], int64, int64), void]
   QueryStmt* = SQLiteStmt[(string), void]
+  GetVersionedStmt* = SQLiteStmt[(string), void]
+  InsertVersionedStmt* = SQLiteStmt[(string, seq[byte], int64, int64), void]
+  UpdateVersionedStmt* = SQLiteStmt[(seq[byte], int64, int64, string, int64), void]
+  DeleteVersionedStmt* = SQLiteStmt[(string, int64), void]
+  GetChangesStmt* = NoParamsStmt
   BeginStmt* = NoParamsStmt
   EndStmt* = NoParamsStmt
   RollbackStmt* = NoParamsStmt
@@ -34,6 +41,11 @@ type
     getDataCol*: BoundDataCol
     getStmt*: GetStmt
     putStmt*: PutStmt
+    getVersionedStmt*: GetVersionedStmt
+    updateVersionedStmt*: UpdateVersionedStmt
+    insertVersionedStmt*: InsertVersionedStmt
+    deleteVersionedStmt*: DeleteVersionedStmt
+    getChangesStmt*: GetChangesStmt
     beginStmt*: BeginStmt
     endStmt*: EndStmt
     rollbackStmt*: RollbackStmt
@@ -44,10 +56,12 @@
const IdColName* = "id" DataColName* = "data" + VersionColName* = "version" TimestampColName* = "timestamp" IdColType = "TEXT" DataColType = "BLOB" + VersionColType = "INTEGER" TimestampColType = "INTEGER" Memory* = ":memory:" @@ -69,6 +83,7 @@ const CREATE TABLE IF NOT EXISTS """ & TableName & """ ( """ & IdColName & """ """ & IdColType & """ NOT NULL PRIMARY KEY, """ & DataColName & """ """ & DataColType & """, + """ & VersionColName & """ """ & VersionColType & """ NOT NULL, """ & TimestampColName & """ """ & TimestampColType & """ NOT NULL ) WITHOUT ROWID; """ @@ -89,8 +104,9 @@ const REPLACE INTO """ & TableName & """ ( """ & IdColName & """, """ & DataColName & """, + """ & VersionColName & """, """ & TimestampColName & """ - ) VALUES (?, ?, ?) + ) VALUES (?, ?, ?, ?) """ QueryStmtIdStr* = """ @@ -119,6 +135,43 @@ const ORDER BY """ & IdColName & """ DESC """ + GetVersionedStmtStr* = fmt""" + SELECT {DataColName}, {VersionColName} FROM {TableName} + WHERE {IdColName} = ? + """ + + GetVersionedStmtDataCol* = 0 + GetVersionedStmtVersionCol* = 1 + + InsertVersionedStmtStr* = fmt""" + INSERT INTO {TableName} + ( + {IdColName}, + {DataColName}, + {VersionColName}, + {TimestampColName} + ) + VALUES (?, ?, ?, ?) + """ + + UpdateVersionedStmtStr* = fmt""" + UPDATE {TableName} + SET + {DataColName} = ?, + {VersionColName} = ?, + {TimestampColName} = ? + WHERE {IdColName} = ? AND {VersionColName} = ? + """ + + DeleteVersionedStmtStr* = fmt""" + DELETE FROM {TableName} + WHERE {IdColName} = ? AND {VersionColName} = ? 
+ """ + + GetChangesStmtStr* = fmt""" + SELECT changes() + """ + BeginTransactionStr* = """ BEGIN; """ @@ -197,6 +250,21 @@ proc timestampCol*( return proc (): int64 = sqlite3_column_int64(s, index.cint) +proc versionCol*( + s: RawStmtPtr, + index: int): BoundVersionCol = + + checkColMetadata(s, index, VersionColName) + + return proc (): int64 = + sqlite3_column_int64(s, index.cint) + +proc changesCol*( + s: RawStmtPtr, + index: int): BoundVersionCol = + return proc (): int64 = + sqlite3_column_int64(s, index.cint) + proc getDBFilePath*(path: string): ?!string = try: let @@ -217,6 +285,11 @@ proc close*(self: SQLiteDsDb) = self.beginStmt.dispose self.endStmt.dispose self.rollbackStmt.dispose + self.getVersionedStmt.dispose + self.updateVersionedStmt.dispose + self.insertVersionedStmt.dispose + self.deleteVersionedStmt.dispose + self.getChangesStmt.dispose if not RawStmtPtr(self.deleteStmt).isNil: self.deleteStmt.dispose @@ -266,6 +339,11 @@ proc open*( deleteStmt: DeleteStmt getStmt: GetStmt putStmt: PutStmt + getVersionedStmt: GetVersionedStmt + updateVersionedStmt: UpdateVersionedStmt + insertVersionedStmt: InsertVersionedStmt + deleteVersionedStmt: DeleteVersionedStmt + getChangesStmt: GetChangesStmt beginStmt: BeginStmt endStmt: EndStmt rollbackStmt: RollbackStmt @@ -279,6 +357,18 @@ proc open*( putStmt = ? PutStmt.prepare( env.val, PutStmtStr, SQLITE_PREPARE_PERSISTENT) + insertVersionedStmt = ? InsertVersionedStmt.prepare( + env.val, InsertVersionedStmtStr, SQLITE_PREPARE_PERSISTENT) + + updateVersionedStmt = ? UpdateVersionedStmt.prepare( + env.val, UpdateVersionedStmtStr, SQLITE_PREPARE_PERSISTENT) + + deleteVersionedStmt = ? DeleteVersionedStmt.prepare( + env.val, DeleteVersionedStmtStr, SQLITE_PREPARE_PERSISTENT) + + getChangesStmt = ? GetChangesStmt.prepare( + env.val, GetChangesStmtStr, SQLITE_PREPARE_PERSISTENT) + beginStmt = ? BeginStmt.prepare( env.val, BeginTransactionStr, SQLITE_PREPARE_PERSISTENT) @@ -294,6 +384,9 @@ proc open*( getStmt = ? 
GetStmt.prepare( env.val, GetStmtStr, SQLITE_PREPARE_PERSISTENT) + getVersionedStmt = ? GetVersionedStmt.prepare( + env.val, GetVersionedStmtStr, SQLITE_PREPARE_PERSISTENT) + # if a readOnly/existing database does not satisfy the expected schema # `pepare()` will fail and `new` will return an error with message # "SQL logic error" @@ -310,6 +403,11 @@ proc open*( getStmt: getStmt, getDataCol: getDataCol, putStmt: putStmt, + getVersionedStmt: getVersionedStmt, + updateVersionedStmt: updateVersionedStmt, + insertVersionedStmt: insertVersionedStmt, + deleteVersionedStmt: deleteVersionedStmt, + getChangesStmt: getChangesStmt, beginStmt: beginStmt, endStmt: endStmt, rollbackStmt: rollbackStmt) diff --git a/datastore/types.nim b/datastore/types.nim index b019cdb..9f1385b 100644 --- a/datastore/types.nim +++ b/datastore/types.nim @@ -8,3 +8,4 @@ type DatastoreKeyNotFound* = object of DatastoreError Datastore* = ref object of RootObj + ConcurrentDatastore* = ref object of Datastore diff --git a/tests/datastore/concurrentdstests.nim b/tests/datastore/concurrentdstests.nim new file mode 100644 index 0000000..9900072 --- /dev/null +++ b/tests/datastore/concurrentdstests.nim @@ -0,0 +1,96 @@ +import std/options +import std/sugar +import std/random +import std/sequtils + +import pkg/asynctest +import pkg/chronos +import pkg/stew/endians2 +import pkg/questionable +import pkg/questionable/results + +import pkg/datastore/concurrentds + +proc concurrentStoreTests*( + ds: ConcurrentDatastore, + key: Key) = + + randomize() + + let processCount = 100 + + proc withRandDelay(op: Future[?!void]): Future[void] {.async.} = + await sleepAsync(rand(processCount).millis) + + let errMsg = (await op).errorOption.map((err) => err.msg) + + require none(string) == errMsg + + proc incAsyncFn(maybeBytes: ?seq[byte]): Future[?seq[byte]] {.async.} = + await sleepAsync(2.millis) # allows interleaving + if bytes =? 
maybeBytes: + let value = uint64.fromBytes(bytes) + return some(@((value + 1).toBytes)) + else: + return seq[byte].none + + test "unsafe increment - demo": + # this test demonstrates non synchronized read-modify-write sequence + (await ds.put(key, @(0.uint64.toBytes))).tryGet + + proc getIncAndPut(): Future[?!void] {.async.} = + without bytes =? (await ds.get(key)), err: + return failure(err) + + let value = uint64.fromBytes(bytes) + await sleepAsync(2.millis) # allows interleaving + + if err =? (await ds.put(key, @((value + 1).toBytes))).errorOption: + return failure(err) + else: + return success() + + let futs = newSeqWith(processCount, withRandDelay(getIncAndPut())) + await allFutures(futs).wait(10.seconds) + + let finalValue = uint64.fromBytes((await ds.get(key)).tryGet) + + check finalValue.int < processCount + + test "safe increment": + (await ds.put(key, @(0.uint64.toBytes))).tryGet + + let futs = newSeqWith(processCount, withRandDelay(ds.modify(key, incAsyncFn))) + await allFutures(futs).wait(10.seconds) + + let finalValue = uint64.fromBytes((await ds.get(key)).tryGet) + + check finalValue.int == processCount + + test "should update value": + (await ds.put(key, @((0.uint64).toBytes))).tryGet + + (await ds.modify(key, incAsyncFn)).tryGet + + let finalValue = uint64.fromBytes((await ds.get(key)).tryGet) + + check finalValue.int == 1 + + test "should put value": + (await ds.delete(key)).tryGet() + + (await ds.modify(key, (_: ?seq[byte]) => @(123.uint64.toBytes).some)).tryGet + + let finalValue = uint64.fromBytes((await ds.get(key)).tryGet) + + check finalValue.int == 123 + + test "should delete value": + let key = Key.init(Key.random).tryGet + (await ds.put(key, @(0.uint64.toBytes))).tryGet + + (await ds.modify(key, (_: ?seq[byte]) => seq[byte].none)).tryGet + + let hasKey = (await ds.has(key)).tryGet + + check not hasKey diff --git a/tests/datastore/sql/testsqliteds.nim b/tests/datastore/sql/testsqliteds.nim index c629eb0..c4cf5be 100644 --- 
a/tests/datastore/sql/testsqliteds.nim +++ b/tests/datastore/sql/testsqliteds.nim @@ -11,6 +11,7 @@ import pkg/stew/byteutils import pkg/datastore/sql/sqliteds import ../dscommontests +import ../concurrentdstests import ../querycommontests suite "Test Basic SQLiteDatastore": @@ -24,6 +25,7 @@ suite "Test Basic SQLiteDatastore": (await ds.close()).tryGet() basicStoreTests(ds, key, bytes, otherBytes) + concurrentStoreTests(ds, key) suite "Test Read Only SQLiteDatastore": let diff --git a/tests/datastore/sql/testsqlitedsdb.nim b/tests/datastore/sql/testsqlitedsdb.nim index d104933..b6ff105 100644 --- a/tests/datastore/sql/testsqlitedsdb.nim +++ b/tests/datastore/sql/testsqlitedsdb.nim @@ -106,9 +106,9 @@ suite "Test SQLite Datastore DB operations": test "Should insert key": check: - readOnlyDb.putStmt.exec((key.id, data, timestamp())).isErr() + readOnlyDb.putStmt.exec((key.id, data, initVersion, timestamp())).isErr() - dsDb.putStmt.exec((key.id, data, timestamp())).tryGet() + dsDb.putStmt.exec((key.id, data, initVersion, timestamp())).tryGet() test "Should select key": let @@ -124,9 +124,9 @@ suite "Test SQLite Datastore DB operations": test "Should update key": check: - readOnlyDb.putStmt.exec((key.id, otherData, timestamp())).isErr() + readOnlyDb.putStmt.exec((key.id, otherData, initVersion, timestamp())).isErr() - dsDb.putStmt.exec((key.id, otherData, timestamp())).tryGet() + dsDb.putStmt.exec((key.id, otherData, initVersion, timestamp())).tryGet() test "Should select updated key": let