diff --git a/eth/db/kvstore_sqlite3.nim b/eth/db/kvstore_sqlite3.nim index 1afffd9..9cea00d 100644 --- a/eth/db/kvstore_sqlite3.nim +++ b/eth/db/kvstore_sqlite3.nim @@ -1,3 +1,10 @@ +# nim-eth +# Copyright (c) 2019-2023 Status Research & Development GmbH +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at https://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at https://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + ## Implementation of KvStore based on sqlite3 {.push raises: [].} @@ -45,12 +52,6 @@ type SqKeyspaceRef* = ref SqKeyspace - CustomFunction* = - proc ( - a: openArray[byte], - b: openArray[byte] - ): Result[seq[byte], cstring] {.noSideEffect, cdecl, callback.} - template dispose(db: Sqlite) = discard sqlite3_close(db) @@ -660,51 +661,32 @@ proc openKvStore*( tmp = SqKeyspace() # make close harmless ok res -proc customScalarBlobFunction(ctx: ptr sqlite3_context, n: cint, v: ptr ptr sqlite3_value) {.cdecl, callback.} = - let ptrs = cast[ptr UncheckedArray[ptr sqlite3_value]](v) - let blob1 = cast[ptr UncheckedArray[byte]](sqlite3_value_blob(ptrs[][0])) - let blob2 = cast[ptr UncheckedArray[byte]](sqlite3_value_blob(ptrs[][1])) - let blob1Len = sqlite3_value_bytes(ptrs[][0]) - let blob2Len = sqlite3_value_bytes(ptrs[][1]) - # sqlite3_user_data retrieves data which was pointed by 5th param to - # sqlite3_create_function functions, which in our case is custom function - # provided by user - let usrFun = cast[CustomFunction](sqlite3_user_data(ctx)) - let s = usrFun( - toOpenArray(blob1, 0, blob1Len - 1), - toOpenArray(blob2, 0, blob2Len - 1) - ) +type + SqliteContext* = ptr sqlite3_context + SqliteValue* = ptr ptr sqlite3_value + SqliteCustomFunction* = + proc (a1: SqliteContext; a2: cint; a3: SqliteValue) {.cdecl, callback.} - if s.isOk(): - let bytes = s.unsafeGet() - # try is necessary as otherwise nim marks SQLITE_TRANSIENT as throwing - # unlisted exception. - # Using SQLITE_TRANSIENT destructor type, as it inform sqlite that data - # under provided pointer may be deleted at any moment, which is the case - # for seq[byte] as it is managed by nim gc. With this flag sqlite copy bytes - # under pointer and then releases them itself. - sqlite3_result_blob(ctx, unsafeAddr bytes[0], bytes.len.cint, SQLITE_TRANSIENT) - else: - let errMsg = s.error - sqlite3_result_error(ctx, errMsg, -1) - -proc registerCustomScalarFunction*(db: SqStoreRef, name: string, fun: CustomFunction): KvResult[void] = - ## Register custom function inside sqlite engine. Registered function can - ## be used in further queries by its name. Function should be side-effect - ## free and depends only on provided arguments. - ## Name of the function should be valid utf8 string. +proc createCustomFunction*( + db: SqStoreRef, name: string, argc: int, + customFunction: SqliteCustomFunction): + KvResult[void] = + ## Create custom function inside sqlite engine. Function can be used in + ## queries by the provided name. Function should be side-effect free and + ## depend only on provided arguments. + ## Name of the function must be a valid utf8 string. # Using SQLITE_DETERMINISTIC flag to inform sqlite that provided function - # won't have any side effect this may enable additional optimisations. + # will be deterministic, this may enable additional optimisations. let deterministicUtf8Func = cint(SQLITE_UTF8 or SQLITE_DETERMINISTIC) checkErr db.env, sqlite3_create_function( db.env, name, - cint(2), + cint(argc), deterministicUtf8Func, - cast[pointer](fun), - customScalarBlobFunction, + nil, + customFunction, nil, nil ) diff --git a/tests/db/all_tests.nim b/tests/db/all_tests.nim index 552aea1..2b676f3 100644 --- a/tests/db/all_tests.nim +++ b/tests/db/all_tests.nim @@ -1,3 +1,4 @@ import ./test_kvstore_sqlite3, + ./test_kvstore_sqlite3_custom_func, ./test_kvstore diff --git a/tests/db/test_kvstore_sqlite3.nim b/tests/db/test_kvstore_sqlite3.nim index 7feb9b3..faba91e 100644 --- a/tests/db/test_kvstore_sqlite3.nim +++ b/tests/db/test_kvstore_sqlite3.nim @@ -1,9 +1,8 @@ {.used.} import - std/[os, options, sequtils], + std/[os, options], testutils/unittests, - stew/endians2, ../../eth/db/[kvstore, kvstore_sqlite3], ./test_kvstore @@ -246,51 +245,3 @@ procSuite "SqStoreRef": check abc == row found = true check found - - proc customSumFun( - a: openArray[byte], - b: openArray[byte]): Result[seq[byte], cstring] {.cdecl.} = - let num1 = uint32.fromBytesBE(a) - let num2 = uint32.fromBytesBE(b) - let sum = num1 + num2 - let asBytes = sum.toBytesBE().toSeq() - return ok(asBytes) - - test "Register custom scalar function": - let db = SqStoreRef.init("", "test", inMemory = true)[] - - let registerResult = db.registerCustomScalarFunction("sum32", customSumFun) - - check: - registerResult.isOk() - - defer: db.close() - - let kv = db.openKvStore().get() - defer: kv.close() - - var sums: seq[seq[byte]] = @[] - - # Use custom function, which interprets blobs as uint32 numbers and sums - # them together - let sumKeyVal = db.prepareStmt( - "SELECT sum32(key, value) FROM kvstore;", - NoParams, seq[byte]).get - - let testUint = uint32(38) - - let putRes = kv.put(testUint.toBytesBE(), testUint.toBytesBE()) - - check: - putRes.isOk() - - discard sumKeyVal.exec do (res: seq[byte]): - sums.add(res) - - check: - len(sums) == 1 - - let sum = uint32.fromBytesBE(sums[0]) - - check: - sum == testUint + testUint diff --git a/tests/db/test_kvstore_sqlite3_custom_func.nim b/tests/db/test_kvstore_sqlite3_custom_func.nim new file mode 100644 index 0000000..0ae986b --- /dev/null +++ b/tests/db/test_kvstore_sqlite3_custom_func.nim @@ -0,0 +1,65 @@ +{.used.} + +import + std/sequtils, + testutils/unittests, + stew/endians2, + stew/ptrops, + sqlite3_abi, + ../../eth/db/kvstore_sqlite3 + +procSuite "SqStoreRef custom function": + + proc customSum( + ctx: SqliteContext, n: cint, v: SqliteValue) + {.cdecl, gcsafe, raises: [].} = + doAssert(n == 2) + + let + ptrs = makeUncheckedArray(v) + blob1Len = sqlite3_value_bytes(ptrs[][0]) + blob2Len = sqlite3_value_bytes(ptrs[][1]) + + num1 = uint32.fromBytesBE(makeOpenArray( + sqlite3_value_blob(ptrs[][0]), byte, blob1Len)) + num2 = uint32.fromBytesBE(makeOpenArray( + sqlite3_value_blob(ptrs[][1]), byte, blob2Len)) + sum = num1 + num2 + + bytes = sum.toBytesBE().toSeq() + + sqlite3_result_blob(ctx, baseAddr bytes, cint bytes.len, SQLITE_TRANSIENT) + + test "Create custom function": + let db = SqStoreRef.init("", "test", inMemory = true)[] + defer: db.close() + + db.createCustomFunction("sum32", 2, customSum).expect( + "Custom function creation OK") + + let kv = db.openKvStore().expect("Working database") + defer: kv.close() + + # Use the custom function, which interprets blobs as uint32 numbers and + # sums them together + let sumStmt = db.prepareStmt( + "SELECT sum32(key, value) FROM kvstore;", + NoParams, seq[byte]).get() + + let + key = uint32(39) + val = uint32(38) + + kv.put(key.toBytesBE(), val.toBytesBE()).expect("Working database") + + var sums: seq[seq[byte]] = @[] + discard sumStmt.exec do (res: seq[byte]): + sums.add(res) + + check: + len(sums) == 1 + + let sum = uint32.fromBytesBE(sums[0]) + + check: + sum == key + val