From 8a9816ac02e6026e0fbeffaea474565cba46fc6a Mon Sep 17 00:00:00 2001 From: Jordan Hrycaj Date: Mon, 18 Jul 2022 18:56:31 +0100 Subject: [PATCH] Fix iterator edge cases for [high(P),high(P)] (#129) also: cascaded `if` in rbtree (for unrelated troubleshooting) --- stew/interval_set.nim | 32 ++++++++++++++++++------------- stew/sorted_set/rbtree_delete.nim | 5 +++-- stew/sorted_set/rbtree_find.nim | 5 +++-- tests/test_interval_set.nim | 18 +++++++++++++++++ 4 files changed, 43 insertions(+), 17 deletions(-) diff --git a/stew/interval_set.nim b/stew/interval_set.nim index 1b7325e..61eb3be 100644 --- a/stew/interval_set.nim +++ b/stew/interval_set.nim @@ -849,13 +849,17 @@ iterator increasing*[P,S]( ## any interval already visited. Intervals not visited yet must not be ## deleted as the loop would become unpredictable. var rc = ds.leftPos.ge(minPt) - while rc.isOk: - let key = rc.value.key - if high(P) <= rc.value.right and ds.lastHigh: - yield Interval[P,S].new(rc.value.left,high(P)) - else: - yield Interval[P,S].new(rc.value) - rc = ds.leftPos.gt(key) + if rc.isErr: + if ds.lastHigh: + yield Interval[P,S].new(high(P),high(P)) + else: + while rc.isOk: + let key = rc.value.key + if high(P) <= rc.value.right and ds.lastHigh: + yield Interval[P,S].new(rc.value.left,high(P)) + else: + yield Interval[P,S].new(rc.value) + rc = ds.leftPos.gt(key) iterator decreasing*[P,S]( ds: IntervalSetRef[P,S]; @@ -866,8 +870,10 @@ iterator decreasing*[P,S]( ## ## See the note at the `increasing()` function comment about deleting items. var rc = ds.leftPos.le(maxPt) - - if rc.isOk: + if rc.isErr: + if ds.lastHigh: + yield Interval[P,S].new(high(P),high(P)) + else: let key = rc.value.key # last entry: check for additional point if high(P) <= rc.value.right and ds.lastHigh: @@ -877,10 +883,10 @@ iterator decreasing*[P,S]( # find the next smaller one rc = ds.leftPos.lt(key) - while rc.isOk: - let key = rc.value.key - yield Interval[P,S].new(rc.value) - rc = ds.leftPos.lt(key) + while rc.isOk: + let key = rc.value.key + yield Interval[P,S].new(rc.value) + rc = ds.leftPos.lt(key) # ------------------------------------------------------------------------------ # Public interval operators diff --git a/stew/sorted_set/rbtree_delete.nim b/stew/sorted_set/rbtree_delete.nim index c602bc9..083cabe 100644 --- a/stew/sorted_set/rbtree_delete.nim +++ b/stew/sorted_set/rbtree_delete.nim @@ -119,8 +119,9 @@ proc rbTreeDelete*[C,K](rbt: RbTreeRef[C,K]; key: K): RbResult[C] = dirY = q.linkLeft.isNil.toDir parent.link[dirX] = q.link[dirY]; # clear node cache if this was the one to be deleted - if not rbt.cache.isNil and rbt.cmp(rbt.cache.casket,key) == 0: - rbt.cache = nil + if not rbt.cache.isNil: + if rbt.cmp(rbt.cache.casket,key) == 0: + rbt.cache = nil q = nil # some hint for the GC to recycle that node rbt.size.dec diff --git a/stew/sorted_set/rbtree_find.nim b/stew/sorted_set/rbtree_find.nim index f87a393..e2c5e3a 100644 --- a/stew/sorted_set/rbtree_find.nim +++ b/stew/sorted_set/rbtree_find.nim @@ -31,8 +31,9 @@ proc rbTreeFindEq*[C,K](rbt: RbTreeRef[C,K]; key: K): RbResult[C] = if rbt.root.isNil: return err(rbEmptyTree) - if not rbt.cache.isNil and rbt.cmp(rbt.cache.casket,key) == 0: - return ok(rbt.cache.casket) + if not rbt.cache.isNil: + if rbt.cmp(rbt.cache.casket,key) == 0: + return ok(rbt.cache.casket) var q = rbt.root diff --git a/tests/test_interval_set.nim b/tests/test_interval_set.nim index af292ee..5c34478 100644 --- a/tests/test_interval_set.nim +++ b/tests/test_interval_set.nim @@ -177,6 +177,24 @@ suite "IntervalSet: Intervals of FancyPoint entries over FancyScalar": check br.total.truncate(uint64) == (uHigh - 10000000) + 1 check br.verify.isOk + br.clear() + check br.total == 0 and br.chunks == 0 + check br.merge(uHigh,uHigh) == 1 + + block: + var (ivVal, ivSet) = (iv(0,0), false) + for iv in br.increasing: + check ivSet == false + (ivVal, ivSet) = (iv, true) + check ivVal == iv(uHigh,uHigh) + + block: + var (ivVal, ivSet) = (iv(0,0), false) + for iv in br.decreasing: + check ivSet == false + (ivVal, ivSet) = (iv, true) + check ivVal == iv(uHigh,uHigh) + test "Merge disjunct intervals on 1st set": br.clear() check br.merge( 0, 99) == 100