From f7584dfc4970fdf90cc701225aca5290e0524be7 Mon Sep 17 00:00:00 2001 From: Lorenzo Delgado Date: Mon, 20 Feb 2023 15:03:32 +0100 Subject: [PATCH] feat(protobuf): added error wrappers for invalid length validation --- tests/all_tests_common.nim | 1 + tests/common/test_protobuf_validation.nim | 103 ++++++++++++++++++ tests/v2/test_waku_noise_sessions.nim | 6 +- waku/common/protobuf.nim | 37 ++++++- waku/v2/protocol/waku_filter/rpc_codec.nim | 34 +++--- waku/v2/protocol/waku_lightpush/rpc_codec.nim | 20 ++-- waku/v2/protocol/waku_message/codec.nim | 15 ++- waku/v2/protocol/waku_message/message.nim | 3 +- waku/v2/protocol/waku_store/rpc_codec.nim | 27 +++-- 9 files changed, 196 insertions(+), 50 deletions(-) create mode 100644 tests/common/test_protobuf_validation.nim diff --git a/tests/all_tests_common.nim b/tests/all_tests_common.nim index 38392726d..56ad5ac01 100644 --- a/tests/all_tests_common.nim +++ b/tests/all_tests_common.nim @@ -4,4 +4,5 @@ import ./common/test_envvar_serialization, ./common/test_confutils_envvar, + ./common/test_protobuf_validation, ./common/test_sqlite_migrations diff --git a/tests/common/test_protobuf_validation.nim b/tests/common/test_protobuf_validation.nim new file mode 100644 index 000000000..150787e99 --- /dev/null +++ b/tests/common/test_protobuf_validation.nim @@ -0,0 +1,103 @@ + +{.used.} + +import + testutils/unittests +import + ../../waku/common/protobuf + + +## Fixtures + +const MaxTestRpcFieldLen = 5 + +type TestRpc = object + testField*: string + +proc init(T: type TestRpc, field: string): T = + T(testField: field) + +proc encode(rpc: TestRpc): ProtoBuffer = + var pb = initProtoBuffer() + pb.write3(1, rpc.testField) + pb.finish3() + pb + +proc encodeWithBadFieldId(rpc: TestRpc): ProtoBuffer = + var pb = initProtoBuffer() + pb.write3(666, rpc.testField) + pb.finish3() + pb + +proc decode(T: type TestRpc, buf: seq[byte]): ProtobufResult[T] = + let pb = initProtoBuffer(buf) + + var field: string + if not ?pb.getField(1, field): + return err(ProtobufError.missingRequiredField("test_field")) + if field.len > MaxTestRpcFieldLen: + return err(ProtobufError.invalidLengthField("test_field")) + + ok(TestRpc.init(field)) + + +## Tests + +suite "Waku Common - libp2p minprotobuf wrapper": + + test "serialize and deserialize - valid length field": + ## Given + let field = "12345" + + let rpc = TestRpc.init(field) + + ## When + let encodedRpc = rpc.encode() + let decodedRpcRes = TestRpc.decode(encodedRpc.buffer) + + ## Then + check: + decodedRpcRes.isOk() + + let decodedRpc = decodedRpcRes.tryGet() + check: + decodedRpc.testField == field + + test "serialize and deserialize - missing required field": + ## Given + let field = "12345" + + let rpc = TestRpc.init(field) + + ## When + let encodedRpc = rpc.encodeWithBadFieldId() + let decodedRpcRes = TestRpc.decode(encodedRpc.buffer) + + ## Then + check: + decodedRpcRes.isErr() + + let error = decodedRpcRes.tryError() + check: + error.kind == ProtobufErrorKind.MissingRequiredField + error.field == "test_field" + + + test "serialize and deserialize - invalid length field": + ## Given + let field = "123456" # field.len = MaxTestRpcFieldLen + 1 + + let rpc = TestRpc.init(field) + + ## When + let encodedRpc = rpc.encode() + let decodedRpcRes = TestRpc.decode(encodedRpc.buffer) + + ## Then + check: + decodedRpcRes.isErr() + + let error = decodedRpcRes.tryError() + check: + error.kind == ProtobufErrorKind.InvalidLengthField + error.field == "test_field" diff --git a/tests/v2/test_waku_noise_sessions.nim b/tests/v2/test_waku_noise_sessions.nim index 9775cc188..5cfbdc4cb 100644 --- a/tests/v2/test_waku_noise_sessions.nim +++ b/tests/v2/test_waku_noise_sessions.nim @@ -3,9 +3,9 @@ import std/tables, stew/[results, byteutils], - testutils/unittests, - libp2p/protobuf/minprotobuf + testutils/unittests import + ../../waku/common/protobuf, ../../waku/v2/utils/noise as waku_message_utils, ../../waku/v2/protocol/waku_noise/noise_types, ../../waku/v2/protocol/waku_noise/noise_utils, @@ -82,7 +82,7 @@ procSuite "Waku Noise Sessions": var sentTransportMessage: seq[byte] aliceStep, bobStep: HandshakeStepResult - msgFromPb: ProtoResult[WakuMessage] + msgFromPb: ProtobufResult[WakuMessage] wakuMsg: Result[WakuMessage, cstring] pb: ProtoBuffer readPayloadV2: PayloadV2 diff --git a/waku/common/protobuf.nim b/waku/common/protobuf.nim index 54bd7f8f6..db66198be 100644 --- a/waku/common/protobuf.nim +++ b/waku/common/protobuf.nim @@ -9,12 +9,47 @@ import std/options, libp2p/protobuf/minprotobuf, libp2p/varint - + export minprotobuf, varint +## Custom errors + +type + ProtobufErrorKind* {.pure.} = enum + DecodeFailure + MissingRequiredField + InvalidLengthField + + ProtobufError* = object + case kind*: ProtobufErrorKind + of DecodeFailure: + error*: minprotobuf.ProtoError + of MissingRequiredField, InvalidLengthField: + field*: string + + ProtobufResult*[T] = Result[T, ProtobufError] + + +converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError = + case err: + of minprotobuf.ProtoError.RequiredFieldMissing: + ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown") + else: + ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err) + + +proc missingRequiredField*(T: type ProtobufError, field: string): T = + ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field) + +proc invalidLengthField*(T: type ProtobufError, field: string): T = + ProtobufError(kind: ProtobufErrorKind.InvalidLengthField, field: field) + + +## Extension methods + proc write3*(proto: var ProtoBuffer, field: int, value: auto) = when value is Option: if value.isSome(): diff --git a/waku/v2/protocol/waku_filter/rpc_codec.nim b/waku/v2/protocol/waku_filter/rpc_codec.nim index db128229e..69ef1ab62 100644 --- a/waku/v2/protocol/waku_filter/rpc_codec.nim +++ b/waku/v2/protocol/waku_filter/rpc_codec.nim @@ -24,15 +24,15 @@ proc encode*(filter: ContentFilter): ProtoBuffer = pb -proc decode*(T: type ContentFilter, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type ContentFilter, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) var rpc = ContentFilter() - var contentTopic: string - if not ?pb.getField(1, contentTopic): - return err(ProtoError.RequiredFieldMissing) + var topic: string + if not ?pb.getField(1, topic): + return err(ProtobufError.missingRequiredField("content_topic")) else: - rpc.contentTopic = contentTopic + rpc.contentTopic = topic ok(rpc) @@ -50,25 +50,25 @@ proc encode*(rpc: FilterRequest): ProtoBuffer = pb -proc decode*(T: type FilterRequest, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type FilterRequest, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) var rpc = FilterRequest() var subflag: uint64 if not ?pb.getField(1, subflag): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("subscribe")) else: rpc.subscribe = bool(subflag) - var pubsubTopic: string - if not ?pb.getField(2, pubsubTopic): - return err(ProtoError.RequiredFieldMissing) + var topic: string + if not ?pb.getField(2, topic): + return err(ProtobufError.missingRequiredField("topic")) else: - rpc.pubsubTopic = pubsubTopic + rpc.pubsubTopic = topic var buffs: seq[seq[byte]] if not ?pb.getRepeatedField(3, buffs): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("content_filters")) else: for buf in buffs: let filter = ?ContentFilter.decode(buf) @@ -87,13 +87,13 @@ proc encode*(push: MessagePush): ProtoBuffer = pb -proc decode*(T: type MessagePush, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type MessagePush, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) var rpc = MessagePush() var messages: seq[seq[byte]] if not ?pb.getRepeatedField(1, messages): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("messages")) else: for buf in messages: let msg = ?WakuMessage.decode(buf) @@ -112,13 +112,13 @@ proc encode*(rpc: FilterRPC): ProtoBuffer = pb -proc decode*(T: type FilterRPC, buffer: seq[byte]): ProtoResult[T] = - let pb = initProtoBuffer(buffer) +proc decode*(T: type FilterRPC, buffer: seq[byte]): ProtobufResult[T] = + let pb = initProtoBuffer(buffer) var rpc = FilterRPC() var requestId: string if not ?pb.getField(1, requestId): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("request_id")) else: rpc.requestId = requestId diff --git a/waku/v2/protocol/waku_lightpush/rpc_codec.nim b/waku/v2/protocol/waku_lightpush/rpc_codec.nim index b71a9c06e..9db2c5c5c 100644 --- a/waku/v2/protocol/waku_lightpush/rpc_codec.nim +++ b/waku/v2/protocol/waku_lightpush/rpc_codec.nim @@ -24,19 +24,19 @@ proc encode*(rpc: PushRequest): ProtoBuffer = pb -proc decode*(T: type PushRequest, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type PushRequest, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) var rpc = PushRequest() var pubSubTopic: PubsubTopic if not ?pb.getField(1, pubSubTopic): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("pubsub_topic")) else: rpc.pubSubTopic = pubSubTopic var messageBuf: seq[byte] if not ?pb.getField(2, messageBuf): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("message")) else: rpc.message = ?WakuMessage.decode(messageBuf) @@ -52,20 +52,20 @@ proc encode*(rpc: PushResponse): ProtoBuffer = pb -proc decode*(T: type PushResponse, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type PushResponse, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) var rpc = PushResponse() var isSuccess: uint64 if not ?pb.getField(1, isSuccess): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("is_success")) else: rpc.isSuccess = bool(isSuccess) var info: string if not ?pb.getField(2, info): rpc.info = none(string) - else: + else: rpc.info = some(info) ok(rpc) @@ -73,7 +73,7 @@ proc decode*(T: type PushResponse, buffer: seq[byte]): ProtoResult[T] = proc encode*(rpc: PushRPC): ProtoBuffer = var pb = initProtoBuffer() - + pb.write3(1, rpc.requestId) pb.write3(2, rpc.request.map(encode)) pb.write3(3, rpc.response.map(encode)) @@ -81,13 +81,13 @@ proc encode*(rpc: PushRPC): ProtoBuffer = pb -proc decode*(T: type PushRPC, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type PushRPC, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) var rpc = PushRPC() var requestId: string if not ?pb.getField(1, requestId): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("request_id")) else: rpc.requestId = requestId @@ -105,4 +105,4 @@ proc decode*(T: type PushRPC, buffer: seq[byte]): ProtoResult[T] = let response = ?PushResponse.decode(responseBuffer) rpc.response = some(response) - ok(rpc) \ No newline at end of file + ok(rpc) diff --git a/waku/v2/protocol/waku_message/codec.nim b/waku/v2/protocol/waku_message/codec.nim index 7cde86b52..a11718414 100644 --- a/waku/v2/protocol/waku_message/codec.nim +++ b/waku/v2/protocol/waku_message/codec.nim @@ -28,34 +28,40 @@ proc encode*(message: WakuMessage): ProtoBuffer = buf -proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtoResult[T] = - var msg = WakuMessage(ephemeral: false) + +proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtobufResult[T] = + var msg = WakuMessage() let pb = initProtoBuffer(buffer) + var payload: seq[byte] if not ?pb.getField(1, payload): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("payload")) else: msg.payload = payload + var topic: ContentTopic if not ?pb.getField(2, topic): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("content_topic")) else: msg.contentTopic = topic + var version: uint32 if not ?pb.getField(3, version): msg.version = 0 else: msg.version = version + var timestamp: zint64 if not ?pb.getField(10, timestamp): msg.timestamp = Timestamp(0) else: msg.timestamp = Timestamp(timestamp) + # Experimental: this is part of https://rfc.vac.dev/spec/17/ spec when defined(rln): var proof: seq[byte] @@ -64,6 +70,7 @@ proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtoResult[T] = else: msg.proof = proof + var ephemeral: uint if not ?pb.getField(31, ephemeral): msg.ephemeral = false diff --git a/waku/v2/protocol/waku_message/message.nim b/waku/v2/protocol/waku_message/message.nim index 3f9bd73bd..dbb11bc8b 100644 --- a/waku/v2/protocol/waku_message/message.nim +++ b/waku/v2/protocol/waku_message/message.nim @@ -12,7 +12,8 @@ else: import ../../utils/time -const MaxWakuMessageSize* = 1024 * 1024 # In bytes. Corresponds to PubSub default +const + MaxWakuMessageSize* = 1024 * 1024 # 1 Mibytes. Corresponds to PubSub default type diff --git a/waku/v2/protocol/waku_store/rpc_codec.nim b/waku/v2/protocol/waku_store/rpc_codec.nim index 853f2ee64..4a7276d1e 100644 --- a/waku/v2/protocol/waku_store/rpc_codec.nim +++ b/waku/v2/protocol/waku_store/rpc_codec.nim @@ -31,14 +31,14 @@ proc encode*(index: PagingIndexRPC): ProtoBuffer = pb -proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtobufResult[T] = ## creates and returns an Index object out of buffer var rpc = PagingIndexRPC() let pb = initProtoBuffer(buffer) var data: seq[byte] if not ?pb.getField(1, data): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("digest")) else: var digest = MessageDigest() for count, b in data: @@ -48,19 +48,19 @@ proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtoResult[T] = var receiverTime: zint64 if not ?pb.getField(2, receiverTime): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("receiver_time")) else: rpc.receiverTime = int64(receiverTime) var senderTime: zint64 if not ?pb.getField(3, senderTime): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("sender_time")) else: rpc.senderTime = int64(senderTime) var pubsubTopic: string if not ?pb.getField(4, pubsubTopic): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("pubsub_topic")) else: rpc.pubsubTopic = pubsubTopic @@ -79,7 +79,7 @@ proc encode*(rpc: PagingInfoRPC): ProtoBuffer = pb -proc decode*(T: type PagingInfoRPC, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type PagingInfoRPC, buffer: seq[byte]): ProtobufResult[T] = ## creates and returns a PagingInfo object out of buffer var rpc = PagingInfoRPC() let pb = initProtoBuffer(buffer) @@ -116,13 +116,12 @@ proc encode*(rpc: HistoryContentFilterRPC): ProtoBuffer = pb -proc decode*(T: type HistoryContentFilterRPC, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type HistoryContentFilterRPC, buffer: seq[byte]): ProtobufResult[T] = let pb = initProtoBuffer(buffer) var contentTopic: ContentTopic if not ?pb.getField(1, contentTopic): - return err(ProtoError.RequiredFieldMissing) - + return err(ProtobufError.missingRequiredField("content_topic")) ok(HistoryContentFilterRPC(contentTopic: contentTopic)) @@ -140,7 +139,7 @@ proc encode*(rpc: HistoryQueryRPC): ProtoBuffer = pb -proc decode*(T: type HistoryQueryRPC, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type HistoryQueryRPC, buffer: seq[byte]): ProtobufResult[T] = var rpc = HistoryQueryRPC() let pb = initProtoBuffer(buffer) @@ -192,7 +191,7 @@ proc encode*(response: HistoryResponseRPC): ProtoBuffer = pb -proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtobufResult[T] = var rpc = HistoryResponseRPC() let pb = initProtoBuffer(buffer) @@ -213,7 +212,7 @@ proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtoResult[T] = var error: uint32 if not ?pb.getField(4, error): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("error")) else: rpc.error = HistoryResponseErrorRPC.parse(error) @@ -230,12 +229,12 @@ proc encode*(rpc: HistoryRPC): ProtoBuffer = pb -proc decode*(T: type HistoryRPC, buffer: seq[byte]): ProtoResult[T] = +proc decode*(T: type HistoryRPC, buffer: seq[byte]): ProtobufResult[T] = var rpc = HistoryRPC() let pb = initProtoBuffer(buffer) if not ?pb.getField(1, rpc.requestId): - return err(ProtoError.RequiredFieldMissing) + return err(ProtobufError.missingRequiredField("request_id")) var queryBuffer: seq[byte] if not ?pb.getField(2, queryBuffer):