mirror of
https://github.com/waku-org/nwaku.git
synced 2025-01-14 17:04:53 +00:00
feat(protobuf): added error wrappers for invalid length validation
This commit is contained in:
parent
8cb418a9e4
commit
f7584dfc49
@ -4,4 +4,5 @@
|
||||
import
|
||||
./common/test_envvar_serialization,
|
||||
./common/test_confutils_envvar,
|
||||
./common/test_protobuf_validation,
|
||||
./common/test_sqlite_migrations
|
||||
|
103
tests/common/test_protobuf_validation.nim
Normal file
103
tests/common/test_protobuf_validation.nim
Normal file
@ -0,0 +1,103 @@
|
||||
|
||||
{.used.}
|
||||
|
||||
import
|
||||
testutils/unittests
|
||||
import
|
||||
../../waku/common/protobuf
|
||||
|
||||
|
||||
## Fixtures
|
||||
|
||||
const MaxTestRpcFieldLen = 5
|
||||
|
||||
type TestRpc = object
|
||||
testField*: string
|
||||
|
||||
proc init(T: type TestRpc, field: string): T =
|
||||
T(testField: field)
|
||||
|
||||
proc encode(rpc: TestRpc): ProtoBuffer =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write3(1, rpc.testField)
|
||||
pb.finish3()
|
||||
pb
|
||||
|
||||
proc encodeWithBadFieldId(rpc: TestRpc): ProtoBuffer =
|
||||
var pb = initProtoBuffer()
|
||||
pb.write3(666, rpc.testField)
|
||||
pb.finish3()
|
||||
pb
|
||||
|
||||
proc decode(T: type TestRpc, buf: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buf)
|
||||
|
||||
var field: string
|
||||
if not ?pb.getField(1, field):
|
||||
return err(ProtobufError.missingRequiredField("test_field"))
|
||||
if field.len > MaxTestRpcFieldLen:
|
||||
return err(ProtobufError.invalidLengthField("test_field"))
|
||||
|
||||
ok(TestRpc.init(field))
|
||||
|
||||
|
||||
## Tests
|
||||
|
||||
suite "Waku Common - libp2p minprotobuf wrapper":
|
||||
|
||||
test "serialize and deserialize - valid length field":
|
||||
## Given
|
||||
let field = "12345"
|
||||
|
||||
let rpc = TestRpc.init(field)
|
||||
|
||||
## When
|
||||
let encodedRpc = rpc.encode()
|
||||
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)
|
||||
|
||||
## Then
|
||||
check:
|
||||
decodedRpcRes.isOk()
|
||||
|
||||
let decodedRpc = decodedRpcRes.tryGet()
|
||||
check:
|
||||
decodedRpc.testField == field
|
||||
|
||||
test "serialize and deserialize - missing required field":
|
||||
## Given
|
||||
let field = "12345"
|
||||
|
||||
let rpc = TestRpc.init(field)
|
||||
|
||||
## When
|
||||
let encodedRpc = rpc.encodeWithBadFieldId()
|
||||
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)
|
||||
|
||||
## Then
|
||||
check:
|
||||
decodedRpcRes.isErr()
|
||||
|
||||
let error = decodedRpcRes.tryError()
|
||||
check:
|
||||
error.kind == ProtobufErrorKind.MissingRequiredField
|
||||
error.field == "test_field"
|
||||
|
||||
|
||||
test "serialize and deserialize - invalid length field":
|
||||
## Given
|
||||
let field = "123456" # field.len = MaxTestRpcFieldLen + 1
|
||||
|
||||
let rpc = TestRpc.init(field)
|
||||
|
||||
## When
|
||||
let encodedRpc = rpc.encode()
|
||||
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)
|
||||
|
||||
## Then
|
||||
check:
|
||||
decodedRpcRes.isErr()
|
||||
|
||||
let error = decodedRpcRes.tryError()
|
||||
check:
|
||||
error.kind == ProtobufErrorKind.InvalidLengthField
|
||||
error.field == "test_field"
|
@ -3,9 +3,9 @@
|
||||
import
|
||||
std/tables,
|
||||
stew/[results, byteutils],
|
||||
testutils/unittests,
|
||||
libp2p/protobuf/minprotobuf
|
||||
testutils/unittests
|
||||
import
|
||||
../../waku/common/protobuf,
|
||||
../../waku/v2/utils/noise as waku_message_utils,
|
||||
../../waku/v2/protocol/waku_noise/noise_types,
|
||||
../../waku/v2/protocol/waku_noise/noise_utils,
|
||||
@ -82,7 +82,7 @@ procSuite "Waku Noise Sessions":
|
||||
var
|
||||
sentTransportMessage: seq[byte]
|
||||
aliceStep, bobStep: HandshakeStepResult
|
||||
msgFromPb: ProtoResult[WakuMessage]
|
||||
msgFromPb: ProtobufResult[WakuMessage]
|
||||
wakuMsg: Result[WakuMessage, cstring]
|
||||
pb: ProtoBuffer
|
||||
readPayloadV2: PayloadV2
|
||||
|
@ -15,6 +15,41 @@ export
|
||||
varint
|
||||
|
||||
|
||||
## Custom errors
|
||||
|
||||
type
|
||||
ProtobufErrorKind* {.pure.} = enum
|
||||
DecodeFailure
|
||||
MissingRequiredField
|
||||
InvalidLengthField
|
||||
|
||||
ProtobufError* = object
|
||||
case kind*: ProtobufErrorKind
|
||||
of DecodeFailure:
|
||||
error*: minprotobuf.ProtoError
|
||||
of MissingRequiredField, InvalidLengthField:
|
||||
field*: string
|
||||
|
||||
ProtobufResult*[T] = Result[T, ProtobufError]
|
||||
|
||||
|
||||
converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError =
|
||||
case err:
|
||||
of minprotobuf.ProtoError.RequiredFieldMissing:
|
||||
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown")
|
||||
else:
|
||||
ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err)
|
||||
|
||||
|
||||
proc missingRequiredField*(T: type ProtobufError, field: string): T =
|
||||
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field)
|
||||
|
||||
proc invalidLengthField*(T: type ProtobufError, field: string): T =
|
||||
ProtobufError(kind: ProtobufErrorKind.InvalidLengthField, field: field)
|
||||
|
||||
|
||||
## Extension methods
|
||||
|
||||
proc write3*(proto: var ProtoBuffer, field: int, value: auto) =
|
||||
when value is Option:
|
||||
if value.isSome():
|
||||
|
@ -24,15 +24,15 @@ proc encode*(filter: ContentFilter): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type ContentFilter, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type ContentFilter, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var rpc = ContentFilter()
|
||||
|
||||
var contentTopic: string
|
||||
if not ?pb.getField(1, contentTopic):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
var topic: string
|
||||
if not ?pb.getField(1, topic):
|
||||
return err(ProtobufError.missingRequiredField("content_topic"))
|
||||
else:
|
||||
rpc.contentTopic = contentTopic
|
||||
rpc.contentTopic = topic
|
||||
|
||||
ok(rpc)
|
||||
|
||||
@ -50,25 +50,25 @@ proc encode*(rpc: FilterRequest): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type FilterRequest, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type FilterRequest, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var rpc = FilterRequest()
|
||||
|
||||
var subflag: uint64
|
||||
if not ?pb.getField(1, subflag):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("subscribe"))
|
||||
else:
|
||||
rpc.subscribe = bool(subflag)
|
||||
|
||||
var pubsubTopic: string
|
||||
if not ?pb.getField(2, pubsubTopic):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
var topic: string
|
||||
if not ?pb.getField(2, topic):
|
||||
return err(ProtobufError.missingRequiredField("topic"))
|
||||
else:
|
||||
rpc.pubsubTopic = pubsubTopic
|
||||
rpc.pubsubTopic = topic
|
||||
|
||||
var buffs: seq[seq[byte]]
|
||||
if not ?pb.getRepeatedField(3, buffs):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("content_filters"))
|
||||
else:
|
||||
for buf in buffs:
|
||||
let filter = ?ContentFilter.decode(buf)
|
||||
@ -87,13 +87,13 @@ proc encode*(push: MessagePush): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type MessagePush, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type MessagePush, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var rpc = MessagePush()
|
||||
|
||||
var messages: seq[seq[byte]]
|
||||
if not ?pb.getRepeatedField(1, messages):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("messages"))
|
||||
else:
|
||||
for buf in messages:
|
||||
let msg = ?WakuMessage.decode(buf)
|
||||
@ -112,13 +112,13 @@ proc encode*(rpc: FilterRPC): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type FilterRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type FilterRPC, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var rpc = FilterRPC()
|
||||
|
||||
var requestId: string
|
||||
if not ?pb.getField(1, requestId):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("request_id"))
|
||||
else:
|
||||
rpc.requestId = requestId
|
||||
|
||||
|
@ -24,19 +24,19 @@ proc encode*(rpc: PushRequest): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type PushRequest, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type PushRequest, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var rpc = PushRequest()
|
||||
|
||||
var pubSubTopic: PubsubTopic
|
||||
if not ?pb.getField(1, pubSubTopic):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("pubsub_topic"))
|
||||
else:
|
||||
rpc.pubSubTopic = pubSubTopic
|
||||
|
||||
var messageBuf: seq[byte]
|
||||
if not ?pb.getField(2, messageBuf):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("message"))
|
||||
else:
|
||||
rpc.message = ?WakuMessage.decode(messageBuf)
|
||||
|
||||
@ -52,13 +52,13 @@ proc encode*(rpc: PushResponse): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type PushResponse, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type PushResponse, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var rpc = PushResponse()
|
||||
|
||||
var isSuccess: uint64
|
||||
if not ?pb.getField(1, isSuccess):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("is_success"))
|
||||
else:
|
||||
rpc.isSuccess = bool(isSuccess)
|
||||
|
||||
@ -81,13 +81,13 @@ proc encode*(rpc: PushRPC): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type PushRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type PushRPC, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
var rpc = PushRPC()
|
||||
|
||||
var requestId: string
|
||||
if not ?pb.getField(1, requestId):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("request_id"))
|
||||
else:
|
||||
rpc.requestId = requestId
|
||||
|
||||
|
@ -28,34 +28,40 @@ proc encode*(message: WakuMessage): ProtoBuffer =
|
||||
|
||||
buf
|
||||
|
||||
proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtoResult[T] =
|
||||
var msg = WakuMessage(ephemeral: false)
|
||||
|
||||
proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtobufResult[T] =
|
||||
var msg = WakuMessage()
|
||||
let pb = initProtoBuffer(buffer)
|
||||
|
||||
|
||||
var payload: seq[byte]
|
||||
if not ?pb.getField(1, payload):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("payload"))
|
||||
else:
|
||||
msg.payload = payload
|
||||
|
||||
|
||||
var topic: ContentTopic
|
||||
if not ?pb.getField(2, topic):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("content_topic"))
|
||||
else:
|
||||
msg.contentTopic = topic
|
||||
|
||||
|
||||
var version: uint32
|
||||
if not ?pb.getField(3, version):
|
||||
msg.version = 0
|
||||
else:
|
||||
msg.version = version
|
||||
|
||||
|
||||
var timestamp: zint64
|
||||
if not ?pb.getField(10, timestamp):
|
||||
msg.timestamp = Timestamp(0)
|
||||
else:
|
||||
msg.timestamp = Timestamp(timestamp)
|
||||
|
||||
|
||||
# Experimental: this is part of https://rfc.vac.dev/spec/17/ spec
|
||||
when defined(rln):
|
||||
var proof: seq[byte]
|
||||
@ -64,6 +70,7 @@ proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtoResult[T] =
|
||||
else:
|
||||
msg.proof = proof
|
||||
|
||||
|
||||
var ephemeral: uint
|
||||
if not ?pb.getField(31, ephemeral):
|
||||
msg.ephemeral = false
|
||||
|
@ -12,7 +12,8 @@ else:
|
||||
import
|
||||
../../utils/time
|
||||
|
||||
const MaxWakuMessageSize* = 1024 * 1024 # In bytes. Corresponds to PubSub default
|
||||
const
|
||||
MaxWakuMessageSize* = 1024 * 1024 # 1 Mibytes. Corresponds to PubSub default
|
||||
|
||||
|
||||
type
|
||||
|
@ -31,14 +31,14 @@ proc encode*(index: PagingIndexRPC): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtobufResult[T] =
|
||||
## creates and returns an Index object out of buffer
|
||||
var rpc = PagingIndexRPC()
|
||||
let pb = initProtoBuffer(buffer)
|
||||
|
||||
var data: seq[byte]
|
||||
if not ?pb.getField(1, data):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("digest"))
|
||||
else:
|
||||
var digest = MessageDigest()
|
||||
for count, b in data:
|
||||
@ -48,19 +48,19 @@ proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
|
||||
var receiverTime: zint64
|
||||
if not ?pb.getField(2, receiverTime):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("receiver_time"))
|
||||
else:
|
||||
rpc.receiverTime = int64(receiverTime)
|
||||
|
||||
var senderTime: zint64
|
||||
if not ?pb.getField(3, senderTime):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("sender_time"))
|
||||
else:
|
||||
rpc.senderTime = int64(senderTime)
|
||||
|
||||
var pubsubTopic: string
|
||||
if not ?pb.getField(4, pubsubTopic):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("pubsub_topic"))
|
||||
else:
|
||||
rpc.pubsubTopic = pubsubTopic
|
||||
|
||||
@ -79,7 +79,7 @@ proc encode*(rpc: PagingInfoRPC): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type PagingInfoRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type PagingInfoRPC, buffer: seq[byte]): ProtobufResult[T] =
|
||||
## creates and returns a PagingInfo object out of buffer
|
||||
var rpc = PagingInfoRPC()
|
||||
let pb = initProtoBuffer(buffer)
|
||||
@ -116,13 +116,12 @@ proc encode*(rpc: HistoryContentFilterRPC): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type HistoryContentFilterRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type HistoryContentFilterRPC, buffer: seq[byte]): ProtobufResult[T] =
|
||||
let pb = initProtoBuffer(buffer)
|
||||
|
||||
var contentTopic: ContentTopic
|
||||
if not ?pb.getField(1, contentTopic):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
|
||||
return err(ProtobufError.missingRequiredField("content_topic"))
|
||||
ok(HistoryContentFilterRPC(contentTopic: contentTopic))
|
||||
|
||||
|
||||
@ -140,7 +139,7 @@ proc encode*(rpc: HistoryQueryRPC): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type HistoryQueryRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type HistoryQueryRPC, buffer: seq[byte]): ProtobufResult[T] =
|
||||
var rpc = HistoryQueryRPC()
|
||||
let pb = initProtoBuffer(buffer)
|
||||
|
||||
@ -192,7 +191,7 @@ proc encode*(response: HistoryResponseRPC): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtobufResult[T] =
|
||||
var rpc = HistoryResponseRPC()
|
||||
let pb = initProtoBuffer(buffer)
|
||||
|
||||
@ -213,7 +212,7 @@ proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
|
||||
var error: uint32
|
||||
if not ?pb.getField(4, error):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("error"))
|
||||
else:
|
||||
rpc.error = HistoryResponseErrorRPC.parse(error)
|
||||
|
||||
@ -230,12 +229,12 @@ proc encode*(rpc: HistoryRPC): ProtoBuffer =
|
||||
|
||||
pb
|
||||
|
||||
proc decode*(T: type HistoryRPC, buffer: seq[byte]): ProtoResult[T] =
|
||||
proc decode*(T: type HistoryRPC, buffer: seq[byte]): ProtobufResult[T] =
|
||||
var rpc = HistoryRPC()
|
||||
let pb = initProtoBuffer(buffer)
|
||||
|
||||
if not ?pb.getField(1, rpc.requestId):
|
||||
return err(ProtoError.RequiredFieldMissing)
|
||||
return err(ProtobufError.missingRequiredField("request_id"))
|
||||
|
||||
var queryBuffer: seq[byte]
|
||||
if not ?pb.getField(2, queryBuffer):
|
||||
|
Loading…
x
Reference in New Issue
Block a user