Hide the complexity of dealing with the BitSeq marker bit inside an efficient machine words iterator

This commit is contained in:
Zahary Karadjov 2019-08-06 21:02:03 +03:00
parent a3df04d701
commit 19fd0cc489
No known key found for this signature in database
GPG Key ID: C8936F8A3073D609
4 changed files with 180 additions and 12 deletions

View File

@ -1,9 +1,25 @@
import
bitops2
bitops2, endians2, ranges/ptr_arith
type
Bytes = seq[byte]
BitSeq* = distinct Bytes
## TODO
##
## The current design of BitSeq tries to follow precisely
## the bitwise representation of the SSZ bitlists.
## This is a relatively compact representation, but as
## evident from the code below, many of the operations
## are not trivial.
##
## An alternative simpler approach would be to maintain
## the BitSeq as a sequence of words with an external uint8
## counter denoting the used bits in the last word.
##
## This will reduce the complexity of the code here, but
## we'll have to define serialization routines for all the
## formats where such values appear (SSZ, JSON, YAML, etc).
BitArray*[bits: static int] = object
bytes*: array[(bits + 7) div 8, byte]
@ -19,9 +35,6 @@ func len*(s: BitSeq): int =
template len*(a: BitArray): int =
a.bits
template bytes*(s: BitSeq): untyped =
Bytes(s)
func add*(s: var BitSeq, value: bool) =
let
lastBytePos = s.Bytes.len - 1
@ -37,6 +50,122 @@ func add*(s: var BitSeq, value: bool) =
s.Bytes[lastBytePos].setBit 7, value
s.Bytes.add byte(1)
func loadLEBytes(WordType: type, bytes: openarray[byte]): WordType =
# TODO: this is a temporary proc until the endians API is improved
var shift = 0
for b in bytes:
result = result or (WordType(b) shl shift)
shift += 8
func storeLEBytes(value: SomeUnsignedInt, dst: var openarray[byte]) =
when system.cpuEndian == bigEndian:
var shift = 0
for i in 0 ..< dst.len:
result[i] = byte((v shr shift) and 0xff)
shift += 8
else:
copyMem(addr dst[0], unsafeAddr value, dst.len)
template loopOverWords(lhs, rhs: BitSeq,
lhsIsVar, rhsIsVar: static bool,
WordType: type,
lhsBits, rhsBits, body: untyped) =
const hasRhs = astToStr(lhs) != astToStr(rhs)
let bytesCount = len Bytes(lhs)
when hasRhs: doAssert len(Bytes(rhs)) == bytesCount
var fullWordsCount = bytesCount div sizeof(WordType)
let lastWordSize = bytesCount mod sizeof(WordType)
block:
var lhsWord: WordType
when hasRhs:
var rhsWord: WordType
var firstByteOfLastWord, lastByteOfLastWord, markerPos: int
# TODO: Returing a `var` value from an iterator is always safe due to
# the way inlining works, but currently the compiler reports an error
# when a local variable escapes. We have to cheat it with this location
# obfuscation through pointers:
template lhsBits: auto = (addr(lhsWord))[]
when hasRhs:
template rhsBits: auto = (addr(rhsWord))[]
template lastWordBytes(bitseq): auto =
Bytes(bitseq).toOpenArray(firstByteOfLastWord, lastByteOfLastWord)
template initBitsVars =
lhsWord = loadLEBytes(WordType, lastWordBytes(lhs))
when hasRhs: rhsWord = loadLEBytes(WordType, lastWordBytes(rhs))
if lastWordSize == 0:
firstByteOfLastWord = bytesCount - sizeof(WordType)
lastByteOfLastWord = bytesCount - 1
initBitsVars()
markerPos = sizeof(WordType) * 8 - 1
dec fullWordsCount
else:
firstByteOfLastWord = bytesCount - lastWordSize
lastByteOfLastWord = bytesCount - 1
initBitsVars()
markerPos = log2trunc(lhsWord)
when hasRhs: doAssert log2trunc(rhsWord) == markerPos
lhsWord.lowerBit markerPos
when hasRhs: rhsWord.lowerBit markerPos
body
when lhsIsVar or rhsIsVar:
let
markerBit = uint(1 shl markerPos)
mask = markerBit - 1'u
when lhsIsVar:
let lhsEndResult = (lhsWord and mask) or markerBit
storeLEBytes(lhsEndResult, lastWordBytes(lhs))
when rhsIsVar:
let rhsEndResult = (rhsWord and mask) or markerBit
storeLEBytes(rhsEndResult, lastWordBytes(rhs))
var lhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(lhs)[0])
let lhsEndAddr = shift(lhsCurrAddr, fullWordsCount)
when hasRhs:
var rhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(rhs)[0])
while lhsCurrAddr < lhsEndAddr:
template lhsBits: auto = lhsCurrAddr[]
when hasRhs:
template rhsBits: auto = rhsCurrAddr[]
body
lhsCurrAddr = shift(lhsCurrAddr, 1)
when hasRhs: rhsCurrAddr = shift(rhsCurrAddr, 1)
iterator words*(x: var BitSeq): var uint =
loopOverWords(x, x, true, false, uint, word, wordB):
yield word
iterator words*(x: BitSeq): uint =
loopOverWords(x, x, false, false, uint, word, word):
yield word
iterator words*(a, b: BitSeq): (uint, uint) =
loopOverWords(a, b, false, false, uint, wordA, wordB):
yield (wordA, wordB)
iterator words*(a: var BitSeq, b: BitSeq): (var uint, uint) =
loopOverWords(a, b, true, false, uint, wordA, wordB):
yield (wordA, wordB)
iterator words*(a, b: var BitSeq): (var uint, var uint) =
loopOverWords(a, b, true, true, uint, wordA, wordB):
yield (wordA, wordB)
func `[]`*(s: BitSeq, pos: Natural): bool {.inline.} =
doAssert pos < s.len
s.Bytes.getBit pos
@ -95,18 +224,17 @@ func `$`*(a: BitSeq): string =
let length = a.len
result = newStringOfCap(2 + length)
result.add "0b"
for i in 0 ..< length:
for i in countdown(length - 1, 0):
result.add if a[i]: '1' else: '0'
func combine*(tgt: var BitSeq, src: BitSeq) =
doAssert tgt.len == src.len
for i in 0 ..< tgt.bytes.len:
tgt.bytes[i] = tgt.bytes[i] or src.bytes[i]
for tgtWord, srcWord in words(tgt, src):
tgtWord = tgtWord or srcWord
func overlaps*(a, b: BitSeq): bool =
doAssert a.len == b.len
for i in 0..< a.bytes.len:
if (a.bytes[i] and b.bytes[i]) > 0'u8:
for wa, wb in words(a, b):
if (wa and wb) != 0:
return true
func isSubsetOf*(a, b: BitSeq): bool =

View File

@ -126,6 +126,12 @@ func fromBytes*(
## Read bytes and convert to an integer according to the given endianess. At
## runtime, v must contain at least sizeof(T) bytes. By default, native
## endianess is used which is not portable!
##
## REVIEW COMMENT (zah)
## This API is very strange. Why can't I pass an open array of 3 bytes
## to be interpreted as a LE number? Also, why is `endian` left as a
## run-time parameter (with such short functions, it could easily be static).
const ts = sizeof(T) # Nim bug: can't use sizeof directly
var tmp: array[ts, byte]
for i in 0..<tmp.len: # Loop since vm can't copymem

View File

@ -14,6 +14,9 @@ template distance*(a, b: pointer): int =
template shift*[T](p: ptr T, delta: int): ptr T =
cast[ptr T](shift(cast[pointer](p), delta * sizeof(T)))
proc `<`*(a, b: pointer): bool =
cast[uint](a) < cast[uint](b)
when (NimMajor,NimMinor,NimPatch) >= (0,19,9):
template makeOpenArray*[T](p: ptr T, len: int): auto =
toOpenArray(cast[ptr UncheckedArray[T]](p), 0, len - 1)

View File

@ -1,6 +1,6 @@
import
unittest,
../stew/bitseqs
unittest, strformat,
../stew/[bitseqs, bitops2]
suite "Bit fields":
test "roundtrips":
@ -25,3 +25,34 @@ suite "Bit fields":
not a[0]
a[1]
a[2]
test "iterating words":
for bitCount in [8, 3, 7, 8, 14, 15, 16, 19, 260]:
checkpoint &"trying bit count {bitCount}"
var
a = BitSeq.init(bitCount)
b = BitSeq.init(bitCount)
bitsInWord = sizeof(uint) * 8
expectedWordCount = (bitCount div bitsInWord) + 1
for i in 0 ..< expectedWordCount:
let every3rdBit = i * sizeof(uint) * 8 + 2
a[every3rdBit] = true
b[every3rdBit] = true
for word in words(a):
check word == 4
word = 2
for wa, wb in words(a, b):
check wa == 2 and wb == 4
wa = 1
wb = 2
for i in 0 ..< expectedWordCount:
for j in 0 ..< bitsInWord:
let bitPos = i * bitsInWord + j
if bitPos < bitCount:
check a[j] == (j == 0)
check b[j] == (j == 1)