From 66ad5497d9b6f341fce0804abaec0cf03b75f497 Mon Sep 17 00:00:00 2001 From: Jacek Sieka Date: Mon, 9 Dec 2024 08:15:04 +0100 Subject: [PATCH] Unroll nibble ops (#2894) A bit unexpectedly, nibble handling shows up in the profiler mainly because the current impl is tuned towards slicing while the most common operation is prefix comparison - since the code is simple, might as well get rid of some of the excess fat by always aligning the nibbles to the byte buffer. --- nimbus/db/aristo/aristo_desc/desc_nibbles.nim | 331 ++++++++++++------ nimbus/db/aristo/aristo_fetch.nim | 11 +- nimbus/db/aristo/aristo_merge.nim | 8 +- tests/test_aristo/test_nibbles.nim | 98 +++++- 4 files changed, 322 insertions(+), 126 deletions(-) diff --git a/nimbus/db/aristo/aristo_desc/desc_nibbles.nim b/nimbus/db/aristo/aristo_desc/desc_nibbles.nim index beb10324d..623df22d7 100644 --- a/nimbus/db/aristo/aristo_desc/desc_nibbles.nim +++ b/nimbus/db/aristo/aristo_desc/desc_nibbles.nim @@ -8,7 +8,9 @@ # at your option. This file may not be copied, modified, or distributed # except according to those terms. -import stew/[arraybuf, arrayops] +{.push raises: [], gcsafe, inline.} + +import stew/[arraybuf, arrayops, bitops2, endians2, staticfor] export arraybuf @@ -16,8 +18,13 @@ type NibblesBuf* = object ## Allocation-free type for storing up to 64 4-bit nibbles, as seen in the ## Ethereum MPT - bytes: array[32, byte] - ibegin, iend: int8 + limbs: array[4, uint64] + # Each limb holds 16 nibbles in big endian order - for buffers shorter + # than 64 nibbles we make sure the last limb holding any data is zero-padded + # (so as to avoid UB on uninitialized reads) - for example a buffer + # holding one nibble will have one fully initialized limb and 3 + # uninitialized limbs.
+ iend: uint8 # Where valid nibbles can be found - we use indices here to avoid copies # wen slicing - iend not inclusive @@ -26,39 +33,60 @@ type func high*(T: type NibblesBuf): int = 63 -func fromBytes*(T: type NibblesBuf, bytes: openArray[byte]): T = - result.iend = 2 * (int8 result.bytes.copyFrom(bytes)) - -func nibble*(T: type NibblesBuf, nibble: byte): T = - result.bytes[0] = nibble shl 4 +func nibble*(T: type NibblesBuf, nibble: byte): T {.noinit.} = + result.limbs[0] = uint64(nibble) shl (64 - 4) result.iend = 1 -template `[]`*(r: NibblesBuf, i: int): byte = - let pos = r.ibegin + i - if (pos and 1) != 0: - (r.bytes[pos shr 1] and 0xf) - else: - (r.bytes[pos shr 1] shr 4) +template limb(i: int | uint8): uint8 = + # In which limb can nibble i be found? + uint8(i) shr 4 # shr 4 = div 16 = 16 nibbles per limb -template `[]=`*(r: NibblesBuf, i: int, v: byte) = - let pos = r.ibegin + i - r.bytes[pos shr 1] = - if (pos and 1) != 0: - (v and 0x0f) or (r.bytes[pos shr 1] and 0xf0) - else: - (v shl 4) or (r.bytes[pos shr 1] and 0x0f) +template shift(i: int | uint8): int = + # How many bits to shift to find nibble i within its limb? 
+ 60 - ((i mod 16) shl 2) # shl 2 = 4 bits per nibble + +func `[]`*(r: NibblesBuf, i: int): byte = + let + ilimb = i.limb + ishift = i.shift + byte((r.limbs[ilimb] shr ishift) and 0x0f) + +func `[]=`*(r: var NibblesBuf, i: int, v: byte) = + let + ilimb = i.limb + ishift = i.shift + + r.limbs[ilimb] = + (uint64(v and 0x0f) shl ishift) or ((r.limbs[ilimb] and not (0x0f'u64 shl ishift))) + +func fromBytes*(T: type NibblesBuf, bytes: openArray[byte]): T {.noinit.} = + if bytes.len >= 32: + result.iend = 64 + staticFor i, 0 ..< result.limbs.len: + const pos = i * 8 # 16 nibbles per limb, 2 nibbles per byte + result.limbs[i] = uint64.fromBytesBE(bytes.toOpenArray(pos, pos + 7)) + else: + let blen = uint8(bytes.len) + result.iend = blen * 2 + + block done: + staticFor i, 0 ..< result.limbs.len: + const pos = i * 8 + if pos + 7 < blen: + result.limbs[i] = uint64.fromBytesBE(bytes.toOpenArray(pos, pos + 7)) + else: + if pos < blen: + var tmp = 0'u64 + var shift = 56'u8 + for j in uint8(pos) ..< blen: + tmp = tmp or uint64(bytes[j]) shl shift + shift -= 8 + + result.limbs[i] = tmp + break done func len*(r: NibblesBuf): int = - r.iend - r.ibegin - -func `==`*(lhs, rhs: NibblesBuf): bool = - if lhs.len == rhs.len: - for i in 0 ..< lhs.len: - if lhs[i] != rhs[i]: - return false - return true - else: - return false + int(r.iend) func `$`*(r: NibblesBuf): string = result = newStringOfCap(64) @@ -66,96 +94,179 @@ func `$`*(r: NibblesBuf): string = const chars = "0123456789abcdef" result.add chars[r[i]] -func slice*(r: NibblesBuf, ibegin: int, iend = -1): NibblesBuf {.noinit.} = - result.bytes = r.bytes - result.ibegin = r.ibegin + ibegin.int8 - let e = - if iend < 0: - min(64, r.iend + iend + 1) - else: - min(64, r.ibegin + iend) - doAssert ibegin >= 0 and e <= result.bytes.len * 2 - result.iend = e.int8 +func `==`*(lhs, rhs: NibblesBuf): bool = + if lhs.iend != rhs.iend: + return false -func replaceSuffix*(r: NibblesBuf, suffix: NibblesBuf): NibblesBuf = - for i in 0 ..< r.len - 
suffix.len: - result[i] = r[i] - for i in 0 ..< suffix.len: - result[i + r.len - suffix.len] = suffix[i] - result.iend = min(64, r.len + suffix.len).int8 - -template writeFirstByte(nibbleCountExpr) {.dirty.} = - let nibbleCount = nibbleCountExpr - var oddnessFlag = (nibbleCount and 1) != 0 - result.setLen((nibbleCount div 2) + 1) - result[0] = byte((int(isLeaf) * 2 + int(oddnessFlag)) shl 4) - var writeHead = 0 - -template writeNibbles(r) {.dirty.} = - for i in 0 ..< r.len: - let nextNibble = r[i] - if oddnessFlag: - result[writeHead] = result[writeHead] or nextNibble - else: - inc writeHead - result[writeHead] = nextNibble shl 4 - oddnessFlag = not oddnessFlag - -func toHexPrefix*(r: NibblesBuf, isLeaf = false): HexPrefixBuf = - writeFirstByte(r.len) - writeNibbles(r) - -func toHexPrefix*(r1, r2: NibblesBuf, isLeaf = false): HexPrefixBuf = - writeFirstByte(r1.len + r2.len) - writeNibbles(r1) - writeNibbles(r2) + staticFor i, 0 ..< lhs.limbs.len: + if uint8(i * 16) >= lhs.iend: + return true + if lhs.limbs[i] != rhs.limbs[i]: + return false + true func sharedPrefixLen*(lhs, rhs: NibblesBuf): int = - result = 0 - while result < lhs.len and result < rhs.len: - if lhs[result] != rhs[result]: - break - inc result + let len = min(lhs.iend, rhs.iend) + staticFor i, 0 ..< lhs.limbs.len: + const pos = i * 16 + + if (pos + 16) >= len or lhs.limbs[i] != rhs.limbs[i]: + return + if pos < len: + let mask = + if len - pos >= 16: + 0'u64 + else: + (not 0'u64) shr ((len - pos) * 4) + pos + leadingZeros((lhs.limbs[i] xor rhs.limbs[i]) or mask) shr 2 + else: + pos + + 64 func startsWith*(lhs, rhs: NibblesBuf): bool = sharedPrefixLen(lhs, rhs) == rhs.len +func slice*(r: NibblesBuf, ibegin: int, iend = -1): NibblesBuf {.noinit.} = + let e = + if iend < 0: + min(64, r.len + iend + 1) + else: + min(64, iend) + + # With noinit, we have to be careful not to read result.limbs + result.iend = uint8(e - ibegin) + + var ilimb = ibegin.limb + block done: + let shift = (ibegin mod 16) shl 2 +
if shift == 0: # Must be careful not to shift by 64 which is UB! + staticFor i, 0 ..< result.limbs.len: + if uint8(i * 16) >= result.iend: + break done + result.limbs[i] = r.limbs[ilimb] + ilimb += 1 + else: + staticFor i, 0 ..< result.limbs.len: + if uint8(i * 16) >= result.iend: + break done + + let cur = r.limbs[ilimb] shl shift + ilimb += 1 + + result.limbs[i] = + if (ilimb * 16) < uint8 r.iend: + let next = r.limbs[ilimb] shr (64 - shift) + cur or next + else: + cur + +template copyshr(aend: uint8) = + block adone: # copy aend nibbles of a + staticFor i, 0 ..< result.limbs.len: + if uint8(i * 16) >= aend: + break adone + + result.limbs[i] = a.limbs[i] + + block bdone: + let shift = (aend mod 16) shl 2 + + var alimb = aend.limb + + if shift == 0: + staticFor i, 0 ..< result.limbs.len: + if uint8(i * 16) >= b.iend: + break bdone + + result.limbs[alimb] = b.limbs[i] + alimb += 1 + else: + # remove the part of a that should be b from the last a limb + result.limbs[alimb] = result.limbs[alimb] and ((not 0'u64) shl (64 - shift)) + + staticFor i, 0 ..< result.limbs.len: + if uint8(i * 16) >= b.iend: + break bdone + + # reading result.limbs here is safe because the previous loop + # iteration will have initialized it (or the copy of a on the initial iteration) + result.limbs[alimb] = result.limbs[alimb] or b.limbs[i] shr shift + + alimb += 1 + if (alimb * 16) < result.iend: + result.limbs[alimb] = b.limbs[i] shl (64 - shift) + +func `&`*(a, b: NibblesBuf): NibblesBuf {.noinit.} = + result.iend = min(64'u8, a.iend + b.iend) + + let aend = a.iend + copyshr(aend) + +func replaceSuffix*(a, b: NibblesBuf): NibblesBuf {.noinit.} = + if b.iend >= a.iend: + result = b + elif b.iend == 0: + result = a + else: + result.iend = a.iend + + let aend = a.iend - b.iend + copyshr(aend) + +func toHexPrefix*(r: NibblesBuf, isLeaf = false): HexPrefixBuf {.noinit.} = + # We'll adjust to the actual length below, but this hack allows us to write + # full limbs + + result.n = 33 # careful
with noinit, to not call setlen + let + limbs = (r.iend + 15).limb + isOdd = (r.iend and 1) > 0 + + result[0] = (byte(isLeaf) * 2 + byte(isOdd)) shl 4 + + if isOdd: + result[0] = result[0] or byte(r.limbs[0] shr 60) + + staticFor i, 0 ..< r.limbs.len: + if i < limbs: + let next = + when i == r.limbs.high: + 0'u64 + else: + r.limbs[i + 1] + let limb = r.limbs[i] shl 4 or next shr 60 + + const pos = i * 8 + 1 + assign(result.data.toOpenArray(pos, pos + 7), limb.toBytesBE()) + else: + staticFor i, 0 ..< r.limbs.len: + if i < limbs: + let limb = r.limbs[i] + const pos = i * 8 + 1 + assign(result.data.toOpenArray(pos, pos + 7), limb.toBytesBE()) + + result.setLen(int((r.iend shr 1) + 1)) + func fromHexPrefix*( - T: type NibblesBuf, r: openArray[byte] + T: type NibblesBuf, bytes: openArray[byte] ): tuple[isLeaf: bool, nibbles: NibblesBuf] {.noinit.} = - result.nibbles.ibegin = 0 + if bytes.len > 0: + result.isLeaf = (bytes[0] and 0x20) != 0 + let hasOddLen = (bytes[0] and 0x10) != 0 - if r.len > 0: - result.isLeaf = (r[0] and 0x20) != 0 - let hasOddLen = (r[0] and 0x10) != 0 - - result.nibbles.iend = - if hasOddLen: - result.nibbles.bytes[0] = r[0] shl 4 - - let bytes = min(31, r.len - 1) - for j in 0 ..< bytes: - result.nibbles.bytes[j] = result.nibbles.bytes[j] or r[j + 1] shr 4 - result.nibbles.bytes[j + 1] = r[j + 1] shl 4 - - int8(bytes) * 2 + 1 - else: - let bytes = min(32, r.len - 1) - assign(result.nibbles.bytes.toOpenArray(0, bytes - 1), r.toOpenArray(1, bytes)) - int8(bytes) * 2 + if hasOddLen: + let high = uint8(min(31, bytes.len - 1)) + result.nibbles = + NibblesBuf.nibble(bytes[0] and 0x0f) & + NibblesBuf.fromBytes(bytes.toOpenArray(1, int high)) + else: + result.nibbles = NibblesBuf.fromBytes(bytes.toOpenArray(1, bytes.high())) else: result.isLeaf = false result.nibbles.iend = 0 -func `&`*(a, b: NibblesBuf): NibblesBuf {.noinit.} = - result.ibegin = 0 - for i in 0 ..< a.len: - result[i] = a[i] - - for i in 0 ..< b.len: - result[i + a.len] = b[i] - - 
result.iend = int8(min(64, a.len + b.len)) - -template getBytes*(a: NibblesBuf): array[32, byte] = - a.bytes +func getBytes*(a: NibblesBuf): array[32, byte] = + staticFor i, 0 ..< a.limbs.len: + const pos = i * 8 + assign(result.toOpenArray(pos, pos + 7), a.limbs[i].toBytesBE) diff --git a/nimbus/db/aristo/aristo_fetch.nim b/nimbus/db/aristo/aristo_fetch.nim index 3e9128230..b364b5091 100644 --- a/nimbus/db/aristo/aristo_fetch.nim +++ b/nimbus/db/aristo/aristo_fetch.nim @@ -26,12 +26,9 @@ import proc retrieveLeaf( db: AristoDbRef; root: VertexID; - path: openArray[byte]; + path: Hash32; ): Result[VertexRef,AristoError] = - if path.len == 0: - return err(FetchPathInvalid) - - for step in stepUp(NibblesBuf.fromBytes(path), root, db): + for step in stepUp(NibblesBuf.fromBytes(path.data), root, db): let vtx = step.valueOr: if error in HikeAcceptableStopsNotFound: return err(FetchPathNotFound) @@ -68,7 +65,7 @@ proc retrieveAccountLeaf( # Updated payloads are stored in the layers so if we didn't find them there, # it must have been in the database let - leafVtx = db.retrieveLeaf(VertexID(1), accPath.data).valueOr: + leafVtx = db.retrieveLeaf(VertexID(1), accPath).valueOr: if error == FetchPathNotFound: db.accLeaves.put(accPath, nil) return err(error) @@ -168,7 +165,7 @@ proc retrieveStoragePayload( # Updated payloads are stored in the layers so if we didn't find them there, # it must have been in the database - let leafVtx = db.retrieveLeaf(? db.fetchStorageIdImpl(accPath), stoPath.data).valueOr: + let leafVtx = db.retrieveLeaf(? 
db.fetchStorageIdImpl(accPath), stoPath).valueOr: if error == FetchPathNotFound: db.stoLeaves.put(mixPath, nil) return err(error) diff --git a/nimbus/db/aristo/aristo_merge.nim b/nimbus/db/aristo/aristo_merge.nim index 89f2c3b0f..00275066a 100644 --- a/nimbus/db/aristo/aristo_merge.nim +++ b/nimbus/db/aristo/aristo_merge.nim @@ -41,7 +41,7 @@ proc layersPutLeaf( proc mergePayloadImpl( db: AristoDbRef, # Database, top layer root: VertexID, # MPT state root - path: openArray[byte], # Leaf item to add to the database + path: Hash32, # Leaf item to add to the database leaf: Opt[VertexRef], payload: LeafPayload, # Payload value ): Result[(VertexRef, VertexRef, VertexRef), AristoError] = @@ -51,7 +51,7 @@ proc mergePayloadImpl( ## accordingly. ## var - path = NibblesBuf.fromBytes(path) + path = NibblesBuf.fromBytes(path.data) cur = root (vtx, _) = db.getVtxRc((root, cur)).valueOr: if error != GetVtxNotFound: @@ -185,7 +185,7 @@ proc mergeAccountRecord*( let pyl = LeafPayload(pType: AccountData, account: accRec) updated = db.mergePayloadImpl( - VertexID(1), accPath.data, db.cachedAccLeaf(accPath), pyl).valueOr: + VertexID(1), accPath, db.cachedAccLeaf(accPath), pyl).valueOr: if error == MergeNoAction: return ok false return err(error) @@ -226,7 +226,7 @@ proc mergeStorageData*( # Call merge pyl = LeafPayload(pType: StoData, stoData: stoData) updated = db.mergePayloadImpl( - useID.vid, stoPath.data, db.cachedStoLeaf(mixPath), pyl).valueOr: + useID.vid, stoPath, db.cachedStoLeaf(mixPath), pyl).valueOr: if error == MergeNoAction: assert stoID.isValid # debugging only return ok() diff --git a/tests/test_aristo/test_nibbles.nim b/tests/test_aristo/test_nibbles.nim index 8fcaffe8a..f91ed2e16 100644 --- a/tests/test_aristo/test_nibbles.nim +++ b/tests/test_aristo/test_nibbles.nim @@ -11,7 +11,7 @@ {.used.} import - std/sequtils, + std/[sequtils, strutils], stew/byteutils, unittest2, ../../nimbus/db/aristo/aristo_desc/desc_nibbles @@ -30,7 +30,13 @@ suite "Nibbles": n[1] == 0 
$n.slice(1) == "0" $n.slice(2) == "" + $n.slice(0, 0) == "" + $n.slice(1, 1) == "" + $n.slice(0, 1) == "1" + $(n & n) == "1010" + + NibblesBuf.nibble(0) != NibblesBuf.nibble(1) block: let n = NibblesBuf.fromBytes(repeat(byte 0x12, 32)) check: @@ -52,7 +58,6 @@ suite "Nibbles": let he = n.toHexPrefix(true) ho = n.slice(1).toHexPrefix(true) - check: NibblesBuf.fromHexPrefix(he.data()) == (true, n) NibblesBuf.fromHexPrefix(ho.data()) == (true, n.slice(1)) @@ -72,11 +77,33 @@ suite "Nibbles": test "long": let n = NibblesBuf.fromBytes( - hexToSeqByte("0100000000000000000000000000000000000000000000000000000000000000") + hexToSeqByte("0100000000000000000000000000000000000000000000000000000000000012") ) + check: + n.getBytes() == + hexToSeqByte("0100000000000000000000000000000000000000000000000000000000000012") - check $n == "0100000000000000000000000000000000000000000000000000000000000000" - check $n.slice(1) == "100000000000000000000000000000000000000000000000000000000000000" + $n == "0100000000000000000000000000000000000000000000000000000000000012" + $(n & default(NibblesBuf)) == + "0100000000000000000000000000000000000000000000000000000000000012" + + $n.slice(1) == "100000000000000000000000000000000000000000000000000000000000012" + $n.slice(1, 2) == "1" + $(n.slice(1) & NibblesBuf.nibble(1)) == + "1000000000000000000000000000000000000000000000000000000000000121" + + $n.replaceSuffix(n.slice(1, 2)) == + "0100000000000000000000000000000000000000000000000000000000000011" + + for i in 0 ..< 64: + check: + $n.slice(0, i) == + "0100000000000000000000000000000000000000000000000000000000000012"[0 ..< i] + + for i in 0 ..< 64: + check: + $n.slice(i) == + "0100000000000000000000000000000000000000000000000000000000000012"[i ..< 64] let he = n.toHexPrefix(true) @@ -84,3 +111,64 @@ suite "Nibbles": check: NibblesBuf.fromHexPrefix(he.data()) == (true, n) NibblesBuf.fromHexPrefix(ho.data()) == (true, n.slice(1)) + + test "sharedPrefixLen": + let n0 = 
NibblesBuf.fromBytes(hexToSeqByte("01000000")) + let n2 = NibblesBuf.fromBytes(hexToSeqByte("10")) + let n = NibblesBuf.fromBytes( + hexToSeqByte("0100000000000000000000000000000000000000000000000000000000000012") + ) + let nn = NibblesBuf.fromBytes( + hexToSeqByte("0100000000000000000000000000000000000000000000000000000000000013") + ) + + check: + n0.sharedPrefixLen(n0) == 8 + n0.sharedPrefixLen(n2) == 0 + n0.slice(1).sharedPrefixLen(n0.slice(1)) == 7 + n0.slice(7).sharedPrefixLen(n0.slice(7)) == 1 + n0.slice(1).sharedPrefixLen(n2) == 2 + + for i in 0 .. 64: + check: + n.slice(0, i).sharedPrefixLen(n) == i + + for i in 0 .. 63: + check: + n.slice(0, i).sharedPrefixLen(nn.slice(0, i)) == i + check: + n.sharedPrefixLen(nn) == 63 + + test "join": + block: + let + n0 = NibblesBuf.fromBytes(repeat(0x00'u8, 32)) + n1 = NibblesBuf.fromBytes(repeat(0x11'u8, 32)) + + for i in 0 ..< 64: + check: + $(n0.slice(0, 64 - i) & n1.slice(0, i)) == + (strutils.repeat('0', 64 - i) & strutils.repeat('1', i)) + + $(n0.slice(0, 1) & n1.slice(0, i)) == + (strutils.repeat('0', 1) & strutils.repeat('1', i)) + + for i in 0 ..< 63: + check: + $(n0.slice(1, 64 - i) & n1.slice(0, i)) == + (strutils.repeat('0', 63 - i) & strutils.repeat('1', i)) + + test "replaceSuffix": + let + n0 = NibblesBuf.fromBytes(repeat(0x00'u8, 32)) + n1 = NibblesBuf.fromBytes(repeat(0x11'u8, 32)) + + check: + n0.replaceSuffix(default(NibblesBuf)) == n0 + n0.replaceSuffix(n0) == n0 + n0.replaceSuffix(n1) == n1 + + for i in 0 ..< 64: + check: + $n0.replaceSuffix(n1.slice(0, i)) == + (strutils.repeat('0', 64 - i) & strutils.repeat('1', i))