adding semaphore

This commit is contained in:
Dmitriy Ryajov 2023-09-14 17:47:37 -06:00
parent af310306ad
commit 3500913642
No known key found for this signature in database
GPG Key ID: DA8C680CE7C657A4
7 changed files with 394 additions and 246 deletions

View File

@ -19,8 +19,7 @@ proc `=destroy`*(x: var DataBufferHolder) =
##
if x.buf != nil:
# when isMainModule or true:
# echo "buffer: FREE: ", repr x.buf.pointer
# echo "buffer: FREE: ", repr x.buf.pointer
deallocShared(x.buf)
proc len*(a: DataBuffer): int = a[].size
@ -77,14 +76,14 @@ converter toString*(data: DataBuffer): string =
if data.len() > 0:
copyMem(addr result[0], unsafeAddr data[].buf[0], data.len)
converter toBuffer*(err: ref CatchableError): DataBuffer =
## convert exception to an object with StringBuffer
##
return DataBuffer.new(err.msg)
proc `$`*(data: DataBuffer): string =
## convert buffer to string type using copy
##
data.toString()
converter toBuffer*(err: ref CatchableError): DataBuffer =
## convert exception to an object with StringBuffer
##
return DataBuffer.new(err.msg)

View File

@ -0,0 +1,57 @@
import std/atomics
import std/locks
type
Semaphore* = object
count: int
size: int
lock {.align: 64.}: Lock
cond: Cond
func `=`*(dst: var Semaphore, src: Semaphore) {.error: "A semaphore cannot be copied".}
func `=sink`*(dst: var Semaphore, src: Semaphore) {.error: "An semaphore cannot be moved".}
proc init*(_: type Semaphore, count: uint): Semaphore =
var
lock: Lock
cond: Cond
lock.initLock()
cond.initCond()
Semaphore(count: count.int, size: count.int, lock: lock, cond: cond)
proc `=destroy`*(self: var Semaphore) =
self.lock.deinitLock()
self.cond.deinitCond()
proc count*(self: var Semaphore): int =
self.count
proc size*(self: var Semaphore): int =
self.size
proc acquire*(self: var Semaphore) {.inline.} =
self.lock.acquire()
while self.count <= 0:
self.cond.wait(self.lock)
self.count -= 1
self.lock.release()
proc release*(self: var Semaphore) {.inline.} =
self.lock.acquire()
if self.count <= 0:
self.count += 1
self.cond.signal()
self.lock.release()
doAssert not (self.count > self.size),
"Semaphore count is greather than size: " & $self.size & " count is: " & $self.count
template withSemaphore*(self: var Semaphore, body: untyped) =
self.acquire()
try:
body
finally:
self.release()

View File

@ -21,20 +21,25 @@ import ../key
import ../query
import ../datastore
import ./semaphore
import ./asyncsemaphore
import ./databuffer
type
ThreadTypes = void | bool | SomeInteger | DataBuffer | tuple
ThreadTypes = void | bool | SomeInteger | DataBuffer | tuple | Atomic
ThreadResult[T: ThreadTypes] = Result[T, DataBuffer]
TaskCtx[T: ThreadTypes] = object
ds: ptr Datastore
res: ptr ThreadResult[T]
semaphore: ptr Semaphore
signal: ThreadSignalPtr
ThreadDatastore* = ref object of Datastore
tp: Taskpool
ds: Datastore
# semaphore: AsyncSemaphore
semaphore: Semaphore
tasks: seq[Future[void]]
template dispatchTask(
@ -46,12 +51,16 @@ template dispatchTask(
fut = wait(ctx.signal)
try:
runTask()
# await self.semaphore.acquire()
self.tasks.add(fut)
runTask()
await fut
if ctx.res[].isErr:
result = failure(ctx.res[].error())
result = failure(ctx.res[].error()) # TODO: fix this, result shouldn't be accessed
except CancelledError as exc:
echo "Cancelling future!"
raise exc
finally:
discard ctx.signal.close()
if (
@ -59,13 +68,17 @@ template dispatchTask(
idx != -1):
self.tasks.del(idx)
# self.semaphore.release()
proc hasTask(
ctx: ptr TaskCtx,
key: ptr Key) =
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without ret =?
(waitFor ctx[].ds[].has(key[])).catch and res =? ret, error:
ctx[].res[].err(error)
@ -76,12 +89,13 @@ proc hasTask(
method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} =
var
signal = ThreadSignalPtr.new().valueOr:
return failure("Failed to create signal")
return failure(error())
res = ThreadResult[bool]()
ctx = TaskCtx[bool](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
@ -93,7 +107,9 @@ method has*(self: ThreadDatastore, key: Key): Future[?!bool] {.async.} =
proc delTask(ctx: ptr TaskCtx, key: ptr Key) =
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without res =? (waitFor ctx[].ds[].delete(key[])).catch, error:
ctx[].res[].err(error)
return
@ -106,12 +122,13 @@ method delete*(
var
signal = ThreadSignalPtr.new().valueOr:
return failure("Failed to create signal")
return failure(error())
res = ThreadResult[void]()
ctx = TaskCtx[void](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
@ -120,7 +137,10 @@ method delete*(
self.dispatchTask(ctx, runTask)
return success()
method delete*(self: ThreadDatastore, keys: seq[Key]): Future[?!void] {.async.} =
method delete*(
self: ThreadDatastore,
keys: seq[Key]): Future[?!void] {.async.} =
for key in keys:
if err =? (await self.delete(key)).errorOption:
return failure err
@ -130,15 +150,18 @@ method delete*(self: ThreadDatastore, keys: seq[Key]): Future[?!void] {.async.}
proc putTask(
ctx: ptr TaskCtx,
key: ptr Key,
data: DataBuffer,
# data: DataBuffer,
data: ptr UncheckedArray[byte],
len: int) =
## run put in a thread task
##
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
without res =? (waitFor ctx[].ds[].put(key[], @data)).catch, error:
ctx[].semaphore[].acquire()
without res =? (waitFor ctx[].ds[].put(key[], @(data.toOpenArray(0, len - 1)))).catch, error:
ctx[].res[].err(error)
return
@ -151,19 +174,20 @@ method put*(
var
signal = ThreadSignalPtr.new().valueOr:
return failure("Failed to create signal")
return failure(error())
res = ThreadResult[void]()
ctx = TaskCtx[void](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
self.tp.spawn putTask(
addr ctx,
unsafeAddr key,
DataBuffer.new(data),
makeUncheckedArray(baseAddr data),
data.len)
self.dispatchTask(ctx, runTask)
@ -187,7 +211,9 @@ proc getTask(
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without res =?
(waitFor ctx[].ds[].get(key[])).catch and data =? res, error:
ctx[].res[].err(error)
@ -201,19 +227,23 @@ method get*(
var
signal = ThreadSignalPtr.new().valueOr:
return failure("Failed to create signal")
return failure(error())
var
res = ThreadResult[DataBuffer]()
ctx = TaskCtx[DataBuffer](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
self.tp.spawn getTask(addr ctx, unsafeAddr key)
self.dispatchTask(ctx, runTask)
if err =? res.errorOption:
return failure err
return success(@(res.get()))
method close*(self: ThreadDatastore): Future[?!void] {.async.} =
@ -228,13 +258,15 @@ proc queryTask(
defer:
discard ctx[].signal.fireSync()
ctx[].semaphore[].release()
ctx[].semaphore[].acquire()
without ret =? (waitFor iter[].next()).catch and res =? ret, error:
ctx[].res[].err(error)
return
if res.key.isNone:
ctx[].res[].ok((false, DataBuffer.new(), DataBuffer.new()))
ctx[].res[].ok((false, default(DataBuffer), default(DataBuffer)))
return
var
@ -275,6 +307,7 @@ method query*(
ctx = TaskCtx[(bool, DataBuffer, DataBuffer)](
ds: addr self.ds,
res: addr res,
semaphore: addr self.semaphore,
signal: signal)
proc runTask() =
@ -302,4 +335,6 @@ func new*(
success ThreadDatastore(
tp: tp,
ds: ds)
ds: ds,
# semaphore: AsyncSemaphore.new(tp.numThreads - 1)) # one thread is needed for the task dispatcher
semaphore: Semaphore.init((tp.numThreads - 1).uint)) # one thread is needed for the task dispatcher

View File

@ -3,6 +3,7 @@ import std/options
import pkg/asynctest
import pkg/chronos
import pkg/stew/results
import pkg/questionable/results
import pkg/datastore
@ -56,3 +57,9 @@ proc basicStoreTests*(
for k in batch:
check: not (await ds.has(k)).tryGet
# test "handle missing key":
# let key = Key.init("/missing/key").tryGet()
# # expect(ResultFailure):
# discard (await ds.get(key)).tryGet() # non existing key

View File

@ -26,118 +26,110 @@ template queryTests*(ds: Datastore, extended = true) {.dirty.} =
val2 = "value for 2".toBytes
val3 = "value for 3".toBytes
test "Key should query all keys and all it's children":
let
q = Query.init(key1)
# test "Key should query all keys and all it's children":
# let
# q = Query.init(key1)
(await ds.put(key1, val1)).tryGet
(await ds.put(key2, val2)).tryGet
(await ds.put(key3, val3)).tryGet
# (await ds.put(key1, val1)).tryGet
# (await ds.put(key2, val2)).tryGet
# (await ds.put(key3, val3)).tryGet
let
iter = (await ds.query(q)).tryGet
res = block:
var
res: seq[QueryResponse]
cnt = 0
# let
# iter = (await ds.query(q)).tryGet
# res = block:
# var
# res: seq[QueryResponse]
# cnt = 0
for pair in iter:
let (key, val) = (await pair).tryGet
if key.isNone:
break
# while not iter.finished:
# let (key, val) = (await iter.next()).tryGet
# if key.isNone:
# break
res.add((key, val))
cnt.inc
# res.add((key, val))
# cnt.inc
res
# res
check:
res.len == 3
res[0].key.get == key1
res[0].data == val1
# check:
# res.len == 3
# res[0].key.get == key1
# res[0].data == val1
res[1].key.get == key2
res[1].data == val2
# res[1].key.get == key2
# res[1].data == val2
res[2].key.get == key3
res[2].data == val3
# res[2].key.get == key3
# res[2].data == val3
(await iter.dispose()).tryGet
# (await iter.dispose()).tryGet
test "Key should query all keys without values":
let
q = Query.init(key1, value = false)
# test "Key should query all keys without values":
# let
# q = Query.init(key1, value = false)
(await ds.put(key1, val1)).tryGet
(await ds.put(key2, val2)).tryGet
(await ds.put(key3, val3)).tryGet
# (await ds.put(key1, val1)).tryGet
# (await ds.put(key2, val2)).tryGet
# (await ds.put(key3, val3)).tryGet
let
iter = (await ds.query(q)).tryGet
res = block:
var
res: seq[QueryResponse]
cnt = 0
# let
# iter = (await ds.query(q)).tryGet
# res = block:
# var res: seq[QueryResponse]
# while not iter.finished:
# let (key, val) = (await iter.next()).tryGet
# if key.isNone:
# break
for pair in iter:
let (key, val) = (await pair).tryGet
if key.isNone:
break
# res.add((key, val))
res.add((key, val))
cnt.inc
# res
res
# check:
# res.len == 3
# res[0].key.get == key1
# res[0].data.len == 0
check:
res.len == 3
res[0].key.get == key1
res[0].data.len == 0
# res[1].key.get == key2
# res[1].data.len == 0
res[1].key.get == key2
res[1].data.len == 0
# res[2].key.get == key3
# res[2].data.len == 0
res[2].key.get == key3
res[2].data.len == 0
# (await iter.dispose()).tryGet
(await iter.dispose()).tryGet
# test "Key should not query parent":
# let
# q = Query.init(key2)
test "Key should not query parent":
let
q = Query.init(key2)
# (await ds.put(key1, val1)).tryGet
# (await ds.put(key2, val2)).tryGet
# (await ds.put(key3, val3)).tryGet
(await ds.put(key1, val1)).tryGet
(await ds.put(key2, val2)).tryGet
(await ds.put(key3, val3)).tryGet
# let
# iter = (await ds.query(q)).tryGet
# res = block:
# var res: seq[QueryResponse]
# while not iter.finished:
# let (key, val) = (await iter.next()).tryGet
# if key.isNone:
# break
let
iter = (await ds.query(q)).tryGet
res = block:
var
res: seq[QueryResponse]
cnt = 0
# res.add((key, val))
for pair in iter:
let (key, val) = (await pair).tryGet
if key.isNone:
break
# res
res.add((key, val))
cnt.inc
# check:
# res.len == 2
# res[0].key.get == key2
# res[0].data == val2
res
# res[1].key.get == key3
# res[1].data == val3
check:
res.len == 2
res[0].key.get == key2
res[0].data == val2
# (await iter.dispose()).tryGet
res[1].key.get == key3
res[1].data == val3
(await iter.dispose()).tryGet
test "Key should all list all keys at the same level":
test "Key should list all keys at the same level":
let
queryKey = Key.init("/a").tryGet
q = Query.init(queryKey)
@ -181,160 +173,145 @@ template queryTests*(ds: Datastore, extended = true) {.dirty.} =
(await iter.dispose()).tryGet
if extended:
test "Should apply limit":
let
key = Key.init("/a").tryGet
q = Query.init(key, limit = 10)
# if extended:
# test "Should apply limit":
# let
# key = Key.init("/a").tryGet
# q = Query.init(key, limit = 10)
for i in 0..<100:
let
key = Key.init(key, Key.init("/" & $i).tryGet).tryGet
val = ("val " & $i).toBytes
# for i in 0..<100:
# let
# key = Key.init(key, Key.init("/" & $i).tryGet).tryGet
# val = ("val " & $i).toBytes
(await ds.put(key, val)).tryGet
# echo "putting ", $key
# (await ds.put(key, val)).tryGet
let
iter = (await ds.query(q)).tryGet
res = block:
var
res: seq[QueryResponse]
cnt = 0
# let
# iter = (await ds.query(q)).tryGet
# res = block:
# var res: seq[QueryResponse]
# while not iter.finished:
# let (key, val) = (await iter.next()).tryGet
# if key.isNone:
# break
for pair in iter:
let (key, val) = (await pair).tryGet
if key.isNone:
break
# res.add((key, val))
res.add((key, val))
cnt.inc
# res
res
# check:
# res.len == 10
check:
res.len == 10
# (await iter.dispose()).tryGet
(await iter.dispose()).tryGet
# test "Should not apply offset":
# let
# key = Key.init("/a").tryGet
# q = Query.init(key, offset = 90)
test "Should not apply offset":
let
key = Key.init("/a").tryGet
q = Query.init(key, offset = 90)
# for i in 0..<100:
# let
# key = Key.init(key, Key.init("/" & $i).tryGet).tryGet
# val = ("val " & $i).toBytes
for i in 0..<100:
let
key = Key.init(key, Key.init("/" & $i).tryGet).tryGet
val = ("val " & $i).toBytes
# (await ds.put(key, val)).tryGet
(await ds.put(key, val)).tryGet
# let
# iter = (await ds.query(q)).tryGet
# res = block:
# var res: seq[QueryResponse]
# while not iter.finished:
# let (key, val) = (await iter.next()).tryGet
# if key.isNone:
# break
let
iter = (await ds.query(q)).tryGet
res = block:
var
res: seq[QueryResponse]
cnt = 0
# res.add((key, val))
for pair in iter:
let (key, val) = (await pair).tryGet
if key.isNone:
break
# res
res.add((key, val))
cnt.inc
# check:
# res.len == 10
res
# (await iter.dispose()).tryGet
check:
res.len == 10
# test "Should not apply offset and limit":
# let
# key = Key.init("/a").tryGet
# q = Query.init(key, offset = 95, limit = 5)
(await iter.dispose()).tryGet
# for i in 0..<100:
# let
# key = Key.init(key, Key.init("/" & $i).tryGet).tryGet
# val = ("val " & $i).toBytes
test "Should not apply offset and limit":
let
key = Key.init("/a").tryGet
q = Query.init(key, offset = 95, limit = 5)
# (await ds.put(key, val)).tryGet
for i in 0..<100:
let
key = Key.init(key, Key.init("/" & $i).tryGet).tryGet
val = ("val " & $i).toBytes
# let
# iter = (await ds.query(q)).tryGet
# res = block:
# var res: seq[QueryResponse]
# while not iter.finished:
# let (key, val) = (await iter.next()).tryGet
# if key.isNone:
# break
(await ds.put(key, val)).tryGet
# res.add((key, val))
let
iter = (await ds.query(q)).tryGet
res = block:
var
res: seq[QueryResponse]
cnt = 0
# res
for pair in iter:
let (key, val) = (await pair).tryGet
if key.isNone:
break
# check:
# res.len == 5
res.add((key, val))
cnt.inc
# for i in 0..<res.high:
# let
# val = ("val " & $(i + 95)).toBytes
# key = Key.init(key, Key.init("/" & $(i + 95)).tryGet).tryGet
res
# check:
# res[i].key.get == key
# res[i].data == val
check:
res.len == 5
# (await iter.dispose()).tryGet
for i in 0..<res.high:
let
val = ("val " & $(i + 95)).toBytes
key = Key.init(key, Key.init("/" & $(i + 95)).tryGet).tryGet
# test "Should apply sort order - descending":
# let
# key = Key.init("/a").tryGet
# q = Query.init(key, sort = SortOrder.Descending)
check:
res[i].key.get == key
res[i].data == val
# var kvs: seq[QueryResponse]
# for i in 0..<100:
# let
# k = Key.init(key, Key.init("/" & $i).tryGet).tryGet
# val = ("val " & $i).toBytes
(await iter.dispose()).tryGet
# kvs.add((k.some, val))
# (await ds.put(k, val)).tryGet
test "Should apply sort order - descending":
let
key = Key.init("/a").tryGet
q = Query.init(key, sort = SortOrder.Descending)
# # lexicographic sort, as it comes from the backend
# kvs.sort do (a, b: QueryResponse) -> int:
# cmp(a.key.get.id, b.key.get.id)
var kvs: seq[QueryResponse]
for i in 0..<100:
let
k = Key.init(key, Key.init("/" & $i).tryGet).tryGet
val = ("val " & $i).toBytes
# kvs = kvs.reversed
# let
# iter = (await ds.query(q)).tryGet
# res = block:
# var res: seq[QueryResponse]
# while not iter.finished:
# let (key, val) = (await iter.next()).tryGet
# if key.isNone:
# break
kvs.add((k.some, val))
(await ds.put(k, val)).tryGet
# res.add((key, val))
# lexicographic sort, as it comes from the backend
kvs.sort do (a, b: QueryResponse) -> int:
cmp(a.key.get.id, b.key.get.id)
# res
kvs = kvs.reversed
let
iter = (await ds.query(q)).tryGet
res = block:
var
res: seq[QueryResponse]
cnt = 0
# check:
# res.len == 100
for pair in iter:
let (key, val) = (await pair).tryGet
if key.isNone:
break
# for i, r in res[1..^1]:
# check:
# res[i].key.get == kvs[i].key.get
# res[i].data == kvs[i].data
res.add((key, val))
cnt.inc
res
check:
res.len == 100
for i, r in res[1..^1]:
check:
res[i].key.get == kvs[i].key.get
res[i].data == kvs[i].data
(await iter.dispose()).tryGet
# (await iter.dispose()).tryGet

View File

@ -0,0 +1,70 @@
import std/os
import std/osproc
import pkg/unittest2
import pkg/taskpools
import pkg/stew/ptrops
import pkg/datastore/threads/semaphore
suite "Test semaphore":
test "Should work as a mutex/lock":
var
tp = TaskPool.new(countProcessors() * 2)
lock = Semaphore.init(1) # mutex/lock
resource = 1
count = 0
const numTasks = 1000
proc task(lock: ptr Semaphore, resource: ptr int, count: ptr int) =
lock[].acquire()
resource[] -= 1
doAssert resource[] == 0, "resource should be 0, but it's: " & $resource[]
resource[] += 1
doAssert resource[] == 1, "resource should be 1, but it's: " & $resource[]
count[] += 1
lock[].release()
for i in 0..<numTasks:
tp.spawn task(addr lock, addr resource, addr count)
tp.syncAll()
tp.shutdown()
check: count == numTasks
test "Should not exceed limit":
var
tp = TaskPool.new(countProcessors() * 2)
lock = Semaphore.init(5)
resource = 5
count = 0
const numTasks = 1000
template testInRange(l, h, item: int) =
doAssert item in l..h, "item should be in range [" & $l & ", " & $h & "], but it's: " & $item
proc task(lock: ptr Semaphore, resource: ptr int, count: ptr int) =
lock[].acquire()
resource[] -= 1
testInRange(1, 5, resource[])
resource[] += 1
testInRange(1, 5, resource[])
count[] += 1
lock[].release()
for i in 0..<numTasks:
tp.spawn task(addr lock, addr resource, addr count)
tp.syncAll()
tp.shutdown()
check: count == numTasks

View File

@ -3,33 +3,36 @@ import std/sequtils
import std/os
import std/cpuinfo
import std/algorithm
import std/importutils
import pkg/asynctest
import pkg/chronos
import pkg/stew/results
import pkg/stew/byteutils
import pkg/taskpools
import pkg/questionable/results
import pkg/datastore/sql
import pkg/datastore/threads/threadproxyds
import pkg/datastore/fsds
import pkg/datastore/threads/threadproxyds {.all.}
import ./dscommontests
import ./querycommontests
suite "Test Basic ThreadDatastore":
suite "Test Basic ThreadDatastore with SQLite":
var
memStore: Datastore
sqlStore: Datastore
ds: ThreadDatastore
taskPool: Taskpool
key = Key.init("/a/b").tryGet()
bytes = "some bytes".toBytes
otherBytes = "some other bytes".toBytes
taskPool: Taskpool
setupAll:
memStore = SQLiteDatastore.new(Memory).tryGet()
sqlStore = SQLiteDatastore.new(Memory).tryGet()
taskPool = Taskpool.new(countProcessors() * 2)
ds = ThreadDatastore.new(memStore, taskPool).tryGet()
ds = ThreadDatastore.new(sqlStore, taskPool).tryGet()
teardownAll:
(await ds.close()).tryGet()
@ -37,19 +40,19 @@ suite "Test Basic ThreadDatastore":
basicStoreTests(ds, key, bytes, otherBytes)
suite "Test Query ThreadDatastore":
var
mem: Datastore
ds: ThreadDatastore
taskPool: Taskpool
# suite "Test Basic ThreadDatastore with fsds":
setup:
taskPool = Taskpool.new(countProcessors() * 2)
mem = SQLiteDatastore.new(Memory).tryGet()
ds = ThreadDatastore.new(mem, taskPool).tryGet()
# let
# path = currentSourcePath() # get this file's name
# basePath = "tests_data"
# basePathAbs = path.parentDir / basePath
# key = Key.init("/a/b").tryGet()
# bytes = "some bytes".toBytes
# otherBytes = "some other bytes".toBytes
teardown:
(await ds.close()).tryGet()
taskPool.shutdown()
# var
# fsStore: FSDatastore
# ds: ThreadDatastore
# taskPool: Taskpool
queryTests(ds, true)