Add custom scalar functions to sqlite (#509)

This commit is contained in:
KonradStaniec 2022-06-02 14:14:15 +02:00 committed by GitHub
parent dffaa78cbe
commit dacf827a86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 115 additions and 2 deletions

View File

@ -42,6 +42,12 @@ type
SqKeyspaceRef* = ref SqKeyspace
CustomFunction* =
proc (
a: openArray[byte],
b: openArray[byte]
): Result[seq[byte], cstring] {.noSideEffect, gcsafe, cdecl, raises: [Defect].}
template dispose(db: Sqlite) =
discard sqlite3_close(db)
@ -593,6 +599,64 @@ proc openKvStore*(db: SqStoreRef, name = "kvstore", withoutRowid = false): KvRes
tmp = SqKeyspace() # make close harmless
ok res
proc customScalarBlobFunction(ctx: ptr sqlite3_context, n: cint, v: ptr ptr sqlite3_value) {.cdecl.} =
let ptrs = cast[ptr UncheckedArray[ptr sqlite3_value]](v)
let blob1 = cast[ptr UncheckedArray[byte]](sqlite3_value_blob(ptrs[][0]))
let blob2 = cast[ptr UncheckedArray[byte]](sqlite3_value_blob(ptrs[][1]))
let blob1Len = sqlite3_value_bytes(ptrs[][0])
let blob2Len = sqlite3_value_bytes(ptrs[][1])
# sqlite3_user_data retrieves data which was pointed by 5th param to
# sqlite3_create_function functions, which in our case is custom function
# provided by user
let usrFun = cast[CustomFunction](sqlite3_user_data(ctx))
let s = usrFun(
toOpenArray(blob1, 0, blob1Len - 1),
toOpenArray(blob2, 0, blob2Len - 1)
)
try:
if s.isOk():
let bytes = s.unsafeGet()
# try is necessessary as otherwise nim marks SQLITE_TRANSIENT as throwning
# unlisted exception.
# Using SQLITE_TRANSIENT destructor type, as it inform sqlite that data
# under provided pointer may be deleted at any moment, which is the case
# for seq[byte] as it is managed by nim gc. With this flag sqlite copy bytes
# under pointer and then realeases them itself.
sqlite3_result_blob(ctx, unsafeAddr bytes[0], bytes.len.cint, SQLITE_TRANSIENT)
else:
let errMsg = s.error
sqlite3_result_error(ctx, errMsg, -1)
except Exception as e:
raiseAssert(e.msg)
proc registerCustomScalarFunction*(db: SqStoreRef, name: string, fun: CustomFunction): KvResult[void] =
## Register custom function inside sqlite engine. Registered function can
## be used in further queries by its name. Function should be side-effect
## free and depends only on provided arguments.
## Name of the function should be valid utf8 string.
# Using SQLITE_DETERMINISTIC flag to inform sqlite that provided function
# won't have any side effect this may enable additional optimisations.
let deterministicUtf8Func = cint(SQLITE_UTF8 or SQLITE_DETERMINISTIC)
let res = sqlite3_create_function(
db.env,
name,
cint(2),
deterministicUtf8Func,
cast[pointer](fun),
customScalarBlobFunction,
nil,
nil
)
if res != SQLITE_OK:
return err($sqlite3_errstr(res))
else:
return ok()
when defined(metrics):
import locks, tables, times,
chronicles, metrics

View File

@ -1,8 +1,9 @@
{.used.}
import
std/[os, options],
std/[os, options, sequtils],
testutils/unittests,
stew/endians2,
../../eth/db/[kvstore, kvstore_sqlite3],
./test_kvstore
@ -232,4 +233,52 @@ procSuite "SqStoreRef":
rowRes.expect("working db")
check abc == row
found = true
check found
check found
proc customSumFun(
a: openArray[byte],
b: openArray[byte]): Result[seq[byte], cstring] {.noSideEffect, gcsafe, cdecl, raises: [Defect].} =
let num1 = uint32.fromBytesBE(a)
let num2 = uint32.fromBytesBE(b)
let sum = num1 + num2
let asBytes = sum.toBytesBE().toSeq()
return ok(asBytes)
test "Register custom scalar function":
let db = SqStoreRef.init("", "test", inMemory = true)[]
let registerResult = db.registerCustomScalarFunction("sum32", customSumFun)
check:
registerResult.isOk()
defer: db.close()
let kv = db.openKvStore().get()
defer: kv.close()
var sums: seq[seq[byte]] = @[]
# Use custom function, which interprest blobs as uint32 numbers and sums
# them together
let sumKeyVal = db.prepareStmt(
"SELECT sum32(key, value) FROM kvstore;",
NoParams, seq[byte]).get
let testUint = uint32(38)
let putRes = kv.put(testUint.toBytesBE(), testUint.toBytesBE())
check:
putRes.isOk()
discard sumKeyVal.exec do (res: seq[byte]):
sums.add(res)
check:
len(sums) == 1
let sum = uint32.fromBytesBE(sums[0])
check:
sum == testUint + testUint