diff --git a/datastore/leveldb/leveldbds.nim b/datastore/leveldb/leveldbds.nim index 6195710..d4f4d46 100644 --- a/datastore/leveldb/leveldbds.nim +++ b/datastore/leveldb/leveldbds.nim @@ -5,6 +5,7 @@ import std/tables import std/os import std/strformat import std/strutils +import std/sets import pkg/leveldbstatic import pkg/chronos @@ -19,6 +20,10 @@ type LevelDbDatastore* = ref object of Datastore db: LevelDb locks: TableRef[Key, AsyncLock] + openIterators: HashSet[QueryIter] + +proc hash(iter: QueryIter): Hash = + hash(addr iter) method has*(self: LevelDbDatastore, key: Key): Future[?!bool] {.async: (raises: [CancelledError]).} = try: @@ -70,6 +75,10 @@ method put*(self: LevelDbDatastore, batch: seq[BatchEntry]): Future[?!void] {.as method close*(self: LevelDbDatastore): Future[?!void] {.async: (raises: [CancelledError]).} = try: + for iter in self.openIterators: + if err =? (await iter.dispose()).errorOption: + return failure(err.msg) + self.openIterators.clear() self.db.close() return success() except LevelDbException as e: @@ -101,6 +110,8 @@ method query*( proc dispose(): Future[?!void] {.async: (raises: [CancelledError]).} = dbIter.dispose() iter.disposed = true + self.openIterators.excl(iter) + return success() proc next(): Future[?!QueryResponse] {.async: (raises: [CancelledError]).} = @@ -124,6 +135,8 @@ method query*( iter.next = next iter.dispose = dispose + self.openIterators.incl(iter) + return success iter method modifyGet*( diff --git a/tests/datastore/leveldb/testleveldbds.nim b/tests/datastore/leveldb/testleveldbds.nim index b07ad27..7434c7d 100644 --- a/tests/datastore/leveldb/testleveldbds.nim +++ b/tests/datastore/leveldb/testleveldbds.nim @@ -97,64 +97,64 @@ suite "LevelDB Query": (await ds.close()).tryGet removeDir(tempDir) - # test "should query by prefix": - # let - # q = Query.init(Key.init("/a/*").tryGet) - # iter = (await ds.query(q)).tryGet - # res = (await allFinished(toSeq(iter))) - # .mapIt( it.read.tryGet ) - # .filterIt( it.key.isSome ) + test "should query by prefix": + let + q = Query.init(Key.init("/a/*").tryGet) + iter = (await ds.query(q)).tryGet + res = (await allFinished(toSeq(iter))) + .mapIt( it.read.tryGet ) + .filterIt( it.key.isSome ) - # check: - # res.len == 3 - # res[0].key.get == key1 - # res[0].data == val1 + check: + res.len == 3 + res[0].key.get == key1 + res[0].data == val1 - # res[1].key.get == key2 - # res[1].data == val2 + res[1].key.get == key2 + res[1].data == val2 - # res[2].key.get == key3 - # res[2].data == val3 + res[2].key.get == key3 + res[2].data == val3 - # (await iter.dispose()).tryGet + (await iter.dispose()).tryGet - # test "should disregard forward trailing wildcards in keys": - # let - # q = Query.init(Key.init("/a/*").tryGet) - # iter = (await ds.query(q)).tryGet - # res = (await allFinished(toSeq(iter))) - # .mapIt( it.read.tryGet ) - # .filterIt( it.key.isSome ) + test "should disregard forward trailing wildcards in keys": + let + q = Query.init(Key.init("/a/*").tryGet) + iter = (await ds.query(q)).tryGet + res = (await allFinished(toSeq(iter))) + .mapIt( it.read.tryGet ) + .filterIt( it.key.isSome ) - # check: - # res.len == 3 - # res[0].key.get == key1 - # res[0].data == val1 + check: + res.len == 3 + res[0].key.get == key1 + res[0].data == val1 - # res[1].key.get == key2 - # res[1].data == val2 + res[1].key.get == key2 + res[1].data == val2 - # res[2].key.get == key3 - # res[2].data == val3 + res[2].key.get == key3 + res[2].data == val3 - # test "should disregard backward trailing wildcards in key": - # let - # q = Query.init(Key.init("/a\\*").tryGet) - # iter = (await ds.query(q)).tryGet - # res = (await allFinished(toSeq(iter))) - # .mapIt( it.read.tryGet ) - # .filterIt( it.key.isSome ) + test "should disregard backward trailing wildcards in key": + let + q = Query.init(Key.init("/a\\*").tryGet) + iter = (await ds.query(q)).tryGet + res = (await allFinished(toSeq(iter))) + .mapIt( it.read.tryGet ) + .filterIt( it.key.isSome ) - # check: - # res.len == 3 - # res[0].key.get == key1 - # res[0].data == val1 + check: + res.len == 3 + res[0].key.get == key1 + res[0].data == val1 - # res[1].key.get == key2 - # res[1].data == val2 + res[1].key.get == key2 + res[1].data == val2 - # res[2].key.get == key3 - # res[2].data == val3 + res[2].key.get == key3 + res[2].data == val3 test "should dispose automatically when iterator is finished": let @@ -174,3 +174,24 @@ suite "LevelDB Query": check iter.finished == true check iter.disposed == true + + test "should dispose automatically of iterators when datastore is closed": + let + q1 = Query.init(Key.init("/a/b/c").tryGet) + q2 = Query.init(Key.init("/a/b").tryGet) + i1 = (await ds.query(q1)).tryGet + i2 = (await ds.query(q2)).tryGet + + check i1.disposed == false + check i2.disposed == false + + (await ds.close()).tryGet + + check i1.disposed == true + check i2.disposed == true + + test "should have idempotent QueryIterator.dispose": + let q = Query.init(Key.init("/a/b/c").tryGet) + let iter = (await ds.query(q)).tryGet + (await iter.dispose()).tryGet + (await iter.dispose()).tryGet