mirror of
https://github.com/logos-messaging/nim-sds.git
synced 2026-07-02 13:59:41 +00:00
perf(protobuf): parse the buffer once instead of rescanning per field
Addresses review feedback on PR #77: `getField` rescanned the whole buffer on every call, giving O(fields × len) decode. Parse the message once at `ProtoBuffer.init(data)` in a single forward pass with the library's reader, indexing values by field number; the accessors are then plain lookups (last-wins / repeated semantics fall out of wire order). Public API and wire behaviour are unchanged. The full type-driven `Protobuf.encode/decode` API is intentionally not used here: it cannot express the backward-compatible causal-history decode (field 3 is repeated HistoryEntry now, repeated bare-string IDs in legacy messages), and its default-value omission would drop the always-written empty fields the strict decoder requires (e.g. empty `content` on sync messages), breaking compatibility with peers on the current wire format. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
parent
b6c19b415d
commit
b7a65e3a02
@ -3,16 +3,22 @@
|
||||
#
|
||||
# `sds/protobuf.nim` and `sds/snapshot_codec.nim` build messages by hand at the
|
||||
# field level — including a backward-compatible decode path the type-driven
|
||||
# `Protobuf.encode/decode` API cannot express — so this exposes a small
|
||||
# `Protobuf.encode/decode` API cannot express, and required-field / always-write
|
||||
# semantics its default-value omission would break — so this exposes a small
|
||||
# accumulating `ProtoBuffer` with `write`/`getField`/`getRepeatedField`/`finish`:
|
||||
# * unsigned integers encode as plain varints (last value wins on decode);
|
||||
# * strings and byte seqs encode length-delimited, with no UTF-8 validation
|
||||
# (strings are treated as opaque bytes — message ids may be binary);
|
||||
# * a field whose stored wire type differs from the requested one is skipped,
|
||||
# as `protoc` does; only a malformed buffer yields an error.
|
||||
#
|
||||
# On construction from bytes the buffer is parsed once, in a single forward pass
|
||||
# with the library's reader, into per-field value lists; the `getField` accessors
|
||||
# are then plain lookups rather than re-scanning the buffer for every field.
|
||||
|
||||
{.push raises: [].}
|
||||
|
||||
import std/tables
|
||||
import results
|
||||
import faststreams/inputs
|
||||
from protobuf_serialization/codec import
|
||||
@ -24,6 +30,12 @@ export results, protobuf_error
|
||||
|
||||
type ProtoBuffer* = object ## Accumulating protobuf field buffer.
|
||||
buffer*: seq[byte]
|
||||
## Reads are served from these parse-once indexes (populated by `init(data)`),
|
||||
## keyed by field number; values are kept in wire order so last-wins / repeated
|
||||
## semantics fall out directly.
|
||||
varints: Table[int, seq[uint64]]
|
||||
lengthDelims: Table[int, seq[seq[byte]]]
|
||||
parseOk: bool
|
||||
|
||||
converter toProtobufError*(err: ProtoError): ProtobufError =
|
||||
case err
|
||||
@ -40,10 +52,29 @@ proc missingRequiredField*(T: type ProtobufError, field: string): T =
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
proc init*(T: type ProtoBuffer): T =
|
||||
return T(buffer: @[])
|
||||
return T(buffer: @[], parseOk: true)
|
||||
|
||||
proc init*(T: type ProtoBuffer, data: seq[byte]): T =
|
||||
return T(buffer: data)
|
||||
## Parse `data` once into per-field value lists. A malformed buffer leaves
|
||||
## `parseOk = false`, which every accessor reports as a decode error.
|
||||
var pb = T(buffer: data, parseOk: true)
|
||||
var sh = memoryInput(data)
|
||||
try:
|
||||
let stream = sh.s
|
||||
while stream.readable:
|
||||
let hdr = readHeader(stream)
|
||||
case hdr.kind
|
||||
of WireKind.Varint:
|
||||
pb.varints.mgetOrPut(hdr.number, @[]).add(uint64(readValue(stream, puint64)))
|
||||
of WireKind.LengthDelim:
|
||||
pb.lengthDelims.mgetOrPut(hdr.number, @[]).add(seq[byte](readValue(stream, pbytes)))
|
||||
of WireKind.Fixed64:
|
||||
discard readValue(stream, fixed64)
|
||||
of WireKind.Fixed32:
|
||||
discard readValue(stream, fixed32)
|
||||
except CatchableError:
|
||||
pb.parseOk = false
|
||||
return pb
|
||||
|
||||
proc finish*(pb: var ProtoBuffer) =
|
||||
## No length prefix is used, so finishing only asserts the invariant that a
|
||||
@ -88,70 +119,37 @@ proc bytesToString(b: seq[byte]): string =
|
||||
copyMem(addr s[0], unsafeAddr b[0], b.len)
|
||||
return s
|
||||
|
||||
proc collectVarints(buffer: seq[byte], field: int): ProtoResult[seq[uint64]] =
|
||||
## All varint values stored at `field`, in order. Mismatched wire types at
|
||||
## the same field number are skipped, as protoc does.
|
||||
var values: seq[uint64]
|
||||
var sh = memoryInput(buffer)
|
||||
try:
|
||||
let stream = sh.s
|
||||
while stream.readable:
|
||||
let hdr = readHeader(stream)
|
||||
if hdr.number == field and hdr.kind == WireKind.Varint:
|
||||
values.add(uint64(readValue(stream, puint64)))
|
||||
else:
|
||||
case hdr.kind
|
||||
of WireKind.Varint: discard readValue(stream, puint64)
|
||||
of WireKind.Fixed64: discard readValue(stream, fixed64)
|
||||
of WireKind.Fixed32: discard readValue(stream, fixed32)
|
||||
of WireKind.LengthDelim: discard readValue(stream, pbytes)
|
||||
except CatchableError:
|
||||
return err(ProtoError.VarintDecode)
|
||||
return ok(values)
|
||||
|
||||
proc collectLengthDelims(buffer: seq[byte], field: int): ProtoResult[seq[seq[byte]]] =
|
||||
## All length-delimited values stored at `field`, in order.
|
||||
var values: seq[seq[byte]]
|
||||
var sh = memoryInput(buffer)
|
||||
try:
|
||||
let stream = sh.s
|
||||
while stream.readable:
|
||||
let hdr = readHeader(stream)
|
||||
if hdr.number == field and hdr.kind == WireKind.LengthDelim:
|
||||
values.add(seq[byte](readValue(stream, pbytes)))
|
||||
else:
|
||||
case hdr.kind
|
||||
of WireKind.Varint: discard readValue(stream, puint64)
|
||||
of WireKind.Fixed64: discard readValue(stream, fixed64)
|
||||
of WireKind.Fixed32: discard readValue(stream, fixed32)
|
||||
of WireKind.LengthDelim: discard readValue(stream, pbytes)
|
||||
except CatchableError:
|
||||
return err(ProtoError.VarintDecode)
|
||||
return ok(values)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, output: var uint64): ProtoResult[bool] =
|
||||
let values = ?collectVarints(pb.buffer, field)
|
||||
if not pb.parseOk:
|
||||
return err(ProtoError.VarintDecode)
|
||||
let values = pb.varints.getOrDefault(field)
|
||||
if values.len > 0:
|
||||
output = values[^1]
|
||||
return ok(true)
|
||||
return ok(false)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, output: var uint32): ProtoResult[bool] =
|
||||
let values = ?collectVarints(pb.buffer, field)
|
||||
if not pb.parseOk:
|
||||
return err(ProtoError.VarintDecode)
|
||||
let values = pb.varints.getOrDefault(field)
|
||||
if values.len > 0:
|
||||
output = uint32(values[^1])
|
||||
return ok(true)
|
||||
return ok(false)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, output: var seq[byte]): ProtoResult[bool] =
|
||||
let values = ?collectLengthDelims(pb.buffer, field)
|
||||
if not pb.parseOk:
|
||||
return err(ProtoError.VarintDecode)
|
||||
let values = pb.lengthDelims.getOrDefault(field)
|
||||
if values.len > 0:
|
||||
output = values[^1]
|
||||
return ok(true)
|
||||
return ok(false)
|
||||
|
||||
proc getField*(pb: ProtoBuffer, field: int, output: var string): ProtoResult[bool] =
|
||||
let values = ?collectLengthDelims(pb.buffer, field)
|
||||
if not pb.parseOk:
|
||||
return err(ProtoError.VarintDecode)
|
||||
let values = pb.lengthDelims.getOrDefault(field)
|
||||
if values.len > 0:
|
||||
output = bytesToString(values[^1])
|
||||
return ok(true)
|
||||
@ -160,13 +158,17 @@ proc getField*(pb: ProtoBuffer, field: int, output: var string): ProtoResult[boo
|
||||
proc getRepeatedField*(
|
||||
pb: ProtoBuffer, field: int, output: var seq[seq[byte]]
|
||||
): ProtoResult[bool] =
|
||||
output = ?collectLengthDelims(pb.buffer, field)
|
||||
if not pb.parseOk:
|
||||
return err(ProtoError.VarintDecode)
|
||||
output = pb.lengthDelims.getOrDefault(field)
|
||||
return ok(output.len > 0)
|
||||
|
||||
proc getRepeatedField*(
|
||||
pb: ProtoBuffer, field: int, output: var seq[string]
|
||||
): ProtoResult[bool] =
|
||||
let values = ?collectLengthDelims(pb.buffer, field)
|
||||
if not pb.parseOk:
|
||||
return err(ProtoError.VarintDecode)
|
||||
let values = pb.lengthDelims.getOrDefault(field)
|
||||
output.setLen(0)
|
||||
for v in values:
|
||||
output.add(bytesToString(v))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user