feat(protobuf): added error wrappers for invalid length validation

This commit is contained in:
Lorenzo Delgado 2023-02-20 15:03:32 +01:00 committed by GitHub
parent 8cb418a9e4
commit f7584dfc49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 196 additions and 50 deletions

View File

@ -4,4 +4,5 @@
import
./common/test_envvar_serialization,
./common/test_confutils_envvar,
./common/test_protobuf_validation,
./common/test_sqlite_migrations

View File

@ -0,0 +1,103 @@
{.used.}
import
testutils/unittests
import
../../waku/common/protobuf
## Fixtures
const MaxTestRpcFieldLen = 5
type TestRpc = object
testField*: string
proc init(T: type TestRpc, field: string): T =
T(testField: field)
proc encode(rpc: TestRpc): ProtoBuffer =
var pb = initProtoBuffer()
pb.write3(1, rpc.testField)
pb.finish3()
pb
proc encodeWithBadFieldId(rpc: TestRpc): ProtoBuffer =
var pb = initProtoBuffer()
pb.write3(666, rpc.testField)
pb.finish3()
pb
proc decode(T: type TestRpc, buf: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buf)
var field: string
if not ?pb.getField(1, field):
return err(ProtobufError.missingRequiredField("test_field"))
if field.len > MaxTestRpcFieldLen:
return err(ProtobufError.invalidLengthField("test_field"))
ok(TestRpc.init(field))
## Tests
suite "Waku Common - libp2p minprotobuf wrapper":
test "serialize and deserialize - valid length field":
## Given
let field = "12345"
let rpc = TestRpc.init(field)
## When
let encodedRpc = rpc.encode()
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)
## Then
check:
decodedRpcRes.isOk()
let decodedRpc = decodedRpcRes.tryGet()
check:
decodedRpc.testField == field
test "serialize and deserialize - missing required field":
## Given
let field = "12345"
let rpc = TestRpc.init(field)
## When
let encodedRpc = rpc.encodeWithBadFieldId()
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)
## Then
check:
decodedRpcRes.isErr()
let error = decodedRpcRes.tryError()
check:
error.kind == ProtobufErrorKind.MissingRequiredField
error.field == "test_field"
test "serialize and deserialize - invalid length field":
## Given
let field = "123456" # field.len = MaxTestRpcFieldLen + 1
let rpc = TestRpc.init(field)
## When
let encodedRpc = rpc.encode()
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)
## Then
check:
decodedRpcRes.isErr()
let error = decodedRpcRes.tryError()
check:
error.kind == ProtobufErrorKind.InvalidLengthField
error.field == "test_field"

View File

@ -3,9 +3,9 @@
import
std/tables,
stew/[results, byteutils],
testutils/unittests,
libp2p/protobuf/minprotobuf
testutils/unittests
import
../../waku/common/protobuf,
../../waku/v2/utils/noise as waku_message_utils,
../../waku/v2/protocol/waku_noise/noise_types,
../../waku/v2/protocol/waku_noise/noise_utils,
@ -82,7 +82,7 @@ procSuite "Waku Noise Sessions":
var
sentTransportMessage: seq[byte]
aliceStep, bobStep: HandshakeStepResult
msgFromPb: ProtoResult[WakuMessage]
msgFromPb: ProtobufResult[WakuMessage]
wakuMsg: Result[WakuMessage, cstring]
pb: ProtoBuffer
readPayloadV2: PayloadV2

View File

@ -15,6 +15,41 @@ export
varint
## Custom errors
type
ProtobufErrorKind* {.pure.} = enum
DecodeFailure
MissingRequiredField
InvalidLengthField
ProtobufError* = object
case kind*: ProtobufErrorKind
of DecodeFailure:
error*: minprotobuf.ProtoError
of MissingRequiredField, InvalidLengthField:
field*: string
ProtobufResult*[T] = Result[T, ProtobufError]
converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError =
case err:
of minprotobuf.ProtoError.RequiredFieldMissing:
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown")
else:
ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err)
proc missingRequiredField*(T: type ProtobufError, field: string): T =
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field)
proc invalidLengthField*(T: type ProtobufError, field: string): T =
ProtobufError(kind: ProtobufErrorKind.InvalidLengthField, field: field)
## Extension methods
proc write3*(proto: var ProtoBuffer, field: int, value: auto) =
when value is Option:
if value.isSome():

View File

@ -24,15 +24,15 @@ proc encode*(filter: ContentFilter): ProtoBuffer =
pb
proc decode*(T: type ContentFilter, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type ContentFilter, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = ContentFilter()
var contentTopic: string
if not ?pb.getField(1, contentTopic):
return err(ProtoError.RequiredFieldMissing)
var topic: string
if not ?pb.getField(1, topic):
return err(ProtobufError.missingRequiredField("content_topic"))
else:
rpc.contentTopic = contentTopic
rpc.contentTopic = topic
ok(rpc)
@ -50,25 +50,25 @@ proc encode*(rpc: FilterRequest): ProtoBuffer =
pb
proc decode*(T: type FilterRequest, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type FilterRequest, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = FilterRequest()
var subflag: uint64
if not ?pb.getField(1, subflag):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("subscribe"))
else:
rpc.subscribe = bool(subflag)
var pubsubTopic: string
if not ?pb.getField(2, pubsubTopic):
return err(ProtoError.RequiredFieldMissing)
var topic: string
if not ?pb.getField(2, topic):
return err(ProtobufError.missingRequiredField("topic"))
else:
rpc.pubsubTopic = pubsubTopic
rpc.pubsubTopic = topic
var buffs: seq[seq[byte]]
if not ?pb.getRepeatedField(3, buffs):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("content_filters"))
else:
for buf in buffs:
let filter = ?ContentFilter.decode(buf)
@ -87,13 +87,13 @@ proc encode*(push: MessagePush): ProtoBuffer =
pb
proc decode*(T: type MessagePush, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type MessagePush, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = MessagePush()
var messages: seq[seq[byte]]
if not ?pb.getRepeatedField(1, messages):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("messages"))
else:
for buf in messages:
let msg = ?WakuMessage.decode(buf)
@ -112,13 +112,13 @@ proc encode*(rpc: FilterRPC): ProtoBuffer =
pb
proc decode*(T: type FilterRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type FilterRPC, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = FilterRPC()
var requestId: string
if not ?pb.getField(1, requestId):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("request_id"))
else:
rpc.requestId = requestId

View File

@ -24,19 +24,19 @@ proc encode*(rpc: PushRequest): ProtoBuffer =
pb
proc decode*(T: type PushRequest, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type PushRequest, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = PushRequest()
var pubSubTopic: PubsubTopic
if not ?pb.getField(1, pubSubTopic):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("pubsub_topic"))
else:
rpc.pubSubTopic = pubSubTopic
var messageBuf: seq[byte]
if not ?pb.getField(2, messageBuf):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("message"))
else:
rpc.message = ?WakuMessage.decode(messageBuf)
@ -52,13 +52,13 @@ proc encode*(rpc: PushResponse): ProtoBuffer =
pb
proc decode*(T: type PushResponse, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type PushResponse, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = PushResponse()
var isSuccess: uint64
if not ?pb.getField(1, isSuccess):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("is_success"))
else:
rpc.isSuccess = bool(isSuccess)
@ -81,13 +81,13 @@ proc encode*(rpc: PushRPC): ProtoBuffer =
pb
proc decode*(T: type PushRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type PushRPC, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = PushRPC()
var requestId: string
if not ?pb.getField(1, requestId):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("request_id"))
else:
rpc.requestId = requestId

View File

@ -28,34 +28,40 @@ proc encode*(message: WakuMessage): ProtoBuffer =
buf
proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtoResult[T] =
var msg = WakuMessage(ephemeral: false)
proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtobufResult[T] =
var msg = WakuMessage()
let pb = initProtoBuffer(buffer)
var payload: seq[byte]
if not ?pb.getField(1, payload):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("payload"))
else:
msg.payload = payload
var topic: ContentTopic
if not ?pb.getField(2, topic):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("content_topic"))
else:
msg.contentTopic = topic
var version: uint32
if not ?pb.getField(3, version):
msg.version = 0
else:
msg.version = version
var timestamp: zint64
if not ?pb.getField(10, timestamp):
msg.timestamp = Timestamp(0)
else:
msg.timestamp = Timestamp(timestamp)
# Experimental: this is part of https://rfc.vac.dev/spec/17/ spec
when defined(rln):
var proof: seq[byte]
@ -64,6 +70,7 @@ proc decode*(T: type WakuMessage, buffer: seq[byte]): ProtoResult[T] =
else:
msg.proof = proof
var ephemeral: uint
if not ?pb.getField(31, ephemeral):
msg.ephemeral = false

View File

@ -12,7 +12,8 @@ else:
import
../../utils/time
const MaxWakuMessageSize* = 1024 * 1024 # In bytes. Corresponds to PubSub default
const
MaxWakuMessageSize* = 1024 * 1024 # 1 Mibytes. Corresponds to PubSub default
type

View File

@ -31,14 +31,14 @@ proc encode*(index: PagingIndexRPC): ProtoBuffer =
pb
proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtobufResult[T] =
## creates and returns an Index object out of buffer
var rpc = PagingIndexRPC()
let pb = initProtoBuffer(buffer)
var data: seq[byte]
if not ?pb.getField(1, data):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("digest"))
else:
var digest = MessageDigest()
for count, b in data:
@ -48,19 +48,19 @@ proc decode*(T: type PagingIndexRPC, buffer: seq[byte]): ProtoResult[T] =
var receiverTime: zint64
if not ?pb.getField(2, receiverTime):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("receiver_time"))
else:
rpc.receiverTime = int64(receiverTime)
var senderTime: zint64
if not ?pb.getField(3, senderTime):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("sender_time"))
else:
rpc.senderTime = int64(senderTime)
var pubsubTopic: string
if not ?pb.getField(4, pubsubTopic):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("pubsub_topic"))
else:
rpc.pubsubTopic = pubsubTopic
@ -79,7 +79,7 @@ proc encode*(rpc: PagingInfoRPC): ProtoBuffer =
pb
proc decode*(T: type PagingInfoRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type PagingInfoRPC, buffer: seq[byte]): ProtobufResult[T] =
## creates and returns a PagingInfo object out of buffer
var rpc = PagingInfoRPC()
let pb = initProtoBuffer(buffer)
@ -116,13 +116,12 @@ proc encode*(rpc: HistoryContentFilterRPC): ProtoBuffer =
pb
proc decode*(T: type HistoryContentFilterRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type HistoryContentFilterRPC, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var contentTopic: ContentTopic
if not ?pb.getField(1, contentTopic):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("content_topic"))
ok(HistoryContentFilterRPC(contentTopic: contentTopic))
@ -140,7 +139,7 @@ proc encode*(rpc: HistoryQueryRPC): ProtoBuffer =
pb
proc decode*(T: type HistoryQueryRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type HistoryQueryRPC, buffer: seq[byte]): ProtobufResult[T] =
var rpc = HistoryQueryRPC()
let pb = initProtoBuffer(buffer)
@ -192,7 +191,7 @@ proc encode*(response: HistoryResponseRPC): ProtoBuffer =
pb
proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtobufResult[T] =
var rpc = HistoryResponseRPC()
let pb = initProtoBuffer(buffer)
@ -213,7 +212,7 @@ proc decode*(T: type HistoryResponseRPC, buffer: seq[byte]): ProtoResult[T] =
var error: uint32
if not ?pb.getField(4, error):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("error"))
else:
rpc.error = HistoryResponseErrorRPC.parse(error)
@ -230,12 +229,12 @@ proc encode*(rpc: HistoryRPC): ProtoBuffer =
pb
proc decode*(T: type HistoryRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type HistoryRPC, buffer: seq[byte]): ProtobufResult[T] =
var rpc = HistoryRPC()
let pb = initProtoBuffer(buffer)
if not ?pb.getField(1, rpc.requestId):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("request_id"))
var queryBuffer: seq[byte]
if not ?pb.getField(2, queryBuffer):