From b5d1364d47372e14e4f5e0e4bc3ebfead96879b9 Mon Sep 17 00:00:00 2001 From: Ivan FB <128452529+Ivansete-status@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:08:23 +0200 Subject: [PATCH] feat: replace nim-libp2p protobuf with nim-protobuf-serialization (#77) Co-authored-by: Esteban C Borsani --- .github/workflows/ci-nix.yml | 2 +- .github/workflows/ci.yml | 4 +- nimble.lock | 158 +++-------------- nix/deps.nix | 81 ++------- sds.nimble | 2 +- sds/protobuf.nim | 321 +++++++++++++++++++---------------- sds/protobufutil.nim | 168 +++++++++++++++++- sds/snapshot_codec.nim | 71 ++++---- sds/types/protobuf_error.nim | 17 +- 9 files changed, 431 insertions(+), 393 deletions(-) diff --git a/.github/workflows/ci-nix.yml b/.github/workflows/ci-nix.yml index f324b4a..60a3fdf 100644 --- a/.github/workflows/ci-nix.yml +++ b/.github/workflows/ci-nix.yml @@ -5,7 +5,7 @@ permissions: checks: write on: pull_request: - branches: [master] + branches: [master, release/*] jobs: build: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 02c2e52..ba331d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,9 +5,9 @@ permissions: checks: write on: pull_request: - branches: [master] + branches: [master, release/*] push: - branches: [master] + branches: [master, release/*] workflow_dispatch: jobs: diff --git a/nimble.lock b/nimble.lock index 2ccd8ab..3520d20 100644 --- a/nimble.lock +++ b/nimble.lock @@ -23,18 +23,6 @@ "sha1": "7e068f119664cf47ad0cfb74ef4c56fb6b616523" } }, - "bearssl_pkey_decoder": { - "version": "0.1.0", - "vcsRevision": "21dd3710df9345ed2ad8bf8f882761e07863b8e0", - "url": "https://github.com/vacp2p/bearssl_pkey_decoder", - "downloadMethod": "git", - "dependencies": [ - "bearssl" - ], - "checksums": { - "sha1": "21b42e2e6ddca6c875d3fc50f36a5115abf51714" - } - }, "results": { "version": "0.5.1", "vcsRevision": "df8113dda4c2d74d460a8fa98252b0b771bf1f27", @@ -85,6 +73,32 @@ "sha1": "fa35c1bb76a0a02a2379fe86eaae0957c7527cb8" } }, + "npeg": { + "version": "1.3.0", + "vcsRevision": "409f6796d0e880b3f0222c964d1da7de6e450811", + "url": "https://github.com/zevv/npeg", + "downloadMethod": "git", + "dependencies": [], + "checksums": { + "sha1": "64f15c85a059c889cb11c5fe72372677c50da621" + } + }, + "protobuf_serialization": { + "version": "0.5.1", + "vcsRevision": "d9aa950b9d9e8bfc8a201740042b5e8ea5880875", + "url": "https://github.com/status-im/nim-protobuf-serialization", + "downloadMethod": "git", + "dependencies": [ + "stew", + "faststreams", + "serialization", + "npeg", + "unittest2" + ], + "checksums": { + "sha1": "c02d1124931612041a25c6c8ed59a6120849c85d" + } + }, "json_serialization": { "version": "0.4.4", "vcsRevision": "c343b0e243d9e17e2c40f3a8a24340f7c4a71d44", @@ -157,39 +171,6 @@ "sha1": "455802a90204d8ad6b31d53f2efff8ebfe4c834a" } }, - "dnsclient": { - "version": "0.3.4", - "vcsRevision": "23214235d4784d24aceed99bbfe153379ea557c8", - "url": "https://github.com/ba0f3/dnsclient.nim", - "downloadMethod": "git", - "dependencies": [], - "checksums": { - "sha1": "65262c7e533ff49d6aca5539da4bc6c6ce132f40" - } - }, - "jwt": { - "version": "0.2", - "vcsRevision": "18f8378de52b241f321c1f9ea905456e89b95c6f", - "url": "https://github.com/vacp2p/nim-jwt.git", - "downloadMethod": "git", - "dependencies": [ - "bearssl", - "bearssl_pkey_decoder" - ], - "checksums": { - "sha1": "bcfd6fc9c5e10a52b87117219b7ab5c98136bc8e" - } - }, - "nimcrypto": { - "version": "0.7.3", - "vcsRevision": "b3dbc9c4d08e58c5b7bfad6dc7ef2ee52f2f4c08", - "url": "https://github.com/cheatfate/nimcrypto", - "downloadMethod": "git", - "dependencies": [], - "checksums": { - "sha1": "f72b90fe3f4da09efa482de4f8729e7ee4abea2f" - } - }, "metrics": { "version": "0.1.2", "vcsRevision": "11d0cddfb0e711aa2a8c75d1892ae24a64c299fc", @@ -204,93 +185,6 @@ "sha1": "5cdac99d85d3c146d170e85064c88fb28f377842" } }, - "secp256k1": { - "version": "0.6.0.3.2", - "vcsRevision": "d8f1288b7c72f00be5fc2c5ea72bf5cae1eafb15", - "url": "https://github.com/status-im/nim-secp256k1", - "downloadMethod": "git", - "dependencies": [ - "stew", - "results", - "nimcrypto" - ], - "checksums": { - "sha1": "6618ef9de17121846a8c1d0317026b0ce8584e10" - } - }, - "zlib": { - "version": "0.1.0", - "vcsRevision": "e680f269fb01af2c34a2ba879ff281795a5258fe", - "url": "https://github.com/status-im/nim-zlib", - "downloadMethod": "git", - "dependencies": [ - "stew", - "results" - ], - "checksums": { - "sha1": "bbde4f5a97a84b450fef7d107461e5f35cf2b47f" - } - }, - "websock": { - "version": "0.2.1", - "vcsRevision": "35ae76f1559e835c80f9c1a3943bf995d3dd9eb5", - "url": "https://github.com/status-im/nim-websock", - "downloadMethod": "git", - "dependencies": [ - "chronos", - "httputils", - "chronicles", - "stew", - "nimcrypto", - "bearssl", - "results", - "zlib" - ], - "checksums": { - "sha1": "1cb5efa10cd389bc01d0707c242ae010c76a03cd" - } - }, - "lsquic": { - "version": "0.0.1", - "vcsRevision": "4fb03ee7bfb39aecb3316889fdcb60bec3d0936f", - "url": "https://github.com/vacp2p/nim-lsquic", - "downloadMethod": "git", - "dependencies": [ - "zlib", - "stew", - "chronos", - "nimcrypto", - "unittest2", - "chronicles" - ], - "checksums": { - "sha1": "f465fa994346490d0924d162f53d9b5aec62f948" - } - }, - "libp2p": { - "version": "1.15.2", - "vcsRevision": "ca48c3718246bb411ff0e354a70cb82d9a28de0d", - "url": "https://github.com/vacp2p/nim-libp2p", - "downloadMethod": "git", - "dependencies": [ - "nimcrypto", - "dnsclient", - "bearssl", - "chronicles", - "chronos", - "metrics", - "secp256k1", - "stew", - "websock", - "unittest2", - "results", - "lsquic", - "jwt" - ], - "checksums": { - "sha1": "3b2cdc7e00261eb4210ca3d44ec3bd64c2b7bbba" - } - }, "stint": { "version": "0.8.2", "vcsRevision": "470b7892561b5179ab20bd389a69217d6213fe58", diff --git a/nix/deps.nix b/nix/deps.nix index a73c829..f728089 100644 --- a/nix/deps.nix +++ b/nix/deps.nix @@ -17,13 +17,6 @@ fetchSubmodules = true; }; - bearssl_pkey_decoder = pkgs.fetchgit { - url = "https://github.com/vacp2p/bearssl_pkey_decoder"; - rev = "21dd3710df9345ed2ad8bf8f882761e07863b8e0"; - sha256 = "0bl3f147zmkazbhdkr4cj1nipf9rqiw3g4hh1j424k9hpl55zdpg"; - fetchSubmodules = true; - }; - results = pkgs.fetchgit { url = "https://github.com/arnetheduck/nim-results"; rev = "df8113dda4c2d74d460a8fa98252b0b771bf1f27"; @@ -52,6 +45,20 @@ fetchSubmodules = true; }; + npeg = pkgs.fetchgit { + url = "https://github.com/zevv/npeg"; + rev = "409f6796d0e880b3f0222c964d1da7de6e450811"; + sha256 = "1h2f5znbpa3svk7wsw2axn8f7f59d23xq85z148kiv6fqh0ffwbm"; + fetchSubmodules = true; + }; + + protobuf_serialization = pkgs.fetchgit { + url = "https://github.com/status-im/nim-protobuf-serialization"; + rev = "d9aa950b9d9e8bfc8a201740042b5e8ea5880875"; + sha256 = "11hrqpq7dpdqfn71izmq7ysrdnh8gry0qvrgqdspcz2k2lifzz0c"; + fetchSubmodules = true; + }; + json_serialization = pkgs.fetchgit { url = "https://github.com/status-im/nim-json-serialization"; rev = "c343b0e243d9e17e2c40f3a8a24340f7c4a71d44"; @@ -87,27 +94,6 @@ fetchSubmodules = true; }; - dnsclient = pkgs.fetchgit { - url = "https://github.com/ba0f3/dnsclient.nim"; - rev = "23214235d4784d24aceed99bbfe153379ea557c8"; - sha256 = "03mf3lw5c0m5nq9ppa49nylrl8ibkv2zzlc0wyhqg7w09kz6hks6"; - fetchSubmodules = true; - }; - - jwt = pkgs.fetchgit { - url = "https://github.com/vacp2p/nim-jwt.git"; - rev = "18f8378de52b241f321c1f9ea905456e89b95c6f"; - sha256 = "1986czmszdxj6g9yr7xn1fx8y2y9mwpb3f1bn9nc6973qawsdm0p"; - fetchSubmodules = true; - }; - - nimcrypto = pkgs.fetchgit { - url = "https://github.com/cheatfate/nimcrypto"; - rev = "b3dbc9c4d08e58c5b7bfad6dc7ef2ee52f2f4c08"; - sha256 = "1v4rz42lwcazs6isi3kmjylkisr84mh0kgmlqycx4i885dn3g0l4"; - fetchSubmodules = true; - }; - metrics = pkgs.fetchgit { url = "https://github.com/status-im/nim-metrics"; rev = "11d0cddfb0e711aa2a8c75d1892ae24a64c299fc"; @@ -115,41 +101,6 @@ fetchSubmodules = true; }; - secp256k1 = pkgs.fetchgit { - url = "https://github.com/status-im/nim-secp256k1"; - rev = "d8f1288b7c72f00be5fc2c5ea72bf5cae1eafb15"; - sha256 = "1qjrmwbngb73f6r1fznvig53nyal7wj41d1cmqfksrmivk2sgrn2"; - fetchSubmodules = true; - }; - - zlib = pkgs.fetchgit { - url = "https://github.com/status-im/nim-zlib"; - rev = "e680f269fb01af2c34a2ba879ff281795a5258fe"; - sha256 = "1xw9f1gjsgqihdg7kdkbaq1wankgnx2vn9l3ihc6nqk2jzv5bvk5"; - fetchSubmodules = true; - }; - - websock = pkgs.fetchgit { - url = "https://github.com/status-im/nim-websock"; - rev = "35ae76f1559e835c80f9c1a3943bf995d3dd9eb5"; - sha256 = "1j6dklzb6b6bv2aiglbiyflja2vdpmyxfirv98f49y62mykq0yrw"; - fetchSubmodules = true; - }; - - lsquic = pkgs.fetchgit { - url = "https://github.com/vacp2p/nim-lsquic"; - rev = "4fb03ee7bfb39aecb3316889fdcb60bec3d0936f"; - sha256 = "0qdhcd4hyp185szc9sv3jvwdwc9zp3j0syy7glxv13k9bchfmkfg"; - fetchSubmodules = true; - }; - - libp2p = pkgs.fetchgit { - url = "https://github.com/vacp2p/nim-libp2p"; - rev = "ca48c3718246bb411ff0e354a70cb82d9a28de0d"; - sha256 = "07qfjjrq6w7bj9dbchvcrpla47jidngbrgmigbhl7fh3cfkdabc9"; - fetchSubmodules = true; - }; - stint = pkgs.fetchgit { url = "https://github.com/status-im/nim-stint"; rev = "470b7892561b5179ab20bd389a69217d6213fe58"; @@ -166,8 +117,8 @@ ffi = pkgs.fetchgit { url = "https://github.com/logos-messaging/nim-ffi"; - rev = "fb25f069d2dfae2b543d79d2c1a81f197de22a2b"; - sha256 = "0zkjnrm2yjlw27q99kv2x8ll61mbz4nr0cvmyq0csydh43c08k0p"; + rev = "d4c87c1f94c4678eea7d32a8f5f41c72420fadb6"; + sha256 = "14dm92l3wl8sc5a108612r1cgjvxksy2chzmn1asph6frl4lm641"; fetchSubmodules = true; }; diff --git a/sds.nimble b/sds.nimble index d844717..78bddea 100644 --- a/sds.nimble +++ b/sds.nimble @@ -10,7 +10,7 @@ srcDir = "sds" # Dependencies requires "nim >= 2.2.4" requires "chronos >= 4.0.4" -requires "libp2p >= 1.15.2" +requires "protobuf_serialization >= 0.5.0" requires "chronicles" requires "stew" requires "stint" diff --git a/sds/protobuf.nim b/sds/protobuf.nim index 916bf18..48eeb0b 100644 --- a/sds/protobuf.nim +++ b/sds/protobuf.nim @@ -1,189 +1,222 @@ -import libp2p/protobuf/minprotobuf +## SDS network wire codec. +## +## Messages are described as annotated protobuf types and (de)serialised with +## `nim-protobuf-serialization`'s type-driven `Protobuf.encode/decode`. The +## domain types (`SdsMessage`, `HistoryEntry`) keep their distinct/`requiresInit` +## shape; small `*PB` mirrors carry the field annotations and a trivial +## conversion bridges the two. The mirror string-ish fields are `seq[byte]` +## (not `pstring`) so message/channel/sender ids stay opaque bytes — no UTF-8 +## validation — and the distinct `SdsParticipantID` needs no special support. +## +## Singular fields use the proto3 `optional` label (`Opt[T]`), which is the +## recommended form for forward-compatibility; presence is exposed but the +## actual validity of mandatory identifiers is checked at the application layer +## after decoding (proto3 has no `required`). + +{.push raises: [].} + import endians +import protobuf_serialization +import protobuf_serialization/pkg/results import ./types/[sds_message_id, history_entry, sds_message, reliability_error] -import ./protobufutil import ./bloom -import ./sds_utils -proc encodeHistoryEntry*(entry: HistoryEntry): ProtoBuffer = - var entryPb = initProtoBuffer() - entryPb.write(1, entry.messageId) - if entry.retrievalHint.len > 0: - entryPb.write(2, entry.retrievalHint) - if entry.senderId.len > 0: - entryPb.write(3, entry.senderId.string) - entryPb.finish() - entryPb +# --------------------------------------------------------------------------- +# Wire types +# --------------------------------------------------------------------------- -proc decodeHistoryEntry*(entryPb: ProtoBuffer): ProtobufResult[HistoryEntry] = - var entry = HistoryEntry.init("") - if not ?entryPb.getField(1, entry.messageId): - return err(ProtobufError.missingRequiredField("HistoryEntry.messageId")) - discard entryPb.getField(2, entry.retrievalHint) - var senderIdStr: string - if entryPb.getField(3, senderIdStr).valueOr(false): - entry.senderId = senderIdStr.SdsParticipantID - ok(entry) +type + HistoryEntryPB* {.proto3.} = object + messageId* {.fieldNumber: 1.}: Opt[seq[byte]] + retrievalHint* {.fieldNumber: 2.}: Opt[seq[byte]] + senderId* {.fieldNumber: 3.}: Opt[seq[byte]] -proc encode*(msg: SdsMessage): ProtoBuffer = - var pb = initProtoBuffer() + SdsMessagePB* {.proto3.} = object + messageId* {.fieldNumber: 1.}: Opt[seq[byte]] + lamportTimestamp* {.fieldNumber: 2, pint.}: Opt[int64] + causalHistory* {.fieldNumber: 3.}: seq[HistoryEntryPB] + channelId* {.fieldNumber: 4.}: Opt[seq[byte]] + content* {.fieldNumber: 5.}: Opt[seq[byte]] + bloomFilter* {.fieldNumber: 6.}: Opt[seq[byte]] + senderId* {.fieldNumber: 7.}: Opt[seq[byte]] + repairRequest* {.fieldNumber: 13.}: seq[HistoryEntryPB] - pb.write(1, msg.messageId) - pb.write(2, uint64(msg.lamportTimestamp)) + BloomFilterPB {.proto3.} = object + data {.fieldNumber: 1.}: Opt[seq[byte]] + capacity {.fieldNumber: 2, pint.}: Opt[uint64] + errorRate {.fieldNumber: 3, pint.}: Opt[uint64] + kHashes {.fieldNumber: 4, pint.}: Opt[uint64] + mBits {.fieldNumber: 5, pint.}: Opt[uint64] - for entry in msg.causalHistory: - let entryPb = encodeHistoryEntry(entry) - pb.write(3, entryPb.buffer) +# --------------------------------------------------------------------------- +# string <-> bytes (opaque, no UTF-8 validation) and optional-bytes helper +# --------------------------------------------------------------------------- - pb.write(4, msg.channelId) - pb.write(5, msg.content) - pb.write(6, msg.bloomFilter) +func toBytes(s: string): seq[byte] = + var b = newSeq[byte](s.len) + if s.len > 0: + copyMem(addr b[0], unsafeAddr s[0], s.len) + return b - if msg.senderId.len > 0: - pb.write(7, msg.senderId.string) +func toStr(b: seq[byte]): string = + var s = newString(b.len) + if b.len > 0: + copyMem(addr s[0], unsafeAddr b[0], b.len) + return s - for entry in msg.repairRequest: - let entryPb = encodeHistoryEntry(entry) - pb.write(13, entryPb.buffer) +func optBytes(b: seq[byte]): Opt[seq[byte]] = + ## Present only when non-empty, so empty optionals stay off the wire. + if b.len > 0: + return Opt.some(b) + return Opt.none(seq[byte]) - pb.finish() +# --------------------------------------------------------------------------- +# Domain <-> wire conversion +# --------------------------------------------------------------------------- +func toPB*(e: HistoryEntry): HistoryEntryPB = + return HistoryEntryPB( + messageId: optBytes(e.messageId.toBytes), + retrievalHint: optBytes(e.retrievalHint), + senderId: optBytes(e.senderId.string.toBytes), + ) + +func fromPB*(e: HistoryEntryPB): HistoryEntry = + return HistoryEntry( + messageId: e.messageId.valueOr(@[]).toStr, + retrievalHint: e.retrievalHint.valueOr(@[]), + senderId: e.senderId.valueOr(@[]).toStr.SdsParticipantID, + ) + +func toPB*(m: SdsMessage): SdsMessagePB = + var pb = SdsMessagePB( + messageId: optBytes(m.messageId.toBytes), + lamportTimestamp: Opt.some(m.lamportTimestamp), + channelId: optBytes(m.channelId.toBytes), + content: optBytes(m.content), + bloomFilter: optBytes(m.bloomFilter), + senderId: optBytes(m.senderId.string.toBytes), + ) + for e in m.causalHistory: + pb.causalHistory.add(e.toPB) + for e in m.repairRequest: + pb.repairRequest.add(e.toPB) return pb -proc decode*(T: type SdsMessage, buffer: seq[byte]): ProtobufResult[T] = - let pb = initProtoBuffer(buffer) - var msg = SdsMessage.init("", 0, @[], "", @[], @[]) +func fromPB*(pb: SdsMessagePB): SdsMessage = + var causal: seq[HistoryEntry] + for e in pb.causalHistory: + causal.add(e.fromPB) + var repair: seq[HistoryEntry] + for e in pb.repairRequest: + repair.add(e.fromPB) + return SdsMessage.init( + messageId = pb.messageId.valueOr(@[]).toStr, + lamportTimestamp = pb.lamportTimestamp.valueOr(0'i64), + causalHistory = causal, + channelId = pb.channelId.valueOr(@[]).toStr, + content = pb.content.valueOr(@[]), + bloomFilter = pb.bloomFilter.valueOr(@[]), + senderId = pb.senderId.valueOr(@[]).toStr.SdsParticipantID, + repairRequest = repair, + ) - if not ?pb.getField(1, msg.messageId): - return err(ProtobufError.missingRequiredField("messageId")) - - var timestamp: uint64 - if not ?pb.getField(2, timestamp): - return err(ProtobufError.missingRequiredField("lamportTimestamp")) - msg.lamportTimestamp = int64(timestamp) - - # Handle both old and new causal history formats - var historyBuffers: seq[seq[byte]] - if pb.getRepeatedField(3, historyBuffers).isOk(): - # New format: repeated HistoryEntry - for histBuffer in historyBuffers: - let entryPb = initProtoBuffer(histBuffer) - let entry = ?decodeHistoryEntry(entryPb) - msg.causalHistory.add(entry) - else: - # Try old format: repeated string - var causalHistory: seq[SdsMessageID] - let histResult = pb.getRepeatedField(3, causalHistory) - if histResult.isOk(): - msg.causalHistory = toCausalHistory(causalHistory) - - if not ?pb.getField(4, msg.channelId): - return err(ProtobufError.missingRequiredField("channelId")) - - if not ?pb.getField(5, msg.content): - return err(ProtobufError.missingRequiredField("content")) - - if not ?pb.getField(6, msg.bloomFilter): - msg.bloomFilter = @[] # Empty if not present - - # SDS-R: decode senderId (field 7, optional) - var msgSenderIdStr: string - if pb.getField(7, msgSenderIdStr).valueOr(false): - msg.senderId = msgSenderIdStr.SdsParticipantID - - # SDS-R: decode repair request (field 13, optional) - var repairBuffers: seq[seq[byte]] - if pb.getRepeatedField(13, repairBuffers).isOk(): - for repairBuffer in repairBuffers: - let entryPb = initProtoBuffer(repairBuffer) - let entry = ?decodeHistoryEntry(entryPb) - msg.repairRequest.add(entry) - - return ok(msg) - -proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = - ## For extraction of channel ID without full message deserialization - try: - let pb = initProtoBuffer(data) - var channelId: SdsChannelID - let fieldOk = pb.getField(4, channelId).valueOr: - return err(ReliabilityError.reDeserializationError) - if not fieldOk: - return err(ReliabilityError.reDeserializationError) - return ok(channelId) - except: - return err(ReliabilityError.reDeserializationError) +# --------------------------------------------------------------------------- +# Message (de)serialisation +# --------------------------------------------------------------------------- proc serializeMessage*(msg: SdsMessage): Result[seq[byte], ReliabilityError] = - let pb = encode(msg) - return ok(pb.buffer) + try: + return ok(Protobuf.encode(msg.toPB)) + except CatchableError: + return err(ReliabilityError.reSerializationError) proc deserializeMessage*(data: seq[byte]): Result[SdsMessage, ReliabilityError] = - let msg = SdsMessage.decode(data).valueOr: + ## proto3 has no required fields, so the mandatory identifiers are validated + ## by hand after decoding. `content`/`bloomFilter`/`lamportTimestamp` may + ## legitimately be empty/zero (e.g. periodic sync messages). + try: + let msg = Protobuf.decode(data, SdsMessagePB).fromPB + if msg.messageId.len == 0 or msg.channelId.len == 0: + return err(ReliabilityError.reDeserializationError) + for e in msg.causalHistory: + if e.messageId.len == 0: + return err(ReliabilityError.reDeserializationError) + for e in msg.repairRequest: + if e.messageId.len == 0: + return err(ReliabilityError.reDeserializationError) + return ok(msg) + except CatchableError: return err(ReliabilityError.reDeserializationError) - return ok(msg) + +proc extractChannelId*(data: seq[byte]): Result[SdsChannelID, ReliabilityError] = + ## Channel ID without retaining the rest of the decoded message. + try: + return ok(Protobuf.decode(data, SdsMessagePB).channelId.valueOr(@[]).toStr) + except CatchableError: + return err(ReliabilityError.reDeserializationError) + +# Single `HistoryEntry` (de)serialisation, used by the snapshot codec for the +# repair-buffer entries it embeds. Kept here so all `Protobuf.decode` calls live +# in this module. + +proc serializeHistoryEntry*(e: HistoryEntry): Result[seq[byte], ReliabilityError] = + try: + return ok(Protobuf.encode(e.toPB)) + except CatchableError: + return err(ReliabilityError.reSerializationError) + +proc deserializeHistoryEntry*(data: seq[byte]): Result[HistoryEntry, ReliabilityError] = + try: + return ok(Protobuf.decode(data, HistoryEntryPB).fromPB) + except CatchableError: + return err(ReliabilityError.reDeserializationError) + +# --------------------------------------------------------------------------- +# Bloom filter (de)serialisation +# --------------------------------------------------------------------------- proc serializeBloomFilter*(filter: BloomFilter): Result[seq[byte], ReliabilityError] = - var pb = initProtoBuffer() - try: var bytes = newSeq[byte](filter.intArray.len * sizeof(int)) for i, val in filter.intArray: var leVal: int littleEndian64(addr leVal, unsafeAddr val) - let start = i * sizeof(int) - copyMem(addr bytes[start], addr leVal, sizeof(int)) + copyMem(addr bytes[i * sizeof(int)], addr leVal, sizeof(int)) - pb.write(1, bytes) - pb.write(2, uint64(filter.capacity)) - pb.write(3, uint64(filter.errorRate * 1_000_000)) - pb.write(4, uint64(filter.kHashes)) - pb.write(5, uint64(filter.mBits)) - except: + let pb = BloomFilterPB( + data: optBytes(bytes), + capacity: Opt.some(uint64(filter.capacity)), + errorRate: Opt.some(uint64(filter.errorRate * 1_000_000)), + kHashes: Opt.some(uint64(filter.kHashes)), + mBits: Opt.some(uint64(filter.mBits)), + ) + return ok(Protobuf.encode(pb)) + except CatchableError: return err(ReliabilityError.reSerializationError) - pb.finish() - return ok(pb.buffer) - proc deserializeBloomFilter*(data: seq[byte]): Result[BloomFilter, ReliabilityError] = if data.len == 0: return err(ReliabilityError.reDeserializationError) - - let pb = initProtoBuffer(data) - var bytes: seq[byte] - var cap, errRate, kHashes, mBits: uint64 - try: - let - field1_Ok = pb.getField(1, bytes).valueOr: - return err(ReliabilityError.reDeserializationError) - field2_Ok = pb.getField(2, cap).valueOr: - return err(ReliabilityError.reDeserializationError) - field3_Ok = pb.getField(3, errRate).valueOr: - return err(ReliabilityError.reDeserializationError) - field4_Ok = pb.getField(4, kHashes).valueOr: - return err(ReliabilityError.reDeserializationError) - field5_Ok = pb.getField(5, mBits).valueOr: - return err(ReliabilityError.reDeserializationError) - - if not field1_Ok or not field2_Ok or not field3_Ok or not field4_Ok or not field5_Ok: - return err(ReliabilityError.reDeserializationError) - - var intArray = newSeq[int](bytes.len div sizeof(int)) + let pb = Protobuf.decode(data, BloomFilterPB) + let rawData = pb.data.valueOr(@[]) + var intArray = newSeq[int](rawData.len div sizeof(int)) for i in 0 ..< intArray.len: var leVal: int - let start = i * sizeof(int) - copyMem(addr leVal, unsafeAddr bytes[start], sizeof(int)) + copyMem(addr leVal, unsafeAddr rawData[i * sizeof(int)], sizeof(int)) littleEndian64(addr intArray[i], addr leVal) return ok( BloomFilter.init( - capacity = int(cap), - errorRate = float(errRate) / 1_000_000, - kHashes = int(kHashes), - mBits = int(mBits), + capacity = int(pb.capacity.valueOr(0'u64)), + errorRate = float(pb.errorRate.valueOr(0'u64)) / 1_000_000, + kHashes = int(pb.kHashes.valueOr(0'u64)), + mBits = int(pb.mBits.valueOr(0'u64)), intArray = intArray, ) ) - except: + except CatchableError: return err(ReliabilityError.reDeserializationError) + +{.pop.} diff --git a/sds/protobufutil.nim b/sds/protobufutil.nim index 3153017..4e272b2 100644 --- a/sds/protobufutil.nim +++ b/sds/protobufutil.nim @@ -1,19 +1,175 @@ -# adapted from https://github.com/waku-org/nwaku/blob/master/waku/common/protobuf.nim +# Minimal hand-rolled protobuf field codec, a thin shim over +# `nim-protobuf-serialization`'s low-level wire `codec` module. +# +# `sds/protobuf.nim` and `sds/snapshot_codec.nim` build messages by hand at the +# field level — including a backward-compatible decode path the type-driven +# `Protobuf.encode/decode` API cannot express, and required-field / always-write +# semantics its default-value omission would break — so this exposes a small +# accumulating `ProtoBuffer` with `write`/`getField`/`getRepeatedField`/`finish`: +# * unsigned integers encode as plain varints (last value wins on decode); +# * strings and byte seqs encode length-delimited, with no UTF-8 validation +# (strings are treated as opaque bytes — message ids may be binary); +# * a field whose stored wire type differs from the requested one is skipped, +# as `protoc` does; only a malformed buffer yields an error. +# +# On construction from bytes the buffer is parsed once, in a single forward pass +# with the library's reader, into per-field value lists; the `getField` accessors +# are then plain lookups rather than re-scanning the buffer for every field. {.push raises: [].} -import libp2p/protobuf/minprotobuf -import libp2p/varint +import std/tables +import results +import faststreams/inputs +import protobuf_serialization/codec except ProtobufError import ./types/protobuf_error -export minprotobuf, varint, protobuf_error +export results, protobuf_error -converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError = +type ProtoBuffer* = object ## Accumulating protobuf field buffer. + buffer*: seq[byte] + ## Reads are served from these parse-once indexes (populated by `init(data)`), + ## keyed by field number; values are kept in wire order so last-wins / repeated + ## semantics fall out directly. + varints: Table[int, seq[uint64]] + lengthDelims: Table[int, seq[seq[byte]]] + parseOk: bool + +converter toProtobufError*(err: ProtoError): ProtobufError = case err - of minprotobuf.ProtoError.RequiredFieldMissing: + of ProtoError.RequiredFieldMissing: return ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown") else: return ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err) proc missingRequiredField*(T: type ProtobufError, field: string): T = return ProtobufError.init(field) + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + +proc init*(T: type ProtoBuffer): T = + return T(buffer: @[], parseOk: true) + +proc init*(T: type ProtoBuffer, data: seq[byte]): T = + ## Parse `data` once into per-field value lists. A malformed buffer leaves + ## `parseOk = false`, which every accessor reports as a decode error. + var pb = T(buffer: data, parseOk: true) + var sh = memoryInput(data) + try: + let stream = sh.s + while stream.readable: + let hdr = readHeader(stream) + case hdr.kind + of WireKind.Varint: + pb.varints.mgetOrPut(hdr.number, @[]).add(uint64(readValue(stream, puint64))) + of WireKind.LengthDelim: + pb.lengthDelims.mgetOrPut(hdr.number, @[]).add(seq[byte](readValue(stream, pbytes))) + of WireKind.Fixed64: + skipValue(stream, fixed64) + of WireKind.Fixed32: + skipValue(stream, fixed32) + except CatchableError: + pb.parseOk = false + return pb + +proc finish*(pb: var ProtoBuffer) = + ## No length prefix is used, so finishing only asserts the invariant that a + ## top-level buffer is never empty. + doAssert(pb.buffer.len > 0) + +# --------------------------------------------------------------------------- +# Writing +# --------------------------------------------------------------------------- + +proc writeVarint(pb: var ProtoBuffer, field: int, value: uint64) = + pb.buffer.add(toBytes(FieldHeader.init(field, WireKind.Varint))) + pb.buffer.add(toBytes(puint64(value))) + +proc write*(pb: var ProtoBuffer, field: int, value: uint64) = + pb.writeVarint(field, value) + +proc write*(pb: var ProtoBuffer, field: int, value: uint32) = + pb.writeVarint(field, uint64(value)) + +proc writeLengthDelim(pb: var ProtoBuffer, field: int, data: openArray[byte]) = + pb.buffer.add(toBytes(FieldHeader.init(field, WireKind.LengthDelim))) + pb.buffer.add(toBytes(puint64(uint64(data.len)))) + if data.len > 0: + pb.buffer.add(data) + +proc write*(pb: var ProtoBuffer, field: int, value: openArray[byte]) = + pb.writeLengthDelim(field, value) + +proc write*(pb: var ProtoBuffer, field: int, value: string) = + pb.writeLengthDelim(field, value.toOpenArrayByte(0, value.high)) + +# --------------------------------------------------------------------------- +# Reading +# --------------------------------------------------------------------------- + +proc bytesToString(b: seq[byte]): string = + ## Copy raw bytes into a string without UTF-8 validation — protobuf strings + ## are opaque bytes here, and message ids may not be valid UTF-8. + var s = newString(b.len) + if b.len > 0: + copyMem(addr s[0], unsafeAddr b[0], b.len) + return s + +proc getField*(pb: ProtoBuffer, field: int, output: var uint64): ProtoResult[bool] = + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.varints.getOrDefault(field) + if values.len > 0: + output = values[^1] + return ok(true) + return ok(false) + +proc getField*(pb: ProtoBuffer, field: int, output: var uint32): ProtoResult[bool] = + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.varints.getOrDefault(field) + if values.len > 0: + output = uint32(values[^1]) + return ok(true) + return ok(false) + +proc getField*(pb: ProtoBuffer, field: int, output: var seq[byte]): ProtoResult[bool] = + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.lengthDelims.getOrDefault(field) + if values.len > 0: + output = values[^1] + return ok(true) + return ok(false) + +proc getField*(pb: ProtoBuffer, field: int, output: var string): ProtoResult[bool] = + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.lengthDelims.getOrDefault(field) + if values.len > 0: + output = bytesToString(values[^1]) + return ok(true) + return ok(false) + +proc getRepeatedField*( + pb: ProtoBuffer, field: int, output: var seq[seq[byte]] +): ProtoResult[bool] = + if not pb.parseOk: + return err(ProtoError.VarintDecode) + output = pb.lengthDelims.getOrDefault(field) + return ok(output.len > 0) + +proc getRepeatedField*( + pb: ProtoBuffer, field: int, output: var seq[string] +): ProtoResult[bool] = + if not pb.parseOk: + return err(ProtoError.VarintDecode) + let values = pb.lengthDelims.getOrDefault(field) + output.setLen(0) + for v in values: + output.add(bytesToString(v)) + return ok(output.len > 0) + +{.pop.} diff --git a/sds/snapshot_codec.nim b/sds/snapshot_codec.nim index 7b626c6..45648c0 100644 --- a/sds/snapshot_codec.nim +++ b/sds/snapshot_codec.nim @@ -13,7 +13,6 @@ {.push raises: [].} import std/[sets, times] -import libp2p/protobuf/minprotobuf import ./types/[ channel_meta, history_update, sds_message, sds_message_id, history_entry, unacknowledged_message, incoming_message, repair_entry, reliability_error, @@ -44,20 +43,19 @@ proc fromUnixMs(ms: int64): Time = # --------------------------------------------------------------------------- proc encodeUnacked(u: UnacknowledgedMessage): ProtoBuffer = - var pb = initProtoBuffer() - let msgPb = wire.encode(u.message) - pb.write(1, msgPb.buffer) + var pb = ProtoBuffer.init() + pb.write(1, wire.serializeMessage(u.message).get()) pb.write(2, uint64(u.sendTime.toUnixMs)) pb.write(3, uint32(u.resendAttempts)) pb.finish() pb proc decodeUnacked(buf: seq[byte]): ProtobufResult[UnacknowledgedMessage] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var msgBytes: seq[byte] if not ?pb.getField(1, msgBytes): return err(ProtobufError.missingRequiredField("UnacknowledgedMessage.message")) - let msg = SdsMessage.decode(msgBytes).valueOr: + let msg = wire.deserializeMessage(msgBytes).valueOr: return err(ProtobufError.missingRequiredField("UnacknowledgedMessage.message")) var sendMs: uint64 if not ?pb.getField(2, sendMs): @@ -77,20 +75,19 @@ proc decodeUnacked(buf: seq[byte]): ProtobufResult[UnacknowledgedMessage] = # --------------------------------------------------------------------------- proc encodeIncoming(m: IncomingMessage): ProtoBuffer = - var pb = initProtoBuffer() - let msgPb = wire.encode(m.message) - pb.write(1, msgPb.buffer) + var pb = ProtoBuffer.init() + pb.write(1, wire.serializeMessage(m.message).get()) for dep in m.missingDeps: pb.write(2, dep) # SdsMessageID is string pb.finish() pb proc decodeIncoming(buf: seq[byte]): ProtobufResult[IncomingMessage] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var msgBytes: seq[byte] if not ?pb.getField(1, msgBytes): return err(ProtobufError.missingRequiredField("IncomingMessage.message")) - let msg = SdsMessage.decode(msgBytes).valueOr: + let msg = wire.deserializeMessage(msgBytes).valueOr: return err(ProtobufError.missingRequiredField("IncomingMessage.message")) var deps: seq[SdsMessageID] discard pb.getRepeatedField(2, deps) @@ -104,20 +101,19 @@ proc decodeIncoming(buf: seq[byte]): ProtobufResult[IncomingMessage] = # --------------------------------------------------------------------------- proc encodeOutRepairEntry(e: OutgoingRepairEntry): ProtoBuffer = - var pb = initProtoBuffer() - let histPb = wire.encodeHistoryEntry(e.outHistEntry) - pb.write(1, histPb.buffer) + var pb = ProtoBuffer.init() + pb.write(1, wire.serializeHistoryEntry(e.outHistEntry).get()) pb.write(2, uint64(e.minTimeRepairReq.toUnixMs)) pb.finish() pb proc decodeOutRepairEntry(buf: seq[byte]): ProtobufResult[OutgoingRepairEntry] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var histBytes: seq[byte] if not ?pb.getField(1, histBytes): return err(ProtobufError.missingRequiredField("OutgoingRepairEntry.outHistEntry")) - let histPb = initProtoBuffer(histBytes) - let entry = ?wire.decodeHistoryEntry(histPb) + let entry = wire.deserializeHistoryEntry(histBytes).valueOr: + return err(ProtobufError.missingRequiredField("HistoryEntry")) var ms: uint64 if not ?pb.getField(2, ms): return err(ProtobufError.missingRequiredField("OutgoingRepairEntry.minTimeRepairReq")) @@ -128,7 +124,7 @@ proc decodeOutRepairEntry(buf: seq[byte]): ProtobufResult[OutgoingRepairEntry] = ) proc encodeOutRepairKV(kv: OutgoingRepairKV): ProtoBuffer = - var pb = initProtoBuffer() + var pb = ProtoBuffer.init() pb.write(1, kv.messageId) let entryPb = encodeOutRepairEntry(kv.entry) pb.write(2, entryPb.buffer) @@ -136,7 +132,7 @@ proc encodeOutRepairKV(kv: OutgoingRepairKV): ProtoBuffer = pb proc decodeOutRepairKV(buf: seq[byte]): ProtobufResult[OutgoingRepairKV] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var msgId: SdsMessageID if not ?pb.getField(1, msgId): return err(ProtobufError.missingRequiredField("OutgoingRepairKV.messageId")) @@ -151,21 +147,20 @@ proc decodeOutRepairKV(buf: seq[byte]): ProtobufResult[OutgoingRepairKV] = # --------------------------------------------------------------------------- proc encodeInRepairEntry(e: IncomingRepairEntry): ProtoBuffer = - var pb = initProtoBuffer() - let histPb = wire.encodeHistoryEntry(e.inHistEntry) - pb.write(1, histPb.buffer) + var pb = ProtoBuffer.init() + pb.write(1, wire.serializeHistoryEntry(e.inHistEntry).get()) pb.write(2, e.cachedMessage) pb.write(3, uint64(e.minTimeRepairResp.toUnixMs)) pb.finish() pb proc decodeInRepairEntry(buf: seq[byte]): ProtobufResult[IncomingRepairEntry] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var histBytes: seq[byte] if not ?pb.getField(1, histBytes): return err(ProtobufError.missingRequiredField("IncomingRepairEntry.inHistEntry")) - let histPb = initProtoBuffer(histBytes) - let entry = ?wire.decodeHistoryEntry(histPb) + let entry = wire.deserializeHistoryEntry(histBytes).valueOr: + return err(ProtobufError.missingRequiredField("HistoryEntry")) var cached: seq[byte] if not ?pb.getField(2, cached): return err(ProtobufError.missingRequiredField("IncomingRepairEntry.cachedMessage")) @@ -181,7 +176,7 @@ proc decodeInRepairEntry(buf: seq[byte]): ProtobufResult[IncomingRepairEntry] = ) proc encodeInRepairKV(kv: IncomingRepairKV): ProtoBuffer = - var pb = initProtoBuffer() + var pb = ProtoBuffer.init() pb.write(1, kv.messageId) let entryPb = encodeInRepairEntry(kv.entry) pb.write(2, entryPb.buffer) @@ -189,7 +184,7 @@ proc encodeInRepairKV(kv: IncomingRepairKV): ProtoBuffer = pb proc decodeInRepairKV(buf: seq[byte]): ProtobufResult[IncomingRepairKV] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var msgId: SdsMessageID if not ?pb.getField(1, msgId): return err(ProtobufError.missingRequiredField("IncomingRepairKV.messageId")) @@ -204,7 +199,7 @@ proc decodeInRepairKV(buf: seq[byte]): ProtobufResult[IncomingRepairKV] = # --------------------------------------------------------------------------- proc encode*(meta: ChannelMeta): ProtoBuffer = - var pb = initProtoBuffer() + var pb = ProtoBuffer.init() pb.write(1, meta.schemaVersion) pb.write(2, uint64(meta.lamportTimestamp)) for u in meta.outgoingBuffer: @@ -223,7 +218,7 @@ proc encode*(meta: ChannelMeta): ProtoBuffer = pb proc decode*(T: type ChannelMeta, buf: seq[byte]): ProtobufResult[T] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var meta = ChannelMeta.init() var ver: uint32 @@ -271,17 +266,16 @@ proc deserializeChannelMeta*( # --------------------------------------------------------------------------- proc encode*(d: ChannelData): ProtoBuffer = - var pb = initProtoBuffer() + var pb = ProtoBuffer.init() let metaPb = encode(d.meta) pb.write(1, metaPb.buffer) for m in d.messageHistory: - let msgPb = wire.encode(m) - pb.write(2, msgPb.buffer) + pb.write(2, wire.serializeMessage(m).get()) pb.finish() pb proc decode*(T: type ChannelData, buf: seq[byte]): ProtobufResult[T] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var d = ChannelData.init() var metaBytes: seq[byte] if not ?pb.getField(1, metaBytes): @@ -290,7 +284,7 @@ proc decode*(T: type ChannelData, buf: seq[byte]): ProtobufResult[T] = var histBufs: seq[seq[byte]] discard pb.getRepeatedField(2, histBufs) for b in histBufs: - let m = SdsMessage.decode(b).valueOr: + let m = wire.deserializeMessage(b).valueOr: return err(ProtobufError.missingRequiredField("ChannelData.messageHistory[i]")) d.messageHistory.add(m) ok(d) @@ -300,22 +294,21 @@ proc decode*(T: type ChannelData, buf: seq[byte]): ProtobufResult[T] = # --------------------------------------------------------------------------- proc encode*(u: HistoryUpdate): ProtoBuffer = - var pb = initProtoBuffer() + var pb = ProtoBuffer.init() for m in u.append: - let msgPb = wire.encode(m) - pb.write(1, msgPb.buffer) + pb.write(1, wire.serializeMessage(m).get()) for id in u.evict: pb.write(2, id) pb.finish() pb proc decode*(T: type HistoryUpdate, buf: seq[byte]): ProtobufResult[T] = - let pb = initProtoBuffer(buf) + let pb = ProtoBuffer.init(buf) var u = HistoryUpdate.init() var appBufs: seq[seq[byte]] discard pb.getRepeatedField(1, appBufs) for b in appBufs: - let m = SdsMessage.decode(b).valueOr: + let m = wire.deserializeMessage(b).valueOr: return err(ProtobufError.missingRequiredField("HistoryUpdate.append[i]")) u.append.add(m) var ev: seq[SdsMessageID] diff --git a/sds/types/protobuf_error.nim b/sds/types/protobuf_error.nim index aff41df..cb338de 100644 --- a/sds/types/protobuf_error.nim +++ b/sds/types/protobuf_error.nim @@ -1,7 +1,18 @@ import results -import libp2p/protobuf/minprotobuf type + ProtoError* {.pure.} = enum + ## Low-level protobuf wire decode errors surfaced by the field codec in + ## `sds/protobufutil.nim`. + VarintDecode + MessageIncomplete + BufferOverflow + BadWireType + IncorrectBlob + RequiredFieldMissing + + ProtoResult*[T] = Result[T, ProtoError] + ProtobufErrorKind* {.pure.} = enum DecodeFailure MissingRequiredField @@ -9,13 +20,13 @@ type ProtobufError* = object case kind*: ProtobufErrorKind of DecodeFailure: - error*: minprotobuf.ProtoError + error*: ProtoError of MissingRequiredField: field*: string ProtobufResult*[T] = Result[T, ProtobufError] -proc init*(T: type ProtobufError, error: minprotobuf.ProtoError): T = +proc init*(T: type ProtobufError, error: ProtoError): T = return T(kind: ProtobufErrorKind.DecodeFailure, error: error) proc init*(T: type ProtobufError, field: string): T =