From b7a65e3a022f853a2db3d6024409f85c5ce6e979 Mon Sep 17 00:00:00 2001 From: Ivan FB Date: Wed, 10 Jun 2026 17:13:05 +0200 Subject: [PATCH] perf(protobuf): parse the buffer once instead of rescanning per field MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback on PR #77: `getField` rescanned the whole buffer on every call, giving O(fields × len) decode. Parse the message once at `ProtoBuffer.init(data)` in a single forward pass with the library's reader, indexing values by field number; the accessors are then plain lookups (last-wins / repeated semantics fall out of wire order). Public API and wire behaviour are unchanged. The full type-driven `Protobuf.encode/decode` API is intentionally not used here: it cannot express the backward-compatible causal-history decode (field 3 is repeated HistoryEntry now, repeated bare-string IDs in legacy messages), and its default-value omission would drop the always-written empty fields the strict decoder requires (e.g. empty `content` on sync messages), breaking compatibility with peers on the current wire format. Co-Authored-By: Claude Opus 4.8 --- sds/protobufutil.nim | 102 ++++++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 50 deletions(-) diff --git a/sds/protobufutil.nim b/sds/protobufutil.nim index 2203bf7..6566a3e 100644 --- a/sds/protobufutil.nim +++ b/sds/protobufutil.nim @@ -3,16 +3,22 @@ # # `sds/protobuf.nim` and `sds/snapshot_codec.nim` build messages by hand at the # field level — including a backward-compatible decode path the type-driven -# `Protobuf.encode/decode` API cannot express — so this exposes a small +# `Protobuf.encode/decode` API cannot express, and required-field / always-write +# semantics its default-value omission would break — so this exposes a small # accumulating `ProtoBuffer` with `write`/`getField`/`getRepeatedField`/`finish`: # * unsigned integers encode as plain varints (last value wins on decode); # * strings and byte seqs encode length-delimited, with no UTF-8 validation # (strings are treated as opaque bytes — message ids may be binary); # * a field whose stored wire type differs from the requested one is skipped, # as `protoc` does; only a malformed buffer yields an error. +# +# On construction from bytes the buffer is parsed once, in a single forward pass +# with the library's reader, into per-field value lists; the `getField` accessors +# are then plain lookups rather than re-scanning the buffer for every field. {.push raises: [].} +import std/tables import results import faststreams/inputs from protobuf_serialization/codec import @@ -24,6 +30,12 @@ export results, protobuf_error type ProtoBuffer* = object ## Accumulating protobuf field buffer. buffer*: seq[byte] + ## Reads are served from these parse-once indexes (populated by `init(data)`), + ## keyed by field number; values are kept in wire order so last-wins / repeated + ## semantics fall out directly. + varints: Table[int, seq[uint64]] + lengthDelims: Table[int, seq[seq[byte]]] + parseOk: bool converter toProtobufError*(err: ProtoError): ProtobufError = case err @@ -40,10 +52,29 @@ proc missingRequiredField*(T: type ProtobufError, field: string): T = # --------------------------------------------------------------------------- proc init*(T: type ProtoBuffer): T = - return T(buffer: @[]) + return T(buffer: @[], parseOk: true) proc init*(T: type ProtoBuffer, data: seq[byte]): T = - return T(buffer: data) + ## Parse `data` once into per-field value lists. A malformed buffer leaves + ## `parseOk = false`, which every accessor reports as a decode error. + var pb = T(buffer: data, parseOk: true) + var sh = memoryInput(data) + try: + let stream = sh.s + while stream.readable: + let hdr = readHeader(stream) + case hdr.kind + of WireKind.Varint: + pb.varints.mgetOrPut(hdr.number, @[]).add(uint64(readValue(stream, puint64))) + of WireKind.LengthDelim: + pb.lengthDelims.mgetOrPut(hdr.number, @[]).add(seq[byte](readValue(stream, pbytes))) + of WireKind.Fixed64: + discard readValue(stream, fixed64) + of WireKind.Fixed32: + discard readValue(stream, fixed32) + except CatchableError: + pb.parseOk = false + return pb proc finish*(pb: var ProtoBuffer) = ## No length prefix is used, so finishing only asserts the invariant that a @@ -88,70 +119,37 @@ proc bytesToString(b: seq[byte]): string = copyMem(addr s[0], unsafeAddr b[0], b.len) return s -proc collectVarints(buffer: seq[byte], field: int): ProtoResult[seq[uint64]] = - ## All varint values stored at `field`, in order. Mismatched wire types at - ## the same field number are skipped, as protoc does. - var values: seq[uint64] - var sh = memoryInput(buffer) - try: - let stream = sh.s - while stream.readable: - let hdr = readHeader(stream) - if hdr.number == field and hdr.kind == WireKind.Varint: - values.add(uint64(readValue(stream, puint64))) - else: - case hdr.kind - of WireKind.Varint: discard readValue(stream, puint64) - of WireKind.Fixed64: discard readValue(stream, fixed64) - of WireKind.Fixed32: discard readValue(stream, fixed32) - of WireKind.LengthDelim: discard readValue(stream, pbytes) - except CatchableError: - return err(ProtoError.VarintDecode) - return ok(values) - -proc collectLengthDelims(buffer: seq[byte], field: int): ProtoResult[seq[seq[byte]]] = - ## All length-delimited values stored at `field`, in order. - var values: seq[seq[byte]] - var sh = memoryInput(buffer) - try: - let stream = sh.s - while stream.readable: - let hdr = readHeader(stream) - if hdr.number == field and hdr.kind == WireKind.LengthDelim: - values.add(seq[byte](readValue(stream, pbytes))) - else: - case hdr.kind - of WireKind.Varint: discard readValue(stream, puint64) - of WireKind.Fixed64: discard readValue(stream, fixed64) - of WireKind.Fixed32: discard readValue(stream, fixed32) - of WireKind.LengthDelim: discard readValue(stream, pbytes) - except CatchableError: - return err(ProtoError.VarintDecode) - return ok(values) - proc getField*(pb: ProtoBuffer, field: int, output: var uint64): ProtoResult[bool] = - let values = ?collectVarints(pb.buffer, field) + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.varints.getOrDefault(field) if values.len > 0: output = values[^1] return ok(true) return ok(false) proc getField*(pb: ProtoBuffer, field: int, output: var uint32): ProtoResult[bool] = - let values = ?collectVarints(pb.buffer, field) + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.varints.getOrDefault(field) if values.len > 0: output = uint32(values[^1]) return ok(true) return ok(false) proc getField*(pb: ProtoBuffer, field: int, output: var seq[byte]): ProtoResult[bool] = - let values = ?collectLengthDelims(pb.buffer, field) + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.lengthDelims.getOrDefault(field) if values.len > 0: output = values[^1] return ok(true) return ok(false) proc getField*(pb: ProtoBuffer, field: int, output: var string): ProtoResult[bool] = - let values = ?collectLengthDelims(pb.buffer, field) + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.lengthDelims.getOrDefault(field) if values.len > 0: output = bytesToString(values[^1]) return ok(true) @@ -160,13 +158,17 @@ proc getField*(pb: ProtoBuffer, field: int, output: var string): ProtoResult[boo proc getRepeatedField*( pb: ProtoBuffer, field: int, output: var seq[seq[byte]] ): ProtoResult[bool] = - output = ?collectLengthDelims(pb.buffer, field) + if not pb.parseOk: + return err(ProtoError.VarintDecode) + output = pb.lengthDelims.getOrDefault(field) return ok(output.len > 0) proc getRepeatedField*( pb: ProtoBuffer, field: int, output: var seq[string] ): ProtoResult[bool] = - let values = ?collectLengthDelims(pb.buffer, field) + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.lengthDelims.getOrDefault(field) output.setLen(0) for v in values: output.add(bytesToString(v))