diff --git a/stew/bitseqs.nim b/stew/bitseqs.nim index 6453fd6..fc1b086 100644 --- a/stew/bitseqs.nim +++ b/stew/bitseqs.nim @@ -1,9 +1,25 @@ import - bitops2 + bitops2, endians2, ranges/ptr_arith type Bytes = seq[byte] + BitSeq* = distinct Bytes + ## TODO + ## + ## The current design of BitSeq tries to follow precisely + ## the bitwise representation of the SSZ bitlists. + ## This is a relatively compact representation, but as + ## evident from the code below, many of the operations + ## are not trivial. + ## + ## An alternative simpler approach would be to maintain + ## the BitSeq as a sequence of words with an external uint8 + ## counter denoting the used bits in the last word. + ## + ## This will reduce the complexity of the code here, but + ## we'll have to define serialization routines for all the + ## formats where such values appear (SSZ, JSON, YAML, etc). BitArray*[bits: static int] = object bytes*: array[(bits + 7) div 8, byte] @@ -19,9 +35,6 @@ func len*(s: BitSeq): int = template len*(a: BitArray): int = a.bits -template bytes*(s: BitSeq): untyped = - Bytes(s) - func add*(s: var BitSeq, value: bool) = let lastBytePos = s.Bytes.len - 1 @@ -37,6 +50,122 @@ func add*(s: var BitSeq, value: bool) = s.Bytes[lastBytePos].setBit 7, value s.Bytes.add byte(1) +func loadLEBytes(WordType: type, bytes: openarray[byte]): WordType = + # TODO: this is a temporary proc until the endians API is improved + var shift = 0 + for b in bytes: + result = result or (WordType(b) shl shift) + shift += 8 + +func storeLEBytes(value: SomeUnsignedInt, dst: var openarray[byte]) = + when system.cpuEndian == bigEndian: + var shift = 0 + for i in 0 ..< dst.len: + result[i] = byte((v shr shift) and 0xff) + shift += 8 + else: + copyMem(addr dst[0], unsafeAddr value, dst.len) + +template loopOverWords(lhs, rhs: BitSeq, + lhsIsVar, rhsIsVar: static bool, + WordType: type, + lhsBits, rhsBits, body: untyped) = + const hasRhs = astToStr(lhs) != astToStr(rhs) + + let bytesCount = len Bytes(lhs) + when hasRhs: doAssert len(Bytes(rhs)) == bytesCount + + var fullWordsCount = bytesCount div sizeof(WordType) + let lastWordSize = bytesCount mod sizeof(WordType) + + block: + var lhsWord: WordType + when hasRhs: + var rhsWord: WordType + var firstByteOfLastWord, lastByteOfLastWord, markerPos: int + + # TODO: Returing a `var` value from an iterator is always safe due to + # the way inlining works, but currently the compiler reports an error + # when a local variable escapes. We have to cheat it with this location + # obfuscation through pointers: + template lhsBits: auto = (addr(lhsWord))[] + + when hasRhs: + template rhsBits: auto = (addr(rhsWord))[] + + template lastWordBytes(bitseq): auto = + Bytes(bitseq).toOpenArray(firstByteOfLastWord, lastByteOfLastWord) + + template initBitsVars = + lhsWord = loadLEBytes(WordType, lastWordBytes(lhs)) + when hasRhs: rhsWord = loadLEBytes(WordType, lastWordBytes(rhs)) + + if lastWordSize == 0: + firstByteOfLastWord = bytesCount - sizeof(WordType) + lastByteOfLastWord = bytesCount - 1 + initBitsVars() + markerPos = sizeof(WordType) * 8 - 1 + dec fullWordsCount + else: + firstByteOfLastWord = bytesCount - lastWordSize + lastByteOfLastWord = bytesCount - 1 + initBitsVars() + markerPos = log2trunc(lhsWord) + when hasRhs: doAssert log2trunc(rhsWord) == markerPos + + lhsWord.lowerBit markerPos + when hasRhs: rhsWord.lowerBit markerPos + + body + + when lhsIsVar or rhsIsVar: + let + markerBit = uint(1 shl markerPos) + mask = markerBit - 1'u + + when lhsIsVar: + let lhsEndResult = (lhsWord and mask) or markerBit + storeLEBytes(lhsEndResult, lastWordBytes(lhs)) + + when rhsIsVar: + let rhsEndResult = (rhsWord and mask) or markerBit + storeLEBytes(rhsEndResult, lastWordBytes(rhs)) + + var lhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(lhs)[0]) + let lhsEndAddr = shift(lhsCurrAddr, fullWordsCount) + when hasRhs: + var rhsCurrAddr = cast[ptr WordType](unsafeAddr Bytes(rhs)[0]) + + while lhsCurrAddr < lhsEndAddr: + template lhsBits: auto = lhsCurrAddr[] + when hasRhs: + template rhsBits: auto = rhsCurrAddr[] + + body + + lhsCurrAddr = shift(lhsCurrAddr, 1) + when hasRhs: rhsCurrAddr = shift(rhsCurrAddr, 1) + +iterator words*(x: var BitSeq): var uint = + loopOverWords(x, x, true, false, uint, word, wordB): + yield word + +iterator words*(x: BitSeq): uint = + loopOverWords(x, x, false, false, uint, word, word): + yield word + +iterator words*(a, b: BitSeq): (uint, uint) = + loopOverWords(a, b, false, false, uint, wordA, wordB): + yield (wordA, wordB) + +iterator words*(a: var BitSeq, b: BitSeq): (var uint, uint) = + loopOverWords(a, b, true, false, uint, wordA, wordB): + yield (wordA, wordB) + +iterator words*(a, b: var BitSeq): (var uint, var uint) = + loopOverWords(a, b, true, true, uint, wordA, wordB): + yield (wordA, wordB) + func `[]`*(s: BitSeq, pos: Natural): bool {.inline.} = doAssert pos < s.len s.Bytes.getBit pos @@ -95,18 +224,17 @@ func `$`*(a: BitSeq): string = let length = a.len result = newStringOfCap(2 + length) result.add "0b" - for i in 0 ..< length: + for i in countdown(length - 1, 0): result.add if a[i]: '1' else: '0' func combine*(tgt: var BitSeq, src: BitSeq) = doAssert tgt.len == src.len - for i in 0 ..< tgt.bytes.len: - tgt.bytes[i] = tgt.bytes[i] or src.bytes[i] + for tgtWord, srcWord in words(tgt, src): + tgtWord = tgtWord or srcWord func overlaps*(a, b: BitSeq): bool = - doAssert a.len == b.len - for i in 0..< a.bytes.len: - if (a.bytes[i] and b.bytes[i]) > 0'u8: + for wa, wb in words(a, b): + if (wa and wb) != 0: return true func isSubsetOf*(a, b: BitSeq): bool = diff --git a/stew/endians2.nim b/stew/endians2.nim index a12f702..a18d61b 100644 --- a/stew/endians2.nim +++ b/stew/endians2.nim @@ -126,6 +126,12 @@ func fromBytes*( ## Read bytes and convert to an integer according to the given endianess. At ## runtime, v must contain at least sizeof(T) bytes. By default, native ## endianess is used which is not portable! + ## + ## REVIEW COMMENT (zah) + ## This API is very strange. Why can't I pass an open array of 3 bytes + ## to be interpreted as a LE number? Also, why is `endian` left as a + ## run-time parameter (with such short functions, it could easily be static). + const ts = sizeof(T) # Nim bug: can't use sizeof directly var tmp: array[ts, byte] for i in 0..= (0,19,9): template makeOpenArray*[T](p: ptr T, len: int): auto = toOpenArray(cast[ptr UncheckedArray[T]](p), 0, len - 1) diff --git a/tests/test_bitseqs.nim b/tests/test_bitseqs.nim index 3f18d25..10f86f6 100644 --- a/tests/test_bitseqs.nim +++ b/tests/test_bitseqs.nim @@ -1,6 +1,6 @@ import - unittest, - ../stew/bitseqs + unittest, strformat, + ../stew/[bitseqs, bitops2] suite "Bit fields": test "roundtrips": @@ -25,3 +25,34 @@ suite "Bit fields": not a[0] a[1] a[2] + + test "iterating words": + for bitCount in [8, 3, 7, 8, 14, 15, 16, 19, 260]: + checkpoint &"trying bit count {bitCount}" + var + a = BitSeq.init(bitCount) + b = BitSeq.init(bitCount) + bitsInWord = sizeof(uint) * 8 + expectedWordCount = (bitCount div bitsInWord) + 1 + + for i in 0 ..< expectedWordCount: + let every3rdBit = i * sizeof(uint) * 8 + 2 + a[every3rdBit] = true + b[every3rdBit] = true + + for word in words(a): + check word == 4 + word = 2 + + for wa, wb in words(a, b): + check wa == 2 and wb == 4 + wa = 1 + wb = 2 + + for i in 0 ..< expectedWordCount: + for j in 0 ..< bitsInWord: + let bitPos = i * bitsInWord + j + if bitPos < bitCount: + check a[j] == (j == 0) + check b[j] == (j == 1) +