diff --git a/eth/rlp/writer.nim b/eth/rlp/writer.nim index 8c4d432..e726ddc 100644 --- a/eth/rlp/writer.nim +++ b/eth/rlp/writer.nim @@ -1,39 +1,39 @@ import std/options, results, - stew/[assign2, shims/macros], + stew/[arraybuf, assign2, bitops2, shims/macros], ./priv/defs +export arraybuf + type RlpWriter* = object pendingLists: seq[tuple[remainingItems, outBytes: int]] output: seq[byte] + RlpIntBuf* = ArrayBuf[9, byte] + ## Small buffer for holding a single RLP-encoded integer + const wrapObjsInList* = true -proc bytesNeeded(num: SomeUnsignedInt): int = - type IntType = type(num) - var n = num - while n != IntType(0): - inc result - n = n shr 8 +func bytesNeeded(num: SomeUnsignedInt): int = + # Number of non-zero bytes in the big endian encoding + sizeof(num) - (num.leadingZeros() shr 3) -proc writeBigEndian(outStream: var seq[byte], number: SomeUnsignedInt, +func writeBigEndian(outStream: var auto, number: SomeUnsignedInt, lastByteIdx: int, numberOfBytes: int) = - mixin `and`, `shr` - var n = number - for i in countdown(lastByteIdx, lastByteIdx - int(numberOfBytes) + 1): + for i in countdown(lastByteIdx, lastByteIdx - numberOfBytes + 1): outStream[i] = byte(n and 0xff) n = n shr 8 -proc writeBigEndian(outStream: var seq[byte], number: SomeUnsignedInt, +func writeBigEndian(outStream: var auto, number: SomeUnsignedInt, numberOfBytes: int) {.inline.} = outStream.setLen(outStream.len + numberOfBytes) outStream.writeBigEndian(number, outStream.len - 1, numberOfBytes) -proc writeCount(bytes: var seq[byte], count: int, baseMarker: byte) = +func writeCount(bytes: var auto, count: int, baseMarker: byte) = if count < THRESHOLD_LIST_LEN: bytes.add(baseMarker + byte(count)) else: @@ -45,6 +45,16 @@ proc writeCount(bytes: var seq[byte], count: int, baseMarker: byte) = bytes[origLen] = baseMarker + (THRESHOLD_LIST_LEN - 1) + byte(lenPrefixBytes) bytes.writeBigEndian(uint64(count), bytes.len - 1, lenPrefixBytes) +func writeInt(outStream: var auto, i: SomeUnsignedInt) = + if i == typeof(i)(0): + outStream.add BLOB_START_MARKER + elif i < typeof(i)(BLOB_START_MARKER): + outStream.add byte(i) + else: + let bytesNeeded = i.bytesNeeded + outStream.writeCount(bytesNeeded, BLOB_START_MARKER) + outStream.writeBigEndian(i, bytesNeeded) + proc initRlpWriter*: RlpWriter = # Avoid allocations during initial write of small items - since the writer is # expected to be short-lived, it doesn't hurt to allocate this buffer @@ -124,16 +134,7 @@ proc appendBlob(self: var RlpWriter, data: openArray[char]) = proc appendInt(self: var RlpWriter, i: SomeUnsignedInt) = # this is created as a separate proc as an extra precaution against # any overloading resolution problems when matching the IntLike concept. - type IntType = type(i) - - if i == IntType(0): - self.output.add BLOB_START_MARKER - elif i < BLOB_START_MARKER.SomeUnsignedInt: - self.output.add byte(i) - else: - let bytesNeeded = i.bytesNeeded - self.output.writeCount(bytesNeeded, BLOB_START_MARKER) - self.output.writeBigEndian(i, bytesNeeded) + self.output.writeInt(i) self.maybeClosePendingLists() @@ -319,16 +320,22 @@ template finish*(self: RlpWriter): seq[byte] = doAssert self.pendingLists.len == 0, "Insufficient number of elements written to a started list" self.output +func clear*(w: var RlpWriter) = + # Prepare writer for reuse + w.pendingLists.setLen(0) + w.output.setLen(0) + proc encode*[T](v: T): seq[byte] = mixin append + var writer = initRlpWriter() writer.append(v) move(writer.finish) -proc encodeInt*(i: SomeUnsignedInt): seq[byte] = - var writer = initRlpWriter() - writer.appendInt(i) - move(writer.finish) +func encodeInt*(i: SomeUnsignedInt): RlpIntBuf = + var buf: RlpIntBuf + buf.writeInt(i) + buf macro encodeList*(args: varargs[untyped]): seq[byte] = var @@ -345,12 +352,3 @@ macro encodeList*(args: varargs[untyped]): seq[byte] = var `writer` = initRlpList(`listLen`) `body` move(finish(`writer`)) - -when false: - # XXX: Currently fails with a malformed AST error on the args.len expression - template encodeList*(args: varargs[untyped]): seq[byte] = - mixin append - var writer = initRlpList(args.len) - for arg in args: - writer.append(arg) - writer.finish diff --git a/eth/trie/ordered_trie.nim b/eth/trie/ordered_trie.nim new file mode 100644 index 0000000..0f332db --- /dev/null +++ b/eth/trie/ordered_trie.nim @@ -0,0 +1,261 @@ +import ../common/hashes, ../rlp, stew/arraybuf + +export hashes + +type + ShortHash = ArrayBuf[32, byte] + + OrderedTrieRootBuilder* = object + ## A special case of hexary trie root building for the case where keys are + ## sorted integers and number of entries is known ahead of time. + ## + ## The builder must be initialized with the value count by calling `init`. + ## + ## In the ethereum MPT, leaf leaves are computed by prefixing the value with + ## its trie path slice. When the keys are ordere, we can pre-compute the + ## trie path slice thus avoiding unnecessary storage of leaf values. + ## + ## Similar implementations with various tradeoffs exist that cover the + ## general case: + ## + ## * https://github.com/alloy-rs/trie + ## * https://github.com/rust-ethereum/ethereum/blob/b160820620aa9fd30050d5fcb306be4e12d58c8c/src/util.rs#L152 + ## * https://github.com/ethereum/go-ethereum/blob/master/trie/stacktrie.go + ## + ## TODO We don't need to store all leaves - instead, we could for each + ## level of the trie store only a hashing state that collects the trie + ## built "so far", similar to the StackTrie implementation - this works + ## for items 0x80 and up where the rlp-encoded order matches insertion + ## order. + leaves: seq[ShortHash] + + items: int + ## Number of items added so far (and therefore also the key of the next item) + +func init*(T: type OrderedTrieRootBuilder, expected: int): T = + T(leaves: newSeq[ShortHash](expected)) + +func toShortHash(v: openArray[byte]): ShortHash = + if v.len < 32: + ShortHash.initCopyFrom(v) + else: + ShortHash.initCopyFrom(keccak256(v).data) + +func append(w: var RlpWriter, key: ShortHash) = + if 1 < key.len and key.len < 32: + w.appendRawBytes key.data + else: + w.append key.data + +func keyAtIndex(b: var OrderedTrieRootBuilder, i: int): RlpIntBuf = + # Given a leaf index, compute the rlp-encoded key + let key = + if i <= 0x7f: + if i == min(0x7f, b.leaves.len - 1): + 0'u64 + else: + uint64 i + 1 + else: + uint64 i + rlp.encodeInt(key) + +func nibble(v: RlpIntBuf, i: int): byte = + let data = v.data[i shr 1] + if (i and 1) != 0: + data and 0xf + else: + data shr 4 + +func nibbles(v: RlpIntBuf): int = + v.len * 2 + +func sharedPrefixLen(a, b: RlpIntBuf): int = + # Number of nibbles the two buffers have in common + for i in 0 ..< min(a.len, b.len): + if a[i] != b[i]: + return + if a.nibble(i * 2) == b.nibble(i * 2): + i * 2 + 1 + else: + i * 2 + min(a.len, b.len) + +func hexPrefixEncode( + r: RlpIntBuf, ibegin, iend: int, isLeaf = false +): ArrayBuf[10, byte] = + let nibbleCount = iend - ibegin + var oddnessFlag = (nibbleCount and 1) != 0 + result.setLen((nibbleCount div 2) + 1) + result[0] = byte((int(isLeaf) * 2 + int(oddnessFlag)) shl 4) + var writeHead = 0 + + for i in ibegin ..< iend: + let nextNibble = r.nibble(i) + if oddnessFlag: + result[writeHead] = result[writeHead] or nextNibble + else: + inc writeHead + result[writeHead] = nextNibble shl 4 + oddnessFlag = not oddnessFlag + +proc keyToIndex(b: var OrderedTrieRootBuilder, key: uint64): int = + ## Given a key, compute its position according to the rlp-encoded integer + ## ordering, ie the order that would result from encoding the key + ## with RLP, "shortest big endian encoding" and sorting lexicographically - + ## this lexicographical order determines the location of the key in the trie + if key == 0: + # Key 0 goes into position 0x7f or last, depending on how many there are + min(0x7f, b.leaves.len - 1) + elif key <= uint64 min(0x7f, b.leaves.len - 1): + int key - 1 + else: + int key + +proc updateHash(b: var OrderedTrieRootBuilder, key: uint64, v: auto, w: var RlpWriter) = + let + pos = b.keyToIndex(key) + cur = rlp.encodeInt(key) + b.leaves[pos] = + try: + w.clear() + w.startList(2) + + # compute the longest shared nibble prefix between a key and its sorted + # neighbours which determines how much of the key is left in the leaf + # itself during encoding + let spl = + if b.leaves.len == 1: + -1 # If there's only one leaf, the whole key is used as leaf path + else: + if pos + 1 < b.leaves.len: + let next = b.keyAtIndex(pos + 1) + if pos > 0: + let prev = b.keyAtIndex(pos - 1) + max(prev.sharedPrefixLen(cur), next.sharedPrefixLen(cur)) + else: + next.sharedPrefixLen(cur) + else: + let prev = b.keyAtIndex(pos - 1) + prev.sharedPrefixLen(cur) + + w.append(cur.hexPrefixEncode(spl + 1, cur.nibbles, isLeaf = true).data()) + w.append(rlp.encode(v)) + + toShortHash(w.finish) + except RlpError: + raiseAssert "RLP failures not expected" + +proc add*[T](b: var OrderedTrieRootBuilder, v: openArray[T]) = + ## Add items to the trie root builder, calling `rlp.encode(item)` to compute + ## the value of the item. The total number of items added before calling + ## `rootHash` must equal what was given in `init`. + ## + ## TODO instead of RLP-encoding the items to bytes, we should be hashing them + ## directly: + ## * https://github.com/status-im/nim-eth/issues/724 + ## * https://github.com/status-im/nim-eth/issues/698 + var w = initRlpWriter() + for item in v: + b.updateHash(uint64 b.items, item, w) + b.items += 1 + +proc computeKey(b: var OrderedTrieRootBuilder, rng: Slice[int], depth: int): ShortHash = + if rng.len == 0: + ShortHash.initCopyFrom([byte 128]) # RLP of empty list + elif rng.len == 1: # Leaf + b.leaves[rng.a] + else: # Branch (or extension) + var p = int.high + let ka = b.keyAtIndex(rng.a) + + # Find the shortest shared prefix among the given keys - if this is not 0, + # it means an extension node must be introduced among the nodes in the given + # range. The top level always has a 0 shared length prefix because the + # encodings for 0 and 1 start with different nibbles. + if depth == 0: + p = 0 + else: + for i in 1 ..< rng.len: + # TODO We can get rid of this loop by observing what the nibbles in the + # RLP integer encoding have in common and adjust accordingly + p = min(p, sharedPrefixLen(ka, b.keyAtIndex(rng.a + i))) + if p == depth: + break + + var w = initRlpWriter() + + if p == depth: # No shared prefix - this is a branch + w.startList(17) + # Sub-divide the keys by nibble and recurse + var pos = rng.a + for n in 0'u8 .. 15'u8: + var x: int + # Pick out the keys that have the asked-for nibble at the given depth + while pos + x <= rng.b and b.keyAtIndex(pos + x).nibble(depth) == n: + x += 1 + + if x > 0: + w.append b.computeKey(pos .. pos + x - 1, depth + 1) + else: + w.append(openArray[byte]([])) + pos += x + + w.append(openArray[byte]([])) # No data in branch nodes + else: + w.startList(2) + w.append(ka.hexPrefixEncode(depth, p, isLeaf = false).data()) + w.append(b.computeKey(rng, p)) + + toShortHash(w.finish()) + +proc rootHash*(b: var OrderedTrieRootBuilder): Root = + doAssert b.items == b.leaves.len, "Items added does not match initial length" + let h = b.computeKey(0 ..< b.leaves.len, 0) + if h.len == 32: + Root(h.buf) + else: + keccak256(h.data) + +proc orderedTrieRoot*[T](items: openArray[T]): Root = + ## Compute the MPT root of a list of items using their rlp-encoded index as + ## key. + ## + ## Typical examples include the transaction and withdrawal roots that appear + ## in blocks. + ## + ## The given values will be rlp-encoded using `rlp.encode`. + var b = OrderedTrieRootBuilder.init(items.len) + b.add(items) + b.rootHash + +when isMainModule: # A small benchmark + import std/[monotimes, times], eth/trie/[hexary, db] + + let n = 1000000 + echo "Testing ", n + let values = block: + var tmp: seq[uint64] + for i in 0 .. n: + tmp.add i.uint64 + tmp + + let x0 = getMonoTime() + let b1 = block: + var db = OrderedTrieRootBuilder.init(values.len) + + db.add(values) + db.rootHash() + echo b1 + let x1 = getMonoTime() + let b2 = block: + var db2 = initHexaryTrie(newMemoryDB()) + for v in values: + db2.put(rlp.encode(v), rlp.encode(v)) + + db2.rootHash() + let x2 = getMonoTime() + assert b1 == b2 + + echo ( + (x1 - x0), (x2 - x1), (x1 - x0).inNanoseconds.float / (x2 - x1).inNanoseconds.float + ) diff --git a/tests/rlp/test_api_usage.nim b/tests/rlp/test_api_usage.nim index 41ec9c2..80fb998 100644 --- a/tests/rlp/test_api_usage.nim +++ b/tests/rlp/test_api_usage.nim @@ -257,3 +257,8 @@ suite "test api usage": expect RlpTypeMismatch: discard rlp.read(MyEnum) rlp.skipElem() + + test "encodeInt basics": + for i in [uint64 0, 1, 10, 100, 1000, uint64.high]: + check: + encode(i) == encodeInt(i).data() diff --git a/tests/trie/all_tests.nim b/tests/trie/all_tests.nim index 5976203..b1e4f28 100644 --- a/tests/trie/all_tests.nim +++ b/tests/trie/all_tests.nim @@ -1,5 +1,6 @@ import ./test_hexary_trie, ./test_json_suite, + ./test_ordered_trie, ./test_transaction_db, ./test_hexary_proof diff --git a/tests/trie/test_ordered_trie.nim b/tests/trie/test_ordered_trie.nim new file mode 100644 index 0000000..b79e1a1 --- /dev/null +++ b/tests/trie/test_ordered_trie.nim @@ -0,0 +1,23 @@ +import ../../eth/trie/[db, hexary, ordered_trie], ../../eth/rlp, unittest2 + +{.used.} + +suite "OrderedTrie": + for n in [0, 1, 2, 3, 126, 127, 128, 129, 130, 1000]: + test "Ordered vs normal trie " & $n: + let values = block: + var tmp: seq[uint64] + for i in 0 .. n: + tmp.add i.uint64 + tmp + + let b1 = orderedTrieRoot(values) + + let b2 = block: + var db2 = initHexaryTrie(newMemoryDB()) + for v in values: + db2.put(rlp.encode(v), rlp.encode(v)) + + db2.rootHash() + check: + b1 == b2