From 3ef6a552570e2a7ce7fc33d113b8a21d6030abb6 Mon Sep 17 00:00:00 2001 From: Ivan FB Date: Wed, 10 Jun 2026 22:45:23 +0200 Subject: [PATCH] refactor(protobuf): use protobuf-serialization encode/decode for wire codec Addresses review: replace the hand-rolled ProtoBuffer codec in the SDS wire format with type-driven `Protobuf.encode/decode` over annotated *PB mirrors (seq[byte] id fields keep ids opaque). The snapshot codec routes its embedded message/entry (de)serialisation through these wire helpers. --- sds/protobuf.nim | 296 +++++++++++++++++++++-------------------- sds/snapshot_codec.nim | 34 ++--- 2 files changed, 167 insertions(+), 163 deletions(-) diff --git a/sds/protobuf.nim b/sds/protobuf.nim index 4f7155b..a875aa2 100644 --- a/sds/protobuf.nim +++ b/sds/protobuf.nim @@ -1,188 +1,198 @@ +## SDS network wire codec. +## +## Messages are described as annotated protobuf types and (de)serialised with +## `nim-protobuf-serialization`'s type-driven `Protobuf.encode/decode`. The +## domain types (`SdsMessage`, `HistoryEntry`) keep their distinct/`requiresInit` +## shape; small `*PB` mirrors carry the field annotations and a trivial +## conversion bridges the two. The mirror string-ish fields are `seq[byte]` +## (not `pstring`) so message/channel/sender ids stay opaque bytes — no UTF-8 +## validation — and the distinct `SdsParticipantID` needs no special support. + +{.push raises: [].} + import endians +import results +import protobuf_serialization import ./types/[sds_message_id, history_entry, sds_message, reliability_error] -import ./protobufutil import ./bloom -import ./sds_utils -proc encodeHistoryEntry*(entry: HistoryEntry): ProtoBuffer = - var entryPb = ProtoBuffer.init() - entryPb.write(1, entry.messageId) - if entry.retrievalHint.len > 0: - entryPb.write(2, entry.retrievalHint) - if entry.senderId.len > 0: - entryPb.write(3, entry.senderId.string) - entryPb.finish() - entryPb +# --------------------------------------------------------------------------- +# Wire types +# --------------------------------------------------------------------------- -proc decodeHistoryEntry*(entryPb: ProtoBuffer): ProtobufResult[HistoryEntry] = - var entry = HistoryEntry.init("") - if not ?entryPb.getField(1, entry.messageId): - return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) - discard entryPb.getField(2, entry.retrievalHint) - var senderIdStr: string - if entryPb.getField(3, senderIdStr).valueOr(false): - entry.senderId = senderIdStr.SdsParticipantID - ok(entry) +type + HistoryEntryPB* {.proto3.} = object + messageId* {.fieldNumber: 1.}: seq[byte] + retrievalHint* {.fieldNumber: 2.}: seq[byte] + senderId* {.fieldNumber: 3.}: seq[byte] -proc encode*(msg: SdsMessage): ProtoBuffer = - var pb = ProtoBuffer.init() + SdsMessagePB* {.proto3.} = object + messageId* {.fieldNumber: 1.}: seq[byte] + lamportTimestamp* {.fieldNumber: 2, pint.}: int64 + causalHistory* {.fieldNumber: 3.}: seq[HistoryEntryPB] + channelId* {.fieldNumber: 4.}: seq[byte] + content* {.fieldNumber: 5.}: seq[byte] + bloomFilter* {.fieldNumber: 6.}: seq[byte] + senderId* {.fieldNumber: 7.}: seq[byte] + repairRequest* {.fieldNumber: 13.}: seq[HistoryEntryPB] - pb.write(1, msg.messageId) - pb.write(2, uint64(msg.lamportTimestamp)) + BloomFilterPB {.proto3.} = object + data {.fieldNumber: 1.}: seq[byte] + capacity {.fieldNumber: 2, pint.}: uint64 + errorRate {.fieldNumber: 3, pint.}: uint64 + kHashes {.fieldNumber: 4, pint.}: uint64 + mBits {.fieldNumber: 5, pint.}: uint64 - for entry in msg.causalHistory: - let entryPb = encodeHistoryEntry(entry) - pb.write(3, entryPb.buffer) +# --------------------------------------------------------------------------- +# string <-> bytes (opaque, no UTF-8 validation) +# --------------------------------------------------------------------------- - pb.write(4, msg.channelId) - pb.write(5, msg.content) - pb.write(6, msg.bloomFilter) +func toBytes(s: string): seq[byte] = + var b = newSeq[byte](s.len) + if s.len > 0: + copyMem(addr b[0], unsafeAddr s[0], s.len) + return b - if msg.senderId.len > 0: - pb.write(7, msg.senderId.string) +func toStr(b: seq[byte]): string = + var s = newString(b.len) + if b.len > 0: + copyMem(addr s[0], unsafeAddr b[0], b.len) + return s - for entry in msg.repairRequest: - let entryPb = encodeHistoryEntry(entry) - pb.write(13, entryPb.buffer) +# --------------------------------------------------------------------------- +# Domain <-> wire conversion +# --------------------------------------------------------------------------- - pb.finish() +func toPB*(e: HistoryEntry): HistoryEntryPB = + return HistoryEntryPB( + messageId: e.messageId.toBytes, + retrievalHint: e.retrievalHint, + senderId: e.senderId.string.toBytes, + ) +func fromPB*(e: HistoryEntryPB): HistoryEntry = + return HistoryEntry( + messageId: e.messageId.toStr, + retrievalHint: e.retrievalHint, + senderId: e.senderId.toStr.SdsParticipantID, + ) + +func toPB*(m: SdsMessage): SdsMessagePB = + var pb = SdsMessagePB( + messageId: m.messageId.toBytes, + lamportTimestamp: m.lamportTimestamp, + channelId: m.channelId.toBytes, + content: m.content, + bloomFilter: m.bloomFilter, + senderId: m.senderId.string.toBytes, + ) + for e in m.causalHistory: + pb.causalHistory.add(e.toPB) + for e in m.repairRequest: + pb.repairRequest.add(e.toPB) return pb -proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = - let pb = ProtoBuffer.init(buffer) - var msg = SdsMessage.init("", 0, @[], "", @[], @[]) +func fromPB*(pb: SdsMessagePB): SdsMessage = + var causal: seq[HistoryEntry] + for e in pb.causalHistory: + causal.add(e.fromPB) + var repair: seq[HistoryEntry] + for e in pb.repairRequest: + repair.add(e.fromPB) + return SdsMessage.init( + messageId = pb.messageId.toStr, + lamportTimestamp = pb.lamportTimestamp, + causalHistory = causal, + channelId = pb.channelId.toStr, + content = pb.content, + bloomFilter = pb.bloomFilter, + senderId = pb.senderId.toStr.SdsParticipantID, + repairRequest = repair, + ) - if not ?pb.getField(1, msg.messageId): - return err(ProtobufError.missingRequiredField("messageId")) - - var timestamp: uint64 - if not ?pb.getField(2, timestamp): - return err(ProtobufError.missingRequiredField("lamportTimestamp")) - msg.lamportTimestamp = int64(timestamp) - - # Handle both old and new causal history formats - var historyBuffers: seq[seq[byte]] - if pb.getRepeatedField(3, historyBuffers).isOk(): - # New format: repeated HistoryEntry - for histBuffer in historyBuffers: - let entryPb = ProtoBuffer.init(histBuffer) - let entry = ?decodeHistoryEntry(entryPb) - msg.causalHistory.add(entry) - else: - # Try old format: repeated string - var causalHistory: seq[SdsMessageID] - let histResult = pb.getRepeatedField(3, causalHistory) - if histResult.isOk(): - msg.causalHistory = toCausalHistory(causalHistory) - - if not ?pb.getField(4, msg.channelId): - return err(ProtobufError.missingRequiredField("channelId")) - - if not ?pb.getField(5, msg.content): - return err(ProtobufError.missingRequiredField("content")) - - if not ?pb.getField(6, msg.bloomFilter): - msg.bloomFilter = @[] # Empty if not present - - # SDS-R: decode senderId (field 7, optional) - var msgSenderIdStr: string - if pb.getField(7, msgSenderIdStr).valueOr(false): - msg.senderId = msgSenderIdStr.SdsParticipantID - - # SDS-R: decode repair request (field 13, optional) - var repairBuffers: seq[seq[byte]] - if pb.getRepeatedField(13, repairBuffers).isOk(): - for repairBuffer in repairBuffers: - let entryPb = ProtoBuffer.init(repairBuffer) - let entry = ?decodeHistoryEntry(entryPb) - msg.repairRequest.add(entry) - - return ok(msg) - -proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = - ## For extraction of channel ID without full message deserialization - try: - let pb = ProtoBuffer.init(data) - var channelId: SdsChannelID - let fieldOk = pb.getField(4, channelId).valueOr: - return err(ReliabilityError.reDeserializationError) - if not fieldOk: - return err(ReliabilityError.reDeserializationError) - return ok(channelId) - except: - return err(ReliabilityError.reDeserializationError) +# --------------------------------------------------------------------------- +# Message (de)serialisation +# --------------------------------------------------------------------------- proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] = - let pb = encode(msg) - return ok(pb.buffer) + try: + return ok(Protobuf.encode(msg.toPB)) + except CatchableError: + return err(ReliabilityError.reSerializationError) proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] = - let msg = SdsMessage.decode(data).valueOr: + try: + return ok(Protobuf.decode(data, SdsMessagePB).fromPB) + except CatchableError: return err(ReliabilityError.reDeserializationError) - return ok(msg) + +proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = + ## Channel ID without retaining the rest of the decoded message. + try: + return ok(Protobuf.decode(data, SdsMessagePB).channelId.toStr) + except CatchableError: + return err(ReliabilityError.reDeserializationError) + +# Single `HistoryEntry` (de)serialisation, used by the snapshot codec for the +# repair-buffer entries it embeds. Kept here so all `Protobuf.decode` calls live +# in this module. + +proc serializeHistoryEntry*(e: HistoryEntry): Result[seq[byte], ReliabilityError] = + try: + return ok(Protobuf.encode(e.toPB)) + except CatchableError: + return err(ReliabilityError.reSerializationError) + +proc deserializeHistoryEntry*(data: seq[byte]): Result[HistoryEntry, ReliabilityError] = + try: + return ok(Protobuf.decode(data, HistoryEntryPB).fromPB) + except CatchableError: + return err(ReliabilityError.reDeserializationError) + +# --------------------------------------------------------------------------- +# Bloom filter (de)serialisation +# --------------------------------------------------------------------------- proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] = - var pb = ProtoBuffer.init() - try: var bytes = newSeq[byte](filter.intArray.len * sizeof(int)) for i, val in filter.intArray: var leVal: int littleEndian64(addr leVal, unsafeAddr val) - let start = i * sizeof(int) - copyMem(addr bytes[start], addr leVal, sizeof(int)) + copyMem(addr bytes[i * sizeof(int)], addr leVal, sizeof(int)) - pb.write(1, bytes) - pb.write(2, uint64(filter.capacity)) - pb.write(3, uint64(filter.errorRate * 1_000_000)) - pb.write(4, uint64(filter.kHashes)) - pb.write(5, uint64(filter.mBits)) - except: + let pb = BloomFilterPB( + data: bytes, + capacity: uint64(filter.capacity), + errorRate: uint64(filter.errorRate * 1_000_000), + kHashes: uint64(filter.kHashes), + mBits: uint64(filter.mBits), + ) + return ok(Protobuf.encode(pb)) + except CatchableError: return err(ReliabilityError.reSerializationError) - pb.finish() - return ok(pb.buffer) - proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] = if data.len == 0: return err(ReliabilityError.reDeserializationError) - - let pb = ProtoBuffer.init(data) - var bytes: seq[byte] - var cap, errRate, kHashes, mBits: uint64 - try: - let - field1_Ok = pb.getField(1, bytes).valueOr: - return err(ReliabilityError.reDeserializationError) - field2_Ok = pb.getField(2, cap).valueOr: - return err(ReliabilityError.reDeserializationError) - field3_Ok = pb.getField(3, errRate).valueOr: - return err(ReliabilityError.reDeserializationError) - field4_Ok = pb.getField(4, kHashes).valueOr: - return err(ReliabilityError.reDeserializationError) - field5_Ok = pb.getField(5, mBits).valueOr: - return err(ReliabilityError.reDeserializationError) - - if not field1_Ok or not field2_Ok or not field3_Ok or not field4_Ok or not field5_Ok: - return err(ReliabilityError.reDeserializationError) - - var intArray = newSeq[int](bytes.len div sizeof(int)) + let pb = Protobuf.decode(data, BloomFilterPB) + var intArray = newSeq[int](pb.data.len div sizeof(int)) for i in 0 ..< intArray.len: var leVal: int - let start = i * sizeof(int) - copyMem(addr leVal, unsafeAddr bytes[start], sizeof(int)) + copyMem(addr leVal, unsafeAddr pb.data[i * sizeof(int)], sizeof(int)) littleEndian64(addr intArray[i], addr leVal) return ok( BloomFilter.init( - capacity = int(cap), - errorRate = float(errRate) / 1_000_000, - kHashes = int(kHashes), - mBits = int(mBits), + capacity = int(pb.capacity), + errorRate = float(pb.errorRate) / 1_000_000, + kHashes = int(pb.kHashes), + mBits = int(pb.mBits), intArray = intArray, ) ) - except: + except CatchableError: return err(ReliabilityError.reDeserializationError) + +{.pop.} diff --git a/sds/snapshot_codec.nim b/sds/snapshot_codec.nim index 6fcb95c..45648c0 100644 --- a/sds/snapshot_codec.nim +++ b/sds/snapshot_codec.nim @@ -44,8 +44,7 @@ proc fromUnixMs(ms: int64): Time = proc encodeUnacked(u: UnacknowledgedMessage): ProtoBuffer = var pb = ProtoBuffer.init() - let msgPb = wire.encode(u.message) - pb.write(1, msgPb.buffer) + pb.write(1, wire.serializeMessage(u.message).get()) pb.write(2, uint64(u.sendTime.toUnixMs)) pb.write(3, uint32(u.resendAttempts)) pb.finish() @@ -56,7 +55,7 @@ proc decodeUnacked(buf: seq[byte]): ProtobufResult[UnacknowledgedMessage] = var msgBytes: seq[byte] if not ?pb.getField(1, msgBytes): return err(ProtobufError.missingRequiredField("UnacknowledgedMessage.message")) - let msg = SdsMessage.decode(msgBytes).valueOr: + let msg = wire.deserializeMessage(msgBytes).valueOr: return err(ProtobufError.missingRequiredField("UnacknowledgedMessage.message")) var sendMs: uint64 if not ?pb.getField(2, sendMs): @@ -77,8 +76,7 @@ proc decodeUnacked(buf: seq[byte]): ProtobufResult[UnacknowledgedMessage] = proc encodeIncoming(m: IncomingMessage): ProtoBuffer = var pb = ProtoBuffer.init() - let msgPb = wire.encode(m.message) - pb.write(1, msgPb.buffer) + pb.write(1, wire.serializeMessage(m.message).get()) for dep in m.missingDeps: pb.write(2, dep) # SdsMessageID is string pb.finish() @@ -89,7 +87,7 @@ proc decodeIncoming(buf: seq[byte]): ProtobufResult[IncomingMessage] = var msgBytes: seq[byte] if not ?pb.getField(1, msgBytes): return err(ProtobufError.missingRequiredField("IncomingMessage.message")) - let msg = SdsMessage.decode(msgBytes).valueOr: + let msg = wire.deserializeMessage(msgBytes).valueOr: return err(ProtobufError.missingRequiredField("IncomingMessage.message")) var deps: seq[SdsMessageID] discard pb.getRepeatedField(2, deps) @@ -104,8 +102,7 @@ proc decodeIncoming(buf: seq[byte]): ProtobufResult[IncomingMessage] = proc encodeOutRepairEntry(e: OutgoingRepairEntry): ProtoBuffer = var pb = ProtoBuffer.init() - let histPb = wire.encodeHistoryEntry(e.outHistEntry) - pb.write(1, histPb.buffer) + pb.write(1, wire.serializeHistoryEntry(e.outHistEntry).get()) pb.write(2, uint64(e.minTimeRepairReq.toUnixMs)) pb.finish() pb @@ -115,8 +112,8 @@ proc decodeOutRepairEntry(buf: seq[byte]): ProtobufResult[OutgoingRepairEntry] = var histBytes: seq[byte] if not ?pb.getField(1, histBytes): return err(ProtobufError.missingRequiredField("OutgoingRepairEntry.outHistEntry")) - let histPb = ProtoBuffer.init(histBytes) - let entry = ?wire.decodeHistoryEntry(histPb) + let entry = wire.deserializeHistoryEntry(histBytes).valueOr: + return err(ProtobufError.missingRequiredField("HistoryEntry")) var ms: uint64 if not ?pb.getField(2, ms): return err(ProtobufError.missingRequiredField("OutgoingRepairEntry.minTimeRepairReq")) @@ -151,8 +148,7 @@ proc decodeOutRepairKV(buf: seq[byte]): ProtobufResult[OutgoingRepairKV] = proc encodeInRepairEntry(e: IncomingRepairEntry): ProtoBuffer = var pb = ProtoBuffer.init() - let histPb = wire.encodeHistoryEntry(e.inHistEntry) - pb.write(1, histPb.buffer) + pb.write(1, wire.serializeHistoryEntry(e.inHistEntry).get()) pb.write(2, e.cachedMessage) pb.write(3, uint64(e.minTimeRepairResp.toUnixMs)) pb.finish() @@ -163,8 +159,8 @@ proc decodeInRepairEntry(buf: seq[byte]): ProtobufResult[IncomingRepairEntry] = var histBytes: seq[byte] if not ?pb.getField(1, histBytes): return err(ProtobufError.missingRequiredField("IncomingRepairEntry.inHistEntry")) - let histPb = ProtoBuffer.init(histBytes) - let entry = ?wire.decodeHistoryEntry(histPb) + let entry = wire.deserializeHistoryEntry(histBytes).valueOr: + return err(ProtobufError.missingRequiredField("HistoryEntry")) var cached: seq[byte] if not ?pb.getField(2, cached): return err(ProtobufError.missingRequiredField("IncomingRepairEntry.cachedMessage")) @@ -274,8 +270,7 @@ proc encode*(d: ChannelData): ProtoBuffer = let metaPb = encode(d.meta) pb.write(1, metaPb.buffer) for m in d.messageHistory: - let msgPb = wire.encode(m) - pb.write(2, msgPb.buffer) + pb.write(2, wire.serializeMessage(m).get()) pb.finish() pb @@ -289,7 +284,7 @@ proc decode*(T: type ChannelData, buf: seq[byte]): ProtobufResult[T] = var histBufs: seq[seq[byte]] discard pb.getRepeatedField(2, histBufs) for b in histBufs: - let m = SdsMessage.decode(b).valueOr: + let m = wire.deserializeMessage(b).valueOr: return err(ProtobufError.missingRequiredField("ChannelData.messageHistory[i]")) d.messageHistory.add(m) ok(d) @@ -301,8 +296,7 @@ proc decode*(T: type ChannelData, buf: seq[byte]): ProtobufResult[T] = proc encode*(u: HistoryUpdate): ProtoBuffer = var pb = ProtoBuffer.init() for m in u.append: - let msgPb = wire.encode(m) - pb.write(1, msgPb.buffer) + pb.write(1, wire.serializeMessage(m).get()) for id in u.evict: pb.write(2, id) pb.finish() @@ -314,7 +308,7 @@ proc decode*(T: type HistoryUpdate, buf: seq[byte]): ProtobufResult[T] = var appBufs: seq[seq[byte]] discard pb.getRepeatedField(1, appBufs) for b in appBufs: - let m = SdsMessage.decode(b).valueOr: + let m = wire.deserializeMessage(b).valueOr: return err(ProtobufError.missingRequiredField("HistoryUpdate.append[i]")) u.append.add(m) var ev: seq[SdsMessageID]