From b970e49570e2800c911d150dd42f1bcacfd63497 Mon Sep 17 00:00:00 2001 From: Ivan FB Date: Thu, 11 Jun 2026 16:19:21 +0200 Subject: [PATCH] refactor(protobuf): proto3 optional (Opt) singular fields + app-layer validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Satisfies the review feedback in one coherent design: - proto3 (no proto2 `required`, which is discouraged and only checks presence) - singular fields use the proto3 `optional` label (Opt[T]) — recommended over implicit fields for forward-compatibility - opaque seq[byte] ids (no UTF-8 validation) - mandatory identifiers (messageId/channelId and each entry's messageId) are validated at the application layer after decoding Opt-in-proto3 requires protobuf_serialization >= 0.5.0 (already pinned to 0.5.1). Co-Authored-By: Claude Opus 4.8 --- sds/protobuf.nim | 118 ++++++++++++++++++++++++++--------------------- 1 file changed, 65 insertions(+), 53 deletions(-) diff --git a/sds/protobuf.nim b/sds/protobuf.nim index d9b2897..86c6465 100644 --- a/sds/protobuf.nim +++ b/sds/protobuf.nim @@ -7,12 +7,17 @@ ## conversion bridges the two. The mirror string-ish fields are `seq[byte]` ## (not `pstring`) so message/channel/sender ids stay opaque bytes — no UTF-8 ## validation — and the distinct `SdsParticipantID` needs no special support. +## +## Singular fields use the proto3 `optional` label (`Opt[T]`), which is the +## recommended form for forward-compatibility; presence is exposed but the +## actual validity of mandatory identifiers is checked at the application layer +## after decoding (proto3 has no `required`). {.push raises: [].} import endians -import results import protobuf_serialization +import protobuf_serialization/pkg/results import ./types/[sds_message_id, history_entry, sds_message, reliability_error] import ./bloom @@ -22,29 +27,29 @@ import ./bloom type HistoryEntryPB* {.proto3.} = object - messageId* {.fieldNumber: 1.}: seq[byte] - retrievalHint* {.fieldNumber: 2.}: seq[byte] - senderId* {.fieldNumber: 3.}: seq[byte] + messageId* {.fieldNumber: 1.}: Opt[seq[byte]] + retrievalHint* {.fieldNumber: 2.}: Opt[seq[byte]] + senderId* {.fieldNumber: 3.}: Opt[seq[byte]] SdsMessagePB* {.proto3.} = object - messageId* {.fieldNumber: 1.}: seq[byte] - lamportTimestamp* {.fieldNumber: 2, pint.}: int64 + messageId* {.fieldNumber: 1.}: Opt[seq[byte]] + lamportTimestamp* {.fieldNumber: 2, pint.}: Opt[int64] causalHistory* {.fieldNumber: 3.}: seq[HistoryEntryPB] - channelId* {.fieldNumber: 4.}: seq[byte] - content* {.fieldNumber: 5.}: seq[byte] - bloomFilter* {.fieldNumber: 6.}: seq[byte] - senderId* {.fieldNumber: 7.}: seq[byte] + channelId* {.fieldNumber: 4.}: Opt[seq[byte]] + content* {.fieldNumber: 5.}: Opt[seq[byte]] + bloomFilter* {.fieldNumber: 6.}: Opt[seq[byte]] + senderId* {.fieldNumber: 7.}: Opt[seq[byte]] repairRequest* {.fieldNumber: 13.}: seq[HistoryEntryPB] BloomFilterPB {.proto3.} = object - data {.fieldNumber: 1.}: seq[byte] - capacity {.fieldNumber: 2, pint.}: uint64 - errorRate {.fieldNumber: 3, pint.}: uint64 - kHashes {.fieldNumber: 4, pint.}: uint64 - mBits {.fieldNumber: 5, pint.}: uint64 + data {.fieldNumber: 1.}: Opt[seq[byte]] + capacity {.fieldNumber: 2, pint.}: Opt[uint64] + errorRate {.fieldNumber: 3, pint.}: Opt[uint64] + kHashes {.fieldNumber: 4, pint.}: Opt[uint64] + mBits {.fieldNumber: 5, pint.}: Opt[uint64] # --------------------------------------------------------------------------- -# string <-> bytes (opaque, no UTF-8 validation) +# string <-> bytes (opaque, no UTF-8 validation) and optional-bytes helper # --------------------------------------------------------------------------- func toBytes(s: string): seq[byte] = @@ -59,32 +64,38 @@ func toStr(b: seq[byte]): string = copyMem(addr s[0], unsafeAddr b[0], b.len) return s +func optBytes(b: seq[byte]): Opt[seq[byte]] = + ## Present only when non-empty, so empty optionals stay off the wire. + if b.len > 0: + return Opt.some(b) + return Opt.none(seq[byte]) + # --------------------------------------------------------------------------- # Domain <-> wire conversion # --------------------------------------------------------------------------- func toPB*(e: HistoryEntry): HistoryEntryPB = return HistoryEntryPB( - messageId: e.messageId.toBytes, - retrievalHint: e.retrievalHint, - senderId: e.senderId.string.toBytes, + messageId: optBytes(e.messageId.toBytes), + retrievalHint: optBytes(e.retrievalHint), + senderId: optBytes(e.senderId.string.toBytes), ) func fromPB*(e: HistoryEntryPB): HistoryEntry = return HistoryEntry( - messageId: e.messageId.toStr, - retrievalHint: e.retrievalHint, - senderId: e.senderId.toStr.SdsParticipantID, + messageId: e.messageId.valueOr(@[]).toStr, + retrievalHint: e.retrievalHint.valueOr(@[]), + senderId: e.senderId.valueOr(@[]).toStr.SdsParticipantID, ) func toPB*(m: SdsMessage): SdsMessagePB = var pb = SdsMessagePB( - messageId: m.messageId.toBytes, - lamportTimestamp: m.lamportTimestamp, - channelId: m.channelId.toBytes, - content: m.content, - bloomFilter: m.bloomFilter, - senderId: m.senderId.string.toBytes, + messageId: optBytes(m.messageId.toBytes), + lamportTimestamp: Opt.some(m.lamportTimestamp), + channelId: optBytes(m.channelId.toBytes), + content: optBytes(m.content), + bloomFilter: optBytes(m.bloomFilter), + senderId: optBytes(m.senderId.string.toBytes), ) for e in m.causalHistory: pb.causalHistory.add(e.toPB) @@ -100,13 +111,13 @@ func fromPB*(pb: SdsMessagePB): SdsMessage = for e in pb.repairRequest: repair.add(e.fromPB) return SdsMessage.init( - messageId = pb.messageId.toStr, - lamportTimestamp = pb.lamportTimestamp, + messageId = pb.messageId.valueOr(@[]).toStr, + lamportTimestamp = pb.lamportTimestamp.valueOr(0'i64), causalHistory = causal, - channelId = pb.channelId.toStr, - content = pb.content, - bloomFilter = pb.bloomFilter, - senderId = pb.senderId.toStr.SdsParticipantID, + channelId = pb.channelId.valueOr(@[]).toStr, + content = pb.content.valueOr(@[]), + bloomFilter = pb.bloomFilter.valueOr(@[]), + senderId = pb.senderId.valueOr(@[]).toStr.SdsParticipantID, repairRequest = repair, ) @@ -121,24 +132,24 @@ proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] = return err(ReliabilityError.reSerializationError) proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] = - ## proto3 has no required fields, so presence is validated by hand. Only the - ## identifiers are mandatory: `content`, `bloomFilter` and a zero - ## `lamportTimestamp` may legitimately be empty (e.g. periodic sync messages). + ## proto3 has no required fields, so the mandatory identifiers are validated + ## by hand after decoding. `content`/`bloomFilter`/`lamportTimestamp` may + ## legitimately be empty/zero (e.g. periodic sync messages). try: - let pb = Protobuf.decode(data, SdsMessagePB) - if pb.messageId.len == 0 or pb.channelId.len == 0: + let msg = Protobuf.decode(data, SdsMessagePB).fromPB + if msg.messageId.len == 0 or msg.channelId.len == 0: return err(ReliabilityError.reDeserializationError) - for e in pb.causalHistory & pb.repairRequest: + for e in msg.causalHistory & msg.repairRequest: if e.messageId.len == 0: return err(ReliabilityError.reDeserializationError) - return ok(pb.fromPB) + return ok(msg) except CatchableError: return err(ReliabilityError.reDeserializationError) proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = ## Channel ID without retaining the rest of the decoded message. try: - return ok(Protobuf.decode(data, SdsMessagePB).channelId.toStr) + return ok(Protobuf.decode(data, SdsMessagePB).channelId.valueOr(@[]).toStr) except CatchableError: return err(ReliabilityError.reDeserializationError) @@ -171,11 +182,11 @@ proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityEr copyMem(addr bytes[i * sizeof(int)], addr leVal, sizeof(int)) let pb = BloomFilterPB( - data: bytes, - capacity: uint64(filter.capacity), - errorRate: uint64(filter.errorRate * 1_000_000), - kHashes: uint64(filter.kHashes), - mBits: uint64(filter.mBits), + data: optBytes(bytes), + capacity: Opt.some(uint64(filter.capacity)), + errorRate: Opt.some(uint64(filter.errorRate * 1_000_000)), + kHashes: Opt.some(uint64(filter.kHashes)), + mBits: Opt.some(uint64(filter.mBits)), ) return ok(Protobuf.encode(pb)) except CatchableError: @@ -186,18 +197,19 @@ proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityEr return err(ReliabilityError.reDeserializationError) try: let pb = Protobuf.decode(data, BloomFilterPB) - var intArray = newSeq[int](pb.data.len div sizeof(int)) + let rawData = pb.data.valueOr(@[]) + var intArray = newSeq[int](rawData.len div sizeof(int)) for i in 0 ..< intArray.len: var leVal: int - copyMem(addr leVal, unsafeAddr pb.data[i * sizeof(int)], sizeof(int)) + copyMem(addr leVal, unsafeAddr rawData[i * sizeof(int)], sizeof(int)) littleEndian64(addr intArray[i], addr leVal) return ok( BloomFilter.init( - capacity = int(pb.capacity), - errorRate = float(pb.errorRate) / 1_000_000, - kHashes = int(pb.kHashes), - mBits = int(pb.mBits), + capacity = int(pb.capacity.valueOr(0'u64)), + errorRate = float(pb.errorRate.valueOr(0'u64)) / 1_000_000, + kHashes = int(pb.kHashes.valueOr(0'u64)), + mBits = int(pb.mBits.valueOr(0'u64)), intArray = intArray, ) )