diff --git a/eth/db/kvstore_sqlite3.nim b/eth/db/kvstore_sqlite3.nim index 99469c8..a1b11b4 100644 --- a/eth/db/kvstore_sqlite3.nim +++ b/eth/db/kvstore_sqlite3.nim @@ -3,11 +3,11 @@ {.push raises: [Defect].} import - std/[os, options, strformat], + std/[os, options, strformat, typetraits], sqlite3_abi, ./kvstore -export kvstore +export kvstore, typetraits type RawStmtPtr = ptr sqlite3_stmt @@ -37,6 +37,7 @@ type SqKeyspace* = object of RootObj # A Keyspace is a single key-value table - it is generally efficient to # create separate keyspaces for each type of data stored + open: bool getStmt, putStmt, delStmt, containsStmt, findStmt0, findStmt1, findStmt2: RawStmtPtr @@ -115,10 +116,11 @@ proc bindParam(s: RawStmtPtr, n: int, val: auto): cint = template bindParams(s: RawStmtPtr, params: auto) = when params is tuple: - var i = 1 - for param in fields(params): - checkErr bindParam(s, i, param) - inc i + when params.type.arity > 0: + var i = 1 + for param in fields(params): + checkErr bindParam(s, i, param) + inc i else: checkErr bindParam(s, 1, params) @@ -208,7 +210,8 @@ proc exec*[Params, Res](s: SqliteStmt[Params, Res], let v = sqlite3_step(s) case v of SQLITE_ROW: - onData(readResult(s, Res)) + if onData != nil: + onData(readResult(s, Res)) gotResults = true of SQLITE_DONE: break @@ -294,8 +297,9 @@ template exec*(db: SqStoreRef, stmt: string): KvResult[void] = proc get*(db: SqKeyspaceRef, key: openArray[byte], onData: DataProc): KvResult[bool] = - if db.getStmt == nil: return err("sqlite: database closed") + if not db.open: return err("sqlite: database closed") let getStmt = db.getStmt + if getStmt == nil: return ok(false) # no such table checkErr bindParam(getStmt, 1, key) let @@ -339,23 +343,26 @@ proc find*( db: SqKeyspaceRef, prefix: openArray[byte], onFind: KeyValueProc): KvResult[int] = + if not db.open: return err("sqlite: database closed") var next: seq[byte] # extended lifetime of bound param let findStmt = if prefix.len == 0: + if db.findStmt0 == nil: return ok(0) # no such table db.findStmt0 # all rows else: if not nextPrefix(prefix, next): # For example when looking for the prefix [byte 255], there are no # prefixes that lexicographically are greater, thus we use the # query that only does the >= comparison + if db.findStmt1 == nil: return ok(0) # no such table checkErr bindParam(db.findStmt1, 1, prefix) db.findStmt1 else: + if db.findStmt2 == nil: return ok(0) # no such table checkErr bindParam(db.findStmt2, 1, prefix) checkErr bindParam(db.findStmt2, 2, next) db.findStmt2 - if findStmt == nil: return err("sqlite: database closed") var total = 0 @@ -369,7 +376,8 @@ proc find*( kl = sqlite3_column_bytes(findStmt, 0) vp = cast[ptr UncheckedArray[byte]](sqlite3_column_blob(findStmt, 1)) vl = sqlite3_column_bytes(findStmt, 1) - onFind(kp.toOpenArray(0, kl - 1), vp.toOpenArray(0, vl - 1)) + if onFind != nil: + onFind(kp.toOpenArray(0, kl - 1), vp.toOpenArray(0, vl - 1)) total += 1 of SQLITE_DONE: break @@ -387,8 +395,9 @@ proc find*( ok(total) proc put*(db: SqKeyspaceRef, key, value: openArray[byte]): KvResult[void] = + if not db.open: return err("sqlite: database closed") let putStmt = db.putStmt - if putStmt == nil: return err("sqlite: database closed") + if putStmt == nil: return err("sqlite: cannot write to read-only database") checkErr bindParam(putStmt, 1, key) checkErr bindParam(putStmt, 2, value) @@ -405,8 +414,10 @@ proc put*(db: SqKeyspaceRef, key, value: openArray[byte]): KvResult[void] = res proc contains*(db: SqKeyspaceRef, key: openArray[byte]): KvResult[bool] = + if not db.open: return err("sqlite: database closed") let containsStmt = db.containsStmt - if containsStmt == nil: return err("sqlite: database closed") + if containsStmt == nil: return ok(false) # no such table + checkErr bindParam(containsStmt, 1, key) let @@ -423,8 +434,9 @@ proc contains*(db: SqKeyspaceRef, key: openArray[byte]): KvResult[bool] = res proc del*(db: SqKeyspaceRef, key: openArray[byte]): KvResult[void] = + if not db.open: return err("sqlite: database closed") let delStmt = db.delStmt - if delStmt == nil: return err("sqlite: database closed") + if delStmt == nil: return ok() # no such table checkErr bindParam(delStmt, 1, key) let res = @@ -473,14 +485,14 @@ proc checkpoint*(db: SqStoreRef, kind = SqStoreCheckpointKind.passive) = template prepare(env: ptr sqlite3, q: string): ptr sqlite3_stmt = block: var s: ptr sqlite3_stmt - checkErr sqlite3_prepare_v2(env, q, q.len.cint, addr s, nil): + checkErr sqlite3_prepare_v2(env, cstring(q), q.len.cint, addr s, nil): discard s template prepare(env: ptr sqlite3, q: string, cleanup: untyped): ptr sqlite3_stmt = block: var s: ptr sqlite3_stmt - checkErr sqlite3_prepare_v2(env, q, q.len.cint, addr s, nil) + checkErr sqlite3_prepare_v2(env, cstring(q), q.len.cint, addr s, nil) s template checkExec(s: ptr sqlite3_stmt) = @@ -564,7 +576,15 @@ proc init*( readOnly: readOnly )) -proc openKvStore*(db: SqStoreRef, name = "kvstore", withoutRowid = false): KvResult[SqKeyspaceRef] = +proc hasTable*(db: SqStoreRef, name: string): KvResult[bool] = + let + sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='" & + name & "';" + db.exec(sql, (), proc(_: openArray[byte]) = discard) + +proc openKvStore*( + db: SqStoreRef, name = "kvstore", withoutRowid = false, + readOnly = false): KvResult[SqKeyspaceRef] = ## Open a new Key-Value store in the SQLite database ## ## withoutRowid: Create the table without rowid - this is more efficient when @@ -572,30 +592,33 @@ proc openKvStore*(db: SqStoreRef, name = "kvstore", withoutRowid = false): KvRes ## rows (the row being the sum of key and value) - see ## https://www.sqlite.org/withoutrowid.html ## - - if not db.readOnly: + let hasTable = if db.readOnly or readOnly: + ? db.hasTable(name) + else: let createSql = """ - CREATE TABLE IF NOT EXISTS """ & name & """ ( + CREATE TABLE IF NOT EXISTS '""" & name & """' ( key BLOB PRIMARY KEY, value BLOB )""" checkExec db.env, if withoutRowid: createSql & " WITHOUT ROWID;" else: createSql & ";" - + true var tmp: SqKeyspace defer: # We'll "move" ownership to the return value, effectively disabling "close" close(tmp) - - tmp.getStmt = prepare(db.env, "SELECT value FROM " & name & " WHERE key = ?;") - tmp.putStmt = - prepare(db.env, "INSERT OR REPLACE INTO " & name & "(key, value) VALUES (?, ?);") - tmp.delStmt = prepare(db.env, "DELETE FROM " & name & " WHERE key = ?;") - tmp.containsStmt = prepare(db.env, "SELECT 1 FROM " & name & " WHERE key = ?;") - tmp.findStmt0 = prepare(db.env, "SELECT key, value FROM " & name & ";") - tmp.findStmt1 = prepare(db.env, "SELECT key, value FROM " & name & " WHERE key >= ?;") - tmp.findStmt2 = prepare(db.env, "SELECT key, value FROM " & name & " WHERE key >= ? and key < ?;") + tmp.open = true + if hasTable: + tmp.getStmt = + prepare(db.env, "SELECT value FROM '" & name & "' WHERE key = ?;") + tmp.putStmt = + prepare(db.env, "INSERT OR REPLACE INTO '" & name & "'(key, value) VALUES (?, ?);") + tmp.delStmt = prepare(db.env, "DELETE FROM '" & name & "' WHERE key = ?;") + tmp.containsStmt = prepare(db.env, "SELECT 1 FROM '" & name & "' WHERE key = ?;") + tmp.findStmt0 = prepare(db.env, "SELECT key, value FROM '" & name & "';") + tmp.findStmt1 = prepare(db.env, "SELECT key, value FROM '" & name & "' WHERE key >= ?;") + tmp.findStmt2 = prepare(db.env, "SELECT key, value FROM '" & name & "' WHERE key >= ? and key < ?;") var res = SqKeyspaceRef() res[] = tmp diff --git a/tests/db/test_kvstore_sqlite3.nim b/tests/db/test_kvstore_sqlite3.nim index e7c181f..252c27e 100644 --- a/tests/db/test_kvstore_sqlite3.nim +++ b/tests/db/test_kvstore_sqlite3.nim @@ -16,6 +16,18 @@ procSuite "SqStoreRef": testKvStore(kvStore kv.get(), true) + test "Readonly kvstore with no table": + let db = SqStoreRef.init("", "test", inMemory = true, readOnly = true)[] + defer: db.close() + let kv = db.openKvStore().expect("working db") + + check: + not kv.get([byte 0, 1, 2], nil).expect("ok to query data") + kv.find([byte 0, 1, 2], nil).expect("ok") == 0 + kv.put([byte 0, 1, 2], []).isErr + kv.del([byte 0, 1, 2]).isOk + defer: kv[].close() + test "Prepare and execute statements": let db = SqStoreRef.init("", "test", inMemory = true)[] defer: db.close()