Hide the complexity of dealing with the BitSeq marker bit inside an efficient machine words iterator
This commit is contained in:
parent
a3df04d701
commit
19fd0cc489
148
stew/bitseqs.nim
148
stew/bitseqs.nim
|
@ -1,9 +1,25 @@
|
|||
import
|
||||
bitops2
|
||||
bitops2, endians2, ranges/ptr_arith
|
||||
|
||||
type
|
||||
Bytes = seq[byte]
|
||||
|
||||
BitSeq* = distinct Bytes
|
||||
## TODO
|
||||
##
|
||||
## The current design of BitSeq tries to follow precisely
|
||||
## the bitwise representation of the SSZ bitlists.
|
||||
## This is a relatively compact representation, but as
|
||||
## evident from the code below, many of the operations
|
||||
## are not trivial.
|
||||
##
|
||||
## An alternative simpler approach would be to maintain
|
||||
## the BitSeq as a sequence of words with an external uint8
|
||||
## counter denoting the used bits in the last word.
|
||||
##
|
||||
## This will reduce the complexity of the code here, but
|
||||
## we'll have to define serialization routines for all the
|
||||
## formats where such values appear (SSZ, JSON, YAML, etc).
|
||||
|
||||
BitArray*[bits: static int] = object
|
||||
bytes*: array[(bits + 7) div 8, byte]
|
||||
|
@ -19,9 +35,6 @@ func len*(s: BitSeq): int =
|
|||
template len*(a: BitArray): int =
  ## Number of bits held by the array, i.e. the static `bits` parameter
  ## the `BitArray` type was instantiated with.
  a.bits
|
||||
|
||||
template bytes*(s: BitSeq): untyped =
  ## Views `s` as its underlying `Bytes` sequence by converting the
  ## distinct `BitSeq` type back to its base type.
  Bytes(s)
|
||||
|
||||
func add*(s: var BitSeq, value: bool) =
|
||||
let
|
||||
lastBytePos = s.Bytes.len - 1
|
||||
|
@ -37,6 +50,122 @@ func add*(s: var BitSeq, value: bool) =
|
|||
s.Bytes[lastBytePos].setBit 7, value
|
||||
s.Bytes.add byte(1)
|
||||
|
||||
func loadLEBytes(WordType: type, bytes: openarray[byte]): WordType =
  ## Assembles `bytes` into a `WordType` value, treating `bytes[0]` as the
  ## least significant byte (little-endian order). Fewer bytes than
  ## `sizeof(WordType)` are allowed; the missing high bytes are zero.
  ## An empty input yields zero.
  # TODO: this is a temporary proc until the endians API is improved

  # More bytes than the word can hold would push the shift amount to or past
  # the word width, which silently produces a wrong value.
  doAssert bytes.len <= sizeof(WordType)

  var shift = 0
  for b in bytes:
    result = result or (WordType(b) shl shift)
    shift += 8
|
||||
|
||||
func storeLEBytes(value: SomeUnsignedInt, dst: var openarray[byte]) =
  ## Writes the `dst.len` least significant bytes of `value` into `dst`
  ## in little-endian order (`dst[0]` receives the lowest byte).
  ## `dst` may be shorter than `value`, but never longer.

  # Copying more bytes than `value` occupies would read past it.
  doAssert dst.len <= sizeof(value)

  when system.cpuEndian == bigEndian:
    # The in-memory layout is reversed on big-endian hosts, so extract the
    # bytes arithmetically instead of copying memory.
    var shift = 0
    for i in 0 ..< dst.len:
      dst[i] = byte((value shr shift) and 0xff)
      shift += 8
  else:
    # On little-endian hosts the low bytes of `value` are already laid out
    # in the desired order. Skip empty destinations: `dst[0]` would be out
    # of bounds.
    if dst.len > 0:
      copyMem(addr dst[0], unsafeAddr value, dst.len)
|
||||
|
||||
template loopOverWords(lhs, rhs: BitSeq,
                       lhsIsVar, rhsIsVar: static bool,
                       WordType: type,
                       lhsBits, rhsBits, body: untyped) =
  ## Expands `body` once per `WordType`-sized chunk of `lhs`, exposing the
  ## current chunk under the name passed as `lhsBits`. When `rhs` is a
  ## different expression than `lhs`, it is walked in lockstep and its chunk
  ## is exposed as `rhsBits`; both sequences must then have the same byte
  ## length and the same marker position.
  ##
  ## The last word (the one holding the marker bit) is visited FIRST, through
  ## a local copy that has the marker bit cleared. If `lhsIsVar`/`rhsIsVar`
  ## is true, that copy is written back afterwards with the marker bit
  ## restored and any bits above it cleared. The remaining full words are
  ## then visited in place, from the start of the byte sequence.
  const hasRhs = astToStr(lhs) != astToStr(rhs)

  let bytesCount = len Bytes(lhs)
  when hasRhs: doAssert len(Bytes(rhs)) == bytesCount

  var fullWordsCount = bytesCount div sizeof(WordType)
  let lastWordSize = bytesCount mod sizeof(WordType)

  block:
    var lhsWord: WordType
    when hasRhs:
      var rhsWord: WordType
    var firstByteOfLastWord, lastByteOfLastWord, markerPos: int

    # TODO: Returning a `var` value from an iterator is always safe due to
    # the way inlining works, but currently the compiler reports an error
    # when a local variable escapes. We have to cheat it with this location
    # obfuscation through pointers:
    template lhsBits: auto = (addr(lhsWord))[]

    when hasRhs:
      template rhsBits: auto = (addr(rhsWord))[]

    template lastWordBytes(bitseq): auto =
      Bytes(bitseq).toOpenArray(firstByteOfLastWord, lastByteOfLastWord)

    template initBitsVars =
      lhsWord = loadLEBytes(WordType, lastWordBytes(lhs))
      when hasRhs: rhsWord = loadLEBytes(WordType, lastWordBytes(rhs))

    if lastWordSize == 0:
      # The byte count is a multiple of the word size: the final full word
      # plays the role of the last word, so exclude it from the in-place
      # loop below.
      firstByteOfLastWord = bytesCount - sizeof(WordType)
      dec fullWordsCount
    else:
      firstByteOfLastWord = bytesCount - lastWordSize
    lastByteOfLastWord = bytesCount - 1
    initBitsVars()

    # The marker is the highest set bit of the last word. It is NOT
    # necessarily the top bit of the word even when the last word is a full
    # one (it can sit anywhere in the final byte), so always locate it.
    markerPos = log2trunc(lhsWord)
    when hasRhs: doAssert log2trunc(rhsWord) == markerPos

    lhsWord.lowerBit markerPos
    when hasRhs: rhsWord.lowerBit markerPos

    body

    when lhsIsVar or rhsIsVar:
      let
        # Shift in the word type itself; shifting a signed `1` by the top
        # bit position would overflow before the conversion.
        markerBit = WordType(1) shl markerPos
        mask = markerBit - WordType(1)

      when lhsIsVar:
        let lhsEndResult = (lhsWord and mask) or markerBit
        storeLEBytes(lhsEndResult, lastWordBytes(lhs))

      when rhsIsVar:
        let rhsEndResult = (rhsWord and mask) or markerBit
        storeLEBytes(rhsEndResult, lastWordBytes(rhs))

  var lhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(lhs)[0])
  let lhsEndAddr = shift(lhsCurrAddr, fullWordsCount)
  when hasRhs:
    var rhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(rhs)[0])

  while lhsCurrAddr < lhsEndAddr:
    # Full words are exposed directly through the pointers, so writes made
    # by `body` land in the sequence without a copy-back step.
    template lhsBits: auto = lhsCurrAddr[]
    when hasRhs:
      template rhsBits: auto = rhsCurrAddr[]

    body

    lhsCurrAddr = shift(lhsCurrAddr, 1)
    when hasRhs: rhsCurrAddr = shift(rhsCurrAddr, 1)
|
||||
|
||||
iterator words*(x: var BitSeq): var uint =
  ## Yields every machine word of `x` as a mutable location, driven by
  ## `loopOverWords` with write-back enabled for `x`.
  loopOverWords(x, x, true, false, uint, bits, bits):
    yield bits
|
||||
|
||||
iterator words*(x: BitSeq): uint =
  ## Yields every machine word of `x` by value, driven by `loopOverWords`
  ## with no write-back.
  loopOverWords(x, x, false, false, uint, bits, bits):
    yield bits
|
||||
|
||||
iterator words*(a, b: BitSeq): (uint, uint) =
  ## Walks `a` and `b` in lockstep and yields their corresponding machine
  ## words as a pair of values; neither sequence is written to.
  loopOverWords(a, b, false, false, uint, bitsA, bitsB):
    yield (bitsA, bitsB)
|
||||
|
||||
iterator words*(a: var BitSeq, b: BitSeq): (var uint, uint) =
  ## Walks `a` and `b` in lockstep; the word of `a` is yielded as a mutable
  ## location (write-back enabled), the word of `b` by value.
  loopOverWords(a, b, true, false, uint, bitsA, bitsB):
    yield (bitsA, bitsB)
|
||||
|
||||
iterator words*(a, b: var BitSeq): (var uint, var uint) =
  ## Walks `a` and `b` in lockstep and yields both corresponding machine
  ## words as mutable locations (write-back enabled for both sequences).
  loopOverWords(a, b, true, true, uint, bitsA, bitsB):
    yield (bitsA, bitsB)
|
||||
|
||||
func `[]`*(s: BitSeq, pos: Natural): bool {.inline.} =
  ## Reads the bit at index `pos`. Asserts that `pos` is below `s.len`.
  doAssert pos < s.len
  getBit(Bytes(s), pos)
|
||||
|
@ -95,18 +224,17 @@ func `$`*(a: BitSeq): string =
|
|||
let length = a.len
|
||||
result = newStringOfCap(2 + length)
|
||||
result.add "0b"
|
||||
for i in 0 ..< length:
|
||||
for i in countdown(length - 1, 0):
|
||||
result.add if a[i]: '1' else: '0'
|
||||
|
||||
func combine*(tgt: var BitSeq, src: BitSeq) =
  ## ORs every bit of `src` into `tgt` in place. Both sequences must have
  ## the same length.
  doAssert tgt.len == src.len
  # Work a machine word at a time through the `words` iterator; the older
  # byte-by-byte loop did the same OR and is redundant.
  for tgtWord, srcWord in words(tgt, src):
    tgtWord = tgtWord or srcWord
|
||||
|
||||
func overlaps*(a, b: BitSeq): bool =
  ## Returns true when at least one bit position is set in both `a` and `b`.
  ## Both sequences must have the same length.
  doAssert a.len == b.len
  # Compare through `words` rather than the raw bytes: the raw last byte of
  # both sequences contains the marker bit, so a byte-wise AND would report
  # an overlap for any two sequences of equal length.
  for wa, wb in words(a, b):
    if (wa and wb) != 0:
      return true
|
||||
|
||||
func isSubsetOf*(a, b: BitSeq): bool =
|
||||
|
|
|
@ -126,6 +126,12 @@ func fromBytes*(
|
|||
## Read bytes and convert to an integer according to the given endianness. At
|
||||
## runtime, v must contain at least sizeof(T) bytes. By default, native
|
||||
## endianness is used which is not portable!
|
||||
##
|
||||
## REVIEW COMMENT (zah)
|
||||
## This API is very strange. Why can't I pass an open array of 3 bytes
|
||||
## to be interpreted as a LE number? Also, why is `endian` left as a
|
||||
## run-time parameter (with such short functions, it could easily be static).
|
||||
|
||||
const ts = sizeof(T) # Nim bug: can't use sizeof directly
|
||||
var tmp: array[ts, byte]
|
||||
for i in 0..<tmp.len: # Loop since vm can't copymem
|
||||
|
|
|
@ -14,6 +14,9 @@ template distance*(a, b: pointer): int =
|
|||
template shift*[T](p: ptr T, delta: int): ptr T =
  ## Advances the typed pointer `p` by `delta` elements, i.e. by
  ## `delta * sizeof(T)` bytes, delegating to the untyped-pointer `shift`.
  cast[ptr T](shift(cast[pointer](p), delta * sizeof(T)))
|
||||
|
||||
proc `<`*(a, b: pointer): bool =
  ## Orders two raw pointers by their numeric address value.
  let
    addrA = cast[uint](a)
    addrB = cast[uint](b)
  addrA < addrB
|
||||
|
||||
when (NimMajor,NimMinor,NimPatch) >= (0,19,9):
|
||||
template makeOpenArray*[T](p: ptr T, len: int): auto =
|
||||
toOpenArray(cast[ptr UncheckedArray[T]](p), 0, len - 1)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import
|
||||
unittest,
|
||||
../stew/bitseqs
|
||||
unittest, strformat,
|
||||
../stew/[bitseqs, bitops2]
|
||||
|
||||
suite "Bit fields":
|
||||
test "roundtrips":
|
||||
|
@ -25,3 +25,34 @@ suite "Bit fields":
|
|||
not a[0]
|
||||
a[1]
|
||||
a[2]
|
||||
|
||||
test "iterating words":
  # Exercises the `words` iterators on sequences whose last word is
  # partial, checking both the values yielded and the write-back.
  for bitCount in [8, 3, 7, 8, 14, 15, 16, 19, 260]:
    checkpoint &"trying bit count {bitCount}"
    var
      a = BitSeq.init(bitCount)
      b = BitSeq.init(bitCount)
    let
      bitsInWord = sizeof(uint) * 8
      expectedWordCount = (bitCount div bitsInWord) + 1

    # Set bit 2 of every word so that each yielded word equals 4.
    for i in 0 ..< expectedWordCount:
      let thirdBitOfWord = i * bitsInWord + 2
      a[thirdBitOfWord] = true
      b[thirdBitOfWord] = true

    for word in words(a):
      check word == 4
      word = 2

    for wa, wb in words(a, b):
      check wa == 2 and wb == 4
      wa = 1
      wb = 2

    # Every word of `a` must now be 1 (only bit 0 set) and every word of
    # `b` must be 2 (only bit 1 set). Index by the absolute bit position so
    # that all words are verified, not just the first one.
    for i in 0 ..< expectedWordCount:
      for j in 0 ..< bitsInWord:
        let bitPos = i * bitsInWord + j
        if bitPos < bitCount:
          check a[bitPos] == (j == 0)
          check b[bitPos] == (j == 1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue